## Classifier - Try 4

Classify articles frames using aggregated SRL and sentence embeddings

In [1]:
import os

try:
  import google.colab

  from google.colab import drive
  drive.mount('/content/drive')
  IN_COLAB = True
except:
  IN_COLAB = False

if IN_COLAB:
  os.chdir('drive/MyDrive/Git/MasterThesis/data')
else:
  os.chdir('../../data/')

labels_path = "data/en/train-labels-subtask-2.txt"
articles_path = "data/en/train-articles-subtask-2/"

In [2]:
import pandas as pd

# Read the dev-labels-subtask-2.txt file
labels_df = pd.read_csv(labels_path, sep="\t")

# Rename the columns for easier processing
labels_df.columns = ["article_id", "frames"]


labels_df.head()

Unnamed: 0,article_id,frames
0,832959523,"Morality,Security_and_defense,Policy_prescript..."
1,833039623,"Political,Crime_and_punishment,External_regula..."
2,833032367,"Political,Crime_and_punishment,Fairness_and_eq..."
3,814777937,"Political,Morality,Fairness_and_equality,Exter..."
4,821744708,"Policy_prescription_and_evaluation,Political,L..."


In [3]:
# A function to read the article text given its ID
def get_article_content(article_id):
    try:
        with open(f"{articles_path}/article{article_id}.txt", "r") as f:
            return f.read()
    except FileNotFoundError:
        return None

df = labels_df

# Apply the function to get the article content
df["content"] = df["article_id"].apply(get_article_content)

# Drop rows where content could not be found
df.dropna(subset=["content"], inplace=True)

df.head()


Unnamed: 0,article_id,frames,content
0,832959523,"Morality,Security_and_defense,Policy_prescript...",How Theresa May Botched\n\nThose were the time...
1,833039623,"Political,Crime_and_punishment,External_regula...",Robert Mueller III Rests His Case—Dems NEVER W...
2,833032367,"Political,Crime_and_punishment,Fairness_and_eq...",Robert Mueller Not Recommending Any More Indic...
3,814777937,"Political,Morality,Fairness_and_equality,Exter...",The Far Right Is Trying to Co-opt the Yellow V...
4,821744708,"Policy_prescription_and_evaluation,Political,L...",‘Special place in hell’ for those who promoted...


In [4]:
# Split the frames column into a list of frames
df["frames_list"] = df["frames"].str.split(",")

# create for each frame a new column with the frame as name and 1 if the frame is present in the article and 0 if not
for frame in df["frames_list"].explode().unique():
    df[frame] = df["frames_list"].apply(lambda x: 1 if frame in x else 0)

df.head()

Unnamed: 0,article_id,frames,content,frames_list,Morality,Security_and_defense,Policy_prescription_and_evaluation,Legality_Constitutionality_and_jurisprudence,Economic,Political,Crime_and_punishment,External_regulation_and_reputation,Public_opinion,Fairness_and_equality,Capacity_and_resources,Quality_of_life,Cultural_identity,Health_and_safety
0,832959523,"Morality,Security_and_defense,Policy_prescript...",How Theresa May Botched\n\nThose were the time...,"[Morality, Security_and_defense, Policy_prescr...",1,1,1,1,1,0,0,0,0,0,0,0,0,0
1,833039623,"Political,Crime_and_punishment,External_regula...",Robert Mueller III Rests His Case—Dems NEVER W...,"[Political, Crime_and_punishment, External_reg...",0,0,1,1,0,1,1,1,1,0,0,0,0,0
2,833032367,"Political,Crime_and_punishment,Fairness_and_eq...",Robert Mueller Not Recommending Any More Indic...,"[Political, Crime_and_punishment, Fairness_and...",0,0,0,1,0,1,1,1,0,1,0,0,0,0
3,814777937,"Political,Morality,Fairness_and_equality,Exter...",The Far Right Is Trying to Co-opt the Yellow V...,"[Political, Morality, Fairness_and_equality, E...",1,1,0,0,1,1,0,1,1,1,0,0,0,0
4,821744708,"Policy_prescription_and_evaluation,Political,L...",‘Special place in hell’ for those who promoted...,"[Policy_prescription_and_evaluation, Political...",0,0,1,1,0,1,0,1,0,0,0,0,0,0


In [5]:
X = df["content"]
y = df.drop(columns=["article_id", "frames", "frames_list", "content"])

In [6]:
X.head()

0    How Theresa May Botched\n\nThose were the time...
1    Robert Mueller III Rests His Case—Dems NEVER W...
2    Robert Mueller Not Recommending Any More Indic...
3    The Far Right Is Trying to Co-opt the Yellow V...
4    ‘Special place in hell’ for those who promoted...
Name: content, dtype: object

In [7]:
y.head()

Unnamed: 0,Morality,Security_and_defense,Policy_prescription_and_evaluation,Legality_Constitutionality_and_jurisprudence,Economic,Political,Crime_and_punishment,External_regulation_and_reputation,Public_opinion,Fairness_and_equality,Capacity_and_resources,Quality_of_life,Cultural_identity,Health_and_safety
0,1,1,1,1,1,0,0,0,0,0,0,0,0,0
1,0,0,1,1,0,1,1,1,1,0,0,0,0,0
2,0,0,0,1,0,1,1,1,0,1,0,0,0,0
3,1,1,0,0,1,1,0,1,1,1,0,0,0,0
4,0,0,1,1,0,1,0,1,0,0,0,0,0,0


In [8]:
len(X), len(y)

(432, 432)

In [31]:
y.to_csv("../notebooks/classifier/y.csv")

### Create Dataset

In [9]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='transformers')

### Extract SRL Embeddings from articles

In [10]:
!pip install pycuda
!pip install allennlp allennlp-models

Collecting pycuda
  Downloading pycuda-2022.2.2.tar.gz (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m51.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting appdirs>=1.4.0
  Downloading appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB)
Collecting pytools>=2011.2
  Downloading pytools-2023.1.1-py2.py3-none-any.whl (70 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m70.6/70.6 kB[0m [31m26.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting mako
  Downloading Mako-1.2.4-py3-none-any.whl (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.7/78.7 kB[0m [31m29.7 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: pycuda
  Building wheel for pycuda (pyproject.toml) ... [?25ldone
[?25h  Created wheel for pycuda: filename=pycud

In [11]:
import pycuda
from pycuda import compiler
import pycuda.driver as drv

drv.init()
print("%d device(s) found." % drv.Device.count())
           
for ordinal in range(drv.Device.count()):
    dev = drv.Device(ordinal)
    print (ordinal, dev.name())

1 device(s) found.
0 NVIDIA RTX A4000


In [12]:
from allennlp.predictors.predictor import Predictor
from allennlp_models.structured_prediction.models import srl_bert
from nltk.tokenize import sent_tokenize
import pandas as pd

In [13]:
def batched_extract_srl_components(sentences, predictor):
    # Prepare the batched input for the predictor
    batched_input = [{'sentence': sentence} for sentence in sentences]
    batched_srl = predictor.predict_batch_json(batched_input)
    
    # Extract SRL components from the batched predictions
    results = []
    for srl in batched_srl:
        best_extracted_data = None
        second_best_extracted_data = None
        for verb_entry in srl['verbs']:
            tags = verb_entry['tags']
            arg0_indices = [i for i, tag in enumerate(tags) if tag in ['B-ARG0', 'I-ARG0']]
            arg1_indices = [i for i, tag in enumerate(tags) if tag in ['B-ARG1', 'I-ARG1']]

            if arg0_indices and arg1_indices:
                best_extracted_data = {
                    'predicate': verb_entry['verb'],
                    'ARG0': ' '.join([srl['words'][i] for i in arg0_indices]),
                    'ARG1': ' '.join([srl['words'][i] for i in arg1_indices])
                }
                break
            elif (arg0_indices or arg1_indices) and not second_best_extracted_data:
                second_best_extracted_data = {
                    'predicate': verb_entry['verb'],
                    'ARG0': ' '.join([srl['words'][i] for i in arg0_indices]) if arg0_indices else '',
                    'ARG1': ' '.join([srl['words'][i] for i in arg1_indices]) if arg1_indices else ''
                }

        if best_extracted_data:
            results.append(best_extracted_data)
        elif second_best_extracted_data:
            results.append(second_best_extracted_data)
            
    return results

def optimized_extract_srl(X, predictor, batch_size=32):
    total_articles = len(X)
    processed_articles = 0

    all_results = []

    for article in X:
        sentences = sent_tokenize(article)
        article_srls = []

        for i in range(0, len(sentences), batch_size):
            batched_sentences = sentences[i:i+batch_size]
            article_srls.extend(batched_extract_srl_components(batched_sentences, predictor))

        all_results.append(article_srls)
        processed_articles += 1
        print(f"Processed article {processed_articles}/{total_articles}")

    return pd.Series(all_results)

In [14]:
import pickle

def get_X_srl(X, recalculate=False, pickle_path="../notebooks/classifier/X_srl.pkl"):
    """
    Returns the X_srl either by loading from a pickled file or recalculating.
    """
    if recalculate or not os.path.exists(pickle_path):
        print("Recalculate SRL")
        # Load predictor
        predictor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/structured-prediction-srl-bert.2020.12.15.tar.gz", cuda_device=0)
        X_srl = optimized_extract_srl(X, predictor)
        with open(pickle_path, 'wb') as f:
            pickle.dump(X_srl, f)
    else:
        print("Load SRL from Pickle")
        with open(pickle_path, 'rb') as f:
            X_srl = pickle.load(f)
    return X_srl

# GPU

In [15]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

def free_gpu():
    print(torch.cuda.mem_get_info())
    print(torch.cuda.memory_summary())
    
free_gpu()

Using device: cuda
(16721772544, 16891248640)
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|---------------------------------------------------------------------------|
| Active memory         |       0 B  |       0 B  |       0 B  |       0 B  |
|       from large pool |       0 B  |       0 B  |       0 B  |       0 B  |
|       from small pool |       0 B  |       0 B  |       0 B  |       0 B  |
|-----------------

In [16]:
import torch
import gc

def list_gpu_tensors():
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                if obj.is_cuda:
                    obj = obj.cpu()
                    obj = obj.to("cpu")
                    print(type(obj), obj.size())
        except:
            pass

        
list_gpu_tensors()



# Dataset

In [17]:
from torch.utils.data import Dataset
from transformers import BertTokenizer
import pandas as pd
import nltk

class ArticleDataset(Dataset):
    def __init__(self, X, X_srl, tokenizer, labels=None, max_sentences_per_article=32, max_sentence_length=32, max_arg_length=16):
        self.X = X
        self.X_srl = X_srl
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_sentences_per_article = max_sentences_per_article
        self.max_sentence_length = max_sentence_length
        self.max_arg_length = max_arg_length
        nltk.download('punkt')  # Download the Punkt tokenizer model for sentence splitting
        
    def __len__(self):
        return len(self.X)
    
    def _truncate_or_pad(self, lst, target_length, pad_value=0):
        """
        Truncate or pad the input list to match the target length.
        """
        if len(lst) > target_length:
            return lst[:target_length]
        else:
            return lst + [pad_value] * (target_length - len(lst))
    
    def __getitem__(self, idx):
        article = self.X.iloc[idx]
        srl = self.X_srl.iloc[idx]

        # Split the article into sentences
        sentences = nltk.sent_tokenize(article)
        sentences = sentences[:self.max_sentences_per_article]  # Limit the number of sentences

        # Tokenize and pad/truncate the sentences
        sentence_ids = [self.tokenizer.encode(sentence, add_special_tokens=True, max_length=self.max_sentence_length, truncation=True, padding='max_length') for sentence in sentences]
        while len(sentence_ids) < self.max_sentences_per_article:
            sentence_ids.append([0] * self.max_sentence_length)

        # Tokenize and pad/truncate the SRL items
        predicate_ids = [self.tokenizer.encode(predicate, add_special_tokens=True, max_length=self.max_arg_length, truncation=True, padding='max_length') for predicate in [item['predicate'] for item in srl]]
        arg0_ids = [self.tokenizer.encode(arg0, add_special_tokens=True, max_length=self.max_arg_length, truncation=True, padding='max_length') for arg0 in [item.get('arg0', '') for item in srl]]
        arg1_ids = [self.tokenizer.encode(arg1, add_special_tokens=True, max_length=self.max_arg_length, truncation=True, padding='max_length') for arg1 in [item.get('arg1', '') for item in srl]]
        
        predicate_ids = predicate_ids[:self.max_sentences_per_article]
        arg0_ids = arg0_ids[:self.max_sentences_per_article]
        arg1_ids = arg1_ids[:self.max_sentences_per_article]  
        
        while len(predicate_ids) < self.max_sentences_per_article:
            predicate_ids.append([0] * self.max_arg_length)
        while len(arg0_ids) < self.max_sentences_per_article:
            arg0_ids.append([0] * self.max_arg_length)
        while len(arg1_ids) < self.max_sentences_per_article:
            arg1_ids.append([0] * self.max_arg_length)

        data = {
            'sentence_ids': torch.tensor(sentence_ids, dtype=torch.long),
            'predicate_ids': torch.tensor(predicate_ids, dtype=torch.long),
            'arg0_ids': torch.tensor(arg0_ids, dtype=torch.long),
            'arg1_ids': torch.tensor(arg1_ids, dtype=torch.long)
        }
        
        if self.labels is not None:
            data['labels'] = self.labels.iloc[idx]
        
        return data


In [18]:
from torch.utils.data import DataLoader, random_split

# Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def custom_collate_fn(batch):
    # Extract individual lists from the batch
    sentence_ids = [item['sentence_ids'] for item in batch]
    predicate_ids = [item['predicate_ids'] for item in batch]
    arg0_ids = [item['arg0_ids'] for item in batch]
    arg1_ids = [item['arg1_ids'] for item in batch]
    
    # Pad each list
    sentence_ids = torch.nn.utils.rnn.pad_sequence(sentence_ids, batch_first=True, padding_value=0)
    predicate_ids = torch.nn.utils.rnn.pad_sequence(predicate_ids, batch_first=True, padding_value=0)
    arg0_ids = torch.nn.utils.rnn.pad_sequence(arg0_ids, batch_first=True, padding_value=0)
    arg1_ids = torch.nn.utils.rnn.pad_sequence(arg1_ids, batch_first=True, padding_value=0)

    # Conditionally extract and add labels
    output_dict = {
        'sentence_ids': sentence_ids,
        'predicate_ids': predicate_ids,
        'arg0_ids': arg0_ids,
        'arg1_ids': arg1_ids
    }
    
    if 'labels' in batch[0]:
        labels = [item['labels'] for item in batch]
        output_dict['labels'] = torch.Tensor(labels)

    return output_dict


def get_datasets_dataloaders(X, y, tokenizer, recalculate_srl=False, pickle_path="../notebooks/classifier/X_srl.pkl", batch_size=4):
    # Get X_srl
    X_srl = get_X_srl(X, recalculate=recalculate_srl, pickle_path=pickle_path)
    
    # Create the dataset
    dataset = ArticleDataset(X, X_srl, tokenizer, y)
    
    # Split into train and test sets
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    
    # Create dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)
    
    print("CREATION DONE")
    return train_dataset, test_dataset, train_dataloader, test_dataloader

train_dataset, test_dataset, train_dataloader, test_dataloader = get_datasets_dataloaders(X, y, tokenizer)

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Load SRL from Pickle


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


CREATION DONE


In [19]:
def get_article_dataloader(article, tokenizer, batch_size=4):
    X = pd.Series([article])
    y = None  # No labels for this single article
    
    predictor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/structured-prediction-srl-bert.2020.12.15.tar.gz", cuda_device=0)
    # Directly use the optimized_extract_srl function since we don't need to cache for single articles
    X_srl = optimized_extract_srl(X, predictor)
    
    # Create the dataset
    dataset = ArticleDataset(X, X_srl, tokenizer, y)
    
    # Create dataloader
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)
    
    return dataloader

In [20]:
get_article_dataloader("This is a test sentence", tokenizer)

Output()

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Spacy models 'en_core_web_sm' not found.  Downloading and installing.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Collecting en-core-web-sm==3.3.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.3.0/en_core_web_sm-3.3.0-py3-none-any.whl (12.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.8/12.8 MB 92.0 MB/s eta 0:00:00
Installing collected packages: en-core-web-sm
Successfully installed en-core-web-sm-3.3.0




[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
Processed article 1/1


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


<torch.utils.data.dataloader.DataLoader at 0x7ffa693c2e20>

In [21]:
free_gpu()

(14647689216, 16891248640)
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |       0 B  |     838 MB |     917 MB |     917 MB |
|       from large pool |       0 B  |     836 MB |     836 MB |     836 MB |
|       from small pool |       0 B  |       2 MB |      80 MB |      80 MB |
|---------------------------------------------------------------------------|
| Active memory         |       0 B  |     838 MB |     917 MB |     917 MB |
|       from large pool |       0 B  |     836 MB |     836 MB |     836 MB |
|       from small pool |       0 B  |       2 MB |      80 MB |      80 MB |
|------------------------------------

# Model

In [22]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
import torch.nn.functional as F

In [23]:
class EmbeddingLayer(nn.Module):
    def __init__(self, bert_model_name):
        super(EmbeddingLayer, self).__init__()
        self.bert_model = BertModel.from_pretrained(bert_model_name)

    def forward(self, x, max_items):
        batch_size = x.shape[0]
        
        # List to collect embeddings for each batch
        batch_embeddings = []

        # Loop through each item in the batch
        for i in range(batch_size):
            sequence_embeddings = []

            # Loop through each sequence in the item and obtain embeddings
            for j in range(max_items):
                sequence = x[i][j].unsqueeze(0)  # Adding an extra dimension for BERT
                embeddings = self.bert_model(sequence).last_hidden_state.squeeze(0)  # Removing the extra dimension after obtaining embeddings
                sequence_embeddings.append(embeddings)
            
            # Stack embeddings for each sequence in the item
            item_embeddings = torch.stack(sequence_embeddings)
            batch_embeddings.append(item_embeddings)
        
        # Stack embeddings for each item in the batch
        reshaped_embeddings = torch.stack(batch_embeddings)
        
        return reshaped_embeddings


In [24]:
class SentenceAttentionLayer(nn.Module):
    def __init__(self, embedding_dim, heads=8):
        super(SentenceAttentionLayer, self).__init__()
        
        self.embedding_dim = embedding_dim
        self.heads = heads
        self.head_dim = embedding_dim // heads

        # Q, K, V weight matrices
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)

        self.fc_out = nn.Linear(heads * self.head_dim, embedding_dim)
        
    def forward(self, values, keys=None, queries=None):
        if keys is None:
            keys = values
        if queries is None:
            queries = values

        N = queries.shape[0]  # batch size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]

        # Split the embedding_dim into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Scaled dot-product attention
        attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) 
        attention = attention / (self.embedding_dim ** (1/2))
        attention = F.softmax(attention, dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads*self.head_dim)
        
        # Combine heads
        out = self.fc_out(out)
        
        # Aggregate to single vector per instance
        out = out.sum(dim=1)
        return out

In [25]:
class AggregationLayer(nn.Module):
    def __init__(self, max_sentences_per_article=32, max_sentence_length=32, max_srl_items=None, 
                 max_arg_length=16, bert_model_name="bert-base-uncased"):
        super(AggregationLayer, self).__init__()
        self.embedding_layer = EmbeddingLayer(bert_model_name)
        
        embedding_dim = 768  # For bert-base-uncased
        self.attention = SentenceAttentionLayer(embedding_dim=embedding_dim)
        
        # Store the values as attributes
        self.max_sentences_per_article = max_sentences_per_article
        self.max_sentence_length = max_sentence_length
        self.max_srl_items = max_srl_items if max_srl_items is not None else max_sentences_per_article
        self.max_arg_length = max_arg_length

    def aggregate_word_embeddings(self, embeddings):
        """
        Apply the attention mechanism to the word embeddings to get the sentence embeddings.
        """
        N, num_sentences, num_words, embedding_dim = embeddings.shape
        embeddings_reshaped = embeddings.reshape(N*num_sentences, num_words, embedding_dim)
        
        sentence_embeddings = self.attention(embeddings_reshaped)
    
        # Reshape back to [N, num_sentences, embedding_dim]
        sentence_embeddings = sentence_embeddings.reshape(N, num_sentences, embedding_dim)

        # Aggregate across the num_sentences dimension to get shape [N, 768]
        aggregated_embeddings = sentence_embeddings.sum(dim=1) # TODO: Try mean
        
        return aggregated_embeddings
    
    def forward(self, sentence_ids, predicate_ids, arg0_ids, arg1_ids):
        # Get embeddings
        sentence_embeddings = self.embedding_layer(sentence_ids, self.max_sentences_per_article)
        predicate_embeddings = self.embedding_layer(predicate_ids, self.max_srl_items)
        arg0_embeddings = self.embedding_layer(arg0_ids, self.max_srl_items)
        arg1_embeddings = self.embedding_layer(arg1_ids, self.max_srl_items)
        
        # Aggregate word embeddings into sentence embeddings
        sentence_embeddings = self.aggregate_word_embeddings(sentence_embeddings)
        predicate_embeddings = self.aggregate_word_embeddings(predicate_embeddings)
        arg0_embeddings = self.aggregate_word_embeddings(arg0_embeddings)
        arg1_embeddings = self.aggregate_word_embeddings(arg1_embeddings)
        
        # Get aggregated embeddings (if you want a second level of attention on the sentence embeddings)
        #sentence_aggregated = self.attention(sentence_embeddings)
        #predicate_aggregated = self.attention(predicate_embeddings)
        #arg0_aggregated = self.attention(arg0_embeddings)
        #arg1_aggregated = self.attention(arg1_embeddings)
        
        return sentence_embeddings, predicate_embeddings, arg0_embeddings, arg1_embeddings #sentence_aggregated, predicate_aggregated, arg0_aggregated, arg1_aggregated

# FRISS

In [32]:
import torch.distributions as dist

class FRISSEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(FRISSEncoder, self).__init__()
        # Shared feed-forward layer
        self.fc_shared = nn.Linear(input_dim, hidden_dim)
        # View-specific feed-forward layers
        self.fc_latent_p = nn.Linear(hidden_dim, latent_dim)
        self.fc_latent_a0 = nn.Linear(hidden_dim, latent_dim)
        self.fc_latent_a1 = nn.Linear(hidden_dim, latent_dim)

    def gumbel_softmax(self, logits, temperature=1.0):
        gumbel_noise = dist.Gumbel(0, 1).sample(logits.size()).to(logits.device)
        y = logits + gumbel_noise
        return F.softmax(y / temperature, dim=-1)

    def forward(self, x, view):
        x = F.relu(self.fc_shared(x))
        if view == "p":
            latent = self.fc_latent_p(x)
        elif view == "a0":
            latent = self.fc_latent_a0(x)
        elif view == "a1":
            latent = self.fc_latent_a1(x)
        else:
            raise ValueError("Invalid view provided!")
        
        # Apply Gumbel-Softmax on the latent representation
        latent = self.gumbel_softmax(latent)
        return latent

class FRISSDecoder(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super(FRISSDecoder, self).__init__()
        # Using the dictionary terms (descriptors) to reconstruct
        self.dictionary = nn.Parameter(torch.randn(latent_dim, output_dim))

    def forward(self, latent):
        # Here we're mimicking the linear combination using matrix multiplication.
        # The Gumbel-Softmax sampled latent weights the dictionary terms for reconstruction.
        reconstructed = torch.matmul(latent, self.dictionary)
        return reconstructed

class AutoEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim, output_dim):
        super(AutoEncoder, self).__init__()
        self.encoder = FRISSEncoder(input_dim, hidden_dim, latent_dim)
        self.decoder = FRISSDecoder(latent_dim, output_dim)

    def forward(self, x, view):
        latent = self.encoder(x, view)
        reconstructed = self.decoder(latent)
        return reconstructed

# Define the model
input_dim = 1536  # Size of the concatenated embeddings
hidden_dim = 768  # An arbitrary size for the hidden layer
latent_dim = 15   # Number of dictionary terms or frames
output_dim = 768  # Size of the original SRL span embedding (e.g., for the predicate, ARG0, or ARG1)

AutoEncoder(input_dim, hidden_dim, latent_dim, output_dim)

AutoEncoder(
  (encoder): FRISSEncoder(
    (fc_shared): Linear(in_features=1536, out_features=768, bias=True)
    (fc_latent_p): Linear(in_features=768, out_features=15, bias=True)
    (fc_latent_a0): Linear(in_features=768, out_features=15, bias=True)
    (fc_latent_a1): Linear(in_features=768, out_features=15, bias=True)
  )
  (decoder): FRISSDecoder()
)

In [27]:
class FRISS(nn.Module):
    def __init__(self, input_dim=1536, hidden_dim=768, latent_dim=14, output_dim=768, num_classes=14, embedding_dim=768):
        super(FRISS, self).__init__()
        
        self.aggregation_layer = AggregationLayer()
        
        # The unsupervised AutoEncoder part
        self.autoencoder = AutoEncoder(input_dim, hidden_dim, latent_dim, output_dim)
        
        # Span-based Classifier
        self.span_classifier = nn.Linear(latent_dim, num_classes)
        
        # Sentence-based Classifier
        self.sentence_encoder = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, num_classes)
        )
        
    def forward(self, sentence_ids, predicate_ids, arg0_ids, arg1_ids):
        sentence_aggregated, predicate_aggregated, arg0_aggregated, arg1_aggregated = self.aggregation_layer(sentence_ids, predicate_ids, arg0_ids, arg1_ids)
        
        # Concatenate aggregated sentence embeddings with the SRL embeddings
        x_p = torch.cat((sentence_aggregated, predicate_aggregated), 1)
        x_a0 = torch.cat((sentence_aggregated, arg0_aggregated), 1)
        x_a1 = torch.cat((sentence_aggregated, arg1_aggregated), 1)
        
        # Pass through AutoEncoder
        latent_p = self.autoencoder.encoder(x_p, view="p")
        latent_a0 = self.autoencoder.encoder(x_a0, view="a0")
        latent_a1 = self.autoencoder.encoder(x_a1, view="a1")
        
        # Span-based Classification
        aggregated_latent = torch.mean(torch.stack([latent_p, latent_a0, latent_a1]), 0)
        y_hat_span = self.span_classifier(aggregated_latent)
        
        # Sentence-based Classification
        y_hat_sentence = self.sentence_encoder(sentence_aggregated)
        
        return y_hat_span, y_hat_sentence


# Train Model

In [28]:
if 'model' in locals():
    model.to("cpu")
    model.cpu()
    print()
    free_gpu()


(15343484928, 17063280640)
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |       0 B  |     838 MB |     917 MB |     917 MB |
|       from large pool |       0 B  |     836 MB |     836 MB |     836 MB |
|       from small pool |       0 B  |       2 MB |      80 MB |      80 MB |
|---------------------------------------------------------------------------|
| Active memory         |       0 B  |     838 MB |     917 MB |     917 MB |
|       from large pool |       0 B  |     836 MB |     836 MB |     836 MB |
|       from small pool |       0 B  |       2 MB |      80 MB |      80 MB |
|-----------------------------------

The F1-Score (micro-averaged) and Average Precision Score are chosen as primary metrics for evaluating the multi-label classification task due to the following reasons:

1. **F1-Score (Micro)**:
    - The micro-averaged F1-score computes global counts of true positives, false negatives, and false positives. 
    - It provides a balance between precision (the number of correct positive results divided by the number of all positive results) and recall (the number of correct positive results divided by the number of positive results that should have been returned).
    - Given the imbalance in the label distribution observed in the dataset, the micro-averaged F1-score is robust against this imbalance, making it a suitable metric for optimization.

2. **Average Precision Score**:
    - This metric summarizes the precision-recall curve, giving a single value that represents the average of precision values at different recall levels.
    - It's especially valuable when class imbalances exist, as it gives more weight to the positive class (the rarer class in an imbalanced dataset).

Using these metrics will ensure that the model is optimized for a balanced performance across all labels, even if some labels are rarer than others.

In [29]:
from sklearn.metrics import f1_score, average_precision_score
from tqdm.notebook import tqdm
import numpy as np
import os
import torch
import json

def train(model, train_dataloader, test_dataloader, optimizer, alpha=0.5, num_epochs=10, device='cuda', save_path='../notebooks/classifier/'):
    loss_function = torch.nn.BCEWithLogitsLoss()
    
    # Metrics dictionary
    metrics = {
        'f1': [],
        'avg_precision': []
    }
    
    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        model.train()
        total_loss = 0
        batch_progress = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc="Batches", leave=False)

        for batch_idx, batch in batch_progress:
            # Zero the gradients
            optimizer.zero_grad()

            # Move data and labels to device
            sentence_ids = batch['sentence_ids'].to(device)
            predicate_ids = batch['predicate_ids'].to(device)
            arg0_ids = batch['arg0_ids'].to(device)
            arg1_ids = batch['arg1_ids'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            logits_span, logits_sentence = model(sentence_ids, predicate_ids, arg0_ids, arg1_ids)

            # Compute the supervised loss
            loss_span = loss_function(logits_span, labels.float())
            loss_sentence = loss_function(logits_sentence, labels.float())
            supervised_loss = loss_span + loss_sentence

            # Combine the supervised and unsupervised losses
            total_loss = alpha * supervised_loss

            # Backward pass and optimization
            total_loss.backward()
            optimizer.step()

            # Clear GPU cache
            torch.cuda.empty_cache()

            batch_progress.set_description(f"Epoch {epoch+1} Loss: {total_loss.item():.4f}")

        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {total_loss/len(train_dataloader)}")

        # Validation
        model.eval()
        all_preds_span = []
        all_labels = []

        with torch.no_grad():
            for batch in test_dataloader:
                # Move data and labels to device
                sentence_ids = batch['sentence_ids'].to(device)
                predicate_ids = batch['predicate_ids'].to(device)
                arg0_ids = batch['arg0_ids'].to(device)
                arg1_ids = batch['arg1_ids'].to(device)
                labels = batch['labels'].to(device)

                # Forward pass
                logits_span, _ = model(sentence_ids, predicate_ids, arg0_ids, arg1_ids)
                preds_span = (torch.sigmoid(logits_span) > 0.5).float()

                all_preds_span.append(preds_span.cpu().numpy())
                all_labels.append(labels.cpu().numpy())

                torch.cuda.empty_cache()

        all_preds_span = np.vstack(all_preds_span)
        all_labels = np.vstack(all_labels)

        # Compute metrics
        f1 = f1_score(all_labels, all_preds_span, average='micro')
        avg_precision = average_precision_score(all_labels, all_preds_span)

        metrics['f1'].append(f1)
        metrics['avg_precision'].append(avg_precision)

        print(f"Validation Metrics - F1 Score (Micro): {f1:.4f}, Average Precision: {avg_precision:.4f}")

    # Save the model
    model_save_path = os.path.join(save_path, 'model1.pth')
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

    # Save metrics to JSON file
    with open('../notebooks/classifier/metrics.json', 'w') as f:
        json.dump(metrics, f)

    return metrics

In [None]:
import torch.optim as optim
import json

# Initialize the FRISS model
model = FRISS()

# Assuming you're using a GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Loss function (criterion)
# You can choose a suitable loss function. For multi-class classification, CrossEntropyLoss is commonly used.
loss_function = nn.BCEWithLogitsLoss()

# Optimizer
# Using the Adam optimizer as it is widely used and works well for many tasks.
# You can adjust the learning rate and other parameters based on your needs.
optimizer = optim.Adam(model.parameters(), lr=0.001)

# now train

# Define your parameters
alpha_value = 0.5
num_epochs_value = 15
device_value = 'cuda'

# Call the train function
metrics = train(model, train_dataloader, test_dataloader, optimizer, alpha=alpha_value, num_epochs=num_epochs_value, device=device_value)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Epochs:   0%|          | 0/15 [00:00<?, ?it/s]

Batches:   0%|          | 0/87 [00:00<?, ?it/s]

Epoch 1/15, Training Loss: 0.012187257409095764
Validation Metrics - Accuracy: 0.0000, F1 Score: 0.4192, Precision: 0.4023, Recall: 0.4375, Average Precision: 0.2627


Batches:   0%|          | 0/87 [00:00<?, ?it/s]

Epoch 2/15, Training Loss: 0.00617617554962635
Validation Metrics - Accuracy: 0.0000, F1 Score: 0.4509, Precision: 0.5019, Recall: 0.4094, Average Precision: 0.2627


Batches:   0%|          | 0/87 [00:00<?, ?it/s]

Epoch 3/15, Training Loss: 0.008196800947189331
Validation Metrics - Accuracy: 0.0000, F1 Score: 0.3725, Precision: 0.5287, Recall: 0.2875, Average Precision: 0.2627


Batches:   0%|          | 0/87 [00:00<?, ?it/s]

Epoch 4/15, Training Loss: 0.005116731394082308
Validation Metrics - Accuracy: 0.0000, F1 Score: 0.3725, Precision: 0.5287, Recall: 0.2875, Average Precision: 0.2627


Batches:   0%|          | 0/87 [00:00<?, ?it/s]

Epoch 5/15, Training Loss: 0.006629019044339657
Validation Metrics - Accuracy: 0.0000, F1 Score: 0.3725, Precision: 0.5287, Recall: 0.2875, Average Precision: 0.2627


Batches:   0%|          | 0/87 [00:00<?, ?it/s]

Epoch 6/15, Training Loss: 0.004905941896140575
Validation Metrics - Accuracy: 0.0000, F1 Score: 0.3725, Precision: 0.5287, Recall: 0.2875, Average Precision: 0.2627


Batches:   0%|          | 0/87 [00:00<?, ?it/s]

Epoch 7/15, Training Loss: 0.005613383837044239
Validation Metrics - Accuracy: 0.0000, F1 Score: 0.3725, Precision: 0.5287, Recall: 0.2875, Average Precision: 0.2627


Batches:   0%|          | 0/87 [00:00<?, ?it/s]

Epoch 8/15, Training Loss: 0.005487698595970869
Validation Metrics - Accuracy: 0.0000, F1 Score: 0.3725, Precision: 0.5287, Recall: 0.2875, Average Precision: 0.2627


Batches:   0%|          | 0/87 [00:00<?, ?it/s]

Epoch 9/15, Training Loss: 0.007822412066161633
Validation Metrics - Accuracy: 0.0000, F1 Score: 0.3725, Precision: 0.5287, Recall: 0.2875, Average Precision: 0.2627


Batches:   0%|          | 0/87 [00:00<?, ?it/s]

In [None]:
def load_model_from_path(model_class, path, device='cuda'):
    """
    Loads the weights into an instance of the model class from the given path.
    
    Args:
    - model_class (torch.nn.Module): The class of the model (uninitialized).
    - path (str): Path to the saved weights.
    - device (str): Device to load the model on ('cpu' or 'cuda').
    
    Returns:
    - model (torch.nn.Module): Model with weights loaded.
    """
    model = model_class().to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model


In [None]:
model = load_model_from_path(FRISS, '../notebooks/classifier/new_model.pth')

In [None]:
def predict(model, dataloader, y_columns, device='cuda'):
    """
    Make predictions with the given model and dataloader.
    
    Args:
    - model (torch.nn.Module): The model to make predictions with.
    - dataloader (DataLoader): DataLoader for the dataset to predict on.
    - y_columns (pandas.Index): Column names from the y dataframe which correspond to labels.
    - device (str): Device to make predictions on ('cpu' or 'cuda').
    
    Returns:
    - predicted_labels (list of lists): List containing the predicted labels for each instance.
    """
    model.eval()
    all_preds_span = []
    
    with torch.no_grad():
        for batch in dataloader:
            # Move data to device
            sentence_ids = batch['sentence_ids'].to(device)
            predicate_ids = batch['predicate_ids'].to(device)
            arg0_ids = batch['arg0_ids'].to(device)
            arg1_ids = batch['arg1_ids'].to(device)

            # Forward pass
            logits_span, _ = model(sentence_ids, predicate_ids, arg0_ids, arg1_ids)
            preds_span = (torch.sigmoid(logits_span) > 0.5).float()

            all_preds_span.append(preds_span.cpu().numpy())
                
            torch.cuda.empty_cache()

    predictions = np.vstack(all_preds_span)
    
    # Convert boolean predictions to labels
    predicted_labels = []
    for pred in predictions:
        labels = list(y_columns[pred.astype(bool)])
        predicted_labels.append(labels)
    
    return predicted_labels


In [None]:
import numpy as np

# article813452859
article = """EU Profits From Trading With UK While London Loses Money – Political Campaigner

With the Parliamentary vote on British Prime Minister Theresa May’s Brexit plan set to be held next month; President of the European Commission Jean Claude Juncker has criticised the UK’s preparations for their departure from the EU.
But is there any chance that May's deal will make it through parliament and if it fails, how could this ongoing political deadlock finally come to an end?
Sputnik spoke with political campaigner Michael Swadling for more…
Sputnik: Does Theresa May have any chance of getting her deal through Parliament on the 14th January?
Michael Swadling: I guess her only chance is if Labour decides that they want to dishonour democracy and effectively keep us in the EU.
© AP Photo / Pablo Martinez Monsivais UK 'In Need of Leadership', May's Brexit Deal Unwelcome to Trump - US Ambassador
There is a chance; as unfortunately there are many MPs who don't respect the vote and may just turn on it, but short of that I don't see any way the Conservatives would vote for it, and the majority is slender as it is, as the DUP is bitterly against it, and I can't see the Lib Dems voting for it, so it will only be if there are enough, what I can describe as remoaner MPs, that the deal won't be dead in the water.
Sputnik: What could be a solution to the political chaos if the Prime Minister's deal is not approved?
Michael Swadling: The EU withdrawal act is in place; we'll leave and revert to WTO terms and that works, that's fine.
I often use the example of an iPhone to people; that's a piece of technology which is manufactured in China, uses American technology and these are two countries we deal with on WTO terms, this isn't a fantasy, stuck in a port somewhere, there isn't a massive tariff, this is the world that really exists today.
When we exit the EU on WTO terms; that will be fine for whatever trading we do with the EU, just as well as it does for our trade in China.READ MORE: UK Finance Chief Bashed for Failing to Unlock Money for No-Deal Brexit — Reports
Sputnik: Do you think that the EU needs the UK more than the UK needs the EU?
Michael Swadling: The EU makes a profit on its trade with the UK; the UK makes a loss on its trade with the EU.
They have a financial incentive to ensure that good trading relations continue far more than we do.
© REUTERS / Toby Melville UK Trade Minister Says '50-50' Chance Brexit Will Not Happen – Reports
The lifeblood and cash flow that keeps manufacturing in Europe going, comes from the city of London.
If someone in a city in Germany wants to do a deal with someone in Japan; the financial services of that are probably going through the city of London, they're not going through Frankfurt and Paris.
Views and opinions, expressed in the article are those of Michael Swadling and do not necessarily reflect those of Sputnik

"""

test_article = get_article_dataloader(article, tokenizer)
predict(model, test_article, y.columns)

In [None]:
import torch
torch.cuda.empty_cache()
free_gpu()

In [None]:
!nvidia-smi