# SLayer Tutorial

This tutorial gives you a brief insight into the functionality offered by the `nn.SLayer` 
module. It assumes familiarity with standard `PyTorch` functionality. 


```python
from shared_code import check_chofer_torchex_availability
check_chofer_torchex_availability()
```

```python
from chofer_torchex.nn import SLayer

# create an instance with 3 structure elements over R^2
sl = SLayer(3, 2)
```

`nn.SLayer` is a `torch.nn.Module` ... 

```python
import torch
isinstance(sl, torch.nn.Module)
```

```
True
```

... so we get everything inherited from `torch.nn.Module` for free, e.g., iterating over its parameters:

```python
for p in sl.parameters():
    print(p)
```

```
Parameter containing:
 0.2816  0.4501
 0.9266  0.2779
 0.6380  0.5584
[torch.FloatTensor of size 3x2]

Parameter containing:
 3  3
 3  3
 3  3
[torch.FloatTensor of size 3x2]
```

The module has **two** parameters: 
1. `centers`: controls the centers of the structure elements. 
2. `sharpness`: controls how tightly the Gaussians are concentrated around their centers. The higher the value, the tighter. 

Both can be initialized using the `centers_init` and `sharpness_init` keyword arguments, respectively.
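Judging from the numbers printed later in this tutorial (this is our reading of those outputs, not a formula quoted from the `chofer_torchex` documentation), structure element $j$ with center $\mu_j$ = `sl.centers[j]` and sharpness $\sigma_j$ = `sl.sharpness[j]` appears to evaluate a multiset $M \subset \mathbb{R}^d$ as

$$
s_j(M) = \sum_{x \in M} \exp\Big(-\sum_{k=1}^{d} \sigma_{j,k}^2 \, (x_k - \mu_{j,k})^2\Big).
$$

For example, with $M = \{(0, 0)\}$, $\mu_j = (0.5, 0.5)$ and $\sigma_j = (2, 2)$ this gives $\exp\big(-(4 \cdot 0.25 + 4 \cdot 0.25)\big) = e^{-2} \approx 0.1353$, which agrees with `output[0, 1]` printed below.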

```python
# here is an initialization example
centers_init = torch.Tensor([[0,0], [0.5, 0.5], [1,1]])
sharpness_init = torch.Tensor([[1,1], [2,2], [3,3]])

sl = SLayer(3, 2, 
            centers_init=centers_init, 
            sharpness_init=sharpness_init)

print(sl.centers)
print(sl.sharpness)
```

```
Parameter containing:
 0.0000  0.0000
 0.5000  0.5000
 1.0000  1.0000
[torch.FloatTensor of size 3x2]

Parameter containing:
 1  1
 2  2
 3  3
[torch.FloatTensor of size 3x2]
```



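The simplest input form for `nn.SLayer` is a `list` of `torch.Tensor` objects, which is treated as a *batch*: each tensor holds one multiset of points (one point per row), and different multisets may contain different numbers of points. 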

```python
# As an example, we create a batch of multisets
mset_1 = [[0, 0]]
mset_2 = [[0, 0], [0, 0]]
mset_3 = [[1, 1], [0, 0]]
mset_4 = [[0, 0], [1, 1]]
batch = [mset_1, mset_2, mset_3, mset_4]
batch = [torch.Tensor(x) for x in batch]
output = sl(batch)
print(output.size())
```

```
torch.Size([4, 3])
```


As we can see, the output has size `(4, 3)`, since
we have a batch of size `4` and `3` structure elements. 

In other words, 
`output[i, j]` is the evaluation of structure element `j` on the multiset `batch[i]` (e.g., `output[0, :]` belongs to `mset_1`).

Let's take a look: 

```python
print(output)
```

```
Variable containing:
 1.0000e+00  1.3534e-01  1.5230e-08
 2.0000e+00  2.7067e-01  3.0460e-08
 1.1353e+00  2.7067e-01  1.0000e+00
 1.1353e+00  2.7067e-01  1.0000e+00
[torch.FloatTensor of size 4x3]
```
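To see where these numbers come from, here is a minimal pure-Python sketch that evaluates every structure element on every multiset. It assumes each element sums a Gaussian weight exp(-Σ_k sharpness_k² · (x_k − center_k)²) over the points of a multiset; that form is inferred from the printed output, and `slayer_reference` is a hypothetical helper written for illustration, not part of `chofer_torchex`.

```python
import math


def slayer_reference(batch, centers, sharpness):
    """Hypothetical pure-Python reference (not part of chofer_torchex).

    Returns a len(batch) x len(centers) nested list whose entry [i][j] is
        sum over x in batch[i] of
        exp(-sum_k sharpness[j][k]**2 * (x[k] - centers[j][k])**2)
    (an assumed form, inferred from the tutorial's printed output).
    """
    output = []
    for mset in batch:
        row = []
        for center, sharp in zip(centers, sharpness):
            total = 0.0
            for x in mset:
                exponent = sum(s ** 2 * (xk - ck) ** 2 for xk, ck, s in zip(x, center, sharp))
                total += math.exp(-exponent)
            row.append(total)
        output.append(row)
    return output


# same values as centers_init / sharpness_init and the batch used in the tutorial
centers = [[0, 0], [0.5, 0.5], [1, 1]]
sharpness = [[1, 1], [2, 2], [3, 3]]
batch = [[[0, 0]], [[0, 0], [0, 0]], [[1, 1], [0, 0]], [[0, 0], [1, 1]]]

for row in slayer_reference(batch, centers, sharpness):
    print("  ".join("%.4e" % v for v in row))
```

Printed in the same format, the rows agree with the `SLayer` output to the shown precision (e.g. `1.1353e+00  2.7067e-01  1.0000e+00` for `mset_3`), and the last two rows coincide because summing over the points of a multiset does not depend on their order.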



**We observe the following:**

1. The j-th structure element approximates the multiplicity function of the given input at the point `sl.centers[j]`. E.g., the output for `mset_1`, `output[0, :]`, is approximately `(1, 0, 0)`. 
2. `sl.sharpness[j]` controls how much points that do not lie exactly on `sl.centers[j]` contribute, depending on their distance to `sl.centers[j]`. 
3. The input is interpreted as a (multi)set, i.e., the layer is permutation invariant: `mset_3` and `mset_4` do not differ as multisets, and indeed `output[2, :] == output[3, :]`. 
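Observation 2 can be made concrete with a small sketch. Assuming each point x contributes exp(-Σ_k sharpness_k² · (x_k − center_k)²) to a structure element (a form inferred from the printed outputs, not quoted from the library), the hypothetical helper `point_contribution` below shows how quickly the contribution of the point (0, 0) to an element centered at (0.5, 0.5) drops as the sharpness grows.

```python
import math


def point_contribution(x, center, sharpness):
    """Gaussian weight of point x for one structure element (assumed form)."""
    exponent = sum(s ** 2 * (xk - ck) ** 2 for xk, ck, s in zip(x, center, sharpness))
    return math.exp(-exponent)


x = (0.0, 0.0)
center = (0.5, 0.5)
for s in (1, 2, 10, 20):
    # same sharpness in both coordinates
    w = point_contribution(x, center, (s, s))
    print(f"sharpness={s:>2}: contribution={w:.3e}")
```

At sharpness 20, which the middle structure element has after scaling `sharpness_init` by 10, the weight is exp(-200) ≈ 1.4e-87. That is far below the smallest positive `float32` value (about 1.4e-45), so in a `torch.FloatTensor` it underflows to exactly 0.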

This may become clearer if we increase the sharpness of our structure elements a "little" (by a factor of 10):

```python
sl = SLayer(3, 2, 
            centers_init=centers_init, 
            sharpness_init=10*sharpness_init)
print(sl(batch))
```

```
Variable containing:
 1  0  0
 2  0  0
 1  0  1
 1  0  1
[torch.FloatTensor of size 4x3]
```



**Performance tip:** 
`nn.SLayer` has a static method called `prepare_batch`. It does a lot of preprocessing that involves Python-level looping and is therefore faster on the CPU than on the GPU. 
You can use this method in your training environment to separate batch preparation from 
the actual computation of the output, since `nn.SLayer` recognizes an already prepared batch when it receives one:
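The tutorial does not say what `prepare_batch` computes internally. One common way to turn multisets of different sizes into fixed-size arrays is to pad every multiset with dummy points and keep a mask so that the dummies contribute nothing; the sketch below illustrates that general idea in pure Python. `pad_batch` and `masked_gaussian_sum` are hypothetical helpers, and the Gaussian form is inferred from the printed outputs, so this does not describe `prepare_batch`'s actual implementation or return value.

```python
import math


def pad_batch(batch, point_dim):
    """Hypothetical helper (NOT chofer_torchex's prepare_batch): pad a list of
    variable-size multisets to a common number of points.

    Returns (padded, mask): padded[i] holds max_points points, and
    mask[i][p] is 1.0 for a real point and 0.0 for a dummy (padding) point.
    """
    max_points = max((len(mset) for mset in batch), default=0)
    padded, mask = [], []
    for mset in batch:
        n_dummy = max_points - len(mset)
        # copy the real points, then append independent all-zero dummy points
        padded.append([list(p) for p in mset] + [[0.0] * point_dim for _ in range(n_dummy)])
        mask.append([1.0] * len(mset) + [0.0] * n_dummy)
    return padded, mask


def masked_gaussian_sum(points, mask, center, sharpness):
    """Sum of (assumed) Gaussian weights over points, ignoring masked-out entries."""
    total = 0.0
    for x, m in zip(points, mask):
        exponent = sum(s ** 2 * (xk - ck) ** 2 for xk, ck, s in zip(x, center, sharpness))
        total += m * math.exp(-exponent)
    return total


batch = [[[0, 0]], [[0, 0], [0, 0]], [[1, 1], [0, 0]], [[0, 0], [1, 1]]]
padded, mask = pad_batch(batch, point_dim=2)
print(mask)  # [[1.0, 0.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]

# mset_1 is padded with a dummy point at the origin; without the mask that
# dummy would add exp(0) = 1 to an element centered at (0, 0).
print(masked_gaussian_sum(padded[0], [1.0, 1.0], (0, 0), (1, 1)))  # 2.0
print(masked_gaussian_sum(padded[0], mask[0], (0, 0), (1, 1)))     # 1.0
```

The per-multiset loop is the kind of bookkeeping that is cheap on the CPU; once the batch is padded, the evaluation itself can be written as dense tensor operations.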

```python
batch = SLayer.prepare_batch(batch, point_dim=2)
sl(batch)
```

```
Variable containing:
 1  0  0
 2  0  0
 1  0  1
 1  0  1
[torch.FloatTensor of size 4x3]
```

**`nn.SLayer` is an input layer ONLY!**
This means you can only use `nn.SLayer` as the input layer of your model, since you cannot differentiate with respect to its input. This is also the reason why it accepts `Tensors` and *not* `Variables` as input.

An example of a model is shown below:

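```python
import torch
from chofer_torchex.nn import SLayer


class MyModel(torch.nn.Module):
    def __init__(self):
        # required so that torch.nn.Module can register submodules and parameters
        super().__init__()
        # input layer: 50 structure elements over R^2 -> 50 features per sample
        self.slayer = SLayer(50, 2)
        # map the 50 SLayer features to 10 outputs
        self.linear = torch.nn.Linear(50, 10)

    def forward(self, input):
        x = self.slayer(input)
        x = self.linear(x)
        return x
```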