In [275]:
#!wget https://dataworks.indianapolis.iu.edu/bitstream/handle/11243/41/data.zip
#!unzip -q data.zip
#!rm data.zip

In [276]:
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from functools import reduce

torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


# Dataset import

In [277]:
class ImageDNATrainDataset():
    def __init__(self, train=True):
        splits_mat = scipy.io.loadmat("data/INSECTS/splits.mat")
        train_loc = splits_mat["train_loc"]-1

        TRAINING_SAMPLES_NUMBER = 12481
        TRAINING_LABELS_NUMBER = 652

        assert len(train_loc[0]) == TRAINING_SAMPLES_NUMBER

        indeces = train_loc
        # indeces.shape is (1, |indeces|), so we extract the whole list using [0]
        indeces = indeces[0]

        data_mat = scipy.io.loadmat("data/INSECTS/data.mat")
        self.embeddings_img = torch.from_numpy(
            data_mat["embeddings_img"][indeces]
        ).float()
        self.embeddings_dna = torch.from_numpy(
            data_mat["embeddings_dna"][indeces]
        ).float()

        # Remap seen species in [0, 651]
        seen_species = data_mat["labels"][train_loc][0]
        seen_species_mapping = {label: i for i, label in enumerate(np.unique(seen_species))}

        species_mapping = seen_species_mapping
        assert len(species_mapping) == TRAINING_LABELS_NUMBER

        species = data_mat["labels"][indeces]
        remapped_species = np.array([species_mapping[label.item()] for label in species])
        self.remapped_species = torch.from_numpy(remapped_species).long()

        assert len(torch.unique(self.remapped_species)) == TRAINING_LABELS_NUMBER

        # data_mat['G'] returns a ndarray of type uint16, therefore we convert into int16 before invoking from_numpy
        self.G = torch.from_numpy(data_mat["G"].astype(np.int16)).long()
        self.genera = torch.empty(species.shape).long()
        for i in range(indeces.size):
            self.genera[i][0] = self.G[species[i][0] - 1][0] - 1041

        assert len(self.genera) == TRAINING_SAMPLES_NUMBER

        self.species_names = data_mat["species"][indeces]
        self.ids = data_mat["ids"][indeces]

    def __len__(self):
        return len(self.embeddings_dna)

    def __getitem__(self, idx):
        embedding_img = self.embeddings_img[idx]
        embedding_dna = self.embeddings_dna[idx]
        label = self.remapped_species[idx].item()
        genera = self.genera[idx].item()

        return embedding_img.view(1, -1), embedding_dna.view(1, -1), label, genera

In [278]:
class ImageDNAValidationDataset():
    def __init__(self, train=True):
        splits_mat = scipy.io.loadmat("data/INSECTS/splits.mat")
        train_loc = splits_mat["train_loc"]-1
        val_seen_loc = splits_mat["val_seen_loc"]-1
        val_unseen_loc = splits_mat["val_unseen_loc"]-1

        TRAINING_LABELS_NUMBER = 652
        VALIDATION_SAMPLES_NUMBER = 6939
        VALIDATION_SPECIES_NUMBER = 774
        TRAINING_VALIDATION_SPECIES_NUMBER = 797
        VALIDATION_SEEN_SPECIES_NUMBER = 629
        VALIDATION_UNSEEN_SPECIES_GENERA_NUMBER = 97

        indeces = np.concatenate((val_seen_loc, val_unseen_loc), axis=1)
        # indeces.shape is (1, |indeces|), so we extract the whole list using [0]
        indeces = indeces[0]
        assert len(indeces) == VALIDATION_SAMPLES_NUMBER

        data_mat = scipy.io.loadmat("data/INSECTS/data.mat")
        self.embeddings_img = torch.from_numpy(data_mat["embeddings_img"][indeces]).float()
        self.embeddings_dna = torch.from_numpy(data_mat["embeddings_dna"][indeces]).float()

        # Remap seen species in [0, 651]
        seen_species = data_mat["labels"][train_loc][0]
        seen_species_mapping = {label: i for i, label in enumerate(np.unique(seen_species))}

        # Remap unseen species during validation in [652, 796]
        unseen_species = data_mat["labels"][val_unseen_loc][0]
        unseen_species_mapping = {label: i + TRAINING_LABELS_NUMBER for i, label in enumerate(np.unique(unseen_species))}

        # Union of the two mappings, allows to fully remap all the labels
        species_mapping = seen_species_mapping | unseen_species_mapping
        assert len(species_mapping) == TRAINING_VALIDATION_SPECIES_NUMBER

        species = data_mat["labels"][indeces]
        remapped_species = np.array([species_mapping[label.item()] for label in species])
        self.remapped_species = torch.from_numpy(remapped_species).long()
        assert len(torch.unique(self.remapped_species)) == VALIDATION_SPECIES_NUMBER

        # data_mat['G'] returns a ndarray of type uint16, therefore we convert into int16 before invoking from_numpy
        self.G = torch.from_numpy(data_mat["G"].astype(np.int16)).long()
        self.genera = torch.empty(species.shape).long()
        for i in range(indeces.size):
            self.genera[i][0] = self.G[species[i][0] - 1][0] - 1041

        # Compute genera of unseen species in the validation set
        unseen_species_genera = []
        for i in val_unseen_loc[0]:
            unseen_species_genera.append(data_mat["G"][data_mat["labels"][i][0] - 1][0] - 1041)
        self.unseen_species_genera = np.array(unseen_species_genera)
        assert len(np.unique(self.unseen_species_genera)) == VALIDATION_UNSEEN_SPECIES_GENERA_NUMBER

        # Compute seen species number in the validation set
        seen_species = []
        for i in val_seen_loc[0]:
            seen_species.append(species_mapping[data_mat["labels"][i].item()])
        self.seen_species = np.array(seen_species)
        assert len(np.unique(self.seen_species)) == VALIDATION_SEEN_SPECIES_NUMBER

        self.species_names = data_mat["species"][indeces]
        self.ids = data_mat["ids"][indeces]

    def __len__(self):
        return len(self.embeddings_dna)

    def __getitem__(self, idx):
        embedding_img = self.embeddings_img[idx]
        embedding_dna = self.embeddings_dna[idx]
        label = self.remapped_species[idx].item()
        genera = self.genera[idx].item()

        return embedding_img.view(1, -1), embedding_dna.view(1, -1), label, genera

In [279]:
class ImageDNATestDataset(Dataset):
    def __init__(self, train=True):
        splits_mat = scipy.io.loadmat("data/INSECTS/splits.mat")
        train_loc = splits_mat["train_loc"]-1
        trainval_loc = splits_mat["trainval_loc"]-1
        test_seen_loc = splits_mat["test_seen_loc"]-1
        test_unseen_loc = splits_mat["test_unseen_loc"]-1
        val_seen_loc = splits_mat["val_seen_loc"]-1
        val_unseen_loc = splits_mat["val_unseen_loc"]-1

        TRAINING_SPECIES_NUMBER = 652
        TRAINING_VALIDATION_SPECIES = 797
        NUMBER_OF_SPECIES = 1040
        TEST_SEEN_SPECIES_NUMBER = 770
        TEST_UNSEEN_SPECIES_GENERA_NUMBER = 134

        indeces = np.concatenate((test_seen_loc, test_unseen_loc), axis=1)
        # indeces.shape is (1, |indeces|), so we extract the whole list using [0]
        indeces = indeces[0]

        data_mat = scipy.io.loadmat("data/INSECTS/data.mat")
        self.embeddings_img = torch.from_numpy(data_mat["embeddings_img"][indeces]).float()
        self.embeddings_dna = torch.from_numpy(data_mat["embeddings_dna"][indeces]).float()

        # Remap seen species in [0, 651]
        seen_species = data_mat["labels"][train_loc][0]
        seen_species_mapping = {label: i for i, label in enumerate(np.unique(seen_species))}

        # Remap unseen species during validation in [652, 796]
        unseen_species_validation = data_mat["labels"][val_unseen_loc][0]
        unseen_species_validation_mapping = {label: i + TRAINING_SPECIES_NUMBER for i, label in enumerate(np.unique(unseen_species_validation))}

        # Remap unseen species during test in [797, 1039]
        unseen_species_test = data_mat["labels"][test_unseen_loc][0]
        unseen_species_test_mapping = {label: i + TRAINING_VALIDATION_SPECIES for i, label in enumerate(np.unique(unseen_species_test))}

        assert reduce(np.intersect1d, (seen_species, unseen_species_validation, unseen_species_test)).size == 0

        # Union of the two mappings, allows to full remap all the labels
        labels_mapping = seen_species_mapping | unseen_species_validation_mapping | unseen_species_test_mapping
        assert len(labels_mapping) == NUMBER_OF_SPECIES

        species = data_mat["labels"][indeces]
        remapped_labels = np.array([labels_mapping[label.item()] for label in species])
        self.remapped_labels = torch.from_numpy(remapped_labels).long()

        # data_mat['G'] returns a ndarray of type uint16, therefore we convert into int16 before invoking from_numpy
        self.G = torch.from_numpy(data_mat["G"].astype(np.int16)).long()
        self.genera = torch.empty(species.shape).long()
        for i in range(indeces.size):
            self.genera[i][0] = self.G[species[i][0] - 1][0] - 1041

        # Compute genera of unseen species
        unseen_species_genera = []
        for i in test_unseen_loc[0]:
            unseen_species_genera.append(data_mat["G"][data_mat["labels"][i][0] - 1][0] - 1041)

        self.unseen_species_genera = np.array(unseen_species_genera)
        assert len(np.unique(self.unseen_species_genera)) == TEST_UNSEEN_SPECIES_GENERA_NUMBER

        # Compute seen species
        seen_species = []
        for i in test_seen_loc[0]:
            seen_species.append(labels_mapping[data_mat["labels"][i].item()])
        self.seen_species = np.array(seen_species)
        assert len(np.unique(self.seen_species)) == TEST_SEEN_SPECIES_NUMBER

        self.species_name = data_mat["species"][indeces]
        self.ids = data_mat["ids"][indeces]

    def __len__(self):
        return len(self.embeddings_dna)

    def __getitem__(self, idx):
        embedding_img = self.embeddings_img[idx]
        embedding_dna = self.embeddings_dna[idx]
        label = self.remapped_labels[idx].item()
        genera = self.genera[idx].item()

        return embedding_img.view(1, -1), embedding_dna.view(1, -1), label, genera

In [280]:
class ImageDNATrainValidationDataset(Dataset):
    def __init__(self, train=True):
        splits_mat = scipy.io.loadmat("data/INSECTS/splits.mat")
        train_loc = splits_mat["train_loc"]-1
        trainval_loc = splits_mat["trainval_loc"]-1
        test_seen_loc = splits_mat["test_seen_loc"]-1
        test_unseen_loc = splits_mat["test_unseen_loc"]-1
        val_seen_loc = splits_mat["val_seen_loc"]-1
        val_unseen_loc = splits_mat["val_unseen_loc"]-1

        TRAINING_SPECIES_NUMBER = 652
        TRAINING_VALIDATION_SPECIES = 797
        NUMBER_OF_SPECIES = 1040

        indeces = trainval_loc
        # indeces.shape is (1, |indeces|), so we extract the whole list using [0]
        indeces = indeces[0]

        data_mat = scipy.io.loadmat("data/INSECTS/data.mat")
        self.embeddings_img = torch.from_numpy(data_mat["embeddings_img"][indeces]).float()
        self.embeddings_dna = torch.from_numpy(data_mat["embeddings_dna"][indeces]).float()

        # Remap seen species in [0, 651]
        seen_species = data_mat["labels"][train_loc][0]
        seen_species_mapping = {label: i for i, label in enumerate(np.unique(seen_species))}

        # Remap unseen species during validation in [652, 796]
        unseen_species_validation = data_mat["labels"][val_unseen_loc][0]
        unseen_species_validation_mapping = {label: i + TRAINING_SPECIES_NUMBER for i, label in enumerate(np.unique(unseen_species_validation))}

        # Remap unseen species during test in [797, 1039]
        unseen_species_test = data_mat["labels"][test_unseen_loc][0]
        unseen_species_test_mapping = {label: i + TRAINING_VALIDATION_SPECIES for i, label in enumerate(np.unique(unseen_species_test))}

        assert reduce(np.intersect1d, (seen_species, unseen_species_validation, unseen_species_test)).size == 0

        # Union of the two mappings, allows to full remap all the labels
        labels_mapping = seen_species_mapping | unseen_species_validation_mapping | unseen_species_test_mapping
        assert len(labels_mapping) == NUMBER_OF_SPECIES

        species = data_mat["labels"][indeces]  # Consider only train
        remapped_labels = np.array([labels_mapping[label.item()] for label in species])
        self.remapped_labels = torch.from_numpy(remapped_labels).long()

        assert len(torch.unique(self.remapped_labels)) == TRAINING_VALIDATION_SPECIES

        # data_mat['G'] returns a ndarray of type uint16, therefore we convert into int16 before invoking from_numpy
        self.G = torch.from_numpy(data_mat["G"].astype(np.int16)).long()
        self.genera = torch.empty(species.shape).long()
        for i in range(indeces.size):
            self.genera[i][0] = self.G[species[i][0] - 1][0] - 1041

        self.species = data_mat["species"][indeces]
        self.ids = data_mat["ids"][indeces]

    def __len__(self):
        return len(self.embeddings_dna)

    def __getitem__(self, idx):
        embedding_img = self.embeddings_img[idx]
        embedding_dna = self.embeddings_dna[idx]
        label = self.remapped_labels[idx].item()
        genera = self.genera[idx].item()

        return embedding_img.view(1, -1), embedding_dna.view(1, -1), label, genera


# Model definition

In [281]:
class AttentionNet(nn.Module):
        def __init__(self, num_seen_species, num_genera, embed_dim_1, num_heads_1, embed_dim_2, num_heads_2, embed_dim_3, num_heads_3):
                super(AttentionNet, self).__init__()

                self.img_fc1 = nn.Linear(2048, 1024)
                self.img_fc2 = nn.Linear(1024, 500)

                self.inception_encoders_block_1 = InceptionEncodersBlock(embed_dim_1, num_heads_1, embed_dim_2, num_heads_2, embed_dim_3, num_heads_3)

                self.fc_species_1 = nn.Linear(4000, 1000)
                self.fc_species_2 = nn.Linear(1000, num_seen_species)

                self.fc_genera_1 = nn.Linear(4000, 500)
                self.fc_genera_2 = nn.Linear(500, num_genera)

                self.dropout = nn.Dropout(0.5)


        def forward(self, x_img, x_dna):
                x_img = F.relu(self.img_fc1(x_img))
                x_img = self.img_fc2(x_img)

                x_img, x_dna = self.inception_encoders_block_1(x_img, x_dna)

                x = torch.cat((x_img, x_dna), axis=1)

                x_species = x.clone()
                x_genera = x.clone()

                x_species = self.dropout(F.relu(self.fc_species_1(x_species)))
                x_species = self.fc_species_2(x_species)

                x_genera = self.dropout(F.relu(self.fc_genera_1(x_genera)))
                x_genera = self.fc_genera_2(x_genera)

                return x_species, x_genera

class ImageDNAEncoder(nn.Module):
        def __init__(self, embed_dim, linear_dim, num_heads):
                super(ImageDNAEncoder, self).__init__()
                self.multi_head_img_1 = nn.MultiheadAttention(embed_dim, num_heads)
                self.multi_head_dna_1 = nn.MultiheadAttention(embed_dim, num_heads)

                self.norm_img_1 = nn.LayerNorm(embed_dim)
                self.norm_dna_1 = nn.LayerNorm(embed_dim)

                self.linear_img_1 = nn.Linear(embed_dim, linear_dim)
                self.dropout_img = nn.Dropout(0.5)
                self.linear_img_2 = nn.Linear(linear_dim, embed_dim)

                self.linear_dna_1 = nn.Linear(embed_dim, linear_dim)
                self.dropout_dna = nn.Dropout(0.5)
                self.linear_dna_2 = nn.Linear(linear_dim, embed_dim)


        def forward(self, x_img, x_dna):
                identity = x_img
                x_img, _ = self.multi_head_img_1(x_img, x_dna, x_dna)
                x_img = self.norm_img_1(x_img + identity)
                x_img = self.feed_forward_img(x_img)

                identity = x_dna
                x_dna, _ = self.multi_head_dna_1(x_dna, x_img, x_img)
                x_dna = self.norm_dna_1(x_dna + identity)
                x_dna = self.feed_forward_dna(x_dna)

                return x_img, x_dna

        def feed_forward_img(self, x):
                return self.linear_img_2(self.dropout_img(F.relu(self.linear_img_1(x))))

        def feed_forward_dna(self, x):
                return self.linear_dna_2(self.dropout_dna(F.relu(self.linear_dna_1(x))))

class InceptionEncodersBlock(nn.Module):
        def __init__(self, embed_dim_1, num_heads_1, embed_dim_2, num_heads_2, embed_dim_3, num_heads_3):
                super(InceptionEncodersBlock, self).__init__()
                self.img_dna_encoder_1 = ImageDNAEncoder(embed_dim_1, 4 * embed_dim_1, num_heads_1)
                self.img_dna_encoder_2 = ImageDNAEncoder(embed_dim_2, 4 * embed_dim_2, num_heads_2)
                self.img_dna_encoder_3 = ImageDNAEncoder(embed_dim_3, 4 * embed_dim_3, num_heads_3)

                self.INPUT_SIZE = 500
                self.embed_dim_1 = embed_dim_1
                self.num_heads_1 = num_heads_1
                self.embed_dim_2 = embed_dim_2
                self.num_heads_2 = num_heads_2
                self.embed_dim_3 = embed_dim_3
                self.num_heads_3 = num_heads_3

        def forward(self, x_img, x_dna):

                img_identity = torch.reshape(x_img.clone(), (-1, self.INPUT_SIZE ))
                dna_identity = torch.reshape(x_dna.clone(), (-1, self.INPUT_SIZE ))
                x_img_1 = torch.reshape(x_img.clone(), (-1, self.INPUT_SIZE // self.embed_dim_1, self.embed_dim_1))
                x_dna_1 = torch.reshape(x_dna.clone(), (-1, self.INPUT_SIZE // self.embed_dim_1, self.embed_dim_1))
                x_img_2 = torch.reshape(x_img.clone(), (-1, self.INPUT_SIZE // self.embed_dim_2, self.embed_dim_2))
                x_dna_2 = torch.reshape(x_dna.clone(), (-1, self.INPUT_SIZE // self.embed_dim_2, self.embed_dim_2))
                x_img_3 = torch.reshape(x_img.clone(), (-1, self.INPUT_SIZE // self.embed_dim_3, self.embed_dim_3))
                x_dna_3 = torch.reshape(x_dna.clone(), (-1, self.INPUT_SIZE // self.embed_dim_3, self.embed_dim_3))

                x_img_1, x_dna_1 = self.img_dna_encoder_1(x_img_1, x_dna_1)
                x_img_2, x_dna_2 = self.img_dna_encoder_2(x_img_2, x_dna_2)
                x_img_3, x_dna_3 = self.img_dna_encoder_3(x_img_3, x_dna_3)

                x_img_1 = torch.reshape(x_img_1, (-1, self.INPUT_SIZE))
                x_dna_1 = torch.reshape(x_dna_1, (-1, self.INPUT_SIZE))
                x_img_2 = torch.reshape(x_img_2, (-1, self.INPUT_SIZE))
                x_dna_2 = torch.reshape(x_dna_2, (-1, self.INPUT_SIZE))
                x_img_3 = torch.reshape(x_img_3, (-1, self.INPUT_SIZE))
                x_dna_3 = torch.reshape(x_dna_3, (-1, self.INPUT_SIZE))

                x_img = img_identity + x_img_1 + x_img_2 + x_img_3
                x_dna = dna_identity + x_dna_1 + x_dna_2 + x_dna_3

                x_img = torch.concat((img_identity, x_img_1, x_img_2, x_img_3), axis=1)
                x_dna = torch.concat((dna_identity, x_dna_1, x_dna_2, x_dna_3), axis=1)

                return x_img, x_dna


# Creating datasets

In [282]:
training_set = ImageDNATrainDataset()
validation_set = ImageDNAValidationDataset()
test_set = ImageDNATestDataset()
training_validation_set = ImageDNATrainValidationDataset()

Defining methods for training, validating and testing the model.

In [283]:
def train(model, lr, momentum, num_epochs, batch_size, train_val=False, print_losses=False, print_step = 200):
    model.train()
    criterion_species = torch.nn.CrossEntropyLoss()
    criterion_genera = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    if train_val:
        loader = torch.utils.data.DataLoader(training_validation_set, batch_size=batch_size, shuffle=True)
    else:
        loader = torch.utils.data.DataLoader(training_set, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epochs):

        running_labels_loss = 0.0
        running_genera_loss = 0.0

        for i, data in enumerate(loader, 0):
            inputs_img, inputs_dna, species, genera = data
            inputs_img, inputs_dna, species, genera = inputs_img.to(device), inputs_dna.to(device), species.to(device), genera.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            labels_outputs, genera_outputs = model(inputs_img, inputs_dna)
            labels_loss = criterion_species(labels_outputs, species)
            genera_loss = criterion_genera(genera_outputs, genera)
            total_loss = labels_loss + genera_loss
            total_loss.backward()
            optimizer.step()

            # Print losses
            if print_losses:
                running_labels_loss += labels_loss.item()
                running_genera_loss += genera_loss.item()
                if i % print_step == print_step - 1:
                    print(f"[{epoch + 1}, {i + 1:5d}] Species loss: {running_labels_loss / print_step:.3f}; Genera loss: {running_genera_loss / print_step:.3f}")
                    running_labels_loss = 0.0
                    running_genera_loss = 0.0

In [284]:
def validate(model, threshold, batch_size):

    validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=batch_size, shuffle=False)
    model.eval()

    with torch.no_grad():
        correct_predictions_per_labels = defaultdict(int)
        total_samples_per_labels = defaultdict(int)
        correct_predictions_per_genera = defaultdict(int)
        total_samples_per_genera = defaultdict(int)

        for data in validation_loader:
            inputs_img, inputs_dna, species, genera = data
            inputs_img, inputs_dna, species, genera = inputs_img.to(device), inputs_dna.to(device), species.to(device), genera.to(device)

            labels_outputs, genera_outputs = model(inputs_img, inputs_dna)

            labels_outputs = nn.Softmax(dim=1)(labels_outputs)
            genera_outputs = nn.Softmax(dim=1)(genera_outputs)

            predicted_labels_values, predicted_labels = torch.topk(labels_outputs.data, k=2, dim=1)
            _, predicted_genera = torch.max(genera_outputs.data, 1)

            differences = predicted_labels_values[:, 0] - predicted_labels_values[:, 1]
            genera_mask = differences <= threshold
            labels_mask = ~genera_mask

            # Update relative frequencies
            for idx in range(len(genera)):
                total_samples_per_labels[species[idx].item()] += 1

                if labels_mask[idx] and predicted_labels[idx, 0] == species[idx]:
                    correct_predictions_per_labels[species[idx].item()] += 1

                # if the sample is of one undescribed species
                if species[idx].item() not in np.unique(validation_set.seen_species):
                    assert genera[idx].item() in np.unique(validation_set.unseen_species_genera)
                    total_samples_per_genera[genera[idx].item()] += 1

                    if genera_mask[idx] and predicted_genera[idx] == genera[idx]:
                        correct_predictions_per_genera[genera[idx].item()] += 1

        accuracy_per_label = {label: (correct_predictions_per_labels[label] / total_samples_per_labels[label]) if total_samples_per_labels[label] > 0 else 0 for label in total_samples_per_labels}
        accuracy_per_genera = {genera: (correct_predictions_per_genera[genera] / total_samples_per_genera[genera]) if total_samples_per_genera[genera] > 0 else 0 for genera in total_samples_per_genera}

        test_described_species_accuracy = 0
        for label in np.unique(validation_set.seen_species):
            test_described_species_accuracy += accuracy_per_label[label]

        test_undescribed_species_accuracy = 0
        for genera in np.unique(validation_set.unseen_species_genera):
            test_undescribed_species_accuracy += accuracy_per_genera[genera]

        normalized_test_described_species_accuracy = test_described_species_accuracy / 629
        normalized_test_undescribed_species_accuracy = test_undescribed_species_accuracy / 97

        return normalized_test_described_species_accuracy, normalized_test_undescribed_species_accuracy

In [285]:
def test(model, threshold, batch_size):
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    model.eval()

    with torch.no_grad():
        correct_predictions_per_labels = defaultdict(int)
        total_samples_per_labels = defaultdict(int)
        correct_predictions_per_genera = defaultdict(int)
        total_samples_per_genera = defaultdict(int)

        for data in test_loader:
            inputs_img, inputs_dna, species, genera = data
            inputs_img, inputs_dna, species, genera = inputs_img.to(device), inputs_dna.to(device), species.to(device), genera.to(device)

            labels_outputs, genera_outputs = model(inputs_img, inputs_dna)

            labels_outputs = nn.Softmax(dim=1)(labels_outputs)
            genera_outputs = nn.Softmax(dim=1)(genera_outputs)

            predicted_labels_values, predicted_labels = torch.topk(labels_outputs.data, k=2, dim=1)
            _, predicted_genera = torch.max(genera_outputs.data, 1)

            differences = predicted_labels_values[:, 0] - predicted_labels_values[:, 1]
            genera_mask = differences <= threshold
            labels_mask = ~genera_mask

            # Update relative frequencies
            for idx in range(len(genera)):
                total_samples_per_labels[species[idx].item()] += 1

                if labels_mask[idx] and predicted_labels[idx, 0] == species[idx]:
                    correct_predictions_per_labels[species[idx].item()] += 1

                # if the sample is of one undescribed species
                if species[idx].item() not in np.unique(test_set.seen_species):
                    assert genera[idx].item() in np.unique(test_set.unseen_species_genera)
                    total_samples_per_genera[genera[idx].item()] += 1

                    if genera_mask[idx] and predicted_genera[idx] == genera[idx]:
                        correct_predictions_per_genera[genera[idx].item()] += 1

        accuracy_per_label = {label: (correct_predictions_per_labels[label] / total_samples_per_labels[label]) if total_samples_per_labels[label] > 0 else 0 for label in total_samples_per_labels}
        accuracy_per_genera = {genera: (correct_predictions_per_genera[genera] / total_samples_per_genera[genera]) if total_samples_per_genera[genera] > 0 else 0 for genera in total_samples_per_genera}

        test_described_species_accuracy = 0
        for label in np.unique(test_set.seen_species):
            test_described_species_accuracy += accuracy_per_label[label]

        test_undescribed_species_accuracy = 0
        for genera in np.unique(test_set.unseen_species_genera):
            test_undescribed_species_accuracy += accuracy_per_genera[genera]

        normalized_test_described_species_accuracy = test_described_species_accuracy / 770
        normalized_test_undescribed_species_accuracy = test_undescribed_species_accuracy / 134

        return normalized_test_described_species_accuracy, normalized_test_undescribed_species_accuracy

In [291]:
lr_values = [0.001]
momentum_values = [0.9]
batch_size_values = [16, 32]
num_epochs_values = [10]
threshold_values = np.linspace(0.7, 1, 2)**0.25
model_parameters = [
    (50, 5, 100, 5, 125, 5)
]

best_validation_accuracy = 0
best_parameters = {}

for lr in lr_values:
    for momentum in momentum_values:
        for num_epochs in num_epochs_values:
            for batch_size in batch_size_values:
                for embed_dim_1, num_heads_1, embed_dim_2, num_heads_2, embed_dim_3, num_heads_3 in model_parameters:
                    model = AttentionNet(652, 368, embed_dim_1, num_heads_1, embed_dim_2, num_heads_2, embed_dim_3, num_heads_3)
                    model.to(device)
                    train(model, lr, momentum, num_epochs, batch_size)
                    for threshold in threshold_values:
                        validation_species_accuracy, validation_genera_accuracy = validate(model, threshold, batch_size)
                        validation_accuracy = validation_species_accuracy + validation_genera_accuracy
                        print((f"Validation accuracy: {validation_accuracy}. Parameters: lr={lr}, momentum={momentum}, num_epochs={num_epochs}, batch_size={batch_size}, threshold={threshold}, "
                                f"num_heads_1={num_heads_1}, "
                                f"embed_dim_2={embed_dim_2}, "
                                f"num_heads_2={num_heads_2}, "
                                f"embed_dim_3={embed_dim_3}, "
                                f"num_heads_3={num_heads_3}"))

                        if validation_accuracy > best_validation_accuracy:
                            best_validation_accuracy = validation_accuracy
                            best_parameters = {
                                'learning_rate': lr,
                                'momentum': momentum,
                                'num_epochs': num_epochs,
                                'batch_size': batch_size,
                                'threshold': threshold,
                                'embed_dim_1': embed_dim_1,
                                'num_heads_1': num_heads_1,
                                'embed_dim_2': embed_dim_2,
                                'num_heads_2': num_heads_2,
                                'embed_dim_3': embed_dim_3,
                                'num_heads_3': num_heads_3
                            }

print("Best parameters:", best_parameters)

Validation accuracy: 1.637361336963266. Parameters: lr=0.001, momentum=0.9, num_epochs=10, batch_size=16, threshold=0.9146912192286945, num_heads_1=5, embed_dim_2=100, num_heads_2=5, embed_dim_3=125, num_heads_3=5
Validation accuracy: 0.717665167937111. Parameters: lr=0.001, momentum=0.9, num_epochs=10, batch_size=16, threshold=1.0, num_heads_1=5, embed_dim_2=100, num_heads_2=5, embed_dim_3=125, num_heads_3=5
Validation accuracy: 0.9250467024979424. Parameters: lr=0.001, momentum=0.9, num_epochs=10, batch_size=32, threshold=0.9146912192286945, num_heads_1=5, embed_dim_2=100, num_heads_2=5, embed_dim_3=125, num_heads_3=5
Validation accuracy: 0.7182467117014119. Parameters: lr=0.001, momentum=0.9, num_epochs=10, batch_size=32, threshold=1.0, num_heads_1=5, embed_dim_2=100, num_heads_2=5, embed_dim_3=125, num_heads_3=5
Best parameters: {'learning_rate': 0.001, 'momentum': 0.9, 'num_epochs': 10, 'batch_size': 16, 'threshold': 0.9146912192286945, 'embed_dim_1': 50, 'num_heads_1': 5, 'embed_

In [292]:
model = AttentionNet(797, 368, best_parameters['embed_dim_1'], best_parameters['num_heads_1'], best_parameters['embed_dim_2'], best_parameters['num_heads_2'], best_parameters['embed_dim_3'], best_parameters['num_heads_3'])
model.to(device)
train(model, best_parameters['learning_rate'], best_parameters['momentum'], best_parameters['num_epochs'], best_parameters['batch_size'], train_val=True, print_losses=True)
species_accuracy, genera_accuracy = test(model, best_parameters['threshold'], best_parameters['batch_size'])

print("-------------------------------------------------------------------------------")
print(f"Final model described species accuracy: ", species_accuracy)
print(f"Final model undescribed species accuracy: ", genera_accuracy)
print("-------------------------------------------------------------------------------")

[1,   200] Species loss: 6.200; Genera loss: 5.002
[1,   400] Species loss: 5.209; Genera loss: 3.632
[1,   600] Species loss: 4.533; Genera loss: 2.949
[1,   800] Species loss: 3.959; Genera loss: 2.419
[1,  1000] Species loss: 3.512; Genera loss: 2.087
[1,  1200] Species loss: 3.164; Genera loss: 1.900
[2,   200] Species loss: 2.683; Genera loss: 1.570
[2,   400] Species loss: 2.480; Genera loss: 1.426
[2,   600] Species loss: 2.172; Genera loss: 1.251
[2,   800] Species loss: 1.920; Genera loss: 1.108
[2,  1000] Species loss: 1.727; Genera loss: 0.999
[2,  1200] Species loss: 1.580; Genera loss: 0.910
[3,   200] Species loss: 1.303; Genera loss: 0.749
[3,   400] Species loss: 1.221; Genera loss: 0.706
[3,   600] Species loss: 1.104; Genera loss: 0.672
[3,   800] Species loss: 0.938; Genera loss: 0.573
[3,  1000] Species loss: 0.861; Genera loss: 0.541
[3,  1200] Species loss: 0.744; Genera loss: 0.469
[4,   200] Species loss: 0.559; Genera loss: 0.385
[4,   400] Species loss: 0.527;