In [1]:
import math

import torch
import torch.nn as nn

from diffusers.optimization import get_scheduler
from tqdm import tqdm

import pandas as pd
import numpy as np

import random

import os
device = "cuda"
def set_seed(seed: int = 42) -> None:
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

set_seed(0)


  from .autonotebook import tqdm as notebook_tqdm


Random seed set as 0


In [2]:
save_location = "/media/bhux/alpha/xsd_mvp/test/"

# uniform, lg, gaussian

NUMERICAL = 12
CATEGORICAL = 0

INFILLING_TYPE = ''
NOISE_TYPE = ''

In [3]:


def m(v):
    nonnull = v[np.isnan(v) == False]
    return np.mean(nonnull)

def std(v):
    nonnull = v[np.isnan(v) == False]
    return np.std(nonnull)

def infill_null(v):
    v[np.isnan(v)] = 0
    return v

def remove_outliers(lists):
    b = np.ones(lists[0].shape)

    for l in lists:
        q1 = np.nanquantile(l,0.25)
        q3 = np.nanquantile(l,0.75)

        iqr = q3 - q1
        lower = q1 - 1.5*iqr
        upper = q3 + 1.5*iqr

        b = np.logical_and(b, np.logical_or(l > lower , np.isnan(l)))
        b = np.logical_and(b, np.logical_or(l < upper , np.isnan(l)))
    
    return b

def norm(v):
    nonnull = v[np.isnan(v) == False]
    max = np.nanmax(nonnull)
    min = np.nanmin(nonnull)

    return 2*((v - min) / (max - min))-1

def get_local_gaussian(ys, numbins=50):
    max = np.nanmax(ys)
    min = np.nanmin(ys)
    bins = np.linspace(min, max, retstep=numbins)[0]
    s_ys = np.array(sorted(ys, reverse=True))

    d, m, s = [], [], []
    for i in range(numbins-1):
        low = bins[i]
        high = bins[i+1]
        tbool = np.logical_and(low<=s_ys, s_ys<=high)
        data = s_ys[tbool]
        d.append(len(data))
        m.append(data.mean() if not np.isnan(data.mean()) else low)
        s.append(data.std() if not np.isnan(data.std()) else 0.01)
    d = np.array(d) / sum(d)

    return d, np.array(m), np.array(s)

def convert_categorical(ys, numbins=50):
    max = np.nanmax(ys)
    min = np.nanmin(ys)
    bins = np.linspace(min, max, retstep=numbins)[0]
    s = np.array(sorted(enumerate(ys), key=lambda x:x[1]))
    s_inds = s[:,0].astype(int)
    s_ys = s[:,1]
    
    
    nys = ys
    for i in range(numbins-1):
        low = bins[i]
        high = bins[i+1]
        tbool = np.logical_and(low<=s_ys, s_ys<=high)
        
        if sum(tbool) > 1:
            nys[np.array(s_inds[tbool])] = i+1

    print(i+1)
    nys[np.isnan(nys)] = 0
    return nys

def sample_local_gaussian(v, numbins=50):
    d,m,s = get_local_gaussian(v, numbins=numbins)
    num = sum(np.isnan(v))

    samples = np.random.choice(numbins-1, num, p=d)
    rand_n = np.random.randn(num)

    adjust = m[samples] + 1.2 * rand_n *s[samples]
    
    # override 
    #adjust = np.zeros(adjust.shape)
    
    if INFILLING_TYPE == "zeros":
        adjust = np.zeros(adjust.shape)

    if INFILLING_TYPE == "uniform":
        adjust = np.random.uniform(low=-1.0, high=1.0, size=adjust.shape)

    if INFILLING_TYPE == "gaussian":
        adjust = np.random.normal(size=adjust.shape)

    v[np.isnan(v)] = adjust
    return v, (d,m,s)


def sample_noise(b, dmss, numbins=50):
    if NOISE_TYPE == "uniform":
        return np.random.uniform(low=-1.0, high=1.0, size=(b,NUMERICAL))
    
    if NOISE_TYPE == "gaussian":
        return np.random.normal(size=(b,NUMERICAL))

    vs = []
    for d,m,s in dmss:
        samples = np.random.choice(numbins-1, b, p=d)
        rand_n = np.random.randn(b)
        vs.append(m[samples] + 1.2 * rand_n *s[samples])
    return np.stack(vs, axis=-1)

def unison_shuffled_copies(a, b):
    #assert len(a) == len(b)
    p = np.random.permutation(len(a))
    return np.array(a)[p], np.array(b)[p]

In [4]:
PATH = "./data/xtended_data_all.csv"
EMB_PATH = "./data/xtended_emb_all_deberta_pubchem.npy"

f = pd.read_csv(PATH)
drug_embeddings = np.load(EMB_PATH)
smiles = f['Drug'].values
vlists = {
    col: f[col].values for col in f.drop(labels=['Drug'], axis=1).columns[:NUMERICAL] 
}

inmask = remove_outliers([v for _,v in vlists.items()])
print(sum(inmask))
smiles = smiles[inmask]
vlists = {
    k: v[inmask] for k,v in vlists.items()
}

vlists = {
    k: norm(v) for k,v in vlists.items()
}

# for col in f.drop(labels=['Drug'], axis=1).columns[:NUMERICAL]:
#     vlists[col+"_cat"] = vlists[col]

nullmask = np.stack([
    np.isnan(v)==False for _,v in vlists.items()
    ], axis=-1)

dmss = []
for k,v in vlists.items():
    vlists[k], dms = sample_local_gaussian(v, numbins=15)
    dmss.append(dms)

for col in f.drop(labels=['Drug'], axis=1).columns[:NUMERICAL]:
    vlists[col+"_cat"] = convert_categorical(vlists[col], numbins=15)

# for col in f.drop(labels=['Drug'], axis=1).columns[NUMERICAL:]:
#     nan = np.isnan(vlists[col])
#     vlists[col] += 1
#     vlists[col][nan] = 0

# dmss = []
# for k,v in vlists.items():
#     dms = get_local_gaussian(v, numbins=50)
#     dmss.append(dms)

dataset = []
for i, gt in enumerate(zip(*[v for _,v in vlists.items()])):
    dataset.append({
        "sm": smiles[i],
        "ft": drug_embeddings[i],
        "ma": nullmask[i],
        "gt": np.array(gt),
        "od": np.array(gt[NUMERICAL:]),
    })
    # print(gt)
    # print(nullmask[i])
    # break

valCount = np.sum(nullmask, axis=0)*0.1
dataset, rcomb = unison_shuffled_copies(dataset, nullmask)
trdataset = []
valdataset = []
for c, d in zip(rcomb, dataset):
    inc = False
    for i, j in enumerate(list(c)):
        if j and valCount[i] > 0:
            valCount[i] -= 1
            inc = True
    if inc:
        valdataset.append(d)
    else:
        trdataset.append(d)

print(len(trdataset))
print(len(valdataset))
print(len(list(vlists.keys())))

28422
14
14


  m.append(data.mean() if not np.isnan(data.mean()) else low)
  ret = ret.dtype.type(ret / rcount)
  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  arrmean = um.true_divide(arrmean, div, out=arrmean,
  ret = ret.dtype.type(ret / rcount)


14
14
14
14
14
14
14
14
14
14
25531
2891
24


In [5]:
from torch.utils.data import Dataset

class GaucamolDataset(Dataset):
    def __init__(self, dataset) -> None:
        self.dataset = dataset
    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]
    
    def update(self, idx, delta):
        item = self.dataset[idx]["gt"]
        self.dataset[idx]["gt"] = item + delta

trainset = GaucamolDataset(trdataset)
valset = GaucamolDataset(valdataset)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=128, shuffle=True)
steps_per_epoch = len(trainset)
DMSS = dmss[:NUMERICAL]

In [6]:
def train(diffusion, ema, gamma, dataloader, optimizer, lr_scheduler, two_noise=False):
    diffusion.train()
    running_loss = 0
    global_step = 0
    for i, batch in enumerate(tqdm(dataloader)):
        ft = batch['ft'].to(device).float()
        gt = batch['gt'].to(device).float()
        od = batch['od'].to(device).long()
        mask = batch['ma'].to(device)
        bs = ft.shape[0]

        optimizer.zero_grad()
        loss_multi, loss_gauss = diffusion.mixed_loss(ft, gt, od, mask, DMSS)

        loss = loss_multi + loss_gauss
        
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        ema.update_params(gamma)
        gamma = ema.update_gamma(global_step)

        running_loss += loss.item()
        global_step += 1
    return running_loss/global_step

In [7]:
from sklearn.metrics import mean_squared_error
import csv
from utils import ohe_to_categories

def evaluate(e, ema, dataloader):
    ema.ema_model.eval()
    before_mse = 0
    running_mse = 0
    global_step = 0
    vals = {}
    device = 'cuda'
    ema.ema_model.to(device)
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader)):
            sm = batch['sm']
            mask = batch['ma'].repeat(1,2)
            ft = batch['ft'].to(device).float()
            gt = batch['gt'].to(device).float()
            od = batch['od'].to(device).long()
            bs = ft.shape[0]

            x_in, generated_ys = ema.ema_model.sample(ft, bs, od, DMSS, clip_sample=True)

            raw_mse = mean_squared_error(gt[mask].flatten().cpu(), x_in[mask].flatten().cpu())
            mse = mean_squared_error(gt[mask].flatten().cpu(), generated_ys[mask].flatten().cpu())

            for s, g in zip(sm, list(generated_ys.cpu().numpy())):
                vals[s] = g
            
            before_mse += raw_mse
            running_mse += mse
            global_step += 1

    with open(save_location+'{}_dict.csv'.format(e), 'w') as csv_file:  
        writer = csv.writer(csv_file)
        for key, value in vals.items():
            writer.writerow([key, value])

    return running_mse / global_step, before_mse / global_step
            

In [8]:
from ema import EMA
lr = 0.0005
wd = 1e-4
warmup = 200
n_timesteps = 2000
n_inference_timesteps = 150
num_epochs = 3000
update_epochs = 500
update_timesteps = int(num_epochs/update_epochs)
gamma = 0.994

In [9]:


from sdt import SDT
from diffusion import GaussianMultinomialDiffusion
torch.set_printoptions(profile="full")


total_num_steps = (steps_per_epoch * num_epochs)

model = SDT(
    time_dim = 128,
    cond_size = 768,
    patch_size = 8,
    y_dim = NUMERICAL+15*(NUMERICAL+CATEGORICAL),
    dim = 768,
    depth = 12,
    heads = 12,
    mlp_dim = 1024,
    dropout = 0.1,
    emb_dropout = 0.1,
    num_classes = 15,
)
model.to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {total_params}")

diffusion = GaussianMultinomialDiffusion(
    num_classes = np.array([15 for _ in range(NUMERICAL+CATEGORICAL)]),
    num_numerical_features = NUMERICAL,
    denoise_fn = model,
    device = device,
)
diffusion.to(device)

ema = EMA(diffusion, gamma, total_num_steps)

optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=wd,
    )

lr_scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=warmup,
        num_training_steps=total_num_steps,
    )

Number of parameters: 62563024
torch.Size([180])


In [19]:
l = ""
best_mse = 1
loss = 0
for e in range(num_epochs):
    loss = train(diffusion, ema, gamma, trainloader, optimizer, lr_scheduler)
    if (e % 10 == 0) and (e > 0):
        mse, bmse = evaluate(e, ema, valloader)
        print(e, "avgloss {}, avgvalmse {}, beforemse: {}".format(loss, mse, bmse))
        l += "{} avgloss {}, avgvalmse {}, beforemse: {}\n".format(e, loss, mse, bmse)

        if mse < best_mse:
            best_mse = mse
            torch.save({
                'e': e,
                'ema_model': ema.ema_model.state_dict(),
                'model': diffusion.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, save_location+"best_model.pt")
    else:
        print(e, "avgloss {}".format(loss))
        l += "{} avgloss {}\n".format(e, loss)

    # if ((e % update_epochs  == 0) and e > 500):
    #     trainloader = update(int((num_epochs-e) / update_epochs), ema, updateloader, trainset, ns, update_timesteps)

    with open(save_location+'output.txt', 'w') as file:
        file.write(l)

 47%|████▋     | 94/200 [00:23<00:25,  4.12it/s]