### Packages

In [None]:
!pip install pytorch_lightning
import time
import torch
import torch.nn as nn
import numpy as np
import os
from torchvision import transforms
from matplotlib import pyplot as plt
from PIL import Image, ImageOps, ImageEnhance
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import gc
import pandas as pd



### Paths

In [None]:
from google.colab import drive
drive.mount('/content/drive')

path = '/content/drive/MyDrive/SMAI_Project/'
image_path = 'train_images'
cddd_path = '/content/drive/MyDrive/SMAI_Project/embeddings.csv'
image_tensor_file = 'tensor.pt'
model_file = '/content/drive/MyDrive/SMAI_Project/50000_images_100_epochs.pt'
img2mol_cddd_path = '/content/drive/MyDrive/SMAI_Project/img2mol_benchmark_embeddings.csv'
img2mol_image_path = 'Img2MolVal'
staker_cddd_path = '/content/drive/MyDrive/SMAI_Project/staker_benchmark_embeddings.csv'
staker_image_path = 'STAKER'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!unzip "drive/My Drive/train_images.zip" -d "./"

Archive:  drive/My Drive/train_images.zip
replace ./train_images/20037.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
opt = {
    'image_size': 234,
    'is_grayscale': False,
    'image_parsed': False,
    'train': True,
    'load': True
}

In [None]:
## Load images, adjust orientations, size and contrast 

In [None]:
def load_image(path):
    # Opening image
    im = Image.open(path).convert('L' if opt['is_grayscale'] else 'RGB')

    ratio = float(224) / max(im.size)
    ns = tuple([int(x * ratio) for x in im.size])
    im = im.resize(ns, Image.BICUBIC)
    ni = Image.new("L", (224, 224), "white")
    ni.paste(im, ((224 - ns[0]) // 2,
                        (224 - ns[1]) // 2))
    ni = ImageOps.expand(ni, int(np.random.randint(5, 25, size=1)), "white")
    im = ni
  
    # Enhancing image
    im = ImageEnhance.Contrast(ImageOps.autocontrast(im)).enhance(2.0)
    # Contrast adjustment
    im = ImageOps.autocontrast(im)
    im = im.resize((opt['image_size'],opt['image_size']))
    im = transforms.ToTensor()(im)
    return im

In [None]:
def imshow(img):
    npimg = img.numpy()
    plt.figure(figsize=(5,5))
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

### Loading cddd embeddings from file

In [None]:
cddd_df = pd.read_csv(cddd_path, delimiter=',',  header=None)
cddd = torch.from_numpy(cddd_df.values).float().to(device)
cddd = cddd[:50000]

In [None]:
cddd.shape

torch.Size([50000, 512])

In [None]:
dir_path = image_path

In [None]:
indices = range(cddd_df.shape[0])
image_list = []
for i in indices:
    image_list.append(os.path.join(dir_path,"{}.png".format(i)))

In [None]:
torch.cuda.get_device_name(0)

'Tesla T4'

### Image to tensor

In [None]:
if opt['image_parsed']:
  images = torch.load(image_tensor_file)
else:
  images = torch.cat([torch.unsqueeze(load_image(image), 0) for image in image_list[:50000]], dim=0)
  # torch.save(images, image_tensor_file)

In [None]:
images.shape

### MODEL

In [None]:
import pytorch_lightning as pl
class Image2CDDD(pl.LightningModule):
    def __init__(self):
        super(Image2CDDD, self).__init__()
        # Input size: [batch, 3, 500, 500]

        self.network = nn.Sequential(
            nn.Conv2d(1, 128, kernel_size = 7, stride=3, padding = 4),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size = 5, stride=1, padding = 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 384, kernel_size = 5, stride=1, padding = 1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),

            nn.Conv2d(384, 384, kernel_size = 3, stride=1, padding = 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size = 3, stride=1, padding = 1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),

            nn.Conv2d(384, 512, kernel_size = 3, stride=1, padding = 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size = 3, stride=1, padding = 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size = 3, stride=1, padding = 1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),

            nn.Flatten(),

            nn.Linear(512*9*9, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.0),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.0),
            nn.Linear(4096, 512),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.network(x)

### Model parameters

In [49]:
model = Image2CDDD()
model.eval()
model.to(device)

learning_rate = 1e-4
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

In [None]:
if opt["load"] == True:
  model.load_state_dict(torch.load(model_file))

### Train, Test Split

In [None]:
dataset = []
for i in range(len(images)):
  dataset.append([images[i], cddd[i]])
train_size = int(0.8 * len(images))
test_size = len(images) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [int(train_size*0.9), int(train_size*0.1)])

### Training

In [None]:
n_epochs = 80
gc.collect()
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=128,shuffle=True)

if opt["train"] == True:
  for epoch in tqdm(range(n_epochs)):
      running_loss = 0.0
      st = time.time()
      for i, (batch_x, batch_y) in enumerate(trainloader, 0):
          batch_x = batch_x.to(device)
          batch_y = batch_y.to(device)

          outputs = model(batch_x)
          loss = criterion(outputs, batch_y)

          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

          running_loss += loss.data
          if i % 50 == 49:
              print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 50))
              running_loss = 0.0
          print("Done with Batch {}, epoch {}".format(i, epoch))
      et = time.time()
      torch.save(model.state_dict(), model_file)
      print(et-st, "Epoch", epoch)

  print('Finished Training')
  print('Saving Model...')
  torch.save(model.state_dict(), model_file)

  0%|          | 0/80 [00:00<?, ?it/s]

Done with Batch 0, epoch 0
Done with Batch 1, epoch 0
Done with Batch 2, epoch 0
Done with Batch 3, epoch 0
Done with Batch 4, epoch 0
Done with Batch 5, epoch 0
Done with Batch 6, epoch 0
Done with Batch 7, epoch 0
Done with Batch 8, epoch 0
Done with Batch 9, epoch 0
Done with Batch 10, epoch 0
Done with Batch 11, epoch 0
Done with Batch 12, epoch 0
Done with Batch 13, epoch 0
Done with Batch 14, epoch 0
Done with Batch 15, epoch 0
Done with Batch 16, epoch 0
Done with Batch 17, epoch 0
Done with Batch 18, epoch 0
Done with Batch 19, epoch 0
Done with Batch 20, epoch 0
Done with Batch 21, epoch 0
Done with Batch 22, epoch 0
Done with Batch 23, epoch 0
Done with Batch 24, epoch 0
Done with Batch 25, epoch 0
Done with Batch 26, epoch 0
Done with Batch 27, epoch 0
Done with Batch 28, epoch 0
Done with Batch 29, epoch 0
Done with Batch 30, epoch 0
Done with Batch 31, epoch 0
Done with Batch 32, epoch 0
Done with Batch 33, epoch 0
Done with Batch 34, epoch 0
Done with Batch 35, epoch 0
Do

  1%|▏         | 1/80 [07:00<9:13:07, 420.10s/it]

416.3992943763733 Epoch 0
Done with Batch 0, epoch 1
Done with Batch 1, epoch 1
Done with Batch 2, epoch 1
Done with Batch 3, epoch 1
Done with Batch 4, epoch 1
Done with Batch 5, epoch 1
Done with Batch 6, epoch 1
Done with Batch 7, epoch 1
Done with Batch 8, epoch 1
Done with Batch 9, epoch 1
Done with Batch 10, epoch 1
Done with Batch 11, epoch 1
Done with Batch 12, epoch 1
Done with Batch 13, epoch 1
Done with Batch 14, epoch 1
Done with Batch 15, epoch 1
Done with Batch 16, epoch 1
Done with Batch 17, epoch 1
Done with Batch 18, epoch 1
Done with Batch 19, epoch 1
Done with Batch 20, epoch 1
Done with Batch 21, epoch 1
Done with Batch 22, epoch 1
Done with Batch 23, epoch 1
Done with Batch 24, epoch 1
Done with Batch 25, epoch 1
Done with Batch 26, epoch 1
Done with Batch 27, epoch 1
Done with Batch 28, epoch 1
Done with Batch 29, epoch 1
Done with Batch 30, epoch 1
Done with Batch 31, epoch 1
Done with Batch 32, epoch 1
Done with Batch 33, epoch 1
Done with Batch 34, epoch 1
Done

  2%|▎         | 2/80 [14:00<9:06:17, 420.23s/it]

416.6819143295288 Epoch 1
Done with Batch 0, epoch 2
Done with Batch 1, epoch 2
Done with Batch 2, epoch 2
Done with Batch 3, epoch 2
Done with Batch 4, epoch 2
Done with Batch 5, epoch 2
Done with Batch 6, epoch 2
Done with Batch 7, epoch 2
Done with Batch 8, epoch 2
Done with Batch 9, epoch 2
Done with Batch 10, epoch 2
Done with Batch 11, epoch 2
Done with Batch 12, epoch 2
Done with Batch 13, epoch 2
Done with Batch 14, epoch 2
Done with Batch 15, epoch 2
Done with Batch 16, epoch 2
Done with Batch 17, epoch 2
Done with Batch 18, epoch 2
Done with Batch 19, epoch 2
Done with Batch 20, epoch 2
Done with Batch 21, epoch 2
Done with Batch 22, epoch 2
Done with Batch 23, epoch 2
Done with Batch 24, epoch 2
Done with Batch 25, epoch 2
Done with Batch 26, epoch 2
Done with Batch 27, epoch 2
Done with Batch 28, epoch 2
Done with Batch 29, epoch 2
Done with Batch 30, epoch 2
Done with Batch 31, epoch 2
Done with Batch 32, epoch 2
Done with Batch 33, epoch 2
Done with Batch 34, epoch 2
Done

  4%|▍         | 3/80 [21:01<8:59:55, 420.72s/it]

416.75495648384094 Epoch 2
Done with Batch 0, epoch 3
Done with Batch 1, epoch 3
Done with Batch 2, epoch 3
Done with Batch 3, epoch 3
Done with Batch 4, epoch 3
Done with Batch 5, epoch 3
Done with Batch 6, epoch 3
Done with Batch 7, epoch 3
Done with Batch 8, epoch 3
Done with Batch 9, epoch 3
Done with Batch 10, epoch 3
Done with Batch 11, epoch 3
Done with Batch 12, epoch 3
Done with Batch 13, epoch 3
Done with Batch 14, epoch 3
Done with Batch 15, epoch 3
Done with Batch 16, epoch 3
Done with Batch 17, epoch 3
Done with Batch 18, epoch 3
Done with Batch 19, epoch 3
Done with Batch 20, epoch 3
Done with Batch 21, epoch 3
Done with Batch 22, epoch 3
Done with Batch 23, epoch 3
Done with Batch 24, epoch 3
Done with Batch 25, epoch 3
Done with Batch 26, epoch 3
Done with Batch 27, epoch 3
Done with Batch 28, epoch 3
Done with Batch 29, epoch 3
Done with Batch 30, epoch 3
Done with Batch 31, epoch 3
Done with Batch 32, epoch 3
Done with Batch 33, epoch 3
Done with Batch 34, epoch 3
Don

  5%|▌         | 4/80 [28:01<8:52:38, 420.51s/it]

416.6169629096985 Epoch 3
Done with Batch 0, epoch 4
Done with Batch 1, epoch 4
Done with Batch 2, epoch 4
Done with Batch 3, epoch 4
Done with Batch 4, epoch 4
Done with Batch 5, epoch 4
Done with Batch 6, epoch 4
Done with Batch 7, epoch 4
Done with Batch 8, epoch 4
Done with Batch 9, epoch 4
Done with Batch 10, epoch 4
Done with Batch 11, epoch 4
Done with Batch 12, epoch 4
Done with Batch 13, epoch 4
Done with Batch 14, epoch 4
Done with Batch 15, epoch 4
Done with Batch 16, epoch 4
Done with Batch 17, epoch 4
Done with Batch 18, epoch 4
Done with Batch 19, epoch 4
Done with Batch 20, epoch 4
Done with Batch 21, epoch 4
Done with Batch 22, epoch 4
Done with Batch 23, epoch 4
Done with Batch 24, epoch 4
Done with Batch 25, epoch 4
Done with Batch 26, epoch 4
Done with Batch 27, epoch 4
Done with Batch 28, epoch 4
Done with Batch 29, epoch 4
Done with Batch 30, epoch 4
Done with Batch 31, epoch 4
Done with Batch 32, epoch 4
Done with Batch 33, epoch 4
Done with Batch 34, epoch 4
Done

  6%|▋         | 5/80 [35:02<8:45:32, 420.43s/it]

416.6848359107971 Epoch 4
Done with Batch 0, epoch 5
Done with Batch 1, epoch 5
Done with Batch 2, epoch 5
Done with Batch 3, epoch 5
Done with Batch 4, epoch 5
Done with Batch 5, epoch 5
Done with Batch 6, epoch 5
Done with Batch 7, epoch 5
Done with Batch 8, epoch 5
Done with Batch 9, epoch 5
Done with Batch 10, epoch 5
Done with Batch 11, epoch 5
Done with Batch 12, epoch 5
Done with Batch 13, epoch 5
Done with Batch 14, epoch 5
Done with Batch 15, epoch 5
Done with Batch 16, epoch 5
Done with Batch 17, epoch 5
Done with Batch 18, epoch 5
Done with Batch 19, epoch 5
Done with Batch 20, epoch 5
Done with Batch 21, epoch 5
Done with Batch 22, epoch 5
Done with Batch 23, epoch 5
Done with Batch 24, epoch 5
Done with Batch 25, epoch 5
Done with Batch 26, epoch 5
Done with Batch 27, epoch 5
Done with Batch 28, epoch 5
Done with Batch 29, epoch 5
Done with Batch 30, epoch 5
Done with Batch 31, epoch 5
Done with Batch 32, epoch 5
Done with Batch 33, epoch 5
Done with Batch 34, epoch 5
Done

  8%|▊         | 6/80 [42:02<8:38:20, 420.28s/it]

416.35153818130493 Epoch 5
Done with Batch 0, epoch 6
Done with Batch 1, epoch 6
Done with Batch 2, epoch 6
Done with Batch 3, epoch 6
Done with Batch 4, epoch 6
Done with Batch 5, epoch 6
Done with Batch 6, epoch 6
Done with Batch 7, epoch 6
Done with Batch 8, epoch 6
Done with Batch 9, epoch 6
Done with Batch 10, epoch 6
Done with Batch 11, epoch 6
Done with Batch 12, epoch 6
Done with Batch 13, epoch 6
Done with Batch 14, epoch 6
Done with Batch 15, epoch 6
Done with Batch 16, epoch 6
Done with Batch 17, epoch 6
Done with Batch 18, epoch 6
Done with Batch 19, epoch 6
Done with Batch 20, epoch 6
Done with Batch 21, epoch 6
Done with Batch 22, epoch 6
Done with Batch 23, epoch 6
Done with Batch 24, epoch 6
Done with Batch 25, epoch 6
Done with Batch 26, epoch 6
Done with Batch 27, epoch 6
Done with Batch 28, epoch 6
Done with Batch 29, epoch 6
Done with Batch 30, epoch 6
Done with Batch 31, epoch 6
Done with Batch 32, epoch 6
Done with Batch 33, epoch 6
Done with Batch 34, epoch 6
Don

  9%|▉         | 7/80 [49:02<8:31:19, 420.27s/it]

416.63417077064514 Epoch 6
Done with Batch 0, epoch 7
Done with Batch 1, epoch 7
Done with Batch 2, epoch 7
Done with Batch 3, epoch 7
Done with Batch 4, epoch 7
Done with Batch 5, epoch 7
Done with Batch 6, epoch 7
Done with Batch 7, epoch 7
Done with Batch 8, epoch 7
Done with Batch 9, epoch 7
Done with Batch 10, epoch 7
Done with Batch 11, epoch 7
Done with Batch 12, epoch 7
Done with Batch 13, epoch 7
Done with Batch 14, epoch 7
Done with Batch 15, epoch 7
Done with Batch 16, epoch 7
Done with Batch 17, epoch 7
Done with Batch 18, epoch 7
Done with Batch 19, epoch 7
Done with Batch 20, epoch 7
Done with Batch 21, epoch 7
Done with Batch 22, epoch 7
Done with Batch 23, epoch 7
Done with Batch 24, epoch 7
Done with Batch 25, epoch 7
Done with Batch 26, epoch 7
Done with Batch 27, epoch 7
Done with Batch 28, epoch 7
Done with Batch 29, epoch 7
Done with Batch 30, epoch 7
Done with Batch 31, epoch 7
Done with Batch 32, epoch 7
Done with Batch 33, epoch 7
Done with Batch 34, epoch 7
Don

### Testing

In [None]:
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=128,shuffle=True)
test_iter = iter(testloader)
test_images, test_labels = test_iter.next()
test_images = test_images.to(device)

In [None]:
test_images.shape

torch.Size([128, 1, 234, 234])

In [None]:
# Calculate MSE loss
for i, (batch_x, test_labels) in enumerate(testloader, 0):
  batch_x = batch_x.to(device)
  test_labels = test_labels.to(device)
  predictions = model(batch_x)
  loss = criterion(predictions[0], test_labels[0])
  # print(i, loss)
  pred_np = predictions.cpu().detach().numpy()
  test_np = test_labels.cpu().detach().numpy()
  file1 = open("test_results/predictions_"+str(i)+".pkl", "wb")
  np.save(file1, pred_np, allow_pickle=False)
  file1.close()
  file2 = open("test_results/test_labels_"+str(i)+".pkl", "wb")
  np.save(file2, test_np, allow_pickle=False)
  file2.close()

In [None]:
!zip -r test_results.zip test_results

  adding: test_results/ (stored 0%)
  adding: test_results/predictions_1.pkl (deflated 7%)
  adding: test_results/predictions_67.pkl (deflated 7%)
  adding: test_results/test_labels_10.pkl (deflated 7%)
  adding: test_results/test_labels_16.pkl (deflated 7%)
  adding: test_results/test_labels_40.pkl (deflated 7%)
  adding: test_results/predictions_57.pkl (deflated 7%)
  adding: test_results/test_labels_3.pkl (deflated 7%)
  adding: test_results/predictions_74.pkl (deflated 7%)
  adding: test_results/test_labels_70.pkl (deflated 7%)
  adding: test_results/predictions_23.pkl (deflated 7%)
  adding: test_results/test_labels_35.pkl (deflated 7%)
  adding: test_results/predictions_63.pkl (deflated 7%)
  adding: test_results/test_labels_4.pkl (deflated 7%)
  adding: test_results/test_labels_62.pkl (deflated 7%)
  adding: test_results/predictions_58.pkl (deflated 7%)
  adding: test_results/predictions_71.pkl (deflated 7%)
  adding: test_results/test_labels_36.pkl (deflated 7%)
  adding: test_

## Results on Benchmark dataset
## ImageMol

In [50]:
!tar -xvf "drive/My Drive/SMAI_Project/Img2Mol.tgz" 

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Img2Mol/19999.png
Img2Mol/20000.png
Img2Mol/20001.png
Img2Mol/20002.png
Img2Mol/20003.png
Img2Mol/20004.png
Img2Mol/20005.png
Img2Mol/20006.png
Img2Mol/20007.png
Img2Mol/20008.png
Img2Mol/20009.png
Img2Mol/20010.png
Img2Mol/20011.png
Img2Mol/20012.png
Img2Mol/20013.png
Img2Mol/20014.png
Img2Mol/20015.png
Img2Mol/20016.png
Img2Mol/20017.png
Img2Mol/20018.png
Img2Mol/20019.png
Img2Mol/20020.png
Img2Mol/20021.png
Img2Mol/20022.png
Img2Mol/20023.png
Img2Mol/20024.png
Img2Mol/20025.png
Img2Mol/20026.png
Img2Mol/20027.png
Img2Mol/20028.png
Img2Mol/20029.png
Img2Mol/20030.png
Img2Mol/20031.png
Img2Mol/20032.png
Img2Mol/20033.png
Img2Mol/20034.png
Img2Mol/20035.png
Img2Mol/20036.png
Img2Mol/20037.png
Img2Mol/20038.png
Img2Mol/20039.png
Img2Mol/20040.png
Img2Mol/20041.png
Img2Mol/20042.png
Img2Mol/20043.png
Img2Mol/20044.png
Img2Mol/20045.png
Img2Mol/20046.png
Img2Mol/20047.png
Img2Mol/20048.png
Img2Mol/20049.png
Img2Mol/20050.png

In [None]:
indices = range(5000)
image_list = []
for i in indices:
    image_list.append(os.path.join(img2mol_image_path,"{}.png".format(i)))
img2mol_images = torch.cat([torch.unsqueeze(load_image(image), 0) for image in image_list], dim=0)

In [None]:
img2mol_images.shape

torch.Size([5000, 1, 234, 234])

In [None]:
img2mol_cddd_df = pd.read_csv(img2mol_cddd_path, delimiter=',',  header=None)
img2mol_cddd = torch.from_numpy(img2mol_cddd_df.values).float().to(device)
img2mol_cddd = img2mol_cddd[:5000]

In [None]:
img2mol_cddd.shape

torch.Size([5000, 512])

In [None]:
b_dataset = []
for i in range(len(img2mol_images)):
  b_dataset.append([img2mol_images[i], img2mol_cddd[i]])

benchmarkloader = torch.utils.data.DataLoader(b_dataset, batch_size=128,shuffle=True)

In [None]:
# Calculate MSE loss
for i, (batch_x, test_labels) in enumerate(benchmarkloader, 0):
  batch_x = batch_x.to(device)
  test_labels = test_labels.to(device)
  predictions = model(batch_x)
  loss = criterion(predictions[0], test_labels[0])
  print(i, loss)
  pred_np = predictions.cpu().detach().numpy()
  test_np = test_labels.cpu().detach().numpy()
  file1 = open("benchmark_predictions.pkl", "wb")
  np.save(file1, pred_np, allow_pickle=False)
  file1.close()
  file2 = open("benchmark_test_labels.pkl", "wb")
  np.save(file2, test_np, allow_pickle=False)
  file2.close()
  break

0 tensor(0.1618, device='cuda:0', grad_fn=<MseLossBackward0>)
