In [1]:
import os
import os.path as path
from rdflib import Graph, URIRef, BNode
from rdflib.namespace import RDF, RDFS, OWL
from om.ont import get_namespace, get_n
from owl_utils import load_entities, load_cqas, load_sg, add_depth, to_pyg
from termcolor import colored
from tqdm.auto import tqdm
from transformers import AutoTokenizer, BertModel
import itertools
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader as TorchDataLoader
from torch_geometric.data import Data

from torch_geometric.loader import DataLoader
from cqa_search import build_raw_data, build_graph_dataset, pad_seq
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree



In [2]:
paths = {
    'edas.owl': '/home/guilherme/Documents/kg/conference/edas.owl',
    'ekaw.owl': '/home/guilherme/Documents/kg/conference/ekaw.owl',
    'confOf.owl': '/home/guilherme/Documents/kg/conference/confOf.owl',
    'conference.owl': '/home/guilherme/Documents/kg/conference/Conference.owl',
    'cmt.owl': '/home/guilherme/Documents/kg/conference/cmt.owl',
}

cqa_path = '/home/guilherme/Documents/complex/CQAs'
entities_path = '/home/guilherme/Documents/complex/entities-cqas'

In [3]:

idata = load_entities(entities_path, paths)
isg = load_sg(entities_path, paths)

cqas = load_cqas(cqa_path)
raw_data = build_raw_data(idata, cqas)


  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

  0%|          | 0/101 [00:00<?, ?it/s]

In [4]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [5]:
dataset = build_graph_dataset(tokenizer, cqas, idata, raw_data['edas'])

In [6]:

class GNN(MessagePassing):
    def __init__(self):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).

    def forward(self, x, edge_index, edge_features):

        # Step 4-5: Start propagating messages.
        out = self.propagate(edge_index, x=x, ef=edge_features)

        return out

    def message(self, x_j, ef):
        return x_j * ef
    

gnn = GNN()

edge_index = torch.LongTensor([[1, 2],
                           [0, 0]])
edge_features = torch.Tensor([[0.5], [1.0]])
x = torch.Tensor([[1], [2], [3]])


data = Data(x=x, edge_index=edge_index, edge_attr=edge_features)
out = gnn(data.x, data.edge_index, data.edge_attr)
print(out)

tensor([[4.],
        [0.],
        [0.]])


In [39]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')

    def forward(self, x):
        x = self.bert(x)
        return x
    
    def embed_cqa(self, x):
        mask = x > 0
        out = self.bert(input_ids=x, attention_mask=mask)['last_hidden_state']
        om =  mask.unsqueeze(-1).float()
        mo = out * om
        return mo.sum(dim=1) / om.sum(dim=1)
    
    def embed_subg(self, x, edge_index, edge_attr):
        print(x.shape)
        return self.embed_cqa(x)
model = Model()

In [8]:
gnn = GNN()

for batch in DataLoader(dataset, batch_size=2):
    print(batch)
    # out = gnn(batch.x_s, batch.edge_index_s)
    print(out.shape)
    break

GraphDataBatch(rsi=[2], rpi=[2], rni=[2], cqs=[2, 429], cqp=[2, 429], cqn=[2, 429], x_s=[18, 459], x_p=[71, 459], x_n=[59, 459], edge_index_s=[2, 36], edge_index_p=[2, 36], edge_index_n=[2, 382], edge_feat_s=[36, 50], edge_feat_p=[377, 50], edge_feat_n=[382, 50])
torch.Size([3, 1])


In [9]:
def build_raw_ts(op, data):
    ge = Graph().parse(op)

    res = {}
    for k in data:
        tn, ng = data[k]
        res[k] = tn
        for t in ng:
            ge.add(t)

    mc = 0
    mp = 0
    ifd = []
    for s in tqdm(set(ge.subjects())):

        eg = Graph()
        eg.add((s, RDF.type, OWL.Class))
        add_depth(s, eg, ge, 4)
        cm, pm, fm = to_pyg(s, eg)
        mcc = max(map(len, cm))
        mpc = max(map(len, pm))
        if mcc > mc:
            mc = mcc
        if mpc > mp:
            mp = mpc

        ifd.append((s, cm, pm, fm))

    return ifd, mc, mp, res

ifd, mc, mp, res = build_raw_ts('/home/guilherme/Documents/kg/conference/edas.owl', isg['edas'])

  0%|          | 0/524 [00:00<?, ?it/s]

In [41]:
def embed_subg(model, ifd, tokenizer, mc, mp):
    model.eval()
    ts = []
    gd = []
    for s, cm, pm, fm in tqdm(ifd):
        e1id = tokenizer(cm, return_tensors='pt', padding=True)['input_ids']
        pd1 = pad_seq(e1id, mc)
        pd1 = torch.cat([torch.zeros((1, mc)), pd1], dim=0)

        e1pid = tokenizer(pm, return_tensors='pt', padding=True)['input_ids']
        pd3 = pad_seq(e1pid, mp)

        edge1 = torch.LongTensor(fm)

        ts.append(s)
        gd.append(Data(x=pd1.long(), edge_index=edge1, edge_attr=pd3))

    fe = []
    for batch in DataLoader(gd, batch_size=2, shuffle=False):
        with torch.no_grad():
            out = model.embed_subg(batch.x, batch.edge_index, batch.edge_attr)
            fe.append(out)
        fe.append(torch.ones((2, 768)))

    fe = torch.cat(fe, dim=0)

    return ts, fe


ts, fe = embed_subg(model, ifd, tokenizer, mc, mp)

  0%|          | 0/524 [00:00<?, ?it/s]

torch.Size([85, 167])
torch.Size([100, 167])


KeyboardInterrupt: 

In [35]:
def embed_cqas(model, data, tokenizer):
    model.eval()
    cq = []
    cqi = []
    for k in data:
        cq.append(k)
        cqi.append(data[k])
    
    cqid = tokenizer(cqi, return_tensors='pt', padding=True)['input_ids']
    
    cqeb = []
    
    for c in DataLoader(cqid, batch_size=2, shuffle=False):
        with torch.no_grad():
            out = model.embed_cqa(c)            
            cqeb.append(out)
        
    cqeb = torch.cat(cqeb, dim=0)
    return cq, cqeb



cq, cqeb = embed_cqas(model, cqas['edas'], tokenizer)

In [36]:
def eval_metrics(cq, cqeb, fe, ts, res, th=0.8):
    metrics = []

    for c, e in zip(cq, cqeb):
        sim = torch.cosine_similarity(e.unsqueeze(0), fe, dim=1)
        resid = torch.where(sim > th)[0]
        rs = set()
        for r in resid:
            rs.add(ts[r])
        metrics.append((1 if res[c] in rs else 0, len(rs)))

    rc = sum([m[0] for m in metrics]) / len(metrics)
    avgp = sum([1 / m[1] if m[1] > 0 else 0 for m in metrics]) / len(metrics)
    fm = 2 * rc * avgp / (rc + avgp) if rc + avgp > 0 else 0
    return rc, avgp, fm

avgp, rc, fm = eval_metrics(cq, cqeb, fe, ts, res, th=0.8)
print(f'avgp: {avgp:.2f}, rec: {rc:.2f}, afm: {fm:.2f}')

avgp: 0.00, rec: 0.00, afm: 0.00
