# Utils

> Utility functions


In [36]:
#| default_exp utils

In [37]:
#| hide
from nbdev.showdoc import *

In [38]:
#| export

import os
import cv2
from math import log10, exp
import random
import numpy as np

import torch 
import torch.nn.functional as F

from torchmetrics.functional.image import structural_similarity_index_measure as structural_similarity
from torchmetrics.functional.image import peak_signal_noise_ratio

In [39]:
#| export
def attributesFromDict(d):
    """Copy the entries of ``d`` onto the object stored under its ``'self'`` key.

    Meant to be called as ``attributesFromDict(locals())`` at the top of an
    ``__init__`` so every constructor argument becomes an instance attribute.
    The ``'self'`` entry is removed from ``d`` as a side effect.
    """
    target = d.pop('self')
    for attr_name in d:
        setattr(target, attr_name, d[attr_name])

### Compute attribute index

In [40]:
#| export

class compute_index():
    """Collapse a combination of attribute values into one flat index.

    ``codes`` maps each attribute name to a 1-D tensor of the values it can
    take. The attributes are treated as the digits of a mixed-radix number
    (in the iteration order of ``codes``), each digit being the position of
    the observed value inside that attribute's tensor.
    """
    def __init__(self, codes, device='cpu') -> None:
        self.codes = codes
        self.device = device

    def _compute_index(self, b, **kwargs):
        # One running index per sample of the batch (b samples).
        flat_index = torch.zeros([b], device=self.device, dtype=torch.float32)
        for attribute, levels in self.codes.items():
            observed = kwargs[attribute]
            # Make room for the next "digit" before adding it.
            flat_index = flat_index * len(levels)
            for position, level in enumerate(levels):
                flat_index += torch.where(observed == level, position, 0.0)

        return flat_index

    def __call__(self, b, **kwargs):
        return self._compute_index(b, **kwargs)

In [41]:
device = 'cpu'

# Every attribute together with the values it is allowed to take.
codes = {
    name: torch.tensor(levels, dtype=torch.float32, device=device)
    for name, levels in [
        ('exposure-time', [10, 50, 100]),
        ('optical-setup', [0, 1]),
        ('camera', [0, 1]),
    ]
}

# A single sample (batch size 1): one observed value per attribute.
kwargs = {
    name: torch.tensor([observed], dtype=torch.float32, device=device)
    for name, observed in [
        ('exposure-time', 100),
        ('optical-setup', 1),
        ('camera', 0),
    ]
}

fn = compute_index(codes)

print('index: ', fn(1, **kwargs))


index:  tensor([10.])


### One-Hot encoding

In [42]:
#| export

class compute_one_hot():
    """Encode categorical attributes as one concatenated one-hot vector.

    Args:
        codes: mapping from attribute name to a 1-D tensor listing the values
            that attribute can take. The one-hot block of an attribute has
            one column per listed value; blocks are concatenated in the
            iteration order of ``codes``.
        device: device on which the per-attribute index tensors are created.
        verbose: when True (the default), print the one-hot block of every
            attribute on each call; pass False to silence it.
    """
    def __init__(self, codes, device='cpu', verbose=True) -> None:
        self.codes = codes
        self.device = device
        self.verbose = verbose

    def _compute_one_hot(self, b, **kwargs):
        """Return a float32 tensor of shape (b, total number of codes)."""
        blocks = []
        for key, value in self.codes.items():
            # Position of the observed value inside `value`; stays 0 when the
            # observed value matches none of the listed codes.
            idx = torch.zeros([b], device=self.device, dtype=torch.float32)
            for i, v in enumerate(value):
                idx += torch.where(kwargs[key] == v, float(i), 0.0)
            idx_one_hot = F.one_hot(idx.to(torch.int64), num_classes=value.shape[0]).to(torch.float32)
            if self.verbose:
                print(key, ': ', idx_one_hot)
            blocks.append(idx_one_hot)

        if not blocks:
            # No attributes configured: an empty embedding per sample.
            return torch.zeros((b, 0), device=self.device, dtype=torch.float32)
        # Single concatenation of tensors that all live on the same device,
        # instead of growing a CPU-allocated empty tensor block by block.
        return torch.cat(blocks, dim=1)

    def __call__(self, b, **kwargs):
        return self._compute_one_hot(b, **kwargs)

In [43]:
# One-hot encode the same sample, reusing the `codes` and `kwargs` defined in
# the attribute-index example above.
fn1hot = compute_one_hot(codes)

print('one hot encoding: ', fn1hot(1, **kwargs))

exposure-time :  tensor([[0., 0., 1.]])
optical-setup :  tensor([[0., 1.]])
camera :  tensor([[1., 0.]])
one hot encoding:  tensor([[0., 0., 1., 0., 1., 1., 0.]])


### Normal Distribution

In [44]:
#| export

import math
import torch
from torch import nn


In [45]:
#| export

class StandardNormal(nn.Module):
    """Isotropic Gaussian distribution: zero mean, identity covariance."""

    def __init__(self):
        super().__init__()
        # Dummy buffer whose only job is to follow the module across
        # .to()/.cuda() calls so `sample` knows which device/dtype to use.
        self.register_buffer('buffer', torch.zeros(1))

    def log_prob(self, x):
        """Log-density of ``x``, summed over every non-batch dimension."""
        # Per element: log N(x; 0, 1) = -0.5 * log(2*pi) - 0.5 * x^2
        # https://www.statlect.com/fundamentals-of-statistics/normal-distribution-maximum-likelihood
        normalizer = -0.5 * math.log(2 * math.pi)
        elementwise = normalizer + (-0.5 * x**2)
        return sum_except_batch(elementwise)

    def sample(self, shape):
        """Draw standard-normal noise of the given shape on the module's device."""
        anchor = self.buffer
        return torch.randn(*shape, device=anchor.device, dtype=anchor.dtype)

def sum_except_batch(x, num_dims=1):
    '''
    Collapse every non-batch dimension of x by summation.
    Args:
        x: Tensor, shape (batch_size, ...)
        num_dims: int, how many leading dimensions are batch dims (default=1)
    Returns:
        x_sum: Tensor shaped like the leading num_dims dims of x,
               i.e. (batch_size,) for the default
    '''
    batch_shape = x.shape[:num_dims]
    # Flatten everything after the batch dims into one axis, then reduce it.
    flattened = x.reshape(*batch_shape, -1)
    return flattened.sum(dim=-1)

## Base Utilities

In [46]:
#| export

def np2tensor(n:np.array):
    '''
    Convert an image-like numpy array into a channel-first torch Tensor.
    2-D (h,w)   -> (1,h,w)
    3-D (h,w,c) -> (c,h,w), channel order reversed (BGR <-> RGB)
    The dtype of the array is kept; any other rank raises RuntimeError.
    '''
    rank = n.ndim
    if rank == 2:
        # gray: only a leading channel axis is added
        chw = n[np.newaxis, :, :]
    elif rank == 3:
        # colour: reverse the channel axis, then move it to the front
        chw = n[:, :, ::-1].transpose(2, 0, 1)
    else:
        raise RuntimeError('wrong numpy dimensions : %s'%(n.shape,))
    # from_numpy cannot wrap negative strides, hence the contiguous copy.
    return torch.from_numpy(np.ascontiguousarray(chw))
    


In [47]:
# np.random.rand yields float64; np2tensor does no dtype conversion, so the
# resulting tensor must be a double tensor.
a = np.random.rand(4, 4)
assert np2tensor(a).type() == 'torch.DoubleTensor'

In [48]:
# Integer input keeps an integer dtype as well.
# NOTE(review): 'torch.LongTensor' assumes numpy's default integer is 64-bit
# on this platform — confirm if the notebook is ever run on Windows.
a = np.random.randint(0, high=255, size=(4,4))
assert np2tensor(a).type() == 'torch.LongTensor'

In [49]:
#| export

def np2tensor_multi(n:np.array):
    '''
    transform a stack of numpy images to one float32 torch Tensor
    Each element of n is converted with np2tensor (after a cast to float32)
    and the results are stacked along a new leading dimension.
    Args:
        n : sequence / array of images, iterated along its first axis
    Returns:
        Tensor whose first dimension has len(n) entries
    Raises:
        IndexError : if n is empty
    '''
    if len(n) == 0:
        raise IndexError('np2tensor_multi expects at least one image')
    # One stack call instead of growing the result with torch.cat inside the
    # loop, which re-copies everything accumulated so far on each iteration.
    return torch.stack([np2tensor(mat.astype(np.float32)).float() for mat in n], dim=0)


In [50]:
#| export

def tensor2np(t:torch.Tensor):
    '''
    transform torch Tensor to numpy having opencv image form.
    RGB -> BGR (the channel axis is reversed)
    (h,w)     -> (h,w)      gray image, returned as is
    (c,h,w)   -> (h,w,c)
    (b,c,h,w) -> (b,h,w,c)
    The tensor is detached and moved to the cpu first.
    Raises RuntimeError for any other number of dimensions.
    '''
    t = t.cpu().detach()
    rank = len(t.shape)

    # gray: there is no channel axis to move or reverse
    # (a 3-index permute on a 2-D tensor would always raise)
    if rank == 2:
        return t.numpy()
    # RGB -> BGR
    if rank == 3:
        return np.flip(t.permute(1,2,0).numpy(), axis=2)
    # image batch
    if rank == 4:
        return np.flip(t.permute(0,2,3,1).numpy(), axis=3)
    raise RuntimeError('wrong tensor dimensions : %s'%(t.shape,))


In [51]:
#| export

def imwrite_tensor(t, name='test.png'):
    '''
    Save tensor t as an image file named `name` in the current directory.
    The tensor is converted to opencv layout with tensor2np before writing.
    '''
    target = './%s' % name
    image = tensor2np(t.cpu())
    cv2.imwrite(target, image)

def imread_tensor(name='test'):
    '''
    Read the image file `name` from the current directory as a torch Tensor
    (converted with np2tensor).
    Raises:
        OSError : if the file cannot be read
    '''
    path = './%s'%name
    img = cv2.imread(path)
    # cv2.imread reports failure by returning None rather than raising;
    # fail here with the offending path instead of deep inside np2tensor.
    if img is None:
        raise OSError('cannot read image: %s' % path)
    return np2tensor(img)


In [52]:
#| export

def rot_hflip_img(img:torch.Tensor, rot_times:int=0, hflip:int=0):
    '''
    Apply one of the 8 combinations of quarter-turn rotation and horizontal flip.
    Args:
        img : image tensor, shape (c,h,w) or (b,c,h,w)
        rot_times : number of 90 degree rotations (taken modulo 4)
        hflip : horizontal flip is applied when this is odd
    Returns:
        the transformed tensor (img itself when nothing has to be done)
    '''
    # Height axis is 1 for (c,h,w) input and 2 otherwise; width follows it.
    h_axis = 1 if len(img.shape) == 3 else 2
    w_axis = h_axis + 1

    transforms = {
        # (flip?, quarter turns) -> operation
        (False, 0): lambda im: im,
        (False, 1): lambda im: im.flip(h_axis).transpose(h_axis, w_axis),
        (False, 2): lambda im: im.flip(w_axis).flip(h_axis),
        (False, 3): lambda im: im.flip(w_axis).transpose(h_axis, w_axis),
        (True, 0): lambda im: im.flip(w_axis),
        (True, 1): lambda im: im.flip(h_axis).flip(w_axis).transpose(h_axis, w_axis),
        (True, 2): lambda im: im.flip(h_axis),
        (True, 3): lambda im: im.transpose(h_axis, w_axis),
    }
    selected = (hflip % 2 != 0, rot_times % 4)
    return transforms[selected](img)
   

In [53]:
#| export
     
def psnr(x, y, mask=None, max_val=1.):
    '''
    peak signal-to-noise ratio between x and y, in dB
    Args:
        x, y : tensors to compare
        mask(opt) : weights multiplied onto the squared error; the error is
                    then averaged over mask.sum() instead of all elements
        max_val(opt) : peak signal value (None is treated as 1.)
    Returns:
        python float; inf when the (masked) mean squared error is exactly 0
    '''
    if max_val is None : max_val = 1.
    squared_error = (x - y) ** 2
    if mask is None:
        mse = torch.mean(squared_error)
    else:
        mse = torch.sum(squared_error * mask) / mask.sum()
    mse = mse.item()
    # Identical inputs: the ratio is unbounded, report it as infinity
    # instead of failing on a division by zero.
    if mse == 0:
        return float('inf')
    return 10 * log10(max_val**2 / mse)


In [54]:
#| export

def ssim(img1, img2, data_range):
    '''
    structural similarity between two images
    image value range : [0 - data_range]
    clipping for model output
    Args:
        img1, img2 : torch Tensor (c,h,w) / (b,c,h,w) / (h,w), or numpy array
                     in opencv layout (h,w,c) / (b,h,w,c) / (h,w).
                     Only the first image of a batch is used.
        data_range : maximum image value; inputs are clipped to [0, data_range]
    Returns:
        python float
    '''
    def _to_bchw(img):
        # Bring either input kind to a clipped float tensor of shape (1,c,h,w),
        # the layout the torchmetrics functional takes.
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu()
            if img.dim() == 4:
                img = img[0]
            if img.dim() == 2:
                img = img.unsqueeze(0)
        else:
            img = np.asarray(img)
            if img.ndim == 4:
                img = img[0]
            if img.ndim == 2:
                img = img[:, :, np.newaxis]
            # numpy inputs are assumed to follow the opencv layout produced by
            # tensor2np, so reverse the channels and move them to the front to
            # line up with tensor inputs.
            img = torch.from_numpy(np.ascontiguousarray(np.flip(img, axis=2).transpose(2, 0, 1)))
        return img.unsqueeze(0).float().clamp(0, data_range)

    # `structural_similarity` is the torchmetrics functional: it works on
    # (b,c,h,w) tensors and has no `channel_axis` argument.
    score = structural_similarity(_to_bchw(img1), _to_bchw(img2), data_range=data_range)
    return score.item()


In [55]:
#| export

class AverageMeter:
    """
    Running-average tracker: remembers the last value, the weighted sum,
    the total weight and their ratio.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val`, counted `n` times, and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
    


In [56]:
#| export

def setup_determinism(seed):
    '''
    make runs reproducible: seed the torch, numpy and python random number
    generators and put cuDNN into deterministic mode
    Args:
        seed : integer seed shared by all three generators
    '''
    # Deterministic cuDNN kernels, and no benchmark-driven algorithm
    # selection (which may pick different kernels from run to run).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)



In [57]:
#| export

def get_gaussian_2d_filter(window_size, sigma, channel=1, device=torch.device('cpu')):
    '''
    build a normalized 2d gaussian filter, repeated for every channel
    Arg:
        window_size : filter window size
        sigma : standard deviation
        channel : number of channels the filter is expanded to
        device : device of the returned tensor
    Returns:
        tensor of shape (channel, 1, window_size, window_size); each
        window_size x window_size slice sums to 1
    '''
    center = window_size // 2
    two_sigma_sq = float(2 * sigma**2)
    # 1d gaussian taps centred on the middle of the window
    taps = [exp(-(pos - center)**2 / two_sigma_sq) for pos in range(window_size)]
    column = torch.tensor(taps, device=device).unsqueeze(1)
    # outer product of the 1d profile with itself gives the separable 2d kernel
    kernel = column.mm(column.t()).float()
    kernel = kernel / kernel.sum()
    return kernel[None, None].expand(channel, 1, window_size, window_size)



In [58]:
#| export

def get_mean_2d_filter(window_size, channel=1, device=torch.device('cpu')):
    '''
    build a 2d box (mean) filter, repeated for every channel
    Args:
        window_size : filter window size
        channel : number of channels the filter is expanded to
        device : device of the returned tensor
    Returns:
        tensor of shape (channel, 1, window_size, window_size) whose entries
        are all 1 / window_size**2
    '''
    shape_2d = (window_size, window_size)
    box = torch.ones(shape_2d, device=device)
    box = box / box.sum()
    return box[None, None].expand(channel, 1, *shape_2d)


In [59]:
#| export

def mean_conv2d(x, window_size=None, window=None, filter_type='gau', sigma=None, keep_sigma=False, padd=True):
    '''
    color channel-wise 2d mean or gaussian convolution
    Args:
        x : input image, shape (c,h,w) or (b,c,h,w)
        window_size : filter window size (required when `window` is not given)
        window(opt) : precomputed filter handed to F.conv2d with one group per
                      channel; when given, window_size is read from its last dim
        filter_type(opt) : 'gau' for a gaussian filter, anything else for mean
        sigma(opt) : standard deviation of gaussian filter,
                     (window_size-1)/6 when omitted
        keep_sigma(opt) : divide the filtered image by the L2 norm of `window`
        padd(opt) : reflect-pad by (window_size-1)//2 on every side first
    Returns:
        filtered image with the same number of dimensions as x
    Raises:
        ValueError : if x is neither 3-D nor 4-D, or if neither window_size
                     nor window is given
    '''
    # Validate before doing any work so bad input fails with a clear message
    # rather than an error from inside F.pad / F.conv2d.
    if len(x.shape) not in (3, 4):
        raise ValueError('input image shape is not correct')
    if window is None and window_size is None:
        raise ValueError('either window_size or window must be given')

    is_batched = len(x.shape) == 4
    b_x = x if is_batched else x.unsqueeze(0)
    channels = b_x.shape[1]

    if window is None:
        if sigma is None: sigma = (window_size-1)/6
        if filter_type == 'gau':
            window = get_gaussian_2d_filter(window_size, sigma=sigma, channel=channels, device=x.device)
        else:
            window = get_mean_2d_filter(window_size, channel=channels, device=x.device)
    else:
        window_size = window.shape[-1]

    if padd:
        pl = (window_size-1)//2
        b_x = F.pad(b_x, (pl,pl,pl,pl), 'reflect')

    # groups == channels: every channel is filtered independently
    m_b_x = F.conv2d(b_x, window, groups=channels)

    if keep_sigma:
        # NOTE(review): the norm is taken over the whole `window` tensor, i.e.
        # across all channel copies — confirm this is intended for
        # multi-channel input.
        m_b_x /= (window**2).sum().sqrt()

    return m_b_x if is_batched else m_b_x.squeeze(0)
    


In [60]:
#| export

def get_file_name_from_path(path):
    '''
    return the file name of `path` without directory and without its last
    extension, e.g. 'a/b/img.tar.gz' -> 'img.tar'
    Both '/' and '\\' are treated as directory separators. Bare file names
    (no directory part) and names without an extension are accepted.
    Raises:
        ValueError : if the path has no file name component (empty string or
                     trailing separator)
    '''
    base = path.replace('\\', '/').rsplit('/', 1)[-1]
    if not base:
        raise ValueError(f'Invalid path: {path}')
    # Drop only the last extension; a name without any dot is kept whole.
    stem, dot, _ = base.rpartition('.')
    return stem if dot else base



In [62]:
#| export

def get_histogram(data, bin_edges=None, cnt_regr=1):
    '''
    normalized histogram of `data` with additive smoothing
    Args:
        data : numpy array of samples
        bin_edges(opt) : bin specification forwarded to np.histogram;
                         np.histogram's own default of 10 bins when omitted
        cnt_regr(opt) : pseudo-count added to every bin so none is empty
    Returns:
        array with one smoothed relative frequency per bin
    '''
    n = np.prod(data.shape)
    # np.histogram does not accept None as a bin specification.
    bins = 10 if bin_edges is None else bin_edges
    hist, _ = np.histogram(data, bins)
    return (hist + cnt_regr)/(n + cnt_regr * len(hist))



In [63]:
#| export

def kl_div_forward(p, q):
    '''
    forward KL divergence sum(p * log(p / q)), restricted to entries with p > 0
    Both inputs must be free of NaN and infinite values (asserted).
    '''
    invalid = np.isnan(p) | np.isinf(p) | np.isnan(q) | np.isinf(q)
    assert not invalid.any()
    # Entries with p == 0 contribute nothing and would produce log(0).
    support = p > 0
    p_pos = p[support]
    q_pos = q[support]
    return np.sum(p_pos * np.log(p_pos / q_pos))



In [61]:
#| export

def kl_div_3_data(real_noise, gen_noise, bin_edges=None, left_edge=0.0, right_edge=1.0, bin_width=4):
    '''
    KL divergence between the histogram of real noise and the histogram of
    generated noise, KL(real || generated)
    Args:
        real_noise, gen_noise : numpy arrays of noise samples
        bin_edges(opt) : histogram bin edges; built as
                         np.arange(left_edge, right_edge, bin_width) when omitted
        left_edge, right_edge(opt) : range of the generated bin edges
        bin_width(opt) : spacing of the generated bin edges
    Returns:
        the divergence as a numpy scalar
    Note: with the default range [0, 1) and bin_width 4, np.arange yields a
    single edge, i.e. no bins at all — pass a range or edges that match the
    scale of the data.
    '''
    # Kousha, Shayan, et al. "Modeling srgb camera noise with normalizing flows." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
    # ref) https://github.com/SamsungLabs/Noise2NoiseFlow

    # histograms: honour caller-supplied edges, otherwise build a regular grid
    if bin_edges is None:
        bin_edges = np.arange(left_edge, right_edge, bin_width)

    cnt_regr = 1
    real_hist = get_histogram(real_noise, bin_edges=bin_edges, cnt_regr=cnt_regr)
    gen_hist = get_histogram(gen_noise, bin_edges=bin_edges, cnt_regr=cnt_regr)

    return kl_div_forward(real_hist, gen_hist)



In [64]:
#| export

def load_numpy_from_raw(path, dtype='float32'):
    '''
    read a raw binary file into a flat numpy array
    Args:
        path : file to read
        dtype(opt) : element type the bytes are interpreted as
    Returns:
        1-D numpy array holding the whole file content
    '''
    # The context manager closes the handle once the data is read.
    with open(path, "rb") as fid:
        return np.fromfile(fid, dtype=dtype)



In [65]:
#| export

def make_predefiend_1d_to_2d(arr):
    '''
    reshape a flat array to the one known (H, W) image size whose pixel count
    matches its length; asserts when arr is not 1-D or no size matches
    '''
    assert arr.ndim == 1
    known_sizes = ((3072, 2560), (3072, 3072), (9216, 3072), (6144, 3072))  # H, W
    # every known size has a distinct pixel count, so a lookup is unambiguous
    size_by_count = {height * width: (height, width) for height, width in known_sizes}
    target = size_by_count.get(arr.shape[0])
    if target is not None:
        arr = np.reshape(arr, target)
    assert len(arr.shape) == 2, "Error: No matching predefined size exists."
    return arr



In [66]:
#| export

def save_img(dir_name, file_name, img):
    '''
    write a numpy image to dir_name/file_name
    File names ending in 'raw' are dumped as raw bytes with ndarray.tofile;
    everything else goes through opencv:
        (h,w,c) with c not in (1,3) : multi-page image, one page per channel
        (b,h,w,1) with b > 1        : multi-page image, one page per batch item
        (1,...)  4-D                : the single batch item
        anything else               : written as is
    The target directory is created when missing.
    '''
    path = os.path.join(dir_name, file_name)
    # Needed for every format, not only raw: the image writer does not create
    # missing directories. (An empty dir_name means the current directory.)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)

    if path.endswith('raw'):
        # binary mode: the payload is raw array bytes, not text
        with open(path, 'wb') as fid:
            img.tofile(fid)
        return

    if len(img.shape) == 3 and img.shape[-1] != 3 and img.shape[-1] > 1:
        cv2.imwritemulti(path, img.transpose([2,0,1])) # multi stack image, convert to CHW
    elif len(img.shape) == 4 and img.shape[0] > 1: # batch image, only grey image is available
        img = img.squeeze(-1)
        cv2.imwritemulti(path, img)
    elif len(img.shape) == 4 and img.shape[0] <= 1: # single batch image
        img = img.squeeze(0)
        cv2.imwrite(path, img)
    else:
        cv2.imwrite(path, img)


## File Manager

In [67]:
#| export

class FileManager:
    """
    Manages the output folder of one session:
    <output_folder>/<session_name>/<dir_name>/...

    Args:
        session_name: name of the session sub-folder.
        output_path: root output folder; "./output" when None.

    On construction the root folder, the session folder and its 'checkpoint'
    and 'img' sub-folders are created if they are missing.
    """
    def __init__(self, session_name, output_path=None):
        self.output_folder = "./output" if output_path is None else output_path

        if not os.path.isdir(self.output_folder):
            os.makedirs(self.output_folder)
            print("[WARNING] output folder is not exist, create new one")

        # init session
        self.session_name = session_name
        os.makedirs(os.path.join(self.output_folder, self.session_name), exist_ok=True)

        # mkdir
        for directory in ['checkpoint', 'img']:
            self.make_dir(directory)

    def is_dir_exist(self, dir_name:str) -> bool:
        """Tell whether <session>/<dir_name> exists."""
        return os.path.isdir(self.get_dir(dir_name))

    def make_dir(self, dir_name:str) -> str:
        """Create <session>/<dir_name> if needed and return its path."""
        path = self.get_dir(dir_name)
        os.makedirs(path, exist_ok=True)
        return path

    def get_dir(self, dir_name:str) -> str:
        # -> './output/<session_name>/dir_name'
        return os.path.join(self.output_folder, self.session_name, dir_name)

    def save_img_tensor(self, dir_name:str, file_name:str, img:torch.Tensor, ext='png'):
        """Convert `img` with tensor2np and save it as <file_name>.<ext>."""
        self.save_img_numpy(dir_name, file_name, tensor2np(img), ext)

    def save_img_numpy(self, dir_name:str, file_name:str, img:np.array, ext='png'):
        """Save a numpy image as <file_name>.<ext> inside <session>/<dir_name>."""
        # Drop a singleton third axis (single-channel (h,w,1) images);
        # 2-D images have no such axis and are saved unchanged.
        if np.ndim(img) >= 3 and np.shape(img)[2] == 1:
            img = np.squeeze(img, 2)
        save_img(self.get_dir(dir_name), '%s.%s'%(file_name, ext), img)
    

In [68]:
#| hide
# Export the `#| export` cells of the notebook(s) to the library module
# (this notebook targets `utils`, see `#| default_exp` at the top).
import nbdev; nbdev.nbdev_export()