In [13]:
import os
import sys

from rdflib import Graph, BNode, URIRef
from om.ont import get_n, tokenize
import itertools
import time
from tqdm.auto import tqdm
import re
import torch
from pymagnitude import Magnitude

In [4]:


def batched(iterable, n):
    """Yield successive tuples of at most ``n`` items taken from ``iterable``.

    The last tuple may be shorter than ``n``; nothing is yielded for an empty
    input. Because this is a generator, the ``ValueError`` for ``n < 1`` is
    raised when the first item is requested, not when the function is called.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    # iter(callable, sentinel) keeps calling the lambda until it returns the
    # empty tuple, i.e. until the underlying iterator is exhausted.
    for chunk in iter(lambda: tuple(itertools.islice(source, n)), ()):
        yield chunk

In [5]:
# Root directory that is walked recursively for the '.ttl' ontology files to embed.
base_path = '/projets/melodi/gsantoss/data/oaei/tracks/populated/data_100'
# Directory that receives one '<ontology>-<model>' embedding file per ontology and model.
out_base = '/projets/melodi/gsantoss/canarde/canard_emb'

In [6]:
# (path to the .magnitude vector file, short model name used in output file names)
models = [
    ('/projets/melodi/gsantoss/canarde/ncembs/fasttext.magnitude', 'fasttext'),
    ('/projets/melodi/gsantoss/canarde/ncembs/glove.magnitude', 'glove'),
    ('/projets/melodi/gsantoss/canarde/ncembs/word2vec.magnitude', 'word2vec'),
]

In [20]:

# For every embedding model and every '.ttl' file under base_path, embed each
# non-blank RDF term (subject, predicate or object) as the mean of its token
# vectors and write the result to '<out_base>/<ontology>-<model>'.
#
# Output file layout:
#   line 1            number of embedded terms N
#   next N lines      one term key per line (newlines inside a term collapsed)
#   next N lines      one vector per line, components separated by spaces

# Create the output directory up front so the first open() cannot fail on it.
os.makedirs(out_base, exist_ok=True)

for md, mn in tqdm(models):

    vectors = Magnitude(md)

    for p, _, fs in os.walk(base_path):
        for fname in tqdm(fs):
            if not fname.endswith('.ttl'):
                continue

            # The ontology name is the part of the file name before the first '_'.
            ont_name = fname.split('_')[0]
            g = Graph().parse(os.path.join(p, fname))

            subs = set(g.subjects())
            props = set(g.predicates())
            objs = set(g.objects())

            ks = []
            sents = []

            for s in subs.union(props, objs):

                # Blank nodes carry no label worth embedding.
                if isinstance(s, BNode):
                    continue

                if s.startswith('http://'):
                    # Embed the tokenized fragment after the last '#'.
                    txt = ' '.join(tokenize(s.split('#')[-1]))
                else:
                    txt = s

                # Keys are written one per line, so embedded newlines are collapsed.
                ks.append(re.sub(r'\n+', ' ', s))
                sents.append(txt)

            # Lower-cased, whitespace-split tokens for each term.
            vcs = [[tok.lower() for tok in x.split()] for x in sents]

            embs = []

            for tks in tqdm(vcs):
                if tks:
                    t = torch.from_numpy(vectors.query(tks))
                    embs.append(torch.mean(t, dim=0))
                else:
                    # A term with no tokens (e.g. an empty literal) has no
                    # defined mean; use a zero vector so keys and vectors stay
                    # aligned instead of producing NaN or a shape error.
                    embs.append(torch.zeros(vectors.dim))

            # torch.stack raises on an empty list, so a graph without any
            # embeddable term yields an explicit (0, dim) tensor.
            embs = torch.stack(embs) if embs else torch.empty((0, vectors.dim))
            embl = embs.tolist()
            eln = [' '.join(str(v) for v in l) for l in embl]

            # Explicit UTF-8: IRIs and literals may contain non-ASCII characters
            # and the platform default encoding is not guaranteed to handle them.
            with open(os.path.join(out_base, f'{ont_name}-{mn}'), 'w', encoding='utf-8') as out:
                out.write(f'{len(embs)}\n')

                out.writelines(f'{k}\n' for k in ks)
                out.writelines(f'{l}\n' for l in eln)