In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm.notebook as tqdm
from torch.utils.data import DataLoader
from transformers import GPT2TokenizerFast, GPT2Model
from transformers import AdamW

from utils.datasets import TextDataset
from models.gpt2 import GPT2ForClassification

In [2]:
MODEL = "pierreguillou/gpt2-small-portuguese"

# Loading tweet dataset

In [3]:
train_df = pd.read_csv("/home/kenzo/datasets/cleaned_tweetsentbr/train.tsv", sep="\t", names=["id", "label", "alfa", "text"], index_col=0)
test_df = pd.read_csv("/home/kenzo/datasets/cleaned_tweetsentbr/test.tsv", sep="\t", names=["id", "label", "alfa", "text"], index_col=0)

In [4]:
# TODO: Exploração dos dados

In [5]:
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL)

In [6]:
train_ds = TextDataset.from_df(train_df, tokenizer, max_seq_len=64)
test_ds = TextDataset.from_df(test_df, tokenizer, max_seq_len=64)

# Preparing model

In [7]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from functools import partial

metrics = {
    "accuracy": accuracy_score,
    "precision": partial(precision_score, average="macro"),
    "recall": partial(recall_score, average="macro"),
    "f1": partial(f1_score, average="macro"),
}

In [8]:
gpu = torch.device("cuda:1")

In [9]:
gpt_model = GPT2Model.from_pretrained(MODEL).cuda(gpu)

In [10]:
model = GPT2ForClassification(gpt_model, metrics, 64, 3).cuda(gpu)

In [11]:
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=64, shuffle=True)

In [12]:
optimizer = AdamW(model.parameters(), lr=1e-5) 
criterion = nn.CrossEntropyLoss()

In [13]:
train, test = model.fit(3, train_dl, test_dl, criterion, optimizer, cuda=True, device=gpu)

- Remaining batches:   0%|          | 0/154 [00:00<?, ?it/s]Epoch: 1
- Remaining batches: 100%|██████████| 154/154 [04:21<00:00,  1.70s/it]
- Remaining batches:   0%|          | 0/154 [00:00<?, ?it/s]	train_loss: 1.0408172851259059 // test_loss: 0.8960718298569704// metrics: {'accuracy': 0.5771730300568644, 'precision': 0.5415346917411347, 'recall': 0.5216257527838212, 'f1': 0.5088595730777764}
Epoch: 2
- Remaining batches: 100%|██████████| 154/154 [04:21<00:00,  1.70s/it]
- Remaining batches:   0%|          | 0/154 [00:00<?, ?it/s]	train_loss: 0.8334620881390262 // test_loss: 0.8213031674042727// metrics: {'accuracy': 0.6246953696181966, 'precision': 0.5955543232494739, 'recall': 0.5937915997668322, 'f1': 0.593472783635817}
Epoch: 3
- Remaining batches: 100%|██████████| 154/154 [04:21<00:00,  1.70s/it]
	train_loss: 0.758679410079857 // test_loss: 0.8062564111672915// metrics: {'accuracy': 0.632412672623883, 'precision': 0.6024772990194894, 'recall': 0.5982070787886076, 'f1': 0.5964506

In [14]:
# torch.save(model.bert.state_dict(), "data/checkpoints/bert_tweetsent_br.ckpt")