In [1]:
from dataset import AIHubDataset,SceneTextDataset
from east_dataset import EASTDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


import numpy as np
import cv2

def gray_mask_to_heatmap(x):
    """Colorize a mask with OpenCV's JET colormap and return it in RGB order.

    Args:
        x (ndarray): mask to colorize. The caller in this file passes a uint8
            array.

    Returns:
        ndarray: the JET-colored mask after a BGR -> RGB channel swap.
    """
    jet_bgr = cv2.applyColorMap(x, cv2.COLORMAP_JET)
    # The colormap result is treated as BGR; swap to RGB for display.
    return cv2.cvtColor(jet_bgr, cv2.COLOR_BGR2RGB)

def get_superimposed_image(image, score_map, heatmap=True, w_image=None, w_map=None):
    """Blend a score map on top of an image.

    Args:
        image (ndarray): (H, W) or (H, W, C) shaped, float32 or uint8 dtype is
            allowed. Non-uint8 input is clipped to [0, 1] and scaled to
            [0, 255]. Single-channel and 4-channel images are converted to
            3 channels before blending.
        score_map (ndarray): (H, W) shaped, float32 or uint8 dtype is allowed.
            Non-uint8 input is clipped to [0, 1] and scaled to [0, 255].
        heatmap (bool): Whether to convert `score_map` into a heatmap.
        w_image (float): Blending weight of `image`, strictly between 0 and 1.
        w_map (float): Blending weight of `score_map`, strictly between 0 and 1.

    Blending weights (`w_image` and `w_map`) default to (0.4, 0.6). When only
    one of them is given, the other one is set to its complement to 1.

    Returns:
        ndarray: uint8 blend of the resized image and the 3-channel score map,
        with the spatial size of `score_map`.

    Raises:
        AssertionError: if a given weight is not strictly between 0 and 1.
    """

    assert w_image is None or (w_image > 0 and w_image < 1), \
        '`w_image` must be strictly between 0 and 1'
    assert w_map is None or (w_map > 0 and w_map < 1), \
        '`w_map` must be strictly between 0 and 1'

    # Bring the image to the spatial size of the score map (dsize is (W, H)).
    image = cv2.resize(image, dsize=(score_map.shape[1], score_map.shape[0]))

    if image.dtype != np.uint8:
        image = (255 * np.clip(image, 0, 1)).astype(np.uint8)

    # cv2.resize drops a singleton channel axis, so an (H, W, 1) input arrives
    # here as (H, W). addWeighted needs both operands to have the same number
    # of channels, hence the image is normalised to 3 channels.
    if image.ndim == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        # Drop the 4th channel (assumed to be alpha).
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    if score_map.dtype != np.uint8:
        score_map = (255 * np.clip(score_map, 0, 1)).astype(np.uint8)
    if heatmap:
        score_map = gray_mask_to_heatmap(score_map)
    elif score_map.ndim == 2 or score_map.shape[2] != 3:
        score_map = cv2.cvtColor(score_map, cv2.COLOR_GRAY2RGB)

    if w_image is None and w_map is None:
        w_image, w_map = 0.4, 0.6
    elif w_image is None:
        w_image = 1 - w_map
    elif w_map is None:
        w_map = 1 - w_image

    return cv2.addWeighted(image, w_image, score_map, w_map, 0)


In [2]:
# Wrap the raw scene-text dataset with EASTDataset and build a shuffled
# loader that yields one sample per batch.
dataset = EASTDataset(SceneTextDataset('../input/data/ICDAR17_Korean'))
train_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
)

In [4]:
# Take the first batch from the loader, pick a random sample from it and show
# the image, the ROI mask, the ground-truth score map and their overlay.
for img, gt_score_map, gt_geo_map, roi_mask in train_loader:
    idx = np.random.randint(0, img.shape[0])
    print('idx :',idx)
    fig, axes = plt.subplots(1, 4, figsize=(24, 12))

    # Convert the selected sample to NumPy once and reuse it below.
    # transpose(1, 2, 0) moves axis 0 to the end (presumably CHW -> HWC).
    image_np = img[idx].detach().cpu().numpy().transpose(1, 2, 0)
    roi_np = roi_mask[idx].detach().cpu().numpy().squeeze()
    score_np = gt_score_map[idx].detach().cpu().numpy().squeeze()

    panels = (
        ('original image', image_np),
        ('roi mask', roi_np),
        ('gt_score_map', score_np),
    )
    for ax, (title, picture) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(title)

    axes[3].imshow(get_superimposed_image(image_np, score_np))
    axes[3].set_title('superimposed_image')
    # Only one batch is needed for the preview.
    break

KeyboardInterrupt: 