In [None]:
!pip install kora

In [None]:
import tensorflow as tf
import torch
from torchvision import transforms
import kora.install.rdkit
from rdkit import Chem
from typing import Optional
import random
import numpy as np
from PIL import Image, ImageOps, ImageEnhance
from typing import Union, List, Optional
from torch import nn
import torch.nn.functional as F

In [None]:
!pip install pytorch_lightning

In [None]:
import pytorch_lightning as pl

In [None]:
# Learning rate; Img2MolModel copies it to `self.learning_rate` and passes it
# to the AdamW optimizer.
LR = 0.0001

In [None]:
# Per-convolution hyper-parameters. Entry k of every list configures the k-th
# "conv2d" item encountered in `layer_list` (max-pool items consume no entry).
configs = {
    "unit": [128, 256, 384, 384, 384, 512, 512, 512],   # output channels
    "kernel_size": [7, 5, 5, 3, 3, 3, 3, 3],
    "stride": [3, 1, 1, 1, 1, 1, 1, 1],
    "padding": [4, 1, 1, 1, 1, 1, 1, 1],
}

# Order in which layers are stacked by getCNN.
# A comma used to be missing between the third "conv2d" and the first
# "maxpool"; Python then concatenated the adjacent literals into the single
# string "conv2dmaxpool", which getCNN did not recognise, so one convolution
# and one pooling stage were silently dropped from the network.
layer_list = [
    "conv2d", "conv2d", "conv2d", "maxpool",
    "conv2d", "conv2d", "maxpool",
    "conv2d", "conv2d", "conv2d", "maxpool",
]

# Guard against the list and the per-layer settings drifting apart.
assert all(len(values) == layer_list.count("conv2d") for values in configs.values()), \
    "every configs entry needs exactly one value per 'conv2d' layer"

In [None]:
def getCNN(configs, layer_list, in_channels=1):
    """Build the convolutional feature extractor as an ``nn.Sequential``.

    Args:
        configs: Mapping with the keys ``"unit"``, ``"kernel_size"``,
            ``"stride"`` and ``"padding"``. Each value is a sequence whose
            k-th element configures the k-th ``"conv2d"`` entry of
            ``layer_list`` (``"unit"`` is the number of output channels).
        layer_list: Sequence of layer names, each either ``"conv2d"`` or
            ``"maxpool"``, in the order the layers should be stacked.
        in_channels: Number of channels of the input fed to the first
            convolution. Defaults to 1, the value that was previously
            hard-coded.

    Returns:
        ``nn.Sequential`` where every ``"conv2d"`` entry contributes a
        ``Conv2d`` followed by an in-place ``ReLU`` and every ``"maxpool"``
        entry contributes a ``MaxPool2d(kernel_size=2, stride=2)``.

    Raises:
        ValueError: If ``layer_list`` contains a name other than ``"conv2d"``
            or ``"maxpool"``.
        IndexError: If ``layer_list`` has more ``"conv2d"`` entries than the
            lists in ``configs`` have elements.
    """
    # Reject unknown names up front. Silently skipping them hides typos such
    # as two adjacent string literals fused into "conv2dmaxpool" by a missing
    # comma, and yields a network with the wrong depth/output size.
    unknown = [name for name in layer_list if name not in ("conv2d", "maxpool")]
    if unknown:
        raise ValueError(
            f"Unsupported layer type(s) in layer_list: {unknown!r}; "
            "expected 'conv2d' or 'maxpool'"
        )

    layers = []
    conv_idx = 0  # index into the per-convolution lists in `configs`
    for layer in layer_list:
        if layer == "conv2d":
            out_channels = configs["unit"][conv_idx]
            layers.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=configs["kernel_size"][conv_idx],
                    stride=configs["stride"][conv_idx],
                    padding=configs["padding"][conv_idx],
                )
            )
            layers.append(nn.ReLU(inplace=True))
            # The next convolution consumes what this one produced.
            in_channels = out_channels
            conv_idx += 1
        else:  # "maxpool" (validated above)
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

    return nn.Sequential(*layers)


In [None]:
def getFCN():
    """Assemble the fully connected head as an ``nn.Sequential``.

    The stack is, in order: Linear(512*9*9 -> 4096), in-place ReLU,
    Dropout(p=0), Linear(4096 -> 4096), in-place ReLU, Dropout(p=0),
    Linear(4096 -> 512), Tanh.

    Returns:
        The eight-module ``nn.Sequential`` described above.
    """
    flat_features = 512 * 9 * 9  # size of the flattened input vector
    hidden_width = 4096
    out_features = 512

    return nn.Sequential(
        nn.Linear(flat_features, hidden_width),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0),  # p=0 leaves activations unchanged
        nn.Linear(hidden_width, hidden_width),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0),
        nn.Linear(hidden_width, out_features),
        nn.Tanh(),
    )

In [None]:
class Img2MolModel(pl.LightningModule):
    """LightningModule pairing a convolutional stack with a dense head.

    The convolutional part comes from ``getCNN(configs, layer_list)`` and the
    dense part from ``getFCN()``. Training, validation and test steps all
    use the mean squared error between the network output and the target
    that accompanies each batch; optimisation uses AdamW.
    """

    def __init__(self):
        super().__init__()
        self.learning_rate = LR
        self.features = getCNN(configs, layer_list)
        self.classifier = getFCN()
        # Replace PyTorch's default parameter values with the custom scheme.
        self._initialize_weights()

    def forward(self, x):
        """Run ``x`` through the conv stack, flatten from dim 1, apply the head."""
        feature_maps = self.features(x)
        flat = torch.flatten(feature_maps, 1)
        return self.classifier(flat)

    def _initialize_weights(self):
        """Re-initialise every Conv2d, BatchNorm2d and Linear sub-module.

        Conv2d weights use Kaiming-normal (fan_out, relu) with zero bias,
        BatchNorm2d gets unit weight and zero bias, and Linear weights are
        drawn from N(0, 0.01) with zero bias.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                nn.init.zeros_(module.bias)

    def _batch_mse(self, batch):
        """Unpack ``(inputs, targets)`` and return the MSE of the prediction."""
        inputs, targets = batch
        predictions = self(inputs)
        return F.mse_loss(predictions, targets)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the batch loss."""
        train_loss = self._batch_mse(batch)
        self.log('train_loss', train_loss, on_epoch=True, prog_bar=True, logger=True)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute the batch loss and log it as ``valid_loss``."""
        valid_loss = self._batch_mse(batch)
        self.log('valid_loss', valid_loss, on_epoch=True, prog_bar=True, logger=True)

    def test_step(self, batch, batch_idx):
        """Compute the batch loss and log it as ``test_loss``."""
        test_loss = self._batch_mse(batch)
        self.log('test_loss', test_loss)

    def configure_optimizers(self):
        """Return an AdamW optimizer over all parameters at ``self.learning_rate``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

In [None]:
# Instantiate the network. The constructor builds the convolutional and fully
# connected stacks and then applies the class's custom weight initialisation.
model = Img2MolModel()