# Semeval 2025 Task 10
### Subtask 2: Narrative Classification

Given a news article and a [two-level taxonomy of narrative labels](https://propaganda.math.unipd.it/semeval2025task10/NARRATIVE-TAXONOMIES.pdf) (where each narrative is subdivided into subnarratives) from a particular domain, assign to the article all the appropriate subnarrative labels. This is a multi-label multi-class document classification task.

## 1. Multi-head per narrative model

### 1.1 Loading pre-saved variables

We start by loading our pre-saved variables:

In [1]:
import pickle
import os
import pandas as pd

root_dir = "../../"
base_save_folder_dir = '../saved/'
dataset_folder = os.path.join(base_save_folder_dir, 'Dataset')

with open(os.path.join(dataset_folder, 'dataset_cleaned.pkl'), 'rb') as f:
    dataset_train = pickle.load(f)

In [2]:
dataset_train.head()

Unnamed: 0,language,article_id,content,narratives,subnarratives,narratives_encoded,subnarratives_encoded
0,RU,RU-URW-1161.txt,<PARA>в ближайшие два месяца сша будут стремит...,[URW: Blaming the war on others rather than th...,"[The West are the aggressors, Other, The West ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,RU,RU-URW-1175.txt,<PARA>в ес испугались последствий популярности...,"[URW: Discrediting the West, Diplomacy, URW: D...","[The West is weak, Other, The EU is divided]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,RU,RU-URW-1149.txt,<PARA>возможность признания аллы пугачевой ино...,[URW: Distrust towards Media],[Western media is an instrument of propaganda],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,RU,RU-URW-1015.txt,<PARA>азаров рассказал о смене риторики киева ...,"[URW: Discrediting Ukraine, URW: Discrediting ...","[Ukraine is a puppet of the West, Discrediting...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,RU,RU-URW-1001.txt,<PARA>в россиянах проснулась массовая любовь к...,[URW: Praise of Russia],[Russia is a guarantor of peace and prosperity],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [3]:
misc_folder = os.path.join(base_save_folder_dir, 'Misc')

with open(os.path.join(misc_folder, 'narrative_to_subnarratives.pkl'), 'rb') as f:
    narrative_to_subnarratives = pickle.load(f)

In [4]:
dataset_train.shape

(1699, 7)

We'll also need the actual hierarchy of narratives to subnarratives for our new model.  

* Each narrative is also mapped to `Other`—this happens because if no subnarrative matches, we assign it to `Other`.

In [5]:
narrative_to_subnarratives

{'URW: Discrediting Ukraine': ['Ukraine is a puppet of the West',
  'Ukraine is associated with nazism',
  'Discrediting Ukrainian government and officials and policies',
  'Rewriting Ukraine’s history',
  'Discrediting Ukrainian military',
  'Other',
  'Ukraine is a hub for criminal activities',
  'Situation in Ukraine is hopeless',
  'Discrediting Ukrainian nation and society'],
 'URW: Discrediting the West, Diplomacy': ['The EU is divided',
  'The West is overreacting',
  'The West is weak',
  'Diplomacy does/will not work',
  'Other',
  'The West does not care about Ukraine, only about its interests',
  'West is tired of Ukraine'],
 'URW: Praise of Russia': ['Praise of Russian President Vladimir Putin',
  'Russian invasion has strong national support',
  'Praise of Russian military might',
  'Russia is a guarantor of peace and prosperity',
  'Other',
  'Russia has international support from a number of countries and people'],
 'URW: Russia is the Victim': ['Other',
  'UA is anti-RU

In [6]:
label_encoder_folder = os.path.join(base_save_folder_dir, 'LabelEncoders')

with open(os.path.join(label_encoder_folder, 'mlb_narratives.pkl'), 'rb') as f:
    mlb_narratives = pickle.load(f)

with open(os.path.join(label_encoder_folder, 'mlb_subnarratives.pkl'), 'rb') as f:
    mlb_subnarratives = pickle.load(f)

Finally, we get our embeddings:

In [7]:
import numpy as np

embeddings_folder = os.path.join(base_save_folder_dir, 'Embeddings/all_embeddings.npy')

def load_embeddings(filename):
    return np.load(filename)

train_embeddings = load_embeddings(embeddings_folder)

In [8]:
dataset_train.head()

Unnamed: 0,language,article_id,content,narratives,subnarratives,narratives_encoded,subnarratives_encoded
0,RU,RU-URW-1161.txt,<PARA>в ближайшие два месяца сша будут стремит...,[URW: Blaming the war on others rather than th...,"[The West are the aggressors, Other, The West ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,RU,RU-URW-1175.txt,<PARA>в ес испугались последствий популярности...,"[URW: Discrediting the West, Diplomacy, URW: D...","[The West is weak, Other, The EU is divided]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,RU,RU-URW-1149.txt,<PARA>возможность признания аллы пугачевой ино...,[URW: Distrust towards Media],[Western media is an instrument of propaganda],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,RU,RU-URW-1015.txt,<PARA>азаров рассказал о смене риторики киева ...,"[URW: Discrediting Ukraine, URW: Discrediting ...","[Ukraine is a puppet of the West, Discrediting...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,RU,RU-URW-1001.txt,<PARA>в россиянах проснулась массовая любовь к...,[URW: Praise of Russia],[Russia is a guarantor of peace and prosperity],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


We also need to make sure that the embeddings stay aligned with the dataset split.  
* We pass `all_embeddings` so that `dataset_train` and `train_embeddings` match up exactly, keeping everything consistent.

We use stratified splitting here to ensure the label distribution stays the same in both the training and validation sets.  
* This somewhat maintain the class proportions, even for the rare cases, making sure both sets are roughly representative of the original dataset.

You're training on the full multilingual dataset and using the English-only validation set provided by the challenge. This could feel wrong because typically we want our validation set to have the same distribution as our training se

The challenge's ultimate goal is English performance
The validation set they provided is specifically for development/tuning
Using this English validation set helps you track what really matters - how well your model performs on English content

In [9]:
train_embeddings.shape

(1699, 896)

In [10]:
dataset_train.shape

(1699, 7)

In [11]:
with open(os.path.join(dataset_folder, 'dataset_val_cleaned.pkl'), 'rb') as f:
    dataset_val = pickle.load(f)

In [12]:
dataset_val.shape

(41, 7)

In [13]:
dataset_val.head()

Unnamed: 0,language,article_id,content,narratives,subnarratives,narratives_encoded,subnarratives_encoded
0,EN,EN_UA_DEV_100029.txt,<PARA>general Milley: russian military stocks ...,[URW: Speculating war outcomes],[Russian army is collapsing],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,EN,EN_UA_DEV_100002.txt,"<PARA>Ukrainian nationalism, ukrainian patriot...","[URW: Discrediting Ukraine, URW: Speculating w...","[Discrediting Ukrainian nation and society, Uk...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,EN,EN_UA_DEV_100003.txt,"<PARA>medvedev: Russia seeks more in Ukraine, ...",[Other],[Other],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,EN,EN_UA_DEV_100013.txt,<PARA>former commander-in-chief of Ukrainian A...,"[URW: Discrediting Ukraine, URW: Blaming the w...",[Discrediting Ukrainian government and officia...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,EN,EN_UA_DEV_100012.txt,<PARA>Ukraine's minerals: what the west is fig...,"[URW: Discrediting the West, Diplomacy, URW: D...","[The West does not care about Ukraine, only ab...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [14]:
embeddings_folder = os.path.join(base_save_folder_dir, 'Embeddings/dev_embeddings_english.npy')

val_embeddings = load_embeddings(embeddings_folder)

In [15]:
val_embeddings.shape

(41, 896)

In [16]:
def custom_shuffling(data, embeddings):
    shuffled_indices = np.arange(len(data))
    np.random.shuffle(shuffled_indices)
    data = data.iloc[shuffled_indices].reset_index(drop=True)
    embeddings = embeddings[shuffled_indices]

    return data, embeddings

In [17]:
dataset_train, train_embeddings = custom_shuffling(dataset_train, train_embeddings)

In [18]:
dataset_val, val_embeddings = custom_shuffling(dataset_val, val_embeddings)

### 1.2 Remapping our subnarrative indices

We know that our articls have many narratives, and each one maps to several subnarratives, creating a hierarchy.  
The problem is, our `subnarratives_encoded` currently looks like a flat list of zeros:

```
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
```

But we need it to reflect the hierarchy properly:

So, we break it down into a list of lists—each inner list represents the true labels for a specific hierarchy:

```
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0] , [0, 0, 0, ...
 ^ hierarchy 0       ^ hierarchy 1          ^ hierarchy 2 ...
```

This will help us significantly later when we need to know for a specific article, the true subnarrative labels for a specific hierarchy.

We’re using the label encoders to get the indices of narratives and subnarratives, which we’ll use later.  
* For each narrative in `narrative_to_subnarratives`, we find the index of the narrative and its corresponding subnarratives using the encoders.

In [19]:
misc_folder = os.path.join(base_save_folder_dir, 'Misc')

In [20]:
narrative_to_sub_map = {}
narrative_classes = list(mlb_narratives.classes_)
subnarrative_classes = list(mlb_subnarratives.classes_)

for narrative, subnarratives in narrative_to_subnarratives.items():
    narrative_idx = narrative_classes.index(narrative)
    subnarrative_indices = [subnarrative_classes.index(sub) for sub in subnarratives]
    narrative_to_sub_map[narrative_idx] = subnarrative_indices

print(narrative_to_sub_map)

{13: [65, 66, 20, 39, 21, 33, 64, 50, 22], 14: [53, 58, 60, 19, 33, 56, 71], 19: [34, 46, 35, 42, 33, 41], 20: [33, 63, 40, 59], 15: [33, 69, 72], 11: [31, 33, 43, 62, 3], 18: [33, 32, 55, 57], 12: [54, 33, 67], 21: [33, 45, 68, 44], 16: [33], 17: [33, 47, 61], 0: [23, 1, 73, 33, 24], 10: [33], 5: [17, 15, 16, 33, 14], 3: [33, 9, 0, 8], 6: [4, 49, 28, 51, 27, 33, 70, 7, 29], 4: [33, 11, 10, 12], 9: [18, 33, 30, 26, 48], 8: [33, 6, 2], 1: [33, 52, 5], 2: [33, 38, 36, 37], 7: [33, 25, 13]}


Now, we remap the `subnarratives_encoded` list to reflect the correct hierarchy for each article.  
* For each narrative, we grab its corresponding subnarrative indices from `narrative_to_sub_map` and assign the sublabels to the appropriate hierarchy column.  

This will give us a new set of columns where each one contains the true subnarrative labels for that narrative hierarchy.

In [21]:
hierarchy_new_column_name = "narrative_hierarchy"

In [22]:
def remap_subnarratives(row, narrative_to_sub_map):
    """Takes in a row and encodes the current subnarrative list to the associated hierarchy based on the narr-subnar map"""
    for narr_idx, sub_indices in narrative_to_sub_map.items():
        sub_labels = [row['subnarratives_encoded'][sub_idx] for sub_idx in sub_indices]
        col_name = f"{hierarchy_new_column_name}_{narr_idx}"
        row[col_name] = sub_labels
    return row

dataset_train_cpy = dataset_train.apply(remap_subnarratives, axis=1, args=(narrative_to_sub_map,)).copy()

We do the same for validation dataset:

In [23]:
dataset_val_cpy = dataset_val.apply(remap_subnarratives, axis=1, args=(narrative_to_sub_map,)).copy()

In [24]:
dataset_val_cpy.head()

Unnamed: 0,language,article_id,content,narratives,subnarratives,narratives_encoded,subnarratives_encoded,narrative_hierarchy_13,narrative_hierarchy_14,narrative_hierarchy_19,...,narrative_hierarchy_10,narrative_hierarchy_5,narrative_hierarchy_3,narrative_hierarchy_6,narrative_hierarchy_4,narrative_hierarchy_9,narrative_hierarchy_8,narrative_hierarchy_1,narrative_hierarchy_2,narrative_hierarchy_7
0,EN,EN_CC_200040.txt,<PARA>climate protesters out of control as the...,"[CC: Criticism of climate movement, CC: Critic...","[Other, Climate movement is alarmist, Criticis...","[0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 1, 0, 0, 0]","[0, 0, 0, 0, 1, 0, 0]","[0, 0, 0, 0, 1, 0]",...,[1],"[0, 1, 0, 1, 0]","[1, 0, 0, 1]","[0, 0, 0, 0, 1, 1, 0, 0, 0]","[1, 0, 0, 0]","[0, 1, 0, 0, 0]","[1, 0, 0]","[1, 0, 0]","[1, 0, 0, 0]","[1, 0, 0]"
1,EN,EN_CC_200054.txt,<PARA>EU Central Banker Pushes Bitcoin Ban Und...,[CC: Hidden plots by secret schemes of powerfu...,[Climate agenda has hidden motives],"[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0, 0, 0]",...,[0],"[0, 0, 0, 0, 0]","[0, 0, 0, 0]","[0, 0, 0, 0, 0, 0, 0, 0, 0]","[0, 0, 0, 0]","[0, 0, 0, 0, 0]","[0, 1, 0]","[0, 0, 0]","[0, 0, 0, 0]","[0, 0, 0]"
2,EN,EN_UA_DEV_213.txt,<PARA>US and NATO escalation of conflict with ...,[URW: Blaming the war on others rather than th...,"[The West are the aggressors, Western media is...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 0, 0, 0, 0, 1, 0, 0, 0]","[0, 0, 0, 1, 1, 0, 0]","[0, 0, 0, 0, 1, 0]",...,[1],"[0, 0, 0, 1, 0]","[1, 0, 0, 0]","[0, 0, 0, 0, 0, 1, 0, 0, 0]","[1, 0, 0, 0]","[0, 1, 0, 0, 0]","[1, 0, 0]","[1, 0, 0]","[1, 0, 0, 0]","[1, 0, 0]"
3,EN,EN_UA_DEV_100036.txt,<PARA>opinion: the unseen scars of Ukraine's m...,[Other],[Other],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 1, 0, 0, 0]","[0, 0, 0, 0, 1, 0, 0]","[0, 0, 0, 0, 1, 0]",...,[1],"[0, 0, 0, 1, 0]","[1, 0, 0, 0]","[0, 0, 0, 0, 0, 1, 0, 0, 0]","[1, 0, 0, 0]","[0, 1, 0, 0, 0]","[1, 0, 0]","[1, 0, 0]","[1, 0, 0, 0]","[1, 0, 0]"
4,EN,EN_UA_DEV_100004.txt,<PARA>RUSSIANS to cut coke plant from avdeevka...,[Other],[Other],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 1, 0, 0, 0]","[0, 0, 0, 0, 1, 0, 0]","[0, 0, 0, 0, 1, 0]",...,[1],"[0, 0, 0, 1, 0]","[1, 0, 0, 0]","[0, 0, 0, 0, 0, 1, 0, 0, 0]","[1, 0, 0, 0]","[0, 1, 0, 0, 0]","[1, 0, 0]","[1, 0, 0]","[1, 0, 0, 0]","[1, 0, 0]"


A sample result looks like this:

In [25]:
for narr_idx, sub_indices in narrative_to_sub_map.items():
    dataset_hierarchy_col_name = f"{hierarchy_new_column_name}_{narr_idx}"
    res = dataset_train_cpy[dataset_hierarchy_col_name]
    print(f"Sample of {dataset_hierarchy_col_name}:")
    print(res.head()) 
    print("\n")

Sample of narrative_hierarchy_13:
0    [0, 0, 0, 0, 0, 0, 0, 0, 0]
1    [0, 0, 0, 0, 0, 1, 0, 0, 0]
2    [0, 0, 0, 0, 1, 0, 0, 0, 0]
3    [0, 0, 1, 0, 0, 0, 0, 0, 0]
4    [0, 0, 0, 0, 0, 1, 0, 0, 0]
Name: narrative_hierarchy_13, dtype: object


Sample of narrative_hierarchy_14:
0    [0, 0, 0, 0, 0, 0, 0]
1    [0, 0, 0, 0, 1, 0, 0]
2    [0, 0, 0, 0, 0, 0, 0]
3    [0, 0, 0, 0, 0, 0, 0]
4    [0, 0, 0, 0, 1, 0, 0]
Name: narrative_hierarchy_14, dtype: object


Sample of narrative_hierarchy_19:
0    [0, 0, 0, 0, 0, 0]
1    [0, 0, 0, 0, 1, 0]
2    [0, 0, 0, 0, 0, 0]
3    [0, 0, 0, 0, 0, 0]
4    [0, 0, 0, 0, 1, 0]
Name: narrative_hierarchy_19, dtype: object


Sample of narrative_hierarchy_20:
0    [0, 0, 0, 0]
1    [1, 0, 0, 0]
2    [0, 0, 0, 0]
3    [0, 0, 0, 0]
4    [1, 0, 0, 0]
Name: narrative_hierarchy_20, dtype: object


Sample of narrative_hierarchy_15:
0    [0, 0, 1]
1    [1, 0, 0]
2    [0, 0, 0]
3    [0, 0, 0]
4    [1, 0, 0]
Name: narrative_hierarchy_15, dtype: object


Sample of narra

In [26]:
narrative_order = sorted(narrative_to_sub_map.keys())
narrative_order

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]

Now we want to make sure that the true subnarratives for hierarchy 0 are in position 0 of the aggregated list, hierarchy 1 in position 1, and so on.  
This ensures the subnarratives are ordered correctly in the final, aggregated list.

In [27]:
def aggregate_subnarratives(row, narrative_order, narrative_to_sub_map):
    """Takes in a row, and aggregates all hierarchy columns to 1 list.
    The encoded list will be a list of lists, starting from the first hierarchy"""
    aggregated = []
    for narr_idx in narrative_order:
        column_name = f"narrative_hierarchy_{narr_idx}"
        sub_labels = row[column_name]
        aggregated.append(sub_labels)
    return aggregated

dataset_train['aggregated_subnarratives'] = dataset_train_cpy.apply(
    aggregate_subnarratives,
    axis=1,
    args=(narrative_order, narrative_to_sub_map)
)

dataset_val['aggregated_subnarratives'] = dataset_val_cpy.apply(
    aggregate_subnarratives,
    axis=1,
    args=(narrative_order, narrative_to_sub_map)
)

In [28]:
dataset_train['aggregated_subnarratives']

0       [[0, 0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0], [0,...
1       [[0, 0, 0, 1, 0], [1, 0, 0], [1, 0, 0, 0], [1,...
2       [[0, 0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0], [0,...
3       [[0, 0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0], [0,...
4       [[0, 0, 0, 1, 0], [1, 0, 0], [1, 0, 0, 0], [1,...
                              ...                        
1694    [[0, 0, 0, 0, 0], [0, 0, 0], [0, 1, 1, 1], [0,...
1695    [[0, 1, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0], [0,...
1696    [[0, 0, 0, 1, 0], [1, 0, 0], [1, 0, 0, 0], [1,...
1697    [[0, 0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0], [0,...
1698    [[0, 0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0], [0,...
Name: aggregated_subnarratives, Length: 1699, dtype: object

In [29]:
y_train_sub_heads = dataset_train['aggregated_subnarratives'].to_numpy()
y_val_sub_heads = dataset_val['aggregated_subnarratives'].to_numpy()

In [30]:
import torch

train_embeddings_tensor = torch.tensor(train_embeddings, dtype=torch.float32)
val_embeddings_tensor = torch.tensor(val_embeddings, dtype=torch.float32)

In [31]:
input_size = train_embeddings_tensor.shape[1]
print(input_size)

896


Now we have a model with a shared layer that captures the general features of the article.  
* The BatchNorm + ReLU combo seems to significantly improve performance by stabilizing training and speeding up convergence.
* Also, it seems like the model overfits very quickly when becoming overly complex.
  
We make predictions for the top-level narratives, followed by separate subnarrative predictions for each narrative.  

In [32]:
dataset_train['language'].unique()

array(['EN', 'HI', 'BG', 'PT', 'RU'], dtype=object)

In [33]:
dataset_train.head()

Unnamed: 0,language,article_id,content,narratives,subnarratives,narratives_encoded,subnarratives_encoded,aggregated_subnarratives
0,EN,EN_UA_300049.txt,<PARA>ukraine war: massacre 28 killed at marke...,[URW: Blaming the war on others rather than th...,"[Ukraine is the aggressor, Western media is an...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[[0, 0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0], [0,..."
1,HI,HI_322.txt,<PARA>डेमोक्रेट पार्टी का रुख़</PARA>\n\n<PARA...,"[URW: Overpraising the West, URW: Overpraising...",[The West belongs in the right side of history...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[[0, 0, 0, 1, 0], [1, 0, 0], [1, 0, 0, 0], [1,..."
2,BG,A9_BG_5143.txt,<PARA>la vanguardia: украйна се отказа от наде...,[URW: Discrediting Ukraine],[Discrediting Ukrainian military],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[[0, 0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0], [0,..."
3,BG,A9_BG_4971.txt,<PARA>байдън със страховита прогноза: тези дър...,"[URW: Discrediting Ukraine, URW: Amplifying wa...",[Discrediting Ukrainian government and officia...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[[0, 0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0], [0,..."
4,PT,PT_277.txt,<PARA>as infeções mortais podem voltar se o pe...,[CC: Amplifying Climate Fears],[Other],"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[[0, 0, 0, 1, 0], [1, 0, 0], [1, 0, 0, 0], [1,..."


In [34]:
import torch.nn as nn
import torch.nn.functional as F

class MultiTaskClassifierMultiHead(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size=1024,
        num_narratives=len(mlb_narratives.classes_),
        narrative_to_sub_map=narrative_to_sub_map,
        dropout_rate=0.4
    ):
        super().__init__()
        self.shared_layer = nn.Sequential(
            nn.Linear(input_size, hidden_size * 2),
            nn.BatchNorm1d(hidden_size * 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        self.narrative_head = nn.Sequential(
            nn.Linear(hidden_size * 2, num_narratives),
            nn.Sigmoid()
        )

        self.subnarrative_heads = nn.ModuleDict()
        for narr_idx, sub_indices in narrative_to_sub_map.items():
            num_subs_for_this_narr = len(sub_indices)
            self.subnarrative_heads[str(narr_idx)] = nn.Sequential(
                nn.Linear(hidden_size * 2, num_subs_for_this_narr),
                nn.Sigmoid()
            )

    def forward(self, x):
        shared_out = self.shared_layer(x)
        narr_probs = self.narrative_head(shared_out)

        sub_probs_dict = {}
        for narr_idx, head in self.subnarrative_heads.items():
            sub_probs_dict[narr_idx] = head(shared_out)

        return narr_probs, sub_probs_dict

In [35]:
network_params = {
    'lr': 0.001,
    'hidden_size': 1024,
    'dropout': 0.4
}

In [36]:
model_multi_head = MultiTaskClassifierMultiHead(
    input_size=input_size,
    hidden_size=network_params['hidden_size'],
)

In [37]:
print(model_multi_head)

MultiTaskClassifierMultiHead(
  (shared_layer): Sequential(
    (0): Linear(in_features=896, out_features=2048, bias=True)
    (1): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.4, inplace=False)
  )
  (narrative_head): Sequential(
    (0): Linear(in_features=2048, out_features=22, bias=True)
    (1): Sigmoid()
  )
  (subnarrative_heads): ModuleDict(
    (13): Sequential(
      (0): Linear(in_features=2048, out_features=9, bias=True)
      (1): Sigmoid()
    )
    (14): Sequential(
      (0): Linear(in_features=2048, out_features=7, bias=True)
      (1): Sigmoid()
    )
    (19): Sequential(
      (0): Linear(in_features=2048, out_features=6, bias=True)
      (1): Sigmoid()
    )
    (20): Sequential(
      (0): Linear(in_features=2048, out_features=4, bias=True)
      (1): Sigmoid()
    )
    (15): Sequential(
      (0): Linear(in_features=2048, out_features=3, bias=True)
      (1): Sigmoid()
    )
    (11): Sequ

In [38]:
y_train_nar = dataset_train['narratives_encoded'].tolist()
y_val_nar = dataset_val['narratives_encoded'].tolist()

y_train_sub_nar = dataset_train['subnarratives_encoded'].tolist()
y_val_sub_nar = dataset_val['subnarratives_encoded'].tolist()

In [39]:
dataset_val.head()

Unnamed: 0,language,article_id,content,narratives,subnarratives,narratives_encoded,subnarratives_encoded,aggregated_subnarratives
0,EN,EN_CC_200040.txt,<PARA>climate protesters out of control as the...,"[CC: Criticism of climate movement, CC: Critic...","[Other, Climate movement is alarmist, Criticis...","[0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...","[[0, 0, 0, 1, 0], [1, 0, 0], [1, 0, 0, 0], [1,..."
1,EN,EN_CC_200054.txt,<PARA>EU Central Banker Pushes Bitcoin Ban Und...,[CC: Hidden plots by secret schemes of powerfu...,[Climate agenda has hidden motives],"[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...","[[0, 0, 0, 0, 0], [0, 0, 0], [0, 0, 0, 0], [0,..."
2,EN,EN_UA_DEV_213.txt,<PARA>US and NATO escalation of conflict with ...,[URW: Blaming the war on others rather than th...,"[The West are the aggressors, Western media is...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[[0, 0, 0, 1, 0], [1, 0, 0], [1, 0, 0, 0], [1,..."
3,EN,EN_UA_DEV_100036.txt,<PARA>opinion: the unseen scars of Ukraine's m...,[Other],[Other],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[[0, 0, 0, 1, 0], [1, 0, 0], [1, 0, 0, 0], [1,..."
4,EN,EN_UA_DEV_100004.txt,<PARA>RUSSIANS to cut coke plant from avdeevka...,[Other],[Other],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[[0, 0, 0, 1, 0], [1, 0, 0], [1, 0, 0, 0], [1,..."


We move everything to a tensor:

In [40]:
y_train_nar = torch.tensor(y_train_nar, dtype=torch.float32)
y_train_sub_nar = torch.tensor(y_train_sub_nar, dtype=torch.float32)

y_val_nar = torch.tensor(y_val_nar, dtype=torch.float32)
y_val_sub_nar = torch.tensor(y_val_sub_nar, dtype=torch.float32)

In [41]:
train_embeddings_tensor = torch.tensor(train_embeddings, dtype=torch.float32)
val_embeddings_tensor = torch.tensor(val_embeddings, dtype=torch.float32)

We calculate class weights to handle label imbalance in the training data. 
* This way, rare labels are given higher importance to ensure the model learns them effectively.
* The custom ```WeightedBCELoss``` applies these weights during training to balance the impact of common and rare labels, preventing the model from focusing only on frequent ones.

In [42]:

def compute_class_weights(y_train):
    total_samples = y_train.shape[0]
    class_weights = []
    for label in range(y_train.shape[1]):
        pos_count = y_train[:, label].sum().item()
        neg_count = total_samples - pos_count
        pos_weight = total_samples / (2 * pos_count) if pos_count > 0 else 0
        neg_weight = total_samples / (2 * neg_count) if neg_count > 0 else 0
        class_weights.append((pos_weight, neg_weight))
    return class_weights

class WeightedBCELoss(nn.Module):
    def __init__(self, class_weights):
        super().__init__()
        self.class_weights = class_weights

    def forward(self, probs, targets):
        bce_loss = 0
        epsilon = 1e-7
        for i, (pos_weight, neg_weight) in enumerate(self.class_weights):
            prob = probs[:, i]
            bce = -pos_weight * targets[:, i] * torch.log(prob + epsilon) - \
                  neg_weight * (1 - targets[:, i]) * torch.log(1 - prob + epsilon)
            bce_loss += bce.mean()
        return bce_loss / len(self.class_weights)

class_weights_sub_nar = compute_class_weights(y_val_sub_nar)
class_weights_nar = compute_class_weights(y_val_nar)
narrative_criterion = WeightedBCELoss(class_weights_nar)

We create a separate loss function for each hierarchy of subnarratives to handle their specific class imbalance.  

In [43]:
sub_criterion_dict = {}

for narr_idx, sub_indices in narrative_to_sub_map.items():
    local_weights = [ class_weights_sub_nar[sub_i] for sub_i in sub_indices ]

    sub_criterion = WeightedBCELoss(local_weights)
    sub_criterion_dict[str(narr_idx)] = sub_criterion

We define a loss class to handle the multi-task loss calculation.

* In the forward method, we first calculate the loss for the top-level narratives using the narrative criterion.
We then loop through each subnarrative head and compute the loss for each one similarly.


* We introduce a conditioning term that penalizes inconsistencies between narrative and subnarrative predictions. The conditioning term is there to match our hierarchical problem.

In [44]:
class MultiHeadLoss(nn.Module):
    def __init__(self, narrative_criterion, sub_criterion_dict, 
                 condition_weight=0.8, sub_weight=0.5):
        super().__init__()
        self.narrative_criterion = narrative_criterion
        self.sub_criterion_dict = sub_criterion_dict
        self.condition_weight = condition_weight
        self.sub_weight = sub_weight
        
    def forward(self, narr_probs, sub_probs_dict, y_narr, y_sub_heads):
        narr_loss = self.narrative_criterion(narr_probs, y_narr)
        sub_loss = 0.0
        condition_loss = 0.0
        
        for narr_idx_str, sub_probs in sub_probs_dict.items():
            narr_idx = int(narr_idx_str)
            y_sub = [row[narr_idx] for row in y_sub_heads]
            y_sub_tensor = torch.tensor(y_sub, dtype=torch.float32, device=sub_probs.device)
            
            sub_loss_func = self.sub_criterion_dict[narr_idx_str]
            sub_loss += sub_loss_func(sub_probs, y_sub_tensor)

            narr_pred = narr_probs[:, narr_idx].unsqueeze(1)
            condition_term = torch.mean(
                # Penalize high probs of sub, based on first level narr predictinos
                torch.abs(sub_probs * (1 - narr_pred)) + 
                # If a narrative is true, then the subnarrative predictions should match their actual true values.
                narr_pred * torch.abs(sub_probs - y_sub_tensor.unsqueeze(1))
            )
            condition_loss += condition_term
            
        sub_loss = sub_loss / len(sub_probs_dict)
        condition_loss = condition_loss / len(sub_probs_dict)
        
        total_loss = (1 - self.sub_weight) * narr_loss + \
                    self.sub_weight * sub_loss + \
                    self.condition_weight * condition_loss
        
        return total_loss

In [45]:
multi_head_loss_fn = MultiHeadLoss(narrative_criterion, sub_criterion_dict)

In [46]:
def check_weight_norms(model):
    for name, param in model.named_parameters():
        if param.requires_grad:
            norm = param.norm(2).item()
            print(f"Layer: {name} | Weight Norm: {norm:.4f}")

We define the function for training our model:

In [47]:
def train_with_multihead(
    model,
    optimizer,
    loss_fn=multi_head_loss_fn,
    train_embeddings=train_embeddings_tensor,
    y_train_nar=y_train_nar,
    y_train_sub_heads=y_train_sub_heads,
    val_embeddings=val_embeddings_tensor,
    y_val_nar=y_val_nar,
    y_val_sub_heads=y_val_sub_heads,
    patience=5,
    num_epochs=100,
    scheduler=None,
    min_delta=0.001
):
    best_val_loss = float('inf')
    best_model = None
    patience_counter = 0
    for epoch in range(num_epochs):
        model.train()
        train_narr_probs, train_sub_probs_dict = model(train_embeddings)
        train_loss = loss_fn(train_narr_probs, train_sub_probs_dict, y_train_nar, y_train_sub_heads)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Validation phase
        model.eval()
        with torch.no_grad():
            val_narr_probs, val_sub_probs_dict = model(val_embeddings)
            val_loss = loss_fn(val_narr_probs, val_sub_probs_dict, y_val_nar, y_val_sub_heads)

        print(f"Epoch {epoch+1}/{num_epochs}, "
              f"Training Loss: {train_loss.item():.4f}, "
              f"Validation Loss: {val_loss.item():.4f}")

        if scheduler:
            scheduler.step(val_loss)
            current_lr = scheduler.optimizer.param_groups[0]['lr']
            print(f"Current Learning Rate: {current_lr:.6f}")

        if val_loss.item() < best_val_loss - min_delta:
            best_val_loss = val_loss.item()
            patience_counter = 0
            best_model = model.state_dict().copy()
        else:
            patience_counter += 1
            print(f"Validation loss did not significantly improve for {patience_counter} epoch(s).")

        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

    if best_model:
        model.load_state_dict(best_model)
    return model

In [48]:
optimizer_multi_head = torch.optim.AdamW(model_multi_head.parameters(), lr=network_params['lr'])

We will also initialize a scheduler to adjust the learning rate dynamically during training based on how the model is performing.

In [49]:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_multi_head, mode='min', factor=0.5, patience=3)

In [50]:
trained_multi_head_simple = train_with_multihead(
    model=model_multi_head,
    optimizer=optimizer_multi_head,
    patience=10
)

Epoch 1/100, Training Loss: 1.0386, Validation Loss: 1.0168
Epoch 2/100, Training Loss: 0.7622, Validation Loss: 1.0059
Epoch 3/100, Training Loss: 0.6248, Validation Loss: 0.9972
Epoch 4/100, Training Loss: 0.5507, Validation Loss: 0.9896
Epoch 5/100, Training Loss: 0.5098, Validation Loss: 0.9820
Epoch 6/100, Training Loss: 0.4804, Validation Loss: 0.9735
Epoch 7/100, Training Loss: 0.4560, Validation Loss: 0.9638
Epoch 8/100, Training Loss: 0.4377, Validation Loss: 0.9530
Epoch 9/100, Training Loss: 0.4261, Validation Loss: 0.9414
Epoch 10/100, Training Loss: 0.4132, Validation Loss: 0.9292
Epoch 11/100, Training Loss: 0.4057, Validation Loss: 0.9167
Epoch 12/100, Training Loss: 0.3958, Validation Loss: 0.9041
Epoch 13/100, Training Loss: 0.3870, Validation Loss: 0.8916
Epoch 14/100, Training Loss: 0.3786, Validation Loss: 0.8793
Epoch 15/100, Training Loss: 0.3723, Validation Loss: 0.8671
Epoch 16/100, Training Loss: 0.3639, Validation Loss: 0.8549
Epoch 17/100, Training Loss: 0.35

In [51]:
coarse_classes = sorted(narrative_to_subnarratives.keys())

We find the coarse, or narrative classes, to be used in the evaluation function.

In [52]:
coarse_classes

['CC: Amplifying Climate Fears',
 'CC: Climate change is beneficial',
 'CC: Controversy about green technologies',
 'CC: Criticism of climate movement',
 'CC: Criticism of climate policies',
 'CC: Criticism of institutions and authorities',
 'CC: Downplaying climate change',
 'CC: Green policies are geopolitical instruments',
 'CC: Hidden plots by secret schemes of powerful groups',
 'CC: Questioning the measurements and science',
 'Other',
 'URW: Amplifying war-related fears',
 'URW: Blaming the war on others rather than the invader',
 'URW: Discrediting Ukraine',
 'URW: Discrediting the West, Diplomacy',
 'URW: Distrust towards Media',
 'URW: Hidden plots by secret schemes of powerful groups',
 'URW: Negative Consequences for the West',
 'URW: Overpraising the West',
 'URW: Praise of Russia',
 'URW: Russia is the Victim',
 'URW: Speculating war outcomes']

We do the same for the fine, this includes every single pair of a `narrative: subnarrative` from all hierarchies.
- In case `Other` is labeled for an article as a narrative, we leave the fine as just `Other` (that is what we should do)
- Also, for instances where the subnarrative is truly `Other` we match it with it's narrative parent.

In [53]:
fine_label_set = set()

for narrative, subnarratives in narrative_to_subnarratives.items():
    if narrative == "Other":
        fine_label_set.add("Other")
    else:
        for sub in subnarratives:
            if sub == "Other":
                fine_label_set.add(f"{narrative}: Other")
            else:
                fine_label_set.add(f"{narrative}: {sub}")

fine_classes = sorted(fine_label_set)

In [54]:
fine_classes[:15]

['CC: Amplifying Climate Fears: Amplifying existing fears of global warming',
 'CC: Amplifying Climate Fears: Doomsday scenarios for humans',
 'CC: Amplifying Climate Fears: Earth will be uninhabitable soon',
 'CC: Amplifying Climate Fears: Other',
 'CC: Amplifying Climate Fears: Whatever we do it is already too late',
 'CC: Climate change is beneficial: CO2 is beneficial',
 'CC: Climate change is beneficial: Other',
 'CC: Climate change is beneficial: Temperature increase is beneficial',
 'CC: Controversy about green technologies: Other',
 'CC: Controversy about green technologies: Renewable energy is costly',
 'CC: Controversy about green technologies: Renewable energy is dangerous',
 'CC: Controversy about green technologies: Renewable energy is unreliable',
 'CC: Criticism of climate movement: Ad hominem attacks on key activists',
 'CC: Criticism of climate movement: Climate movement is alarmist',
 'CC: Criticism of climate movement: Climate movement is corrupt']

For the evaluator, we start by iterating over a range of thresholds, for both narratives and subnarratives.

* First, we get the predictions of our model for the narratives and subnarratives.
  - For each sample, we make a prediction, with the current thresholds.
  - The predictions are evaluated, with the exact scorer that is used by the challenge.
  - The metrics that we are aiming for, based on the evaluation rules of the challenge that claims:
 ```
    The official evaluation measure will be averaged (over test documents) samples F1 computed for entire narrative_x:subnarrative_x labels. That is, we will first compute an F1 score per test document by comparing the predicted to the gold narrative_x:subnarrative_x labels of the document, and we will then average over the test documents. Both the narrative_x and the subnarrative_x part of each predicted narrative_x:subnarrative_x label will have to be correct for the predicted label to be considered correct.
```

In [55]:
import os
from typing import Dict, List
from sklearn import metrics

class MultiHeadEvaluator:
    def __init__(
        self,
        classes_coarse=coarse_classes,
        classes_fine=fine_classes,
        narrative_to_sub_map=narrative_to_sub_map,
        narrative_order=narrative_order,
        narrative_classes=mlb_narratives.classes_,
        subnarrative_classes=mlb_subnarratives.classes_,
        device='cpu',
        output_dir='../../submissions',
    ):
        self.narrative_to_sub_map = narrative_to_sub_map
        self.narrative_order = narrative_order
        self.narrative_classes = list(narrative_classes)
        self.subnarrative_classes = list(subnarrative_classes)
        
        self.classes_coarse = classes_coarse
        self.classes_fine = classes_fine

        self.device = device
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
    
    def evaluate(
        self,
        model,
        embeddings=val_embeddings_tensor,
        dataset=dataset_val,
        thresholds=None,
        save=False,
        std_weight=0.6,
        lower_thres=0.1,
        upper_thres=0.55
    ):
        if thresholds is None:
            thresholds = np.arange(lower_thres, upper_thres, 0.05)    
        embeddings = embeddings.to(self.device)
    
        best_results = {
            'best_coarse_f1': -1,
            'best_coarse_std': float('inf'),
            'best_fine_f1': -1,
            'best_fine_std': float('inf'),
            'narr_threshold': 0,
            'sub_threshold': 0,
            'predictions': None,
            'best_combined_score': -float('inf')
        }
    
        with torch.no_grad():
            narr_probs, sub_probs_dict = model(embeddings)
            narr_probs = narr_probs.cpu().numpy()
            sub_probs_dict = {k: v.cpu().numpy() for k, v in sub_probs_dict.items()}
    
        for narr_threshold in thresholds:
            for sub_threshold in thresholds:
                predictions = []
                for sample_idx, row in dataset.iterrows():
                    pred = self._make_prediction(
                        row['article_id'],
                        sample_idx,
                        narr_probs,
                        sub_probs_dict,
                        narr_threshold,
                        sub_threshold
                    )
                    predictions.append(pred)
                
                f1_coarse_mean, coarse_std, f1_fine_mean, fine_std = self._compute_metrics_coarse_fine(
                    predictions,
                    dataset
                )
                
                combined_score = f1_fine_mean - (std_weight * coarse_std)
                
                if combined_score > best_results['best_combined_score']:
                    best_results.update({
                        'best_coarse_f1': f1_coarse_mean,
                        'best_coarse_std': coarse_std,
                        'best_fine_f1': f1_fine_mean,
                        'best_fine_std': fine_std,
                        'narr_threshold': narr_threshold,
                        'sub_threshold': sub_threshold,
                        'predictions': predictions,
                        'best_combined_score': combined_score
                    })
    
        print("\nBest thresholds found:")
        print(f"Narrative threshold: {best_results['narr_threshold']:.2f}")
        print(f"Subnarrative threshold: {best_results['sub_threshold']:.2f}")
        print('\n')
        print(f"Coarse-F1: {best_results['best_coarse_f1']:.3f}")
        print(f"F1 st. dev. coarse: {best_results['best_coarse_std']:.3f}")
        print(f"Fine-F1: {best_results['best_fine_f1']:.3f} ")
        print(f"F1 st. dev. fine: {best_results['best_fine_std']:.3f}")

        if save:
            self._save_predictions(best_results, os.path.join(self.output_dir, 'submission.txt'))
        
        return best_results

    def _make_prediction(self, article_id, sample_idx, narr_probs, sub_probs_dict, narr_threshold, sub_threshold):
        other_idx = self.narrative_classes.index("Other")
        # Find all narratives >= narr_threshold 
        # (except 'Other', this is going to be used as a fallback)
        active_narratives = [
            (n_idx, prob)
            for n_idx, prob in enumerate(narr_probs[sample_idx])
            if n_idx != other_idx and prob >= narr_threshold
        ]
        
        if not active_narratives:
            return {
                'article_id': article_id,
                'narratives': ["Other"],
                'pairs': ["Other"]
            }
        
        narratives = []
        pairs = []
        seen_pairs = set()
        
        active_narratives.sort(key=lambda x: x[1], reverse=True)
        for narr_idx, narr_prob in active_narratives:
            narr_name = self.narrative_classes[narr_idx]
            
            sub_probs = sub_probs_dict[str(narr_idx)][sample_idx]
            active_subnarratives = [
                (local_idx, s_prob)
                for local_idx, s_prob in enumerate(sub_probs)
                if s_prob >= sub_threshold
            ]
            active_subnarratives.sort(key=lambda x: x[1], reverse=True)
            # Fallback, when no subnarrative active, we put other.
            if not active_subnarratives:
                pairs.append(f"{narr_name}: Other")
            else:
                for local_idx, _ in active_subnarratives:
                    global_sub_idx = self.narrative_to_sub_map[narr_idx][local_idx]
                    sub_name = self.subnarrative_classes[global_sub_idx]
                    pair = f"{narr_name}: {sub_name}"
                    if pair not in seen_pairs:
                        pairs.append(pair)
                        seen_pairs.add(pair)

            narratives.append(narr_name)
        
        return {
            'article_id': article_id,
            'narratives': narratives,
            'pairs': pairs
        }

    def _compute_metrics_coarse_fine(self, predictions, dataset):
        """
        Mimics the official scorere used by the challenge.
        """
        gold_coarse_all = []
        gold_fine_all = []
        pred_coarse_all = []
        pred_fine_all = []

        for pred, (_, row) in zip(predictions, dataset.iterrows()):
            gold_coarse = row['narratives']
            gold_subnarratives = row['subnarratives']
            
            pred_coarse = pred['narratives']
            pred_fine = []
            for p in pred['pairs']:
                # If previously we predicted a "Other",
                # we output "Other Other" as the narrative and subnarrative.
                if p == "Other":
                    pred_fine.append("Other")
                else:
                    # Takes the whole nar : sub pair.
                    pred_fine.append(p)

            gold_fine = []
            for gold_nar, gold_sub in zip(gold_coarse, gold_subnarratives):
                # We do the same for truths.
                if gold_nar == "Other":
                    gold_fine.append("Other")
                else:
                    gold_fine.append(f"{gold_nar}: {gold_sub}")
            
            gold_coarse_all.append(gold_coarse)
            gold_fine_all.append(gold_fine)
            pred_coarse_all.append(pred_coarse)
            pred_fine_all.append(pred_fine)

        f1_coarse_mean, coarse_std = self._evaluate_multi_label(gold_coarse_all, pred_coarse_all, self.classes_coarse)
        f1_fine_mean, fine_std = self._evaluate_multi_label(gold_fine_all, pred_fine_all, self.classes_fine)

        return f1_coarse_mean, coarse_std, f1_fine_mean, fine_std

    def _evaluate_multi_label(self, gold, predicted, class_list):
        """
        Mimics the official f1-score calculation used by the challenge.
        """
        f1_scores = []
        for g_labels, p_labels in zip(gold, predicted):
            g_onehot = np.zeros(len(class_list), dtype=int)
            for lab in g_labels:
                if lab in class_list:
                    g_onehot[class_list.index(lab)] = 1
                    
            p_onehot = np.zeros(len(class_list), dtype=int)
            for lab in p_labels:
                if lab in class_list:
                    p_onehot[class_list.index(lab)] = 1

            f1_doc = metrics.f1_score(g_onehot, p_onehot, zero_division=0)
            f1_scores.append(f1_doc)
        
        return float(np.mean(f1_scores)), float(np.std(f1_scores))

    def _save_predictions(self, best_results, filepath):
        predictions = best_results['predictions']
        if os.path.exists(filepath):
            os.remove(filepath)
        
        with open(filepath, 'w', encoding='utf-8') as f:
            for pred in predictions:
                line = (f"{pred['article_id']}\t"
                        f"{';'.join(pred['narratives'])}\t"
                        f"{';'.join(pred['pairs'])}\n")
                f.write(line)

In [56]:
evaluator = MultiHeadEvaluator()

Our model does a decent job correctly predicting the fine-grained roles, about 46% of the time.
* The somewhat high standard deviation suggests some inconsistent performance across articles.
    - We aim for a balanced prediction, between the F1 and the std score.

Our model also does a decent job when it comes to predicting exact pairs of `narrative: subnarratives`.

In [57]:
results = evaluator.evaluate(
    model=trained_multi_head_simple,
)


Best thresholds found:
Narrative threshold: 0.40
Subnarrative threshold: 0.25


Coarse-F1: 0.433
F1 st. dev. coarse: 0.374
Fine-F1: 0.308 
F1 st. dev. fine: 0.313


### Providing the already predicted narrative

The results we got from the base, multi-head model, are encouraging. But, despite of the loss, our model by itself isn't providing any extra info when it comes to the phase of predicting subnarratives.

Let h(x) be the shared layer output for the embedding x:

        shared_out = self.shared_layer(x)

We compute the probability P(narr_i | x) for each narrative:

        narr_probs = self.narrative_head(shared_out)

Previously, we used the following formula for the subnarrative probability P(subnarr_j | x):

        P(subnarr_j | x) = σ(h(x))

It makes sense to try:

        P(subnarr_j | x) = σ(concat(h(x), P(narr_i | x)))


Where narr_i is the narrative associated with subnarrative subnarr_j in the hierarchy.

Essentially:

* If the probability of the narrative is high, the subnarrative head will be more likely to predict the relevant subnarratives.
* If the probability is low, the model will ignore the corresponding subnarratives.
* At the same time, the shared output of the shared layer will help determine which subnarrative is most appropriate for the given document (and we can potentially use other techniques like attention to further improve the model).

In [58]:

class MultiTaskClassifierMultiHeadConcat(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        num_narratives=len(mlb_narratives.classes_),
        narrative_to_sub_map=narrative_to_sub_map,
        dropout_rate=network_params['dropout']
    ):
        super().__init__()
        
        self.shared_layer = nn.Sequential(
            nn.Linear(input_size, hidden_size * 2),
            nn.BatchNorm1d(hidden_size * 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        self.narrative_head = nn.Sequential(
            nn.Linear(hidden_size * 2, num_narratives),
            nn.Sigmoid()
        )

        self.subnarrative_heads = nn.ModuleDict()
        for narr_idx, sub_indices in narrative_to_sub_map.items():
            num_subs_for_this_narr = len(sub_indices)
            self.subnarrative_heads[str(narr_idx)] = nn.Sequential(
                nn.Linear(hidden_size * 2 + 1, num_subs_for_this_narr),
                nn.Sigmoid()
            )

    def forward(self, x):
        shared_out = self.shared_layer(x)

        narr_probs = self.narrative_head(shared_out)

        sub_probs_dict = {}
        for narr_idx, head in self.subnarrative_heads.items():
            conditioned_input = torch.cat((shared_out, narr_probs[:, int(narr_idx)].unsqueeze(1)), dim=1)
            sub_probs_dict[narr_idx] = head(conditioned_input)

        return narr_probs, sub_probs_dict

We make a function to initialize and train the model.

In [59]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

def initialize_and_train_model(
    model,
    num_epochs=100,
    lr=0.001,
    patience=10,
    use_scheduler=True,
    scheduler_patience=3,
    loss_fn=multi_head_loss_fn,
    num_subnarratives=len(mlb_subnarratives.classes_),
    device='cpu'
):
    optimizer = AdamW(model.parameters(), lr=lr)

    scheduler = None
    if use_scheduler:
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=scheduler_patience)

    trained_model = train_with_multihead(
                                    model=model,
                                    optimizer=optimizer,
                                    scheduler=scheduler,
                                    loss_fn=loss_fn,
                                    patience=patience
                                )
    return trained_model

In [60]:
model_multi_head_concat= MultiTaskClassifierMultiHeadConcat(
    input_size=input_size,
    hidden_size=2048,
)

In [61]:
print(model_multi_head_concat)

MultiTaskClassifierMultiHeadConcat(
  (shared_layer): Sequential(
    (0): Linear(in_features=896, out_features=4096, bias=True)
    (1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.4, inplace=False)
  )
  (narrative_head): Sequential(
    (0): Linear(in_features=4096, out_features=22, bias=True)
    (1): Sigmoid()
  )
  (subnarrative_heads): ModuleDict(
    (13): Sequential(
      (0): Linear(in_features=4097, out_features=9, bias=True)
      (1): Sigmoid()
    )
    (14): Sequential(
      (0): Linear(in_features=4097, out_features=7, bias=True)
      (1): Sigmoid()
    )
    (19): Sequential(
      (0): Linear(in_features=4097, out_features=6, bias=True)
      (1): Sigmoid()
    )
    (20): Sequential(
      (0): Linear(in_features=4097, out_features=4, bias=True)
      (1): Sigmoid()
    )
    (15): Sequential(
      (0): Linear(in_features=4097, out_features=3, bias=True)
      (1): Sigmoid()
    )
    (11)

We train:

In [62]:
trained_model_concat = initialize_and_train_model(
    model_multi_head_concat,
    patience=5,
)

Epoch 1/100, Training Loss: 1.0259, Validation Loss: 1.0033
Current Learning Rate: 0.001000
Epoch 2/100, Training Loss: 0.6668, Validation Loss: 0.9884
Current Learning Rate: 0.001000
Epoch 3/100, Training Loss: 0.5528, Validation Loss: 0.9783
Current Learning Rate: 0.001000
Epoch 4/100, Training Loss: 0.4988, Validation Loss: 0.9696
Current Learning Rate: 0.001000
Epoch 5/100, Training Loss: 0.4657, Validation Loss: 0.9603
Current Learning Rate: 0.001000
Epoch 6/100, Training Loss: 0.4424, Validation Loss: 0.9496
Current Learning Rate: 0.001000
Epoch 7/100, Training Loss: 0.4241, Validation Loss: 0.9374
Current Learning Rate: 0.001000
Epoch 8/100, Training Loss: 0.4076, Validation Loss: 0.9237
Current Learning Rate: 0.001000
Epoch 9/100, Training Loss: 0.3957, Validation Loss: 0.9091
Current Learning Rate: 0.001000
Epoch 10/100, Training Loss: 0.3847, Validation Loss: 0.8944
Current Learning Rate: 0.001000
Epoch 11/100, Training Loss: 0.3740, Validation Loss: 0.8799
Current Learning R

In [63]:
_ = evaluator.evaluate(
    model=trained_model_concat,
    save=True
)


Best thresholds found:
Narrative threshold: 0.50
Subnarrative threshold: 0.25


Coarse-F1: 0.493
F1 st. dev. coarse: 0.379
Fine-F1: 0.347 
F1 st. dev. fine: 0.325


### Using multiplication instead of concantenation

Instead of using concat, we can try an element-wise multiplication. This sounds more logical as multiplication can act as a "gate":

* If the narrative probability is close to 0, the corresponding subnarrative head’s input will be scaled down, effectively disabling that subnarrative head.
* If the narrative probability is close to 1, the shared layer output passes through somewhat unaffected.

In [64]:

class MultiTaskClassifierMultiHeadMult(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        num_narratives=len(mlb_narratives.classes_),
        narrative_to_sub_map=narrative_to_sub_map,
        dropout_rate=network_params['dropout'],
        bias=0.1
    ):
        super().__init__()
        
        self.shared_layer = nn.Sequential(
            nn.Linear(input_size, hidden_size * 2),
            nn.BatchNorm1d(hidden_size * 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        self.narrative_head = nn.Sequential(
            nn.Linear(hidden_size * 2, num_narratives),
            nn.Sigmoid()
        )

        self.subnarrative_heads = nn.ModuleDict()
        for narr_idx, sub_indices in narrative_to_sub_map.items():
            num_subs_for_this_narr = len(sub_indices)
            self.subnarrative_heads[str(narr_idx)] = nn.Sequential(
                nn.Linear(hidden_size * 2, num_subs_for_this_narr),
                nn.Sigmoid()
            )

        self.bias = bias

    def forward(self, x):
        shared_out = self.shared_layer(x)


        narr_probs = self.narrative_head(shared_out)

        sub_probs_dict = {}
        for narr_idx, head in self.subnarrative_heads.items():
            narr_pred = narr_probs[:, int(narr_idx)].unsqueeze(1)

            conditioned_input = shared_out * (narr_pred + self.bias)

            sub_probs_dict[narr_idx] = head(conditioned_input)

        return narr_probs, sub_probs_dict

In [65]:
model_multi_head_mult = MultiTaskClassifierMultiHeadMult(
    input_size=input_size,
    hidden_size=network_params['hidden_size'],
)

In [66]:
trained_model_mult = initialize_and_train_model(model_multi_head_mult)

Epoch 1/100, Training Loss: 1.0236, Validation Loss: 1.0195
Current Learning Rate: 0.001000
Epoch 2/100, Training Loss: 0.8291, Validation Loss: 1.0108
Current Learning Rate: 0.001000
Epoch 3/100, Training Loss: 0.7524, Validation Loss: 1.0033
Current Learning Rate: 0.001000
Epoch 4/100, Training Loss: 0.7044, Validation Loss: 0.9970
Current Learning Rate: 0.001000
Epoch 5/100, Training Loss: 0.6702, Validation Loss: 0.9915
Current Learning Rate: 0.001000
Epoch 6/100, Training Loss: 0.6403, Validation Loss: 0.9865
Current Learning Rate: 0.001000
Epoch 7/100, Training Loss: 0.6134, Validation Loss: 0.9817
Current Learning Rate: 0.001000
Epoch 8/100, Training Loss: 0.5908, Validation Loss: 0.9768
Current Learning Rate: 0.001000
Epoch 9/100, Training Loss: 0.5704, Validation Loss: 0.9715
Current Learning Rate: 0.001000
Epoch 10/100, Training Loss: 0.5521, Validation Loss: 0.9659
Current Learning Rate: 0.001000
Epoch 11/100, Training Loss: 0.5376, Validation Loss: 0.9599
Current Learning R

The results are a bit suprising, but they make sense because concatenation gives the subnarrative heads more "flexibility" while multiplication is more restrictive acting as a hard gate.

* If our narrative predictions are not confident or even and most importantly, not correct, the subnarrative head will receive very weak input because of the multiplication.

In [67]:
_ = evaluator.evaluate(
    model=trained_model_mult,
)


Best thresholds found:
Narrative threshold: 0.50
Subnarrative threshold: 0.10


Coarse-F1: 0.480
F1 st. dev. coarse: 0.352
Fine-F1: 0.300 
F1 st. dev. fine: 0.315


Because subnarrative heads rely really on the narrative probabilities, we will try reducing the sub_weight proportionally increasing the the narrative weight and see what happens:

In [68]:
multi_head_loss_fn = MultiHeadLoss(
    narrative_criterion,
    sub_criterion_dict,
    condition_weight=0.8,
    sub_weight=0.3
)

In [69]:
model_multi_head_mult = MultiTaskClassifierMultiHeadMult(
    input_size=input_size,
    hidden_size=2048,
)

In [70]:
trained_model_mult = initialize_and_train_model(model_multi_head_mult,
                                                loss_fn=multi_head_loss_fn,
                                                patience=10
                                               )

Epoch 1/100, Training Loss: 1.0348, Validation Loss: 1.0224
Current Learning Rate: 0.001000
Epoch 2/100, Training Loss: 0.7825, Validation Loss: 1.0075
Current Learning Rate: 0.001000
Epoch 3/100, Training Loss: 0.7030, Validation Loss: 0.9970
Current Learning Rate: 0.001000
Epoch 4/100, Training Loss: 0.6580, Validation Loss: 0.9886
Current Learning Rate: 0.001000
Epoch 5/100, Training Loss: 0.6115, Validation Loss: 0.9817
Current Learning Rate: 0.001000
Epoch 6/100, Training Loss: 0.5754, Validation Loss: 0.9755
Current Learning Rate: 0.001000
Epoch 7/100, Training Loss: 0.5481, Validation Loss: 0.9695
Current Learning Rate: 0.001000
Epoch 8/100, Training Loss: 0.5288, Validation Loss: 0.9633
Current Learning Rate: 0.001000
Epoch 9/100, Training Loss: 0.5116, Validation Loss: 0.9568
Current Learning Rate: 0.001000
Epoch 10/100, Training Loss: 0.4929, Validation Loss: 0.9498
Current Learning Rate: 0.001000
Epoch 11/100, Training Loss: 0.4784, Validation Loss: 0.9421
Current Learning R

The results are better, but still not good enough.

In [71]:
_ = evaluator.evaluate(
    model=trained_model_mult,
)


Best thresholds found:
Narrative threshold: 0.45
Subnarrative threshold: 0.15


Coarse-F1: 0.467
F1 st. dev. coarse: 0.353
Fine-F1: 0.300 
F1 st. dev. fine: 0.312


## Ensemble predictions

Another approach we can try is to take different checkpoints of the training phase during the neural network training. Different checkpoints might be better at detecting different types of narratives.

* Early stages of our model may be better at capturing some narratives and subnarratives, while later stages might need more training.
. With this sapproach we can use multiple "good" snapshots of our model.
* An obvious improvement we can do the previous ensemble model, is to consider the loss of each checkpoint as a factor to what each one should say.
  - This means that model states that didn't do very good, won't get too much "say" in the final result, in comparison to better models.

In [72]:
checkpoint_dir='checkpoints'

A checkpoint is a snapshot of our model during training.
* We will save the epoch, model state and validation loss during that epoch.

In [73]:
from dataclasses import dataclass

@dataclass
class Checkpoint:
    epoch: int              
    model_state: Dict
    val_loss: float

def save_checkpoint(checkpoint):
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    file_name = f"checkpoint_epoch_{checkpoint.epoch}.pt"
    
    checkpoint_path = os.path.join(checkpoint_dir, file_name)
    
    checkpoint_dict = {
        'epoch': checkpoint.epoch,
        'model_state_dict': checkpoint.model_state,
        'val_loss': checkpoint.val_loss
    }
    with open(checkpoint_path, 'wb') as f:
        pickle.dump(checkpoint_dict, f)
    
    return checkpoint_path

There are different strategies out there on how to select checkpoints.
* We will start with a very simple one, and that is:
   - We select the best checkpoint based on the lowest val loss.
   - For the rest checkpoints, select them evenly across the training phase.


In [74]:
def select_checkpoints(checkpoints, k=5, strategy='linear'):
    if not checkpoints:
        print('\n[WARNING] Found empty checkpoints')
        return []
    if k == 1 or len(checkpoints) == 1:
        return min(checkpoints, key=lambda x: x[0])
        
    best_checkpoint = min(checkpoints, key=lambda x: x[0])
    
    sorted_by_epoch = sorted(checkpoints, key=lambda x: x[2])
    total_epochs = len(sorted_by_epoch)
    
    if strategy == 'linear':
        indices = np.linspace(0, total_epochs-1, k-1).astype(int)
    elif strategy == 'log':
        indices = np.logspace(0, np.log10(total_epochs-1), k-1).astype(int)
    else:
        print('Unsupported strategy.')
    
    time_diverse = [sorted_by_epoch[i] for i in indices]
    
    all_checkpoints = [best_checkpoint]
    for checkpoint in time_diverse:
        if checkpoint not in all_checkpoints:
            all_checkpoints.append(checkpoint)
    
    return all_checkpoints[:k]

Our model train function is modified so that:
* It saves a checkpoint of the model at each epoch.
    - This will store the captured epoch, and the validation loss during that epoch.
    - At the end of the training phase, we select some of the checkpoints based on a strategy.

In [75]:
import glob 

def train_best_checkp(
    model,
    optimizer,
    loss_fn=multi_head_loss_fn,
    train_embeddings=train_embeddings_tensor,
    y_train_nar=y_train_nar,
    y_train_sub_heads=y_train_sub_heads,
    val_embeddings=val_embeddings_tensor,
    y_val_nar=y_val_nar,
    y_val_sub_heads=y_val_sub_heads,
    patience=10,
    num_epochs=100,
    scheduler=None,
    min_delta=0.001,
    clear_prev_checkp=True,
    top_k=5,
    strategy='linear'
):
    best_val_loss = float('inf')
    patience_counter = 0
    all_checkpoints = []
    print('Deleting previous checkpoints..')
    files = glob.glob(os.path.join(checkpoint_dir, '*'))
    for f in files:
        try:
            if os.path.isfile(f):
                os.remove(f)
        except Exception as e:
            print(f"\n[WARNING] Couldn't delete {f}: {e}")
    
    for epoch in range(num_epochs):
        model.train()
        train_narr_probs, train_sub_probs_dict = model(train_embeddings)
        train_loss = loss_fn(train_narr_probs, train_sub_probs_dict, y_train_nar, y_train_sub_heads)
        
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        
        model.eval()
        with torch.no_grad():
            val_narr_probs, val_sub_probs_dict = model(val_embeddings)
            val_loss = loss_fn(val_narr_probs, val_sub_probs_dict, y_val_nar, y_val_sub_heads)
        
        print(f"Epoch {epoch+1}/{num_epochs}, "
              f"Training Loss: {train_loss.item():.4f}, "
              f"Validation Loss: {val_loss.item():.4f}")
        
        if scheduler:
            scheduler.step(val_loss)
            current_lr = scheduler.optimizer.param_groups[0]['lr']
            print(f"Current Learning Rate: {current_lr:.6f}")

        checkpoint = Checkpoint(
            epoch=epoch,
            model_state=model.state_dict(),
            val_loss=val_loss.item()
        )
            
        checkpoint_path = save_checkpoint(checkpoint)
        all_checkpoints.append((val_loss.item(), checkpoint_path, epoch))
        
        if val_loss.item() < best_val_loss - min_delta:
            best_val_loss = val_loss.item()
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"Validation loss did not significantly improve for {patience_counter} epoch(s).")
        
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break
            
    selected_checkpoints = select_checkpoints(all_checkpoints, k=top_k, strategy=strategy)
    selected_paths = [cp[1] for cp in selected_checkpoints]
    selected_losses = [cp[0] for cp in selected_checkpoints]
    
    for _, path, _ in all_checkpoints:
        if path not in selected_paths:
            try:
                os.remove(path)
            except:
                print('\n[WARNING] Could not remove checkpoint path: ', path)
                pass
    
    print("\nSelected Checkpoints:")
    print("Val Loss | Epoch | Checkpoint")
    print("-" * 50)
    for val_loss, path, epoch in selected_checkpoints:
        print(f"{val_loss:.4f} | {epoch:5d} | {os.path.basename(path)}")
        
    return model, selected_paths, selected_losses

We then define the actual ensemble model.
* For the final prediction, each model will have a vote in proportion to it's loss in that checkpoint
  - By taking the average, predictions of all votes, we remove any noise by using better checkpoints.
  - The final model is also more robust since it is not relying on a single model.

In [76]:
class WeightedCheckpointEnsemble:
    def __init__(self, checkpoint_paths, val_losses, model_class=MultiTaskClassifierMultiHeadConcat):
        self.models = []
        self.weights = []
        
        losses = torch.tensor(val_losses)
        weights = torch.softmax(-losses / losses.mean(), dim=0)
        
        print('Checkpoint Weights:')
        for path, weight in zip(checkpoint_paths, weights):
            print(f"{os.path.basename(path)} weight: {weight:.3f}")
            
        checkpoint = self._load_checkpoint(checkpoint_paths[0])
        state_dict = checkpoint['model_state_dict']
        
        for checkpoint_path, weight in zip(checkpoint_paths, weights):            
            model = model_class(
                input_size=input_size,
                hidden_size=2048,
                dropout_rate=0.4 
            )
            
            checkpoint = self._load_checkpoint(checkpoint_path)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            
            self.models.append(model)
            self.weights.append(weight)
    
    def predict(self, x):
        narrative_probs_sum = None
        subnarrative_probs_dict_sum = {}
        
        with torch.no_grad():
            for model, weight in zip(self.models, self.weights):
                narr_probs, sub_probs_dict = model(x)
                
                weighted_narr_probs = narr_probs * weight
                
                if narrative_probs_sum is None:
                    narrative_probs_sum = weighted_narr_probs
                else:
                    narrative_probs_sum += weighted_narr_probs
                
                for narr_idx, sub_probs in sub_probs_dict.items():
                    weighted_sub_probs = sub_probs * weight
                    if narr_idx not in subnarrative_probs_dict_sum:
                        subnarrative_probs_dict_sum[narr_idx] = weighted_sub_probs
                    else:
                        subnarrative_probs_dict_sum[narr_idx] += weighted_sub_probs
        
        return narrative_probs_sum, subnarrative_probs_dict_sum
    
    def _load_checkpoint(self, checkpoint_path):
        with open(checkpoint_path, 'rb') as f:
            return pickle.load(f)

In [77]:
def evaluate_ensemble(
    base_evaluator,
    ensemble_model,
    embeddings=val_embeddings_tensor,
    y_nar_true=y_val_nar,
    y_sub_hierarchical=y_val_sub_heads,
    thresholds=None,
    save=False,
):
    def ensemble_predict(embedding):
        return ensemble_model.predict(embedding)
    
    return base_evaluator.evaluate(
        model=ensemble_predict,
        embeddings=embeddings,
        save=save
    )

In [78]:
model_concat_ens = MultiTaskClassifierMultiHeadConcat(
    input_size=input_size,
    hidden_size=2048,
    dropout_rate=0.4
)
optimizer= AdamW(model_concat_ens.parameters(), lr=0.001)

In [79]:
trained_concat, checkpoint_paths, val_losses = train_best_checkp(
    model=model_concat_ens,
    optimizer=optimizer,
    patience=5,
    top_k=5,
)

Deleting previous checkpoints..
Epoch 1/100, Training Loss: 1.0535, Validation Loss: 1.0145
Epoch 2/100, Training Loss: 0.6633, Validation Loss: 0.9960
Epoch 3/100, Training Loss: 0.5377, Validation Loss: 0.9844
Epoch 4/100, Training Loss: 0.4841, Validation Loss: 0.9751
Epoch 5/100, Training Loss: 0.4553, Validation Loss: 0.9655
Epoch 6/100, Training Loss: 0.4299, Validation Loss: 0.9543
Epoch 7/100, Training Loss: 0.4092, Validation Loss: 0.9413
Epoch 8/100, Training Loss: 0.3926, Validation Loss: 0.9276
Epoch 9/100, Training Loss: 0.3788, Validation Loss: 0.9140
Epoch 10/100, Training Loss: 0.3667, Validation Loss: 0.9008
Epoch 11/100, Training Loss: 0.3536, Validation Loss: 0.8878
Epoch 12/100, Training Loss: 0.3451, Validation Loss: 0.8749
Epoch 13/100, Training Loss: 0.3377, Validation Loss: 0.8619
Epoch 14/100, Training Loss: 0.3296, Validation Loss: 0.8489
Epoch 15/100, Training Loss: 0.3217, Validation Loss: 0.8352
Epoch 16/100, Training Loss: 0.3132, Validation Loss: 0.8204
E

In [80]:
ensemble_model = WeightedCheckpointEnsemble(checkpoint_paths, val_losses)

Checkpoint Weights:
checkpoint_epoch_37.pt weight: 0.233
checkpoint_epoch_0.pt weight: 0.141
checkpoint_epoch_13.pt weight: 0.175
checkpoint_epoch_27.pt weight: 0.220
checkpoint_epoch_41.pt weight: 0.230


  return torch.load(io.BytesIO(b))


In [81]:
_ = evaluate_ensemble(
    evaluator,
    ensemble_model,
)


Best thresholds found:
Narrative threshold: 0.50
Subnarrative threshold: 0.30


Coarse-F1: 0.438
F1 st. dev. coarse: 0.385
Fine-F1: 0.292 
F1 st. dev. fine: 0.310
