In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [8]:
!pip install fair-esm



In [9]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from Bio import SeqIO
import esm
import numpy as np
from sklearn.metrics import f1_score
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import Counter
import pandas as pd
import random

# import os.path
# os.chdir("/kaggle/working")

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [5]:
class ProteinSequenceDataset(Dataset):
    def __init__(self, fasta_file, arg_dict, alphabet, max_samples_per_class=1000):
        self.alphabet = alphabet
        self.sequences, self.labels, self.ids= self.load_sequences(fasta_file, arg_dict, max_samples_per_class)

    def encode_sequence(self, sequence, alphabet):
        tokens = alphabet.encode(sequence)
        return torch.tensor(tokens, dtype=torch.long)
    
    def load_sequences(self, fasta_file, arg_dict, max_samples_per_class):
        temp_sequences = {label: [] for label in arg_dict.values()}  # Initialize with integers
        sequences, labels, ids = [], [], []
        
        for record in SeqIO.parse(fasta_file, "fasta"):
            if 'FEATURES' in record.description:
                label = record.description.split('|')[3]
                label_idx = arg_dict.get(label, arg_dict['nonarg'])
            else:
                label_idx = arg_dict['nonarg']

            seq_encoded = self.encode_sequence(str(record.seq), self.alphabet)
            temp_sequences[label_idx].append((seq_encoded, label_idx, record.id))
        
        for label_idx, items in temp_sequences.items():
            if max_samples_per_class is not None and len(items) > max_samples_per_class:
                items = random.sample(items, max_samples_per_class)
            for seq_encoded, label_idx, record_id in items:
                sequences.append(seq_encoded)
                labels.append(label_idx)
                ids.append(record_id)

        return sequences, labels, ids
    
    def one_hot_encode(self, sequence):
        aa_to_int = {aa: i for i, aa in enumerate('ARNDCEQGHILKMFPSTWYV')}
        one_hot = torch.zeros(len(sequence), 20)  # 20 amino acids
        for i, aa in enumerate(sequence):
            if aa in aa_to_int:
                one_hot[i, aa_to_int[aa]] = 1.0
        return one_hot
    
    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return self.sequences[idx], self.labels[idx], self.ids[idx]

arg_dict = {
    'aminoglycoside': 0, 'macrolide-lincosamide-streptogramin': 1, 'polymyxin': 2,
    'fosfomycin': 3, 'trimethoprim': 4, 'bacitracin': 5, 'quinolone': 6, 'multidrug': 7,
    'chloramphenicol': 8, 'tetracycline': 9, 'rifampin': 10, 'beta_lactam': 11,
    'sulfonamide': 12, 'glycopeptide': 13, 'nonarg': 14
}

from torch.nn.utils.rnn import pad_sequence
# dealing with sequences of different lengths
def collate_batch(batch):
    sequences, labels, ids = zip(*batch)
    # pad sequences to have the same length
    sequences_padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    labels = torch.LongTensor(labels)
    return sequences_padded, labels, ids


In [None]:
class ESMClassifier(nn.Module):
    def __init__(self, esm_model, num_classes):
        super().__init__()
        self.esm_model = esm_model
        self.classifier = nn.Linear(320, num_classes)

    def forward(self, x):
        results = self.esm_model(x, repr_layers=[6])
        representations = results["representations"][6]
        pooled_representations = representations.mean(dim=1)
        return self.classifier(pooled_representations)

In [None]:
def train_model(classifier, train_loader, val_loader, criterion, optimizer, num_epochs):
    best_val_f1 = -1
    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        classifier.train()
        total_loss = 0
        for sequences, labels, _ in train_loader:

            sequences, labels = sequences.to(device), labels.to(device)

            predictions = classifier(sequences)
            loss = criterion(predictions, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        val_f1 = evaluate_model(classifier, val_loader)
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save(classifier.state_dict(), 'model.pth')
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader)}, Val F1: {val_f1}')


def evaluate_model(classifier, data_loader):
    classifier.eval()
    all_predictions = []
    all_targets = []
    with torch.no_grad():
        for sequences, labels, _ in data_loader:
            sequences, labels = sequences.to(device), labels.to(device)
            outputs = classifier(sequences)
            _, predicted = torch.max(outputs.data, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

    f1 = f1_score(all_targets, all_predictions, average='macro')
#     print('F1 Score:', f1)
    return f1

def generate_predictions_and_save(classifier, test_loader, file_path='submission.csv'):
    classifier.eval()
    sequence_ids = []
    predictions = []
    with torch.no_grad():
        for sequences, labels, ids in test_loader:
            print(len(sequences))
            sequences = sequences.to(device)
            outputs = classifier(sequences)
            _, predicted = torch.max(outputs.data, 1)
            predictions.extend(predicted.cpu().numpy())
            sequence_ids.extend(ids)

    df = pd.DataFrame({'id': sequence_ids, 'label': predictions})
    df.to_csv(file_path, index=False)

In [None]:
# model_path = "/kaggle/input/model-pth/model.pth"
model_path = "/kaggle/input/model/model.pth"
esm_path = "/kaggle/input/esm2-pretrained/esm2_t6_8M_UR50D.pt"
data_path = "/kaggle/input/aist4010-spring2024-a2/data"
num_epochs = 15
lr=0.0000005
batch_size = 4

model, alphabet = esm.pretrained.load_model_and_alphabet(esm_path)
model.to(device)
model.eval()

train_dataset = ProteinSequenceDataset(data_path+'/train.fasta', arg_dict, alphabet)
val_dataset = ProteinSequenceDataset(data_path+'/val.fasta', arg_dict, alphabet, max_samples_per_class = None)
test_dataset = ProteinSequenceDataset(data_path+'/test.fasta', arg_dict, alphabet, max_samples_per_class = None)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)

classifier = ESMClassifier(model, len(arg_dict))
classifier = classifier.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=lr)
classifier.load_state_dict(torch.load(model_path))



In [None]:

# train_model(classifier, train_loader, val_loader, criterion, optimizer, num_epochs)

In [None]:
# classifier.load_state_dict(torch.load("/kaggle/working/model.pth"))
# generate_predictions_and_save(classifier, test_loader)

# LORA (Microsoft)

In [11]:
!pip install loralib



In [20]:
import loralib as lora

class ESMClassifier(nn.Module):
    def __init__(self, esm_model, num_classes, lora_config=None):
        super().__init__()
        self.esm_model = esm_model
        if lora_config:
            self.classifier = lora.Linear(320, # esm2_t6_8M_UR50D Embedding Dim: 320 
                                          num_classes, 
                                          r=lora_config['r'], 
                                          lora_alpha=lora_config['lora_alpha'], 
                                          lora_dropout=lora_config['lora_dropout'])
        else:
            self.classifier = nn.Linear(esm_model.classifier.in_features, num_classes)


    def forward(self, x):
        results = self.esm_model(x, repr_layers=[6]) #esm2_t6_8M_UR50D layers = 6
        representations = results["representations"][6]
        pooled_representations = representations.mean(dim=1)
        return self.classifier(pooled_representations)
    

# model_path = "/kaggle/input/model-pth/model.pth"
model_path = "/kaggle/input/model/model.pth"
model_path = "/kaggle/working/model_base_lora.pth"
esm_path = "/kaggle/input/esm2-pretrained/esm2_t6_8M_UR50D.pt"
data_path = "/kaggle/input/aist4010-spring2024-a2/data"

batch_size = 10

model, alphabet = esm.pretrained.load_model_and_alphabet(esm_path)
model.to(device)
model.eval()

train_dataset = ProteinSequenceDataset(data_path+'/train.fasta', arg_dict, alphabet)
val_dataset = ProteinSequenceDataset(data_path+'/val.fasta', arg_dict, alphabet, max_samples_per_class = None)
test_dataset = ProteinSequenceDataset(data_path+'/test.fasta', arg_dict, alphabet, max_samples_per_class = None)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)


In [24]:
model_path = "/kaggle/input/model/model.pth"
model_path = "/kaggle/working/model_base_lora.pth"

lora_config = {
    'r': 8,
    'lora_alpha': 24,
    'lora_dropout': 0.5,
}
num_epochs = 100
lr=0.0001



classifier = ESMClassifier(model, len(arg_dict), lora_config=lora_config)
classifier = classifier.to(device)
classifier.load_state_dict(torch.load(model_path), strict=False)

lora.mark_only_lora_as_trainable(classifier)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=lr)



def train_model(classifier, train_loader, val_loader, criterion, optimizer, num_epochs):
    best_val_f1 = -1
    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        classifier.train()
        total_loss = 0
        for sequences, labels, _ in train_loader:

            sequences, labels = sequences.to(device), labels.to(device)

            predictions = classifier(sequences)
            loss = criterion(predictions, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        val_f1 = evaluate_model(classifier, val_loader)
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save(classifier.state_dict(), 'model_base_lora.pth')
            torch.save(lora.lora_state_dict(classifier), 'lora_model.pth')

        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader)}, Val F1: {val_f1}')


def evaluate_model(classifier, data_loader):
    classifier.eval()
    all_predictions = []
    all_targets = []
    with torch.no_grad():
        for sequences, labels, _ in data_loader:
            sequences, labels = sequences.to(device), labels.to(device)
            outputs = classifier(sequences)
            _, predicted = torch.max(outputs.data, 1)
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())

    f1 = f1_score(all_targets, all_predictions, average='macro')
#     print('F1 Score:', f1)
    return f1

def generate_predictions_and_save(classifier, test_loader, file_path='submission.csv'):
    classifier.eval()
    sequence_ids = []
    predictions = []
    with torch.no_grad():
        for sequences, labels, ids in test_loader:
            outputs = classifier(sequences)            sequences = sequences.to(device)

            _, predicted = torch.max(outputs.data, 1)
            predictions.extend(predicted.cpu().numpy())
            sequence_ids.extend(ids)

    df = pd.DataFrame({'id': sequence_ids, 'label': predictions})
    df.to_csv(file_path, index=False)

train_model(classifier, train_loader, val_loader, criterion, optimizer, num_epochs)


Epochs:   1%|          | 1/100 [01:42<2:49:29, 102.73s/it]

Epoch 1/100, Loss: 0.012277000624962658, Val F1: 0.9221905176510666


Epochs:   2%|▏         | 2/100 [03:24<2:47:03, 102.28s/it]

Epoch 2/100, Loss: 0.012562273693557422, Val F1: 0.9221304016881108


Epochs:   3%|▎         | 3/100 [05:06<2:45:12, 102.19s/it]

Epoch 3/100, Loss: 0.01289036996037418, Val F1: 0.9221905176510666


Epochs:   4%|▍         | 4/100 [06:50<2:44:18, 102.70s/it]

Epoch 4/100, Loss: 0.012414641327045122, Val F1: 0.9219109907984725


Epochs:   5%|▌         | 5/100 [08:32<2:42:35, 102.69s/it]

Epoch 5/100, Loss: 0.013434155087072677, Val F1: 0.9222272579327081


Epochs:   6%|▌         | 6/100 [10:15<2:40:37, 102.52s/it]

Epoch 6/100, Loss: 0.013200434092977598, Val F1: 0.9215113958036952


Epochs:   7%|▋         | 7/100 [11:57<2:38:47, 102.44s/it]

Epoch 7/100, Loss: 0.01207172692654672, Val F1: 0.9215113958036952


Epochs:   8%|▊         | 8/100 [13:40<2:37:14, 102.55s/it]

Epoch 8/100, Loss: 0.012812371591265286, Val F1: 0.9215113958036952


Epochs:   9%|▉         | 9/100 [15:22<2:35:13, 102.34s/it]

Epoch 9/100, Loss: 0.013284417463232792, Val F1: 0.9215113958036952


Epochs:  10%|█         | 10/100 [17:05<2:33:47, 102.52s/it]

Epoch 10/100, Loss: 0.012097589907341524, Val F1: 0.9215113958036952


Epochs:  11%|█         | 11/100 [18:47<2:31:53, 102.40s/it]

Epoch 11/100, Loss: 0.011896885841086283, Val F1: 0.9215113958036952


Epochs:  12%|█▏        | 12/100 [20:29<2:30:13, 102.43s/it]

Epoch 12/100, Loss: 0.011879707687608774, Val F1: 0.9219338078801218


Epochs:  13%|█▎        | 13/100 [22:12<2:28:38, 102.51s/it]

Epoch 13/100, Loss: 0.011263226108590721, Val F1: 0.9216776570449633


Epochs:  14%|█▍        | 14/100 [23:54<2:26:41, 102.34s/it]

Epoch 14/100, Loss: 0.011651009210867586, Val F1: 0.9216941092837148


Epochs:  15%|█▌        | 15/100 [25:35<2:24:39, 102.11s/it]

Epoch 15/100, Loss: 0.01179824252917268, Val F1: 0.9214379642672147


Epochs:  16%|█▌        | 16/100 [27:18<2:23:13, 102.30s/it]

Epoch 16/100, Loss: 0.0124485856200416, Val F1: 0.9214379642672147


Epochs:  17%|█▋        | 17/100 [29:00<2:21:10, 102.06s/it]

Epoch 17/100, Loss: 0.011252952741382004, Val F1: 0.921657369002073


Epochs:  18%|█▊        | 18/100 [30:42<2:19:34, 102.13s/it]

Epoch 18/100, Loss: 0.012499950084877285, Val F1: 0.9208645164866308


Epochs:  19%|█▉        | 19/100 [32:25<2:18:08, 102.33s/it]

Epoch 19/100, Loss: 0.011343638738625511, Val F1: 0.9208645164866308


Epochs:  20%|██        | 20/100 [34:07<2:16:36, 102.46s/it]

Epoch 20/100, Loss: 0.012664104955688479, Val F1: 0.9208645164866308


Epochs:  21%|██        | 21/100 [35:50<2:14:57, 102.49s/it]

Epoch 21/100, Loss: 0.011194266792172412, Val F1: 0.9208645164866308


Epochs:  22%|██▏       | 22/100 [37:33<2:13:19, 102.56s/it]

Epoch 22/100, Loss: 0.011402367388734498, Val F1: 0.9208645164866308


Epochs:  23%|██▎       | 23/100 [39:15<2:11:40, 102.60s/it]

Epoch 23/100, Loss: 0.010914732113678261, Val F1: 0.9208645164866308


Epochs:  24%|██▍       | 24/100 [40:57<2:09:38, 102.34s/it]

Epoch 24/100, Loss: 0.010730891258286423, Val F1: 0.9208645164866308


Epochs:  25%|██▌       | 25/100 [42:40<2:07:56, 102.35s/it]

Epoch 25/100, Loss: 0.011022173255042324, Val F1: 0.9214379642672147


Epochs:  26%|██▌       | 26/100 [44:22<2:06:08, 102.28s/it]

Epoch 26/100, Loss: 0.010910104393340394, Val F1: 0.9214379642672147


Epochs:  27%|██▋       | 27/100 [46:04<2:04:28, 102.30s/it]

Epoch 27/100, Loss: 0.01066626511181258, Val F1: 0.9214379642672147


Epochs:  28%|██▊       | 28/100 [47:46<2:02:39, 102.22s/it]

Epoch 28/100, Loss: 0.011257845715759848, Val F1: 0.9213759390033595


Epochs:  29%|██▉       | 29/100 [49:29<2:01:21, 102.55s/it]

Epoch 29/100, Loss: 0.010336302508403861, Val F1: 0.9216216516002709


Epochs:  30%|███       | 30/100 [51:12<1:59:41, 102.59s/it]

Epoch 30/100, Loss: 0.010975906499656579, Val F1: 0.9213778421494788


Epochs:  31%|███       | 31/100 [52:54<1:57:52, 102.49s/it]

Epoch 31/100, Loss: 0.01073151461438173, Val F1: 0.9213778421494788


Epochs:  32%|███▏      | 32/100 [54:37<1:56:22, 102.68s/it]

Epoch 32/100, Loss: 0.011143673089633174, Val F1: 0.9213158189660291


Epochs:  33%|███▎      | 33/100 [56:20<1:54:43, 102.74s/it]

Epoch 33/100, Loss: 0.010526082381251623, Val F1: 0.9213158189660291


Epochs:  34%|███▍      | 34/100 [58:02<1:52:44, 102.50s/it]

Epoch 34/100, Loss: 0.010792300185835975, Val F1: 0.9213759390033595


Epochs:  35%|███▌      | 35/100 [59:44<1:50:48, 102.29s/it]

Epoch 35/100, Loss: 0.010503509177883239, Val F1: 0.9213759390033595


Epochs:  36%|███▌      | 36/100 [1:01:26<1:49:01, 102.21s/it]

Epoch 36/100, Loss: 0.010109354796800219, Val F1: 0.9213759390033595


Epochs:  37%|███▋      | 37/100 [1:03:09<1:47:31, 102.41s/it]

Epoch 37/100, Loss: 0.010394542480649537, Val F1: 0.9213759390033595


Epochs:  38%|███▊      | 38/100 [1:04:51<1:45:38, 102.23s/it]

Epoch 38/100, Loss: 0.010607455388614687, Val F1: 0.9215596226007002


Epochs:  39%|███▉      | 39/100 [1:06:32<1:43:45, 102.05s/it]

Epoch 39/100, Loss: 0.010193992775802874, Val F1: 0.9215596226007002


Epochs:  40%|████      | 40/100 [1:08:14<1:41:54, 101.90s/it]

Epoch 40/100, Loss: 0.010242078234931295, Val F1: 0.9217557050772704


Epochs:  41%|████      | 41/100 [1:09:56<1:40:14, 101.95s/it]

Epoch 41/100, Loss: 0.009993003764336756, Val F1: 0.9214395560743632


Epochs:  42%|████▏     | 42/100 [1:11:38<1:38:41, 102.09s/it]

Epoch 42/100, Loss: 0.010483753648309183, Val F1: 0.9213976170998804


Epochs:  43%|████▎     | 43/100 [1:13:21<1:37:11, 102.31s/it]

Epoch 43/100, Loss: 0.010395845744430513, Val F1: 0.9213158189660291


Epochs:  44%|████▍     | 44/100 [1:15:04<1:35:29, 102.31s/it]

Epoch 44/100, Loss: 0.008700421391835763, Val F1: 0.9211181731632493


Epochs:  45%|████▌     | 45/100 [1:16:46<1:33:47, 102.32s/it]

Epoch 45/100, Loss: 0.009813660850813031, Val F1: 0.9213759390033595


Epochs:  46%|████▌     | 46/100 [1:18:28<1:31:59, 102.22s/it]

Epoch 46/100, Loss: 0.01065016960820802, Val F1: 0.9213158189660291


Epochs:  47%|████▋     | 47/100 [1:20:11<1:30:33, 102.51s/it]

Epoch 47/100, Loss: 0.01047438340718681, Val F1: 0.9208533958029937


Epochs:  48%|████▊     | 48/100 [1:21:53<1:28:38, 102.27s/it]

Epoch 48/100, Loss: 0.009656888948490679, Val F1: 0.9210586304243146


Epochs:  49%|████▉     | 49/100 [1:23:35<1:26:59, 102.34s/it]

Epoch 49/100, Loss: 0.010688907426829016, Val F1: 0.9205447257261787


Epochs:  50%|█████     | 50/100 [1:25:17<1:25:03, 102.08s/it]

Epoch 50/100, Loss: 0.009760272996180363, Val F1: 0.9209985040571131


Epochs:  51%|█████     | 51/100 [1:27:00<1:23:40, 102.46s/it]

Epoch 51/100, Loss: 0.009010712064935379, Val F1: 0.9207038840805549


Epochs:  52%|█████▏    | 52/100 [1:28:43<1:22:01, 102.54s/it]

Epoch 52/100, Loss: 0.010337941995127753, Val F1: 0.9206364691200898


Epochs:  53%|█████▎    | 53/100 [1:30:25<1:20:11, 102.37s/it]

Epoch 53/100, Loss: 0.010209814292160732, Val F1: 0.9212099352131929


Epochs:  54%|█████▍    | 54/100 [1:32:07<1:18:24, 102.28s/it]

Epoch 54/100, Loss: 0.010102340115868424, Val F1: 0.9208927335892819


Epochs:  55%|█████▌    | 55/100 [1:33:49<1:16:38, 102.18s/it]

Epoch 55/100, Loss: 0.009510556539579352, Val F1: 0.9212865327111349


Epochs:  56%|█████▌    | 56/100 [1:35:31<1:14:55, 102.16s/it]

Epoch 56/100, Loss: 0.010645029074990953, Val F1: 0.9207130479195695


Epochs:  57%|█████▋    | 57/100 [1:37:13<1:13:13, 102.17s/it]

Epoch 57/100, Loss: 0.009025746468111306, Val F1: 0.9215303412911395


Epochs:  58%|█████▊    | 58/100 [1:38:55<1:11:33, 102.22s/it]

Epoch 58/100, Loss: 0.010230999599861628, Val F1: 0.9215262041855203


Epochs:  59%|█████▉    | 59/100 [1:40:38<1:09:58, 102.41s/it]

Epoch 59/100, Loss: 0.009587266726557624, Val F1: 0.9215262041855203


Epochs:  60%|██████    | 60/100 [1:42:20<1:08:12, 102.31s/it]

Epoch 60/100, Loss: 0.01033689231404521, Val F1: 0.9216010119991281


Epochs:  61%|██████    | 61/100 [1:44:03<1:06:33, 102.40s/it]

Epoch 61/100, Loss: 0.010112100492111437, Val F1: 0.9213463642588614


Epochs:  62%|██████▏   | 62/100 [1:45:46<1:05:02, 102.70s/it]

Epoch 62/100, Loss: 0.009103906190870224, Val F1: 0.9215862663840375


Epochs:  63%|██████▎   | 63/100 [1:47:28<1:03:10, 102.43s/it]

Epoch 63/100, Loss: 0.00915074746590313, Val F1: 0.921488760456003


Epochs:  64%|██████▍   | 64/100 [1:49:10<1:01:22, 102.28s/it]

Epoch 64/100, Loss: 0.010656748462413853, Val F1: 0.9213463642588614


Epochs:  65%|██████▌   | 65/100 [1:50:53<59:44, 102.42s/it]  

Epoch 65/100, Loss: 0.008919025158774543, Val F1: 0.9215069870805218


Epochs:  66%|██████▌   | 66/100 [1:52:35<58:00, 102.35s/it]

Epoch 66/100, Loss: 0.008821180070437846, Val F1: 0.9217297971483653


Epochs:  67%|██████▋   | 67/100 [1:54:17<56:13, 102.22s/it]

Epoch 67/100, Loss: 0.008965617937198371, Val F1: 0.9215303412911395


Epochs:  68%|██████▊   | 68/100 [1:55:59<54:28, 102.15s/it]

Epoch 68/100, Loss: 0.008509320152970835, Val F1: 0.9215832336585225


Epochs:  69%|██████▉   | 69/100 [1:57:42<52:53, 102.37s/it]

Epoch 69/100, Loss: 0.009708903609687859, Val F1: 0.9215901730157892


Epochs:  70%|███████   | 70/100 [1:59:24<51:08, 102.27s/it]

Epoch 70/100, Loss: 0.008816805760043164, Val F1: 0.9215832336585225


Epochs:  71%|███████   | 71/100 [2:01:07<49:31, 102.46s/it]

Epoch 71/100, Loss: 0.009902316231327422, Val F1: 0.922044322076916


Epochs:  72%|███████▏  | 72/100 [2:02:49<47:47, 102.42s/it]

Epoch 72/100, Loss: 0.009872863902620172, Val F1: 0.9215303412911395


Epochs:  73%|███████▎  | 73/100 [2:04:31<46:04, 102.38s/it]

Epoch 73/100, Loss: 0.00909840039093594, Val F1: 0.9217208479476617


Epochs:  74%|███████▍  | 74/100 [2:06:13<44:18, 102.27s/it]

Epoch 74/100, Loss: 0.008881969958584847, Val F1: 0.9224190641530805


Epochs:  75%|███████▌  | 75/100 [2:07:56<42:39, 102.39s/it]

Epoch 75/100, Loss: 0.00872860574760787, Val F1: 0.9228484677113008


Epochs:  76%|███████▌  | 76/100 [2:09:38<40:53, 102.21s/it]

Epoch 76/100, Loss: 0.00927003448247204, Val F1: 0.9227285421276685


Epochs:  77%|███████▋  | 77/100 [2:11:21<39:14, 102.38s/it]

Epoch 77/100, Loss: 0.009005781797513636, Val F1: 0.922314373699883


Epochs:  78%|███████▊  | 78/100 [2:13:03<37:34, 102.48s/it]

Epoch 78/100, Loss: 0.009470969152296775, Val F1: 0.922451308283155


Epochs:  79%|███████▉  | 79/100 [2:14:45<35:49, 102.36s/it]

Epoch 79/100, Loss: 0.009939633148990155, Val F1: 0.922228248177037


Epochs:  80%|████████  | 80/100 [2:16:27<34:03, 102.17s/it]

Epoch 80/100, Loss: 0.009079301700255773, Val F1: 0.9229441168130516


Epochs:  81%|████████  | 81/100 [2:18:09<32:20, 102.12s/it]

Epoch 81/100, Loss: 0.00832529472549261, Val F1: 0.9218448265805209


Epochs:  82%|████████▏ | 82/100 [2:19:52<30:39, 102.20s/it]

Epoch 82/100, Loss: 0.009717161127998529, Val F1: 0.9218448265805209


Epochs:  83%|████████▎ | 83/100 [2:21:34<28:57, 102.21s/it]

Epoch 83/100, Loss: 0.009185779332435652, Val F1: 0.921961963784206


Epochs:  84%|████████▍ | 84/100 [2:23:16<27:15, 102.24s/it]

Epoch 84/100, Loss: 0.0076057093554710315, Val F1: 0.9226894747168463


Epochs:  85%|████████▌ | 85/100 [2:24:58<25:34, 102.28s/it]

Epoch 85/100, Loss: 0.008516726821867944, Val F1: 0.9220217955088558


Epochs:  86%|████████▌ | 86/100 [2:26:42<23:55, 102.53s/it]

Epoch 86/100, Loss: 0.010977349707810636, Val F1: 0.9217159044488796


Epochs:  87%|████████▋ | 87/100 [2:28:24<22:13, 102.60s/it]

Epoch 87/100, Loss: 0.009259428311785736, Val F1: 0.922228248177037


Epochs:  88%|████████▊ | 88/100 [2:30:07<20:30, 102.54s/it]

Epoch 88/100, Loss: 0.009072073400002333, Val F1: 0.921922356764824


Epochs:  89%|████████▉ | 89/100 [2:31:49<18:46, 102.38s/it]

Epoch 89/100, Loss: 0.008699435903325418, Val F1: 0.9222127489760574


Epochs:  90%|█████████ | 90/100 [2:33:31<17:02, 102.25s/it]

Epoch 90/100, Loss: 0.008951067612206738, Val F1: 0.9226755820173608


Epochs:  91%|█████████ | 91/100 [2:35:14<15:23, 102.60s/it]

Epoch 91/100, Loss: 0.008549899638170792, Val F1: 0.9224009144300017


Epochs:  92%|█████████▏| 92/100 [2:36:57<13:40, 102.62s/it]

Epoch 92/100, Loss: 0.008108125143857243, Val F1: 0.9210460969804968


Epochs:  93%|█████████▎| 93/100 [2:38:39<11:56, 102.37s/it]

Epoch 93/100, Loss: 0.00903947359158558, Val F1: 0.9214963584056058


Epochs:  94%|█████████▍| 94/100 [2:40:21<10:13, 102.29s/it]

Epoch 94/100, Loss: 0.008478455372858736, Val F1: 0.9220356206830033


Epochs:  95%|█████████▌| 95/100 [2:42:03<08:31, 102.36s/it]

Epoch 95/100, Loss: 0.00881819783419998, Val F1: 0.9226073739700162


Epochs:  96%|█████████▌| 96/100 [2:43:46<06:49, 102.48s/it]

Epoch 96/100, Loss: 0.007937883095963062, Val F1: 0.9231350702802531


Epochs:  97%|█████████▋| 97/100 [2:45:29<05:07, 102.52s/it]

Epoch 97/100, Loss: 0.007779528056546763, Val F1: 0.9226052980070071


Epochs:  98%|█████████▊| 98/100 [2:47:11<03:25, 102.52s/it]

Epoch 98/100, Loss: 0.009874432329544248, Val F1: 0.9233211491619868


Epochs:  99%|█████████▉| 99/100 [2:48:54<01:42, 102.56s/it]

Epoch 99/100, Loss: 0.008118097492718248, Val F1: 0.9226273980529996


Epochs: 100%|██████████| 100/100 [2:50:35<00:00, 102.36s/it]

Epoch 100/100, Loss: 0.0079377332543038, Val F1: 0.9217151394802078





In [None]:

classifier.load_state_dict(torch.load('model_base_lora.pth'), strict=False)
classifier.load_state_dict(torch.load('lora_model.pth'), strict=False)

generate_predictions_and_save(classifier, test_loader)