In [11]:
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
    DistilBertTokenizer,
    DistilBertModel,
    get_linear_schedule_with_warmup,
    DataCollatorWithPadding,
)

In [12]:
# Pick the compute device: CUDA when available, otherwise Apple MPS, otherwise CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

In [13]:
# Claim categories in class-id order: a name's position in this tuple is its id.
_LABEL_NAMES = (
    "0_not_relevant",
    "1_not_happening",
    "2_not_human",
    "3_not_bad",
    "4_solutions_harmful_unnecessary",
    "5_science_unreliable",
    "6_proponents_biased",
    "7_fossil_fuels_needed",
)

# Maps each label string to its integer class id.
labels_to_num = {name: index for index, name in enumerate(_LABEL_NAMES)}


def label_to_class(label: str) -> int:
    """Return the integer class id for a label string.

    Raises:
        KeyError: if ``label`` is not a key of ``labels_to_num``.
    """
    class_id = labels_to_num[label]
    return class_id

In [14]:
# Load the climate-claims text dataset by its repository id.
dataset = load_dataset(path="quotaclimat/frugalaichallenge-text-train")

In [15]:
# Tokenizer for the same pretrained checkpoint the classifier's encoder loads.
tokenizer_checkpoint = "distilbert-base-uncased"
tokenizer = DistilBertTokenizer.from_pretrained(tokenizer_checkpoint)


def tokenize(datum):
    """Tokenize one dataset row and attach its integer class id.

    Args:
        datum: A dataset row providing the text under ``"quote"`` and the
            string category under ``"label"``.

    Returns:
        The tokenizer output for the quote, with an extra ``"label"`` entry
        holding the integer class id.

    Raises:
        KeyError: if the row's label is not a known category.
    """
    # Only truncate here. Padding every row to 256 tokens up front would make
    # the batch-time DataCollatorWithPadding a no-op and spend compute on pad
    # tokens; leaving rows unpadded lets each batch be padded to its own
    # longest sequence instead.
    tokens = tokenizer(datum["quote"], truncation=True, max_length=256)
    tokens["label"] = label_to_class(datum["label"])
    return tokens


# Apply `tokenize` to every row of the loaded dataset.
tokenized = dataset.map(function=tokenize)

In [16]:
# The classifier's forward() and the training loop read the targets as
# "labels", so rename the column to match.
tokenized = tokenized.rename_column(
    original_column_name="label", new_column_name="labels"
)

# Expose only the model inputs and targets, formatted as torch tensors.
model_columns = ["input_ids", "attention_mask", "labels"]
tokenized.set_format(type="torch", columns=model_columns)

In [17]:
class DistilBERTClassifier(nn.Module):
    """DistilBERT encoder followed by dropout and a linear classification head.

    Args:
        num_labels: Number of output classes of the linear head.
        model_name: Pretrained DistilBERT checkpoint loaded for the encoder.
        dropout: Dropout probability applied to the pooled [CLS] representation.
    """

    def __init__(self, num_labels=8, model_name="distilbert-base-uncased", dropout=0.3):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        # Not applied in forward(), which returns raw logits; kept available so
        # callers can convert logits to probabilities.
        self.softmax = nn.Softmax(1)
        # Built once here rather than on every forward pass. It holds no
        # parameters, so saved state dicts are unaffected.
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the encoder and classification head.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder.
            labels: Optional integer class targets. When given, the
                cross-entropy loss against the logits is computed.

        Returns:
            A ``(logits, loss)`` tuple. ``logits`` are the unnormalized class
            scores; ``loss`` is ``None`` when ``labels`` is not provided.
        """
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at the first position ([CLS] token) as the
        # summary of the whole sequence.
        pooled_output = output.last_hidden_state[:, 0]
        logits = self.classifier(self.dropout(pooled_output))

        loss = self.loss_fn(logits, labels) if labels is not None else None
        return logits, loss

In [18]:
# Collator that pads the examples of each batch and returns PyTorch tensors.
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    padding=True,
    return_tensors="pt",
)

In [25]:
# Fresh classifier, placed on the selected device.
model = DistilBERTClassifier(num_labels=8)
model = model.to(device)

# Training batches are shuffled; evaluation batches keep dataset order.
train_loader = DataLoader(
    tokenized["train"],
    batch_size=16,
    shuffle=True,
    collate_fn=data_collator,
)
eval_loader = DataLoader(
    tokenized["test"],
    batch_size=32,
    collate_fn=data_collator,
)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

num_epochs = 10
# The scheduler is stepped once per training batch, so the schedule length is
# the number of batches per epoch times the number of epochs.
total_steps = len(train_loader) * num_epochs

# Warm up over the first 10% of the steps.
warmup_steps = int(0.1 * total_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

In [26]:
# Fine-tune for `num_epochs` passes over the training set and report the mean
# batch loss after each pass.
for epoch in range(num_epochs):
    model.train()  # training mode, so dropout is active
    total_loss = 0
    for batch in train_loader:
        # Move the three tensors the model needs onto the training device.
        input_ids, attention_mask, targets = (
            batch[key].to(device)
            for key in ("input_ids", "attention_mask", "labels")
        )

        optimizer.zero_grad()
        logits, loss = model(input_ids, attention_mask, labels=targets)
        loss.backward()
        optimizer.step()
        scheduler.step()  # the learning-rate schedule advances once per batch

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1} - Train Loss: {avg_loss:.3f}")


Epoch 1 - Train Loss: 1.608
Epoch 2 - Train Loss: 0.907
Epoch 3 - Train Loss: 0.555
Epoch 4 - Train Loss: 0.317
Epoch 5 - Train Loss: 0.152
Epoch 6 - Train Loss: 0.081
Epoch 7 - Train Loss: 0.045
Epoch 8 - Train Loss: 0.032
Epoch 9 - Train Loss: 0.025
Epoch 10 - Train Loss: 0.019


In [27]:
# Persist only the model's state dict (its weights), not the module object.
weights = model.state_dict()
torch.save(weights, "distilbert_climate_classifier.pt")

In [None]:
# Rebuild the architecture, then restore the weights from the saved checkpoint.
model = DistilBERTClassifier()
state_dict = torch.load(
    "./distilbert_climate_classifier.pt",
    # Load onto the CPU first so a checkpoint saved on a GPU still loads on a
    # machine without that device; the model is moved to `device` below.
    map_location="cpu",
    # The file holds only a state dict of tensors, so refuse to unpickle
    # arbitrary objects.
    weights_only=True,
)
model.load_state_dict(state_dict)
model.to(device)

DistilBERTClassifier(
  (bert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): L

In [28]:
from sklearn.metrics import classification_report

# Class names ordered by their integer id. Derived from `labels_to_num` so the
# report's names cannot drift from the mapping used to encode the targets.
labels = sorted(labels_to_num, key=labels_to_num.get)

model.eval()  # evaluation mode, so dropout is disabled
all_preds = []
all_labels = []

with torch.inference_mode():
    for batch in eval_loader:
        b_input_ids = batch["input_ids"].to(device)
        b_attention_mask = batch["attention_mask"].to(device)

        # No labels are passed, so the returned loss is None and is discarded.
        logits, _ = model(b_input_ids, b_attention_mask)
        preds = torch.argmax(logits, dim=1)

        all_preds.extend(preds.cpu().tolist())
        # Targets are only compared on the host, so they are read straight
        # from the batch without a round trip through the device.
        all_labels.extend(batch["labels"].tolist())

print(
    classification_report(
        all_labels,
        all_preds,
        # Pass the full id range explicitly so the report still lines up with
        # `target_names` when a class is absent from the evaluation data.
        labels=list(range(len(labels))),
        target_names=labels,
    )
)


                                 precision    recall  f1-score   support

                 0_not_relevant       0.75      0.76      0.76       307
                1_not_happening       0.71      0.78      0.75       154
                    2_not_human       0.66      0.65      0.65       137
                      3_not_bad       0.70      0.62      0.66        97
4_solutions_harmful_unnecessary       0.69      0.71      0.70       160
           5_science_unreliable       0.60      0.66      0.63       160
            6_proponents_biased       0.67      0.57      0.61       139
          7_fossil_fuels_needed       0.61      0.55      0.58        65

                       accuracy                           0.69      1219
                      macro avg       0.67      0.66      0.67      1219
                   weighted avg       0.69      0.69      0.69      1219

