In [19]:
!pip install torch transformers




In [23]:
pip install allennlp

Note: you may need to restart the kernel to use updated packages.


In [21]:
pip install allennlp_models

Note: you may need to restart the kernel to use updated packages.


In [16]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

In [17]:
class SentenceEncoder(nn.Module):
    """Encode raw sentences into fixed-size vectors with a pretrained BERT model.

    Each sentence is tokenized, passed through BERT, and represented by the
    mean of its final-layer token embeddings. Padding positions are excluded
    from the mean, so a sentence's embedding does not depend on the lengths of
    the other sentences in the batch.

    Args:
        pretrained_model_name: Name or path passed to ``from_pretrained`` for
            both the tokenizer and the model.
    """

    def __init__(self, pretrained_model_name='bert-base-uncased'):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)
        self.bert = BertModel.from_pretrained(pretrained_model_name)

    def forward(self, sentences):
        """Return one embedding per input sentence.

        Args:
            sentences: A string or list of strings accepted by the tokenizer.

        Returns:
            Tensor of shape [batch_size, hidden_size] on the same device as the
            BERT parameters.
        """
        # Tokenize; sequences are padded to the longest one in the batch and
        # truncated to BERT's 512-token limit.
        inputs = self.tokenizer(
            sentences,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )

        # The tokenizer always produces CPU tensors; move them next to the
        # model so the module keeps working after `.to("cuda")`.
        device = next(self.bert.parameters()).device
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        outputs = self.bert(**inputs)
        hidden = outputs.last_hidden_state  # [batch, seq_len, hidden]

        # Masked mean: zero out padding positions and divide by the number of
        # real tokens instead of the padded sequence length.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)  # guard against divide-by-zero
        sentence_embeddings = (hidden * mask).sum(dim=1) / token_counts
        return sentence_embeddings

In [18]:
def get_view_representation(sentence_embeddings, spans):
    """
    Build one view representation per sentence: the mean embedding of a token
    span concatenated with the mean embedding of the whole sentence.

    sentence_embeddings: Tensor of shape [batch_size, seq_len, embedding_dim]
        holding per-token embeddings (a token axis is required so that a span
        can be sliced out of it).
    spans: List of (start, end) tuples, one per sentence and matched to
        sentences by position; `end` is exclusive.

    Returns a tensor of shape [len(spans), 2 * embedding_dim], where the first
    half of each row is the span mean and the second half is the mean over all
    seq_len positions of that sentence (padding positions, if any, included).

    Raises ValueError if `sentence_embeddings` is not 3-dimensional or if a
    span selects no tokens.
    """
    if sentence_embeddings.dim() != 3:
        raise ValueError(
            "sentence_embeddings must have shape [batch_size, seq_len, embedding_dim], "
            f"got {tuple(sentence_embeddings.shape)}"
        )
    embedding_dim = sentence_embeddings.shape[-1]

    representations = []
    for i, (start, end) in enumerate(spans):
        token_embeddings = sentence_embeddings[i]  # [seq_len, embedding_dim]
        span_tokens = token_embeddings[start:end]
        if span_tokens.shape[0] == 0:
            # The mean of zero rows would silently produce NaNs.
            raise ValueError(f"span {(start, end)} for sentence {i} selects no tokens")
        span_embedding = span_tokens.mean(dim=0)
        # Reduce over the token axis only, so this stays a vector that can be
        # concatenated with the span embedding.
        sentence_embedding = token_embeddings.mean(dim=0)
        representations.append(torch.cat([span_embedding, sentence_embedding], dim=0))

    if not representations:
        # torch.stack rejects an empty list; return a well-shaped empty batch.
        return sentence_embeddings.new_zeros((0, 2 * embedding_dim))
    return torch.stack(representations)


In [19]:
class MultiViewFrameRepresentation(nn.Module):
    """Score a view representation against K latent classes, per view type.

    A shared linear layer projects the input to `hidden_dim`; each view type
    then has its own linear head producing K logits. Each view type also owns
    an embedding table with K rows of size `input_dim // 2` (`Fz_dict`), which
    is created here but not used by `forward`.

    Args:
        input_dim: Size of the incoming view representation.
        hidden_dim: Size of the shared hidden layer.
        K: Number of output logits (and embedding rows) per view type.
        view_types: Names of the views to create heads and embedding tables
            for. Defaults to ('p', 'a0', 'a1') -- presumably predicate and two
            argument roles; confirm against the surrounding notebook.
    """

    def __init__(self, input_dim, hidden_dim, K, view_types=('p', 'a0', 'a1')):
        super().__init__()
        # Materialize once so a one-shot iterable can feed both dicts below.
        view_types = tuple(view_types)
        self.Wh = nn.Linear(input_dim, hidden_dim)
        self.Wz_dict = nn.ModuleDict(
            {view: nn.Linear(hidden_dim, K) for view in view_types}
        )
        self.Fz_dict = nn.ModuleDict(
            {view: nn.Embedding(K, input_dim // 2) for view in view_types}
        )

    def forward(self, view_representation, view_type):
        """Return the K logits for `view_representation` under `view_type`.

        Args:
            view_representation: Tensor whose last dimension is `input_dim`.
            view_type: One of the view names given at construction time.

        Returns:
            Tensor with the same leading dimensions as the input and a last
            dimension of size K.

        Raises:
            KeyError: If `view_type` is not a configured view.
        """
        if view_type not in self.Wz_dict:
            raise KeyError(
                f"unknown view_type {view_type!r}; expected one of {list(self.Wz_dict.keys())}"
            )
        # nn.functional is reachable through the existing `nn` import, so this
        # does not depend on a separate top-level `torch` name being bound.
        h = nn.functional.relu(self.Wh(view_representation))
        lz = self.Wz_dict[view_type](h)
        return lz
