In [1]:
import torch
from tqdm import tqdm
from network.conv_node import NODE
from misc import *
import os
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from torchmetrics.multimodal import CLIPImageQualityAssessment

Model and Dataset (LOL train 485 images)

In [2]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the NODE network in evaluation mode and place it on `device`.
# NOTE(review): (3, 256, 256) is presumably (channels, height, width) and 32 a
# filter/width setting — confirm against network.conv_node.NODE.
model = NODE(device, (3, 256, 256), 32, augment_dim=0, time_dependent=True, adjoint=True)
model.eval()
model.to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# saved from a GPU still loads when this notebook runs on a CPU-only machine.
# NOTE(review): strict=False ignores missing/unexpected keys without raising —
# confirm that a partial load is really intended here.
model.load_state_dict(
    torch.load('pth/universal.pth', map_location=device, weights_only=True),
    strict=False,
)

# Dataset root; file names are taken from the `low` sub-folder and sorted so
# that indices are stable between runs.
file_path = Path('/data/soom/lol_dataset/our485')
img_labels = sorted(os.listdir(file_path / 'low'))

def load_image(idx):
    """Load the idx-th low-light image together with its reference image.

    The file name is looked up in ``img_labels`` and the same name is read from
    the ``low`` and ``high`` sub-folders of ``file_path``. Both images are
    loaded through ``image_tensor`` with ``size=(256, 256)``.

    Args:
        idx: Position of the image in ``img_labels``.

    Returns:
        Tuple ``(low_quality, ground_truth)``, each moved to ``device``.
    """
    name = img_labels[idx]
    target_size = (256, 256)

    low = image_tensor(file_path / 'low' / name, size=target_size)
    high = image_tensor(file_path / 'high' / name, size=target_size)
    return low.to(device), high.to(device)


Find the best integration time T for each image by PSNR

In [4]:
# Candidate integration end-times: 30 evenly spaced values from 2 to 5.
T_values = np.linspace(2, 5, 30)

# One [best_T, best_psnr] entry per processed image, in `img_labels` order.
results = []
with torch.no_grad():
    for idx in tqdm(range(len(img_labels))):
        lq_img, gt_img = load_image(idx)

        # Start from -inf so the first candidate always becomes the running
        # best; starting from 0.0 would record a fake PSNR of 0 whenever every
        # candidate scores at or below zero.
        high_psnr = float('-inf')
        best_T = float(T_values[0])
        for T in tqdm(T_values, leave=False):
            # Integrate from 0 to T. The tensor is created on `device` instead
            # of calling .cuda(), so the loop also runs on a CPU-only machine.
            integration_time = torch.tensor([0.0, float(T)], dtype=torch.float32, device=device)
            pred = model(lq_img, integration_time, inference=True)['output'][0]

            _psnr = calculate_psnr(pred, gt_img).item()
            if high_psnr < _psnr:
                high_psnr = _psnr
                best_T = float(T)
        results.append([best_T, high_psnr])

  9%|▉         | 46/485 [37:57<6:02:18, 49.52s/it]


KeyboardInterrupt: 

Adapt the weights $\alpha$, $\beta$, $\gamma$ so that the weighted IQA score approximates the score obtained with the optimal $T$

In [None]:
# CLIP-IQA prompts to score, with one starting weight per prompt.
prompts = ['brightness', 'noisiness', 'quality']
weights = [1.0] * 3
clip_iqa = CLIPImageQualityAssessment(prompts=prompts).to(device)

# Step sizes used by adjust_clip_weights: `asc` is added to a single weight and
# `desc` is then subtracted from every weight. The net change is +learning_rate
# for that one weight and -learning_rate / (len(prompts) - 1) for each of the
# others, so the sum of the weights does not change.
learning_rate = 0.1
asc = learning_rate * len(prompts) / (len(prompts) - 1)
desc = learning_rate / (len(prompts) - 1)

def adjust_clip_weights(pred, weights):
    """Shift weight towards the prompt with the highest CLIP-IQA score.

    ``pred`` is scored with ``clip_iqa`` after a leading dimension is added with
    ``unsqueeze(0)``. The weight of the highest-scoring prompt is raised by
    ``asc`` and then every weight is lowered by ``desc``.

    Args:
        pred: Image tensor to score.
            NOTE(review): presumably (C, H, W) — confirm against the model output.
        weights: Sequence (list or array) with one weight per entry of ``prompts``.
            It is not modified.

    Returns:
        numpy.ndarray of floats holding the updated weights.
    """
    score = clip_iqa(pred.unsqueeze(0))
    scores = [score[prompt].item() for prompt in prompts]
    max_idx = int(np.argmax(scores))

    # `list -= float` raises TypeError, so the update is done on a float array.
    # np.array copies, which keeps the caller's sequence untouched.
    updated = np.array(weights, dtype=float)
    updated[max_idx] += asc
    updated -= desc

    return updated

with torch.no_grad():
    # Walk over the images that actually have a best-T entry. `results` can be
    # shorter than `img_labels` when the search cell was interrupted, and
    # indexing past its end would raise IndexError.
    for idx in tqdm(range(len(results))):
        lq_img, gt_img = load_image(idx)

        # Re-run the model with the best integration end-time found for this image.
        T = results[idx][0]
        # Created on `device` instead of calling .cuda(), so this also runs on CPU.
        integration_time = torch.tensor([0.0, float(T)], dtype=torch.float32, device=device)
        pred = model(lq_img, integration_time, inference=True)['output'][0]

        weights = adjust_clip_weights(pred, weights)
