# Semeval 2025 Task 10
### Subtask 2: Narrative Classification

Given a news article and a [two-level taxonomy of narrative labels](https://propaganda.math.unipd.it/semeval2025task10/NARRATIVE-TAXONOMIES.pdf) (where each narrative is subdivided into subnarratives) from a particular domain, assign to the article all the appropriate subnarrative labels. This is a multi-label multi-class document classification task.

## 1. Multi-head per narrative model

### 1.1 Loading pre-saved variables

In [1]:
import pickle
import os

base_save_folder_dir = '../saved/'
dataset_folder = os.path.join(base_save_folder_dir, 'Dataset')

with open(os.path.join(dataset_folder, 'dataset.pkl'), 'rb') as f:
    dataset = pickle.load(f)

In [2]:
misc_folder = os.path.join(base_save_folder_dir, 'Misc')

with open(os.path.join(misc_folder, 'narrative_to_subnarratives.pkl'), 'rb') as f:
    narrative_to_subnarratives = pickle.load(f)

In [3]:
narrative_to_subnarratives

{'Discrediting Ukraine': ['Discrediting Ukrainian nation and society',
  'Ukraine is associated with nazism',
  'Ukraine is a hub for criminal activities',
  'Other',
  'Situation in Ukraine is hopeless',
  'Discrediting Ukrainian military',
  'Rewriting Ukraine’s history',
  'Ukraine is a puppet of the West',
  'Discrediting Ukrainian government and officials and policies'],
 'Discrediting the West, Diplomacy': ['The West is overreacting',
  'The West does not care about Ukraine, only about its interests',
  'The EU is divided',
  'Other',
  'Diplomacy does/will not work',
  'The West is weak',
  'West is tired of Ukraine'],
 'Praise of Russia': ['Praise of Russian military might',
  'Russia is a guarantor of peace and prosperity',
  'Russian invasion has strong national support',
  'Other',
  'Russia has international support from a number of countries and people',
  'Praise of Russian President Vladimir Putin'],
 'Russia is the Victim': ['Russia actions in Ukraine are only self-defe

In [4]:
label_encoder_folder = os.path.join(base_save_folder_dir, 'LabelEncoders')

with open(os.path.join(label_encoder_folder, 'mlb_narratives.pkl'), 'rb') as f:
    mlb_narratives = pickle.load(f)

with open(os.path.join(label_encoder_folder, 'mlb_subnarratives.pkl'), 'rb') as f:
    mlb_subnarratives = pickle.load(f)

In [5]:
import numpy as np

embeddings_folder = os.path.join(base_save_folder_dir, 'Embeddings/all_embeddings.npy')

def load_embeddings(filename):
    return np.load(filename)

all_embeddings = load_embeddings(embeddings_folder)

In [6]:
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
import numpy as np
import pandas as pd

def stratified_train_val_split_with_embeddings(data, embeddings, labels_column, train_size=0.8, splits=5, shuffle=True, min_instances=2):
    if shuffle:
        shuffled_indices = np.arange(len(data))
        np.random.shuffle(shuffled_indices)
        data = data.iloc[shuffled_indices].reset_index(drop=True)
        embeddings = embeddings[shuffled_indices]

    labels = np.array(data[labels_column].tolist())
    rare_indices = []
    common_indices = []

    class_counts = labels.sum(axis=0)
    rare_classes = np.where(class_counts <= min_instances)[0]

    for idx, label_row in enumerate(labels):
        if any(label_row[rare_classes]):
            rare_indices.append(idx)
        else:
            common_indices.append(idx)

    rare_data = data.iloc[rare_indices]
    rare_labels = labels[rare_indices]
    rare_embeddings = embeddings[rare_indices]

    train_rare = rare_data.iloc[:len(rare_data) // 2].reset_index(drop=True)
    val_rare = rare_data.iloc[len(rare_data) // 2:].reset_index(drop=True)

    train_rare_embeddings = rare_embeddings[:len(rare_data) // 2]
    val_rare_embeddings = rare_embeddings[len(rare_data) // 2:]

    common_data = data.iloc[common_indices].reset_index(drop=True)
    common_labels = labels[common_indices]
    common_embeddings = embeddings[common_indices]

    mskf = MultilabelStratifiedKFold(n_splits=splits)
    for train_idx, val_idx in mskf.split(np.zeros(len(common_labels)), common_labels):
        train_common = common_data.iloc[train_idx]
        val_common = common_data.iloc[val_idx]
        train_common_embeddings = common_embeddings[train_idx]
        val_common_embeddings = common_embeddings[val_idx]
        break

    train_data = pd.concat([train_rare, train_common]).reset_index(drop=True)
    val_data = pd.concat([val_rare, val_common]).reset_index(drop=True)

    train_embeddings = np.concatenate([train_rare_embeddings, train_common_embeddings], axis=0)
    val_embeddings = np.concatenate([val_rare_embeddings, val_common_embeddings], axis=0)

    return (train_data, train_embeddings), (val_data, val_embeddings)

(dataset_train, train_embeddings), (dataset_val, val_embeddings) = stratified_train_val_split_with_embeddings(
    dataset,
    all_embeddings,
    labels_column="subnarratives_encoded",
    min_instances=2
)

### 1.2 Remapping our subnarrative indices

In [7]:
dataset['subnarratives_encoded']

0       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
1       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
2       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
3       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
4       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
                              ...                        
1694    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
1695    [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, ...
1696    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...
1697    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
1698    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
Name: subnarratives_encoded, Length: 1699, dtype: object

In [8]:
narrative_to_sub_map = {}
narrative_classes = list(mlb_narratives.classes_)
subnarrative_classes = list(mlb_subnarratives.classes_)

for narrative, subnarratives in narrative_to_subnarratives.items():
    narrative_idx = narrative_classes.index(narrative)
    subnarrative_indices = [subnarrative_classes.index(sub) for sub in subnarratives]
    narrative_to_sub_map[narrative_idx] = subnarrative_indices

print(narrative_to_sub_map)

{8: [22, 66, 64, 33, 50, 21, 39, 65, 20], 9: [58, 56, 53, 33, 19, 60, 71], 17: [35, 42, 46, 33, 41, 34], 19: [40, 33, 63, 59], 10: [72, 33, 69], 1: [43, 3, 62, 33, 31], 16: [32, 33, 57, 55], 2: [67, 33, 54], 20: [33, 44, 45, 68], 13: [2, 33, 6], 14: [47, 33, 61], 0: [73, 24, 23, 1, 33], 15: [33], 7: [14, 33, 17, 16, 15], 5: [8, 9, 33, 0], 11: [29, 51, 7, 33, 28, 27, 4, 70, 49], 6: [33, 11, 12, 10], 18: [48, 18, 33, 26, 30], 3: [33, 52, 5], 4: [33, 36, 37, 38], 12: [25, 13, 33]}


In [9]:
narrative_to_sub_map

{8: [22, 66, 64, 33, 50, 21, 39, 65, 20],
 9: [58, 56, 53, 33, 19, 60, 71],
 17: [35, 42, 46, 33, 41, 34],
 19: [40, 33, 63, 59],
 10: [72, 33, 69],
 1: [43, 3, 62, 33, 31],
 16: [32, 33, 57, 55],
 2: [67, 33, 54],
 20: [33, 44, 45, 68],
 13: [2, 33, 6],
 14: [47, 33, 61],
 0: [73, 24, 23, 1, 33],
 15: [33],
 7: [14, 33, 17, 16, 15],
 5: [8, 9, 33, 0],
 11: [29, 51, 7, 33, 28, 27, 4, 70, 49],
 6: [33, 11, 12, 10],
 18: [48, 18, 33, 26, 30],
 3: [33, 52, 5],
 4: [33, 36, 37, 38],
 12: [25, 13, 33]}

In [10]:
def remap_subnarratives(row, narrative_to_sub_map):
    """Takes in a row and encodes the current subnarrative list to the associated hierarchy based on the narr-subnar map"""
    for narr_idx, sub_indices in narrative_to_sub_map.items():
        sub_labels = [row['subnarratives_encoded'][sub_idx] for sub_idx in sub_indices]
        col_name = f"narrative_hierarchy_{narr_idx}"
        row[col_name] = sub_labels
    return row

dataset_train_cpy = dataset_train.apply(remap_subnarratives, axis=1, args=(narrative_to_sub_map,)).copy()

In [11]:
dataset_val_cpy = dataset_val.apply(remap_subnarratives, axis=1, args=(narrative_to_sub_map,)).copy()

In [12]:
dataset_val_cpy.head()

Unnamed: 0,language,article_id,content,narratives,subnarratives,narratives_encoded,subnarratives_encoded,narrative_hierarchy_8,narrative_hierarchy_9,narrative_hierarchy_17,...,narrative_hierarchy_0,narrative_hierarchy_15,narrative_hierarchy_7,narrative_hierarchy_5,narrative_hierarchy_11,narrative_hierarchy_6,narrative_hierarchy_18,narrative_hierarchy_3,narrative_hierarchy_4,narrative_hierarchy_12
0,EN,EN_CC_200022.txt,<PARA>Denmark to Punish Farmers for cow ‘emiss...,"[Criticism of institutions and authorities, Cr...","[Criticism of national governments, Other, Met...","[0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 1, 0, 0, 0, 0, 0]","[0, 0, 0, 1, 0, 0, 0]","[0, 0, 0, 1, 0, 0]",...,"[0, 0, 0, 0, 1]",[1],"[1, 1, 0, 0, 1]","[0, 0, 1, 0]","[0, 0, 0, 1, 0, 0, 0, 0, 0]","[1, 0, 0, 0]","[0, 0, 1, 1, 1]","[1, 0, 0]","[1, 0, 0, 0]","[0, 0, 1]"
1,EN,EN_CC_200221.txt,<PARA>“the hour of decision”</PARA>\n\nshortly...,[Hidden plots by secret schemes of powerful gr...,"[Blaming global elites, Other, Other, CO2 conc...","[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, ...","[0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 1, 0, 0, 0, 0, 0]","[0, 0, 0, 1, 0, 0, 0]","[0, 0, 0, 1, 0, 0]",...,"[0, 0, 0, 0, 1]",[1],"[0, 1, 0, 1, 0]","[0, 0, 1, 0]","[0, 0, 0, 1, 0, 0, 1, 0, 1]","[1, 0, 0, 0]","[1, 0, 1, 0, 0]","[1, 0, 0]","[1, 0, 0, 0]","[0, 0, 1]"
2,EN,EN_CC_200110.txt,<PARA>if democrats successfully ban gas stoves...,"[Criticism of climate policies, Criticism of i...",[Climate policies have negative impact on the ...,"[0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, ...","[0, 0, 0, 1, 0, 0, 0, 0, 0]","[0, 0, 0, 1, 0, 0, 0]","[0, 0, 0, 1, 0, 0]",...,"[0, 0, 0, 0, 1]",[1],"[0, 1, 0, 1, 1]","[0, 0, 1, 0]","[0, 0, 0, 1, 0, 0, 0, 0, 0]","[1, 0, 1, 1]","[0, 0, 1, 0, 0]","[1, 0, 0]","[1, 1, 0, 0]","[0, 0, 1]"
3,HI,HI_173.txt,<PARA>यूक्रेन के बढ़ते हमलों के बीच रूस ने क्य...,"[Praise of Russia, Russia is the Victim, Blami...",[Russia is a guarantor of peace and prosperity...,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0]","[1, 1, 0, 0, 0, 1]",...,"[0, 0, 0, 0, 0]",[0],"[0, 0, 0, 0, 0]","[0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0]","[0, 0, 0, 0, 0]","[0, 0, 0]","[0, 0, 0, 0]","[0, 0, 0]"
4,PT,PT_130.txt,<PARA>hungria mantém veto de fundos da UE a Ky...,"[Discrediting Ukraine, Discrediting Ukraine]","[Other, Discrediting Ukrainian government and ...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 1, 0, 0, 0, 0, 1]","[0, 0, 0, 1, 0, 0, 0]","[0, 0, 0, 1, 0, 0]",...,"[0, 0, 0, 0, 1]",[1],"[0, 1, 0, 0, 0]","[0, 0, 1, 0]","[0, 0, 0, 1, 0, 0, 0, 0, 0]","[1, 0, 0, 0]","[0, 0, 1, 0, 0]","[1, 0, 0]","[1, 0, 0, 0]","[0, 0, 1]"


In [13]:
for narr_idx, sub_indices in narrative_to_sub_map.items():
    column_name = f"narrative_hierarchy_{narr_idx}"
    res = dataset_train_cpy[column_name]
    print(res)

0       [0, 0, 0, 0, 0, 0, 0, 0, 0]
1       [0, 0, 0, 1, 0, 0, 0, 0, 0]
2       [0, 0, 0, 1, 0, 0, 0, 0, 0]
3       [0, 0, 0, 0, 0, 0, 0, 0, 0]
4       [0, 0, 0, 1, 0, 0, 0, 0, 0]
                   ...             
1359    [0, 0, 0, 0, 1, 0, 0, 0, 0]
1360    [0, 0, 0, 0, 0, 0, 0, 0, 0]
1361    [0, 0, 0, 1, 0, 0, 0, 0, 0]
1362    [0, 0, 0, 0, 0, 0, 0, 0, 0]
1363    [0, 0, 0, 1, 0, 0, 0, 0, 0]
Name: narrative_hierarchy_8, Length: 1364, dtype: object
0       [0, 0, 0, 0, 0, 0, 0]
1       [0, 0, 0, 1, 0, 0, 0]
2       [0, 0, 0, 1, 0, 0, 0]
3       [0, 0, 0, 0, 0, 0, 0]
4       [0, 0, 0, 1, 0, 0, 0]
                ...          
1359    [0, 0, 0, 0, 0, 0, 0]
1360    [0, 0, 0, 0, 0, 0, 0]
1361    [0, 0, 0, 1, 0, 0, 0]
1362    [0, 0, 0, 0, 0, 0, 0]
1363    [0, 0, 0, 1, 0, 0, 0]
Name: narrative_hierarchy_9, Length: 1364, dtype: object
0       [0, 0, 0, 0, 0, 0]
1       [0, 0, 0, 1, 0, 0]
2       [0, 0, 0, 1, 0, 0]
3       [0, 0, 0, 0, 0, 0]
4       [0, 0, 0, 1, 0, 0]
               ...       

In [14]:
# Sort order of narratives to start from hierarchy 0
narrative_order = sorted(narrative_to_sub_map.keys())

In [15]:
narrative_order

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]

In [16]:
def aggregate_subnarratives(row, narrative_order, narrative_to_sub_map):
    """Takes in a row, and aggregates all hierarchy columns to 1 list.
    The encoded list will be a list of lists, starting from the first hierarchy"""
    aggregated = []
    for narr_idx in narrative_order:
        column_name = f"narrative_hierarchy_{narr_idx}"
        sub_labels = row[column_name]
        aggregated.append(sub_labels)
    return aggregated

dataset_train['aggregated_subnarratives'] = dataset_train_cpy.apply(
    aggregate_subnarratives,
    axis=1,
    args=(narrative_order, narrative_to_sub_map)
)

dataset_val['aggregated_subnarratives'] = dataset_val_cpy.apply(
    aggregate_subnarratives,
    axis=1,
    args=(narrative_order, narrative_to_sub_map)
)

In [17]:
dataset_train['aggregated_subnarratives']

0       [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0], ...
1       [[0, 0, 0, 0, 1], [0, 0, 0, 1, 0], [0, 1, 0], ...
2       [[0, 0, 0, 0, 1], [0, 0, 0, 1, 0], [0, 1, 0], ...
3       [[0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0], ...
4       [[0, 0, 0, 0, 1], [0, 0, 0, 1, 0], [0, 1, 0], ...
                              ...                        
1359    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0], ...
1360    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0], ...
1361    [[0, 0, 0, 0, 1], [0, 0, 0, 1, 0], [0, 1, 0], ...
1362    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0], ...
1363    [[0, 0, 0, 0, 1], [0, 0, 0, 1, 0], [0, 1, 0], ...
Name: aggregated_subnarratives, Length: 1364, dtype: object

In [18]:
y_train_sub_heads = dataset_train['aggregated_subnarratives'].to_numpy()
y_val_sub_heads = dataset_val['aggregated_subnarratives'].to_numpy()

In [19]:
import torch

train_embeddings_tensor = torch.tensor(train_embeddings, dtype=torch.float32)
val_embeddings_tensor = torch.tensor(val_embeddings, dtype=torch.float32)

In [20]:
input_size = train_embeddings_tensor.shape[1]
print(input_size)

896


In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTaskClassifierMultiHead(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        num_narratives=len(mlb_narratives.classes_),
        narrative_to_sub_map=narrative_to_sub_map,
        dropout_rate=0.3
    ):
        super().__init__()
        # Shared layer
        self.shared_layer = nn.Sequential(
            nn.Linear(input_size, hidden_size * 2),
            nn.BatchNorm1d(hidden_size * 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Top-level narratives: multi-label => Sigmoid
        self.narrative_head = nn.Sequential(
            nn.Linear(hidden_size * 2, num_narratives),
            nn.Sigmoid()
        )

        # Subnarrative heads: multi-label => Sigmoid
        self.subnarrative_heads = nn.ModuleDict()
        for narr_idx, sub_indices in narrative_to_sub_map.items():
            num_subs_for_this_narr = len(sub_indices)
            self.subnarrative_heads[str(narr_idx)] = nn.Sequential(
                nn.Linear(hidden_size * 2, num_subs_for_this_narr),
                nn.Sigmoid()
            )

    def forward(self, x):
        shared_out = self.shared_layer(x)
        narr_probs = self.narrative_head(shared_out)

        sub_probs_dict = {}
        for narr_idx, head in self.subnarrative_heads.items():
            sub_probs_dict[narr_idx] = head(shared_out)

        return narr_probs, sub_probs_dict

In [22]:
model_multi_head = MultiTaskClassifierMultiHead(
    input_size=input_size,
    hidden_size=512,
)

In [23]:
y_train_nar = dataset_train['narratives_encoded'].tolist()
y_val_nar = dataset_val['narratives_encoded'].tolist()

y_train_sub_nar = dataset_train['subnarratives_encoded'].tolist()
y_val_sub_nar = dataset_val['subnarratives_encoded'].tolist()

In [24]:
y_train_nar = torch.tensor(y_train_nar, dtype=torch.float32)
y_train_sub_nar = torch.tensor(y_train_sub_nar, dtype=torch.float32)

y_val_nar = torch.tensor(y_val_nar, dtype=torch.float32)
y_val_sub_nar = torch.tensor(y_val_sub_nar, dtype=torch.float32)

In [25]:
train_embeddings_tensor = torch.tensor(train_embeddings, dtype=torch.float32)
val_embeddings_tensor = torch.tensor(val_embeddings, dtype=torch.float32)

In [26]:
import torch
import torch.nn as nn

def compute_class_weights(y_train):
    total_samples = y_train.shape[0]
    class_weights = []
    for label in range(y_train.shape[1]):
        pos_count = y_train[:, label].sum().item()
        neg_count = total_samples - pos_count
        pos_weight = total_samples / (2 * pos_count) if pos_count > 0 else 0
        neg_weight = total_samples / (2 * neg_count) if neg_count > 0 else 0
        class_weights.append((pos_weight, neg_weight))
    return class_weights

class WeightedBCELoss(nn.Module):
    def __init__(self, class_weights):
        super().__init__()
        self.class_weights = class_weights

    def forward(self, probs, targets):
        bce_loss = 0
        epsilon = 1e-7
        for i, (pos_weight, neg_weight) in enumerate(self.class_weights):
            prob = probs[:, i]
            bce = -pos_weight * targets[:, i] * torch.log(prob + epsilon) - \
                  neg_weight * (1 - targets[:, i]) * torch.log(1 - prob + epsilon)
            bce_loss += bce.mean()
        return bce_loss / len(self.class_weights)

class_weights_sub_nar = compute_class_weights(y_val_sub_nar)
class_weights_nar = compute_class_weights(y_val_nar)
narrative_criterion = WeightedBCELoss(class_weights_nar)

In [27]:
# For each subnarrative head, add a weighted version of BCE based on the indices
sub_criterion_dict = {}

for narr_idx, sub_indices in narrative_to_sub_map.items():
    local_weights = [ class_weights_sub_nar[sub_i] for sub_i in sub_indices ]

    sub_criterion = WeightedBCELoss(local_weights)
    sub_criterion_dict[str(narr_idx)] = sub_criterion

In [28]:
def multi_head_loss(narr_probs, sub_probs_dict, y_narr, y_sub_heads):
    narr_loss = narrative_criterion(narr_probs, y_narr)

    sub_loss = 0.0
    count_active = 0
    i = 0
    for narr_idx_str, sub_probs in sub_probs_dict.items():
        narr_idx = int(narr_idx_str)
        # Find the true subnarratives for the batch
        y_sub = [row[narr_idx] for row in y_sub_heads]
        y_sub_tensor = torch.tensor(y_sub, dtype=torch.float32)

        sub_loss_func = sub_criterion_dict[narr_idx_str]
        ce_loss = sub_loss_func(sub_probs, y_sub_tensor)

        sub_loss += ce_loss
        count_active += 1
        i += 1

    if count_active > 0:
        sub_loss = sub_loss / count_active
    else:
        sub_loss = 0.0

    total_loss = narr_loss + sub_loss

    return total_loss

In [29]:
def train_with_multihead(
    model,
    optimizer,
    narrative_criterion,
    train_embeddings=train_embeddings_tensor,
    y_train_nar=y_train_nar,
    y_train_sub_heads=y_train_sub_heads,
    val_embeddings=val_embeddings_tensor,
    y_val_nar=y_val_nar,
    y_val_sub_heads=y_val_sub_heads,
    patience=3,
    num_epochs=100,
):
    best_val_loss = float('inf')
    best_model = None
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        train_narr_probs, train_sub_probs_dict = model(train_embeddings)
        train_loss = multi_head_loss(train_narr_probs, train_sub_probs_dict, y_train_nar, y_train_sub_heads)

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            val_narr_probs, val_sub_probs_dict = model(val_embeddings)
            val_loss = multi_head_loss(val_narr_probs, val_sub_probs_dict, y_val_nar, y_val_sub_heads)

        print(f"Epoch {epoch+1}/{num_epochs}, "
              f"Training Loss: {train_loss.item():.4f} "
              f"Validation Loss: {val_loss.item():.4f} ")

        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            patience_counter = 0
            best_model = model.state_dict()
        else:
            patience_counter += 1
            print(f"Validation loss did not improve for {patience_counter} epoch(s).")

        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

    if best_model:
        model.load_state_dict(best_model)

    return model

In [30]:
optimizer_multi_head = torch.optim.AdamW(model_multi_head.parameters(), lr=0.001)

In [31]:
train_with_multihead(
    model=model_multi_head,
    optimizer=optimizer_multi_head,
    narrative_criterion=narrative_criterion,
)

Epoch 1/100, Training Loss: 1.4292 Validation Loss: 1.3683 
Epoch 2/100, Training Loss: 1.1706 Validation Loss: 1.3606 
Epoch 3/100, Training Loss: 1.0477 Validation Loss: 1.3530 
Epoch 4/100, Training Loss: 0.9628 Validation Loss: 1.3454 
Epoch 5/100, Training Loss: 0.8954 Validation Loss: 1.3376 
Epoch 6/100, Training Loss: 0.8417 Validation Loss: 1.3297 
Epoch 7/100, Training Loss: 0.7930 Validation Loss: 1.3215 
Epoch 8/100, Training Loss: 0.7576 Validation Loss: 1.3129 
Epoch 9/100, Training Loss: 0.7215 Validation Loss: 1.3038 
Epoch 10/100, Training Loss: 0.6884 Validation Loss: 1.2941 
Epoch 11/100, Training Loss: 0.6615 Validation Loss: 1.2841 
Epoch 12/100, Training Loss: 0.6369 Validation Loss: 1.2740 
Epoch 13/100, Training Loss: 0.6142 Validation Loss: 1.2637 
Epoch 14/100, Training Loss: 0.5916 Validation Loss: 1.2533 
Epoch 15/100, Training Loss: 0.5720 Validation Loss: 1.2430 
Epoch 16/100, Training Loss: 0.5552 Validation Loss: 1.2325 
Epoch 17/100, Training Loss: 0.53

MultiTaskClassifierMultiHead(
  (shared_layer): Sequential(
    (0): Linear(in_features=896, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.3, inplace=False)
  )
  (narrative_head): Sequential(
    (0): Linear(in_features=1024, out_features=21, bias=True)
    (1): Sigmoid()
  )
  (subnarrative_heads): ModuleDict(
    (8): Sequential(
      (0): Linear(in_features=1024, out_features=9, bias=True)
      (1): Sigmoid()
    )
    (9): Sequential(
      (0): Linear(in_features=1024, out_features=7, bias=True)
      (1): Sigmoid()
    )
    (17): Sequential(
      (0): Linear(in_features=1024, out_features=6, bias=True)
      (1): Sigmoid()
    )
    (19): Sequential(
      (0): Linear(in_features=1024, out_features=4, bias=True)
      (1): Sigmoid()
    )
    (10): Sequential(
      (0): Linear(in_features=1024, out_features=3, bias=True)
      (1): Sigmoid()
    )
    (1): Sequent

In [32]:
import numpy as np
import torch
from sklearn.metrics import classification_report, f1_score

def evaluate_multihead_model(
    model,
    embeddings,
    y_nar_true,
    y_sub_hierarchical,
    num_subnarratives = len(mlb_subnarratives.classes_),
    thresholds = np.arange(0.1, 1.0, 0.1),
    target_names_nar=mlb_narratives.classes_,
    target_names_sub=mlb_subnarratives.classes_,
    device='cpu',
):

    def build_global_sub_array(
        y_sub_hierarchical,
        num_subnarratives=74,
        narrative_to_sub_map=narrative_to_sub_map,
        narrative_order=narrative_order,
    ):
        """Reconstructs the subnarratives to flatten them (again) to a single array for evaluation"""
        num_samples = len(y_sub_hierarchical)
        sub_global_array = np.zeros((num_samples, num_subnarratives), dtype=int)

        for i in range(num_samples):
            for j, narr_idx in enumerate(narrative_order):
                sub_label_vec = y_sub_hierarchical[i][j]
                narr_idx = int(narr_idx)
                sub_indices = narrative_to_sub_map[narr_idx]
                for local_sub_i, global_sub_i in enumerate(sub_indices):
                    sub_global_array[i, global_sub_i] = sub_label_vec[local_sub_i]

        return sub_global_array

    embeddings = embeddings.to(device)
    y_nar_true_np = y_nar_true.cpu().numpy()

    best_threshold = 0
    best_f1 = -1
    best_report_nar = None
    best_report_sub = None
    samples = len(embeddings)

    with torch.no_grad():
        # get the predictions for both
        narr_probs, sub_probs_dict = model(embeddings)

        narr_probs = narr_probs.cpu().numpy()
        for k in sub_probs_dict:
            sub_probs_dict[k] = sub_probs_dict[k].cpu().numpy()

    for threshold in thresholds:
        narr_preds = (narr_probs >= threshold).astype(int)

        # Need to reconstruct the subnarratives to flatten them (again) to a single array for evaluation
        sub_preds_global = np.zeros((samples, num_subnarratives), dtype=int)

        for narr_idx, sub_indices in narrative_to_sub_map.items():
            # Get the predictions for this narrative hierarchy
            sub_probs_for_narr = sub_probs_dict[str(narr_idx)]
            # If top-level narrative is 1, then threshold subnarratives; otherwise 0.
            # Finds for each sample, go to the narr_idx position (the hierarchy we are at)
            predicted_narr_mask = narr_preds[:, narr_idx] == 1  # shape (num_samples,)

            # For all samples, threshold sub_probs_for_narr:
            sub_preds_for_narr = (sub_probs_for_narr >= threshold).astype(int)

            # But only keep sub_preds_for_narr if predicted_narr_mask is True:
            # If predicted_narr_mask is False for a sample, subnarratives go to 0.
            for sample_idx in range(samples):
                if predicted_narr_mask[sample_idx] == 1:
                    # Construct the flattened pred array
                    for local_sub_i, global_sub_i in enumerate(sub_indices):
                        sub_preds_global[sample_idx, global_sub_i] = sub_preds_for_narr[sample_idx, local_sub_i]
                else:
                    continue

        f1_nar = f1_score(y_nar_true_np, narr_preds, average="macro", zero_division=0)

        # Also flatten the true y_sub to a single array in the same way as we did with the predictions
        y_sub_true_np = build_global_sub_array(y_sub_hierarchical, num_subnarratives=num_subnarratives)

        f1_sub = f1_score(y_sub_true_np, sub_preds_global, average="macro", zero_division=0)

        avg_f1 = (f1_nar + f1_sub) / 2.0

        if avg_f1 > best_f1:
            best_f1 = avg_f1
            best_threshold = threshold

            report_nar = classification_report(
                y_nar_true_np,
                narr_preds,
                target_names=target_names_nar,
                zero_division=0
            )
            report_sub = classification_report(
                y_sub_true_np,
                sub_preds_global,
                target_names=target_names_sub,
                zero_division=0
            )
            best_report_nar = report_nar
            best_report_sub = report_sub

    print(f"Best threshold = {best_threshold:.2f}, best (avg) F1 = {best_f1:.4f}")
    print("Best Narratives classification report:")
    print(best_report_nar)
    print("Best Subnarratives classification report:")
    print(best_report_sub)

In [33]:
evaluate_multihead_model(
    model=model_multi_head,
    embeddings=val_embeddings_tensor,
    y_nar_true=y_val_nar,
    y_sub_hierarchical=y_val_sub_heads,
)

Best threshold = 0.50, best (avg) F1 = 0.3827
Best Narratives classification report:
                                                   precision    recall  f1-score   support

                         Amplifying Climate Fears       0.81      1.00      0.90        48
                     Amplifying war-related fears       0.69      0.69      0.69        49
Blaming the war on others rather than the invader       0.27      0.38      0.32        39
                     Climate change is beneficial       0.00      0.00      0.00         2
             Controversy about green technologies       0.40      1.00      0.57         4
                    Criticism of climate movement       0.26      0.70      0.38        10
                    Criticism of climate policies       0.50      0.58      0.54        19
        Criticism of institutions and authorities       0.60      0.77      0.68        35
                             Discrediting Ukraine       0.70      0.79      0.74        85
    