In [1]:
import os
import torch
import pickle
import numpy as np
from tqdm import tqdm
from time import time
from common import *
from model import HALOModel
from config import HALOConfig

# Circuit imports
import sys
sys.path.append(os.path.join(sys.path[0],'hmc-utils'))
sys.path.append(os.path.join(sys.path[0],'hmc-utils', 'pypsdd'))
from GatingFunction import DenseGatingFunction
from compute_mpe import CircuitMPE

RUNS = 1  # number of independent generation runs

config = HALOConfig()
NUM_GENERATIONS = 10000  # synthetic patients generated per run
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Create circuit object
cmpe = CircuitMPE('constraints/inpatient.vtree', 'constraints/inpatient.sdd')

# Create gating function
gate = DenseGatingFunction(cmpe.beta, gate_layers=[config.n_embd] + [256]*config.num_gates, num_reps=config.num_reps).to(device)

# Create the model
model = HALOModel(config).to(device)

# map_location keeps the load working on the CPU fallback selected above:
# without it, a checkpoint saved from a CUDA process fails to deserialize on a
# machine that has no GPU.
# NOTE(review): neither model.eval() nor gate.eval() is called before
# sampling below -- confirm whether train-mode layers (e.g. dropout) matter here.
state = torch.load('../../save/spl_model_old', map_location=device)
model.load_state_dict(state['model'])
gate.load_state_dict(state['gate'])

[('linear1', Linear(in_features=768, out_features=256, bias=True)), ('relu1', ReLU())]


  torch.nn.init.xavier_uniform(m.weight)


<All keys matched successfully>

In [3]:
def sample_sequence(model, length, context, batch_size, device='cuda', sample=True):
  """Autoregressively generate a batch of sequences, one step at a time.

  Every sequence starts from the same `context` vector. At each step
  `model.sample` is called on the whole sequence generated so far, the
  notebook-level `gate` maps its output to circuit parameters, and the
  notebook-level `cmpe` circuit returns one instantiation per batch row. That
  instantiation is thresholded at 0 into a 0/1 float vector and appended.

  Generation stops after `length - 1` appended steps, or earlier once every
  sequence has a nonzero entry at index
  `config.code_vocab_size + config.label_vocab_size + 1` in some step.

  Args:
    model: object exposing `sample(prev)`.
    length: maximum number of steps, counting the initial context step.
    context: 1-D array-like used as the first step of every sequence.
    batch_size: number of sequences generated in parallel.
    device: torch device on which the sequence tensor is created.
    sample: unused; kept so existing callers keep working.

  Returns:
    numpy array of shape (batch_size, steps, len(context)), steps <= length.
  """
  first_step = torch.tensor(context, device=device, dtype=torch.float32)
  sequence = first_step.unsqueeze(0).repeat(batch_size, 1).unsqueeze(1)
  end_index = config.code_vocab_size + config.label_vocab_size + 1
  with torch.no_grad():
    for _ in range(length - 1):
      # `proposal` replaces the original local named `next`, which shadowed
      # the builtin of the same name.
      proposal = model.sample(sequence)
      cmpe.set_params(gate(proposal))
      instantiation = cmpe.get_mpe_inst(proposal.size(0))
      step = (instantiation > 0).float()
      sequence = torch.cat((sequence, step.unsqueeze(1)), dim=1)

      # Stop once every row has a nonzero total at the end index.
      if sequence[:, :, end_index].sum(dim=1).bool().all().item():
        break
  return sequence.cpu().detach().numpy()

def convert_ehr(ehrs, index_to_code=None):
  """Turn generated multi-hot sequences into per-patient visit/label records.

  For each sequence, row 1 supplies the labels (the slice between
  `config.code_vocab_size` and `config.code_vocab_size + config.label_vocab_size`)
  and rows 2 onward are visits. Within a visit only nonzero positions below
  `config.code_vocab_size` are kept as codes. Visits with no codes are
  dropped, and reading stops after the first visit whose position
  `config.code_vocab_size + config.label_vocab_size + 1` is nonzero.

  Args:
    ehrs: sequence of 2-D arrays indexed as [step][vocab position].
    index_to_code: optional mapping from vocab position to a code name. When
      given, visit codes and labels are translated through it.

  Returns:
    List of {'visits': [[code, ...], ...], 'labels': ...}. 'labels' is the raw
    label slice of row 1 when `index_to_code` is None, otherwise the list of
    translated names of the nonzero label positions.
  """
  n_codes = config.code_vocab_size
  label_stop = n_codes + config.label_vocab_size
  end_position = label_stop + 1
  translate = index_to_code is not None

  patients = []
  for record in ehrs:
    labels = record[1][n_codes:label_stop]
    if translate:
      labels = [index_to_code[pos + n_codes] for pos in np.nonzero(labels)[0]]

    visits = []
    for step in record[2:]:
      active = np.nonzero(step)[0]
      codes = [index_to_code[pos] if translate else pos
               for pos in active if pos < n_codes]
      if codes:
        visits.append(codes)
      # The end marker closes the record, whether or not this visit had codes.
      if end_position in active:
        break
    patients.append({'visits': visits, 'labels': labels})
  return patients

# Generate Synthetic EHR dataset
speeds = []
# Start token: a vector with a single 1 at the position right after the
# code and label ranges.
stoken = np.zeros(config.total_vocab_size)
stoken[config.code_vocab_size+config.label_vocab_size] = 1
# Make sure the output directories exist before the (long) generation runs so
# the results are not lost to a missing-directory error at dump time.
os.makedirs('../../results/generationSpeeds', exist_ok=True)
for run in tqdm(range(RUNS)):
  # Seed every RNG with the run index so each run is reproducible.
  # NOTE(review): `random` is not imported in the visible cells; presumably it
  # arrives through `from common import *` -- confirm.
  SEED = run
  random.seed(SEED)
  np.random.seed(SEED)
  torch.manual_seed(SEED)
  if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

  synthetic_ehr_dataset = []
  start = time()
  for i in tqdm(range(0, NUM_GENERATIONS, config.sample_batch_size), leave=False):
    # The last batch may be smaller than config.sample_batch_size.
    bs = min(NUM_GENERATIONS - i, config.sample_batch_size)
    batch_synthetic_ehrs = sample_sequence(model, config.n_ctx, stoken, batch_size=bs, device=device, sample=True)
    batch_synthetic_ehrs = convert_ehr(batch_synthetic_ehrs)
    synthetic_ehr_dataset += batch_synthetic_ehrs
  end = time()

  generationTime = end - start
  secondsPerPatient = generationTime / NUM_GENERATIONS
  speeds.append(secondsPerPatient)
  # Context managers close the files deterministically; the original passed
  # bare open(...) handles to pickle.dump and never closed them.
  with open(f'../../results/generationSpeeds/splSpeed_{run}.pkl', 'wb') as speed_file:
    pickle.dump(secondsPerPatient, speed_file)
  with open(f'../../results/splDataset_{run}.pkl', 'wb') as dataset_file:
    pickle.dump(synthetic_ehr_dataset, dataset_file)
# Mean seconds per patient with a 1.96 * standard-error half-width over runs.
print(f"Seconds Per Patient: {np.mean(speeds)} +/- {np.std(speeds) / np.sqrt(RUNS) * 1.96}")

In [2]:
# Debugging scratch: rebuild the start-token vector and batch size by hand so
# a single sampling step can be run outside sample_sequence.
stoken = np.zeros(config.total_vocab_size)
stoken[config.code_vocab_size+config.label_vocab_size] = 1
# NOTE(review): this rebinds synthetic_ehr_dataset to an empty list. If this
# cell is run after the generation loop, later cells that read
# synthetic_ehr_dataset see no patients.
synthetic_ehr_dataset = []
bs = config.sample_batch_size
# batch_synthetic_ehrs = sample_sequence(model, config.n_ctx, stoken, batch_size=bs, device=device, sample=True)
# batch_synthetic_ehrs = convert_ehr(batch_synthetic_ehrs)
# synthetic_ehr_dataset += batch_synthetic_ehrs

In [3]:
# Debugging scratch: reproduce the first lines of sample_sequence at notebook
# scope, using `bs` and `stoken` from an earlier cell.
batch_size = bs
context = stoken
length = config.n_ctx
# Repeat the start token for every row of the batch, then add a step axis so
# `prev` is indexed as [row, step, vocab position].
context = torch.tensor(context, device=device, dtype=torch.float32).unsqueeze(0).repeat(batch_size, 1)
prev = context.unsqueeze(1)
context = None

In [None]:

# Debugging scratch: one iteration of the sample_sequence loop at notebook
# scope. Re-running this cell appends another step to `prev` each time.
# NOTE(review): `next` shadows the Python builtin for the rest of the session.
with torch.no_grad():
    next = model.sample(prev)
    theta = gate(next)
    cmpe.set_params(theta)
    res = cmpe.get_mpe_inst(next.size(0))
    # Threshold the circuit output at 0 to get a 0/1 float vector per row.
    next = (res > 0).float()
    prev = torch.cat((prev, next.unsqueeze(1)), dim=1)

    # True once every row has a nonzero total at the index used as the stop
    # condition in sample_sequence.
    print(torch.sum(torch.sum(prev[:,:,config.code_vocab_size+config.label_vocab_size+1], dim=1).bool().int(), dim=0).item() == batch_size)


In [8]:
# Debugging scratch: run the model, gate and circuit once without appending to
# `prev`, so the raw circuit output `res` can be inspected in later cells.
with torch.no_grad():
    next = model.sample(prev)
    theta = gate(next)
    cmpe.set_params(theta)
    # Request one instantiation per batch row, as every other call site does,
    # instead of a hard-coded 32 that only matches one batch size.
    res = cmpe.get_mpe_inst(next.size(0))
    

In [27]:
# Spot-check one entry (row 0, position 100) of the circuit output after a sigmoid.
res.sigmoid()[0][100]

tensor(0., device='cuda:0')

In [14]:
# Threshold the raw circuit output at 0, as sample_sequence does.
# NOTE(review): `next` shadows the Python builtin at notebook scope.
next = (res > 0).float()

In [15]:
# List the [row, position] pairs that are set in the thresholded output.
next.nonzero()

tensor([[   0, 1611],
        [   0, 1614],
        [   1, 1611],
        [   1, 1614],
        [   2, 1611],
        [   2, 1614],
        [   3, 1611],
        [   3, 1614],
        [   4, 1611],
        [   4, 1614],
        [   5, 1611],
        [   5, 1614],
        [   6, 1611],
        [   6, 1614],
        [   7, 1611],
        [   7, 1614],
        [   8, 1611],
        [   8, 1614],
        [   9, 1611],
        [   9, 1614],
        [  10, 1611],
        [  10, 1614],
        [  11, 1611],
        [  11, 1614],
        [  12, 1611],
        [  12, 1614],
        [  13, 1611],
        [  13, 1614],
        [  14, 1611],
        [  14, 1614],
        [  15, 1611],
        [  15, 1614],
        [  16, 1611],
        [  16, 1614],
        [  17, 1611],
        [  17, 1614],
        [  18, 1611],
        [  18, 1614],
        [  19, 1611],
        [  19, 1614],
        [  20, 1611],
        [  20, 1614],
        [  21, 1611],
        [  21, 1614],
        [  22, 1611],
        [ 

In [None]:
# Move the hand-built sequence tensor to host memory as a numpy array,
# mirroring the last step of sample_sequence.
ehr = prev.cpu().detach().numpy()

In [9]:
def evaluateDataset(dataset, rules):
  """Count, per rule, the visits at which the rule fires but its output is wrong.

  Each rule is a 7-tuple (past_visits, past_pos_codes, past_neg_codes,
  curr_pos_codes, curr_neg_codes, output_code, output_value). A rule fires at
  a visit when both its past condition and its current-visit condition hold.
  It is violated there when the presence of output_code in that visit
  disagrees with the truthiness of output_value.

  Every patient is expanded to the timeline [[], label codes] + p['visits'],
  where the label codes are the nonzero positions of p['labels'] shifted by
  config.code_vocab_size. The rule is checked at every timeline position.

  Args:
    dataset: iterable of dicts with 'labels' (needs a numpy-style .nonzero())
      and 'visits' (a list of code lists).
    rules: iterable of rule tuples as described above.

  Returns:
    {'Per Rule': [violations per rule, in rule order],
     'Total Number': sum of those counts}.
  """
  violationsPerRule = []
  for rule in tqdm(rules, leave=False):
    (past_visits, past_pos_codes, past_neg_codes,
     curr_pos_codes, curr_neg_codes, output_code, output_value) = rule
    expect_present = bool(output_value)
    count = 0

    for patient in tqdm(dataset, leave=False):
      label_codes = [l + config.code_vocab_size for l in patient['labels'].nonzero()[0]]
      timeline = [[], label_codes] + patient['visits']

      for pos, current in enumerate(timeline):
        if past_visits:
          if past_visits == -1:
            # NOTE(review): this slices patient['visits'], not the timeline,
            # while `pos` indexes the timeline (two extra leading entries), so
            # the slice can include the current and a later visit. Looks
            # unintended -- confirm before changing.
            seen = {code for earlier in patient['visits'][:pos] for code in earlier}
          else:
            seen = set()
            for ref in past_visits:
              if ref >= 0:
                # Absolute index: only counts once we are past it.
                if pos > ref:
                  seen.update(timeline[ref])
              elif pos + ref >= 0:
                # Negative index: relative to the current position.
                seen.update(timeline[pos + ref])
          past_ok = (all(code in seen for code in past_pos_codes)
                     and not any(code in seen for code in past_neg_codes))
        else:
          # A rule with no past_visits has no past condition.
          past_ok = True

        curr_ok = (all(code in current for code in curr_pos_codes)
                   and not any(code in current for code in curr_neg_codes))

        if past_ok and curr_ok and (output_code in current) != expect_present:
          count += 1

    violationsPerRule.append(count)

  return {'Per Rule': violationsPerRule, 'Total Number': sum(violationsPerRule)}

In [10]:
# Count rule violations in whatever synthetic_ehr_dataset currently holds
# (the dataset from the last generation run, unless a later cell rebound it).
violations = evaluateDataset(synthetic_ehr_dataset, config.rules)

                                               

In [11]:
# Seconds per generated patient for the most recent run; divides by the
# configured patient count instead of repeating it as a literal.
(end - start) / NUM_GENERATIONS

0.0014319023609161378

In [62]:
# Display the per-rule violation counts and their total.
violations

{'Per Rule': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  2,
  0,
  5,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  8,
  1,
  1,
  27,
  25,
  62,
  2,
  0,
  10,
  10,
  0,
  15,
  0,
  2,
  137,
  75,
  21,
  2,
  4,
  3,
  3,
  0,
  0,
  0,
  16,
  1,
  5,
  47,
  0,
  0,
  167,
  127,
  131,
  0,
  0,
  1,
  114,
  0,
  8,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 'Total Number': 1032}

In [63]:
# Display the rule tuples, in the same order as violations['Per Rule'].
config.rules

[([1], [1610], [], [], [], 972, 0),
 ([1], [1610], [], [], [], 89, 0),
 ([1], [1610], [], [], [], 92, 0),
 ([1], [1610], [], [], [], 93, 0),
 ([1], [1610], [], [], [], 94, 0),
 ([1], [1610], [], [], [], 95, 0),
 ([1], [1610], [], [], [], 162, 0),
 ([1], [1610], [], [], [], 163, 0),
 ([1], [1610], [], [], [], 208, 0),
 ([1], [1610], [], [], [], 230, 0),
 ([1], [1610], [], [], [], 294, 0),
 ([1], [1611], [], [], [], 89, 0),
 ([1], [1611], [], [], [], 90, 0),
 ([1], [1611], [], [], [], 92, 0),
 ([1], [1611], [], [], [], 93, 0),
 ([1], [1611], [], [], [], 94, 0),
 ([1], [1611], [], [], [], 95, 0),
 ([1], [1611], [], [], [], 162, 0),
 ([1], [1611], [], [], [], 163, 0),
 ([1], [1611], [], [], [], 208, 0),
 ([1], [1611], [], [], [], 230, 0),
 ([1], [1611], [], [], [], 294, 0),
 ([1], [1612], [], [], [], 972, 0),
 ([1], [1612], [], [], [], 89, 0),
 ([1], [1612], [], [], [], 92, 0),
 ([1], [1612], [], [], [], 93, 0),
 ([1], [1612], [], [], [], 94, 0),
 ([1], [1612], [], [], [], 95, 0),
 ([1], [

In [13]:
# Inspect the violation counts of the last 11 rules only.
violations['Per Rule'][-11:]

[0, 0, 0, 0, 0, 10000, 0, 0, 0, 0, 0]

In [None]:
def evaluate_circuit(model, gate, cmpe, epoch, data_loader, data_split, prefix):
    """Evaluate the model + gate + circuit pipeline over a data loader.

    For every (x, y) batch the model output is passed through `gate` to
    parameterize `cmpe`. The circuit's cross entropy against y is recorded, and
    its instantiation (thresholded at 0) is used as the prediction. Exact-match
    counts use all columns; Jaccard and Hamming use only the columns selected
    by `data_split.to_eval`.

    Fixes relative to the previous version:
      * nll is now the mean of the per-batch nll values. Previously only the
        last batch's nll was kept and then divided by the number of batches.
      * an empty `data_loader` raises a clear ValueError instead of failing on
        an unbound local.
      * `perf_counter` is imported locally, since it is not imported in the
        visible notebook cells.
      * `model.eval()` / `gate.eval()` are called once rather than per batch.

    NOTE(review): `device`, `jaccard_score` and `hamming_loss` are read from
    notebook scope; the latter two are presumably sklearn.metrics functions
    brought in elsewhere -- confirm.

    Args:
        model: callable module with `.eval()`; applied to `x.float()`.
        gate: callable module with `.eval()`; maps model output to parameters.
        cmpe: circuit exposing `set_params`, `cross_entropy`, `get_mpe_inst`.
        epoch: value stored alongside each metric in the result.
        data_loader: iterable of (x, y) tensor batches.
        data_split: object whose `to_eval` indexes the label columns to score.
        prefix: metric-name prefix, also used in the printed header.

    Returns:
        Dict mapping f"{prefix}/accuracy", "/hamming", "/jaccard" and "/nll"
        to (value, epoch, elapsed_seconds) tuples.

    Raises:
        ValueError: if `data_loader` yields no batches.
    """
    from time import perf_counter

    test_val_t = perf_counter()

    model.eval()
    gate.eval()

    num_batches = 0
    nll_sum = 0.0
    test_correct = 0
    predicted_batches = []
    target_batches = []

    for x, y in data_loader:
        x = x.to(device)
        y = y.to(device)

        # Parameterize circuit using nn
        emb = model(x.float())
        thetas = gate(emb)

        # negative log likelihood and map
        cmpe.set_params(thetas)
        nll_sum = nll_sum + cmpe.cross_entropy(y, log_space=True).mean().detach()

        # Parameters are set again before the MPE query, as in the original.
        cmpe.set_params(thetas)
        pred_y = (cmpe.get_mpe_inst(x.shape[0]) > 0).long()

        pred_y = pred_y.to('cpu')
        y = y.to('cpu')

        # A row counts as correct only when every label matches.
        test_correct = test_correct + (pred_y == y.byte()).all(dim=-1).sum()
        predicted_batches.append(pred_y)
        target_batches.append(y)
        num_batches += 1

    if num_batches == 0:
        raise ValueError("evaluate_circuit: data_loader yielded no batches")

    dt = perf_counter() - test_val_t
    y_test = torch.cat(target_batches, dim=0)[:,data_split.to_eval]
    predicted_test = torch.cat(predicted_batches, dim=0)[:,data_split.to_eval]

    accuracy = test_correct / len(y_test)
    # Mean of per-batch means; a smaller final batch is weighted like the others.
    nll = (nll_sum / num_batches).to("cpu").numpy()
    jaccard = jaccard_score(y_test, predicted_test, average='micro')
    hamming = hamming_loss(y_test, predicted_test)

    print(f"Evaluation metrics on {prefix} \t {dt:.4f}")
    print(f"Num. correct: {test_correct}")
    print(f"Accuracy: {accuracy}")
    print(f"Hamming Loss: {hamming}")
    print(f"Jaccard Score: {jaccard}")
    print(f"nll: {nll}")


    return {
        f"{prefix}/accuracy": (accuracy, epoch, dt),
        f"{prefix}/hamming": (hamming, epoch, dt),
        f"{prefix}/jaccard": (jaccard, epoch, dt),
        f"{prefix}/nll": (nll, epoch, dt),
    }