In [1]:
import torch
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torch.utils.data._utils.collate import default_convert
import numpy as np
import pandas as pd
from itertools import chain
from tqdm.notebook import tqdm

from collections import namedtuple

from typing import Dict, Set, Callable, List, Union, Iterable

In [34]:
# Fetch the FB15K-237 archive (wget -nc: skip the download if fb15k237.zip already exists)
# and unpack it without overwriting files that were already extracted (unzip -n).
# The data splits end up under Release/ (train.txt, valid.txt, test.txt).
!wget 'https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip' -nc -O fb15k237.zip
!unzip -n fb15k237.zip

--2021-09-17 14:30:06--  https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip
Распознаётся download.microsoft.com (download.microsoft.com)… 104.73.92.109
Подключение к download.microsoft.com (download.microsoft.com)|104.73.92.109|:443... соединение установлено.
HTTP-запрос отправлен. Ожидание ответа… 200 OK
Длина: 146221215 (139M) [application/octet-stream]
Сохранение в: «fb15k237.zip»


2021-09-17 14:30:19 (10,6 MB/s) - «fb15k237.zip» сохранён [146221215/146221215]

Archive:  fb15k237.zip
  inflating: Release/MSR-LA_Data_Full Rights_FB15K-237 Knowledge Base Completion Dataset (2650).docx  
  inflating: Release/README.txt      
  inflating: Release/test.txt        
  inflating: Release/text_cvsc.txt   
  inflating: Release/text_emnlp.txt  
  inflating: Release/train.txt       
  inflating: Release/valid.txt       


In [2]:
# Type aliases: entities and relationships are both identified by plain strings.
Entity = str
# NOTE: the spelling "Relashionship" is kept as-is because the name is referenced
# by the class definitions further down in this notebook.
Relashionship = str
# Either kind of identifier (used where a function accepts entities or relationships).
Item = Union[Entity, Relashionship]

In [3]:
# A single knowledge-graph triple; the fields are intended to hold
# (Entity, Relashionship, Entity).
Link = namedtuple('Link', ['head', 'rel', 'tail'])
# Signature of a dissimilarity function: takes two tensors and returns a tensor.
# Intended shapes: [k], [k] -> [] (a scalar).
Dissimilarity = Callable[[Tensor, Tensor], torch.Tensor]

In [4]:
def normalize(t: Tensor) -> None:
    """Divide ``t`` in place by its norm (as computed by ``torch.norm``).

    The tensor is modified directly and nothing is returned.
    """
    t.div_(torch.norm(t))

In [5]:
def L2_dissimilarity(a: Tensor, b: Tensor) -> Tensor: # shape=[k], [k] -> []
    """Return the norm of ``a - b`` as a 0-d tensor (Euclidean distance for vectors)."""
    difference = a - b
    return difference.norm()

In [6]:
class FB15KDataset(Dataset):
    """Triples of one FB15K-style split, loaded from a tab-separated file.

    The file is read without a header row; its three columns are taken to be
    the head entity, the relationship and the tail entity of each link.
    """

    # All triples, one row per link, with columns 'Head', 'Rel', 'Tail'.
    data: pd.DataFrame
    # Distinct entities in order of first appearance (all heads, then tails).
    entities: List[Entity]
    # Distinct relationships in order of first appearance.
    relationships: List[Relashionship]

    def __init__(self, file: str):
        """Read the triples from ``file`` and collect the distinct entities/relationships."""
        self.data = pd.read_csv(file, sep='\t', names=['Head', 'Rel', 'Tail'])
        self.entities = list(pd.concat([self.data['Head'], self.data['Tail']]).unique())
        self.relationships = list(self.data['Rel'].unique())

    def __len__(self) -> int:
        """Number of triples in the split."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Link:
        """Return the ``idx``-th triple as a ``Link``.

        Negative indices count from the end, as for built-in sequences.

        Raises:
            IndexError: if ``idx`` is out of range. This is what the sequence
                protocol expects: plain ``for link in dataset`` iteration stops
                on it, whereas a ``StopIteration`` raised here would turn into
                a ``RuntimeError`` when the lookup happens inside a generator.
        """
        size = len(self)
        if not -size <= idx < size:
            raise IndexError(f"index {idx} is out of range for dataset of size {size}")
        # Positional lookup, so the result does not depend on the frame's index labels.
        return Link(*self.data.iloc[idx])

In [7]:
class TransE:
    """Translation-based embedding model for knowledge-graph links.

    Every entity and relationship of the training dataset gets a k-dimensional
    vector. A link ``(head, rel, tail)`` is scored by
    ``dissimilarity(head + rel, tail)``; lower values mean a more plausible link.
    Training minimises a margin loss between each training link and a randomly
    corrupted copy of it.
    """

    data: FB15KDataset
    dissimilarity: Dissimilarity

    ent_embeddings: Dict[Entity, Tensor] # shape=[k]
    rel_embeddings: Dict[Relashionship, Tensor] # shape=[k]

    # For unknown entities/rels
    zero_embedding: Tensor # shape=[k]

    @staticmethod
    def _initial_embedding(k: int) -> Tensor:
        """Create one trainable embedding of size ``k``.

        Components are drawn from uniform(-6/sqrt(k), 6/sqrt(k)), the vector is
        then scaled to unit norm and marked as requiring gradients.
        """
        krt = np.sqrt(k)
        ret = torch.rand(k) * (12/krt) - (6/krt) # uniform(-6/sqrt(k) : 6/sqrt(k))
        # Normalise before enabling gradients so the in-place division is not tracked.
        normalize(ret)
        ret.requires_grad_(True)
        return ret

    def _init_state(self, k: int) -> None:
        """(Re)initialise all embeddings for the entities/relationships of ``self.data``."""
        self.ent_embeddings = { e: self._initial_embedding(k) for e in self.data.entities }
        self.rel_embeddings = { l: self._initial_embedding(k) for l in self.data.relationships }
        self.zero_embedding = torch.zeros(k)

    def __init__(self, k: int, data: FB15KDataset, dissimilarity: Dissimilarity=L2_dissimilarity):
        """
        Args:
            k: embedding dimension.
            data: training dataset; its entities and relationships define the vocabulary.
            dissimilarity: function comparing ``head + rel`` with ``tail``.
        """
        self.data = data
        self.dissimilarity = dissimilarity
        self._init_state(k)

    def corrupt(self, x: Link) -> Link:
        """Corrupt the link by replacing either head or tail with a random entity"""
        head, rel, tail = x
        entities = self.data.entities
        # Draw a random position and index the list directly:
        # np.random.choice(list) would first copy the whole entity list into
        # an array on every call, and this method runs once per training link.
        new_e = entities[np.random.randint(len(entities))]
        # Replace head or tail with equal probability.
        if np.random.rand() < 0.5:
            return Link(new_e, rel, tail)
        return Link(head, rel, new_e)

    def link_dissimilarity(self, x: Link) -> Tensor: # shape=[]
        """Compute dissimilarity of the link under current embedding

        Entities or relationships without an embedding fall back to the zero vector.
        """
        head, rel, tail = x
        return self.dissimilarity(
            self.ent_embeddings.get(head, self.zero_embedding) +
            self.rel_embeddings.get(rel, self.zero_embedding),
            self.ent_embeddings.get(tail, self.zero_embedding)
        )

    def _element_loss(self, x: Link, corrupt_x: Link, margin: float) -> Tensor: # shape=[]
        """Margin loss for one pair: max(0, margin + d(x) - d(corrupt_x))."""
        return torch.clamp(
            margin +
            self.link_dissimilarity(x) -
            self.link_dissimilarity(corrupt_x),
            min=0
        )

    def _batch_loss(self, batch: List[Link], margin: float) -> Tensor: # shape=[]
        """Sum of the margin losses over ``batch``, one fresh corruption per link.

        All entity embeddings are rescaled to unit norm first; this happens
        under ``no_grad`` so the in-place update is not recorded by autograd.
        """
        with torch.no_grad():
            for e in self.ent_embeddings.values():
                normalize(e)
        losses = [
            self._element_loss(trp, self.corrupt(trp), margin)
            for trp in batch
        ]
        return torch.sum(torch.stack(losses))

    def fit(self, epoch: int, batch_size: int, margin: float, lr: float) -> None:
        """Train the embeddings with SGD.

        Args:
            epoch: number of passes over the training data.
            batch_size: number of links per optimisation step.
            margin: margin of the ranking loss.
            lr: SGD learning rate.
        """
        optim = torch.optim.SGD(
            chain(
                self.ent_embeddings.values(),
                self.rel_embeddings.values()
            ),
            lr
        )
        dl = DataLoader(
            self.data,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=default_convert # default_collate transposes the result
        )
        for i in range(epoch):
            # len(dl) is the exact number of batches; len(data)//batch_size + 1
            # overshoots by one when the size is a multiple of batch_size.
            for batch in tqdm(dl, total=len(dl), desc=f"Epoch {i+1}/{epoch}"):
                optim.zero_grad()
                self._batch_loss(batch, margin).backward()
                optim.step()

    def predict_link(self, target: Iterable[Item], construct: Callable[[Item], Link]) -> List[Item]:
        """
        Rank elements of `target` (usually self.data.entities or self.data.relationships)
        construct should take element of `target` and return a link corresponding to it
        Constructed links are ranked by dissimilarity, and their elements are returned
        (most plausible, i.e. lowest dissimilarity, first; ties keep `target` order)
        """
        # Only the scalar scores are needed, so skip building autograd graphs.
        with torch.no_grad():
            scored = [
                (t, self.link_dissimilarity(construct(t)).item())
                for t in target
            ]
        # Sort on the score alone (stable), never on the items themselves.
        scored.sort(key=lambda pair: pair[1])
        return [t for t, _ in scored]

    def predict_rel(self, head: Entity, tail: Entity) -> List[Relashionship]:
        """Rank all known relationships as candidates for ``(head, ?, tail)``."""
        return self.predict_link(self.data.relationships, lambda rel: Link(head, rel, tail))

    def predict_head(self, rel: Relashionship, tail: Entity) -> List[Entity]:
        """Rank all known entities as candidates for ``(?, rel, tail)``."""
        return self.predict_link(self.data.entities, lambda head: Link(head, rel, tail))

    def predict_tail(self, head: Entity, rel: Relashionship) -> List[Entity]:
        """Rank all known entities as candidates for ``(head, rel, ?)``."""
        return self.predict_link(self.data.entities, lambda tail: Link(head, rel, tail))

    @staticmethod
    def _rank(l: List[Item], i: Item) -> int:
        """1-based position of ``i`` in ``l``; ``len(l) + 1`` when ``i`` is absent."""
        try:
            return l.index(i) + 1
        except ValueError:
            return len(l) + 1

    def rank_rel(self, x: Link) -> int:
        """Rank of the true relationship of ``x`` among all known relationships."""
        (head, rel, tail) = x
        return self._rank(self.predict_rel(head, tail), rel)

    def rank_head(self, x: Link) -> int:
        """Rank of the true head of ``x`` among all known entities."""
        (head, rel, tail) = x
        return self._rank(self.predict_head(rel, tail), head)

    def rank_tail(self, x: Link) -> int:
        """Rank of the true tail of ``x`` among all known entities."""
        (head, rel, tail) = x
        return self._rank(self.predict_tail(head, rel), tail)

In [35]:
# Training split extracted from the downloaded archive.
fb15 = FB15KDataset(file='Release/train.txt')

In [11]:
# 20-dimensional embeddings trained on the training split (default L2 dissimilarity).
model = TransE(k=20, data=fb15)

In [12]:
# Train for 100 epochs with batches of 512 links, margin 2 and learning rate 0.01.
model.fit(epoch=100, batch_size=512, margin=2, lr=0.01)

HBox(children=(FloatProgress(value=0.0, description='Epoch 1/100', max=532.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 2/100', max=532.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 3/100', max=532.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 4/100', max=532.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 5/100', max=532.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 6/100', max=532.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 7/100', max=532.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 8/100', max=532.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 9/100', max=532.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Epoch 10/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 11/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 12/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 13/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 14/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 15/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 16/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 17/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 18/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 19/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 20/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 21/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 22/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 23/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 24/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 25/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 26/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 27/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 28/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 29/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 30/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 31/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 32/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 33/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 34/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 35/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 36/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 37/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 38/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 39/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 40/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 41/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 42/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 43/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 44/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 45/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 46/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 47/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 48/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 49/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 50/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 51/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 52/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 53/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 54/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 55/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 56/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 57/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 58/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 59/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 60/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 61/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 62/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 63/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 64/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 65/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 66/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 67/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 68/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 69/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 70/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 71/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 72/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 73/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 74/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 75/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 76/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 77/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 78/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 79/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 80/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 81/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 82/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 83/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 84/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 85/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 86/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 87/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 88/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 89/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 90/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 91/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 92/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 93/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 94/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 95/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 96/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 97/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 98/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 99/100', max=532.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Epoch 100/100', max=532.0, style=ProgressStyle(descriptio…




In [13]:
# Sanity check on a training example: rank of the true relationship of the
# first training triple among all known relationships (1 is the best rank).
model.rank_rel(fb15[0])

1

In [36]:
# Validation split, used below to measure the mean rank.
fb15_valid = FB15KDataset(file='Release/valid.txt')

In [15]:
# Mean rank of the true head over all validation triples. Candidates are the
# entities of the training split; other candidates are not filtered out, so
# this is a "raw" rank. A head unseen in training gets rank len(entities) + 1.
head_mean_rank = np.mean([model.rank_head(link) for link in tqdm(fb15_valid)])

HBox(children=(FloatProgress(value=0.0, max=17535.0), HTML(value='')))




In [16]:
# Mean rank of the true tail over all validation triples, computed the same
# way as for heads: candidates are the training entities, with no filtering.
tail_mean_rank = np.mean([model.rank_tail(link) for link in tqdm(fb15_valid)])

HBox(children=(FloatProgress(value=0.0, max=17535.0), HTML(value='')))




In [167]:
# Display the mean tail rank on the validation split.
tail_mean_rank

1008.6183062446536

In [17]:
# Overall mean rank: the average of the head-side and tail-side mean ranks.
print(f"Mean rank: {np.mean((head_mean_rank, tail_mean_rank))}")

Mean rank: 679.5085258055318


In [159]:
# Debugging check: iterate over the whole validation split and print any item
# that is None. No output means every item was a proper value.
for i in fb15_valid:
    if i is None:
        print(i)

In [144]:
# Peek at the last rows of the validation frame to confirm it was parsed to the end.
fb15_valid.data.tail()

Unnamed: 0,Head,Rel,Tail
17530,/m/02x4x18,/award/award_category/nominees./award/award_no...,/m/0dgst_d
17531,/m/0bw20,/film/film/other_crew./film/film_crew_gig/film...,/m/09vw2b7
17532,/m/01j4ls,/common/topic/webpage./common/webpage/category,/m/08mbj5d
17533,/m/0cmdwwg,/film/film/release_date_s./film/film_regional_...,/m/06t2t
17534,/m/0gs6vr,/film/actor/film./film/performance/film,/m/0gj96ln
