# Utils

> Utility functions


In [None]:
#| default_exp utils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

import os
import time
import datetime
import cv2
from math import log10, exp, log, pi
import random
import numpy as np

import torch 
from torch import nn
import torch.nn.functional as F

from torchmetrics.functional.image import structural_similarity_index_measure as structural_similarity
from torchmetrics.functional.image import peak_signal_noise_ratio

from fastai.vision.all import store_attr

import logging

### Compute attribute index

In [None]:
#| export
class ComputeIndex:
    '''
    Map a combination of attribute values to one flat index.

    `codes` maps attribute name -> sequence (e.g. 1-D tensor) of allowed values.
    The result is a mixed-radix number: attributes are the digits (in dict order)
    and each digit is the position of the observed value in that attribute's list.
    '''
    def __init__(self, codes) -> None:
        self.codes = codes

    def _compute_index(self, b, **kwargs):
        # the output is placed on the device of the first keyword tensor
        ref_device = next(iter(kwargs.values())).device

        flat = torch.zeros([b], device=ref_device, dtype=torch.float32)
        for attr, levels in self.codes.items():
            # position of the observed value inside `levels`
            # (a value not present in `levels` contributes 0, like the first level)
            digit = torch.zeros_like(flat)
            for pos, level in enumerate(levels):
                digit += torch.where(kwargs[attr] == level, pos, 0.0)
            # shift previous digits by this attribute's radix, then append the new digit
            flat = flat * len(levels) + digit
        return flat

    def __call__(self, b, **kwargs):
        return self._compute_index(b, **kwargs)


In [None]:
# Use the GPU when available; fall back to CPU so the demo (and nbdev tests) also run without CUDA
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
codes = {
        'exposure-time': torch.tensor([10, 50, 100], dtype=torch.float32, device=device),
        'optical-setup': torch.tensor([0, 1], dtype=torch.float32).to(device),
        'camera': torch.tensor([0, 1], dtype=torch.float32).to(device)
    }
kwargs = {
        'exposure-time': torch.tensor([100], dtype=torch.float32).to(device),
        'optical-setup': torch.tensor([1], dtype=torch.float32).to(device),
        'camera': torch.tensor([0], dtype=torch.float32).to(device)
    }

fn = ComputeIndex(codes)

# digits (2, 1, 0) in radices (3, 2, 2) -> ((2 * 2) + 1) * 2 + 0 = 10
print('index: ', fn(1, **kwargs))


index:  tensor([10.], device='cuda:0')


### One-Hot encoding

In [None]:
#| export

class ComputeOneHot:
    '''
    Encode each attribute as a one-hot vector and concatenate them.

    `codes` maps attribute name -> sequence (e.g. 1-D tensor) of allowed values;
    the position of a value in its sequence is its class index. The output has
    shape (b, sum of len(values) over all attributes), dtype float32.
    '''
    def __init__(self, codes) -> None:
        self.codes = codes

    def _compute_one_hot(self, b, **kwargs):
        # the output is placed on the device of the first keyword tensor
        device = next(iter(kwargs.values())).device

        parts = []
        for key, value in self.codes.items():
            idx = torch.zeros([b], device=device, dtype=torch.float32)
            for i, v in enumerate(value):
                # a value not present in `value` keeps index 0
                idx += torch.where(kwargs[key] == v, i, 0.0)
            parts.append(F.one_hot(idx.to(torch.int64), num_classes=len(value)).to(torch.float32))

        if not parts:
            # no attributes: an empty (b, 0) embedding keeps the batch dimension
            return torch.zeros((b, 0), device=device)
        # one concatenation instead of repeatedly growing a tensor
        return torch.cat(parts, dim=1)

    def __call__(self, b, **kwargs):
        return self._compute_one_hot(b, **kwargs)


In [None]:
# Encode the same attribute values as one concatenated one-hot vector:
# 3 exposure-time slots + 2 optical-setup slots + 2 camera slots = 7 entries.
fn1hot = ComputeOneHot(codes)

print('one hot encoding: ', fn1hot(1, **kwargs))

one hot encoding:  tensor([[0., 0., 1., 0., 1., 1., 0.]], device='cuda:0')


### Normal Distribution

In [None]:
#| export

class StandardNormal(nn.Module):
    """
    Multivariate Normal distribution N(0, I) (zero mean, identity covariance).

    `log_prob` sums the per-element log-density over all non-batch axes;
    `sample` draws on the device/dtype the module currently lives on.
    """

    def __init__(self):
        super().__init__()
        # placeholder buffer: it follows the module through .to()/.double()/...,
        # so sample() can read the current device and dtype from it
        self.register_buffer('buffer', torch.zeros(1))

    def log_prob(self, x):
        # per element: log N(x; 0, 1) = -0.5 * log(2*pi) - 0.5 * x^2
        # https://www.statlect.com/fundamentals-of-statistics/normal-distribution-maximum-likelihood
        const_term = - 0.5 * log(2 * pi)
        quad_term = - 0.5 * x**2
        return sum_except_batch(const_term + quad_term)

    def sample(self, shape):
        target = self.buffer
        return torch.randn(*shape, device=target.device, dtype=target.dtype)

def sum_except_batch(x, num_dims=1):
    '''
    Sum over every axis that follows the leading `num_dims` batch axes.
    Args:
        x: Tensor of shape (*batch_shape, ...) where len(batch_shape) == num_dims
        num_dims: int, number of leading axes to keep (default=1)
    Returns:
        Tensor of shape batch_shape, i.e. (batch_size,) for the default num_dims=1
    '''
    batch_shape = x.shape[:num_dims]
    flattened = x.reshape(*batch_shape, -1)
    return flattened.sum(dim=-1)

## Base Utilities

In [None]:
#| export

def np2tensor(n: np.ndarray):
    '''
    Convert an OpenCV-style numpy image to a torch Tensor (dtype preserved).
    (h,w)   -> (1,h,w)  gray image, channel axis added
    (h,w,c) -> (c,h,w)  channel order reversed (BGR <-> RGB)
    Raises RuntimeError for any other number of dimensions.
    '''
    ndim = len(n.shape)
    if ndim not in (2, 3):
        raise RuntimeError('wrong numpy dimensions : %s'%(n.shape,))

    if ndim == 2:
        hwc = np.expand_dims(n, axis=2)   # gray: add a trailing channel axis
    else:
        hwc = np.flip(n, axis=2)          # reverse channel order
    return torch.from_numpy(np.ascontiguousarray(np.transpose(hwc, (2, 0, 1))))
    


In [None]:
# np2tensor keeps the numpy dtype: float64 input -> torch.DoubleTensor
a = np.random.rand(4, 4)
assert np2tensor(a).type() == 'torch.DoubleTensor'

In [None]:
# int64 input -> torch.LongTensor. The dtype is pinned because randint's default
# integer type is platform dependent (int32 on Windows with NumPy < 2), which made
# this check fail there.
a = np.random.randint(0, high=255, size=(4,4), dtype=np.int64)
assert np2tensor(a).type() == 'torch.LongTensor'

In [None]:
#| export

def np2tensor_multi(n: np.ndarray):
    '''
    Convert a stack of OpenCV-style images into one float32 batch tensor.

    Each element of `n` (e.g. the (h,w,c) slices of a (k,h,w,c) array) is
    converted with np2tensor, so the result has shape (k,c,h,w).

    Raises:
        ValueError: if `n` is empty.
    '''
    if len(n) == 0:
        raise ValueError('np2tensor_multi: empty image stack')
    # stack once instead of concatenating in a loop (which is quadratic in len(n))
    return torch.stack([np2tensor(mat.astype(np.float32)).float() for mat in n], dim=0)


In [None]:
#| export

def tensor2np(t:torch.Tensor):
    '''
    transform torch Tensor to a numpy array in OpenCV image layout.
    (h,w)     -> (h,w,1)   gray
    (c,h,w)   -> (h,w,c)   channel order reversed (RGB -> BGR)
    (b,c,h,w) -> (b,h,w,c) channel order reversed for every image
    The result is C-contiguous (np.flip alone returns a negative-stride view,
    which e.g. torch.from_numpy rejects).
    Raises RuntimeError for any other number of dimensions.
    '''
    t = t.cpu().detach()

    # gray: permute(1,2,0) needs three axes, so add the channel axis explicitly
    if len(t.shape) == 2:
        return t.unsqueeze(-1).numpy()
    # RGB -> BGR
    elif len(t.shape) == 3:
        return np.ascontiguousarray(np.flip(t.permute(1,2,0).numpy(), axis=2))
    # image batch
    elif len(t.shape) == 4:
        return np.ascontiguousarray(np.flip(t.permute(0,2,3,1).numpy(), axis=3))
    else:
        raise RuntimeError('wrong tensor dimensions : %s'%(t.shape,))


In [None]:
#| export

def imwrite_tensor(t, name='test.png'):
    '''Write tensor image `t` to ./<name> with OpenCV (layout converted by tensor2np).'''
    cv2.imwrite('./%s'%name, tensor2np(t.cpu()))

def imread_tensor(name='test'):
    '''
    Read ./<name> with OpenCV and return it as a tensor (layout converted by np2tensor).

    NOTE(review): the default name 'test' differs from imwrite_tensor's 'test.png'.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    '''
    path = './%s'%name
    img = cv2.imread(path)
    # cv2.imread reports failure by returning None rather than raising
    if img is None:
        raise FileNotFoundError('could not read image: %s' % path)
    return np2tensor(img)


In [None]:
#| export

def rot_hflip_img(img:torch.Tensor, rot_times:int=0, hflip:int=0):
    '''
    Optionally mirror an image horizontally, then rotate it by rot_times * 90 degrees
    (clockwise as displayed with row 0 at the top).
    Args:
        img : (c,h,w) or (b,c,h,w)
        rot_times : number of 90-degree steps, taken modulo 4
        hflip : odd -> mirror along the width axis first, even -> no mirror
    '''
    lead = 0 if len(img.shape) == 3 else 1
    h_ax, w_ax = lead + 1, lead + 2
    quarter = rot_times % 4

    if hflip % 2:
        if quarter == 0:
            return img.flip(w_ax)
        if quarter == 1:
            return img.flip(h_ax).flip(w_ax).transpose(h_ax, w_ax)
        if quarter == 2:
            return img.flip(h_ax)
        return img.transpose(h_ax, w_ax)

    if quarter == 0:
        return img
    if quarter == 1:
        return img.flip(h_ax).transpose(h_ax, w_ax)
    if quarter == 2:
        return img.flip(w_ax).flip(h_ax)
    return img.flip(w_ax).transpose(h_ax, w_ax)
   

In [None]:
#| export
     
def psnr(x, y, mask=None, max_val=1.):
    '''
    Peak signal-to-noise ratio in dB: 10 * log10(max_val**2 / MSE).
    Args:
        x, y : tensors of the same shape
        mask : optional weight tensor; MSE = sum(mask * (x - y)**2) / sum(mask)
        max_val : peak signal value (None -> 1.)
    Returns:
        float; inf when the MSE is zero (identical inputs)
    '''
    if max_val is None : max_val = 1.
    sq_err = (x - y) ** 2
    if mask is None:
        mse = torch.mean(sq_err).item()
    else:
        mse = (torch.sum(sq_err * mask) / mask.sum()).item()
    # identical inputs: PSNR is unbounded (avoids ZeroDivisionError)
    if mse == 0:
        return float('inf')
    return 10 * log10(max_val**2 / mse)


In [None]:
#| export

def ssim(img1, img2, data_range):
    '''
    Structural similarity (SSIM) between two images, as a float.

    Inputs may be torch tensors, (c,h,w) or (b,c,h,w), or numpy arrays in
    OpenCV layout, (h,w,c), (h,w) or (b,h,w,c). For batched inputs only the
    first image is compared. Values are clipped to [0, data_range] first,
    because model outputs can overshoot.

    Uses torchmetrics' structural_similarity_index_measure (imported as
    `structural_similarity`). That function takes (N,C,H,W) tensors and has no
    `channel_axis` argument; `channel_axis` belongs to scikit-image's API.
    '''
    def _as_nchw(img):
        # drop the batch axis: only the first image is compared
        if len(img.shape) == 4:
            img = img[0]
        if isinstance(img, torch.Tensor):
            t = img.detach().cpu().float()              # (c,h,w) or (h,w)
        else:
            t = torch.from_numpy(np.ascontiguousarray(img, dtype=np.float32))
            if t.dim() == 3:
                t = t.permute(2, 0, 1)                  # (h,w,c) -> (c,h,w)
        if t.dim() == 2:
            t = t.unsqueeze(0)                          # gray -> (1,h,w)
        return t.clamp(0, data_range).unsqueeze(0)      # (1,c,h,w)

    score = structural_similarity(_as_nchw(img1), _as_nchw(img2), data_range=float(data_range))
    return score.item()


In [None]:
#| export

class AverageMeter:
    """
    Keep the latest value of a metric together with its running, count-weighted mean.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every recorded value."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as observed `n` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
    


In [None]:
#| export

def setup_determinism(seed):
    '''
    Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.
    Args:
        seed (int) : seed shared by all generators
    '''
    # deterministic cuDNN kernels; benchmark mode may pick different algorithms per run
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)  # seeds the RNG of every device
    np.random.seed(seed)
    random.seed(seed)



In [None]:
#| export

def get_gaussian_2d_filter(window_size, sigma, channel=1, device=torch.device('cpu')):
    '''
    return a normalized 2d gaussian filter window as a tensor
    Args:
        window_size : filter window size
        sigma : standard deviation; sigma == 0 yields a delta at the centre tap
                (identity filter) instead of dividing by zero
        channel : number of channels; the window is expanded (not copied) along dim 0
    Returns:
        tensor of shape (channel, 1, window_size, window_size) summing to 1 per channel
    '''
    centre = window_size // 2
    if sigma == 0:
        gauss = torch.zeros(window_size, device=device)
        gauss[centre] = 1.
    else:
        denom = float(2*sigma**2)
        gauss = torch.tensor([exp(-(x - centre)**2/denom) for x in range(window_size)], device=device)
    gauss = gauss.unsqueeze(1)
    # separable kernel: outer product of the 1-d gaussian with itself
    filter2d = gauss.mm(gauss.t()).float()
    filter2d = (filter2d/filter2d.sum()).unsqueeze(0).unsqueeze(0)
    return filter2d.expand(channel, 1, window_size, window_size)



In [None]:
#| export

def get_mean_2d_filter(window_size, channel=1, device=torch.device('cpu')):
    '''
    Build a normalized 2d box (mean) filter.
    Args:
        window_size : filter window size
        channel : number of channels; the window is expanded (not copied) along dim 0
    Returns:
        tensor of shape (channel, 1, window_size, window_size), every tap 1/window_size**2
    '''
    box = torch.ones((window_size, window_size), device=device)
    box = box / box.sum()
    return box[None, None].expand(channel, 1, window_size, window_size)


In [None]:
#| export

def mean_conv2d(x, window_size=None, window=None, filter_type='gau', sigma=None, keep_sigma=False, padd=True):
    '''
    color channel-wise 2d mean or gaussian convolution
    Args:
        x : input image, (c,h,w) or (b,c,h,w)
        window_size : filter window size (ignored when `window` is given)
        window : optional precomputed filter of shape (c,1,k,k)
        filter_type(opt) : 'gau' or 'mean'
        sigma : standard deviation of gaussian filter (default (window_size-1)/6)
        keep_sigma : divide each output channel by the L2 norm of its kernel, so
                     i.i.d. noise keeps its standard deviation after filtering
        padd : reflect-pad so the spatial size is kept (for odd window sizes)
    Returns:
        filtered tensor with the same number of dimensions as x
    Raises:
        ValueError: if x is neither 3-D nor 4-D
    '''
    if len(x.shape) not in (3, 4):
        raise ValueError('input image shape is not correct')
    b_x = x.unsqueeze(0) if len(x.shape) == 3 else x

    if window is None:
        if sigma is None: sigma = (window_size-1)/6
        if filter_type == 'gau':
            window = get_gaussian_2d_filter(window_size, sigma=sigma, channel=b_x.shape[1], device=x.device)
        else:
            window = get_mean_2d_filter(window_size, channel=b_x.shape[1], device=x.device)
    else:
        window_size = window.shape[-1]

    if padd:
        pl = (window_size-1)//2
        b_x = F.pad(b_x, (pl,pl,pl,pl), 'reflect')

    # depthwise convolution: each channel is filtered independently
    m_b_x = F.conv2d(b_x, window, groups=b_x.shape[1])

    if keep_sigma:
        # per-output-channel kernel norm; summing over all channels would
        # over-divide by sqrt(c) for multi-channel input
        kernel_norm = window.pow(2).sum(dim=(1, 2, 3)).sqrt().view(1, -1, 1, 1)
        m_b_x = m_b_x / kernel_norm

    return m_b_x if len(x.shape) == 4 else m_b_x.squeeze(0)
    


In [None]:
#| export

def get_file_name_from_path(path):
    '''
    Return the file name of `path` without its last extension.

    Both '/' and '\\' are treated as separators, and a bare file name is accepted:
        'a/b/c.tar.gz' -> 'c.tar', 'C:\\x\\img.png' -> 'img', 'img.png' -> 'img'.
    A name without an extension is returned unchanged ('dir/README' -> 'README').
    '''
    base = path.replace('\\', '/').rsplit('/', 1)[-1]
    stem, dot, _ = base.rpartition('.')
    return stem if dot else base



In [None]:
#| export

def get_histogram(data, bin_edges=None, cnt_regr=1):
    '''
    Normalized histogram with additive smoothing.

    Each bin count is increased by `cnt_regr` and divided by
    (number of elements + cnt_regr * number of bins). The element count includes
    values falling outside the bin range.
    Args:
        data : numpy array
        bin_edges : bin edges passed to np.histogram; None -> np.histogram's default of 10 bins
        cnt_regr : pseudo-count added to every bin
    '''
    # np.histogram rejects bins=None, so map it to numpy's own default
    bins = 10 if bin_edges is None else bin_edges
    n = np.prod(data.shape)
    hist, _ = np.histogram(data, bins)
    return (hist + cnt_regr)/(n + cnt_regr * len(hist))



In [None]:
#| export

def kl_div_forward(p, q):
    '''
    Forward KL divergence KL(p || q) = sum p * log(p / q) over bins where p > 0.
    Args:
        p, q : numpy arrays of (smoothed) probabilities; must contain no NaN/inf
    '''
    assert (np.isfinite(p) & np.isfinite(q)).all()
    # terms with p == 0 contribute nothing (0 * log 0 := 0)
    support = p > 0
    p_s, q_s = p[support], q[support]
    return np.sum(p_s * np.log(p_s / q_s))



In [None]:
#| export

def kl_div_3_data(real_noise, gen_noise, bin_edges=None, left_edge=0.0, right_edge=1.0, bw=4):
    '''
    Forward KL divergence KL(real || gen) between histograms of real and generated noise.

    Kousha, Shayan, et al. "Modeling srgb camera noise with normalizing flows."
    Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
    ref) https://github.com/SamsungLabs/Noise2NoiseFlow
    Args:
        real_noise, gen_noise : numpy arrays
        bin_edges : explicit histogram bin edges; when None they are
                    np.arange(left_edge, right_edge, bw)
        left_edge, right_edge : range used to build the edges when bin_edges is None
        bw : bin width used when bin_edges is None
    NOTE: with the default edges (0, 1) and bw=4, np.arange yields a single edge,
    so the histograms are empty and the result is 0. Pass edges or a bin width
    matching the data range.
    '''
    if bin_edges is None:
        bin_edges = np.arange(left_edge, right_edge, bw)

    cnt_regr = 1
    hist_real = get_histogram(real_noise, bin_edges=bin_edges, cnt_regr=cnt_regr)
    hist_gen = get_histogram(gen_noise, bin_edges=bin_edges, cnt_regr=cnt_regr)
    return kl_div_forward(hist_real, hist_gen)



In [None]:
#| export

def load_numpy_from_raw(path, dtype='float32'):
    '''
    Read a headerless binary file as a flat numpy array of `dtype`.
    The file is closed after reading.
    '''
    with open(path, "rb") as fid:
        return np.fromfile(fid, dtype=dtype)



In [None]:
#| export

def make_predefiend_1d_to_2d(arr):
    '''
    Reshape a flat raw buffer to 2-D using a fixed list of known (H, W) sizes.
    Raises AssertionError if `arr` is not 1-D or its length matches none of the sizes.
    '''
    assert len(arr.shape) == 1
    n_pixels = arr.shape[0]
    known_hw = ((3072, 2560), (3072, 3072), (9216, 3072), (6144, 3072))  # H, W
    matched = next((hw for hw in known_hw if hw[0] * hw[1] == n_pixels), None)
    if matched is not None:
        arr = np.reshape(arr, matched)
    assert len(arr.shape) == 2, "Error: No matching predefined size exists."
    return arr



In [None]:
#| export

def save_img(dir_name, file_name, img):
    '''
    Save a numpy image/array to dir_name/file_name, creating dir_name if needed.
      - path ending in 'raw'                 : raw binary dump (ndarray.tofile)
      - (h,w,c) with c != 3 and c > 1        : multi-page image, one page per channel
      - (b,h,w,1) with b > 1                 : multi-page gray image, one page per batch item
      - (1,h,w,c)                            : single image (batch axis dropped)
      - anything else                        : cv2.imwrite as-is
    '''
    path = os.path.join(dir_name, file_name)
    if dir_name:
        os.makedirs(dir_name, exist_ok=True)

    if path.endswith('raw'):
        # binary mode: a text-mode handle may translate newline bytes on some platforms
        with open(path, 'wb') as fid:
            img.tofile(fid)
    elif len(img.shape) == 3 and img.shape[-1] != 3 and img.shape[-1] > 1:
        cv2.imwritemulti(path, img.transpose([2,0,1])) # multi stack image, convert to CHW
    elif len(img.shape) == 4 and img.shape[0] > 1: # batch image, only grey image is available
        img = img.squeeze(-1)
        cv2.imwritemulti(path, img)
    elif len(img.shape) == 4 and img.shape[0] <= 1: # single batch image
        img = img.squeeze(0)
        cv2.imwrite(path, img)
    else:
        cv2.imwrite(path, img)


## File Manager

In [None]:
#| export

class FileManager:
    '''
    Manage one session's output tree: <output_folder>/<session_name>/<dir_name>.
    'checkpoint' and 'img' sub-directories are created on construction.
    '''
    def __init__(self, session_name, output_path=None):
        '''
        Args:
            session_name (str) : name of the session sub-directory
            output_path (str) : root output folder (default "./output"); created if missing
        '''
        self.output_folder = "./output" if output_path is None else output_path

        if not os.path.isdir(self.output_folder):
            os.makedirs(self.output_folder)
            print("[WARNING] output folder does not exist, creating a new one")

        # init session
        self.session_name = session_name
        os.makedirs(os.path.join(self.output_folder, self.session_name), exist_ok=True)

        # mkdir
        for directory in ['checkpoint', 'img']:
            self.make_dir(directory)

    def is_dir_exist(self, dir_name:str) -> bool:
        return os.path.isdir(self.get_dir(dir_name))

    def make_dir(self, dir_name:str) -> str:
        '''Create the session sub-directory (if needed) and return its path.'''
        path = self.get_dir(dir_name)
        os.makedirs(path, exist_ok=True)
        return path

    def get_dir(self, dir_name:str) -> str:
        # -> './output/<session_name>/dir_name'
        return os.path.join(self.output_folder, self.session_name, dir_name)

    def save_img_tensor(self, dir_name:str, file_name:str, img:torch.Tensor, ext='png'):
        self.save_img_numpy(dir_name, file_name, tensor2np(img), ext)

    def save_img_numpy(self, dir_name:str, file_name:str, img:np.ndarray, ext='png'):
        # (h,w,1) -> (h,w); 2-D images are passed through unchanged
        if np.ndim(img) == 3 and np.shape(img)[2] == 1:
            img = np.squeeze(img, 2)
        save_img(self.get_dir(dir_name), '%s.%s'%(file_name, ext), img)
    

## Logging

### Progress Message

In [None]:
#| export

class ProgressMsg():
    '''
    Console progress reporter: percentage done plus elapsed, remaining and total time.

    Progress is tracked over nested counters, outermost first, e.g. (epoch, batch)
    against max_iter = (max_epoch, max_batch); an inner counter contributes a
    fraction of one outer step.
    '''
    def __init__(self, max_iter, min_time_interval=0.1):
        '''
        Args:
            max_iter : (max_epoch, max_data_length, ...)
            min_time_interval (second) : minimum delay between two printed messages
        '''
        self.max_iter = max_iter
        self.min_time_interval = min_time_interval

        self.start_time = time.time()
        self.progress_time = self.start_time

    def start(self, start_iter):
        '''Reset the clock and record the counters the run (re)starts from.'''
        assert len(self.max_iter) == len(start_iter), 'start_iter should have same length with max variable.'

        self.start_iter = start_iter
        self.current_iter = start_iter
        self.start_time = time.time()
        self.progress_time = self.start_time

    def _to_percent(self, iters):
        '''Fold nested counters (outermost first) into one percentage of max_iter.'''
        per = 0
        for i in reversed(range(len(self.max_iter))):
            per += iters[i]
            per /= self.max_iter[i]
        return per * 100

    def calculate_progress(self, current_iter):
        '''
        Returns:
            (pg_per, elapsed_str, remain_str, total_str), where pg_per is the percentage
            of the work between the start counters and max_iter that is done.
            remain_str and total_str are 'INF' while no progress has been made.
        '''
        self.progress_time = time.time()

        assert len(self.max_iter) == len(current_iter), 'current should have same length with max variable.'
        for i in range(len(self.max_iter)):
            assert current_iter[i] <= self.max_iter[i], 'current value should be less than max value.'
        self.current_iter = current_iter

        # start() may never have been called: the run is then taken to start at zero
        start_iter = getattr(self, 'start_iter', None)
        if start_iter is None:
            start_iter = (0,) * len(self.max_iter)
        start_per = self._to_percent(start_iter)

        span = 100 - start_per
        if span != 0:
            pg_per = (self._to_percent(current_iter) - start_per) / span * 100
        else:
            # started at the very end: nothing is left to do (avoids ZeroDivisionError)
            pg_per = 100.0

        elapsed = time.time() - self.start_time
        elapsed_str = str(datetime.timedelta(seconds=int(elapsed)))
        if pg_per != 0:
            total = 100*elapsed/pg_per
            remain = total - elapsed
            remain_str = str(datetime.timedelta(seconds=int(remain)))
            total_str = str(datetime.timedelta(seconds=int(total)))
        else:
            remain_str = 'INF'
            total_str = 'INF'

        return pg_per, elapsed_str, remain_str, total_str

    def print_prog_msg(self, current_iter):
        '''
        Print a one-line progress message (overwriting the current terminal line)
        at most once every min_time_interval seconds. Returns the message without
        tabs, or None if it was throttled.
        '''
        if time.time() - self.progress_time >= self.min_time_interval:
            pg_per, elapsed_str, remain_str, total_str = self.calculate_progress(current_iter)

            txt = '\033[K>>> progress : %.2f%%, elapsed: %s, remaining: %s, total: %s \t\t\t\t\t' % (pg_per, elapsed_str, remain_str, total_str)

            print(txt, end='\r')

            return txt.replace('\t', '')
        return None

    def get_start_msg(self):
        return 'Start >>>'

    def get_finish_msg(self):
        total = time.time() - self.start_time
        total_str = str(datetime.timedelta(seconds=int(total)))
        txt = 'Finish >>> (total elapsed time : %s)' % total_str
        return txt

        

In [None]:
# Demo: send logging output to the console using a bare message format.
logging.basicConfig(
        format='%(message)s',
        level=logging.INFO,
        handlers=[logging.StreamHandler()]
        )

min_time = 1
max_iter = 1

# NOTE(review): min_time is passed as the second max_iter dimension here, not as
# min_time_interval (which stays at its 0.1 s default) — confirm this is intended.
pp = ProgressMsg((max_iter,min_time))
ss = (0, 0)

pp.start(ss)

for i in range(0, max_iter):
    for j in range(max_iter):
        for k in range(max_iter):
            time.sleep(0.5)
            # progress is reported for (i, j) only; k is not part of it
            pp.print_prog_msg((i, j))
        logging.info('ttt')
            

ttt


[K>>> progress : 0.00%, elapsed: 0:00:00, remaining: INF, total: INF 					

### Logger

In [None]:
#| export

class Logger(ProgressMsg):
    '''
    Leveled console and file logger that also provides ProgressMsg progress reporting.
    Call close() when done to release the log files.
    '''
    def __init__(self, max_iter:tuple=None, log_dir:str=None, log_file_option:str='w', log_lvl:str='note', log_file_lvl:str='info', log_include_time:bool=True):
        '''
        Args:
            max_iter (tuple) : max iteration for progress (forwarded to ProgressMsg)
            log_dir (str) : directory for the log files, created if missing;
                            if None, nothing is written to disk
            log_file_option (str) : 'w' or 'a'
            log_lvl (str) : minimum level printed to the console;
                            'debug' < 'note' < 'info' < 'highlight' < 'val'
            log_file_lvl (str) : minimum level written to the log file (same scale)
            log_include_time (bool) : prefix each log-file line with [HH:MM:SS]
        '''
        self.lvl_list = ['debug', 'note', 'info', 'highlight', 'val']
        self.lvl_color = [bcolors.FAIL, None, None, bcolors.WARNING, bcolors.OKGREEN]

        assert log_file_option in ['w', 'a']
        assert log_lvl in self.lvl_list
        assert log_file_lvl in self.lvl_list

        # init progress message class
        ProgressMsg.__init__(self, max_iter)

        # log setting
        self.log_dir = log_dir
        self.log_lvl      = self.lvl_list.index(log_lvl)
        self.log_file_lvl = self.lvl_list.index(log_file_lvl)
        self.log_include_time = log_include_time

        # init logging: one general log and one validation log per run
        if self.log_dir is not None:
            os.makedirs(log_dir, exist_ok=True)
            logfile_time = datetime.datetime.now().strftime('%m-%d-%H-%M')
            self.log_file = open(os.path.join(log_dir, 'log_%s.log'%logfile_time), log_file_option, encoding='utf-8')
            self.val_file = open(os.path.join(log_dir, 'validation_%s.log'%logfile_time), log_file_option, encoding='utf-8')

    def _print(self, txt, lvl_n, end):
        '''Print `txt` (colored per level) and/or write it to the log file, depending on thresholds.'''
        txt = str(txt)
        if self.log_lvl <= lvl_n:
            # '\033[K' clears the rest of the line (e.g. a pending progress message)
            if self.lvl_color[lvl_n] is not None:
                print('\033[K'+ self.lvl_color[lvl_n] + txt + bcolors.ENDC, end=end)
            else:
                print('\033[K'+txt, end=end)
        if self.log_file_lvl <= lvl_n:
            self.write_file(txt)

    def debug(self, txt, end=None):
        self._print(txt, self.lvl_list.index('debug'), end)

    def note(self, txt, end=None):
        self._print(txt, self.lvl_list.index('note'), end)

    def info(self, txt, end=None):
        self._print(txt, self.lvl_list.index('info'), end)

    def highlight(self, txt, end=None):
        self._print(txt, self.lvl_list.index('highlight'), end)

    def val(self, txt, end=None):
        '''Log at 'val' level and also append the line to the validation log file.'''
        txt = str(txt)  # allow non-str values, as the other level methods do
        self._print(txt, self.lvl_list.index('val'), end)
        if self.log_dir is not None:
            self.val_file.write(txt+'\n')
            self.val_file.flush()

    def write_file(self, txt):
        if self.log_dir is not None:
            if self.log_include_time:
                stamp = datetime.datetime.now().strftime('%H:%M:%S')
                txt = "[%s] "%stamp + txt
            self.log_file.write(txt+'\n')
            self.log_file.flush()

    def close(self):
        '''Close the log files opened in __init__ (no-op when log_dir is None).'''
        if self.log_dir is not None:
            for fid in (self.log_file, self.val_file):
                if not fid.closed:
                    fid.close()

    def clear_screen(self):
        if os.name == 'nt':
            os.system('cls')
        else:
            os.system('clear')

# https://stackoverflow.com/questions/287871/how-to-print-colored-text-in-python
class bcolors:
    # ANSI escape sequences for colored terminal output

    # foreground colors
    HEADER = '\033[95m'   # bright magenta
    OKBLUE = '\033[94m'   # bright blue
    OKCYAN = '\033[96m'   # bright cyan
    OKGREEN = '\033[92m'  # bright green
    WARNING = '\033[93m'  # bright yellow
    FAIL = '\033[91m'     # bright red

    # text attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    ENDC = '\033[0m'      # reset all attributes


In [None]:
#| hide
# Regenerate the Python library modules from the notebooks' `#| export` cells (nbdev).
import nbdev; nbdev.nbdev_export()