In [26]:
import distrib_zoo as dz
import torch
from layers import *
from func import *
import torch
from torch.nn import functional as F
import numpy as np
import math

In [2]:
# Shared layer used by the cells below. SimpleLinear is not imported by name; it
# presumably comes from one of the star imports (`layers` / `func`) — TODO confirm.
# The stored outputs show it mapping 5 input columns to 3 output columns, with an
# ELU as its final op (grad_fn=<EluBackward>).
f = SimpleLinear(5,3)

In [9]:
def mlp_fn(mlp_func, in_feats, out_feats, chunk_sizes):
    """Wrap ``mlp_func`` into a conditional head that returns ``(mu, var)``.

    The returned closure concatenates ``x`` and ``cond`` along the last
    dimension, feeds the result through ``mlp_func``, and splits the output
    along the last dimension with ``torch.split(..., chunk_sizes, -1)``.
    The first chunk is returned unchanged as ``mu``. The second chunk is
    mapped through ``softplus(.) / log(2)``, which is non-negative and equals
    exactly 1 where the pre-activation is 0.

    Args:
        mlp_func: callable applied to the concatenated ``[x | cond]`` tensor.
        in_feats: NOTE(review): not used by this function. Kept for interface
            compatibility; presumably the expected input width of ``mlp_func``.
        out_feats: NOTE(review): also unused; presumably the output width.
        chunk_sizes: passed straight to ``torch.split``. It must yield exactly
            two chunks, otherwise the ``mu, var`` unpacking raises ``ValueError``.

    Returns:
        A function ``mlp(x, cond) -> (mu, var)``.
    """
    log2 = math.log(2)

    def mlp(x, cond):
        joint = torch.cat((x, cond), dim=-1)
        mu, raw_var = torch.split(mlp_func(joint), chunk_sizes, -1)
        # softplus(0) == log(2), so dividing by log(2) maps a zero
        # pre-activation to exactly 1.
        return mu, F.softplus(raw_var) / log2

    return mlp

In [10]:
# Build the demo inputs in this cell: `a`/`b` are only assigned further down
# (In [13]), so relying on them breaks Restart & Run All. A local seeded
# generator keeps the result reproducible without touching the global RNG.
demo_gen = torch.Generator().manual_seed(0)
x_demo = torch.rand([7, 4], generator=demo_gen)     # 7 rows, 4 feature columns
cond_demo = torch.rand([7, 1], generator=demo_gen)  # 1 conditioning column (4 + 1 = 5 = in_feats)
mlp_fn(f, 5, 3, [1, 2])(x_demo, cond_demo)

(tensor([[0.7437],
         [0.3418],
         [0.6956],
         [0.5973],
         [0.4781],
         [0.8379],
         [0.4216]], grad_fn=<SplitWithSizesBackward>),
 tensor([[0.8899, 1.2222],
         [0.9214, 1.2993],
         [0.9226, 1.2301],
         [0.8311, 1.3665],
         [0.8873, 1.4396],
         [0.7287, 1.5651],
         [0.7332, 1.7231]], grad_fn=<DivBackward0>))

In [13]:
# Random demo inputs: 7 rows with 4 feature columns (`a`) plus 1 conditioning
# column (`b`). Seed first so every cell that uses them is reproducible.
torch.manual_seed(0)
a = torch.rand([7, 4])
b = torch.rand([7, 1])

In [8]:
# Apply `f` to the 5-column input formed by joining `a` (7x4) and `b` (7x1) column-wise.
f(torch.cat([a, b], -1))

tensor([[ 0.7437, -0.1589,  0.2875],
        [ 0.3418, -0.1120,  0.3792],
        [ 0.6956, -0.1104,  0.2970],
        [ 0.5973, -0.2496,  0.4564],
        [ 0.4781, -0.1629,  0.5379],
        [ 0.8379, -0.4199,  0.6725],
        [ 0.4216, -0.4120,  0.8335]], grad_fn=<EluBackward>)

In [39]:
# Split a 3x7 tensor of ones into column chunks of width 2 and 5 (the sizes must
# sum to the width, 7). At this point in the notebook `a` is the 7x4 tensor from
# In [13], for which [2, 5] raises, so build the input the stored output used.
ones_3x7 = torch.ones([3, 7])
torch.split(ones_3x7, [2, 5], -1)

(tensor([[1., 1.],
         [1., 1.],
         [1., 1.]]), tensor([[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]))

[0;31mSignature:[0m [0mtorch[0m[0;34m.[0m[0msplit[0m[0;34m([0m[0mtensor[0m[0;34m,[0m [0msplit_size_or_sections[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Splits the tensor into chunks.

If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
be split into equally sized chunks (if possible). Last chunk will be smaller if
the tensor size along the given dimension :attr:`dim` is not divisible by
:attr:`split_size`.

If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
to :attr:`split_size_or_sections`.

Arguments:
    tensor (Tensor): tensor to split.
    split_size_or_sections (int) or (list(int)): size of a single chunk or
        list of sizes for each chunk
    dim (int): dimension along which to split the tensor.
[0;31mFile:[0m      /opt/conda/lib/python3.7/site-packages/t

In [46]:
# An *integer* split size gives chunks of that width plus a smaller remainder:
# 7 // 2 == 3, so a width-7 tensor yields widths 3, 3, 1 (three chunks, not two).
# Built explicitly because `a` here is the 7x4 tensor from In [13], which would
# not reproduce the stored output.
ones_3x7 = torch.ones([3, 7])
torch.split(ones_3x7, ones_3x7.size(1) // 2, -1)

(tensor([[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]), tensor([[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]), tensor([[1.],
         [1.],
         [1.]]))

In [49]:
# Shape of `a`. NOTE(review): on a fresh run `a` here is the 7x4 tensor from
# In [13], so this will not reproduce the torch.Size([3, 7]) stored below.
a.size()

torch.Size([3, 7])

In [48]:
from torch.nn import functional as F

In [50]:
?F.linear

[0;31mSignature:[0m [0mF[0m[0;34m.[0m[0mlinear[0m[0;34m([0m[0minput[0m[0;34m,[0m [0mweight[0m[0;34m,[0m [0mbias[0m[0;34m=[0m[0;32mNone[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

Shape:

    - Input: :math:`(N, *, in\_features)` where `*` means any number of
      additional dimensions
    - Weight: :math:`(out\_features, in\_features)`
    - Bias: :math:`(out\_features)`
    - Output: :math:`(N, *, out\_features)`
[0;31mFile:[0m      /opt/conda/lib/python3.7/site-packages/torch/nn/functional.py
[0;31mType:[0m      function


In [51]:
# NOTE(review): this rebinds `b`, which earlier held the (7, 1) conditioning
# tensor from In [13], to an InvLinear module; any later cell expecting the
# tensor would break. Its flow output is 7 columns wide, so presumably it acts
# on 7-feature inputs — TODO confirm what the second argument (5) controls.
b = dz.InvLinear(7, 5)

In [53]:
# Apply the InvLinear flow to a 3x7 tensor of ones, the input behind the stored
# output (three identical rows). NOTE(review): `a` at this point is the 7x4
# tensor from In [13]; presumably InvLinear(7, ...) needs 7 features, so pass a
# 7-wide input explicitly.
ones_3x7 = torch.ones([3, 7])
b.flow(ones_3x7)

tensor([[-0.5545, -1.5532, -0.2023,  1.3605,  0.9206, -0.9471,  0.8023],
        [-0.5545, -1.5532, -0.2023,  1.3605,  0.9206, -0.9471,  0.8023],
        [-0.5545, -1.5532, -0.2023,  1.3605,  0.9206, -0.9471,  0.8023]],
       grad_fn=<MmBackward>)

In [58]:
# F.linear computes x @ W.T; W has shape (out_features=5, in_features=7), so the
# input needs 7 columns. `a` here is the 7x4 tensor from In [13] (4 != 7 raises),
# so use an explicit 3x7 tensor of ones, matching the stored output.
ones_3x7 = torch.ones([3, 7])
F.linear(ones_3x7, torch.rand([5, 7]))

tensor([[3.1820, 4.1477, 2.7423, 3.8388, 3.6860],
        [3.1820, 4.1477, 2.7423, 3.8388, 3.6860],
        [3.1820, 4.1477, 2.7423, 3.8388, 3.6860]])

In [56]:
# Shape of `a` again. NOTE(review): on a fresh run this shows (7, 4) from In [13],
# not the stored torch.Size([3, 7]).
a.size()

torch.Size([3, 7])

In [3]:
# Display `a`. NOTE(review): the stored 3x7 tensor of ones does not match the 7x4
# random `a` assigned in In [13] above; presumably it came from a cell that was
# run earlier and has since been removed, so Restart & Run All shows different output.
a

tensor([[1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1.]])

In [4]:
# Row sums of `a` (reduce over the last dimension).
a.sum(dim=-1)

tensor([7., 7., 7.])

In [5]:
# Two copies of `a` side by side along the last dimension (width doubles).
torch.cat([a] * 2, dim=-1)

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])

tensor([0., 0., 0., 0., 0.])

In [14]:
# ReLU via the module API. Use the fully qualified `torch.nn`: `nn` is never
# imported explicitly above and only resolves if one of the star imports leaks it.
torch.nn.ReLU()(a)

tensor([[0.6776, 0.3710, 0.8945, 0.8784],
        [0.4639, 0.4442, 0.4395, 0.0650],
        [0.8493, 0.2488, 0.9899, 0.1586],
        [0.7743, 0.2697, 0.1206, 0.9328],
        [0.9514, 0.9723, 0.8638, 0.8053],
        [0.4224, 0.0471, 0.0620, 0.1813],
        [0.8320, 0.7917, 0.9346, 0.3507]])

In [15]:
# Functional form of the ReLU in the previous cell; the stored output matches In [14].
torch.relu(a)

tensor([[0.6776, 0.3710, 0.8945, 0.8784],
        [0.4639, 0.4442, 0.4395, 0.0650],
        [0.8493, 0.2488, 0.9899, 0.1586],
        [0.7743, 0.2697, 0.1206, 0.9328],
        [0.9514, 0.9723, 0.8638, 0.8053],
        [0.4224, 0.0471, 0.0620, 0.1813],
        [0.8320, 0.7917, 0.9346, 0.3507]])

In [22]:
from torch_scatter import *

In [23]:
# Segment id for each of the 7 rows of `a`: rows 0-1 -> 0, rows 2-4 -> 1, rows 5-6 -> 2.
seg_ids = [0] * 2 + [1] * 3 + [2] * 2

In [24]:
# Sum the ReLU'd rows of `a` within each segment along dim 0, giving one output row per segment id.
scatter_add(src=F.relu(a), index=torch.tensor(seg_ids), dim=0)

tensor([[1.1415, 0.8152, 1.3340, 0.9433],
        [2.5749, 1.4907, 1.9743, 1.8967],
        [1.2544, 0.8388, 0.9966, 0.5320]])

In [19]:
?scatter_add

[0;31mSignature:[0m [0mscatter_add[0m[0;34m([0m[0msrc[0m[0;34m,[0m [0mindex[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;34m-[0m[0;36m1[0m[0;34m,[0m [0mout[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m [0mdim_size[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m [0mfill_value[0m[0;34m=[0m[0;36m0[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
|

.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
        master/docs/source/_figures/add.svg?sanitize=true
    :align: center
    :width: 400px

|

Sums all values from the :attr:`src` tensor into :attr:`out` at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
each value in :attr:`src`, its output index is specified by its index in
:attr:`input` for dimensions outside of :attr:`dim` and by the
corresponding value in :attr:`index` for dimension :attr:`dim`. If
multiple indices reference the same location, their **contributions add**.

Formally, if :attr:`src` a

In [21]:
?torch.index_add

[0;31mDocstring:[0m <no docstring>
[0;31mType:[0m      builtin_function_or_method


In [29]:
# Expand segment lengths [2, 4] into per-element segment ids: 0 twice, then 1 four times.
np.repeat(np.arange(len([2, 4])), [2, 4])

array([0, 0, 1, 1, 1, 1])