In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
from datasets import load_dataset
import tqdm
from transformers import GPT2TokenizerFast, BertTokenizerFast, Trainer, TrainingArguments, EvalPrediction
import re
import tiktoken
import numpy as np
from bus_nGPT import Decoder, TransformerLayer, AttentionHead, Rotary, LMHead
import tensorboard as tb
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from importlib import reload

2024-12-02 22:14:04.371491: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1733199244.388617 1067838 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1733199244.394517 1067838 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-02 22:14:04.413579: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
baseline = torch.load('bus_model.pt')

  baseline = torch.load('bus_model.pt')


In [3]:
snli = load_dataset('stanfordnlp/snli')
mnli = load_dataset('nyu-mll/multi_nli')

In [4]:
tokenizer = BertTokenizerFast.from_pretrained('google-bert/bert-base-uncased')

In [5]:
def tokenize_fn(examples):
    return tokenizer(examples['premise'] + " " + examples['hypothesis'], 
                     padding='max_length', 
                     truncation=True,
                     )

In [6]:
def compute_accuracy(eval: EvalPrediction):
    x = np.argmax(eval.predictions[0], axis=-1)
    z = np.average(x == eval.predictions[1])
    return {
        'accuracy': z    
        }

In [7]:
def abs_func(example):
    return {'label': example['label']} if example['label'] >= 0 else {'label':-example['label']}

In [8]:
snli = snli.map(abs_func)

In [9]:
# snli_tokenized = snli.map(tokenize_fn)
mnli_tokenized = mnli.map(tokenize_fn)

Map:   0%|          | 0/392702 [00:00<?, ? examples/s]

Map:   0%|          | 0/9815 [00:00<?, ? examples/s]

Map:   0%|          | 0/9832 [00:00<?, ? examples/s]

In [10]:
bl = LMHead(baseline, 512, 3)
bl.to('cuda')

LMHead(
  (ffn): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): GELU(approximate='none')
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=512, out_features=3, bias=True)
  )
  (model): Decoder(
    (blocks): ModuleList(
      (0-7): 8 x TransformerLayer(
        (heads): ModuleList(
          (0-7): 8 x AttentionHead(
            (rope): Rotary()
          )
        )
        (silu): SiLU()
      )
    )
    (embeddings): Embedding(30523, 384)
  )
)

In [12]:
mnli_tokenized

DatasetDict({
    train: Dataset({
        features: ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 392702
    })
    validation_matched: Dataset({
        features: ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 9815
    })
    validation_mismatched: Dataset({
        features: ['promptID', 'pairID', 'premise', 'premise_binary_parse', 'premise_parse', 'hypothesis', 'hypothesis_binary_parse', 'hypothesis_parse', 'genre', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 9832
    })
})

In [13]:

training_args = TrainingArguments(
	output_dir="output", 
	eval_strategy="steps", 
	num_train_epochs = 2,
	warmup_steps = 0,
	logging_steps = 100,
	save_steps = 100,
	load_best_model_at_end = True,
	learning_rate = 1e-3,
    per_device_train_batch_size=8,
    label_names=['labels'],
    gradient_accumulation_steps=200,
	)

trainer = Trainer(
    model=bl,
    args=training_args,
    train_dataset=mnli_tokenized['train'],
    eval_dataset=mnli_tokenized['validation_mismatched'],
	tokenizer=tokenizer,
    compute_metrics=compute_accuracy

)



  trainer = Trainer(


In [14]:
trainer.train()

  0%|          | 0/490 [00:00<?, ?it/s]

{'loss': 1.0989, 'grad_norm': 13.102109909057617, 'learning_rate': 0.0007959183673469387, 'epoch': 0.41}


  0%|          | 0/1229 [00:00<?, ?it/s]

{'eval_loss': 1.097913146018982, 'eval_accuracy': 0.3347233523189585, 'eval_runtime': 69.6883, 'eval_samples_per_second': 141.085, 'eval_steps_per_second': 17.636, 'epoch': 0.41}
{'loss': 1.097, 'grad_norm': 15.252888679504395, 'learning_rate': 0.0005918367346938776, 'epoch': 0.81}


  0%|          | 0/1229 [00:00<?, ?it/s]

{'eval_loss': 1.0950745344161987, 'eval_accuracy': 0.35262408462164363, 'eval_runtime': 58.5094, 'eval_samples_per_second': 168.041, 'eval_steps_per_second': 21.005, 'epoch': 0.81}
{'loss': 1.0965, 'grad_norm': 28.049867630004883, 'learning_rate': 0.0003877551020408163, 'epoch': 1.22}


  0%|          | 0/1229 [00:00<?, ?it/s]

{'eval_loss': 1.0963783264160156, 'eval_accuracy': 0.350793327908869, 'eval_runtime': 60.083, 'eval_samples_per_second': 163.64, 'eval_steps_per_second': 20.455, 'epoch': 1.22}
{'loss': 1.0943, 'grad_norm': 11.529874801635742, 'learning_rate': 0.00018367346938775512, 'epoch': 1.63}


  0%|          | 0/1229 [00:00<?, ?it/s]

{'eval_loss': 1.091947317123413, 'eval_accuracy': 0.36696501220504474, 'eval_runtime': 57.4565, 'eval_samples_per_second': 171.121, 'eval_steps_per_second': 21.39, 'epoch': 1.63}
{'train_runtime': 4939.3435, 'train_samples_per_second': 159.01, 'train_steps_per_second': 0.099, 'train_loss': 1.095941940619021, 'epoch': 2.0}


TrainOutput(global_step=490, training_loss=1.095941940619021, metrics={'train_runtime': 4939.3435, 'train_samples_per_second': 159.01, 'train_steps_per_second': 0.099, 'total_flos': 0.0, 'train_loss': 1.095941940619021, 'epoch': 1.996414602346806})

In [15]:
trainer.evaluate(mnli_tokenized['validation_matched'])

  0%|          | 0/1227 [00:00<?, ?it/s]

{'eval_loss': 1.0934261083602905,
 'eval_accuracy': 0.3681100356597045,
 'eval_runtime': 61.3461,
 'eval_samples_per_second': 159.994,
 'eval_steps_per_second': 20.001,
 'epoch': 1.996414602346806}

In [None]:
trainer.evaluate(snli_tokenized['test'])

  0%|          | 0/1250 [00:00<?, ?it/s]

{'eval_loss': 1.079468011856079,
 'eval_accuracy': 0.4098,
 'eval_runtime': 61.0158,
 'eval_samples_per_second': 163.892,
 'eval_steps_per_second': 20.487,
 'epoch': 1.9953321991013393}