In [1]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import ipywidgets as widgets
from ipywidgets import interact
from IPython.display import display

from sklearn.metrics import precision_score, recall_score, f1_score

from src.dataloaders import FPathDataset, FPathLazyDataset

In [None]:
import os

# Validation split config. Set the VTF_VAL_YAML environment variable to point at a
# local copy instead of editing this machine-specific absolute path.
VAL_YAML = os.environ.get("VTF_VAL_YAML", "/home/work/joono/VTFSketch/dataset/val.yaml")
dataset = FPathLazyDataset(VAL_YAML)

In [None]:
def draw_gray_img(img):
    """Render ``img`` with the 'gray' colormap in a new square figure, axes hidden."""
    side_inches = 10.24
    plt.figure(figsize=(side_inches, side_inches))
    plt.imshow(img, cmap="gray")
    plt.axis("off")
    
def draw_img(img):
    """Render ``img`` in a new 10.24 x 10.24-inch figure with axes hidden.

    A new figure is created on every call (no fixed figure number), so repeated
    calls within one cell produce separate images and ``figsize`` always applies.
    This matches ``draw_gray_img``.
    """
    # Passing a figure number (previously 0) would re-activate that existing figure:
    # figsize is then ignored and imshow draws on top of the stale axes.
    plt.figure(figsize=(10.24, 10.24))
    plt.imshow(img)
    plt.axis('off')

In [None]:
def save_vtf_gif(vtf, path="out.gif", duration=500, loop=0xff):
    """Write the frames of a VTF stack to an animated GIF.

    The GIF opens on the middle frame (``vtf[len(vtf) // 2]``, which is index 10
    for a 21-frame stack), then plays every frame ``vtf[0] .. vtf[-1]`` in order.
    The middle frame therefore appears twice.

    Args:
        vtf: Sequence/array of 2-D frames. Values are scaled by 255, rounded and
            clipped to 8 bits. Presumably float intensities in [0, 1]; TODO confirm.
        path: Output file path (default ``"out.gif"`` in the working directory).
        duration: Display time of each frame, in milliseconds.
        loop: Pillow GIF loop count (0 means loop forever). The default 0xff keeps
            the previous behaviour of looping 255 times.

    Raises:
        ValueError: If ``vtf`` contains no frames.
    """
    if len(vtf) == 0:
        raise ValueError("vtf must contain at least one frame")

    def _to_frame(plane):
        # Convert to uint8 explicitly. Image.fromarray on a float array gives a
        # 32-bit 'F' image that the GIF writer would otherwise convert implicitly.
        scaled = np.clip(np.rint(np.asarray(plane) * 255), 0, 255)
        return Image.fromarray(scaled.astype(np.uint8))

    # Iterate over however many frames are present instead of a hard-coded 21.
    frames = [_to_frame(plane) for plane in vtf]
    first_frame = _to_frame(vtf[len(vtf) // 2])
    first_frame.save(path, save_all=True, append_images=frames, loop=loop, duration=duration)

In [None]:
# Load validation sample 6 and turn each of its four components into a NumPy array.
vtf, img, infodraw, target = (np.array(part) for part in dataset[6])

print(*(arr.shape for arr in (vtf, img, infodraw, target)))
# Show infodraw[0] and target[0] side by side (index 0 of the leading axis,
# presumably a singleton channel dimension; TODO confirm).
draw_gray_img(np.concatenate([infodraw[0], target[0]], axis=1))

In [None]:
@interact(threshold_w=widgets.FloatSlider(min=0.0, max=1.0, step=0.01, value=0.99, description='Threshold W'),
          threshold_b=widgets.FloatSlider(min=0.0, max=1.0, step=0.01, value=0.0, description='Threshold B'))
def display_visualizations(threshold_w, threshold_b):
    """Interactively overlay InfoDraw threshold masks next to the target.

    Reads the notebook-level ``infodraw`` and ``target`` arrays and draws three panels:

    1. ``infodraw[0]`` replicated to RGB, unmodified.
    2. The same image with pixels ``< threshold_w`` painted red and pixels
       ``< threshold_b`` painted blue. Blue is applied last, so it wins where both apply.
    3. ``target[0]`` in grayscale.

    Args:
        threshold_w: Pixels of ``infodraw[0]`` strictly below this are painted red.
        threshold_b: Pixels of ``infodraw[0]`` strictly below this are painted blue.
            At the default 0.0 nothing is painted, assuming ``infodraw`` is
            non-negative (TODO confirm).
    """
    gray = infodraw[0]

    # Replicate the grayscale plane into three identical RGB channels.
    infodraw_rgb_origin = np.stack([gray, gray, gray], axis=2)
    infodraw_rgb = infodraw_rgb_origin.copy()

    # Boolean-mask assignment paints every selected pixel's RGB triple at once.
    infodraw_rgb[gray < threshold_w] = [1.0, 0.0, 0.0]
    # The threshold_b slider previously had no effect. It now highlights pixels below it.
    infodraw_rgb[gray < threshold_b] = [0.0, 0.0, 1.0]

    fig, axes = plt.subplots(1, 3, figsize=(30.24, 10.48))
    panels = [
        (infodraw_rgb_origin, None, "InfoDraw"),
        (infodraw_rgb, None, f"red: < {threshold_w:.2f}   blue: < {threshold_b:.2f}"),
        (target[0], "gray", "Target"),
    ]
    for ax, (image, cmap, title) in zip(axes, panels):
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    plt.show()

In [None]:
# Precision of the prediction "infodraw[0] < 0.8" against the inverted target (1 - target).
# Assumes target is a {0, 1} mask; presumably 1 = background. TODO confirm.
precision = precision_score(
    y_true=(1 - target).ravel(),
    y_pred=(infodraw[0] < 0.8).ravel(),
)

In [None]:
# Display the precision score computed in the cell above as this cell's output.
precision