In [None]:
# Importing important libraries
import os
import random

import numpy as np
import PIL
import PIL.Image
import PIL.ImageEnhance
import PIL.ImageOps
import rdkit
import rdkit.Chem
import torch
import torchvision as tv

In [None]:
# Auxiliary Functions

def initialize(model, checkpoint):
  """Prepare a model for CPU inference.

  When ``checkpoint`` is truthy, the model is replaced by
  ``model.load_from_checkpoint(checkpoint)``; otherwise ``model`` itself is
  used. Whichever model is selected is switched to eval mode and moved to CPU.

  Args:
    model: Model providing ``eval``/``to`` and, when a checkpoint is given,
      ``load_from_checkpoint``.
    checkpoint: Value forwarded to ``load_from_checkpoint``; any falsy value
      skips loading.

  Returns:
    The loaded model, or ``model`` itself when no checkpoint was given.
  """
  net = model.load_from_checkpoint(checkpoint) if checkpoint else model
  net.eval()
  net.to("cpu")
  return net


def fit(image, target_size=224, border_range=(5, 25)):
  """Letterbox ``image`` onto a white square canvas and add a random white border.

  The image is bicubically rescaled so its longer side equals ``target_size``.
  It is then centred on a white ``"L"`` (8-bit grayscale) canvas of
  ``target_size x target_size`` and padded on every side by a random border
  width drawn from the half-open range ``[border_range[0], border_range[1])``.

  Args:
    image: PIL image to fit.
    target_size: Side length of the square canvas (default 224).
    border_range: ``(low, high)`` bounds for the random border width
      (``high`` exclusive, default ``(5, 25)``).

  Returns:
    A new PIL ``"L"`` image with the border added on all four sides.

  Raises:
    ValueError: If ``image`` has a zero width or height.
  """
  width, height = image.size
  if min(width, height) <= 0:
    raise ValueError(f"cannot fit an image of size {image.size}")

  factor = target_size / max(width, height)
  # Clamp to >= 1 px so very elongated images never scale to a zero-size
  # dimension, which PIL's resize rejects.
  new_size = (max(1, int(width * factor)), max(1, int(height * factor)))
  resized = image.resize(new_size, PIL.Image.BICUBIC)

  canvas = PIL.Image.new("L", (target_size, target_size), "white")
  offset = ((target_size - new_size[0]) // 2, (target_size - new_size[1]) // 2)
  canvas.paste(resized, offset)

  # Draw a plain scalar: int() on a size-1 array is deprecated since NumPy 1.25.
  border = int(np.random.randint(border_range[0], border_range[1]))
  return PIL.ImageOps.expand(canvas, border, "white")


def transform(cls, image=None):
  """Run one random augmentation pass over an image and return it as a tensor.

  Pipeline:
    1. Fit to a 224x224 white canvas with a random border.
    2. Random rotation in [-15, 15] degrees (canvas expanded, filled white).
    3. Random brightness factor in [0.75, 2.0].
    4. Random shear along x, y, or both (filled white). The shear bound is
       drawn uniformly from [0.1, 7) degrees.
    5. Autocontrast followed by a 2x contrast boost.
    6. Bicubic resize to 224x224.
    7. A final autocontrast.
    8. ``ToTensor`` conversion.

  Call it as ``transform(image)``. That is the form ``read`` uses, and ``fit``
  does the fitting. The legacy ``transform(cls, image)`` form is still
  accepted; in that case ``cls.fit_image`` does the fitting instead.

  Args:
    cls: The image (single-argument form), or an object exposing
      ``fit_image`` (legacy two-argument form).
    image: The image in the legacy two-argument form; omit otherwise.

  Returns:
    A float tensor produced by ``ToTensor``. For the grayscale ``"L"`` images
    that ``fit`` produces, its shape is ``(1, 224, 224)``.
  """
  if image is None:
    # Single-argument form: the first positional argument is the image itself.
    image, fit_fn = cls, fit
  else:
    fit_fn = cls.fit_image
  fitted = fit_fn(image)

  # Equivalent of the old integer code 3 (PIL bicubic); integer codes and the
  # `resample` / `fillcolor` keywords are no longer accepted by torchvision.
  bicubic = tv.transforms.InterpolationMode.BICUBIC

  shear_val = np.random.uniform(0.1, 7)
  # A 4-element shear is read by RandomAffine as (x_min, x_max, y_min, y_max).
  shear = random.choice([
    [0, 0, -shear_val, shear_val],
    [-shear_val, shear_val, 0, 0],
    [-shear_val, shear_val, -shear_val, shear_val],
  ])

  out = tv.transforms.RandomRotation(
    degrees=(-15, 15), interpolation=bicubic, expand=True, center=None, fill=255
  )(fitted)
  out = tv.transforms.ColorJitter(brightness=[0.75, 2.0], contrast=0, saturation=0, hue=0)(out)
  out = tv.transforms.RandomAffine(
    degrees=0, translate=None, scale=None, shear=shear, interpolation=bicubic, fill=255
  )(out)
  # ImageOps / ImageEnhance are Pillow modules, not torchvision attributes.
  out = PIL.ImageEnhance.Contrast(PIL.ImageOps.autocontrast(out)).enhance(2.0)
  out = tv.transforms.Resize(size=(224, 224), interpolation=bicubic)(out)
  out = PIL.ImageOps.autocontrast(out)
  return tv.transforms.ToTensor()(out)

def read(path, n_augment=50):
  """Load an image file and return a batch of randomly augmented copies.

  The image is converted to grayscale (``"L"``). Images with transparency are
  first composited onto a white background. Each copy is then produced by an
  independent call to ``transform``.

  Args:
    path: Path to a ``.jpg``, ``.jpeg`` or ``.png`` file (case-insensitive).
    n_augment: Number of augmented copies to produce (default 50).

  Returns:
    A CPU tensor stacking ``n_augment`` outputs of ``transform`` along a new
    leading dimension. For compatibility with existing callers, it returns the
    string ``"Incorrect file type"`` when the extension is not supported.

  Raises:
    ValueError: If ``n_augment`` is less than 1.
  """
  if n_augment < 1:
    raise ValueError(f"n_augment must be >= 1, got {n_augment}")

  extension = os.path.splitext(os.fspath(path))[1].lower()
  if extension not in (".jpg", ".jpeg", ".png"):
    return "Incorrect file type"

  # The context manager closes the file handle. convert() returns a fully
  # loaded copy, so `image` stays usable after the file is closed.
  with PIL.Image.open(path) as raw:
    has_alpha = raw.mode in ("RGBA", "LA") or (raw.mode == "P" and "transparency" in raw.info)
    if has_alpha:
      # Composite onto white so transparent regions become white instead of
      # exposing whatever colour is stored under them.
      rgba = raw.convert("RGBA")
      background = PIL.Image.new("RGB", rgba.size, (255, 255, 255))
      background.paste(rgba, (0, 0), rgba)
      image = background.convert("L")
    else:
      image = raw.convert("L")

  images = torch.stack([transform(image) for _ in range(n_augment)], dim=0)
  return images.to("cpu")

In [None]:
def predict(path, model, server):
  """Predict a molecule from an image of a chemical structure.

  The image at ``path`` is loaded as a batch of augmented copies by ``read``.
  The batch is run through ``model``, and the per-copy outputs are reduced to
  their element-wise median along axis 0. ``server.cddd_to_smiles`` decodes
  that median into SMILES, and RDKit canonicalises the result.

  Args:
    path: Image file path, forwarded to ``read``.
    model: Callable mapping the image batch to a tensor of embeddings.
    server: Object exposing ``cddd_to_smiles(list)``. Presumably it returns
      one SMILES string for a single embedding; TODO confirm against server.

  Returns:
    A dict with keys ``"filepath"`` (the given path), ``"smiles"`` (the
    canonical SMILES) and ``"mol"`` (the RDKit molecule parsed from it). Both
    ``"smiles"`` and ``"mol"`` are ``None`` when the decoded SMILES does not
    parse.

  Raises:
    ValueError: If ``read`` rejects the file type.
  """
  images = read(path)
  if isinstance(images, str):
    # read() signals an unsupported file type by returning an error string;
    # fail loudly here instead of feeding that string to the model.
    raise ValueError(f"{images}: {path}")

  with torch.no_grad():
    embeddings = model(images).cpu().numpy()
  # The median over augmented copies gives one robust embedding per image.
  cddd = np.median(embeddings, axis=0)
  smiles = server.cddd_to_smiles(cddd.tolist())

  canonical_smiles = None
  canonical_mol = None
  mol = rdkit.Chem.MolFromSmiles(smiles, sanitize=True) if smiles else None
  if mol is not None:
    canonical_smiles = rdkit.Chem.MolToSmiles(mol)
    canonical_mol = rdkit.Chem.MolFromSmiles(canonical_smiles)

  return {"filepath": path, "smiles": canonical_smiles, "mol": canonical_mol}