In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt


device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device = 'cpu'

from upscaling_model import UpscalingModel
from PIL import Image


In [2]:
import os
from hashlib import md5

from PIL import ExifTags
from pillow_heif import register_heif_opener
register_heif_opener()

os.makedirs('test_images', exist_ok=True)

for filename in os.listdir("Test dataset photos"):
    path = os.path.join("Test dataset photos", filename)
    hash = md5(bytearray(path, 'utf-8')).hexdigest()
    img = Image.open(path)

    # https://stackoverflow.com/a/26928142
    for orientation in ExifTags.TAGS.keys():
        if ExifTags.TAGS[orientation]=='Orientation':
            break
    if hasattr(img, '_getexif'):
        exif = img._getexif()
        if exif is not None and orientation in exif:
            if exif[orientation] == 3:
                img=img.rotate(180, expand=True)
            elif exif[orientation] == 6:
                img=img.rotate(270, expand=True)
            elif exif[orientation] == 8:
                img=img.rotate(90, expand=True)

    img = img.convert("RGB")
    img0 = img
    scale = 256/(img.width*img.height)**0.5
    if scale < 1:
        img = img.resize((int(img.width*scale+0.5), int(img.height*scale+0.5)), Image.LANCZOS)
    print(img.width, img.height, filename)
    img.save(os.path.join('test_images', hash+'.png'))
    gt = img0.resize((img.width*3, img.height*3), Image.LANCZOS)
    gt.save(os.path.join('test_images', hash+'-gt.png'))
    bicubic = img.resize((img.width*3, img.height*3), Image.BICUBIC)
    bicubic.save(os.path.join('test_images', hash+'-bicubic.png'))


225 292 IMG_6286 2.jpg
222 295 PXL_20240329_161801625.RAW-01.COVER.jpg
296 222 20240301_131843.jpg
222 296 IMG-20240213-WA0000.jpg
319 206 1212939746330935396-image.png




296 222 53568265220_59468e4d3a_o.jpg
295 222 PXL_20240329_161646005.RAW-01.COVER.jpg
174 377 IMG_5365.PNG
316 207 d793bfe8-5a9f-48cc-9527-d890c28b99cb.jpg
286 229 IMG-20240205-WA0002.jpg
296 222 20231214_153842.jpg
222 295 PXL_20240329_161349020.RAW-01.COVER.jpg
222 295 PXL_20240329_161429285.RAW-01.COVER.jpg
293 223 image.png
296 222 20230904_163004.jpg
296 222 20240328_112527.jpg
222 296 IMG_2180 2.heic
222 296 20240130_205627.jpg
296 222 20230904_112951.jpg
222 296 IMG_1100 2.HEIC
317 207 Screenshot from 2024-03-27 15-43-05.png
295 222 PXL_20240329_161847168.RAW-01.COVER.jpg
222 296 IMG_1631 2.HEIC


In [4]:
def compute_W_and_B(batch1, batch2):
    # Reshape batches to (batch_size, num_channels, num_pixels)
    batch1 = batch1.view(batch1.size(0), batch1.size(1), -1)
    batch2 = batch2.view(batch2.size(0), batch2.size(1), -1)
    
    # Compute means of the batches
    mean_batch1 = torch.mean(batch1, dim=2, keepdim=True)
    mean_batch2 = torch.mean(batch2, dim=2, keepdim=True)
    
    # Compute centered batches
    centered_batch1 = batch1 - mean_batch1
    centered_batch2 = batch2 - mean_batch2
    
    # Compute covariance matrix
    xTy = torch.matmul(centered_batch1, centered_batch2.transpose(1, 2))
    xTx = torch.matmul(centered_batch1, centered_batch1.transpose(1, 2))
    
    # Compute weight matrix
    W = torch.matmul(torch.linalg.inv(xTx), xTy)
    W = W.transpose(1, 2)

    # Compute bias vector
    B = mean_batch2.squeeze() - torch.matmul(W, mean_batch1).squeeze()

    return W, B

def correct_color_shift(x, y):
    # W, B = compute_W_and_B(y, x)
    # y1 = torch.matmul(W, y.reshape(1, 3, -1)).view_as(y) + B.reshape(y.shape[0], y.shape[1], 1, 1)

    def get_ms_1(x):
        u = torch.mean(x, (2, 3)).reshape((len(x), 3, 1, 1))
        s = torch.std(x, (2, 3)).reshape((len(x), 3, 1, 1))
        return u, s
    ux, sx = get_ms_1(x)
    uy, sy = get_ms_1(y)
    y1 = (y-uy)*(sx/sy)+ux

    return torch.clamp(y1, 0.0, 1.0)


model = torch.load("final_models/model_8_64_10_12_51_g.pth").to(device)

for filename in os.listdir("test_images"):
    if '-' in filename:
        continue
    filename = os.path.join("test_images", filename)
    print(filename)
    img = Image.open(filename)
    x = np.array(img, dtype=np.float32) / 255.0
    x = np.transpose(x, (2, 0, 1))
    with torch.no_grad():
        x = torch.tensor(x, device=device).unsqueeze(0)
        y = model(x)
        x3 = F.interpolate(x, scale_factor=3)
        y = correct_color_shift(x3, y)
    y = y[0].cpu().numpy()
    y = np.transpose(y, (1, 2, 0))
    y = (y*255).astype(np.uint8)
    filename = filename[:filename.rfind('.')] + '-x3.png'
    Image.fromarray(y).save(filename)

test_images/92de6702244303a86415e73c226aad40.png
test_images/335c9330fb4039d0dad6eac436d41870.png
test_images/79e4d548c008d6ddd65936694bd39947.png
test_images/86fa6d32cd5c1c1bbc1edae7c0582736.png
test_images/705439935649698038a4a9f3965642b8.png
test_images/46bbb3abe03c2165f800a4c23a1d7979.png
test_images/dbce7fc2513c93ce70dff1331b5e88eb.png
test_images/d0ac30fa1232f2647ae66e391f63b75f.png
test_images/6a464ccb3298d7ca722edd0860a68115.png
test_images/2968097a7600848ac4d355f64252e87c.png
test_images/2d20003fe50beaf62366be76c2d8d4d7.png
test_images/5d80e7926fefbe3cd9e1076749d27c0d.png
test_images/ff1e886199fcfc5e2d780b954426030b.png
test_images/1561c24c73923e5d45f087824775d11c.png
test_images/b78d1af230e05ae4777bb2aef97af047.png
test_images/fcf7ad0d4786868c27b71c03322bd48e.png
test_images/9c13d6c16cf68e590b58188b94b9b7d1.png
test_images/1f7f2a0ffe106dcb848f26e366faccc7.png
test_images/0cdc33ecfb38841606fca6b434668bee.png
test_images/a2bb97a1c93d12945215774ce445dd6d.png
test_images/15a87487