# 1. Imports and Constants
Core dependencies and configuration constants for the narrative classification system.

In [1]:
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
import os
import tqdm
import zipfile
from conllu import parse
from torch.utils.data.dataset import Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import hamming_loss, f1_score, classification_report, precision_recall_fscore_support
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.optim import AdamW
import numpy as np

# Configuration constants
# Language codes to load; each one is also the name of that language's data sub-directory.
TARGET_LANG = ['EN', 'PT', 'RU']
# Dataset roots; extract_datasets() unpacks "<root>.zip" into each one when it is missing.
RAW_DATASET_PATH = '../data/raw/target_4_December_release'
PREPROCESSED_DATASET_PATH = '../data/preprocessed/preprocessed_target_4_December_release'
# One subtask-2 annotation file per language, in TARGET_LANG order.
LABELS_PATH = [
    os.path.join(RAW_DATASET_PATH, code, 'subtask-2-annotations.txt')
    for code in TARGET_LANG
]
# One directory of preprocessed (CoNLL-U) articles per language, in TARGET_LANG order.
INPUTS_PATH = [
    os.path.join(PREPROCESSED_DATASET_PATH, code)
    for code in TARGET_LANG
]

  from .autonotebook import tqdm as notebook_tqdm


# 2. Data Loading and Preprocessing
Functions for loading and preprocessing the narrative classification dataset.
Includes archive extraction, label parsing, matching articles to their labels, and rebuilding article text from CoNLL-U files.

In [2]:
def extract_datasets():
    """Unpack the raw and preprocessed dataset archives if not already extracted.

    For each dataset root, ``<root>.zip`` is extracted into ``<root>`` only when
    ``<root>`` does not exist yet. The raw archive is password-protected; the
    preprocessed archive is extracted without a password.
    """
    archives = (
        (RAW_DATASET_PATH, b'narratives5202trainTHREE'),
        (PREPROCESSED_DATASET_PATH, None),
    )
    for target_dir, password in archives:
        if os.path.exists(target_dir):
            continue  # already extracted
        with zipfile.ZipFile(target_dir + '.zip', 'r') as archive:
            archive.extractall(target_dir, pwd=password)

def load_and_map_labels(label_file_paths):
    """Load narrative and subnarrative annotations from one or more label files.

    Each file is tab-separated with no header and three columns: article ID,
    ``;``-separated narratives, and ``;``-separated subnarratives. Labels are
    whitespace-stripped. Empty pieces (from a trailing ``;`` or an empty cell)
    are dropped, so every per-article list only contains labels that also appear
    in the returned vocabularies.

    Args:
        label_file_paths: Iterable of paths to annotation files.

    Returns:
        ``(labels_df, all_narratives, all_subnarratives)``. ``labels_df`` has the
        columns ``article_id``, ``narratives`` (list) and ``subnarratives`` (list),
        with one row per annotation line in file order. The other two values are
        the sorted unique narrative and subnarrative labels across all files.
    """
    def _split_labels(cell):
        # A missing cell (NaN) means "no labels" for that article.
        if pd.isna(cell):
            return []
        pieces = (piece.strip() for piece in str(cell).split(";"))
        return [piece for piece in pieces if piece]

    all_labels = []
    all_narratives_set = set()
    all_subnarratives_set = set()

    for label_file_path in label_file_paths:
        labels_df = pd.read_csv(
            label_file_path,
            sep="\t",
            header=None,
            names=["article_id", "narratives", "subnarratives"],
        )

        # itertuples avoids building a Series per row (as iterrows does).
        for article_id, narratives_cell, subnarratives_cell in labels_df.itertuples(index=False, name=None):
            narratives = _split_labels(narratives_cell)
            subnarratives = _split_labels(subnarratives_cell)

            all_narratives_set.update(narratives)
            all_subnarratives_set.update(subnarratives)

            all_labels.append({
                "article_id": article_id,
                "narratives": narratives,
                "subnarratives": subnarratives,
            })

    return pd.DataFrame(all_labels), sorted(all_narratives_set), sorted(all_subnarratives_set)

def parse_conllu_file(file_path):
    """Read a CoNLL-U file and return every token's ``form`` joined by single spaces.

    All sentences are flattened into one string; sentence boundaries are not marked.
    NOTE(review): if the files contain multiword-token ranges (e.g. Portuguese
    contractions), both the range token and its parts may be emitted. Confirm this
    against the preprocessing output.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        raw_text = handle.read()
    sentences = parse(raw_text)

    forms = []
    for sentence in sentences:
        forms.extend(token["form"] for token in sentence)
    return " ".join(forms)

def map_input_to_label_with_lang(articles_paths, article_ids, labels):
    """Join preprocessed article texts with their labels and tag each with its language.

    Args:
        articles_paths: Per-language directories of ``.conllu`` files. The final path
            component of each directory is used as the language code
            (e.g. ``.../EN`` -> ``"EN"``).
        article_ids: Annotation IDs such as ``X.txt``. Each one is looked up in every
            directory with ``.txt`` replaced by ``.conllu``.
        labels: DataFrame with ``article_id``, ``narratives`` and ``subnarratives`` columns.

    Returns:
        DataFrame with the columns ``article_id``, ``text``, ``narratives``,
        ``subnarratives`` and ``language``. There is one row per (directory, id) pair
        whose file exists and whose id is labelled. The frame is empty (no columns)
        when nothing matches.
    """
    labels = labels.set_index("article_id")
    # Keep one annotation per ID so .loc yields a single row, not a sub-DataFrame.
    labels = labels[~labels.index.duplicated(keep="first")]
    articles_data = []

    for articles_path in articles_paths:
        # basename(normpath(...)) is separator-agnostic (os.path.join uses '\\' on Windows)
        # and tolerates a trailing slash.
        lang = os.path.basename(os.path.normpath(articles_path))

        for article_id in article_ids:
            if article_id not in labels.index:
                continue
            file_path = os.path.join(articles_path, article_id.replace('.txt', '.conllu'))
            if not os.path.exists(file_path):
                continue

            article_text = parse_conllu_file(file_path)
            article_labels = labels.loc[article_id]
            articles_data.append({
                "article_id": article_id,
                "text": article_text,
                "narratives": article_labels["narratives"],
                "subnarratives": article_labels["subnarratives"],
                "language": lang,
            })
    return pd.DataFrame(articles_data)

# 3. Dataset and Model Components
Core components for the narrative classification model: a custom tokenizing dataset
and a BCE-with-logits loss with per-batch positive-class weighting. Evaluation metrics are defined in section 5.

In [3]:
class NarrativeDataset(Dataset):
    """Map-style dataset that tokenizes one article per item for multi-label training.

    ``articles`` is indexed positionally (``.iloc``). Each row must provide a
    ``"text"`` column and a ``"<task_type>_labels"`` column holding that row's label
    vector. ``tokenizer`` is called with ``max_length``, ``padding="max_length"``,
    ``truncation=True`` and ``return_tensors="pt"``. It is assumed to return
    ``input_ids`` and ``attention_mask`` with a leading batch dimension of 1, which is
    squeezed away here (TODO confirm for non-HF tokenizers).
    """

    def __init__(self, articles, tokenizer, max_len, task_type='narrative'):
        self.articles = articles
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.task_type = task_type

    def __len__(self):
        return len(self.articles)

    def __getitem__(self, idx):
        row = self.articles.iloc[idx]
        label_column = f"{self.task_type}_labels"
        encoding = self.tokenizer(
            row["text"],
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        target = row[label_column]

        # Drop the tokenizer's batch-of-one dimension so the DataLoader can batch items.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": torch.tensor(target, dtype=torch.float32),
        }

class CustomBCEWithLogitsLoss(nn.Module):
    """BCE-with-logits loss whose positive-class weight is recomputed for every batch.

    For each label column the positive weight is ``(negatives + eps) / (positives + eps)``
    counted over the current batch's ``target``. When a static ``pos_weight`` is given,
    the batch ratio is multiplied by it. A column with no positives in the batch gets a
    huge ratio, but that weight only scales positive targets, so it has no effect there.

    Args:
        pos_weight: Optional per-label tensor multiplied into the batch ratio.
        reduction: Passed through to the BCE computation (``'mean'``, ``'sum'``, ``'none'``).
    """

    def __init__(self, pos_weight=None, reduction='mean'):
        super().__init__()
        self.pos_weight = pos_weight
        self.reduction = reduction

    def forward(self, logits, target):
        eps = 1e-7
        positives = target.sum(dim=0)
        negatives = target.size(0) - positives
        ratio = (negatives + eps) / (positives + eps)

        if self.pos_weight is None:
            combined_weight = ratio
        else:
            combined_weight = ratio * self.pos_weight

        return nn.functional.binary_cross_entropy_with_logits(
            logits,
            target,
            pos_weight=combined_weight,
            reduction=self.reduction,
        )


# 4. Data Preparation and Analysis
Functions for encoding labels, creating language-specific data splits, analyzing rare labels,
and handling class imbalance (class weights and a weighted sampler).

In [4]:
def process_label_data(df):
    """Fit multi-hot encoders on ``df`` and attach the encoded label columns.

    One ``MultiLabelBinarizer`` is fitted on the ``narratives`` column and one on the
    ``subnarratives`` column. Class weights for each level come from
    ``compute_class_weights``.

    Returns:
        ``(df_copy, narrative_weights, subnarrative_weights, narrative_classes,
        subnarrative_classes)``. ``df_copy`` is a copy of ``df`` with
        ``narrative_labels`` and ``subnarrative_labels`` columns (one multi-hot row
        each). The class lists give the column order of those vectors.
    """
    encoded = {}
    classes = {}
    for level in ("narrative", "subnarrative"):
        binarizer = MultiLabelBinarizer()
        encoded[level] = binarizer.fit_transform(df[f"{level}s"])
        classes[level] = list(binarizer.classes_)

    result = df.copy()
    for level, matrix in encoded.items():
        result[f"{level}_labels"] = list(matrix)

    return (
        result,
        compute_class_weights(encoded["narrative"]),
        compute_class_weights(encoded["subnarrative"]),
        classes["narrative"],
        classes["subnarrative"],
    )

def prepare_data_for_split(df, narrative_classes, subnarrative_classes):
    """
    Multi-hot encode a validation or test split against the training label vocabulary.

    Vector positions follow ``narrative_classes`` / ``subnarrative_classes``, so they
    line up with the training encoding. Labels that are not in those lists are
    ignored. Returns a copy of ``df`` with ``narrative_labels`` and
    ``subnarrative_labels`` columns added; ``df`` itself is not modified.
    """
    def _encode(column, vocabulary):
        # One 0/1 list per row: 1 where the vocabulary label occurs in that row's labels.
        return [
            [int(label in row_labels) for label in vocabulary]
            for row_labels in df[column]
        ]

    encoded_df = df.copy()
    encoded_df["narrative_labels"] = _encode("narratives", narrative_classes)
    encoded_df["subnarrative_labels"] = _encode("subnarratives", subnarrative_classes)
    return encoded_df


def create_language_specific_splits(df, test_size=0.1, val_size=0.1, train_target_ratio=0.6, random_state=42, target_lang='EN'):
    """
    Create train/val/test splits with a controlled amount of target-language data in training.

    The validation and test sets contain only ``target_lang`` rows. The training set
    contains every non-target row plus at most ``int(n_target * train_target_ratio)``
    of the target rows left after the test and validation hold-outs.

    Args:
        df: DataFrame with a ``language`` column.
        test_size, val_size: Fractions of the target-language rows held out for
            test / validation. Fractions that round down to 0 rows produce an empty split.
        train_target_ratio: Cap on the target-language share of training, as a fraction
            of all target-language rows.
        random_state: Seed for the hold-out splits, the cap sampling and the final shuffle.
        target_lang: Language code used for validation and test.

    Returns:
        ``(train_df, val_df, test_df)``. ``train_df`` is shuffled and has a fresh
        RangeIndex; ``val_df`` and ``test_df`` keep their original index.
    """
    def _hold_out(frame, n_samples):
        # train_test_split rejects test_size=0, so an empty hold-out is built directly.
        if n_samples <= 0:
            return frame, frame.iloc[0:0]
        return train_test_split(frame, test_size=n_samples, random_state=random_state)

    target_lang_df = df[df['language'] == target_lang]
    other_langs_df = df[df['language'] != target_lang]

    total_target = len(target_lang_df)
    test_samples = int(total_target * test_size)
    val_samples = int(total_target * val_size)
    train_samples = int(total_target * train_target_ratio)

    remaining_target_df, test_df = _hold_out(target_lang_df, test_samples)
    train_target_df, val_df = _hold_out(remaining_target_df, val_samples)

    # Cap (never grow) the target-language part of training at train_samples;
    # rows beyond the cap are left unused.
    if len(train_target_df) > train_samples:
        train_target_df = train_target_df.sample(n=train_samples, random_state=random_state)

    # Combine target-language training data with all other languages, then shuffle.
    train_df = pd.concat([other_langs_df, train_target_df])
    train_df = train_df.sample(frac=1, random_state=random_state).reset_index(drop=True)

    print("\nData split sizes:")
    print("Training set:")
    print(f"Total samples: {len(train_df)}")
    print("Language distribution:")
    print(train_df['language'].value_counts())
    print(f"Validation set ({target_lang} only):")
    print(f"Total samples: {len(val_df)}")
    print(f"Test set ({target_lang} only):")
    print(f"Total samples: {len(test_df)}")

    return train_df, val_df, test_df

def compute_class_weights(labels):
    """Return per-label negative/positive count ratios, clipped to [0.1, 10.0], as a float32 tensor.

    ``labels`` is a 2-D multi-hot matrix (rows are samples, columns are labels). A label
    with no positives hits the upper clip value (for non-empty input). A label with no
    negatives hits the lower clip value.
    """
    positive = np.sum(labels, axis=0)
    negative = len(labels) - positive
    ratio = negative / (positive + 1e-7)  # epsilon guards against labels with no positives
    return torch.FloatTensor(ratio.clip(0.1, 10.0))

def create_weighted_sampler(labels):
    """Create a ``WeightedRandomSampler`` that oversamples rows carrying rare labels.

    Each row's weight is the sum of the inverse frequencies of the labels it carries.
    Rows with no positive label would otherwise score 0 and never be drawn. They get
    the smallest weight of any labelled row instead, or uniform weights if no row is
    labelled.

    Args:
        labels: Sequence of equal-length multi-hot vectors (one per training row).

    Returns:
        ``WeightedRandomSampler`` drawing ``len(labels)`` indices with replacement.
    """
    label_matrix = np.asarray(labels, dtype=np.float64)
    label_counts = label_matrix.sum(axis=0)
    # Inverse label frequency. Labels never seen get 1.0, as before, but without
    # triggering a divide-by-zero warning.
    inverse_freq = np.divide(
        1.0, label_counts, out=np.ones_like(label_counts), where=label_counts > 0
    )
    sample_weights = np.sum(label_matrix * inverse_freq, axis=1)

    unlabeled = sample_weights <= 0
    if unlabeled.all():
        sample_weights = np.ones_like(sample_weights)
    elif unlabeled.any():
        sample_weights[unlabeled] = sample_weights[~unlabeled].min()

    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

def analyze_rare_labels(df):
    """Count label occurrences and report labels seen at most twice.

    Counts run over the per-row label lists in the ``narratives`` and ``subnarratives``
    columns. A short summary is printed.

    Returns:
        ``(rare_narratives, rare_subnarratives)``: dicts mapping each label that occurs
        at most twice to its count.
    """
    def _tally(column):
        # Occurrences of each label across all rows, in first-seen order.
        tally = {}
        flattened = (label for row_labels in df[column] for label in row_labels)
        for label in flattened:
            if label in tally:
                tally[label] += 1
            else:
                tally[label] = 1
        return tally

    narrative_counts = _tally("narratives")
    subnarrative_counts = _tally("subnarratives")

    rare_narratives = {label: n for label, n in narrative_counts.items() if n <= 2}
    rare_subnarratives = {label: n for label, n in subnarrative_counts.items() if n <= 2}

    summary = [
        "\nRare Label Analysis:",
        f"Total unique narratives: {len(narrative_counts)}",
        f"Rare narratives (<=2 occurrences): {len(rare_narratives)}",
        f"Total unique subnarratives: {len(subnarrative_counts)}",
        f"Rare subnarratives (<=2 occurrences): {len(rare_subnarratives)}",
    ]
    for line in summary:
        print(line)

    return rare_narratives, rare_subnarratives

# 5. Model Training and Evaluation
Functions for model training, threshold optimization, and performance evaluation.

In [5]:
def _train_one_epoch(model, train_loader, optimizer, criterion, device, desc):
    """Run one optimization pass over ``train_loader`` and return the mean batch loss."""
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    progress_bar = tqdm.tqdm(train_loader, desc=desc)

    for batch in progress_bar:
        optimizer.zero_grad()
        # Labels are kept out of the model call: the transformers model would otherwise
        # compute its own (unused) loss. The custom criterion is what is optimized.
        labels = batch['labels'].to(device)
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        outputs = model(**inputs)
        loss = criterion(outputs.logits, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1
        progress_bar.set_postfix({"Loss": f"{loss.item():.4f}"})

    return epoch_loss / max(num_batches, 1)


def train_model_improved(model, train_loader, val_loader, device, task_type, narrative_classes, subnarrative_classes, class_weights, num_epochs=5):
    """Fine-tune ``model`` for multi-label classification with early stopping on validation macro F1.

    Each epoch trains once over ``train_loader`` with ``CustomBCEWithLogitsLoss``
    (``class_weights`` multiplied by per-batch ratios). It then re-tunes the decision
    threshold on ``val_loader`` and computes validation macro F1 at that threshold.
    That F1 drives the LR scheduler, checkpointing to ``best_{task_type}_model.pt``,
    and early stopping (patience 5).

    Args:
        model: Classifier whose output exposes ``.logits``.
        train_loader, val_loader: Loaders yielding dicts with ``labels`` plus model inputs.
        device: Device the model lives on.
        task_type: ``'narrative'`` or ``'subnarrative'``. It selects the class names
            used in reports and names the checkpoint file.
        narrative_classes, subnarrative_classes: Class names for each level.
        class_weights: Per-label tensor passed as the loss's static ``pos_weight``.
        num_epochs: Maximum number of epochs.

    Returns:
        ``(model, threshold)``: the model with the best checkpoint's weights restored,
        and the threshold tuned for that checkpoint. If no epoch ran, the model is
        returned unchanged with threshold 0.5.
    """
    optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
    criterion = CustomBCEWithLogitsLoss(pos_weight=class_weights.to(device))
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
    class_labels = narrative_classes if task_type == 'narrative' else subnarrative_classes
    checkpoint_path = f'best_{task_type}_model.pt'

    # -inf so that the first evaluated epoch always writes a checkpoint, even at F1 == 0.
    best_f1 = float('-inf')
    best_threshold = 0.5
    checkpoint_saved = False
    patience = 5
    patience_counter = 0
    thresholds = np.arange(0.1, 0.9, 0.1)

    for epoch in range(num_epochs):
        mean_loss = _train_one_epoch(
            model, train_loader, optimizer, criterion, device,
            desc=f"Training {task_type} model - Epoch {epoch + 1}/{num_epochs}",
        )
        print(f"Epoch {epoch + 1} mean training loss: {mean_loss:.4f}")

        epoch_threshold = optimize_threshold(model, val_loader, device, thresholds)
        y_pred, y_true = get_predictions(model, val_loader, device, threshold=epoch_threshold)
        val_metrics = evaluate_model(y_pred, y_true, class_labels, print_report=True)
        val_f1 = val_metrics['Macro F1']

        scheduler.step(val_f1)

        if val_f1 > best_f1:
            best_f1 = val_f1
            # Plain floats keep the checkpoint loadable with torch.load(weights_only=True).
            best_threshold = float(epoch_threshold)
            patience_counter = 0
            torch.save({
                'model_state_dict': model.state_dict(),
                'best_threshold': best_threshold,
                'best_f1': float(best_f1)
            }, checkpoint_path)
            checkpoint_saved = True
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after epoch {epoch + 1}")
                break

    if checkpoint_saved:
        # The loop leaves the *last* epoch's weights in `model`; restore the best ones so
        # the returned weights match the returned threshold.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

    return model, best_threshold

def get_predictions(model, data_loader, device, threshold=0.3):
    """Run ``model`` over ``data_loader`` and return thresholded predictions with the gold labels.

    Args:
        model: Module whose output exposes ``.logits``. It is switched to eval mode.
        data_loader: Iterable of dicts holding ``labels`` plus the model's keyword inputs.
        device: Device that the model inputs are moved to.
        threshold: A label is predicted when ``sigmoid(logit) > threshold``.

    Returns:
        ``(predictions, labels)`` as NumPy arrays concatenated over all batches.
        Predictions are int32 0/1 values; labels keep the loader's dtype.
    """
    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in data_loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}
            # Labels are only needed on the host, so skip the device round-trip.
            all_labels.append(batch["labels"].cpu())

            outputs = model(**inputs)
            probs = torch.sigmoid(outputs.logits)
            all_predictions.append((probs > threshold).int().cpu())

    return torch.cat(all_predictions, dim=0).numpy(), torch.cat(all_labels, dim=0).numpy()

def evaluate_model(y_pred, y_true, class_labels, print_report=False):
    """Compute multi-label metrics and optionally print a per-class report.

    Args:
        y_pred: ``(n_samples, n_classes)`` 0/1 prediction matrix.
        y_true: ``(n_samples, n_classes)`` 0/1 ground-truth matrix.
        class_labels: Class names indexed by column.
        print_report: If True, print sklearn's classification report restricted to
            classes with at least one positive in ``y_true``.

    Returns:
        dict with "Hamming Loss", "Macro F1", "Micro F1" and "Subset Accuracy"
        (the fraction of rows whose whole label vector is predicted exactly).
    """
    metrics = {
        "Hamming Loss": hamming_loss(y_true, y_pred),
        "Macro F1": f1_score(y_true, y_pred, average='macro', zero_division=0),
        "Micro F1": f1_score(y_true, y_pred, average='micro', zero_division=0),
        "Subset Accuracy": (y_true == y_pred).all(axis=1).mean()
    }

    if print_report:
        active_classes = np.where(y_true.sum(axis=0) > 0)[0]
        if len(active_classes) == 0:
            print("\nClassification Report: no class has a positive label in y_true; skipped.")
        else:
            active_labels = [class_labels[i] for i in active_classes]
            print("\nClassification Report (Active Classes Only):")
            if y_true.shape[1] > 1:
                # Select columns with `labels=` on the full indicator matrices. Slicing down
                # to one active column would make sklearn see a binary target whose two
                # classes (0 and 1) do not match the single target name.
                report = classification_report(
                    y_true,
                    y_pred,
                    labels=active_classes.tolist(),
                    target_names=active_labels,
                    digits=2,
                    zero_division=0
                )
            else:
                # Only one class overall: a plain binary problem, so report its positive class.
                report = classification_report(
                    y_true.ravel(),
                    y_pred.ravel(),
                    labels=[1],
                    target_names=active_labels,
                    digits=2,
                    zero_division=0
                )
            print(report)

    return metrics

def optimize_threshold(model, val_loader, device, thresholds):
    """Pick the single decision threshold that maximizes macro F1 on a validation set.

    The same threshold is applied to every class. Ties keep the earlier candidate in
    ``thresholds``. If no candidate reaches a macro F1 above 0, 0.5 is returned.

    Args:
        model: Module whose output exposes ``.logits``. It is switched to eval mode.
        val_loader: Iterable of dicts holding ``labels`` plus the model's keyword inputs.
        device: Device that the model inputs are moved to.
        thresholds: Candidate thresholds applied to ``sigmoid(logits)``.

    Returns:
        The best threshold (an element of ``thresholds``, or 0.5).
    """
    model.eval()
    best_f1 = 0
    best_threshold = 0.5

    prob_batches = []
    label_batches = []

    with torch.no_grad():
        for batch in val_loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            outputs = model(**inputs)
            prob_batches.append(torch.sigmoid(outputs.logits).cpu())
            label_batches.append(batch['labels'].cpu())

    probs = torch.cat(prob_batches, dim=0)
    # Labels do not depend on the threshold, so they are converted to NumPy only once.
    labels = torch.cat(label_batches, dim=0).numpy()

    for threshold in thresholds:
        preds = (probs > threshold).float().numpy()
        # zero_division=0 matches evaluate_model and avoids an UndefinedMetricWarning for
        # every empty class at every candidate; such classes already score 0 by default.
        f1 = f1_score(labels, preds, average='macro', zero_division=0)

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    return best_threshold

def evaluate_global_performance(narrative_model, subnarrative_model, val_narrative_loader, val_subnarrative_loader, device, narrative_threshold=0.3, subnarrative_threshold=0.3):
    """Predict with both models and return the combined metrics from ``compute_global_score``.

    Each model runs over its own loader with its own threshold, narrative level first.
    The result uses ``compute_global_score``'s default level weights.
    """
    per_level = [
        get_predictions(narrative_model, val_narrative_loader, device, narrative_threshold),
        get_predictions(subnarrative_model, val_subnarrative_loader, device, subnarrative_threshold),
    ]
    (narrative_preds, narrative_true), (subnarrative_preds, subnarrative_true) = per_level

    return compute_global_score(
        narrative_preds,
        narrative_true,
        subnarrative_preds,
        subnarrative_true,
    )

def compute_global_score(narrative_preds, narrative_true, subnarrative_preds, subnarrative_true, narrative_weight=0.4, subnarrative_weight=0.6):
    """
    Score the narrative and subnarrative levels separately, then combine them.

    Each level reports Hamming loss, macro/micro precision, recall and F1 (with
    ``zero_division=0``), and subset accuracy. Each ``global_<metric>`` equals
    ``narrative_weight * narrative + subnarrative_weight * subnarrative``.
    ``hierarchical_accuracy`` is the fraction of rows where both label vectors are
    predicted exactly. This assumes row i refers to the same article at both levels.
    """

    def compute_metrics(preds, true):
        # Hamming, macro/micro P/R/F1, then subset accuracy (this order fixes print order).
        level = {"hamming": hamming_loss(true, preds)}
        for average in ("macro", "micro"):
            precision, recall, f1, _ = precision_recall_fscore_support(
                true, preds, average=average, zero_division=0
            )
            level[f"{average}_precision"] = precision
            level[f"{average}_recall"] = recall
            level[f"{average}_f1"] = f1
        level["subset_acc"] = (true == preds).all(axis=1).mean()
        return level

    narrative_metrics = compute_metrics(narrative_preds, narrative_true)
    subnarrative_metrics = compute_metrics(subnarrative_preds, subnarrative_true)

    global_metrics = {
        f"global_{name}": narrative_weight * value + subnarrative_weight * subnarrative_metrics[name]
        for name, value in narrative_metrics.items()
    }

    narrative_exact = (narrative_true == narrative_preds).all(axis=1)
    subnarrative_exact = (subnarrative_true == subnarrative_preds).all(axis=1)

    return {
        "narrative_metrics": narrative_metrics,
        "subnarrative_metrics": subnarrative_metrics,
        "global_metrics": global_metrics,
        "hierarchical_accuracy": (narrative_exact & subnarrative_exact).mean(),
    }

def print_all_metrics(metrics_dict):
    """Print every nested metric group, then the hierarchical accuracy (4 decimals each).

    Keys are shown title-cased with underscores replaced by spaces.
    """
    def _pretty(name):
        return name.replace('_', ' ').title()

    for level, metrics in metrics_dict.items():
        if level == "hierarchical_accuracy":
            continue
        print(f"\n{_pretty(level)}:")
        for metric, value in metrics.items():
            print(f"{_pretty(metric)}: {value:.4f}")

    print(f"\nHierarchical Accuracy: {metrics_dict['hierarchical_accuracy']:.4f}")

# 6. Main Execution Flow

In [6]:

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pinned host memory only speeds up host->GPU copies; on a CPU-only run it just warns.
pin_memory = device.type == "cuda"
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

# Extract and load data
extract_datasets()
labels_df, all_narratives, all_subnarratives = load_and_map_labels(LABELS_PATH)

# Process and prepare data
article_ids = labels_df["article_id"]
df = map_input_to_label_with_lang(INPUTS_PATH, article_ids, labels_df)

# Remove excess "Other" labels: randomly drop 70% of the rows whose narratives AND
# subnarratives both contain a label with the substring "Other".
other_df = df[
    df["narratives"].apply(lambda x: any("Other" in item for item in x)) &
    df["subnarratives"].apply(lambda x: any("Other" in item for item in x))
].sample(frac=0.7, random_state=42)
df = df.drop(other_df.index)

# Analyze labels and create splits
rare_narratives, rare_subnarratives = analyze_rare_labels(df)
train_df, val_df, test_df = create_language_specific_splits(
    df, test_size=0.1, val_size=0.1, train_target_ratio=0.6,
    random_state=42, target_lang='EN'
)

# Fit label encoders and class weights on the training split only;
# the validation and test splits are encoded against the training classes.
train_data, narrative_weights, subnarrative_weights, narrative_classes, subnarrative_classes = process_label_data(
    train_df)
val_data = prepare_data_for_split(val_df, narrative_classes, subnarrative_classes)
test_data = prepare_data_for_split(test_df, narrative_classes, subnarrative_classes)

# Initialize models
narrative_model = BertForSequenceClassification.from_pretrained(
    "bert-base-multilingual-cased",
    num_labels=len(narrative_classes),
    problem_type="multi_label_classification"
).to(device)

subnarrative_model = BertForSequenceClassification.from_pretrained(
    "bert-base-multilingual-cased",
    num_labels=len(subnarrative_classes),
    problem_type="multi_label_classification"
).to(device)

narrative_test_loader = None
subnarrative_test_loader = None
narrative_threshold = None
subnarrative_threshold = None

# Create data loaders and train models
for task_type, model, weights, classes in [
    ('narrative', narrative_model, narrative_weights, narrative_classes),
    ('subnarrative', subnarrative_model, subnarrative_weights, subnarrative_classes)
]:
    print(f"\nTraining {task_type} model...")
    train_dataset = NarrativeDataset(train_data, tokenizer, max_len=512, task_type=task_type)
    val_dataset = NarrativeDataset(val_data, tokenizer, max_len=512, task_type=task_type)
    test_dataset = NarrativeDataset(test_data, tokenizer, max_len=512, task_type=task_type)

    sampler = create_weighted_sampler(train_data[f"{task_type}_labels"].tolist())
    train_loader = DataLoader(train_dataset, batch_size=16, sampler=sampler, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, pin_memory=pin_memory)

    if task_type == 'narrative':
        narrative_test_loader = test_loader
    else:
        subnarrative_test_loader = test_loader

    model, threshold = train_model_improved(
        model, train_loader, val_loader, device,
        task_type, narrative_classes, subnarrative_classes,
        weights, num_epochs=10
    )

    if task_type == 'narrative':
        narrative_threshold = threshold
    else:
        subnarrative_threshold = threshold

print("\nGenerating final predictions and computing metrics...")
# Test-set evaluation of both models through the shared helper
# (get_predictions for each level, then compute_global_score).
scores = evaluate_global_performance(
    narrative_model, subnarrative_model,
    narrative_test_loader, subnarrative_test_loader,
    device,
    narrative_threshold=narrative_threshold,
    subnarrative_threshold=subnarrative_threshold,
)

print_all_metrics(scores)


Rare Label Analysis:
Total unique narratives: 22
Rare narratives (<=2 occurrences): 0
Total unique subnarratives: 92
Rare subnarratives (<=2 occurrences): 12

Data split sizes:
Total samples: 683
Language distribution:
language
PT    384
EN    166
RU    133
Name: count, dtype: int64
Validation set (EN only):
Total samples: 27
Test set (EN only):
Total samples: 27


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


RuntimeError: CUDA error: CUDA-capable device(s) is/are busy or unavailable
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
