In [1]:
import math

import torch
import torch.nn as nn

from diffusers.optimization import get_scheduler
from tqdm import tqdm

import pandas as pd
import numpy as np

import random

import os
device = "cuda"
def set_seed(seed: int = 42) -> None:
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

set_seed(0)


  from .autonotebook import tqdm as notebook_tqdm


Random seed set as 0


In [2]:
save_location = "/media/bhux/alpha/xsd_mvp/3000e_optimal_cat_normalized/"

# uniform, lg, gaussian

NUMERICAL = 12
CATEGORICAL = 0

NUM_BINS = 15

INFILLING_TYPE = ''
NOISE_TYPE = ''

In [3]:
from utils import remove_outliers, norm, sample_local_gaussian, convert_categorical, unison_shuffled_copies, categorical_norm

PATH = "./data/xtended_data_all.csv"
EMB_PATH = "./data/xtended_emb_all_deberta_pubchem.npy"

f = pd.read_csv(PATH)
drug_embeddings = np.load(EMB_PATH)
smiles = f['Drug'].values
vlists = {
    col: f[col].values for col in f.drop(labels=['Drug'], axis=1).columns[:NUMERICAL] 
}

inmask = remove_outliers([v for _,v in vlists.items()])

smiles = smiles[inmask]
vlists = {
    k: v[inmask] for k,v in vlists.items()
}

vlists = {
    k: norm(v) for k,v in vlists.items()
}

for col in f.drop(labels=['Drug'], axis=1).columns[NUMERICAL:NUMERICAL+CATEGORICAL]:
    vlists[col] = f[col].values[inmask]

nullmask = np.stack([
    np.isnan(v)==False for _,v in vlists.items()
    ], axis=-1)

for col in f.drop(labels=['Drug'], axis=1).columns[:NUMERICAL]:
    vlists[col+"_cat"], vlists[col] = categorical_norm(vlists[col], numbins=NUM_BINS)

dmss = []
for k,v in vlists.items():
    vlists[k], dms = sample_local_gaussian(v, numbins=NUM_BINS)
    dmss.append(dms)
    print(k)


# for col in f.drop(labels=['Drug'], axis=1).columns[NUMERICAL:]:
#     nan = np.isnan(vlists[col])
#     vlists[col] += 1
#     vlists[col][nan] = 0

# dmss = []
# for k,v in vlists.items():
#     dms = get_local_gaussian(v, numbins=50)
#     dmss.append(dms)

dataset = []
for i, gt in enumerate(zip(*[v for _,v in vlists.items()])):
    dataset.append({
        "sm": smiles[i],
        "ft": drug_embeddings[i],
        "ma": nullmask[i],
        "gt": np.array(gt),
        "od": np.array(gt[NUMERICAL:]),
    })
    # print(gt)
    # print(nullmask[i])
    # break

valCount = np.sum(nullmask, axis=0)*0.1
dataset, rcomb = unison_shuffled_copies(dataset, nullmask)
trdataset = []
valdataset = []
for c, d in zip(rcomb, dataset):
    inc = False
    for i, j in enumerate(list(c)):
        if j and valCount[i] > 0:
            valCount[i] -= 1
            inc = True
    if inc:
        valdataset.append(d)
    else:
        trdataset.append(d)

print(len(trdataset))
print(len(valdataset))
print(len(list(vlists.keys())))

Caco2_Wang
Lipophilicity_AstraZeneca
Solubility_AqSolDB
HydrationFreeEnergy_FreeSolv
PPBR_AZ
VDss_Lombardo
Half_Life_Obach
Clearance_Hepatocyte_AZ
Clearance_Microsome_AZ
LD50_Zhu
herg_central_hERG_at_1uM
herg_central_hERG_at_10uM
Caco2_Wang_cat
Lipophilicity_AstraZeneca_cat
Solubility_AqSolDB_cat
HydrationFreeEnergy_FreeSolv_cat
PPBR_AZ_cat
VDss_Lombardo_cat


  m.append(data.mean() if not np.isnan(data.mean()) else low)
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


Half_Life_Obach_cat
Clearance_Hepatocyte_AZ_cat
Clearance_Microsome_AZ_cat
LD50_Zhu_cat
herg_central_hERG_at_1uM_cat
herg_central_hERG_at_10uM_cat
30264
2864
24


In [4]:
from torch.utils.data import Dataset

class GaucamolDataset(Dataset):
    def __init__(self, dataset) -> None:
        self.dataset = dataset
    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]
    
    def update(self, idx, delta):
        item = self.dataset[idx]["gt"]
        self.dataset[idx]["gt"] = item + delta

trainset = GaucamolDataset(trdataset)
valset = GaucamolDataset(valdataset)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=True)
steps_per_epoch = len(trainset)
DMSS = dmss[:NUMERICAL]

In [5]:
def train(diffusion, ema, gamma, dataloader, optimizer, lr_scheduler, two_noise=False):
    diffusion.train()
    running_loss = 0
    global_step = 0
    for i, batch in enumerate(tqdm(dataloader)):
        ft = batch['ft'].to(device).float()
        gt = batch['gt'].to(device).float()
        od = batch['od'].to(device).long()
        mask = batch['ma'].to(device)
        bs = ft.shape[0]

        optimizer.zero_grad()
        loss_multi, loss_gauss = diffusion.mixed_loss(ft, gt, od, mask, DMSS)

        loss = loss_multi + loss_gauss
        
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        ema.update_params(gamma)
        gamma = ema.update_gamma(global_step)

        running_loss += loss.item()
        global_step += 1
    return running_loss/global_step

In [6]:
from sklearn.metrics import mean_squared_error
import csv
from utils import ohe_to_categories

def evaluate(e, ema, dataloader):
    ema.ema_model.eval()
    before_mse = 0
    running_mse = 0
    global_step = 0
    vals = {}
    device = 'cuda'
    ema.ema_model.to(device)
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader)):
            sm = batch['sm']
            mask = batch['ma'].repeat(1,2)
            ft = batch['ft'].to(device).float()
            gt = batch['gt'].to(device).float()
            od = batch['od'].to(device).long()
            bs = ft.shape[0]

            x_in, generated_ys = ema.ema_model.sample(ft, bs, od, DMSS, clip_sample=True)

            raw_mse = mean_squared_error(gt[mask].flatten().cpu(), x_in[mask].flatten().cpu())
            mse = mean_squared_error(gt[mask].flatten().cpu(), generated_ys[mask].flatten().cpu())

            for s, g in zip(sm, list(generated_ys.cpu().numpy())):
                vals[s] = g
            
            before_mse += raw_mse
            running_mse += mse
            global_step += 1

    with open(save_location+'{}_dict.csv'.format(e), 'w') as csv_file:  
        writer = csv.writer(csv_file)
        for key, value in vals.items():
            writer.writerow([key, value])

    return running_mse / global_step, before_mse / global_step
            

In [7]:
from ema import EMA
lr = 0.00022643741661680812
wd = 0.0004940511856069883
warmup = 50
n_timesteps = 2000
n_inference_timesteps = 150
num_epochs = 30000
update_epochs = 500
update_timesteps = int(num_epochs/update_epochs)
gamma = 0.9739783641481703

In [8]:


from sdt import SDT
from diffusion import GaussianMultinomialDiffusion
torch.set_printoptions(profile="full")


total_num_steps = (steps_per_epoch * num_epochs)

model = SDT(
    time_dim = 64,
    cond_size = 768,
    patch_size = 64,
    y_dim = NUMERICAL+(NUM_BINS+1)*(NUMERICAL+CATEGORICAL),
    dim = 768,
    depth = 12,
    heads = 7,
    mlp_dim = 512,
    dropout =  0.16614739878727047,
    emb_dropout =  0.16614739878727047,
)
model.to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {total_params}")

diffusion = GaussianMultinomialDiffusion(
    num_classes = np.array([(NUM_BINS+1) for _ in range(NUMERICAL+CATEGORICAL)]),
    num_numerical_features = NUMERICAL,
    denoise_fn = model,
    device = device,
)
diffusion.to(device)

ema = EMA(diffusion, gamma, total_num_steps)

optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=wd,
    )

lr_scheduler = get_scheduler(
        "cosine_with_restarts",
        optimizer=optimizer,
        num_warmup_steps=warmup,
        num_training_steps=total_num_steps,
    )

Number of parameters: 40865868
torch.Size([192])


In [None]:
l = ""
best_mse = 1
loss = 0
for e in range(num_epochs):
    loss = train(diffusion, ema, gamma, trainloader, optimizer, lr_scheduler)
    if (e % 10 == 0) and (e > 0):
        mse, bmse = evaluate(e, ema, valloader)
        print(e, "avgloss {}, avgvalmse {}, beforemse: {}".format(loss, mse, bmse))
        l += "{} avgloss {}, avgvalmse {}, beforemse: {}\n".format(e, loss, mse, bmse)

        if mse < best_mse:
            best_mse = mse
            torch.save({
                'e': e,
                'ema_model': ema.ema_model.state_dict(),
                'model': diffusion.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, save_location+"best_model.pt")
    else:
        print(e, "avgloss {}".format(loss))
        l += "{} avgloss {}\n".format(e, loss)

    # if ((e % update_epochs  == 0) and e > 500):
    #     trainloader = update(int((num_epochs-e) / update_epochs), ema, updateloader, trainset, ns, update_timesteps)

    with open(save_location+'output.txt', 'w') as file:
        file.write(l)

100%|██████████| 237/237 [00:11<00:00, 20.01it/s]


0 avgloss 0.6395241586719385


100%|██████████| 237/237 [00:10<00:00, 23.49it/s]


1 avgloss 0.6381343012117635


100%|██████████| 237/237 [00:10<00:00, 23.70it/s]


2 avgloss 0.6316374517694304


100%|██████████| 237/237 [00:09<00:00, 25.26it/s]


3 avgloss 0.6277973905897342


100%|██████████| 237/237 [00:10<00:00, 22.49it/s]


4 avgloss 0.6305067690112923


100%|██████████| 237/237 [00:10<00:00, 23.26it/s]


5 avgloss 0.6228757196337865


100%|██████████| 237/237 [00:11<00:00, 20.06it/s]


6 avgloss 0.6282753148410893


100%|██████████| 237/237 [00:09<00:00, 23.86it/s]


7 avgloss 0.6297401914113685


100%|██████████| 237/237 [00:10<00:00, 23.27it/s]


8 avgloss 0.6274231576718359


100%|██████████| 237/237 [00:10<00:00, 22.61it/s]


9 avgloss 0.623486944014513


100%|██████████| 237/237 [00:09<00:00, 24.23it/s]
100%|██████████| 23/23 [00:37<00:00,  1.61s/it]


10 avgloss 0.6285430082288975, avgvalmse 11.414409667894878, beforemse: 14.288568496704102


100%|██████████| 237/237 [00:10<00:00, 22.94it/s]


11 avgloss 0.6279820988449869


100%|██████████| 237/237 [00:09<00:00, 24.03it/s]


12 avgloss 0.6239260992168877


100%|██████████| 237/237 [00:11<00:00, 21.45it/s]


13 avgloss 0.6282374456461974


100%|██████████| 237/237 [00:10<00:00, 23.59it/s]


14 avgloss 0.62162687632605


100%|██████████| 237/237 [00:09<00:00, 23.74it/s]


15 avgloss 0.627450876095124


100%|██████████| 237/237 [00:10<00:00, 23.43it/s]


16 avgloss 0.6227635836802454


100%|██████████| 237/237 [00:10<00:00, 22.95it/s]


17 avgloss 0.625689890314255


100%|██████████| 237/237 [00:11<00:00, 21.16it/s]


18 avgloss 0.6255097893471456


100%|██████████| 237/237 [00:11<00:00, 20.64it/s]


19 avgloss 0.6234439375028329


100%|██████████| 237/237 [00:11<00:00, 20.16it/s]
100%|██████████| 23/23 [00:37<00:00,  1.61s/it]


20 avgloss 0.6297613578506663, avgvalmse 11.809320513444588, beforemse: 14.615535736083984


100%|██████████| 237/237 [00:10<00:00, 21.71it/s]


21 avgloss 0.6163036722674149


100%|██████████| 237/237 [00:10<00:00, 22.26it/s]


22 avgloss 0.6222974353701757


100%|██████████| 237/237 [00:09<00:00, 24.31it/s]


23 avgloss 0.6300021679089542


100%|██████████| 237/237 [00:10<00:00, 22.75it/s]


24 avgloss 0.6244052356808497


100%|██████████| 237/237 [00:11<00:00, 20.51it/s]


25 avgloss 0.6197955759265755


100%|██████████| 237/237 [00:09<00:00, 24.27it/s]


26 avgloss 0.6295668777031235


100%|██████████| 237/237 [00:11<00:00, 21.30it/s]


27 avgloss 0.625040755870473


100%|██████████| 237/237 [00:11<00:00, 20.67it/s]


28 avgloss 0.6311146733881552


100%|██████████| 237/237 [00:10<00:00, 22.43it/s]


29 avgloss 0.625118241908681


100%|██████████| 237/237 [00:10<00:00, 22.89it/s]
100%|██████████| 23/23 [00:37<00:00,  1.61s/it]


30 avgloss 0.6212587692315066, avgvalmse 11.822517651924088, beforemse: 14.729496002197266


100%|██████████| 237/237 [00:10<00:00, 23.38it/s]


31 avgloss 0.6179171218399258


100%|██████████| 237/237 [00:10<00:00, 21.93it/s]


32 avgloss 0.6296239008631888


100%|██████████| 237/237 [00:11<00:00, 20.56it/s]


33 avgloss 0.6230449645076623


100%|██████████| 237/237 [00:11<00:00, 20.70it/s]


34 avgloss 0.6280535493470445


100%|██████████| 237/237 [00:10<00:00, 22.48it/s]


35 avgloss 0.6246258569920616


100%|██████████| 237/237 [00:11<00:00, 20.66it/s]


36 avgloss 0.61397399467255


100%|██████████| 237/237 [00:10<00:00, 23.52it/s]


37 avgloss 0.6182851191562942


100%|██████████| 237/237 [00:09<00:00, 23.98it/s]


38 avgloss 0.6168110207163332


100%|██████████| 237/237 [00:10<00:00, 22.81it/s]


39 avgloss 0.6248307721021307


100%|██████████| 237/237 [00:10<00:00, 23.33it/s]
100%|██████████| 23/23 [00:37<00:00,  1.62s/it]


40 avgloss 0.6178818939858851, avgvalmse 11.606199131753561, beforemse: 14.700361251831055


100%|██████████| 237/237 [00:10<00:00, 21.88it/s]


41 avgloss 0.6252050885168309


100%|██████████| 237/237 [00:09<00:00, 24.59it/s]


42 avgloss 0.6312661749401173


100%|██████████| 237/237 [00:09<00:00, 23.84it/s]


43 avgloss 0.6219623189435226


100%|██████████| 237/237 [00:10<00:00, 23.69it/s]


44 avgloss 0.6230154636036997


100%|██████████| 237/237 [00:09<00:00, 24.05it/s]


45 avgloss 0.6271216693306774


100%|██████████| 237/237 [00:10<00:00, 22.67it/s]


46 avgloss 0.6212973914065945


100%|██████████| 237/237 [00:10<00:00, 22.42it/s]


47 avgloss 0.6133224815758975


100%|██████████| 237/237 [00:11<00:00, 20.17it/s]


48 avgloss 0.6282566989524455


100%|██████████| 237/237 [00:10<00:00, 22.67it/s]


49 avgloss 0.6209332520448709


100%|██████████| 237/237 [00:09<00:00, 23.92it/s]
100%|██████████| 23/23 [00:37<00:00,  1.63s/it]


50 avgloss 0.6189510569542269, avgvalmse 11.594048586754422, beforemse: 14.716779708862305


100%|██████████| 237/237 [00:11<00:00, 20.87it/s]


51 avgloss 0.6145745029177847


100%|██████████| 237/237 [00:10<00:00, 23.64it/s]


52 avgloss 0.6200942195920501


100%|██████████| 237/237 [00:09<00:00, 23.89it/s]


53 avgloss 0.6201061400180125


100%|██████████| 237/237 [00:09<00:00, 24.88it/s]


54 avgloss 0.6246274741138587


100%|██████████| 237/237 [00:10<00:00, 23.59it/s]


55 avgloss 0.626307441212457


100%|██████████| 237/237 [00:11<00:00, 20.32it/s]


56 avgloss 0.630282379403899


100%|██████████| 237/237 [00:11<00:00, 20.32it/s]


57 avgloss 0.6234408368038226


100%|██████████| 237/237 [00:11<00:00, 21.31it/s]


58 avgloss 0.6218680998444054


100%|██████████| 237/237 [00:09<00:00, 24.10it/s]


59 avgloss 0.6171599471870857


100%|██████████| 237/237 [00:10<00:00, 22.11it/s]
100%|██████████| 23/23 [00:37<00:00,  1.61s/it]


60 avgloss 0.6178791339387371, avgvalmse 11.396386847663127, beforemse: 14.63599967956543


100%|██████████| 237/237 [00:10<00:00, 21.91it/s]


61 avgloss 0.6206945002330506


100%|██████████| 237/237 [00:09<00:00, 23.73it/s]


62 avgloss 0.611452250541011


100%|██████████| 237/237 [00:11<00:00, 20.21it/s]


63 avgloss 0.6315647063124532


100%|██████████| 237/237 [00:10<00:00, 22.65it/s]


64 avgloss 0.6214591677178813


100%|██████████| 237/237 [00:10<00:00, 23.57it/s]


65 avgloss 0.6219578482179199


100%|██████████| 237/237 [00:10<00:00, 23.57it/s]


66 avgloss 0.6194070146305148


100%|██████████| 237/237 [00:09<00:00, 24.82it/s]


67 avgloss 0.6192451114392985


100%|██████████| 237/237 [00:09<00:00, 24.62it/s]


68 avgloss 0.6215079370178754


100%|██████████| 237/237 [00:10<00:00, 21.83it/s]


69 avgloss 0.6167141909589244


100%|██████████| 237/237 [00:11<00:00, 20.53it/s]
100%|██████████| 23/23 [00:37<00:00,  1.62s/it]


70 avgloss 0.635216951873232, avgvalmse 11.461121965976838, beforemse: 14.75210952758789


100%|██████████| 237/237 [00:10<00:00, 22.40it/s]


71 avgloss 0.6220005359579239


100%|██████████| 237/237 [00:10<00:00, 22.51it/s]


72 avgloss 0.6267822067948836


100%|██████████| 237/237 [00:11<00:00, 21.48it/s]


73 avgloss 0.6189477781957715


100%|██████████| 237/237 [00:10<00:00, 21.82it/s]


74 avgloss 0.6113576552032921


100%|██████████| 237/237 [00:10<00:00, 22.95it/s]


75 avgloss 0.6222772114136048


100%|██████████| 237/237 [00:10<00:00, 22.44it/s]


76 avgloss 0.6229729329231922


100%|██████████| 237/237 [00:09<00:00, 24.02it/s]


77 avgloss 0.6157593950943605


100%|██████████| 237/237 [00:10<00:00, 22.81it/s]


78 avgloss 0.6246429882975068


100%|██████████| 237/237 [00:10<00:00, 22.48it/s]


79 avgloss 0.6129164024244381


100%|██████████| 237/237 [00:10<00:00, 22.21it/s]
100%|██████████| 23/23 [00:37<00:00,  1.65s/it]


80 avgloss 0.622816378934474, avgvalmse 11.570579057838371, beforemse: 14.777853012084961


100%|██████████| 237/237 [00:10<00:00, 22.89it/s]


81 avgloss 0.6160459317235504


100%|██████████| 237/237 [00:11<00:00, 20.51it/s]


82 avgloss 0.6181443720930236


100%|██████████| 237/237 [00:10<00:00, 23.67it/s]


83 avgloss 0.6045368946302793


100%|██████████| 237/237 [00:10<00:00, 23.48it/s]


84 avgloss 0.6152875794388574


100%|██████████| 237/237 [00:10<00:00, 22.71it/s]


85 avgloss 0.6307994034219895


100%|██████████| 237/237 [00:10<00:00, 23.35it/s]


86 avgloss 0.6160101845294614


100%|██████████| 237/237 [00:10<00:00, 22.28it/s]


87 avgloss 0.6229154465067739


100%|██████████| 237/237 [00:10<00:00, 22.06it/s]


88 avgloss 0.6274985959006765


100%|██████████| 237/237 [00:10<00:00, 22.77it/s]


89 avgloss 0.616854998381329


100%|██████████| 237/237 [00:10<00:00, 23.28it/s]
100%|██████████| 23/23 [00:38<00:00,  1.67s/it]


90 avgloss 0.6210047849119967, avgvalmse 11.511149643098383, beforemse: 14.76513671875


100%|██████████| 237/237 [00:10<00:00, 21.74it/s]


91 avgloss 0.6194308831470425


100%|██████████| 237/237 [00:09<00:00, 24.30it/s]


92 avgloss 0.6254353181219302


100%|██████████| 237/237 [00:09<00:00, 23.93it/s]


93 avgloss 0.6220478136328202


100%|██████████| 237/237 [00:11<00:00, 20.72it/s]


94 avgloss 0.6172266313295324


100%|██████████| 237/237 [00:10<00:00, 21.91it/s]


95 avgloss 0.6273614995590242


100%|██████████| 237/237 [00:11<00:00, 21.17it/s]


96 avgloss 0.6174813902579279


100%|██████████| 237/237 [00:11<00:00, 20.06it/s]


97 avgloss 0.6148093950144852


100%|██████████| 237/237 [00:10<00:00, 22.34it/s]


98 avgloss 0.6147431077333442


100%|██████████| 237/237 [00:10<00:00, 21.73it/s]


99 avgloss 0.6155470764335198


100%|██████████| 237/237 [00:09<00:00, 24.49it/s]
100%|██████████| 23/23 [00:39<00:00,  1.70s/it]


100 avgloss 0.6202309620782795, avgvalmse 11.37491571187731, beforemse: 14.607922554016113


 47%|████▋     | 111/237 [00:05<00:06, 19.98it/s]