In [16]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import pickle
import warnings

warnings.filterwarnings("ignore")

In [18]:
class SimilarityDataset(Dataset):
    """Sentence-pair dataset that tokenizes each pair for a BERT-style model.

    Each item is a dict holding the row's ``'id'`` plus the tokenizer's
    ``input_ids`` and ``attention_mask`` for ``(sentence1, sentence2)``,
    padded/truncated to ``max_length``.

    Args:
        dataframe: table with at least ``'id'``, ``'sentence1'`` and
            ``'sentence2'`` columns. Rows missing any of those are skipped.
        tokenizer: object exposing ``encode_plus`` (e.g. a Hugging Face
            ``BertTokenizer``).
        max_length: token length every encoded pair is padded/truncated to.
    """

    # The only columns read per item; NaN filtering is limited to these.
    REQUIRED_COLUMNS = ['id', 'sentence1', 'sentence2']

    def __init__(self, dataframe, tokenizer, max_length):
        # A plain dropna() would also discard rows because of unrelated empty
        # columns (e.g. an unlabeled 'score' column in a test file), silently
        # shrinking or emptying the dataset. Only drop rows we cannot encode.
        # The original index labels are kept so rows can be traced back to
        # the source frame.
        self.data = dataframe.dropna(subset=self.REQUIRED_COLUMNS)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Positional lookup, fetched once rather than three separate iloc calls.
        row = self.data.iloc[idx]

        inputs = self.tokenizer.encode_plus(
            row['sentence1'],
            row['sentence2'],
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # return_tensors='pt' adds a leading batch dim of size 1 per call;
        # squeeze it so the DataLoader can stack items into a batch itself.
        return {
            'id': row['id'],
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
        }

class BertSimilarityModel(torch.nn.Module):
    """BERT encoder followed by a linear head giving one similarity score per pair.

    NOTE(review): the notebook restores a trained instance with pickle.load,
    which rebuilds the object from its saved attributes without running
    __init__. The attribute names ``bert`` and ``linear`` read by ``forward``
    presumably must stay unchanged for existing checkpoints to keep loading.
    """

    def __init__(self, bert_model_name):
        """Load the encoder and attach a ``hidden_size -> 1`` linear head.

        Args:
            bert_model_name: despite the name, a path handed straight to
                ``torch.load``; the loaded object must expose
                ``config.hidden_size``.
        """
        super().__init__()
        # NOTE(review): recent torch releases default torch.load to
        # weights_only=True, which refuses whole pickled modules — confirm the
        # installed version (or pass weights_only=False for trusted files).
        encoder = torch.load(bert_model_name)
        self.bert = encoder
        self.linear = torch.nn.Linear(encoder.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Score a batch of tokenized sentence pairs.

        Returns:
            The linear head's output with dimension 1 squeezed away, i.e. one
            score per pair when the pooled output is (batch, hidden) —
            presumably the case for a BERT encoder; confirm for other encoders.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        head_out = self.linear(encoded.pooler_output)
        return head_out.squeeze(1)

test_data = pd.read_csv('Data/sample_test.csv', sep='\t')

# NOTE(security): pickle.load executes arbitrary code embedded in the file;
# only load checkpoints from a trusted source. Unpickling also needs the
# BertSimilarityModel class defined in this notebook to be in scope under the
# same name it had when the checkpoint was saved.
with open('checkpoints/model_1A.pkl', 'rb') as f:
    model = pickle.load(f)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

test_dataset = SimilarityDataset(test_data, tokenizer, max_length=128)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

model.eval()  # inference mode for layers such as dropout
predictions = []

with torch.no_grad():
    for batch in test_loader:
        outputs = model(batch['input_ids'], batch['attention_mask'])
        # tolist() works for any batch size, unlike .item(), which requires
        # exactly one element; each score is still rounded to 3 decimals.
        predictions.extend(round(score, 3) for score in outputs.tolist())

# The dataset can drop rows with missing values, so attach the scores to the
# rows it actually scored. Assigning to the raw test_data would raise a
# length-mismatch error whenever any row was dropped.
test_data = test_dataset.data.copy()
test_data['score'] = predictions

test_data = test_data[['id', 'score', 'sentence1', 'sentence2']]

test_data.to_csv('Data/sample_demo.csv', sep='\t', index=False)

print(test_data.head())

   id  score                             sentence1  \
0   1  4.923     A man with a hard hat is dancing.   
1   2  4.811      A young child is riding a horse.   
2   3  4.879  A man is feeding a mouse to a snake.   
3   4  2.047        A woman is playing the guitar.   
4   5  2.456         A woman is playing the flute.   

                                  sentence2  
0      A man wearing a hard hat is dancing.  
1                A child is riding a horse.  
2  The man is feeding a mouse to the snake.  
3                  A man is playing guitar.  
4                 A man is playing a flute.  
