In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
import tqdm.notebook as tqdm
from torch.utils.data import DataLoader
from transformers import GPT2TokenizerFast, GPT2Model
from transformers import AdamW
from utils.datasets import TextDataset
from models.gpt2 import GPT2ForClassification

In [2]:
MODEL = "pierreguillou/gpt2-small-portuguese"

# Seed the RNGs this run relies on so that weight initialisation of newly
# created layers and the training DataLoader's shuffling are reproducible.
# torch.manual_seed also seeds every CUDA device.
# NOTE: bit-exact GPU reproducibility may additionally require
# torch.use_deterministic_algorithms(True).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load the cleaned TweetSentBR tweet dataset

In [3]:
import os

# Root directory of the cleaned TweetSentBR splits. Set the TWEETSENTBR_DIR
# environment variable to point elsewhere instead of editing this
# machine-specific default.
DATA_DIR = os.environ.get("TWEETSENTBR_DIR", "/home/kenzo/datasets/cleaned_tweetsentbr")

# Column names supplied explicitly via `names=` (no `header=` is passed, so
# pandas reads the first line as data). The first column, "id", becomes the index.
TSV_COLUMNS = ["id", "label", "alfa", "text"]


def load_split(split: str) -> pd.DataFrame:
    """Read ``<DATA_DIR>/<split>.tsv`` (tab-separated) into a DataFrame indexed by ``id``."""
    path = os.path.join(DATA_DIR, f"{split}.tsv")
    return pd.read_csv(path, sep="\t", names=TSV_COLUMNS, index_col=0)


train_df = load_split("train")
test_df = load_split("test")

# Fail fast on an empty / wrong file rather than deep inside training.
assert len(train_df) > 0 and len(test_df) > 0, "empty TweetSentBR split"

In [4]:
# Tokenizer loaded from the same checkpoint as the GPT-2 backbone so token ids match.
tokenizer = GPT2TokenizerFast.from_pretrained(pretrained_model_name_or_path=MODEL)

In [5]:
# Build both splits the same way, with sequences capped at 64 tokens.
# NOTE(review): confirm whether TextDataset also pads shorter tweets to 64.
train_ds, test_ds = (
    TextDataset.from_df(split_df, tokenizer, max_seq_len=64)
    for split_df in (train_df, test_df)
)

# Prepare the model, metrics and training setup

In [6]:
from functools import partial
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Evaluation metrics handed to the model. Precision, recall and F1 use
# average="macro", i.e. the unweighted mean of the per-class scores.
metrics = {
    "accuracy": accuracy_score,
    **{
        metric_name: partial(scorer, average="macro")
        for metric_name, scorer in (
            ("precision", precision_score),
            ("recall", recall_score),
            ("f1", f1_score),
        )
    },
}

In [7]:
# Target device: CUDA GPU with index 1.
gpu = torch.device("cuda", 1)

In [8]:
# GPT-2 backbone loaded from the same checkpoint as the tokenizer.
gpt_model = GPT2Model.from_pretrained(pretrained_model_name_or_path=MODEL)

In [9]:
# Wrap the backbone in the project's classification model, then move it to the GPU.
# NOTE(review): the positional args 64 and 3 presumably are max_seq_len and the
# number of sentiment classes — confirm against models.gpt2.GPT2ForClassification.
model = GPT2ForClassification(gpt_model, 64, 3, metrics)
model = model.cuda(gpu)

In [10]:
BATCH_SIZE = 64

# Shuffle only the training data. Evaluation order does not affect the metrics,
# and a fixed order keeps the per-batch-averaged test loss identical across
# runs, because the same samples always land in the final partial batch.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [11]:
batches = len(train_dl)
epochs = 3

# torch.optim.AdamW replaces the deprecated transformers.AdamW.
# eps=1e-6 and weight_decay=0.0 reproduce the transformers implementation's
# defaults; torch's own defaults are eps=1e-8 and weight_decay=1e-2.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0)

# OneCycleLR overrides the optimizer's lr. It starts at max_lr / div_factor
# (div_factor defaults to 25), warms up to max_lr, then anneals.
# It expects one scheduler.step() per batch (steps_per_epoch * epochs steps in
# total). NOTE(review): confirm model.fit steps the scheduler per batch.
scheduler = OneCycleLR(optimizer, max_lr=1e-5, steps_per_epoch=batches, epochs=epochs)
criterion = nn.CrossEntropyLoss()

In [12]:
# Train and evaluate once per epoch. The two return values are presumably
# per-epoch train/test histories — TODO confirm against GPT2ForClassification.fit.
train, test = model.fit(
    epochs,
    train_dl,
    test_dl,
    criterion,
    optimizer,
    scheduler=scheduler,
    cuda=True,
    device=gpu,
)

- Remaining batches: 100%|██████████| 154/154 [04:21<00:00,  1.70s/it]
Epoch: 1
- Remaining batches:   0%|          | 0/154 [00:00<?, ?it/s]	train_loss: 1.0228747721616325 // test_loss: 0.8990437984466553// metrics: {'accuracy': 0.5735174654752234, 'precision': 0.5398938962755699, 'recall': 0.5288365584829621, 'f1': 0.5292728976278424}

- Remaining batches: 100%|██████████| 154/154 [04:21<00:00,  1.70s/it]
Epoch: 2
- Remaining batches:   0%|          | 0/154 [00:00<?, ?it/s]	train_loss: 0.826910300301267 // test_loss: 0.8286859928033291// metrics: {'accuracy': 0.6214459788789602, 'precision': 0.5919524792585519, 'recall': 0.5810038573422652, 'f1': 0.5781046059050848}

- Remaining batches: 100%|██████████| 154/154 [04:21<00:00,  1.70s/it]
Epoch: 3
	train_loss: 0.7604062886207135 // test_loss: 0.8199584407684131// metrics: {'accuracy': 0.6267262388302194, 'precision': 0.6002453969914275, 'recall': 0.5902444582887446, 'f1': 0.5895114958316555}



In [13]:
# NOTE(review): stale. This notebook trains a GPT2ForClassification, but the
# disabled line below saves `model.bert` to a "bert_tweetsent_br" checkpoint,
# which looks copied from a BERT notebook. Fix the attribute and the file name
# before enabling it.
# torch.save(model.bert.state_dict(), "data/checkpoints/bert_tweetsent_br.ckpt")