In [2]:
import math
import torch as th
import torch.nn as nn
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.nn import functional as F

In [3]:
# Fix the seed so the sample tensors (and every printed output below) are reproducible.
SEED = 0
th.manual_seed(SEED)

c = th.randn(2)   # constant features
cs = c.shape[0]
n = th.randn(4)   # sign-flipping ("negative") features
ns = n.shape[0]
l = th.randn(3)   # left-side features
r = th.randn(3)   # right-side features
ss = l.shape[0]
assert l.shape[0] == r.shape[0], "left and right sides must have the same width"

In [45]:
# Inspect the sampled constant features `c`.
c

tensor([-0.4235,  0.2918])

In [88]:
class MyLayer(nn.Module):
    """Linear layer on a (constant, left, right) triple that is symmetric in left/right.

    The constant output sees the sides only through ``(l + r) / 2``. Both side
    outputs share ``s1s`` (same-side weights) and ``s2s`` (other-side weights).
    As a result, swapping ``l`` and ``r`` keeps the constant output and swaps the
    left and right outputs.
    """
    __constants__ = ['sbias', 'cbias']
    __weights__ = ['c2s', 'c2c', 's2c', 's1s', 's2s']

    def __init__(self, in_const, in_side, out_const, out_side):
        """Allocate all weights/biases and initialise them via ``reset_parameters``.

        Args:
            in_const: width of the constant input ``c``.
            in_side: width of each side input ``l`` / ``r``.
            out_const: width of the constant output.
            out_side: width of each side output.
        """
        super().__init__()

        self.in_const, self.in_side = in_const, in_side
        self.out_const, self.out_side = out_const, out_side

        def empty(*shape):
            # Uninitialised storage; reset_parameters() fills it in below.
            return Parameter(th.Tensor(*shape))

        # Registration order is kept stable so parameters()/state_dict() order is unchanged.
        # Side outputs:
        self.c2s = empty(out_side, in_const)
        self.s1s = empty(out_side, in_side)
        self.s2s = empty(out_side, in_side)
        self.sbias = empty(out_side)
        # Constant output:
        self.s2c = empty(out_const, in_side)
        self.c2c = empty(out_const, in_const)
        self.cbias = empty(out_const)

        self.reset_parameters()

    def forward(self, c, l, r):
        """Return ``(constant, left, right)`` outputs for inputs ``c``, ``l``, ``r``."""
        side_common = F.linear(c, self.c2s, self.sbias)
        const_from_sides = F.linear((l + r) / 2, self.s2c, self.cbias)

        const_out = const_from_sides + F.linear(c, self.c2c)
        left_out = side_common + F.linear(l, self.s1s) + F.linear(r, self.s2s)
        right_out = side_common + F.linear(r, self.s1s) + F.linear(l, self.s2s)
        return const_out, left_out, right_out

    def reset_parameters(self):
        """Kaiming-uniform init for every weight in ``__weights__``; uniform init for biases.

        Both biases are drawn from ``[-b, b]`` with ``b = 1 / sqrt(in_const + 2 * in_side)``.
        """
        for weight in (getattr(self, name) for name in self.__weights__):
            init.kaiming_uniform_(weight, a=math.sqrt(5))

        fan_in = self.in_const + 2 * self.in_side
        bound = 1 / math.sqrt(fan_in)
        for bias in (self.sbias, self.cbias):
            init.uniform_(bias, -bound, bound)

In [89]:
class MyNet(nn.Module):
    """First attempt at a left/right-symmetric network built from ``MyLayer``s.

    ``obs`` is laid out along its last dimension as ``[c, n, l, r]`` with widths
    ``c_in``, ``n_in``, ``s_in``, ``s_in``. A single observation ``(features,)``
    and a batch ``(..., features)`` are both accepted. The layer stack is run
    twice, once with ``[c, n]`` and once with ``[c, -n]`` as the constant input,
    and the two results are summed. The output is ``[c_out, s_out, s_out]``
    (constant, left, right) along the last dimension.

    NOTE(review): no nonlinearity is applied between layers, and each ``MyLayer``
    is affine, so the whole stack is affine in its inputs.
    """

    def __init__(self, c_in, n_in, s_in, c_out, s_out, num_layers=3, num_hidden=16):
        super().__init__()
        self.c_in = c_in
        self.s_in = s_in
        self.n_in = n_in
        # Plain list for iteration. Each layer is also registered via add_module so
        # its parameters are tracked (state_dict keys "layer<i>.*" and "final.*").
        self.layers = []
        last_cin = c_in + n_in
        last_sin = s_in
        for i in range(num_layers - 1):
            self.layers.append(MyLayer(last_cin, last_sin, num_hidden, num_hidden))
            self.add_module("layer%d" % i, self.layers[i])
            last_cin, last_sin = num_hidden, num_hidden
        self.layers.append(MyLayer(last_cin, last_sin, c_out, s_out))
        self.add_module("final", self.layers[-1])

    def _split_obs(self, obs):
        """Slice ``obs`` along its last dimension into ``(c, n, l, r)``.

        ``r`` is the last ``s_in`` features and ``l`` directly follows ``n``.
        """
        width = obs.shape[-1]
        n_start = self.c_in
        l_start = n_start + self.n_in
        c = obs[..., :n_start]
        n = obs[..., n_start:l_start]
        l = obs[..., l_start:l_start + self.s_in]
        # Compute the start index explicitly: `obs[-0:]` would return the whole
        # tensor (not an empty slice) when s_in == 0.
        r = obs[..., width - self.s_in:]
        return c, n, l, r

    def _forward_one_side(self, c, l, r):
        """Run every layer in order and concatenate ``(c, l, r)`` along the last dim."""
        for layer in self.layers:
            c, l, r = layer(c, l, r)
        return th.cat([c, l, r], -1)

    def forward(self, obs):
        c, n, l, r = self._split_obs(obs)
        # Summing the n and -n passes makes the output invariant to the sign of n.
        return (
              self._forward_one_side(th.cat([c, n], -1), l, r)
            + self._forward_one_side(th.cat([c, -n], -1), l, r)
        )


In [90]:
# Symmetric net with no constant output and 4 outputs per side.
net = MyNet(c_in=cs, n_in=ns, s_in=ss, c_out=0, s_out=4)

In [91]:
# Observation in its original order: [c, n, l, r].
net(th.cat((c, n, l, r), dim=-1))

tensor([-2.3242,  0.0867,  0.1860,  2.3088,  0.9979,  0.3886, -0.3321, -0.7505],
       grad_fn=<AddBackward0>)

In [94]:
# Mirrored observation: negate n and swap the left/right sides.
net(th.cat((c, n.neg(), r, l), dim=-1))

tensor([ 0.9979,  0.3886, -0.3321, -0.7505, -2.3242,  0.0867,  0.1860,  2.3088],
       grad_fn=<AddBackward0>)

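### Negatives
`MyNet` doesn't handle sign-flipping ("negative") features correctly. Its `forward` sums the outputs for `n` and `-n`, so it also satisfies $f(c, n, l, r) = f(c, -n, l, r)$: the output cannot depend on the sign of `n` at all.

Correct way:
 - negatives on the sides: just invert (negate) them in the input
 - fixed (non-side) negatives: handle them explicitly in the layer, as a separate channel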

In [46]:

class SymmetricLayer(nn.Module):
    """Linear layer on (constant, negative, left, right) inputs with a mirror symmetry.

    Under the mirror ``(c, n, l, r) -> (c, -n, r, l)`` the outputs map as
    ``(const, neg, left, right) -> (const, -neg, right, left)``:

    * ``c`` and ``(l + r) / 2`` feed the constant output, which is mirror-invariant.
    * ``n`` and ``l - r`` feed the negative output. It has no bias, so it stays
      odd under the mirror.
    * Both sides share ``s1s`` (same side) and ``s2s`` (other side). ``n`` enters
      the left output with ``+`` and the right output with ``-``.
    """
    __constants__ = ["sbias", "cbias"]
    __weights__ = ["c2s", "n2s", "c2c", "s2c", "s1s", "s2s", "n2n", "s2n"]

    def __init__(self, in_const, in_neg, in_side, out_const, out_neg, out_side):
        """Allocate all weights/biases and initialise them via ``reset_parameters``.

        Args:
            in_const / out_const: widths of the constant input / output.
            in_neg / out_neg: widths of the negative input / output.
            in_side / out_side: widths of each side input / output.
        """
        super().__init__()

        self.in_const, self.in_neg, self.in_side = in_const, in_neg, in_side
        self.out_const, self.out_neg, self.out_side = out_const, out_neg, out_side

        def empty(*shape):
            # Uninitialised storage; reset_parameters() fills it in below.
            return Parameter(th.Tensor(*shape))

        # Registration order is kept stable so parameters()/state_dict() order is unchanged.
        # Side outputs:
        self.c2s = empty(out_side, in_const)
        self.s1s = empty(out_side, in_side)
        self.s2s = empty(out_side, in_side)
        self.n2s = empty(out_side, in_neg)
        self.sbias = empty(out_side)
        # Constant output:
        self.s2c = empty(out_const, in_side)
        self.c2c = empty(out_const, in_const)
        self.cbias = empty(out_const)
        # Negative output (deliberately no bias):
        self.n2n = empty(out_neg, in_neg)
        self.s2n = empty(out_neg, in_side)

        self.reset_parameters()

    def forward(self, c, n, l, r):
        """Return ``(constant, negative, left, right)`` outputs."""
        side_common = F.linear(c, self.c2s, self.sbias)
        side_from_neg = F.linear(n, self.n2s)
        const_from_sides = F.linear((l + r) / 2, self.s2c, self.cbias)
        neg_from_sides = F.linear(l - r, self.s2n)

        const_out = const_from_sides + F.linear(c, self.c2c)
        neg_out = neg_from_sides + F.linear(n, self.n2n)
        left_out = side_common + side_from_neg + F.linear(l, self.s1s) + F.linear(r, self.s2s)
        right_out = side_common - side_from_neg + F.linear(r, self.s1s) + F.linear(l, self.s2s)
        return const_out, neg_out, left_out, right_out

    def reset_parameters(self):
        """Kaiming-uniform init for every weight in ``__weights__``; uniform init for biases.

        Both biases are drawn from ``[-b, b]`` with
        ``b = 1 / sqrt(in_const + in_neg + 2 * in_side)``.
        """
        for weight in (getattr(self, name) for name in self.__weights__):
            init.kaiming_uniform_(weight, a=math.sqrt(5))

        fan_in = self.in_const + self.in_neg + 2 * self.in_side
        bound = 1 / math.sqrt(fan_in)
        for bias in (self.sbias, self.cbias):
            init.uniform_(bias, -bound, bound)


In [74]:
class SymmetricNet(nn.Module):
    """Stack of ``SymmetricLayer``s mapping an observation to an action mean.

    ``obs`` is laid out along its last dimension as ``[c, n, l, r]`` with widths
    ``c_in``, ``n_in``, ``s_in``, ``s_in``. A single observation ``(features,)``
    and a batch ``(..., features)`` are both accepted. The output ``mean`` is
    ``[c, n, l, r]`` with widths ``c_out``, ``n_out``, ``s_out``, ``s_out``.

    Between layers, ``c``, ``l`` and ``r`` go through ``relu`` and ``n`` goes
    through ``tanh``. If ``tanh_finish`` is set, ``mean`` is squashed with ``tanh``.

    Returns ``mean`` when ``deterministic`` is True. Otherwise it returns
    ``(mean, log_std)``, where ``log_std`` is ``log_std_param`` expanded to the
    shape of ``mean``.
    """

    def __init__(self, c_in, n_in, s_in, c_out, n_out, s_out, num_layers=3, num_hidden=16, tanh_finish=True,
                 varying_std=False,
                 log_std=-1,
                 deterministic=False):
        super().__init__()
        self.c_in = c_in
        self.s_in = s_in
        self.n_in = n_in
        self.tanh_finish = tanh_finish
        self.deterministic = deterministic

        # Plain list for iteration. Each layer is also registered via add_module so
        # its parameters are tracked (state_dict keys "layer<i>.*" and "final.*").
        self.layers = []
        last_cin = c_in
        last_nin = n_in
        last_sin = s_in
        for i in range(num_layers - 1):
            self.layers.append(SymmetricLayer(last_cin, last_nin, last_sin, num_hidden, num_hidden, num_hidden))
            self.add_module("layer%d" % i, self.layers[i])
            last_cin, last_nin, last_sin = num_hidden, num_hidden, num_hidden
        self.layers.append(SymmetricLayer(last_cin, last_nin, last_sin, c_out, n_out, s_out))
        self.add_module("final", self.layers[-1])

        if not self.deterministic:
            action_size = c_out + n_out + 2 * s_out
            if varying_std:
                self.log_std_param = nn.Parameter(
                    th.randn(action_size) * 1e-10 + log_std
                )
            else:
                # NOTE(review): a plain tensor attribute is not registered with the
                # module, so `.to(device)` won't move it and it is not in state_dict.
                # Consider register_buffer.
                self.log_std_param = log_std * th.ones(action_size)

    def _split_obs(self, obs):
        """Slice ``obs`` along its last dimension into ``(c, n, l, r)``.

        ``r`` is the last ``s_in`` features and ``l`` directly follows ``n``.
        Slicing with ``...`` replaces the former transpose/slice/transpose dance.
        """
        width = obs.shape[-1]
        n_start = self.c_in
        l_start = n_start + self.n_in
        c = obs[..., :n_start]
        n = obs[..., n_start:l_start]
        l = obs[..., l_start:l_start + self.s_in]
        # Compute the start index explicitly: `obs[-0:]` would return the whole
        # tensor (not an empty slice) when s_in == 0.
        r = obs[..., width - self.s_in:]
        return c, n, l, r

    def forward(self, obs):
        c, n, l, r = self._split_obs(obs)

        for i, layer in enumerate(self.layers):
            if i != 0:
                # tanh is odd (tanh(-x) = -tanh(x)), which keeps the sign flip of n
                # under mirroring intact; relu is not odd.
                n = th.tanh(n)   # TODO
                c = th.relu(c)
                r = th.relu(r)
                l = th.relu(l)

            c, n, l, r = layer(c, n, l, r)

        mean = th.cat([c, n, l, r], -1)
        if self.tanh_finish:
            mean = th.tanh(mean)

        if not self.deterministic:
            log_std = self.log_std_param.expand_as(mean)
            return mean, log_std
        else:
            return mean

In [75]:
# Symmetric net with no constant or negative outputs and 4 outputs per side.
net = SymmetricNet(c_in=cs, n_in=ns, s_in=ss, c_out=0, n_out=0, s_out=4)

In [76]:
# Observation in its original order: [c, n, l, r].
net(th.cat((c, n, l, r), dim=-1))

(tensor([ 0.0825, -0.1993,  0.0163,  0.0147,  0.0537,  0.1788,  0.3346, -0.7536],
        grad_fn=<TanhBackward>),
 tensor([-1., -1., -1., -1., -1., -1., -1., -1.]))

In [77]:
# Mirrored observation: negate n and swap the left/right sides.
net(th.cat((c, n.neg(), r, l), dim=-1))

(tensor([ 0.0537,  0.1788,  0.3346, -0.7536,  0.0825, -0.1993,  0.0163,  0.0147],
        grad_fn=<TanhBackward>),
 tensor([-1., -1., -1., -1., -1., -1., -1., -1.]))

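### Thoughts
 - The final layer could have a bias for the negative part, but it wouldn't really be useful anyway.
 - Is there anything to replace `transpose`? (Slicing the last dimension with `obs[..., a:b]` does the same job for any number of leading dimensions.)
 - Investigate why `relu` can't be used for `n`. The likely reason: `n` must flip sign under mirroring, which needs an odd activation. `tanh(-x) = -tanh(x)` holds, but `relu(-x) != -relu(x)`.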

12