In [None]:
# Start by importing necessary libraries
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import onnxruntime as ort


# Machine-specific locations of the sample picture and the exported model.
TEST_IMAGE_PATH = r'H:\my_files\my_programs\cat_upscaler\datasets\raw_cat_images\0dc281ea-1f96-4001-a14f-08d1f4df6f3e.jpg'
ONNX_MODEL_PATH = r'H:\my_files\my_programs\cat_upscaler\cat_downscale_4th_500_count_2025_01_06_17_37_26_epochs_10.onnx'

# How many times the model is applied back-to-back on its own output.
INFER_COUNT = 100

# Single ONNX Runtime session used by get_higher_res_image.
onnx_session = ort.InferenceSession(ONNX_MODEL_PATH)

# Preprocessing pipeline: only converts the loaded image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])


def load_image(image_path):
    """Read an image file and return it as a batched, rotated tensor.

    The file is opened with PIL and converted to RGB, run through the
    module-level ``transform`` pipeline, given a leading batch dimension,
    and finally rotated by one quarter turn in the plane of dims 2 and 3.
    """
    rgb = Image.open(image_path).convert("RGB")
    batched = transform(rgb).unsqueeze(0)  # prepend the batch axis
    # k=1: a single 90-degree rotation over dims 2 and 3.
    return torch.rot90(batched, 1, [2, 3])


def infer(model, input_tensor):
    """Run one forward pass through an ONNX Runtime session.

    Args:
        model: Object exposing ``get_inputs()`` and ``run()`` the way
            ``onnxruntime.InferenceSession`` does. Only the first declared
            input is fed.
        input_tensor: ``torch.Tensor`` supplied to that input.

    Returns:
        A new ``torch.Tensor`` built from the first output of the run.
    """
    input_name = model.get_inputs()[0].name

    # ``Tensor.numpy()`` raises for tensors that track gradients or live on
    # a non-CPU device; detach and move to CPU first so such tensors are
    # accepted too. Both calls are cheap for a plain CPU tensor.
    feed_array = input_tensor.detach().cpu().numpy()

    # Passing ``None`` as the output list requests every model output.
    outputs = model.run(None, {input_name: feed_array})

    return torch.tensor(outputs[0])


def get_higher_res_image(lr_image):
    """Apply the model once to ``lr_image`` and return its output.

    The call goes through ``infer`` using the module-level ``onnx_session``
    and is made inside a ``torch.no_grad()`` context.
    """
    with torch.no_grad():
        return infer(onnx_session, lr_image)


def show_image_comparison(lr_image, hr_image):
    """Display the input and the model output next to each other.

    Both arguments are handed straight to ``plt.imshow``. The left panel is
    titled "Low-Resolution Input" and the right one "Super-Resolution
    Output"; axes are hidden on both, and ``plt.show()`` is called at the end.
    """
    plt.figure(figsize=(12, 6))
    panels = (
        (lr_image, "Low-Resolution Input"),
        (hr_image, "Super-Resolution Output"),
    )
    # Subplot indices are 1-based: slot 1 is the left panel, slot 2 the right.
    for slot, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.imshow(picture)
        plt.title(caption)
        plt.axis("off")
    plt.show()


def show_image_as_matplotlib(image):
    """Display a single image with pyplot, axes hidden.

    ``image`` is passed unchanged to ``plt.imshow``.
    """
    plt.imshow(image)
    plt.axis("off")
    plt.show()


def upscale_image(image_path, upscale_count, w, h):
    """Resize an image and feed it through the model repeatedly.

    Args:
        image_path: Path of the image file, forwarded to ``load_image``.
        upscale_count: Number of model passes; each pass consumes the
            output of the previous one.
        w: Target size of tensor dim 2 for the initial bicubic resize.
        h: Target size of tensor dim 3 for the initial bicubic resize.

    Returns:
        A ``(lr_image, hr_image, images)`` tuple. ``lr_image`` (the resized
        input) and ``hr_image`` (the final result) are NumPy arrays with the
        batch axis removed, the first remaining axis moved last, and values
        clamped to ``[0, 1]``. ``images`` is the list of raw tensors: the
        resized input followed by the output of every pass.
    """
    image = load_image(image_path)

    # NOTE(review): ``interpolate`` reads ``size`` as (dim 2, dim 3), i.e.
    # (height, width) for NCHW data, so ``w`` lands on the row axis here.
    # load_image has already rotated the tensor by 90 degrees, so this may
    # be intentional -- confirm before passing a non-square size.
    image = torch.nn.functional.interpolate(
        image, size=(w, h), mode="bicubic", align_corners=False
    )

    resized_input = image
    images = [image]

    for i in range(upscale_count):
        # "\r" keeps the progress message on a single console line.
        print(f"doing upscale {i+1} / {upscale_count}", end="\r")
        image = get_higher_res_image(image)
        images.append(image)
    print("\nDone upscaling")

    # Bicubic resampling can overshoot the input range, so the resized input
    # is clamped for display exactly like the model output. The tensors in
    # ``images`` are left untouched.
    hr_image = image.squeeze(0).clamp(0, 1).permute(1, 2, 0).numpy()
    lr_image = resized_input.squeeze(0).clamp(0, 1).permute(1, 2, 0).numpy()
    return lr_image, hr_image, images


# Run the whole pipeline on the sample image and show input vs. final output.
lr_image, hr_image, images = upscale_image(TEST_IMAGE_PATH, INFER_COUNT, 2560, 2560)
show_image_comparison(lr_image, hr_image)

# Show every stage (resized input plus each pass). The frames are clamped to
# [0, 1] the same way hr_image is, so imshow is never given out-of-range
# float RGB data.
for image in images:
    show_image_as_matplotlib(
        image.squeeze(0).clamp(0, 1).permute(1, 2, 0).numpy()
    )


doing upscale 62 / 100