In [1]:
import sys

module_path = "../src"

if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
# Load dataset
from dataset import get_dataset

dataset = get_dataset()

In [3]:
# Load libraries
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoConfig
import torch
import wandb

model_name = "gpt2"
epochs = 3
num_labels = 4
batch_size = 32

run = wandb.init(project=model_name)

[34m[1mwandb[0m: Currently logged in as: [33mkpierzynski[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Define tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=256)

# check which one is required
tokenizer.pad_token = tokenizer.eos_token


def tokenize(input):
    return tokenizer(input["text"], padding="max_length", truncation=True)


# Tokenize dataset
tokenized_dataset = dataset.map(tokenize, batched=True)

# Shuffle and pick subset from dataset
train_dataset = tokenized_dataset["train"].shuffle(seed=442333 + 424714).select(range(5000))
eval_dataset = tokenized_dataset["test"].shuffle(seed=442333 + 424714).select(range(1000))

In [5]:
import evaluate
import numpy as np

metric = evaluate.load("accuracy")


# Prepare evaluation callback, metric = accuracy
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [6]:
# Define custom classification head, 2 layers deep


class CustomClassificationHead(torch.nn.Module):
    def __init__(self, input_dim, num_labels):
        super().__init__()
        self.dense = torch.nn.Linear(input_dim, 512)
        self.dense2nd = torch.nn.Linear(512, 256)
        self.dropout = torch.nn.Dropout(0.2)
        self.out_proj = torch.nn.Linear(256, num_labels)

    def forward(self, x):
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dense2nd(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x

In [7]:
# Create model
config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)

model.config.pad_token_id = model.config.eos_token_id
# Connect new classfier to model
model.transformer.add_module("classifier", CustomClassificationHead(config.n_embd, config.num_labels))

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# Freeze odd layers

i = 0
for param in model.parameters():
    if i % 2 == 0:
        continue
    param.requires_grad = False
    i += 1

In [9]:
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    report_to="wandb",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

In [10]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.33468,0.886
2,No log,0.289173,0.91
3,No log,0.301305,0.911


TrainOutput(global_step=471, training_loss=0.40511615210531116, metrics={'train_runtime': 179.7601, 'train_samples_per_second': 83.445, 'train_steps_per_second': 2.62, 'total_flos': 1971881994240000.0, 'train_loss': 0.40511615210531116, 'epoch': 3.0})

In [None]:
print(model)

In [12]:
trainer.save_model(model_name + "_model")

run.save(f"{model_name}_model/*")
wandb.finish()



VBox(children=(Label(value='2.665 MB of 476.743 MB uploaded\r'), FloatProgress(value=0.005591002636754955, max…

0,1
eval/accuracy,▁██
eval/loss,█▁▃
eval/runtime,▁█▃
eval/samples_per_second,█▁▆
eval/steps_per_second,█▁▆
train/epoch,▁▅██
train/global_step,▁▄██
train/total_flos,▁
train/train_loss,▁
train/train_runtime,▁

0,1
eval/accuracy,0.911
eval/loss,0.30131
eval/runtime,4.1105
eval/samples_per_second,243.28
eval/steps_per_second,30.41
train/epoch,3.0
train/global_step,471.0
train/total_flos,1971881994240000.0
train/train_loss,0.40512
train/train_runtime,179.7601
