In [60]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

onehot_encoder = OneHotEncoder(sparse_output=False)
data = pd.read_csv('train_top5.csv')['class'].values.reshape(-1, 1)
onehot_encoder.fit(data)

class CustomDataset(Dataset):
    def __init__(self, csv_file, onehot_encoder, method='endpoint'):
        self.data = pd.read_csv(csv_file)
        self.labels = onehot_encoder.transform(self.data['class'].values.reshape(-1, 1))
        self.method = method

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        file_path = self.data.iloc[idx]['path']
        label = self.labels[idx]

        raw_embedding = torch.load(file_path)
        if self.method == 'endpoint':
            embedding = torch.cat((raw_embedding[:, 0, :], raw_embedding[:, -1, :]), dim=1)
        elif self.method == 'diff-sum':
            embedding = torch.cat(
                (
                    raw_embedding[:, 0, :] + raw_embedding[:, -1, :],
                    raw_embedding[:, 0, :] - raw_embedding[:, -1, :]
                ),
                dim=1
            )
        elif self.method == 'coherent':
            embedding = torch.cat(
                (
                    raw_embedding[:, 0, :360],
                    raw_embedding[:, -1, 360:720],
                    torch.dot(
                        raw_embedding[:, 0, :].squeeze()[720:],
                        raw_embedding[:, -1, :].squeeze()[720:]
                    ).unsqueeze(0).unsqueeze(0)
                ),
                dim=1
            )
        elif self.method == 'maxpool':
            embedding = torch.max(raw_embedding, dim=1)[0]
        elif self.method == 'avgpool':
            embedding = torch.mean(raw_embedding, dim=1)

        return embedding, torch.tensor(label)

In [61]:
import torch.nn.functional as F
import torch.nn as nn

class CustomDeepClassifier(nn.Module):
    def __init__(self, input_dim=1536, num_classes=20, hidden_dim=2048):
        super(CustomDeepClassifier, self).__init__()

        layers = []
        layers.append(nn.Linear(input_dim, hidden_dim))
        layers.append(nn.ReLU())
        layers.append(nn.BatchNorm1d(1))

        layers.append(nn.Linear(hidden_dim, hidden_dim))
        layers.append(nn.ReLU())
        layers.append(nn.BatchNorm1d(1))

        layers.append(nn.Linear(hidden_dim, num_classes))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        x = self.layers(x)
        x = F.softmax(x, dim=0)
        return x

In [62]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_recall_fscore_support, classification_report
import matplotlib.pyplot as plt
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_validate_model(model, train_loader, val_loader, num_epochs=10, learning_rate=0.1, stepslr=10, gamma=0.9):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, stepslr, gamma=gamma)

    train_losses = []
    val_losses = []
    precision_scores = []
    recall_scores = []
    f1_scores = []

    for epoch in tqdm(range(num_epochs), total=num_epochs):
        model.train()
        running_train_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs.float())
            loss = criterion(outputs.squeeze(), labels)
            loss.backward()
            optimizer.step()
            running_train_loss += loss.item()
        
        train_loss = running_train_loss / len(train_loader)
        train_losses.append(train_loss)

        if (epoch + 1) % stepslr == 0:
            model.eval()
            running_val_loss = 0.0
            all_preds = []
            all_labels = []
            with torch.no_grad():
                for inputs, labels in val_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = model(inputs.float())
                    loss = criterion(outputs.squeeze(), labels)
                    running_val_loss += loss.item()

                    predicted = torch.argmax(outputs.squeeze(1), 1)
                    labels = torch.argmax(labels, 1)
                    all_preds.extend(predicted.cpu().numpy())
                    all_labels.extend(labels.cpu().numpy())

            val_loss = running_val_loss / len(val_loader)
            val_losses.append(val_loss)

            precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='macro')
            precision_scores.append(precision)
            recall_scores.append(recall)
            f1_scores.append(f1)

            scheduler.step()

            print(f"Epoch [{epoch + 1}/{num_epochs}] "
                f"Train Loss: {train_loss:.4f} "
                f"Val Loss: {val_loss:.4f} "
                f"Precision: {precision:.4f} "
                f"Recall: {recall:.4f} "
                f"F1 Score: {f1:.4f}\n")
            print(classification_report(all_labels, all_preds))

    return train_losses, val_losses, precision_scores, recall_scores, f1_scores

In [66]:
train_csv_path = 'train_top5.csv'
test_csv_path = 'dev_top5.csv'

method = 'endpoint'

train_dataset = CustomDataset(train_csv_path, onehot_encoder, method=method)
test_dataset = CustomDataset(test_csv_path, onehot_encoder, method=method)

train_batch_size = 256
test_batch_size = 256

train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1], num_classes=6)
tl, vl, p, r, f1 = train_validate_model(model,
                                        train_data_loader,
                                        test_data_loader,
                                        num_epochs=400,
                                        learning_rate=3e-2,
                                        stepslr=40,
                                        gamma=0.9)

torch.save(model.state_dict(), f'best_model_{method}.pt')

 10%|█         | 40/400 [23:32<3:33:28, 35.58s/it]

Epoch [40/400] Train Loss: 1.7778 Val Loss: 1.7872 Precision: 0.3590 Recall: 0.2719 F1 Score: 0.2661

              precision    recall  f1-score   support

           0       0.29      0.16      0.21       183
           1       0.12      0.18      0.15       119
           2       0.62      0.10      0.17       548
           3       0.33      0.22      0.26       285
           4       0.74      0.71      0.73      1801
           5       0.05      0.26      0.08       150

    accuracy                           0.48      3086
   macro avg       0.36      0.27      0.27      3086
weighted avg       0.60      0.48      0.50      3086



 20%|██        | 80/400 [46:22<3:09:35, 35.55s/it]

Epoch [80/400] Train Loss: 1.7832 Val Loss: 1.7910 Precision: 0.3254 Recall: 0.1931 F1 Score: 0.0753

              precision    recall  f1-score   support

           0       0.06      0.97      0.12       183
           1       0.20      0.03      0.04       119
           2       0.39      0.02      0.03       548
           3       0.38      0.06      0.11       285
           4       0.88      0.07      0.14      1801
           5       0.03      0.01      0.01       150

    accuracy                           0.11      3086
   macro avg       0.33      0.19      0.08      3086
weighted avg       0.63      0.11      0.10      3086



 30%|███       | 120/400 [1:09:00<2:41:37, 34.63s/it]

Epoch [120/400] Train Loss: 1.7819 Val Loss: 1.7895 Precision: 0.3515 Recall: 0.2083 F1 Score: 0.1098

              precision    recall  f1-score   support

           0       0.06      0.84      0.12       183
           1       0.17      0.07      0.10       119
           2       0.60      0.06      0.11       548
           3       0.29      0.04      0.07       285
           4       0.93      0.10      0.18      1801
           5       0.06      0.14      0.08       150

    accuracy                           0.13      3086
   macro avg       0.35      0.21      0.11      3086
weighted avg       0.69      0.13      0.14      3086



 40%|████      | 160/400 [1:30:47<2:12:49, 33.21s/it]

Epoch [160/400] Train Loss: 1.7817 Val Loss: 1.7884 Precision: 0.2891 Recall: 0.1972 F1 Score: 0.1210

              precision    recall  f1-score   support

           0       0.06      0.69      0.11       183
           1       0.07      0.05      0.06       119
           2       0.50      0.05      0.09       548
           3       0.21      0.16      0.18       285
           4       0.86      0.13      0.22      1801
           5       0.04      0.11      0.06       150

    accuracy                           0.15      3086
   macro avg       0.29      0.20      0.12      3086
weighted avg       0.62      0.15      0.18      3086



 50%|█████     | 200/400 [1:51:33<1:47:36, 32.28s/it]

Epoch [200/400] Train Loss: 1.7748 Val Loss: 1.7847 Precision: 0.3988 Recall: 0.3854 F1 Score: 0.3642

              precision    recall  f1-score   support

           0       0.33      0.52      0.41       183
           1       0.18      0.18      0.18       119
           2       0.45      0.48      0.46       548
           3       0.49      0.10      0.17       285
           4       0.82      0.80      0.81      1801
           5       0.12      0.23      0.16       150

    accuracy                           0.61      3086
   macro avg       0.40      0.39      0.36      3086
weighted avg       0.64      0.61      0.61      3086



 60%|██████    | 240/400 [2:12:12<1:25:56, 32.23s/it]

Epoch [240/400] Train Loss: 1.7740 Val Loss: 1.7836 Precision: 0.4035 Recall: 0.4230 F1 Score: 0.4015

              precision    recall  f1-score   support

           0       0.40      0.43      0.41       183
           1       0.24      0.17      0.20       119
           2       0.47      0.53      0.50       548
           3       0.29      0.48      0.36       285
           4       0.92      0.70      0.80      1801
           5       0.10      0.22      0.14       150

    accuracy                           0.59      3086
   macro avg       0.40      0.42      0.40      3086
weighted avg       0.68      0.59      0.63      3086



 70%|███████   | 280/400 [2:32:48<1:04:30, 32.26s/it]

Epoch [280/400] Train Loss: 1.7742 Val Loss: 1.7848 Precision: 0.3862 Recall: 0.3662 F1 Score: 0.3062

              precision    recall  f1-score   support

           0       0.17      0.78      0.27       183
           1       0.26      0.08      0.13       119
           2       0.43      0.50      0.46       548
           3       0.44      0.05      0.09       285
           4       0.89      0.71      0.79      1801
           5       0.14      0.07      0.10       150

    accuracy                           0.56      3086
   macro avg       0.39      0.37      0.31      3086
weighted avg       0.66      0.56      0.58      3086



 80%|████████  | 320/400 [2:53:22<42:46, 32.08s/it]  

Epoch [320/400] Train Loss: 1.7736 Val Loss: 1.7841 Precision: 0.3904 Recall: 0.4492 F1 Score: 0.3938

              precision    recall  f1-score   support

           0       0.33      0.67      0.44       183
           1       0.11      0.31      0.16       119
           2       0.54      0.39      0.45       548
           3       0.28      0.46      0.34       285
           4       0.93      0.70      0.80      1801
           5       0.16      0.17      0.17       150

    accuracy                           0.58      3086
   macro avg       0.39      0.45      0.39      3086
weighted avg       0.70      0.58      0.62      3086



 90%|█████████ | 360/400 [3:13:54<21:26, 32.17s/it]

Epoch [360/400] Train Loss: 1.7736 Val Loss: 1.7834 Precision: 0.4127 Recall: 0.4303 F1 Score: 0.4056

              precision    recall  f1-score   support

           0       0.50      0.33      0.40       183
           1       0.15      0.20      0.17       119
           2       0.45      0.58      0.51       548
           3       0.30      0.54      0.38       285
           4       0.94      0.69      0.80      1801
           5       0.14      0.23      0.18       150

    accuracy                           0.60      3086
   macro avg       0.41      0.43      0.41      3086
weighted avg       0.70      0.60      0.63      3086



100%|██████████| 400/400 [3:34:25<00:00, 32.16s/it]

Epoch [400/400] Train Loss: 1.7735 Val Loss: 1.7847 Precision: 0.4039 Recall: 0.4516 F1 Score: 0.4088

              precision    recall  f1-score   support

           0       0.31      0.62      0.41       183
           1       0.18      0.23      0.20       119
           2       0.44      0.57      0.50       548
           3       0.43      0.32      0.37       285
           4       0.92      0.69      0.79      1801
           5       0.14      0.29      0.19       150

    accuracy                           0.59      3086
   macro avg       0.40      0.45      0.41      3086
weighted avg       0.69      0.59      0.62      3086






In [67]:
train_csv_path = 'train_top5.csv'
test_csv_path = 'dev_top5.csv'

method = 'diff-sum'

train_dataset = CustomDataset(train_csv_path, onehot_encoder, method=method)
test_dataset = CustomDataset(test_csv_path, onehot_encoder, method=method)

train_batch_size = 256
test_batch_size = 256

train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1], num_classes=6)
tl, vl, p, r, f1 = train_validate_model(model,
                                        train_data_loader,
                                        test_data_loader,
                                        num_epochs=400,
                                        learning_rate=3e-2,
                                        stepslr=40,
                                        gamma=0.9)

torch.save(model.state_dict(), f'best_model_{method}.pt')

 10%|█         | 40/400 [18:55<2:58:15, 29.71s/it]

Epoch [40/400] Train Loss: 1.7846 Val Loss: 1.7930 Precision: 0.2694 Recall: 0.1778 F1 Score: 0.0664

              precision    recall  f1-score   support

           0       0.06      0.84      0.11       183
           1       0.00      0.00      0.00       119
           2       0.38      0.01      0.01       548
           3       0.20      0.02      0.03       285
           4       0.93      0.09      0.17      1801
           5       0.05      0.11      0.07       150

    accuracy                           0.11      3086
   macro avg       0.27      0.18      0.07      3086
weighted avg       0.63      0.11      0.12      3086



 20%|██        | 80/400 [37:51<2:38:21, 29.69s/it]

Epoch [80/400] Train Loss: 1.7829 Val Loss: 1.7885 Precision: 0.2758 Recall: 0.2327 F1 Score: 0.2222

              precision    recall  f1-score   support

           0       0.19      0.05      0.08       183
           1       0.04      0.05      0.05       119
           2       0.27      0.21      0.24       548
           3       0.34      0.14      0.20       285
           4       0.77      0.59      0.67      1801
           5       0.05      0.34      0.09       150

    accuracy                           0.42      3086
   macro avg       0.28      0.23      0.22      3086
weighted avg       0.54      0.42      0.46      3086



 30%|███       | 120/400 [56:45<2:18:21, 29.65s/it]

Epoch [120/400] Train Loss: 1.7833 Val Loss: 1.7929 Precision: 0.3338 Recall: 0.2159 F1 Score: 0.1260

              precision    recall  f1-score   support

           0       0.07      0.90      0.12       183
           1       0.14      0.05      0.07       119
           2       0.48      0.09      0.16       548
           3       0.35      0.04      0.07       285
           4       0.87      0.15      0.25      1801
           5       0.09      0.07      0.08       150

    accuracy                           0.17      3086
   macro avg       0.33      0.22      0.13      3086
weighted avg       0.64      0.17      0.20      3086



 40%|████      | 160/400 [1:15:38<1:58:33, 29.64s/it]

Epoch [160/400] Train Loss: 1.7840 Val Loss: 1.7898 Precision: 0.3300 Recall: 0.2148 F1 Score: 0.1233

              precision    recall  f1-score   support

           0       0.07      0.90      0.12       183
           1       0.12      0.03      0.04       119
           2       0.52      0.04      0.07       548
           3       0.33      0.07      0.12       285
           4       0.82      0.18      0.29      1801
           5       0.11      0.08      0.09       150

    accuracy                           0.17      3086
   macro avg       0.33      0.21      0.12      3086
weighted avg       0.62      0.17      0.21      3086



 50%|█████     | 200/400 [1:34:27<1:38:15, 29.48s/it]

Epoch [200/400] Train Loss: 1.7744 Val Loss: 1.7853 Precision: 0.3840 Recall: 0.3831 F1 Score: 0.3755

              precision    recall  f1-score   support

           0       0.47      0.40      0.43       183
           1       0.19      0.12      0.15       119
           2       0.42      0.58      0.49       548
           3       0.28      0.40      0.33       285
           4       0.85      0.78      0.81      1801
           5       0.09      0.03      0.05       150

    accuracy                           0.62      3086
   macro avg       0.38      0.38      0.38      3086
weighted avg       0.63      0.62      0.62      3086



 60%|██████    | 240/400 [1:53:12<1:18:36, 29.48s/it]

Epoch [240/400] Train Loss: 1.7735 Val Loss: 1.7837 Precision: 0.3905 Recall: 0.4473 F1 Score: 0.3911

              precision    recall  f1-score   support

           0       0.27      0.62      0.38       183
           1       0.20      0.19      0.20       119
           2       0.51      0.45      0.48       548
           3       0.31      0.50      0.39       285
           4       0.93      0.63      0.75      1801
           5       0.11      0.30      0.16       150

    accuracy                           0.55      3086
   macro avg       0.39      0.45      0.39      3086
weighted avg       0.69      0.55      0.60      3086



 70%|███████   | 280/400 [2:11:57<58:52, 29.44s/it]  

Epoch [280/400] Train Loss: 1.7745 Val Loss: 1.7867 Precision: 0.4281 Recall: 0.3462 F1 Score: 0.3520

              precision    recall  f1-score   support

           0       0.53      0.22      0.31       183
           1       0.23      0.15      0.18       119
           2       0.44      0.60      0.51       548
           3       0.47      0.12      0.19       285
           4       0.81      0.81      0.81      1801
           5       0.08      0.18      0.11       150

    accuracy                           0.62      3086
   macro avg       0.43      0.35      0.35      3086
weighted avg       0.64      0.62      0.61      3086



 80%|████████  | 320/400 [2:30:43<39:21, 29.52s/it]

Epoch [320/400] Train Loss: 1.7732 Val Loss: 1.7829 Precision: 0.4438 Recall: 0.4200 F1 Score: 0.3766

              precision    recall  f1-score   support

           0       0.48      0.46      0.47       183
           1       0.12      0.21      0.16       119
           2       0.58      0.33      0.42       548
           3       0.43      0.38      0.41       285
           4       0.97      0.51      0.67      1801
           5       0.08      0.63      0.14       150

    accuracy                           0.46      3086
   macro avg       0.44      0.42      0.38      3086
weighted avg       0.75      0.46      0.54      3086



 90%|█████████ | 360/400 [2:50:30<21:27, 32.19s/it]

Epoch [360/400] Train Loss: 1.7736 Val Loss: 1.7836 Precision: 0.4116 Recall: 0.4417 F1 Score: 0.3941

              precision    recall  f1-score   support

           0       0.27      0.66      0.39       183
           1       0.21      0.19      0.20       119
           2       0.53      0.45      0.49       548
           3       0.44      0.36      0.40       285
           4       0.93      0.63      0.75      1801
           5       0.09      0.35      0.14       150

    accuracy                           0.55      3086
   macro avg       0.41      0.44      0.39      3086
weighted avg       0.71      0.55      0.60      3086



100%|██████████| 400/400 [3:11:04<00:00, 28.66s/it]

Epoch [400/400] Train Loss: 1.7743 Val Loss: 1.7837 Precision: 0.3790 Recall: 0.4093 F1 Score: 0.3335

              precision    recall  f1-score   support

           0       0.25      0.67      0.36       183
           1       0.16      0.09      0.12       119
           2       0.61      0.15      0.24       548
           3       0.27      0.66      0.38       285
           4       0.90      0.68      0.77      1801
           5       0.10      0.21      0.13       150

    accuracy                           0.54      3086
   macro avg       0.38      0.41      0.33      3086
weighted avg       0.68      0.54      0.56      3086






In [68]:
train_csv_path = 'train_top5.csv'
test_csv_path = 'dev_top5.csv'

method = 'coherent'

train_dataset = CustomDataset(train_csv_path, onehot_encoder, method=method)
test_dataset = CustomDataset(test_csv_path, onehot_encoder, method=method)

train_batch_size = 256
test_batch_size = 256

train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1], num_classes=6)
tl, vl, p, r, f1 = train_validate_model(model,
                                        train_data_loader,
                                        test_data_loader,
                                        num_epochs=400,
                                        learning_rate=3e-2,
                                        stepslr=40,
                                        gamma=0.9)

torch.save(model.state_dict(), f'best_model_{method}.pt')

 10%|█         | 40/400 [20:22<3:11:10, 31.86s/it]

Epoch [40/400] Train Loss: 1.7749 Val Loss: 1.7855 Precision: 0.3940 Recall: 0.3284 F1 Score: 0.3317

              precision    recall  f1-score   support

           0       0.39      0.17      0.24       183
           1       0.23      0.08      0.12       119
           2       0.39      0.64      0.48       548
           3       0.44      0.15      0.22       285
           4       0.81      0.81      0.81      1801
           5       0.12      0.12      0.12       150

    accuracy                           0.62      3086
   macro avg       0.39      0.33      0.33      3086
weighted avg       0.62      0.62      0.60      3086



 20%|██        | 80/400 [40:44<2:49:41, 31.82s/it]

Epoch [80/400] Train Loss: 1.7741 Val Loss: 1.7853 Precision: 0.3873 Recall: 0.4231 F1 Score: 0.3978

              precision    recall  f1-score   support

           0       0.35      0.43      0.39       183
           1       0.15      0.21      0.17       119
           2       0.47      0.55      0.50       548
           3       0.30      0.48      0.37       285
           4       0.89      0.73      0.80      1801
           5       0.17      0.14      0.15       150

    accuracy                           0.61      3086
   macro avg       0.39      0.42      0.40      3086
weighted avg       0.67      0.61      0.63      3086



 30%|███       | 120/400 [1:01:04<2:28:15, 31.77s/it]

Epoch [120/400] Train Loss: 1.7737 Val Loss: 1.7842 Precision: 0.3807 Recall: 0.4155 F1 Score: 0.3848

              precision    recall  f1-score   support

           0       0.39      0.48      0.43       183
           1       0.14      0.13      0.13       119
           2       0.46      0.52      0.49       548
           3       0.27      0.44      0.34       285
           4       0.91      0.66      0.77      1801
           5       0.11      0.26      0.15       150

    accuracy                           0.57      3086
   macro avg       0.38      0.42      0.38      3086
weighted avg       0.67      0.57      0.60      3086



 40%|████      | 160/400 [1:21:23<2:07:15, 31.82s/it]

Epoch [160/400] Train Loss: 1.7742 Val Loss: 1.7845 Precision: 0.4447 Recall: 0.3748 F1 Score: 0.3407

              precision    recall  f1-score   support

           0       0.49      0.34      0.40       183
           1       0.20      0.09      0.13       119
           2       0.49      0.39      0.43       548
           3       0.46      0.14      0.22       285
           4       0.94      0.57      0.71      1801
           5       0.08      0.71      0.15       150

    accuracy                           0.48      3086
   macro avg       0.44      0.37      0.34      3086
weighted avg       0.72      0.48      0.55      3086



 50%|█████     | 200/400 [1:41:43<1:45:44, 31.72s/it]

Epoch [200/400] Train Loss: 1.7739 Val Loss: 1.7845 Precision: 0.4297 Recall: 0.3746 F1 Score: 0.3458

              precision    recall  f1-score   support

           0       0.46      0.36      0.41       183
           1       0.21      0.08      0.12       119
           2       0.57      0.31      0.40       548
           3       0.31      0.31      0.31       285
           4       0.94      0.56      0.70      1801
           5       0.08      0.63      0.13       150

    accuracy                           0.47      3086
   macro avg       0.43      0.37      0.35      3086
weighted avg       0.72      0.47      0.55      3086



 60%|██████    | 240/400 [2:02:02<1:24:37, 31.73s/it]

Epoch [240/400] Train Loss: 1.7734 Val Loss: 1.7838 Precision: 0.4200 Recall: 0.3842 F1 Score: 0.3424

              precision    recall  f1-score   support

           0       0.34      0.38      0.36       183
           1       0.19      0.12      0.15       119
           2       0.61      0.30      0.40       548
           3       0.33      0.31      0.32       285
           4       0.96      0.54      0.69      1801
           5       0.08      0.67      0.14       150

    accuracy                           0.45      3086
   macro avg       0.42      0.38      0.34      3086
weighted avg       0.73      0.45      0.54      3086



 70%|███████   | 280/400 [2:22:11<1:00:47, 30.39s/it]

Epoch [280/400] Train Loss: 1.7738 Val Loss: 1.7843 Precision: 0.4168 Recall: 0.4211 F1 Score: 0.3792

              precision    recall  f1-score   support

           0       0.40      0.42      0.41       183
           1       0.17      0.31      0.22       119
           2       0.60      0.34      0.43       548
           3       0.30      0.37      0.33       285
           4       0.94      0.60      0.74      1801
           5       0.08      0.49      0.14       150

    accuracy                           0.51      3086
   macro avg       0.42      0.42      0.38      3086
weighted avg       0.72      0.51      0.58      3086



 80%|████████  | 320/400 [2:40:44<37:52, 28.41s/it]  

Epoch [320/400] Train Loss: 1.7735 Val Loss: 1.7840 Precision: 0.4070 Recall: 0.4066 F1 Score: 0.3739

              precision    recall  f1-score   support

           0       0.38      0.31      0.34       183
           1       0.21      0.18      0.20       119
           2       0.52      0.36      0.43       548
           3       0.31      0.39      0.34       285
           4       0.93      0.65      0.77      1801
           5       0.10      0.54      0.17       150

    accuracy                           0.53      3086
   macro avg       0.41      0.41      0.37      3086
weighted avg       0.70      0.53      0.59      3086



 90%|█████████ | 360/400 [2:58:45<18:45, 28.15s/it]

Epoch [360/400] Train Loss: 1.7742 Val Loss: 1.7858 Precision: 0.4143 Recall: 0.3680 F1 Score: 0.2992

              precision    recall  f1-score   support

           0       0.21      0.66      0.32       183
           1       0.26      0.11      0.15       119
           2       0.62      0.10      0.17       548
           3       0.42      0.16      0.23       285
           4       0.89      0.69      0.78      1801
           5       0.09      0.49      0.15       150

    accuracy                           0.50      3086
   macro avg       0.41      0.37      0.30      3086
weighted avg       0.69      0.50      0.54      3086



100%|██████████| 400/400 [3:16:35<00:00, 29.49s/it]

Epoch [400/400] Train Loss: 1.7743 Val Loss: 1.7841 Precision: 0.4697 Recall: 0.3383 F1 Score: 0.2986

              precision    recall  f1-score   support

           0       0.59      0.09      0.16       183
           1       0.18      0.11      0.14       119
           2       0.55      0.28      0.37       548
           3       0.47      0.22      0.30       285
           4       0.96      0.53      0.68      1801
           5       0.08      0.79      0.14       150

    accuracy                           0.43      3086
   macro avg       0.47      0.34      0.30      3086
weighted avg       0.74      0.43      0.51      3086






In [69]:
train_csv_path = 'train_top5.csv'
test_csv_path = 'dev_top5.csv'

method = 'maxpool'

train_dataset = CustomDataset(train_csv_path, onehot_encoder, method=method)
test_dataset = CustomDataset(test_csv_path, onehot_encoder, method=method)

train_batch_size = 256
test_batch_size = 256

train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1], num_classes=6)
tl, vl, p, r, f1 = train_validate_model(model,
                                        train_data_loader,
                                        test_data_loader,
                                        num_epochs=400,
                                        learning_rate=3e-2,
                                        stepslr=40,
                                        gamma=0.9)

torch.save(model.state_dict(), f'best_model_{method}.pt')

 10%|█         | 40/400 [24:02<3:46:52, 37.81s/it]

Epoch [40/400] Train Loss: 1.7766 Val Loss: 1.7873 Precision: 0.2940 Recall: 0.2802 F1 Score: 0.2734

              precision    recall  f1-score   support

           0       0.21      0.15      0.17       183
           1       0.10      0.12      0.11       119
           2       0.48      0.20      0.28       548
           3       0.27      0.41      0.33       285
           4       0.68      0.80      0.74      1801
           5       0.02      0.01      0.01       150

    accuracy                           0.55      3086
   macro avg       0.29      0.28      0.27      3086
weighted avg       0.52      0.55      0.52      3086



 20%|██        | 80/400 [48:19<3:21:44, 37.83s/it]

Epoch [80/400] Train Loss: 1.7769 Val Loss: 1.7883 Precision: 0.3872 Recall: 0.2659 F1 Score: 0.2692

              precision    recall  f1-score   support

           0       0.28      0.17      0.21       183
           1       0.16      0.10      0.12       119
           2       0.76      0.06      0.12       548
           3       0.39      0.30      0.34       285
           4       0.63      0.89      0.74      1801
           5       0.10      0.07      0.08       150

    accuracy                           0.58      3086
   macro avg       0.39      0.27      0.27      3086
weighted avg       0.57      0.58      0.51      3086



 30%|███       | 120/400 [1:12:31<2:57:43, 38.08s/it]

Epoch [120/400] Train Loss: 1.7823 Val Loss: 1.7885 Precision: 0.3840 Recall: 0.1823 F1 Score: 0.1147

              precision    recall  f1-score   support

           0       0.05      0.56      0.09       183
           1       0.15      0.07      0.09       119
           2       0.95      0.03      0.07       548
           3       0.29      0.12      0.17       285
           4       0.82      0.11      0.19      1801
           5       0.05      0.20      0.07       150

    accuracy                           0.13      3086
   macro avg       0.38      0.18      0.11      3086
weighted avg       0.68      0.13      0.15      3086



 40%|████      | 160/400 [1:36:51<2:36:23, 39.10s/it]

Epoch [160/400] Train Loss: 1.7756 Val Loss: 1.7865 Precision: 0.3229 Recall: 0.2920 F1 Score: 0.2866

              precision    recall  f1-score   support

           0       0.26      0.30      0.27       183
           1       0.14      0.20      0.17       119
           2       0.43      0.14      0.21       548
           3       0.33      0.28      0.30       285
           4       0.64      0.80      0.71      1801
           5       0.14      0.03      0.05       150

    accuracy                           0.54      3086
   macro avg       0.32      0.29      0.29      3086
weighted avg       0.51      0.54      0.51      3086



 50%|█████     | 200/400 [2:01:20<2:07:22, 38.21s/it]

Epoch [200/400] Train Loss: 1.7747 Val Loss: 1.7870 Precision: 0.3804 Recall: 0.3144 F1 Score: 0.2795

              precision    recall  f1-score   support

           0       0.27      0.15      0.19       183
           1       0.11      0.15      0.13       119
           2       0.75      0.10      0.18       548
           3       0.22      0.67      0.34       285
           4       0.72      0.73      0.72      1801
           5       0.21      0.09      0.12       150

    accuracy                           0.52      3086
   macro avg       0.38      0.31      0.28      3086
weighted avg       0.60      0.52      0.51      3086



 60%|██████    | 240/400 [2:25:31<1:40:51, 37.82s/it]

Epoch [240/400] Train Loss: 1.7743 Val Loss: 1.7872 Precision: 0.3388 Recall: 0.3566 F1 Score: 0.3356

              precision    recall  f1-score   support

           0       0.25      0.38      0.31       183
           1       0.15      0.19      0.17       119
           2       0.41      0.36      0.38       548
           3       0.31      0.46      0.37       285
           4       0.73      0.70      0.72      1801
           5       0.17      0.04      0.06       150

    accuracy                           0.55      3086
   macro avg       0.34      0.36      0.34      3086
weighted avg       0.56      0.55      0.55      3086



 70%|███████   | 280/400 [2:49:41<1:15:21, 37.68s/it]

Epoch [280/400] Train Loss: 1.7747 Val Loss: 1.7867 Precision: 0.3350 Recall: 0.3314 F1 Score: 0.3103

              precision    recall  f1-score   support

           0       0.26      0.34      0.30       183
           1       0.10      0.18      0.13       119
           2       0.50      0.20      0.28       548
           3       0.32      0.48      0.38       285
           4       0.70      0.75      0.73      1801
           5       0.13      0.03      0.04       150

    accuracy                           0.55      3086
   macro avg       0.33      0.33      0.31      3086
weighted avg       0.55      0.55      0.53      3086



 80%|████████  | 320/400 [3:13:54<50:41, 38.02s/it]  

Epoch [320/400] Train Loss: 1.7749 Val Loss: 1.7887 Precision: 0.3475 Recall: 0.2414 F1 Score: 0.2287

              precision    recall  f1-score   support

           0       0.24      0.15      0.19       183
           1       0.09      0.15      0.11       119
           2       0.80      0.04      0.07       548
           3       0.27      0.20      0.23       285
           4       0.64      0.76      0.69      1801
           5       0.06      0.15      0.08       150

    accuracy                           0.49      3086
   macro avg       0.35      0.24      0.23      3086
weighted avg       0.56      0.49      0.46      3086



 90%|█████████ | 360/400 [3:38:06<25:18, 37.96s/it]

Epoch [360/400] Train Loss: 1.7750 Val Loss: 1.7865 Precision: 0.3893 Recall: 0.3004 F1 Score: 0.2549

              precision    recall  f1-score   support

           0       0.31      0.21      0.25       183
           1       0.14      0.10      0.12       119
           2       0.77      0.10      0.18       548
           3       0.15      0.79      0.25       285
           4       0.74      0.50      0.60      1801
           5       0.21      0.10      0.14       150

    accuracy                           0.40      3086
   macro avg       0.39      0.30      0.25      3086
weighted avg       0.62      0.40      0.43      3086



100%|██████████| 400/400 [4:02:20<00:00, 36.35s/it]

Epoch [400/400] Train Loss: 1.7750 Val Loss: 1.7877 Precision: 0.3828 Recall: 0.2943 F1 Score: 0.2861

              precision    recall  f1-score   support

           0       0.35      0.22      0.27       183
           1       0.08      0.24      0.12       119
           2       0.75      0.13      0.22       548
           3       0.23      0.32      0.27       285
           4       0.67      0.77      0.72      1801
           5       0.21      0.08      0.12       150

    accuracy                           0.53      3086
   macro avg       0.38      0.29      0.29      3086
weighted avg       0.58      0.53      0.51      3086






In [70]:
train_csv_path = 'train_top5.csv'
test_csv_path = 'dev_top5.csv'

method = 'avgpool'

train_dataset = CustomDataset(train_csv_path, onehot_encoder, method=method)
test_dataset = CustomDataset(test_csv_path, onehot_encoder, method=method)

train_batch_size = 256
test_batch_size = 256

train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1], num_classes=6)
tl, vl, p, r, f1 = train_validate_model(model,
                                        train_data_loader,
                                        test_data_loader,
                                        num_epochs=400,
                                        learning_rate=3e-2,
                                        stepslr=40,
                                        gamma=0.9)

torch.save(model.state_dict(), f'best_model_{method}.pt')

 10%|█         | 40/400 [18:50<2:57:31, 29.59s/it]

Epoch [40/400] Train Loss: 1.7757 Val Loss: 1.7897 Precision: 0.3069 Recall: 0.2858 F1 Score: 0.2853

              precision    recall  f1-score   support

           0       0.13      0.07      0.09       183
           1       0.07      0.11      0.08       119
           2       0.44      0.32      0.37       548
           3       0.37      0.24      0.29       285
           4       0.73      0.76      0.75      1801
           5       0.09      0.22      0.13       150

    accuracy                           0.54      3086
   macro avg       0.31      0.29      0.29      3086
weighted avg       0.56      0.54      0.54      3086



 20%|██        | 80/400 [40:09<3:05:23, 34.76s/it]

Epoch [80/400] Train Loss: 1.7756 Val Loss: 1.7875 Precision: 0.3623 Recall: 0.3383 F1 Score: 0.3357

              precision    recall  f1-score   support

           0       0.17      0.09      0.12       183
           1       0.22      0.07      0.10       119
           2       0.40      0.63      0.49       548
           3       0.44      0.34      0.39       285
           4       0.78      0.74      0.76      1801
           5       0.17      0.16      0.16       150

    accuracy                           0.59      3086
   macro avg       0.36      0.34      0.34      3086
weighted avg       0.59      0.59      0.58      3086



 30%|███       | 120/400 [1:01:35<2:36:17, 33.49s/it]

Epoch [120/400] Train Loss: 1.7753 Val Loss: 1.7893 Precision: 0.3495 Recall: 0.3084 F1 Score: 0.3026

              precision    recall  f1-score   support

           0       0.19      0.06      0.09       183
           1       0.12      0.05      0.07       119
           2       0.39      0.68      0.50       548
           3       0.49      0.20      0.28       285
           4       0.78      0.80      0.79      1801
           5       0.12      0.06      0.08       150

    accuracy                           0.61      3086
   macro avg       0.35      0.31      0.30      3086
weighted avg       0.59      0.61      0.59      3086



 40%|████      | 160/400 [1:23:06<2:14:04, 33.52s/it]

Epoch [160/400] Train Loss: 1.7744 Val Loss: 1.7862 Precision: 0.3416 Recall: 0.3314 F1 Score: 0.3214

              precision    recall  f1-score   support

           0       0.17      0.09      0.12       183
           1       0.19      0.08      0.12       119
           2       0.39      0.66      0.49       548
           3       0.38      0.25      0.30       285
           4       0.80      0.70      0.75      1801
           5       0.12      0.20      0.15       150

    accuracy                           0.57      3086
   macro avg       0.34      0.33      0.32      3086
weighted avg       0.60      0.57      0.57      3086



 50%|█████     | 200/400 [1:44:34<1:51:42, 33.51s/it]

Epoch [200/400] Train Loss: 1.7749 Val Loss: 1.7874 Precision: 0.3363 Recall: 0.3370 F1 Score: 0.3291

              precision    recall  f1-score   support

           0       0.17      0.07      0.10       183
           1       0.11      0.09      0.10       119
           2       0.42      0.60      0.50       548
           3       0.38      0.41      0.40       285
           4       0.78      0.75      0.76      1801
           5       0.15      0.11      0.12       150

    accuracy                           0.59      3086
   macro avg       0.34      0.34      0.33      3086
weighted avg       0.59      0.59      0.59      3086



 60%|██████    | 240/400 [2:06:03<1:29:21, 33.51s/it]

Epoch [240/400] Train Loss: 1.7745 Val Loss: 1.7869 Precision: 0.3430 Recall: 0.3365 F1 Score: 0.3270

              precision    recall  f1-score   support

           0       0.15      0.17      0.16       183
           1       0.17      0.13      0.15       119
           2       0.39      0.59      0.47       548
           3       0.43      0.22      0.29       285
           4       0.80      0.68      0.73      1801
           5       0.13      0.22      0.16       150

    accuracy                           0.55      3086
   macro avg       0.34      0.34      0.33      3086
weighted avg       0.60      0.55      0.56      3086



 70%|███████   | 280/400 [2:27:56<1:08:40, 34.34s/it]

Epoch [280/400] Train Loss: 1.7742 Val Loss: 1.7865 Precision: 0.3433 Recall: 0.3449 F1 Score: 0.3222

              precision    recall  f1-score   support

           0       0.14      0.15      0.14       183
           1       0.18      0.09      0.12       119
           2       0.44      0.52      0.48       548
           3       0.39      0.30      0.34       285
           4       0.82      0.61      0.70      1801
           5       0.10      0.40      0.15       150

    accuracy                           0.51      3086
   macro avg       0.34      0.34      0.32      3086
weighted avg       0.61      0.51      0.54      3086



 80%|████████  | 320/400 [2:48:44<41:08, 30.85s/it]  

Epoch [320/400] Train Loss: 1.7746 Val Loss: 1.7865 Precision: 0.3424 Recall: 0.3560 F1 Score: 0.3375

              precision    recall  f1-score   support

           0       0.17      0.15      0.16       183
           1       0.16      0.12      0.14       119
           2       0.45      0.54      0.49       548
           3       0.37      0.40      0.39       285
           4       0.81      0.63      0.71      1801
           5       0.10      0.29      0.14       150

    accuracy                           0.53      3086
   macro avg       0.34      0.36      0.34      3086
weighted avg       0.61      0.53      0.56      3086



 90%|█████████ | 360/400 [3:07:56<20:22, 30.57s/it]

Epoch [360/400] Train Loss: 1.7741 Val Loss: 1.7852 Precision: 0.3530 Recall: 0.3774 F1 Score: 0.3600

              precision    recall  f1-score   support

           0       0.21      0.20      0.20       183
           1       0.15      0.20      0.17       119
           2       0.43      0.54      0.48       548
           3       0.35      0.37      0.36       285
           4       0.82      0.68      0.74      1801
           5       0.15      0.27      0.20       150

    accuracy                           0.56      3086
   macro avg       0.35      0.38      0.36      3086
weighted avg       0.61      0.56      0.58      3086



100%|██████████| 400/400 [3:27:16<00:00, 31.09s/it]

Epoch [400/400] Train Loss: 1.7740 Val Loss: 1.7867 Precision: 0.3586 Recall: 0.3675 F1 Score: 0.3605

              precision    recall  f1-score   support

           0       0.20      0.20      0.20       183
           1       0.18      0.16      0.17       119
           2       0.49      0.53      0.51       548
           3       0.36      0.35      0.36       285
           4       0.80      0.73      0.76      1801
           5       0.13      0.23      0.17       150

    accuracy                           0.58      3086
   macro avg       0.36      0.37      0.36      3086
weighted avg       0.61      0.58      0.60      3086






---

### Attention pooling

In [9]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.nn as nn
import pandas as pd
from sklearn.preprocessing import OneHotEncoder
import torch.optim as optim
from sklearn.metrics import precision_recall_fscore_support, classification_report
import matplotlib.pyplot as plt
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

onehot_encoder = OneHotEncoder(sparse_output=False)
data = pd.read_csv('train_top5.csv')['class'].values.reshape(-1, 1)
onehot_encoder.fit(data)

class CustomDataset(Dataset):
    def __init__(self, csv_file, onehot_encoder, method='endpoint', padding_size=256):
        self.data = pd.read_csv(csv_file)
        self.labels = onehot_encoder.transform(self.data['class'].values.reshape(-1, 1))
        self.method = method
        self.padding_size = padding_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        file_path = self.data.iloc[idx]['path']
        label = self.labels[idx]

        raw_embedding = torch.load(file_path)
        if self.method == 'endpoint':
            embedding = torch.cat((raw_embedding[:, 0, :], raw_embedding[:, -1, :]), dim=1)
        elif self.method == 'diff-sum':
            embedding = torch.cat(
                (
                    raw_embedding[:, 0, :] + raw_embedding[:, -1, :],
                    raw_embedding[:, 0, :] - raw_embedding[:, -1, :]
                ),
                dim=1
            )
        elif self.method == 'coherent':
            embedding = torch.cat(
                (
                    raw_embedding[:, 0, :360],
                    raw_embedding[:, -1, 360:720],
                    torch.dot(
                        raw_embedding[:, 0, :].squeeze()[720:],
                        raw_embedding[:, -1, :].squeeze()[720:]
                    ).unsqueeze(0).unsqueeze(0)
                ),
                dim=1
            )
        elif self.method == 'maxpool':
            embedding = torch.max(raw_embedding, dim=1)[0]
        elif self.method == 'avgpool':
            embedding = torch.mean(raw_embedding, dim=1)
        else:
            embedding = raw_embedding
            padding = torch.zeros(1, self.padding_size - embedding.shape[1], embedding.shape[-1])
            embedding = torch.cat((embedding, padding), dim=1)

        return embedding, torch.tensor(label)


class CustomAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim=1024, proj_dim=256, num_layers=3, num_classes=20):
        super(CustomAttention, self).__init__()

        self.projection = nn.Linear(input_dim, proj_dim)

        self.attention_params = nn.Linear(proj_dim, 1)

        layers = []
        layers.append(nn.Linear(proj_dim, hidden_dim))
        layers.append(nn.ReLU())
        layers.append(nn.BatchNorm1d(1))
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(1))
        layers.append(nn.Linear(hidden_dim, num_classes))
        self.layers = nn.Sequential(*layers)
        
    def forward(self, embeddings):
        embeddings = self.projection(embeddings)

        attn_logits = self.attention_params(embeddings)
        attention_wts = nn.functional.softmax(attn_logits, dim=2)

        attention_term = torch.sum(attention_wts * embeddings, dim=-1)

        output = self.layers(attention_term)
        
        output = F.softmax(output, dim=0)
        
        return output
    

def train_validate_model(model, train_loader, val_loader, num_epochs=10, learning_rate=0.1, stepslr=10, step_report=10, gamma=0.9):
    model.to(device)
    criterion = nn.CrossEntropyLoss(weight=torch.tensor([1./183, 1./119, 1./548, 1./285, 1./1801, 1./150]).to(device))
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, stepslr, gamma=gamma)

    train_losses = []
    val_losses = []
    precision_scores = []
    recall_scores = []
    f1_scores = []

    for epoch in tqdm(range(num_epochs), total=num_epochs, leave=False):
        model.train()
        running_train_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs.float())
            # print(outputs.device, labels.device)
            loss = criterion(outputs.squeeze(), labels)
            loss.backward()
            optimizer.step()
            running_train_loss += loss.item()
        
        train_loss = running_train_loss / len(train_loader)
        train_losses.append(train_loss)

        model.eval()
        running_val_loss = 0.0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs.float())
                loss = criterion(outputs.squeeze(), labels)
                running_val_loss += loss.item()

                predicted = torch.argmax(outputs.squeeze(1), 1)
                labels = torch.argmax(labels, 1)
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        val_loss = running_val_loss / len(val_loader)
        val_losses.append(val_loss)

        precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='macro')
        precision_scores.append(precision)
        recall_scores.append(recall)
        f1_scores.append(f1)

        scheduler.step()

        print(f"Epoch [{epoch + 1}/{num_epochs}] "
            f"Train Loss: {train_loss:.4f} "
            f"Val Loss: {val_loss:.4f} "
            f"Precision: {precision:.4f} "
            f"Recall: {recall:.4f} "
            f"F1 Score: {f1:.4f}")
        if (epoch + 1) % step_report == 0:
            print(classification_report(all_labels, all_preds))

    return train_losses, val_losses, precision_scores, recall_scores, f1_scores

In [10]:
train_csv_path = 'train_top5.csv'
test_csv_path = 'dev_top5.csv'

method = 'attnpool'

train_dataset = CustomDataset(train_csv_path, onehot_encoder, method=method)
test_dataset = CustomDataset(test_csv_path, onehot_encoder, method=method)

train_batch_size = 256
test_batch_size = 256

train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

model = CustomAttention(input_dim=768, num_classes=6).to(device)
tl, vl, p, r, f1 = train_validate_model(model,
                                        train_data_loader,
                                        test_data_loader,
                                        num_epochs=100,
                                        learning_rate=3e-2,
                                        stepslr=10,
                                        gamma=0.5)

torch.save(model.state_dict(), f'best_model_{method}.pt')

  0%|          | 0/100 [00:00<?, ?it/s]

  1%|          | 1/100 [00:56<1:33:50, 56.87s/it]

Epoch [1/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.1999 Recall: 0.2095 F1 Score: 0.0878


  2%|▏         | 2/100 [01:53<1:32:35, 56.69s/it]

Epoch [2/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2141 Recall: 0.1857 F1 Score: 0.0930


  3%|▎         | 3/100 [02:48<1:30:18, 55.86s/it]

Epoch [3/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2515 Recall: 0.1914 F1 Score: 0.1042


  4%|▍         | 4/100 [03:43<1:28:51, 55.54s/it]

Epoch [4/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2527 Recall: 0.2078 F1 Score: 0.0695


  5%|▌         | 5/100 [04:37<1:27:05, 55.01s/it]

Epoch [5/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2428 Recall: 0.2080 F1 Score: 0.0754


  6%|▌         | 6/100 [05:32<1:26:14, 55.04s/it]

Epoch [6/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2371 Recall: 0.2140 F1 Score: 0.1064


  7%|▋         | 7/100 [06:28<1:25:37, 55.24s/it]

Epoch [7/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2703 Recall: 0.2108 F1 Score: 0.1029


  8%|▊         | 8/100 [07:24<1:25:05, 55.50s/it]

Epoch [8/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2833 Recall: 0.2158 F1 Score: 0.0960


  9%|▉         | 9/100 [08:19<1:24:06, 55.45s/it]

Epoch [9/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2483 Recall: 0.2132 F1 Score: 0.0974


 10%|█         | 10/100 [09:14<1:22:47, 55.20s/it]

Epoch [10/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2302 Recall: 0.2066 F1 Score: 0.1008
              precision    recall  f1-score   support

           0       0.07      0.01      0.01       183
           1       0.04      0.23      0.06       119
           2       0.29      0.20      0.23       548
           3       0.11      0.72      0.19       285
           4       0.77      0.01      0.01      1801
           5       0.11      0.08      0.09       150

    accuracy                           0.12      3086
   macro avg       0.23      0.21      0.10      3086
weighted avg       0.52      0.12      0.07      3086



 11%|█         | 11/100 [10:11<1:22:53, 55.88s/it]

Epoch [11/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2313 Recall: 0.2110 F1 Score: 0.1026


 12%|█▏        | 12/100 [11:05<1:21:10, 55.34s/it]

Epoch [12/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2231 Recall: 0.1977 F1 Score: 0.1021


 13%|█▎        | 13/100 [12:01<1:20:35, 55.58s/it]

Epoch [13/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2461 Recall: 0.2106 F1 Score: 0.1185


 14%|█▍        | 14/100 [12:55<1:18:47, 54.97s/it]

Epoch [14/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2414 Recall: 0.2117 F1 Score: 0.1198


 15%|█▌        | 15/100 [13:52<1:18:41, 55.55s/it]

Epoch [15/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2346 Recall: 0.2033 F1 Score: 0.1156


 16%|█▌        | 16/100 [14:55<1:21:08, 57.96s/it]

Epoch [16/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2251 Recall: 0.2027 F1 Score: 0.1144


 17%|█▋        | 17/100 [15:50<1:18:55, 57.06s/it]

Epoch [17/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2329 Recall: 0.1959 F1 Score: 0.1085


 18%|█▊        | 18/100 [16:45<1:16:59, 56.34s/it]

Epoch [18/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2376 Recall: 0.1958 F1 Score: 0.1083


 19%|█▉        | 19/100 [17:42<1:16:29, 56.65s/it]

Epoch [19/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2420 Recall: 0.2097 F1 Score: 0.1210


 20%|██        | 20/100 [18:38<1:15:01, 56.27s/it]

Epoch [20/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2078 Recall: 0.2106 F1 Score: 0.1112
              precision    recall  f1-score   support

           0       0.00      0.00      0.00       183
           1       0.06      0.16      0.09       119
           2       0.28      0.34      0.31       548
           3       0.10      0.68      0.17       285
           4       0.70      0.00      0.01      1801
           5       0.11      0.08      0.09       150

    accuracy                           0.14      3086
   macro avg       0.21      0.21      0.11      3086
weighted avg       0.47      0.14      0.08      3086



 21%|██        | 21/100 [19:32<1:13:24, 55.75s/it]

Epoch [21/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2356 Recall: 0.1585 F1 Score: 0.0770


 22%|██▏       | 22/100 [20:28<1:12:31, 55.78s/it]

Epoch [22/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2060 Recall: 0.2011 F1 Score: 0.1064


 23%|██▎       | 23/100 [21:30<1:13:55, 57.60s/it]

Epoch [23/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2684 Recall: 0.2016 F1 Score: 0.1063


 24%|██▍       | 24/100 [22:32<1:14:47, 59.04s/it]

Epoch [24/100] Train Loss: 0.0036 Val Loss: 0.0033 Precision: 0.2673 Recall: 0.1893 F1 Score: 0.0943


                                                  

KeyboardInterrupt: 