In [None]:
import lightning as L
from model import BaseLineUnet
import torch
import torch.nn as nn
from dataset import GoProDataset
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import random
from PIL import Image
import torch.nn.functional as F

In [None]:
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BaseLineUnet()
model.to(device)
CHECKPOINT_PATH = "lightning_logs/baselineunetRandomPerceptualL1Loss/checkpoints/best-checkpoint-epoch=182-train_loss=0.048400.ckpt"
# map_location moves tensors that were saved on another device (e.g. a CUDA
# training run) onto `device`, so the checkpoint also loads on CPU-only hosts.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

In [None]:
# The checkpoint comes from a Lightning training run, so its state_dict keys carry
# the LightningModule attribute path: the wrapped network lives under "model."
# and (presumably) the training loss module under "loss_fn.". Strip only the
# *leading* "model." so keys match BaseLineUnet's own parameter names; a plain
# str.replace would also rewrite nested names such as "model.submodel.w" -> "subw".
# Entries under "loss_fn." are dropped because strict loading would reject them.
# A new dict is built so checkpoint["state_dict"] stays untouched and this cell
# can be re-run safely.
lightning_prefix = "model."
model_weights = {}
for key, value in checkpoint["state_dict"].items():
    if key.startswith(lightning_prefix):
        key = key[len(lightning_prefix):]
    if key.startswith("loss_fn."):
        continue
    model_weights[key] = value

In [None]:
# Strict loading (the default, spelled out here) raises on any missing or
# unexpected key, so a successful load means model_weights matches the network.
model.load_state_dict(model_weights, strict=True)
# Put layers such as dropout / batch-norm into inference behaviour.
model.eval()

In [None]:
# GoPro deblurring dataset built with its default arguments; elements 0 and 1 of
# each sample are used below as the blurred input and the sharp target.
dataset = GoProDataset()

In [None]:
# Pick one random example. random.randint includes BOTH endpoints, so
# randint(0, len(dataset)) could return len(dataset) and raise IndexError;
# randrange(n) draws uniformly from 0..n-1.
sample = dataset[random.randrange(len(dataset))]
blurred, sharp = sample[0], sample[1]
with torch.inference_mode():
    # The model expects a batch dimension; `predict` keeps it (shape 1 x C x H x W).
    predict = model(blurred.unsqueeze(0).to(device))

panels = [
    ('Blurred input', blurred),
    ('Sharp target', sharp),
    # Clamp so imshow does not warn about (and silently clip) out-of-range values.
    ('Prediction', predict.squeeze(0).clamp(0, 1)),
]
fig, axes = plt.subplots(1, 3, figsize=(10, 5))
for ax, (title, img) in zip(axes, panels):
    # CHW tensor -> HWC numpy array, the channel-last layout imshow expects.
    ax.imshow(img.cpu().numpy().transpose(1, 2, 0))
    ax.set_title(title)
    ax.axis('off')
plt.show()

In [None]:
# Write the blurred input, the sharp target and the model prediction to disk.
# Each tensor is clamped to [0, 1] first so out-of-range values do not wrap
# around when ToPILImage converts them to 8-bit pixels.
to_pil = v2.ToPILImage()
for tensor, filename in (
    (sample[0], 'blur.png'),
    (sample[1], 'sharp.png'),
    (predict.squeeze(), 'predict.png'),
):
    image = to_pil(tensor.clamp(min=0, max=1))
    image.save(filename)

In [None]:
def _tile_starts(length, patch_size, stride):
    """Return tile origins along one axis so the tiles cover [0, length).

    Assumes ``length >= patch_size``. Origins are ``stride`` apart, and the last
    origin is pinned to ``length - patch_size`` so the trailing edge is covered
    even when ``stride`` does not divide ``length - patch_size``.
    """
    last = length - patch_size
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)
    return starts


def sliding_window_inference(img_path, save_patch=False, patch_size=256, stride=128, out_path='y_hat.png'):
    """Deblur an image of arbitrary size by running the model on overlapping tiles.

    The image is zero-padded on the bottom/right to a multiple of ``patch_size``
    and split into ``patch_size`` x ``patch_size`` windows whose origins are
    ``stride`` pixels apart. Each window is passed through the global ``model``
    on ``device``. Overlapping predictions are averaged with equal weight per
    tile. The padding is cropped off, values are clamped to [0, 1], and the
    image is saved to ``out_path``.

    Args:
        img_path: Path of the input image (converted to RGB).
        save_patch: If True, also save the blended result for each window as
            ``z{row}_{col}.png`` right after that window is processed (debug aid).
        patch_size: Side length of the square tiles fed to the model.
        stride: Step between tile origins; must be in ``1..patch_size`` so the
            tiles leave no gaps.
        out_path: File the reconstructed image is written to.

    Raises:
        ValueError: If ``stride`` is not in ``1..patch_size``.
    """
    if not 0 < stride <= patch_size:
        raise ValueError('stride must be in 1..%d, got %r' % (patch_size, stride))

    to_tensor = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True)
    ])
    to_img = v2.ToPILImage()

    img_tensor = to_tensor(Image.open(img_path).convert('RGB'))
    _, H, W = img_tensor.shape
    # Amount needed to round each side up to the next multiple of patch_size.
    pad_h = (-H) % patch_size
    pad_w = (-W) % patch_size
    img_tensor = F.pad(img_tensor, (0, pad_w, 0, pad_h), mode='constant', value=0).to(device)
    _, padded_h, padded_w = img_tensor.shape

    # Accumulate a per-pixel sum of tile predictions together with how many tiles
    # covered each pixel, then divide. Checking `out > 0` to detect "already
    # written" pixels would misclassify genuinely dark (<= 0) predictions, and a
    # running pairwise mean would over-weight later tiles.
    acc = torch.zeros_like(img_tensor)
    hits = torch.zeros((1, padded_h, padded_w), dtype=img_tensor.dtype, device=img_tensor.device)
    rows = _tile_starts(padded_h, patch_size, stride)
    cols = _tile_starts(padded_w, patch_size, stride)
    with torch.inference_mode():
        for i in rows:
            for j in cols:
                window = (slice(None), slice(i, i + patch_size), slice(j, j + patch_size))
                pred = model(img_tensor[window].unsqueeze(0)).squeeze(0)
                acc[window] += pred
                hits[window] += 1
                if save_patch:
                    blended = acc[window] / hits[window]
                    to_img(torch.clamp(blended, 0, 1).cpu()).save('z' + str(i) + '_' + str(j) + '.png')

    # Every pixel is covered at least once; clamp(min=1) only guards the division.
    out_tensor = torch.clamp(acc / hits.clamp(min=1), 0, 1)
    out_tensor = out_tensor[:, :H, :W]
    to_img(out_tensor.cpu()).save(out_path)

In [None]:
# Deblur one GoPro test frame tile by tile; the result is written to y_hat.png.
sliding_window_inference(img_path='E:\\Downloads\\GOPRO_Large\\test\\GOPR0384_11_00\\blur_gamma\\000001.png', save_patch=False)

In [None]:
def big_img_inference(img_path, out_path='y_hat2.png', multiple=256):
    """Deblur a whole image in a single forward pass (no tiling).

    The image is zero-padded on the bottom/right so both sides are a multiple of
    ``multiple``. This is presumably the divisibility the U-Net's down/up-sampling
    needs (TODO confirm the real constraint). The padded image is run through the
    global ``model`` on ``device``, cropped back to its original size, clamped to
    [0, 1] and saved to ``out_path``. The whole padded image goes through the
    model at once, so peak memory scales with the image size.

    Args:
        img_path: Path of the input image (converted to RGB).
        out_path: File the reconstructed image is written to.
        multiple: Each padded side is rounded up to a multiple of this value.
    """
    to_tensor = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True)
    ])
    to_img = v2.ToPILImage()

    img_tensor = to_tensor(Image.open(img_path).convert('RGB'))
    _, H, W = img_tensor.shape
    # Amount needed to round each side up to the next multiple.
    pad_h = (-H) % multiple
    pad_w = (-W) % multiple
    # F.pad defaults to constant zero padding.
    img_tensor = F.pad(img_tensor, (0, pad_w, 0, pad_h)).to(device)
    with torch.inference_mode():
        out_tensor = model(img_tensor.unsqueeze(0)).squeeze(0)

    out_tensor = torch.clamp(out_tensor, 0, 1)
    out_tensor = out_tensor[:, :H, :W]
    to_img(out_tensor.cpu()).save(out_path)

In [None]:
# Deblur the same test frame in one full-resolution pass; the result is written to y_hat2.png.
big_img_inference(img_path='E:\\Downloads\\GOPRO_Large\\test\\GOPR0384_11_00\\blur_gamma\\000001.png')