In [97]:
import re
import numpy as np
import pandas as pd
import math
from time import time
from copy import deepcopy

import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter
from tensorflow.keras.preprocessing.sequence import pad_sequences as pad
from sklearn.preprocessing import LabelEncoder

from Models import FullyConnected, RNN, RNNEmbed, LSTM, BiLSTM, Transformer

from sklearn.metrics import f1_score
from torchmetrics.functional.classification import f1_score

In [68]:
class PfamDataset(Dataset):
    '''
    Map-style dataset over one Pfam CSV split stored under "data/".

    The file read is "data/<title>.csv" when k == "all" and
    "data/<title>_<k>.csv" otherwise; it must provide the columns
    "sequence" and "family_accession". Indexing yields (X, y): the encoded,
    padded sequence as a float tensor and the one-hot label as a float tensor.
    '''

    # One encoder shared by every instance: a split that only calls
    # transform() relies on an earlier "train"/"test" dataset having fitted it
    label_encoder = LabelEncoder()

    def __init__(self, title, num_classes, k, mapping, max_seq_len, oversampling):
        # Settings supplied by the caller
        self.title = title
        self.num_classes = num_classes
        self.k = k
        self.mapping = mapping
        self.max_seq_len = max_seq_len
        self.oversampling = oversampling
        # Everything below is filled in by __initiate__
        self.data = self.X = self.y = None
        self.weights = self.len = None
        self.__initiate__()

    def __len__(self):
        ''' Number of encoded sequences held by the dataset '''
        return len(self.X)

    def __getitem__(self, idx):
        ''' Returns the (input, label) pair stored at position idx '''
        return self.X[idx], self.y[idx]

    def __initiate__(self):
        ''' Reads the CSV split, then derives the inputs, labels and sample weights '''
        suffix = "" if self.k == "all" else "_" + str(self.k)
        csv_path = "data/" + self.title + suffix + ".csv"
        self.data = pd.read_csv(csv_path)
        self.X = self.get_inputs()
        self.y = self.get_labels()
        self.weights = self.get_class_weights()

    def replace_uncommon_amino_acids(self):
        ''' Rewrites the "sequence" column so that X, U, B, O and Z all become "X" '''
        self.data["sequence"] = [
            re.sub(r"[XUBOZ]", "X", sequence) for sequence in self.data["sequence"]
        ]

    def get_inputs(self, padding="pre"):
        '''
        Encodes every sequence into a float tensor with max_seq_len columns.

        Each sequence is cut to max_seq_len - 2 residues, translated through
        self.mapping and wrapped between the codes 22 and 23 (presumably
        start/end markers) before being padded on the side given by `padding`.
        '''
        self.replace_uncommon_amino_acids()
        residue_budget = self.max_seq_len - 2
        framed_sequences = []
        for sequence in self.data["sequence"]:
            codes = [self.mapping[aa] for aa in sequence[:residue_budget]]
            framed_sequences.append([22] + codes + [23])
        padded = pad(framed_sequences, maxlen=self.max_seq_len,
                     padding=padding, truncating="post")
        return torch.tensor(padded).float()

    def get_labels(self):
        '''
        One-hot encodes the "family_accession" column as a float tensor.

        The shared encoder is (re)fitted for the "train" and "test" splits and
        only applied for any other split.
        '''
        accessions = self.data["family_accession"]
        if self.title in ("train", "test"):
            encoded_labels = self.label_encoder.fit_transform(accessions)
        else:
            encoded_labels = self.label_encoder.transform(accessions)
        label_indices = torch.tensor(encoded_labels).long()
        return F.one_hot(label_indices, num_classes=self.num_classes).float()

    def get_class_weights(self):
        '''
        Returns the per-sample class weights for use in WeightedRandomSampler.

        "regular" uses 1 / class count, "sqrt" its square root, "beta" the
        inverse of (1 - beta**count) / (1 - beta) with beta = 0.9; any other
        value gives every sample the weight 1. The weights are also stored in
        the "family_weight" column and their number in self.len.
        '''
        accessions = self.data["family_accession"]
        class_counts = accessions.value_counts()

        if self.oversampling == "regular":
            weight_of = (1 / class_counts).to_dict()
        elif self.oversampling == "sqrt":
            weight_of = {family: np.sqrt(inverse)
                         for family, inverse in (1 / class_counts).to_dict().items()}
        elif self.oversampling == "beta":
            beta = 0.9
            weight_of = {family: 1 / ((1 - beta**count)/(1 - beta))
                         for family, count in class_counts.to_dict().items()}
        else:
            weight_of = None

        if weight_of is None:
            self.data["family_weight"] = [1] * len(accessions)
        else:
            self.data["family_weight"] = [weight_of[family] for family in accessions]

        self.len = len(self.data["family_weight"])
        return self.data["family_weight"].to_list()
    

In [54]:
def _build_model(model_parameters, max_seq_len, vocab_len, num_classes, device):
    ''' Instantiates the architecture named by model_parameters["type"] '''
    model_type = model_parameters["type"]

    if model_type == "fc":
        # The fully connected model reads its depth from "layers", not "num_layers"
        return FullyConnected(
            dropout = model_parameters["dropout"],
            hidden_dim = model_parameters["hidden_dim"],
            num_layers = model_parameters["layers"],
            input_dim = max_seq_len,
            output_dim = num_classes,
        )

    if model_type == "rnn":
        return RNN(
            dropout = model_parameters["dropout"],
            hidden_dim = model_parameters["hidden_dim"],
            num_layers = model_parameters["num_layers"],
            vocab_len = vocab_len,
            output_dim = num_classes,
            device = device,
        )

    # The embedding-based recurrent models are all built from the same arguments
    embedded_recurrent = {"rnn_embed": RNNEmbed, "lstm": LSTM, "bi-lstm": BiLSTM}
    if model_type in embedded_recurrent:
        return embedded_recurrent[model_type](
            dropout = model_parameters["dropout"],
            hidden_dim = model_parameters["hidden_dim"],
            embed_size = model_parameters["embed_size"],
            num_layers = model_parameters["num_layers"],
            vocab_len = vocab_len,
            output_dim = num_classes,
            device = device,
        )

    if model_type == "transformer":
        return Transformer(
            dropout_pos = model_parameters["dropout_pos"],
            dropout_transformer = model_parameters["dropout_transformer"],
            dropout_class = model_parameters["dropout_class"],
            feed_forward_dim = model_parameters["feed_forward_dim"],
            hidden_dim = model_parameters["hidden_dim"],
            embed_size = model_parameters["embed_size"],
            num_layers = model_parameters["num_layers"],
            num_heads = model_parameters["num_heads"],
            vocab_len = vocab_len,
            output_dim = num_classes,
        )

    raise ValueError(
        "Unknown model type %r; expected one of 'fc', 'rnn', 'rnn_embed', "
        "'lstm', 'bi-lstm' or 'transformer'" % (model_type,)
    )


def train_evaluate(
    label: str,
    model_parameters: dict,
    oversampling: str = "none",
    epochs: int = 10, 
    num_classes: int = 100,
    class_gap: int = 1,
    max_seq_len: int = 128,
    vocab_len = 24,
    batch_size: int = 64,
    lr: float = 0.001,
):
    '''
    The main loop for training the models

    Trains the architecture described by model_parameters on the "train"
    split, validates it on the "validation" split after every epoch, logs the
    metrics to TensorBoard under "runs/<label>" and saves the model with the
    best validation accuracy to "model/<label>.pt".

    Parameters:
        label: name of the run, used for the TensorBoard and model file names
        model_parameters: "type" plus the hyper-parameters of that architecture
        oversampling: weighting scheme of the training sampler
                      ("regular", "sqrt", "beta"; anything else is uniform)
        epochs: number of passes over the training sampler
        num_classes: number of protein families (size of the one-hot labels)
        class_gap: dataset variant, selects the "data/<split>_<class_gap>.csv" files
        max_seq_len: length every encoded sequence is cut or padded to
        vocab_len: vocabulary size handed to the sequence models
        batch_size: mini-batch size of both dataloaders
        lr: learning rate of the AdamW optimizer

    Raises:
        ValueError: if model_parameters["type"] is not a known architecture
    '''

    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Mapping: the 20 standard amino acids take the codes 1..20 and every
    # uncommon one shares the code 21
    amino_acids = ["A","C","D","E","F","G","H","I","K","L",
                   "M","N","P","Q","R","S","T","V","W","Y","X"]
    mapping = {aa:i + 1 for i, aa in enumerate(amino_acids)}
    mapping.update({'X': 21, 'U': 21, 'B': 21, 'O': 21, 'Z': 21})

    # Creating the datasets (validation always uses the "regular" weighting)
    train_dataset = PfamDataset("train", num_classes, class_gap, 
                                mapping, max_seq_len, oversampling)
    validation_dataset = PfamDataset("validation", num_classes, class_gap, 
                                     mapping, max_seq_len, oversampling="regular")

    # Weighted Random Sampler
    train_sampler = WeightedRandomSampler(train_dataset.weights, train_dataset.len)
    validation_sampler = WeightedRandomSampler(validation_dataset.weights, validation_dataset.len)

    # Creating the dataloaders
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, 
                                                   num_workers=0, sampler=train_sampler)
    validation_dataloader = torch.utils.data.DataLoader(validation_dataset,
                                                        batch_size=batch_size, 
                                                        num_workers=0, sampler=validation_sampler)

    # Model
    model = _build_model(model_parameters, max_seq_len, vocab_len, num_classes, device)
    model = model.to(device)

    # Loss function and optimizer (the learning rate now follows the lr argument)
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Snapshot of the model with the highest validation accuracy seen so far
    best_model = None
    best_model_val_acc = 0.0

    # Tensorboard writer, closed even if training is interrupted
    writer = SummaryWriter(log_dir="runs/" + label)
    try:
        for epoch in range(1, epochs+1):
            # --- TRAIN AND EVALUATE ON TRAINING SET 
            start = time()
            model.train()

            train_loss = 0.0
            num_train_correct = 0
            num_train_examples = 0

            for X, y in train_dataloader:
                X = X.to(device)
                y = y.to(device)

                optimizer.zero_grad()
                y_pred = model(X)
                loss = loss_function(y_pred, y)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                num_train_correct += (y_pred.argmax(axis=1) == y.argmax(axis=1)).sum().item()
                num_train_examples += X.shape[0]

            train_acc = num_train_correct / num_train_examples
            # NOTE(review): CrossEntropyLoss already averages over each batch, so
            # dividing the summed batch means by the dataset size reports roughly
            # (mean loss / batch_size). Kept as is so that the logged values stay
            # comparable with the runs recorded so far.
            train_loss = train_loss / len(train_dataloader.dataset)

            # --- EVALUATE ON VALIDATION SET -------------------------------------
            model.eval()
            val_loss = 0.0
            num_val_correct = 0
            num_val_examples = 0

            with torch.no_grad():
                for X, y in validation_dataloader:
                    X = X.to(device)
                    y = y.to(device)
                    y_pred = model(X)
                    loss = loss_function(y_pred, y)

                    val_loss += loss.item()
                    num_val_correct += (y_pred.argmax(axis=1) == y.argmax(axis=1)).sum().item()
                    num_val_examples += X.shape[0]

            val_acc = num_val_correct / num_val_examples
            val_loss = val_loss / len(validation_dataloader.dataset)

            if epoch == 1 or epoch % 5 == 0:
                print('Epoch %3d/%3d, train loss: %3.2f, train acc: %3.2f, val loss: %3.2f, val acc: %3.2f, duration: %3.1fs'% \
                      (epoch, epochs, train_loss, train_acc, val_loss, val_acc, time() - start))

            writer.add_scalar("Loss/Train", train_loss, epoch)
            writer.add_scalar("Accuracy/Train", train_acc, epoch)
            writer.add_scalar("Loss/Eval", val_loss, epoch)
            writer.add_scalar("Accuracy/Eval", val_acc, epoch)

            # Keep the snapshot only when the validation accuracy improves;
            # the first epoch always provides the initial snapshot
            if best_model is None or val_acc > best_model_val_acc:
                best_model_val_acc = val_acc
                best_model = deepcopy(model)
    finally:
        writer.close()

    # With epochs == 0 there is no snapshot, so the untrained model is saved
    torch.save(best_model if best_model is not None else model, "model/" + label + ".pt")

In [56]:
def evaluate(
    label: str,
    oversampling: str = "none",
    num_classes: int = 100,
    class_gap: int = 1,
    max_seq_len: int = 128,
    vocab_len = 24,
    batch_size: int = 64,
    lr: float = 0.001,
):
    '''
    Evaluates the models against the test set returning the loss and accuracy

    Loads "model/<label>.pt" and runs it over the "test" split drawn through a
    WeightedRandomSampler built from the dataset weights.

    Parameters:
        label: name of the saved run to load
        oversampling: weighting scheme of the test sampler
                      ("regular", "sqrt", "beta"; anything else is uniform)
        num_classes: number of protein families (size of the one-hot labels)
        class_gap: dataset variant, selects the "data/test_<class_gap>.csv" file
        max_seq_len: length every encoded sequence is cut or padded to
        vocab_len: unused here, kept for signature parity with train_evaluate
        batch_size: mini-batch size of the dataloader
        lr: unused here, kept for signature parity with train_evaluate

    Returns:
        (test_loss, test_acc): the summed batch losses divided by the dataset
        size, and the fraction of correctly classified samples
    '''

    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # The 20 standard amino acids take the codes 1..20, uncommon ones share 21
    amino_acids = ["A","C","D","E","F","G","H","I","K","L",
                   "M","N","P","Q","R","S","T","V","W","Y","X"]
    mapping = {aa:i + 1 for i, aa in enumerate(amino_acids)}
    mapping.update({'X': 21, 'U': 21, 'B': 21, 'O': 21, 'Z': 21})

    # Creating the datasets
    test_dataset = PfamDataset("test", num_classes, class_gap, 
                               mapping, max_seq_len, oversampling)

    # Weighted Random Sampler
    sampler = WeightedRandomSampler(test_dataset.weights, test_dataset.len)

    # Creating the dataloaders
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=batch_size, 
                                                  num_workers=0, sampler=sampler)

    # Model
    # map_location lets a model that was saved on a GPU be loaded on a
    # CPU-only machine.
    # NOTE(review): the file holds a fully pickled model, so only load files
    # produced by this notebook; recent PyTorch releases may additionally
    # require weights_only=False for this kind of checkpoint - confirm against
    # the installed version.
    model = torch.load("model/" + label + ".pt", map_location=device)
    model = model.to(device)

    # Loss function
    loss_function = nn.CrossEntropyLoss()

    model.eval()
    test_loss = 0.0
    num_test_correct = 0
    num_test_examples = 0

    with torch.no_grad():
        for X, y in test_dataloader:
            X = X.to(device)
            y = y.to(device)
            y_pred = model(X)
            loss = loss_function(y_pred, y)

            test_loss += loss.item()
            num_test_correct += (y_pred.argmax(axis=1) == y.argmax(axis=1)).sum().item()
            num_test_examples += X.shape[0]

    test_acc = num_test_correct / num_test_examples
    # NOTE(review): each batch loss is already a mean, so this reports roughly
    # (mean loss / batch_size); kept as is to stay comparable with the
    # recorded results.
    test_loss = test_loss / len(test_dataloader.dataset)

    return test_loss, test_acc

# Architecture Experiment Training

In [None]:
# Iter1 architecture experiment: every entry is (run label, model parameters)
# and the runs are trained in the order listed
architecture_runs = [
    # Transformer
    ("transformer_100", {
        "type": "transformer",
        "vocab_len": 24, 
        "embed_size": 256,
        "hidden_dim": 128, 
        "feed_forward_dim": 256,
        "dropout_pos": 0.2, 
        "dropout_transformer": 0.2,
        "dropout_class": 0.2, 
        "num_heads": 4, 
        "num_layers": 4,
    }),
    # Bi-LSTM
    ("bi-lstm_100", {
        "type": "bi-lstm",
        "hidden_dim": 128, 
        "dropout": 0.2,
        "num_layers": 4,
        "embed_size": 64,
    }),
    # LSTM
    ("lstm_100", {
        "type": "lstm",
        "hidden_dim": 256, 
        "dropout": 0.2,
        "num_layers": 4,
        "embed_size": 64,
    }),
    # RNN
    ("rnn_embed_100", {
        "type": "rnn_embed",
        "hidden_dim": 256, 
        "dropout": 0.2,
        "num_layers": 4,
        "embed_size": 64,
    }),
    # RNN (No Embed)
    ("rnn_100", {
        "type": "rnn",
        "hidden_dim": 256, 
        "dropout": 0.2,
        "num_layers": 4,
        "embed_size": 64,
    }),
    # FC
    ("fc_100", {
        "type": "fc",
        "dropout": 0.1,
        "hidden_dim": 128, 
        "layers": 1,
    }),
]

for run_label, run_parameters in architecture_runs:
    train_evaluate(
        label = run_label,
        num_classes = 100,
        model_parameters = run_parameters,
        epochs=10
    )
    print("\n\n")

# Architecture Experiment Testing

In [157]:
def model_testing(models):
    '''
    Runs evaluate() on every saved run named in `models`.

    `models` maps a display name to a run label; the result is a pair of
    dictionaries (losses, accuracies), both keyed by the display name.
    '''
    test_loss, test_acc = {}, {}
    for name, label in models.items():
        test_loss[name], test_acc[name] = evaluate(label)
    return test_loss, test_acc

In [158]:
# Display name -> saved run label; every Iter1 run is stored as "<name>_100"
models = {
    architecture: architecture + "_100"
    for architecture in ("fc", "rnn", "rnn_embed", "lstm", "bi-lstm", "transformer")
}

In [159]:
# Evaluate every saved architecture on the test split; both results are
# dictionaries keyed by the display names used in `models`
test_loss, test_acc = model_testing(models)

In [171]:
# One report line per architecture: right-aligned name, loss to five decimals
loss_line = "Architecture: %15s has a loss of %14.5f"
for name, result in test_loss.items():
    print(loss_line % (name, result))

Architecture:              fc has a loss of        0.03796
Architecture:             rnn has a loss of        0.01714
Architecture:       rnn_embed has a loss of        0.01072
Architecture:            lstm has a loss of        0.00093
Architecture:         bi-lstm has a loss of        0.00071
Architecture:     transformer has a loss of        0.00103


In [170]:
# One report line per architecture, with the accuracy shown as a percentage
accuracy_line = "Architecture: %15s has an accuracy of %7.2f"
for name, result in test_acc.items():
    result *= 100
    print(accuracy_line % (name, result))

Architecture:              fc has an accuracy of   50.63
Architecture:             rnn has an accuracy of   75.37
Architecture:       rnn_embed has an accuracy of   83.31
Architecture:            lstm has an accuracy of   98.28
Architecture:         bi-lstm has an accuracy of   98.75
Architecture:     transformer has an accuracy of   98.24


# Oversampling Experiment Training

This experiment is explained in the report.

In [57]:
# Train and validate one transformer for every (oversampling scheme, class gap) pair
for oversampling in ("none", "regular", "sqrt", "beta"):
    for class_gap in (1, 3, 10, 30, 100):
        print(f"Training gap {class_gap} with oversampling {oversampling}.")

        # Run name under which the logs and the model file are stored
        run_label = f"transformer_oversampling-{oversampling}_{class_gap}"

        # A fresh parameter dictionary is built for every run
        transformer_parameters = {
            "type": "transformer",
            "vocab_len": 24, 
            "embed_size": 64,
            "hidden_dim": 128, 
            "feed_forward_dim": 128,
            "dropout_pos": 0.2, 
            "dropout_transformer": 0.2,
            "dropout_class": 0.2, 
            "num_heads": 4, 
            "num_layers": 2,
        }

        train_evaluate(
            label = run_label,
            class_gap = class_gap,
            oversampling = oversampling,
            model_parameters = transformer_parameters,
            epochs=20
        )

Training gap 1 with oversampling none.
Epoch   1/ 20, train loss: 0.05, train acc: 0.24, val loss: 0.04, val acc: 0.41, duration: 16.8s
Epoch   5/ 20, train loss: 0.01, train acc: 0.84, val loss: 0.01, val acc: 0.86, duration: 16.3s
Epoch  10/ 20, train loss: 0.00, train acc: 0.92, val loss: 0.00, val acc: 0.94, duration: 18.6s
Epoch  15/ 20, train loss: 0.00, train acc: 0.95, val loss: 0.00, val acc: 0.97, duration: 17.9s
Epoch  20/ 20, train loss: 0.00, train acc: 0.97, val loss: 0.00, val acc: 0.98, duration: 17.1s
Training gap 3 with oversampling none.
Epoch   1/ 20, train loss: 0.05, train acc: 0.21, val loss: 0.04, val acc: 0.31, duration: 12.8s
Epoch   5/ 20, train loss: 0.01, train acc: 0.81, val loss: 0.01, val acc: 0.82, duration: 13.2s
Epoch  10/ 20, train loss: 0.01, train acc: 0.91, val loss: 0.00, val acc: 0.92, duration: 13.1s
Epoch  15/ 20, train loss: 0.00, train acc: 0.94, val loss: 0.00, val acc: 0.96, duration: 13.1s
Epoch  20/ 20, train loss: 0.00, train acc: 0.96,

Epoch  20/ 20, train loss: 0.00, train acc: 0.97, val loss: 0.00, val acc: 0.98, duration: 17.3s
Training gap 3 with oversampling beta.
Epoch   1/ 20, train loss: 0.05, train acc: 0.20, val loss: 0.04, val acc: 0.36, duration: 13.3s
Epoch   5/ 20, train loss: 0.01, train acc: 0.81, val loss: 0.01, val acc: 0.83, duration: 13.4s
Epoch  10/ 20, train loss: 0.01, train acc: 0.91, val loss: 0.00, val acc: 0.92, duration: 14.1s
Epoch  15/ 20, train loss: 0.00, train acc: 0.94, val loss: 0.00, val acc: 0.95, duration: 13.4s
Epoch  20/ 20, train loss: 0.00, train acc: 0.96, val loss: 0.00, val acc: 0.97, duration: 13.3s
Training gap 10 with oversampling beta.
Epoch   1/ 20, train loss: 0.05, train acc: 0.15, val loss: 0.05, val acc: 0.11, duration: 8.5s
Epoch   5/ 20, train loss: 0.02, train acc: 0.72, val loss: 0.02, val acc: 0.58, duration: 8.7s
Epoch  10/ 20, train loss: 0.01, train acc: 0.84, val loss: 0.01, val acc: 0.83, duration: 8.6s
Epoch  15/ 20, train loss: 0.01, train acc: 0.89, v

# Oversampling Experiment Testing

This experiment is explained in the report.

In [65]:
# Oversampling scheme -> middle part of the saved run label. The training loop
# stores each run as "transformer_oversampling-<scheme>_<class gap>", so every
# value has to be "oversampling-<scheme>" for the saved model to be found.
models = {
    "none": "oversampling-none",
    "regular": "oversampling-regular",
    "sqrt": "oversampling-sqrt",
    "beta": "oversampling-beta",
}

In [77]:
def oversampling_testing(models):
    '''
    Evaluates the saved oversampling runs for every class gap.

    `models` maps a display name to the middle part of the run label. Each
    run "transformer_<title>_<class gap>" is evaluated with "regular"
    oversampling, and the result is a pair of nested dictionaries
    (losses, accuracies) of the form {name: {class_gap: value}}.
    '''
    class_gaps = [1, 3, 10, 30, 100]
    test_loss, test_acc = {}, {}
    for name, title in models.items():
        # (loss, accuracy) per class gap, evaluated in the order of class_gaps
        results = {
            class_gap: evaluate("transformer_" + title + "_" + str(class_gap),
                                oversampling="regular", class_gap=class_gap)
            for class_gap in class_gaps
        }
        test_loss[name] = {class_gap: loss for class_gap, (loss, _) in results.items()}
        test_acc[name] = {class_gap: acc for class_gap, (_, acc) in results.items()}
    return test_loss, test_acc

In [78]:
# Evaluate every oversampling run on the test split; both results are nested
# dictionaries of the form {oversampling name: {class gap: value}}
test_loss, test_acc = oversampling_testing(models)

In [150]:
# One group of report lines per oversampling method, separated by a blank line
loss_line = "Oversampling method: %8s with class gap %5s has a loss of %9.5f"
for oversampling, tests in test_loss.items():
    for class_gap, result in tests.items():
        print(loss_line % (oversampling, class_gap, result))
    print("")

Oversampling method:     none with class gap     1 has a loss of   0.00099
Oversampling method:     none with class gap     3 has a loss of   0.00172
Oversampling method:     none with class gap    10 has a loss of   0.00425
Oversampling method:     none with class gap    30 has a loss of   0.01478
Oversampling method:     none with class gap   100 has a loss of   0.02576

Oversampling method:  regular with class gap     1 has a loss of   0.00129
Oversampling method:  regular with class gap     3 has a loss of   0.00186
Oversampling method:  regular with class gap    10 has a loss of   0.00495
Oversampling method:  regular with class gap    30 has a loss of   0.01225
Oversampling method:  regular with class gap   100 has a loss of   0.01958

Oversampling method:     sqrt with class gap     1 has a loss of   0.00143
Oversampling method:     sqrt with class gap     3 has a loss of   0.00135
Oversampling method:     sqrt with class gap    10 has a loss of   0.00437
Oversampling method:   

In [156]:
# Same grouping as the loss report, with the accuracy shown as a percentage
accuracy_line = "Oversampling method: %8s with class gap %5s has an accuracy of %8.2f"
for oversampling, tests in test_acc.items():
    for class_gap, result in tests.items():
        result *= 100
        print(accuracy_line % (oversampling, class_gap, result))
    print("")

Oversampling method:     none with class gap     1 has an accuracy of    98.19
Oversampling method:     none with class gap     3 has an accuracy of    96.64
Oversampling method:     none with class gap    10 has an accuracy of    91.91
Oversampling method:     none with class gap    30 has an accuracy of    71.95
Oversampling method:     none with class gap   100 has an accuracy of    54.50

Oversampling method:  regular with class gap     1 has an accuracy of    97.50
Oversampling method:  regular with class gap     3 has an accuracy of    96.42
Oversampling method:  regular with class gap    10 has an accuracy of    91.51
Oversampling method:  regular with class gap    30 has an accuracy of    78.82
Oversampling method:  regular with class gap   100 has an accuracy of    64.00

Oversampling method:     sqrt with class gap     1 has an accuracy of    97.29
Oversampling method:     sqrt with class gap     3 has an accuracy of    97.56
Oversampling method:     sqrt with class gap    10

## Testing Best Model

We found that, when dealing with an imbalanced class distribution (we test on the Iter100 dataset), the transformer model trained with the square-root oversampling method performed best. Let's compute its macro-averaged F1 score and plot a confusion matrix to get a better look at how it performs.

In [134]:
def evaluation_f1(
    label: str,
    oversampling: str = "none",
    num_classes: int = 100,
    class_gap: int = 100,
    max_seq_len: int = 128,
    vocab_len = 24,
    batch_size: int = 64,
    lr: float = 0.001,
):
    '''
    Evaluates the models against the test set returning the f1 score

    Loads "model/<label>.pt", predicts the family of every sample drawn from
    the "test" split and returns the macro-averaged F1 score.

    Parameters:
        label: name of the saved run to load
        oversampling: weighting scheme of the test sampler
                      ("regular", "sqrt", "beta"; anything else is uniform)
        num_classes: number of protein families (size of the one-hot labels)
        class_gap: dataset variant, selects the "data/test_<class_gap>.csv" file
        max_seq_len: length every encoded sequence is cut or padded to
        vocab_len: unused here, kept for signature parity with train_evaluate
        batch_size: mini-batch size of the dataloader
        lr: unused here, kept for signature parity with train_evaluate

    Returns:
        The macro-averaged F1 score computed by scikit-learn.
    '''
    # Imported locally under an explicit alias: the notebook-level name
    # f1_score is bound twice (scikit-learn, then torchmetrics), and this
    # function passes (true labels, predictions), which is scikit-learn's order.
    from sklearn.metrics import f1_score as sklearn_f1_score

    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # The 20 standard amino acids take the codes 1..20, uncommon ones share 21
    amino_acids = ["A","C","D","E","F","G","H","I","K","L",
                   "M","N","P","Q","R","S","T","V","W","Y","X"]
    mapping = {aa:i + 1 for i, aa in enumerate(amino_acids)}
    mapping.update({'X': 21, 'U': 21, 'B': 21, 'O': 21, 'Z': 21})

    # Creating the datasets
    test_dataset = PfamDataset("test", num_classes, class_gap, 
                               mapping, max_seq_len, oversampling)

    # Weighted Random Sampler
    # NOTE(review): the sampler draws test_dataset.len indices at random, so
    # the score is computed on a resample of the test split and can vary
    # between calls - confirm this is intended.
    sampler = WeightedRandomSampler(test_dataset.weights, test_dataset.len)

    # Creating the dataloaders
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=batch_size, 
                                                  num_workers=0, sampler=sampler)

    # Model
    # map_location lets a model that was saved on a GPU be loaded on a
    # CPU-only machine.
    # NOTE(review): the file holds a fully pickled model, so only load files
    # produced by this notebook.
    model = torch.load("model/" + label + ".pt", map_location=device)
    model = model.to(device)
    model.eval()

    # Class indices of the true labels and of the predictions, batch by batch
    all_y = []
    all_preds = []

    with torch.no_grad():
        for X, y in test_dataloader:
            y_pred = model(X.to(device))
            all_y.append(y.argmax(dim=1).cpu())
            all_preds.append(y_pred.argmax(dim=1).cpu())

    # Calculate f1 score
    all_y = torch.cat(all_y, dim=0).numpy()
    all_preds = torch.cat(all_preds, dim=0).numpy()

    return sklearn_f1_score(all_y, all_preds, average='macro')

In [136]:
# Macro F1 of the square-root oversampling transformer on the class-gap-100
# test split (the default class_gap of evaluation_f1)
f1 = evaluation_f1("transformer_oversampling-sqrt_100")
# A bare assignment displays nothing; end the cell with the value to show it
f1

0.6719625422412902
