In [1]:
import math

import torch
import torch.nn as nn

from diffusers.optimization import get_scheduler
from tqdm import tqdm

import pandas as pd
import numpy as np

import random

import os
# Prefer the GPU, but fall back to the CPU so the notebook can still be
# executed (e.g. for debugging) on a machine without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
def set_seed(seed: int = 42) -> None:
    """Seed every random number generator this notebook relies on.

    Seeds NumPy, Python's ``random`` module and PyTorch (CPU and every
    CUDA device), and switches cuDNN to deterministic mode.

    Args:
        seed: Value applied to all generators.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every GPU, not only the currently selected one;
    # it is safe to call when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # When running on the cuDNN backend, two further options must be set.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this only
    # influences processes spawned after this point, not the running kernel.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


set_seed(0)


Random seed set as 0


In [None]:
# Output directory for the best checkpoint, the per-evaluation prediction CSVs
# and the training log. File names are appended with plain string
# concatenation, so the trailing "/" is required.
save_location = "/media/bhux/alpha/xsd_mvp/3000e_normalized/"

# Candidate values: uniform, lg, gaussian
# NOTE(review): presumably the options for INFILLING_TYPE / NOISE_TYPE below — TODO confirm.

# Number of leading property columns (after dropping 'Drug') used as numerical targets.
NUMERICAL = 12
# Extra categorical feature count; added to NUMERICAL when sizing the model inputs.
CATEGORICAL = 0

# Passed as `numbins` to the binning helpers; the model is built with NUM_BINS + 1 classes.
NUM_BINS = 30

# Left empty here and not referenced by any other visible cell.
INFILLING_TYPE = ''
NOISE_TYPE = ''

In [None]:
from utils import remove_outliers, norm, sample_local_gaussian, convert_categorical, unison_shuffled_copies, categorical_norm

PATH = "./data/xtended_data_all.csv"
EMB_PATH = "./data/xtended_emb_all_deberta_pubchem.npy"

# --- Load the property table and the pre-computed drug embeddings -----------
f = pd.read_csv(PATH)
drug_embeddings = np.load(EMB_PATH)
smiles = f['Drug'].values

# The first NUMERICAL columns after 'Drug' are the numerical properties to model.
property_cols = list(f.drop(labels=['Drug'], axis=1).columns[:NUMERICAL])
vlists = {col: f[col].values for col in property_cols}

# --- Outlier removal ---------------------------------------------------------
n_rows_raw = len(smiles)
inmask = remove_outliers(list(vlists.values()))
print(sum(inmask))
smiles = smiles[inmask]
vlists = {k: v[inmask] for k, v in vlists.items()}
# Bug fix: the embeddings must be filtered with the same mask as the SMILES
# and property columns; otherwise drug_embeddings[i] no longer belongs to
# smiles[i] as soon as a single row is dropped.
if len(drug_embeddings) == n_rows_raw:
    drug_embeddings = drug_embeddings[inmask]
assert len(drug_embeddings) == len(smiles), (
    "drug_embeddings and SMILES rows are misaligned after outlier removal"
)

# --- Normalisation and missing-value mask ------------------------------------
vlists = {k: norm(v) for k, v in vlists.items()}

# True where a property value is present (not NaN); one column per property.
nullmask = np.stack(
    [np.logical_not(np.isnan(v)) for v in vlists.values()],
    axis=-1,
)

# --- Binned ("_cat") companions for every numerical property -----------------
# The "_cat" keys are appended after the numerical ones, so the dict order is
# [numerical columns..., categorical columns...].
for col in property_cols:
    vlists[col+"_cat"], vlists[col] = categorical_norm(vlists[col], numbins=NUM_BINS)

# sample_local_gaussian returns the (possibly resampled) column plus a
# per-column statistic that is collected in `dmss`, in dict order.
dmss = []
for k in list(vlists):
    vlists[k], dms = sample_local_gaussian(vlists[k], numbins=NUM_BINS)
    dmss.append(dms)

# --- Assemble one record per molecule ----------------------------------------
dataset = []
for i, gt in enumerate(zip(*vlists.values())):
    dataset.append({
        "sm": smiles[i],                  # SMILES string
        "ft": drug_embeddings[i],         # conditioning embedding
        "ma": nullmask[i],                # which numerical targets are present
        "gt": np.array(gt),               # all columns (numerical + "_cat")
        "od": np.array(gt[NUMERICAL:]),   # the "_cat" columns only
    })

# --- Train / validation split -------------------------------------------------
# Budget per property: 10% of its non-missing values. A sample goes to the
# validation set if at least one of its present properties still has budget
# left; every such property's budget is decremented.
valCount = np.sum(nullmask, axis=0)*0.1
# NOTE(review): rcomb is presumably nullmask shuffled in the same order as
# dataset — confirm against utils.unison_shuffled_copies.
dataset, rcomb = unison_shuffled_copies(dataset, nullmask)
trdataset = []
valdataset = []
for c, d in zip(rcomb, dataset):
    inc = False
    for i, j in enumerate(list(c)):
        if j and valCount[i] > 0:
            valCount[i] -= 1
            inc = True
    if inc:
        valdataset.append(d)
    else:
        trdataset.append(d)

print(len(trdataset))
print(len(valdataset))
print(len(vlists))

28443
-1.0 -0.8571428571428572
[0.72844186 0.         0.32770072 0.90040047 0.43479552 0.33093674
 0.43549484 0.17768169 0.22859833 0.50619097 0.18490945 0.43479552
 0.72037842 0.17955583]
-0.8571428571428572 -0.7142857142857143
[0.37051335 0.98325722 0.42995929 0.21552874 0.11305576 0.50573167
 0.68670864 0.47003448 0.54143029 0.68670864 0.37051335 0.01038715
 0.22316488 0.12060337 0.43162924 0.36294004 0.82555101]
-0.7142857142857143 -0.5714285714285714
[0.06998814 0.18399191 0.18399191 0.75515986 0.61236823 0.46113726
 0.61420667 0.15398211 0.38033147 0.86225251 0.61604511 0.14350656
 0.95252217 0.89832917 0.07550526 0.79085669]
-0.5714285714285714 -0.4285714285714286
[0.82609601 0.5719544  0.19624463 0.68698127 0.47637861 0.26521977
 0.10812292 0.18353403 0.61152232 0.36202429 0.79658991 0.96888728
 0.82609601 0.68946655 0.68698127 0.00504378 0.47637861 0.68330438
 0.01022533 0.64950703 0.25492985 0.24979256 0.29305771 0.18353403
 0.01093358 0.35819604 0.4334201  0.79039882 0.00504

  m.append(data.mean() if not np.isnan(data.mean()) else low)
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


25590
2853
24


In [4]:
from torch.utils.data import Dataset

class GaucamolDataset(Dataset):
    """Thin ``Dataset`` wrapper around an in-memory sequence of record dicts."""

    def __init__(self, dataset) -> None:
        # The sequence is stored by reference, not copied.
        self.dataset = dataset

    def __len__(self):
        """Return the number of stored records."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the record stored at position ``idx``."""
        return self.dataset[idx]

    def update(self, idx, delta):
        """Replace the ``"gt"`` entry of record ``idx`` with ``gt + delta``."""
        record = self.dataset[idx]
        # Out-of-place addition: a new object is stored, the old one is untouched.
        record["gt"] = record["gt"] + delta

trainset = GaucamolDataset(trdataset)
valset = GaucamolDataset(valdataset)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
# Validation order does not need to be random; a fixed order keeps the
# per-batch averaged metric comparable between evaluations.
valloader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=False)
# One optimizer step is taken per batch, so the number of steps per epoch is
# the number of batches, not the number of samples.
steps_per_epoch = len(trainloader)
# Keep only the statistics that belong to the numerical columns.
DMSS = dmss[:NUMERICAL]

In [5]:
def train(diffusion, ema, gamma, dataloader, optimizer, lr_scheduler, two_noise=False):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        diffusion: Model exposing ``train()`` and
            ``mixed_loss(ft, gt, od, mask, DMSS)``, which returns two losses.
        ema: Helper exposing ``update_params(gamma)`` and
            ``update_gamma(step)``.
        gamma: Value passed to ``ema.update_params`` on the first step; it is
            then replaced by whatever ``ema.update_gamma`` returns.
        dataloader: Iterable of batches with the keys ``'ft'``, ``'gt'``,
            ``'od'`` and ``'ma'``.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Scheduler stepped once per batch.
        two_noise: Unused; kept for backward compatibility.

    Returns:
        Sum of the per-batch losses divided by the number of batches, or NaN
        when the dataloader yields no batch.
    """
    diffusion.train()
    running_loss = 0
    # NOTE(review): this counter is local, so it restarts at 0 on every call
    # and `gamma` restarts from the caller's value each epoch — confirm that
    # the EMA schedule is meant to be per-epoch rather than global.
    global_step = 0
    for batch in tqdm(dataloader):
        ft = batch['ft'].to(device).float()
        gt = batch['gt'].to(device).float()
        od = batch['od'].to(device).long()
        mask = batch['ma'].to(device)

        optimizer.zero_grad()
        loss_multi, loss_gauss = diffusion.mixed_loss(ft, gt, od, mask, DMSS)

        # Both loss terms are optimised jointly with equal weight.
        loss = loss_multi + loss_gauss

        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        ema.update_params(gamma)
        gamma = ema.update_gamma(global_step)

        running_loss += loss.item()
        global_step += 1

    if global_step == 0:
        # Empty loader: report "no value" instead of raising ZeroDivisionError.
        return float("nan")
    return running_loss/global_step

In [6]:
from sklearn.metrics import mean_squared_error
import csv
from utils import ohe_to_categories

def evaluate(e, ema, dataloader):
    """Sample from the EMA model on ``dataloader`` and report masked MSEs.

    For every batch the EMA model's ``sample`` method is called; the MSE is
    computed between the ground truth and (a) the first value returned by
    ``sample`` and (b) the second value returned by ``sample``, restricted to
    the entries selected by the batch mask. The second returned value is also
    written, keyed by the batch's ``'sm'`` entries, to
    ``save_location + '<e>_dict.csv'``.

    Args:
        e: Epoch index, used only to name the CSV file.
        ema: Object exposing the model to evaluate as ``ema.ema_model``.
        dataloader: Iterable of batches with the keys ``'sm'``, ``'ma'``,
            ``'ft'``, ``'gt'`` and ``'od'``.

    Returns:
        ``(mse, before_mse)``: the per-batch MSEs averaged over all batches for
        the generated output and for ``sample``'s first output respectively.
        Both are NaN when the dataloader yields no batch.
    """
    ema.ema_model.eval()
    before_mse = 0
    running_mse = 0
    global_step = 0
    vals = {}
    # Use the notebook-wide `device` (as train() does) instead of a
    # hard-coded local override.
    ema.ema_model.to(device)
    with torch.no_grad():
        for batch in tqdm(dataloader):
            sm = batch['sm']
            # The mask has one column per numerical target; it is tiled twice
            # along dim 1 before being used to index gt and the model outputs.
            mask = batch['ma'].repeat(1,2)
            ft = batch['ft'].to(device).float()
            gt = batch['gt'].to(device).float()
            od = batch['od'].to(device).long()
            bs = ft.shape[0]

            x_in, generated_ys = ema.ema_model.sample(ft, bs, od, DMSS, clip_sample=True)

            target = gt[mask].flatten().cpu()
            raw_mse = mean_squared_error(target, x_in[mask].flatten().cpu())
            mse = mean_squared_error(target, generated_ys[mask].flatten().cpu())

            # Later occurrences of the same key overwrite earlier ones.
            for s, g in zip(sm, generated_ys.cpu().numpy()):
                vals[s] = g

            before_mse += raw_mse
            running_mse += mse
            global_step += 1

    # newline='' is what the csv module expects for files it writes.
    with open(save_location+'{}_dict.csv'.format(e), 'w', newline='') as csv_file:
        writer = csv.writer(csv_file)
        for key, value in vals.items():
            writer.writerow([key, value])

    if global_step == 0:
        # Empty loader: report "no value" instead of raising ZeroDivisionError.
        return float("nan"), float("nan")
    return running_mse / global_step, before_mse / global_step
            

In [7]:
from ema import EMA
# Learning rate and weight decay handed to torch.optim.AdamW.
lr = 0.009980944985851876
wd = 0.0002561887904737375
# Passed to get_scheduler as num_warmup_steps.
warmup = 150
# NOTE(review): n_timesteps / n_inference_timesteps are not referenced by any
# visible cell — presumably diffusion step counts; confirm they are still needed.
n_timesteps = 2000
n_inference_timesteps = 150
# Number of passes over the training loader.
num_epochs = 3000
# update_epochs / update_timesteps are not used by any active code in the visible cells.
update_epochs = 500
update_timesteps = int(num_epochs/update_epochs)
# Handed to EMA(...) at construction and to train(...) on every epoch
# (presumably the EMA decay — TODO confirm against ema.py).
gamma = 0.9793712795623954

In [None]:


from sdt import SDT
from diffusion import GaussianMultinomialDiffusion
# Print tensors in full instead of the summarised default.
torch.set_printoptions(profile="full")


# steps_per_epoch * num_epochs; handed to the EMA helper and the LR scheduler.
total_num_steps = (steps_per_epoch * num_epochs)

# Denoising network.
model = SDT(
    time_dim = 16,
    cond_size = 768,
    patch_size = 32,
    # NUMERICAL values plus (NUM_BINS + 1) entries for each of the
    # NUMERICAL + CATEGORICAL features.
    y_dim = NUMERICAL+(NUM_BINS+1)*(NUMERICAL+CATEGORICAL),
    dim = 768,
    depth = 5,
    heads = 7,
    mlp_dim = 768,
    dropout =  0.1423146531023044,
    emb_dropout =  0.1423146531023044,
    num_classes = NUM_BINS+1,
)
model.to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {total_params}")

# Diffusion wrapper around the denoiser: NUMERICAL numerical features plus one
# (NUM_BINS + 1)-class feature per entry of num_classes.
diffusion = GaussianMultinomialDiffusion(
    num_classes = np.array([(NUM_BINS+1) for _ in range(NUMERICAL+CATEGORICAL)]),
    num_numerical_features = NUMERICAL,
    denoise_fn = model,
    device = device,
)
diffusion.to(device)

ema = EMA(diffusion, gamma, total_num_steps)

optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=wd,
    )

# Fix: the plain "constant" schedule does not take warm-up into account, so
# `warmup` was silently ignored. "constant_with_warmup" ramps the learning
# rate over the first `warmup` steps and then holds it constant.
lr_scheduler = get_scheduler(
        "constant_with_warmup",
        optimizer=optimizer,
        num_warmup_steps=warmup,
        num_training_steps=total_num_steps,
    )

Number of parameters: 19151756
torch.Size([192])


In [None]:
# Create the output directory up front so that the first log/CSV/checkpoint
# write does not fail after training time has already been spent.
os.makedirs(save_location, exist_ok=True)

# Accumulated text log; the whole log is rewritten to disk after every epoch.
l = ""
# Start at +inf so the first evaluation always produces a checkpoint,
# whatever the scale of the validation MSE.
best_mse = float("inf")
loss = 0
for e in range(num_epochs):
    loss = train(diffusion, ema, gamma, trainloader, optimizer, lr_scheduler)
    # Evaluate every 10th epoch, skipping epoch 0.
    if (e % 10 == 0) and (e > 0):
        mse, bmse = evaluate(e, ema, valloader)
        print(e, "avgloss {}, avgvalmse {}, beforemse: {}".format(loss, mse, bmse))
        l += "{} avgloss {}, avgvalmse {}, beforemse: {}\n".format(e, loss, mse, bmse)

        # Keep only the checkpoint with the lowest validation MSE so far.
        if mse < best_mse:
            best_mse = mse
            torch.save({
                'e': e,
                'ema_model': ema.ema_model.state_dict(),
                'model': diffusion.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, save_location+"best_model.pt")
    else:
        print(e, "avgloss {}".format(loss))
        l += "{} avgloss {}\n".format(e, loss)

    with open(save_location+'output.txt', 'w') as file:
        file.write(l)

 48%|████▊     | 95/200 [00:03<00:02, 37.28it/s]