In [1]:
import os
import pickle
import json
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
from torchviz import make_dot

In [2]:
# set seed
seed = 24
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x20158f9ba30>

In [3]:
from pyhealth.datasets import MIMIC3Dataset

mimic3_ds = MIMIC3Dataset(
        root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/",
        tables=["DIAGNOSES_ICD"])

# we show the statistics below.
mimic3_ds.stat()


Statistics of base dataset (dev=False):
	- Dataset: MIMIC3Dataset
	- Number of patients: 49993
	- Number of visits: 52769
	- Number of visits per patient: 1.0555
	- Number of events per visit in DIAGNOSES_ICD: 9.1038



'\nStatistics of base dataset (dev=False):\n\t- Dataset: MIMIC3Dataset\n\t- Number of patients: 49993\n\t- Number of visits: 52769\n\t- Number of visits per patient: 1.0555\n\t- Number of events per visit in DIAGNOSES_ICD: 9.1038\n'

In [4]:
from pyhealth.medcode import InnerMap
icd9cm = InnerMap.load("ICD9CM")
icd9cm.stat()
from pyhealth.medcode import CrossMap
codemap = CrossMap.load("ICD9CM", "CCSCM")
# ccscm = InnerMap.load("CCSCM")
# ccscm.stat()


Statistics for ICD9CM:
	- Number of nodes: 17736
	- Number of edges: 17733
	- Available attributes: ['name']



In [5]:
def proc_code(code):
    if code[0] == 'E':
        if len(code) > 4 and len(code) < 8:
            return (code[:4] + '.' + code[4:])
        else:
            return code
    else:
        if len(code) > 3 and len(code) < 7:
            return (code[:3] + '.' + code[3:])
        else:
            return code

In [6]:
diag_codes = set()
pa = pc = 0
codes = []
visits = 0
leaf_codes = set()
ancester_codes = set()
ccs_codes = set()
ccs_codes_count = []
ancester = None
max_codes_per_visit = 0
max_visits = 0
for pid in mimic3_ds.patients:
    pa +=1
    max_visits = max(max_visits, len(mimic3_ds.patients[pid].visits))
    if len(mimic3_ds.patients[pid].visits) > 1:
        pc += 1
#        print("patients:", pid)
        for vid in mimic3_ds.patients[pid].visits:
            visits += 1
#            print("visit:", vid)
            max_codes_per_visit = max(max_codes_per_visit, len(mimic3_ds.patients[pid].visits[vid].get_code_list("DIAGNOSES_ICD")))
            for code in mimic3_ds.patients[pid].visits[vid].get_code_list("DIAGNOSES_ICD"):
                code = proc_code(code)
#                print("diagnoses:", code)
                if code in icd9cm:
                    diag_codes.add(code)
                    leaf_codes.add(code)
                    ancesters = icd9cm.get_ancestors(code)
#                    print("ancesters:", ancesters)
                    for acode in (ancesters):
                        ancester_codes.add(acode)
                    codes.append(code)
#                     print(code, codemap.map(code))
                    ccs_codes.add(codemap.map(code)[0])
                    ccs_codes_count.append(codemap.map(code)[0])
print("leaf codes:", len(leaf_codes))
print("ancester codes:", len(ancester_codes))
print("# of patients:", pc)
print("# of codes:", len(codes))
print("code per visit:", len(codes)/visits)
print("visit per patient", visits/pc)
print("max visit per patient:", max_visits)
print("max code per visit:", max_codes_per_visit)
print("# of ccs codes", len(ccs_codes))

leaf codes: 2987
ancester codes: 1612
# of patients: 2339
# of codes: 38767
code per visit: 7.579081133919844
visit per patient 2.1868319794784097
max visit per patient: 7
max code per visit: 63
# of ccs codes 259


In [7]:
from collections import Counter
code_freq = dict(Counter(code for code in codes))
code_freq = sorted(code_freq.items(), key=lambda kv: kv[1])
n_code = len(code_freq)
n_code, code_freq
code_freq_dict = {}
for i in range(len(code_freq)):
    code_freq_dict[code_freq[i][0]] = int((i/len(code_freq))*5)

In [8]:
from collections import Counter
ccs_code_freq = dict(Counter(code for code in ccs_codes_count))
ccs_code_freq = sorted(ccs_code_freq.items(), key=lambda kv: kv[1])
n_ccs_code = len(ccs_code_freq)
ccs_code_freq_dict = {}
for i in range(len(ccs_code_freq)):
    ccs_code_freq_dict[ccs_code_freq[i][0]] = int((i/len(ccs_code_freq))*5)

In [9]:
leaf_ancester_codes = leaf_codes & ancester_codes
pure_leaf_codes = leaf_codes - leaf_ancester_codes
all_codes = leaf_codes | ancester_codes
n_leaf_codes = len(leaf_codes)
n_all_codes = len(all_codes)

In [10]:
all_codes_ordered = list(pure_leaf_codes) + list(leaf_ancester_codes) + list(ancester_codes - leaf_ancester_codes)

In [11]:
all_code_dict = {all_codes_ordered[i]:i for i in range(len(all_codes_ordered))}
diag_code_dict = {list(diag_codes)[i]:i for i in range(len(diag_codes))}
n_diag_codes = len(diag_code_dict)
ccs_codes_list = list(ccs_codes)
ccs_codes_dict = {ccs_codes_list[i]:i for i in range(len(ccs_codes_list))}


In [12]:
#
# Now build the attention mask: the mask has a value of 1 if there is an ancester relationship, 0 otherwise.
#
seqs = []
labels = []
n_ancesters = 6
att_mask = torch.zeros((len(leaf_codes), n_ancesters))
att_mask_1 = torch.zeros((n_leaf_codes, n_all_codes))
n_labels = 0
ancesters_dict = {}
for pid in mimic3_ds.patients:
    seq = []
    label = []
    if len(mimic3_ds.patients[pid].visits) > 1:
        for vid in mimic3_ds.patients[pid].visits:
            icd_codes = mimic3_ds.patients[pid].visits[vid].get_code_list("DIAGNOSES_ICD")
            processed_icd_codes = []
            for code in icd_codes:
                code = proc_code(code)
                if code in icd9cm:
                    ancesters = icd9cm.get_ancestors(code)
                    ancesters_dict[code] = [code] + ancesters
                    att_mask_1[all_code_dict[code], all_code_dict[code]] = 1
                    for idx_ancester in range(len(ancesters)):
                        att_mask[all_code_dict[code], idx_ancester] = 1
                        att_mask_1[all_code_dict[code], all_code_dict[ancesters[idx_ancester]]] = 1
                    processed_icd_codes.append(all_code_dict[code])
            if len(processed_icd_codes) > 0:
                seq.append(processed_icd_codes)
        if len(seq) > 1:
            labels.append([[ccs_codes_dict[codemap.map(all_codes_ordered[code])[0]] for code in codes] for codes in seq[1:]])
            seqs.append(seq[:-1])

print("average number of label codes:", n_labels/len(labels))

average number of label codes: 0.0


In [13]:
from torch.utils.data import Dataset
class CustomDataset(Dataset):   
    def __init__(self, seqs, labels):
        self.x = seqs
        self.y = labels
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, index):
        return self.x[index], self.y[index]

dataset = CustomDataset(seqs, labels)
print(len(dataset))

2085


In [14]:
def collate_fn(data):
    """
    TODO: Collate the the list of samples into batches. For each patient, you need to pad the diagnosis
        sequences to the sample shape (max # visits, max # diagnosis codes). The padding infomation
        is stored in `mask`.
    
    Arguments:
        data: a list of samples fetched from `CustomDataset`
        
    Outputs:
        x: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.long
        masks: a tensor of shape (# patiens, max # visits, max # diagnosis codes) of type torch.bool
        rev_x: same as x but in reversed time. This will be used in our RNN model for masking 
        rev_masks: same as mask but in reversed time. This will be used in our RNN model for masking
        y: a tensor of shape (# patiens) of type torch.float
        
    Note that you can obtains the list of diagnosis codes and the list of hf labels
        using: `sequences, labels = zip(*data)`
    """

    sequences, labels = zip(*data)


    num_patients = len(sequences)
    num_visits = [len(patient) for patient in sequences]
    num_codes = [len(visit) for patient in sequences for visit in patient]

    max_num_visits = max(num_visits)
    max_num_codes = max(num_codes)
    
    x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    rev_x = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.long)
    masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    rev_masks = torch.zeros((num_patients, max_num_visits, max_num_codes), dtype=torch.bool)
    for i_patient, patient in enumerate(sequences):
        for j_visit, visit in enumerate(patient):
            x[i_patient][j_visit][range(len(sequences[i_patient][j_visit]))] = torch.LongTensor(sequences[i_patient][j_visit])
            masks[i_patient][j_visit][range(len(sequences[i_patient][j_visit]))] = True
            rev_x[i_patient][num_visits[i_patient] - 1 - j_visit][range(len(sequences[i_patient][j_visit]))] = torch.LongTensor(sequences[i_patient][j_visit])
            rev_masks[i_patient][num_visits[i_patient] - 1 - j_visit][range(len(sequences[i_patient][j_visit]))] = True

                        
    y = torch.zeros((num_patients, max_num_visits, n_ccs_code), dtype = torch.float) # n_leaf_codes), dtype=torch.float)
    for i_patient, visits in enumerate(labels):
        for i_visit, codes in enumerate(visits):
            for code in codes:
                y[i_patient][i_visit][code] = 1
            
    return x, masks, rev_x, rev_masks, y

In [15]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
loader_iter = iter(loader)
x, masks, rev_x, rev_masks, y = next(loader_iter)

In [16]:
from torch.utils.data.dataset import random_split
split = int(len(dataset)*0.8)
lengths = [split, len(dataset) - split]
train_dataset, val_dataset = random_split(dataset, lengths)

print("Length of train dataset:", len(train_dataset))
print("Length of val dataset:", len(val_dataset))

Length of train dataset: 1668
Length of val dataset: 417


In [17]:
from torch.utils.data import DataLoader

def load_data(train_dataset, val_dataset, collate_fn):
    
    '''
    TODO: Implement this function to return the data loader for  train and validation dataset. 
    Set batchsize to 32. Set `shuffle=True` only for train dataloader.
    
    Arguments:
        train dataset: train dataset of type `CustomDataset`
        val dataset: validation dataset of type `CustomDataset`
        collate_fn: collate function
        
    Outputs:
        train_loader, val_loader: train and validation dataloaders
    
    Note that you need to pass the collate function to the data loader `collate_fn()`.
    '''
    
    batch_size = 2000
    train_loader = DataLoader(train_dataset, batch_size = batch_size, collate_fn = collate_fn, shuffle = True)
    val_loader = DataLoader(val_dataset, batch_size = batch_size, collate_fn = collate_fn, shuffle = False)
        
    return train_loader, val_loader


train_loader, val_loader = load_data(train_dataset, val_dataset, collate_fn)

In [18]:
def sum_embeddings_with_mask(x, masks):
    """
    TODO: mask select the embeddings for true visits (not padding visits) and then
        sum the embeddings for each visit up.

    Arguments:
        x: the embeddings of diagnosis sequence of shape (batch_size, # visits, # diagnosis codes, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        sum_embeddings: the sum of embeddings of shape (batch_size, # visits, embedding_dim)
        
    NOTE: Do NOT use for loop.

    """
    masked_x = torch.mul(x, masks[:,:,:,None])
    return torch.sum(masked_x, 2, keepdim=False)

In [19]:
def get_last_visit(hidden_states, masks):
    """
    TODO: obtain the hidden state for the last true visit (not padding visits)

    Arguments:
        hidden_states: the hidden states of each visit of shape (batch_size, # visits, embedding_dim)
        masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

    Outputs:
        last_hidden_state: the hidden state for the last true visit of shape (batch_size, embedding_dim)
        
    NOTE: DO NOT use for loop.
    
    HINT: First convert the mask to a vector of shape (batch_size,) containing the true visit length; 
          and then use this length vector as index to select the last visit.
    """
    
    # your code here
    last_visit = torch.sum((torch.sum(masks, 2, keepdim = False) > 0).int(), 1)
    return hidden_states[range(hidden_states.shape[0]),last_visit-1,:]

In [20]:
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, top_k_accuracy_score

def eval_model(model, val_loader):
    
    """
    TODO: evaluate the model.
    
    Arguments:
        model: the RNN model
        val_loader: validation dataloader
        
    Outputs:
        precision: overall precision score
        recall: overall recall score
        f1: overall f1 score
        roc_auc: overall roc_auc score
        
    Note that please pass all four arguments to the model so that we can use this function for both 
    models. (Use `model(x, masks, rev_x, rev_masks)`.)
        
    HINT: checkout https://scikit-learn.org/stable/modules/classes.html#module-sklearn.metrics
    """
    k = 20
    model.eval()
    y_pred = torch.LongTensor()
    y_score = torch.Tensor()
    y_true = torch.LongTensor()
    model.eval()
    found_all = 0
    actual_all = 0 
    samples = 0
    n_freq_band = 5
    top_k_tp = torch.zeros(n_freq_band, dtype=torch.float)
    top_k_fn = torch.zeros(n_freq_band, dtype=torch.float)
    top_k_fp = torch.zeros(n_freq_band, dtype=torch.float)
    top_k_tn = torch.zeros(n_freq_band, dtype=torch.float)
    for x, masks, rev_x, rev_masks, y in val_loader:
        y_hat = model(x, masks, rev_x, rev_masks).detach()
        for i in range(len(y_hat)):
            for m in range(len(y_hat[0])):
                top_k = np.argpartition(y_hat[i][m].detach(), -k)[-k:]
                for j in range(n_ccs_code): # n_leaf_codes):
                    if(y[i][m][j] > 0.5):
                        # code = all_codes_ordered[j]
                        # print("actual code", code)
                        if j in top_k:
                            # print(code, "in top ", code_freq_dict[code], "percentile")
                            # top_k_tp[code_freq_dict[code]] += 1
                            top_k_tp[ccs_code_freq_dict[ccs_codes_list[j]]] += 1
                        else:
                            # top_k_fn[code_freq_dict[code]] += 1
                            top_k_fn[ccs_code_freq_dict[ccs_codes_list[j]]] += 1
    print(top_k_tp, top_k_fn)
    return (top_k_tp/(top_k_tp+top_k_fn))

In [21]:
import time

def train(model, train_loader, val_loader, n_epochs):
    """
    TODO: train the model.
    
    Arguments:
        model: the RNN model
        train_loader: training dataloder
        val_loader: validation dataloader
        n_epochs: total number of epochs
        
    You need to call `eval_model()` at the end of each training epoch to see how well the model performs 
    on validation data.
        
    Note that please pass all four arguments to the model so that we can use this function for both 
    models. (Use `model(x, masks, rev_x, rev_masks)`.)
    """
    
    for epoch in range(n_epochs):
        start_time = time.time()
        model.train()
        train_loss = 0
        for x, masks, rev_x, rev_masks, y in train_loader:
            optimizer.zero_grad()
            y_hat = model(x, masks, rev_x, rev_masks)
#             print("y_hat size", y_hat.size())
#             print("y size", y.size())
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        train_loss = train_loss / len(train_loader)
        print('Epoch: {} \t Training Loss: {:.6f}'.format(epoch+1, train_loss))
       # make_dot(y_hat, params=dict(model.named_parameters()))
        end_time = time.time()
        print("time in this epoch", end_time-start_time)
        top_k_acc = eval_model(model, val_loader)
        print("top k acc ", top_k_acc)

# RNN Model

In [22]:
class NaiveRNN(nn.Module):
    
    """
    TODO: implement the naive RNN model above.
    """
    
    def __init__(self, num_codes):
        super().__init__()
        embDimSize = 128
        hidden_size = 128
        self.embedding = nn.Embedding(num_codes, embDimSize)
        self.rnn = nn.GRU(input_size = embDimSize, hidden_size = hidden_size, batch_first = True)
        self.rev_rnn = nn.GRU(input_size = embDimSize, hidden_size = hidden_size, batch_first = True)
        self.fc = nn.Linear(hidden_size * 2, n_ccs_code) # n_diag_codes)
        self.sigmoid = nn.Sigmoid()
        
    
    def forward(self, x, masks, rev_x, rev_masks):
        """
        Arguments:
            x: the diagnosis sequence of shape (batch_size, # visits, # diagnosis codes)
            masks: the padding masks of shape (batch_size, # visits, # diagnosis codes)

        Outputs:
            probs: probabilities of shape (batch_size, # diagnosis codes)
        """
        
        batch_size = x.shape[0]
        
        x = self.embedding(x)
        x = sum_embeddings_with_mask(x, masks)
        output, _ = self.rnn(x)
        true_h_n = output # get_last_visit(output, masks)
        
        """
        TODO:
            5. Do the step 1-4 again for the reverse order (rev_x), and concatenate the hidden
               states for both directions;
        """
    
        rev_x = self.embedding(rev_x)
        # 5b. Sum the embeddings for each diagnosis code up for a visit of a patient.
        rev_x = sum_embeddings_with_mask(rev_x, rev_masks)
        # 5c. Pass the embegginds through the RNN layer;
        rev_output, _ = self.rev_rnn(rev_x)
        # 5d. Obtain the hidden state at the last visit.
        true_h_n_rev = rev_output # get_last_visit(rev_output, rev_masks)        
        # 6. Pass the hidden state through the linear and activation layers.
        logits = self.fc(torch.cat([true_h_n, true_h_n_rev], 2))
        probs = self.sigmoid(logits)
        return probs
    

# load the model here
naive_rnn = NaiveRNN(num_codes = n_leaf_codes)
naive_rnn

NaiveRNN(
  (embedding): Embedding(2987, 128)
  (rnn): GRU(128, 128, batch_first=True)
  (rev_rnn): GRU(128, 128, batch_first=True)
  (fc): Linear(in_features=256, out_features=259, bias=True)
  (sigmoid): Sigmoid()
)

In [23]:
summary(naive_rnn)

Layer (type:depth-idx)                   Param #
NaiveRNN                                 --
├─Embedding: 1-1                         382,336
├─GRU: 1-2                               99,072
├─GRU: 1-3                               99,072
├─Linear: 1-4                            66,563
├─Sigmoid: 1-5                           --
Total params: 647,043
Trainable params: 647,043
Non-trainable params: 0

In [29]:
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(naive_rnn.parameters(), lr = 0.01)
# number of epochs to train the model
n_epochs = 10
train(naive_rnn, train_loader, val_loader, n_epochs)

Epoch: 1 	 Training Loss: 0.034411
time in this epoch 1.9839236736297607
tensor([0.0000e+00, 0.0000e+00, 1.0000e+00, 1.6000e+01, 1.1030e+03]) tensor([  17.,   67.,  183.,  452., 1115.])
top k acc  tensor([0.0000, 0.0000, 0.0054, 0.0342, 0.4973])
Epoch: 2 	 Training Loss: 0.028249
time in this epoch 1.8102672100067139
tensor([   0.,    0.,    0.,   13., 1142.]) tensor([  17.,   67.,  184.,  455., 1076.])
top k acc  tensor([0.0000, 0.0000, 0.0000, 0.0278, 0.5149])
Epoch: 3 	 Training Loss: 0.024460
time in this epoch 1.9080932140350342
tensor([   0.,    0.,    0.,    5., 1169.]) tensor([  17.,   67.,  184.,  463., 1049.])
top k acc  tensor([0.0000, 0.0000, 0.0000, 0.0107, 0.5271])
Epoch: 4 	 Training Loss: 0.023011
time in this epoch 1.8099751472473145
tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 1.2620e+03]) tensor([ 17.,  67., 184., 467., 956.])
top k acc  tensor([0.0000, 0.0000, 0.0000, 0.0021, 0.5690])
Epoch: 5 	 Training Loss: 0.023128
time in this epoch 1.974869489669799

# GRAM Model

In [39]:
class GRAM(torch.nn.Module):
    def __init__(self, n_leaf_codes, n_all_codes, n_diag_codes, n_emb=128, l=64):
        super(GRAM, self).__init__()
        self.n_emb = n_emb # size of code embedding, m in the paper
        self.l = l # size of hidden layer
        self.n_ancesters = 6
        self.n_hidden = 128
        self.n_leaf_codes = n_leaf_codes
        self.n_all_codes = n_all_codes
        #
        # initialzie embedding matrix to random.  Future work: use Glove to train embedding.
        #
        self.E = nn.Parameter(torch.rand(self.n_all_codes, self.n_emb))

        embDimSize = self.n_emb
        hidden_size = self.n_hidden
       
        self.fc1 = nn.Linear(2*self.n_emb, self.l)  # Wa, b_a
        self.fc2 = nn.Linear(self.l, 1, bias = False)   # ua
        
        self.rnn = nn.GRU(input_size = embDimSize, hidden_size = hidden_size, batch_first = True)
        self.rev_rnn = nn.GRU(input_size = embDimSize, hidden_size = hidden_size, batch_first = True)
        self.fc = nn.Linear(hidden_size * 2, n_ccs_code) # n_diag_codes)  # W, b
        self.sigmoid = nn.Sigmoid()
    
    def update_G(self):
        a1 = torch.zeros((self.n_leaf_codes, self.n_all_codes))
        for child in range(self.n_leaf_codes):
            if child % 400 == 0:
                print("processing leaf node", child)
            for parent in range(self.n_all_codes):
                if att_mask_1[child][parent]:
                    eij = torch.cat([self.E[child], self.E[parent]])
                    a1[child][parent] = self.fc2(torch.tanh(self.fc1(eij)))        
        a2 = torch.mul(nn.Softmax(dim=1)(a1), att_mask_1)
        a = (a2/torch.sum(a2, dim=1, keepdim=True))
        G = torch.matmul(a, self.E)  # G shape: n_leaf_node x n_emb
        return G
    
    def embedding(self, x, G):
        v = torch.tanh(torch.matmul(x.float(), G))
        return v
                                                         
    def forward(self, x, masks, rev_x, rev_masks):
        batch_size = x.shape[0]
        G = self.update_G()
        
        x = nn.functional.one_hot(x, num_classes = n_diag_codes).sum(dim=2)
        x_emb = self.embedding(x, G)
        output, _ = self.rnn(x_emb)
        true_h_n = output # get_last_visit(output, masks)
        
        rev_x = nn.functional.one_hot(rev_x, num_classes = n_diag_codes).sum(dim=2)
        rev_x_emb = self.embedding(rev_x, G)
        rev_output, _ = self.rev_rnn(rev_x_emb)
        true_h_n_rev = rev_output # get_last_visit(rev_output, rev_masks)
        
        logits = self.fc(torch.cat([true_h_n, true_h_n_rev], 2))
        probs = self.sigmoid(logits)
        return probs 

In [40]:
print(n_leaf_codes, n_all_codes, n_diag_codes)
gram = GRAM(n_leaf_codes, n_all_codes, n_diag_codes)
print(gram)
summary(gram)

2987 4543 2987
GRAM(
  (fc1): Linear(in_features=256, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=1, bias=False)
  (rnn): GRU(128, 128, batch_first=True)
  (rev_rnn): GRU(128, 128, batch_first=True)
  (fc): Linear(in_features=256, out_features=259, bias=True)
  (sigmoid): Sigmoid()
)


Layer (type:depth-idx)                   Param #
GRAM                                     581,504
├─Linear: 1-1                            16,448
├─Linear: 1-2                            64
├─GRU: 1-3                               99,072
├─GRU: 1-4                               99,072
├─Linear: 1-5                            66,563
├─Sigmoid: 1-6                           --
Total params: 862,723
Trainable params: 862,723
Non-trainable params: 0

In [41]:
torch.autograd.set_detect_anomaly(False)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(gram.parameters(), lr = 0.01)
n_epochs = 10
train(gram, train_loader, val_loader, n_epochs)

processing leaf node 0
processing leaf node 400
processing leaf node 800
processing leaf node 1200
processing leaf node 1600
processing leaf node 2000
processing leaf node 2400
processing leaf node 2800
Epoch: 1 	 Training Loss: 0.694918
time in this epoch 554.800961971283
processing leaf node 0
processing leaf node 400
processing leaf node 800
processing leaf node 1200
processing leaf node 1600
processing leaf node 2000
processing leaf node 2400
processing leaf node 2800
tensor([  1.,   7.,   3.,  62., 271.]) tensor([  16.,   60.,  181.,  406., 1947.])
top k acc  tensor([0.0588, 0.1045, 0.0163, 0.1325, 0.1222])
processing leaf node 0
processing leaf node 400
processing leaf node 800
processing leaf node 1200
processing leaf node 1600
processing leaf node 2000
processing leaf node 2400
processing leaf node 2800
Epoch: 2 	 Training Loss: 0.366427
time in this epoch 574.2729349136353
processing leaf node 0
processing leaf node 400
processing leaf node 800
processing leaf node 1200
proces

In [None]:
gram.E[0,]

# Ablation Test

In [None]:
pure_ancester_codes = list(all_codes - leaf_ancester_codes)
n_pure_ancester_code = len(pure_ancester_codes)
                           
def random_code():
    indexes = random.sample(range(n_leaf_codes, n_all_codes), n_ancesters-1)
    return indexes

att_mask_1 = torch.zeros((n_leaf_codes, n_all_codes))

for child in range(n_leaf_codes):
    ancesters = random_code()
    att_mask_1[child, child] = 1
    for ancester in ancesters:
        att_mask_1[child,ancester] = 1

In [None]:
att_mask_1[0].sum()

In [None]:
gram_ablation = GRAM(n_leaf_codes, n_all_codes, n_diag_codes)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(gram_ablation.parameters(), lr = 0.01)
n_epochs = 10
train(gram_ablation, train_loader, val_loader, n_epochs)

In [None]:
gram_ablation.E

In [None]:
gram.E

In [None]:
g