In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

onehot_encoder = OneHotEncoder(sparse_output=False)
data = pd.read_csv('train_top5.csv')['class'].values.reshape(-1, 1)
onehot_encoder.fit(data)

class CustomDataset(Dataset):
    def __init__(self, csv_file, onehot_encoder, method='endpoint'):
        self.data = pd.read_csv(csv_file)
        self.labels = onehot_encoder.transform(self.data['class'].values.reshape(-1, 1))
        self.method = method

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        file_path = self.data.iloc[idx]['path']
        label = self.labels[idx]

        raw_embedding = torch.load(file_path)
        if self.method == 'endpoint':
            embedding = torch.cat((raw_embedding[:, 0, :], raw_embedding[:, -1, :]), dim=1)
        elif self.method == 'diff-sum':
            embedding = torch.cat(
                (
                    raw_embedding[:, 0, :] + raw_embedding[:, -1, :],
                    raw_embedding[:, 0, :] - raw_embedding[:, -1, :]
                ),
                dim=1
            )
        elif self.method == 'coherent':
            embedding = torch.cat(
                (
                    raw_embedding[:, 0, :360],
                    raw_embedding[:, -1, 360:720],
                    torch.dot(
                        raw_embedding[:, 0, :].squeeze()[720:],
                        raw_embedding[:, -1, :].squeeze()[720:]
                    ).unsqueeze(0).unsqueeze(0)
                ),
                dim=1
            )
        elif self.method == 'maxpool':
            embedding = torch.max(raw_embedding, dim=1)[0]
        elif self.method == 'avgpool':
            embedding = torch.mean(raw_embedding, dim=1)

        return embedding, torch.tensor(label)

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
import torch.nn.functional as F
import torch.nn as nn

class CustomDeepClassifier(nn.Module):
    def __init__(self, input_dim=1536, num_classes=20, hidden_dim=2048):
        super(CustomDeepClassifier, self).__init__()

        layers = []
        layers.append(nn.Linear(input_dim, hidden_dim))
        layers.append(nn.ReLU())
        layers.append(nn.BatchNorm1d(1))

        layers.append(nn.Linear(hidden_dim, hidden_dim))
        layers.append(nn.ReLU())
        layers.append(nn.BatchNorm1d(1))

        layers.append(nn.Linear(hidden_dim, num_classes))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        x = self.layers(x)
        x = F.softmax(x, dim=0)
        return x

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_recall_fscore_support, classification_report
import matplotlib.pyplot as plt
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_validate_model(model, train_loader, val_loader, num_epochs=10, learning_rate=0.1, step_report=40, stepslr=10, gamma=0.9):
    model.to(device)
    criterion = nn.CrossEntropyLoss(weight=torch.tensor([1./183, 1./119, 1./548, 1./285, 1./1801, 1./150]).to(device))
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, stepslr, gamma=gamma)

    train_losses = []
    val_losses = []
    precision_scores = []
    recall_scores = []
    f1_scores = []

    for epoch in tqdm(range(num_epochs), total=num_epochs):
        model.train()
        running_train_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs.float())
            loss = criterion(outputs.squeeze(), labels)
            loss.backward()
            optimizer.step()
            running_train_loss += loss.item()
        
        train_loss = running_train_loss / len(train_loader)
        train_losses.append(train_loss)

        model.eval()
        running_val_loss = 0.0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs.float())
                loss = criterion(outputs.squeeze(), labels)
                running_val_loss += loss.item()

                predicted = torch.argmax(outputs.squeeze(1), 1)
                labels = torch.argmax(labels, 1)
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        val_loss = running_val_loss / len(val_loader)
        val_losses.append(val_loss)

        precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='macro')
        precision_scores.append(precision)
        recall_scores.append(recall)
        f1_scores.append(f1)

        scheduler.step()

        if (epoch + 1) % step_report == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}] "
                f"Train Loss: {train_loss:.4f} "
                f"Val Loss: {val_loss:.4f} "
                f"Precision: {precision:.4f} "
                f"Recall: {recall:.4f} "
                f"F1 Score: {f1:.4f}\n")
            print(classification_report(all_labels, all_preds))

    return train_losses, val_losses, precision_scores, recall_scores, f1_scores

  return torch._C._cuda_getDeviceCount() > 0


In [4]:
train_csv_path = 'train_top5.csv'
test_csv_path = 'dev_top5.csv'

method = 'endpoint'

train_dataset = CustomDataset(train_csv_path, onehot_encoder, method=method)
test_dataset = CustomDataset(test_csv_path, onehot_encoder, method=method)

train_batch_size = 256
test_batch_size = 256

train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1], num_classes=6)
tl, vl, p, r, f1 = train_validate_model(model,
                                        train_data_loader,
                                        test_data_loader,
                                        num_epochs=400,
                                        learning_rate=3e-3,
                                        stepslr=40,
                                        gamma=0.5)

torch.save(model.state_dict(), f'best_model_{method}.pt')

 10%|█         | 40/400 [28:28<4:14:37, 42.44s/it]

Epoch [40/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3748 Recall: 0.2873 F1 Score: 0.2938

              precision    recall  f1-score   support

           0       0.57      0.14      0.22       183
           1       0.25      0.08      0.12       119
           2       0.32      0.46      0.37       548
           3       0.28      0.16      0.21       285
           4       0.77      0.73      0.74      1801
           5       0.07      0.16      0.10       150

    accuracy                           0.54      3086
   macro avg       0.37      0.29      0.29      3086
weighted avg       0.58      0.54      0.54      3086



 20%|██        | 80/400 [56:09<3:41:42, 41.57s/it]

Epoch [80/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3741 Recall: 0.3848 F1 Score: 0.3436

              precision    recall  f1-score   support

           0       0.38      0.43      0.41       183
           1       0.15      0.14      0.15       119
           2       0.47      0.42      0.44       548
           3       0.26      0.55      0.36       285
           4       0.91      0.46      0.61      1801
           5       0.06      0.31      0.10       150

    accuracy                           0.44      3086
   macro avg       0.37      0.38      0.34      3086
weighted avg       0.67      0.44      0.50      3086



 30%|███       | 120/400 [1:23:54<3:14:41, 41.72s/it]

Epoch [120/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3709 Recall: 0.4292 F1 Score: 0.3693

              precision    recall  f1-score   support

           0       0.31      0.59      0.41       183
           1       0.14      0.25      0.18       119
           2       0.49      0.43      0.46       548
           3       0.30      0.53      0.38       285
           4       0.92      0.54      0.68      1801
           5       0.07      0.23      0.11       150

    accuracy                           0.50      3086
   macro avg       0.37      0.43      0.37      3086
weighted avg       0.68      0.50      0.55      3086



 40%|████      | 160/400 [1:51:57<2:47:29, 41.87s/it]

Epoch [160/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3761 Recall: 0.4364 F1 Score: 0.3697

              precision    recall  f1-score   support

           0       0.33      0.58      0.42       183
           1       0.11      0.37      0.17       119
           2       0.51      0.42      0.46       548
           3       0.30      0.52      0.38       285
           4       0.92      0.53      0.67      1801
           5       0.07      0.19      0.10       150

    accuracy                           0.49      3086
   macro avg       0.38      0.44      0.37      3086
weighted avg       0.69      0.49      0.55      3086



 50%|█████     | 200/400 [2:19:57<2:20:22, 42.11s/it]

Epoch [200/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3798 Recall: 0.4474 F1 Score: 0.3817

              precision    recall  f1-score   support

           0       0.34      0.57      0.43       183
           1       0.12      0.42      0.19       119
           2       0.51      0.45      0.48       548
           3       0.30      0.52      0.38       285
           4       0.92      0.58      0.71      1801
           5       0.09      0.15      0.11       150

    accuracy                           0.52      3086
   macro avg       0.38      0.45      0.38      3086
weighted avg       0.69      0.52      0.57      3086



 60%|██████    | 240/400 [2:47:54<1:51:49, 41.93s/it]

Epoch [240/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3781 Recall: 0.4496 F1 Score: 0.3883

              precision    recall  f1-score   support

           0       0.36      0.58      0.45       183
           1       0.14      0.38      0.20       119
           2       0.47      0.47      0.47       548
           3       0.30      0.54      0.38       285
           4       0.92      0.60      0.73      1801
           5       0.09      0.12      0.10       150

    accuracy                           0.54      3086
   macro avg       0.38      0.45      0.39      3086
weighted avg       0.68      0.54      0.58      3086



 70%|███████   | 280/400 [3:16:07<1:24:11, 42.10s/it]

Epoch [280/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3809 Recall: 0.4484 F1 Score: 0.3919

              precision    recall  f1-score   support

           0       0.37      0.57      0.45       183
           1       0.14      0.34      0.19       119
           2       0.47      0.48      0.47       548
           3       0.30      0.56      0.39       285
           4       0.92      0.62      0.74      1801
           5       0.09      0.12      0.10       150

    accuracy                           0.55      3086
   macro avg       0.38      0.45      0.39      3086
weighted avg       0.68      0.55      0.59      3086



 80%|████████  | 320/400 [3:44:19<56:07, 42.09s/it]  

Epoch [320/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3815 Recall: 0.4497 F1 Score: 0.3930

              precision    recall  f1-score   support

           0       0.37      0.57      0.45       183
           1       0.14      0.34      0.20       119
           2       0.47      0.48      0.48       548
           3       0.30      0.56      0.39       285
           4       0.92      0.62      0.74      1801
           5       0.09      0.12      0.10       150

    accuracy                           0.55      3086
   macro avg       0.38      0.45      0.39      3086
weighted avg       0.68      0.55      0.59      3086



 90%|█████████ | 360/400 [4:12:25<28:02, 42.07s/it]

Epoch [360/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3819 Recall: 0.4517 F1 Score: 0.3942

              precision    recall  f1-score   support

           0       0.36      0.57      0.44       183
           1       0.15      0.36      0.21       119
           2       0.47      0.49      0.48       548
           3       0.29      0.55      0.38       285
           4       0.92      0.62      0.74      1801
           5       0.10      0.12      0.11       150

    accuracy                           0.55      3086
   macro avg       0.38      0.45      0.39      3086
weighted avg       0.68      0.55      0.59      3086



100%|██████████| 400/400 [4:40:38<00:00, 42.10s/it]

Epoch [400/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3814 Recall: 0.4495 F1 Score: 0.3936

              precision    recall  f1-score   support

           0       0.36      0.57      0.44       183
           1       0.15      0.34      0.21       119
           2       0.47      0.49      0.48       548
           3       0.29      0.55      0.38       285
           4       0.92      0.63      0.75      1801
           5       0.10      0.12      0.11       150

    accuracy                           0.56      3086
   macro avg       0.38      0.45      0.39      3086
weighted avg       0.68      0.56      0.59      3086






In [5]:
train_csv_path = 'train_top5.csv'
test_csv_path = 'dev_top5.csv'

method = 'diff-sum'

train_dataset = CustomDataset(train_csv_path, onehot_encoder, method=method)
test_dataset = CustomDataset(test_csv_path, onehot_encoder, method=method)

train_batch_size = 256
test_batch_size = 256

train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1], num_classes=6)
tl, vl, p, r, f1 = train_validate_model(model,
                                        train_data_loader,
                                        test_data_loader,
                                        num_epochs=400,
                                        learning_rate=3e-3,
                                        stepslr=40,
                                        gamma=0.5)

torch.save(model.state_dict(), f'best_model_{method}.pt')

 10%|█         | 40/400 [27:41<4:09:23, 41.57s/it]

Epoch [40/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3564 Recall: 0.2297 F1 Score: 0.2334

              precision    recall  f1-score   support

           0       0.47      0.08      0.13       183
           1       0.23      0.05      0.08       119
           2       0.57      0.15      0.23       548
           3       0.24      0.17      0.20       285
           4       0.63      0.93      0.75      1801
           5       0.00      0.00      0.00       150

    accuracy                           0.59      3086
   macro avg       0.36      0.23      0.23      3086
weighted avg       0.53      0.59      0.51      3086



 20%|██        | 80/400 [55:25<3:41:59, 41.62s/it]

Epoch [80/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.4048 Recall: 0.3353 F1 Score: 0.3300

              precision    recall  f1-score   support

           0       0.53      0.17      0.26       183
           1       0.39      0.11      0.17       119
           2       0.35      0.63      0.45       548
           3       0.23      0.36      0.28       285
           4       0.86      0.71      0.78      1801
           5       0.07      0.03      0.05       150

    accuracy                           0.58      3086
   macro avg       0.40      0.34      0.33      3086
weighted avg       0.63      0.58      0.58      3086



 30%|███       | 120/400 [1:23:23<3:16:45, 42.16s/it]

Epoch [120/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3928 Recall: 0.3635 F1 Score: 0.3566

              precision    recall  f1-score   support

           0       0.49      0.27      0.35       183
           1       0.27      0.13      0.18       119
           2       0.44      0.53      0.48       548
           3       0.21      0.41      0.27       285
           4       0.89      0.66      0.76      1801
           5       0.08      0.18      0.11       150

    accuracy                           0.55      3086
   macro avg       0.39      0.36      0.36      3086
weighted avg       0.66      0.55      0.58      3086



 40%|████      | 160/400 [1:51:15<2:47:17, 41.82s/it]

Epoch [160/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.4164 Recall: 0.3730 F1 Score: 0.3509

              precision    recall  f1-score   support

           0       0.45      0.28      0.35       183
           1       0.26      0.13      0.17       119
           2       0.48      0.44      0.46       548
           3       0.31      0.32      0.31       285
           4       0.92      0.54      0.68      1801
           5       0.07      0.53      0.13       150

    accuracy                           0.47      3086
   macro avg       0.42      0.37      0.35      3086
weighted avg       0.69      0.47      0.54      3086



 50%|█████     | 200/400 [2:19:01<2:18:24, 41.52s/it]

Epoch [200/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3917 Recall: 0.3827 F1 Score: 0.3509

              precision    recall  f1-score   support

           0       0.36      0.42      0.39       183
           1       0.16      0.14      0.15       119
           2       0.49      0.43      0.46       548
           3       0.33      0.31      0.32       285
           4       0.93      0.51      0.66      1801
           5       0.07      0.48      0.12       150

    accuracy                           0.46      3086
   macro avg       0.39      0.38      0.35      3086
weighted avg       0.69      0.46      0.53      3086



 60%|██████    | 240/400 [2:46:56<1:50:55, 41.59s/it]

Epoch [240/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3841 Recall: 0.3738 F1 Score: 0.3395

              precision    recall  f1-score   support

           0       0.36      0.44      0.40       183
           1       0.11      0.15      0.13       119
           2       0.50      0.42      0.46       548
           3       0.33      0.27      0.30       285
           4       0.94      0.49      0.64      1801
           5       0.07      0.47      0.12       150

    accuracy                           0.44      3086
   macro avg       0.38      0.37      0.34      3086
weighted avg       0.70      0.44      0.52      3086



 70%|███████   | 280/400 [3:14:38<1:23:11, 41.60s/it]

Epoch [280/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3848 Recall: 0.3797 F1 Score: 0.3452

              precision    recall  f1-score   support

           0       0.36      0.45      0.40       183
           1       0.11      0.16      0.13       119
           2       0.50      0.44      0.47       548
           3       0.33      0.28      0.30       285
           4       0.94      0.50      0.65      1801
           5       0.07      0.45      0.12       150

    accuracy                           0.45      3086
   macro avg       0.38      0.38      0.35      3086
weighted avg       0.70      0.45      0.53      3086



 80%|████████  | 320/400 [3:42:34<56:44, 42.56s/it]  

Epoch [320/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3852 Recall: 0.3746 F1 Score: 0.3408

              precision    recall  f1-score   support

           0       0.36      0.44      0.40       183
           1       0.11      0.16      0.13       119
           2       0.50      0.43      0.46       548
           3       0.34      0.27      0.30       285
           4       0.94      0.49      0.64      1801
           5       0.07      0.46      0.11       150

    accuracy                           0.44      3086
   macro avg       0.39      0.37      0.34      3086
weighted avg       0.70      0.44      0.52      3086



 90%|█████████ | 360/400 [4:10:38<27:51, 41.78s/it]

Epoch [360/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3848 Recall: 0.3738 F1 Score: 0.3388

              precision    recall  f1-score   support

           0       0.36      0.44      0.40       183
           1       0.11      0.16      0.13       119
           2       0.50      0.42      0.46       548
           3       0.34      0.27      0.30       285
           4       0.94      0.48      0.64      1801
           5       0.07      0.47      0.12       150

    accuracy                           0.44      3086
   macro avg       0.38      0.37      0.34      3086
weighted avg       0.70      0.44      0.51      3086



100%|██████████| 400/400 [4:38:26<00:00, 41.77s/it]

Epoch [400/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3863 Recall: 0.3726 F1 Score: 0.3385

              precision    recall  f1-score   support

           0       0.36      0.43      0.39       183
           1       0.11      0.16      0.13       119
           2       0.50      0.43      0.46       548
           3       0.35      0.27      0.30       285
           4       0.94      0.48      0.63      1801
           5       0.07      0.47      0.12       150

    accuracy                           0.43      3086
   macro avg       0.39      0.37      0.34      3086
weighted avg       0.70      0.43      0.51      3086






In [6]:
train_csv_path = 'train_top5.csv'
test_csv_path = 'dev_top5.csv'

method = 'coherent'

train_dataset = CustomDataset(train_csv_path, onehot_encoder, method=method)
test_dataset = CustomDataset(test_csv_path, onehot_encoder, method=method)

train_batch_size = 256
test_batch_size = 256

train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1], num_classes=6)
tl, vl, p, r, f1 = train_validate_model(model,
                                        train_data_loader,
                                        test_data_loader,
                                        num_epochs=400,
                                        learning_rate=3e-3,
                                        stepslr=40,
                                        gamma=0.5)

torch.save(model.state_dict(), f'best_model_{method}.pt')

 10%|█         | 40/400 [29:42<4:26:46, 44.46s/it]

Epoch [40/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3848 Recall: 0.2263 F1 Score: 0.2286

              precision    recall  f1-score   support

           0       0.44      0.06      0.11       183
           1       0.17      0.03      0.06       119
           2       0.46      0.16      0.23       548
           3       0.47      0.06      0.10       285
           4       0.63      0.95      0.76      1801
           5       0.14      0.11      0.12       150

    accuracy                           0.59      3086
   macro avg       0.38      0.23      0.23      3086
weighted avg       0.53      0.59      0.51      3086



 20%|██        | 80/400 [59:18<3:56:55, 44.42s/it]

Epoch [80/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.4114 Recall: 0.3040 F1 Score: 0.3093

              precision    recall  f1-score   support

           0       0.57      0.16      0.25       183
           1       0.21      0.08      0.11       119
           2       0.35      0.55      0.42       548
           3       0.48      0.11      0.18       285
           4       0.76      0.76      0.76      1801
           5       0.10      0.16      0.12       150

    accuracy                           0.57      3086
   macro avg       0.41      0.30      0.31      3086
weighted avg       0.60      0.57      0.56      3086



 30%|███       | 120/400 [1:29:35<3:31:38, 45.35s/it]

Epoch [120/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3641 Recall: 0.3605 F1 Score: 0.3356

              precision    recall  f1-score   support

           0       0.38      0.30      0.33       183
           1       0.14      0.11      0.12       119
           2       0.41      0.43      0.42       548
           3       0.28      0.27      0.28       285
           4       0.89      0.60      0.71      1801
           5       0.09      0.45      0.14       150

    accuracy                           0.49      3086
   macro avg       0.36      0.36      0.34      3086
weighted avg       0.65      0.49      0.55      3086



 40%|████      | 160/400 [1:59:46<3:00:12, 45.05s/it]

Epoch [160/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3566 Recall: 0.3803 F1 Score: 0.3354

              precision    recall  f1-score   support

           0       0.27      0.44      0.34       183
           1       0.16      0.17      0.16       119
           2       0.46      0.34      0.39       548
           3       0.25      0.35      0.29       285
           4       0.92      0.55      0.69      1801
           5       0.08      0.43      0.14       150

    accuracy                           0.47      3086
   macro avg       0.36      0.38      0.34      3086
weighted avg       0.67      0.47      0.53      3086



 50%|█████     | 200/400 [2:30:16<2:35:51, 46.76s/it]

Epoch [200/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3553 Recall: 0.3866 F1 Score: 0.3389

              precision    recall  f1-score   support

           0       0.27      0.46      0.34       183
           1       0.16      0.18      0.17       119
           2       0.47      0.34      0.39       548
           3       0.22      0.43      0.29       285
           4       0.92      0.56      0.70      1801
           5       0.09      0.36      0.14       150

    accuracy                           0.48      3086
   macro avg       0.36      0.39      0.34      3086
weighted avg       0.67      0.48      0.54      3086



 60%|██████    | 240/400 [3:00:34<2:00:47, 45.30s/it]

Epoch [240/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3546 Recall: 0.3958 F1 Score: 0.3383

              precision    recall  f1-score   support

           0       0.26      0.51      0.34       183
           1       0.16      0.17      0.16       119
           2       0.47      0.34      0.40       548
           3       0.22      0.50      0.31       285
           4       0.92      0.53      0.68      1801
           5       0.09      0.33      0.15       150

    accuracy                           0.47      3086
   macro avg       0.35      0.40      0.34      3086
weighted avg       0.67      0.47      0.53      3086



 70%|███████   | 280/400 [3:28:48<1:24:14, 42.12s/it]

Epoch [280/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3537 Recall: 0.3949 F1 Score: 0.3374

              precision    recall  f1-score   support

           0       0.26      0.51      0.34       183
           1       0.16      0.18      0.17       119
           2       0.47      0.33      0.39       548
           3       0.22      0.50      0.31       285
           4       0.92      0.53      0.68      1801
           5       0.09      0.32      0.14       150

    accuracy                           0.47      3086
   macro avg       0.35      0.39      0.34      3086
weighted avg       0.67      0.47      0.53      3086



 80%|████████  | 320/400 [3:57:05<58:44, 44.05s/it]  

Epoch [320/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3537 Recall: 0.3942 F1 Score: 0.3371

              precision    recall  f1-score   support

           0       0.25      0.50      0.34       183
           1       0.16      0.18      0.17       119
           2       0.48      0.34      0.40       548
           3       0.22      0.50      0.30       285
           4       0.92      0.53      0.68      1801
           5       0.09      0.31      0.14       150

    accuracy                           0.47      3086
   macro avg       0.35      0.39      0.34      3086
weighted avg       0.67      0.47      0.53      3086



 90%|█████████ | 360/400 [4:25:33<28:15, 42.38s/it]

Epoch [360/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3519 Recall: 0.3942 F1 Score: 0.3353

              precision    recall  f1-score   support

           0       0.25      0.51      0.33       183
           1       0.15      0.18      0.16       119
           2       0.47      0.33      0.39       548
           3       0.22      0.51      0.31       285
           4       0.92      0.53      0.68      1801
           5       0.09      0.31      0.15       150

    accuracy                           0.47      3086
   macro avg       0.35      0.39      0.34      3086
weighted avg       0.67      0.47      0.53      3086



100%|██████████| 400/400 [4:54:27<00:00, 44.17s/it]

Epoch [400/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3524 Recall: 0.3945 F1 Score: 0.3357

              precision    recall  f1-score   support

           0       0.25      0.51      0.33       183
           1       0.15      0.18      0.16       119
           2       0.48      0.33      0.39       548
           3       0.22      0.51      0.31       285
           4       0.93      0.54      0.68      1801
           5       0.09      0.31      0.15       150

    accuracy                           0.47      3086
   macro avg       0.35      0.39      0.34      3086
weighted avg       0.67      0.47      0.53      3086






In [7]:
train_csv_path = 'train_top5.csv'
test_csv_path = 'dev_top5.csv'

method = 'maxpool'

train_dataset = CustomDataset(train_csv_path, onehot_encoder, method=method)
test_dataset = CustomDataset(test_csv_path, onehot_encoder, method=method)

train_batch_size = 256
test_batch_size = 256

train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1], num_classes=6)
tl, vl, p, r, f1 = train_validate_model(model,
                                        train_data_loader,
                                        test_data_loader,
                                        num_epochs=400,
                                        learning_rate=3e-3,
                                        stepslr=40,
                                        gamma=0.5)

torch.save(model.state_dict(), f'best_model_{method}.pt')

 10%|█         | 40/400 [36:12<5:26:46, 54.46s/it]

Epoch [40/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3261 Recall: 0.3306 F1 Score: 0.3089

              precision    recall  f1-score   support

           0       0.29      0.35      0.32       183
           1       0.12      0.26      0.16       119
           2       0.49      0.21      0.29       548
           3       0.27      0.41      0.32       285
           4       0.68      0.70      0.69      1801
           5       0.12      0.05      0.07       150

    accuracy                           0.52      3086
   macro avg       0.33      0.33      0.31      3086
weighted avg       0.53      0.52      0.51      3086



 20%|██        | 80/400 [1:12:47<4:53:08, 54.96s/it]

Epoch [80/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3440 Recall: 0.3364 F1 Score: 0.2922

              precision    recall  f1-score   support

           0       0.29      0.33      0.31       183
           1       0.13      0.13      0.13       119
           2       0.55      0.20      0.30       548
           3       0.25      0.51      0.34       285
           4       0.78      0.45      0.57      1801
           5       0.06      0.39      0.11       150

    accuracy                           0.39      3086
   macro avg       0.34      0.34      0.29      3086
weighted avg       0.60      0.39      0.44      3086



 30%|███       | 120/400 [1:49:19<4:14:52, 54.62s/it]

Epoch [120/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3375 Recall: 0.3409 F1 Score: 0.2712

              precision    recall  f1-score   support

           0       0.26      0.41      0.32       183
           1       0.12      0.20      0.15       119
           2       0.52      0.16      0.25       548
           3       0.27      0.37      0.31       285
           4       0.79      0.35      0.48      1801
           5       0.07      0.55      0.12       150

    accuracy                           0.32      3086
   macro avg       0.34      0.34      0.27      3086
weighted avg       0.60      0.32      0.38      3086



 40%|████      | 160/400 [2:26:10<3:42:15, 55.57s/it]

Epoch [160/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3503 Recall: 0.3533 F1 Score: 0.2879

              precision    recall  f1-score   support

           0       0.25      0.34      0.29       183
           1       0.11      0.24      0.16       119
           2       0.54      0.18      0.27       548
           3       0.33      0.35      0.34       285
           4       0.80      0.41      0.54      1801
           5       0.08      0.59      0.13       150

    accuracy                           0.36      3086
   macro avg       0.35      0.35      0.29      3086
weighted avg       0.61      0.36      0.42      3086



 50%|█████     | 200/400 [3:02:49<3:03:29, 55.05s/it]

Epoch [200/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3366 Recall: 0.3610 F1 Score: 0.3002

              precision    recall  f1-score   support

           0       0.24      0.35      0.29       183
           1       0.10      0.26      0.15       119
           2       0.47      0.22      0.30       548
           3       0.34      0.37      0.36       285
           4       0.78      0.44      0.56      1801
           5       0.08      0.52      0.14       150

    accuracy                           0.39      3086
   macro avg       0.34      0.36      0.30      3086
weighted avg       0.59      0.39      0.45      3086



 60%|██████    | 240/400 [3:39:44<2:29:51, 56.20s/it]

Epoch [240/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3333 Recall: 0.3539 F1 Score: 0.2974

              precision    recall  f1-score   support

           0       0.24      0.33      0.28       183
           1       0.10      0.24      0.14       119
           2       0.45      0.24      0.31       548
           3       0.34      0.36      0.35       285
           4       0.78      0.44      0.56      1801
           5       0.08      0.51      0.14       150

    accuracy                           0.38      3086
   macro avg       0.33      0.35      0.30      3086
weighted avg       0.59      0.38      0.44      3086



 70%|███████   | 280/400 [4:16:38<1:50:00, 55.00s/it]

Epoch [280/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3351 Recall: 0.3556 F1 Score: 0.3002

              precision    recall  f1-score   support

           0       0.24      0.33      0.28       183
           1       0.10      0.24      0.14       119
           2       0.44      0.24      0.31       548
           3       0.36      0.35      0.36       285
           4       0.78      0.44      0.56      1801
           5       0.08      0.53      0.14       150

    accuracy                           0.39      3086
   macro avg       0.34      0.36      0.30      3086
weighted avg       0.59      0.39      0.45      3086



 80%|████████  | 320/400 [4:53:21<1:12:53, 54.67s/it]

Epoch [320/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3379 Recall: 0.3576 F1 Score: 0.3030

              precision    recall  f1-score   support

           0       0.25      0.33      0.28       183
           1       0.11      0.24      0.15       119
           2       0.45      0.24      0.32       548
           3       0.37      0.35      0.36       285
           4       0.78      0.45      0.57      1801
           5       0.08      0.53      0.14       150

    accuracy                           0.39      3086
   macro avg       0.34      0.36      0.30      3086
weighted avg       0.59      0.39      0.45      3086



 90%|█████████ | 360/400 [5:30:10<36:33, 54.83s/it]  

Epoch [360/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3401 Recall: 0.3592 F1 Score: 0.3052

              precision    recall  f1-score   support

           0       0.25      0.33      0.28       183
           1       0.11      0.24      0.15       119
           2       0.45      0.25      0.32       548
           3       0.37      0.35      0.36       285
           4       0.78      0.46      0.58      1801
           5       0.08      0.53      0.14       150

    accuracy                           0.40      3086
   macro avg       0.34      0.36      0.31      3086
weighted avg       0.59      0.40      0.46      3086



100%|██████████| 400/400 [6:06:50<00:00, 55.03s/it]

Epoch [400/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3411 Recall: 0.3584 F1 Score: 0.3048

              precision    recall  f1-score   support

           0       0.25      0.33      0.29       183
           1       0.11      0.24      0.15       119
           2       0.46      0.24      0.31       548
           3       0.37      0.35      0.36       285
           4       0.78      0.46      0.58      1801
           5       0.08      0.53      0.14       150

    accuracy                           0.40      3086
   macro avg       0.34      0.36      0.30      3086
weighted avg       0.59      0.40      0.46      3086






In [8]:
train_csv_path = 'train_top5.csv'
test_csv_path = 'dev_top5.csv'

method = 'avgpool'

train_dataset = CustomDataset(train_csv_path, onehot_encoder, method=method)
test_dataset = CustomDataset(test_csv_path, onehot_encoder, method=method)

train_batch_size = 256
test_batch_size = 256

train_data_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1], num_classes=6)
tl, vl, p, r, f1 = train_validate_model(model,
                                        train_data_loader,
                                        test_data_loader,
                                        num_epochs=400,
                                        learning_rate=3e-3,
                                        stepslr=40,
                                        gamma=0.5)

torch.save(model.state_dict(), f'best_model_{method}.pt')

 10%|█         | 40/400 [31:21<4:44:30, 47.42s/it]

Epoch [40/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3011 Recall: 0.2465 F1 Score: 0.2512

              precision    recall  f1-score   support

           0       0.12      0.04      0.06       183
           1       0.04      0.02      0.02       119
           2       0.44      0.38      0.41       548
           3       0.44      0.13      0.20       285
           4       0.66      0.87      0.75      1801
           5       0.10      0.05      0.07       150

    accuracy                           0.59      3086
   macro avg       0.30      0.25      0.25      3086
weighted avg       0.52      0.59      0.54      3086



 20%|██        | 80/400 [1:02:25<4:07:40, 46.44s/it]

Epoch [80/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3080 Recall: 0.2777 F1 Score: 0.2675

              precision    recall  f1-score   support

           0       0.12      0.15      0.13       183
           1       0.10      0.04      0.06       119
           2       0.35      0.62      0.45       548
           3       0.50      0.15      0.23       285
           4       0.75      0.67      0.71      1801
           5       0.04      0.03      0.03       150

    accuracy                           0.53      3086
   macro avg       0.31      0.28      0.27      3086
weighted avg       0.56      0.53      0.53      3086



 30%|███       | 120/400 [1:33:47<3:41:55, 47.55s/it]

Epoch [120/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3251 Recall: 0.3097 F1 Score: 0.3083

              precision    recall  f1-score   support

           0       0.12      0.19      0.15       183
           1       0.11      0.07      0.08       119
           2       0.42      0.56      0.48       548
           3       0.51      0.29      0.37       285
           4       0.75      0.73      0.74      1801
           5       0.04      0.03      0.03       150

    accuracy                           0.56      3086
   macro avg       0.33      0.31      0.31      3086
weighted avg       0.57      0.56      0.56      3086



 40%|████      | 160/400 [2:05:11<3:07:18, 46.83s/it]

Epoch [160/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3107 Recall: 0.3279 F1 Score: 0.3133

              precision    recall  f1-score   support

           0       0.14      0.20      0.16       183
           1       0.10      0.11      0.10       119
           2       0.51      0.42      0.46       548
           3       0.33      0.53      0.41       285
           4       0.76      0.67      0.71      1801
           5       0.03      0.03      0.03       150

    accuracy                           0.53      3086
   macro avg       0.31      0.33      0.31      3086
weighted avg       0.58      0.53      0.55      3086



 50%|█████     | 200/400 [2:36:47<2:37:50, 47.35s/it]

Epoch [200/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3206 Recall: 0.3425 F1 Score: 0.3177

              precision    recall  f1-score   support

           0       0.15      0.23      0.19       183
           1       0.10      0.21      0.13       119
           2       0.55      0.38      0.45       548
           3       0.31      0.54      0.40       285
           4       0.77      0.63      0.70      1801
           5       0.04      0.05      0.05       150

    accuracy                           0.51      3086
   macro avg       0.32      0.34      0.32      3086
weighted avg       0.59      0.51      0.54      3086



 60%|██████    | 240/400 [3:08:21<2:06:30, 47.44s/it]

Epoch [240/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3257 Recall: 0.3449 F1 Score: 0.3152

              precision    recall  f1-score   support

           0       0.16      0.22      0.19       183
           1       0.09      0.26      0.14       119
           2       0.57      0.33      0.42       548
           3       0.32      0.56      0.40       285
           4       0.77      0.62      0.69      1801
           5       0.05      0.07      0.06       150

    accuracy                           0.50      3086
   macro avg       0.33      0.34      0.32      3086
weighted avg       0.60      0.50      0.53      3086



 70%|███████   | 280/400 [3:39:50<1:34:36, 47.31s/it]

Epoch [280/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3276 Recall: 0.3451 F1 Score: 0.3124

              precision    recall  f1-score   support

           0       0.16      0.21      0.18       183
           1       0.09      0.30      0.14       119
           2       0.58      0.31      0.40       548
           3       0.31      0.56      0.40       285
           4       0.77      0.63      0.69      1801
           5       0.05      0.06      0.05       150

    accuracy                           0.50      3086
   macro avg       0.33      0.35      0.31      3086
weighted avg       0.60      0.50      0.53      3086



 80%|████████  | 320/400 [4:11:06<1:02:39, 46.99s/it]

Epoch [320/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3331 Recall: 0.3528 F1 Score: 0.3142

              precision    recall  f1-score   support

           0       0.16      0.21      0.18       183
           1       0.10      0.35      0.15       119
           2       0.60      0.31      0.41       548
           3       0.31      0.57      0.40       285
           4       0.78      0.62      0.69      1801
           5       0.05      0.06      0.05       150

    accuracy                           0.50      3086
   macro avg       0.33      0.35      0.31      3086
weighted avg       0.61      0.50      0.53      3086



 90%|█████████ | 360/400 [4:42:36<31:44, 47.62s/it]  

Epoch [360/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3328 Recall: 0.3504 F1 Score: 0.3121

              precision    recall  f1-score   support

           0       0.16      0.21      0.18       183
           1       0.09      0.35      0.14       119
           2       0.60      0.30      0.40       548
           3       0.32      0.57      0.41       285
           4       0.78      0.62      0.69      1801
           5       0.05      0.06      0.05       150

    accuracy                           0.49      3086
   macro avg       0.33      0.35      0.31      3086
weighted avg       0.61      0.49      0.53      3086



100%|██████████| 400/400 [5:14:16<00:00, 47.14s/it]

Epoch [400/400] Train Loss: 0.0035 Val Loss: 0.0033 Precision: 0.3331 Recall: 0.3502 F1 Score: 0.3114

              precision    recall  f1-score   support

           0       0.16      0.20      0.18       183
           1       0.09      0.36      0.14       119
           2       0.60      0.30      0.40       548
           3       0.32      0.57      0.41       285
           4       0.78      0.61      0.69      1801
           5       0.05      0.06      0.05       150

    accuracy                           0.49      3086
   macro avg       0.33      0.35      0.31      3086
weighted avg       0.61      0.49      0.53      3086




