In [1]:
import pandas as pd

In [2]:
# Load the word-pair table; 'Unnamed: 0' is presumably the index column written
# by an earlier DataFrame.to_csv export, so it is discarded right after reading.
df = pd.read_csv('siamese_synonyms.csv').drop(columns=['Unnamed: 0'])

In [3]:
import fasttext
# Pretrained fastText binary; judging by the file name: 300-d vectors trained on
# lemmatized Russian Wikipedia + Lenta text (TODO confirm provenance).
path = 'ft_native_300_ru_wiki_lenta_lemmatize.bin'
fasttext_model = fasttext.load_model(path=path)



In [4]:
# Replace every word in both pair columns with its fastText vector.
# The isinstance guard keeps this cell idempotent: on a re-run the columns
# already hold vectors, and fasttext_model[...] expects a word, not a vector.
for column in ('Word_1', 'Word_2'):
    df[column] = df[column].map(
        lambda word: fasttext_model[word] if isinstance(word, str) else word
    )

In [5]:
import torch
from torch.utils.data import Dataset, DataLoader

class SynonimsDataset(Dataset):
    """Map-style dataset of word pairs labelled as synonyms or not.

    Each row of ``data_file`` must hold exactly three values,
    ``(word_1, word_2, is_synonims)``. The boolean-like flag is converted to
    the +1 / -1 target convention expected by ``torch.nn.CosineEmbeddingLoss``.

    Args:
        data_file: pandas DataFrame with three columns (despite the name, not a path).
        embeddings: optional lookup (e.g. a dict, array, or fastText model) indexed
            by each word; when ``None`` the row values are returned unchanged.
    """

    def __init__(self, data_file, embeddings=None):
        self.data = data_file
        self.embeddings = embeddings

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        word_1, word_2, is_synonims = self.data.iloc[idx]
        # Compare with None instead of relying on truthiness: numpy arrays and
        # tensors raise on bool(), and containers with __len__ == 0 are falsy.
        if self.embeddings is not None:
            word_1 = self.embeddings[word_1]
            word_2 = self.embeddings[word_2]
        # CosineEmbeddingLoss targets: 1 = similar pair, -1 = dissimilar pair.
        label = 1 if is_synonims else -1
        return word_1, word_2, label

In [6]:
from sklearn.model_selection import train_test_split
# Fixed seed so the train/val/test split is reproducible across kernel restarts.
RANDOM_STATE = 42
BATCH_SIZE = 128

# Two successive default splits (test_size=0.25 each): test is 25% of the data,
# val is 25% of the remainder. The *_df suffix matters: `train` is redefined
# later as the training function, so a DataFrame named `train` would make this
# notebook depend on cell execution order.
train_val_df, test_df = train_test_split(df, random_state=RANDOM_STATE)
train_df, val_df = train_test_split(train_val_df, random_state=RANDOM_STATE)

# Only the training loader reshuffles every epoch; evaluation order is irrelevant.
train_dataloader = DataLoader(SynonimsDataset(train_df), batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(SynonimsDataset(val_df), batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(SynonimsDataset(test_df), batch_size=BATCH_SIZE, shuffle=False)

In [7]:
from model import BaseSiamese
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

# Fall back to CPU so the notebook still runs on machines without a CUDA device.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One encoder is applied to both words of a pair (see train/evaluate).
# 300 is presumably the input size; it matches the "300" in the fastText file name.
model = BaseSiamese(300)
model.to(DEVICE)

# Pair loss over the two embeddings with +1 / -1 targets.
loss_fn = torch.nn.CosineEmbeddingLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-5)

def train(device, model, num_epochs, train_dataloader, val_dataloader, loss_fn=loss_fn, optimizer=optimizer):
    """Train a siamese encoder and report validation metrics after every epoch.

    Both words of a pair pass through the same ``model``, and ``loss_fn`` is
    called as ``loss_fn(emb_1, emb_2, label)`` (e.g. CosineEmbeddingLoss).

    Args:
        device: torch device that every batch tensor is moved to.
        model: encoder applied to each word batch; updated in place.
        num_epochs: number of full passes over ``train_dataloader``.
        train_dataloader: yields ``(word_1, word_2, label)`` tensor batches.
        val_dataloader: same format; scored with ``evaluate`` after each epoch.
        loss_fn: pair loss used for training and validation
            (defaults to the notebook-level ``loss_fn``).
        optimizer: optimizer over ``model.parameters()``
            (defaults to the notebook-level ``optimizer``).
    """
    for epoch in range(num_epochs):
        # evaluate() puts the model into eval mode, so train mode must be
        # re-enabled at the start of every epoch, not only once before the loop.
        model.train()
        tr_loss = 0.0
        for batch in tqdm(train_dataloader):
            word_1, word_2, label = (t.to(device) for t in batch)
            loss = loss_fn(model(word_1), model(word_2), label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            tr_loss += loss.item()

        # Validate with the same loss the model is trained with.
        val_loss, val_acc = evaluate(device, model, val_dataloader, loss_fn)
        # Reported loss is the average of per-batch mean losses.
        print(f"epoch {epoch}, loss: {tr_loss / len(train_dataloader)}")
        print(f"valid loss: {val_loss}, valid accuracy {val_acc}")

def evaluate(device, model, val_dataloader, loss_fn=loss_fn, threshold=0.5):
    """Compute the mean pair loss and thresholded cosine-similarity accuracy.

    A pair is predicted to be synonyms when the cosine similarity of its two
    embeddings is greater than ``threshold``. It counts as a true synonym pair
    when its label equals 1. The model is left in eval mode.

    Args:
        device: torch device that every batch tensor is moved to.
        model: encoder applied to both words of each pair.
        val_dataloader: yields ``(word_1, word_2, label)`` tensor batches.
        loss_fn: pair loss called as ``loss_fn(emb_1, emb_2, label)``.
        threshold: cosine-similarity cut-off for a positive prediction.

    Returns:
        ``(average of per-batch mean losses, accuracy)``.
    """
    model.eval()
    cos = torch.nn.CosineSimilarity(dim=1)
    eval_loss = 0.0
    predicted_labels = []
    correct_labels = []
    # Nothing is back-propagated here, so skip building the autograd graph.
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            word_1, word_2, label = (t.to(device) for t in batch)
            word_1_processed = model(word_1)
            word_2_processed = model(word_2)
            eval_loss += loss_fn(word_1_processed, word_2_processed, label).item()
            sim = cos(word_1_processed, word_2_processed)
            # .tolist() yields plain Python bools rather than 0-d tensors.
            predicted_labels.extend((sim > threshold).cpu().tolist())
            correct_labels.extend((label == 1).cpu().tolist())

    # accuracy_score signature is (y_true, y_pred).
    return eval_loss / len(val_dataloader), accuracy_score(correct_labels, predicted_labels)

In [8]:
# Train for 10 epochs; per-epoch train/validation metrics are printed below.
train(
    DEVICE,
    model,
    num_epochs=10,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=loss_fn,
    optimizer=optimizer,
)

  0%|          | 0/2438 [00:00<?, ?it/s]

  0%|          | 0/813 [00:00<?, ?it/s]

epoch 0, loss: 0.25185492804688053
valid loss: 0.2390517454965528, valid accuracy 0.821225235622235


  0%|          | 0/2438 [00:00<?, ?it/s]

  0%|          | 0/813 [00:00<?, ?it/s]

epoch 1, loss: 0.23927350608758793
valid loss: 0.23770079161481517, valid accuracy 0.8223600692440854


  0%|          | 0/2438 [00:00<?, ?it/s]

  0%|          | 0/813 [00:00<?, ?it/s]

epoch 2, loss: 0.2384187940071116
valid loss: 0.23723934253570045, valid accuracy 0.8225427966916715


  0%|          | 0/2438 [00:00<?, ?it/s]

  0%|          | 0/813 [00:00<?, ?it/s]

epoch 3, loss: 0.23795537847392964
valid loss: 0.23701629588818052, valid accuracy 0.8226678207347566


  0%|          | 0/2438 [00:00<?, ?it/s]

  0%|          | 0/813 [00:00<?, ?it/s]

epoch 4, loss: 0.2376587052903672
valid loss: 0.23680795837299118, valid accuracy 0.8230909790344297


  0%|          | 0/2438 [00:00<?, ?it/s]

  0%|          | 0/813 [00:00<?, ?it/s]

epoch 5, loss: 0.23742746231104137
valid loss: 0.23666376548133067, valid accuracy 0.8234564339296019


  0%|          | 0/2438 [00:00<?, ?it/s]

  0%|          | 0/813 [00:00<?, ?it/s]

epoch 6, loss: 0.23727606615881336
valid loss: 0.23662392368029259, valid accuracy 0.8231102135025966


  0%|          | 0/2438 [00:00<?, ?it/s]

  0%|          | 0/813 [00:00<?, ?it/s]

epoch 7, loss: 0.23716251541714475
valid loss: 0.23648477759015224, valid accuracy 0.8234756683977688


  0%|          | 0/2438 [00:00<?, ?it/s]

  0%|          | 0/813 [00:00<?, ?it/s]

epoch 8, loss: 0.23707268909047724
valid loss: 0.23641458485472538, valid accuracy 0.8237257164839392


  0%|          | 0/2438 [00:00<?, ?it/s]

  0%|          | 0/813 [00:00<?, ?it/s]

epoch 9, loss: 0.2370064547063876
valid loss: 0.23637510955260396, valid accuracy 0.8235141373341027


In [10]:
# Final score on the held-out test split. The loss goes into a named variable
# rather than `_`, which would shadow IPython's output-history variable.
test_loss, accuracy = evaluate(DEVICE, model, test_dataloader, loss_fn)
accuracy

  0%|          | 0/1084 [00:00<?, ?it/s]

0.8236295441431044