# Create data for Baseline Run

## Download the data 
The dataset class below is copied from the PyTorch Geometric source: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/dbp15k.html#DBP15K

In [None]:
import os
import os.path as osp
import shutil
from typing import Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_zip,
)
from torch_geometric.io import read_txt_array
from torch_geometric.utils import sort_edge_index


class DBP15K(InMemoryDataset):
    r"""The DBP15K dataset from the
    `"Cross-lingual Entity Alignment via Joint Attribute-Preserving Embedding"
    <https://arxiv.org/abs/1708.05045>`_ paper, where Chinese, Japanese and
    French versions of DBpedia were linked to its English version.
    Node features are given by pre-trained and aligned monolingual word
    embeddings from the `"Cross-lingual Knowledge Graph Alignment via Graph
    Matching Neural Network" <https://arxiv.org/abs/1905.11605>`_ paper.

    Args:
        root (string): Root directory where the dataset should be saved.
        pair (string): The pair of languages (:obj:`"en_zh"`, :obj:`"en_fr"`,
            :obj:`"en_ja"`, :obj:`"zh_en"`, :obj:`"fr_en"`, :obj:`"ja_en"`).
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """
    # Download URL template; ``download`` fills in ``file_id``.
    url = 'https://docs.google.com/uc?export=download&id={}&confirm=t'
    file_id = '1ggYlYf2_kTyi7oF9g07oTNn3VDhjl7so'

    def __init__(self, root: str, pair: str,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None):
        assert pair in ['en_zh', 'en_fr', 'en_ja', 'zh_en', 'fr_en', 'ja_en']
        # ``processed_file_names`` reads ``self.pair``, so it is assigned
        # before the base-class constructor is invoked.
        self.pair = pair
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        """Names of the per-language-pair entries expected in ``raw_dir``."""
        return ['en_zh', 'en_fr', 'en_ja', 'zh_en', 'fr_en', 'ja_en']

    @property
    def processed_file_names(self) -> str:
        """Name of the processed file for the selected language pair."""
        return f'{self.pair}.pt'

    def download(self):
        """Fetch the archive, unpack it into ``root`` and install it as
        ``raw_dir``.

        The downloaded archive is deleted after extraction, any existing
        ``raw_dir`` is removed, and the extracted ``DBP15K`` folder is renamed
        to ``raw_dir``.
        """
        path = download_url(self.url.format(self.file_id), self.root)
        extract_zip(path, self.root)
        os.unlink(path)
        shutil.rmtree(self.raw_dir)
        os.rename(osp.join(self.root, 'DBP15K'), self.raw_dir)

    def process(self):
        """Build the two-graph ``Data`` object for ``self.pair`` and save it.

        Reads the word vectors from ``sub.glove.300d``, converts both
        knowledge graphs with :meth:`process_graph`, remaps the train/test
        alignment pairs with :meth:`process_y`, and writes the collated result
        to ``processed_paths[0]``.
        """
        embs = {}
        with open(osp.join(self.raw_dir, 'sub.glove.300d'), 'r') as f:
            for line in f:
                info = line.strip().split(' ')
                if len(info) > 300:
                    # More than 300 fields: the first one is the token, the
                    # rest are its vector components.
                    embs[info[0]] = torch.tensor([float(x) for x in info[1:]])
                else:
                    # A line without a leading token is stored as the fallback
                    # vector used for unknown words.
                    embs['**UNK**'] = torch.tensor([float(x) for x in info])

        g1_path = osp.join(self.raw_dir, self.pair, 'triples_1')
        x1_path = osp.join(self.raw_dir, self.pair, 'id_features_1')
        g2_path = osp.join(self.raw_dir, self.pair, 'triples_2')
        x2_path = osp.join(self.raw_dir, self.pair, 'id_features_2')

        x1, edge_index1, rel1, assoc1 = self.process_graph(
            g1_path, x1_path, embs)
        x2, edge_index2, rel2, assoc2 = self.process_graph(
            g2_path, x2_path, embs)

        train_path = osp.join(self.raw_dir, self.pair, 'train.examples.20')
        train_y = self.process_y(train_path, assoc1, assoc2)

        test_path = osp.join(self.raw_dir, self.pair, 'test.examples.1000')
        test_y = self.process_y(test_path, assoc1, assoc2)

        data = Data(x1=x1, edge_index1=edge_index1, rel1=rel1, x2=x2,
                    edge_index2=edge_index2, rel2=rel2, train_y=train_y,
                    test_y=test_y)
        torch.save(self.collate([data]), self.processed_paths[0])

    def process_graph(
        self,
        triple_path: str,
        feature_path: str,
        embeddings: Dict[str, Tensor],
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Convert one knowledge graph from its raw text files into tensors.

        Args:
            triple_path: Tab-separated file of integer triples, read as the
                columns subject, relation, object.
            feature_path: File with one ``<id>\\t<text>`` line per entity. A
                line that does not split into exactly two fields gets the
                text ``'**UNK**'`` appended.
            embeddings: Token -> vector lookup; must contain ``'**UNK**'``,
                which is used for tokens that are not in the lookup.

        Returns:
            ``(x, edge_index, rel, assoc)``: the per-entity token embeddings
            padded to a common sequence length, the edge index sorted
            together with its relation ids, and ``assoc``, which maps a raw
            entity id to its row in ``x`` (``-1`` for ids that have no line
            in the feature file).
        """
        triples = read_txt_array(triple_path, sep='\t', dtype=torch.long)
        subj, rel, obj = triples.t()

        x_dict = {}
        with open(feature_path, 'r') as f:
            for line in f:
                info = line.strip().split('\t')
                info = info if len(info) == 2 else info + ['**UNK**']
                seq = info[1].lower().split()
                hs = [embeddings.get(w, embeddings['**UNK**']) for w in seq]
                x_dict[int(info[0])] = torch.stack(hs, dim=0)

        # Raw ids may be sparse; ``assoc`` compacts them to 0..N-1 in the
        # order in which they appear in the feature file.
        idx = torch.tensor(list(x_dict.keys()))
        assoc = torch.full((idx.max().item() + 1, ), -1, dtype=torch.long)
        assoc[idx] = torch.arange(idx.size(0))

        subj, obj = assoc[subj], assoc[obj]
        edge_index = torch.stack([subj, obj], dim=0)
        edge_index, rel = sort_edge_index(edge_index, rel)

        xs = [None for _ in range(idx.size(0))]
        for raw_id, tokens in x_dict.items():
            xs[assoc[raw_id]] = tokens
        x = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True)

        return x, edge_index, rel, assoc

    def process_y(self, path: str, assoc1: Tensor, assoc2: Tensor) -> Tensor:
        """Read an alignment file and remap its entity ids to node indices.

        The file is read as three tab-separated integer columns: the entity
        id in graph 1, the entity id in graph 2, and a flag. Only rows with a
        non-zero flag are kept.

        Returns:
            A tensor with two rows: the remapped graph-1 indices and the
            remapped graph-2 indices of the kept pairs.
        """
        row, col, mask = read_txt_array(path, sep='\t', dtype=torch.long).t()
        mask = mask.to(torch.bool)
        return torch.stack([assoc1[row[mask]], assoc2[col[mask]]], dim=0)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.pair})'


## Process the data

In [None]:
import torch
import os

In [None]:
from torch_geometric.data import Data

## ResMap

In [None]:
# Read both feature files completely into lists of lines. Unlike an open file
# handle, a list can be iterated any number of times by the cells below, and
# the ``with`` blocks close the files immediately.
with open("db/raw/en_fr/id_features_1", encoding="utf-8") as fh:
    id_fts1 = fh.readlines()
with open("db/raw/en_fr/id_features_2", encoding="utf-8") as fh:
    id_fts2 = fh.readlines()

In [None]:
# Lookup tables (one per graph) that the next cell fills with
# raw entity id -> feature string.
res_pos_map_1, res_pos_map_2 = {}, {}

In [None]:
# Fill the id -> feature-string tables from lines of the form "<id>\t<text>".
# Only a trailing newline is removed from the text; cutting the last character
# unconditionally would lose a real character on a final line that has no
# newline.
for line in id_fts1:
    fields = line.split('\t')
    res_pos_map_1[int(fields[0])] = fields[1].rstrip('\n')
for line in id_fts2:
    fields = line.split('\t')
    res_pos_map_2[int(fields[0])] = fields[1].rstrip('\n')

In [None]:
# Map every raw entity id of graph 2 to its 0-based position in the feature
# file. The ids come from ``res_pos_map_2``, whose keys were inserted in file
# order, rather than from a second pass over ``id_fts2``: that way the mapping
# does not depend on the input still being readable after the earlier loop has
# consumed it. For unique ids the position equals the line index.
uebersetzung_2_umgedreht = {
    entity_id: position for position, entity_id in enumerate(res_pos_map_2)
}

In [None]:
# Map every raw entity id of graph 1 to its 0-based position in the feature
# file. The ids come from ``res_pos_map_1``, whose keys were inserted in file
# order, rather than from a second pass over ``id_fts1``: that way the mapping
# does not depend on the input still being readable after the earlier loop has
# consumed it. For unique ids the position equals the line index.
uebersetzung_1_umgedreht = {
    entity_id: position for position, entity_id in enumerate(res_pos_map_1)
}

In [None]:
# Re-key the graph-2 feature strings by position index. The two dicts are
# paired positionally, i.e. by their insertion order.
new_2_res_map = {
    position: text
    for position, text in zip(uebersetzung_2_umgedreht.values(),
                              res_pos_map_2.values())
}

In [None]:
# Re-key the graph-1 feature strings by position index. The two dicts are
# paired positionally, i.e. by their insertion order.
new_1_res_map = {
    position: text
    for position, text in zip(uebersetzung_1_umgedreht.values(),
                              res_pos_map_1.values())
}

## Edges

In [None]:
# Load the triple files as lists of lines so the file handles are closed
# right away instead of staying open for the rest of the session.
with open("db/raw/en_fr/triples_2", encoding="utf-8") as fh:
    edges_2 = fh.readlines()
with open("db/raw/en_fr/triples_1", encoding="utf-8") as fh:
    edges_1 = fh.readlines()

In [None]:
# Accumulators for the [subject, object] pairs of each graph.
edges_2_list, edges_1_list = [], []

In [None]:
# Keep the subject (column 0) and object (column 2) of every triple; the
# middle column is not used. ``int()`` ignores the trailing newline itself, so
# nothing is sliced off: dropping the last character unconditionally would
# truncate the object id on a final line without a newline.
for line in edges_2:
    fields = line.split('\t')
    edges_2_list.append([int(fields[0]), int(fields[2])])

In [None]:
# Keep the subject (column 0) and object (column 2) of every triple; the
# middle column is not used. ``int()`` ignores the trailing newline itself, so
# nothing is sliced off: dropping the last character unconditionally would
# truncate the object id on a final line without a newline.
for line in edges_1:
    fields = line.split('\t')
    edges_1_list.append([int(fields[0]), int(fields[2])])

In [None]:
# Translate both endpoints of every graph-2 edge from raw entity ids to
# position indices.
new_edges_2_list = [
    [uebersetzung_2_umgedreht[source], uebersetzung_2_umgedreht[target]]
    for source, target in edges_2_list
]

In [None]:
# Translate both endpoints of every graph-1 edge from raw entity ids to
# position indices.
new_edges_1_list = [
    [uebersetzung_1_umgedreht[source], uebersetzung_1_umgedreht[target]]
    for source, target in edges_1_list
]

## Alignment

In [None]:
# Load the training alignment file as a list of lines; the ``with`` block
# closes the file handle immediately.
with open("db/raw/en_fr/train.ref", encoding="utf-8") as fh:
    train_ref = fh.readlines()

In [None]:
# Each line holds "<id in graph 1>\t<id in graph 2>". Both ids are translated
# to position indices. ``int()`` tolerates the trailing newline, so the second
# field is not sliced (slicing would cut a digit off a final line that has no
# newline).
training_alignment = []
for line in train_ref:
    fields = line.split('\t')
    training_alignment.append([uebersetzung_1_umgedreht[int(fields[0])],
                               uebersetzung_2_umgedreht[int(fields[1])]])
        
        

In [None]:
# Each line holds "<id in graph 1>\t<id in graph 2>". Both ids are translated
# to position indices. The file is closed when the ``with`` block ends, and
# ``int()`` tolerates the trailing newline, so the second field is not sliced.
testing_alignment = []
with open("db/raw/en_fr/test.ref", encoding="utf-8") as test:
    for line in test:
        fields = line.split('\t')
        testing_alignment.append([uebersetzung_1_umgedreht[int(fields[0])],
                                  uebersetzung_2_umgedreht[int(fields[1])]])

In [None]:
# Display the [graph-1 index, graph-2 index] pairs for a visual check.
testing_alignment

[[4500, 4500],
 [4501, 4501],
 [4502, 4502],
 [4503, 4503],
 [4504, 4504],
 [4505, 4505],
 [4506, 4506],
 [4507, 4507],
 [4508, 4508],
 [4509, 4509],
 [4510, 4510],
 [4511, 4511],
 [4512, 4512],
 [4513, 4513],
 [4514, 4514],
 [4515, 4515],
 [4516, 4516],
 [4517, 4517],
 [4518, 4518],
 [4519, 4519],
 [4520, 4520],
 [4521, 4521],
 [4522, 4522],
 [4523, 4523],
 [4524, 4524],
 [4525, 4525],
 [4526, 4526],
 [4527, 4527],
 [4528, 4528],
 [4529, 4529],
 [4530, 4530],
 [4531, 4531],
 [4532, 4532],
 [4533, 4533],
 [4534, 4534],
 [4535, 4535],
 [4536, 4536],
 [4537, 4537],
 [4538, 4538],
 [4539, 4539],
 [4540, 4540],
 [4541, 4541],
 [4542, 4542],
 [4543, 4543],
 [4544, 4544],
 [4545, 4545],
 [4546, 4546],
 [4547, 4547],
 [4548, 4548],
 [4549, 4549],
 [4550, 4550],
 [4551, 4551],
 [4552, 4552],
 [4553, 4553],
 [4554, 4554],
 [4555, 4555],
 [4556, 4556],
 [4557, 4557],
 [4558, 4558],
 [4559, 4559],
 [4560, 4560],
 [4561, 4561],
 [4562, 4562],
 [4563, 4563],
 [4564, 4564],
 [4565, 4565],
 [4566, 45

In [None]:
from sentence_transformers import SentenceTransformer

In [None]:
# Instantiate the sentence encoder identified by this model name. The recorded
# cell output shows files being downloaded when this cell ran.
model = SentenceTransformer('distiluse-base-multilingual-cased-v1')

Downloading:   0%|          | 0.00/690 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/114 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.58M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.38k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/556 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/539M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/452 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/996k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/341 [00:00<?, ?B/s]

In [None]:
# Encode the graph-1 feature strings in the dict's iteration order and wrap
# the encoder output in a tensor.
attributes = new_1_res_map.values()
embeddings = model.encode(list(attributes))
fertige_embeddings = torch.tensor(embeddings)

In [None]:
# Encode the graph-2 feature strings in the dict's iteration order and wrap
# the encoder output in a tensor.
attributes = new_2_res_map.values()
embeddings = model.encode(list(attributes))
fertige_embeddings2 = torch.tensor(embeddings)

In [None]:
# Graph-1 side (first element) of every test alignment pair.
testing_alignment_left = torch.tensor(
    [left for left, _ in testing_alignment]
)

In [None]:
# Graph-2 side (second element) of every test alignment pair.
testing_alignment_right = torch.tensor(
    [right for _, right in testing_alignment]
)

In [None]:
# Split the training alignment pairs into their graph-1 (first element) and
# graph-2 (second element) sides.
training_alignment_left = torch.tensor(
    [left for left, _ in training_alignment]
)
training_alignment_right = torch.tensor(
    [right for _, right in training_alignment]
)

In [None]:
# Graph-1 edges as an int64 tensor with one [source, target] row per edge ...
new_edges_1_tens = torch.tensor(new_edges_1_list, dtype=torch.long)
# ... then transposed and stored contiguously.
new_edges_1_cont = new_edges_1_tens.t().contiguous()

In [None]:
# Graph-2 edges as an int64 tensor with one [source, target] row per edge ...
new_edges_2_tens = torch.tensor(new_edges_2_list, dtype=torch.long)
# ... then transposed and stored contiguously.
new_edges_2_cont = new_edges_2_tens.t().contiguous()

In [None]:
# Bundle both graphs and the alignment splits into a single Data object.
dbp15k = Data(
    # Graph 1: sentence embeddings and edges.
    x_one=fertige_embeddings,
    edge_index_one=new_edges_1_cont,
    # Graph 2: sentence embeddings and edges.
    x_two=fertige_embeddings2,
    edge_index_two=new_edges_2_cont,
    # Alignment pairs, stored as separate left/right index tensors.
    train_set_left=training_alignment_left,
    train_set_right=training_alignment_right,
    test_set_left=testing_alignment_left,
    test_set_right=testing_alignment_right,
)

In [None]:
# Create the target folder first: torch.save does not create missing parent
# directories and would fail on a fresh checkout.
output_path = "../2_preprocessed/dbp15k/dbp15k_estis.pt"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
torch.save(dbp15k, output_path)