In [None]:
import lightning as L
from model import BaseLineUnet
import torch
import torch.nn as nn
from dataset import GoProDataset
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import random
from PIL import Image
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure, LearnedPerceptualImagePatchSimilarity
from tqdm import tqdm
import os

In [None]:
# Build BaseLineUnet and load its trained weights from a Lightning checkpoint.
CKPT_PATH = "lightning_logs/version_14/checkpoints/best-checkpoint-epoch=267-train_loss=0.047806.ckpt"

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BaseLineUnet()
model.to(device)
# map_location lets a checkpoint that was saved on GPU load on a CPU-only machine
# (otherwise torch.load tries to restore tensors onto CUDA and fails).
# NOTE(review): torch>=2.6 defaults to weights_only=True; if loading fails because the
# checkpoint holds non-tensor objects, only pass weights_only=False for trusted files.
checkpoint = torch.load(CKPT_PATH, map_location=device)

# The checkpoint's state_dict keys carry a "model." prefix (the wrapped network) and
# "loss_fn." entries (loss-module state) that BaseLineUnet does not own.
# Strip the prefix only at the start of the key -- str.replace would also rewrite
# nested names that merely contain "model." -- then drop the loss entries.
model_weights = {}
for key, value in checkpoint["state_dict"].items():
    if key.startswith("model."):
        key = key[len("model."):]
    if key.startswith("loss_fn."):
        continue
    model_weights[key] = value
model.load_state_dict(model_weights)
model.eval()

In [None]:
# GoPro test split. 'A' is the dataset root used on this machine -- TODO confirm / make configurable.
DATA_ROOT = 'A'
dataset = GoProDataset(mode='test', root_dir=DATA_ROOT)

In [None]:
def big_img_inference_to_file(img_tensor, sharp_tensor, index, multiple=256, out_dir='output'):
    """Deblur one image with the global ``model`` and save the results as PNGs.

    The input is zero-padded on the bottom/right so its height and width are
    multiples of ``multiple`` (presumably required by the U-Net's down/up-sampling
    -- TODO confirm), run through ``model`` in a single forward pass, then cropped
    back to the original size and clamped to [0, 1].

    Files written to ``out_dir``:
        ``{index}blur.png``      -- the input image
        ``{index}predicted.png`` -- the model output
        ``{index}sharp.png``     -- the ground truth (only if ``sharp_tensor`` is given)

    Args:
        img_tensor: blurred image as a (C, H, W) tensor on the same device as
            ``model``; values assumed to be floats in [0, 1] -- TODO confirm
            against GoProDataset.
        sharp_tensor: optional (C, H, W) ground-truth tensor, or None.
        index: prefix for the output filenames.
        multiple: the padded height/width are rounded up to a multiple of this.
        out_dir: destination directory; created if it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)
    to_PIL = v2.ToPILImage()
    to_PIL(img_tensor.cpu()).save(os.path.join(out_dir, f'{index}blur.png'))

    H, W = img_tensor.shape[-2:]
    # (-n) % m is the distance up to the next multiple of m (0 when already aligned).
    pad_h = (-H) % multiple
    pad_w = (-W) % multiple
    # F.pad orders pads from the last dim backwards: (left, right, top, bottom).
    padded = F.pad(img_tensor.unsqueeze(0), (0, pad_w, 0, pad_h))
    with torch.inference_mode():
        out_tensor = model(padded)

    # Drop the batch dim, crop the padding away and clamp to the valid image range.
    out_tensor = out_tensor[0, :, :H, :W].clamp(0, 1)
    to_PIL(out_tensor.cpu()).save(os.path.join(out_dir, f'{index}predicted.png'))

    if sharp_tensor is not None:
        to_PIL(sharp_tensor.cpu()).save(os.path.join(out_dir, f'{index}sharp.png'))

In [None]:
def delete_folder_contents(folder_path):
    """Delete the regular files and symlinks that sit directly in ``folder_path``.

    Subdirectories and anything inside them are kept. Every deletion is printed.
    If the folder does not exist, a message is printed and nothing else happens.
    Returns None.
    """
    if os.path.exists(folder_path):
        for name in os.listdir(folder_path):
            path = os.path.join(folder_path, name)
            if not (os.path.islink(path) or os.path.isfile(path)):
                continue  # directories (and other non-file entries) are left alone
            os.remove(path)
            print(f"Deleted file: {path}")
    else:
        print(f"The folder {folder_path} does not exist.")
# Make sure the output folder exists first, then clear files left from earlier runs.
# (Creating it before clearing avoids a spurious "does not exist" message on a fresh run.)
os.makedirs('output', exist_ok=True)
delete_folder_contents('output')

In [None]:
# Pick one random test pair (blurred, sharp) and write blur / prediction / ground truth to output/.
index = random.randrange(len(dataset))
x, y = dataset[index]
x, y = x.to(device), y.to(device)
big_img_inference_to_file(x, y, index)

In [None]:
# Run the model on an arbitrary image file; set img_path = '' to skip this cell.
img_path = 'aaa.jpg'
if img_path:
    # PIL image -> float32 (C, H, W) tensor scaled to [0, 1].
    to_tensor = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])

    # Close the file handle; convert() returns a fully loaded copy, so the
    # image remains usable after the file is closed.
    with Image.open(img_path) as pil_img:
        rgb_img = pil_img.convert('RGB')

    img = to_tensor(rgb_img).to(device)

    # 99999 is only a filename prefix, presumably chosen so it won't clash with
    # dataset indices written by the previous cell -- TODO confirm.
    big_img_inference_to_file(img, None, 99999)