In [None]:
import json
import torch
torch.manual_seed(123)
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Load the cached hidden-state features and their labels.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# files you trust (recent PyTorch accepts weights_only=True for plain tensors).
X = torch.load('hidden_states.pt', map_location=torch.device('cpu'))
with open('labels.json') as labels_file:
    Y = torch.tensor(json.load(labels_file), dtype=torch.long)

# Later cells index Y by X's row index, so the two must line up.
assert X.shape[0] == Y.shape[0], "hidden_states.pt and labels.json have different lengths"

# Sum of labels (number of positives, assuming 0/1 labels).
Y.sum()

In [None]:
# Build a class-balanced subset: keep at most MAX_PER_CLASS examples with
# label 0 and MAX_PER_CLASS with label 1 (the earliest ones in dataset order).
# Rows with any other label are dropped, and the original row order is kept.
MAX_PER_CLASS = 450

assert X.shape[0] == Y.shape[0], "X and Y must have one label per row"

per_class_idx = [
    torch.nonzero(Y == label, as_tuple=True)[0][:MAX_PER_CLASS]
    for label in (0, 1)
]
# Sorting the merged indices restores dataset order across both classes.
keep_idx, _ = torch.sort(torch.cat(per_class_idx))

X_balanced = X[keep_idx].float()
Y_balanced = Y[keep_idx].float()
Y_balanced.shape

In [None]:
# Shuffle the balanced set *before* splitting. The balanced subset keeps
# dataset order, so once one class reaches its cap the tail of the data is
# dominated by the other class; an unshuffled 90/10 split would therefore
# produce a class-skewed test set. (torch.manual_seed at the top makes this
# permutation reproducible.)
TRAIN_FRACTION = 0.9

rand_perm = torch.randperm(X_balanced.shape[0])
X_shuffled = X_balanced[rand_perm]
Y_shuffled = Y_balanced[rand_perm]

n = int(TRAIN_FRACTION * X_shuffled.shape[0])
X_train, Y_train = X_shuffled[:n], Y_shuffled[:n]
X_test, Y_test = X_shuffled[n:], Y_shuffled[n:]

# Sanity check: fraction of positives in each split should be close to 0.5.
Y_train.mean(), Y_test.mean()

## Binary Classifier

### Model

In [None]:
class BinaryClassifier(nn.Module):
  """Feed-forward MLP that maps a feature vector to a probability via Sigmoid.

  Architecture: ``max(n_layers, 1)`` stages of ``Linear -> ReLU`` (the first
  stage maps ``input_dim -> hidden_dim``, the rest ``hidden_dim -> hidden_dim``),
  followed by ``Linear(hidden_dim, output_dim) -> Sigmoid``.

  Args:
    input_dim: size of each input feature vector.
    hidden_dim: width of every hidden stage.
    output_dim: number of sigmoid outputs.
    n_layers: number of hidden ``Linear -> ReLU`` stages; values below 1
      still produce the single input stage.
  """
  def __init__(self, input_dim=4096, hidden_dim=2048, output_dim=1, n_layers=4):
    super().__init__()

    # Layer widths of the hidden stack: input, then one entry per hidden stage.
    widths = [input_dim] + [hidden_dim] * max(n_layers, 1)

    stages = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
      stages += [nn.Linear(fan_in, fan_out), nn.ReLU()]

    # Output head squashes to (0, 1) so the result can be fed to nn.BCELoss.
    stages += [nn.Linear(widths[-1], output_dim), nn.Sigmoid()]

    self.model = nn.Sequential(*stages)

  def forward(self, x):
    """Return per-example probabilities with trailing dimension ``output_dim``."""
    return self.model(x)

In [None]:
class Trainer:
  """Mini-batch trainer for a model that outputs probabilities.

  Optimizes ``model`` with ``nn.BCELoss`` and AdamW, evaluating on
  ``(eval_X, eval_Y)`` after every epoch.

  Args:
    X: training inputs, indexed along dim 0.
    Y: training labels, one per row of ``X`` (cast to float for BCE).
    eval_X: held-out inputs used by :meth:`eval` after each epoch.
    eval_Y: held-out labels matching ``eval_X``.
    model: module whose output for a batch has the shape of
      ``Y[batch].unsqueeze(-1)`` and lies in [0, 1] (required by nn.BCELoss).
    batch_size: maximum examples per optimizer step.
    lr: AdamW learning rate.
    epochs: number of full passes over ``X``.

  Attributes:
    training_losses: list of 0-d tensors, mean batch loss of each epoch.
    eval_losses: list of 0-d tensors, eval loss recorded after each epoch.
  """
  def __init__(self, X, Y, eval_X, eval_Y, model, batch_size=32, lr=1e-5, epochs=10):
    self.X = X
    self.Y = Y
    self.eval_X = eval_X
    self.eval_Y = eval_Y
    self.model = model
    self.batch_size = batch_size
    self.lr = lr
    self.epochs = epochs
    self.criterion = nn.BCELoss()
    self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
    self.training_losses = []
    self.eval_losses = []

  def train(self):
    """Run ``self.epochs`` epochs, printing and recording losses per epoch."""
    num_examples = self.X.shape[0]
    for epoch in range(self.epochs):
      self.model.train()
      epoch_loss = []
      # A fresh permutation per epoch visits every example exactly once and keeps
      # the final partial batch. Sampling indices with replacement would skip some
      # examples, and would take zero steps whenever num_examples < batch_size.
      for batch_idx in torch.randperm(num_examples).split(self.batch_size):
        x = self.X[batch_idx].float()
        y_label = self.Y[batch_idx].float().unsqueeze(-1)
        y_pred = self.model(x)
        loss = self.criterion(y_pred, y_label)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        epoch_loss.append(loss.item())

      epoch_mean = torch.tensor(epoch_loss).mean()
      print(f"Epoch loss {epoch}: {epoch_mean}")
      self.training_losses.append(epoch_mean)
      self.eval(eval_X=self.eval_X, eval_Y=self.eval_Y)

  def eval(self, eval_X, eval_Y):
    """Compute, print, and record the BCE loss of ``self.model`` on one eval set."""
    was_training = self.model.training
    self.model.eval()
    with torch.no_grad():
      y_pred = self.model(eval_X.float())
      loss = self.criterion(y_pred, eval_Y.float().unsqueeze(-1))
    # The criterion already reduces to a scalar (mean), so no extra .mean() needed.
    print(f"Eval Loss: {loss}")
    self.eval_losses.append(loss)
    # Restore whatever mode the caller had the model in.
    self.model.train(was_training)
      
    

In [None]:
# Build the default MLP and fit it on the training split; the trainer
# evaluates on the held-out split after every epoch.
model = BinaryClassifier()
trainer = Trainer(X=X_train, Y=Y_train, eval_X=X_test, eval_Y=Y_test, model=model)
trainer.train()

In [None]:
# Mean training loss of each epoch as plain Python floats
# (training_losses is a list of 0-d tensors, not a tensor).
[loss.item() for loss in trainer.training_losses]

In [None]:
# Eval loss recorded after each epoch, as plain Python floats (handy for plotting).
[float(loss) for loss in trainer.eval_losses]

In [None]:
# Eval losses as a 1-D NumPy array, one entry per epoch
# (eval_losses is a Python list of 0-d tensors, so stack it first).
torch.stack(trainer.eval_losses).numpy()

In [None]:
 trainer.eval_losses