# In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import pandas as pd
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, f1_score, accuracy_score, precision_score, recall_score
from tqdm.notebook import tqdm


class FakeNewsDataset(Dataset):
    """Labelled statements read from ``path + mode + '.tsv'``, tokenized on access.

    The TSV is read with pandas (tab-separated, header row, NaN filled with
    ''). Each row is unpacked as ``statement, label``, so the file must have
    exactly two columns in that order. Each item is the tuple
    ``(input_ids, attention_mask, label)`` as tensors.
    """

    def __init__(self, mode, tokenizer, path, max_length=512):
        """Load the split into memory.

        Args:
            mode: ``'train'`` or ``'val'``; selects the file ``path + mode + '.tsv'``.
            tokenizer: object exposing a Hugging Face-style ``encode_plus``.
            path: prefix prepended verbatim to the file name (e.g. ``'./'``).
            max_length: longest token sequence (special tokens included) to
                keep; longer statements are truncated. 512 is the maximum
                sequence length of ``distilbert-base-uncased``.

        Raises:
            ValueError: if ``mode`` is not ``'train'`` or ``'val'``.
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if mode not in ('train', 'val'):
            raise ValueError(f"mode must be 'train' or 'val', got {mode!r}")
        self.mode = mode
        self.df = pd.read_csv(path + mode + '.tsv', sep='\t').fillna('')
        self.len = len(self.df)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, label)`` for row ``idx``."""
        statement, label = self.df.iloc[idx, :].values
        label_tensor = torch.tensor(label)

        # add_special_tokens=True wraps the statement in the tokenizer's
        # special tokens (e.g. [CLS] ... [SEP] for BERT-style vocabularies).
        # Truncation keeps long statements within the model's position
        # embeddings; without it, over-long inputs make the forward pass fail.
        tokens = self.tokenizer.encode_plus(
            statement,
            add_special_tokens=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )

        # return_tensors='pt' adds a leading batch dimension of size 1; drop
        # only that dimension so even a single-token sequence stays 1-D.
        tokens_tensor = tokens['input_ids'].squeeze(0)
        attention_tensor = tokens['attention_mask'].squeeze(0)

        return (tokens_tensor, attention_tensor, label_tensor)

    def __len__(self):
        """Number of rows in the loaded TSV."""
        return self.len

def create_mini_batch(samples):
    """Collate ``(input_ids, attention_mask, label)`` samples into one batch.

    Token ids and attention masks are right-padded with 0 to the longest
    sequence in the batch (0 is assumed to be the tokenizer's pad id — confirm
    if switching vocabularies). Labels are stacked along a new leading
    dimension, or returned as ``None`` when the first sample has no label.
    """
    has_labels = samples[0][2] is not None
    labels = torch.stack([item[2] for item in samples]) if has_labels else None

    input_ids = pad_sequence([item[0] for item in samples], batch_first=True)
    masks = pad_sequence([item[1] for item in samples], batch_first=True)

    return input_ids, masks, labels

# Shared configuration for training and evaluation.
BATCH_SIZE = 16
NUM_LABELS = 2
MODEL_NAME = 'distilbert-base-uncased'

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Pretrained tokenizer and a sequence-classification head with NUM_LABELS outputs.
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
model = DistilBertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
)
model.to(device)

def evaluate(valloader):
    """Run the global ``model`` over ``valloader`` and report classification metrics.

    Prints the row-normalised confusion matrix (rows/columns ordered label 1,
    then label 0) and plots it. Then prints accuracy, precision, recall and F1,
    with label 1 as the positive class (scikit-learn's default ``pos_label``).

    Args:
        valloader: iterable of ``(input_ids, attention_mask, labels)`` batches,
            as produced by ``create_mini_batch``; ``labels`` must not be None.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``.
    """
    true = []
    predictions = []

    # Send inputs to wherever the model's weights actually live, instead of
    # moving them only when the model happens to be on CUDA.
    model_device = next(model.parameters()).device
    # Remember the caller's mode so evaluation does not silently leave the
    # model in eval mode (dropout disabled) for any later training.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data in valloader:
                # Keep positions stable: a None entry stays None rather than
                # being dropped, which would shift where the labels are found.
                tokens_tensors, attention_tensors, labels = [
                    t.to(model_device) if t is not None else None for t in data
                ]
                val_outputs = model(input_ids=tokens_tensors,
                                    attention_mask=attention_tensors)
                pred = val_outputs.logits.argmax(dim=1)

                true.extend(labels.cpu().tolist())
                predictions.extend(pred.cpu().tolist())
    finally:
        model.train(was_training)

    cm = confusion_matrix(true, predictions, labels=[1, 0], normalize='true')
    print(cm)

    # NOTE(review): display names assume label 1 = Real and label 0 = Fake;
    # confirm against the dataset's labelling convention.
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Real', 'Fake'])
    disp.plot()

    accuracy = accuracy_score(true, predictions)
    precision = precision_score(true, predictions)
    recall = recall_score(true, predictions)
    f1 = f1_score(true, predictions)

    print('\nAccuracy:', accuracy)
    print('Precision:', precision)
    print('Recall:', recall)
    print('F1 Score:', f1)

    return accuracy, precision, recall, f1

trainset = FakeNewsDataset('train', tokenizer=tokenizer, path='./')
# Shuffle so each epoch sees the examples in a fresh order; without it a TSV
# sorted by label would feed the optimizer long runs of a single class.
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True,
                         collate_fn=create_mini_batch)

# fine-tuning
# Re-derive the device so this cell can be re-run on its own in a notebook.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device:', device)
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
NUM_EPOCHS = 10

for epoch in range(NUM_EPOCHS):
    # Enable dropout at the start of every epoch, in case anything in between
    # (e.g. an evaluation pass) switched the model to eval mode.
    model.train()
    train_loss = 0.0
    train_acc = 0.0
    correct = 0
    seen = 0

    loop = tqdm(trainloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]')
    for batch_idx, data in enumerate(loop):
        tokens_tensors, attention_tensors, labels = [t.to(device) for t in data]

        optimizer.zero_grad()

        outputs = model(input_ids=tokens_tensors,
                        attention_mask=attention_tensors,
                        labels=labels)

        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # Epoch-to-date accuracy, consistent with the running-mean loss below,
        # instead of the accuracy of the most recent batch alone.
        pred = outputs.logits.argmax(dim=1)
        correct += (pred == labels).sum().item()
        seen += labels.size(0)
        train_acc = correct / seen

        train_loss += loss.item()

        loop.set_postfix(acc=train_acc, loss=train_loss / (batch_idx + 1))


valset_distilbert = FakeNewsDataset('val', tokenizer=tokenizer, path='./')
print('DistilBERT valset size:', len(valset_distilbert))
# No shuffling for evaluation: the metrics do not depend on order, and a
# fixed order keeps runs comparable.
valloader_distilbert = DataLoader(valset_distilbert, batch_size=BATCH_SIZE,
                                  collate_fn=create_mini_batch)

print('DistilBERT:')
print('Confusion Matrix:')
evaluate(valloader_distilbert)
