In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

class CustomDataset(Dataset):
    """Dataset of pre-computed embeddings listed in a CSV index file.

    The CSV must contain a ``path`` column (a file readable by ``torch.load``)
    and a ``class`` column. A row whose class is the string ``'NoClass'`` gets
    label 0; every other row gets label 1.

    Each loaded tensor is indexed as ``raw[:, position, :]`` and must therefore
    be 3-D. (Presumably ``(1, sequence, hidden)`` -- TODO confirm against the
    code that wrote the embedding files.)

    ``method`` selects how the tensor is reduced along dimension 1:

    * ``'endpoint'`` -- first and last positions concatenated.
    * ``'diff-sum'`` -- (first + last) concatenated with (first - last).
    * ``'coherent'`` -- features ``[:360]`` of the first position, features
      ``[360:720]`` of the last position, and the dot product of the
      ``[720:]`` tails of both positions (requires a leading dimension of 1).
    * ``'maxpool'``  -- element-wise maximum over dimension 1.
    * ``'avgpool'``  -- mean over dimension 1.
    """

    # Every value accepted for ``method``.
    METHODS = ('endpoint', 'diff-sum', 'coherent', 'maxpool', 'avgpool')

    def __init__(self, csv_file, method='endpoint'):
        """Read the CSV index and remember the pooling method.

        Args:
            csv_file: Path (or buffer) accepted by ``pandas.read_csv``.
            method: One of ``CustomDataset.METHODS``.

        Raises:
            ValueError: If ``method`` is not a supported pooling method.
        """
        # Fail fast: previously a typo in ``method`` only surfaced on the first
        # ``__getitem__`` call, as an UnboundLocalError on ``embedding``.
        if method not in self.METHODS:
            raise ValueError(
                f"Unknown method {method!r}; expected one of {self.METHODS}"
            )
        self.data = pd.read_csv(csv_file)
        self.labels = self.data['class'].values
        self.method = method

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load one embedding, pool it, and return ``(embedding, label)``.

        Args:
            idx: Row position; a tensor index is converted with ``tolist()``.

        Returns:
            Tuple of the pooled embedding and a ``FloatTensor`` of shape
            ``(1,)`` holding 0.0 or 1.0.

        Raises:
            ValueError: If ``self.method`` is not a supported pooling method
                (e.g. the attribute was reassigned after construction).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        file_path = self.data.iloc[idx]['path']
        label = 0 if self.labels[idx] == 'NoClass' else 1

        # NOTE(security): torch.load unpickles the file; only point the CSV at
        # embedding files from a trusted source.
        raw_embedding = torch.load(file_path)

        if self.method in ('maxpool', 'avgpool'):
            if self.method == 'maxpool':
                embedding = torch.max(raw_embedding, dim=1)[0]
            else:
                embedding = torch.mean(raw_embedding, dim=1)
        elif self.method in ('endpoint', 'diff-sum', 'coherent'):
            first = raw_embedding[:, 0, :]
            last = raw_embedding[:, -1, :]
            if self.method == 'endpoint':
                embedding = torch.cat((first, last), dim=1)
            elif self.method == 'diff-sum':
                embedding = torch.cat((first + last, first - last), dim=1)
            else:
                # Scalar similarity of the two tails, reshaped to (1, 1) so it
                # can be concatenated along the feature dimension.
                tail_similarity = torch.dot(
                    first.squeeze()[720:],
                    last.squeeze()[720:]
                ).unsqueeze(0).unsqueeze(0)
                embedding = torch.cat(
                    (first[:, :360], last[:, 360:720], tail_similarity),
                    dim=1
                )
        else:
            raise ValueError(
                f"Unknown method {self.method!r}; expected one of {self.METHODS}"
            )

        return embedding, torch.FloatTensor([label])

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
import torch.nn.functional as F
import torch.nn as nn

class CustomDeepClassifier(nn.Module):
    """Binary classifier: two hidden Linear/ReLU/BatchNorm stages and a sigmoid head.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of both hidden layers.

    Because the normalisation layers are ``BatchNorm1d(1)``, the tensor that
    reaches them must have exactly one channel, i.e. inputs are expected as
    ``(batch, 1, input_dim)``; the output is then ``(batch, 1, 1)``.
    """

    def __init__(self, input_dim=1536, hidden_dim=2048):
        super().__init__()

        # The position of each module inside the Sequential determines the
        # state_dict keys ("layers.0.weight", ...), so the order is fixed.
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.BatchNorm1d(1),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """Return the sigmoid of the stacked layers' output (values in [0, 1])."""
        return F.sigmoid(self.layers(x))

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_recall_fscore_support, classification_report, accuracy_score
import matplotlib.pyplot as plt
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train_validate_model(model, train_loader, val_loader, num_epochs=10,
                         learning_rate=0.1, step_report=5, stepslr=10, gamma=0.9,
                         sgd=False):
    """Train a sigmoid-output binary classifier and validate it every epoch.

    The model is moved to the module-level ``device``. It is optimised with
    ``BCELoss`` using SGD (momentum 0.9) when ``sgd`` is true and AdamW
    otherwise; the learning rate is multiplied by ``gamma`` every ``stepslr``
    epochs via ``StepLR``.

    Args:
        model: Module whose output, after ``squeeze(-1)``, matches the shape of
            the labels and lies in [0, 1] (required by ``BCELoss``).
        train_loader: Sized iterable of ``(inputs, labels)`` training batches.
        val_loader: Sized iterable of ``(inputs, labels)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Initial learning rate.
        step_report: Print losses and a classification report (threshold 0.5)
            every ``step_report`` epochs; ``0`` or ``None`` disables reporting.
        stepslr: ``step_size`` of the ``StepLR`` scheduler, in epochs.
        gamma: Multiplicative learning-rate decay factor.
        sgd: Use SGD with momentum instead of AdamW.

    Returns:
        ``(train_losses, val_losses)``: one float per epoch, each the mean of
        the per-batch losses for that epoch.

    Raises:
        ValueError: If either loader yields no batches.
    """
    # An empty loader used to surface as a ZeroDivisionError only after a
    # wasted pass; reject it up front with a clear message.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each contain at least one batch")

    model.to(device)
    criterion = nn.BCELoss()
    if sgd:
        optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
    else:
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer, stepslr, gamma=gamma)

    train_losses = []
    val_losses = []

    for epoch in tqdm(range(num_epochs), total=num_epochs):
        # ---- training pass ----
        model.train()
        running_train_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs.float())
            loss = criterion(outputs.squeeze(-1), labels)
            loss.backward()
            optimizer.step()
            running_train_loss += loss.item()

        train_loss = running_train_loss / len(train_loader)
        train_losses.append(train_loss)

        # ---- validation pass ----
        model.eval()
        running_val_loss = 0.0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                probabilities = model(inputs.float()).squeeze(-1)
                running_val_loss += criterion(probabilities, labels).item()

                # Flatten to plain Python floats so the report receives 1-D
                # sequences rather than lists of 1-element arrays.
                all_preds.extend(probabilities.reshape(-1).cpu().tolist())
                all_labels.extend(labels.reshape(-1).cpu().tolist())

        val_loss = running_val_loss / len(val_loader)
        val_losses.append(val_loss)

        scheduler.step()

        # ``step_report`` of 0/None turns reporting off instead of raising
        # ZeroDivisionError in the modulo below.
        if step_report and (epoch + 1) % step_report == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}] "
                f"Train Loss: {train_loss:.4f} "
                f"Val Loss: {val_loss:.4f} ")
            print(classification_report(all_labels, [1 if p > 0.5 else 0 for p in all_preds]))

    return train_losses, val_losses

In [10]:
# Experiment: 'diff-sum' features, SGD, learning rate 3e-2 halved every 10 epochs.
# NOTE(review): no RNG seed is set in this cell, so shuffling and weight
# initialisation (and therefore the logged metrics) vary between runs.
# NOTE(review): the logged output shows validation loss rising from 0.2660
# (epoch 10) to 0.7929 (epoch 100) while train loss falls to 0.0010; the
# checkpoint written at the end of this cell is the final-epoch model.
method = 'diff-sum'

train_csv_path, test_csv_path = 'train.csv', 'dev.csv'
train_batch_size = test_batch_size = 256

train_dataset = CustomDataset(csv_file=train_csv_path, method=method)
test_dataset = CustomDataset(csv_file=test_csv_path, method=method)

train_data_loader = DataLoader(dataset=train_dataset,
                               batch_size=train_batch_size,
                               shuffle=True)
test_data_loader = DataLoader(dataset=test_dataset,
                              batch_size=test_batch_size,
                              shuffle=False)

# The classifier's input width is read off the first training example.
model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1])
tl, vl = train_validate_model(
    model=model,
    train_loader=train_data_loader,
    val_loader=test_data_loader,
    num_epochs=100,
    learning_rate=3e-2,
    stepslr=10,
    gamma=0.5,
    sgd=True,
)

torch.save(model.state_dict(), f'binary_model_{method}.pt')

  0%|          | 0/100 [00:00<?, ?it/s]

  5%|▌         | 5/100 [04:29<1:24:53, 53.61s/it]

Epoch [5/100] Train Loss: 0.2806 Val Loss: 0.3338 
              precision    recall  f1-score   support

         0.0       0.93      0.80      0.86      1801
         1.0       0.82      0.94      0.88      1750

    accuracy                           0.87      3551
   macro avg       0.88      0.87      0.87      3551
weighted avg       0.88      0.87      0.87      3551



 10%|█         | 10/100 [08:47<1:17:35, 51.72s/it]

Epoch [10/100] Train Loss: 0.1884 Val Loss: 0.2660 
              precision    recall  f1-score   support

         0.0       0.92      0.86      0.89      1801
         1.0       0.86      0.92      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 15%|█▌        | 15/100 [12:58<1:11:40, 50.59s/it]

Epoch [15/100] Train Loss: 0.2466 Val Loss: 0.3018 
              precision    recall  f1-score   support

         0.0       0.93      0.84      0.89      1801
         1.0       0.85      0.94      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 20%|██        | 20/100 [17:06<1:05:39, 49.24s/it]

Epoch [20/100] Train Loss: 0.0836 Val Loss: 0.5621 
              precision    recall  f1-score   support

         0.0       0.81      0.96      0.88      1801
         1.0       0.94      0.77      0.85      1750

    accuracy                           0.87      3551
   macro avg       0.88      0.86      0.86      3551
weighted avg       0.88      0.87      0.86      3551



 25%|██▌       | 25/100 [21:07<1:00:22, 48.30s/it]

Epoch [25/100] Train Loss: 0.0046 Val Loss: 0.6578 
              precision    recall  f1-score   support

         0.0       0.88      0.92      0.90      1801
         1.0       0.91      0.88      0.89      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



 30%|███       | 30/100 [25:07<56:07, 48.10s/it]  

Epoch [30/100] Train Loss: 0.0031 Val Loss: 0.7011 
              precision    recall  f1-score   support

         0.0       0.90      0.89      0.89      1801
         1.0       0.88      0.90      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 35%|███▌      | 35/100 [29:07<52:02, 48.04s/it]

Epoch [35/100] Train Loss: 0.0017 Val Loss: 0.7270 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.88      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 40%|████      | 40/100 [33:08<48:06, 48.10s/it]

Epoch [40/100] Train Loss: 0.0015 Val Loss: 0.7503 
              precision    recall  f1-score   support

         0.0       0.89      0.90      0.90      1801
         1.0       0.90      0.89      0.89      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



 45%|████▌     | 45/100 [37:08<44:05, 48.10s/it]

Epoch [45/100] Train Loss: 0.0013 Val Loss: 0.7440 
              precision    recall  f1-score   support

         0.0       0.88      0.91      0.90      1801
         1.0       0.91      0.87      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 50%|█████     | 50/100 [41:09<40:06, 48.12s/it]

Epoch [50/100] Train Loss: 0.0014 Val Loss: 0.7509 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.88      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 55%|█████▌    | 55/100 [45:10<36:06, 48.14s/it]

Epoch [55/100] Train Loss: 0.0012 Val Loss: 0.7749 
              precision    recall  f1-score   support

         0.0       0.89      0.90      0.90      1801
         1.0       0.90      0.88      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 60%|██████    | 60/100 [49:10<32:04, 48.11s/it]

Epoch [60/100] Train Loss: 0.0012 Val Loss: 0.7680 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.88      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 65%|██████▌   | 65/100 [53:11<28:05, 48.15s/it]

Epoch [65/100] Train Loss: 0.0011 Val Loss: 0.7874 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.88      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.90      0.89      0.89      3551
weighted avg       0.90      0.89      0.89      3551



 70%|███████   | 70/100 [57:11<24:02, 48.07s/it]

Epoch [70/100] Train Loss: 0.0011 Val Loss: 0.7857 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.88      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 75%|███████▌  | 75/100 [1:01:14<20:10, 48.42s/it]

Epoch [75/100] Train Loss: 0.0011 Val Loss: 0.7871 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.88      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 80%|████████  | 80/100 [1:05:17<16:07, 48.39s/it]

Epoch [80/100] Train Loss: 0.0011 Val Loss: 0.7783 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.88      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 85%|████████▌ | 85/100 [1:09:21<12:14, 48.98s/it]

Epoch [85/100] Train Loss: 0.0011 Val Loss: 0.7902 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.88      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 90%|█████████ | 90/100 [1:13:46<08:44, 52.41s/it]

Epoch [90/100] Train Loss: 0.0011 Val Loss: 0.7767 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.88      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 95%|█████████▌| 95/100 [1:17:57<04:12, 50.43s/it]

Epoch [95/100] Train Loss: 0.0011 Val Loss: 0.7833 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.88      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



100%|██████████| 100/100 [1:22:09<00:00, 49.29s/it]

Epoch [100/100] Train Loss: 0.0010 Val Loss: 0.7929 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.88      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551






In [12]:
# Experiment: 'diff-sum' features, SGD, learning rate 3e-4 multiplied by 0.8
# after every epoch, with a report printed each epoch.
# NOTE(review): no RNG seed is set in this cell, so shuffling and weight
# initialisation (and therefore the logged metrics) vary between runs.
method = 'diff-sum'

train_csv_path, test_csv_path = 'train.csv', 'dev.csv'
train_batch_size = test_batch_size = 256

train_dataset = CustomDataset(csv_file=train_csv_path, method=method)
test_dataset = CustomDataset(csv_file=test_csv_path, method=method)

train_data_loader = DataLoader(dataset=train_dataset,
                               batch_size=train_batch_size,
                               shuffle=True)
test_data_loader = DataLoader(dataset=test_dataset,
                              batch_size=test_batch_size,
                              shuffle=False)

# The classifier's input width is read off the first training example.
model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1])
tl, vl = train_validate_model(
    model=model,
    train_loader=train_data_loader,
    val_loader=test_data_loader,
    num_epochs=30,
    learning_rate=3e-4,
    stepslr=1,
    gamma=0.8,
    sgd=True,
    step_report=1,
)

# The "_1" suffix keeps this checkpoint separate from 'binary_model_diff-sum.pt'.
torch.save(model.state_dict(), f'binary_model_{method}_1.pt')

  0%|          | 0/30 [00:00<?, ?it/s]

  3%|▎         | 1/30 [00:54<26:17, 54.40s/it]

Epoch [1/30] Train Loss: 0.5824 Val Loss: 0.4848 
              precision    recall  f1-score   support

         0.0       0.77      0.84      0.80      1801
         1.0       0.82      0.74      0.77      1750

    accuracy                           0.79      3551
   macro avg       0.79      0.79      0.79      3551
weighted avg       0.79      0.79      0.79      3551



  7%|▋         | 2/30 [01:48<25:24, 54.44s/it]

Epoch [2/30] Train Loss: 0.4410 Val Loss: 0.4113 
              precision    recall  f1-score   support

         0.0       0.83      0.82      0.82      1801
         1.0       0.82      0.83      0.82      1750

    accuracy                           0.82      3551
   macro avg       0.82      0.82      0.82      3551
weighted avg       0.82      0.82      0.82      3551



 10%|█         | 3/30 [02:43<24:32, 54.54s/it]

Epoch [3/30] Train Loss: 0.3953 Val Loss: 0.3845 
              precision    recall  f1-score   support

         0.0       0.86      0.81      0.84      1801
         1.0       0.82      0.87      0.84      1750

    accuracy                           0.84      3551
   macro avg       0.84      0.84      0.84      3551
weighted avg       0.84      0.84      0.84      3551



 13%|█▎        | 4/30 [03:36<23:22, 53.92s/it]

Epoch [4/30] Train Loss: 0.3629 Val Loss: 0.3681 
              precision    recall  f1-score   support

         0.0       0.86      0.83      0.85      1801
         1.0       0.83      0.86      0.85      1750

    accuracy                           0.85      3551
   macro avg       0.85      0.85      0.85      3551
weighted avg       0.85      0.85      0.85      3551



 17%|█▋        | 5/30 [04:27<21:59, 52.77s/it]

Epoch [5/30] Train Loss: 0.3462 Val Loss: 0.3593 
              precision    recall  f1-score   support

         0.0       0.84      0.86      0.85      1801
         1.0       0.85      0.83      0.84      1750

    accuracy                           0.85      3551
   macro avg       0.85      0.84      0.85      3551
weighted avg       0.85      0.85      0.85      3551



 20%|██        | 6/30 [05:17<20:49, 52.08s/it]

Epoch [6/30] Train Loss: 0.3304 Val Loss: 0.3511 
              precision    recall  f1-score   support

         0.0       0.85      0.86      0.85      1801
         1.0       0.85      0.84      0.85      1750

    accuracy                           0.85      3551
   macro avg       0.85      0.85      0.85      3551
weighted avg       0.85      0.85      0.85      3551



 23%|██▎       | 7/30 [06:09<19:50, 51.78s/it]

Epoch [7/30] Train Loss: 0.3195 Val Loss: 0.3452 
              precision    recall  f1-score   support

         0.0       0.85      0.86      0.86      1801
         1.0       0.85      0.85      0.85      1750

    accuracy                           0.85      3551
   macro avg       0.85      0.85      0.85      3551
weighted avg       0.85      0.85      0.85      3551



 27%|██▋       | 8/30 [06:58<18:45, 51.14s/it]

Epoch [8/30] Train Loss: 0.3120 Val Loss: 0.3402 
              precision    recall  f1-score   support

         0.0       0.87      0.84      0.85      1801
         1.0       0.84      0.87      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 30%|███       | 9/30 [07:48<17:44, 50.70s/it]

Epoch [9/30] Train Loss: 0.3064 Val Loss: 0.3389 
              precision    recall  f1-score   support

         0.0       0.85      0.87      0.86      1801
         1.0       0.86      0.84      0.85      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 33%|███▎      | 10/30 [08:38<16:48, 50.41s/it]

Epoch [10/30] Train Loss: 0.3022 Val Loss: 0.3363 
              precision    recall  f1-score   support

         0.0       0.85      0.87      0.86      1801
         1.0       0.87      0.84      0.85      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 37%|███▋      | 11/30 [09:28<15:54, 50.23s/it]

Epoch [11/30] Train Loss: 0.2981 Val Loss: 0.3363 
              precision    recall  f1-score   support

         0.0       0.84      0.88      0.86      1801
         1.0       0.87      0.83      0.85      1750

    accuracy                           0.85      3551
   macro avg       0.85      0.85      0.85      3551
weighted avg       0.85      0.85      0.85      3551



 40%|████      | 12/30 [10:17<15:01, 50.09s/it]

Epoch [12/30] Train Loss: 0.2963 Val Loss: 0.3319 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.86      1801
         1.0       0.86      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 43%|████▎     | 13/30 [11:07<14:10, 50.01s/it]

Epoch [13/30] Train Loss: 0.2929 Val Loss: 0.3310 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.86      1801
         1.0       0.86      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 47%|████▋     | 14/30 [11:57<13:18, 49.94s/it]

Epoch [14/30] Train Loss: 0.2920 Val Loss: 0.3301 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.86      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 50%|█████     | 15/30 [12:47<12:28, 49.89s/it]

Epoch [15/30] Train Loss: 0.2891 Val Loss: 0.3293 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.87      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 53%|█████▎    | 16/30 [13:37<11:37, 49.85s/it]

Epoch [16/30] Train Loss: 0.2880 Val Loss: 0.3293 
              precision    recall  f1-score   support

         0.0       0.85      0.88      0.86      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 57%|█████▋    | 17/30 [14:27<10:48, 49.88s/it]

Epoch [17/30] Train Loss: 0.2873 Val Loss: 0.3284 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.87      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 60%|██████    | 18/30 [15:18<10:02, 50.21s/it]

Epoch [18/30] Train Loss: 0.2871 Val Loss: 0.3280 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.87      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 63%|██████▎   | 19/30 [16:07<09:11, 50.10s/it]

Epoch [19/30] Train Loss: 0.2869 Val Loss: 0.3280 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.86      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 67%|██████▋   | 20/30 [16:57<08:19, 50.00s/it]

Epoch [20/30] Train Loss: 0.2860 Val Loss: 0.3280 
              precision    recall  f1-score   support

         0.0       0.86      0.88      0.87      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 70%|███████   | 21/30 [17:47<07:29, 49.93s/it]

Epoch [21/30] Train Loss: 0.2871 Val Loss: 0.3277 
              precision    recall  f1-score   support

         0.0       0.86      0.88      0.87      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 73%|███████▎  | 22/30 [18:37<06:39, 49.88s/it]

Epoch [22/30] Train Loss: 0.2855 Val Loss: 0.3273 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.87      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 77%|███████▋  | 23/30 [19:26<05:48, 49.83s/it]

Epoch [23/30] Train Loss: 0.2857 Val Loss: 0.3274 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.87      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 80%|████████  | 24/30 [20:16<04:59, 49.84s/it]

Epoch [24/30] Train Loss: 0.2850 Val Loss: 0.3273 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.87      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 83%|████████▎ | 25/30 [21:06<04:09, 49.84s/it]

Epoch [25/30] Train Loss: 0.2852 Val Loss: 0.3271 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.87      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 87%|████████▋ | 26/30 [21:56<03:19, 49.89s/it]

Epoch [26/30] Train Loss: 0.2848 Val Loss: 0.3270 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.87      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 90%|█████████ | 27/30 [22:46<02:29, 49.89s/it]

Epoch [27/30] Train Loss: 0.2842 Val Loss: 0.3269 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.87      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 93%|█████████▎| 28/30 [23:36<01:39, 49.88s/it]

Epoch [28/30] Train Loss: 0.2847 Val Loss: 0.3270 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.87      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



 97%|█████████▋| 29/30 [24:26<00:49, 49.87s/it]

Epoch [29/30] Train Loss: 0.2855 Val Loss: 0.3270 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.87      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551



100%|██████████| 30/30 [25:16<00:00, 50.53s/it]

Epoch [30/30] Train Loss: 0.2862 Val Loss: 0.3270 
              precision    recall  f1-score   support

         0.0       0.86      0.87      0.87      1801
         1.0       0.87      0.85      0.86      1750

    accuracy                           0.86      3551
   macro avg       0.86      0.86      0.86      3551
weighted avg       0.86      0.86      0.86      3551






In [17]:
# Experiment: 'endpoint' features with the default optimiser (sgd is left at
# its default), learning rate 3e-4 multiplied by 0.75 every 30 epochs.
# NOTE(review): no RNG seed is set in this cell, so shuffling and weight
# initialisation (and therefore the logged metrics) vary between runs.
method = 'endpoint'

train_csv_path, test_csv_path = 'train.csv', 'dev.csv'
train_batch_size = test_batch_size = 256

train_dataset = CustomDataset(csv_file=train_csv_path, method=method)
test_dataset = CustomDataset(csv_file=test_csv_path, method=method)

train_data_loader = DataLoader(dataset=train_dataset,
                               batch_size=train_batch_size,
                               shuffle=True)
test_data_loader = DataLoader(dataset=test_dataset,
                              batch_size=test_batch_size,
                              shuffle=False)

# The classifier's input width is read off the first training example.
model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1])
tl, vl = train_validate_model(
    model=model,
    train_loader=train_data_loader,
    val_loader=test_data_loader,
    num_epochs=300,
    learning_rate=3e-4,
    stepslr=30,
    gamma=0.75,
)

torch.save(model.state_dict(), f'binary_model_{method}.pt')

 10%|█         | 30/300 [19:26<2:57:19, 39.40s/it]

[array([0.9992539], dtype=float32), array([0.9992539], dtype=float32), array([0.99271566], dtype=float32), array([0.999933], dtype=float32), array([0.9996593], dtype=float32), array([0.99999094], dtype=float32), array([0.9979373], dtype=float32), array([0.90290123], dtype=float32), array([1.], dtype=float32), array([0.9867425], dtype=float32), array([0.9999107], dtype=float32), array([0.99969494], dtype=float32), array([0.9986078], dtype=float32), array([0.9992507], dtype=float32), array([0.9998908], dtype=float32), array([0.999966], dtype=float32), array([0.9977102], dtype=float32), array([0.9999993], dtype=float32), array([0.88852036], dtype=float32), array([0.9999989], dtype=float32), array([0.9997731], dtype=float32), array([0.9984864], dtype=float32), array([0.1293112], dtype=float32), array([0.3460961], dtype=float32), array([0.9999759], dtype=float32), array([0.9999982], dtype=float32), array([0.81751966], dtype=float32), array([0.99999976], dtype=float32), array([0.9998522], dt

 20%|██        | 60/300 [38:28<2:41:49, 40.46s/it]

[array([0.99991345], dtype=float32), array([0.99991345], dtype=float32), array([0.99555284], dtype=float32), array([0.99999046], dtype=float32), array([0.9999566], dtype=float32), array([0.9999982], dtype=float32), array([0.99947923], dtype=float32), array([0.88321626], dtype=float32), array([1.], dtype=float32), array([0.9951108], dtype=float32), array([0.99999356], dtype=float32), array([0.9998776], dtype=float32), array([0.99958235], dtype=float32), array([0.9997868], dtype=float32), array([0.99997973], dtype=float32), array([0.99999785], dtype=float32), array([0.99977857], dtype=float32), array([1.], dtype=float32), array([0.9613548], dtype=float32), array([0.9999999], dtype=float32), array([0.9999833], dtype=float32), array([0.9997476], dtype=float32), array([0.16654086], dtype=float32), array([0.69695824], dtype=float32), array([0.9999887], dtype=float32), array([0.9999995], dtype=float32), array([0.90840244], dtype=float32), array([1.], dtype=float32), array([0.9996253], dtype=f

 30%|███       | 90/300 [57:26<2:17:34, 39.31s/it]

[array([0.9999838], dtype=float32), array([0.9999838], dtype=float32), array([0.9981437], dtype=float32), array([0.999985], dtype=float32), array([0.99988043], dtype=float32), array([0.99999535], dtype=float32), array([0.9997733], dtype=float32), array([0.6402869], dtype=float32), array([0.9999995], dtype=float32), array([0.8535758], dtype=float32), array([0.99980766], dtype=float32), array([0.9992637], dtype=float32), array([0.9990113], dtype=float32), array([0.9999068], dtype=float32), array([0.9997992], dtype=float32), array([0.99999416], dtype=float32), array([0.99913174], dtype=float32), array([0.99999976], dtype=float32), array([0.32265997], dtype=float32), array([0.999995], dtype=float32), array([0.99983466], dtype=float32), array([0.9994234], dtype=float32), array([0.5022407], dtype=float32), array([0.08420923], dtype=float32), array([0.99514765], dtype=float32), array([0.9999188], dtype=float32), array([0.7046157], dtype=float32), array([0.9999999], dtype=float32), array([0.99

 40%|████      | 120/300 [1:16:34<1:58:42, 39.57s/it]

[array([0.9999968], dtype=float32), array([0.9999968], dtype=float32), array([0.9986332], dtype=float32), array([0.99999726], dtype=float32), array([0.9999691], dtype=float32), array([0.9999989], dtype=float32), array([0.9999132], dtype=float32), array([0.68161076], dtype=float32), array([0.9999999], dtype=float32), array([0.8270377], dtype=float32), array([0.9999467], dtype=float32), array([0.99961793], dtype=float32), array([0.9995146], dtype=float32), array([0.9999566], dtype=float32), array([0.9999256], dtype=float32), array([0.99999917], dtype=float32), array([0.999699], dtype=float32), array([1.], dtype=float32), array([0.30312285], dtype=float32), array([0.99999917], dtype=float32), array([0.99996185], dtype=float32), array([0.9998306], dtype=float32), array([0.4820967], dtype=float32), array([0.05587598], dtype=float32), array([0.9992518], dtype=float32), array([0.99997735], dtype=float32), array([0.7671334], dtype=float32), array([1.], dtype=float32), array([0.9992236], dtype=

 50%|█████     | 150/300 [1:35:57<1:41:53, 40.75s/it]

[array([0.99999535], dtype=float32), array([0.99999535], dtype=float32), array([0.2559833], dtype=float32), array([0.99995136], dtype=float32), array([0.9998661], dtype=float32), array([1.], dtype=float32), array([0.9967573], dtype=float32), array([0.06712987], dtype=float32), array([0.9999964], dtype=float32), array([0.47117904], dtype=float32), array([0.9995876], dtype=float32), array([0.78328484], dtype=float32), array([0.9830289], dtype=float32), array([0.99479574], dtype=float32), array([0.99990225], dtype=float32), array([1.], dtype=float32), array([0.9996209], dtype=float32), array([1.], dtype=float32), array([0.989248], dtype=float32), array([0.9999819], dtype=float32), array([0.99991107], dtype=float32), array([0.99966383], dtype=float32), array([0.36820266], dtype=float32), array([0.81778526], dtype=float32), array([0.99999094], dtype=float32), array([0.9999087], dtype=float32), array([0.97605], dtype=float32), array([0.99999964], dtype=float32), array([0.99990165], dtype=flo

 60%|██████    | 180/300 [1:55:14<1:18:13, 39.11s/it]

[array([0.99999964], dtype=float32), array([0.99999964], dtype=float32), array([0.395095], dtype=float32), array([0.99999785], dtype=float32), array([0.9987488], dtype=float32), array([1.], dtype=float32), array([0.99986327], dtype=float32), array([0.9269275], dtype=float32), array([0.9999995], dtype=float32), array([0.12073614], dtype=float32), array([0.9998733], dtype=float32), array([0.8747747], dtype=float32), array([0.98402804], dtype=float32), array([0.99992776], dtype=float32), array([0.99998796], dtype=float32), array([1.], dtype=float32), array([0.99968946], dtype=float32), array([1.], dtype=float32), array([0.39298534], dtype=float32), array([0.9999995], dtype=float32), array([0.99998534], dtype=float32), array([0.9997433], dtype=float32), array([0.01303871], dtype=float32), array([0.02146838], dtype=float32), array([0.9999999], dtype=float32), array([0.999995], dtype=float32), array([0.9891004], dtype=float32), array([1.], dtype=float32), array([0.99999905], dtype=float32), 

 70%|███████   | 210/300 [2:14:11<59:14, 39.50s/it]  

[array([0.9999999], dtype=float32), array([0.9999999], dtype=float32), array([0.5844889], dtype=float32), array([0.9999994], dtype=float32), array([0.99945897], dtype=float32), array([1.], dtype=float32), array([0.9999509], dtype=float32), array([0.89551246], dtype=float32), array([0.9999999], dtype=float32), array([0.08243697], dtype=float32), array([0.9999156], dtype=float32), array([0.94901204], dtype=float32), array([0.9945385], dtype=float32), array([0.99995935], dtype=float32), array([0.9999956], dtype=float32), array([1.], dtype=float32), array([0.9998659], dtype=float32), array([1.], dtype=float32), array([0.45775673], dtype=float32), array([0.9999999], dtype=float32), array([0.99999595], dtype=float32), array([0.99985874], dtype=float32), array([0.02096878], dtype=float32), array([0.03534953], dtype=float32), array([0.9999999], dtype=float32), array([0.9999981], dtype=float32), array([0.9892113], dtype=float32), array([1.], dtype=float32), array([0.99999964], dtype=float32), a

 80%|████████  | 240/300 [2:33:13<39:50, 39.85s/it]

[array([1.], dtype=float32), array([1.], dtype=float32), array([0.7840277], dtype=float32), array([0.99999976], dtype=float32), array([0.99977], dtype=float32), array([1.], dtype=float32), array([0.99997795], dtype=float32), array([0.9035791], dtype=float32), array([0.9999999], dtype=float32), array([0.07053714], dtype=float32), array([0.99995136], dtype=float32), array([0.95482], dtype=float32), array([0.9965995], dtype=float32), array([0.9999716], dtype=float32), array([0.9999976], dtype=float32), array([1.], dtype=float32), array([0.9999212], dtype=float32), array([1.], dtype=float32), array([0.54129416], dtype=float32), array([0.9999999], dtype=float32), array([0.99999845], dtype=float32), array([0.9999237], dtype=float32), array([0.03742378], dtype=float32), array([0.04207102], dtype=float32), array([1.], dtype=float32), array([0.99999905], dtype=float32), array([0.99196106], dtype=float32), array([1.], dtype=float32), array([0.99999976], dtype=float32), array([0.99999917], dtype=

 90%|█████████ | 270/300 [2:52:19<19:47, 39.57s/it]

[array([0.9999999], dtype=float32), array([0.9999999], dtype=float32), array([0.9432296], dtype=float32), array([0.9999995], dtype=float32), array([0.9991844], dtype=float32), array([1.], dtype=float32), array([0.99954057], dtype=float32), array([0.991433], dtype=float32), array([1.], dtype=float32), array([0.90663564], dtype=float32), array([0.9999974], dtype=float32), array([0.9996773], dtype=float32), array([0.99999464], dtype=float32), array([0.9998135], dtype=float32), array([0.99986744], dtype=float32), array([1.], dtype=float32), array([0.99947315], dtype=float32), array([1.], dtype=float32), array([0.14409569], dtype=float32), array([1.], dtype=float32), array([0.999995], dtype=float32), array([0.9999925], dtype=float32), array([0.41795278], dtype=float32), array([0.67281675], dtype=float32), array([0.9999999], dtype=float32), array([0.9999999], dtype=float32), array([0.7721881], dtype=float32), array([1.], dtype=float32), array([0.9999527], dtype=float32), array([0.99988925], 

100%|██████████| 300/300 [3:11:21<00:00, 38.27s/it]

[array([1.], dtype=float32), array([1.], dtype=float32), array([0.94016486], dtype=float32), array([1.], dtype=float32), array([0.99973327], dtype=float32), array([1.], dtype=float32), array([0.99987185], dtype=float32), array([0.99076754], dtype=float32), array([1.], dtype=float32), array([0.9075767], dtype=float32), array([0.9999994], dtype=float32), array([0.9998242], dtype=float32), array([0.99999833], dtype=float32), array([0.99994576], dtype=float32), array([0.99994683], dtype=float32), array([1.], dtype=float32), array([0.9999318], dtype=float32), array([1.], dtype=float32), array([0.14032556], dtype=float32), array([1.], dtype=float32), array([0.9999993], dtype=float32), array([0.99999857], dtype=float32), array([0.41030252], dtype=float32), array([0.7116566], dtype=float32), array([1.], dtype=float32), array([1.], dtype=float32), array([0.8812498], dtype=float32), array([1.], dtype=float32), array([0.9999975], dtype=float32), array([0.999974], dtype=float32), array([1.], dtype




In [22]:
# Experiment: binary classifier on 'endpoint' features. Per the CustomDataset
# definition in the notebook's first cell, 'endpoint' concatenates the first
# and last positions (index 0 and -1 along dim 1) of each stored embedding.
method = 'endpoint'

# CSV manifests for the training split and the held-out (dev) split.
train_csv_path, test_csv_path = 'train.csv', 'dev.csv'

# Both splits are built with the same feature method.
train_dataset = CustomDataset(train_csv_path, method=method)
test_dataset = CustomDataset(test_csv_path, method=method)

train_batch_size = test_batch_size = 256

# Only the training loader shuffles; the dev loader keeps file order.
train_data_loader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
)
test_data_loader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
)

# The input width is read from dim 1 of the first training sample's feature
# tensor (indexing the dataset loads that sample from disk).
model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1])

# stepslr / gamma are presumably the LR-scheduler step size and decay factor,
# and tl / vl the train / validation loss histories -- TODO confirm against
# train_validate_model.
tl, vl = train_validate_model(
    model,
    train_data_loader,
    test_data_loader,
    num_epochs=20,
    learning_rate=3e-4,
    stepslr=1,
    gamma=0.95,
)

# NOTE(review): in the recorded run the validation loss is lowest at epoch 2
# (0.2698) and climbs to 0.4410 by epoch 20 while the train loss falls to
# 0.0018. The weights saved here are whatever `model` holds after training
# returns -- confirm whether train_validate_model restores a best checkpoint.
torch.save(model.state_dict(), f'binary_model_{method}_1.pt')

  5%|▌         | 1/20 [00:46<14:52, 46.97s/it]

Epoch [1/20] Train Loss: 0.4504 Val Loss: 0.3041 
              precision    recall  f1-score   support

         0.0       0.85      0.91      0.88      1801
         1.0       0.90      0.83      0.87      1750

    accuracy                           0.87      3551
   macro avg       0.88      0.87      0.87      3551
weighted avg       0.88      0.87      0.87      3551



 10%|█         | 2/20 [01:33<13:56, 46.47s/it]

Epoch [2/20] Train Loss: 0.1971 Val Loss: 0.2698 
              precision    recall  f1-score   support

         0.0       0.92      0.86      0.89      1801
         1.0       0.86      0.93      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 15%|█▌        | 3/20 [02:19<13:10, 46.51s/it]

Epoch [3/20] Train Loss: 0.0912 Val Loss: 0.3030 
              precision    recall  f1-score   support

         0.0       0.87      0.92      0.89      1801
         1.0       0.91      0.86      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 20%|██        | 4/20 [03:05<12:20, 46.31s/it]

Epoch [4/20] Train Loss: 0.0292 Val Loss: 0.3365 
              precision    recall  f1-score   support

         0.0       0.92      0.88      0.90      1801
         1.0       0.88      0.92      0.90      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



 25%|██▌       | 5/20 [03:51<11:32, 46.14s/it]

Epoch [5/20] Train Loss: 0.0174 Val Loss: 0.3416 
              precision    recall  f1-score   support

         0.0       0.88      0.92      0.90      1801
         1.0       0.91      0.87      0.89      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



 30%|███       | 6/20 [04:38<10:48, 46.30s/it]

Epoch [6/20] Train Loss: 0.0102 Val Loss: 0.3583 
              precision    recall  f1-score   support

         0.0       0.91      0.88      0.90      1801
         1.0       0.88      0.91      0.90      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



 35%|███▌      | 7/20 [05:24<10:02, 46.35s/it]

Epoch [7/20] Train Loss: 0.0061 Val Loss: 0.3616 
              precision    recall  f1-score   support

         0.0       0.90      0.89      0.90      1801
         1.0       0.89      0.90      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.89      0.89      0.89      3551
weighted avg       0.90      0.89      0.89      3551



 40%|████      | 8/20 [06:10<09:14, 46.21s/it]

Epoch [8/20] Train Loss: 0.0046 Val Loss: 0.3636 
              precision    recall  f1-score   support

         0.0       0.89      0.90      0.90      1801
         1.0       0.90      0.89      0.89      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



 45%|████▌     | 9/20 [06:56<08:27, 46.15s/it]

Epoch [9/20] Train Loss: 0.0039 Val Loss: 0.3674 
              precision    recall  f1-score   support

         0.0       0.90      0.90      0.90      1801
         1.0       0.90      0.90      0.90      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



 50%|█████     | 10/20 [07:43<07:42, 46.28s/it]

Epoch [10/20] Train Loss: 0.0027 Val Loss: 0.3802 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.89      0.89      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



 55%|█████▌    | 11/20 [08:29<06:55, 46.18s/it]

Epoch [11/20] Train Loss: 0.0026 Val Loss: 0.3829 
              precision    recall  f1-score   support

         0.0       0.89      0.90      0.90      1801
         1.0       0.90      0.89      0.89      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



 60%|██████    | 12/20 [09:16<06:12, 46.58s/it]

Epoch [12/20] Train Loss: 0.0031 Val Loss: 0.3980 
              precision    recall  f1-score   support

         0.0       0.88      0.91      0.90      1801
         1.0       0.91      0.87      0.89      1750

    accuracy                           0.89      3551
   macro avg       0.90      0.89      0.89      3551
weighted avg       0.89      0.89      0.89      3551



 65%|██████▌   | 13/20 [10:03<05:27, 46.80s/it]

Epoch [13/20] Train Loss: 0.0029 Val Loss: 0.3884 
              precision    recall  f1-score   support

         0.0       0.90      0.90      0.90      1801
         1.0       0.90      0.89      0.90      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



 70%|███████   | 14/20 [10:50<04:41, 46.90s/it]

Epoch [14/20] Train Loss: 0.0022 Val Loss: 0.3953 
              precision    recall  f1-score   support

         0.0       0.90      0.90      0.90      1801
         1.0       0.90      0.90      0.90      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



 75%|███████▌  | 15/20 [11:37<03:54, 46.85s/it]

Epoch [15/20] Train Loss: 0.0019 Val Loss: 0.3965 
              precision    recall  f1-score   support

         0.0       0.90      0.90      0.90      1801
         1.0       0.89      0.90      0.90      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



 80%|████████  | 16/20 [12:25<03:08, 47.16s/it]

Epoch [16/20] Train Loss: 0.0020 Val Loss: 0.3992 
              precision    recall  f1-score   support

         0.0       0.90      0.90      0.90      1801
         1.0       0.89      0.90      0.90      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



 85%|████████▌ | 17/20 [13:11<02:20, 46.67s/it]

Epoch [17/20] Train Loss: 0.0019 Val Loss: 0.4053 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.89      0.89      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



 90%|█████████ | 18/20 [13:58<01:33, 46.88s/it]

Epoch [18/20] Train Loss: 0.0017 Val Loss: 0.4117 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.88      0.89      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



 95%|█████████▌| 19/20 [14:47<00:47, 47.40s/it]

Epoch [19/20] Train Loss: 0.0018 Val Loss: 0.4328 
              precision    recall  f1-score   support

         0.0       0.90      0.90      0.90      1801
         1.0       0.89      0.90      0.90      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



100%|██████████| 20/20 [15:34<00:00, 46.73s/it]

Epoch [20/20] Train Loss: 0.0018 Val Loss: 0.4410 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.89      0.89      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551






In [25]:
# Experiment: binary classifier on 'diff-sum' features. Per the CustomDataset
# definition in the notebook's first cell, 'diff-sum' concatenates the sum and
# the difference of the first and last positions (dim 1) of each embedding.
method = 'diff-sum'

# CSV manifests for the training split and the held-out (dev) split.
train_csv_path, test_csv_path = 'train.csv', 'dev.csv'

# Both splits are built with the same feature method.
train_dataset = CustomDataset(train_csv_path, method=method)
test_dataset = CustomDataset(test_csv_path, method=method)

train_batch_size = test_batch_size = 256

# Only the training loader shuffles; the dev loader keeps file order.
train_data_loader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
)
test_data_loader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
)

# The input width is read from dim 1 of the first training sample's feature
# tensor (indexing the dataset loads that sample from disk).
model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1])

# stepslr / gamma are presumably the LR-scheduler step size and decay factor,
# and tl / vl the train / validation loss histories -- TODO confirm against
# train_validate_model.
tl, vl = train_validate_model(
    model,
    train_data_loader,
    test_data_loader,
    num_epochs=100,
    learning_rate=3e-4,
    stepslr=10,
    gamma=0.5,
)

# NOTE(review): in the recorded run the logged validation loss rises from
# 0.3933 (epoch 10) to 0.7320 (epoch 80) and ends at 0.5265, with the train
# loss near 0.001 throughout. The weights saved here are whatever `model`
# holds after training returns -- confirm whether a best checkpoint is kept.
torch.save(model.state_dict(), f'binary_model_{method}.pt')

 10%|█         | 10/100 [06:48<1:02:56, 41.97s/it]

Epoch [10/100] Train Loss: 0.0021 Val Loss: 0.3933 
              precision    recall  f1-score   support

         0.0       0.91      0.91      0.91      1801
         1.0       0.91      0.90      0.90      1750

    accuracy                           0.91      3551
   macro avg       0.91      0.91      0.91      3551
weighted avg       0.91      0.91      0.91      3551



 20%|██        | 20/100 [13:35<55:43, 41.80s/it]  

Epoch [20/100] Train Loss: 0.0013 Val Loss: 0.4417 
              precision    recall  f1-score   support

         0.0       0.91      0.91      0.91      1801
         1.0       0.90      0.91      0.91      1750

    accuracy                           0.91      3551
   macro avg       0.91      0.91      0.91      3551
weighted avg       0.91      0.91      0.91      3551



 30%|███       | 30/100 [20:11<47:09, 40.43s/it]

Epoch [30/100] Train Loss: 0.0011 Val Loss: 0.4739 
              precision    recall  f1-score   support

         0.0       0.91      0.91      0.91      1801
         1.0       0.91      0.91      0.91      1750

    accuracy                           0.91      3551
   macro avg       0.91      0.91      0.91      3551
weighted avg       0.91      0.91      0.91      3551



 40%|████      | 40/100 [26:43<40:14, 40.24s/it]

Epoch [40/100] Train Loss: 0.0011 Val Loss: 0.4950 
              precision    recall  f1-score   support

         0.0       0.91      0.91      0.91      1801
         1.0       0.91      0.90      0.91      1750

    accuracy                           0.91      3551
   macro avg       0.91      0.91      0.91      3551
weighted avg       0.91      0.91      0.91      3551



 50%|█████     | 50/100 [33:15<33:36, 40.32s/it]

Epoch [50/100] Train Loss: 0.0010 Val Loss: 0.5316 
              precision    recall  f1-score   support

         0.0       0.91      0.91      0.91      1801
         1.0       0.91      0.90      0.90      1750

    accuracy                           0.91      3551
   macro avg       0.91      0.91      0.91      3551
weighted avg       0.91      0.91      0.91      3551



 60%|██████    | 60/100 [39:49<26:48, 40.22s/it]

Epoch [60/100] Train Loss: 0.0010 Val Loss: 0.6407 
              precision    recall  f1-score   support

         0.0       0.91      0.91      0.91      1801
         1.0       0.91      0.90      0.91      1750

    accuracy                           0.91      3551
   macro avg       0.91      0.91      0.91      3551
weighted avg       0.91      0.91      0.91      3551



 70%|███████   | 70/100 [46:10<19:29, 38.98s/it]

Epoch [70/100] Train Loss: 0.0010 Val Loss: 0.7015 
              precision    recall  f1-score   support

         0.0       0.91      0.91      0.91      1801
         1.0       0.90      0.91      0.91      1750

    accuracy                           0.91      3551
   macro avg       0.91      0.91      0.91      3551
weighted avg       0.91      0.91      0.91      3551



 80%|████████  | 80/100 [52:23<12:47, 38.38s/it]

Epoch [80/100] Train Loss: 0.0010 Val Loss: 0.7320 
              precision    recall  f1-score   support

         0.0       0.91      0.91      0.91      1801
         1.0       0.91      0.90      0.90      1750

    accuracy                           0.91      3551
   macro avg       0.91      0.91      0.91      3551
weighted avg       0.91      0.91      0.91      3551



 90%|█████████ | 90/100 [58:34<06:22, 38.20s/it]

Epoch [90/100] Train Loss: 0.0017 Val Loss: 0.4955 
              precision    recall  f1-score   support

         0.0       0.89      0.91      0.90      1801
         1.0       0.90      0.89      0.90      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551



100%|██████████| 100/100 [1:04:49<00:00, 38.90s/it]

Epoch [100/100] Train Loss: 0.0011 Val Loss: 0.5265 
              precision    recall  f1-score   support

         0.0       0.89      0.92      0.90      1801
         1.0       0.91      0.88      0.90      1750

    accuracy                           0.90      3551
   macro avg       0.90      0.90      0.90      3551
weighted avg       0.90      0.90      0.90      3551






In [26]:
# Experiment: binary classifier on 'coherent' features. Per the CustomDataset
# definition in the notebook's first cell, 'coherent' concatenates features
# [:360] of the first position, features [360:720] of the last position, and
# the dot product of the two positions' remaining features ([720:]).
method = 'coherent'

# CSV manifests for the training split and the held-out (dev) split.
train_csv_path, test_csv_path = 'train.csv', 'dev.csv'

# Both splits are built with the same feature method.
train_dataset = CustomDataset(train_csv_path, method=method)
test_dataset = CustomDataset(test_csv_path, method=method)

train_batch_size = test_batch_size = 256

# Only the training loader shuffles; the dev loader keeps file order.
train_data_loader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
)
test_data_loader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
)

# The input width is read from dim 1 of the first training sample's feature
# tensor (indexing the dataset loads that sample from disk).
model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1])

# stepslr / gamma are presumably the LR-scheduler step size and decay factor,
# and tl / vl the train / validation loss histories -- TODO confirm against
# train_validate_model.
tl, vl = train_validate_model(
    model,
    train_data_loader,
    test_data_loader,
    num_epochs=100,
    learning_rate=3e-4,
    stepslr=10,
    gamma=0.5,
)

# NOTE(review): in the recorded run the logged validation loss grows from
# 0.5348 (epoch 10) to 1.1077 (epoch 100) while accuracy stays around 0.87.
# The weights saved here are whatever `model` holds after training returns --
# confirm whether train_validate_model restores a best checkpoint.
torch.save(model.state_dict(), f'binary_model_{method}.pt')

 10%|█         | 10/100 [05:42<53:00, 35.34s/it]

Epoch [10/100] Train Loss: 0.0117 Val Loss: 0.5348 
              precision    recall  f1-score   support

         0.0       0.87      0.88      0.88      1801
         1.0       0.88      0.87      0.87      1750

    accuracy                           0.87      3551
   macro avg       0.88      0.87      0.87      3551
weighted avg       0.88      0.87      0.87      3551



 20%|██        | 20/100 [11:23<47:05, 35.32s/it]

Epoch [20/100] Train Loss: 0.0087 Val Loss: 0.6511 
              precision    recall  f1-score   support

         0.0       0.89      0.87      0.88      1801
         1.0       0.87      0.89      0.88      1750

    accuracy                           0.88      3551
   macro avg       0.88      0.88      0.88      3551
weighted avg       0.88      0.88      0.88      3551



 30%|███       | 30/100 [16:59<40:11, 34.44s/it]

Epoch [30/100] Train Loss: 0.0085 Val Loss: 0.7013 
              precision    recall  f1-score   support

         0.0       0.87      0.88      0.88      1801
         1.0       0.88      0.87      0.87      1750

    accuracy                           0.87      3551
   macro avg       0.87      0.87      0.87      3551
weighted avg       0.87      0.87      0.87      3551



 40%|████      | 40/100 [22:44<37:17, 37.28s/it]

Epoch [40/100] Train Loss: 0.0080 Val Loss: 0.7685 
              precision    recall  f1-score   support

         0.0       0.86      0.89      0.88      1801
         1.0       0.89      0.86      0.87      1750

    accuracy                           0.88      3551
   macro avg       0.88      0.87      0.88      3551
weighted avg       0.88      0.88      0.88      3551



 50%|█████     | 50/100 [28:57<31:46, 38.14s/it]

Epoch [50/100] Train Loss: 0.0242 Val Loss: 0.7666 
              precision    recall  f1-score   support

         0.0       0.85      0.90      0.87      1801
         1.0       0.89      0.83      0.86      1750

    accuracy                           0.87      3551
   macro avg       0.87      0.86      0.87      3551
weighted avg       0.87      0.87      0.87      3551



 60%|██████    | 60/100 [35:07<25:19, 37.98s/it]

Epoch [60/100] Train Loss: 0.0011 Val Loss: 0.9011 
              precision    recall  f1-score   support

         0.0       0.86      0.89      0.87      1801
         1.0       0.88      0.85      0.87      1750

    accuracy                           0.87      3551
   macro avg       0.87      0.87      0.87      3551
weighted avg       0.87      0.87      0.87      3551



 70%|███████   | 70/100 [41:15<18:45, 37.51s/it]

Epoch [70/100] Train Loss: 0.0011 Val Loss: 0.9244 
              precision    recall  f1-score   support

         0.0       0.86      0.89      0.87      1801
         1.0       0.88      0.85      0.87      1750

    accuracy                           0.87      3551
   macro avg       0.87      0.87      0.87      3551
weighted avg       0.87      0.87      0.87      3551



 80%|████████  | 80/100 [47:17<12:24, 37.22s/it]

Epoch [80/100] Train Loss: 0.0012 Val Loss: 1.0238 
              precision    recall  f1-score   support

         0.0       0.86      0.89      0.87      1801
         1.0       0.88      0.85      0.87      1750

    accuracy                           0.87      3551
   macro avg       0.87      0.87      0.87      3551
weighted avg       0.87      0.87      0.87      3551



 90%|█████████ | 90/100 [53:28<06:18, 37.83s/it]

Epoch [90/100] Train Loss: 0.0010 Val Loss: 1.0348 
              precision    recall  f1-score   support

         0.0       0.86      0.89      0.87      1801
         1.0       0.89      0.85      0.87      1750

    accuracy                           0.87      3551
   macro avg       0.87      0.87      0.87      3551
weighted avg       0.87      0.87      0.87      3551



100%|██████████| 100/100 [59:33<00:00, 35.74s/it]

Epoch [100/100] Train Loss: 0.0011 Val Loss: 1.1077 
              precision    recall  f1-score   support

         0.0       0.86      0.89      0.87      1801
         1.0       0.88      0.85      0.87      1750

    accuracy                           0.87      3551
   macro avg       0.87      0.87      0.87      3551
weighted avg       0.87      0.87      0.87      3551






In [27]:
# Experiment: binary classifier on 'maxpool' features. Per the CustomDataset
# definition in the notebook's first cell, 'maxpool' takes the element-wise
# maximum of each stored embedding over dim 1.
method = 'maxpool'

# CSV manifests for the training split and the held-out (dev) split.
train_csv_path, test_csv_path = 'train.csv', 'dev.csv'

# Both splits are built with the same feature method.
train_dataset = CustomDataset(train_csv_path, method=method)
test_dataset = CustomDataset(test_csv_path, method=method)

train_batch_size = test_batch_size = 256

# Only the training loader shuffles; the dev loader keeps file order.
train_data_loader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
)
test_data_loader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
)

# The input width is read from dim 1 of the first training sample's feature
# tensor (indexing the dataset loads that sample from disk).
model = CustomDeepClassifier(input_dim=train_dataset[0][0].shape[1])

# stepslr / gamma are presumably the LR-scheduler step size and decay factor,
# and tl / vl the train / validation loss histories -- TODO confirm against
# train_validate_model.
tl, vl = train_validate_model(
    model,
    train_data_loader,
    test_data_loader,
    num_epochs=100,
    learning_rate=3e-4,
    stepslr=10,
    gamma=0.5,
)

# NOTE(review): in the recorded run the logged validation loss grows from
# 0.7236 (epoch 10) to 2.8455 (epoch 100) and accuracy plateaus near 0.74.
# The weights saved here are whatever `model` holds after training returns --
# confirm whether train_validate_model restores a best checkpoint.
torch.save(model.state_dict(), f'binary_model_{method}.pt')

 10%|█         | 10/100 [07:41<1:11:07, 47.42s/it]

Epoch [10/100] Train Loss: 0.2929 Val Loss: 0.7236 
              precision    recall  f1-score   support

         0.0       0.64      0.96      0.77      1801
         1.0       0.91      0.44      0.59      1750

    accuracy                           0.70      3551
   macro avg       0.77      0.70      0.68      3551
weighted avg       0.77      0.70      0.68      3551



 20%|██        | 20/100 [15:38<1:04:50, 48.63s/it]

Epoch [20/100] Train Loss: 0.1366 Val Loss: 0.8831 
              precision    recall  f1-score   support

         0.0       0.68      0.90      0.78      1801
         1.0       0.85      0.57      0.68      1750

    accuracy                           0.74      3551
   macro avg       0.77      0.74      0.73      3551
weighted avg       0.76      0.74      0.73      3551



 30%|███       | 30/100 [23:31<56:57, 48.82s/it]  

Epoch [30/100] Train Loss: 0.1241 Val Loss: 1.2442 
              precision    recall  f1-score   support

         0.0       0.77      0.70      0.74      1801
         1.0       0.72      0.79      0.75      1750

    accuracy                           0.74      3551
   macro avg       0.75      0.75      0.74      3551
weighted avg       0.75      0.74      0.74      3551



 40%|████      | 40/100 [31:24<48:40, 48.68s/it]

Epoch [40/100] Train Loss: 0.0720 Val Loss: 1.8861 
              precision    recall  f1-score   support

         0.0       0.69      0.86      0.77      1801
         1.0       0.81      0.61      0.70      1750

    accuracy                           0.74      3551
   macro avg       0.75      0.74      0.73      3551
weighted avg       0.75      0.74      0.73      3551



 50%|█████     | 50/100 [39:19<40:27, 48.55s/it]

Epoch [50/100] Train Loss: 0.0519 Val Loss: 1.8301 
              precision    recall  f1-score   support

         0.0       0.73      0.77      0.75      1801
         1.0       0.75      0.71      0.73      1750

    accuracy                           0.74      3551
   macro avg       0.74      0.74      0.74      3551
weighted avg       0.74      0.74      0.74      3551



 60%|██████    | 60/100 [47:11<32:23, 48.60s/it]

Epoch [60/100] Train Loss: 0.0318 Val Loss: 2.6037 
              precision    recall  f1-score   support

         0.0       0.75      0.75      0.75      1801
         1.0       0.75      0.74      0.74      1750

    accuracy                           0.75      3551
   macro avg       0.75      0.75      0.75      3551
weighted avg       0.75      0.75      0.75      3551



 70%|███████   | 70/100 [55:06<24:23, 48.79s/it]

Epoch [70/100] Train Loss: 0.0454 Val Loss: 2.6719 
              precision    recall  f1-score   support

         0.0       0.71      0.80      0.76      1801
         1.0       0.77      0.67      0.71      1750

    accuracy                           0.74      3551
   macro avg       0.74      0.73      0.73      3551
weighted avg       0.74      0.74      0.73      3551



 80%|████████  | 80/100 [1:03:06<16:18, 48.93s/it]

Epoch [80/100] Train Loss: 0.0282 Val Loss: 2.7582 
              precision    recall  f1-score   support

         0.0       0.74      0.74      0.74      1801
         1.0       0.73      0.73      0.73      1750

    accuracy                           0.73      3551
   macro avg       0.73      0.73      0.73      3551
weighted avg       0.73      0.73      0.73      3551



 90%|█████████ | 90/100 [1:11:01<08:09, 48.97s/it]

Epoch [90/100] Train Loss: 0.0418 Val Loss: 2.4277 
              precision    recall  f1-score   support

         0.0       0.70      0.85      0.77      1801
         1.0       0.80      0.63      0.70      1750

    accuracy                           0.74      3551
   macro avg       0.75      0.74      0.73      3551
weighted avg       0.75      0.74      0.73      3551



100%|██████████| 100/100 [1:18:56<00:00, 47.36s/it]

Epoch [100/100] Train Loss: 0.0277 Val Loss: 2.8455 
              precision    recall  f1-score   support

         0.0       0.74      0.76      0.75      1801
         1.0       0.74      0.72      0.73      1750

    accuracy                           0.74      3551
   macro avg       0.74      0.74      0.74      3551
weighted avg       0.74      0.74      0.74      3551






In [28]:
# Train and save the binary classifier on 'avgpool' embeddings.
# With method='avgpool', CustomDataset returns torch.mean(raw_embedding, dim=1)
# for each sample (see the CustomDataset definition at the top of the notebook).
# Note: the loader named "test" reads dev.csv and is passed to
# train_validate_model as the third positional argument.
train_csv_path, test_csv_path = 'train.csv', 'dev.csv'

method = 'avgpool'

train_dataset = CustomDataset(train_csv_path, method=method)
test_dataset = CustomDataset(test_csv_path, method=method)

# Same batch size for both splits; both names are kept in case later cells read them.
train_batch_size = test_batch_size = 256

# Only the training loader shuffles.
train_data_loader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
test_data_loader = DataLoader(test_dataset, shuffle=False, batch_size=test_batch_size)

# Size the classifier input from dim 1 of the first training embedding.
first_embedding = train_dataset[0][0]
model = CustomDeepClassifier(input_dim=first_embedding.shape[1])

# Hyper-parameters forwarded unchanged to train_validate_model.
# presumably stepslr/gamma configure a step LR schedule — TODO confirm against
# train_validate_model, which is defined elsewhere in the notebook.
training_options = dict(num_epochs=100, learning_rate=3e-4, stepslr=10, gamma=0.5)
# tl / vl: presumably the train / validation loss histories — TODO confirm.
tl, vl = train_validate_model(model, train_data_loader, test_data_loader, **training_options)

# NOTE(review): in the recorded run, val loss rises from 0.8119 (epoch 10) to
# 1.9716 (epoch 100) while train loss falls below 0.01. Unless
# train_validate_model restores a best checkpoint (not visible here), this
# saves the weights as they are after the final epoch.
checkpoint_path = f'binary_model_{method}.pt'
torch.save(model.state_dict(), checkpoint_path)

 10%|█         | 10/100 [06:15<58:00, 38.67s/it] 

Epoch [10/100] Train Loss: 0.0636 Val Loss: 0.8119 
              precision    recall  f1-score   support

         0.0       0.75      0.88      0.81      1801
         1.0       0.85      0.70      0.77      1750

    accuracy                           0.79      3551
   macro avg       0.80      0.79      0.79      3551
weighted avg       0.80      0.79      0.79      3551



 20%|██        | 20/100 [12:31<51:32, 38.66s/it]

Epoch [20/100] Train Loss: 0.0230 Val Loss: 0.9662 
              precision    recall  f1-score   support

         0.0       0.82      0.80      0.81      1801
         1.0       0.80      0.82      0.81      1750

    accuracy                           0.81      3551
   macro avg       0.81      0.81      0.81      3551
weighted avg       0.81      0.81      0.81      3551



 30%|███       | 30/100 [18:46<45:11, 38.74s/it]

Epoch [30/100] Train Loss: 0.0224 Val Loss: 1.0777 
              precision    recall  f1-score   support

         0.0       0.80      0.81      0.80      1801
         1.0       0.80      0.79      0.80      1750

    accuracy                           0.80      3551
   macro avg       0.80      0.80      0.80      3551
weighted avg       0.80      0.80      0.80      3551



 40%|████      | 40/100 [25:02<38:44, 38.75s/it]

Epoch [40/100] Train Loss: 0.0099 Val Loss: 1.3312 
              precision    recall  f1-score   support

         0.0       0.76      0.86      0.81      1801
         1.0       0.83      0.72      0.77      1750

    accuracy                           0.79      3551
   macro avg       0.80      0.79      0.79      3551
weighted avg       0.80      0.79      0.79      3551



 50%|█████     | 50/100 [31:17<32:15, 38.70s/it]

Epoch [50/100] Train Loss: 0.0178 Val Loss: 1.6152 
              precision    recall  f1-score   support

         0.0       0.81      0.78      0.79      1801
         1.0       0.78      0.80      0.79      1750

    accuracy                           0.79      3551
   macro avg       0.79      0.79      0.79      3551
weighted avg       0.79      0.79      0.79      3551



 60%|██████    | 60/100 [37:38<26:51, 40.28s/it]

Epoch [60/100] Train Loss: 0.0027 Val Loss: 1.8226 
              precision    recall  f1-score   support

         0.0       0.78      0.82      0.80      1801
         1.0       0.81      0.76      0.78      1750

    accuracy                           0.79      3551
   macro avg       0.79      0.79      0.79      3551
weighted avg       0.79      0.79      0.79      3551



 70%|███████   | 70/100 [43:55<19:22, 38.76s/it]

Epoch [70/100] Train Loss: 0.0701 Val Loss: 1.3451 
              precision    recall  f1-score   support

         0.0       0.76      0.84      0.80      1801
         1.0       0.82      0.73      0.77      1750

    accuracy                           0.78      3551
   macro avg       0.79      0.78      0.78      3551
weighted avg       0.79      0.78      0.78      3551



 80%|████████  | 80/100 [50:11<12:54, 38.75s/it]

Epoch [80/100] Train Loss: 0.0038 Val Loss: 1.7775 
              precision    recall  f1-score   support

         0.0       0.79      0.80      0.80      1801
         1.0       0.79      0.78      0.79      1750

    accuracy                           0.79      3551
   macro avg       0.79      0.79      0.79      3551
weighted avg       0.79      0.79      0.79      3551



 90%|█████████ | 90/100 [56:28<06:27, 38.75s/it]

Epoch [90/100] Train Loss: 0.0031 Val Loss: 1.8941 
              precision    recall  f1-score   support

         0.0       0.75      0.87      0.80      1801
         1.0       0.84      0.70      0.76      1750

    accuracy                           0.78      3551
   macro avg       0.79      0.78      0.78      3551
weighted avg       0.79      0.78      0.78      3551



100%|██████████| 100/100 [1:02:43<00:00, 37.63s/it]

Epoch [100/100] Train Loss: 0.0068 Val Loss: 1.9716 
              precision    recall  f1-score   support

         0.0       0.81      0.79      0.80      1801
         1.0       0.79      0.81      0.80      1750

    accuracy                           0.80      3551
   macro avg       0.80      0.80      0.80      3551
weighted avg       0.80      0.80      0.80      3551




