In [1]:
import skimage
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import time
import argparse
import cv2
from scipy import io
from tqdm.notebook import tqdm

import torch
from torch import nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim.lr_scheduler as lr_scheduler
from pytorch_msssim import ssim

from modules import utils
from modules.models import INR

In [12]:
def _str2bool(value):
    """Parse a command-line string into a real boolean.

    `type=bool` must not be used with argparse: bool('False') is True because
    any non-empty string is truthy, so the flag could never be switched off.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='INCODE')

# Shared Parameters
parser.add_argument('--input',type=str, default='./incode_data/Image/0882.png', help='Input image path')
parser.add_argument('--inr_model',type=str, default='parac', help='[gauss, mfn, relu, siren, wire, wire2d, ffn, incode, parac]')
parser.add_argument('--lr',type=float, default=9e-4, help='Learning rate')
parser.add_argument('--using_schedular', type=_str2bool, default=True, help='Whether to use schedular')
parser.add_argument('--scheduler_b', type=float, default=0.1, help='Learning rate scheduler')
parser.add_argument('--maxpoints', type=int, default=16*16, help='Batch size')
parser.add_argument('--niters', type=int, default=501, help='Number if iterations')
parser.add_argument('--steps_til_summary', type=int, default=100, help='Number of steps till summary visualization')
parser.add_argument('--upscale_factor', type=int, default=4, help='Upscale factor for super-resolution (e.g., 4x larger output)')
parser.add_argument('--eval_epoch', type=int, default=400, help='HR evaluation epoch')

# INCODE Parameters (weights of the regularization terms on the a/b/c/d coefficients)
parser.add_argument('--a_coef',type=float, default=0.1993, help='a coeficient')
parser.add_argument('--b_coef',type=float, default=0.0196, help='b coeficient')
parser.add_argument('--c_coef',type=float, default=0.0588, help='c coeficient')
parser.add_argument('--d_coef',type=float, default=0.0269, help='d coeficient')


# Parse an empty argv so the notebook always runs with the defaults above
# (the kernel's own command line must not be parsed).
args = parser.parse_args(args=[])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): the auto-detected device is deliberately overridden here, so
# everything runs on CPU even when CUDA is available. Remove the next line to
# train on the GPU.
device = torch.device("cpu")

## Loading Data

In [13]:
# High-resolution ground truth: read the image as float32 and normalize it.
hr_pixels = plt.imread(args.input).astype(np.float32)
im_hr = utils.normalize(hr_pixels, True)

# Low-resolution training target: shrink both axes by 1/upscale_factor
# using area interpolation.
inv_scale = 1/args.upscale_factor
im_lr = cv2.resize(im_hr, None, fx=inv_scale, fy=inv_scale, interpolation=cv2.INTER_AREA)

# Spatial sizes of both images (third axis is the channel axis).
(H_hr, W_hr, _), (H_lr, W_lr, _) = im_hr.shape, im_lr.shape

## Defining Model

### Defining desired Positional Encoding

In [14]:
# Three alternative positional-encoding configurations; one of them is handed
# to the model constructor as `pos_encode_configs`.

# Frequency encoding, sized from the longer side of the LR image
pos_encode_freq = dict(type='frequency', use_nyquist=True, mapping_input=int(max(H_lr, W_lr)))

# Gaussian encoding
pos_encode_gaus = dict(type='gaussian', scale_B=10, mapping_input=256)

# No encoding
pos_encode_no = dict(type=None)

### Model Configurations

In [15]:
### Harmonizer Configurations
# The LR image is passed as 'GT' with a leading batch axis and the channel
# axis moved to position 1: (H, W, C) -> (1, C, H, W).
lr_gt_nchw = torch.tensor(im_lr).to(device)[None, ...].permute(0, 3, 1, 2)

MLP_configs = dict(
    task='image',
    model='resnet34',
    truncated_layer=5,
    in_channels=64,
    hidden_channels=[64, 32, 4],
    mlp_bias=0.3120,
    activation_layer=nn.SiLU,
    GT=lr_gt_nchw,
)

### Model Configurations
# 2-D coordinates in, 3 channels out; no positional encoding is applied.
inr_kwargs = dict(
    in_features=2,
    out_features=3,
    hidden_features=256,
    hidden_layers=3,
    first_omega_0=30.0,
    hidden_omega_0=30.0,
    pos_encode_configs=pos_encode_no,
    MLP_configs=MLP_configs,
)
model = INR(args.inr_model).run(**inr_kwargs).to(device)

In [16]:
# ### Model Configurations for parac
# model = INR(args.inr_model).run(in_features=3,
#                                 out_features=3, 
#                                 hidden_features=256,
#                                 hidden_layers=3,
#                                 first_omega_0=30.0,
#                                 hidden_omega_0=30.0
#                                ).to(device)

## Training Code

In [17]:
# Optimizer setup
if args.inr_model == 'wire':
    # Scale the learning rate by the fraction of LR pixels seen per batch.
    # The original expression used undefined names `H` and `W`, which raised a
    # NameError for the 'wire' model; training runs on the LR grid, so the
    # pixel count is H_lr * W_lr.
    # NOTE(review): this mutates args.lr, so re-running this cell without
    # re-running the argument cell shrinks the rate again.
    args.lr = args.lr * min(1, args.maxpoints / (H_lr * W_lr))
optim = torch.optim.Adam(lr=args.lr, params=model.parameters())
# Exponential decay: the LR multiplier goes from 1 to scheduler_b over niters
# scheduler steps and is clamped at scheduler_b afterwards.
scheduler = lr_scheduler.LambdaLR(optim, lambda x: args.scheduler_b ** min(x / args.niters, 1))

# Per-epoch metric histories (HR lists are only filled once HR evaluation starts)
psnr_values_lr = []
psnr_values_hr = []
ssim_values_hr = []
mse_array = torch.zeros(args.niters, device=device)

# Initialize best loss value as positive infinity
best_loss = torch.tensor(float('inf'))

# Generate coordinate grids for the LR (training) and HR (evaluation)
# resolutions, each with a leading batch axis
coords_lr = utils.get_coords(H_lr, W_lr, dim=2)[None, ...]
coords_hr = utils.get_coords(H_hr, W_hr, dim=2)[None, ...]

# Flatten both images to (1, num_pixels, 3) tensors on the target device
gt_lr = torch.tensor(im_lr).reshape(H_lr * W_lr, 3)[None, ...].to(device)
gt_hr = torch.tensor(im_hr).reshape(H_hr * W_hr, 3)[None, ...].to(device)

# Buffers that collect the model's reconstruction, same shape as the targets
rec_lr = torch.zeros_like(gt_lr)
rec_hr = torch.zeros_like(gt_hr)

In [18]:
# Main training loop. Each step is one full pass over the LR pixels in random
# mini-batches of args.maxpoints coordinates. After args.eval_epoch the model
# is additionally queried on the HR grid every step to track PSNR/SSIM and to
# keep the best HR reconstruction seen so far.
for step in tqdm(range(args.niters)):
    # Randomize the order of data points for each iteration
    indices = torch.randperm(H_lr*W_lr)

    # Process data points in batches
    for b_idx in range(0, H_lr*W_lr, args.maxpoints):
        b_indices = indices[b_idx:min(H_lr*W_lr, b_idx+args.maxpoints)]
        b_coords = coords_lr[:, b_indices, ...].to(device)
        b_indices = b_indices.to(device)
        
        # Calculate model output ('incode' also returns its coefficients)
        if args.inr_model == 'incode':
            model_output, coef = model(b_coords)  
        else:
            model_output = model(b_coords) 

        # Update the reconstructed data (outside autograd). The buffer is filled
        # batch by batch while the weights keep changing within the step.
        with torch.no_grad():
            rec_lr[:, b_indices, :] = model_output

        # Calculate the output loss (MSE against the LR target for this batch)
        output_loss = ((model_output - gt_lr[:, b_indices, :])**2).mean()
        
        if args.inr_model == 'incode':
            # Calculate regularization loss for 'incode' model:
            # relu(-x) is non-zero only for negative x, so each term penalizes
            # a negative coefficient, weighted by the matching args.*_coef.
            a_coef, b_coef, c_coef, d_coef = coef[0]  
            reg_loss = args.a_coef * torch.relu(-a_coef) + \
                       args.b_coef * torch.relu(-b_coef) + \
                       args.c_coef * torch.relu(-c_coef) + \
                       args.d_coef * torch.relu(-d_coef)

            # Total loss for 'incode' model
            loss = output_loss + reg_loss 
        else: 
            # Total loss for other models
            loss = output_loss

        # Perform backpropagation and update model parameters
        optim.zero_grad()
        loss.backward()
        optim.step()
    
    
    # Calculate and log mean squared error (MSE) and PSNR
    with torch.no_grad():
        # The value printed below as "Total Loss" is this LR reconstruction MSE
        mse_array[step] = ((gt_lr - rec_lr)**2).mean().item()
        psnr_lr = -10*torch.log10(mse_array[step])
        psnr_values_lr.append(psnr_lr.item())
        
        #### HR Evaluation (every step strictly after args.eval_epoch)
        if step > args.eval_epoch:
            # Query the model on every HR coordinate, in batches.
            # NOTE(review): the random order is not needed for evaluation, but
            # it does consume the global torch RNG.
            indices_hr = torch.randperm(H_hr*W_hr)
            for b_idx in range(0, H_hr*W_hr, args.maxpoints):
                b_indices_hr = indices_hr[b_idx:min(H_hr*W_hr, b_idx+args.maxpoints)]
                b_coords_hr = coords_hr[:, b_indices_hr, ...].to(device)
                b_indices_hr = b_indices_hr.to(device)

                if args.inr_model == 'incode':
                    model_eval, _ = model(b_coords_hr)  
                else:
                    model_eval = model(b_coords_hr) 
                    
                rec_hr[:, b_indices_hr, :] = model_eval
            
            # HR PSNR is computed on the raw model output ...
            loss_hr = ((gt_hr - rec_hr)**2).mean()
            psnr_hr = -10*torch.log10(loss_hr)
            psnr_values_hr.append(psnr_hr.item())
            # ... whereas hr_pred is min-max rescaled to [0, 1]; this rescaled
            # image is what gets displayed and what SSIM is computed on below.
            hr_pred = rec_hr[0, ...].reshape(H_hr, W_hr, 3).detach().cpu().numpy()
            hr_pred = (hr_pred - hr_pred.min()) / (hr_pred.max() - hr_pred.min())

            # Check if the current iteration's HR image is the best so far
            # (the first evaluated step, eval_epoch+1, is always accepted)
            if (loss_hr < best_loss) or (step == args.eval_epoch+1):
                best_loss = loss_hr
                best_img_hr = hr_pred
                best_img_lr = rec_lr[0, ...].reshape(H_lr, W_lr, 3).detach().cpu().numpy()
                best_img_lr = (best_img_lr - best_img_lr.min()) / (best_img_lr.max() - best_img_lr.min())
                
                ### Plot ground truth next to the current best reconstructions
                fig, axes = plt.subplots(1, 4, figsize=(9, 9))
                subplot_info = [
                    {'title': 'GT HR', 'image': im_hr},
                    {'title': 'HR Image', 'image': best_img_hr},
                    {'title': 'GT LR', 'image': im_lr},
                    {'title': 'LR Image', 'image': best_img_lr}
                ]

                for ax, info in zip(axes, subplot_info):
                    ax.set_title(info['title'])
                    # NOTE(review): these arrays are 3-channel, so cmap='gray'
                    # presumably has no effect here; confirm before relying on it
                    ax.imshow(info['image'], cmap='gray')
                    ax.axis('off')
                plt.show()

            
            # SSIM (single-scale `ssim` from pytorch_msssim, despite the
            # variable name); inputs are converted to (1, C, H, W) CPU tensors
            ms_ssim_val = ssim(torch.tensor(im_hr[None,...]).permute(0, 3, 1, 2),
                                torch.tensor(hr_pred[None, ...]).permute(0, 3, 1, 2),
                                data_range=1, size_average=False)
            ssim_values_hr.append(ms_ssim_val[0].item())
            
            # Once HR evaluation is active, print LR and HR metrics on every step
            print("Epoch: {} | Total Loss: {:.5f} | PSNR LR: {:.4f} | PSNR HR: {:.4f} | SSIM: {:.4f}".format(step, 
                                                                                                  mse_array[step].item(),
                                                                                                  psnr_lr.item(),
                                                                                                  psnr_hr.item(),
                                                                                                  ms_ssim_val[0].item())) 
            
        # Before HR evaluation starts, print LR metrics every
        # args.steps_til_summary steps only
        if (step % args.steps_til_summary == 0) and step <= args.eval_epoch:
            print("Epoch: {} | Total Loss: {:.5f} | PSNR LR: {:.4f}".format(step, 
                                                                             mse_array[step].item(),
                                                                             psnr_lr.item())) 
    
    # Adjust learning rate using a scheduler if applicable (once per step)
    if args.using_schedular:
        scheduler.step()

        
# Print maximum PSNR achieved during training.
# psnr_values_hr stays empty when no HR evaluation ran (args.niters is not
# larger than args.eval_epoch + 1); `default` makes the summary print nan in
# that case instead of raising ValueError on max() of an empty list.
print('--------------------')
print('Max PSNR LR:', max(psnr_values_lr, default=float('nan')))
print('Max PSNR HR:', max(psnr_values_hr, default=float('nan')))
print('--------------------')

  0%|          | 0/501 [00:00<?, ?it/s]

Epoch: 0 | Total Loss: 0.07180 | PSNR LR: 11.4389


KeyboardInterrupt: 

# Convergence Rate

In [None]:
## PSNR LR vs. #Epochs
## PSNR HR vs. #Epochs
## SSIM vs. #Epochs


# Define the font settings.
# 'family' is the Text property for the typeface; it matches the key used by
# the other font dicts in this cell (the previous 'font' key is not accepted
# by every matplotlib version).
font = {'family': 'Times New Roman', 'size': 12}
axfont = {'family': 'Times New Roman', 'weight': 'regular', 'size': 10}
title_font = {'family': 'Times New Roman', 'size': 12, 'weight': 'bold'}

# HR metrics are only recorded for steps strictly after args.eval_epoch, so
# their first sample belongs to this epoch rather than to epoch 0.
hr_first_epoch = args.eval_epoch + 1

# One entry per subplot: (values, epoch of first value, y label, title, line kwargs)
curve_specs = [
    (psnr_values_lr, 0, 'PSNR (dB)', 'PSNR LR vs. #Epochs', {}),
    (psnr_values_hr, hr_first_epoch, 'PSNR (dB)', 'PSNR HR vs. #Epochs', {'color': 'black'}),
    (ssim_values_hr, hr_first_epoch, 'SSIM', 'SSIM vs. #Epochs', {'color': 'red'}),
]

model_label = f"{(args.inr_model).upper()}"

# Create a figure with 3 subplots
fig, axes = plt.subplots(1, 3, figsize=(12, 3))

for ax, (values, first_epoch, ylabel, title, line_kwargs) in zip(axes, curve_specs):
    # The last recorded value is left out of every curve, as before.
    # NOTE(review): reason for dropping it is not visible here; confirm.
    shown = values[:-1]
    epochs = first_epoch + np.arange(len(shown))
    ax.plot(epochs, shown, label=model_label, **line_kwargs)
    ax.set_xlabel('# Epochs', fontdict=font)
    ax.set_ylabel(ylabel, fontdict=font)
    ax.set_title(title, fontdict=title_font)
    ax.legend()
    ax.grid(True, color='lightgray')

# Adjust spacing between subplots
plt.tight_layout()

# Show the plot
plt.show()