In [None]:
# export
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from fastcore.all import *

In [None]:
# default_exp torch_utils

# Torch Utils
> Useful utilities that extend PyTorch functionality

In [None]:
# export
class InfiniteDl():
    """Endless wrapper around a re-iterable object `dl`.

    Each call to `next()` returns the next item of `dl`; when `dl` is
    exhausted a fresh iterator is created and iteration restarts from the
    beginning. The wrapper also implements the iterator protocol, so it can
    be used with the built-in `next()` or in a `for` loop (which then never
    terminates on its own).
    """
    def __init__(self, dl):
        self.dl = dl
        self.it = iter(self.dl)
    def next(self):
        "Return the next item, restarting `dl` once it is exhausted."
        try:
            # Built-in `next()` works for every iterator; a `.next()` method
            # is not part of the Python 3 iterator protocol.
            return next(self.it)
        except StopIteration:
            self.it = iter(self.dl)
            # If `dl` is empty this raises StopIteration again, which is
            # propagated to the caller instead of looping forever.
            return next(self.it)
    def __iter__(self):
        return self
    def __next__(self):
        return self.next()

In [None]:
# export
def isin(t, ids):
    ''' Returns a bool tensor shaped like `t` that is True wherever the value of `t` is one of `ids`.

    t: tensor of any shape
    ids: values to look for (list, tuple or tensor); moved to `t`'s device before comparing
    '''
    # `as_tensor` accepts both python sequences and existing tensors without
    # copy-constructing a tensor from a tensor.
    ids = torch.as_tensor(ids, device=t.device)
    # Compare every element of `t` against every id, then reduce over the id axis.
    return (t[..., None] == ids).any(-1)

In [None]:
# Every position of `t` holding 0 or 1 must be flagged, everything else must not.
t = torch.tensor([[12, 11, 0, 0],
                  [9, 1, 5, 0]])
mask = isin(t, [0, 1])
expected_mask = torch.tensor([[False, False, True, True],
                              [False, True, False, True]])
test_eq(mask, expected_mask)

In [None]:
# export
def get_src_mask(cap_len, max_seq_len, device='cpu'):
    ''' Build a (bs, max_seq_len) bool mask that is True at positions >= the caption length.

    cap_len: (bs,), max_seq_len: int
    '''
    positions = torch.arange(max_seq_len, device=device).unsqueeze(0)  # (1, max_seq_len)
    lengths = cap_len.unsqueeze(1)                                     # (bs, 1)
    return positions >= lengths

In [None]:
# Captions of length 2, 1 and 3 in a width-5 mask: positions at or past each length are True.
cap_len = torch.tensor([2, 1, 3])
max_seq_len = 5
src_mask = get_src_mask(cap_len, max_seq_len)
expected_src_mask = torch.tensor([[0, 0, 1, 1, 1],
                                  [0, 1, 1, 1, 1],
                                  [0, 0, 0, 1, 1]]).bool()
test_eq(src_mask, expected_src_mask)

In [None]:
# export
class Normalizer():
    " normalize input image to -1 ~ 1 "
    def __init__(self, device='cpu'):
        # Per-channel statistics, shaped (1, 3, 1, 1) so they broadcast over (bs, 3, H, W).
        self.mean = torch.tensor([0.5, 0.5, 0.5], device=device)[None, ..., None, None]
        self.std = torch.tensor([0.5, 0.5, 0.5], device=device)[None, ..., None, None]
    def set_device(self, device='cpu'):
        "Move the stored mean/std to `device`."
        # `Tensor.to` is not in-place: the result has to be assigned back.
        self.mean = self.mean.to(device)
        self.std = self.std.to(device)
    def encode(self, x):
        "x: (bs, 3, _, _) with values in 0~255; returns a float tensor in -1~1"
        return (x.float()/255-self.mean) / self.std
    def decode(self, x):
        "Inverse of `encode`: maps -1~1 back to 0~255 and returns a long tensor."
        x = x*self.std + self.mean
        # Clamp to 0~1 before scaling so the result stays in 0~255; `.long()` truncates.
        return (x.clamp(0., 1.)*255).long()

In [None]:
normalizer = Normalizer()
# `torch.randint` excludes `high`, so 256 is needed for the value 255 to be sampled.
img = torch.randint(0, 256, (2, 3, 16, 16))
img_encoded = normalizer.encode(img)
img_decoded = normalizer.decode(img_encoded)
# decode truncates with `.long()`, so the round trip is only checked up to a small tolerance
test_close(img, img_decoded, eps=2)

# a position scores 2 only when it satisfies both the lower and the upper bound
both_bounds_hold = torch.ones_like(img) * 2
# test encoded img is in range -1~1
test_eq((img_encoded>=-1).long() + (img_encoded<=1).long(), both_bounds_hold)
# test decoded img is in range 0~255
test_eq((img_decoded>=0).long() + (img_decoded<=255).long(), both_bounds_hold)

In [None]:
# export
def to_device(tensors, device='cpu'):
    "Return a new list holding every item of `tensors` after `.to(device)`."
    moved = []
    for tensor in tensors:
        moved.append(tensor.to(device))
    return moved
def detach(tensors, is_to_cpu=False):
    "Return a list of detached tensors; with `is_to_cpu` each one is moved to the CPU first."
    if is_to_cpu:
        return [tensor.cpu().detach() for tensor in tensors]
    return [tensor.detach() for tensor in tensors]
def is_models_equal(model_1, model_2):
    """Return True when both models hold the same state-dict tensors, else False.

    The state dicts are compared entry by entry in iteration order. The first
    mismatch is reported with `print` and ends the comparison. State dicts
    with a different number of entries are never equal.
    """
    state_1, state_2 = model_1.state_dict(), model_2.state_dict()
    # `zip` stops at the shorter sequence, so without this check a model whose
    # state dict is a prefix of the other one would be reported as equal.
    if len(state_1) != len(state_2):
        print('Mismatch found: state dicts have', len(state_1), 'and', len(state_2), 'entries')
        return False
    for (name_1, value_1), (name_2, value_2) in zip(state_1.items(), state_2.items()):
        if torch.equal(value_1, value_2):
            continue
        if name_1 == name_2:
            print('Mismatch found at', name_1)
        else:
            # Tensors differ and the entries do not even share a name.
            print('Oops somethings wrong')
        return False
    return True
class MultiWrapper(nn.Module):
    """Wrap a single-input `layer` so it can sit in a pipeline that passes several values.

    `forward(x, *others)` applies `layer` to `x` only. With `n_returns == 1`
    just that result is returned; otherwise a tuple of the result followed by
    the first `n_returns - 1` items of `others` is returned.
    """
    def __init__(self, layer, n_returns=1):
        super().__init__()
        assert n_returns>=1
        self.layer, self.n_returns = layer, n_returns
    def forward(self, x, *others):
        out = self.layer(x)
        if self.n_returns == 1:
            return out
        # Forward the leading extra inputs untouched after the layer output.
        passthrough = others[:self.n_returns - 1]
        return (out,) + passthrough
class MultiSequential(nn.Sequential):
    """`nn.Sequential` variant whose modules may take and return several values.

    Whenever the value being passed along is exactly a `tuple`, it is unpacked
    into positional arguments for the next module; any other value is passed
    as a single argument.
    """
    def forward(self, *inputs):
        for layer in self:
            # Exact type check on purpose: only plain tuples are unpacked.
            inputs = layer(*inputs) if type(inputs) is tuple else layer(inputs)
        return inputs
class IdentityModule(nn.Module):
    "Module that returns its input unchanged."
    def forward(self, x):
        return x

In [None]:
# export
# Shared zero-mean Gaussian from which noise tensors are sampled.
# Its standard deviation is exp(-1/pi), roughly 0.727.
# `np` must be imported explicitly at the top of the notebook; do not rely on
# it leaking in through a star import.
noise_gen = torch.distributions.normal.Normal(
    loc=0,
    scale=torch.exp(torch.tensor(-1 / np.pi)),
)

In [None]:
# Sampling with an explicit shape must give back a tensor of exactly that shape.
sample_shape = (2, 100)
noise = noise_gen.sample(sample_shape)
test_eq(noise.shape, sample_shape)

## Export -

In [None]:
# hide
# Run nbdev's notebook-to-script export; the recorded output below lists
# each notebook it converted.
from nbdev.export import notebook2script
notebook2script()

Converted 00_torch_utils.ipynb.
Converted 02a_data_anime_heads.ipynb.
Converted 02b_data_birds.ipynb.
Converted 03a_model.ipynb.
Converted 04a_trainer_DAMSM.ipynb.
Converted 04b_trainer.ipynb.
Converted 05a_inference_anime_heads.ipynb.
Converted 05b_inference_birds.ipynb.
Converted index.ipynb.
