In [104]:
#!wget https://dataworks.indianapolis.iu.edu/bitstream/handle/11243/41/data.zip
#!unzip data.zip
#!rm data.zip

In [105]:
import scipy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from functools import reduce

torch.manual_seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


# Dataset import

In [106]:
class ImageDNATrainDataset():
    def __init__(self, train=True):
        splits_mat = scipy.io.loadmat("data/INSECTS/splits.mat")
        train_loc = splits_mat["train_loc"]-1

        TRAINING_SAMPLES_NUMBER = 12481
        TRAINING_LABELS_NUMBER = 652
        
        assert len(train_loc[0]) == TRAINING_SAMPLES_NUMBER

        indeces = train_loc
        # indeces.shape is (1, |indeces|), so we extract the whole list using [0]
        indeces = indeces[0]

        data_mat = scipy.io.loadmat("data/INSECTS/data.mat")
        self.embeddings_img = torch.from_numpy(data_mat["embeddings_img"][indeces]).float()
        self.embeddings_dna = torch.from_numpy(data_mat["embeddings_dna"][indeces]).float()

        # Remap seen species in [0, 651]
        seen_species = data_mat["labels"][train_loc][0]
        seen_species_mapping = {label: i for i, label in enumerate(np.unique(seen_species))}

        species_mapping = seen_species_mapping
        assert len(species_mapping) == TRAINING_LABELS_NUMBER

        species = data_mat["labels"][indeces]
        remapped_species = np.array([species_mapping[label.item()] for label in species])
        self.remapped_species = torch.from_numpy(remapped_species).long()
        self.species = torch.from_numpy(species).long()

        assert len(torch.unique(self.remapped_species)) == TRAINING_LABELS_NUMBER

        # data_mat['G'] returns a ndarray of type uint16, therefore we convert into int16 before invoking from_numpy
        self.G = torch.from_numpy(data_mat["G"].astype(np.int16)).long()
        self.genera = torch.empty(self.species.shape).long()
        for i in range(indeces.size):
            self.genera[i][0] = self.G[self.species[i][0] - 1][0] - 1041

        assert len(self.genera) == TRAINING_SAMPLES_NUMBER

        self.species_names = data_mat["species"][indeces]
        self.ids = data_mat["ids"][indeces]

    def __len__(self):
        return len(self.embeddings_dna)

    def __getitem__(self, idx):
        embedding_img = self.embeddings_img[idx]
        embedding_dna = self.embeddings_dna[idx]
        label = self.remapped_species[idx].item()
        genera = self.genera[idx].item()

        return embedding_img.view(1, -1), embedding_dna.view(1, -1), label, genera
    

In [107]:
class ImageDNAValidationDataset():
    def __init__(self, train=True):
        splits_mat = scipy.io.loadmat("data/INSECTS/splits.mat")
        train_loc = splits_mat["train_loc"]-1
        val_seen_loc = splits_mat["val_seen_loc"]-1
        val_unseen_loc = splits_mat["val_unseen_loc"]-1

        VALIDATION_SAMPLES_NUMBER = 6939
        VALIDATION_SPECIES_NUMBER = 774
        TRAIN_VALIDATION_SPECIES_NUMBER = 797

        indeces = np.concatenate((val_seen_loc, val_unseen_loc), axis=1)
        # indeces.shape is (1, |indeces|), so we extract the whole list using [0]
        indeces = indeces[0]
        assert len(indeces) == VALIDATION_SAMPLES_NUMBER

        data_mat = scipy.io.loadmat("data/INSECTS/data.mat")
        self.embeddings_img = torch.from_numpy(data_mat["embeddings_img"][indeces]).float()
        self.embeddings_dna = torch.from_numpy(data_mat["embeddings_dna"][indeces]).float()

        # Remap seen species in [0, 651]
        seen_species = data_mat["labels"][train_loc][0]
        seen_species_mapping = {label: i for i, label in enumerate(np.unique(seen_species))}

        # Remap unseen species during validation in [652, 796]
        unseen_species = data_mat["labels"][val_unseen_loc][0]
        unseen_species_mapping = {label: i + 652 for i, label in enumerate(np.unique(unseen_species))}

        # Union of the two mappings, allows to full remap all the labels
        species_mapping = seen_species_mapping | unseen_species_mapping
        assert len(species_mapping) == TRAIN_VALIDATION_SPECIES_NUMBER

        species = data_mat["labels"][indeces]
        remapped_species = np.array([species_mapping[label.item()] for label in species])
        self.remapped_species = torch.from_numpy(remapped_species).long()
        self.species = torch.from_numpy(species).long()

        assert len(torch.unique(self.remapped_species)) == VALIDATION_SPECIES_NUMBER

        # data_mat['G'] returns a ndarray of type uint16, therefore we convert into int16 before invoking from_numpy
        self.G = torch.from_numpy(data_mat["G"].astype(np.int16)).long()
        self.genera = torch.empty(self.species.shape).long()
        for i in range(indeces.size):
            self.genera[i][0] = self.G[self.species[i][0] - 1][0] - 1041

        # Compute genera of unseen species in the validation set
        unseen_species_genera = []
        for i in val_unseen_loc[0]:
            unseen_species_genera.append(data_mat["G"][data_mat["labels"][i][0] - 1][0] - 1041)
        self.unseen_species_genera = np.array(unseen_species_genera)
        assert len(np.unique(self.unseen_species_genera)) == 97

        # Compute seen species number in the validation set
        seen_species = []
        for i in val_seen_loc[0]:
            seen_species.append(species_mapping[data_mat["labels"][i].item()])
        self.seen_species = np.array(seen_species)
        assert len(np.unique(self.seen_species)) == 629


        self.species_names = data_mat["species"][indeces]
        self.ids = data_mat["ids"][indeces]

    def __len__(self):
        return len(self.embeddings_dna)

    def __getitem__(self, idx):
        embedding_img = self.embeddings_img[idx]
        embedding_dna = self.embeddings_dna[idx]
        label = self.remapped_species[idx].item()
        genera = self.genera[idx].item()

        return embedding_img.view(1, -1), embedding_dna.view(1, -1), label, genera

In [108]:
class ImageDNATestDataset(Dataset):
    def __init__(self, train=True):
        splits_mat = scipy.io.loadmat("data/INSECTS/splits.mat")
        train_loc = splits_mat["train_loc"]-1
        trainval_loc = splits_mat["trainval_loc"]-1
        test_seen_loc = splits_mat["test_seen_loc"]-1
        test_unseen_loc = splits_mat["test_unseen_loc"]-1
        val_seen_loc = splits_mat["val_seen_loc"]-1
        val_unseen_loc = splits_mat["val_unseen_loc"]-1

        assert len(trainval_loc[0] == 19420)
        assert len(test_seen_loc[0] == 4965)
        assert len(test_unseen_loc[0] == 8463)

        indeces = np.concatenate((test_seen_loc, test_unseen_loc), axis=1)
        # indeces.shape is (1, |indeces|), so we extract the whole list using [0]
        indeces = indeces[0]

        data_mat = scipy.io.loadmat("data/INSECTS/data.mat")
        self.embeddings_img = torch.from_numpy(data_mat["embeddings_img"][indeces]).float()
        self.embeddings_dna = torch.from_numpy(data_mat["embeddings_dna"][indeces]).float()

        # Remap seen species in [0, 651]
        seen_species = data_mat["labels"][train_loc][0]
        seen_species_mapping = {label: i for i, label in enumerate(np.unique(seen_species))}

        # Remap unseen species during validation in [652, 796]
        unseen_species_validation = data_mat["labels"][val_unseen_loc][0]
        unseen_species_validation_mapping = {label: i + 652 for i, label in enumerate(np.unique(unseen_species_validation))}

        # Remap unseen species during test in [797, 1039]
        unseen_species_test = data_mat["labels"][test_unseen_loc][0]
        unseen_species_test_mapping = {label: i + 797 for i, label in enumerate(np.unique(unseen_species_test))}

        assert reduce(np.intersect1d, (seen_species, unseen_species_validation, unseen_species_test)).size == 0

        # Union of the two mappings, allows to full remap all the labels
        labels_mapping = seen_species_mapping | unseen_species_validation_mapping | unseen_species_test_mapping
        assert len(labels_mapping) == 1040

        species = data_mat["labels"][indeces]  # Consider only train/test labels
        remapped_labels = np.array([labels_mapping[label.item()] for label in species])
        self.remapped_labels = torch.from_numpy(remapped_labels).long()
        self.species = torch.from_numpy(species).long()

        assert len(torch.unique(self.remapped_labels)) == 1013

        # data_mat['G'] returns a ndarray of type uint16, therefore we convert into int16 before invoking from_numpy
        self.G = torch.from_numpy(data_mat["G"].astype(np.int16)).long()
        self.genera = torch.empty(self.species.shape).long()
        for i in range(indeces.size):
            self.genera[i][0] = self.G[self.species[i][0] - 1][0] - 1041

        assert len(self.genera) == 13428

        # Compute genera of unseen species
        unseen_species_genera = []
        for i in test_unseen_loc[0]:
            unseen_species_genera.append(data_mat["G"][data_mat["labels"][i][0] - 1][0] - 1041)

        self.unseen_species_genera = np.array(unseen_species_genera)
        assert len(np.unique(self.unseen_species_genera)) == 134

        # Compute seen species
        seen_species = []
        for i in test_seen_loc[0]:
            seen_species.append(labels_mapping[data_mat["labels"][i].item()])
        self.seen_species = np.array(seen_species)
        assert len(np.unique(self.seen_species)) == 770

        self.species_name = data_mat["species"][indeces]
        self.ids = data_mat["ids"][indeces]

    def __len__(self):
        return len(self.embeddings_dna)

    def __getitem__(self, idx):
        embedding_img = self.embeddings_img[idx]
        embedding_dna = self.embeddings_dna[idx]
        label = self.remapped_labels[idx].item()
        genera = self.genera[idx].item()

        return embedding_img.view(1, -1), embedding_dna.view(1, -1), label, genera

In [109]:
class ImageDNATrainValidationDataset(Dataset):
    def __init__(self, train=True):
        splits_mat = scipy.io.loadmat("data/INSECTS/splits.mat")
        train_loc = splits_mat["train_loc"]-1
        trainval_loc = splits_mat["trainval_loc"]-1
        test_seen_loc = splits_mat["test_seen_loc"]-1
        test_unseen_loc = splits_mat["test_unseen_loc"]-1
        val_seen_loc = splits_mat["val_seen_loc"]-1
        val_unseen_loc = splits_mat["val_unseen_loc"]-1

        assert len(trainval_loc[0] == 19420)
        assert len(test_seen_loc[0] == 4965)
        assert len(test_unseen_loc[0] == 8463)

        indeces = trainval_loc
        # indeces.shape is (1, |indeces|), so we extract the whole list using [0]
        indeces = indeces[0]

        data_mat = scipy.io.loadmat("data/INSECTS/data.mat")
        self.embeddings_img = torch.from_numpy(data_mat["embeddings_img"][indeces]).float()
        self.embeddings_dna = torch.from_numpy(data_mat["embeddings_dna"][indeces]).float()

        # Remap seen species in [0, 651]
        seen_species = data_mat["labels"][train_loc][0]
        seen_species_mapping = {label: i for i, label in enumerate(np.unique(seen_species))}

        # Remap unseen species during validation in [652, 796]
        unseen_species_validation = data_mat["labels"][val_unseen_loc][0]
        unseen_species_validation_mapping = {label: i + 652 for i, label in enumerate(np.unique(unseen_species_validation))}

        # Remap unseen species during test in [797, 1039]
        unseen_species_test = data_mat["labels"][test_unseen_loc][0]
        unseen_species_test_mapping = {label: i + 797 for i, label in enumerate(np.unique(unseen_species_test))}

        assert reduce(np.intersect1d, (seen_species, unseen_species_validation, unseen_species_test)).size == 0
        
        # Union of the two mappings, allows to full remap all the labels
        labels_mapping = seen_species_mapping | unseen_species_validation_mapping | unseen_species_test_mapping
        assert len(labels_mapping) == 1040

        labels = data_mat["labels"][indeces]  # Consider only train
        remapped_labels = np.array([labels_mapping[label.item()] for label in labels])
        self.remapped_labels = torch.from_numpy(remapped_labels).long()
        self.labels = torch.from_numpy(labels).long()

        assert len(torch.unique(self.remapped_labels)) == 797

        # data_mat['G'] returns a ndarray of type uint16, therefore we convert into int16 before invoking from_numpy
        self.G = torch.from_numpy(data_mat["G"].astype(np.int16)).long()
        self.genera = torch.empty(self.labels.shape).long()
        for i in range(indeces.size):
            self.genera[i][0] = self.G[self.labels[i][0] - 1][0] - 1041
        
        assert len(self.genera) == 19420

        self.species = data_mat["species"][indeces]
        self.ids = data_mat["ids"][indeces]

    def __len__(self):
        return len(self.embeddings_dna)

    def __getitem__(self, idx):
        embedding_img = self.embeddings_img[idx]
        embedding_dna = self.embeddings_dna[idx]
        label = self.remapped_labels[idx].item()
        genera = self.genera[idx].item()

        return embedding_img.view(1, -1), embedding_dna.view(1, -1), label, genera
    
class ImageDNATrainDataset():
    def __init__(self, train=True):
        splits_mat = scipy.io.loadmat("data/INSECTS/splits.mat")
        train_loc = splits_mat["train_loc"]-1

        TRAINING_SAMPLES_NUMBER = 12481
        TRAINING_LABELS_NUMBER = 652
        
        assert len(train_loc[0]) == TRAINING_SAMPLES_NUMBER

        indeces = train_loc
        # indeces.shape is (1, |indeces|), so we extract the whole list using [0]
        indeces = indeces[0]

        data_mat = scipy.io.loadmat("data/INSECTS/data.mat")
        self.embeddings_img = torch.from_numpy(
            data_mat["embeddings_img"][indeces]
        ).float()
        self.embeddings_dna = torch.from_numpy(
            data_mat["embeddings_dna"][indeces]
        ).float()

        # Remap seen species in [0, 651]
        seen_species = data_mat["labels"][train_loc][0]
        seen_species_mapping = {label: i for i, label in enumerate(np.unique(seen_species))}

        species_mapping = seen_species_mapping
        assert len(species_mapping) == TRAINING_LABELS_NUMBER

        species = data_mat["labels"][indeces]
        remapped_species = np.array([species_mapping[label.item()] for label in species])
        self.remapped_species = torch.from_numpy(remapped_species).long()
        self.species = torch.from_numpy(species).long()

        assert len(torch.unique(self.remapped_species)) == TRAINING_LABELS_NUMBER

        # data_mat['G'] returns a ndarray of type uint16, therefore we convert into int16 before invoking from_numpy
        self.G = torch.from_numpy(data_mat["G"].astype(np.int16)).long()
        self.genera = torch.empty(self.species.shape).long()
        for i in range(indeces.size):
            self.genera[i][0] = self.G[self.species[i][0] - 1][0] - 1041

        assert len(self.genera) == TRAINING_SAMPLES_NUMBER

        self.species_names = data_mat["species"][indeces]
        self.ids = data_mat["ids"][indeces]

    def __len__(self):
        return len(self.embeddings_dna)

    def __getitem__(self, idx):
        embedding_img = self.embeddings_img[idx]
        embedding_dna = self.embeddings_dna[idx]
        label = self.remapped_species[idx].item()
        genera = self.genera[idx].item()

        return embedding_img.view(1, -1), embedding_dna.view(1, -1), label, genera
    

# Model definition

In [110]:
class CrossNet(nn.Module):
    def __init__(self, num_seen_species, num_genera):
        super(CrossNet, self).__init__()
        # Pre-core network
        # Image embedding dimensionality reduction
        self.img_fc1 = nn.Linear(2048, 1024)
        self.img_fc2 = nn.Linear(1024, 500)

        # Separate processing pipelines
        self.img_resblock1 = ResidualBlock1d(1, 4)
        self.img_resblock2 = ResidualBlock1d(4, 4)

        self.dna_resblock1 = ResidualBlock1d(1, 4)
        self.dna_resblock2 = ResidualBlock1d(4, 4)

        self.resblock1 = ResidualBlock1d(4, 4)
        self.resblock2 = ResidualBlock1d(4, 4)
        self.resblock3 = ResidualBlock1d(4, 4)
        self.resblock4 = ResidualBlock1d(4, 4)

        # Fully connected layers for classification
        self.fc_species_1 = nn.Linear(4*2548, 2048)
        self.fc_species_2 = nn.Linear(2048, num_seen_species)

        self.fc_genera_1 = nn.Linear(4*2548, 2048)
        self.fc_genera_2 = nn.Linear(2048, num_genera)

        # Dropout layers for regularization
        self.conv_dropout = nn.Dropout(0.2)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x_img, x_dna):
        # Reduce dimensionality of image embeddings
        #x_img = F.relu(self.img_fc1(x_img))
        #x_img = F.relu(self.img_fc2(x_img))

        x_img = self.img_resblock1(x_img)
        x_img = self.img_resblock2(x_img)

        x_dna = self.dna_resblock1(x_dna)
        x_dna = self.dna_resblock2(x_dna)

        x = torch.cat((x_img, x_dna), axis=2)

        # CrossNet core
        x = F.relu(self.resblock1(x))
        x = self.conv_dropout(F.relu(self.resblock2(x)))
        x = F.relu(self.resblock3(x))
        x = self.conv_dropout(F.relu(self.resblock4(x)))

        x = x.view(x.shape[0], 4*2548)
        #x = self.dropout(F.relu(self.fc1(x)))

        x_species = x.clone()
        x_genera = x.clone()

        # Dropout for regularization
        x_species = self.dropout(F.relu(self.fc_species_1(x_species)))
        x_species = self.fc_species_2(x_species)

        x_genera = self.dropout(F.relu(self.fc_genera_1(x_genera)))
        x_genera = self.fc_genera_2(x_genera)

        return x_species, x_genera

class ResidualBlock1d(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock1d, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        return out

# Creating datasets and dataloaders

In [111]:
model = CrossNet(797, 368)
model.to(device)
training_set = ImageDNATrainDataset()
validation_set = ImageDNAValidationDataset()
test_set = ImageDNATestDataset()
training_validation_set = ImageDNATrainValidationDataset()

batch_size = 32
loader = torch.utils.data.DataLoader(
    training_set, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=batch_size, shuffle=False
)

Print datasets statistics.

In [112]:
print(f"Training set has {len(training_set)} instances.")
print(f"Test set has {len(test_set)} instances.")

inputs_img, inputs_dna, species, genera = next(iter(loader))
print(f"Training input batch: {inputs_img.shape}, {inputs_dna.shape}")
print(f"Training label batch: {species.shape}")
print(f"Training genera batch: {genera.shape}")

Training set has 12481 instances.
Test set has 13428 instances.
Training input batch: torch.Size([32, 1, 2048]), torch.Size([32, 1, 500])
Training label batch: torch.Size([32])
Training genera batch: torch.Size([32])


# Training

In [113]:
def train(model, lr, momentum, num_epochs, batch_size, train_val=False):
    model.train()
    criterion_species = torch.nn.CrossEntropyLoss()
    criterion_genera = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    if train_val:
        loader = torch.utils.data.DataLoader(training_validation_set, batch_size=batch_size, shuffle=True)
    else:
        loader = torch.utils.data.DataLoader(training_set, batch_size=batch_size, shuffle=True)

    for epoch in range(num_epochs):

        running_labels_loss = 0.0
        running_genera_loss = 0.0

        for i, data in enumerate(loader, 0):
            inputs_img, inputs_dna, species, genera = data
            inputs_img, inputs_dna, species, genera = inputs_img.to(device), inputs_dna.to(device), species.to(device), genera.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            labels_outputs, genera_outputs = model(inputs_img, inputs_dna)
            labels_loss = criterion_species(labels_outputs, species)
            genera_loss = criterion_genera(genera_outputs, genera)
            total_loss = labels_loss + genera_loss
            total_loss.backward()
            optimizer.step()

            # print statistics
            running_labels_loss += labels_loss.item()
            running_genera_loss += genera_loss.item()
            print_step = 200
            if i % print_step == print_step - 1:
                print(f"[{epoch + 1}, {i + 1:5d}] Species loss: {running_labels_loss / print_step:.3f}; Genera loss: {running_genera_loss / print_step:.3f}")
                running_labels_loss = 0.0
                running_genera_loss = 0.0
        
print("Finished Training")



Finished Training


In [114]:
def validate(model, threshold):

    validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=batch_size, shuffle=False)
    model.eval()

    with torch.no_grad():
        correct_predictions_per_labels = defaultdict(int)
        total_samples_per_labels = defaultdict(int)
        correct_predictions_per_genera = defaultdict(int)
        total_samples_per_genera = defaultdict(int)

        for data in validation_loader:
            inputs_img, inputs_dna, species_name, genera = data
            inputs_img, inputs_dna, species_name, genera = inputs_img.to(device), inputs_dna.to(device), species_name.to(device), genera.to(device)

            labels_outputs, genera_outputs = model(inputs_img, inputs_dna)

            labels_outputs = nn.Softmax(dim=1)(labels_outputs)
            genera_outputs = nn.Softmax(dim=1)(genera_outputs)

            predicted_labels_values, predicted_labels = torch.topk(labels_outputs.data, k=2, dim=1)
            _, predicted_genera = torch.max(genera_outputs.data, 1)

            differences = predicted_labels_values[:, 0] - predicted_labels_values[:, 1]
            genera_mask = differences <= threshold
            labels_mask = ~genera_mask

            for idx in range(len(genera)):
                total_samples_per_genera[genera[idx].item()] += 1
                total_samples_per_labels[species_name[idx].item()] += 1
                if genera_mask[idx]:
                    if predicted_genera[idx] == genera[idx]:
                        correct_predictions_per_genera[genera[idx].item()] += 1
                if labels_mask[idx]:
                    if predicted_labels[idx, 0] == species_name[idx]:
                        correct_predictions_per_labels[species_name[idx].item()] += 1

        accuracy_per_label = {label: (correct_predictions_per_labels[label] / total_samples_per_labels[label]) if total_samples_per_labels[label] > 0 else 0 for label in total_samples_per_labels}
        accuracy_per_genera = {genera: (correct_predictions_per_genera[genera] / total_samples_per_genera[genera]) if total_samples_per_genera[genera] > 0 else 0 for genera in total_samples_per_genera}

        test_described_species_accuracy = 0
        for label in np.unique(validation_set.seen_species):
            test_described_species_accuracy += accuracy_per_label[label]

        test_undescribed_species_accuracy = 0
        for genera in np.unique(validation_set.unseen_species_genera):
            test_undescribed_species_accuracy += accuracy_per_genera[genera]

        normalized_test_described_species_accuracy = test_described_species_accuracy / 629
        normalized_test_undescribed_species_accuracy = test_undescribed_species_accuracy / 97

        print("-------------------------------------------------------------------------------")
        print(f"threshold: {threshold}")
        print(f"Validation described species accuracy: {normalized_test_described_species_accuracy}")
        print(f"Validation undescribed species accuracy: {normalized_test_undescribed_species_accuracy}")
        print("-------------------------------------------------------------------------------")

        return normalized_test_described_species_accuracy + normalized_test_undescribed_species_accuracy

In [115]:
def test(model, threshold):
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    model.eval()

    with torch.no_grad():
        correct_predictions_per_labels = defaultdict(int)
        total_samples_per_labels = defaultdict(int)
        correct_predictions_per_genera = defaultdict(int)
        total_samples_per_genera = defaultdict(int)

        for data in test_loader:
            inputs_img, inputs_dna, species_name, genera = data
            inputs_img, inputs_dna, species_name, genera = inputs_img.to(device), inputs_dna.to(device), species_name.to(device), genera.to(device)

            labels_outputs, genera_outputs = model(inputs_img, inputs_dna)

            labels_outputs = nn.Softmax(dim=1)(labels_outputs)
            genera_outputs = nn.Softmax(dim=1)(genera_outputs)

            predicted_labels_values, predicted_labels = torch.topk(labels_outputs.data, k=2, dim=1)
            _, predicted_genera = torch.max(genera_outputs.data, 1)

            differences = predicted_labels_values[:, 0] - predicted_labels_values[:, 1]
            genera_mask = differences <= threshold
            labels_mask = ~genera_mask

            for idx in range(len(genera)):
                total_samples_per_genera[genera[idx].item()] += 1
                total_samples_per_labels[species_name[idx].item()] += 1
                if genera_mask[idx]:
                    if predicted_genera[idx] == genera[idx]:
                        correct_predictions_per_genera[genera[idx].item()] += 1
                if labels_mask[idx]:
                    if predicted_labels[idx, 0] == species_name[idx]:
                        correct_predictions_per_labels[species_name[idx].item()] += 1

        accuracy_per_label = {label: (correct_predictions_per_labels[label] / total_samples_per_labels[label]) if total_samples_per_labels[label] > 0 else 0 for label in total_samples_per_labels}
        accuracy_per_genera = {genera: (correct_predictions_per_genera[genera] / total_samples_per_genera[genera]) if total_samples_per_genera[genera] > 0 else 0 for genera in total_samples_per_genera}

        test_described_species_accuracy = 0
        for label in np.unique(test_set.seen_species):
            test_described_species_accuracy += accuracy_per_label[label]

        test_undescribed_species_accuracy = 0
        for genera in np.unique(test_set.unseen_species_genera):
            test_undescribed_species_accuracy += accuracy_per_genera[genera]

        normalized_test_described_species_accuracy = test_described_species_accuracy / 770
        normalized_test_undescribed_species_accuracy = test_undescribed_species_accuracy / 134

        print("-------------------------------------------------------------------------------")
        print(f"threshold: {threshold}")
        print(f"Test described species accuracy: {normalized_test_described_species_accuracy}")
        print(f"Test undescribed species accuracy: {normalized_test_undescribed_species_accuracy}")
        print("-------------------------------------------------------------------------------")

        return normalized_test_described_species_accuracy, normalized_test_undescribed_species_accuracy

In [116]:
lr_values = [0.001]
momentum_values = [0.9]
batch_size_values = [16, 32]
num_epochs_values = [5]
threshold_values = np.linspace(0.7, 1, 2)**0.25

best_validation_accuracy = 0
best_parameters = {}

for lr in lr_values:
    for momentum in momentum_values:
        for num_epochs in num_epochs_values:
            for batch_size in batch_size_values:
                model = CrossNet(652, 368)
                model.to(device)
                train(model, lr, momentum, num_epochs, batch_size)
                for threshold in threshold_values:
                    validation_accuracy = validate(model, threshold)
                    print(f"Validation accuracy for lr={lr}, momentum={momentum}, num_epochs={num_epochs}, batch_size={batch_size}, threshold={threshold}: {validation_accuracy}")

                    if validation_accuracy > best_validation_accuracy:
                        best_validation_accuracy = validation_accuracy
                        best_parameters = {'learning_rate': lr, 'momentum': momentum, 'num_epochs': num_epochs, 'batch_size': batch_size, 'threshold': threshold}

print("Best parameters:", best_parameters)

[1,   200] Species loss: 4.562; Genera loss: 3.506
[1,   400] Species loss: 1.697; Genera loss: 1.120
[1,   600] Species loss: 0.559; Genera loss: 0.350
[2,   200] Species loss: 0.098; Genera loss: 0.147
[2,   400] Species loss: 0.066; Genera loss: 0.053
[2,   600] Species loss: 0.042; Genera loss: 0.034
[3,   200] Species loss: 0.011; Genera loss: 0.013
[3,   400] Species loss: 0.010; Genera loss: 0.010
[3,   600] Species loss: 0.011; Genera loss: 0.010
[4,   200] Species loss: 0.004; Genera loss: 0.003
[4,   400] Species loss: 0.015; Genera loss: 0.011
[4,   600] Species loss: 0.009; Genera loss: 0.008
[5,   200] Species loss: 0.004; Genera loss: 0.006
[5,   400] Species loss: 0.004; Genera loss: 0.003
[5,   600] Species loss: 0.005; Genera loss: 0.003
-------------------------------------------------------------------------------
threshold: 0.9146912192286945
Validation described species accuracy: 0.9832655535596713
Validation undescribed species accuracy: 0.3376109088925106
-------

In [117]:
model = CrossNet(797, 368)
model.to(device)
train(model, best_parameters['learning_rate'], best_parameters['momentum'], best_parameters['num_epochs'], best_parameters['batch_size'], train_val=True)
species_accuracy, genera_accuracy = test(model, best_parameters['threshold'])

print("-------------------------------------------------------------------------------")
print(f"Final model described species accuracy: ", species_accuracy)
print(f"Final model undescribed species accuracy: ", genera_accuracy)
print("-------------------------------------------------------------------------------")

[1,   200] Species loss: 4.168; Genera loss: 2.671
[1,   400] Species loss: 0.916; Genera loss: 0.437
[1,   600] Species loss: 0.164; Genera loss: 0.092
[2,   200] Species loss: 0.047; Genera loss: 0.029
[2,   400] Species loss: 0.030; Genera loss: 0.017
[2,   600] Species loss: 0.038; Genera loss: 0.029
[3,   200] Species loss: 0.010; Genera loss: 0.008
[3,   400] Species loss: 0.006; Genera loss: 0.010
[3,   600] Species loss: 0.007; Genera loss: 0.010
[4,   200] Species loss: 0.006; Genera loss: 0.005
[4,   400] Species loss: 0.006; Genera loss: 0.004
[4,   600] Species loss: 0.007; Genera loss: 0.005
[5,   200] Species loss: 0.004; Genera loss: 0.003
[5,   400] Species loss: 0.004; Genera loss: 0.003
[5,   600] Species loss: 0.003; Genera loss: 0.004
-------------------------------------------------------------------------------
threshold: 0.9146912192286945
Test described species accuracy: 0.9843164834205329
Test undescribed species accuracy: 0.3728011226537804
-------------------