In [12]:
from transformers import BertModel, RobertaModel
import torch.nn as nn
import torch

class SRLEmbeddings(nn.Module):
    def __init__(self, model_name_or_path: str, model_type: str = "bert-base-uncased"):
        super(SRLEmbeddings, self).__init__()

        if model_type == "bert-base-uncased":
            self.model = BertModel.from_pretrained(model_name_or_path)
        elif model_type == "roberta-base":
            self.model = RobertaModel.from_pretrained(model_name_or_path)
        else:
            raise ValueError("Unsupported model_type. Choose either 'bert-base-uncased' or 'roberta-base'.")

        # Move model to CUDA if available
        if torch.cuda.is_available():
            self.model.cuda()

        self.embedding_dim = self.model.config.hidden_size

    def get_sentence_embedding(self, ids: torch.Tensor, attention_masks: torch.Tensor):
        # Assume ids and attention_masks shapes are [batch_size, num_sentences, max_sentence_length]
        batch_size, num_sentences, max_sentence_length = ids.size()

        # Flatten ids and attention_masks to 2D tensors
        ids_flat = ids.view(-1, max_sentence_length)
        attention_masks_flat = attention_masks.view(-1, max_sentence_length)

        with torch.no_grad():
            # Obtain the embeddings from the BERT model
            embeddings = self.model(input_ids=ids_flat, attention_mask=attention_masks_flat)[0]

        # Reshape back to original batch and sentence dimensions
        embeddings_reshaped = embeddings.view(batch_size, num_sentences, max_sentence_length, -1)
        
        # Calculate mean embeddings across the token dimension while ignoring padded tokens
        attention_masks_expanded = attention_masks_flat.unsqueeze(-1).expand(embeddings.size())
        embeddings_masked = embeddings * attention_masks_expanded
        sum_embeddings = torch.sum(embeddings_masked, dim=1)
        token_counts = attention_masks_flat.sum(dim=1, keepdim=True).clamp(min=1)
        embeddings_mean = sum_embeddings / token_counts
        embeddings_mean_reshaped = embeddings_mean.view(batch_size, num_sentences, -1)

        return embeddings_reshaped, embeddings_mean_reshaped

    def get_arg_embedding(self, arg_ids: torch.Tensor, sentence_ids: torch.Tensor, sentence_embeddings: torch.Tensor):
        batch_size, num_sentences, max_sentence_length = sentence_ids.shape
        _, _, num_args, max_arg_length = arg_ids.shape

        arg_embeddings = torch.zeros(batch_size, num_sentences, num_args, self.embedding_dim, device=sentence_embeddings.device)

        print(sentence_ids[0,0,2])
        print(sentence_embeddings[0,0,2])

        for batch_idx in range(batch_size):
            for sent_idx in range(num_sentences):
                for arg_idx in range(num_args):
                    for token_idx in range(max_arg_length):
                        arg_token_id = arg_ids[batch_idx, sent_idx, arg_idx, token_idx].item()
                        if arg_token_id == 0:  # Skip padding tokens
                            continue
                        match_indices = (sentence_ids[batch_idx, sent_idx] == arg_token_id).nonzero(as_tuple=False)
                        if match_indices.nelement() == 0:
                            continue
                        flat_indices = match_indices[:, 0]
                        selected_embeddings = sentence_embeddings[batch_idx, sent_idx, flat_indices]
                        avg_embedding = selected_embeddings.mean(dim=0)
                        arg_embeddings[batch_idx, sent_idx, arg_idx] = avg_embedding

        return arg_embeddings

    def forward(self, sentence_ids: torch.Tensor, sentence_attention_masks: torch.Tensor, predicate_ids: torch.Tensor, arg0_ids: torch.Tensor, arg1_ids: torch.Tensor):
        with torch.no_grad():
            sentence_embeddings, sentence_embeddings_avg = self.get_sentence_embedding(sentence_ids, sentence_attention_masks)

            predicate_embeddings = self.get_arg_embedding(predicate_ids, sentence_ids, sentence_embeddings)
            arg0_embeddings = self.get_arg_embedding(arg0_ids, sentence_ids, sentence_embeddings)
            arg1_embeddings = self.get_arg_embedding(arg1_ids, sentence_ids, sentence_embeddings)

        return sentence_embeddings_avg, predicate_embeddings, arg0_embeddings, arg1_embeddings


In [13]:
# Mock data for testing
batch_size = 2
num_sentences = 3
max_sentence_length = 8
embedding_dim = 768

# Mock input tensors
sentence_ids = torch.tensor([
    [[23, 323, 433, 213, 534, 0, 0, 0], [45, 67, 89, 0, 0, 0, 0, 0], [100, 200, 300, 400, 500, 600, 700, 800]],
    [[101, 102, 103, 0, 0, 0, 0, 0], [201, 202, 203, 204, 205, 0, 0, 0], [301, 302, 303, 304, 0, 0, 0, 0]]
])

sentence_attention_masks = torch.tensor([
    [[1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1]],
    [[1, 1, 1, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 0, 0, 0, 0]]
])

predicate_ids = torch.tensor([
    [[[433, 0, 0, 0]], [[89, 0, 0, 0]], [[500, 0, 0, 0]]],
    [[[103, 0, 0, 0]], [[202, 0, 0, 0]], [[303, 0, 0, 0]]]
])

arg0_ids = torch.tensor([
    [[[23, 0, 0, 0]], [[45, 0, 0, 0]], [[100, 0, 0, 0]]],
    [[[101, 0, 0, 0]], [[201, 0, 0, 0]], [[301, 0, 0, 0]]]
])

arg1_ids = torch.tensor([
    [[[323, 0, 0, 0]], [[67, 0, 0, 0]], [[200, 0, 0, 0]]],
    [[[102, 0, 0, 0]], [[202, 0, 0, 0]], [[302, 0, 0, 0]]]
])

# Instantiate and test the model
model = SRLEmbeddings(model_name_or_path="roberta-base", model_type="roberta-base")
sentence_embeddings_avg, predicate_embeddings, arg0_embeddings, arg1_embeddings = model(
    sentence_ids, sentence_attention_masks, predicate_ids, arg0_ids, arg1_ids
)

print("Sentence Embeddings Avg:", sentence_embeddings_avg)
print("Predicate Embeddings:", predicate_embeddings)
print("Arg0 Embeddings:", arg0_embeddings)
print("Arg1 Embeddings:", arg1_embeddings)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor(433)
tensor([-1.7004e-01,  1.1647e-02,  8.8073e-02,  9.3565e-02,  3.1244e-01,
        -2.7817e-01,  2.1436e-02,  1.8705e-01, -2.2464e-02, -1.5706e-01,
        -2.4647e-01, -2.5412e-02, -5.6753e-02, -3.5687e-01,  1.6351e-01,
         6.3674e-02,  3.0988e-01, -3.0804e-02,  2.3896e-01,  1.6658e-01,
         5.2035e-03,  1.3341e-01, -3.3445e-02,  2.1375e-01, -8.7511e-02,
        -1.1443e-01, -1.3420e-01, -8.0512e-03, -2.7295e-02,  2.8130e-02,
         9.9028e-03, -2.3141e-01,  6.5424e-02,  4.8875e-02, -1.3874e-01,
        -8.1838e-02,  1.8053e-01,  1.3529e-01,  1.9469e-01,  8.3232e-02,
        -4.3019e-01,  4.1315e-01, -1.7605e-01, -5.3113e-02,  5.7006e-03,
        -1.8289e-01, -1.8628e-02, -3.3845e-02, -3.1026e-02, -1.0559e-01,
         8.8791e-03,  9.8897e-02,  3.7124e-02, -1.6649e-01,  5.0435e-02,
        -1.3841e-03,  1.0986e-01, -2.8244e-02, -1.0000e-01, -2.7286e-03,
        -7.2472e-02,  9.1051e-01,  8.2347e-02, -2.7762e-02, -1.1047e-01,
         1.1745e-02, -5.1817e-02,  3.78

In [17]:
predicate_embeddings[0, 0]

tensor([[-1.7004e-01,  1.1647e-02,  8.8073e-02,  9.3565e-02,  3.1244e-01,
         -2.7817e-01,  2.1436e-02,  1.8705e-01, -2.2464e-02, -1.5706e-01,
         -2.4647e-01, -2.5412e-02, -5.6753e-02, -3.5687e-01,  1.6351e-01,
          6.3674e-02,  3.0988e-01, -3.0804e-02,  2.3896e-01,  1.6658e-01,
          5.2035e-03,  1.3341e-01, -3.3445e-02,  2.1375e-01, -8.7511e-02,
         -1.1443e-01, -1.3420e-01, -8.0512e-03, -2.7295e-02,  2.8130e-02,
          9.9028e-03, -2.3141e-01,  6.5424e-02,  4.8875e-02, -1.3874e-01,
         -8.1838e-02,  1.8053e-01,  1.3529e-01,  1.9469e-01,  8.3232e-02,
         -4.3019e-01,  4.1315e-01, -1.7605e-01, -5.3113e-02,  5.7006e-03,
         -1.8289e-01, -1.8628e-02, -3.3845e-02, -3.1026e-02, -1.0559e-01,
          8.8791e-03,  9.8897e-02,  3.7124e-02, -1.6649e-01,  5.0435e-02,
         -1.3841e-03,  1.0986e-01, -2.8244e-02, -1.0000e-01, -2.7286e-03,
         -7.2472e-02,  9.1051e-01,  8.2347e-02, -2.7762e-02, -1.1047e-01,
          1.1745e-02, -5.1817e-02,  3.