In [103]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm

In [19]:
pip install -U sentence-transformers

Collecting sentence-transformers
  Using cached sentence-transformers-2.2.1.tar.gz (84 kB)
Collecting transformers<5.0.0,>=4.6.0
  Using cached transformers-4.20.1-py3-none-any.whl (4.4 MB)
Collecting torchvision
  Downloading torchvision-0.12.0-cp39-cp39-manylinux1_x86_64.whl (21.0 MB)
[K     |████████████████████████████████| 21.0 MB 554 kB/s eta 0:00:01    |███████                         | 4.5 MB 772 kB/s eta 0:00:22     |██████████████████▋             | 12.2 MB 574 kB/s eta 0:00:16     |██████████████████████          | 14.4 MB 587 kB/s eta 0:00:12
Collecting sentencepiece
  Downloading sentencepiece-0.1.96-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[K     |████████████████████████████████| 1.2 MB 351 kB/s eta 0:00:01
[?25hCollecting huggingface-hub>=0.8.1
  Downloading huggingface_hub-0.8.1-py3-none-any.whl (101 kB)
[K     |████████████████████████████████| 101 kB 586 kB/s ta 0:00:01
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers

In [73]:
from sentence_transformers import SentenceTransformer
hugface_mdl = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v1')

In [90]:
mentions_df = pd.read_csv('mentions.csv', dtype=str)
print(f"Len(mentions_df)={len(mentions_df)}")
mentions_df.head(2)



Len(mentions_df)=3915


Unnamed: 0,idx,left_context,link_title,link_text,right_context,url,mention_in_page
0,0,стил е съвременният международно признат,светски,,календар на който се основава,,Григориански календар
1,1,е съвременният международно признат светски,календар,,на който се основава и,,Григориански календар


In [94]:
mentions_df['link_repr'] = mentions_df.apply(lambda x: x['link_title'] if x['link_text']is None else x['link_title'], axis=1)
mentions_df.head(2)

Unnamed: 0,idx,left_context,link_title,link_text,right_context,url,mention_in_page,link_repr
0,0,стил е съвременният международно признат,светски,,календар на който се основава,,Григориански календар,светски
1,1,е съвременният международно признат светски,календар,,на който се основава и,,Григориански календар,календар


In [93]:
entities_df = pd.read_csv('entities.csv')
print(f"Len(entities_df)={len(entities_df)}")
entities_df.head(2)

Len(entities_df)=201


Unnamed: 0,idx,title,text,url
0,0,Григориански календар,'Григорианският календар (понякога наричан и Г...,https://bg.wikipedia.org/wiki/%D0%93%D1%80%D0%...
1,1,GNU General Public License,GNU General Public License (на български преве...,https://bg.wikipedia.org/wiki/GNU_General_Publ...


In [95]:
# Merge mentions and entities (inner merge)
merge_df = mentions_df.merge(entities_df, \
                           left_on='link_title', \
                           right_on='title', \
                           how='inner',
                           suffixes=['_mention', '_entitity'])
print('Eligible mentions: ', len(merge_df))
merge_df.head(2)

Eligible mentions:  220


Unnamed: 0,idx_mention,left_context,link_title,link_text,right_context,url_mention,mention_in_page,link_repr,idx_entitity,title,text,url_entitity
0,6,е въведен в употреба на,4 октомври,,1582 г в съответствие с,,Григориански календар,4 октомври,62,4 октомври,4 октомври е 277-ият ден в годината според гри...,https://bg.wikipedia.org/wiki/4_%D0%BE%D0%BA%D...
1,205,са следните За събитията до,4 октомври,,1582 г включително има само,,Приемане на григорианския календар,4 октомври,62,4 октомври,4 октомври е 277-ият ден в годината според гри...,https://bg.wikipedia.org/wiki/4_%D0%BE%D0%BA%D...


In [98]:
class MentionEntityDataset(Dataset):
    def __init__(self, hugface_mdl, merge_df):
        self.mention_vecs = hugface_mdl.encode(pd.array( \
            merge_df['left_context'] +' ' +\
            merge_df['link_repr'] + ' ' + \
            merge_df['right_context']))
        self.entities_vec = hugface_mdl.encode(pd.array( \
            merge_df['title'] +' ' +\
            merge_df['text']))
        
        assert(len(self.mention_vecs) == len(self.entities_vec))

    def __len__(self):
        return len(self.mention_vecs)

    def __getitem__(self, idx):
        return self.mention_vecs[idx], self.entities_vec[idx]

In [99]:
dataset = MentionEntityDataset(hugface_mdl, merge_df)
len(dataset)

220

In [26]:
class MentionToEntityNet(nn.Module):
    def __init__(self, in_size=300, out_size=300):
        super(MentionToEntityNet, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_size, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, out_size),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

In [27]:
model = MentionToEntityNet()
print(model)

MentionToEntityNet(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=300, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=300, bias=True)
  )
)


In [105]:
# Training cycle
MAX_EPOCHS = 20
BATCH_SIZE = 32
DISPLAY_STEP = 2
LEARNING_RATE = 0.01
def train(model, dataset):
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    loss_fct = torch.nn.CrossEntropyLoss() # Cross Entropy
    model.train()
    for epoch in range(1, MAX_EPOCHS+1):
        losses = []
        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE,
                        shuffle=True, drop_last=False)
        for x, y in tqdm(dataloader):
            optimizer.zero_grad()
            logits = model(x)
            loss = loss_fct(logits, y)  
            loss.backward()
            loss_value = loss.item()
            losses.append(loss_value)
        
        train_loss_value = np.mean(losses)
        
        # Display logs per each DISPLAY_STEP
        if (epoch) % DISPLAY_STEP == 0:
            print("Epoch: {:04d} loss={:.9f} ".format(epoch, train_loss_value))
        
    

train(model, dataset)
print ("Optimization Finished!")

  0%|          | 0/7 [00:00<?, ?it/s]

IndexError: too many indices for tensor of dimension 2