In [None]:
import numpy as np
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

from model.utils.transfroms import get_train_transform, get_valid_transform, get_test_transform
from model.dataset import CustomDataset

# Dataset locations. HOME is not referenced by any cell shown here.
HOME = ''
# Base directory handed to CustomDataset as `dir_path`.
HOME_DIR = '../Data/ChEMBL/OCR'
# Per-split CSV paths, keyed by split name ('train' / 'val' / 'test').
DATAFRAME_LIST = {
    'train': 'data/chembl_31_smiles_train.csv',
    'val': 'data/chembl_31_smiles_val.csv',
    'test': 'data/chembl_31_smiles_test.csv',
}

def collate_fn(batch):
    """Transpose a batch of samples into one tuple per sample field.

    Intended as a DataLoader ``collate_fn``: ``[(a1, b1), (a2, b2)]`` becomes
    ``((a1, a2), (b1, b2))``. No stacking or tensor conversion is performed.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Validation split, loaded in 'test' mode with the test-time transforms.
# NOTE(review): the inference cells below treat ocr_dataset[i] as a single
# image tensor, while the sample-visualisation cell unpacks five values from
# it — confirm which return shape 'test' mode actually produces.
ocr_dataset = CustomDataset(
    data_df=DATAFRAME_LIST['val'],
    mode='test',
    transforms=get_test_transform(),
    dir_path=HOME_DIR,
)


In [None]:
# Preview sample 14: first channel of the image followed by the four gt_* maps.
# NOTE(review): this 5-way unpack assumes the dataset returns
# (image, gt_shr, gt_shr_mask, gt_thr, gt_thr_mask); with mode='test' other
# cells use ocr_dataset[i] as a lone image tensor — confirm this cell still runs.
image, gt_shr, gt_shr_mask, gt_thr, gt_thr_mask = ocr_dataset[14]

fig, ax = plt.subplots(1, 5, figsize=(24, 6))
panels = (image[0], gt_shr, gt_shr_mask, gt_thr, gt_thr_mask)
for panel_ax, panel in zip(ax, panels):
    panel_ax.imshow(panel)

In [None]:
import torch
from torch import Tensor

from model.dbnet import DBNet
from model.loss import DBLoss

# Build the DBNet model in test mode and load trained weights.
WEIGHTS_PATH = 'model_weights.v9.mbv3s.final.pth'  # alt: 'model_weights.v9.mbv3s.15.pth'

model = DBNet(
        inner_channels=128,
        out_channels=64,
        head_in_channels=320,
        test=True,
    )

# map_location='cpu': later cells call .numpy() on the outputs and feed CPU
# tensors, and this lets a GPU-saved checkpoint load on a CPU-only kernel.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True on torch >= 1.13).
state_dict = torch.load(WEIGHTS_PATH, map_location='cpu')

# strict=False is kept (presumably the test=True model intentionally differs
# from the training checkpoint — verify), but report what was skipped so that a
# mismatched checkpoint is not silently run with randomly initialised layers.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f'missing keys ({len(load_result.missing_keys)}):', load_result.missing_keys)
    print(f'unexpected keys ({len(load_result.unexpected_keys)}):', load_result.unexpected_keys)

model.eval()

# Sample counter for the inference cell below (it advances this by one per run).
idx = 0

In [None]:
# Run the model on dataset sample `idx`; show the input next to output channels 1-7.
image = ocr_dataset[idx]
image = image[None,]  # add a leading batch axis

# Pure inference: don't build an autograd graph.
with torch.no_grad():
    x = model(image)
print(idx)

fig, ax = plt.subplots(2, 4, figsize=(24, 10))
ax[0, 0].imshow(image[0, 0].detach().numpy())
# The remaining seven panels show output channels 1..7 in row-major order.
for channel, panel_ax in enumerate(ax.flat[1:], start=1):
    panel_ax.imshow(x[0, channel].detach().numpy())

idx += 1

In [None]:
from PIL import Image
from datetime import datetime

def save_image(img, pred, idx, out_dir="tmp_img"):
    """Save an image array as a PNG named ``<timestamp>_<idx>_<pred>.png``.

    Args:
        img: array-like scaled by 255 before the uint8 cast (so presumably
            holding values in [0, 1]); values outside that range are clipped.
        pred: label embedded in the filename; path separators are replaced.
        idx: index embedded in the filename.
        out_dir: destination directory, created if it does not exist
            (default "tmp_img", the directory the original code wrote to).
    """
    import os

    cur_time_str = datetime.now().strftime("%d%m%Y_%H%M%S")

    # Clip before the uint8 cast so out-of-range values saturate instead of
    # wrapping around (e.g. 1.01 * 255 -> 257 would otherwise become 1).
    pixels = np.clip(np.asarray(img) * 255, 0, 255).astype(np.uint8)

    # A label containing '/' or '\\' would otherwise be treated as a sub-directory.
    safe_pred = str(pred).replace("/", "_").replace("\\", "_")

    os.makedirs(out_dir, exist_ok=True)
    Image.fromarray(pixels).save(os.path.join(out_dir, f"{cur_time_str}_{idx}_{safe_pred}.png"))


In [None]:
from utils.parser import get_mol_conn_info, get_mol

In [None]:
# Sample index for the molecule-parsing cell below, which runs `_idx += 1` on every execution.
_idx = 200

In [None]:
# Parse the current `image` into a molecule: run model.neck + model.head, extract
# contours / bond pairs / character predictions, save each character crop, then
# assemble the molecule with get_mol (the last expression, so it is displayed).
# To parse a specific dataset sample instead, uncomment:
# image = ocr_dataset[_idx]

# Add the batch axis only when it is missing. The inference cell above already
# leaves `image` batched, and indexing with [None,] unconditionally would pass a
# 5-D tensor to the model and grow by one axis on every re-run.
if image.dim() == 3:
    image = image[None,]

with torch.no_grad():  # inference only
    neck_out = model.neck(image)
    out = model.head(neck_out)
out = out.detach().cpu().numpy()

contours, b_pair, pred_heavy_char_list, pred_char_list, pred_img_char_list = get_mol_conn_info(out)

# Dedicated loop variable so the global `idx` sample counter used by the
# inference cell is not overwritten.
for char_idx, (char_img, char_pred) in enumerate(zip(pred_img_char_list, pred_char_list)):
    save_image(char_img, char_pred, char_idx)

_idx += 1

get_mol(contours, pred_heavy_char_list, b_pair)

In [None]:
import cv2  # used only here; it was never imported elsewhere in the notebook

# Overlay the contours of thresholded output channel 1 on the input image.
SCORE_THRESHOLD = 0.2  # cut-off applied to out[0][1] before contour extraction
MIN_RECT_AREA = 9      # keep contours whose min-area rectangle w*h exceeds this (px^2)

fig, ax = plt.subplots(1, 1, dpi=150)
img = image[0, 0].detach().numpy().copy()  # copy: drawContours draws in place
binary_map = 255 * np.array(out[0][1] > SCORE_THRESHOLD, dtype=np.uint8)
# [-2] selects the contour list under both OpenCV 3 (3 return values) and 4 (2).
_contours = cv2.findContours(binary_map, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]

tmp_contours = []
for _polygon in _contours:
    if _polygon.ndim > 1:
        # Only the rectangle size is needed; don't rebind the global `x` (model output).
        _, (rect_w, rect_h), _ = cv2.minAreaRect(_polygon)
        if rect_w * rect_h > MIN_RECT_AREA:
            tmp_contours.append(_polygon)

# (5) is the scalar 5 (not a tuple): the value drawn onto the single-channel image.
ax.imshow(cv2.drawContours(img, tmp_contours, -1, (5), 1))

In [None]:
# Show the input next to channels 1-7 of `out` (the head output from the parsing cell).
# Display only: this cell intentionally leaves the `idx` sample counter untouched.
fig, ax = plt.subplots(2, 4, figsize=(24, 10))
ax[0, 0].imshow(image[0, 0].detach().numpy())
# The remaining seven panels show output channels 1..7 in row-major order.
for channel, panel_ax in enumerate(ax.flat[1:], start=1):
    panel_ax.imshow(out[0, channel])

In [None]:
from rdkit.Chem.Draw import IPythonConsole
# RDKit notebook rendering: do not label atoms with their indices in drawn molecules.
IPythonConsole.drawOptions.addAtomIndices = False

In [None]:
# Scatter the points in contours[1], contours[2] and contours[6] (from
# get_mol_conn_info) and label each point with its index within its channel.
# NOTE(review): assumes point[0]/point[1] are x/y in image pixel coordinates;
# if so, the y axis here points up while imshow panels point down — confirm
# before adding ax.invert_yaxis().
CONTOUR_STYLES = (
    # (channel, marker colour, label colour; None keeps the default text colour)
    (1, 'k', None),
    (2, 'b', None),
    (6, 'gray', 'gray'),
)

fig, ax = plt.subplots()
for channel, marker_color, label_color in CONTOUR_STYLES:
    points = contours[channel]
    label_kwargs = {} if label_color is None else {'color': label_color}
    for point in points:
        ax.scatter(point[0], point[1], c=marker_color, alpha=0.5)
    for point_idx, point in enumerate(points):
        ax.text(point[0], point[1], point_idx, **label_kwargs)