In [None]:
import os
from pathlib import Path
import sys
import matplotlib.pyplot as plt

sys.path.append(os.path.abspath(".."))

import torch
from src.models.ddpm import *
from src.models.unet import Unet
from src.models.ddpm_classifier_free import Unet as Unet_class
from src.utils.image_utils import save_image_to_dir, save_patches_to_dir
from src.utils.model_utils import (load_model, load_classifier_free_model, generate_whole_image, 
                                   create_lcl_ctx_channels, create_inputs, generate_patches, stitch_patches,
                                   create_patch_channels)
from src.config import IS_COND, OVERLAP

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoints for the three stages of the generation pipeline:
#   1. whole image in resolution 256x256 (channels=1)
#   2. local contexts, i.e. mid-resolution images (channels=3)
#   3. patches that are stitched into the final image (channels=3)
whole_img_model_path = '../models/artifacts/vindr_healthy_256:v82/model_124499.pt'
local_context_model_path = '../models/artifacts/vindr_lcl_ctx_3072:v37/model_56999.pt'
patch_model_path = '../models/artifacts/vindr_3c_256_v2:v84/model_169999.pt'

# Stage 1 model. When IS_COND is set, the classifier-free loader is used with
# 3 classes and class 0 is requested at sampling time; otherwise the plain
# loader is used and an empty class ('') is passed downstream.
if IS_COND:
    whole_image_model = load_classifier_free_model(whole_img_model_path, channels=1, num_classes=3)
    img_class = 0
else:
    whole_image_model = load_model(whole_img_model_path, channels=1)
    img_class = ''

# Stage 2 (local contexts) and stage 3 (patches) both use the plain loader.
local_context_model = load_model(local_context_model_path, channels=3)
patch_model = load_model(patch_model_path, channels=3)

In [None]:
# Stage 1: sample one whole image, save it, and show it in grayscale.
# img[0] presumably selects the single sample of the batch (batch_size=1).
img = generate_whole_image(
    whole_image_model, device, batch_size=1, img_class=img_class
)
save_image_to_dir(img, '../images/whole_small.png')
plt.imshow(img[0], cmap='gray')

In [None]:
# Stage 2: derive per-patch channels and patch coordinates from the whole
# image, assemble the model inputs, and sample the local contexts.
# NOTE(review): `timesteps` is not defined anywhere in this notebook; it
# presumably comes from `from src.models.ddpm import *` — confirm and prefer
# an explicit import so Restart & Run All does not depend on a star import.
img_channels, patch_coords = create_lcl_ctx_channels(img, overlap=OVERLAP)
print(len(img_channels))

inputs, black_idx = create_inputs(img, img_channels, patch_coords, mask_shape=1024)
print(len(inputs), len(black_idx))

local_contexts = generate_patches(
    local_context_model,
    inputs,
    black_idx,
    timesteps=timesteps,
    overlap=OVERLAP,
    device=device,
)

In [None]:
# Stitch the local contexts into one mid-resolution image.
# The overlap must be the same OVERLAP that create_lcl_ctx_channels and
# generate_patches used to lay out and generate these patches (and that the
# final-stage stitch uses); a hard-coded value mis-stitches whenever
# OVERLAP differs from it.
mid_img = stitch_patches(local_contexts, overlap=OVERLAP)
plt.imshow(mid_img, cmap='gray')

In [None]:
# Persist the stitched mid-resolution image and the individual local-context patches.
save_image_to_dir(mid_img, '../images/mid_img.png')
save_patches_to_dir(local_contexts, '../images/local_contexts')

In [None]:
# Stage 3 inputs: convert the stitched mid-resolution numpy array to a tensor
# with a new leading dimension, then build patch channels/coordinates from it
# together with the whole image.
# NOTE(review): create_inputs receives the whole image `img` here, not
# `mid_img` — confirm this is intended for the patch stage.
mid_img_tensor = torch.from_numpy(mid_img).unsqueeze(0)
img_channels, patch_coords = create_patch_channels(mid_img_tensor, img, overlap=OVERLAP)
inputs, black_idx = create_inputs(img, img_channels, patch_coords, mask_shape=3072)
print(len(inputs), len(black_idx))

In [None]:
# Stage 3: sample the patches and stitch them into the final image,
# using the same OVERLAP for generation and stitching.
patches = generate_patches(
    patch_model,
    inputs,
    black_idx,
    timesteps=timesteps,
    overlap=OVERLAP,
    device=device,
)
final_img = stitch_patches(patches, overlap=OVERLAP)
plt.imshow(final_img, cmap='gray')

In [None]:
# Persist the final stitched image and the individual patches it was built from.
save_image_to_dir(final_img, '../images/final_img.png')
save_patches_to_dir(patches, '../images/patches')