In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

from PIL import ImageDraw, Image
from shapely.geometry import Polygon, MultiPoint

from model.utils.targets import DBNetTargets
from model.utils.chem import MolSVG

from multiprocessing import Pool

import cv2

In [2]:
def get_img_size(smi):
    """Return the rendered image size for a SMILES string.

    Builds ``MolSVG(smi, precision=1)`` and returns its ``image_size``
    attribute.

    Args:
        smi: SMILES string handed to ``MolSVG``.

    Returns:
        The ``image_size`` of the rendered molecule, or ``(0, 0)`` when
        rendering raises. The offending SMILES is printed in that case.
    """
    try:
        return MolSVG(smi, precision=1).image_size
    except Exception:
        # Best-effort: report the bad input and hand back a sentinel size.
        # `Exception` (instead of a bare `except`) lets KeyboardInterrupt and
        # SystemExit propagate so a long batch run can still be interrupted.
        print(smi)
        return (0, 0)
        
def get_mol_weight(smi):
    """Return the molecular weight computed for a SMILES string.

    Parses ``smi`` with ``Chem.MolFromSmiles`` and passes the result to
    ``Descriptors.MolWt``.

    NOTE(review): ``Chem`` and ``Descriptors`` are not imported anywhere in
    the visible notebook (presumably they are meant to come from rdkit --
    confirm and add the import to the import cell).

    Args:
        smi: SMILES string to evaluate.

    Returns:
        The value returned by ``Descriptors.MolWt``, or ``0`` when parsing or
        the weight computation raises. The offending SMILES is printed in
        that case.

    Raises:
        NameError: if ``Chem`` / ``Descriptors`` are not in scope. This is a
            setup error, not a bad molecule, so it must not be reported as a
            weight of 0 for every input.
    """
    try:
        _mol = Chem.MolFromSmiles(smi)
        return Descriptors.MolWt(_mol)
    except NameError:
        # Missing import: surface it instead of silently returning 0.
        raise
    except Exception:
        # Best-effort for unparsable molecules; KeyboardInterrupt/SystemExit
        # are intentionally not caught.
        print(smi)
        return 0

def chk_ion_molecules(smi):
    """Flag SMILES strings that contain a ``.`` character.

    Args:
        smi: SMILES string (any object supporting ``in`` with a string).

    Returns:
        ``1`` if ``'.'`` occurs in ``smi``, otherwise ``0``.
    """
    # The membership test yields a bool; convert it to the 0/1 integer flag.
    return int('.' in smi)

def rewrite_id(x):
    """Zero-pad the numeric part of an identifier to at least seven digits.

    Everything after the first six characters of ``x`` (the length of the
    ``CHEMBL`` prefix) is parsed as an integer and re-emitted behind a literal
    ``CHEMBL`` prefix. The leading six characters themselves are not checked.

    Args:
        x: identifier string such as ``"CHEMBL25"``.

    Returns:
        The identifier in the form ``"CHEMBL0000025"``; numbers wider than
        seven digits are kept in full.

    Raises:
        ValueError: if the text after the first six characters is not an
            integer (including when it is empty).
    """
    prefix_len = len("CHEMBL")
    return f"CHEMBL{int(x[prefix_len:]):07d}"

def save_image(idx):
    """Render one row of ``df`` and write its image and four target maps as PNGs.

    Relies on the notebook-level globals ``df``, ``HOME_DIR`` and ``targets``.
    Column 0 of the row is used as the identifier and column 1 is passed to
    ``MolSVG`` (presumably a SMILES string). Characters 6 through 10 of the
    identifier become five nested directory levels under ``HOME_DIR``, so the
    identifier needs at least 11 characters.

    The row is skipped when ``<id>.thr_mask.png`` -- the last file written --
    already exists in the output directory.

    Args:
        idx: positional row index into ``df``.

    Returns:
        None.
    """
    _id, _smi = df.iloc[idx, 0], df.iloc[idx, 1]
    out_dir = f"{HOME_DIR}/{_id[6]}/{_id[7]}/{_id[8]}/{_id[9]}/{_id[10]}"

    # The threshold mask is saved last, so its presence marks a finished row.
    if Path(f"{out_dir}/{_id}.thr_mask.png").exists():
        return

    Path(out_dir).absolute().mkdir(parents=True, exist_ok=True)

    rendered = MolSVG(_smi)
    image_array = rendered.image
    target_maps = targets.generate_targets(rendered)

    Image.fromarray(image_array).save(f"{out_dir}/{_id}.png")

    # Same write order as before: shrink map, shrink mask, threshold map,
    # threshold mask.
    outputs = (
        ('gt_shr', 'shr'),
        ('gt_shr_mask', 'shr_mask'),
        ('gt_thr', 'thr'),
        ('gt_thr_mask', 'thr_mask'),
    )
    for key, suffix in outputs:
        Image.fromarray(target_maps[key]).save(f"{out_dir}/{_id}.{suffix}.png")

In [None]:
# Preview cell: build the targets for a single row of the evaluation split and
# show the four generated maps side by side.
df = pd.read_csv('~/Data/chembl_30_chemreps_eval.csv')

targets = DBNetTargets(shrink_ratio=0.5)

idx = 54  # positional row to inspect
_smi = df.iloc[idx, 1]
# Alternative inputs for manual checks:
# _smi = 'CN(C)CCOc1cc(NC(=O)Nc2cccc([N+](=O)[O-])c2)ccc1I'
# _smi = 'O=C(OCc1ccccc1)c1sc2c([N+](=O)[O-])c(O)c(O)cc2c1Cl'
mol_svg = MolSVG(_smi)
x = mol_svg.image
y = targets.generate_targets(mol_svg)

fig, ax = plt.subplots(1, 4, figsize=(21, 6))
# One panel per target map, titled with its key so the figure stands alone.
for panel, key in zip(ax, ('gt_shr', 'gt_shr_mask', 'gt_thr', 'gt_thr_mask')):
    panel.imshow(y[key])
    panel.set_title(key)

In [3]:
# df = pd.read_csv('~/Data/chembl_30_chemreps_mw600nonion.csv')

# Load the three splits. The frames carry a `_df` suffix so that the builtin
# `eval` is not shadowed by a DataFrame for the rest of the session.
train_df = pd.read_csv('~/Data/chembl_30_chemreps_train.w.csv')
eval_df = pd.read_csv('~/Data/chembl_30_chemreps_eval.csv')
test_df = pd.read_csv('~/Data/chembl_30_chemreps_test.csv')

# Row labels of the individual frames are kept (no `ignore_index`), so labels
# repeat across splits; the rest of the notebook addresses rows via `iloc`.
df = pd.concat([train_df, eval_df, test_df])

In [4]:
# Prefix of every output directory built in `save_image`.
# NOTE(review): with an empty string the generated paths start with '/',
# i.e. at the filesystem root -- confirm the intended output location.
HOME_DIR = ''

targets = DBNetTargets(shrink_ratio=0.5)
_index = list(range(len(df)))  # positional row indices, one task per row

# Fan the rows out over 16 worker processes, 8 rows per task chunk.
with Pool(16) as workers:
    workers.map(save_image, _index, chunksize=8)

# Serial equivalent:
# for i in _index:
#     save_image(i)

In [None]:
# Build the rendered image and the target maps for the first row of `df`.
idx = 0
_id, _smi = df.iloc[idx, 0], df.iloc[idx, 1]  # column 0 -> id, column 1 -> MolSVG input
mol_svg = MolSVG(_smi)
x = mol_svg.image
y = targets.generate_targets(mol_svg)

In [None]:
idx = 0
# Named `_id` (as in the other cells) so the builtin `id` is not shadowed.
_id = df.iloc[idx, 0]
# NOTE(review): `load_image` is neither defined nor imported in the visible
# notebook, so this cell raises NameError on a fresh kernel -- restore the
# helper or remove the cell.
image, mask = load_image(idx)

In [None]:
# NOTE(review): `get_char_pos` is neither defined nor imported in the visible
# notebook; this cell raises NameError on a fresh kernel -- confirm where the
# helper lives. The crop code that follows treats each entry of the result as
# a pair of corner points.
char_pos = get_char_pos(mask)

In [None]:
# Crop the first entry of `char_pos` out of `image` and display it.
# The entry unpacks into (max corner, min corner); element [1] of each corner
# bounds the first array axis and element [0] the second (presumably the
# corners are (x, y) points -- confirm against `get_char_pos`).
_max, _min = char_pos[0]
row_span = slice(_min[1], _max[1])
col_span = slice(_min[0], _max[0])
_image = image[row_span, col_span]

plt.imshow(_image)