In [1]:
import torch
import os
import sys

import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader, random_split

sys.path.append("../py")
sys.path.append("../lib/BioInfer_software_1.0.1_Python3/")

from bioinferdataset import BioInferDataset
from config import *
from INN import INNModelLightning


def update_batch_S(new_batch, batch):
    """Concatenate every sample's ``S`` tensor into ``new_batch["S"]``.

    Entries of ``S`` that are greater than ``-1`` are shifted by the total
    number of ``T`` rows contributed by all preceding samples in ``batch``;
    entries of ``-1`` or lower are copied unchanged. The shift is cumulative,
    so the third and later samples are offset by the sum of the earlier ``T``
    lengths rather than by the previous sample's length only.

    The shift is computed out of place: the ``S`` tensors stored in the
    samples are never modified, so collating the same sample again yields the
    same result.

    NOTE(review): this presumes values > -1 in ``S`` are row indices into the
    sample's ``T`` and that -1 marks "no entry" -- confirm against the dataset.

    Args:
        new_batch: dict that receives the ``"S"`` entry.
        batch: sequence of sample dicts, each with ``"S"`` and ``"T"`` tensors.

    Returns:
        ``new_batch``, with ``new_batch["S"]`` set to the concatenated tensor.
    """
    offset = 0
    shifted = []
    for sample in batch:
        s = sample["S"]
        # torch.where builds a new tensor, leaving the sample's own S untouched.
        shifted.append(torch.where(s > -1, s + offset, s))
        # Running total of T rows seen so far = offset for the next sample.
        offset += sample["T"].shape[0]
    new_batch["S"] = torch.cat(shifted)
    return new_batch


def collate_list_keys(new_batch, batch, list_keys):
    """Gather per-sample values into plain Python lists.

    For each key in ``list_keys``, ``new_batch[key]`` becomes a list holding
    ``sample[key]`` for every sample in ``batch``, in batch order.

    Args:
        new_batch: dict that receives the collected lists.
        batch: iterable of sample dicts.
        list_keys: keys to collect.

    Returns:
        The same ``new_batch`` object, updated in place.
    """
    for field in list_keys:
        gathered = []
        for item in batch:
            gathered.append(item[field])
        new_batch[field] = gathered
    return new_batch


def collate_cat_keys(new_batch, batch, cat_keys):
    """Concatenate per-sample tensors key by key.

    For each key in ``cat_keys``, the values ``sample[key]`` of all samples
    are joined with ``torch.cat`` (default dimension) and stored under the
    same key in ``new_batch``.

    Args:
        new_batch: dict that receives the concatenated tensors.
        batch: iterable of sample dicts.
        cat_keys: keys whose values are concatenated.

    Returns:
        The same ``new_batch`` object, updated in place.
    """
    for field in cat_keys:
        pieces = []
        for item in batch:
            pieces.append(item[field])
        new_batch[field] = torch.cat(pieces)
    return new_batch


def collate_func(batch):
    """Merge a list of sample dicts into one batch dict for the DataLoader.

    * ``tokens`` and ``entity_spans`` are kept as per-sample lists.
    * ``element_names``, ``L``, ``labels`` and ``is_entity`` are concatenated
      with ``torch.cat``.
    * ``S`` is built by ``update_batch_S``.
    * ``T`` is rebuilt as ``torch.arange`` over the concatenated
      ``element_names``.

    Args:
        batch: a list of sample dicts, or a single sample dict (which is
            wrapped into a one-element list).

    Returns:
        dict with keys ``tokens``, ``entity_spans``, ``element_names``, ``L``,
        ``labels``, ``is_entity``, ``S`` and ``T``.
    """
    # Each key is listed once; repeating a key would only redo the same concatenation.
    cat_keys = ["element_names", "L", "labels", "is_entity"]
    list_keys = ["tokens", "entity_spans"]

    # isinstance also accepts dict subclasses, unlike an exact type comparison.
    if isinstance(batch, dict):
        batch = [batch]

    new_batch = {}
    new_batch = collate_list_keys(new_batch, batch, list_keys)
    new_batch = collate_cat_keys(new_batch, batch, cat_keys)
    new_batch = update_batch_S(new_batch, batch)

    # One running index per element of the merged batch.
    new_batch["T"] = torch.arange(len(new_batch["element_names"]))

    return new_batch


def set_device():
    """Report which device type is available and return the GPU count to use.

    Prints "Running on the GPU" or "Running on the CPU" depending on
    ``torch.cuda.is_available()``. Despite the name, nothing is configured
    globally: the only result is the returned integer.

    Returns:
        1 when CUDA is available, otherwise 0.
    """
    if torch.cuda.is_available():
        print("Running on the GPU")
        return 1
    print("Running on the CPU")
    return 0


def load_dataset():
    """Build the BioInfer dataset, reusing the prepared-samples file if present.

    A ``BioInferDataset`` is created from ``XML_PATH``. When
    ``PREPPED_DATA_PATH`` is an existing file the samples are loaded from it;
    otherwise the data is prepared and then written to that path.

    NOTE(review): the cache is read via ``load_samples_from_pickle`` -- if it
    unpickles the file, only use cache files from a trusted source.

    Returns:
        The populated ``BioInferDataset`` instance.
    """
    data = BioInferDataset(XML_PATH)
    cache_missing = not os.path.isfile(PREPPED_DATA_PATH)
    if cache_missing:
        # First run: do the preprocessing and persist it for next time.
        data.prep_data()
        data.samples_to_pickle(PREPPED_DATA_PATH)
        return data
    data.load_samples_from_pickle(PREPPED_DATA_PATH)
    return data


def split_data(dataset, train_fraction=0.8, generator=None):
    """Randomly split ``dataset`` into a training and a validation subset.

    Args:
        dataset: any object accepted by ``torch.utils.data.random_split``
            (must support ``len``).
        train_fraction: share of the samples assigned to the training subset;
            the count is ``round(train_fraction * len(dataset))`` and the
            remainder goes to validation. Defaults to 0.8.
        generator: optional ``torch.Generator`` forwarded to ``random_split``
            to make the split reproducible. When ``None`` the call is made
            without a generator, i.e. with ``random_split``'s own default.

    Returns:
        Tuple ``(train_set, val_set)``.

    Raises:
        ValueError: if ``train_fraction`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError("train_fraction must be between 0 and 1")

    n_total = len(dataset)
    n_train = round(train_fraction * n_total)
    lengths = [n_train, n_total - n_train]

    if generator is None:
        train_set, val_set = random_split(dataset, lengths=lengths)
    else:
        train_set, val_set = random_split(
            dataset, lengths=lengths, generator=generator
        )
    return train_set, val_set



# --- Notebook setup: device check, data, loaders and model -----------------
# GPUS is 1 when CUDA is available, else 0 (see set_device).
GPUS = set_device()
# Loads the prepared samples from PREPPED_DATA_PATH if that file exists,
# otherwise prepares them and writes the file (see load_dataset).
dataset = load_dataset()
train_set, val_set = split_data(dataset)

# Enables autograd anomaly detection for the whole session.
# NOTE(review): the PyTorch docs describe this as a debugging aid that slows
# execution -- consider turning it off once debugging is finished.
torch.autograd.set_detect_anomaly(True)

# Both loaders use the custom collate_func to merge sample dicts.
# BATCH_SIZE is not defined in this notebook; presumably it comes from
# `from config import *` -- confirm.
train_data_loader = DataLoader(
    train_set, collate_fn=collate_func, batch_size=BATCH_SIZE
)
# Validation is fed one sample at a time.
val_data_loader = DataLoader(val_set, collate_fn=collate_func, batch_size=1)

# The model takes its vocabulary and element index from the dataset; the
# upper-case hyperparameters are presumably also provided by `config`.
model = INNModelLightning(
    vocab_dict=dataset.vocab_dict,
    element_to_idx=dataset.element_to_idx,
    word_embedding_dim=WORD_EMBEDDING_DIM,
    cell_state_clamp_val=CELL_STATE_CLAMP_VAL,
    hidden_state_clamp_val=HIDDEN_STATE_CLAMP_VAL,
)

In [6]:
# Bare expression: shows the module's repr (its layer structure) as the cell output.
model

INNModelLightning(
  (cell): DAGLSTMCell(
    (W_ioc_hat): Linear(in_features=512, out_features=1536, bias=False)
    (U_ioc_hat): Linear(in_features=1024, out_features=1536, bias=False)
    (W_fs): Linear(in_features=512, out_features=1024, bias=False)
    (U_fs): Linear(in_features=1024, out_features=1024, bias=False)
  )
  (inn): INNModel(
    (word_embeddings): Embedding(5200, 256)
    (element_embeddings): Embedding(132, 512)
    (attn_scores): Linear(in_features=512, out_features=1, bias=True)
    (blstm): LSTM(256, 256, bidirectional=True)
    (cell): DAGLSTMCell(
      (W_ioc_hat): Linear(in_features=512, out_features=1536, bias=False)
      (U_ioc_hat): Linear(in_features=1024, out_features=1536, bias=False)
      (W_fs): Linear(in_features=512, out_features=1024, bias=False)
      (U_fs): Linear(in_features=1024, out_features=1024, bias=False)
    )
    (output_linear): Sequential(
      (0): Linear(in_features=512, out_features=1024, bias=True)
      (1): Linear(in_features=

In [7]:
# Pull the first collated batch from the training loader for a manual check.
batch = next(iter(train_data_loader))

In [8]:
# Run the model once on that batch; `x` is the output passed to make_dot below.
x = model(batch)

In [9]:
from torchviz import make_dot

In [11]:
# Build a torchviz graph of the computation that produced `x`, passing the
# model's named parameters via `params`.
dot = make_dot(x, params=dict(model.named_parameters()))
# Render as SVG; the recorded cell output shows the resulting path './dot.svg'.
dot.render('./dot',format='svg')

'./dot.svg'