In [1]:
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
import os
import tqdm
import zipfile
from conllu import parse
from torch.utils.data.dataset import Dataset
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.optim import AdamW
from sklearn.metrics import hamming_loss, f1_score, classification_report
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Dataset configuration. Every path below is relative, so it is resolved
# against the working directory the notebook kernel was started in.

# Names of the per-language sub-folders that are loaded
# (presumably English, Bulgarian, Portuguese, Russian -- confirm against the release).
TARGET_LANG = ['EN', 'BG', 'PT', 'RU']
# Folder holding the raw release; extract_datasets() unpacks RAW_DATASET_PATH + '.zip' into it.
RAW_DATASET_PATH = '../data/raw/target_4_December_release'
# Folder holding the preprocessed articles; map_input_to_label() looks for '.conllu' files below it.
PREPROCESSED_DATASET_PATH = '../data/preprocessed/preprocessed_target_4_December_release'
# One annotation file per language: <raw>/<lang>/subtask-2-annotations.txt
LABELS_PATH = [os.path.join(RAW_DATASET_PATH, lang, 'subtask-2-annotations.txt') for lang in TARGET_LANG]
# One article directory per language: <preprocessed>/<lang>
INPUTS_PATH = [os.path.join(PREPROCESSED_DATASET_PATH, lang) for lang in TARGET_LANG]

In [3]:
# Extract datasets if needed
from sklearn.preprocessing import MultiLabelBinarizer


def extract_datasets():
    """Unpack the raw and preprocessed dataset archives when they are missing.

    For each dataset directory that does not exist yet, the sibling
    ``<directory>.zip`` archive is extracted into that directory. Existing
    directories are left untouched, so the function is safe to call repeatedly.
    """
    # (target directory, archive password); ``None`` means the archive is not encrypted.
    # NOTE(review): the archive password is hard-coded here; consider reading it
    # from an environment variable instead of committing it with the notebook.
    archives = (
        (RAW_DATASET_PATH, b'narratives5202trainTHREE'),
        (PREPROCESSED_DATASET_PATH, None),
    )

    for target_dir, password in archives:
        if os.path.exists(target_dir):
            continue
        with zipfile.ZipFile(target_dir + '.zip', 'r') as archive:
            archive.extractall(target_dir, pwd=password)

def load_and_map_labels(label_file_paths: list[str]):
    """Read annotation files and collect per-article labels plus label vocabularies.

    Each file is parsed as a header-less, tab-separated table with the columns
    ``article_id``, ``narratives`` and ``subnarratives``; the two label columns
    hold ``;``-separated label names.

    Args:
        label_file_paths: Paths of the annotation files to read.

    Returns:
        A tuple ``(labels_df, all_narratives, all_subnarratives)`` where
        ``labels_df`` has one row per annotated article (label columns are
        lists of strings) and the two vocabularies are sorted lists of the
        distinct label names, with the empty string excluded.
    """
    def _split_cell(cell):
        # A missing cell (NaN) means the article carries no label of that kind.
        return cell.split(";") if pd.notna(cell) else []

    records = []
    narrative_vocab = set()
    subnarrative_vocab = set()

    for path in label_file_paths:
        frame = pd.read_csv(
            path,
            sep="\t",
            header=None,
            names=["article_id", "narratives", "subnarratives"],
        )

        for _, row in frame.iterrows():
            narratives = _split_cell(row["narratives"])
            subnarratives = _split_cell(row["subnarratives"])

            narrative_vocab |= set(narratives)
            subnarrative_vocab |= set(subnarratives)

            records.append({
                "article_id": row["article_id"],
                "narratives": narratives,
                "subnarratives": subnarratives,
            })

    # Empty names (e.g. from a trailing ';') are not real labels; the sort
    # gives a stable label order across runs.
    narrative_vocab.discard('')
    subnarrative_vocab.discard('')

    return pd.DataFrame(records), sorted(narrative_vocab), sorted(subnarrative_vocab)

def parse_conllu_file(file_path):
    """Read a CoNLL-U file and return its token forms joined by single spaces.

    Args:
        file_path: Path of the UTF-8 encoded CoNLL-U file.

    Returns:
        One string containing the ``form`` of every token of every parsed
        token list, in file order, separated by a single space.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        sentences = parse(handle.read())

    forms = []
    for sentence in sentences:
        for token in sentence:
            forms.append(token["form"])
    return " ".join(forms)

def map_input_to_label(articles_paths: list[str], article_ids: list[str], labels: pd.DataFrame):
    """Join article texts found on disk with their annotation rows.

    For every directory in ``articles_paths`` and every id in ``article_ids``
    the file ``<directory>/<id with '.txt' replaced by '.conllu'>`` is looked
    up. Only ids that have both an existing file and a row in ``labels``
    produce an output record.

    Args:
        articles_paths: Directories to search for ``.conllu`` files.
        article_ids: Article identifiers to look up.
        labels: Frame with the columns ``article_id``, ``narratives`` and
            ``subnarratives``.

    Returns:
        A DataFrame with the columns ``article_id``, ``text``, ``narratives``
        and ``subnarratives``; it is empty (no columns) when nothing matched.
    """
    labels_by_id = labels.set_index("article_id")

    rows = []
    for directory in articles_paths:
        for article_id in article_ids:
            conllu_name = article_id.replace('.txt', '.conllu')
            candidate = os.path.join(directory, conllu_name)

            # Skip ids without a preprocessed file or without an annotation row.
            if not os.path.exists(candidate):
                continue
            if article_id not in labels_by_id.index:
                continue

            text = parse_conllu_file(candidate)
            annotation = labels_by_id.loc[article_id]
            rows.append({
                "article_id": article_id,
                "text": text,
                "narratives": annotation["narratives"],
                "subnarratives": annotation["subnarratives"],
            })

    return pd.DataFrame(rows)

class NarrativeDataset(Dataset):
    """Torch dataset that tokenizes article text and exposes multi-hot targets.

    Args:
        articles: DataFrame with a ``text`` column and the label columns
            ``narrative_labels`` / ``subnarrative_labels``.
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            padding=..., truncation=..., return_tensors=...)`` that returns a
            mapping with ``input_ids`` and ``attention_mask`` tensors.
        max_len: Sequence length every article is padded / truncated to.
        task_type: ``'narrative'`` selects ``narrative_labels``; any other
            value selects ``subnarrative_labels``.
    """

    def __init__(self, articles, tokenizer, max_len, task_type='narrative'):
        self.articles = articles
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.task_type = task_type

    def __len__(self):
        """Return the number of articles."""
        return len(self.articles)

    def __getitem__(self, idx):
        """Return the tokenized article at position ``idx`` with its target vector."""
        row = self.articles.iloc[idx]
        encoded = self.tokenizer(
            row["text"],
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        label_column = "narrative_labels" if self.task_type == 'narrative' else "subnarrative_labels"
        target = torch.tensor(row[label_column], dtype=torch.float32)

        # squeeze(0) drops the leading dimension of the tokenizer's tensors.
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": target
        }

def get_predictions(model, data_loader, device, threshold=0.3):
    """Run inference over a loader and return thresholded predictions with targets.

    Args:
        model: Module whose call returns an object exposing ``.logits``.
        data_loader: Iterable of batches; each batch is a mapping of tensors
            that contains a ``labels`` entry next to the model inputs.
        device: Device the batch tensors are moved to before the forward pass.
        threshold: A label is predicted when its sigmoid probability is
            strictly greater than this value.

    Returns:
        ``(predictions, labels)`` as numpy arrays stacked over all batches.
    """
    model.eval()
    prediction_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch in data_loader:
            # Everything except the targets is forwarded to the model.
            features = {}
            for name, tensor in batch.items():
                if name == "labels":
                    continue
                features[name] = tensor.to(device)
            targets = batch["labels"].to(device)

            probabilities = torch.sigmoid(model(**features).logits)
            decisions = (probabilities > threshold).int()

            prediction_chunks.append(decisions.cpu())
            label_chunks.append(targets.cpu())

    predictions = torch.cat(prediction_chunks, dim=0).numpy()
    references = torch.cat(label_chunks, dim=0).numpy()
    return predictions, references

def evaluate_model(y_pred, y_true, class_labels, print_report=False):
    """Score multi-label predictions with several complementary metrics.

    Args:
        y_pred: Binary prediction matrix.
        y_true: Binary ground-truth matrix of the same shape.
        class_labels: Label names used as ``target_names`` in the report.
        print_report: When true, also print the per-class classification report.

    Returns:
        Dict with the keys ``"Hamming Loss"``, ``"Macro F1"``, ``"Micro F1"``
        and ``"Subset Accuracy"`` (fraction of rows predicted exactly right).
    """
    metrics = {
        "Hamming Loss": hamming_loss(y_true, y_pred),
        "Macro F1": f1_score(y_true, y_pred, average='macro', zero_division=0),
        "Micro F1": f1_score(y_true, y_pred, average='micro', zero_division=0),
        # A row only counts when every label in it matches.
        "Subset Accuracy": (y_true == y_pred).all(axis=1).mean(),
    }

    if print_report:
        per_class_report = classification_report(
            y_true, y_pred, target_names=class_labels, digits=2, zero_division=0
        )
        print("\nClassification Report:\n")
        print(per_class_report)

    return metrics

def create_weighted_sampler(labels):
    """Create a weighted sampler that counteracts label imbalance.

    Every label gets the weight ``1 / (number of samples carrying it)``, and a
    sample's weight is the sum of the weights of its labels, so samples with
    rare labels are drawn more often.

    Samples without any positive label are treated as one extra pseudo-class
    and share the weight ``1 / (number of such samples)``. Without this they
    would get weight 0 and never be drawn, and a label matrix consisting only
    of such rows would make the sampler fail.

    Args:
        labels: 2-D binary label matrix (array-like, samples x labels).

    Returns:
        A ``WeightedRandomSampler`` drawing ``len(labels)`` samples with
        replacement.
    """
    label_matrix = np.asarray(labels, dtype=np.float64)
    label_counts = label_matrix.sum(axis=0)

    # Inverse frequency per label. Labels that never occur keep the neutral
    # weight 1.0 (they cannot contribute to any sample anyway); the masked
    # division avoids a divide-by-zero RuntimeWarning.
    weights = np.ones_like(label_counts)
    np.divide(1.0, label_counts, out=weights, where=label_counts > 0)

    sample_weights = (label_matrix * weights).sum(axis=1)

    unlabeled = sample_weights == 0
    if unlabeled.any():
        sample_weights[unlabeled] = 1.0 / unlabeled.sum()

    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

def prepare_data(df, all_narratives, all_subnarratives):
    """Attach multi-hot label vectors for both tasks to ``df``.

    Args:
        df: Frame with the list-valued columns ``narratives`` and
            ``subnarratives``; it is modified in place.
        all_narratives: Ordered narrative vocabulary; position ``i`` of each
            vector in ``narrative_labels`` refers to ``all_narratives[i]``.
        all_subnarratives: Ordered subnarrative vocabulary, used the same way
            for ``subnarrative_labels``.

    Returns:
        The same ``df`` object with the columns ``narrative_labels`` and
        ``subnarrative_labels`` added.
    """
    def _multi_hot(label_lists, vocabulary):
        # One 0/1 entry per vocabulary item, in vocabulary order.
        return label_lists.apply(
            lambda assigned: [int(name in assigned) for name in vocabulary]
        ).tolist()

    # Both encodings are computed before the frame is touched.
    narrative_vectors = _multi_hot(df["narratives"], all_narratives)
    subnarrative_vectors = _multi_hot(df["subnarratives"], all_subnarratives)

    df["narrative_labels"] = narrative_vectors
    df["subnarrative_labels"] = subnarrative_vectors

    return df

class CustomBCEWithLogitsLoss(nn.Module):
    """Binary cross-entropy on logits with per-batch positive-class weighting.

    For every label the positive weight is the ratio of negative to positive
    targets in the current batch, optionally multiplied by a fixed
    ``pos_weight``.

    Args:
        pos_weight: Optional tensor multiplied onto the batch-derived weights.
            It is stored as a plain attribute, so the caller has to place it on
            the right device.
        reduction: Reduction mode handed to the underlying BCE-with-logits loss.
    """

    def __init__(self, pos_weight=None, reduction='mean'):
        super().__init__()
        self.pos_weight = pos_weight
        self.reduction = reduction

    def forward(self, logits, target):
        """Return the weighted BCE-with-logits loss for one batch."""
        batch_size = target.size(0)
        positives = target.sum(dim=0)
        negatives = batch_size - positives

        # The small constant keeps the ratio finite for labels without
        # positives in this batch.
        smoothing = 1e-7
        dynamic_weight = (negatives + smoothing) / (positives + smoothing)

        if self.pos_weight is not None:
            dynamic_weight = dynamic_weight * self.pos_weight

        return nn.functional.binary_cross_entropy_with_logits(
            logits,
            target,
            pos_weight=dynamic_weight,
            reduction=self.reduction,
        )

def analyze_label_distribution(df, task_type='subnarrative'):
    """Print label distribution statistics and return the per-label counts.

    Printed statistics:
      * total number of samples,
      * average labels per sample,
      * label cardinality (mean number of labels per sample),
      * label density (cardinality divided by the number of labels),
      * up to five label pairs whose absolute Pearson correlation exceeds 0.5,
        strongest first.

    Args:
        df: Frame holding a ``"<task_type>_labels"`` column of equal-length
            binary label vectors.
        task_type: Prefix of the label column to analyse.

    Returns:
        Array with the number of positive samples per label.
    """
    labels = df[f"{task_type}_labels"].tolist()
    label_matrix = np.array(labels)
    label_sums = np.sum(label_matrix, axis=0)
    labels_per_sample = np.sum(label_matrix, axis=1)

    print(f"\n{task_type.capitalize()} Label Distribution:")
    print(f"Total samples: {len(labels)}")
    print(f"Average labels per sample: {np.mean(labels_per_sample):.2f}")
    # Cardinality is the mean number of labels per sample; the previous code
    # printed the mean per-label count here instead.
    print(f"Label cardinality: {np.mean(labels_per_sample):.2f}")
    print(f"Label density: {np.mean(label_sums)/len(labels):.4f}")

    # Constant label columns have zero variance and yield NaN correlations;
    # silence the resulting division warnings (NaN never passes the threshold).
    # atleast_2d covers the single-label case where corrcoef returns a scalar.
    with np.errstate(divide="ignore", invalid="ignore"):
        correlations = np.atleast_2d(np.corrcoef(label_matrix.T))

    high_corr_pairs = []
    for i in range(len(correlations)):
        for j in range(i+1, len(correlations)):
            if abs(correlations[i,j]) > 0.5:  # Threshold for high correlation
                high_corr_pairs.append((i, j, correlations[i,j]))

    if high_corr_pairs:
        # Strongest correlations first, so the five shown really are the top five.
        high_corr_pairs.sort(key=lambda pair: abs(pair[2]), reverse=True)
        print("\nHighly correlated label pairs:")
        for i, j, corr in high_corr_pairs[:5]:
            print(f"Labels {i} and {j}: {corr:.2f}")

    return label_sums

def prepare_improved_data(df, all_narratives, all_subnarratives):
    """Attach multi-hot label vectors to ``df`` and compute per-class weights.

    The label vectors are aligned with the supplied vocabularies: column ``i``
    of ``narrative_labels`` refers to ``all_narratives[i]`` (likewise for the
    subnarratives). This keeps their width and order consistent with code that
    sizes model heads or names report rows from those same lists, even when
    some labels no longer occur in ``df``. Previously the classes were inferred
    from ``df`` alone and the vocabulary arguments were ignored.

    Args:
        df: Frame with the list-valued columns ``narratives`` and
            ``subnarratives``; it is modified in place.
        all_narratives: Ordered narrative vocabulary. When ``None`` or empty
            the classes are inferred from ``df`` (the former behaviour).
        all_subnarratives: Ordered subnarrative vocabulary, handled the same way.

    Returns:
        ``(df, narrative_weights, subnarrative_weights)`` where the weights are
        the tensors returned by ``compute_class_weights``.
    """
    def _binarize(label_lists, vocabulary):
        # Empty names are not part of the vocabularies; dropping them up front
        # avoids "unknown class" warnings from the binarizer.
        cleaned = [[name for name in names if name != ''] for names in label_lists]
        has_vocabulary = vocabulary is not None and len(vocabulary) > 0
        binarizer = MultiLabelBinarizer(classes=list(vocabulary) if has_vocabulary else None)
        return binarizer.fit_transform(cleaned)

    # Create binary label matrices
    narrative_labels = _binarize(df["narratives"], all_narratives)
    subnarrative_labels = _binarize(df["subnarratives"], all_subnarratives)

    # Calculate class weights
    narrative_weights = compute_class_weights(narrative_labels)
    subnarrative_weights = compute_class_weights(subnarrative_labels)

    # Add labels to dataframe (one vector per row)
    df["narrative_labels"] = list(narrative_labels)
    df["subnarrative_labels"] = list(subnarrative_labels)

    return df, narrative_weights, subnarrative_weights

def compute_class_weights(labels):
    """Compute a clipped negative-to-positive ratio for every label.

    Args:
        labels: Binary label matrix (samples x labels).

    Returns:
        Float tensor with one weight per label, limited to ``[0.1, 10.0]``.
    """
    positives = np.sum(labels, axis=0)
    negatives = len(labels) - positives

    # The small constant keeps the ratio finite for labels without positives.
    ratio = negatives / (positives + 1e-7)

    # Bound the weights so very rare or very common labels cannot dominate.
    bounded = np.minimum(np.maximum(ratio, 0.1), 10.0)

    return torch.FloatTensor(bounded)



def optimize_threshold(model, val_loader, device, thresholds):
    """Pick the decision threshold that maximises macro F1 on a validation set.

    Args:
        model: Module whose call returns an object exposing ``.logits``.
        val_loader: Iterable of batches (mappings of tensors with a ``labels``
            entry next to the model inputs).
        device: Device the batch tensors are moved to.
        thresholds: Candidate thresholds, tried in the given order.

    Returns:
        The first candidate reaching the highest macro F1. ``0.5`` is returned
        when no candidate scores above zero or when ``val_loader`` is empty.
    """
    model.eval()
    best_f1 = 0
    best_threshold = 0.5

    all_logits = []
    all_labels = []

    with torch.no_grad():
        for batch in val_loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)
            outputs = model(**inputs)
            all_logits.append(outputs.logits)
            all_labels.append(labels)

    # Nothing to tune on: keep the default instead of failing in torch.cat.
    if not all_logits:
        return best_threshold

    probs = torch.sigmoid(torch.cat(all_logits, dim=0))
    # The targets do not depend on the threshold, so convert them only once.
    y_true = torch.cat(all_labels, dim=0).cpu().numpy()

    for threshold in thresholds:
        y_pred = (probs > threshold).float().cpu().numpy()
        # zero_division=0 matches evaluate_model and avoids a warning for every
        # label that has no predicted or true positives at this threshold.
        f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    return best_threshold


def train_model_improved(model, train_loader, val_loader, device, task_type, all_narratives,
                        all_subnarratives, class_weights, num_epochs=5, patience=5):
    """Fine-tune ``model`` for multi-label classification with early stopping.

    Each epoch trains on ``train_loader`` with ``CustomBCEWithLogitsLoss``,
    tunes the decision threshold on ``val_loader``, evaluates with that
    threshold and halves the learning rate when macro F1 plateaus. Whenever
    macro F1 improves, the weights are written to ``best_<task_type>_model.pt``.

    Args:
        model: Module whose call returns an object exposing ``.logits``.
        train_loader: Training batches (mappings of tensors incl. ``labels``).
        val_loader: Validation batches in the same format.
        device: Device the batches are moved to.
        task_type: ``'narrative'`` reports against ``all_narratives``; any
            other value reports against ``all_subnarratives``. It is also part
            of the checkpoint file name.
        all_narratives: Narrative label names for the classification report.
        all_subnarratives: Subnarrative label names for the report.
        class_weights: Per-label positive weights passed to the loss.
        num_epochs: Maximum number of epochs.
        patience: Number of consecutive epochs without a macro-F1 improvement
            after which training stops.

    Returns:
        ``(model, best_threshold)``: the model with the weights of the best
        validation epoch restored, and the threshold tuned in that epoch
        (``0.5`` when no epoch ran).
    """
    optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
    criterion = CustomBCEWithLogitsLoss(pos_weight=class_weights.to(device))
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
                                                         factor=0.5, patience=2)

    class_labels = all_narratives if task_type == 'narrative' else all_subnarratives
    checkpoint_path = f'best_{task_type}_model.pt'

    # Start below any reachable F1 so the first epoch always yields a checkpoint.
    best_f1 = float('-inf')
    best_threshold = 0.5
    patience_counter = 0
    checkpoint_saved = False

    # Candidate thresholds for the per-epoch threshold search
    thresholds = np.arange(0.1, 0.9, 0.1)

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        progress_bar = tqdm.tqdm(train_loader, desc="Training")

        for batch in progress_bar:
            optimizer.zero_grad()

            # Keep the targets out of the forward call: the loss is computed by
            # `criterion`, so the model does not need to build its own.
            labels = batch['labels'].to(device)
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            outputs = model(**inputs)
            loss = criterion(outputs.logits, labels)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1
            progress_bar.set_postfix({"Loss": f"{loss.item():.4f}"})

        # Find best threshold on validation set for this epoch
        epoch_threshold = optimize_threshold(model, val_loader, device, thresholds)

        y_pred, y_true = get_predictions(model, val_loader, device, threshold=epoch_threshold)
        val_metrics = evaluate_model(y_pred, y_true, class_labels, print_report=True)

        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        if num_batches:
            print(f"Average training loss: {epoch_loss / num_batches:.4f}")
        print(f"Best threshold: {epoch_threshold:.2f}")
        for metric, value in val_metrics.items():
            print(f"{metric}: {value:.4f}")

        # Update learning rate based on F1 score
        current_f1 = val_metrics['Macro F1']
        scheduler.step(current_f1)

        if current_f1 > best_f1:
            # Plain Python floats keep the checkpoint free of numpy scalars.
            best_f1 = float(current_f1)
            best_threshold = float(epoch_threshold)
            patience_counter = 0
            torch.save({
                'model_state_dict': model.state_dict(),
                'best_threshold': best_threshold,
                'best_f1': best_f1
            }, checkpoint_path)
            checkpoint_saved = True
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after epoch {epoch + 1}")
                break

    # Hand back the best epoch rather than whatever the last epoch produced.
    if checkpoint_saved:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

    return model, best_threshold

In [5]:
# Main execution: load the data, build loaders and models for both tasks, then
# train the narrative model followed by the subnarrative model.
# NOTE(review): inside a notebook kernel __name__ is normally "__main__", so
# this guard is expected to always pass -- confirm it is intentional.
if __name__ == "__main__":
    # Extract datasets (only unpacks archives whose target folder is missing)
    extract_datasets()
    
    # Load labels and get unique narratives and subnarratives
    labels_df, all_narratives, all_subnarratives = load_and_map_labels(LABELS_PATH)
    
    print(f"Found {len(all_narratives)} unique narratives and {len(all_subnarratives)} unique subnarratives")
    print("\nNarratives:", all_narratives)
    print("\nSubnarratives:", all_subnarratives)
    
    # Map inputs to labels: every id is probed in every language directory and
    # kept only where a matching .conllu file exists.
    article_ids = labels_df["article_id"]
    df = map_input_to_label(INPUTS_PATH, article_ids, labels_df)
    
    # Remove excess "Other" labels: randomly drop 70% (fixed seed) of the
    # articles that have an "Other" entry in both narratives and subnarratives.
    other_df = df[
        df["narratives"].apply(lambda x: any("Other" in item for item in x)) & 
        df["subnarratives"].apply(lambda x: any("Other" in item for item in x))
    ].sample(frac=0.7, random_state=42)
    df = df.drop(other_df.index)
    
    # Prepare data with improved handling (adds the *_labels columns and
    # returns the per-class weights used by the loss)
    df, narrative_weights, subnarrative_weights = prepare_improved_data(
        df, all_narratives, all_subnarratives)
    
    # Analyze label distribution
    print("\nAnalyzing label distribution...")
    narrative_label_sums = analyze_label_distribution(df, 'narrative')
    subnarrative_label_sums = analyze_label_distribution(df, 'subnarrative')
    
    # Split data 80/20, stratified by the number of narratives per article
    train_data, val_data = train_test_split(df, test_size=0.2, random_state=42, stratify=df["narratives"].apply(len))
    
    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
    
    # Create datasets for both tasks (same articles, different label column)
    train_narrative_dataset = NarrativeDataset(train_data, tokenizer, max_len=512, task_type='narrative')
    val_narrative_dataset = NarrativeDataset(val_data, tokenizer, max_len=512, task_type='narrative')
    
    train_subnarrative_dataset = NarrativeDataset(train_data, tokenizer, max_len=512, task_type='subnarrative')
    val_subnarrative_dataset = NarrativeDataset(val_data, tokenizer, max_len=512, task_type='subnarrative')
    
    # Create weighted samplers (built from the training split only)
    narrative_sampler = create_weighted_sampler(train_data["narrative_labels"].tolist())
    subnarrative_sampler = create_weighted_sampler(train_data["subnarrative_labels"].tolist())
    
    # Create data loaders with weighted sampling; the sampler takes the place
    # of shuffling for the training loaders.
    # NOTE(review): pin_memory=True is set before the device is chosen below;
    # confirm it is wanted when no GPU is available.
    train_narrative_loader = DataLoader(
        train_narrative_dataset, 
        batch_size=16,  # Reduced batch size
        sampler=narrative_sampler,
        pin_memory=True
    )
    
    train_subnarrative_loader = DataLoader(
        train_subnarrative_dataset,
        batch_size=16,
        sampler=subnarrative_sampler,
        pin_memory=True
    )
    
    # Validation loaders keep the original order.
    val_narrative_loader = DataLoader(val_narrative_dataset, batch_size=16, shuffle=False)
    val_subnarrative_loader = DataLoader(val_subnarrative_dataset, batch_size=16, shuffle=False)
    
    # Initialize models (both are created and moved to the device before any
    # training starts)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # NOTE(review): num_labels is taken from the vocabulary lists; this assumes
    # the label vectors produced by prepare_improved_data have the same width
    # and order -- confirm, otherwise logits and targets will not line up.
    narrative_model = BertForSequenceClassification.from_pretrained(
        "bert-base-multilingual-cased",
        num_labels=len(all_narratives)
    ).to(device)
    
    subnarrative_model = BertForSequenceClassification.from_pretrained(
        "bert-base-multilingual-cased",
        num_labels=len(all_subnarratives)
    ).to(device)
    
    # Train models with improvements; each call returns the model together
    # with the decision threshold selected on the validation set.
    print("\nTraining Narrative Model...")
    narrative_model, narrative_threshold = train_model_improved(
        narrative_model,
        train_narrative_loader,
        val_narrative_loader,
        device,
        'narrative',
        all_narratives,
        all_subnarratives,
        narrative_weights,
        num_epochs=5
    )
    
    print("\nTraining Subnarrative Model...")
    subnarrative_model, subnarrative_threshold = train_model_improved(
        subnarrative_model,
        train_subnarrative_loader,
        val_subnarrative_loader,
        device,
        'subnarrative',
        all_narratives,
        all_subnarratives,
        subnarrative_weights,
        num_epochs=5
    )

Found 22 unique narratives and 94 unique subnarratives

Narratives: ['CC: Amplifying Climate Fears', 'CC: Climate change is beneficial', 'CC: Controversy about green technologies', 'CC: Criticism of climate movement', 'CC: Criticism of climate policies', 'CC: Criticism of institutions and authorities', 'CC: Downplaying climate change', 'CC: Green policies are geopolitical instruments', 'CC: Hidden plots by secret schemes of powerful groups', 'CC: Questioning the measurements and science', 'Other', 'URW: Amplifying war-related fears', 'URW: Blaming the war on others rather than the invader', 'URW: Discrediting Ukraine', 'URW: Discrediting the West, Diplomacy', 'URW: Distrust towards Media', 'URW: Hidden plots by secret schemes of powerful groups', 'URW: Negative Consequences for the West', 'URW: Overpraising the West', 'URW: Praise of Russia', 'URW: Russia is the Victim', 'URW: Speculating war outcomes']

Subnarratives: ['CC: Amplifying Climate Fears: Amplifying existing fears of global

  return torch._C._cuda_getDeviceCount() > 0
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Training Narrative Model...


Training:   2%|▎         | 1/40 [00:15<09:47, 15.06s/it, Loss=4.5574]


KeyboardInterrupt: 