# Semeval 2025 Task 10
### Subtask 2: Narrative Classification

Given a news article and a [two-level taxonomy of narrative labels](https://propaganda.math.unipd.it/semeval2025task10/NARRATIVE-TAXONOMIES.pdf) (where each narrative is subdivided into subnarratives) from a particular domain, assign to the article all the appropriate subnarrative labels. This is a multi-label multi-class document classification task.

## 1. Multi-head per narrative model

### 1.1 Loading pre-saved variables

We start by loading our pre-saved variables:

In [1]:
import pickle
import os
import pandas as pd

base_save_folder_dir = '../saved/'
dataset_folder = os.path.join(base_save_folder_dir, 'Dataset')
embeddings_folder = os.path.join(base_save_folder_dir, 'Embeddings/all_embeddings.npy')

with open(os.path.join(dataset_folder, 'dataset_cleaned.pkl'), 'rb') as f:
    dataset = pickle.load(f)

In [2]:
dataset.head()

Unnamed: 0,language,article_id,content,narratives,subnarratives,narratives_encoded,subnarratives_encoded
0,RU,RU-URW-1161.txt,<PARA>в ближайшие два месяца сша будут стремит...,[Blaming the war on others rather than the inv...,"[The West are the aggressors, Other, The West ...","[0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,RU,RU-URW-1175.txt,<PARA>в ес испугались последствий популярности...,"[Discrediting the West, Diplomacy, Discreditin...","[The West is weak, Other, The EU is divided]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,RU,RU-URW-1149.txt,<PARA>возможность признания аллы пугачевой ино...,[Distrust towards Media],[Western media is an instrument of propaganda],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,RU,RU-URW-1015.txt,<PARA>азаров рассказал о смене риторики киева ...,"[Discrediting Ukraine, Discrediting Ukraine]","[Ukraine is a puppet of the West, Discrediting...","[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,RU,RU-URW-1001.txt,<PARA>в россиянах проснулась массовая любовь к...,[Praise of Russia],[Russia is a guarantor of peace and prosperity],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [3]:
misc_folder = os.path.join(base_save_folder_dir, 'Misc')

with open(os.path.join(misc_folder, 'narrative_to_subnarratives.pkl'), 'rb') as f:
    narrative_to_subnarratives = pickle.load(f)

We'll also need the actual hierarchy of narratives to subnarratives for our new model.  

* Each narrative is also mapped to `Other`—this happens because if no subnarrative matches, we assign it to `Other`.

In [4]:
narrative_to_subnarratives

{'Discrediting Ukraine': ['Other',
  'Ukraine is associated with nazism',
  'Situation in Ukraine is hopeless',
  'Ukraine is a hub for criminal activities',
  'Discrediting Ukrainian government and officials and policies',
  'Rewriting Ukraine’s history',
  'Discrediting Ukrainian nation and society',
  'Ukraine is a puppet of the West',
  'Discrediting Ukrainian military'],
 'Discrediting the West, Diplomacy': ['Other',
  'West is tired of Ukraine',
  'The West is weak',
  'The West does not care about Ukraine, only about its interests',
  'The EU is divided',
  'Diplomacy does/will not work',
  'The West is overreacting'],
 'Praise of Russia': ['Other',
  'Praise of Russian President Vladimir Putin',
  'Praise of Russian military might',
  'Russia has international support from a number of countries and people',
  'Russian invasion has strong national support',
  'Russia is a guarantor of peace and prosperity'],
 'Russia is the Victim': ['Russia actions in Ukraine are only self-defe

In [5]:
label_encoder_folder = os.path.join(base_save_folder_dir, 'LabelEncoders')

with open(os.path.join(label_encoder_folder, 'mlb_narratives.pkl'), 'rb') as f:
    mlb_narratives = pickle.load(f)

with open(os.path.join(label_encoder_folder, 'mlb_subnarratives.pkl'), 'rb') as f:
    mlb_subnarratives = pickle.load(f)

Finally, we get our embeddings:

In [6]:
import numpy as np

embeddings_folder = os.path.join(base_save_folder_dir, 'Embeddings/all_embeddings.npy')

def load_embeddings(filename):
    return np.load(filename)

embeddings_kalm = load_embeddings(embeddings_folder)

We also need to make sure that the embeddings stay aligned with the dataset split.  
* We pass `all_embeddings` so that `dataset_train` and `train_embeddings` match up exactly, keeping everything consistent.

We use stratified splitting here to ensure the label distribution stays the same in both the training and validation sets.  
* This somewhat maintain the class proportions, even for the rare cases, making sure both sets are roughly representative of the original dataset.

In [7]:
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
import numpy as np
import pandas as pd

def stratified_train_val_split_with_embeddings(data, embeddings, labels_column, train_size=0.8, splits=5, shuffle=True, min_instances=2):
    if shuffle:
        shuffled_indices = np.arange(len(data))
        np.random.shuffle(shuffled_indices)
        data = data.iloc[shuffled_indices].reset_index(drop=True)
        embeddings = embeddings[shuffled_indices]

    labels = np.array(data[labels_column].tolist())
    rare_indices = []
    common_indices = []

    class_counts = labels.sum(axis=0)
    rare_classes = np.where(class_counts <= min_instances)[0]

    for idx, label_row in enumerate(labels):
        if any(label_row[rare_classes]):
            rare_indices.append(idx)
        else:
            common_indices.append(idx)

    rare_data = data.iloc[rare_indices]
    rare_labels = labels[rare_indices]
    rare_embeddings = embeddings[rare_indices]

    train_rare = rare_data.iloc[:len(rare_data) // 2].reset_index(drop=True)
    val_rare = rare_data.iloc[len(rare_data) // 2:].reset_index(drop=True)

    train_rare_embeddings = rare_embeddings[:len(rare_data) // 2]
    val_rare_embeddings = rare_embeddings[len(rare_data) // 2:]

    common_data = data.iloc[common_indices].reset_index(drop=True)
    common_labels = labels[common_indices]
    common_embeddings = embeddings[common_indices]

    mskf = MultilabelStratifiedKFold(n_splits=splits)
    for train_idx, val_idx in mskf.split(np.zeros(len(common_labels)), common_labels):
        train_common = common_data.iloc[train_idx]
        val_common = common_data.iloc[val_idx]
        train_common_embeddings = common_embeddings[train_idx]
        val_common_embeddings = common_embeddings[val_idx]
        break

    train_data = pd.concat([train_rare, train_common]).reset_index(drop=True)
    val_data = pd.concat([val_rare, val_common]).reset_index(drop=True)

    train_embeddings = np.concatenate([train_rare_embeddings, train_common_embeddings], axis=0)
    val_embeddings = np.concatenate([val_rare_embeddings, val_common_embeddings], axis=0)

    return (train_data, train_embeddings), (val_data, val_embeddings)

(dataset_train, train_embeddings), (dataset_val, val_embeddings) = stratified_train_val_split_with_embeddings(
    dataset,
    embeddings_kalm,
    labels_column="subnarratives_encoded",
    min_instances=2
)

In [8]:
dataset_train.shape

(1354, 7)

In [9]:
train_embeddings.shape

(1354, 896)

### 1.2 Remapping our subnarrative indices

We know that our articls have many narratives, and each one maps to several subnarratives, creating a hierarchy.  
The problem is, our `subnarratives_encoded` currently looks like a flat list of zeros:

```
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
```

But we need it to reflect the hierarchy properly:

So, we break it down into a list of lists—each inner list represents the true labels for a specific hierarchy:

```
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0] , [0, 0, 0, ...
 ^ hierarchy 0       ^ hierarchy 1          ^ hierarchy 2 ...
```

This will help us significantly later when we need to know for a specific article, the true subnarrative labels for a specific hierarchy.

In [10]:
dataset['subnarratives_encoded']

0       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
1       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
2       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
3       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
4       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
                              ...                        
1694    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
1695    [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, ...
1696    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...
1697    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
1698    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
Name: subnarratives_encoded, Length: 1699, dtype: object

We’re using the label encoders to get the indices of narratives and subnarratives, which we’ll use later.  
* For each narrative in `narrative_to_subnarratives`, we find the index of the narrative and its corresponding subnarratives using the encoders.

In [11]:
narrative_to_sub_map = {}
narrative_classes = list(mlb_narratives.classes_)
subnarrative_classes = list(mlb_subnarratives.classes_)

for narrative, subnarratives in narrative_to_subnarratives.items():
    narrative_idx = narrative_classes.index(narrative)
    subnarrative_indices = [subnarrative_classes.index(sub) for sub in subnarratives]
    narrative_to_sub_map[narrative_idx] = subnarrative_indices

print(narrative_to_sub_map)

{8: [33, 66, 50, 64, 20, 39, 22, 65, 21], 9: [33, 71, 60, 56, 53, 19, 58], 17: [33, 34, 35, 41, 46, 42], 19: [40, 33, 63, 59], 10: [33, 69, 72], 1: [33, 43, 31, 3, 62], 16: [55, 33, 57, 32], 2: [54, 67, 33], 20: [44, 33, 45, 68], 13: [2, 33, 6], 14: [47, 33, 61], 0: [33, 1, 24, 73, 23], 15: [33], 7: [33, 14, 16, 15, 17], 5: [9, 33, 8, 0], 11: [28, 33, 7, 51, 27, 70, 29, 49, 4], 6: [12, 33, 10, 11], 18: [18, 33, 30, 26, 48], 3: [33, 5, 52], 4: [37, 36, 33, 38], 12: [25, 13, 33]}


Now, we remap the `subnarratives_encoded` list to reflect the correct hierarchy for each article.  
* For each narrative, we grab its corresponding subnarrative indices from `narrative_to_sub_map` and assign the sublabels to the appropriate hierarchy column.  

This will give us a new set of columns where each one contains the true subnarrative labels for that narrative hierarchy.

In [12]:
hierarchy_new_column_name = "narrative_hierarchy"

In [13]:
def remap_subnarratives(row, narrative_to_sub_map):
    """Takes in a row and encodes the current subnarrative list to the associated hierarchy based on the narr-subnar map"""
    for narr_idx, sub_indices in narrative_to_sub_map.items():
        sub_labels = [row['subnarratives_encoded'][sub_idx] for sub_idx in sub_indices]
        col_name = f"{hierarchy_new_column_name}_{narr_idx}"
        row[col_name] = sub_labels
    return row

dataset_train_cpy = dataset_train.apply(remap_subnarratives, axis=1, args=(narrative_to_sub_map,)).copy()

We do the same for validation dataset:

In [14]:
dataset_val_cpy = dataset_val.apply(remap_subnarratives, axis=1, args=(narrative_to_sub_map,)).copy()

In [15]:
dataset_val_cpy.head()

Unnamed: 0,language,article_id,content,narratives,subnarratives,narratives_encoded,subnarratives_encoded,narrative_hierarchy_8,narrative_hierarchy_9,narrative_hierarchy_17,...,narrative_hierarchy_0,narrative_hierarchy_15,narrative_hierarchy_7,narrative_hierarchy_5,narrative_hierarchy_11,narrative_hierarchy_6,narrative_hierarchy_18,narrative_hierarchy_3,narrative_hierarchy_4,narrative_hierarchy_12
0,PT,PT_CC_416.txt,<PARA>da patranha do aquecimento global</PARA>...,"[Questioning the measurements and science, Hid...","[Scientific community is unreliable, Climate a...","[0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, ...","[0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, ...","[1, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 0, 0, 0, 0, 0, 0]","[1, 0, 0, 0, 0, 0]",...,"[1, 0, 0, 0, 0]",[1],"[1, 0, 0, 0, 0]","[0, 1, 1, 0]","[0, 1, 0, 0, 0, 0, 1, 1, 0]","[0, 1, 0, 0]","[0, 1, 0, 0, 1]","[1, 0, 1]","[0, 0, 1, 0]","[0, 0, 1]"
1,EN,EN_CC_200022.txt,<PARA>Denmark to Punish Farmers for cow ‘emiss...,"[Criticism of institutions and authorities, Cr...","[Criticism of national governments, Other, Met...","[0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[1, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 0, 0, 0, 0, 0, 0]","[1, 0, 0, 0, 0, 0]",...,"[1, 0, 0, 0, 0]",[1],"[1, 1, 0, 1, 0]","[0, 1, 0, 0]","[0, 1, 0, 0, 0, 0, 0, 0, 0]","[0, 1, 0, 0]","[0, 1, 1, 1, 0]","[1, 0, 0]","[0, 0, 1, 0]","[0, 0, 1]"
2,BG,A9_BG_4016.txt,<PARA>британски допломат потресен от политикат...,"[Discrediting the West, Diplomacy, Blaming the...","[Other, The West are the aggressors]","[0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 0, 0, 0, 0, 0, 0]","[1, 0, 0, 0, 0, 0]",...,"[1, 0, 0, 0, 0]",[1],"[1, 0, 0, 0, 0]","[0, 1, 0, 0]","[0, 1, 0, 0, 0, 0, 0, 0, 0]","[0, 1, 0, 0]","[0, 1, 0, 0, 0]","[1, 0, 0]","[0, 0, 1, 0]","[0, 0, 1]"
3,HI,HI_112.txt,<PARA>टैंकों और बख्तरबंद वाहनों के साथ रूसी इल...,"[Russia is the Victim, Blaming the war on othe...",[Russia actions in Ukraine are only self-defen...,"[0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 0, 0, 0, 1, 0, 0, 0, 0]","[1, 0, 1, 0, 0, 0, 0]","[1, 0, 0, 0, 0, 0]",...,"[1, 0, 0, 0, 0]",[1],"[1, 0, 0, 0, 0]","[0, 1, 0, 0]","[0, 1, 0, 0, 0, 0, 0, 0, 0]","[0, 1, 0, 0]","[0, 1, 0, 0, 0]","[1, 0, 0]","[0, 0, 1, 0]","[0, 0, 1]"
4,PT,PT_393.txt,<PARA>estudante que atirou tinta a montenegro ...,"[Criticism of institutions and authorities, Am...",[Criticism of political organizations and figu...,"[1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 0, 0, 0, 0, 0, 0, 0, 0]","[1, 0, 0, 0, 0, 0, 0]","[1, 0, 0, 0, 0, 0]",...,"[1, 0, 0, 0, 0]",[1],"[1, 0, 1, 1, 0]","[0, 1, 0, 0]","[0, 1, 0, 0, 0, 0, 0, 0, 0]","[0, 1, 0, 0]","[0, 1, 0, 0, 0]","[1, 0, 0]","[0, 0, 1, 0]","[0, 0, 1]"


A sample result looks like this:

In [16]:
for narr_idx, sub_indices in narrative_to_sub_map.items():
    dataset_hierarchy_col_name = f"{hierarchy_new_column_name}_{narr_idx}"
    res = dataset_train_cpy[dataset_hierarchy_col_name]
    print(f"Sample of {dataset_hierarchy_col_name}:")
    print(res.head()) 
    print("\n")

Sample of narrative_hierarchy_8:
0    [0, 0, 0, 0, 0, 0, 0, 0, 0]
1    [1, 0, 0, 0, 0, 0, 0, 0, 0]
2    [1, 0, 0, 0, 0, 0, 0, 0, 0]
3    [0, 0, 0, 0, 0, 0, 0, 0, 0]
4    [1, 0, 0, 0, 0, 0, 0, 0, 0]
Name: narrative_hierarchy_8, dtype: object


Sample of narrative_hierarchy_9:
0    [0, 0, 0, 0, 0, 0, 0]
1    [1, 0, 0, 0, 0, 0, 0]
2    [1, 0, 0, 0, 0, 0, 0]
3    [0, 0, 0, 0, 0, 0, 0]
4    [1, 0, 0, 0, 0, 0, 0]
Name: narrative_hierarchy_9, dtype: object


Sample of narrative_hierarchy_17:
0    [0, 0, 0, 0, 0, 0]
1    [1, 0, 0, 0, 0, 0]
2    [1, 0, 0, 0, 0, 0]
3    [0, 0, 0, 0, 0, 0]
4    [1, 0, 0, 0, 0, 0]
Name: narrative_hierarchy_17, dtype: object


Sample of narrative_hierarchy_19:
0    [0, 0, 0, 0]
1    [0, 1, 0, 0]
2    [0, 1, 0, 0]
3    [0, 0, 0, 0]
4    [0, 1, 0, 0]
Name: narrative_hierarchy_19, dtype: object


Sample of narrative_hierarchy_10:
0    [0, 0, 0]
1    [1, 0, 0]
2    [1, 0, 0]
3    [0, 0, 0]
4    [1, 0, 0]
Name: narrative_hierarchy_10, dtype: object


Sample of narrative

In [17]:
narrative_order = sorted(narrative_to_sub_map.keys())
narrative_order

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]

Now we want to make sure that the true subnarratives for hierarchy 0 are in position 0 of the aggregated list, hierarchy 1 in position 1, and so on.  
This ensures the subnarratives are ordered correctly in the final, aggregated list.

In [18]:
def aggregate_subnarratives(row, narrative_order, narrative_to_sub_map):
    """Takes in a row, and aggregates all hierarchy columns to 1 list.
    The encoded list will be a list of lists, starting from the first hierarchy"""
    aggregated = []
    for narr_idx in narrative_order:
        column_name = f"narrative_hierarchy_{narr_idx}"
        sub_labels = row[column_name]
        aggregated.append(sub_labels)
    return aggregated

dataset_train['aggregated_subnarratives'] = dataset_train_cpy.apply(
    aggregate_subnarratives,
    axis=1,
    args=(narrative_order, narrative_to_sub_map)
)

dataset_val['aggregated_subnarratives'] = dataset_val_cpy.apply(
    aggregate_subnarratives,
    axis=1,
    args=(narrative_order, narrative_to_sub_map)
)

In [19]:
dataset_train['aggregated_subnarratives']

0       [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0], ...
1       [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [0, 0, 1], ...
2       [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [0, 0, 1], ...
3       [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0], ...
4       [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [0, 0, 1], ...
                              ...                        
1349    [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [0, 0, 1], ...
1350    [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [0, 0, 1], ...
1351    [[0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0], ...
1352    [[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [0, 0, 1], ...
1353    [[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0], ...
Name: aggregated_subnarratives, Length: 1354, dtype: object

In [20]:
y_train_sub_heads = dataset_train['aggregated_subnarratives'].to_numpy()
y_val_sub_heads = dataset_val['aggregated_subnarratives'].to_numpy()

In [21]:
import torch

train_embeddings_tensor = torch.tensor(train_embeddings, dtype=torch.float32)
val_embeddings_tensor = torch.tensor(val_embeddings, dtype=torch.float32)

In [22]:
input_size = train_embeddings_tensor.shape[1]
print(input_size)

896


Now we have a model with a shared layer that captures the general features of the article.  
* The model was finalised after a lot of experimentaions the BatchNorm + ReLU combo significantly improves performance by stabilizing training and speeding up convergence.
* Also, it seems like the model overfits very quickly when becoming overly complex.
  
We make predictions for the top-level narratives, followed by separate subnarrative predictions for each narrative.  

In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTaskClassifierMultiHead(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size=1024,
        num_narratives=len(mlb_narratives.classes_),
        narrative_to_sub_map=narrative_to_sub_map,
        dropout_rate=0.4
    ):
        super().__init__()
        self.shared_layer = nn.Sequential(
            nn.Linear(input_size, hidden_size * 2),
            nn.BatchNorm1d(hidden_size * 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        self.narrative_head = nn.Sequential(
            nn.Linear(hidden_size * 2, num_narratives),
            nn.Sigmoid()
        )

        self.subnarrative_heads = nn.ModuleDict()
        for narr_idx, sub_indices in narrative_to_sub_map.items():
            num_subs_for_this_narr = len(sub_indices)
            self.subnarrative_heads[str(narr_idx)] = nn.Sequential(
                nn.Linear(hidden_size * 2, num_subs_for_this_narr),
                nn.Sigmoid()
            )

    def forward(self, x):
        shared_out = self.shared_layer(x)
        narr_probs = self.narrative_head(shared_out)

        sub_probs_dict = {}
        for narr_idx, head in self.subnarrative_heads.items():
            sub_probs_dict[narr_idx] = head(shared_out)

        return narr_probs, sub_probs_dict

In [24]:
network_params = {
    'lr': 0.001,
    'hidden_size': 512,
    'dropout': 0.4
}

In [25]:
model_multi_head = MultiTaskClassifierMultiHead(
    input_size=input_size,
    hidden_size=network_params['hidden_size'],
)

In [26]:
print(model_multi_head)

MultiTaskClassifierMultiHead(
  (shared_layer): Sequential(
    (0): Linear(in_features=896, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.4, inplace=False)
  )
  (narrative_head): Sequential(
    (0): Linear(in_features=1024, out_features=21, bias=True)
    (1): Sigmoid()
  )
  (subnarrative_heads): ModuleDict(
    (8): Sequential(
      (0): Linear(in_features=1024, out_features=9, bias=True)
      (1): Sigmoid()
    )
    (9): Sequential(
      (0): Linear(in_features=1024, out_features=7, bias=True)
      (1): Sigmoid()
    )
    (17): Sequential(
      (0): Linear(in_features=1024, out_features=6, bias=True)
      (1): Sigmoid()
    )
    (19): Sequential(
      (0): Linear(in_features=1024, out_features=4, bias=True)
      (1): Sigmoid()
    )
    (10): Sequential(
      (0): Linear(in_features=1024, out_features=3, bias=True)
      (1): Sigmoid()
    )
    (1): Sequent

In [27]:
y_train_nar = dataset_train['narratives_encoded'].tolist()
y_val_nar = dataset_val['narratives_encoded'].tolist()

y_train_sub_nar = dataset_train['subnarratives_encoded'].tolist()
y_val_sub_nar = dataset_val['subnarratives_encoded'].tolist()

We move everything to a tensor:

In [28]:
y_train_nar = torch.tensor(y_train_nar, dtype=torch.float32)
y_train_sub_nar = torch.tensor(y_train_sub_nar, dtype=torch.float32)

y_val_nar = torch.tensor(y_val_nar, dtype=torch.float32)
y_val_sub_nar = torch.tensor(y_val_sub_nar, dtype=torch.float32)

In [29]:
train_embeddings_tensor = torch.tensor(train_embeddings, dtype=torch.float32)
val_embeddings_tensor = torch.tensor(val_embeddings, dtype=torch.float32)

We calculate class weights to handle label imbalance in the training data. 
* This way, rare labels are given higher importance to ensure the model learns them effectively.
* The custom ```WeightedBCELoss``` applies these weights during training to balance the impact of common and rare labels, preventing the model from focusing only on frequent ones.

In [30]:
import torch
import torch.nn as nn

def compute_class_weights(y_train):
    total_samples = y_train.shape[0]
    class_weights = []
    for label in range(y_train.shape[1]):
        pos_count = y_train[:, label].sum().item()
        neg_count = total_samples - pos_count
        pos_weight = total_samples / (2 * pos_count) if pos_count > 0 else 0
        neg_weight = total_samples / (2 * neg_count) if neg_count > 0 else 0
        class_weights.append((pos_weight, neg_weight))
    return class_weights

class WeightedBCELoss(nn.Module):
    def __init__(self, class_weights):
        super().__init__()
        self.class_weights = class_weights

    def forward(self, probs, targets):
        bce_loss = 0
        epsilon = 1e-7
        for i, (pos_weight, neg_weight) in enumerate(self.class_weights):
            prob = probs[:, i]
            bce = -pos_weight * targets[:, i] * torch.log(prob + epsilon) - \
                  neg_weight * (1 - targets[:, i]) * torch.log(1 - prob + epsilon)
            bce_loss += bce.mean()
        return bce_loss / len(self.class_weights)

class_weights_sub_nar = compute_class_weights(y_val_sub_nar)
class_weights_nar = compute_class_weights(y_val_nar)
narrative_criterion = WeightedBCELoss(class_weights_nar)

We create a separate loss function for each hierarchy of subnarratives to handle their specific class imbalance.  

In [31]:
sub_criterion_dict = {}

for narr_idx, sub_indices in narrative_to_sub_map.items():
    local_weights = [ class_weights_sub_nar[sub_i] for sub_i in sub_indices ]

    sub_criterion = WeightedBCELoss(local_weights)
    sub_criterion_dict[str(narr_idx)] = sub_criterion

We define a ```ConditionalLoss``` class to handle the multi-task loss calculation.

* In the forward method, we first calculate the loss for the top-level narratives using the narrative criterion.
We then loop through each subnarrative head and compute the loss for each one, based on its specific subnarrative labels


* We introduce a conditioning term that penalizes inconsistencies between narrative and subnarrative predictions.
Finally, we combine the narrative loss, sub-loss, and conditioning loss to get the total loss, which is returned.

In [32]:
class MultiHeadLoss(nn.Module):
    def __init__(self, narrative_criterion, sub_criterion_dict, 
                 condition_weight=0.6, sub_weight=0.5):
        super().__init__()
        self.narrative_criterion = narrative_criterion
        self.sub_criterion_dict = sub_criterion_dict
        self.condition_weight = condition_weight
        self.sub_weight = sub_weight
        
    def forward(self, narr_probs, sub_probs_dict, y_narr, y_sub_heads):
        narr_loss = self.narrative_criterion(narr_probs, y_narr)
        sub_loss = 0.0
        condition_loss = 0.0
        
        for narr_idx_str, sub_probs in sub_probs_dict.items():
            narr_idx = int(narr_idx_str)
            y_sub = [row[narr_idx] for row in y_sub_heads]
            y_sub_tensor = torch.tensor(y_sub, dtype=torch.float32, device=sub_probs.device)
            
            sub_loss_func = self.sub_criterion_dict[narr_idx_str]
            sub_loss += sub_loss_func(sub_probs, y_sub_tensor)

            narr_pred = narr_probs[:, narr_idx].unsqueeze(1)
            condition_term = torch.mean(
                # Penalize high probs of sub, based on first level narr predictinos
                torch.abs(sub_probs * (1 - narr_pred)) + 
                # If a narrative is true, then the subnarrative predictions should match their actual true values.
                narr_pred * torch.abs(sub_probs - y_sub_tensor.unsqueeze(1))
            )
            condition_loss += condition_term
            
        sub_loss = sub_loss / len(sub_probs_dict)
        condition_loss = condition_loss / len(sub_probs_dict)
        
        total_loss = (1 - self.sub_weight) * narr_loss + \
                    self.sub_weight * sub_loss + \
                    self.condition_weight * condition_loss
        
        return total_loss

In [33]:
multi_head_loss_fn = MultiHeadLoss(narrative_criterion, sub_criterion_dict)

In [34]:
def check_weight_norms(model):
    for name, param in model.named_parameters():
        if param.requires_grad:
            norm = param.norm(2).item()
            print(f"Layer: {name} | Weight Norm: {norm:.4f}")

We define the function for training our model:

In [35]:
def train_with_multihead(
    model,
    optimizer,
    loss_fn=multi_head_loss_fn,
    train_embeddings=train_embeddings_tensor,
    y_train_nar=y_train_nar,
    y_train_sub_heads=y_train_sub_heads,
    val_embeddings=val_embeddings_tensor,
    y_val_nar=y_val_nar,
    y_val_sub_heads=y_val_sub_heads,
    patience=5,
    num_epochs=100,
    scheduler=None,
    min_delta=0.001
):
    best_val_loss = float('inf')
    best_model = None
    patience_counter = 0
    for epoch in range(num_epochs):
        model.train()
        train_narr_probs, train_sub_probs_dict = model(train_embeddings)
        train_loss = loss_fn(train_narr_probs, train_sub_probs_dict, y_train_nar, y_train_sub_heads)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        
        # Validation phase
        model.eval()
        with torch.no_grad():
            val_narr_probs, val_sub_probs_dict = model(val_embeddings)
            val_loss = loss_fn(val_narr_probs, val_sub_probs_dict, y_val_nar, y_val_sub_heads)
        
        print(f"Epoch {epoch+1}/{num_epochs}, "
              f"Training Loss: {train_loss.item():.4f}, "
              f"Validation Loss: {val_loss.item():.4f}")
        
        if scheduler:
            scheduler.step(val_loss)
            current_lr = scheduler.optimizer.param_groups[0]['lr']
            print(f"Current Learning Rate: {current_lr:.6f}")
        
        if val_loss.item() < best_val_loss - min_delta:
            best_val_loss = val_loss.item()
            patience_counter = 0
            best_model = model.state_dict().copy()
        else:
            patience_counter += 1
            print(f"Validation loss did not significantly improve for {patience_counter} epoch(s).")
        
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break
    
    if best_model:
        model.load_state_dict(best_model)
    return model

In [36]:
optimizer_multi_head = torch.optim.AdamW(model_multi_head.parameters(), lr=0.001)

We will also initialize a scheduler to adjust the learning rate dynamically during training based on how the model is performing

In [37]:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_multi_head, mode='min', factor=0.5, patience=3)

In [38]:
train_with_multihead(
    model=model_multi_head,
    optimizer=optimizer_multi_head,
    scheduler=scheduler,
    patience=10
)

Epoch 1/100, Training Loss: 1.0249, Validation Loss: 0.9849
Current Learning Rate: 0.001000
Epoch 2/100, Training Loss: 0.8745, Validation Loss: 0.9799
Current Learning Rate: 0.001000
Epoch 3/100, Training Loss: 0.7871, Validation Loss: 0.9748
Current Learning Rate: 0.001000
Epoch 4/100, Training Loss: 0.7228, Validation Loss: 0.9698
Current Learning Rate: 0.001000
Epoch 5/100, Training Loss: 0.6734, Validation Loss: 0.9648
Current Learning Rate: 0.001000
Epoch 6/100, Training Loss: 0.6298, Validation Loss: 0.9595
Current Learning Rate: 0.001000
Epoch 7/100, Training Loss: 0.5975, Validation Loss: 0.9537
Current Learning Rate: 0.001000
Epoch 8/100, Training Loss: 0.5706, Validation Loss: 0.9474
Current Learning Rate: 0.001000
Epoch 9/100, Training Loss: 0.5436, Validation Loss: 0.9405
Current Learning Rate: 0.001000
Epoch 10/100, Training Loss: 0.5230, Validation Loss: 0.9331
Current Learning Rate: 0.001000
Epoch 11/100, Training Loss: 0.5039, Validation Loss: 0.9251
Current Learning R

MultiTaskClassifierMultiHead(
  (shared_layer): Sequential(
    (0): Linear(in_features=896, out_features=1024, bias=True)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.4, inplace=False)
  )
  (narrative_head): Sequential(
    (0): Linear(in_features=1024, out_features=21, bias=True)
    (1): Sigmoid()
  )
  (subnarrative_heads): ModuleDict(
    (8): Sequential(
      (0): Linear(in_features=1024, out_features=9, bias=True)
      (1): Sigmoid()
    )
    (9): Sequential(
      (0): Linear(in_features=1024, out_features=7, bias=True)
      (1): Sigmoid()
    )
    (17): Sequential(
      (0): Linear(in_features=1024, out_features=6, bias=True)
      (1): Sigmoid()
    )
    (19): Sequential(
      (0): Linear(in_features=1024, out_features=4, bias=True)
      (1): Sigmoid()
    )
    (10): Sequential(
      (0): Linear(in_features=1024, out_features=3, bias=True)
      (1): Sigmoid()
    )
    (1): Sequent

In [39]:
from sklearn.metrics import classification_report, f1_score
from dataclasses import dataclass

@dataclass
class EvaluationResults:
    best_threshold: float
    best_f1: float
    narrative_report: str
    subnarrative_report: str
    narrative_predictions: np.ndarray
    subnarrative_predictions: np.ndarray

class MultiHeadEvaluator:    
    def __init__(
        self,
        narrative_to_sub_map=narrative_to_sub_map,
        narrative_order=narrative_order,
        num_subnarratives=len(mlb_subnarratives.classes_),
        narrative_classes=mlb_narratives.classes_,
        subnarrative_classes=mlb_subnarratives.classes_,
        device='cpu'
    ):
        self.narrative_to_sub_map = narrative_to_sub_map
        self.narrative_order = narrative_order
        self.num_subnarratives = num_subnarratives
        self.narrative_classes = narrative_classes
        self.subnarrative_classes = subnarrative_classes
        self.device = device

    def _flatten_subnarratives(self, y_sub_hierarchical):
        """Reconstruct flattened subnarrative array from hierarchical structure."""
        num_samples = len(y_sub_hierarchical)
        sub_global_array = np.zeros((num_samples, self.num_subnarratives), dtype=int)

        for i in range(num_samples):
            for j, narr_idx in enumerate(self.narrative_order):
                sub_label_vec = y_sub_hierarchical[i][j]
                narr_idx = int(narr_idx)
                sub_indices = self.narrative_to_sub_map[narr_idx]
                for local_sub_i, global_sub_i in enumerate(sub_indices):
                    sub_global_array[i, global_sub_i] = sub_label_vec[local_sub_i]

        return sub_global_array

    def visualize_results(self, results):
        print(f"\nBest Threshold: {results.best_threshold:.2f}")
        print(f"Best Average F1: {results.best_f1:.4f}")
        print("\nNarrative Classification Report:")
        print(results.narrative_report)
        print("\nSubnarrative Classification Report:")
        print(results.subnarrative_report)

    def evaluate(
        self,
        model,
        embeddings=val_embeddings_tensor,
        y_nar_true=y_val_nar,
        y_sub_hierarchical=y_val_sub_heads,
        thresholds=None,
    ):
        if thresholds is None:
            thresholds = np.arange(0.1, 1.0, 0.1)

        y_nar_true_np = y_nar_true.cpu().numpy()
        
        embeddings = embeddings.to(self.device)
        with torch.no_grad():
            narr_probs, sub_probs_dict = model(embeddings)
            narr_probs = narr_probs.cpu().numpy()
            sub_probs_dict = {k: v.cpu().numpy() for k, v in sub_probs_dict.items()}

        best_threshold = 0
        best_f1 = -1
        best_nar_report = ""
        best_sub_report = ""
        best_nar_preds = None
        best_sub_preds = None
        
        for threshold in thresholds:
            narr_preds = (narr_probs >= threshold).astype(int)
            samples = len(narr_probs)
            sub_preds_flattened = np.zeros((samples, self.num_subnarratives), dtype=int)

            for narr_idx, sub_indices in self.narrative_to_sub_map.items():
                sub_probs_for_narr = sub_probs_dict[str(narr_idx)]
                predicted_narr_mask = narr_preds[:, narr_idx] == 1
                sub_preds_for_narr = (sub_probs_for_narr >= threshold).astype(int)

                for sample_idx in range(samples):
                    if predicted_narr_mask[sample_idx]:
                        for local_sub_i, global_sub_i in enumerate(sub_indices):
                            sub_preds_flattened[sample_idx, global_sub_i] = sub_preds_for_narr[sample_idx, local_sub_i]

            f1_nar = f1_score(y_nar_true_np, narr_preds, average="macro", zero_division=0)
            y_sub_true_np = self._flatten_subnarratives(y_sub_hierarchical)
            f1_sub = f1_score(y_sub_true_np, sub_preds_flattened, average="macro", zero_division=0)
            avg_f1 = (f1_nar + f1_sub) / 2.0

            if avg_f1 > best_f1:
                best_f1 = avg_f1
                best_threshold = threshold
                best_nar_preds = narr_preds
                best_sub_preds = sub_preds_flattened
                best_nar_report = classification_report(
                    y_nar_true_np, narr_preds,
                    target_names=self.narrative_classes,
                    zero_division=0
                )
                best_sub_report = classification_report(
                    y_sub_true_np, sub_preds_flattened,
                    target_names=self.subnarrative_classes,
                    zero_division=0
                )

        results = EvaluationResults(
            best_threshold=best_threshold,
            best_f1=best_f1,
            narrative_report=best_nar_report,
            subnarrative_report=best_sub_report,
            narrative_predictions=best_nar_preds,
            subnarrative_predictions=best_sub_preds
        )
        
        self.visualize_results(results)
        
        return results

In [40]:
evaluator = MultiHeadEvaluator()

In [41]:
_ = evaluator.evaluate(
    model=model_multi_head,
)


Best Threshold: 0.50
Best Average F1: 0.3699

Narrative Classification Report:
                                                   precision    recall  f1-score   support

                         Amplifying Climate Fears       0.70      0.96      0.81        46
                     Amplifying war-related fears       0.69      0.72      0.71        47
Blaming the war on others rather than the invader       0.31      0.62      0.41        32
                     Climate change is beneficial       0.50      0.25      0.33         4
             Controversy about green technologies       0.36      1.00      0.53         4
                    Criticism of climate movement       0.41      0.82      0.55        11
                    Criticism of climate policies       0.36      0.57      0.44        21
        Criticism of institutions and authorities       0.48      0.86      0.62        28
                             Discrediting Ukraine       0.67      0.80      0.73        79
         

### Providing the already predicted narrative

Let h(x) be the shared layer output for the embedding x:

        shared_out = self.shared_layer(x)

We compute the probability P(narr_i | x) for each narrative:

        narr_probs = self.narrative_head(shared_out)

Previously, we used the following formula for the subnarrative probability P(subnarr_j | x):

        P(subnarr_j | x) = σ(h(x))

* Where sigma is the sigmoid activation function. This means the subnarrative prediction was based only on the shared layer output.

However, the new idea is to consider:

        P(subnarr_j | x) = σ(concat(h(x), P(narr_i | x)))


Where narr_i is the narrative associated with subnarrative subnarr_j in the hierarchy.

* If the probability of the narrative is high, the subnarrative head will be more likely to predict the relevant subnarratives.
* If the probability is low, the model will ignore the corresponding subnarratives.
* At the same time, the shared output of the shared layer will help determine which subnarrative is most appropriate for the given document (and we can potentially use other techniques like attention to further improve the model).

In [42]:
import torch
import torch.nn as nn

class MultiTaskClassifierMultiHeadConcat(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        num_narratives=len(mlb_narratives.classes_),
        narrative_to_sub_map=narrative_to_sub_map,
        dropout_rate=network_params['dropout']
    ):
        super().__init__()
        
        # Shared layer
        self.shared_layer = nn.Sequential(
            nn.Linear(input_size, hidden_size * 2),
            nn.BatchNorm1d(hidden_size * 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        self.narrative_head = nn.Sequential(
            nn.Linear(hidden_size * 2, num_narratives),
            nn.Sigmoid()
        )

        self.subnarrative_heads = nn.ModuleDict()
        for narr_idx, sub_indices in narrative_to_sub_map.items():
            num_subs_for_this_narr = len(sub_indices)
            self.subnarrative_heads[str(narr_idx)] = nn.Sequential(
                nn.Linear(hidden_size * 2 + 1, num_subs_for_this_narr),
                nn.Sigmoid()
            )

    def forward(self, x):
        shared_out = self.shared_layer(x)

        narr_probs = self.narrative_head(shared_out)

        sub_probs_dict = {}
        for narr_idx, head in self.subnarrative_heads.items():
            conditioned_input = torch.cat((shared_out, narr_probs[:, int(narr_idx)].unsqueeze(1)), dim=1)
            sub_probs_dict[narr_idx] = head(conditioned_input)

        return narr_probs, sub_probs_dict

In [43]:
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

def initialize_and_train_model(
    model,
    num_epochs=100,
    lr=0.001,
    patience=10,
    use_scheduler=True,
    scheduler_patience=3,
    loss_fn=multi_head_loss_fn,
    num_subnarratives=len(mlb_subnarratives.classes_),
    device='cpu'
):
    optimizer = AdamW(model.parameters(), lr=lr)

    scheduler = None
    if use_scheduler:
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=scheduler_patience)

    trained_model = train_with_multihead(
                                    model=model,
                                    optimizer=optimizer,
                                    scheduler=scheduler,
                                    loss_fn=loss_fn,
                                    patience=patience
                                )
    return trained_model

In [44]:
model_multi_head_concat= MultiTaskClassifierMultiHeadConcat(
    input_size=input_size,
    hidden_size=network_params['hidden_size'],
)

In [45]:
trained_model_concat = initialize_and_train_model(
    model_multi_head_concat,
)

Epoch 1/100, Training Loss: 1.0219, Validation Loss: 0.9845
Current Learning Rate: 0.001000
Epoch 2/100, Training Loss: 0.8686, Validation Loss: 0.9790
Current Learning Rate: 0.001000
Epoch 3/100, Training Loss: 0.7851, Validation Loss: 0.9738
Current Learning Rate: 0.001000
Epoch 4/100, Training Loss: 0.7216, Validation Loss: 0.9688
Current Learning Rate: 0.001000
Epoch 5/100, Training Loss: 0.6703, Validation Loss: 0.9640
Current Learning Rate: 0.001000
Epoch 6/100, Training Loss: 0.6304, Validation Loss: 0.9591
Current Learning Rate: 0.001000
Epoch 7/100, Training Loss: 0.5972, Validation Loss: 0.9539
Current Learning Rate: 0.001000
Epoch 8/100, Training Loss: 0.5681, Validation Loss: 0.9479
Current Learning Rate: 0.001000
Epoch 9/100, Training Loss: 0.5427, Validation Loss: 0.9408
Current Learning Rate: 0.001000
Epoch 10/100, Training Loss: 0.5227, Validation Loss: 0.9327
Current Learning Rate: 0.001000
Epoch 11/100, Training Loss: 0.5037, Validation Loss: 0.9243
Current Learning R

In [46]:
_ = evaluator.evaluate(
    model=trained_model_concat,
)


Best Threshold: 0.40
Best Average F1: 0.3741

Narrative Classification Report:
                                                   precision    recall  f1-score   support

                         Amplifying Climate Fears       0.69      1.00      0.81        46
                     Amplifying war-related fears       0.59      0.83      0.69        47
Blaming the war on others rather than the invader       0.25      0.78      0.38        32
                     Climate change is beneficial       0.50      0.50      0.50         4
             Controversy about green technologies       0.20      1.00      0.33         4
                    Criticism of climate movement       0.39      0.82      0.53        11
                    Criticism of climate policies       0.36      0.76      0.49        21
        Criticism of institutions and authorities       0.40      0.86      0.55        28
                             Discrediting Ukraine       0.59      0.92      0.72        79
         

### Using multiplication instead of concantenation

Instead of using concat, we can try an element-wise multiplication. This sounds more logical as multiplication can act as a "gate":

* If the narrative probability is close to 0, the corresponding subnarrative head’s input will be scaled down, effectively disabling that subnarrative head.
* If the narrative probability is close to 1, the shared layer output passes through somewhat unaffected.

In [47]:
import torch
import torch.nn as nn

class MultiTaskClassifierMultiHeadMult(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        num_narratives=len(mlb_narratives.classes_),
        narrative_to_sub_map=narrative_to_sub_map,
        dropout_rate=network_params['dropout'],
        bias=0.1
    ):
        super().__init__()
        
        self.shared_layer = nn.Sequential(
            nn.Linear(input_size, hidden_size * 2),
            nn.BatchNorm1d(hidden_size * 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        self.narrative_head = nn.Sequential(
            nn.Linear(hidden_size * 2, num_narratives),
            nn.Sigmoid()
        )

        self.subnarrative_heads = nn.ModuleDict()
        for narr_idx, sub_indices in narrative_to_sub_map.items():
            num_subs_for_this_narr = len(sub_indices)
            self.subnarrative_heads[str(narr_idx)] = nn.Sequential(
                nn.Linear(hidden_size * 2, num_subs_for_this_narr),
                nn.Sigmoid()
            )

        self.bias = bias

    def forward(self, x):
        # Shared layer output
        shared_out = self.shared_layer(x)

        # Narrative probabilities
        narr_probs = self.narrative_head(shared_out)

        sub_probs_dict = {}
        for narr_idx, head in self.subnarrative_heads.items():
            narr_pred = narr_probs[:, int(narr_idx)].unsqueeze(1)

            conditioned_input = shared_out * (narr_pred + self.bias)

            sub_probs_dict[narr_idx] = head(conditioned_input)

        return narr_probs, sub_probs_dict

In [48]:
model_multi_head_mult = MultiTaskClassifierMultiHeadMult(
    input_size=input_size,
    hidden_size=512,
)

In [49]:
trained_model_mult = initialize_and_train_model(model_multi_head_mult, loss_fn=multi_head_loss_fn)

Epoch 1/100, Training Loss: 1.0104, Validation Loss: 0.9860
Current Learning Rate: 0.001000
Epoch 2/100, Training Loss: 0.8908, Validation Loss: 0.9823
Current Learning Rate: 0.001000
Epoch 3/100, Training Loss: 0.8295, Validation Loss: 0.9787
Current Learning Rate: 0.001000
Epoch 4/100, Training Loss: 0.7878, Validation Loss: 0.9751
Current Learning Rate: 0.001000
Epoch 5/100, Training Loss: 0.7531, Validation Loss: 0.9715
Current Learning Rate: 0.001000
Epoch 6/100, Training Loss: 0.7282, Validation Loss: 0.9677
Current Learning Rate: 0.001000
Epoch 7/100, Training Loss: 0.7062, Validation Loss: 0.9635
Current Learning Rate: 0.001000
Epoch 8/100, Training Loss: 0.6851, Validation Loss: 0.9591
Current Learning Rate: 0.001000
Epoch 9/100, Training Loss: 0.6666, Validation Loss: 0.9542
Current Learning Rate: 0.001000
Epoch 10/100, Training Loss: 0.6513, Validation Loss: 0.9491
Current Learning Rate: 0.001000
Epoch 11/100, Training Loss: 0.6383, Validation Loss: 0.9439
Current Learning R

The results are a bit suprising, but they make sense because concatenation gives the subnarrative heads more "flexibility", while multiplication is more restrictive acting as a hard gate.

* If our narrative predictions are not confident or most importantly not correct, the subnarrative head will receive very weak input because of the multiplication.

In [50]:
_ = evaluator.evaluate(
    model=trained_model_mult,
)


Best Threshold: 0.50
Best Average F1: 0.3465

Narrative Classification Report:
                                                   precision    recall  f1-score   support

                         Amplifying Climate Fears       0.64      0.98      0.78        46
                     Amplifying war-related fears       0.51      0.87      0.65        47
Blaming the war on others rather than the invader       0.23      0.72      0.35        32
                     Climate change is beneficial       0.43      0.75      0.55         4
             Controversy about green technologies       0.17      1.00      0.30         4
                    Criticism of climate movement       0.34      0.91      0.50        11
                    Criticism of climate policies       0.33      0.81      0.47        21
        Criticism of institutions and authorities       0.36      0.86      0.51        28
                             Discrediting Ukraine       0.58      0.91      0.71        79
         

Because subnarrative heads rely heavily on the narrative probabilities, we will reduce the `sub_weight` increasing the the narrative weight:

In [51]:
multi_head_loss_fn = MultiHeadLoss(
    narrative_criterion, 
    sub_criterion_dict,
    condition_weight=0.5,
    sub_weight=0.3
)

In [52]:
model_multi_head_mult = MultiTaskClassifierMultiHeadMult(
    input_size=input_size,
    hidden_size=512,
)

In [53]:
trained_model_mult = initialize_and_train_model(model_multi_head_mult,
                                                loss_fn=multi_head_loss_fn,
                                               )

Epoch 1/100, Training Loss: 0.9761, Validation Loss: 0.9359
Current Learning Rate: 0.001000
Epoch 2/100, Training Loss: 0.8304, Validation Loss: 0.9310
Current Learning Rate: 0.001000
Epoch 3/100, Training Loss: 0.7645, Validation Loss: 0.9266
Current Learning Rate: 0.001000
Epoch 4/100, Training Loss: 0.7228, Validation Loss: 0.9229
Current Learning Rate: 0.001000
Epoch 5/100, Training Loss: 0.6819, Validation Loss: 0.9192
Current Learning Rate: 0.001000
Epoch 6/100, Training Loss: 0.6584, Validation Loss: 0.9155
Current Learning Rate: 0.001000
Epoch 7/100, Training Loss: 0.6368, Validation Loss: 0.9115
Current Learning Rate: 0.001000
Epoch 8/100, Training Loss: 0.6182, Validation Loss: 0.9071
Current Learning Rate: 0.001000
Epoch 9/100, Training Loss: 0.6035, Validation Loss: 0.9023
Current Learning Rate: 0.001000
Epoch 10/100, Training Loss: 0.5870, Validation Loss: 0.8968
Current Learning Rate: 0.001000
Epoch 11/100, Training Loss: 0.5735, Validation Loss: 0.8908
Current Learning R

In [54]:
_ = evaluator.evaluate(
    model=trained_model_mult,
)


Best Threshold: 0.40
Best Average F1: 0.3496

Narrative Classification Report:
                                                   precision    recall  f1-score   support

                         Amplifying Climate Fears       0.62      1.00      0.77        46
                     Amplifying war-related fears       0.54      0.85      0.66        47
Blaming the war on others rather than the invader       0.24      0.69      0.36        32
                     Climate change is beneficial       0.20      0.50      0.29         4
             Controversy about green technologies       0.19      1.00      0.32         4
                    Criticism of climate movement       0.31      1.00      0.48        11
                    Criticism of climate policies       0.31      0.71      0.43        21
        Criticism of institutions and authorities       0.35      0.89      0.51        28
                             Discrediting Ukraine       0.57      0.92      0.71        79
         