# Utils

> Utility functions


In [1]:
#| default_exp utils

In [2]:
#| hide
from nbdev.showdoc import *

In [3]:
#| export
from fastai.vision.all import torch, nn
import numpy as np
from torchmetrics.functional.image import structural_similarity_index_measure as structural_similarity
from torchmetrics.functional.image import peak_signal_noise_ratio
import torch 
import torch.nn.functional as F

In [4]:
#| export
def attributesFromDict(d):
    self = d.pop('self')
    for n, v in d.items():
        setattr(self, n, v)

In [5]:
#| export

def batch_PSNR(img, imclean, data_range):
    Img = img.data.cpu().numpy().astype(np.float32)
    Iclean = imclean.data.cpu().numpy().astype(np.float32)
    PSNR = 0
    for i in range(Img.shape[0]):
        PSNR += peak_signal_noise_ratio(Iclean[i,:,:,:], Img[i,:,:,:], data_range=data_range)
    return (PSNR/Img.shape[0])

In [6]:
#| export

class compute_index():
    def __init__(self, codes, device='cpu') -> None:
        attributesFromDict(locals( ))
      
    def _compute_index(self, b, **kwargs):
        idx = torch.zeros([b], device=self.device, dtype=torch.float32)
        for key, value in self.codes.items():
            idx = idx * len(value)
            for i, v in enumerate(value):
                idx += torch.where(kwargs[key] == v, i, 0.0)

        return idx
    
    def __call__(self, b, **kwargs):
        return self._compute_index(b, **kwargs)

In [7]:
device = 'cpu'
codes = {
        'exposure-time': torch.tensor([10, 50, 100], dtype=torch.float32, device=device),
        'optical-setup': torch.tensor([0, 1], dtype=torch.float32).to(device),
        'camera': torch.tensor([0, 1], dtype=torch.float32).to(device)
    }
kwargs = {
        'exposure-time': torch.tensor([100], dtype=torch.float32).to(device),
        'optical-setup': torch.tensor([1], dtype=torch.float32).to(device),
        'camera': torch.tensor([0], dtype=torch.float32).to(device)
    }

fn = compute_index(codes)

print('index: ', fn(1, **kwargs))


index:  tensor([10.])


In [8]:
#| export

class compute_one_hot():
    def __init__(self, codes, device='cpu') -> None:
        attributesFromDict(locals( ))
      
    def _compute_one_hot(self, b, **kwargs):
        embedding = torch.tensor([])
        for key, value in self.codes.items():
            idx = torch.zeros([b], device=self.device, dtype=torch.float32)
            for i, v in enumerate(value):
                idx += torch.where(kwargs[key] == v, i, 0.0)
            idx_one_hot = F.one_hot(idx.to(torch.int64), num_classes=value.shape[0]).to(torch.float32)
            print(key, ': ', idx_one_hot)
            embedding = torch.cat((embedding, idx_one_hot), dim=1)

        return embedding
    
    def __call__(self, b, **kwargs):
        return self._compute_one_hot(b, **kwargs)

In [9]:
fn1hot = compute_one_hot(codes)

print('one hot encoding: ', fn1hot(1, **kwargs))

exposure-time :  tensor([[0., 0., 1.]])
optical-setup :  tensor([[0., 1.]])
camera :  tensor([[1., 0.]])
one hot encoding:  tensor([[0., 0., 1., 0., 1., 1., 0.]])


In [10]:
#| hide
import nbdev; nbdev.nbdev_export()