# Prerequisites

In [None]:
import os
import torch

# Export the installed torch version so the shell commands below can expand
# ${TORCH} when selecting the matching wheel index.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
# Notebook shell escapes: install the third-party packages imported further down
# (tdc, torcheval, hyppo, torch_geometric and its compiled extensions).
!pip install PyTDC
!pip install torcheval
!pip install hyppo
# The -f index URL is parameterised by the torch version exported above.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.ticker as ticker
import numpy as np
import networkx as nx
import seaborn as sns

import pandas as pd

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem, MACCSkeys

from hyppo.independence import Dcorr, MGC

from scipy.spatial.distance import pdist, squareform
#from scipy.stats import multiscale_graphcorr as MGC

from time import time

from google.colab import files

from torch import nn
from torch.nn import Linear, ReLU
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import BatchNorm, GCN
from torch_geometric.nn import global_mean_pool
import torcheval
from torcheval.metrics.functional import binary_auroc

from copy import deepcopy

# Use the GPU when torch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def logistic_loss(output, target):
  """Element-wise logistic loss for raw scores against 0/1 labels.

  The labels are first mapped from {0, 1} to {-1, +1}; the loss is then
  ``-logsigmoid(sign * output)``. No reduction is applied.
  """
  sign = 2*target-1  # 0 -> -1, 1 -> +1
  return -F.logsigmoid(sign*output)


# Name under which the training and evaluation code looks up the loss.
loss_function = logistic_loss

In [None]:
# Datasets HIV, HERG, CYP, TOX, BBB, HIA, all loaded through the `tdc` package.
# Each `get_split()` result is later indexed by 'train' / 'valid' / 'test'.
from tdc.single_pred import ADME, HTS, Tox
from tdc.utils import retrieve_label_name_list

data = HTS(name='HIV')
HIV = data.get_split()  # 41,127 drugs (3.5% labeled as 1)

label_list = retrieve_label_name_list('herg_central')
data = Tox(name='herg_central', label_name=label_list[2])
HERG = data.get_split()  # 306,893 drugs (4.48% labeled as 1)

data = ADME(name='CYP2D6_Veith')
CYP = data.get_split()  # 13,130 drugs (19.15% labeled as 1)

label_list = retrieve_label_name_list('Tox21')
data = Tox(name='Tox21', label_name=label_list[0])
TOX = data.get_split()  # 5086 drugs (4.253% labeled as 1)

data = ADME(name='BBB_Martins')
BBB = data.get_split()  # 2030 drugs (76.4% labeled as 1)

data = ADME(name='HIA_Hou')
HIA = data.get_split()  # 578 drugs (86.5% labeled as 1)

In [None]:
#changed to get mol
from typing import Any, Dict, List
# Vocabulary of every categorical atom feature. `from_smiles` encodes an atom
# as the position of each observed value inside the corresponding list, so
# the order of the entries (and of the keys) must not change.
x_map: Dict[str, List[Any]] = dict(
    atomic_num=list(range(119)),
    chirality=[
        'CHI_UNSPECIFIED', 'CHI_TETRAHEDRAL_CW', 'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER', 'CHI_TETRAHEDRAL', 'CHI_ALLENE', 'CHI_SQUAREPLANAR',
        'CHI_TRIGONALBIPYRAMIDAL', 'CHI_OCTAHEDRAL',
    ],
    degree=list(range(11)),
    formal_charge=list(range(-5, 7)),
    num_hs=list(range(9)),
    num_radical_electrons=list(range(5)),
    hybridization=[
        'UNSPECIFIED', 'S', 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'OTHER',
    ],
    is_aromatic=[False, True],
    is_in_ring=[False, True],
)

# Vocabulary of every categorical bond feature. `from_smiles` stores the list
# position of each observed value, so entry order is part of the encoding.
e_map: Dict[str, List[Any]] = dict(
    bond_type=[
        'UNSPECIFIED', 'SINGLE', 'DOUBLE', 'TRIPLE', 'QUADRUPLE',
        'QUINTUPLE', 'HEXTUPLE', 'ONEANDAHALF', 'TWOANDAHALF',
        'THREEANDAHALF', 'FOURANDAHALF', 'FIVEANDAHALF', 'AROMATIC', 'IONIC',
        'HYDROGEN', 'THREECENTER', 'DATIVEONE', 'DATIVE', 'DATIVEL',
        'DATIVER', 'OTHER', 'ZERO',
    ],
    stereo=[
        'STEREONONE', 'STEREOANY', 'STEREOZ', 'STEREOE', 'STEREOCIS',
        'STEREOTRANS',
    ],
    is_conjugated=[False, True],
)

def from_smiles(smiles: str, with_hydrogen: bool = False,
                kekulize: bool = False) -> 'List[Any]':
    r"""Converts a SMILES string to a :class:`torch_geometric.data.Data`
    instance and also hands back the parsed RDKit molecule.

    Args:
        smiles (str): The SMILES string.
        with_hydrogen (bool, optional): If set to :obj:`True`, will store
            hydrogens in the molecule graph. (default: :obj:`False`)
        kekulize (bool, optional): If set to :obj:`True`, converts aromatic
            bonds to single/double bonds. (default: :obj:`False`)

    Returns:
        A two-element list ``[data, mol]``. ``data`` holds ``x`` (one row of
        9 integer feature indices per atom), ``edge_index`` (both directions
        of every bond), ``edge_attr`` (3 integer feature indices per directed
        edge) and the input ``smiles``. ``mol`` is the RDKit molecule the
        graph was built from.

    Raises:
        ValueError: If an atom or bond property is not listed in
            :obj:`x_map` / :obj:`e_map` (raised by ``list.index``).
    """
    from rdkit import Chem, RDLogger

    from torch_geometric.data import Data

    RDLogger.DisableLog('rdApp.*')  # type: ignore

    mol = Chem.MolFromSmiles(smiles)

    if mol is None:
        # Unparseable input: fall back to the molecule parsed from the empty
        # string so that callers still receive a (graph, mol) pair.
        mol = Chem.MolFromSmiles('')
    if with_hydrogen:
        mol = Chem.AddHs(mol)
    if kekulize:
        Chem.Kekulize(mol)

    # Node features: each property is stored as its index in x_map.
    xs: List[List[int]] = []
    for atom in mol.GetAtoms():  # type: ignore
        row: List[int] = []
        row.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
        row.append(x_map['chirality'].index(str(atom.GetChiralTag())))
        row.append(x_map['degree'].index(atom.GetTotalDegree()))
        row.append(x_map['formal_charge'].index(atom.GetFormalCharge()))
        row.append(x_map['num_hs'].index(atom.GetTotalNumHs()))
        row.append(x_map['num_radical_electrons'].index(
            atom.GetNumRadicalElectrons()))
        row.append(x_map['hybridization'].index(str(atom.GetHybridization())))
        row.append(x_map['is_aromatic'].index(atom.GetIsAromatic()))
        row.append(x_map['is_in_ring'].index(atom.IsInRing()))
        xs.append(row)

    # view(-1, 9) keeps the column count even when the molecule has no atoms.
    x = torch.tensor(xs, dtype=torch.long).view(-1, 9)

    # Edge features: every bond is inserted once per direction.
    edge_indices, edge_attrs = [], []
    for bond in mol.GetBonds():  # type: ignore
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()

        e = []
        e.append(e_map['bond_type'].index(str(bond.GetBondType())))
        e.append(e_map['stereo'].index(str(bond.GetStereo())))
        e.append(e_map['is_conjugated'].index(bond.GetIsConjugated()))

        edge_indices += [[i, j], [j, i]]
        edge_attrs += [e, e]

    edge_index = torch.tensor(edge_indices)
    edge_index = edge_index.t().to(torch.long).view(2, -1)
    edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)

    if edge_index.numel() > 0:  # Sort indices.
        # Row-major key (source * num_nodes + target) orders edges by source,
        # then by target.
        perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
        edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

    return [Data(x=x, edge_index=edge_index, edge_attr=edge_attr, smiles=smiles), mol]

In [None]:
def _convert_split(frame):
  """Convert one split table into parallel lists of graphs and RDKit mols.

  ``frame`` must expose a 'Drug' column of SMILES strings and a 'Y' column of
  labels. Rows are visited positionally, so the result does not depend on the
  table's index labels.
  """
  graphs, mols = [], []
  for smiles, label in zip(frame['Drug'], frame['Y']):
    data, mol = from_smiles(smiles)  # hydrogens are not added (default)
    data.x = data.x.float()          # integer feature indices are cast to float
    data.y = torch.tensor(float(label))
    graphs.append(data)
    mols.append(mol)
  return graphs, mols


def prepare_datasets (split):
  """Build graph datasets and molecule lists for the three parts of ``split``.

  Args:
    split: Mapping with 'train', 'test' and 'valid' tables, each providing
      'Drug' (SMILES) and 'Y' (label) columns.

  Returns:
    ``[train_dataset, test_dataset, valid_dataset, train_mol, test_mol,
    valid_mol]`` where the ``*_dataset`` entries are lists of graph objects
    produced by ``from_smiles`` (with float ``x`` and scalar float ``y``) and
    the ``*_mol`` entries are the matching RDKit molecules.
  """
  train_dataset, train_mol = _convert_split(split['train'])
  test_dataset, test_mol = _convert_split(split['test'])
  valid_dataset, valid_mol = _convert_split(split['valid'])

  print(f'Train : {len(train_dataset)}, Test : {len(test_dataset)}, Valid : {len(valid_dataset)}')
  return [train_dataset, test_dataset, valid_dataset, train_mol, test_mol, valid_mol]

In [None]:
# Mini-batch size used for the training loaders.
batch_size = 1024

# Result registries shared by the helper functions below. Each one is an
# independent dict that the decide_* / *_get_eval helpers fill in, keyed first
# by the dataset key. The cmp_* variants hold the results of decoders trained
# on a frozen encoder for a second ("cmp") dataset.
enc, dec, cmp_dec = {}, {}, {}
train_duration, cmp_train_duration = {}, {}
opt_epochs, cmp_opt_epochs = {}, {}
test_auroc, cmp_test_auroc = {}, {}
test_loss, cmp_test_loss = {}, {}

In [None]:
class classifier(nn.Module):
  """MLP head mapping a 32-dim input to one raw score (32 -> 8 -> 1)."""

  def __init__(self):
    super().__init__()
    self.lin = Linear(32, 8)
    self.lin2 = Linear(8, 1)

  def forward(self, x):
    """Return the un-squashed score for each row of ``x``."""
    hidden = F.relu(self.lin(x))
    return self.lin2(hidden)

class classifier2(nn.Module):
  """Deeper MLP head for 32-dim inputs: 32 -> 16 -> 8 -> 4 -> 2 -> 1."""

  def __init__(self):
    super().__init__()
    self.lin = Linear(32, 16)
    self.lin2 = Linear(16, 8)
    self.lin3 = Linear(8, 4)
    self.lin4 = Linear(4, 2)
    self.lin5 = Linear(2, 1)

  def forward(self, x):
    """Apply ReLU after every layer except the last and return the raw score."""
    for layer in (self.lin, self.lin2, self.lin3, self.lin4):
      x = F.relu(layer(x))
    return self.lin5(x)

class classifier3(nn.Module):
  """MLP head for 2048-dim inputs: 2048 -> 2 -> 2 -> 1."""

  def __init__(self):
    super().__init__()
    self.lin = Linear(2048, 2)
    self.lin2 = Linear(2, 2)
    self.lin3 = Linear(2, 1)

  def forward(self, x):
    """Return the raw score; ReLU follows the first two layers only."""
    for layer in (self.lin, self.lin2):
      x = F.relu(layer(x))
    return self.lin3(x)

class classifier4(nn.Module):
  """MLP head for 167-dim inputs: 167 -> 24 -> 2 -> 1."""

  def __init__(self):
    super().__init__()
    self.lin = Linear(167, 24)
    self.lin2 = Linear(24, 2)
    self.lin3 = Linear(2, 1)

  def forward(self, x):
    """Return the raw score; ReLU follows the first two layers only."""
    for layer in (self.lin, self.lin2):
      x = F.relu(layer(x))
    return self.lin3(x)

class classifier5(nn.Module):
  """Narrow MLP head for 167-dim inputs: 167 -> 4 -> 4 -> 4 -> 2 -> 1."""

  def __init__(self):
    super().__init__()
    self.lin = Linear(167, 4)
    self.lin2 = Linear(4, 4)
    self.lin3 = Linear(4, 4)
    self.lin4 = Linear(4, 2)
    self.lin5 = Linear(2, 1)

  def forward(self, x):
    """Apply ReLU after every layer except the last and return the raw score."""
    for layer in (self.lin, self.lin2, self.lin3, self.lin4):
      x = F.relu(layer(x))
    return self.lin5(x)

class GCNencoder(nn.Module):
  """Graph encoder: a GCN stack followed by mean pooling over each graph.

  Args:
    nlayer: Number of GCN layers (``num_layers`` of the wrapped ``GCN``).
    dropout: Dropout probability forwarded to the wrapped ``GCN``. It used to
      be accepted but silently replaced by ``0.``; the default keeps the old
      behaviour for callers that do not pass it.
  """

  def __init__(self, nlayer=1, dropout=0.):
    super(GCNencoder, self).__init__()
    bn = nn.BatchNorm1d(32)
    self.gcn = GCN(in_channels=9,      # 9 atom features per node (see from_smiles)
                   hidden_channels = 32,
                   num_layers = nlayer,
                   norm=bn,
                   dropout=dropout
                   )

  def forward(self, x, edge_index, batch, edge_attr = None):
    """Return one 32-dim embedding per graph in the batch.

    ``edge_attr`` is accepted but not used by this encoder.
    """
    x = self.gcn(x, edge_index)
    # Average node embeddings graph by graph, as assigned by ``batch``.
    x = global_mean_pool(x, batch)

    return x

In [None]:
def decide_dataset(key_dataset):
  """Select the active dataset and reset every per-dataset result registry.

  Sets the global ``key`` and rebuilds ``train_loader`` (shuffled, mini-batches
  of ``batch_size``) plus ``valid_loader`` / ``test_loader`` (one batch holding
  the whole split). All registries get a fresh empty dict under ``key``, which
  discards anything previously stored for that dataset.
  """
  global key, train_loader, valid_loader, test_loader, opt_epochs, test_auroc, test_loss, train_duration, cmp_opt_epochs, cmp_test_auroc, cmp_test_loss, cmp_train_duration, cmp_dec
  key = key_dataset
  splits = dataset[key]
  train_loader = DataLoader(splits['train'], batch_size=batch_size, shuffle=True)
  valid_loader = DataLoader(splits['valid'], batch_size=len(splits['valid']), shuffle=False)
  test_loader = DataLoader(splits['test'], batch_size=len(splits['test']), shuffle=False)

  registries = (
      enc, dec, opt_epochs, test_auroc, test_loss, train_duration,
      cmp_dec, cmp_opt_epochs, cmp_test_auroc, cmp_test_loss, cmp_train_duration,
  )
  for registry in registries:
    registry[key] = {}

In [None]:
def _embed_split(en, graphs):
  """Embed every graph of one split with ``en`` in a single batch.

  Returns ``(embeddings, labels)``, both on ``device``. The forward pass runs
  under ``torch.no_grad()`` so no autograd graph is kept for the whole split.
  """
  ldr = DataLoader(graphs, batch_size=len(graphs), shuffle=False)
  data = next(iter(ldr)).to(device)
  with torch.no_grad():
    out = en(data.x, data.edge_index, data.batch)
  return out, data.y


def decide_cmp_dataset(cmp_key_dataset, reset=False):
  """Prepare training of a fresh decoder on frozen embeddings of another dataset.

  Args:
    cmp_key_dataset: Key of the dataset whose molecules are embedded with the
      stored encoder ``enc[key][nlayer]``.
    reset: When ``False`` (default) the embeddings are recomputed and the
      ``cmp_*_loader`` globals rebuilt; any other value keeps the existing
      loaders and only resets the decoder and the training history.

  Side effects: sets ``cmp_key``, ``batch_size``, ``de`` (a new
  ``classifier2``), ``loss_function``, empties the four history lists and
  zeroes ``cmp_train_duration[key][nlayer][cmp_key]``.
  """
  global cmp_key, cmp_train_loader, cmp_valid_loader, cmp_test_loader, cmp_opt_epochs, test_auroc, test_loss
  global de, optimizer, loss_function, train_loss, valid_loss, train_auroc, valid_auroc, cmp_train_duration, batch_size

  cmp_key = cmp_key_dataset
  batch_size =1024

  if reset is False:
    en = enc[key][nlayer].to(device)
    # The encoder is only used as a fixed feature extractor here, so make sure
    # normalisation layers use their stored statistics.
    en.eval()

    out, lbl = _embed_split(en, dataset[cmp_key]['train'])
    cmp_train_loader = DataLoader(TensorDataset(out, lbl), batch_size=batch_size, shuffle=True)

    out, lbl = _embed_split(en, dataset[cmp_key]['valid'])
    cmp_valid_loader = DataLoader(TensorDataset(out, lbl), batch_size=out.shape[0], shuffle=False)

    out, lbl = _embed_split(en, dataset[cmp_key]['test'])
    cmp_test_loader = DataLoader(TensorDataset(out, lbl), batch_size=out.shape[0], shuffle=False)

    # Park the encoder back on the CPU, as the registry stores it there.
    enc[key][nlayer] = en.cpu()

  de = classifier2().to(device)
  loss_function = logistic_loss

  train_loss = []
  valid_loss = []
  train_auroc = []
  valid_auroc = []
  cmp_train_duration[key][nlayer][cmp_key] = 0

  print(f'# of parameters in decoder : {sum(p.numel() for p in de.parameters())}')

In [None]:
def decide_fp(fp_key_in, cmp_key_dataset, reset=False):
  """Prepare training of a decoder on molecular fingerprints of a dataset.

  Args:
    fp_key_in: Fingerprint kind: 'top', 'ap', 'tt' or 'morgan' (2048-bit
      generators, decoded by ``classifier3``) or 'maccs' (decoded by
      ``classifier5``).
    cmp_key_dataset: Key of the dataset whose molecules are fingerprinted.
    reset: When ``False`` (default) the fingerprints are recomputed and the
      ``cmp_*_loader`` globals rebuilt; any other value keeps the existing
      loaders and only resets the decoder and the training history.

  Raises:
    ValueError: If ``fp_key_in`` is not one of the supported kinds. Previously
      this left ``func`` undefined or silently reused a stale global decoder.
  """
  global fp_key, cmp_key, cmp_train_loader, cmp_valid_loader, cmp_test_loader
  global de, optimizer, loss_function, train_loss, valid_loss, train_auroc, valid_auroc,batch_size

  # Factories instead of instances: only the requested generator is built.
  generator_factories = {
      'top': AllChem.GetRDKitFPGenerator,
      'ap': AllChem.GetAtomPairGenerator,
      'tt': AllChem.GetTopologicalTorsionGenerator,
      'morgan': AllChem.GetMorganGenerator,
  }
  if fp_key_in in generator_factories:
    func = generator_factories[fp_key_in](fpSize=2048).GetFingerprint
    decoder_cls = classifier3
  elif fp_key_in == 'maccs':
    func = MACCSkeys.GenMACCSKeys
    decoder_cls = classifier5
  else:
    raise ValueError(f'Unknown fingerprint type: {fp_key_in!r}')

  fp_key = fp_key_in
  cmp_key = cmp_key_dataset
  batch_size =1024
  de = decoder_cls().to(device)

  if reset is False:
    def _fp_dataset(split):
      # Fingerprint bits as float features, paired with the graph labels.
      feats = torch.tensor([(func(m)) for m in mol[cmp_key][split]]).float()
      lbls = torch.tensor([data.y for data in dataset[cmp_key][split]])
      return TensorDataset(feats, lbls)

    cmp_train_loader = DataLoader(_fp_dataset('train'), batch_size = batch_size, shuffle=True)
    cmp_valid_loader = DataLoader(_fp_dataset('valid'), batch_size = len(mol[cmp_key]['valid']), shuffle=False)
    cmp_test_loader = DataLoader(_fp_dataset('test'), batch_size = len(mol[cmp_key]['test']), shuffle=False)

  loss_function = logistic_loss

  train_loss = []
  valid_loss = []
  train_auroc = []
  valid_auroc = []

  print(f'# of parameters in decoder : {sum(p.numel() for p in de.parameters())}')

In [None]:
def decide_nlayer(n=None, dropout=0.):
  """Create a fresh encoder/decoder pair for the active dataset ``key``.

  Args:
    n: Number of GCN layers. When given it becomes the global ``nlayer``;
      when ``None`` the current global ``nlayer`` is reused.
    dropout: Passed through to ``GCNencoder``.

  Side effects: replaces the globals ``en`` and ``de``, resets
  ``loss_function``, ``batch_size`` and the four history lists, zeroes
  ``train_duration[key][n]`` and clears the ``cmp_*`` registries for
  ``(key, n)``.
  """
  global batch_size, en, de, nlayer, optimizer, loss_function, train_loss, valid_loss, train_auroc, valid_auroc, train_duration
  if n is None:
    n = nlayer
  else:
    nlayer = n
  en = GCNencoder(nlayer=n, dropout=dropout).to(device)
  de = classifier().to(device)
  loss_function = logistic_loss
  batch_size = 1024

  train_loss, valid_loss, train_auroc, valid_auroc = [], [], [], []
  train_duration[key][n] = 0

  for registry in (cmp_dec, cmp_opt_epochs, cmp_test_auroc, cmp_test_loss, cmp_train_duration):
    registry[key][n] = {}

  print(f'# of parameters in encoder : {sum(p.numel() for p in en.parameters())}')
  print(f'# of parameters in decoder : {sum(p.numel() for p in de.parameters())}')

In [None]:
def train_loss_auroc(epochs, lr_val = None, cont = False, load_dir = None):
  """Jointly train the global encoder ``en`` and decoder ``de``.

  Args:
    epochs: Number of passes over ``train_loader``.
    lr_val: Learning rate; when given it replaces the global ``lr``.
    cont: When true (and ``load_dir`` is ``None``) training resumes from
      ``final_state`` with the existing global optimizer; otherwise a new
      Adam optimizer is created.
    load_dir: Score suffix of a checkpoint written by ``get_eval``. When given,
      encoder, decoder, optimizer state and training history are loaded from
      the matching files and ``cont`` is ignored.

  After training, ``final_state`` holds the last-epoch weights while ``en`` /
  ``de`` are reset to the weights with the lowest validation loss seen in this
  call (or left at their starting weights if no epoch produced a comparable
  loss). Per-epoch mean loss and AUROC are appended to the global history
  lists and plotted.
  """
  global lr, train_loss, valid_loss, train_auroc, valid_auroc, optimizer, final_state, en, de
  if lr_val is not None:
    lr = lr_val

  # Move the models first so that optimizer state restored below is placed on
  # the same device as the parameters it belongs to.
  en = en.to(device)
  de = de.to(device)

  if load_dir is not None:
    # The blank after the last underscore mirrors the ' .3f' formatting that
    # get_eval uses when it names the checkpoint files.
    suffix = '_'+key+'_'+str(nlayer)+'_ '+str(load_dir)+'.pt'
    en.load_state_dict(torch.load('encoder_state'+suffix))
    de.load_state_dict(torch.load('decoder_state'+suffix))
    # Bind a new optimizer to the *current* parameters before restoring its
    # state; reusing an older global optimizer would keep updating the
    # parameters of whatever model it was created for. The learning rate comes
    # from the restored state.
    optimizer = optim.Adam(list(en.parameters())+list(de.parameters()))
    optimizer.load_state_dict(torch.load('optimizer_state'+suffix))
    [train_loss, train_auroc, valid_loss, valid_auroc] = torch.load('training mem'+suffix)
  elif cont is False:
    optimizer = optim.Adam(list(en.parameters())+list(de.parameters()), lr=lr)
  else:
    en.load_state_dict(final_state[0])
    de.load_state_dict(final_state[1])

  # Start from the current weights so the restore step at the end is always
  # defined, even for epochs == 0 or when the validation loss is NaN.
  save_state = [deepcopy(en.state_dict()), deepcopy(de.state_dict())]
  valid_loss_min = float('inf')
  printed_valid_loss_min = 100
  start=time()
  for it in range(epochs):
    en.train()
    de.train()
    loss_sum = 0
    auroc_sum = 0
    train_len = 0
    for data in train_loader:
      data = data.to(device)
      out = de(en(data.x, data.edge_index, data.batch))
      optimizer.zero_grad()
      loss = loss_function(out, data.y.view(-1,1)).sum()
      loss_sum += loss.item()
      loss.backward()
      optimizer.step()
      # Weight each batch AUROC by its size; the epoch value is their mean.
      auroc_sum += binary_auroc(out.view(-1), data.y).item() * data.y.shape[0]
      train_len += data.y.shape[0]
    train_loss.append(loss_sum/train_len)
    train_auroc.append(auroc_sum/train_len)

    en.eval()
    de.eval()
    # valid_loader yields the whole validation split as one batch.
    with torch.no_grad():
      data = next(iter(valid_loader)).to(device)
      out = de(en(data.x, data.edge_index, data.batch))
      loss = loss_function(out, data.y.view(-1,1)).sum()
      loss = loss.item()/data.y.shape[0]

    if it == 0:
      print(f'Estimated training time : {(time()-start)*epochs} secs.')

    if valid_loss_min > loss:
      # Only report improvements of more than 5% once the loss is below 0.6.
      if loss<0.6 and printed_valid_loss_min*0.95 > loss:
        printed_valid_loss_min = loss
        print(f'Find new local minima of valid loss : {loss : .3f}')
      valid_loss_min = loss
      save_state = [deepcopy(en.state_dict()), deepcopy(de.state_dict())]

    valid_loss.append(loss)
    valid_auroc.append(binary_auroc(out.view(-1), data.y).item())

  train_duration[key][nlayer] += time()-start
  final_state=[deepcopy(en.state_dict()), deepcopy(de.state_dict())]
  en.load_state_dict(save_state[0])
  de.load_state_dict(save_state[1])
  plot_loss_auroc()

In [None]:
def cmp_train_loss_auroc(epochs, lr_val = None, cont=False):
  """Train the global decoder ``de`` on precomputed feature tensors.

  Args:
    epochs: Number of passes over ``cmp_train_loader`` (batches of
      ``(features, label)``).
    lr_val: Learning rate; when given it replaces the global ``lr``.
    cont: When true, training resumes from ``final_state`` with the existing
      global optimizer; otherwise a new Adam optimizer is created.

  After training, ``final_state`` holds the last-epoch weights while ``de`` is
  reset to the weights with the lowest validation loss seen in this call (or
  left at its starting weights if no epoch produced a comparable loss).
  Per-epoch mean loss and AUROC are appended to the global history lists and
  plotted.
  """
  global train_loss, valid_loss, train_auroc, valid_auroc, lr, optimizer, final_state, de
  if lr_val is not None:
    lr = lr_val
  if cont is False:
    optimizer = optim.Adam(list(de.parameters()), lr=lr)
  else:
    de.load_state_dict(final_state)
  de = de.to(device)

  # Start from the current weights so the restore step at the end is always
  # defined, even for epochs == 0 or when the validation loss is NaN.
  save_state = deepcopy(de.state_dict())
  valid_loss_min = float('inf')
  printed_valid_loss_min = 100
  start=time()
  for it in range(epochs):
    de.train()
    loss_sum = 0
    auroc_sum = 0
    train_len = 0
    for data, label in cmp_train_loader:
      data = data.to(device)
      label = label.to(device)
      out = de(data)
      optimizer.zero_grad()
      loss = loss_function(out, label.view(-1,1)).sum()
      loss_sum += loss.item()
      loss.backward()
      optimizer.step()
      # Weight each batch AUROC by its size; the epoch value is their mean.
      auroc_sum += binary_auroc(out.view(-1), label).item() * label.shape[0]
      train_len += label.shape[0]
    train_loss.append(loss_sum/train_len)
    train_auroc.append(auroc_sum/train_len)

    de.eval()
    # cmp_valid_loader yields the whole validation split as one batch.
    with torch.no_grad():
      data, label = next(iter(cmp_valid_loader))
      data, label = data.to(device), label.to(device)
      out = de(data)
      loss = loss_function(out, label.view(-1,1)).sum()
      loss = loss.item()/label.shape[0]

    if it == 0:
      print(f'Estimated training time : {(time()-start)*epochs} secs.')

    if valid_loss_min > loss:
      # Only report improvements of more than 10% once the loss is below 0.6.
      if loss<0.6 and printed_valid_loss_min*0.9 > loss:
        printed_valid_loss_min = loss
        print(f'Find new local minima of valid loss : {loss : .3f}')
      valid_loss_min = loss
      save_state = deepcopy(de.state_dict())
    valid_loss.append(loss)
    valid_auroc.append(binary_auroc(out.view(-1), label).item())

  final_state=deepcopy(de.state_dict())
  de.load_state_dict(save_state)
  plot_loss_auroc()

In [None]:
def plot_loss_auroc():
  """Plot the global loss and AUROC histories on a shared epoch axis.

  Loss curves use the left y-axis (red/orange), AUROC curves the right y-axis
  (blue/cyan); validation curves are dashed. One combined legend is drawn.
  """
  fig, loss_ax = plt.subplots()

  loss_color = 'tab:red'
  loss_ax.set_xlabel('epochs')
  loss_ax.set_ylabel('Mean Loss', color=loss_color)
  curves = loss_ax.plot(range(len(train_loss)), train_loss, color = loss_color, label='Train set-loss')
  curves = curves + loss_ax.plot(range(len(valid_loss)), valid_loss, color = 'tab:orange', linestyle='dashed', label='Valid set-loss')
  loss_ax.tick_params(axis='y', labelcolor=loss_color)

  # Second y-axis sharing the same x-axis for the AUROC curves.
  auroc_ax = loss_ax.twinx()
  auroc_color = 'tab:blue'
  auroc_ax.set_ylabel('AUROC', color=auroc_color)
  curves = curves + auroc_ax.plot(range(len(train_auroc)), train_auroc, color = auroc_color, label='Train set-AUROC')
  curves = curves + auroc_ax.plot(range(len(valid_auroc)), valid_auroc, color = 'tab:cyan', linestyle='dashed', label='Valid set-AUROC')
  auroc_ax.tick_params(axis='y', labelcolor=auroc_color)

  # Lines live on two axes, so the legend is assembled by hand.
  loss_ax.legend(curves, [curve.get_label() for curve in curves], loc=7)

  fig.tight_layout()
  plt.show()

In [None]:
def plot_nlayer_auroc():
  """Plot test AUROC against the number of GCN layers for the active dataset.

  The solid black line shows ``test_auroc[key]``; one dashed line per other
  dataset shows ``cmp_test_auroc[key][n][dataset]`` for the layer counts where
  such a result exists.
  """
  nlayers = list(enc[key])
  plt.plot(nlayers, [test_auroc[key][n] for n in nlayers], color='black', label = key)
  plt.xlabel('# of GCN layers')
  plt.xticks(range(min(nlayers), max(nlayers)+1, 1))
  plt.ylabel('AUROC of the embedding trained with '+key)
  for k in ['HIV', 'HERG', 'CYP', 'TOX', 'BBB', 'HIA']:
    # Layer counts for which a decoder was evaluated on dataset k.
    available = [n for n in nlayers if k in cmp_test_auroc[key][n]]
    if len(available) == 0:
      continue
    plt.plot(available, [cmp_test_auroc[key][n][k] for n in available], linestyle='dashed', label='cmp: '+k)
  plt.legend()
  plt.show()

In [None]:
def get_eval():
  """Evaluate the current encoder/decoder on the test split and export them.

  Stores mean test loss and AUROC under ``[key][nlayer]``, moves both models
  to the CPU and registers them in ``enc`` / ``dec``, then saves and downloads
  four files: encoder, decoder and optimizer state dicts plus the training
  history. File names end in the AUROC formatted with ' .3f' (note the leading
  blank that format produces).
  """
  global test_loss, test_auroc, enc, dec
  en.eval()
  de.eval()
  graphs = next(iter(test_loader)).to(device)
  out = de(en(graphs.x, graphs.edge_index, graphs.batch))
  test_loss[key][nlayer] = loss_function(out, graphs.y.view(-1,1)).mean().item()
  test_auroc[key][nlayer] = binary_auroc(out.view(-1), graphs.y).item()
  enc[key][nlayer] = en.cpu()
  dec[key][nlayer] = de.cpu()

  # Common tail of every exported file name.
  tail = '_'+key+'_'+str(nlayer)+'_'+f'{test_auroc[key][nlayer] : .3f}'+'.pt'

  def _export(payload, prefix):
    torch.save(payload, prefix+tail)
    files.download(prefix+tail)

  _export(enc[key][nlayer].state_dict(), 'encoder_state')
  _export(dec[key][nlayer].state_dict(), 'decoder_state')
  _export(optimizer.state_dict(), 'optimizer_state')

  _export([train_loss, train_auroc, valid_loss, valid_auroc], 'training mem')

  print(f'Test loss : {test_loss[key][nlayer] : .3f}, Test auroc : {test_auroc[key][nlayer] : .3f}')

In [None]:
def fp_get_eval():
  """Evaluate the fingerprint decoder on the cmp test split and export it.

  Computes mean test loss and AUROC, moves ``de`` to the CPU, then saves and
  downloads the whole decoder module and the training history. File names are
  built from ``fp_key``, ``cmp_key`` and the AUROC formatted with ' .3f'.
  """
  global de
  de.eval()
  feats, label = next(iter(cmp_test_loader))
  feats = feats.to(device)
  label = label.to(device)
  out = de(feats)
  loss = loss_function(out, label.view(-1,1)).mean().item()
  auroc = binary_auroc(out.view(-1), label).item()

  de = de.cpu()
  # Common tail of both exported file names.
  tail = '_'+fp_key+'_to_'+cmp_key+'_'+f'{auroc : .3f}'+'.pt'

  decoder_file = 'decoder'+tail
  torch.save(de, decoder_file)
  files.download(decoder_file)

  history_file = 'training mem'+tail
  torch.save([train_loss, train_auroc, valid_loss, valid_auroc], history_file)
  files.download(history_file)

  print(f'Cmp Test loss : {loss : .3f}, Cmp Test auroc : {auroc : .3f}')

In [None]:
def cmp_get_eval():
  """Evaluate the cmp decoder on the cmp test split and export it.

  Stores mean test loss and AUROC under ``[key][nlayer][cmp_key]``, registers
  the decoder (moved to the CPU) in ``cmp_dec``, then saves and downloads the
  whole decoder module and the training history. File names end in the AUROC
  formatted with ' .3f'.
  """
  global cmp_test_loss, cmp_test_auroc
  de.eval()
  feats, label = next(iter(cmp_test_loader))
  feats = feats.to(device)
  label = label.to(device)
  out = de(feats)
  cmp_test_loss[key][nlayer][cmp_key] = loss_function(out, label.view(-1,1)).mean().item()
  cmp_test_auroc[key][nlayer][cmp_key] = binary_auroc(out.view(-1), label).item()

  cmp_dec[key][nlayer][cmp_key] = de.cpu()
  # Common tail of both exported file names.
  tail = '_'+key+'_'+str(nlayer)+'_to_'+cmp_key+'_'+f'{cmp_test_auroc[key][nlayer][cmp_key] : .3f}'+'.pt'

  decoder_file = 'decoder'+tail
  torch.save(cmp_dec[key][nlayer][cmp_key], decoder_file)
  files.download(decoder_file)

  history_file = 'training mem'+tail
  torch.save([train_loss, train_auroc, valid_loss, valid_auroc], history_file)
  files.download(history_file)

  print(f'Cmp Test loss : {cmp_test_loss[key][nlayer][cmp_key] : .3f}, Cmp Test auroc : {cmp_test_auroc[key][nlayer][cmp_key] : .3f}')

In [None]:
def prepare_distance(key_val, data_type, fp_type, emb_metric='euclidean'):
  """Compute pairwise fingerprint and embedding distances for one split.

  The global encoder ``en`` must be set (and hold the desired weights) before
  calling this function.

  Args:
    key_val: Dataset key; becomes the global ``key``.
    data_type: Split name, e.g. 'train', 'valid' or 'test'.
    fp_type: 'top', 'ap', 'tt', 'morgan' or 'maccs'.
    emb_metric: Metric passed to ``scipy.spatial.distance.pdist`` for the
      embedding distances.

  Side effects: overwrites the globals ``fp``, ``darr_fp`` / ``d_fp``
  (condensed / square Tanimoto distances), ``darr_emb`` / ``d_emb`` (condensed
  / square embedding distances) and the ``darr_*00`` / ``*01`` / ``*11``
  arrays, which hold the distances between label-0/0, label-0/1 and label-1/1
  pairs.

  Raises:
    ValueError: If ``fp_type`` is not a supported fingerprint kind.
  """
  global key, en, fp, d_fp, d_emb, darr_fp, darr_fp00, darr_fp01, darr_fp11, darr_emb, darr_emb00, darr_emb01, darr_emb11
  # Factories instead of instances: only the requested generator is built.
  generator_factories = {
      'top': AllChem.GetRDKitFPGenerator,
      'ap': AllChem.GetAtomPairGenerator,
      'tt': AllChem.GetTopologicalTorsionGenerator,
      'morgan': AllChem.GetMorganGenerator,
  }
  if fp_type not in generator_factories and fp_type != 'maccs':
    raise ValueError(f'Unknown fingerprint type: {fp_type!r}')

  key = key_val
  if fp_type == 'maccs':
    fp = [MACCSkeys.GenMACCSKeys(m) for m in mol[key][data_type]]
  else:
    fpgen = generator_factories[fp_type](fpSize=2048)
    fp = [fpgen.GetFingerprint(m) for m in mol[key][data_type]]

  # Condensed upper-triangle vector of Tanimoto distances, in pdist order.
  darr_fp = np.array([1-DataStructs.TanimotoSimilarity(fp[i], fp[j])
                      for i in range(len(fp)) for j in range(i+1, len(fp))])
  d_fp = squareform(darr_fp)

  en = en.to(device)
  # Embeddings are extracted, not trained: use stored normalisation statistics.
  en.eval()
  ldr = DataLoader(dataset[key][data_type], batch_size=len(dataset[key][data_type]), shuffle=False)
  data = next(iter(ldr)).to(device)
  with torch.no_grad():
    out = en(data.x, data.edge_index, data.batch)
  # numpy needs host memory, so copy the tensors back from the device first.
  darr_emb = pdist(out.cpu().numpy(), metric=emb_metric)
  d_emb = squareform(darr_emb)

  labels = data.y.cpu().numpy()
  zero = np.where(labels==0)
  one = np.where(labels==1)

  def _pair_index(rows, cols):
    # All (row, col) index combinations, flattened for fancy indexing.
    x,y = np.meshgrid(rows, cols)
    return x.flatten(), y.flatten()

  idx = _pair_index(zero, zero)
  darr_fp00 = np.squeeze(np.asarray(d_fp[idx]))
  darr_emb00 = np.squeeze(np.asarray(d_emb[idx]))
  idx = _pair_index(zero, one)
  darr_fp01 = np.squeeze(np.asarray(d_fp[idx]))
  darr_emb01 = np.squeeze(np.asarray(d_emb[idx]))
  idx = _pair_index(one, one)
  darr_fp11 = np.squeeze(np.asarray(d_fp[idx]))
  darr_emb11 = np.squeeze(np.asarray(d_emb[idx]))

In [None]:
# HIV, HERG, CYP, TOX, BBB, HIA
# dataset[name][split] holds the graph lists and mol[name][split] the RDKit
# molecules. Every name gets an (initially empty) entry.
dataset={}
mol={}
for keys in ['HIV', 'HERG', 'CYP', 'TOX', 'BBB', 'HIA']:
  dataset[keys]={}
  mol[keys]={}

# Only CYP, BBB and HIA are converted; HIV, HERG and TOX are currently
# disabled and keep their empty placeholders.
for keys in ['CYP', 'BBB', 'HIA']:
  (dataset[keys]['train'], dataset[keys]['test'], dataset[keys]['valid'],
   mol[keys]['train'], mol[keys]['test'], mol[keys]['valid']) = prepare_datasets(
       {'CYP': CYP, 'BBB': BBB, 'HIA': HIA}[keys])

In [None]:
# Per-dataset stores filled below and read by the plotting/statistics helpers.
# d_* are square distance matrices, darr_* condensed or label-filtered
# distance vectors; the fingerprint stores are keyed [dataset][fingerprint].
fp={}
d_fp = {}
darr_fp = {}
darr_fp00 = {}
darr_fp01 = {}
darr_fp11 = {}
d_emb = {}
darr_emb = {}
darr_emb00 = {}
darr_emb01 = {}
darr_emb11 = {}
d_label = {}

fp_kinds = ['top','ap','tt','morgan','maccs']
# Built once; MACCS keys have no generator object and are handled separately.
fp_generators = {'top':AllChem.GetRDKitFPGenerator(fpSize=2048),
                 'ap':AllChem.GetAtomPairGenerator(fpSize=2048),
                 'tt':AllChem.GetTopologicalTorsionGenerator(fpSize=2048),
                 'morgan':AllChem.GetMorganGenerator(fpSize=2048),
                 }

#for k in ['CYP','BBB','HIA']:
for k in ['CYP']:
  fp[k]={}
  d_fp[k]={}
  darr_fp[k]={}
  darr_fp00[k]={}
  darr_fp01[k]={}
  darr_fp11[k]={}

  # Tanimoto distances between all test molecules, per fingerprint kind.
  for fpk in fp_kinds:
    fpgen = fp_generators.get(fpk, None)
    if fpk != 'maccs':
      fps = [fpgen.GetFingerprint(m) for m in mol[k]['test']]
    else:
      fps = [MACCSkeys.GenMACCSKeys(m) for m in mol[k]['test']]
    darr_fp[k][fpk] = []
    for i in range(len(fps)):
      darr_fp[k][fpk] += [1-DataStructs.TanimotoSimilarity(fps[i], fps[j]) for j in range(i+1, len(fps))]
    darr_fp[k][fpk] = np.array(darr_fp[k][fpk])
    d_fp[k][fpk] = squareform(darr_fp[k][fpk])

  # Rebuild the encoder architecture and load the trained weights from disk.
  decide_dataset(k)
  decide_nlayer({'CYP' : 4, 'BBB' :6, 'HIA':7}.get(k, None))
  en.load_state_dict(torch.load({
      'CYP' : 'encoder_state_CYP_4_ 0.847.pt',
      'BBB' : 'encoder_state_BBB_6_ 0.871.pt',
      'HIA' : 'encoder_state_HIA_7_ 0.973.pt'
  }.get(k,None)))
  en = en.to(device)
  # decide_nlayer hands back a newly constructed module, which starts in
  # training mode; switch to eval so normalisation layers use the loaded
  # running statistics instead of statistics of this one batch.
  en.eval()
  ldr = DataLoader(dataset[k]['test'], batch_size=len(dataset[k]['test']), shuffle=False)
  data = next(iter(ldr)).to(device)
  with torch.no_grad():
    out = en(data.x, data.edge_index, data.batch)
  # numpy needs host memory, so copy the tensors back from the device first.
  darr_emb[k] = pdist(out.cpu().numpy(), metric='euclidean')
  d_emb[k] = squareform(darr_emb[k])

  labels = data.y.cpu().numpy()
  # d_label is 1 where two molecules carry different labels, 0 otherwise.
  x,y = np.meshgrid(labels, labels)
  d_label[k] = np.matrix((x!=y).astype(int))

  zero = np.where(labels==0)
  one = np.where(labels==1)

  # Split every distance matrix into label-0/0, label-0/1 and label-1/1 pairs.
  for fp_store, emb_store, first, second in (
      (darr_fp00, darr_emb00, zero, zero),
      (darr_fp01, darr_emb01, zero, one),
      (darr_fp11, darr_emb11, one, one),
  ):
    x,y=np.meshgrid(first,second)
    x,y=x.flatten(), y.flatten()
    for fpk in fp_kinds:
      fp_store[k][fpk] = np.squeeze(np.asarray(d_fp[k][fpk][(x,y)]))
    emb_store[k] = np.squeeze(np.asarray(d_emb[k][(x,y)]))

In [None]:
def khist(k):
  """Save and download a KDE plot of embedding distances for dataset ``k``.

  Distances are grouped into same-label pairs (0/0 and 1/1) and
  different-label pairs (0/1). Exact zeros are replaced by NaN in the plotted
  table. Dotted vertical lines mark the mean of each group, computed on the
  unfiltered arrays. The figure is written to 'khist_<k>.png'.
  """
  same = np.concatenate([darr_emb00[k],darr_emb11[k]])
  different = darr_emb01[k]

  xname = 'Distance'
  lname = 'Embedding w/ '+k
  df = pd.DataFrame({
      xname : np.concatenate([same, different]),
      lname : ['same label']*len(same) + ['different label']*len(different),
  })
  df = df.replace(.0,np.nan)

  plt.figure()
  grid = sns.displot(df, x= xname,hue=lname,kind="kde",common_norm=False)
  sns.move_legend(grid, 'center right')
  plt.axvline(same.mean(),c='blue',ls=':',lw=2.5)
  plt.axvline(different.mean(),c='orange',ls=':',lw=2.5)

  target = 'khist_'+k+'.png'
  plt.savefig(target)
  plt.close()
  files.download(target)

def fphist(k,fpk):
  """Save and download a KDE plot of fingerprint distances.

  Uses the ``fpk`` fingerprint distances of dataset ``k``, grouped into
  same-label pairs (0/0 and 1/1) and different-label pairs (0/1). Exact zeros
  are replaced by NaN in the plotted table. Dotted vertical lines mark the
  mean of each group, computed on the unfiltered arrays. The figure is written
  to 'fphist_<k>_<fpk>.png'.
  """
  same = np.concatenate([darr_fp00[k][fpk],darr_fp11[k][fpk]])
  different = darr_fp01[k][fpk]

  xname = 'Distance'
  lname = fpk+' fingerprint w/ '+k
  df = pd.DataFrame({
      xname : np.concatenate([same, different]),
      lname : ['same label']*len(same) + ['different label']*len(different),
  })
  df = df.replace(.0,np.nan)

  plt.figure()
  grid = sns.displot(df, x= xname,hue=lname,kind="kde",common_norm=False)
  sns.move_legend(grid, 'center right')
  plt.axvline(same.mean(),c='blue',ls=':',lw=2.5)
  plt.axvline(different.mean(),c='orange',ls=':',lw=2.5)

  target = 'fphist_'+k+'_'+fpk+'.png'
  plt.savefig(target)
  plt.close()
  files.download(target)

def dcorrplot(k,fpk,fpk2=None):
  """Plot a bivariate KDE of two pairwise-distance sets and download the PNG.

  With ``fpk2`` omitted, the x axis is the embedding distance
  (``darr_emb*[k]``) and the y axis the ``fpk`` fingerprint distance
  (``darr_fp*[k][fpk]``). With ``fpk2`` given, both axes are fingerprint
  distances (``fpk`` on x, ``fpk2`` on y). Points are coloured by whether
  the pair shares a label (0-0 and 1-1 pairs) or not (0-1 pairs).

  Args:
    k: Key into the ``darr_emb*`` / ``darr_fp*`` dictionaries; also the
      legend title.
    fpk: Fingerprint key.
    fpk2: Optional second fingerprint key; switches to the
      fingerprint-vs-fingerprint plot.
  """
  # Each list is ordered [0-0 pairs, 1-1 pairs, 0-1 pairs].
  fp_groups = [darr_fp00[k][fpk], darr_fp11[k][fpk], darr_fp01[k][fpk]]
  if fpk2 is None:
    xname = 'Distance from embedding trained on '+k
    yname = 'Distance from '+fpk+' fingerprint'
    logpair = (False,False)
    savename = 'dcorrplot_'+k+'_'+fpk+'.png'
    x_groups = [darr_emb00[k], darr_emb11[k], darr_emb01[k]]
    y_groups = fp_groups
  else:
    xname = 'Distance from '+fpk+' fingerprint'
    yname = 'Distance from '+fpk2+' fingerprint'
    logpair = None
    savename = 'dcorrplot of two fps_'+k+'_'+fpk+'_'+fpk2+'.png'
    x_groups = fp_groups
    y_groups = [darr_fp00[k][fpk2], darr_fp11[k][fpk2], darr_fp01[k][fpk2]]
  lname = k
  n_same = len(fp_groups[0]) + len(fp_groups[1])
  n_diff = len(fp_groups[2])
  data = {
      xname : np.concatenate(x_groups),
      yname : np.concatenate(y_groups),
      lname : ['same label']*n_same + ['different label']*n_diff,
  }
  df = pd.DataFrame(data)
  # Exact zeros become NaN so they are left out of the density estimate
  # (presumably these are the self-pair distances - confirm).
  df = df.replace(.0, np.nan)
  # displot is figure-level and opens its own figure, so no plt.figure() is
  # created beforehand: that extra empty figure was never closed.
  grid = sns.displot(df, x=xname, y=yname, hue=lname, kind="kde", rug=True, rug_kws={"alpha" : 0.01}, common_norm=False, fill=True, levels=15, alpha=.5, log_scale=logpair)
  sns.move_legend(grid, 'center right')
  plt.savefig(savename)
  plt.close()
  files.download(savename)

# compute_distance=None: the inputs passed to test() are used as distance
# matrices as-is rather than being converted from raw samples.
dcorr = Dcorr(compute_distance=None)
mgc = MGC(compute_distance=None)
def getstat(k, fpk=None, fpk2=None):
  """Print Dcorr (and optionally MGC) statistics between two distance matrices.

  The pair of matrices depends on the arguments:
    * ``getstat(k)``: ``d_emb[k]`` vs. ``d_label[k]`` (Dcorr only).
    * ``getstat(k, fpk)``: ``d_emb[k]`` vs. ``d_fp[k][fpk]``.
    * ``getstat(k, fpk, fpk2)``: ``d_fp[k][fpk]`` vs. ``d_fp[k][fpk2]``.

  When ``fpk`` is given, the MGC statistic is printed as well and the MGC
  map is drawn as a heatmap, saved as a PNG and passed to
  ``files.download``.

  Args:
    k: Key into ``d_emb`` / ``d_label`` / ``d_fp``.
    fpk: Optional fingerprint key.
    fpk2: Optional second fingerprint key (only read when ``fpk`` is given).
  """
  if fpk is None:
    d1, d2 = d_emb[k], d_label[k]
    savename = None
  elif fpk2 is None:
    d1, d2 = d_emb[k], d_fp[k][fpk]
    savename = 'MGC_'+k+'_'+fpk+'.png'
  else:
    d1, d2 = d_fp[k][fpk], d_fp[k][fpk2]
    savename = 'MGC of two fps_'+k+'_'+fpk+'_'+fpk2+'.png'

  stat,pvalue = dcorr.test(d1, d2)
  print('[Dcorr] stat : %.5f, pvalue : %.5f'%(stat,pvalue))

  # The embedding-vs-label comparison stops after Dcorr.
  if fpk is None:
    return

  # NOTE(review): reps=0 requests no permutation replications, so the MGC
  # p-value printed below is presumably not informative - confirm.
  stat, pvalue, mgc_dict = mgc.test(d1,d2, reps=0)
  print('[MGC] stat : %.5f, pvalue : %.5f'%(stat,pvalue))

  # make plots look pretty (this changes the global seaborn theme)
  sns.set(color_codes=True, style="white", context="talk", font_scale=1)

  mgc_map = mgc_dict["mgc_map"]
  opt_scale = mgc_dict["opt_scale"]  # i.e. maximum smoothed test statistic
  print("Optimal Scale:", opt_scale)

  # create figure; plt.subplots opens it, so no separate plt.figure() call
  # is made (the extra empty figure was never closed)
  fig, (ax, cax) = plt.subplots(
      ncols=2, figsize=(9, 8.5), gridspec_kw={"width_ratios": [1, 0.05]}
  )

  # draw heatmap and colorbar
  ax = sns.heatmap(mgc_map, cmap="YlGnBu", ax=ax, cbar=False)
  fig.colorbar(ax.get_children()[0], cax=cax, orientation="vertical")
  ax.invert_yaxis()

  # optimal scale
  ax.scatter(opt_scale[1], opt_scale[0], marker="X", s=200, color="red")

  # make plots look nice; ticks are placed at 0, N/2 and N only
  ax.set_title("MGC Map")
  ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
  ax.yaxis.set_major_formatter(ticker.ScalarFormatter())
  ax.set_xlabel("Neighbors for x")
  ax.set_ylabel("Neighbors for y")
  N = d_emb[k].shape[0]
  ax.set_xticks([0, N/2, N])
  ax.set_yticks([0, N/2, N])

  plt.savefig(savename)
  plt.close()
  files.download(savename)

# Training

## Train encoder & decoder

In [None]:
decide_dataset('HERG')

In [None]:
decide_nlayer(5)

In [None]:
train_loss_auroc(10,0.01, True)
plot_loss_auroc()

In [None]:
get_eval()

In [None]:
# Alternatively, load an encoder (en) and decoder (de) that were saved by get_eval().
# NOTE(review): the loaded objects are bound to en/de directly, so these files
# presumably hold whole pickled modules rather than state dicts - confirm.
# torch.load unpickles the file; only load checkpoints you trust.
en=torch.load('encoder_HERG_3.pt')
de=torch.load('decoder_HERG_3.pt')

In [None]:
train_loss_auroc(10,0.01, True)
plot_loss_auroc()

In [None]:
get_eval()

In [None]:
decide_nlayer(7)

In [None]:
train_loss_auroc(10,0.01, True)
plot_loss_auroc()

In [None]:
get_eval()

In [None]:
decide_dataset('HIV')

In [None]:
decide_nlayer(5)

In [None]:
train_loss_auroc(10,0.01, True)
plot_loss_auroc()

In [None]:
get_eval()

## Train decoder (with pre-trained encoder)

In [None]:
decide_cmp_dataset('BBB')

In [None]:
cmp_train_loss_auroc(1000, 0.001)

In [None]:
plot_loss_auroc()
# Record the stopping epoch for the current (key, nlayer, cmp_key) combination.
# The value 800 is hard-coded (presumably read off the curves plotted above - confirm).
cmp_opt_epochs[key][nlayer][cmp_key] = 800

Once you have identified the optimal stopping epoch, you can reset the decoder and re-train it up to that epoch.

In [None]:
# Re-select BBB with reset=True, then run a first training call with (100, 0.001)
# (presumably epoch count and learning rate - confirm against cmp_train_loss_auroc).
decide_cmp_dataset('BBB', reset=True)
cmp_train_loss_auroc(100,0.001)
plot_loss_auroc()

In [None]:
# Train for the recorded optimal epoch count minus 100; the subtraction assumes the
# preceding cell's 100-epoch call has already been run (confirm), then evaluate.
cmp_train_loss_auroc(cmp_opt_epochs[key][nlayer][cmp_key] -100, 0.001)
plot_loss_auroc()
cmp_get_eval()

# Characterization

In [None]:
khist('HIA')

In [None]:
getstat('HIA','maccs','top')
getstat('HIA','maccs')

In [None]:
dcorrplot('BBB','top')

In [None]:
dcorrplot('HIA','top','maccs')

In [None]:
fphist('HIA','maccs')

In [None]:
# Overlay 50-bin histograms of the HIA embedding distances for the three
# label-pair groups (0-0, 0-1, 1-1), drawn with decreasing opacity.
plt.figure()
for group, opacity, tag in (
    (darr_emb00, 1, '00'),
    (darr_emb01, 0.67, '01'),
    (darr_emb11, 0.33, '11'),
):
  plt.hist(group['HIA'], bins=50, alpha=opacity, label=tag)
plt.legend()
plt.show()

In [None]:
# Save the embedding-distance KDE for HIA, then one fingerprint-distance KDE
# per fingerprint type.
k = 'HIA'
khist(k)
for fpk in ['top','ap','tt','morgan','maccs']:
  fphist(k,fpk)

In [None]:
# For each fingerprint type on HIA: save the joint KDE of embedding vs.
# fingerprint distances, then print the Dcorr/MGC statistics (getstat also
# saves the MGC map).
k='HIA'
for fpk in ['top','ap','tt','morgan','maccs']:
  print(k, fpk)
  dcorrplot(k,fpk)
  getstat(k,fpk)

In [None]:
getstat('HIA')
getstat('BBB')
getstat('CYP')

In [None]:
# Embed the CYP test split with the saved encoder for each layer count and
# collect the pairwise Euclidean distances, split by label pair.
decide_dataset('CYP')
ldr = DataLoader(dataset['CYP']['test'], batch_size=len(dataset['CYP']['test']), shuffle=False)
# batch_size equals the split size, so this single batch is the whole test split.
data = next(iter(ldr)).to(device)
darr_cyp={}    # n -> condensed pairwise distances
darr_cyp00={}  # n -> distances between label-0 / label-0 pairs
darr_cyp01={}  # n -> distances between label-0 / label-1 pairs
darr_cyp11={}  # n -> distances between label-1 / label-1 pairs
d_cyp = {}     # n -> square distance matrix

# Layer count -> suffix embedded in the checkpoint file name
# (presumably the evaluation score of that checkpoint - confirm).
cyp_ckpt_suffix = {
    2:'0.824',
    3:'0.843',
    4:'0.847',
    5:'0.845',
    6:'0.823',
    7:'0.830',
    8:'0.840',
    9:'0.832',
    11:'0.835',
    13:'0.810',
    15:'0.826',
    20:'0.814',
    30:'0.793'
}

def _pair_distances(dist, rows, cols):
  """Return the flattened entries of ``dist`` for every (row, col) index pair."""
  gx, gy = np.meshgrid(rows, cols)
  return np.squeeze(np.asarray(dist[(gx.flatten(), gy.flatten())]))

# The labels do not depend on the encoder, so the index sets are built once.
# They are moved to the CPU first: NumPy cannot read a CUDA tensor.
labels = data.y.cpu().numpy()
zero = np.where(labels==0)
one = np.where(labels==1)

for n, tmp in cyp_ckpt_suffix.items():
  decide_nlayer(n)
  # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
  # The file name (including the space after the underscore) is kept verbatim.
  en.load_state_dict(torch.load('encoder_state_CYP_'+str(n)+'_ '+tmp+'.pt', map_location=device))
  en = en.to(device)
  # NOTE(review): the encoder stays in whatever train/eval mode decide_nlayer
  # leaves it in - confirm it is in eval mode if it has BatchNorm/dropout.
  with torch.no_grad():
    out = en(data.x, data.edge_index, data.batch)
  # .cpu() is required before .numpy() when the model runs on CUDA.
  darr_cyp[n] = pdist(out.detach().cpu().numpy(), metric='euclidean')
  d_cyp[n] = squareform(darr_cyp[n])

  darr_cyp00[n] = _pair_distances(d_cyp[n], zero, zero)
  darr_cyp01[n] = _pair_distances(d_cyp[n], zero, one)
  darr_cyp11[n] = _pair_distances(d_cyp[n], one, one)

In [None]:
# One KDE plot of CYP embedding distances per encoder depth. The dataset name
# is spelled out here instead of read from the notebook-global `k`, which
# earlier cells leave set to a different dataset ('HIA').
cyp_name = 'CYP'
for n in [2,3,4,5,6,7,8,9,11,13,15,20,30]:
  xname = 'Distance'
  lname = 'Embedding w/ '+cyp_name+' and #layer = '+str(n)
  same = np.concatenate([darr_cyp00[n], darr_cyp11[n]])
  diff = darr_cyp01[n]
  df = pd.DataFrame({
      xname : np.concatenate([same, diff]),
      lname : ['same label']*len(same) + ['different label']*len(diff),
  })
  # Exact zeros become NaN so they are left out of the density estimate
  # (presumably these are the self-pair distances - confirm).
  df = df.replace(.0, np.nan)
  # displot is figure-level and opens its own figure, so no plt.figure() is
  # created beforehand: that extra empty figure was never closed.
  grid = sns.displot(df, x=xname, hue=lname, kind="kde", common_norm=False)
  sns.move_legend(grid, 'center right')
  # Dotted vertical lines mark the group means, computed from the raw arrays.
  plt.axvline(same.mean(), c='blue', ls=':', lw=2.5)
  plt.axvline(diff.mean(), c='orange', ls=':', lw=2.5)
  savename = 'khist_'+cyp_name+'_layer'+str(n)+'.png'
  plt.savefig(savename)
  plt.close()
  files.download(savename)

In [None]:
# For every encoder depth, print the distance correlation between the CYP
# embedding distances and (a) the label distances, (b) each fingerprint's
# distances.
cyp_layer_counts = [2,3,4,5,6,7,8,9,11,13,15,20,30]
cyp_fp_names = ('top','ap','tt','morgan','maccs')
for n in cyp_layer_counts:
  print('Layer #: ',n)
  emb_dist = d_cyp[n]
  stat, pvalue = dcorr.test(emb_dist, d_label['CYP'])
  print('[Dcorr-label] stat : %.5f, pvalue : %.5f'%(stat,pvalue))
  for fpk in cyp_fp_names:
    print(fpk)
    stat, pvalue = dcorr.test(emb_dist, d_fp['CYP'][fpk])
    print('[Dcorr] stat : %.5f, pvalue : %.5f'%(stat,pvalue))

In [None]:
# For every encoder depth and fingerprint type, run MGC between the CYP
# embedding distances and the fingerprint distances, then save and download
# the MGC map as a heatmap.

# make plots look pretty (global seaborn theme; set once, not per plot)
sns.set(color_codes=True, style="white", context="talk", font_scale=1)

for n in [2,3,4,5,6,7,8,9,11,13,15,20,30]:
  print('Layer #: ',n)
  for fpk in ['top','ap','tt','morgan','maccs']:
    savename = 'MGC_CYP_'+str(n)+'_'+fpk+'.png'
    print(fpk)

    stat, pvalue, mgc_dict = mgc.test(d_cyp[n], d_fp['CYP'][fpk], reps=10)
    print('[MGC] stat : %.5f, pvalue : %.5f'%(stat,pvalue))

    mgc_map = mgc_dict["mgc_map"]
    opt_scale = mgc_dict["opt_scale"]  # i.e. maximum smoothed test statistic
    print("Optimal Scale:", opt_scale)

    # create figure; plt.subplots opens it, so no separate plt.figure() call
    # is made (the extra empty figure was never closed)
    fig, (ax, cax) = plt.subplots(
        ncols=2, figsize=(9, 8.5), gridspec_kw={"width_ratios": [1, 0.05]}
    )

    # draw heatmap and colorbar
    ax = sns.heatmap(mgc_map, cmap="YlGnBu", ax=ax, cbar=False)
    fig.colorbar(ax.get_children()[0], cax=cax, orientation="vertical")
    ax.invert_yaxis()

    # optimal scale
    ax.scatter(opt_scale[1], opt_scale[0], marker="X", s=200, color="red")

    # make plots look nice; ticks are placed at 0, N/2 and N only
    ax.set_title("MGC Map")
    ax.xaxis.set_major_formatter(ticker.ScalarFormatter())
    ax.yaxis.set_major_formatter(ticker.ScalarFormatter())
    ax.set_xlabel("Neighbors for x")
    ax.set_ylabel("Neighbors for y")
    # Size of the CYP matrix actually being plotted. This previously read
    # d_emb[k] through the notebook-global `k`, i.e. another dataset's size.
    N = d_cyp[n].shape[0]
    ax.set_xticks([0, N/2, N])
    ax.set_yticks([0, N/2, N])

    plt.savefig(savename)
    plt.close()
    files.download(savename)