In [1]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

class UnifiedNLIModel(nn.Module):
    def __init__(self, model_name1='bert-base-uncased', model_name2='bert-base-uncased', num_classes=3):
        super(UnifiedNLIModel, self).__init__()

        # Load pre-trained models
        self.model1 = BertModel.from_pretrained(model_name1)
        self.model2 = BertModel.from_pretrained(model_name2)

        # Freeze the models if you don't want to train them further
        for param in self.model1.parameters():
            param.requires_grad = True
        for param in self.model2.parameters():
            param.requires_grad = True

        # Define a fully connected layer to combine the outputs
        combined_hidden_size = self.model1.config.hidden_size + self.model2.config.hidden_size
        self.fc = nn.Linear(combined_hidden_size, num_classes)

    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
        # Get outputs from both models
        outputs1 = self.model1(input_ids1, attention_mask=attention_mask1)
        outputs2 = self.model2(input_ids2, attention_mask=attention_mask2)

        # Get the pooled output (CLS token) from both models
        pooled_output1 = outputs1.pooler_output
        pooled_output2 = outputs2.pooler_output

        # Concatenate the pooled outputs
        combined_output = torch.cat((pooled_output1, pooled_output2), dim=1)

        # Pass the combined output through the fully connected layer
        logits = self.fc(combined_output)

        return logits

# Example usage
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Define sample input sentences
premise = "A man is eating food."
hypothesis = "The man is having a meal."

# Tokenize the input sentences for both models
inputs1 = tokenizer(premise, hypothesis, return_tensors='pt', padding=True, truncation=True)
inputs2 = tokenizer(hypothesis, premise, return_tensors='pt', padding=True, truncation=True)

# Instantiate the model
model = UnifiedNLIModel()

# Pass the inputs through the model
logits = model(inputs1['input_ids'], inputs1['attention_mask'], inputs2['input_ids'], inputs2['attention_mask'])

# Output logits for each class (entailment, contradiction, neutral)
print("Logits:", logits)

# Get the predicted class by applying torch.argmax
predicted_class = torch.argmax(logits, dim=1).item()

# Define the mapping of index to label
label_map = {0: "entailment", 1: "contradiction", 2: "neutral"}

# Get the corresponding label
predicted_label = label_map[predicted_class]

print(f"Predicted class: {predicted_class} : {predicted_label}")

  from .autonotebook import tqdm as notebook_tqdm


Logits: tensor([[-0.0545,  0.3824, -0.2765]], grad_fn=<AddmmBackward0>)
Predicted class: 1 : contradiction
