In [1]:
import matplotlib.pyplot as plt
from rdflib import Graph
from rdflib.term import URIRef, BNode, Literal
from rdflib.namespace import RDF, RDFS, OWL
from om.ont import get_n, tokenize
from termcolor import colored
from ldp import parser
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import math
import random

import itertools

In [2]:
def print_node(n, g, l=0, ml=1):
    print('\t' * l, get_n(n, g))
    if l >= ml:
        return
    for _, p, o in g.triples((n, None, None)):
        print('\t' * (l + 1), colored(get_n(p, g), 'blue'), end=' ')
        print_node(o, g, l+1, ml)

def get_ld(ld, g, res, l=0):

    if type(ld) is tuple:
        for c in ld[1]:
            get_ld(c, g, res, l+1)


    elif type(ld) is URIRef:
        res.append(str(g.value(ld, RDFS.label)))

In [3]:
with open('logdefs_HP.csv', 'r') as f:
    raw_defs = list(map(lambda x: x.split(','), f.readlines()))

ldf = []
for ld in tqdm(raw_defs):

    if len(ld) > 2:
        continue
    line = []
    try:
        ldf.append((URIRef(ld[0]), parser.parse(ld[1])))
    except:
        continue

  0%|          | 0/24429 [00:00<?, ?it/s]

In [4]:


hp = Graph().parse('hp.owl')

x = []
c = []

for e, l in ldf:

    e1l = hp.value(e, RDFS.label)

    ld = []
    get_ld(l, hp, ld)

    x.append(e1l)
    c.append(ld)


In [5]:
tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")

class CustomDataset(Dataset):
    def __init__(self, x, c):
        self.transform = None
        self.target_transform = None

        self.x = x
        self.c = c


        ml = -1
        for q in itertools.chain(x, itertools.chain(*c)):
            l = len(tokenizer.tokenize(q))
            if l > ml:
                ml = l
        self.ml = ml
        self.ms = max(map(len, c))

        tk1 = tokenizer(x, return_tensors='pt', padding='max_length', max_length=ml)
        self.id1 = tk1['input_ids']
        self.at1 = tk1['attention_mask']

        tc = []
        tm = {}
        for q in c:

            tk1 = tokenizer(q, return_tensors='pt', padding='max_length', truncation=True, max_length=ml)

            tc.append((tk1['input_ids'], tk1['attention_mask']))
            for w, a in zip(tk1['input_ids'], tk1['attention_mask']):
                tm[w] = a

        self.tc = tc
        self.tm = tm

    def __len__(self):
        return len(self.id1)

    def __getitem__(self, idx):
        id1 = self.id1[idx]
        at1 = self.at1[idx]

        tc1, ta1 = self.tc[idx]
        ns = random.choices(list(self.tm.keys()), k=tc1.shape[0])
        nat = list(map(lambda x: self.tm[x], ns))

        ns = torch.cat(list(map(lambda x: x.unsqueeze(0), ns)), dim=0)
        nat = torch.cat(list(map(lambda x: x.unsqueeze(0), nat)), dim=0)

        sm = [1] * tc1.shape[0]

        if tc1.shape[0] < self.ms:
            dif = self.ms - tc1.shape[0]
            sm.extend([0] * dif)
            pad = torch.zeros((dif, self.ml))
            tc1 = torch.cat([tc1, pad], dim=0)
            ta1 = torch.cat([ta1, pad], dim=0)
            ns = torch.cat([ns, pad], dim=0)
            nat = torch.cat([nat, pad], dim=0)

        sm = torch.Tensor(sm)
        return (id1, at1), (tc1, ta1), (ns, nat), sm



dataset = CustomDataset(x, c)

print(len(dataset))

24425


In [6]:

class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")
        self.biobert = AutoModel.from_pretrained("dmis-lab/biobert-v1.1")
        for param in self.biobert.base_model.parameters():
            param.requires_grad = False

        self.biobert.eval()

        self.dff = nn.Sequential(
            nn.Linear(768 * 2, 768 * 4),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(768 * 4, 768),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def emb(self, x, a):
        out = self.biobert(input_ids=x, attention_mask=a)['last_hidden_state']
        cf = a.sum(dim=1, keepdims=True)
        cf[cf == 0] = 1
        res = out.sum(1) / cf
        return res

    def forward(self, x, xa, y, ya):
        e1 = self.emb(x, xa)
        bs = y.shape[0]
        ss = y.shape[1]

        e2 = self.emb(torch.flatten(y, end_dim=1).long(), torch.flatten(ya, end_dim=1))

        jc = torch.cat([e1.unsqueeze(1).repeat(1, ss, 1), e2.reshape(bs, ss, -1)], dim=2)

        return self.dff(jc)


    def sims(self, anc, ex, bs=10):
        at = self.tokenizer([anc] + ex, return_tensors='pt', padding=True)


        ids = at['input_ids']
        ats = at['attention_mask']

        act = ids[0]
        aca = ats[0]

        ext = ids[1:]
        exa = ats[1:]


        os = []

        with torch.no_grad():

            for e, a in DataLoader(list(zip(ext, exa)), batch_size=bs):
                out = self(act.unsqueeze(0), aca.unsqueeze(0), e.unsqueeze(0), a.unsqueeze(0))
                os.append(out.squeeze(0).t().squeeze(0))

        return torch.cat(os, dim=0)


In [20]:
model = nn.DataParallel(Model())
model.cuda(0)
crit = nn.BCELoss()
optimizer = optim.AdamW(model.parameters(), lr=0.00003)

batch_size = 32
epochs = 30

progress = tqdm(total=epochs * math.ceil(len(dataset) / batch_size))

lh = []
for e in range(epochs):
    el = []
    for (ids, ati), (ps, pa), (ns, na), sm in DataLoader(dataset, batch_size=batch_size, shuffle=True):
        optimizer.zero_grad()

        ep = model(ids.cuda(0), ati.cuda(0), ps.cuda(0), pa.cuda(0))
        en = model(ids.cuda(0), ati.cuda(0), ns.cuda(0), na.cuda(0))


        pv = ep * sm.unsqueeze(-1).cuda(0) + (sm == 0).float().unsqueeze(-1).cuda(0)
        nv = en * sm.unsqueeze(-1).cuda(0)

        ploss = crit(pv, torch.ones(sm.shape).unsqueeze(-1).cuda(0))
        nloss = crit(nv, torch.zeros(sm.shape).unsqueeze(-1).cuda(0))
        loss = ploss + nloss
        loss.backward()
        el.append(loss.item())
        optimizer.step()
        progress.update(1)

    lh.append(sum(el) / len(el))

progress.close()
plt.plot(lh)
plt.show()

torch.save(model.state_dict(), 'complex_biob2.pt')

  0%|          | 0/22920 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [8]:
model = Model()
model.load_state_dict(torch.load('complex_bio.pt'))
model.eval()

Model(
  (biobert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
   

In [12]:
print(x[0], c[0] + c[1])
print(x[1], c[0] + c[1])
print(model.sims(x[0], c[0] + c[1]))
print(model.sims(x[1], c[0] + c[1]))


Muscle fiber hypertrophy ['hypertrophic', 'characteristic of part of', 'cell of skeletal muscle', 'has modifier', 'abnormal', 'increased amount', 'characteristic of', 'amyloid deposition', 'part of', 'nerve', 'has modifier', 'abnormal']
Amyloidosis of peripheral nerves ['hypertrophic', 'characteristic of part of', 'cell of skeletal muscle', 'has modifier', 'abnormal', 'increased amount', 'characteristic of', 'amyloid deposition', 'part of', 'nerve', 'has modifier', 'abnormal']
tensor([0.7600, 0.0228, 0.8211, 0.2107, 0.3271, 0.1553, 0.2601, 0.0024, 0.4524,
        0.2386, 0.2107, 0.3271])
tensor([0.0152, 0.0732, 0.4795, 0.5809, 0.6409, 0.6073, 0.5172, 0.8248, 0.6383,
        0.7233, 0.5809, 0.6409])
