# Utils

> Utility functions: attribute initialisation, batch PSNR, and mapping categorical condition codes to a flat index


In [9]:
#| default_exp utils

In [10]:
#| hide
from nbdev.showdoc import *

In [11]:
#| export
from fastai.vision.all import torch, nn
import numpy as np
from torchmetrics.functional.image import structural_similarity_index_measure as structural_similarity
from torchmetrics.functional.image import peak_signal_noise_ratio
import math

In [12]:
#| export
def attributesFromDict(d):
    """Set every entry of `d` except ``'self'`` as an attribute on ``d['self']``.

    Intended to be called as ``attributesFromDict(locals())`` at the top of an
    ``__init__`` so that each constructor argument becomes an attribute of the
    same name.

    Args:
        d: mapping that must contain the key ``'self'`` (the target object);
            all other keys are used as attribute names.

    Raises:
        KeyError: if ``'self'`` is not present in `d`.
    """
    # Work on a copy so the caller's mapping keeps its 'self' entry untouched.
    attrs = dict(d)
    self = attrs.pop('self')
    for name, value in attrs.items():
        setattr(self, name, value)

In [13]:
#| export

def batch_PSNR(img, imclean, data_range):
    """Mean peak signal-to-noise ratio (in dB) over a batch of images.

    PSNR is computed per sample as ``10 * log10(data_range**2 / MSE)``, where the
    MSE is taken over all non-batch elements, and the per-sample values are then
    averaged over the batch (dim 0).

    Args:
        img: tensor of reconstructed images; dim 0 is the batch
            (e.g. ``(N, C, H, W)``).
        imclean: reference tensor with the same shape as `img`.
        data_range: span of possible pixel values (e.g. ``1.0`` or ``255``).

    Returns:
        The batch-mean PSNR as a Python float. A sample identical to its
        reference has infinite PSNR, which makes the mean ``inf``.

    Raises:
        ValueError: if the shapes differ or the batch is empty.
    """
    # Stay in torch (float64 for accuracy) instead of round-tripping through NumPy;
    # this is the same formula torchmetrics uses for base-10 PSNR.
    pred = img.detach().to('cpu', torch.float64)
    target = imclean.detach().to('cpu', torch.float64)
    if pred.shape != target.shape:
        raise ValueError(
            f"batch_PSNR: shape mismatch {tuple(pred.shape)} vs {tuple(target.shape)}")
    if pred.dim() == 0 or pred.shape[0] == 0:
        raise ValueError("batch_PSNR: need a non-empty batch along dim 0")
    n = pred.shape[0]
    # Per-sample MSE over every non-batch element (works for any rank >= 1).
    mse = (pred - target).pow(2).reshape(n, -1).mean(dim=1)
    psnr = 10.0 * torch.log10(float(data_range) ** 2 / mse)
    return psnr.mean().item()

In [14]:
#| export

class compute_index:
    """Collapse several categorical conditioning values into one flat index.

    `codes` maps each key to its ordered candidate values. For every sample, the
    value supplied for a key is replaced by its position among that key's
    candidates, and the positions are combined as the digits of a mixed-radix
    number whose most significant digit belongs to the first key in `codes`.
    For example, radices (3, 2, 2) and positions (2, 1, 0) give
    (2*2 + 1)*2 + 0 = 10.

    Note: despite the method name `_compute_one_hot`, the result is an index,
    not a one-hot vector. A value that matches none of its candidates
    contributes 0, exactly like matching the first candidate.
    """

    def __init__(self, codes, device='cpu') -> None:
        # Same effect as attributesFromDict(locals()): keep both arguments as-is.
        self.codes = codes
        self.device = device

    def _compute_one_hot(self, b, **kwargs):
        """Return a float32 tensor of shape (b,) with each sample's flat index.

        `kwargs` must supply a tensor for every key of `self.codes`; it is
        compared element-wise against the candidates and the result is added
        in place, so it is assumed to broadcast to (b,) — TODO confirm callers.
        """
        flat = torch.zeros([b], device=self.device, dtype=torch.float32)
        for key, candidates in self.codes.items():
            # Shift the digits gathered so far one place left in this key's radix.
            flat = flat * len(candidates)
            for digit, candidate in enumerate(candidates):
                flat += torch.where(kwargs[key] == candidate, digit, 0.0)
        return flat

    def __call__(self, b, **kwargs):
        """Alias for `_compute_one_hot`."""
        return self._compute_one_hot(b, **kwargs)

In [15]:
device = 'cpu'
# Each key lists its candidate values in order; a sample's position in each list
# becomes one digit of the mixed-radix index produced by `compute_index`.
codes = {
    'exposure-time': torch.tensor([10, 50, 100], dtype=torch.float32, device=device),
    'optical-setup': torch.tensor([0, 1], dtype=torch.float32, device=device),
    'camera': torch.tensor([0, 1], dtype=torch.float32, device=device),
}
kwargs = {
    'exposure-time': torch.tensor([100], dtype=torch.float32, device=device),
    'optical-setup': torch.tensor([1], dtype=torch.float32, device=device),
    'camera': torch.tensor([0], dtype=torch.float32, device=device),
}

# Pass `device` explicitly so the index tensor lives on the same device as the inputs.
fn1hot = compute_index(codes, device=device)

idx = fn1hot(1, **kwargs)
print('index from one hot encoding: ', idx)
# positions (2, 1, 0) in radices (3, 2, 2) -> (2*2 + 1)*2 + 0 = 10
assert idx.tolist() == [10.0]


index from one hot encoding:  tensor([10.])


In [16]:
#| hide
# nbdev: regenerate the Python modules from the `#| export` cells (this notebook's go to the module named by `#| default_exp utils`).
import nbdev; nbdev.nbdev_export()