In [65]:
import torch
import pandas as pd
from typing import Callable, Optional, Tuple
from contextlib import nullcontext
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.optimizer import Optimizer
from adamp import AdamP
from datetime import datetime 
from tqdm import tqdm, trange

from pathlib import Path

In [47]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

device: cpu


In [48]:
model_name = 'distilbert'
dataset_name = 'mbti'
lr = 0.001 # default
betas=(0.9, 0.999) 
weight_decay=1e-2
n_epochs = 200
early_stopping_window = 10

path_data_dir = Path('..') / Path('data')
path_dataset = path_data_dir / Path('split') / Path(f'{dataset_name}.csv')
path_checkpoint_dir = Path('..') / Path('checkpoints')

In [49]:
data = pd.read_csv(path_dataset, header=[0, 1], index_col=0)
data

GROUP,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,CLS,STATS,STATS,STATS,STATS,TARGET,TARGET,TARGET,TARGET
FEATURE,0,1,10,100,101,102,103,104,105,106,...,98,99,NUM_CHARS,NUM_EMOJI,NUM_POSTS,NUM_UPPERCASED,mbtiEXT,mbtiJUD,mbtiSEN,mbtiTHI
AUTHOR,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
-9221022384933360074,0.004908,-0.052056,-0.076416,0.136052,-0.131765,0.015147,-0.403912,-0.041273,-0.104842,-0.108374,...,0.210299,0.208951,82.842105,0.067669,133,7.924812,0.0,0.0,0.0,1.0
-9220031623198266213,0.012287,-0.074256,0.020761,0.190153,-0.158079,0.020834,-0.453263,0.046208,-0.154505,-0.102281,...,0.133708,0.280356,59.416667,0.341667,120,2.958333,0.0,1.0,1.0,1.0
-9219633155989415906,0.038200,-0.042412,0.025798,0.197942,-0.157285,0.081923,-0.349617,0.033938,-0.145553,-0.093792,...,0.216485,0.223513,178.041667,0.000000,48,5.354167,0.0,0.0,0.0,1.0
-9219237589017844173,0.075252,-0.001803,0.024392,0.180314,-0.205115,0.078434,-0.453093,0.042417,-0.187603,-0.015976,...,0.182507,0.178945,77.430380,0.221519,158,3.303797,0.0,0.0,0.0,0.0
-9214568075844254832,0.057628,-0.087571,0.060698,0.196147,-0.166876,0.022728,-0.441546,0.062697,-0.161590,-0.092093,...,0.182287,0.315929,100.617284,0.000000,81,6.716049,0.0,0.0,0.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9220307502816513261,0.054631,-0.040838,0.048029,0.274093,-0.162984,0.100086,-0.424381,0.060659,-0.172998,-0.052123,...,0.206001,0.284869,125.840000,0.000000,50,4.840000,1.0,0.0,0.0,0.0
9220556403022889385,-0.170737,-0.091014,-0.033050,0.133041,-0.146434,0.156010,-0.422089,-0.113840,-0.160334,-0.087500,...,0.261026,0.242038,134.159091,0.000000,44,6.681818,0.0,1.0,0.0,1.0
9221651641191792423,0.035870,-0.075790,0.032501,0.209679,-0.196971,0.039044,-0.391962,0.067448,-0.152983,-0.029710,...,0.217194,0.289530,173.071429,0.000000,42,4.976190,0.0,0.0,0.0,0.0
9222607780732095571,-0.007086,-0.089152,0.079623,0.223943,-0.160965,0.143289,-0.386648,0.021990,-0.167608,0.045628,...,0.167094,0.210510,186.740000,0.000000,50,5.980000,0.0,1.0,0.0,0.0


In [50]:
class PPDataset(Dataset):
  def __init__(self, data:pd.DataFrame):
    self.data_x = torch.Tensor(data.drop(["TARGET"], axis=1).values)
    self.data_y = torch.Tensor(data["TARGET"].values)

  def __len__(self) -> int:
    return self.data_y.shape[0]
  
  def __getitem__(self, index:int) -> torch.Tensor:
    return self.data_x[index], self.data_y[index]

In [51]:
dataset = PPDataset(data)

In [52]:
def split_dataset(dataset: Dataset, train_size: float, test_size: float):
  total_length = len(dataset)
  train_length = int(train_size * total_length)
  test_length = int(test_size * total_length)
  val_length = total_length - (train_length + test_length)
  return random_split(dataset, [train_length, test_length, val_length])

In [53]:
train_ds, test_ds, val_ds = split_dataset(dataset, 0.8, 0.1)

In [54]:
train_args = {
  'batch_size': 1024,
  'num_workers': 1,
  'shuffle': True,
  'pin_memory': False,
  'drop_last': True
}

test_args = {
  'batch_size': 1024,
  'num_workers': 1,
  'shuffle': False,
  'pin_memory': False,
  'drop_last': True
}

In [55]:
train_dl = DataLoader(train_ds, **train_args)
test_dl = DataLoader(test_ds, **test_args)
val_dl = DataLoader(val_ds, **test_args)

In [56]:
class Decoder(nn.Module):
  def __init__(self, hiddens:list[int], dropout_percent=0.5, final='Sigmoid'):
    super(Decoder, self).__init__()
    layers = []
    for i in range(len(hiddens) - 1):
        layers.append(nn.Linear(hiddens[i], hiddens[i + 1]))
        if i < len(hiddens) - 2:
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_percent))
    if final.lower() == 'sigmoid': layers.append(nn.Sigmoid())
    elif final.lower() == 'relu': layers.append(nn.ReLU())
    self.model = nn.Sequential(*layers)

  def forward(self, x:torch.Tensor) -> torch.Tensor:
    return self.model(x)

In [59]:
def handle_epoch(epoch_type:str, dl: DataLoader, model: nn.Module, loss_fn:Callable, optimizer:Optimizer=None) -> Tuple[Optional[float], float]:
    valid_epoch_types = ('train', 'test', 'val', 'validation')
    if epoch_type.lower() not in valid_epoch_types: raise TypeError(f'Argument "epoch_type" must be one of {valid_epoch_types}')
    model.train() if epoch_type == 'train' else model.eval()
    context_manager = torch.no_grad() if epoch_type != 'train' else nullcontext()

    running_loss = 0
    running_corrects = 0
    total_samples = 0

    with context_manager:
      for data_x, data_y in dl:
          data_x = data_x.to(device)
          data_y = data_y.to(device)

          if epoch_type == 'train': optimizer.zero_grad()

          predictions = model(data_x)

          loss = loss_fn(predictions, data_y) if epoch_type != 'test' else None

          if epoch_type == 'train':
            loss.backward()
            optimizer.step()

          preds = (predictions > 0.5).float()
          running_corrects += torch.sum(preds == data_y).item()
          if epoch_type != 'test': running_loss += loss.item() * data_x.size(0)
          total_samples += data_y.numel()

    avg_loss = running_loss / total_samples if epoch_type != 'test' else None
    avg_acc = running_corrects / total_samples
    return avg_loss, avg_acc

In [62]:
def train(model:nn.Module, train:DataLoader, val:DataLoader, test:DataLoader, optimizer:Optimizer, loss_fn:Callable, n_epochs:int, checkpoint_name:str='default', early_stopping_window=5):
  start = datetime.now()
  model_path = path_checkpoint_dir / Path(f'{checkpoint_name}.pth')

  best_epoch = -1
  best_vacc = float('-inf')

  train_loop = trange(n_epochs, desc='Training', leave=True)
  for epoch in train_loop:

    # train
    avg_loss, avg_acc = handle_epoch('train', train, model, loss_fn=loss_fn, optimizer=optimizer) 

    # val
    avg_vloss, avg_vacc = handle_epoch('val', val, model, loss_fn=loss_fn)
    
    train_loop.set_description(f'EPOCH {epoch}: Train loss: {avg_loss:.3f}, Val loss: {avg_vloss:.3f} \t Train acc: {avg_acc:.3f}, Val acc: {avg_vacc:.3f}')

    if avg_vacc > best_vacc:
      best_vacc = avg_vacc
      best_epoch = epoch
      torch.save(model.state_dict(), model_path)
    elif epoch - best_epoch > early_stopping_window:
      tqdm.write(f'Early stopping with best validation accuracy {best_vacc*100:.3f}%')
      break
  
  _, avg_tacc = handle_epoch('test', test, model, loss_fn=loss_fn) 

  end = datetime.now()
  total_time = end - start

  tqdm.write(f'Training finished after {total_time}. Test accuracy {avg_tacc*100:.3f}%')

In [64]:
classes = 4
model = Decoder([data.shape[1]-classes, 2048, 2048, 512, classes])
loss_fn = nn.BCELoss()
optimizer = AdamP(model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
model

Decoder(
  (model): Sequential(
    (0): Linear(in_features=772, out_features=2048, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=2048, out_features=2048, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=2048, out_features=512, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.5, inplace=False)
    (9): Linear(in_features=512, out_features=4, bias=True)
    (10): Sigmoid()
  )
)

In [66]:
train(model, train_dl, val_dl, test_dl, optimizer, loss_fn, n_epochs, checkpoint_name='mbti-test')

TypeError: train() got an unexpected keyword argument 'device'

In [None]:
#  1: Train loss: 0.1488, Val loss: 0.1498. Low dropout
# 2: High dropout: Early stopped training at 20 with best accuracy 66.1376953125%
# 3: 2048: Early stopped training at 15 with best accuracy 66.0889%
# Decoder([data.shape[1]-classes, 4096, 4096, 1024, 256, classes]): EPOCH 15 	 Train loss: 0.152, Val loss: 0.151 	 Train acc: 0.672, Val acc: 0.667
# Decoder([data.shape[1]-classes, 2048, 2048, 1024, 256, classes]) EPOCH 11 	 Train loss: 0.151, Val loss: 0.152 	 Train acc: 0.673, Val acc: 0.665 [01:30<25:56,  8.24s/it]
# EPOCH 13: Train loss: 0.151, Val loss: 0.151 	 Train acc: 0.672, Val acc: 0.669:   6%|▋         | 13/200 [01:30<21:48,  7.00s/it]
# Decoder([data.shape[1]-classes, 2048, 2048, 512, classes]): Training finished after 0:01:31.861364. Test accuracy 66.919%