# Predicting the veracity of claims in the LIAR dataset

In [None]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


import transformers as ppb

In [None]:
# Read the training split: a tab-separated file without a header row, so
# pandas labels the columns with integers on load.
df = pd.read_csv('train.tsv', sep='\t', header=None)

# Human-readable names for the positional columns, in file order.
column_names = [
    'id', 'label', 'statement', 'subject', 'speaker', 'job-title',
    'state_info', 'party_affiliation', 'barely_true_counts', 'false_counts',
    'half_true_counts', 'mostly_true_counts', 'pants_on_fire_counts', 'context',
]
df = df.rename(columns=dict(enumerate(column_names)))

In [None]:
# Encode the six truthfulness ratings as integer class ids.
# NOTE: the ids are arbitrary class indices, not an ordinal scale
# ('barely-true' is 4 even though 'false' is 0 and 'half-true' is 1).
mapping = {
    'false': 0,
    'half-true': 1,
    'mostly-true': 2,
    'true': 3,
    'barely-true': 4,
    'pants-fire': 5,
}

# Look every value up explicitly instead of using Series.replace(): replace()
# relies on an implicit object -> int downcast that pandas 2.2 deprecates.
# Values that are already class ids are kept, so re-running this cell is safe.
df['label'] = df['label'].map(lambda value: mapping.get(value, value))

# Fail fast on anything that is not one of the six known classes; otherwise
# stray strings would only surface later as an obscure tensor-conversion error.
unknown_labels = set(df['label'].unique()) - set(mapping.values())
if unknown_labels:
    raise ValueError(
        f"Unexpected label values: {sorted(unknown_labels, key=str)}"
    )

# Make the integer dtype explicit so downstream tensor conversion is reliable.
df['label'] = df['label'].astype('int64')

In [None]:
# Work on the first 1,000 rows only, to keep fine-tuning quick for now.
batch_1 = df.iloc[:1000]

In [None]:
# Use the first CUDA GPU when one is visible; otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
if use_cuda:
    # Only the GPU case is announced.
    print('cuda')

In [None]:
from transformers import BertTokenizer, BertForSequenceClassification

In [None]:
# Fine-tune BERT as a 6-way sequence classifier on the statements in `batch_1`.
tuned_bert = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=6).to(device)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Dropout
# BertForSequenceClassification applies `self.dropout` to the pooled output and
# then feeds it to `self.classifier`, which is a plain nn.Linear. Assigning
# `classifier.dropout` only hangs an unused sub-module on that Linear layer, so
# the dropout the model actually calls has to be replaced instead.
dropout_rate = 0.6
tuned_bert.dropout = nn.Dropout(p=dropout_rate)

# Tokenize
train_texts = batch_1['statement'].tolist()
train_labels = batch_1['label']

# padding=True pads every statement to the longest one in the subset;
# truncation=True caps statements at the model's maximum input length.
encoded_data = tokenizer(train_texts, padding=True,
                         truncation=True, return_tensors='pt',
                         return_attention_mask=True)

# Go through NumPy so the conversion does not depend on the Series index, and
# force int64 because CrossEntropyLoss expects class-index targets as long.
label_tensor = torch.tensor(train_labels.to_numpy(), dtype=torch.long)

# DataLoader
# The dataset stays on the CPU; each mini-batch is moved to `device` when used.
train_dataset = TensorDataset(encoded_data.input_ids,
                              encoded_data.attention_mask, label_tensor)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Fine-tuning
# NOTE(review): 1e-4 is above the 2e-5..5e-5 range commonly recommended for
# BERT fine-tuning - worth revisiting when tuning.
optimizer = torch.optim.AdamW(tuned_bert.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()
epochs = 5

for epoch in range(epochs):
    tuned_bert.train()  # enable dropout for the training pass
    total_loss = 0.0
    for batch_input_ids, batch_attention_mask, batch_labels in train_loader:
        batch_input_ids = batch_input_ids.to(device)
        batch_attention_mask = batch_attention_mask.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        outputs = tuned_bert(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
        loss = criterion(outputs.logits, batch_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Mean of the per-batch losses for this epoch.
    average_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{epochs}] Loss: {average_loss:.4f}")

print("Training finished")

In [None]:
# Predict a class id for every statement in `batch_1` (the training subset).
predicted_labels = []

eval_texts = batch_1['statement'].tolist()
eval_batch_size = 32

# Switch to inference mode once (disables dropout) instead of on every row.
tuned_bert.eval()
with torch.no_grad():
    for start in range(0, len(eval_texts), eval_batch_size):
        chunk = eval_texts[start:start + eval_batch_size]

        # Encode the chunk the same way as the training data; the tokenizer
        # builds the attention mask that hides the padding positions.
        encoded = tokenizer(chunk, padding=True, truncation=True,
                            return_tensors='pt').to(device)
        logits = tuned_bert(input_ids=encoded.input_ids,
                            attention_mask=encoded.attention_mask).logits

        # argmax over the logits picks the same class as argmax over softmax.
        predicted_labels.extend(logits.argmax(dim=1).tolist())

In [None]:
# Fraction of `batch_1` statements whose predicted class equals the label
# (around 0.9).
batch_1['label'].eq(predicted_labels).mean()

In [None]:
# Rows 1000-1999 were not part of the fine-tuning subset, so they serve as a
# held-out check.
batch_2 = df[1000:2000]

# Predict a class id for every statement in `batch_2`.
predicted_labels_2 = []

eval_texts = batch_2['statement'].tolist()
eval_batch_size = 32

# Switch to inference mode once (disables dropout) instead of on every row.
tuned_bert.eval()
with torch.no_grad():
    for start in range(0, len(eval_texts), eval_batch_size):
        chunk = eval_texts[start:start + eval_batch_size]

        # Encode the chunk the same way as the training data; the tokenizer
        # builds the attention mask that hides the padding positions.
        encoded = tokenizer(chunk, padding=True, truncation=True,
                            return_tensors='pt').to(device)
        logits = tuned_bert(input_ids=encoded.input_ids,
                            attention_mask=encoded.attention_mask).logits

        # argmax over the logits picks the same class as argmax over softmax.
        predicted_labels_2.extend(logits.argmax(dim=1).tolist())


In [None]:
# Fraction of `batch_2` statements whose predicted class equals the label
# (around 0.22 - need to revisit and tune!).
batch_2['label'].eq(predicted_labels_2).mean()