In [None]:
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW, BertConfig, RobertaForSequenceClassification

from media_frame_transformer.dataset import PrimaryFrameDataset
from transformers import AdamW
from torch.nn import functional as F
from config import ISSUES

# Number of examples per batch; passed as `batch_size` to both the train and
# test DataLoader below.
BATCH_SIZE = 5
# Passed as `num_workers` to both DataLoaders (worker subprocesses used for
# loading batches).
NUM_DATALOADER_WORKER = 2

In [None]:
# Build the train/test splits of the primary-frame dataset and wrap each one in
# a DataLoader that shares the batch size and worker count configured above.
train_set = PrimaryFrameDataset(ISSUES, "train")
test_set = PrimaryFrameDataset(ISSUES, "test")

# Training batches are reshuffled on every pass over the loader.
train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_DATALOADER_WORKER,
)

# Evaluation does not benefit from shuffling: accuracy is order-independent,
# and a fixed order keeps per-batch outputs reproducible between runs.
test_loader = DataLoader(
    test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_DATALOADER_WORKER,
)

In [None]:
# Number of batches each loader yields per pass: (train, test).
tuple(map(len, (train_loader, test_loader)))

In [None]:
# Sizes reported by len() for each dataset split: (train, test).
tuple(map(len, (train_set, test_set)))

In [None]:
# Load the pretrained "roberta-base" checkpoint with a 15-class
# sequence-classification head and move it straight onto the GPU.
# Attention weights and per-layer hidden states are not requested; only the
# logits are consumed by the cells that follow.
model = RobertaForSequenceClassification.from_pretrained(
    "roberta-base",
    num_labels=15,
    output_attentions=False,
    output_hidden_states=False,
).cuda()

In [None]:
# Fine-tune the classifier for one pass over `train_loader`.
# NOTE(review): `AdamW` here is the one imported from `transformers`, which is
# deprecated in newer releases in favour of `torch.optim.AdamW` — confirm
# against the installed version before upgrading.
optimizer = AdamW(model.parameters(), lr=1e-5)

# Training mode only needs to be enabled once, not on every batch.
model.train()

for i, batch in enumerate(train_loader):
    optimizer.zero_grad()

    # assumes batch['x'] holds token ids and batch['y'] integer class labels —
    # TODO confirm against PrimaryFrameDataset.
    x = batch['x'].cuda()
    y = batch['y'].cuda()

    # NOTE(review): no attention_mask is passed; if the dataset pads sequences
    # to a common length, padding positions are attended to — confirm.
    outputs = model(x)
    loss = F.cross_entropy(outputs.logits, y)
    loss.backward()
    optimizer.step()

    # .item() logs the scalar value instead of the full tensor repr
    # (device / grad_fn noise).
    print(i, loss.item())

In [None]:
from tqdm import tqdm
import torch  

# Accuracy of the current model on the training split.
num_correct = 0

# Evaluation mode only needs to be enabled once, before iterating.
model.eval()

with torch.no_grad():
    for batch in tqdm(train_loader):
        x = batch['x'].cuda()
        y = batch['y'].cuda()
        # Predicted class = index of the largest logit along the last axis.
        preds = torch.argmax(model(x).logits, dim=-1)
        # .item() converts the per-batch count to a Python int so the running
        # total (and the printed ratio) is a plain number, not a CUDA tensor.
        num_correct += (preds == y).sum().item()

print(num_correct / len(train_set))

In [None]:
from tqdm import tqdm
import torch  

# Accuracy of the current model on the held-out test split.
num_correct = 0

# Evaluation mode only needs to be enabled once, before iterating.
model.eval()

with torch.no_grad():
    for batch in tqdm(test_loader):
        x = batch['x'].cuda()
        y = batch['y'].cuda()
        # Predicted class = index of the largest logit along the last axis.
        preds = torch.argmax(model(x).logits, dim=-1)
        # .item() converts the per-batch count to a Python int so the running
        # total (and the printed ratio) is a plain number, not a CUDA tensor.
        num_correct += (preds == y).sum().item()

print(num_correct / len(test_set))