In [None]:
# default_exp layers

In [None]:
#export
from local.imports import *
from local.test import *
from local.core import *

# Layers
> Custom fastai layers and basic functions to grab them.

## Basic manipulations and resize

In [None]:
# export
class Lambda(nn.Module):
    "Wrap a plain callable `func` so it can be used as a pytorch layer"
    def __init__(self, func):
        super().__init__()
        self.func = func  # stored as-is and invoked on every forward pass

    def forward(self, x):
        # The layer has no state of its own: just delegate to the wrapped callable.
        result = self.func(x)
        return result

> Warning: In the tests below, we use lambda functions for convenience, but you shouldn't do this when building real modules, as it would produce models that can't be pickled (so you won't be able to save/export them).

In [None]:
# Wrapping `x+2` in a `Lambda` layer must give the same result as applying it directly.
tst = Lambda(lambda x:x+2)
x = torch.randn(10,20)  # also reused by the `View` and `ResizeBatch` test cells further down
test_eq(tst(x), x+2)

In [None]:
# export
class View(nn.Module):
    "Layer that returns its input viewed with shape `size`"
    def __init__(self, *size):
        super().__init__()
        self.size = size  # tuple of target dims, handed to `Tensor.view` unchanged

    def forward(self, x):
        target_shape = self.size
        return x.view(target_shape)

In [None]:
# `x` is still the 10x20 tensor from the `Lambda` test cell: its 200 elements are viewed as 10x5x4.
tst = View(10,5,4)
test_eq(tst(x).shape, [10,5,4])

In [None]:
# export
class ResizeBatch(nn.Module):
    "View `x` with shape `size` after the first dimension, which keeps its current (batch) size"
    def __init__(self, *size):
        super().__init__()
        self.size = size  # shape of everything after dim 0

    def forward(self, x):
        # Dim 0 is read from the input, so the same layer works for any batch size.
        batch_dim = x.shape[0]
        return x.view((batch_dim,) + self.size)

In [None]:
# Only the trailing shape (5,4) is given; the leading size (10) is read from `x`, the 10x20 tensor defined earlier.
tst = ResizeBatch(5,4)
test_eq(tst(x).shape, [10,5,4])

In [None]:
# export
class Flatten(nn.Module):
    "Collapse `x` to shape (x.size(0), -1), or to a rank-1 tensor if `full`; often used at the end of a model"
    def __init__(self, full=False):
        super().__init__()
        self.full = full

    def forward(self, x):
        if self.full:
            # Merge every dimension, the first one included.
            return x.view(-1)
        # Keep dim 0 and merge all remaining dimensions into one.
        return x.view(x.size(0), -1)

In [None]:
# Default: the first dim is kept and the others are merged (5*4 = 20).
tst = Flatten()
x = torch.randn(10,5,4)
test_eq(tst(x).shape, [10,20])
# `full=True`: every dim is merged into a single one (10*5*4 = 200).
tst = Flatten(full=True)
test_eq(tst(x).shape, [200])

In [None]:
# export
class PoolFlatten(nn.Sequential):
    "Combine `nn.AdaptiveAvgPool2d` and `Flatten`: average-pool to output size 1, then flatten everything after dim 0"
    def __init__(self):
        # Pooling to size 1 leaves two trailing dims of size 1, which the default `Flatten()` merges away.
        super().__init__(nn.AdaptiveAvgPool2d(1), Flatten())

In [None]:
tst = PoolFlatten()
x = torch.randn(10,5,4,4)
# The two trailing 4x4 dims are pooled away, leaving one value per entry of the first two dims...
test_eq(tst(x).shape, [10,5])
# ...and that value is the mean over dims 2 and 3.
test_eq(tst(x), x.mean(dim=[2,3]))

## BatchNorm