In [1]:
import pandas as pd
import torch
import os
import tqdm
import zipfile
from conllu import parse
from torch.utils.data.dataset import Dataset
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from torch.utils.data import DataLoader
from torch.optim import AdamW
from sklearn.metrics import hamming_loss, f1_score, classification_report

  from .autonotebook import tqdm as notebook_tqdm


# Defining constants and paths
Here, we specify the target language and the paths to the raw and preprocessed data.
We also unzip the raw dataset.

In [2]:
TARGET_LANG = ['EN', 'BG', 'PT', 'RU']

RAW_DATASET_PATH = '../data/raw/target_4_December_release'
PREPROCESSED_DATASET_PATH = '../data/preprocessed/preprocessed_target_4_December_release'

LABELS_PATH = [os.path.join(RAW_DATASET_PATH, lang, 'subtask-2-annotations.txt') for lang in TARGET_LANG]
INPUTS_PATH = [os.path.join(PREPROCESSED_DATASET_PATH, lang) for lang in TARGET_LANG]

with zipfile.ZipFile(RAW_DATASET_PATH + '.zip', 'r') as zip_ref:
    zip_ref.extractall(RAW_DATASET_PATH, pwd=b'narratives5202trainTHREE')
    

with zipfile.ZipFile(PREPROCESSED_DATASET_PATH + '.zip', 'r') as zip_ref:
    zip_ref.extractall(PREPROCESSED_DATASET_PATH)

# Loading and mapping narrative labels
load the label file and map article IDs to their respective narratives and subnarratives.

In [3]:
def load_and_map_labels(label_file_paths: list[str]):
    all_labels = []
    
    for label_file_path in label_file_paths:
        labels_df = pd.read_csv(
            label_file_path, 
            sep="\t", 
            header=None, 
            names=["article_id", "narratives", "subnarratives"]
        )
        
        labels_df.head()
        
        for _, row in labels_df.iterrows():
            all_labels.append({
                "article_id": row["article_id"],
                "narratives": row["narratives"].split(";") if pd.notna(row["narratives"]) else [],
                "subnarratives": row["subnarratives"].split(";") if pd.notna(row["subnarratives"]) else []
            })
    
    labels_mapping = pd.DataFrame(all_labels)
    return labels_mapping

labels = load_and_map_labels(LABELS_PATH)
labels.head()

Unnamed: 0,article_id,narratives,subnarratives
0,EN_CC_100013.txt,[CC: Criticism of climate movement],[CC: Criticism of climate movement: Ad hominem...
1,EN_UA_300009.txt,[Other],[Other]
2,EN_UA_300017.txt,[Other],[Other]
3,EN_CC_100021.txt,[Other],[Other]
4,EN_UA_300041.txt,[Other],[Other]


# Mapping articles to labels
Takes the article IDs and corresponding labels, reads the associated `.conllu` file, and creates a dataframe that maps the text and labels together.

In [4]:
def parse_conllu_file(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        data = f.read()
    token_lists = parse(data)
    all_tokens = [token["form"] for token_list in token_lists for token in token_list]
    return " ".join(all_tokens)

def map_input_to_label(articles_paths: list[str], article_ids: list[str], labels: pd.DataFrame):
    labels = labels.set_index("article_id")
    
    articles_data = []
    for articles_path in articles_paths:
        for article_id in article_ids:
            file_path = os.path.join(articles_path, f"{article_id.replace('.txt', '.conllu')}")
            if os.path.exists(file_path) and article_id in labels.index:
                article_text = parse_conllu_file(file_path)
                article_labels = labels.loc[article_id]
                articles_data.append({
                    "article_id": article_id,
                    "text": article_text,
                    "narratives": article_labels["narratives"],
                    "subnarratives": article_labels["subnarratives"]
                })
    return pd.DataFrame(articles_data)

article_ids = labels["article_id"]
article_ids.head()
df = map_input_to_label(INPUTS_PATH, article_ids, labels)
df.head()

Unnamed: 0,article_id,text,narratives,subnarratives
0,EN_CC_100013.txt,bill gates says solution climate change ok fou...,[CC: Criticism of climate movement],[CC: Criticism of climate movement: Ad hominem...
1,EN_UA_300009.txt,russia clashes erupt bashkortostan rights acti...,[Other],[Other]
2,EN_UA_300017.txt,mcdonalds exit russia sell business country am...,[Other],[Other]
3,EN_CC_100021.txt,collaborative plans innovation keys circular r...,[Other],[Other]
4,EN_UA_300041.txt,russia intends supply light mountain tanks inf...,[Other],[Other]


In [5]:
# Remove 50% of the articles with the "Other" label
other_df = df[
    df["narratives"].apply(lambda x: any("Other" in item for item in x)) & 
    df["subnarratives"].apply(lambda x: any("Other" in item for item in x))
].sample(frac=0.7, random_state=42)

print(f"Removing {len(other_df)} articles with the 'Other' label")

df = df.drop(other_df.index)

df.info()

Removing 137 articles with the 'Other' label
<class 'pandas.core.frame.DataFrame'>
Index: 795 entries, 0 to 931
Data columns (total 4 columns):
 #   Column         Non-Null Count  Dtype 
---  ------         --------------  ----- 
 0   article_id     795 non-null    object
 1   text           795 non-null    object
 2   narratives     795 non-null    object
 3   subnarratives  795 non-null    object
dtypes: object(4)
memory usage: 31.1+ KB


# Initializing the BERT tokenizer
Loads a multilingual BERT tokenizer for processing the article text.


In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

# Generating label vocabularies
We gather all unique narratives and subnarratives from the dataset, and create mapping dictionaries to convert them into numerical indices.

In [None]:
all_narratives = df["narratives"].explode().unique()
all_subnarratives = df["subnarratives"].explode().unique()

narrative_to_index = {n: i for i, n in enumerate(all_narratives)}
subnarrative_to_index = {sn: i for i, sn in enumerate(all_subnarratives)}

def encode_labels(narratives, subnarratives):
    narrative_vector = [0] * len(all_narratives)
    subnarrative_vector = [0] * len(all_subnarratives)

    for n in narratives:
        narrative_vector[narrative_to_index[n]] = 1
    for sn in subnarratives:
        subnarrative_vector[subnarrative_to_index[sn]] = 1

    return narrative_vector + subnarrative_vector

df["labels"] = df.apply(lambda x: encode_labels(x["narratives"], x["subnarratives"]), axis=1)

df = pd.DataFrame(df)
df.head()

Unnamed: 0,article_id,text,narratives,subnarratives,labels
0,EN_CC_100013.txt,bill gates says solution climate change ok fou...,[CC: Criticism of climate movement],[CC: Criticism of climate movement: Ad hominem...,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,EN_UA_300009.txt,russia clashes erupt bashkortostan rights acti...,[Other],[Other],"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,EN_UA_300017.txt,mcdonalds exit russia sell business country am...,[Other],[Other],"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,EN_CC_100021.txt,collaborative plans innovation keys circular r...,[Other],[Other],"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,EN_UA_300041.txt,russia intends supply light mountain tanks inf...,[Other],[Other],"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


# Creating a PyTorch Dataset class
`NarrativeDataset` is a custom dataset that tokenizes text on the fly and prepares inputs and labels for training.


In [None]:
class NarrativeDataset(Dataset):
    def __init__(self, articles, tokenizer, max_len):
        self.articles = articles
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.articles)

    def __getitem__(self, idx):
        article = self.articles.iloc[idx]
        inputs = self.tokenizer(
            article["text"],
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "labels": torch.tensor(article["labels"], dtype=torch.float32)
        }


# Defining a prediction helper function
This function generates predictions from the model and applies a threshold to determine class membership. It returns the predicted and true labels.


In [None]:
def get_predictions(model, data_loader, device, threshold=0.2):
    model.eval()
    all_predictions = []
    all_labels = []
    with torch.no_grad():
        for batch in data_loader:
            # Move data to device
            inputs = {key: val.to(device) for key, val in batch.items() if key != "labels"}
            labels = batch["labels"].to(device)
            
            # Forward pass
            outputs = model(**inputs)
            logits = outputs.logits
            probs = torch.sigmoid(logits)  # Convert logits to probabilities
            preds = (probs > threshold).int()  # Apply threshold 
            
            all_predictions.append(preds.cpu())
            all_labels.append(labels.cpu())
    
    # Concatenate predictions and labels
    all_predictions = torch.cat(all_predictions, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    return all_predictions.numpy(), all_labels.numpy()

# Evaluating the model
A utility function to compute Hamming loss, Macro/Micro F1, and Subset Accuracy. It can also print a detailed classification report.


In [None]:
def evaluate_model(y_pred, y_true, class_labels, print_report=False):
    hamming = hamming_loss(y_true, y_pred)

    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    micro_f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)

    subset_accuracy = (y_true == y_pred).all(axis=1).mean()

    if print_report:
        report = classification_report(
            y_true, y_pred, target_names=class_labels, digits=2, zero_division=0
        )
        print("\nClassification Report:\n")
        print(report)

    return {"Hamming Loss": hamming, "Macro F1": macro_f1, "Micro F1": micro_f1, "Subset Accuracy": subset_accuracy}


# Setting up the model and device
Loads a pre-trained multilingual BERT and sets it up for multi-label classification on our dataset. Moves model to GPU if available.


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_labels = len(all_narratives) + len(all_subnarratives)
model = BertForSequenceClassification.from_pretrained(
    "bert-base-multilingual-cased", 
    num_labels=num_labels
)
model = model.to(device)
model.train()

# Preparing data loaders for training and validation
Splits the data into train and validation sets, and creates `DataLoader` instances for each.


In [None]:
train_data, val_data = train_test_split(df, test_size=0.2, random_state=42)
train_dataset = NarrativeDataset(train_data, tokenizer, max_len=512)
val_dataset = NarrativeDataset(val_data, tokenizer, max_len=512)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, pin_memory=True)

# Training loop
Trains the model for a specified number of epochs, reports batch loss, and evaluates on the validation set after each epoch.


In [None]:
optimizer = AdamW(model.parameters(), lr=5e-5)
num_epochs = 3

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    model.train()
    epoch_loss = 0
    progress_bar = tqdm.tqdm(train_loader, desc="Processing Batches", leave=True)
    
    for batch in progress_bar:
        print(f"Batch labels shape: {batch['labels'].shape}")

        optimizer.zero_grad()

        # Prepare inputs
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # Backward pass
        loss.backward()
        optimizer.step()

        # Track loss
        epoch_loss += loss.item()
        progress_bar.set_postfix({"Batch Loss": f"{loss.item():.4f}"})

    # Validation
    y_pred, y_true = get_predictions(model, val_loader, device, threshold=0.3)
    val_metrics = evaluate_model(y_pred, y_true, all_narratives.tolist() + all_subnarratives.tolist())
    print(f"- Validation Micro F1: {val_metrics['Micro F1']:.4f} | Macro F1: {val_metrics['Macro F1']:.4f} | Hamming Loss: {val_metrics['Hamming Loss']:.4f}")

  return torch._C._cuda_getDeviceCount() > 0
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/3


Processing Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  10%|█         | 1/10 [00:23<03:34, 23.83s/it, Batch Loss=0.7114]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  20%|██        | 2/10 [00:50<03:23, 25.39s/it, Batch Loss=0.6805]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  30%|███       | 3/10 [01:15<02:56, 25.14s/it, Batch Loss=0.6478]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  40%|████      | 4/10 [01:39<02:28, 24.70s/it, Batch Loss=0.6142]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  50%|█████     | 5/10 [02:03<02:02, 24.48s/it, Batch Loss=0.5979]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  60%|██████    | 6/10 [02:27<01:37, 24.31s/it, Batch Loss=0.5801]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  70%|███████   | 7/10 [02:51<01:12, 24.18s/it, Batch Loss=0.5632]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  80%|████████  | 8/10 [03:15<00:48, 24.17s/it, Batch Loss=0.5556]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  90%|█████████ | 9/10 [03:39<00:24, 24.09s/it, Batch Loss=0.5352]

Batch labels shape: torch.Size([31, 106])


Processing Batches: 100%|██████████| 10/10 [04:02<00:00, 24.25s/it, Batch Loss=0.5210]


- Validation F1: 0.0690 | Hamming Loss: 0.9290
Epoch 2/3


Processing Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  10%|█         | 1/10 [00:23<03:33, 23.69s/it, Batch Loss=0.5106]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  20%|██        | 2/10 [00:47<03:10, 23.85s/it, Batch Loss=0.4946]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  30%|███       | 3/10 [01:11<02:47, 23.89s/it, Batch Loss=0.4871]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  40%|████      | 4/10 [01:35<02:23, 23.94s/it, Batch Loss=0.4737]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  50%|█████     | 5/10 [01:59<01:59, 23.93s/it, Batch Loss=0.4638]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  60%|██████    | 6/10 [02:23<01:35, 23.94s/it, Batch Loss=0.4514]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  70%|███████   | 7/10 [02:47<01:11, 23.89s/it, Batch Loss=0.4403]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  80%|████████  | 8/10 [03:11<00:47, 23.88s/it, Batch Loss=0.4346]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  90%|█████████ | 9/10 [03:35<00:23, 24.00s/it, Batch Loss=0.4162]

Batch labels shape: torch.Size([31, 106])


Processing Batches: 100%|██████████| 10/10 [03:58<00:00, 23.86s/it, Batch Loss=0.3980]


- Validation F1: 0.0914 | Hamming Loss: 0.4805
Epoch 3/3


Processing Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  10%|█         | 1/10 [00:24<03:37, 24.16s/it, Batch Loss=0.4011]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  20%|██        | 2/10 [00:48<03:12, 24.08s/it, Batch Loss=0.3829]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  30%|███       | 3/10 [01:12<02:48, 24.01s/it, Batch Loss=0.3761]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  40%|████      | 4/10 [01:36<02:23, 23.97s/it, Batch Loss=0.3675]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  50%|█████     | 5/10 [02:00<02:00, 24.06s/it, Batch Loss=0.3616]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  60%|██████    | 6/10 [02:23<01:35, 23.95s/it, Batch Loss=0.3476]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  70%|███████   | 7/10 [02:47<01:11, 23.88s/it, Batch Loss=0.3451]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  80%|████████  | 8/10 [03:11<00:47, 23.88s/it, Batch Loss=0.3331]

Batch labels shape: torch.Size([32, 106])


Processing Batches:  90%|█████████ | 9/10 [03:35<00:23, 23.94s/it, Batch Loss=0.3227]

Batch labels shape: torch.Size([31, 106])


Processing Batches: 100%|██████████| 10/10 [03:58<00:00, 23.87s/it, Batch Loss=0.3177]


- Validation F1: 0.1593 | Hamming Loss: 0.1170


# Final evaluation on the validation set
After training, we run the evaluation once more and print out a full classification report.

In [None]:
print("Evaluating on Validation Set...")
y_pred, y_true = get_predictions(model, val_loader, device, threshold=0.3)
val_metrics = evaluate_model(y_pred, y_true, all_narratives.tolist() + all_subnarratives.tolist(), print_report=True)

Evaluating on Validation Set...

Classification Report:

                                                                                                                   precision    recall  f1-score   support

                                                                                CC: Criticism of climate movement       0.00      0.00      0.00         8
                                                                                                            Other       0.47      1.00      0.64        38
                                                                     CC: Questioning the measurements and science       0.00      0.00      0.00         2
                                                                                    URW: Speculating war outcomes       0.00      0.00      0.00         6
                                                                                            URW: Praise of Russia       0.00      0.00      0.00         6
            