In [2]:
import numpy as np

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pack_padded_sequence
from torch.autograd import Variable
from diffusers import DDPMScheduler
import os

import sys
sys.path.insert(1,"../scripts")
from get_voxels import get_mol_voxels, smile_to_sstring, collate_batch
from networks import EncoderCNN, DecoderRNN, UNet3D, Encoder

In [2]:
# Pick the compute device: default to the CPU and switch to CUDA when a GPU is visible.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
print(f'Using device: {device}')

Using device: cuda


In [8]:
# Directory that receives the saved model weights.
out_dir = "../models/"

# Create it if missing. `exist_ok=True` replaces the earlier "list the parent
# directory, then mkdir" sequence, which scanned every entry of "../" and could
# race with another process creating the same directory.
os.makedirs(out_dir, exist_ok=True)

In [9]:
# Load the SMILES strings, one entry per line of the input file.
# Reading stops after the first line whose 0-based index exceeds MAX_LINE_INDEX.
MAX_LINE_INDEX = 20_000_000

smiles = []
with open("../datasets/raw/zinc15_druglike_clean_canonical_max60.smi") as f:
    for i, line in enumerate(f):
        # Strip only the line terminator. Slicing with [:-1] would also cut a
        # real character off the last entry when the file does not end in "\n".
        smiles.append(line.rstrip("\n"))
        if i > MAX_LINE_INDEX:
            break

# smiles = smiles[112320:]

In [5]:
class CustomImageDataset(Dataset):
    """Map-style dataset that hands back the stored entries by index.

    Despite the class name, no images are held here: each item is whatever
    element sits at that position of the `smiles` sequence given at
    construction time, returned unmodified.
    """

    def __init__(self, smiles):
        """Keep a reference to the `smiles` sequence (no copy is made)."""
        self.smiles = smiles

    def __len__(self):
        """Return how many entries the dataset holds."""
        n_items = len(self.smiles)
        return n_items

    def __getitem__(self, idx):
        """Return the entry stored at position `idx`."""
        return self.smiles[idx]

# Wrap the loaded `smiles` list in the Dataset defined above so it can be indexed and batched.
smile_DS = CustomImageDataset(smiles)

In [7]:
# --- Networks -----------------------------------------------------------
# Construction order is kept exactly as before.
encoderCNN = EncoderCNN(5)
decoderRNN = DecoderRNN(512, 1024, 29, 1)
encoder = Encoder()
net = UNet3D(5, 5)

# Move all four networks onto the selected device.
for _network in (encoderCNN, decoderRNN, net, encoder):
    _network.to(device)

# --- Encoder: loss and optimizer ----------------------------------------
criterionEncoder = nn.BCELoss()
optimizerEncoder = torch.optim.Adam(encoder.parameters(), lr=1e-3)

# --- UNet: loss and optimizer -------------------------------------------
criterionNet = nn.BCELoss()
optimizerNet = torch.optim.Adam(net.parameters(), lr=1e-3)

# --- Captioning: DecoderRNN and EncoderCNN share a single optimizer -----
criterionCaption = nn.CrossEntropyLoss()
caption_params = [*decoderRNN.parameters(), *encoderCNN.parameters()]
caption_optimizer = torch.optim.Adam(caption_params, lr=1e-3)

# --- Remaining training configuration -----------------------------------
n_epochs = 1
train_dataloader = DataLoader(dataset=smile_DS, batch_size=32, collate_fn=collate_batch)

# Noise scheduler used to corrupt the inputs during UNet training.
scheduler = DDPMScheduler(num_train_timesteps=1000)

In [8]:
# Single training pass over `train_dataloader`.
# Every batch updates the UNet (reconstructing the clean input from a noised
# copy) and the Encoder (predicting `pharm` from the clean input). The
# captioning pair (EncoderCNN + DecoderRNN) only starts training after a
# warm-up number of batches.
LOG_EVERY = 10               # print a progress line every N batches
CHECKPOINT_EVERY = 500       # save weights and report losses every N batches
CAPTION_START_BATCH = 3200   # ~100,000 compounds at the configured batch size of 32

# len(DataLoader) is the number of batches, i.e. ceil(len(dataset) / batch_size)
# when drop_last is left at its default.
n_batches = len(train_dataloader)

for i, (x, captions, pharm, lengths) in enumerate(train_dataloader):
    if (i + 1) % LOG_EVERY == 0:
        print("Batch {} of {}.".format(i + 1, n_batches))

    # ---- UNet ----------------------------------------------------------
    # One random diffusion timestep per example in the batch.
    timesteps = torch.randint(
        0,
        scheduler.num_train_timesteps,
        (x.shape[0],),
        device=x.device,
    ).long()

    noise = torch.randn(x.shape).to(x.device)
    noisy_x = scheduler.add_noise(x, noise, timesteps)
    # Cast to float32 and move to the training device in one step.
    noisy_x = noisy_x.to(device=device, dtype=torch.float32)
    x = x.to(device)

    pred = net(noisy_x)
    net_loss = criterionNet(pred, x)

    # Backprop and update the UNet parameters.
    optimizerNet.zero_grad()
    net_loss.backward()
    optimizerNet.step()
    net_loss = net_loss.cpu()

    # ---- Encoder -------------------------------------------------------
    pharm = pharm.to(device)
    encoded_tensor = encoder(x)
    enc_loss = criterionEncoder(encoded_tensor, pharm)

    optimizerEncoder.zero_grad()
    enc_loss.backward()
    optimizerEncoder.step()

    # ---- Captioning networks (after the warm-up) -----------------------
    if i > CAPTION_START_BATCH:
        # `Variable` is a deprecated no-op wrapper; a plain tensor is enough.
        captions = captions.to(device)
        targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]

        # caption_optimizer owns exactly the DecoderRNN + EncoderCNN parameters.
        caption_optimizer.zero_grad()
        # Detach so the captioning loss never back-propagates into the UNet.
        features = encoderCNN(pred.detach())
        outputs = decoderRNN(features, captions, lengths)
        cap_loss = criterionCaption(outputs, targets)
        cap_loss.backward()
        caption_optimizer.step()

    # ---- Checkpointing and loss report ---------------------------------
    if (i + 1) % CHECKPOINT_EVERY == 0:
        # Each network gets its own file. Previously all four were written to
        # the same "net_weights_<step>.pkl" path, so every save overwrote the
        # one before it and only the DecoderRNN weights survived.
        # The UNet keeps its original file name.
        checkpoints = {
            "net": net,
            "encoder": encoder,
            "encoderCNN": encoderCNN,
            "decoderRNN": decoderRNN,
        }
        for tag, module in checkpoints.items():
            weights_path = os.path.join(out_dir, "{}_weights_{}.pkl".format(tag, i + 1))
            torch.save(module.state_dict(), weights_path)

        if i > CAPTION_START_BATCH:
            print("Net Loss: {}\nEncoder Loss: {}\nCaptioning Loss: {}".format(net_loss,enc_loss,cap_loss))
        else:
            print("Net Loss: {}\nEncoder Loss: {}".format(net_loss,enc_loss))

Batch 10 of 8991.
Batch 20 of 8991.
Batch 30 of 8991.
Batch 40 of 8991.
Batch 50 of 8991.
Batch 60 of 8991.
Batch 70 of 8991.
Batch 80 of 8991.
Batch 90 of 8991.
Batch 100 of 8991.
Batch 110 of 8991.
Batch 120 of 8991.
Batch 130 of 8991.
Batch 140 of 8991.
Batch 150 of 8991.
Batch 160 of 8991.
Batch 170 of 8991.
Batch 180 of 8991.
Batch 190 of 8991.
Batch 200 of 8991.
Batch 210 of 8991.
Batch 220 of 8991.
Batch 230 of 8991.
Batch 240 of 8991.
Batch 250 of 8991.
Batch 260 of 8991.
Batch 270 of 8991.
Batch 280 of 8991.
Batch 290 of 8991.
Batch 300 of 8991.
Batch 310 of 8991.
Batch 320 of 8991.
Batch 330 of 8991.
Batch 340 of 8991.
Batch 350 of 8991.
Batch 360 of 8991.
Batch 370 of 8991.
Batch 380 of 8991.
Batch 390 of 8991.
Batch 400 of 8991.
Batch 410 of 8991.
Batch 420 of 8991.
Batch 430 of 8991.
Batch 440 of 8991.
Batch 450 of 8991.
Batch 460 of 8991.
Batch 470 of 8991.
Batch 480 of 8991.
Batch 490 of 8991.
Batch 500 of 8991.
Net Loss: 0.017858007922768593
Encoder Loss: 0.009011551737