In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline
import argparse
import os
import numpy as np
import torch
import torch.optim
from torch import nn
from torchvision import transforms
from model import UNet
from metrics import Metrics
import torch.functional as F
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from torch.utils.data import Dataset, DataLoader

In [None]:
# Run on the first CUDA GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')
print(DEVICE)

In [None]:
class DIP:
    """Deep Image Prior (DIP) super-resolution trainer for a single image.

    The model maps an (optionally noise-perturbed) input to a high-resolution
    image. The loss compares a bicubic downsample of that output with the
    observed low-resolution image. The ground-truth high-resolution image is
    used only for periodic evaluation and logging, never for the loss.
    """

    def __init__(
        self,
        low_res: torch.Tensor,
        high_res: torch.Tensor,
        input: torch.Tensor,
        reg_noise_std: float,
        num_iterations: int,
        criterion: nn.MSELoss,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
        metrics: Metrics,
        model: nn.Module,
        summary_writer: SummaryWriter,
    ):
        """Store the training configuration.

        Args:
            low_res: observed low-resolution image; its last two dimensions give
                the (H, W) the model output is downsampled to before the loss.
            high_res: ground-truth image without a batch dimension (one is added
                here so it matches the model output).
            input: model input tensor.
            reg_noise_std: std of fresh Gaussian noise added to the saved input
                every step; 0 disables the perturbation.
            num_iterations: number of optimisation steps run by ``train``.
            criterion: loss between the downsampled output and ``low_res``.
            optimizer: optimiser over the model parameters.
            device: device inputs, targets and ground truth are moved to.
            metrics: object providing calculate_mse/psnr/ssim/lpips.
            model: network being fitted.
            summary_writer: TensorBoard writer used by ``log``.
        """
        self.low_res = low_res
        # Kept for backward compatibility (square-image width).
        self.low_res_size = self.low_res.shape[-1]
        # (H, W) downsampling target, so non-square images also work.
        self.low_res_hw = tuple(self.low_res.shape[-2:])
        self.high_res = high_res.unsqueeze(0)
        self.input = input
        # Use the caller-supplied value so reg_noise_std=0 really disables perturbation.
        self.reg_noise_std = reg_noise_std
        self.num_iterations = num_iterations
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.metrics = metrics
        self.model = model
        self.summary_writer = summary_writer
        self.step = 0
        self.history = []
        # Final-iteration metrics, filled by closure(); stay None if no step runs.
        self.mse = None
        self.psnr = None
        self.ssim = None
        self.lpips = None

    def _should_evaluate(self) -> bool:
        """Evaluate every 10 steps up to step 100, every 100 steps after, and on the final step."""
        step = self.step
        return (
            (step <= 100 and step % 10 == 0)
            or step % 100 == 0
            or step == self.num_iterations - 1
        )

    def closure(self, input_saved: torch.Tensor, noise: torch.Tensor):
        """Run one forward/backward pass; periodically evaluate, log, plot and snapshot.

        Args:
            input_saved: unperturbed copy of the model input.
            noise: buffer that is refilled in place with N(0, 1) samples each step.
        """
        if self.reg_noise_std > 0:
            self.input = input_saved + (noise.normal_() * self.reg_noise_std)

        # No-ops once the tensors already live on self.device.
        self.input = self.input.to(self.device)
        self.low_res = self.low_res.to(self.device)

        output_hr = self.model(self.input)
        output_lr = downsample(output_hr, self.low_res_hw)

        loss = self.criterion(output_lr, self.low_res)
        loss.backward()

        if self._should_evaluate():
            mse, psnr, ssim, lpips = self.evaluate(output_hr)
            print(f"{self.step} Loss: {loss.item()} MSE: {mse} PSNR: {psnr} SSIM: {ssim} LPIPS: {lpips}")
            self.log(loss, mse, psnr, ssim, lpips, self.step)
            if self.step == self.num_iterations - 1:
                self.mse, self.psnr, self.ssim, self.lpips = mse, psnr, ssim, lpips

            # Drop the batch dim and go channel-last for matplotlib.
            output_hr_numpy = output_hr.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
            plt.imshow(output_hr_numpy)
            plt.show()
            self.history.append(output_hr_numpy)
        self.step += 1

    def train(self, input: torch.Tensor) -> tuple:
        """Run ``num_iterations`` optimisation steps starting from ``input``.

        Returns:
            (mse, psnr, ssim, lpips) measured at the final step, or all None when
            ``num_iterations`` is 0.
        """
        # Move once instead of copying host->device on every step.
        input_saved = input.detach().clone().to(self.device)
        noise = input.detach().clone().to(self.device)

        for _ in range(self.num_iterations):
            self.optimizer.zero_grad()
            self.closure(input_saved, noise)
            self.optimizer.step()

        return self.mse, self.psnr, self.ssim, self.lpips

    def evaluate(self, prediction):
        """Compare ``prediction`` with the ground truth; returns (mse, psnr, ssim, lpips)."""
        self.high_res = self.high_res.to(self.device)
        # Metrics are for reporting only: do not build an autograd graph for them.
        with torch.no_grad():
            prediction = prediction.detach().to(self.device)
            mse = self.metrics.calculate_mse(prediction, self.high_res)
            psnr = self.metrics.calculate_psnr(prediction, self.high_res)
            ssim = self.metrics.calculate_ssim(prediction, self.high_res)
            lpips = self.metrics.calculate_lpips(prediction, self.high_res)
        return mse, psnr, ssim, lpips

    def log(self, loss, mse, psnr, ssim, lpips, step):
        """Write each scalar under its own tag (series "train") at ``step``."""
        scalars = {"loss": loss, "mse": mse, "psnr": psnr, "ssim": ssim, "lpips": lpips}
        for tag, value in scalars.items():
            self.summary_writer.add_scalars(tag, {"train": float(value)}, step)

    def get_history(self) -> list:
        """Return the live list of channel-last numpy snapshots recorded during training.

        The same list object is returned, so a reference taken before ``train``
        keeps filling up as training proceeds.
        """
        return self.history

In [None]:
def downsample(image: torch.Tensor, size: (int, int)):
    """Resize ``image`` to ``size`` (height, width) using bicubic interpolation.

    This is called on the model output inside the DIP loss, so gradients are
    expected to flow through it.
    """
    bicubic = transforms.InterpolationMode.BICUBIC
    return transforms.Resize(size, interpolation=bicubic)(image)

In [None]:
def display_tensor(tensor: torch.Tensor):
    """Show a channel-first image tensor with matplotlib after min-max scaling to [0, 1].

    A constant image is shown as all zeros instead of the NaNs that 0 / 0 would give.
    """
    image = tensor.permute(1, 2, 0).detach().cpu().numpy()
    lo, hi = np.min(image), np.max(image)
    span = hi - lo
    image = (image - lo) / span if span > 0 else np.zeros_like(image)
    plt.imshow(image)

In [None]:
def min_max_norm(tensor: torch.Tensor) -> torch.Tensor:
    """Linearly rescale ``tensor`` so its global minimum maps to 0 and its maximum to 1.

    Min and max are taken over all elements, not per channel. A constant tensor
    (max == min) maps to zeros rather than the NaNs that 0 / 0 would produce.
    An empty tensor is returned as an unchanged copy.
    """
    if tensor.numel() == 0:
        return tensor.clone()
    lo = tensor.min()
    span = tensor.max() - lo
    shifted = tensor - lo
    if span == 0:
        # All elements equal the minimum, so the normalised value is 0 everywhere.
        # Dividing by 1 keeps the floating dtype that `shifted / span` would produce.
        return shifted / 1
    return shifted / span

In [None]:
def main(low_res: torch.Tensor, high_res: torch.Tensor, id: int, num_iterations: int, reg_noise_std: float, histories, log_id: str):
    """Fit a freshly initialised DIP model to one low-res / high-res image pair.

    Args:
        low_res: observed low-resolution image tensor.
        high_res: ground-truth high-resolution image without a batch dimension;
            also fixes the shape of the random model input.
        id: image index; used as the key into ``histories`` and as the log sub-directory.
        num_iterations: number of optimisation steps.
        reg_noise_std: std of the per-step input perturbation.
        histories: dict updated in place with ``histories[id] = <snapshot list>``.
        log_id: run identifier used in the TensorBoard log path.

    Returns:
        (mse, psnr, ssim, lpips) from the final iteration, as returned by DIP.train.
    """
    # Low-amplitude random input with a batch dimension; shape follows high_res.
    net_input = torch.randn_like(high_res).unsqueeze(0) * 0.1

    model = UNet(input_channels=3, output_channels=3).to(DEVICE)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    metrics = Metrics()

    log_dir = "./logs"
    os.makedirs(log_dir, exist_ok=True)
    log_path = f"{log_dir}/run_{log_id}/{id}"
    print(f"Writing logs to {log_path}")
    summary_writer = SummaryWriter(str(log_path), flush_secs=5)

    try:
        dip = DIP(
            low_res=low_res,
            high_res=high_res,
            input=net_input,
            reg_noise_std=reg_noise_std,
            num_iterations=num_iterations,
            criterion=criterion,
            optimizer=optimizer,
            device=DEVICE,
            metrics=metrics,
            model=model,
            summary_writer=summary_writer,
        )
        # The returned list is filled in place during train(), so registering it
        # first still exposes every snapshot to the caller.
        histories[id] = dip.get_history()
        return dip.train(net_input)
    finally:
        # Flush pending events and release the event file even if training fails.
        summary_writer.close()

In [None]:
def plot_history(history: [np.ndarray], rows: int = 4, cols: int = 5):
    """Show the first ``rows * cols`` snapshots of ``history`` on a grid, row-major.

    Grid cells beyond ``len(history)`` are left blank instead of raising
    IndexError. Snapshots beyond ``rows * cols`` are not shown.

    Args:
        history: sequence of images accepted by ``plt.imshow``.
        rows: number of grid rows (default 4).
        cols: number of grid columns (default 5).
    """
    # squeeze=False always yields a 2-D axes array, even for a single row or column.
    fig, axes = plt.subplots(rows, cols, squeeze=False)
    for row in range(rows):
        for col in range(cols):
            ax = axes[row][col]
            idx = row * cols + col
            if idx < len(history):
                ax.imshow(history[idx])
            ax.axis("off")
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
    

In [None]:
# Experiment configuration: noise level and downsampling factor select the input set.
NOISE_LEVEL = 0
DOWN_FACTOR = 8
LOOPS = 1  # how many times the whole image set is processed

# Input directories: ground-truth high-res images, and the low-res set for this config.
hr_input_dir = "./input/original_high"
lr_input_dir = f"./input/noise_{NOISE_LEVEL}_down_{DOWN_FACTOR}"

In [None]:
# Run configuration (previously inline magic numbers).
NUM_ITERATIONS = 1000
REG_NOISE_STD = 0.05
SEED = 0


def load_image_tensor(path: str) -> torch.Tensor:
    """Load an image file as an RGB, channel-first tensor min-max normalised to [0, 1]."""
    # The context manager closes the file handle; convert() also forces 3 channels.
    with Image.open(path) as img:
        tensor = transforms.ToTensor()(img.convert("RGB"))
    return min_max_norm(tensor)


torch.manual_seed(SEED)  # reproducible random inputs, weight init and input noise

image_indices = [(0,1), (1,0), (1,3), (3,7)]
overall_histories = {}
run_timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
total_mse = 0
total_psnr = 0
total_ssim = 0
total_lpips = 0
for j in range(LOOPS):
    overall_histories[j] = {}
    histories = overall_histories[j]
    for i, indices in enumerate(image_indices):
        tag = f"{indices[0]}-{indices[1]}"
        low_res = load_image_tensor(f"{lr_input_dir}/original_low-{tag}.jpg")
        high_res = load_image_tensor(f"{hr_input_dir}/original_high-{tag}.jpg")

        # NOTE(review): the low-res file is shrunk again by DOWN_FACTOR here,
        # presumably because it is stored at full size — confirm against the data.
        height, width = low_res.shape[-2:]
        assert height % DOWN_FACTOR == 0 and width % DOWN_FACTOR == 0, \
            f"image {tag}: size {height}x{width} not divisible by DOWN_FACTOR={DOWN_FACTOR}"
        low_res = downsample(low_res, (height // DOWN_FACTOR, width // DOWN_FACTOR))

        mse, psnr, ssim, lpips = main(
            low_res=low_res,
            high_res=high_res,
            id=i,
            num_iterations=NUM_ITERATIONS,
            reg_noise_std=REG_NOISE_STD,
            histories=histories,
            log_id=run_timestamp,
        )
        total_mse += mse
        total_psnr += psnr
        total_ssim += ssim
        total_lpips += lpips

num_runs = len(image_indices) * LOOPS
avg_mse = total_mse / num_runs
avg_psnr = total_psnr / num_runs
avg_ssim = total_ssim / num_runs
avg_lpips = total_lpips / num_runs
print(f"Test Evaluation: Avg MSE: {avg_mse} Avg PSNR {avg_psnr} Avg SSIM {avg_ssim} Avg LPIPS {avg_lpips}")

In [None]:
# Training snapshots for the first image of the first loop.
first_loop_histories = overall_histories[0]
history = first_loop_histories[0]
plot_history(history)

In [None]:
output_dir = f"./output/noise_{NOISE_LEVEL}_down_{DOWN_FACTOR}"
os.makedirs(output_dir, exist_ok = True)


def save_images(histories=None, save: bool = False):
    """Show the final snapshot of every (loop, image) run on a LOOPS x len(image_indices) grid.

    Args:
        histories: unused; accepted so existing calls keep working. The grid is
            built from the notebook-level ``overall_histories``.
        save: if True, also write each final snapshot as an 8-bit JPEG into
            ``output_dir``. Files are named ``output_{i}.jpg``, or
            ``output_{j}_{i}.jpg`` when LOOPS > 1 so loops do not overwrite each other.
    """
    n_images = len(image_indices)
    # squeeze=False keeps `axes` 2-D even when LOOPS == 1 (otherwise axes[j][i] fails).
    fig, axes = plt.subplots(LOOPS, n_images, squeeze=False)
    for j in range(LOOPS):
        for i in range(n_images):
            image = overall_histories[j][i][-1]
            ax = axes[j][i]
            ax.imshow(image)
            ax.axis("off")

            if save:
                name = f"output_{i}.jpg" if LOOPS == 1 else f"output_{j}_{i}.jpg"
                lo, hi = np.min(image), np.max(image)
                span = hi - lo
                # Guard 0 / 0 for a constant image.
                scaled = (image - lo) / span if span > 0 else np.zeros_like(image)
                Image.fromarray((255 * scaled).astype('uint8')).save(f"{output_dir}/{name}")
    plt.show()
    


In [None]:
save_images(histories=histories)