# Semeval 2025 Task 10
### Subtask 2: Narrative Classification

Given a news article and a [two-level taxonomy of narrative labels](https://propaganda.math.unipd.it/semeval2025task10/NARRATIVE-TAXONOMIES.pdf) (where each narrative is subdivided into subnarratives) from a particular domain, assign to the article all the appropriate subnarrative labels. This is a multi-label multi-class document classification task.

In [1]:
# Instructor compatible with these versions only
!pip install -q huggingface_hub==0.25.0
!pip install -q langchain==0.1.2 sentence_transformers==2.2.2

## 1. Importing libraries

We will start by importing the libraries needed.

In [2]:
import pandas as pd
import numpy as np

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import ModelCheckpoint

from matplotlib import pyplot as plt
import seaborn as sns
import os

In [3]:
!pip install InstructorEmbedding



## 2. Reading our data

In [4]:
raw_annotation_data = []

with open("data/semeval_data/labels/EN/subtask-2-annotations.txt", 'r') as file:
    for line in file:
        parts = line.strip().split('\t')
        article_id = parts[0]
        narrative_to_subnarratives = parts[2].split(';') # second part is the mapping from narrative to subnarrative
        narratives = []
        subnarratives = []

        # look to that narrative to subnarrative mapping
        for nar_to_sub in narrative_to_subnarratives:
          subnarrative_list = nar_to_sub.split(' ')
          if subnarrative_list[0] == 'Other':
            narratives.append('Other')
            subnarratives.append('Other')
            continue

          nar_to_sub = ' '.join(subnarrative_list[1:])
          nar, sub = nar_to_sub.split(':')
          narratives.append(nar.strip())
          subnarratives.append(sub.strip())

        raw_annotation_data.append({
            'article_id': article_id,
            'narratives': narratives,
            'subnarratives': subnarratives
        })

annotations_df = pd.DataFrame(raw_annotation_data)

In [5]:
annotations_df.sample(20)

Unnamed: 0,article_id,narratives,subnarratives
47,EN_UA_024847.txt,[Blaming the war on others rather than the inv...,"[The West are the aggressors, Other, Other, Uk..."
191,EN_CC_100002.txt,"[Criticism of institutions and authorities, Hi...","[Criticism of national governments, Blaming gl..."
128,EN_CC_100147.txt,[Other],[Other]
152,EN_UA_022319.txt,[Other],[Other]
177,EN_UA_103168.txt,[Amplifying war-related fears],[Other]
144,EN_UA_100868.txt,[Hidden plots by secret schemes of powerful gr...,[Other]
77,EN_UA_016466.txt,"[Amplifying war-related fears, Amplifying war-...","[By continuing the war we risk WWIII, Russia w..."
51,EN_CC_100065.txt,[Other],[Other]
48,EN_UA_100587.txt,[Other],[Other]
148,EN_UA_100864.txt,"[Amplifying war-related fears, Discrediting th...","[Russia will also attack other countries, Othe..."


In [6]:
annotations_df.shape

(200, 3)

In [7]:
annotations_df.iloc[2]

article_id                                        EN_UA_021270.txt
narratives       [Speculating war outcomes, Discrediting Ukrain...
subnarratives    [Other, Situation in Ukraine is hopeless, West...
Name: 2, dtype: object

In [8]:
annotations_df.iloc[2].narratives

['Speculating war outcomes',
 'Discrediting Ukraine',
 'Discrediting the West, Diplomacy',
 'Praise of Russia',
 'Discrediting the West, Diplomacy']

In [9]:
annotations_df.iloc[2].subnarratives

['Other',
 'Situation in Ukraine is hopeless',
 'West is tired of Ukraine',
 'Praise of Russian military might',
 'The West does not care about Ukraine, only about its interests']

In [10]:
def read_file_content(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.read()


In [11]:
folder_path = "data/semeval_data/raw-documents/EN"
documents_df = pd.DataFrame(columns=['article_id', 'content'])
for filename in os.listdir(folder_path):
    if filename.endswith('.txt') and filename.startswith('EN'):
        article_id = (filename)
        file_path = os.path.join(folder_path, filename)
        content = read_file_content(file_path)

        new_row = pd.DataFrame({'article_id': [article_id], 'content': [content]})

        documents_df = pd.concat([documents_df, new_row], ignore_index=True)

documents_df.head()

Unnamed: 0,article_id,content
0,EN_UA_104876.txt,Putin honours army unit blamed for Bucha massa...
1,EN_UA_023211.txt,Europe Putin thanks US journalist Tucker Carls...
2,EN_UA_011260.txt,Russia has a clear plan to resolve the conflic...
3,EN_UA_101067.txt,"First war of TikTok era sees tragedy, humor an..."
4,EN_UA_102963.txt,Ukraine's President Zelenskyy to address Mexic...


In [12]:
documents_df.shape

(200, 2)

In [13]:
dataset = pd.merge(documents_df, annotations_df, on='article_id')
dataset.head()

Unnamed: 0,article_id,content,narratives,subnarratives
0,EN_UA_104876.txt,Putin honours army unit blamed for Bucha massa...,[Other],[Other]
1,EN_UA_023211.txt,Europe Putin thanks US journalist Tucker Carls...,[Other],[Other]
2,EN_UA_011260.txt,Russia has a clear plan to resolve the conflic...,"[Russia is the Victim, Discrediting Ukraine, D...","[UA is anti-RU extremists, Ukraine is a hub fo..."
3,EN_UA_101067.txt,"First war of TikTok era sees tragedy, humor an...",[Other],[Other]
4,EN_UA_102963.txt,Ukraine's President Zelenskyy to address Mexic...,[Other],[Other]


In [14]:
def extract_article_id(filename):
    number_part = filename.split('_')[-1].split('.')[0]
    return number_part

print(extract_article_id('EN_UA_103861.txt'))

103861


In [15]:
dataset['article_id'] = dataset['article_id'].apply(extract_article_id)
dataset.head()

Unnamed: 0,article_id,content,narratives,subnarratives
0,104876,Putin honours army unit blamed for Bucha massa...,[Other],[Other]
1,23211,Europe Putin thanks US journalist Tucker Carls...,[Other],[Other]
2,11260,Russia has a clear plan to resolve the conflic...,"[Russia is the Victim, Discrediting Ukraine, D...","[UA is anti-RU extremists, Ukraine is a hub fo..."
3,101067,"First war of TikTok era sees tragedy, humor an...",[Other],[Other]
4,102963,Ukraine's President Zelenskyy to address Mexic...,[Other],[Other]


In [16]:
dataset['narratives']

0                                                [Other]
1                                                [Other]
2      [Russia is the Victim, Discrediting Ukraine, D...
3                                                [Other]
4                                                [Other]
                             ...                        
195    [Criticism of institutions and authorities, Cr...
196                   [Discrediting the West, Diplomacy]
197                                              [Other]
198                       [Amplifying war-related fears]
199                           [Speculating war outcomes]
Name: narratives, Length: 200, dtype: object

In [17]:
dataset.shape

(200, 4)

In [18]:
unique_narratives = dataset['narratives'].explode().unique()
unique_narratives

array(['Other', 'Russia is the Victim', 'Discrediting Ukraine',
       'Blaming the war on others rather than the invader',
       'Discrediting the West, Diplomacy',
       'Criticism of institutions and authorities',
       'Criticism of climate policies', 'Criticism of climate movement',
       'Hidden plots by secret schemes of powerful groups',
       'Controversy about green technologies',
       'Amplifying war-related fears', 'Downplaying climate change',
       'Speculating war outcomes', 'Overpraising the West',
       'Distrust towards Media',
       'Questioning the measurements and science', 'Praise of Russia',
       'Negative Consequences for the West',
       'Climate change is beneficial',
       'Green policies are geopolitical instruments'], dtype=object)

In [19]:
dataset['narratives'].explode().value_counts()

narratives
Other                                                97
Discrediting the West, Diplomacy                     50
Amplifying war-related fears                         43
Discrediting Ukraine                                 29
Criticism of institutions and authorities            24
Blaming the war on others rather than the invader    21
Criticism of climate movement                        18
Russia is the Victim                                 17
Speculating war outcomes                             17
Criticism of climate policies                        16
Hidden plots by secret schemes of powerful groups    16
Praise of Russia                                     12
Overpraising the West                                10
Distrust towards Media                               10
Controversy about green technologies                  9
Questioning the measurements and science              8
Negative Consequences for the West                    7
Downplaying climate change           

In [20]:
unique_subnarratives = dataset['subnarratives'].explode().unique()
unique_subnarratives

array(['Other', 'UA is anti-RU extremists',
       'Ukraine is a hub for criminal activities',
       'Ukraine is associated with nazism', 'Ukraine is the aggressor',
       'The West are the aggressors', 'Ukraine is a puppet of the West',
       'Diplomacy does/will not work',
       'Criticism of political organizations and figures',
       'Climate movement is corrupt',
       'Ad hominem attacks on key activists', 'Blaming global elites',
       'Climate policies are ineffective',
       'By continuing the war we risk WWIII',
       'There is a real possibility that nuclear weapons will be employed',
       'Ice is not melting', 'Climate cycles are natural',
       'Criticism of international entities',
       'Russia will also attack other countries',
       'Criticism of national governments',
       'Discrediting Ukrainian military', 'Ukrainian army is collapsing',
       'Discrediting Ukrainian government and officials and policies',
       'West is tired of Ukraine',
       'C

In [21]:
dataset['subnarratives'].explode().value_counts()

subnarratives
Other                                                                     151
The West are the aggressors                                                18
There is a real possibility that nuclear weapons will be employed          16
Criticism of national governments                                          12
Russia will also attack other countries                                    11
Western media is an instrument of propaganda                                9
The West does not care about Ukraine, only about its interests              8
Ad hominem attacks on key activists                                         8
Ukraine is a puppet of the West                                             8
Criticism of political organizations and figures                            7
Diplomacy does/will not work                                                7
The West belongs in the right side of history                               7
The West is weak                                  

In [22]:
from sklearn.preprocessing import MultiLabelBinarizer

mlb_narratives = MultiLabelBinarizer()
mlb_subnarratives = MultiLabelBinarizer()

In [23]:
dataset['narratives']

0                                                [Other]
1                                                [Other]
2      [Russia is the Victim, Discrediting Ukraine, D...
3                                                [Other]
4                                                [Other]
                             ...                        
195    [Criticism of institutions and authorities, Cr...
196                   [Discrediting the West, Diplomacy]
197                                              [Other]
198                       [Amplifying war-related fears]
199                           [Speculating war outcomes]
Name: narratives, Length: 200, dtype: object

In [24]:
narratives_binary = mlb_narratives.fit_transform(dataset['narratives'])
subnarratives_binary = mlb_subnarratives.fit_transform(dataset['subnarratives'])

dataset['narratives_encoded'] = narratives_binary.tolist()
dataset['subnarratives_encoded'] = subnarratives_binary.tolist()

In [25]:
dataset.head()

Unnamed: 0,article_id,content,narratives,subnarratives,narratives_encoded,subnarratives_encoded
0,104876,Putin honours army unit blamed for Bucha massa...,[Other],[Other],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,23211,Europe Putin thanks US journalist Tucker Carls...,[Other],[Other],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,11260,Russia has a clear plan to resolve the conflic...,"[Russia is the Victim, Discrediting Ukraine, D...","[UA is anti-RU extremists, Ukraine is a hub fo...","[0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,101067,"First war of TikTok era sees tragedy, humor an...",[Other],[Other],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,102963,Ukraine's President Zelenskyy to address Mexic...,[Other],[Other],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [26]:
len(unique_narratives)

20

In [27]:
len(unique_subnarratives)

58

In [28]:
import spacy
import re
import emoji

nlp = spacy.load("en_core_web_sm")

def clean_article(article_text):
    # Remove URLs
    article_text = re.sub(r'http\S+|www\S+|https\S+|[a-zA-Z0-9.-]+\.com', '', article_text, flags=re.MULTILINE)

    doc = nlp(article_text)
    cleaned_tokens = []

    for token in doc:
        if (token.is_space or '@' in token.text or emoji.is_emoji(token.text)):
            continue

        cleaned_tokens.append(token.text + token.whitespace_)

    cleaned_article = "".join(cleaned_tokens).strip()

    return cleaned_article

In [29]:
dataset['cleaned_content'] = dataset['content'].apply(clean_article)

In [30]:
import warnings
from sklearn.metrics import classification_report, confusion_matrix

def get_classification_report(y_true, y_pred):
  # We will ignore the warnings we get
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        report = classification_report(y_true, y_pred, output_dict=True)
    report_df = pd.DataFrame(report).transpose()
    return report_df

def plot_confusion_matrix(y_true, y_pred):
    conf_matrix = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(8, 6))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()
    return conf_matrix

In [31]:
embeddings_dir = 'data/embeddings/narrative_classification/'
embedding_file_name = 'all_embeddings.npy'
embeddings_full_path = embeddings_dir + embedding_file_name

In [32]:
import os

def are_embeddings_saved(filepath):
    if os.path.exists(filepath):
        return True
    return False

In [33]:
from InstructorEmbedding import INSTRUCTOR
instructor_model = INSTRUCTOR('hkunlp/instructor-large')

  from tqdm.autonotebook import trange


load INSTRUCTOR_Transformer
max_seq_length  512


  model.load_state_dict(torch.load(os.path.join(input_path, 'pytorch_model.bin'), map_location=torch.device('cpu')))


In [34]:
def precompute_embeddings(dataset, model, file_path):
    embeddings = []
    for index, row in dataset.iterrows():
        context = row['cleaned_content']
        instruction = f"""
                    Given the following article, classify the narrative based on its content. The narrative classification
                    should capture the main theme of the article, which could relate to topics such as the Ukraine-Russia war, 
                    climate change, or other global issues. The model should focus on understanding the key events, actors, and 
                    emotions expressed in the article.
                    """
        embedding = instructor_model.encode([[instruction, context]])
        embeddings.append(embedding[0])

    embeddings_array = np.array(embeddings)
    np.save(file_path, embeddings_array)
    print(f"Embeddings saved to {file_path}")

def load_embeddings(filename):
    return np.load(filename)

if not are_embeddings_saved(embeddings_full_path): precompute_embeddings(dataset, instructor_model, embeddings_full_path)

In [35]:
all_embeddings = load_embeddings(embeddings_full_path)

In [36]:
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
import numpy as np

def iterative_split_data_with_embeddings(data, embeddings, labels_column, 
                                         train_size=0.8, val_size_ratio=0.5, 
                                         splits=5, shuffle=True):
    if shuffle:
        shuffled_indices = np.arange(len(data))
        np.random.shuffle(shuffled_indices)
        data = data.iloc[shuffled_indices].reset_index(drop=True)
        embeddings = embeddings[shuffled_indices]

    labels = np.array(data[labels_column].tolist())

    mskf = MultilabelStratifiedKFold(n_splits=splits)
    for train_idx, temp_idx in mskf.split(np.zeros(len(labels)), labels):
        train_data = data.iloc[train_idx]
        temp_data = data.iloc[temp_idx]
        train_embeddings = embeddings[train_idx]
        temp_embeddings = embeddings[temp_idx]
        break

    temp_labels = np.array(temp_data[labels_column].tolist())
    val_size = int(len(temp_data) * val_size_ratio)
    mskf_temp = MultilabelStratifiedKFold(n_splits=splits)

    for val_idx, test_idx in mskf_temp.split(np.zeros(len(temp_labels)), temp_labels):
        val_data = temp_data.iloc[val_idx]
        test_data = temp_data.iloc[test_idx]
        val_embeddings = temp_embeddings[val_idx]
        test_embeddings = temp_embeddings[test_idx]
        break

    return (train_data, train_embeddings), (val_data, val_embeddings), (test_data, test_embeddings)


(dataset_train, train_embeddings), \
(dataset_val, val_embeddings), \
(dataset_test, test_embeddings) = iterative_split_data_with_embeddings(
    dataset, 
    all_embeddings, 
    splits=4,
    labels_column="narratives_encoded", 
)

In [37]:
dataset_train.shape

(151, 7)

In [38]:
dataset_val.shape

(37, 7)

In [39]:
y_train_nar = dataset_train['narratives_encoded'].tolist()
y_val_nar = dataset_val['narratives_encoded'].tolist()
y_test_nar = dataset_test['narratives_encoded'].tolist()

In [40]:
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, multilabel_confusion_matrix
from sklearn.multiclass import OneVsRestClassifier

ovr_logistic = OneVsRestClassifier(LogisticRegression(max_iter=1000, class_weight='balanced'))

In [41]:
ovr_logistic.fit(train_embeddings, y_train_nar)



In [42]:
from sklearn.model_selection import StratifiedKFold, cross_val_score

def evaluate_model(model, x, y_true, show_confusion=True):
  y_pred = model.predict(x)
  classification_report_df = get_classification_report(y_true, y_pred)
  print(classification_report_df)
  print('\n')
  if show_confusion: conf = plot_confusion_matrix(y_true, y_pred)

def get_cross_val_score(model, x, y, scoring='f1_macro', splits=3):
  cv = StratifiedKFold(n_splits=splits, shuffle=True)
  cross_val_scores = cross_val_score(model, x, y, cv=cv, scoring=scoring)
  print(f"Cross-validation scores: {cross_val_scores}")
  print(f"Mean CV F1 Score: {cross_val_scores.mean()}")

In [43]:
evaluate_model(ovr_logistic, val_embeddings, y_val_nar, False)

              precision    recall  f1-score  support
0              0.250000  0.600000  0.352941      5.0
1              0.090909  0.250000  0.133333      4.0
2              0.000000  0.000000  0.000000      1.0
3              0.666667  1.000000  0.800000      2.0
4              0.428571  1.000000  0.600000      3.0
5              0.200000  1.000000  0.333333      1.0
6              0.444444  1.000000  0.615385      4.0
7              0.166667  0.666667  0.266667      3.0
8              0.352941  0.857143  0.500000      7.0
9              0.250000  1.000000  0.400000      1.0
10             0.000000  0.000000  0.000000      1.0
11             0.000000  0.000000  0.000000      0.0
12             0.250000  0.666667  0.363636      3.0
13             0.333333  1.000000  0.500000      1.0
14             0.785714  0.611111  0.687500     18.0
15             0.125000  0.500000  0.200000      2.0
16             0.000000  0.000000  0.000000      1.0
17             0.000000  0.000000  0.000000   

In [44]:
from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier(n_neighbors=5, metric='cosine')

ovr_knn = OneVsRestClassifier(knn)

ovr_knn.fit(train_embeddings, y_train_nar)



In [45]:
evaluate_model(ovr_knn, val_embeddings, y_val_nar, False)

              precision    recall  f1-score  support
0              0.400000  0.400000  0.400000      5.0
1              0.000000  0.000000  0.000000      4.0
2              0.000000  0.000000  0.000000      1.0
3              0.000000  0.000000  0.000000      2.0
4              0.000000  0.000000  0.000000      3.0
5              0.500000  1.000000  0.666667      1.0
6              1.000000  0.250000  0.400000      4.0
7              0.000000  0.000000  0.000000      3.0
8              0.285714  0.285714  0.285714      7.0
9              0.000000  0.000000  0.000000      1.0
10             0.000000  0.000000  0.000000      1.0
11             0.000000  0.000000  0.000000      0.0
12             0.000000  0.000000  0.000000      3.0
13             0.000000  0.000000  0.000000      1.0
14             0.888889  0.444444  0.592593     18.0
15             0.000000  0.000000  0.000000      2.0
16             0.000000  0.000000  0.000000      1.0
17             0.000000  0.000000  0.000000   

### Fine-tuning BERT to predict narratives

In [46]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")



In [47]:
dataset_train.head()

Unnamed: 0,article_id,content,narratives,subnarratives,narratives_encoded,subnarratives_encoded,cleaned_content
0,25165,Zelensky Says US Politicians ‘Don’t Care About...,"[Discrediting the West, Diplomacy, Discreditin...","[The West does not care about Ukraine, only ab...","[0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",Zelensky Says US Politicians ‘Don’t Care About...
1,10901,G20 communique set to echo Modi's Ukraine line...,[Other],[Other],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",G20 communique set to echo Modi's Ukraine line...
3,100093,Disease X is a ‘blueprint’ for future pandemic...,[Other],[Other],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",Disease X is a ‘blueprint’ for future pandemic...
4,103861,The World Needs Peacemaker Trump Again \n\n by...,[Other],[Other],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",The World Needs Peacemaker Trump Again by Jeff...
5,101954,A 'Watershed Event': Five Takeaways From Israe...,[Other],[Other],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",A 'Watershed Event': Five Takeaways From Israe...


In [48]:
def tokenize_data(texts, max_length=512):
    return tokenizer(texts, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")

train_encodings = tokenize_data(dataset_train['cleaned_content'].tolist())
val_encodings = tokenize_data(dataset_val['cleaned_content'].tolist())

In [49]:
narrative_encodings = dataset_train['narratives_encoded']

print(narrative_encodings)

0      [0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, ...
1      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...
3      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...
4      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...
5      [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...
                             ...                        
193    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...
195    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...
196    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...
198    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...
199    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...
Name: narratives_encoded, Length: 151, dtype: object


In [50]:
import torch
from torch import nn

class NarrativeClassificationBERT(nn.Module):
    def __init__(self, bert_model, narrative_classes):
        super(NarrativeClassificationBERT, self).__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(0.3)

        hidden_size = self.bert.config.hidden_size
        self.narrative_classifier = nn.Linear(hidden_size, narrative_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        pooled_output = outputs.pooler_output
        pooled_output = self.dropout(pooled_output)

        narrative_logits = self.narrative_classifier(pooled_output)
        return narrative_logits

In [55]:
import torch
from torch.utils.data import Dataset

class ArticleDataset(Dataset):
    def __init__(self, encodings, narrative_labels):
        self.encodings = encodings
        self.narrative_labels = narrative_labels

    def __getitem__(self, idx):
        print(self.encodings)
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['narrative_labels'] = torch.tensor(self.narrative_labels[idx], dtype=torch.float)
        return item

    def __len__(self):
        return len(self.narrative_labels)

train_dataset = ArticleDataset(
    train_encodings,
    dataset_train['narratives_encoded'].tolist()
)

val_dataset = ArticleDataset(
    val_encodings,
    dataset_val['narratives_encoded'].tolist()
)

In [56]:
len(train_dataset)

151

In [None]:
from torch.nn.functional import binary_cross_entropy_with_logits

def compute_loss(narrative_logits, narrative_labels, weights=None):
    return binary_cross_entropy_with_logits(narrative_logits, narrative_labels, pos_weight=weights)

In [68]:
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import BertModel, BertConfig

model_config = {
    'lr': 4e-5,
    'batch_size': 16,
}

bert_config = BertConfig.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

model = NarrativeClassificationBERT(bert_model, len(mlb_narratives.classes_))
optimizer = AdamW(model.parameters(), lr=model_config['lr'])

train_loader = DataLoader(train_dataset, batch_size=model_config['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=model_config['batch_size'])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

NarrativeClassificationBERT(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-1

In [None]:
def train(num_epochs=15):
    for epoch in range(num_epochs):
        
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
    
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            narrative_labels = batch['narrative_labels'].to(device)
    
            narrative_logits = model(input_ids, attention_mask)
    
            loss = compute_loss(narrative_logits, narrative_labels)
            
            loss.backward()
            optimizer.step()
    
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                narrative_labels = batch['narrative_labels'].to(device)
    
                narrative_logits = model(input_ids, attention_mask)
                loss = compute_loss(narrative_logits, narrative_labels)
                val_loss += loss.item()
    
        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch + 1}, Training Loss: {loss.item()}, Validation Loss: {avg_val_loss}")

In [None]:
train()