In [None]:
!pip install kora

In [36]:
import torch
import torch.nn as nn
import numpy as np
import os
from torchvision import transforms
from matplotlib import pyplot as plt
from PIL import Image, ImageOps, ImageEnhance
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import tensorflow as tf
import gc
from torchvision import transforms
import kora.install.rdkit
from rdkit import Chem
from typing import Optional
import random
from PIL import Image, ImageOps, ImageEnhance
from typing import Union, List, Optional
import torch.nn.functional as F

In [None]:
!pip install pytorch_lightning

In [38]:
import pytorch_lightning as pl

In [39]:
LR = 0.0001  # AdamW learning rate read by Img2MolModel when it is constructed

In [40]:
# Per-conv-layer hyper-parameters, consumed in order by getCNN: the i-th "conv2d" entry in
# layer_list uses index i of every list below.
configs = {"unit" : [128,256,384,384,384,512,512,512],
           "kernel_size" : [7,5,5,3,3,3,3,3],
           "stride" : [3,1,1,1,1,1,1,1],
           "padding" : [4,1,1,1,1,1,1,1]}
# Layer sequence for getCNN: 3 convs + pool, 2 convs + pool, 3 convs + pool. For 234x234 inputs
# this ends in a 512x9x9 map, matching getFCN's default input size.
# Keep every comma: adjacent string literals concatenate silently ("conv2d" "maxpool" ->
# "conv2dmaxpool"), which previously dropped a conv and a pool and left an 18x18 map.
layer_list = ["conv2d", "conv2d", "conv2d", "maxpool",
              "conv2d", "conv2d", "maxpool",
              "conv2d", "conv2d", "conv2d", "maxpool"]
assert layer_list.count("conv2d") == len(configs["unit"]), "need exactly one configs entry per conv2d layer"

In [41]:
def getCNN(configs, layer_list, in_channels=1):
    """Build the convolutional trunk as an ``nn.Sequential``.

    Args:
        configs: dict of lists under "unit", "kernel_size", "stride" and "padding";
            the i-th "conv2d" entry of ``layer_list`` uses index i of each list.
        layer_list: sequence of layer names. "conv2d" adds a Conv2d followed by an
            in-place ReLU; "maxpool" adds a 2x2 max-pool with stride 2.
        in_channels: channel count of the input images (default 1, i.e. grayscale).

    Returns:
        ``nn.Sequential`` containing the requested layers in order.

    Raises:
        ValueError: if ``layer_list`` contains any other name. Unknown names used to be
            skipped silently, which hides typos such as a missing comma between two
            string literals.
        IndexError: if ``layer_list`` has more "conv2d" entries than ``configs`` provides.
    """
    layers = []
    conv_idx = 0  # which configs entry the next conv2d layer consumes
    for layer in layer_list:
        if layer == "conv2d":
            out_channels = configs["unit"][conv_idx]
            layers.append(nn.Conv2d(in_channels, out_channels,
                                    kernel_size=configs["kernel_size"][conv_idx],
                                    stride=configs["stride"][conv_idx],
                                    padding=configs["padding"][conv_idx]))
            layers.append(nn.ReLU(inplace=True))
            # The next conv consumes this conv's output channels.
            in_channels = out_channels
            conv_idx += 1
        elif layer == "maxpool":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            raise ValueError(f"Unknown layer type {layer!r}; expected 'conv2d' or 'maxpool'")

    return nn.Sequential(*layers)

In [42]:
def getFCN(in_features=512 * 9 * 9, hidden_features=4096, out_features=512, dropout=0.0):
    """Build the fully connected head as an ``nn.Sequential``.

    Layout: two hidden stages of Linear -> ReLU -> Dropout, then a Linear projection
    followed by Tanh, so every output value lies in [-1, 1].

    Args:
        in_features: flattened size of the CNN output; the default 512*9*9 matches a
            512-channel 9x9 feature map.
        hidden_features: width of both hidden Linear layers.
        out_features: length of the output vector.
        dropout: dropout probability after each hidden stage. The default 0.0 (the
            original value) makes the Dropout layers no-ops.

    Returns:
        ``nn.Sequential`` with 8 modules in the order described above.
    """
    layers = [
        nn.Linear(in_features, hidden_features),
        nn.ReLU(True),
        nn.Dropout(p=dropout),
        nn.Linear(hidden_features, hidden_features),
        nn.ReLU(True),
        nn.Dropout(p=dropout),
        nn.Linear(hidden_features, out_features),
        nn.Tanh(),
    ]
    return nn.Sequential(*layers)

In [43]:
class Img2MolModel(pl.LightningModule):
    """CNN encoder + fully connected head that regresses images onto fixed-size target vectors.

    The convolutional trunk is ``getCNN(configs, layer_list)`` and the head is ``getFCN()``,
    whose final Tanh keeps every output in [-1, 1]. Training, validation and test all use
    the mean-squared error against the targets. In this notebook the targets are the
    vectors returned by the CDDD server.
    """

    def __init__(self, learning_rate=None):
        """Build the network and re-initialise its weights.

        Args:
            learning_rate: AdamW learning rate. ``None`` (the default) uses the
                module-level ``LR`` as it is at construction time.
        """
        super().__init__()
        self.learning_rate = LR if learning_rate is None else learning_rate
        self.features = getCNN(configs, layer_list)
        self.classifier = getFCN()
        self._initialize_weights()

    def forward(self, x):
        """Map an image batch to one output vector per sample.

        Assumes x is (batch, channels, H, W), with channels matching the first Conv2d of
        ``self.features``. TODO: confirm that H/W produce the flattened size getFCN expects.
        """
        x = self.features(x)
        # Collapse (C, H, W) into a single feature axis per sample for the Linear head.
        x = torch.flatten(x, 1)
        return self.classifier(x)

    def _initialize_weights(self):
        """Re-initialise parameters in place.

        Convs get Kaiming-normal weights, Linears get normal(0, std=0.01) weights, and
        all biases are set to zero.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # getCNN builds no BatchNorm layers at present; kept in case some are added.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def _shared_step(self, batch):
        """Return the MSE between predictions and targets for one ``(x, y)`` batch."""
        x, y = batch
        y_hat = self(x)
        # Targets built from NumPy/JSON may be float64 (or integer). Mixing them with the
        # float32 predictions breaks mse_loss's backward pass, so align the dtypes first.
        return F.mse_loss(y_hat, y.to(y_hat.dtype))

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        loss = self._shared_step(batch)
        self.log('train_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss for one batch."""
        loss = self._shared_step(batch)
        self.log('valid_loss', loss, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        """Log the test loss for one batch."""
        loss = self._shared_step(batch)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        """Return AdamW over all parameters at ``self.learning_rate``."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

In [44]:
# Seed the Python, NumPy and torch RNGs so weight initialisation, and any later use of
# the global RNG, is reproducible across kernel restarts.
pl.seed_everything(42)
model = Img2MolModel()

In [None]:
%cd drive/MyDrive

In [46]:
# File names found in train_data/, in lexicographic order.
# NOTE(review): the order is lexicographic, not numeric. If files are numbered, "10.png"
# sorts before "2.png"; confirm this order matches the row order of chembl_cleaned.csv.
train_images = os.listdir("train_data")
train_images.sort()

In [47]:
# Relative paths of the first 5000 training images (the subset that actually gets loaded).
image_names = [os.path.join("train_data", file_name) for file_name in train_images[:5000]]

In [None]:
def load_image(path, size=(234, 234)):
    """Load one image file as a contrast-normalised single-channel tensor.

    The image is converted to grayscale ("L"). The CNN built by getCNN starts with a
    Conv2d that takes 1 input channel, so an RGB (3-channel) tensor would fail at the
    first convolution.

    Args:
        path: image file path.
        size: (width, height) passed to PIL's ``resize``; default 234x234.

    Returns:
        Float tensor of shape (1, height, width) with values in [0, 1].
    """
    # Close the file handle deterministically; convert() returns an independent image.
    with Image.open(path) as raw:
        im = raw.convert('L')
    # NOTE(review): transparent pixels of RGBA inputs are not composited onto a white
    # background here — confirm the input images have no alpha channel.
    im = ImageEnhance.Contrast(ImageOps.autocontrast(im)).enhance(2.0)
    im = ImageOps.autocontrast(im)
    im = im.resize(size)
    # For a 2-D uint8 array, ToTensor adds the channel axis and scales values to [0, 1].
    return transforms.ToTensor()(np.array(im))

In [None]:
# Collect all images into one (N + 1, C, H, W) tensor, allocated once. Growing it with
# torch.cat per image copies the whole tensor on every iteration, which is O(N^2).
# Row 0 is a duplicate of image_names[0]; this is the notebook's seeding convention, and the
# later `images = images[1:]` cell removes it. Rows 1..N therefore line up with image_names.
first_image = load_image(image_names[0])
images = torch.empty((len(image_names) + 1, *first_image.shape), dtype=first_image.dtype)
images[0] = first_image
for row, image_name in enumerate(tqdm(image_names), start=1):
    images[row] = load_image(image_name)

In [58]:
import pandas as pd

# Target molecules: the "Smiles" column of the tab-separated ChEMBL export.
chembl = pd.read_csv("chembl_cleaned.csv", delimiter = "\t")
# Only the first len(image_names) targets are ever paired with an image, because the
# dataset is built by index up to len(images). Each SMILES also costs one CDDD request
# (~1 s, see the progress bar below), so fetch no more than that.
# NOTE(review): the pairing is purely positional — confirm that CSV row i is the molecule
# drawn in the i-th sorted file of train_data/.
chembl = chembl.iloc[:min(50000, len(image_names))]
smiles = list(chembl["Smiles"])

In [48]:
import json
import requests
# Silence urllib3's HTTP warnings process-wide, e.g. the InsecureRequestWarning that
# verify=False would otherwise emit for every HTTPS request.
requests.packages.urllib3.disable_warnings()
# NOTE(review): a public third-party host over plain HTTP; every SMILES sent is visible in transit.
DEFAULT_HOST = "http://ec2-18-157-240-87.eu-central-1.compute.amazonaws.com"
class CDDDRequest:
    """Minimal HTTP client for a CDDD server's ``/smiles_to_cddd/`` endpoint."""

    def __init__(self, host=DEFAULT_HOST, port=8892, timeout=60.0):
        """
        Args:
            host: base URL, including the scheme.
            port: server port, appended as ``host:port``.
            timeout: seconds to wait for the server, passed to ``requests.post``.
                ``None`` waits indefinitely, which was the previous behaviour.
        """
        self.host = host
        self.port = port
        self.timeout = timeout
        self.headers = {'content-type': 'application/json'}

    def smiles_to_cddd(self, smiles):
        """POST ``{"smiles": smiles}`` and return the decoded JSON response.

        NOTE(review): this notebook passes one SMILES string per call. Confirm whether the
        server also accepts a list (one request for many molecules would be far faster),
        and what response shape each case returns.

        Raises:
            requests.HTTPError: on a 4xx/5xx response.
            requests.Timeout: if the server does not answer within ``self.timeout``.
        """
        url = "{}:{}/smiles_to_cddd/".format(self.host, self.port)
        req = json.dumps({"smiles": smiles})
        # verify=False only matters for HTTPS hosts, where it skips certificate checks.
        response = requests.post(url, data=req, headers=self.headers, verify=False,
                                 timeout=self.timeout)
        # Fail loudly rather than JSON-decoding (or silently storing) an error response.
        response.raise_for_status()
        return json.loads(response.content.decode("utf-8"))

In [59]:
# Fetch one CDDD response per SMILES string, in input order (~1 s per request per the
# recorded progress bar). Results are appended one at a time, so a failure part-way
# through keeps everything fetched so far in `cddd`.
CDDDserver = CDDDRequest()
cddd = []
for smile in tqdm(smiles):
    embedding = CDDDserver.smiles_to_cddd(smile)
    cddd.append(embedding)

100%|██████████| 1000/1000 [17:31<00:00,  1.05s/it]


In [60]:
# Device string: the first CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [61]:
# Turn the list of server responses into one float32 tensor on `device`.
# np.array of JSON-decoded numbers is float64 (or int64). The model emits float32, and
# F.mse_loss's backward fails when predictions and targets differ in dtype, so cast here.
# The is_tensor guard keeps the cell idempotent: on a re-run, np.array would be called
# on a CUDA tensor and raise.
# NOTE(review): assumes every response has the same length, so the array is rectangular
# (n_smiles, embedding_dim) — confirm against the server's response format.
if not torch.is_tensor(cddd):
    cddd = torch.from_numpy(np.array(cddd))
cddd = cddd.to(device=device, dtype=torch.float32)

In [64]:
# The loading cell seeds the tensor with an extra copy of image_names[0] in row 0. Drop it
# exactly once, so re-running this cell cannot silently discard a real image and shift
# every image out of alignment with its target.
if len(images) == len(image_names) + 1:
    images = images[1:]
assert len(images) == len(image_names), "images must correspond one-to-one with image_names"

In [65]:
# Pair image i with target row i; targets beyond len(images) are unused.
assert len(cddd) >= len(images), f"need a target for every image: {len(cddd)} targets vs {len(images)} images"
dataset = torch.utils.data.TensorDataset(images, cddd[:len(images)])
train_size = int(0.8 * len(images))
test_size = len(images) - train_size
# A dedicated, seeded generator makes the train/val/test partition reproducible across runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=split_generator)
# Take the validation size as the remainder so the two lengths always sum to train_size.
# int(0.9 * n) + int(0.1 * n) can fall one short, which makes random_split raise.
val_size = train_size - int(train_size * 0.9)
train_dataset, val_dataset = torch.utils.data.random_split(
    train_dataset, [train_size - val_size, val_size], generator=split_generator)

In [None]:
# shuffle=True reshuffles the training samples every epoch. The validation order is irrelevant.
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
valloader = torch.utils.data.DataLoader(val_dataset, batch_size = 128)
# NOTE(review): max_epochs is not set, so Lightning's default epoch limit applies — confirm intended.
trainer = pl.Trainer(callbacks=[pl.callbacks.TQDMProgressBar(refresh_rate=10)])
trainer.fit(model, trainloader, valloader)
# nn.Module has no save_dict(); state_dict() is the parameter mapping torch.save should persist.
torch.save(model.state_dict(), "model.ckpt")