In [None]:
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss
import os
import tqdm
import zipfile
from conllu import parse
from torch.utils.data.dataset import Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import hamming_loss, f1_score, classification_report
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.optim import AdamW
import numpy as np

In [2]:
# Language codes covered by the experiment; each code is used as a sub-folder name below.
TARGET_LANG = ['EN', 'PT', 'RU']

# Root folders of the raw release and of its preprocessed counterpart.
RAW_DATASET_PATH = '../data/raw/target_4_December_release'
PREPROCESSED_DATASET_PATH = '../data/preprocessed/preprocessed_target_4_December_release'

# One annotation file per language (under the raw root) and one input folder
# per language (under the preprocessed root), in TARGET_LANG order.
LABELS_PATH = [
    os.path.join(RAW_DATASET_PATH, lang_code, 'subtask-2-annotations.txt')
    for lang_code in TARGET_LANG
]
INPUTS_PATH = [
    os.path.join(PREPROCESSED_DATASET_PATH, lang_code)
    for lang_code in TARGET_LANG
]

In [3]:
def extract_datasets():
    """Unpack the raw and preprocessed dataset archives when needed.

    For each dataset folder that does not exist yet, the sibling
    ``<folder>.zip`` archive is extracted into that folder. Only the raw
    archive is opened with a password; the preprocessed one is not.
    """
    archives = (
        (RAW_DATASET_PATH, b'narratives5202trainTHREE'),
        (PREPROCESSED_DATASET_PATH, None),
    )
    for target_dir, password in archives:
        # An existing folder is taken as "already extracted".
        if os.path.exists(target_dir):
            continue
        with zipfile.ZipFile(target_dir + '.zip', 'r') as archive:
            archive.extractall(target_dir, pwd=password)

def load_and_map_labels(label_file_paths: list[str]):
    """Read tab-separated annotation files and collect their labels.

    Each file is read without a header row as three columns
    (article id, narratives, subnarratives); the two label columns hold
    ';'-separated values and may be empty.

    Args:
        label_file_paths: Paths of the annotation files to read.

    Returns:
        A tuple ``(labels_df, all_narratives, all_subnarratives)`` where
        ``labels_df`` has one row per annotation line with the label cells
        split into lists, and the other two items are sorted lists of the
        unique non-empty narrative / subnarrative labels.
    """
    def _split_cell(cell):
        # A missing cell means "no labels" for that article.
        if pd.notna(cell):
            return cell.split(";")
        return []

    column_names = ["article_id", "narratives", "subnarratives"]
    records = []
    narrative_pool = set()
    subnarrative_pool = set()

    for path in label_file_paths:
        frame = pd.read_csv(path, sep="\t", header=None, names=column_names)

        for _, row in frame.iterrows():
            narratives = _split_cell(row["narratives"])
            subnarratives = _split_cell(row["subnarratives"])

            narrative_pool.update(narratives)
            subnarrative_pool.update(subnarratives)

            records.append({
                "article_id": row["article_id"],
                "narratives": narratives,
                "subnarratives": subnarratives,
            })

    # Sorting gives a stable label order; the empty string is not a label.
    all_narratives = sorted(narrative_pool - {''})
    all_subnarratives = sorted(subnarrative_pool - {''})

    return pd.DataFrame(records), all_narratives, all_subnarratives

def parse_conllu_file(file_path):
    """Read a CoNLL-U file and return its token forms as one string.

    The file is read as UTF-8, parsed into token lists, and the ``"form"``
    of every token is joined with single spaces in document order.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        raw_text = handle.read()

    forms = []
    for sentence in parse(raw_text):
        for token in sentence:
            forms.append(token["form"])
    return " ".join(forms)

def map_input_to_label_with_lang(articles_paths: list[str], article_ids: list[str], labels: pd.DataFrame):
    """Map input articles to their corresponding labels and add language information.

    Args:
        articles_paths: One folder per language; the folder's base name is
            used as the language code of every article found in it.
        article_ids: Article ids to look up. ``.txt`` in an id is replaced by
            ``.conllu`` to build the file name inside each folder.
        labels: DataFrame with ``article_id``, ``narratives`` and
            ``subnarratives`` columns.

    Returns:
        DataFrame with one row per (folder, article id) pair whose file exists
        and whose id is annotated, with columns ``article_id``, ``text``,
        ``narratives``, ``subnarratives`` and ``language``.
    """
    labels = labels.set_index("article_id")
    # Keep a single annotation row per id: with duplicates, ``.loc`` would
    # return a DataFrame and whole Series would end up in the label cells.
    labels = labels[~labels.index.duplicated(keep="first")]

    articles_data = []
    for articles_path in articles_paths:
        # normpath drops a trailing separator and normalises separators for the
        # current OS, so basename reliably yields the last path component.
        lang = os.path.basename(os.path.normpath(articles_path))

        for article_id in article_ids:
            file_path = os.path.join(articles_path, article_id.replace('.txt', '.conllu'))
            if not os.path.exists(file_path) or article_id not in labels.index:
                continue

            article_labels = labels.loc[article_id]
            articles_data.append({
                "article_id": article_id,
                "text": parse_conllu_file(file_path),
                "narratives": article_labels["narratives"],
                "subnarratives": article_labels["subnarratives"],
                "language": lang
            })
    return pd.DataFrame(articles_data)

class NarrativeDataset(Dataset):
    """Dataset yielding tokenized article text together with one label vector.

    ``articles`` is indexed with ``.iloc`` and must provide a ``"text"`` field
    plus ``"narrative_labels"`` / ``"subnarrative_labels"``. ``task_type``
    selects the label field: ``'narrative'`` picks the narrative labels, any
    other value picks the subnarrative labels.
    """

    def __init__(self, articles, tokenizer, max_len, task_type='narrative'):
        self.articles = articles
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.task_type = task_type

    def __len__(self):
        return len(self.articles)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and float32 ``labels`` for row ``idx``."""
        row = self.articles.iloc[idx]
        encoded = self.tokenizer(
            row["text"],
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        label_column = "narrative_labels" if self.task_type == 'narrative' else "subnarrative_labels"

        # The tokenizer output carries a leading dimension that is squeezed away here.
        sample = {
            field: encoded[field].squeeze(0)
            for field in ("input_ids", "attention_mask")
        }
        sample["labels"] = torch.tensor(row[label_column], dtype=torch.float32)
        return sample

def get_predictions(model, data_loader, device, threshold=0.3):
    """Run the model over a loader and return binary predictions and labels.

    Every batch entry except ``"labels"`` is moved to ``device`` and passed
    to the model as a keyword argument. Sigmoid probabilities strictly above
    ``threshold`` become 1, the rest 0.

    Returns:
        ``(predictions, labels)`` as NumPy arrays concatenated over batches.
    """
    model.eval()
    prediction_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch in data_loader:
            targets = batch["labels"].to(device)
            features = {
                name: tensor.to(device)
                for name, tensor in batch.items()
                if name != "labels"
            }

            probabilities = torch.sigmoid(model(**features).logits)
            binary = (probabilities > threshold).int()

            prediction_chunks.append(binary.cpu())
            label_chunks.append(targets.cpu())

    predictions = torch.cat(prediction_chunks, dim=0).numpy()
    labels = torch.cat(label_chunks, dim=0).numpy()
    return predictions, labels

def evaluate_model(y_pred, y_true, class_labels, print_report=False):
    """Score multilabel predictions and optionally print a per-class report.

    Args:
        y_pred: Binary prediction matrix.
        y_true: Binary ground-truth matrix with the same layout as ``y_pred``.
        class_labels: Label names indexed by column position.
        print_report: When true, print a classification report restricted to
            the columns that have at least one positive in ``y_true``.

    Returns:
        Dict with ``"Hamming Loss"``, ``"Macro F1"``, ``"Micro F1"`` and
        ``"Subset Accuracy"`` (share of rows predicted exactly right).
    """
    metrics = {
        "Hamming Loss": hamming_loss(y_true, y_pred),
        # zero_division=0 scores classes without support as 0 instead of warning
        "Macro F1": f1_score(y_true, y_pred, average='macro', zero_division=0),
        "Micro F1": f1_score(y_true, y_pred, average='micro', zero_division=0),
        "Subset Accuracy": (y_true == y_pred).all(axis=1).mean(),
    }

    if not print_report:
        return metrics

    # Restrict the report to the columns that actually occur in the ground truth.
    active_columns = np.where(y_true.sum(axis=0) > 0)[0]
    active_names = [class_labels[col] for col in active_columns]

    report = classification_report(
        y_true[:, active_columns],
        y_pred[:, active_columns],
        target_names=active_names,
        digits=2,
        zero_division=0
    )
    print("\nClassification Report (Active Classes Only):\n")
    print(report)

    return metrics

def create_weighted_sampler(labels):
    """Create a weighted sampler to handle class imbalance.

    Each class gets the weight ``1 / count`` (classes that never occur get
    1.0), and each sample's weight is the sum of the weights of its positive
    classes.

    Args:
        labels: Binary label matrix (samples x classes), as an array or a
            list of rows.

    Returns:
        ``WeightedRandomSampler`` drawing ``len(labels)`` indices with
        replacement.
    """
    label_matrix = np.asarray(labels, dtype=np.float64)
    label_counts = label_matrix.sum(axis=0)

    # Divide by 1 for unused classes: same resulting weight (1.0) as before,
    # but without the divide-by-zero RuntimeWarning.
    safe_counts = np.where(label_counts > 0, label_counts, 1.0)
    class_weights = 1.0 / safe_counts

    sample_weights = (label_matrix * class_weights).sum(axis=1)

    # A row without any positive label would get weight 0 and never be drawn
    # (and sampling fails outright if every weight is 0). Give such rows the
    # smallest weight seen among labelled rows so they stay reachable.
    unlabeled = sample_weights <= 0
    if unlabeled.any():
        labeled_weights = sample_weights[~unlabeled]
        fallback = labeled_weights.min() if labeled_weights.size else 1.0
        sample_weights[unlabeled] = fallback

    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

def prepare_data(df, all_narratives, all_subnarratives):
    """Add binary label vectors for narratives and subnarratives to ``df``.

    For every row, ``narrative_labels`` holds one 0/1 entry per item of
    ``all_narratives`` (1 when the item is in the row's ``narratives``), and
    ``subnarrative_labels`` does the same for ``all_subnarratives``.
    The frame is modified in place and also returned.
    """
    def _encode(row_labels, vocabulary):
        return [int(name in row_labels) for name in vocabulary]

    df["narrative_labels"] = [
        _encode(row_labels, all_narratives) for row_labels in df["narratives"]
    ]
    df["subnarrative_labels"] = [
        _encode(row_labels, all_subnarratives) for row_labels in df["subnarratives"]
    ]
    return df

class CustomBCEWithLogitsLoss(nn.Module):
    """BCE-with-logits loss whose positive weights are recomputed per batch.

    For every class the positive weight is the batch's negative/positive
    count ratio (smoothed with a small epsilon), optionally multiplied by a
    fixed ``pos_weight`` supplied at construction time.
    """

    def __init__(self, pos_weight=None, reduction='mean'):
        super().__init__()
        self.pos_weight = pos_weight
        self.reduction = reduction

    def forward(self, logits, target):
        """Return the weighted loss for ``logits`` against binary ``target``."""
        eps = 1e-7  # keeps the ratio finite when a class has no positives in the batch

        positives = target.sum(dim=0)
        negatives = target.size(0) - positives
        weights = (negatives + eps) / (positives + eps)

        if self.pos_weight is not None:
            weights = weights * self.pos_weight

        return nn.functional.binary_cross_entropy_with_logits(
            logits,
            target,
            pos_weight=weights,
            reduction=self.reduction,
        )

def prepare_improved_data(df, all_narratives, all_subnarratives):
    """Binarize the label columns of ``df`` and derive per-class weights.

    The label vocabularies are learned from ``df`` itself; ``all_narratives``
    and ``all_subnarratives`` are accepted but not used here. ``df`` is
    modified in place (``narrative_labels`` / ``subnarrative_labels`` columns
    are added).

    Returns:
        ``(df, narrative_weights, subnarrative_weights, narrative_classes,
        subnarrative_classes)``.
    """
    # One binarizer per label level, each fitted on this frame only.
    narrative_binarizer = MultiLabelBinarizer()
    subnarrative_binarizer = MultiLabelBinarizer()
    narrative_matrix = narrative_binarizer.fit_transform(df["narratives"])
    subnarrative_matrix = subnarrative_binarizer.fit_transform(df["subnarratives"])

    # Column order of the matrices follows these class lists.
    narrative_classes = list(narrative_binarizer.classes_)
    subnarrative_classes = list(subnarrative_binarizer.classes_)

    print(f"Number of narrative classes: {len(narrative_classes)}")
    print(f"Number of subnarrative classes: {len(subnarrative_classes)}")

    narrative_weights = compute_class_weights(narrative_matrix)
    subnarrative_weights = compute_class_weights(subnarrative_matrix)

    # Store one label vector per row.
    df["narrative_labels"] = list(narrative_matrix)
    df["subnarrative_labels"] = list(subnarrative_matrix)

    return (
        df,
        narrative_weights,
        subnarrative_weights,
        narrative_classes,
        subnarrative_classes,
    )

def compute_class_weights(labels):
    """Return clipped negative/positive ratios per class as a float tensor.

    For each column of the binary ``labels`` matrix the weight is
    ``negatives / (positives + 1e-7)``, limited to the range [0.1, 10.0].
    """
    total_rows = len(labels)
    positives = np.sum(labels, axis=0)
    negatives = total_rows - positives

    ratio = negatives / (positives + 1e-7)
    # Bound the ratio so very rare or very common classes cannot dominate.
    bounded = np.clip(ratio, 0.1, 10.0)

    return torch.FloatTensor(bounded)

def create_stratified_splits_with_target_lang(df, test_size=0.1, val_size=0.1, 
                                            train_target_ratio=0.6, random_state=42, 
                                            target_lang='EN'):
    """
    Split ``df`` into train/val/test with validation and test drawn only from the target language.

    The split is random (seeded by ``random_state``). All rows of the other
    languages go to training; the target-language share of training is capped.

    Args:
        df: DataFrame containing all data; needs a ``language`` column
        test_size: Fraction of the target-language rows used for testing
        val_size: Fraction of the target-language rows used for validation
        train_target_ratio: Upper bound, as a fraction of the target-language
            rows, on how many of them are kept for training
        random_state: Random seed for reproducibility
        target_lang: Target language code (default: 'EN')

    Returns:
        ``(train_df, val_df, test_df)``; ``train_df`` is shuffled and re-indexed.
    """
    target_pool = df[df['language'] == target_lang]
    auxiliary_pool = df[df['language'] != target_lang]

    # Absolute sample counts derived from the target-language pool size
    n_target = len(target_pool)
    n_test = int(n_target * test_size)
    n_val = int(n_target * val_size)
    n_train_cap = int(n_target * train_target_ratio)

    # Carve off the test rows first, then the validation rows from the rest
    target_rest, test_df = train_test_split(
        target_pool, test_size=n_test, random_state=random_state
    )
    train_target_df, val_df = train_test_split(
        target_rest, test_size=n_val, random_state=random_state
    )

    # Down-sample the target-language training rows when they exceed the cap
    if len(train_target_df) > n_train_cap:
        train_target_df = train_target_df.sample(n=n_train_cap, random_state=random_state)

    # Other languages plus the (possibly capped) target-language rows, shuffled
    train_df = (
        pd.concat([auxiliary_pool, train_target_df])
        .sample(frac=1, random_state=random_state)
        .reset_index(drop=True)
    )

    # Summary of the resulting split
    print("\nData split sizes:")
    print("\nTraining set:")
    print(f"Total samples: {len(train_df)}")
    print("Language distribution:")
    print(train_df['language'].value_counts())
    print(f"\nValidation set ({target_lang} only):")
    print(f"Total samples: {len(val_df)}")
    print(f"\nTest set ({target_lang} only):")
    print(f"Total samples: {len(test_df)}")

    return train_df, val_df, test_df

def analyze_rare_labels(df):
    """Count label occurrences and report the labels seen at most twice.

    Reads the list-valued ``narratives`` and ``subnarratives`` columns of
    ``df``, prints a short summary and returns
    ``(rare_narratives, rare_subnarratives)`` as ``{label: count}`` dicts.
    """
    def _count(label_lists):
        # Tally in first-seen order.
        tally = {}
        for label_list in label_lists:
            for label in label_list:
                tally[label] = tally.get(label, 0) + 1
        return tally

    def _rare(tally):
        return {label: count for label, count in tally.items() if count <= 2}

    narrative_counts = _count(df["narratives"])
    subnarrative_counts = _count(df["subnarratives"])

    rare_narratives = _rare(narrative_counts)
    rare_subnarratives = _rare(subnarrative_counts)

    print("\nRare Label Analysis:")
    print(f"Total unique narratives: {len(narrative_counts)}")
    print(f"Rare narratives (<=2 occurrences): {len(rare_narratives)}")
    print("Rare narrative counts:", rare_narratives)

    print(f"\nTotal unique subnarratives: {len(subnarrative_counts)}")
    print(f"Rare subnarratives (<=2 occurrences): {len(rare_subnarratives)}")
    print("Rare subnarrative counts:", rare_subnarratives)

    return rare_narratives, rare_subnarratives

def optimize_threshold(model, val_loader, device, thresholds):
    """Find the decision threshold with the best macro F1 on a validation loader.

    Args:
        model: Model whose output exposes ``.logits``.
        val_loader: Iterable of batches (dicts with a ``'labels'`` entry; all
            other entries are passed to the model as keyword arguments).
        device: Device the batch inputs are moved to.
        thresholds: Candidate thresholds, tried in order.

    Returns:
        The first candidate reaching the highest macro F1, or 0.5 when no
        candidate scores above 0 or the loader yields no batches.
    """
    model.eval()
    best_f1 = 0
    best_threshold = 0.5

    logit_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch in val_loader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            outputs = model(**inputs)
            # Move results off the device right away so the whole validation
            # set is not held in accelerator memory.
            logit_chunks.append(outputs.logits.cpu())
            label_chunks.append(batch['labels'].cpu())

    # Nothing to score: keep the default instead of failing in torch.cat.
    if not logit_chunks:
        return best_threshold

    probs = torch.sigmoid(torch.cat(logit_chunks, dim=0))
    y_true = torch.cat(label_chunks, dim=0).numpy()

    for threshold in thresholds:
        preds = (probs > threshold).int().numpy()
        # zero_division=0 matches evaluate_model and avoids a warning per
        # class without predictions/support at every candidate threshold.
        f1 = f1_score(y_true, preds, average='macro', zero_division=0)

        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    return best_threshold

def prepare_data_for_split(df, narrative_classes, subnarrative_classes):
    """
    Encode the labels of a validation or test split against given class lists.

    Args:
        df: DataFrame containing the split data (list-valued ``narratives``
            and ``subnarratives`` columns)
        narrative_classes: Ordered narrative classes defining the vector layout
        subnarrative_classes: Ordered subnarrative classes defining the vector layout

    Returns:
        A copy of ``df`` with ``narrative_labels`` and ``subnarrative_labels``
        columns holding one 0/1 list per row; labels that are not in the class
        lists are ignored. The input frame is left untouched.
    """
    def _to_vector(row_labels, classes):
        return [int(name in row_labels) for name in classes]

    narrative_vectors = [_to_vector(labels, narrative_classes) for labels in df["narratives"]]
    subnarrative_vectors = [_to_vector(labels, subnarrative_classes) for labels in df["subnarratives"]]

    # Work on a copy so the caller's frame keeps its original columns.
    encoded = df.copy()
    encoded["narrative_labels"] = narrative_vectors
    encoded["subnarrative_labels"] = subnarrative_vectors

    return encoded


def compute_global_score(narrative_preds, narrative_true, subnarrative_preds, subnarrative_true, 
                        narrative_weight=0.4, subnarrative_weight=0.6):
    """
    Score both label levels and blend them into weighted global metrics.

    Args:
        narrative_preds: Binary prediction matrix for narratives
        narrative_true: Binary ground-truth matrix for narratives
        subnarrative_preds: Binary prediction matrix for subnarratives
        subnarrative_true: Binary ground-truth matrix for subnarratives
        narrative_weight: Weight of the narrative metrics (default: 0.4)
        subnarrative_weight: Weight of the subnarrative metrics (default: 0.6)

    Returns:
        dict with ``narrative_metrics``, ``subnarrative_metrics``,
        ``global_metrics`` (weighted sums, keys prefixed with ``global_``) and
        ``hierarchical_accuracy`` (share of rows correct on both levels)
    """
    def _level_metrics(truth, preds):
        # Same four metrics for either level.
        return {
            "hamming": hamming_loss(truth, preds),
            "macro_f1": f1_score(truth, preds, average='macro', zero_division=0),
            "micro_f1": f1_score(truth, preds, average='micro', zero_division=0),
            "subset_acc": (truth == preds).all(axis=1).mean()
        }

    narrative_metrics = _level_metrics(narrative_true, narrative_preds)
    subnarrative_metrics = _level_metrics(subnarrative_true, subnarrative_preds)

    # Weighted combination, metric by metric.
    global_metrics = {
        f"global_{name}": (
            narrative_weight * narrative_metrics[name] +
            subnarrative_weight * subnarrative_metrics[name]
        )
        for name in ("hamming", "macro_f1", "micro_f1", "subset_acc")
    }

    # A row counts only when every label on both levels is right.
    narrative_exact = (narrative_true == narrative_preds).all(axis=1)
    subnarrative_exact = (subnarrative_true == subnarrative_preds).all(axis=1)
    hierarchical_accuracy = (narrative_exact & subnarrative_exact).mean()

    return {
        "narrative_metrics": narrative_metrics,
        "subnarrative_metrics": subnarrative_metrics,
        "global_metrics": global_metrics,
        "hierarchical_accuracy": hierarchical_accuracy
    }

def evaluate_global_performance(narrative_model, subnarrative_model, val_narrative_loader, val_subnarrative_loader, device, 
                                narrative_threshold=0.3, subnarrative_threshold=0.3):
    """
    Score the narrative and subnarrative models together and print the results.

    Each model is run over its own loader with its own threshold; the two
    prediction sets are combined with ``compute_global_score`` and the
    resulting dict is returned after printing all metrics.
    """
    narrative_preds, narrative_true = get_predictions(
        narrative_model, val_narrative_loader, device, narrative_threshold)
    subnarrative_preds, subnarrative_true = get_predictions(
        subnarrative_model, val_subnarrative_loader, device, subnarrative_threshold)

    global_scores = compute_global_score(
        narrative_preds, narrative_true,
        subnarrative_preds, subnarrative_true
    )

    # One printed section per metric group, in a fixed order.
    report_sections = (
        ("\nNarrative Level Metrics:", "narrative_metrics"),
        ("\nSubnarrative Level Metrics:", "subnarrative_metrics"),
        ("\nGlobal Combined Metrics:", "global_metrics"),
    )

    print("\nGlobal Performance Metrics:")
    for heading, group in report_sections:
        print(heading)
        for metric, value in global_scores[group].items():
            print(f"{metric}: {value:.4f}")

    print(f"\nHierarchical Accuracy: {global_scores['hierarchical_accuracy']:.4f}")

    return global_scores

def train_model_improved(model, train_loader, val_loader, device, task_type, 
                        narrative_classes, subnarrative_classes, class_weights, num_epochs=5):
    """Train a multilabel classifier with per-epoch threshold tuning and early stopping.

    After every epoch the decision threshold is tuned on ``val_loader``, the
    model is scored, and the state with the best validation macro F1 so far is
    saved to ``best_{task_type}_model.pt``.

    Args:
        model: Model whose output exposes ``.logits``.
        train_loader: Loader yielding dict batches that include ``'labels'``.
        val_loader: Loader used for threshold tuning and validation metrics.
        device: Device on which batches are processed.
        task_type: ``'narrative'`` selects ``narrative_classes`` for the
            report; any other value selects ``subnarrative_classes``. Also
            used in the checkpoint file name.
        narrative_classes: Narrative label names.
        subnarrative_classes: Subnarrative label names.
        class_weights: Per-class positive weights passed to the loss.
        num_epochs: Maximum number of epochs.

    Returns:
        ``(model, threshold)``. When at least one epoch improved on the
        initial macro F1 of 0, the model carries the weights of the best epoch
        and ``threshold`` is the one tuned in that epoch; otherwise the
        last-epoch model and threshold are returned.
    """
    optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
    criterion = CustomBCEWithLogitsLoss(pos_weight=class_weights.to(device))
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', 
                                                         factor=0.5, patience=2)

    class_labels = narrative_classes if task_type == 'narrative' else subnarrative_classes
    checkpoint_path = f'best_{task_type}_model.pt'

    best_f1 = 0
    patience = 5
    patience_counter = 0
    best_threshold = 0.5       # threshold belonging to the best checkpoint
    epoch_threshold = 0.5      # threshold tuned in the most recent epoch
    checkpoint_saved = False   # only reload a checkpoint written by this call

    # Candidate thresholds for the per-epoch search
    thresholds = np.arange(0.1, 0.9, 0.1)

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        progress_bar = tqdm.tqdm(train_loader, desc="Training")

        for batch in progress_bar:
            optimizer.zero_grad()

            # Keep the labels out of the forward call: the loss comes from
            # `criterion`, so a loss computed inside the model would be wasted.
            labels = batch['labels'].to(device)
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            outputs = model(**inputs)
            loss = criterion(outputs.logits, labels)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1
            progress_bar.set_postfix({"Loss": f"{loss.item():.4f}"})

        # Find best threshold on validation set for the current weights
        epoch_threshold = optimize_threshold(model, val_loader, device, thresholds)

        y_pred, y_true = get_predictions(model, val_loader, device, threshold=epoch_threshold)
        val_metrics = evaluate_model(y_pred, y_true, class_labels, print_report=True)

        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        if num_batches:
            print(f"Average training loss: {epoch_loss / num_batches:.4f}")
        print(f"Best threshold: {epoch_threshold:.2f}")
        for metric, value in val_metrics.items():
            print(f"{metric}: {value:.4f}")

        # Update learning rate based on F1 score
        scheduler.step(val_metrics['Macro F1'])

        current_f1 = val_metrics['Macro F1']
        if current_f1 > best_f1:
            best_f1 = current_f1
            best_threshold = epoch_threshold
            patience_counter = 0
            # Store plain Python floats so the checkpoint holds only tensors
            # and builtin types.
            torch.save({
                'model_state_dict': model.state_dict(),
                'best_threshold': float(best_threshold),
                'best_f1': float(best_f1)
            }, checkpoint_path)
            checkpoint_saved = True
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after epoch {epoch + 1}")
                break

    if checkpoint_saved:
        # Hand back the best epoch rather than whatever the last epoch left.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        return model, best_threshold

    return model, epoch_threshold

In [None]:

# Run-wide settings (previously repeated as literals throughout this cell)
SEED = 42
MAX_LEN = 512
BATCH_SIZE = 16
NUM_EPOCHS = 5

# Seed the global RNGs so random draws made through them (e.g. by the weighted
# samplers) are repeatable across runs.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Extract datasets
extract_datasets()

# Load labels and get unique narratives and subnarratives
labels_df, all_narratives, all_subnarratives = load_and_map_labels(LABELS_PATH)

print(f"Found {len(all_narratives)} unique narratives and {len(all_subnarratives)} unique subnarratives")

# Map inputs to labels
article_ids = labels_df["article_id"]
df = map_input_to_label_with_lang(INPUTS_PATH, article_ids, labels_df)

# Remove excess "Other" labels: drop a random 70% of the articles whose
# narratives and subnarratives both contain an "Other" entry
other_df = df[
    df["narratives"].apply(lambda x: any("Other" in item for item in x)) & 
    df["subnarratives"].apply(lambda x: any("Other" in item for item in x))
].sample(frac=0.7, random_state=SEED)
df = df.drop(other_df.index)

# Analyze rare labels before splitting
print("Analyzing label distribution before splitting...")
rare_narratives, rare_subnarratives = analyze_rare_labels(df)

# Create splits with modified strategy
train_df, val_df, test_df = create_stratified_splits_with_target_lang(
    df,
    test_size=0.1,
    val_size=0.1,
    train_target_ratio=0.6,  
    random_state=SEED,
    target_lang='EN'
)

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')

# Prepare training data and get class information
train_data, narrative_weights, subnarrative_weights, narrative_classes, subnarrative_classes = prepare_improved_data(
    train_df, all_narratives, all_subnarratives)

# Prepare validation and test data using training classes
val_data = prepare_data_for_split(val_df, narrative_classes, subnarrative_classes)
test_data = prepare_data_for_split(test_df, narrative_classes, subnarrative_classes)

# Create datasets
train_narrative_dataset = NarrativeDataset(train_data, tokenizer, max_len=MAX_LEN, task_type='narrative')
val_narrative_dataset = NarrativeDataset(val_data, tokenizer, max_len=MAX_LEN, task_type='narrative')
test_narrative_dataset = NarrativeDataset(test_data, tokenizer, max_len=MAX_LEN, task_type='narrative')

train_subnarrative_dataset = NarrativeDataset(train_data, tokenizer, max_len=MAX_LEN, task_type='subnarrative')
val_subnarrative_dataset = NarrativeDataset(val_data, tokenizer, max_len=MAX_LEN, task_type='subnarrative')
test_subnarrative_dataset = NarrativeDataset(test_data, tokenizer, max_len=MAX_LEN, task_type='subnarrative')

# Create weighted samplers for training
narrative_sampler = create_weighted_sampler(train_data["narrative_labels"].tolist())
subnarrative_sampler = create_weighted_sampler(train_data["subnarrative_labels"].tolist())

# Setup device (before the loaders, because pinning depends on it)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Only request pinned host memory when batches will actually be copied to a GPU
pin_memory = device.type == "cuda"

# Create data loaders
train_narrative_loader = DataLoader(
    train_narrative_dataset, 
    batch_size=BATCH_SIZE,
    sampler=narrative_sampler,
    pin_memory=pin_memory
)

val_narrative_loader = DataLoader(
    val_narrative_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=pin_memory
)

test_narrative_loader = DataLoader(
    test_narrative_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=pin_memory
)

train_subnarrative_loader = DataLoader(
    train_subnarrative_dataset,
    batch_size=BATCH_SIZE,
    sampler=subnarrative_sampler,
    pin_memory=pin_memory
)

val_subnarrative_loader = DataLoader(
    val_subnarrative_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=pin_memory
)

test_subnarrative_loader = DataLoader(
    test_subnarrative_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=pin_memory
)

# Initialize models with one output per class seen in the training data
narrative_model = BertForSequenceClassification.from_pretrained(
    "bert-base-multilingual-cased",
    num_labels=len(narrative_classes),
    problem_type="multi_label_classification"
).to(device)

subnarrative_model = BertForSequenceClassification.from_pretrained(
    "bert-base-multilingual-cased",
    num_labels=len(subnarrative_classes),
    problem_type="multi_label_classification"
).to(device)

# Train models with correct class labels
print("\nTraining Narrative Model...")
narrative_model, narrative_threshold = train_model_improved(
    narrative_model,
    train_narrative_loader,
    val_narrative_loader,
    device,
    'narrative',
    narrative_classes,
    subnarrative_classes,
    narrative_weights,
    num_epochs=NUM_EPOCHS
)

print("\nTraining Subnarrative Model...")
subnarrative_model, subnarrative_threshold = train_model_improved(
    subnarrative_model,
    train_subnarrative_loader,
    val_subnarrative_loader,
    device,
    'subnarrative',
    narrative_classes,
    subnarrative_classes,
    subnarrative_weights,
    num_epochs=NUM_EPOCHS
)


In [None]:
# Evaluate global performance on validation set, using the thresholds tuned
# during training (the same ones applied to the test set below) so that the
# validation and test numbers are comparable
print("\nEvaluating global performance on validation set...")
val_global_scores = evaluate_global_performance(
    narrative_model, subnarrative_model,
    val_narrative_loader, val_subnarrative_loader,
    device,
    narrative_threshold=narrative_threshold,
    subnarrative_threshold=subnarrative_threshold
)

# Evaluate global performance on test set
print("\nEvaluating global performance on test set...")
test_global_scores = evaluate_global_performance(
    narrative_model, subnarrative_model,
    test_narrative_loader, test_subnarrative_loader,
    device,
    narrative_threshold=narrative_threshold,
    subnarrative_threshold=subnarrative_threshold
)


# Collect final results in one dict (kept in memory; nothing is written to disk here)
results = {
    'validation_scores': val_global_scores,
    'test_scores': test_global_scores,
    'narrative_threshold': narrative_threshold,
    'subnarrative_threshold': subnarrative_threshold,
    'model_parameters': {
        'narrative_labels': len(narrative_classes),
        'subnarrative_labels': len(subnarrative_classes)
    }
}

# Print final summary
print("\nFinal Performance Summary:")
print("\nValidation Set Performance:")
print(f"Global Macro F1: {val_global_scores['global_metrics']['global_macro_f1']:.4f}")
print(f"Hierarchical Accuracy: {val_global_scores['hierarchical_accuracy']:.4f}")

print("\nTest Set Performance:")
print(f"Global Macro F1: {test_global_scores['global_metrics']['global_macro_f1']:.4f}")
print(f"Hierarchical Accuracy: {test_global_scores['hierarchical_accuracy']:.4f}")