In [173]:
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import BertTokenizer, BertModel
from datasets import load_dataset

In [186]:
class PointWiseFFN(nn.Module):
    def __init__(self, d_hidden: int, d_ff: int):
        super().__init__()

        self.dense1 = nn.Linear(d_hidden, d_ff)
        self.dense2 = nn.Linear(d_ff, d_hidden)

    def forward(self, h):
        h = self.dense1(h)
        h = F.relu(h)
        h = self.dense2(h)

        return h


class IntraAttentionStack(nn.Module):
    def __init__(self, d_hidden: int, num_heads: int, layers: int):
        super().__init__()

        self.layers = nn.ModuleList([
            IntraAttentionBlock(d_hidden, d_hidden, num_heads)
            for _ in range(layers)
        ])
    
    def forward(self, emb: torch.Tensor):
        for layer in self.layers:
            emb = layer(emb)

        return emb 

class IntraAttentionBlock(nn.Module):
    def __init__(self, d_hidden: int, d_ff: int, num_heads: int):
        super().__init__()
        
        self.attn = nn.MultiheadAttention(d_hidden, num_heads, dropout=.2, batch_first=True)
        self.pffn = PointWiseFFN(d_hidden, d_ff) 
    
    def forward(self, emb: torch.Tensor):
        attn_output, _ = self.attn(emb, emb, emb, need_weights=False)
        h = self.pffn(attn_output)

        return h
    
class GlobalAttention(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features 
        self.out_features = out_features

        self.weights = nn.Parameter(torch.randn(in_features, out_features))
        self.bias = nn.Parameter(torch.randn(out_features))
    
    # TODO: implement biases
    def forward(self, h_cw: torch.Tensor, h_ap: torch.Tensor, only_weights = False):
        if h_cw.dim() == 3:
            logits = F.tanh(h_cw @ self.weights @
                            h_ap.swapaxes(1, 2))  # + self.bias
        else:
            logits = F.tanh(h_cw @ self.weights @ h_ap.swapaxes(1, 2))  # + self.bias

        I_attn = F.softmax(logits, 
                           dim=-1)
        
        if only_weights:
            return I_attn.squeeze(1)
        return I_attn @ h_cw
    
    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}'.format(
            self.in_features, self.out_features
        )


def wdmc(h_cp: torch.Tensor, a_range: tuple[tuple[int, int], ...], window_size: int):
    tensors = []

    for a_s, a_e in a_range:
        n = h_cp.size(1)

        d = torch.arange(window_size / 2, max(a_s, n - a_e) - 1) + 1

        d_weighted = 1 - (d - (window_size/2))/n

        r_s = int(a_s - window_size / 2)
        r_e = int(n - a_e - window_size / 2 - 1)

        tensors.append(torch.cat((d_weighted[:r_s].flip(-1), torch.ones(
            window_size + a_e - a_s + 1), d_weighted[-r_e:])).view(-1, 1).repeat(1, 1, h_cp.size(-1)))
    return torch.cat(tensors) * h_cp


In [188]:
class MAMN(nn.Module):
    def __init__(self, num_embeddings: int, d_hidden: int, num_heads: int, layers: int, window_size: int,
                 num_labels: int):
        super().__init__()
        
        self.embedding = nn.Embedding(num_embeddings, d_hidden)
        self.intra_attn_layers = IntraAttentionStack(d_hidden, num_heads, layers)

        self.window_size = window_size 
        self.global_attn = GlobalAttention(d_hidden, d_hidden)

        self.dense = nn.Linear(d_hidden, num_labels)


    def forward(self, context: torch.LongTensor, 
                aspects: torch.LongTensor,
                a_ranges: tuple[tuple[int, int], ...]):
        context_emb = self.embedding(context)
        aspects_emb = self.embedding(aspects)

        h_cp = self.intra_attn_layers(context_emb)
        h_ap = self.intra_attn_layers(aspects_emb)

        h_cw = wdmc(h_cp, a_ranges, self.window_size)
        g = self.global_attn(h_cw, h_ap) 

        h_cw_avg = torch.mean(h_cw, dim=1)

        attn_weights = self.global_attn(h_cw_avg, h_ap, only_weights=True)

        O = (attn_weights @ g).squeeze(1)
        logits = F.tanh(self.dense(O))

        return F.softmax(logits, -1)


In [189]:
x = MAMN(30522, 768, 4, 2, 8, 3)(torch.arange(80).long().unsqueeze(0), torch.arange(80).long().unsqueeze(0), ((20, 20),))
x

torch.Size([1, 80, 768]) torch.Size([768, 768]) torch.Size([1, 80, 768])


tensor([[0.3216, 0.3348, 0.3436]], grad_fn=<SoftmaxBackward0>)

In [153]:
a = torch.randn(1, 80, 768)
b = torch.randn(768, 768)
c = torch.randn(1, 3, 768)

(c @ b @ a.swapaxes(1, 2)).shape

torch.Size([1, 3, 80])

In [23]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
dataset = load_dataset('grostaco/laptops-trial')


Downloading model.safetensors: 100%|██████████| 440M/440M [07:01<00:00, 1.05MB/s] 
Found cached dataset laptops-trial (C:/Users/User/.cache/huggingface/datasets/grostaco___laptops-trial/default/0.0.0/330ba984e4d7218c66e6e89063270f8c480a86a2c58afa0f854519ce925c5330)
100%|██████████| 3/3 [00:00<00:00, 749.30it/s]


In [19]:
def tokenize_aspects(aspects):
    tokenized = tokenizer(aspects)
    
    return {f'aspect_{k}': v for k, v in tokenized.items()}

dataset = dataset.map(tokenizer, input_columns='content', batched=True)
dataset = dataset.map(tokenize_aspects, input_columns='aspect', batched=True)

                                                                 

In [29]:
model.embeddings.word_embeddings

Embedding(30522, 768, padding_idx=0)