In [1]:
import os
import sys

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from utils import load, dump, distill, categories, load_scraped_data, load_diff_data
from evaluate import eval, Evaluator
from os import getcwd
from tqdm import tqdm
import torch

from datasets import Dataset, DatasetDict
from torch.utils.data import DataLoader

root = f'{getcwd()}/..'
path = lambda p: f'{root}/../out/{p}'

In [2]:
datafile = f'{root}/../scraped_diffs.pickle'

In [3]:
best_nl = path('distilbert-base-uncased_scraped_5e-04_thaw1.pt')
best_diff = path('CodeBERTa-small-v1_diffs_pp_dis_solid.pt')

In [4]:
eval_nl = Evaluator(best_nl)
eval_diff = Evaluator(best_diff)

In [5]:
nl_in, nl_labels = load_scraped_data(datafile)
di_in, di_labels = load_diff_data(datafile)

assert(nl_labels == di_labels)

inputs = [*zip(nl_in, di_in)]
labels = nl_labels

In [6]:
split_percent = 0.8
split=round(len(inputs)*split_percent)

#inputs = inputs[split:]
#labels = labels[split:]

assert(len(inputs) == len(labels))

In [24]:
metaset = {'inputs': [], 'labels': []}

# Stacked outputs from both classifiers
# M1 (nl) = [n1, n2, n3, n4]
# M2 (diff) = [d1, d2, d3, d4]
# Metaset = [[n1, n2, n3, n4, d1, d2, d3, d4]]
with torch.no_grad():
    for (nl, diff), label in tqdm([*zip(inputs, labels)]):
        m1 = eval_nl.predict(nl, raw=True)
        m2 = eval_diff.predict(diff, raw=True)
        stacked = torch.cat((m1, m2), -1)
        metaset['inputs'].append(stacked)
        metaset['labels'].append(label)

dump(metaset, 'metaset2.pickle')

 51%|█████     | 19674/38550 [2:46:23<2:33:31,  2.05it/s] IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 81%|████████  | 31239/38550 [4:25:54<56:11,  2.17it/s]  IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 38550/38550 [5:27:53<00:00,  1.96it/s]  


In [7]:
metaset = load('metaset2.pickle')

In [8]:
split_percent = 0.8
split=round(len(metaset['inputs'])*split_percent)

train_inputs = metaset['inputs'][:split]
train_labels = metaset['labels'][:split]
valid_inputs = metaset['inputs'][split:]
valid_labels = metaset['labels'][split:]

datasets = DatasetDict({
    'train': Dataset.from_dict({
      'inputs': train_inputs,
      'labels': [categories.index(c) for c in train_labels]
    }),
    'valid': Dataset.from_dict({
      'inputs': valid_inputs,
      'labels': [categories.index(c) for c in valid_labels]
    }),
})

datasets.set_format("torch")

In [9]:
class LogisticRegression(torch.nn.Module):

    def __init__(self, inp, out):

        super(LogisticRegression, self).__init__()

        self.linear = torch.nn.Linear(inp, out)

    def forward(self, x):

        outputs = self.linear(x)

        return outputs

In [10]:
batch_size = 2
epochs = 10
lr = 1e-3
model = LogisticRegression(8, 4)
lossFunc = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
seed=42

# Create dataloaders
train_dataloader = DataLoader(datasets["train"].shuffle(seed=seed), shuffle=True, batch_size=batch_size)
eval_dataloader = DataLoader(datasets["valid"].shuffle(seed=seed), batch_size=batch_size)

num_training_steps = epochs * len(train_dataloader)

PATH = f'{root}/../out/meta_classifier2.pt'
metrics = []
best_state = None

In [11]:
from datasets import load_metric
import numpy as np
import time

device = torch.device('cuda')
model.to(device)
model.train()

def get_target(like, labels):
    target = torch.zeros_like(like).to(device)

    for i, l in enumerate(labels):
        target[i][l] = 1
        
    return target


for epoch in range(epochs):
    
    train_acc = load_metric('accuracy')
    train_f1 = load_metric('f1')
    val_acc = load_metric('accuracy')
    val_f1 = load_metric('f1')
    start_time = time.time()
    train_loss = 0
    val_loss = 0
    
    for batch in tqdm(train_dataloader):
        inputs = batch['inputs'][0].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        target = get_target(outputs, labels)

        loss = lossFunc(outputs, target)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        
        preds = torch.argmax(outputs, dim=-1)
        train_acc.add_batch(predictions=preds, references=labels)
        train_f1.add_batch(predictions=preds, references=labels)
    
    with torch.no_grad():
      for batch in eval_dataloader:
        inputs = batch['inputs'][0].to(device)
        labels = batch['labels'].to(device)
        outputs = model(inputs)
        loss = lossFunc(outputs, target)
        val_loss += loss.item()
        
        preds = torch.argmax(outputs, dim=-1)
        val_acc.add_batch(predictions=preds, references=labels)
        val_f1.add_batch(predictions=preds, references=labels)

      train_acc = train_acc.compute()
      train_f1 = train_f1.compute(average='macro')
      val_acc = val_acc.compute()
      val_f1 = val_f1.compute(average='macro')

      vpoint = {
          'train_acc': train_acc,
          'train_f1': train_f1,
          'train_loss': train_loss / len(train_dataloader),
          'val_acc': val_acc,
          'val_f1': val_f1,
          'val_loss': val_loss / len(eval_dataloader)
      }

      metrics.append(vpoint)

      print(f'train_loss: {vpoint["train_loss"]:.3f}, val_loss: {vpoint["val_loss"]:.3f}')
    
    print(f"Epoch time: {((time.time() - start_time) / 60):.3f} minutes")
    
    curr_state = {
        **vpoint,
        'model_state_dict': model.state_dict()
    }
    
    if not best_state or best_state['val_loss'] > curr_state['val_loss']:
        best_state = curr_state
      
    torch.save({
      'epoch': epoch,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'best_state': best_state,
      'metrics': metrics,
    }, PATH)

print(f'{epochs}/{epochs} done')

100%|██████████| 15420/15420 [04:44<00:00, 54.11it/s]


train_loss: 0.611, val_loss: 2.155
Epoch time: 5.147 minutes


100%|██████████| 15420/15420 [04:46<00:00, 53.75it/s]


train_loss: 0.549, val_loss: 2.279
Epoch time: 5.182 minutes


100%|██████████| 15420/15420 [04:46<00:00, 53.77it/s]


train_loss: 0.544, val_loss: 1.583
Epoch time: 5.182 minutes


100%|██████████| 15420/15420 [04:46<00:00, 53.81it/s]


train_loss: 0.542, val_loss: 3.289
Epoch time: 5.178 minutes


100%|██████████| 15420/15420 [04:47<00:00, 53.68it/s]


train_loss: 0.541, val_loss: 1.497
Epoch time: 5.190 minutes


100%|██████████| 15420/15420 [04:48<00:00, 53.50it/s]


train_loss: 0.540, val_loss: 3.201
Epoch time: 5.206 minutes


100%|██████████| 15420/15420 [04:56<00:00, 52.03it/s]


train_loss: 0.540, val_loss: 2.484
Epoch time: 5.344 minutes


100%|██████████| 15420/15420 [04:50<00:00, 53.17it/s]


train_loss: 0.539, val_loss: 1.552
Epoch time: 5.235 minutes


100%|██████████| 15420/15420 [04:47<00:00, 53.61it/s]


train_loss: 0.539, val_loss: 2.425
Epoch time: 5.195 minutes


100%|██████████| 15420/15420 [04:47<00:00, 53.67it/s]


train_loss: 0.539, val_loss: 1.546
Epoch time: 5.191 minutes
10/10 done
