In [1]:
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import BertTokenizer, BertModel, PretrainedConfig, PreTrainedModel, Trainer, TrainingArguments, DataCollatorWithPadding, default_data_collator, EvalPrediction
from transformers.modeling_outputs import SequenceClassifierOutput
from datasets import load_dataset

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report


In [2]:
bert = BertModel.from_pretrained('bert-base-uncased')
dataset = load_dataset('grostaco/laptops-trial')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

Found cached dataset laptops-trial (C:/Users/User/.cache/huggingface/datasets/grostaco___laptops-trial/default/0.0.0/ca5733dfc0f9290466b24cc18c4981e2c3f639aa23138bec1753f67f23cb530a)
100%|██████████| 3/3 [00:00<00:00, 82.99it/s]


In [3]:
class PointWiseFFN(nn.Module):
    def __init__(self, d_hidden: int, d_ff: int):
        super().__init__()

        self.dense1 = nn.Linear(d_hidden, d_ff)
        self.dense2 = nn.Linear(d_ff, d_hidden)

    def forward(self, h):
        h = self.dense1(h)
        h = F.relu(h)
        h = self.dense2(h)

        return h


class IntraAttentionStack(nn.Module):
    def __init__(self, d_hidden: int, num_heads: int, layers: int):
        super().__init__()

        self.layers = nn.ModuleList([
            IntraAttentionBlock(d_hidden, d_hidden, num_heads)
            for _ in range(layers)
        ])
    
    def forward(self, emb: torch.Tensor, attn_mask = None):
        for layer in self.layers:
            emb = layer(emb, attn_mask=attn_mask)

        return emb 

class IntraAttentionBlock(nn.Module):
    def __init__(self, d_hidden: int, d_ff: int, num_heads: int):
        super().__init__()
        
        self.attn = nn.MultiheadAttention(d_hidden, num_heads, dropout=.2, batch_first=True)
        self.pffn = PointWiseFFN(d_hidden, d_ff) 
    
    def forward(self, emb: torch.Tensor, attn_mask = None):
        attn_output, _ = self.attn(emb, emb, emb, need_weights=False, key_padding_mask=attn_mask)
        h = self.pffn(attn_output)

        return h
    
class GlobalAttention(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features 
        self.out_features = out_features

        self.weight = nn.Parameter(torch.randn(in_features, out_features))
        #self.bias = nn.Parameter(torch.randn(in_features, out_features))
    
    # TODO: implement biases
    def forward(self, h_cw: torch.Tensor, h_ap: torch.Tensor, only_weights = False):

        logits = F.tanh(h_ap @ self.weight @
                        h_cw.swapaxes(1, 2))  # + self.bias
        I_attn = F.softmax(logits, 
                           dim=-1)
        
        if only_weights:
            return I_attn
        return I_attn @ h_cw
    
    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}'.format(
            self.in_features, self.out_features
        )


def wdmc(h_cp: torch.Tensor, a_starts: tuple[int, ...], a_ends: tuple[int, ...], window_size: int):
    """_summary_

    Args:
        h_cp (torch.Tensor): _description_
        a_range (tuple[tuple[int, int], ...]): _description_
        window_size (int): _description_

    Returns:
        _type_: _description_
    """
    tensors = []

    for a_s, a_e in zip(a_starts, a_ends):
        r = []
        n = h_cp.size(1)

        if a_s - window_size/2 > 0:
            d_fwd = torch.arange(a_s - window_size/2, 0, -1)
            d_fwd = (1 - d_fwd/n)

            r.append(d_fwd)

        r.append(torch.ones(min(a_s, window_size//2) +
                            min(n - a_e - 1, window_size//2) + a_e - a_s + 1))

        if a_e + window_size/2 + 1 < n:
            d_bwd = torch.arange(1, n - a_e - window_size/2)
            d_bwd = (1 - d_bwd/n)
            r.append(d_bwd)

        tensors.append(torch.cat(r).view(-1, 1).repeat(1, 1, h_cp.size(-1)))

    return torch.cat(tensors).to(h_cp.device) * h_cp


In [13]:
class MAMNConfig(PretrainedConfig):
    model_type = 'MAMN'

    def __init__(self, num_embeddings: int = 30522, d_hidden: int = 768, num_heads: int = 4,
                 layers: int = 4, window_size: int = 8, **kwargs):
        self.num_embeddings = num_embeddings
        self.d_hidden = d_hidden 
        self.num_heads = num_heads 
        self.layers = layers 
        self.window_size = window_size

        super().__init__(**kwargs)

class MAMN(PreTrainedModel):
    config_class = MAMNConfig 

    def __init__(self, config: MAMNConfig):
        super().__init__(config)
        
        self.embedding = nn.Embedding(config.num_embeddings, config.d_hidden)
        self.intra_attn_layers1 = IntraAttentionStack(
            config.d_hidden, config.num_heads, config.layers)
        self.intra_attn_layers2 = IntraAttentionStack(
            config.d_hidden, config.num_heads, config.layers)

        self.window_size = config.window_size
        self.global_attn1 = GlobalAttention(config.d_hidden, config.d_hidden)
        self.global_attn2 = GlobalAttention(config.d_hidden, config.d_hidden)

        self.dense = nn.Linear(config.d_hidden, config.num_labels)

    def forward(self, input_ids: torch.LongTensor, 
                aspects_input_ids: torch.LongTensor,
                attention_mask: torch.LongTensor,
                aspects_attention_mask: torch.LongTensor,
                start: tuple[int, ...], 
                end: tuple[int, ...],
                labels = None,
                **kwargs):
        context_emb = self.embedding(input_ids)
        aspects_emb = self.embedding(aspects_input_ids)

        h_cp = self.intra_attn_layers1(context_emb, attn_mask=attention_mask.bool() if attention_mask is not None else None)
        h_ap = self.intra_attn_layers2(
            aspects_emb, attn_mask=aspects_attention_mask.bool() if attention_mask is not None else None)

        #print(f'h_cp: {h_cp.shape} {torch.isnan(h_cp).any()}')
        h_cw = wdmc(h_cp, start, end, self.window_size)
        
        #print(f'h_cw: {h_cw.shape} {torch.isnan(h_cw).any()}')
        g = self.global_attn1(h_cw, h_ap) #* aspects_attention_mask

        #print(f'g: {g.shape} {torch.isnan(g).any()}')

        h_cw_avg = torch.mean(h_cw, dim=1) #* attention_mask
        #print(f'h_cw_avg: {h_cw_avg.shape} {torch.isnan(h_cw_avg).any()}')

        attn_weights = self.global_attn2(h_ap, h_cw_avg, only_weights=True)
        # print(
        #     f'attn_weights: {attn_weights.shape} {torch.isnan(attn_weights).any()}')

        O = (attn_weights @ g)
        logits = F.tanh(self.dense(O))[:, 0]

        #print(logits.shape)

        loss = None 
        if labels is not None:
            criterion = nn.CrossEntropyLoss()
            loss = criterion(logits, labels)
            
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits, 
            attentions=attn_weights
        )
    
    def _init_weights(self, module: nn.Module):
        if isinstance(module, nn.Linear | GlobalAttention):
            torch.nn.init.uniform(module.weight)

In [14]:
label2id = dataset['train'].features['labels']._str2int
id2label = {v:k for k, v in label2id.items()}

config = MAMNConfig(label2id=label2id, id2label=id2label)
model = MAMN(config)
model.embedding.weight = bert.embeddings.word_embeddings.weight

In [6]:
def tokenize_aspects(aspects):
    tokenized = tokenizer(aspects, padding='max_length', truncation=True, max_length=16, return_tensors='pt')
    
    return {f'aspects_{k}': v for k, v in tokenized.items()}

def tokenize(contents):
    return tokenizer(contents, padding='max_length', truncation=True, max_length=256, return_tensors='pt')



dataset = dataset.map(tokenize, input_columns='content', batched=True)
dataset = dataset.map(tokenize_aspects, input_columns='aspect', batched=True)
dataset = dataset.filter(lambda end: end < 80, input_columns='end')


Loading cached processed dataset at C:\Users\User\.cache\huggingface\datasets\grostaco___laptops-trial\default\0.0.0\ca5733dfc0f9290466b24cc18c4981e2c3f639aa23138bec1753f67f23cb530a\cache-d8b2b0b0d98ff577.arrow
Loading cached processed dataset at C:\Users\User\.cache\huggingface\datasets\grostaco___laptops-trial\default\0.0.0\ca5733dfc0f9290466b24cc18c4981e2c3f639aa23138bec1753f67f23cb530a\cache-f9857aefe704468a.arrow
Loading cached processed dataset at C:\Users\User\.cache\huggingface\datasets\grostaco___laptops-trial\default\0.0.0\ca5733dfc0f9290466b24cc18c4981e2c3f639aa23138bec1753f67f23cb530a\cache-309cc52e379f16ca.arrow
Loading cached processed dataset at C:\Users\User\.cache\huggingface\datasets\grostaco___laptops-trial\default\0.0.0\ca5733dfc0f9290466b24cc18c4981e2c3f639aa23138bec1753f67f23cb530a\cache-0a8afeac1ec0ff91.arrow
Loading cached processed dataset at C:\Users\User\.cache\huggingface\datasets\grostaco___laptops-trial\default\0.0.0\ca5733dfc0f9290466b24cc18c4981e2c3f639a

In [15]:
samples = 3

contexts = torch.tensor(dataset['train']['input_ids'][:samples], dtype=torch.long) 
aspects = torch.tensor(dataset['train']['aspects_input_ids'][:samples], dtype=torch.long)
attention_mask = torch.tensor(dataset['train']['attention_mask'][:samples], dtype=torch.long)
aspects_attention_mask = torch.tensor(dataset['train']['aspects_attention_mask'][:samples], dtype=torch.long)
starts = dataset['train']['start'][:samples]
ends = dataset['train']['end'][:samples]
model.to('cpu').forward(contexts, aspects,
                        attention_mask, aspects_attention_mask, starts, ends).logits


tensor([[-0.0333, -0.0240, -0.0156],
        [-0.0331, -0.0241, -0.0156],
        [-0.0317, -0.0257, -0.0164]], grad_fn=<SelectBackward0>)

In [25]:
def compute_metrics(p: EvalPrediction):
    y_pred = p.predictions[0].argmax(-1)
    print(y_pred)
    y_true = p.label_ids

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(
        y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    report = classification_report(
        y_true, y_pred, target_names=label2id.keys())

    print(report)
    return {
        'accuracy': accuracy,
        'precision': precision,
        'f1': f1,
        'recall': recall,
    }

args = TrainingArguments(
    output_dir='mamn',
    evaluation_strategy='epoch',
    save_strategy='epoch',
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    load_best_model_at_end=True,
    metric_for_best_model='f1',
    num_train_epochs=1,
    learning_rate=2e-5,
)

trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    train_dataset=dataset['train'].select(range(128)),
    eval_dataset=dataset['validation'].select(range(128)),
    compute_metrics=compute_metrics,
)

trainer.train()

  return torch._native_multi_head_attention(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
                                             
100%|██████████| 8/8 [00:03<00:00,  3.51it/s]

[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
              precision    recall  f1-score   support

    positive       0.00      0.00      0.00        58
    negative       0.38      1.00      0.55        48
     neutral       0.00      0.00      0.00        22

    accuracy                           0.38       128
   macro avg       0.12      0.33      0.18       128
weighted avg       0.14      0.38      0.20       128

{'eval_loss': 1.0969144105911255, 'eval_accuracy': 0.375, 'eval_precision': 0.140625, 'eval_f1': 0.20454545454545453, 'eval_recall': 0.375, 'eval_runtime': 0.9281, 'eval_samples_per_second': 137.914, 'eval_steps_per_second': 8.62, 'epoch': 1.0}


100%|██████████| 8/8 [00:13<00:00,  1.68s/it]

{'train_runtime': 13.3253, 'train_samples_per_second': 9.606, 'train_steps_per_second': 0.6, 'train_loss': 1.0981285572052002, 'epoch': 1.0}





TrainOutput(global_step=8, training_loss=1.0981285572052002, metrics={'train_runtime': 13.3253, 'train_samples_per_second': 9.606, 'train_steps_per_second': 0.6, 'train_loss': 1.0981285572052002, 'epoch': 1.0})

In [26]:
trainer.evaluate(dataset['test'].select(range(32)))

  return torch._native_multi_head_attention(
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
              precision    recall  f1-score   support

    positive       0.00      0.00      0.00        10
    negative       0.50      1.00      0.67        16
     neutral       0.00      0.00      0.00         6

    accuracy                           0.50        32
   macro avg       0.17      0.33      0.22        32
weighted avg       0.25      0.50      0.33        32



100%|██████████| 2/2 [00:00<00:00,  4.87it/s]


{'eval_loss': 1.0960354804992676,
 'eval_accuracy': 0.5,
 'eval_precision': 0.25,
 'eval_f1': 0.3333333333333333,
 'eval_recall': 0.5,
 'eval_runtime': 0.4597,
 'eval_samples_per_second': 69.608,
 'eval_steps_per_second': 4.351,
 'epoch': 1.0}