In [None]:
import os
# Must be set before torch initializes CUDA so only GPU 0 is visible to this kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader

from src.models import FPathPredictor, UNetFPathPredictor
from src.dataloaders import FPathDataset
from src.utils import inference
from src.preprocesses import VTFPreprocessor, VTFPreprocessorUNet, ImagePreprocessor

In [None]:
# Root of the test split. TODO: make this configurable instead of an absolute local path.
TEST_DATA_DIR = "/home/joono/VTFSketch/dataset/simple_data/test"

vtf_paths = sorted(glob(os.path.join(TEST_DATA_DIR, "vtfs", "*")))
img_paths = sorted(glob(os.path.join(TEST_DATA_DIR, "imgs", "*")))

# Downstream cells pair vtfs[i] with imgs[i] purely by sorted position, so fail early
# if nothing matched or if the two folders do not line up one-to-one.
assert vtf_paths, f"no VTF files found under {TEST_DATA_DIR}/vtfs"
assert len(vtf_paths) == len(img_paths), (
    f"found {len(vtf_paths)} VTF files but {len(img_paths)} images under {TEST_DATA_DIR}"
)

vtfs = [VTFPreprocessorUNet.get(vtf_path) for vtf_path in vtf_paths]
imgs = [ImagePreprocessor.get(img_path) for img_path in img_paths]

In [None]:
# Checkpoint to evaluate.
weight_path = "/home/joono/VTFSketch/checkpoints/results_20240523_135551/best_model_loss_1_1141014397144318.pth"

model = UNetFPathPredictor()
# Deserialize onto the CPU first: CUDA_VISIBLE_DEVICES exposes a single GPU, so a
# checkpoint whose tensors were saved on another device index would otherwise fail to
# load. The whole model is moved to the GPU right after.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
model.load_state_dict(torch.load(weight_path, map_location="cpu"))

model = model.to("cuda").eval()

In [None]:
# Index of the test pair to run through the model.
SAMPLE_IDX = 1

vtf, vtf_path = vtfs[SAMPLE_IDX], vtf_paths[SAMPLE_IDX]
img, img_path = imgs[SAMPLE_IDX], img_paths[SAMPLE_IDX]

print(f"{vtf_path=}, {img_path=}")

# Add a leading batch dimension and move both inputs to the GPU the model was moved to.
vtf_batch = torch.as_tensor(vtf).unsqueeze(0).to("cuda")
img_batch = torch.as_tensor(img).unsqueeze(0).to("cuda")

# Pure inference: no autograd graph is needed, which also saves GPU memory.
with torch.no_grad():
    pred = inference(model, vtf_batch, img_batch)

# Drop singleton dims and swap the two remaining axes before saving.
# NOTE(review): assumes the squeezed prediction is 2-D with its axes in the opposite
# order from cv2's (rows, cols) layout -- confirm against the model's output.
result = pred.squeeze().detach().cpu().numpy()
result = result.transpose((1, 0))

# PNG needs 8-bit data: clip to [0, 1], scale, and round explicitly instead of relying on
# OpenCV's implicit float -> uint8 conversion inside imwrite.
result_u8 = np.round(np.clip(result, 0.0, 1.0) * 255).astype(np.uint8)

out_path = f"{os.path.basename(vtf_path).split('.')[0]}.png"
# cv2.imwrite returns a success flag; fail loudly instead of continuing on False.
if not cv2.imwrite(out_path, result_u8):
    raise IOError(f"cv2.imwrite failed to write {out_path}")
out_path

In [None]:
import cv2
import numpy as np

import torch
from src.dataloaders import get_data_loaders, get_FPathUNetDataset, UNetFPathDataset

import matplotlib.pyplot as plt

In [None]:
# Dataset config; a relative path, presumably resolved against the kernel's working directory.
TEST_CONFIG_PATH = "dataset/test.yaml"
dset = UNetFPathDataset(config_path=TEST_CONFIG_PATH)

In [None]:
# Fetch the first dataset item and split it into its three components.
sample_0 = dset[0]
vtf, img, target = sample_0

In [None]:
# Shapes of the three sample components, then the (min, max) of each.
print(*(part.shape for part in (vtf, img, target)))
print(
    np.min(vtf), np.max(vtf),
    torch.min(img), torch.max(img),
    torch.min(target), torch.max(target),
)

In [None]:
# Channel 10 of the VTF stack (the same slice a later cell writes to infodraw.png).
vtf_channel_10 = vtf[10]
plt.imshow(vtf_channel_10, cmap="gray")

In [None]:
# Move axis 0 (presumably channels) to the end, since imshow expects (H, W, C).
img_hwc = img.permute(1, 2, 0)
plt.imshow(img_hwc)

In [None]:
# Boolean mask of pixels where VTF channel 10 is not exactly 1.0.
channel_not_one = vtf[10, :, :] != 1.0
mask = torch.tensor(channel_not_one)
print(mask.shape)

# Sums of the target and of its complement (1 - target) over all elements.
W = torch.sum(target)
B = torch.sum(1 - target)

# The mask restricted to the target and to its complement, respectively.
mask_W = mask * target
mask_B = mask * (1 - target)

In [None]:
# Show VTF channel 10 and save it, scaled by 255, as infodraw.png.
infodraw = vtf[10, :, :]
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(infodraw, cmap='gray')
ax.axis('off')
cv2.imwrite("infodraw.png", infodraw * 255)

In [None]:
# Show the boolean mask and save it, scaled to 0/255, as mask.png.
mask_np = mask.numpy()
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(mask_np, cmap='gray')
ax.axis('off')
cv2.imwrite("mask.png", mask_np * 255)

In [None]:
# Show the first slice of mask_W and save it, scaled by 255, as mask_W.png.
mask_W_np = mask_W[0].numpy()
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(mask_W_np, cmap='gray')
ax.axis('off')
cv2.imwrite("mask_W.png", mask_W_np * 255)

In [None]:
# Show the first slice of mask_B and save it, scaled by 255, as mask_B.png.
mask_B_np = mask_B[0].numpy()
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(mask_B_np, cmap='gray')
ax.axis('off')
cv2.imwrite("mask_B.png", mask_B_np * 255)

In [None]:
# Show the first slice of the target and save it, scaled by 255, as target_W.png.
target_W_np = target[0].numpy()
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(target_W_np, cmap='gray')
ax.axis('off')
cv2.imwrite("target_W.png", target_W_np * 255)

In [None]:
# Show the complement (1 - target) of the first target slice and save it as target_B.png.
target_B_np = (1 - target[0]).numpy()
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(target_B_np, cmap='gray')
ax.axis('off')
cv2.imwrite("target_B.png", target_B_np * 255)