In [None]:
%pip install kagglehub
%pip install transformers
%pip install torch torchvision torchaudio
%pip install transformers datasets torch

import kagglehub
import os
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DistilBertForSequenceClassification

# NOTE(review): not referenced in this section of the notebook; kept in case later cells use it.
download_folder = '.'

# Download (or reuse the cached copy of) the labelled political-bias corpus.
# dataset_download returns the local directory holding the files, so use it directly
# rather than hard-coding a /root/.cache path (which breaks for non-root users,
# a custom cache location, or a new dataset version).
path = kagglehub.dataset_download("surajkarakulath/labelled-corpus-political-bias-hugging-face")
data_dir = path

# Corpus folder name -> class id (0 = left, 1 = center, 2 = right).
# NOTE(review): presumably must match the teacher model's label order — verify against its config.
label_mapping = {"Left Data": 0, "Center Data": 1, "Right Data": 2}

# One tokenizer feeds both the teacher and the DistilBERT student (the original code
# already relies on bert-base-cased ids being valid for distilbert-base-cased).
# DistilBERT's forward() takes no token_type_ids, so restrict the tokenizer's outputs
# to input_ids/attention_mask; this is stored in the tokenizer config, so the copy
# saved next to the student can be used as model(**tokenizer(text)) directly.
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-cased",
    model_input_names=["input_ids", "attention_mask"],
)

class BiasDataset(Dataset):
    """Political-bias text classification dataset read from per-label folders.

    Expects ``<data_dir>/<folder>/<folder>/*.txt`` for every folder name in the
    label map (e.g. ``Left Data/Left Data/article.txt``). Each ``.txt`` file is
    one example labelled with the folder's integer class id. Files are loaded
    eagerly; tokenization happens lazily in ``__getitem__``.

    Args:
        data_dir: Root directory of the extracted corpus.
        tokenizer: Hugging Face-style tokenizer, called once per item.
        max_length: Every example is padded/truncated to this many tokens
            (default 512, as before).
        label_map: Folder name -> class id. Defaults to the module-level
            ``label_mapping``.
    """

    def __init__(self, data_dir, tokenizer, max_length=512, label_map=None):
        self.data = []  # list of (text, label_id) pairs
        self.tokenizer = tokenizer
        self.max_length = max_length
        if label_map is None:
            label_map = label_mapping

        for bias_label, label in label_map.items():
            bias_path = os.path.join(data_dir, bias_label, bias_label)
            # os.listdir order is arbitrary; sort so the example order (and any
            # seeded split built on top of it) is reproducible across machines.
            for txt_file in sorted(os.listdir(bias_path)):
                if not txt_file.endswith(".txt"):
                    continue
                file_path = os.path.join(bias_path, txt_file)
                with open(file_path, "r", encoding="utf-8") as f:
                    self.data.append((f.read(), label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(inputs, label)`` for example *idx*.

        ``inputs`` maps tokenizer output names to 1-D tensors of length
        ``max_length`` (``token_type_ids`` removed); ``label`` is a 0-d tensor.
        """
        text, label = self.data[idx]
        encoding = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
        )
        # The DistilBERT student's forward() has no token_type_ids argument.
        encoding.pop("token_type_ids", None)
        # The tokenizer returns a batch of one; drop that dim so DataLoader can stack items.
        return {key: val.squeeze(0) for key, val in encoding.items()}, torch.tensor(label)

# Load the whole corpus, then hold out 20% of it for evaluation.
dataset = BiasDataset(data_dir, tokenizer)
if len(dataset) == 0:
    raise RuntimeError(f"No .txt files found under {data_dir!r}; check the dataset download.")

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
# Seed the split so the held-out set, and therefore the reported accuracy, is reproducible.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=split_generator
)

# NOTE(review): batch_size=64 at max_length=512 with two models on one device is
# memory-heavy; lower it if you hit out-of-memory errors.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Teacher: pretrained political-bias classifier from the hub (only used for targets).
# Student: DistilBERT-cased with a freshly initialised 3-way classification head.
teacher_model = AutoModelForSequenceClassification.from_pretrained("bucketresearch/politicalBiasBERT")
student_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-cased", num_labels=3)

def distillation_loss(student_logits, teacher_logits, temperature=4.0):
    """Soft-target knowledge-distillation loss.

    Both logit tensors are divided by *temperature* and softmaxed along the
    last axis. The result is KL(teacher || student), averaged over the batch
    ("batchmean") and multiplied by ``temperature ** 2`` so gradient magnitudes
    stay comparable across temperatures.

    Returns:
        A scalar tensor; gradients flow into *student_logits* only if the
        caller computed *teacher_logits* without autograd.
    """
    grad_scale = temperature ** 2
    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    # kl_div expects log-probabilities as input and probabilities as target.
    return F.kl_div(student_log_probs, soft_targets, reduction="batchmean") * grad_scale

# Only the student's parameters are optimised; the teacher stays fixed.
optimizer = optim.AdamW(student_model.parameters(), lr=2e-5)

epochs = 3
temperature = 4.0  # softening temperature for the distillation term
alpha = 0.5  # weight on hard-label cross-entropy; (1 - alpha) goes to the KD term
device = "cuda" if torch.cuda.is_available() else "cpu"

teacher_model.to(device)
student_model.to(device)
# The teacher only supplies targets: keep dropout off so its soft labels are deterministic.
teacher_model.eval()

if device == "cuda":
    torch.cuda.empty_cache()
    # Inputs are always padded to the same length, so cuDNN autotuning can pay off.
    torch.backends.cudnn.benchmark = True

for epoch in range(epochs):
    student_model.train()
    total_loss = 0.0
    num_batches = len(train_dataloader)

    for inputs, labels in train_dataloader:
        inputs = {key: val.to(device) for key, val in inputs.items()}
        labels = labels.to(device)

        # Teacher forward pass without autograd bookkeeping.
        with torch.no_grad():
            teacher_logits = teacher_model(**inputs).logits

        student_logits = student_model(**inputs).logits

        # Blend hard-label supervision with matching the teacher's softened distribution.
        loss_ce = F.cross_entropy(student_logits, labels)
        loss_kd = distillation_loss(student_logits, teacher_logits, temperature=temperature)
        loss = alpha * loss_ce + (1 - alpha) * loss_kd

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()

    # Report progress; previously total_loss was accumulated and then discarded.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch + 1}/{epochs} - average training loss: {avg_loss:.4f}")

    if device == "cuda":
        torch.cuda.empty_cache()

def evaluate_model(model, dataloader, device=None):
    """Return the classification accuracy of *model* over *dataloader*.

    Switches *model* to eval mode (and leaves it there) and runs without gradients.

    Args:
        model: Module called as ``model(**inputs)`` whose output has ``.logits``.
        dataloader: Iterable of ``(inputs, labels)`` pairs; ``inputs`` is a dict
            of tensors and ``labels`` a 1-D tensor of class ids.
        device: Where each batch is moved. Defaults to the device of the model's
            first parameter (CPU if it has none).

    Returns:
        Fraction of examples whose argmax prediction equals the label.

    Raises:
        ValueError: if *dataloader* yields no examples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = {key: val.to(device) for key, val in inputs.items()}
            labels = labels.to(device)

            predictions = torch.argmax(model(**inputs).logits, dim=-1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    # Previously an empty loader crashed with an opaque ZeroDivisionError.
    if total == 0:
        raise ValueError("evaluate_model: dataloader yielded no examples")
    return correct / total

# Score the distilled student on the held-out split and report it;
# the result was previously computed and then silently dropped.
student_accuracy = evaluate_model(student_model, test_dataloader)
print(f"Student test accuracy: {student_accuracy:.4f}")

# Save weights/config and the tokenizer side by side so the directory can be
# reloaded with from_pretrained().
output_dir = "distilled_biasBERT"
student_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
