In [1]:
import fire, sys, tqdm
import os
import yaml
import cv2
import json
import gzip
import numpy as np
import pandas as pd
from math import ceil
from skimage import io as skio
from collections import Counter
from collections import OrderedDict
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset

import transformers as tf
from transformers import AutoModel, AutoTokenizer
from torchvision import transforms

import utils.load as load
from rgcn import RGCN
from utils.load import tic, toc, here
from model_utils import sum_sparse, adj, pca, enrich, bert_emb, squeeze_emb, score_distmult_bc, compute_ranks_fast, sfcn_emb
from squeezenet import SqueezeNetwork
from SFCNnet import SFCNNetwork
from utils.data_utils import PrepareDataset



In [2]:
def node_classification_task(config_yaml, data, emb, trainable, device):
    '''
    Train an RGCN node classifier with early stopping, or evaluate a saved one.

    Behaviour is selected by ``config_yaml["data_prep"]["final"]``:

    * ``False`` -- train for up to ``training.epochs`` epochs. After a warm-up
      of 20 epochs the validation (withheld) loss is tracked; training stops
      once it has not improved by more than ``training.delta`` for
      ``training.patience`` consecutive epochs. The weights with the lowest
      validation loss are written to ``{model_path}/nc_{val_acc:.2}.pt`` and the
      per-epoch metrics to ``{results_path}/nc_train_val_logs.csv``.
    * ``True`` -- load a saved checkpoint, score ``data.withheld`` and write the
      accuracy as JSON to ``{results_path}/nc_test_logs.csv``. The checkpoint
      file name is read from the optional ``training.nc_checkpoint`` config key
      and defaults to ``nc_0.66.pt``.

    Args:
        config_yaml: nested dict with the ``data_prep``, ``rgcn`` and
            ``training`` sections read below.
        data: dataset object exposing ``triples``, ``num_entities``,
            ``num_relations``, ``num_classes``, ``training`` and ``withheld``
            (the latter two are indexed as ``[:, 0]`` = node index,
            ``[:, 1]`` = class label).
        emb: embedding size, used as both input and hidden size of the RGCN.
        trainable: node feature matrix passed to the RGCN.
        device: device the model and features are moved to.

    Returns:
        None. Results are written to disk.
    '''
    final = config_yaml["data_prep"]["final"]
    bases = config_yaml["rgcn"]["bases"]
    lr = config_yaml["rgcn"]["lr"]
    l2 = config_yaml["rgcn"]["l2"]
    wd = config_yaml["rgcn"]["wd"]
    epochs = config_yaml["training"]["epochs"]
    patience = config_yaml["training"]["patience"]
    best_score = config_yaml["training"]["best_score"]
    delta = config_yaml["training"]["delta"]
    model_path = config_yaml["training"]["model_path"]
    results_path = config_yaml["training"]["results_path"]

    rgcn = RGCN(data.triples.long().to(device), n=data.num_entities, r=data.num_relations, insize=emb, hidden=emb, numcls=data.num_classes, device = device, link_prediction = False, bases=bases)
    rgcn.to(device)

    if final==False:
        # NOTE(review): only the RGCN weights are handed to the optimizer, so
        # `trainable` itself is never updated here -- confirm this is intended.
        opt = torch.optim.Adam(lr=lr, weight_decay=wd, params=rgcn.parameters())

        epochs_list = []
        train_loss_list = []
        valid_loss_list = []
        training_acc_list = []
        withheld_acc_list = []

        # Indices and labels do not change between epochs.
        idxt, clst = data.training[:, 0], data.training[:, 1] # indices, classes
        idxw, clsw = data.withheld[:, 0], data.withheld[:, 1] # indices, classes
        clst = clst.to(device)
        clsw = clsw.to(device)

        patience_left = patience
        best_state = None
        val_acc = None
        for e in range(epochs):
            tic()
            # Train and validation nodes are both part of the graph, but only
            # the training loss is back-propagated.
            # https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn/entity_classify.py
            opt.zero_grad()
            features = torch.cat([trainable], dim=0) # number of entities x embedding size
            features = features.requires_grad_()
            # Tensor.to() is not in-place: the result has to be kept.
            features = features.to(device)
            out = rgcn(features)

            loss = F.cross_entropy(out[idxt, :], clst, reduction='mean')
            if l2 != 0.0:
                loss = loss + l2 * rgcn.penalty()

            # Validation loss and accuracies are monitoring-only: no graph needed.
            with torch.no_grad():
                val_loss = F.cross_entropy(out[idxw, :], clsw, reduction='mean')
                training_acc = (out[idxt, :].argmax(dim=1) == clst).sum().item() / idxt.size(0)
                withheld_acc = (out[idxw, :].argmax(dim=1) == clsw).sum().item() / idxw.size(0)

            loss.backward()
            opt.step()

            epochs_list.append(e)
            train_loss_list.append(loss.cpu().detach().numpy())
            valid_loss_list.append(val_loss.cpu().detach().numpy())
            training_acc_list.append(training_acc)
            withheld_acc_list.append(withheld_acc)

            stop_early = False
            if e > 20:  # warm-up: ignore the first noisy epochs
                current_val_loss = val_loss.item()
                # A negative best_score is the "nothing recorded yet" sentinel.
                if best_score < 0 or current_val_loss < best_score - delta:
                    best_score = current_val_loss
                    val_acc = withheld_acc
                    # state_dict() tensors share storage with the live
                    # parameters, so they must be cloned or later optimizer
                    # steps would silently overwrite the "best" weights.
                    # The sparse adjacency buffers cannot be saved and are skipped.
                    best_state = OrderedDict(
                        (k, v.detach().clone())
                        for k, v in rgcn.state_dict().items()
                        if k not in ['hor_graph','ver_graph']
                    )
                    patience_left = patience
                else:
                    patience_left -= 1
                stop_early = patience_left <= 0

            # Written every epoch so the log survives an interrupted run and
            # also contains the epoch that triggers early stopping.
            results_dict = {'epochs':epochs_list, 'train_loss': train_loss_list, 'valid_loss': valid_loss_list, 'training_acc':training_acc_list, 'validation_acc': withheld_acc_list}
            df = pd.DataFrame(data=results_dict)
            df.to_csv(os.path.join(results_path, "nc_train_val_logs.csv"), index=False)

            if stop_early:
                print("Early stopping after no improvement for {} epoch".format(patience))
                break

            print(f'epoch {e:02}: loss {loss:.2}, train acc {training_acc:.2}, \t val_loss {val_loss:.2}, withheld acc {withheld_acc:.2} \t ({toc():.5}s)')

        # Persist the best weights whether training stopped early or simply
        # ran out of epochs.
        if best_state is not None:
            torch.save(best_state, f'{model_path}/nc_{val_acc:.2}.pt')

    elif final==True:
        # Optional override; the default keeps the previously hard-coded file.
        checkpoint_name = config_yaml["training"].get("nc_checkpoint", "nc_0.66.pt")
        # strict=False because the sparse graph buffers are not in the checkpoint.
        rgcn.load_state_dict(torch.load(f'{model_path}/{checkpoint_name}', map_location=device), strict=False)

        rgcn.eval()
        features = torch.cat([trainable], dim=0).to(device)
        with torch.no_grad():
            out = rgcn(features)

        idxw, clsw = data.withheld[:, 0], data.withheld[:, 1] # indices, classes
        clsw = clsw.to(device)
        test_acc = (out[idxw, :].argmax(dim=1) == clsw).sum().item() / idxw.size(0)

        test_results_dict = {'type':'test', 'acc':test_acc}
        with open(os.path.join(results_path, "nc_test_logs.csv"), 'w') as file:
             file.write(json.dumps(test_results_dict)) # use `json.loads` to do the reverse
        

In [3]:
def link_prediction_task(config_yaml, data, emb, trainable, device):
    '''
    Train an RGCN + DistMult link predictor with early stopping, then test it.

    The triples are split into train / validation / test with
    ``link_pred.split_ratio`` (applied twice, fixed ``random_state=0``). Each
    epoch scores all training triples plus ``len(train) // 5`` corrupted
    triples (half with a random head, half with a random tail) and optimises a
    binary cross-entropy loss. Validation MRR and Hits@{1,3,10,100} are computed
    every epoch; after a warm-up of 500 epochs training stops once the MRR has
    not improved by more than ``training.delta`` for ``training.patience``
    consecutive epochs.

    Side effects:
        * ``{model_path}/lp_{best_score:.2}.pt`` -- best weights (if any epoch
          passed the warm-up).
        * ``{results_path}/lp_train_val_logs.csv`` -- metrics every 50 epochs.
        * whatever ``test_link_pred`` writes for the test split.

    NOTE(review): the RGCN graph is built from *all* ``data.triples``, i.e. it
    also contains the validation and test edges -- confirm this is intended.
    NOTE(review): only the RGCN weights are handed to the optimizer, so
    ``trainable`` itself is never updated here.

    Args:
        config_yaml: nested dict with the ``training``, ``link_pred`` and
            ``rgcn`` sections read below.
        data: dataset object exposing ``triples``, ``num_entities``,
            ``num_relations`` and ``num_classes``.
        emb: embedding size, used as both input and hidden size of the RGCN.
        trainable: node feature matrix passed to the RGCN.
        device: device the model, features and targets are moved to.

    Returns:
        None. Results are written to disk.
    '''
    results_path = config_yaml["training"]["results_path"]
    split_ratio = config_yaml["link_pred"]["split_ratio"]
    bases = config_yaml["rgcn"]["bases"]
    lr = config_yaml["rgcn"]["lr"]
    wd = config_yaml["rgcn"]["wd"]
    epochs = config_yaml["training"]["epochs"]
    patience = config_yaml["training"]["patience"]
    best_score = config_yaml["training"]["best_score"]
    delta = config_yaml["training"]["delta"]
    model_path = config_yaml["training"]["model_path"]

    train_val_triples, test_triples = train_test_split(data.triples.long(), test_size = split_ratio, random_state=0)
    train_triples, val_triples = train_test_split(train_val_triples, test_size = split_ratio, random_state=0)
    rgcn = RGCN(data.triples.long().to(device), n=data.num_entities, r=data.num_relations, insize=emb, hidden=emb, numcls=data.num_classes, device = device, link_prediction=True, bases=bases)
    rgcn.to(device)

    opt = torch.optim.Adam(lr=lr, weight_decay=wd, params=rgcn.parameters())
    criterion = nn.BCEWithLogitsLoss()

    epochs_list = []
    train_loss_list = []
    valid_MRR_list = []
    hits1_list = []
    hits3_list = []
    hits10_list = []
    hits100_list = []

    # Sample counts and targets are identical for every epoch.
    nsamples = len(train_triples)
    ncorrupt = nsamples//5
    ncorrupt_head = ncorrupt//2
    ncorrupt_tail = ncorrupt - ncorrupt_head

    # Positives first, negatives last. Slicing from `nsamples` (rather than
    # `-ncorrupt`) stays correct when ncorrupt == 0, where `[-0:]` would
    # select -- and zero out -- every label.
    Y = torch.ones(nsamples+ncorrupt, dtype=torch.float32)
    Y[nsamples:] = 0
    Y = Y.to(device)

    # torch.device() accepts both strings and device objects; comparing a
    # plain string against torch.device('cuda') is not reliable.
    on_cuda = torch.device(device).type == 'cuda'

    patience_left = patience
    best_state = None
    for e in range(epochs):
        tic()
        opt.zero_grad()

        rgcn.train()
        features = torch.cat([trainable], dim=0)
        features.requires_grad_()
        features = features.to(device)
        node_embeddings = rgcn(features) # number of entities x embedding size
        edge_embeddings = rgcn.w_relation

        # Build corrupted triples: indexing with a tensor copies the rows, so
        # the training triples themselves are left untouched.
        neg_samples_idx = torch.from_numpy(np.random.choice(np.arange(nsamples),
                                                            ncorrupt,
                                                            replace=False))
        corrupted_data = train_triples[neg_samples_idx]
        corrupted_data[:ncorrupt_head, 0] = torch.from_numpy(np.random.choice(np.arange(data.num_entities),
                                                                              ncorrupt_head))
        corrupted_data[ncorrupt_head:, 2] = torch.from_numpy(np.random.choice(np.arange(data.num_entities),
                                                                              ncorrupt_tail))

        # compute score
        Y_hat = torch.empty((nsamples+ncorrupt), dtype=torch.float32).to(device)
        Y_hat[:nsamples] = score_distmult_bc((train_triples[:, 0],
                                              train_triples[:, 1],
                                              train_triples[:, 2]),
                                             node_embeddings,
                                             edge_embeddings).to(device)
        Y_hat[nsamples:] = score_distmult_bc((corrupted_data[:, 0],
                                              corrupted_data[:, 1],
                                              corrupted_data[:, 2]),
                                             node_embeddings,
                                             edge_embeddings).to(device)

        loss = criterion(Y_hat, Y)

        loss.backward()
        nn.utils.clip_grad_norm_(rgcn.parameters(), 1.0)
        opt.step()

        # validate
        rgcn.eval()
        with torch.no_grad():
            node_embeddings = rgcn(features)
            edge_embeddings = rgcn.w_relation
            ranks = compute_ranks_fast(val_triples, node_embeddings, edge_embeddings, batch_size=16)

            mrr_raw = torch.mean(1.0 / ranks.float()).item()

            # share of validation triples ranked within the top k
            hits_at_k = {k: float(torch.mean((ranks <= k).float())) for k in [1, 3, 10, 100]}

        if e > 500:  # warm-up: do not track early stopping before this epoch
            # A negative best_score is the "nothing recorded yet" sentinel.
            # MRR is higher-is-better, so an improvement must EXCEED the best
            # score by delta; otherwise best_score could drift downwards.
            if best_score < 0 or mrr_raw > best_score + delta:
                best_score = mrr_raw
                # state_dict() tensors share storage with the live parameters,
                # so clone them or later optimizer steps would overwrite the
                # "best" weights. Sparse adjacency buffers cannot be saved.
                best_state = OrderedDict(
                    (k, v.detach().clone())
                    for k, v in rgcn.state_dict().items()
                    if k not in ['hor_graph','ver_graph']
                )
                patience_left = patience
            else:
                patience_left -= 1
            if patience_left <= 0:
                print("Early stopping after no improvement for {} epoch".format(patience))
                break

        # clear gpu cache
        if on_cuda:
            del node_embeddings
            del edge_embeddings
            torch.cuda.empty_cache()

        if e%50==0:
            epochs_list.append(e)
            train_loss_list.append(loss.cpu().detach().numpy())
            valid_MRR_list.append(mrr_raw)
            hits1_list.append(hits_at_k[1])
            hits3_list.append(hits_at_k[3])
            hits10_list.append(hits_at_k[10])
            hits100_list.append(hits_at_k[100])
            print("{:04d} ".format(e) \
                         + "| train loss {:.4f} ".format(loss)
                         + "| valid MRR (raw) {:.4f} ".format(mrr_raw)
                         + "/ " + " / ".join(["H@{} {:.4f}".format(k,v)
                                             for k,v in hits_at_k.items()]))

    if best_state is not None:
        # Persist the best weights whether training stopped early or simply
        # ran out of epochs.
        torch.save(best_state, f'{model_path}/lp_{best_score:.2}.pt')
    else:
        # Training never got past the warm-up: evaluate the current weights
        # instead of handing None to the test step.
        best_state = OrderedDict(
            (k, v.detach().clone())
            for k, v in rgcn.state_dict().items()
            if k not in ['hor_graph','ver_graph']
        )

    results_dict = {'epochs':epochs_list,'train_loss':train_loss_list, 'valid_MRR': valid_MRR_list, 'hits1':hits1_list, 'hits3':hits3_list, 'hits10':hits10_list, 'hits100':hits100_list}
    df = pd.DataFrame(data=results_dict)
    df.to_csv(os.path.join(results_path, "lp_train_val_logs.csv"), index=False)

    # Fresh, graph-free features so the test step also works when epochs == 0.
    features = torch.cat([trainable], dim=0).detach().to(device)
    test_link_pred(config_yaml, emb, best_state, features, test_triples, rgcn, device, results_path)
    

In [4]:
def test_link_pred(config_yaml, emb, best_state, features, test_triples, rgcn, device, results_path):
    '''
    Evaluate a trained link-prediction RGCN on the held-out test triples.

    The model and features passed in by the caller are used directly, so the
    function does not depend on notebook globals and evaluates the same
    architecture (including ``w_relation``) that was trained.

    Args:
        config_yaml: configuration dict (kept for interface compatibility;
            not read here).
        emb: embedding size (kept for interface compatibility; not read here).
        best_state: state dict to restore before evaluating, or ``None`` to
            evaluate ``rgcn`` with its current weights.
        features: node feature matrix fed to ``rgcn``.
        test_triples: triples to rank.
        rgcn: the trained model. NOTE: ``best_state`` is loaded into this
            object in place.
        device: device the model is moved to.
        results_path: directory receiving ``lp_test_logs.csv`` (JSON content).

    Returns:
        Tuple ``(mrr, hits_at_k, ranks)`` where ``hits_at_k`` maps
        k in {1, 3, 10, 100} to a float and ``ranks`` is a NumPy array.
    '''
    if best_state is not None:
        # strict=False: the sparse graph buffers are not part of saved states.
        rgcn.load_state_dict(best_state, strict=False)
    rgcn.to(device)
    rgcn.eval()

    with torch.no_grad():
        node_embeddings = rgcn(features)
        edge_embeddings = rgcn.w_relation

        ranks = compute_ranks_fast(test_triples, node_embeddings, edge_embeddings, batch_size=16)

        mrr = torch.mean(1.0 / ranks.float()).item()
        hits_at_k = {k: float(torch.mean((ranks <= k).float())) for k in [1, 3, 10, 100]}

    rank_type = "raw"
    print("Performance on test set: MRR ({}) {:.4f}".format(rank_type, mrr)
          + " / " + " / ".join(["H@{} {:.4f}".format(k,v) for k,v in hits_at_k.items()]))

    test_results_dict = {'test_MRR': mrr, 'hits1':hits_at_k[1], 'hits3':hits_at_k[3], 'hits10':hits_at_k[10], 'hits100':hits_at_k[100]}
    with open(os.path.join(results_path, "lp_test_logs.csv"), 'w') as file:
        file.write(json.dumps(test_results_dict)) # use `json.loads` to do the reverse

    # .cpu() first: Tensor.numpy() raises for tensors living on a GPU.
    return (mrr, hits_at_k, ranks.cpu().numpy())


In [5]:
class Args:
    """Notebook stand-in for parsed command-line arguments."""

    # Path of the YAML configuration file.
    config: str = "config.yaml"

# Driver cell: load the configuration, prepare the data and node embeddings,
# then run the tasks enabled in the config.
args = Args()
with open(args.config) as file:
    # yaml.load() without an explicit Loader is rejected by PyYAML >= 6 and can
    # construct arbitrary objects on older versions; safe_load parses plain data.
    config_yaml = yaml.safe_load(file)

print("DATA PREPARATION")
prune = config_yaml["data_prep"]["prune"]
data_path = config_yaml["data_prep"]['data_path']
final = config_yaml["data_prep"]["final"]
prune_dist = config_yaml["data_prep"]["prune_dist"]
# The section key is spelled "enivron" here; it has to match the config file.
device = config_yaml["enivron"]["device"]
emb = config_yaml["data_prep"]["emb"]

data = load.Data(data_path, final=final, use_torch=True)
if prune_dist is not None:
    data = load.prune(data, n=prune_dist)
data = load.group(data)

print("PREPARE EMBEDDINGS")
use_saved_embeddings = config_yaml["embeddings"]["use_saved_embeddings"]
imagebatch = config_yaml["embeddings"]["imagebatch"]
stringbatch = config_yaml["embeddings"]["stringbatch"]
sample_size = config_yaml["embeddings"]["sample_size"]
sample_duration = config_yaml["embeddings"]["sample_duration"]
embedding_path = config_yaml["embeddings"]["embedding_path"]
cnn_type = config_yaml["embeddings"]["cnn_type"]

# Relations whose literal values are treated as categories / as numbers.
CATEGORICAL_DATATYPES = ['hasName','hasSubjectSex','hasAPOEA1','hasAPOEA2','hasMMSCORE','hasGDTOTAL','hasCDGLOBAL','hasMoCa','hasESS','hasUPSIT','hasSTAI']
NUMERICAL_DATATYPES = ['hasWeightKg','hasSubjectAge']

# Cache file used for each supported image encoder.
EMBEDDING_FILES = {"squeezenet": 'squeezenet_embeddings.pt', "SFCNnet": 'sfcn_embeddings.pt'}
embedding_file = EMBEDDING_FILES.get(cnn_type)

if not use_saved_embeddings:
    with torch.no_grad():

        embeddings = [] # one block of rows per datatype

        for datatype in data.datatypes():
            if datatype in CATEGORICAL_DATATYPES:
                # one randomly initialised embedding row per value
                print(f'Initializing embedding for datatype {datatype}.')
                n = len(data.get_strings(dtype=datatype))
                embedding = nn.Embedding(n, emb)
                embedding = embedding.to(device)
                x = torch.arange(n, dtype=torch.long)
                x = x.to(device)
                categorical_embedding = embedding(x)

                embeddings.append(categorical_embedding.float())

            elif datatype in NUMERICAL_DATATYPES:
                # project the scalar value through a randomly initialised linear layer
                print(f'Initializing embedding for datatype {datatype}.')
                n = len(data.get_strings(dtype=datatype))
                linear = nn.Linear(1, emb)
                linear = linear.to(device)
                array = [float(x) for x in data.get_strings(dtype=datatype)]
                x = torch.unsqueeze(torch.FloatTensor(array), dim=1)
                x = x.to(device)
                linear_embedding = linear(x)

                embeddings.append(linear_embedding.float())

            elif datatype == 'hasImage':
                # CNN features for the images, reduced to `emb` dimensions
                print(f'Computing embeddings for images.')
                if cnn_type == "squeezenet":
                    image_embeddings = squeeze_emb(data.get_images(), sample_size, sample_duration, device, bs=imagebatch)
                elif cnn_type == "SFCNnet":
                    image_embeddings = sfcn_emb(data.get_images(), device, bs=imagebatch)
                else:
                    raise ValueError(f'Unsupported cnn_type {cnn_type!r}; expected one of {sorted(EMBEDDING_FILES)}.')
                image_embeddings = pca(image_embeddings, target_dim=emb, device=device)
                embeddings.append(image_embeddings.float())

            elif datatype == "hasCONDTERM":
                # embed medical conditions with bio-clinical bert
                print(f'Computing embeddings for datatype {datatype}.')
                string_embeddings = bert_emb(data.get_strings(dtype=datatype), device, bs_chars=stringbatch)
                string_embeddings = pca(string_embeddings, target_dim=emb, device= device) #emb is input size dimensions
                embeddings.append(string_embeddings.float())

        # data loader clusters nodes by data type, and in order given by data._datasets
        embeddings = torch.cat(embeddings, dim=0).to(torch.float)

    # Cache the embeddings so later runs can set use_saved_embeddings.
    if embedding_file is not None:
        torch.save(embeddings, os.path.join(embedding_path, embedding_file))
else:
    if embedding_file is None:
        raise ValueError(f'Unsupported cnn_type {cnn_type!r}; expected one of {sorted(EMBEDDING_FILES)}.')
    # map_location lets a cache written on one device be reused on another.
    embeddings = torch.load(os.path.join(embedding_path, embedding_file), map_location=device)

# all our embeddings are wrapped as a Parameter and handed to the tasks
trainable = nn.Parameter(embeddings)

print("PEFORM TASK")

link_prediction = config_yaml["general"]["link_prediction"]
node_classification = config_yaml["general"]["node_classification"]

if link_prediction == True:
    link_prediction_task(config_yaml, data, emb, trainable, device)
if node_classification == True:
    node_classification_task(config_yaml, data, emb, trainable, device)
        
        

  


DATA PREPARATION
PREPARE EMBEDDINGS
Initializing embedding for datatype hasAPOEA1.
Initializing embedding for datatype hasCDGLOBAL.
Initializing embedding for datatype hasGDTOTAL.
Computing embeddings for images.


100%|██████████| 880/880 [30:40<00:00,  2.09s/it]


Initializing embedding for datatype hasMMSCORE.
Initializing embedding for datatype hasName.
Initializing embedding for datatype hasSubjectAge.
Initializing embedding for datatype hasSubjectSex.
Initializing embedding for datatype hasWeightKg.
PEFORM TASK
epoch 00: loss 1.9e+01, train acc 0.69, 	 val_loss 2.7e+01, withheld acc 0.53 	 (0.012019s)
epoch 01: loss 2.8, train acc 0.67, 	 val_loss 4.8, withheld acc 0.53 	 (0.0113s)
epoch 02: loss 2.9e+01, train acc 0.31, 	 val_loss 2.3e+01, withheld acc 0.44 	 (0.010592s)
epoch 03: loss 2.3e+01, train acc 0.32, 	 val_loss 1.9e+01, withheld acc 0.44 	 (0.010193s)
epoch 04: loss 8.3, train acc 0.32, 	 val_loss 7.3, withheld acc 0.42 	 (0.0098672s)
epoch 05: loss 4.9, train acc 0.69, 	 val_loss 7.9, withheld acc 0.53 	 (0.0097623s)
epoch 06: loss 1e+01, train acc 0.69, 	 val_loss 1.5e+01, withheld acc 0.53 	 (0.0097127s)
epoch 07: loss 1.3e+01, train acc 0.69, 	 val_loss 1.9e+01, withheld acc 0.53 	 (0.0096254s)
epoch 08: loss 1.4e+01, train ac