## Classifier - Try 5

Classify articles frames using aggregated SRL and sentence embeddings

In [1]:
import os

try:
  import google.colab

  from google.colab import drive
  drive.mount('/content/drive')
  IN_COLAB = True
except:
  IN_COLAB = False

if IN_COLAB:
  os.chdir('drive/MyDrive/Git/MasterThesis/data')
else:
  os.chdir('../../data/')

labels_path = "data/en/train-labels-subtask-2.txt"
articles_path = "data/en/train-articles-subtask-2/"

In [2]:
import pandas as pd

# Read the dev-labels-subtask-2.txt file
labels_df = pd.read_csv(labels_path, sep="\t")

# Rename the columns for easier processing
labels_df.columns = ["article_id", "frames"]


labels_df.head()

Unnamed: 0,article_id,frames
0,832959523,"Morality,Security_and_defense,Policy_prescript..."
1,833039623,"Political,Crime_and_punishment,External_regula..."
2,833032367,"Political,Crime_and_punishment,Fairness_and_eq..."
3,814777937,"Political,Morality,Fairness_and_equality,Exter..."
4,821744708,"Policy_prescription_and_evaluation,Political,L..."


In [3]:
# A function to read the article text given its ID
def get_article_content(article_id):
    try:
        with open(f"{articles_path}/article{article_id}.txt", "r") as f:
            return f.read()
    except FileNotFoundError:
        return None

df = labels_df

# Apply the function to get the article content
df["content"] = df["article_id"].apply(get_article_content)

# Drop rows where content could not be found
df.dropna(subset=["content"], inplace=True)

df.head()


Unnamed: 0,article_id,frames,content
0,832959523,"Morality,Security_and_defense,Policy_prescript...",How Theresa May Botched\n\nThose were the time...
1,833039623,"Political,Crime_and_punishment,External_regula...",Robert Mueller III Rests His Case—Dems NEVER W...
2,833032367,"Political,Crime_and_punishment,Fairness_and_eq...",Robert Mueller Not Recommending Any More Indic...
3,814777937,"Political,Morality,Fairness_and_equality,Exter...",The Far Right Is Trying to Co-opt the Yellow V...
4,821744708,"Policy_prescription_and_evaluation,Political,L...",‘Special place in hell’ for those who promoted...


In [4]:
# Split the frames column into a list of frames
df["frames_list"] = df["frames"].str.split(",")

# create for each frame a new column with the frame as name and 1 if the frame is present in the article and 0 if not
for frame in df["frames_list"].explode().unique():
    df[frame] = df["frames_list"].apply(lambda x: 1 if frame in x else 0)

df.head()

Unnamed: 0,article_id,frames,content,frames_list,Morality,Security_and_defense,Policy_prescription_and_evaluation,Legality_Constitutionality_and_jurisprudence,Economic,Political,Crime_and_punishment,External_regulation_and_reputation,Public_opinion,Fairness_and_equality,Capacity_and_resources,Quality_of_life,Cultural_identity,Health_and_safety
0,832959523,"Morality,Security_and_defense,Policy_prescript...",How Theresa May Botched\n\nThose were the time...,"[Morality, Security_and_defense, Policy_prescr...",1,1,1,1,1,0,0,0,0,0,0,0,0,0
1,833039623,"Political,Crime_and_punishment,External_regula...",Robert Mueller III Rests His Case—Dems NEVER W...,"[Political, Crime_and_punishment, External_reg...",0,0,1,1,0,1,1,1,1,0,0,0,0,0
2,833032367,"Political,Crime_and_punishment,Fairness_and_eq...",Robert Mueller Not Recommending Any More Indic...,"[Political, Crime_and_punishment, Fairness_and...",0,0,0,1,0,1,1,1,0,1,0,0,0,0
3,814777937,"Political,Morality,Fairness_and_equality,Exter...",The Far Right Is Trying to Co-opt the Yellow V...,"[Political, Morality, Fairness_and_equality, E...",1,1,0,0,1,1,0,1,1,1,0,0,0,0
4,821744708,"Policy_prescription_and_evaluation,Political,L...",‘Special place in hell’ for those who promoted...,"[Policy_prescription_and_evaluation, Political...",0,0,1,1,0,1,0,1,0,0,0,0,0,0


In [5]:
X = df["content"]
y = df.drop(columns=["article_id", "frames", "frames_list", "content"])

In [6]:
X.head()

0    How Theresa May Botched\n\nThose were the time...
1    Robert Mueller III Rests His Case—Dems NEVER W...
2    Robert Mueller Not Recommending Any More Indic...
3    The Far Right Is Trying to Co-opt the Yellow V...
4    ‘Special place in hell’ for those who promoted...
Name: content, dtype: object

In [7]:
y.head()

Unnamed: 0,Morality,Security_and_defense,Policy_prescription_and_evaluation,Legality_Constitutionality_and_jurisprudence,Economic,Political,Crime_and_punishment,External_regulation_and_reputation,Public_opinion,Fairness_and_equality,Capacity_and_resources,Quality_of_life,Cultural_identity,Health_and_safety
0,1,1,1,1,1,0,0,0,0,0,0,0,0,0
1,0,0,1,1,0,1,1,1,1,0,0,0,0,0
2,0,0,0,1,0,1,1,1,0,1,0,0,0,0
3,1,1,0,0,1,1,0,1,1,1,0,0,0,0
4,0,0,1,1,0,1,0,1,0,0,0,0,0,0


In [8]:
len(X), len(y)

(432, 432)

In [9]:
# y.to_csv("../notebooks/classifier/y.csv")

### Create Dataset

In [10]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='transformers')

### Extract SRL Embeddings from articles

In [11]:
!pip install pycuda
!pip install allennlp allennlp-models

Collecting pycuda
  Downloading pycuda-2022.2.2.tar.gz (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m25.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting pytools>=2011.2
  Downloading pytools-2023.1.1-py2.py3-none-any.whl (70 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m70.6/70.6 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting appdirs>=1.4.0
  Downloading appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB)
Collecting mako
  Downloading Mako-1.2.4-py3-none-any.whl (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.7/78.7 kB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: pycuda
  Building wheel for pycuda (pyproject.toml) ... [?25ldone
[?25h  Created wheel for pycu

In [12]:
import pycuda
from pycuda import compiler
import pycuda.driver as drv

drv.init()
print("%d device(s) found." % drv.Device.count())
           
for ordinal in range(drv.Device.count()):
    dev = drv.Device(ordinal)
    print (ordinal, dev.name())

1 device(s) found.
0 Quadro P5000


In [13]:
from allennlp.predictors.predictor import Predictor
from allennlp_models.structured_prediction.models import srl_bert
from nltk.tokenize import sent_tokenize
import pandas as pd

In [14]:
def batched_extract_srl_components(sentences, predictor):
    # Prepare the batched input for the predictor
    batched_input = [{'sentence': sentence} for sentence in sentences]
    batched_srl = predictor.predict_batch_json(batched_input)
    
    # Extract SRL components from the batched predictions
    results = []
    for srl in batched_srl:
        best_extracted_data = None
        second_best_extracted_data = None
        for verb_entry in srl['verbs']:
            tags = verb_entry['tags']
            arg0_indices = [i for i, tag in enumerate(tags) if tag in ['B-ARG0', 'I-ARG0']]
            arg1_indices = [i for i, tag in enumerate(tags) if tag in ['B-ARG1', 'I-ARG1']]

            if arg0_indices and arg1_indices:
                best_extracted_data = {
                    'predicate': verb_entry['verb'],
                    'ARG0': ' '.join([srl['words'][i] for i in arg0_indices]),
                    'ARG1': ' '.join([srl['words'][i] for i in arg1_indices])
                }
                break
            elif (arg0_indices or arg1_indices) and not second_best_extracted_data:
                second_best_extracted_data = {
                    'predicate': verb_entry['verb'],
                    'ARG0': ' '.join([srl['words'][i] for i in arg0_indices]) if arg0_indices else '',
                    'ARG1': ' '.join([srl['words'][i] for i in arg1_indices]) if arg1_indices else ''
                }

        if best_extracted_data:
            results.append(best_extracted_data)
        elif second_best_extracted_data:
            results.append(second_best_extracted_data)
            
    return results

def optimized_extract_srl(X, predictor, batch_size=32):
    total_articles = len(X)
    processed_articles = 0

    all_results = []

    for article in X:
        sentences = sent_tokenize(article)
        article_srls = []

        for i in range(0, len(sentences), batch_size):
            batched_sentences = sentences[i:i+batch_size]
            article_srls.extend(batched_extract_srl_components(batched_sentences, predictor))

        all_results.append(article_srls)
        processed_articles += 1
        print(f"Processed article {processed_articles}/{total_articles}")

    return pd.Series(all_results)

In [15]:
import pickle

def get_X_srl(X, recalculate=False, pickle_path="../notebooks/classifier/X_srl.pkl"):
    """
    Returns the X_srl either by loading from a pickled file or recalculating.
    """
    if recalculate or not os.path.exists(pickle_path):
        print("Recalculate SRL")
        # Load predictor
        predictor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/structured-prediction-srl-bert.2020.12.15.tar.gz", cuda_device=0)
        X_srl = optimized_extract_srl(X, predictor)
        with open(pickle_path, 'wb') as f:
            pickle.dump(X_srl, f)
    else:
        print("Load SRL from Pickle")
        with open(pickle_path, 'rb') as f:
            X_srl = pickle.load(f)
    return X_srl

# GPU

In [16]:
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

def free_gpu():
    print(torch.cuda.mem_get_info())
    print(torch.cuda.memory_summary())

Using device: cuda


In [17]:
import torch
import gc

def list_gpu_tensors():
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                if obj.is_cuda:
                    obj = obj.cpu()
                    obj = obj.to("cpu")
                    print(type(obj), obj.size())
        except:
            pass

        
list_gpu_tensors()



# Dataset

In [18]:
from torch.utils.data import Dataset
from transformers import BertTokenizer
import pandas as pd
import nltk

class ArticleDataset(Dataset):
    def __init__(self, X, X_srl, tokenizer, labels=None, max_sentences_per_article=32, max_sentence_length=32, max_arg_length=16):
        self.X = X
        self.X_srl = X_srl
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_sentences_per_article = max_sentences_per_article
        self.max_sentence_length = max_sentence_length
        self.max_arg_length = max_arg_length
        nltk.download('punkt')  # Download the Punkt tokenizer model for sentence splitting
        
    def __len__(self):
        return len(self.X)
    
    def _truncate_or_pad(self, lst, target_length, pad_value=0):
        """
        Truncate or pad the input list to match the target length.
        """
        if len(lst) > target_length:
            return lst[:target_length]
        else:
            return lst + [pad_value] * (target_length - len(lst))
    
    def __getitem__(self, idx):
        article = self.X.iloc[idx]
        srl = self.X_srl.iloc[idx]

        # Split the article into sentences
        sentences = nltk.sent_tokenize(article)
        sentences = sentences[:self.max_sentences_per_article]  # Limit the number of sentences

        # Tokenize and pad/truncate the sentences
        sentence_ids = [self.tokenizer.encode(sentence, add_special_tokens=True, max_length=self.max_sentence_length, truncation=True, padding='max_length') for sentence in sentences]
        while len(sentence_ids) < self.max_sentences_per_article:
            sentence_ids.append([0] * self.max_sentence_length)

        # Tokenize and pad/truncate the SRL items
        predicate_ids = [self.tokenizer.encode(predicate, add_special_tokens=True, max_length=self.max_arg_length, truncation=True, padding='max_length') for predicate in [item['predicate'] for item in srl]]
        arg0_ids = [self.tokenizer.encode(arg0, add_special_tokens=True, max_length=self.max_arg_length, truncation=True, padding='max_length') for arg0 in [item.get('arg0', '') for item in srl]]
        arg1_ids = [self.tokenizer.encode(arg1, add_special_tokens=True, max_length=self.max_arg_length, truncation=True, padding='max_length') for arg1 in [item.get('arg1', '') for item in srl]]
        
        predicate_ids = predicate_ids[:self.max_sentences_per_article]
        arg0_ids = arg0_ids[:self.max_sentences_per_article]
        arg1_ids = arg1_ids[:self.max_sentences_per_article]  
        
        while len(predicate_ids) < self.max_sentences_per_article:
            predicate_ids.append([0] * self.max_arg_length)
        while len(arg0_ids) < self.max_sentences_per_article:
            arg0_ids.append([0] * self.max_arg_length)
        while len(arg1_ids) < self.max_sentences_per_article:
            arg1_ids.append([0] * self.max_arg_length)

        data = {
            'sentence_ids': torch.tensor(sentence_ids, dtype=torch.long),
            'predicate_ids': torch.tensor(predicate_ids, dtype=torch.long),
            'arg0_ids': torch.tensor(arg0_ids, dtype=torch.long),
            'arg1_ids': torch.tensor(arg1_ids, dtype=torch.long)
        }
        
        if self.labels is not None:
            data['labels'] = self.labels.iloc[idx]
        
        return data


In [20]:
from torch.utils.data import DataLoader, random_split

# Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def custom_collate_fn(batch):
    # Extract individual lists from the batch
    sentence_ids = [item['sentence_ids'] for item in batch]
    predicate_ids = [item['predicate_ids'] for item in batch]
    arg0_ids = [item['arg0_ids'] for item in batch]
    arg1_ids = [item['arg1_ids'] for item in batch]
    
    # Pad each list
    sentence_ids = torch.nn.utils.rnn.pad_sequence(sentence_ids, batch_first=True, padding_value=0)
    predicate_ids = torch.nn.utils.rnn.pad_sequence(predicate_ids, batch_first=True, padding_value=0)
    arg0_ids = torch.nn.utils.rnn.pad_sequence(arg0_ids, batch_first=True, padding_value=0)
    arg1_ids = torch.nn.utils.rnn.pad_sequence(arg1_ids, batch_first=True, padding_value=0)

    # Conditionally extract and add labels
    output_dict = {
        'sentence_ids': sentence_ids,
        'predicate_ids': predicate_ids,
        'arg0_ids': arg0_ids,
        'arg1_ids': arg1_ids
    }
    
    if 'labels' in batch[0]:
        labels = [item['labels'] for item in batch]
        output_dict['labels'] = torch.Tensor(labels)

    return output_dict


def get_datasets_dataloaders(X, y, tokenizer, recalculate_srl=False, pickle_path="../notebooks/classifier/X_srl.pkl", batch_size=4):
    # Get X_srl
    X_srl = get_X_srl(X, recalculate=recalculate_srl, pickle_path=pickle_path)
    
    # Create the dataset
    dataset = ArticleDataset(X, X_srl, tokenizer, y)
    
    # Split into train and test sets
    train_size = int(0.80 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
    
    # Create dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)
    
    print("CREATION DONE")
    return train_dataset, test_dataset, train_dataloader, test_dataloader

train_dataset, test_dataset, train_dataloader, test_dataloader = get_datasets_dataloaders(X, y, tokenizer, batch_size=4)

Recalculate SRL


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Processed article 1/432
Processed article 2/432
Processed article 3/432
Processed article 4/432
Processed article 5/432
Processed article 6/432
Processed article 7/432
Processed article 8/432


KeyboardInterrupt: 

In [54]:
def get_article_dataloader(article, tokenizer, batch_size=4):
    X = pd.Series([article])
    y = None  # No labels for this single article
    
    predictor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/structured-prediction-srl-bert.2020.12.15.tar.gz", cuda_device=0)
    # Directly use the optimized_extract_srl function since we don't need to cache for single articles
    X_srl = optimized_extract_srl(X, predictor)
    
    # Create the dataset
    dataset = ArticleDataset(X, X_srl, tokenizer, y)
    
    # Create dataloader
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)
    
    return dataloader

In [55]:
get_article_dataloader("This is a test sentence", tokenizer)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Processed article 1/1


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


<torch.utils.data.dataloader.DataLoader at 0x7fb6f94dfd00>

# Model

In [56]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel

### Aggregate

In [57]:
# Embedding Layer
class EmbeddingLayer(nn.Module):
    def __init__(self, bert_model_name):
        super(EmbeddingLayer, self).__init__()
        self.bert_model = BertModel.from_pretrained(bert_model_name)

    def forward(self, x):
        batch_size, num_sentences, sentence_length = x.shape
        batch_embeddings = []
        for i in range(batch_size):
            sentence_embeddings = []
            for j in range(num_sentences):
                sentence = x[i][j].unsqueeze(0)  # Get the j-th sentence for the i-th batch item
                # If using actual BERT, you'd get the embeddings as:
                outputs = self.bert_model(sentence)
                embeddings = outputs.last_hidden_state.squeeze(0)
                
                sentence_embeddings.append(embeddings)
            batch_embeddings.append(torch.stack(sentence_embeddings))
        reshaped_embeddings = torch.stack(batch_embeddings)
        return reshaped_embeddings
    
# Create mock data
batch_size = 2
max_items = 8
mock_data = torch.randint(0, 10000, (batch_size, 8, max_items))

# Pass mock data through the embedding layer
layer = EmbeddingLayer("bert-base-uncased")
embeddings_output = layer(mock_data)

embeddings_output.shape

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([2, 8, 8, 768])

In [58]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class WordAttentionLayer(nn.Module):
    def __init__(self, embedding_dim, heads=8):
        super(WordAttentionLayer, self).__init__()
        self.embedding_dim = embedding_dim
        self.heads = heads
        self.head_dim = embedding_dim // heads
        
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embedding_dim)

    def forward(self, values, keys=None, queries=None):
        if keys is None:
            keys = values
        if queries is None:
            queries = values
            
        N, _, word_len, _ = values.shape
        key_len, query_len = keys.shape[2], queries.shape[2]
        
        # Reshape for multi-head attention
        values = values.reshape(N, -1, word_len, self.heads, self.head_dim)
        keys = keys.reshape(N, -1, key_len, self.heads, self.head_dim)
        queries = queries.reshape(N, -1, query_len, self.heads, self.head_dim)
        
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)
        
        attention = torch.einsum("nsqhd,nskhd->nshqk", [queries, keys])
        attention = attention / (self.embedding_dim ** (1/2))
        attention = torch.nn.functional.softmax(attention, dim=-1)
        
        out = torch.einsum("nshql,nslhd->nsqhd", [attention, values]).reshape(N, -1, query_len, self.heads*self.head_dim)
        out = self.fc_out(out)
        out = out.mean(dim=2)  # Aggregate across words
        
        return out

batch_size = 2    
    
mock_data = torch.randn(batch_size, 32, 32, 768)
    
word_attention_layer = WordAttentionLayer(embedding_dim=768)

sentence_embeddings = word_attention_layer(mock_data)

sentence_embeddings.shape

torch.Size([2, 32, 768])

In [59]:
class AggregationLayer(nn.Module):
    def __init__(self, max_sentences_per_article=32, max_sentence_length=32, max_srl_items=None, 
                 max_arg_length=16, bert_model_name="bert-base-uncased"):
        super(AggregationLayer, self).__init__()
        self.embedding_layer = EmbeddingLayer(bert_model_name)
        embedding_dim = 768
        self.word_attention = WordAttentionLayer(embedding_dim=embedding_dim)
        self.max_sentences_per_article = max_sentences_per_article
        self.max_sentence_length = max_sentence_length
        self.max_srl_items = max_srl_items if max_srl_items is not None else max_sentences_per_article
        self.max_arg_length = max_arg_length

    def aggregate_word_embeddings(self, embeddings):
        # Using word attention to get sentence embeddings
        return self.word_attention(embeddings)
    
    def forward(self, sentence_ids, predicate_ids, arg0_ids, arg1_ids):
        sentence_embeddings = self.embedding_layer(sentence_ids)
        predicate_embeddings = self.embedding_layer(predicate_ids)
        arg0_embeddings = self.embedding_layer(arg0_ids)
        arg1_embeddings = self.embedding_layer(arg1_ids)
        
        sentence_embeddings = self.aggregate_word_embeddings(sentence_embeddings)
        predicate_embeddings = self.aggregate_word_embeddings(predicate_embeddings)
        arg0_embeddings = self.aggregate_word_embeddings(arg0_embeddings)
        arg1_embeddings = self.aggregate_word_embeddings(arg1_embeddings)
        
        return sentence_embeddings, predicate_embeddings, arg0_embeddings, arg1_embeddings
    
# Generate dummy data for the AggregationLayer
batch_size = 2
num_sentences = 12
sentence_length = 8
predicate_length = 8
arg0_length = 8
arg1_length = 8

# Dummy data for sentences, predicates, arg0, and arg1
sentence_ids = torch.randint(0, 10000, (batch_size, num_sentences, sentence_length))
predicate_ids = torch.randint(0, 10000, (batch_size, num_sentences, predicate_length))
arg0_ids = torch.randint(0, 10000, (batch_size, num_sentences, arg0_length))
arg1_ids = torch.randint(0, 10000, (batch_size, num_sentences, arg1_length))

# Rerun the AggregationLayer with the adjusted logic
aggregation_layer = AggregationLayer()

# Pass the dummy data through the AggregationLayer
sentence_emb_output, predicate_emb_output, arg0_emb_output, arg1_emb_output = aggregation_layer(sentence_ids, predicate_ids, arg0_ids, arg1_ids)

sentence_emb_output.shape, predicate_emb_output.shape, arg0_emb_output.shape, arg1_emb_output.shape

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


(torch.Size([2, 12, 768]),
 torch.Size([2, 12, 768]),
 torch.Size([2, 12, 768]),
 torch.Size([2, 12, 768]))

# Unsupervised Model

In [60]:
class Autoencoder(nn.Module):
    def __init__(self, D_w, D_h, K, identifier):
        super(Autoencoder, self).__init__()
        self.D_w = D_w
        self.D_h = D_h
        self.K = K
        
        self.feed_forward1 = nn.Linear(D_w * 2, D_h)
        self.feed_forward2 = nn.Linear(D_h, K)
        
        if identifier == 'p':
            self.F = nn.Parameter(torch.randn(K, self.D_w))
        elif identifier == 'a0':
            self.F = nn.Parameter(torch.randn(K, self.D_w))
        elif identifier == 'a1':
            self.F = nn.Parameter(torch.randn(K, self.D_w))
        else:
            raise ValueError(f"Invalid identifier: {identifier}")

    def forward(self, v, v_sentence):
        # Concatenate the embeddings
        concatenated_embedding = torch.cat((v, v_sentence), dim=-1)
        
        # Compute hidden representation
        h = torch.nn.functional.relu(self.feed_forward1(concatenated_embedding))
        #print("h:", h.shape)
        # Compute logits
        logits = self.feed_forward2(h)
        #print("logits:", logits.shape)
        # Compute descriptor weights using softmax
        d = torch.nn.functional.softmax(logits, dim=-1)
        #print("d:", d.shape)
        # Compute the reconstructed embeddings
        vhat = torch.mm(d, self.F)
        #print("vhat:", vhat.shape)
        return vhat, d, self.F

batch_size = 2
embedding_dim = 768
    
# Generating mock embeddings for article, predicate, ARG0, ARG1
article_embedding = torch.randn(batch_size, embedding_dim)
v_p = torch.randn(batch_size, embedding_dim)
v_a0 = torch.randn(batch_size, embedding_dim)
v_a1 = torch.randn(batch_size, embedding_dim)

# Testing Autoencoder
autoencoder = Autoencoder(embedding_dim, 768, 15, "p")

# Testing Autoencoder again
vhat_p, dz, F = autoencoder(v_p, article_embedding)
vhat_p.shape, dz.shape, F.shape


(torch.Size([2, 768]), torch.Size([2, 15]), torch.Size([15, 768]))

In [61]:
class FRISSLoss(nn.Module):
    def __init__(self, lambda_orthogonality, M, t):
        super(FRISSLoss, self).__init__()
        
        self.lambda_orthogonality = lambda_orthogonality
        self.M = M
        self.t = t

    def contrastive_loss(self, vhat, v, negatives):
        loss = 0
        for i in range(vhat.size(0)):  # loop over batch dimension
            for n in negatives[i]:  # loop over negative samples for the current batch entry
                loss += torch.max(torch.tensor(0.0), 1 + (vhat[i] - v[i]).pow(2).sum() - (vhat[i] - n).pow(2).sum())
        return loss / (vhat.size(0) * negatives.size(1))


    def focal_triplet_loss(self, vhat, v, d, F):
        loss = 0
        for b in range(vhat.size(0)):  # loop over batch dimension
            # Get the indices of the t smallest descriptor weights
            indices = d[b].argsort()[:int(self.t)]
            for i in indices:
                m_i = self.M * (1 - d[b, i]).pow(2)
                v_i = torch.mm(d[b].unsqueeze(0), F)[0]  # Reconstructing using descriptor matrix
                loss += torch.max(torch.tensor(0.0, device=v.device), m_i + (vhat[b] - v[b]).pow(2).sum() - (vhat[b] - v_i).pow(2).sum())
        return loss / (vhat.size(0) * self.t)

    def orthogonality_term(self, F):
        return (torch.mm(F, F.t()) - torch.eye(F.size(0)).to(F.device)).pow(2).sum()

    def forward(self, v_p, vhat_p, d_p, F_p, v_a0, vhat_a0, d_a0, F_a0, v_a1, vhat_a1, d_a1, F_a1, negatives):
        # Calculate losses for predicate
        Ju_p = self.contrastive_loss(vhat_p, v_p, negatives)
        Jt_p = self.focal_triplet_loss(vhat_p, v_p, d_p, F_p)
        
        # Calculate losses for ARG0
        Ju_a0 = self.contrastive_loss(vhat_a0, v_a0, negatives)
        Jt_a0 = self.focal_triplet_loss(vhat_a0, v_a0, d_a0, F_a0)
        
        # Calculate losses for ARG1
        Ju_a1 = self.contrastive_loss(vhat_a1, v_a1, negatives)
        Jt_a1 = self.focal_triplet_loss(vhat_a1, v_a1, d_a1, F_a1)
        
        orthogonality = self.lambda_orthogonality * (self.orthogonality_term(F_p) + self.orthogonality_term(F_a0) + self.orthogonality_term(F_a1))
        
        # Aggregate the losses
        loss = Ju_p + Jt_p + Ju_a0 + Jt_a0 + Ju_a1 + Jt_a1 + orthogonality
        
        return loss

# Mock Data Preparation
batch_size = 2
embedding_dim = 768
K = 15  # Number of frames/descriptors

# Generating mock embeddings for article, predicate, ARG0, ARG1 and their reconstructions
article_embedding = torch.randn(batch_size, embedding_dim)
v_p = torch.randn(batch_size, embedding_dim)
vhat_p = torch.randn(batch_size, embedding_dim)
v_a0 = torch.randn(batch_size, embedding_dim)
vhat_a0 = torch.randn(batch_size, embedding_dim)
v_a1 = torch.randn(batch_size, embedding_dim)
vhat_a1 = torch.randn(batch_size, embedding_dim)

# Generating mock descriptor weights and descriptor matrices for predicate, ARG0, ARG1
d_p = torch.randn(batch_size, K)
d_a0 = torch.randn(batch_size, K)
d_a1 = torch.randn(batch_size, K)
F_p = torch.randn(K, embedding_dim)
F_a0 = torch.randn(K, embedding_dim)
F_a1 = torch.randn(K, embedding_dim)

# Generating some negative samples (let's assume 5 negative samples per batch entry)
num_negatives = 5
negatives = torch.randn(batch_size, num_negatives, embedding_dim)

# Initialize loss function
lambda_orthogonality = 0.1
M = 1.0
t = 3  # Number of descriptors with smallest weights for negative samples

loss_fn = FRISSLoss(lambda_orthogonality, M, t)

# Calculate loss
loss = loss_fn(v_p, vhat_p, d_p, F_p, v_a0, vhat_a0, d_a0, F_a0, v_a1, vhat_a1, d_a1, F_a1, negatives)
print("FRiSSLoss output:", loss.item())


FRiSSLoss output: 2683819.5


In [62]:
class FRISSUnsupervised(nn.Module):
    def __init__(self, D_w, D_h, K, lambda_orthogonality, M, t):
        super(FRISSUnsupervised, self).__init__()
        
        self.loss_fn = FRISSLoss(lambda_orthogonality, M, t)

        # Separate autoencoders for "p", "a0", and "a1"
        self.autoencoder_p = Autoencoder(D_w, D_h, K, "p")
        self.autoencoder_a0 = Autoencoder(D_w, D_h, K, "a0")
        self.autoencoder_a1 = Autoencoder(D_w, D_h, K, "a1")

    def forward(self, v_p, v_a0, v_a1, v_article, negatives):
        # Get reconstructed embeddings and descriptor weights for each view using autoencoders
        vhat_p, d_p, F_p = self.autoencoder_p(v_p, v_article)
        vhat_a0, d_a0, F_a0 = self.autoencoder_a0(v_a0, v_article)
        vhat_a1, d_a1, F_a1 = self.autoencoder_a1(v_a1, v_article)
        
        
        # Compute unsupervised loss
        loss = self.loss_fn(
            v_p, vhat_p, d_p, self.autoencoder_p.F, 
            v_a0, vhat_a0, d_a0, self.autoencoder_a0.F, 
            v_a1, vhat_a1, d_a1, self.autoencoder_a1.F, 
            negatives
        )
        
        return loss

    
# Mock Data Preparation
D_h = 768
batch_size = 2
embedding_dim = 768
K = 15  # Number of frames/descriptors

# Generating mock embeddings for article, predicate, ARG0, ARG1, and their corresponding sentence embeddings
article_embedding = torch.randn(batch_size, embedding_dim)
v_p = torch.randn(batch_size, embedding_dim)
v_a0 = torch.randn(batch_size, embedding_dim)
v_a1 = torch.randn(batch_size, embedding_dim)


# Generating some negative samples (let's assume 5 negative samples per batch entry)
num_negatives = 5
negatives = torch.randn(batch_size, num_negatives, embedding_dim)

# Testing FRISSUnsupervised
unsupervised_module = FRISSUnsupervised(embedding_dim, D_h, K, lambda_orthogonality, M, t)
loss = unsupervised_module(v_p, v_a0, v_a1, article_embedding, negatives)
print("Unsupervised module loss:", loss.item())

Unsupervised module loss: 2757215.5


# Supervised

In [63]:
class FRISSSupervised(nn.Module):
    def __init__(self, D_w, K, num_frames):
        super(FRISSSupervised, self).__init__()
        
        # Semantic role classifier for predicate, ARG0, and ARG1
        self.sem_role_classifier = nn.Linear(D_w, num_frames)
        
        # Sentence classifier
        self.sentence_classifier = nn.Linear(D_w, num_frames)

    def forward(self, vhat_p, vhat_a0, vhat_a1, article_embedding):
        # Predictions from the aggregated semantic role embeddings
        pred_p = self.sem_role_classifier(vhat_p)
        pred_a0 = self.sem_role_classifier(vhat_a0)
        pred_a1 = self.sem_role_classifier(vhat_a1)
        
        # Average the predictions
        avg_pred = (pred_p + pred_a0 + pred_a1) / 3.0
        
        # Predictions from the aggregated sentence embeddings
        sent_pred = self.sentence_classifier(article_embedding)
        
        # Combine the predictions
        combined_pred = (avg_pred + sent_pred) / 2.0
        
        return combined_pred

# Mock Data Preparation

batch_size = 2
embedding_dim = 768
K = 15  # Number of frames/descriptors
num_frames = 15  # Assuming the number of frames is equal to K for simplicity

# Generating mock reconstructed embeddings for predicate, ARG0, ARG1
vhat_p = torch.randn(batch_size, embedding_dim)
vhat_a0 = torch.randn(batch_size, embedding_dim)
vhat_a1 = torch.randn(batch_size, embedding_dim)

# Generating mock embeddings for the article
article_embedding = torch.randn(batch_size, embedding_dim)

# Initialize and test the supervised module
supervised_module = FRISSSupervised(embedding_dim, K, num_frames)

# Forward pass the mock data
combined_pred = supervised_module(vhat_p, vhat_a0, vhat_a1, article_embedding)
print("Combined predictions shape:", combined_pred.shape)


Combined predictions shape: torch.Size([2, 15])


# Combined Model

In [64]:
import torch.nn as nn

import torch.nn as nn

class FRISS(nn.Module):
    def __init__(self, embedding_dim, D_h, K, lambda_orthogonality, M, t, num_frames, max_sentences_per_article=32, max_sentence_length=32, bert_model_name="bert-base-uncased"):
        super(FRISS, self).__init__()
        
        # Aggregation layer
        self.aggregation = AggregationLayer(max_sentences_per_article, max_sentence_length, bert_model_name=bert_model_name)
        
        # Autoencoder for each type (p, a0, a1)
        self.autoencoder_p = Autoencoder(embedding_dim, D_h, K, "p")
        self.autoencoder_a0 = Autoencoder(embedding_dim, D_h, K, "a0")
        self.autoencoder_a1 = Autoencoder(embedding_dim, D_h, K, "a1")
        
        # Unsupervised training module
        self.unsupervised = FRISSUnsupervised(embedding_dim, D_h, K, lambda_orthogonality, M, t)
        
        # Supervised training module
        self.supervised = FRISSSupervised(embedding_dim, K, num_frames)
        
    def forward(self, sentence_ids, predicate_ids, arg0_ids, arg1_ids, negatives=None):
        # Convert input IDs to embeddings
        v_sentence, v_p, v_a0, v_a1 = self.aggregation(sentence_ids, predicate_ids, arg0_ids, arg1_ids)
        
        # Handle multiple spans by averaging predictions
        unsupervised_losses = []
        combined_preds = []
        
        # Process each span
        for span_idx in range(v_p.size(1)):
            # Extract embeddings for the current span
            s_sentence_span = v_sentence[:, span_idx, :]
            v_p_span = v_p[:, span_idx, :]
            v_a0_span = v_a0[:, span_idx, :]
            v_a1_span = v_a1[:, span_idx, :]
            
            # Autoencoder embeddings for the current span
            vhat_p, d_p, F_p = self.autoencoder_p(v_p_span, s_sentence_span)
            vhat_a0, d_a0, F_a0 = self.autoencoder_a0(v_a0_span, s_sentence_span)
            vhat_a1, d_a1, F_a1 = self.autoencoder_a1(v_a1_span, s_sentence_span)
            
            # Unsupervised loss for the current span
            unsupervised_loss = self.unsupervised(v_p_span, v_a0_span, v_a1_span, s_sentence_span, negatives)
            unsupervised_losses.append(unsupervised_loss)
            
            # Supervised prediction for the current span
            combined_pred = self.supervised(vhat_p, vhat_a0, vhat_a1, s_sentence_span)
            combined_preds.append(combined_pred)
        
        # Average predictions and losses across all spans
        unsupervised_loss = torch.mean(torch.stack(unsupervised_losses))
        combined_pred = torch.mean(torch.stack(combined_preds), dim=0)
        
        return unsupervised_loss, combined_pred


# Set the necessary parameters
batch_size = 2
embedding_dim = 768
K = 15  # Number of frames/descriptors
num_frames = 15  # Assuming the number of frames is equal to K for simplicity
D_h = 512  # Dimension of the hidden representation
lambda_orthogonality = 0.1
M = 0.5
t = 0.1

# Define some mock token IDs data parameters
max_sentences_per_article = 5
max_sentence_length = 10

# Generating mock token IDs for predicate, ARG0, ARG1, and their corresponding sentences
# We assume a vocab size of 30522 (standard BERT vocab size) for simplicity.
vocab_size = 1000

sentence_ids = torch.randint(0, vocab_size, (batch_size, max_sentences_per_article, max_sentence_length))
predicate_ids = torch.randint(0, vocab_size, (batch_size, max_sentences_per_article, max_sentence_length))
arg0_ids = torch.randint(0, vocab_size, (batch_size, max_sentences_per_article, max_sentence_length))
arg1_ids = torch.randint(0, vocab_size, (batch_size, max_sentences_per_article, max_sentence_length))

# Generating some negative samples (let's assume 5 negative samples per batch entry)
num_negatives = 5
negatives = torch.randn(batch_size, num_negatives, embedding_dim)

# Initialize the FRISS model
friss_model = FRISS(embedding_dim, D_h, K, lambda_orthogonality, M, t, num_frames, max_sentences_per_article, max_sentence_length)

# Forward pass the mock data
unsupervised_loss, combined_pred = friss_model(sentence_ids, predicate_ids, arg0_ids, arg1_ids, negatives)
unsupervised_loss.item(), combined_pred.shape

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


(2736702.75, torch.Size([2, 15]))

# Train Model

The F1-Score (micro-averaged) and Average Precision Score are chosen as primary metrics for evaluating the multi-label classification task due to the following reasons:

1. **F1-Score (Micro)**:
    - The micro-averaged F1-score computes global counts of true positives, false negatives, and false positives. 
    - It provides a balance between precision (the number of correct positive results divided by the number of all positive results) and recall (the number of correct positive results divided by the number of positive results that should have been returned).
    - Given the imbalance in the label distribution observed in the dataset, the micro-averaged F1-score is robust against this imbalance, making it a suitable metric for optimization.

2. **Average Precision Score**:
    - This metric summarizes the precision-recall curve, giving a single value that represents the average of precision values at different recall levels.
    - It's especially valuable when class imbalances exist, as it gives more weight to the positive class (the rarer class in an imbalanced dataset).

Using these metrics will ensure that the model is optimized for a balanced performance across all labels, even if some labels are rarer than others.

In [65]:
from tqdm.notebook import tqdm

import numpy as np
from sklearn.metrics import f1_score, average_precision_score

def train(model, train_dataloader, test_dataloader, optimizer, alpha=0.5, num_epochs=10, device='cuda', save_path='../notebooks/classifier/'):
    loss_function = torch.nn.BCEWithLogitsLoss()
    
    metrics = {
        'f1': [],
        'avg_precision': []
    }
    
    def negative_sampling(embeddings, num_negatives=5):
        """
        Performs negative sampling for contrastive loss.

        Args:
        - embeddings (torch.Tensor): Tensor of shape [batch_size, num_sentences, embedding_dim]
        - num_negatives (int): Number of negative samples required

        Returns:
        - negatives (torch.Tensor): Tensor of shape [batch_size, num_negatives, embedding_dim]
        """
        batch_size, num_sentences, embedding_dim = embeddings.size()
        negatives = []

        for i in range(batch_size):
            # Randomly sample negative indices, ensuring they are different from the current index
            negative_indices = torch.randint(0, num_sentences, (num_negatives,))
            negative_samples = embeddings[i, negative_indices]
            negatives.append(negative_samples)

        return torch.stack(negatives)
    
    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        model.train()
        total_loss = 0
        supervised_total_loss = 0
        unsupervised_total_loss = 0
        batch_progress = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc="Batches", leave=False)

        for batch_idx, batch in batch_progress:
            optimizer.zero_grad()

            sentence_ids = batch['sentence_ids'].to(device)
            predicate_ids = batch['predicate_ids'].to(device)
            arg0_ids = batch['arg0_ids'].to(device)
            arg1_ids = batch['arg1_ids'].to(device)
            labels = batch['labels'].to(device)
            
            v_sentence, v_p, v_a0, v_a1 = model.aggregation(sentence_ids, predicate_ids, arg0_ids, arg1_ids)
            
            negatives_p = negative_sampling(v_p)
            negatives_a0 = negative_sampling(v_a0)
            negatives_a1 = negative_sampling(v_a1)

            negatives = torch.cat((negatives_p, negatives_a0, negatives_a1), dim=1)

            unsupervised_loss, combined_pred = model(sentence_ids, predicate_ids, arg0_ids, arg1_ids, negatives)
            supervised_loss = loss_function(combined_pred, labels.float())
            combined_loss = alpha * supervised_loss + (1-alpha) * unsupervised_loss
            combined_loss.backward()
            optimizer.step()

            total_loss += combined_loss.item()
            supervised_total_loss += supervised_loss.item()
            unsupervised_total_loss += unsupervised_loss.item()

            batch_progress.set_description(f"Epoch {epoch+1} Combined Loss: {combined_loss.item():.4f}")
            
             # Log metrics to CSV
            with open(save_path + 'training_metrics.csv', 'a') as f:
                writer = csv.writer(f)
                writer.writerow([batch_idx, epoch+1, total_loss/len(train_dataloader), supervised_total_loss/len(train_dataloader), unsupervised_total_loss/len(train_dataloader)])

            # Explicitly delete tensors to free up memory
            del sentence_ids, predicate_ids, arg0_ids, arg1_ids, labels, v_sentence, v_p, v_a0, v_a1, negatives, unsupervised_loss, combined_pred
            torch.cuda.empty_cache()

        print(f"Epoch {epoch+1}/{num_epochs}, Combined Loss: {total_loss/len(train_dataloader)}, Supervised Loss: {supervised_total_loss/len(train_dataloader)}, Unsupervised Loss: {unsupervised_total_loss/len(train_dataloader)}")
        
        model.eval()
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in test_dataloader:
                sentence_ids = batch['sentence_ids'].to(device)
                predicate_ids = batch['predicate_ids'].to(device)
                arg0_ids = batch['arg0_ids'].to(device)
                arg1_ids = batch['arg1_ids'].to(device)
                labels = batch['labels'].to(device)
                
                v_sentence, v_p, v_a0, v_a1 = model.aggregation(sentence_ids, predicate_ids, arg0_ids, arg1_ids)
                
                negatives_p = negative_sampling(v_p)
                negatives_a0 = negative_sampling(v_a0)
                negatives_a1 = negative_sampling(v_a1)

                negatives = torch.cat((negatives_p, negatives_a0, negatives_a1), dim=1)

                _, logits = model(sentence_ids, predicate_ids, arg0_ids, arg1_ids, negatives)
                preds = (torch.sigmoid(logits) > 0.5).float()

                all_preds.append(preds.cpu().numpy())
                all_labels.append(labels.cpu().numpy())

                # Explicitly delete tensors to free up memory
                del sentence_ids, predicate_ids, arg0_ids, arg1_ids, labels, v_sentence, v_p, v_a0, v_a1, negatives, logits, preds
                torch.cuda.empty_cache()

        all_preds = np.vstack(all_preds)
        all_labels = np.vstack(all_labels)

        f1 = f1_score(all_labels, all_preds, average='micro')
        avg_precision = average_precision_score(all_labels, all_preds)

        metrics['f1'].append(f1)
        metrics['avg_precision'].append(avg_precision)

        print(f"Validation Metrics - F1 Score (Micro): {f1:.4f}, Average Precision: {avg_precision:.4f}")

    model_save_path = os.path.join(save_path, 'model1.pth')
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")
    
    with open('../notebooks/classifier/metrics.json', 'w') as f:
        json.dump(metrics, f)

    return metrics


In [66]:
import torch.optim as optim
import json

# Hyperparameters
embedding_dim = 768
D_h = int(embedding_dim / 2)
K = 5
lambda_orthogonality = 0.1
M = 0.1
t = 1.0
num_frames = 14

# Model instantiation
model = FRISS(embedding_dim, D_h, K, lambda_orthogonality, M, t, num_frames)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
loss_function = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Train the model
alpha_value = 0.5
num_epochs_value = 5
device_value = 'cuda'
metrics = train(model, train_dataloader, test_dataloader, optimizer, alpha=alpha_value, num_epochs=num_epochs_value, device=device_value)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Batches:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 1/5, Combined Loss: 334345.1158988402, Supervised Loss: 0.5196840799960893, Unsupervised Loss: 668689.7150451031
Validation Metrics - F1 Score (Micro): 0.3388, Average Precision: 0.2500




Batches:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 2/5, Combined Loss: 194698.326433634, Supervised Loss: 0.5081640991967978, Unsupervised Loss: 389396.14497422683
Validation Metrics - F1 Score (Micro): 0.3388, Average Precision: 0.2500




Batches:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 3/5, Combined Loss: 122032.01409471649, Supervised Loss: 0.511193690840731, Unsupervised Loss: 244063.5166720361
Validation Metrics - F1 Score (Micro): 0.3719, Average Precision: 0.2500




Batches:   0%|          | 0/194 [00:00<?, ?it/s]

Epoch 4/5, Combined Loss: 80651.81962789949, Supervised Loss: 0.5101502588422028, Unsupervised Loss: 161303.12926868556
Validation Metrics - F1 Score (Micro): 0.2222, Average Precision: 0.2500




Batches:   0%|          | 0/194 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [41]:
torch.cuda.empty_cache()

In [43]:
!nvidia-smi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Sun Oct 22 13:52:11 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.116.04   Driver Version: 525.116.04   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Quadro RTX 5000     Off  | 00000000:00:05.0 Off |                  Off |
| 33%   31C    P8    14W / 230W |   2811MiB / 16384MiB |      0%      Default |
|                               |            

In [None]:
def load_model_from_path(model_class, path, device='cuda'):
    """
    Loads the weights into an instance of the model class from the given path.
    
    Args:
    - model_class (torch.nn.Module): The class of the model (uninitialized).
    - path (str): Path to the saved weights.
    - device (str): Device to load the model on ('cpu' or 'cuda').
    
    Returns:
    - model (torch.nn.Module): Model with weights loaded.
    """
    model = model_class().to(device)
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    return model


In [None]:
model = load_model_from_path(FRISS, '../notebooks/classifier/new_model.pth')

In [None]:
def predict(model, dataloader, y_columns, device='cuda'):
    """
    Make predictions with the given model and dataloader.
    
    Args:
    - model (torch.nn.Module): The model to make predictions with.
    - dataloader (DataLoader): DataLoader for the dataset to predict on.
    - y_columns (pandas.Index): Column names from the y dataframe which correspond to labels.
    - device (str): Device to make predictions on ('cpu' or 'cuda').
    
    Returns:
    - predicted_labels (list of lists): List containing the predicted labels for each instance.
    """
    model.eval()
    all_preds_span = []
    
    with torch.no_grad():
        for batch in dataloader:
            # Move data to device
            sentence_ids = batch['sentence_ids'].to(device)
            predicate_ids = batch['predicate_ids'].to(device)
            arg0_ids = batch['arg0_ids'].to(device)
            arg1_ids = batch['arg1_ids'].to(device)

            # Forward pass
            logits_span, _ = model(sentence_ids, predicate_ids, arg0_ids, arg1_ids)
            preds_span = (torch.sigmoid(logits_span) > 0.5).float()

            all_preds_span.append(preds_span.cpu().numpy())
                
            torch.cuda.empty_cache()

    predictions = np.vstack(all_preds_span)
    
    # Convert boolean predictions to labels
    predicted_labels = []
    for pred in predictions:
        labels = list(y_columns[pred.astype(bool)])
        predicted_labels.append(labels)
    
    return predicted_labels


In [None]:
import numpy as np

# article813452859
article = """EU Profits From Trading With UK While London Loses Money – Political Campaigner

With the Parliamentary vote on British Prime Minister Theresa May’s Brexit plan set to be held next month; President of the European Commission Jean Claude Juncker has criticised the UK’s preparations for their departure from the EU.
But is there any chance that May's deal will make it through parliament and if it fails, how could this ongoing political deadlock finally come to an end?
Sputnik spoke with political campaigner Michael Swadling for more…
Sputnik: Does Theresa May have any chance of getting her deal through Parliament on the 14th January?
Michael Swadling: I guess her only chance is if Labour decides that they want to dishonour democracy and effectively keep us in the EU.
© AP Photo / Pablo Martinez Monsivais UK 'In Need of Leadership', May's Brexit Deal Unwelcome to Trump - US Ambassador
There is a chance; as unfortunately there are many MPs who don't respect the vote and may just turn on it, but short of that I don't see any way the Conservatives would vote for it, and the majority is slender as it is, as the DUP is bitterly against it, and I can't see the Lib Dems voting for it, so it will only be if there are enough, what I can describe as remoaner MPs, that the deal won't be dead in the water.
Sputnik: What could be a solution to the political chaos if the Prime Minister's deal is not approved?
Michael Swadling: The EU withdrawal act is in place; we'll leave and revert to WTO terms and that works, that's fine.
I often use the example of an iPhone to people; that's a piece of technology which is manufactured in China, uses American technology and these are two countries we deal with on WTO terms, this isn't a fantasy, stuck in a port somewhere, there isn't a massive tariff, this is the world that really exists today.
When we exit the EU on WTO terms; that will be fine for whatever trading we do with the EU, just as well as it does for our trade in China.READ MORE: UK Finance Chief Bashed for Failing to Unlock Money for No-Deal Brexit — Reports
Sputnik: Do you think that the EU needs the UK more than the UK needs the EU?
Michael Swadling: The EU makes a profit on its trade with the UK; the UK makes a loss on its trade with the EU.
They have a financial incentive to ensure that good trading relations continue far more than we do.
© REUTERS / Toby Melville UK Trade Minister Says '50-50' Chance Brexit Will Not Happen – Reports
The lifeblood and cash flow that keeps manufacturing in Europe going, comes from the city of London.
If someone in a city in Germany wants to do a deal with someone in Japan; the financial services of that are probably going through the city of London, they're not going through Frankfurt and Paris.
Views and opinions, expressed in the article are those of Michael Swadling and do not necessarily reflect those of Sputnik

"""

test_article = get_article_dataloader(article, tokenizer)
predict(model, test_article, y.columns)

In [None]:
import torch
torch.cuda.empty_cache()
free_gpu()

In [None]:
!nvidia-smi