In [1]:
import torch
from torch.utils.data import DataLoader
import gentrl
import gentrl.lp
from gentrl.tokenizer import encode, get_vocab_size
from gentrl.new_dataloader import NewMolecularDataset
import time

## Set GPU

In [2]:
# Ordinal of the CUDA device this notebook trains on; edit to match the host.
GPU_INDEX = 3

# Validate the ordinal up front so a host with fewer GPUs gets an actionable
# message instead of an opaque CUDA "invalid device ordinal" failure.
visible_gpus = torch.cuda.device_count()
if GPU_INDEX >= visible_gpus:
    raise RuntimeError(
        f"GPU_INDEX={GPU_INDEX} is not available: "
        f"only {visible_gpus} CUDA device(s) are visible to PyTorch"
    )

# Make GPU_INDEX the current CUDA device for the rest of the notebook.
torch.cuda.set_device(GPU_INDEX)

In [3]:
# Seed for the notebook's random number generation.
RANDOM_SEED = 42

# Actually apply the seed: defining the constant alone does not seed anything.
# This seeds PyTorch's global RNG, which should cover the weight initialisation
# and DataLoader shuffling in the cells that follow (assuming gentrl draws its
# randomness from torch — TODO confirm). The trailing semicolon hides the
# returned Generator repr from the cell output.
torch.manual_seed(RANDOM_SEED);

In [4]:
# Size shared by the encoder's latent output, the decoder's latent input and
# the number of entries in the latent descriptor.
LATENT_SIZE = 50

enc = gentrl.RNNEncoder(latent_size=LATENT_SIZE)
dec = gentrl.DilConvDecoder(latent_input_size=LATENT_SIZE)

# One ('c', 20) entry per latent dimension, and a single ('c', 20) entry for
# the one conditioning feature.
latent_descr = [('c', 20) for _ in range(LATENT_SIZE)]
feature_descr = [('c', 20)]

model = gentrl.GENTRL(
    enc,
    dec,
    latent_descr=latent_descr,
    feature_descr=feature_descr,
    beta=0.001,
)

# Move the model to the GPU; the trailing semicolon hides the cell's repr output.
model.to('cuda');

In [5]:
# Device handle given to the dataset in the next cell. No index is specified,
# so it names the CUDA backend rather than one particular ordinal.
CUDA_BACKEND = 'cuda'
device = torch.device(CUDA_BACKEND)

In [6]:
# Description of the single CSV source used for training.
# NOTE(review): the 'smiles' and 'plogP' values look like CSV column names and
# 'prob' like a sampling weight for this source — confirm against
# NewMolecularDataset.
plogp_source = {
    'path': 'train_subset_100_000.csv',
    'smiles': 'SMILES',
    'prob': 1,
    'plogP': 'plogP',
}

# Build the dataset from that one source, requesting 'plogP' as the only property.
md = NewMolecularDataset(
    device=device,
    sources=[plogp_source],
    props=['plogP'],
)

In [7]:
# Build the dataset object that the DataLoader cell below iterates over.
# NOTE(review): the name suggests the tensors already live on `device` —
# confirm in NewMolecularDataset.create_tensor_dataset.
gpu_dataset = md.create_tensor_dataset()

In [8]:
# --- Optimisation settings (passed to train_as_vaelp) ---
NUM_EPOCHS = 1  # value of num_epochs
LR = 0.0001     # value of lr

# --- DataLoader settings ---
BATCH_SIZE = 50     # samples per batch
NUM_WORKERS = 0     # 0 = load batches in the main process
PIN_MEMORY = False  # do not stage batches through pinned host memory

In [9]:
# Batch the tensor dataset for training: reshuffle on every pass and drop the
# final incomplete batch so each batch holds exactly BATCH_SIZE samples.
train_loader = DataLoader(
    dataset=gpu_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

In [10]:
# Time one full training run. perf_counter() is a monotonic, high-resolution
# clock, so the measured interval cannot be skewed by system clock adjustments
# (NTP sync, manual changes) the way a difference of time.time() values can.
start_time = time.perf_counter()
model.train_as_vaelp(train_loader, lr=LR, num_epochs=NUM_EPOCHS)
end_time = time.perf_counter()

duration = end_time - start_time  # elapsed seconds
print(f"Time Taken is {duration/60} min")

Epoch 0 :
loss: 2.488;rec: -2.4;kl: -64.57;log_p_y_by_z: -1.526;log_p_z_by_y: -70.97;
loss: 1.79;rec: -1.689;kl: -50.16;log_p_y_by_z: -1.512;log_p_z_by_y: -55.86;
loss: 1.53;rec: -1.417;kl: -37.85;log_p_y_by_z: -1.511;log_p_z_by_y: -45.38;
loss: 1.397;rec: -1.279;kl: -33.24;log_p_y_by_z: -1.513;log_p_z_by_y: -43.06;
loss: 1.299;rec: -1.188;kl: -30.75;log_p_y_by_z: -1.416;log_p_z_by_y: -42.78;
loss: 1.232;rec: -1.127;kl: -29.62;log_p_y_by_z: -1.345;log_p_z_by_y: -43.42;
loss: 1.168;rec: -1.068;kl: -27.41;log_p_y_by_z: -1.268;log_p_z_by_y: -43.46;
loss: 1.128;rec: -1.041;kl: -26.17;log_p_y_by_z: -1.126;log_p_z_by_y: -43.61;
loss: 1.09;rec: -1.006;kl: -24.57;log_p_y_by_z: -1.084;log_p_z_by_y: -44.68;
loss: 1.061;rec: -0.9886;kl: -24.18;log_p_y_by_z: -0.9637;log_p_z_by_y: -44.51;
loss: 1.032;rec: -0.9661;kl: -24.15;log_p_y_by_z: -0.9058;log_p_z_by_y: -44.81;
loss: 1.022;rec: -0.9567;kl: -23.36;log_p_y_by_z: -0.8832;log_p_z_by_y: -45.17;
loss: 0.9886;rec: -0.929;kl: -22.44;log_p_y_by_z: -0.

In [None]:
# NOTE: disabled helper code, kept for reference only — nothing in this cell runs.
# It sketches a penalized logP score computed as logP - SA - (number of rings
# with more than 6 atoms), returning `default` when the molecule cannot be
# parsed or, with masked=True, when it fails mol_passes_filters.
# Re-enabling it requires the `moses` package.
# from moses.metrics import mol_passes_filters, QED, SA, logP
# from moses.metrics.utils import get_n_rings, get_mol


# def get_num_rings_6(mol):
#     r = mol.GetRingInfo()
#     return len([x for x in r.AtomRings() if len(x) > 6])


# def penalized_logP(mol_or_smiles, masked=False, default=-5):
#     mol = get_mol(mol_or_smiles)
#     if mol is None:
#         return default
#     reward = logP(mol) - SA(mol) - get_num_rings_6(mol)
#     if masked and not mol_passes_filters(mol):
#         return default
#     return reward

In [None]:
# NOTE: disabled preprocessing, kept for reference only — nothing in this cell runs.
# It would keep the rows of dataset_v1.csv whose SPLIT is 'train', add a 'plogP'
# column computed from 'SMILES', and write the result to train_plogp_plogpm.csv.
# Re-enabling it needs `import pandas as pd` (pandas is not imported in the
# import cell) and the penalized_logP helper from the disabled cell above.
# df = pd.read_csv('dataset_v1.csv')
# df = df[df['SPLIT'] == 'train']
# df['plogP'] = df['SMILES'].apply(penalized_logP)
# df.to_csv('train_plogp_plogpm.csv', index=None)

In [None]:
# Disabled: would create the saved_gentrl/ directory (-p: no error if it already exists).
# ! mkdir -p saved_gentrl

In [None]:
# Disabled: would save the trained model under ./saved_gentrl/.
# NOTE(review): presumably the directory must exist first — confirm in GENTRL.save.
# model.save('./saved_gentrl/')