In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [27]:
class Controller(nn.Module):
    """Low-rank controller that emits a block-wise one-hot mask over a hidden vector.

    The input is projected through a bias-free rank-``dim_lowrank`` bottleneck
    (``U``: dim_in -> dim_lowrank, then ``V``: dim_lowrank -> dim_hidden) to get
    ``dim_hidden`` logits. The logits are split into ``num_blocks`` contiguous blocks
    of size ``dim_hidden // num_blocks``, and exactly one position per block is set.

    * Training mode: hard Gumbel-softmax (``F.gumbel_softmax(..., hard=True)``). The
      forward value is one-hot per block, and gradients flow through the soft sample.
    * Eval mode: deterministic argmax one-hot.

    In both modes the returned mask has the same shape and dtype as the logits.

    Args:
        dim_in: size of the last axis of the input.
        dim_lowrank: width of the low-rank bottleneck.
        dim_hidden: size of the produced mask; must be divisible by ``num_blocks``.
        num_blocks: number of blocks, each receiving exactly one active position.
        tau: Gumbel-softmax temperature used in training mode (default 0.1).
    """

    def __init__(self, dim_in, dim_lowrank, dim_hidden, num_blocks, tau=0.1):
        super(Controller, self).__init__()
        self.dim_in = dim_in
        self.dim_lowrank = dim_lowrank
        self.dim_hidden = dim_hidden
        self.num_blocks = num_blocks
        self.tau = tau
        assert self.dim_hidden % self.num_blocks == 0, "hidden vector must be divisible into N blocks"
        self.U = nn.Linear(dim_in, dim_lowrank, bias = False)
        self.V = nn.Linear(dim_lowrank, dim_hidden, bias = False)

    def forward(self, x):
        """Return a mask of shape ``(*x.shape[:-1], dim_hidden)`` with one 1 per block."""
        logits = self.V(self.U(x))
        original_shape = logits.shape
        block_size = self.dim_hidden // self.num_blocks
        # Split the hidden axis into (num_blocks, block_size) so selection is per block.
        logits = logits.reshape(*original_shape[:-1], self.num_blocks, block_size)
        if self.training:
            mask = F.gumbel_softmax(logits, tau=self.tau, hard=True)
        else:
            selected = torch.argmax(logits, dim=-1)
            # F.one_hot returns int64; cast so the eval mask has the same dtype as the
            # training mask instead of relying on type promotion downstream.
            mask = F.one_hot(selected, num_classes=block_size).to(logits.dtype)
        return mask.reshape(original_shape)
            
            

In [15]:
# Sanity-check the eval-mode masking idea: argmax over the last axis, one-hot it,
# then flatten the trailing (3 blocks x 4 slots) axes into 12 so that block b,
# slot j lands at flat index 4*b + j.
DEMO_SEED = 0
# A local generator keeps the printout reproducible without touching the global RNG.
A = torch.rand(2, 5, 3, 4, generator=torch.Generator().manual_seed(DEMO_SEED))
am = torch.argmax(A, dim=-1)
onehot = F.one_hot(am, num_classes=4)
flat = onehot.reshape(2, 5, 12)

# Check the layout programmatically instead of eyeballing the printed tensors.
assert torch.equal(onehot.sum(-1), torch.ones(2, 5, 3, dtype=onehot.dtype))
assert all(
    torch.equal(flat[..., 4 * blk:4 * (blk + 1)], onehot[..., blk, :]) for blk in range(3)
)

print(A)
print(onehot)
print(flat)

tensor([[[[7.3692e-01, 7.9813e-01, 5.0458e-01, 5.3212e-01],
          [5.1928e-01, 1.1359e-01, 9.7737e-01, 6.4920e-01],
          [6.8741e-01, 4.5131e-01, 1.5995e-01, 8.3960e-01]],

         [[5.4992e-01, 7.4053e-04, 1.3080e-01, 5.8648e-01],
          [7.3517e-01, 5.8267e-01, 9.2709e-01, 6.9977e-01],
          [7.9869e-01, 1.7985e-02, 3.0743e-01, 8.6149e-01]],

         [[2.7280e-01, 6.8720e-01, 8.7931e-01, 6.1018e-01],
          [1.7987e-02, 6.3314e-01, 3.6387e-02, 8.4156e-01],
          [2.3171e-01, 3.3489e-02, 8.5126e-01, 8.1929e-01]],

         [[3.5983e-01, 4.8006e-01, 9.2413e-01, 7.2229e-01],
          [1.8599e-02, 5.9205e-02, 1.5740e-02, 1.5232e-01],
          [3.7424e-01, 6.5595e-01, 4.5042e-01, 2.0002e-01]],

         [[6.4676e-01, 7.0073e-01, 4.3187e-01, 4.0098e-01],
          [4.7939e-01, 3.3698e-01, 1.3873e-01, 1.9097e-01],
          [9.3650e-01, 7.9411e-01, 8.3972e-01, 4.2133e-01]]],


        [[[7.9493e-01, 5.3347e-01, 1.3234e-01, 2.0028e-01],
          [2.2333e-01, 8.089

In [26]:
# Hard Gumbel-softmax picks one "hot" slot in every length-4 row (the forward value
# is one-hot). Flattening the trailing (3, 4) axes to 12 mirrors how Controller
# reshapes its per-block masks back to the hidden size.
A = torch.rand(2, 5, 3, 4)
gs = F.gumbel_softmax(A, tau=0.1, hard=True)
for shown in (gs, gs.reshape(2, 5, 12)):
    print(shown)

tensor([[[[0., 0., 0., 1.],
          [1., 0., 0., 0.],
          [0., 0., 0., 1.]],

         [[0., 0., 1., 0.],
          [0., 0., 0., 1.],
          [0., 0., 0., 1.]],

         [[1., 0., 0., 0.],
          [1., 0., 0., 0.],
          [0., 0., 1., 0.]],

         [[0., 0., 1., 0.],
          [0., 1., 0., 0.],
          [0., 0., 1., 0.]],

         [[0., 0., 1., 0.],
          [0., 1., 0., 0.],
          [0., 1., 0., 0.]]],


        [[[0., 0., 0., 1.],
          [1., 0., 0., 0.],
          [0., 1., 0., 0.]],

         [[0., 1., 0., 0.],
          [0., 0., 0., 1.],
          [1., 0., 0., 0.]],

         [[0., 1., 0., 0.],
          [1., 0., 0., 0.],
          [1., 0., 0., 0.]],

         [[0., 1., 0., 0.],
          [0., 0., 1., 0.],
          [0., 0., 0., 1.]],

         [[1., 0., 0., 0.],
          [0., 1., 0., 0.],
          [0., 1., 0., 0.]]]])
tensor([[[0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1.],
         [1., 0., 0

In [24]:
# Smoke test: a freshly built module is in training mode, so this exercises the
# Gumbel-softmax branch; the mask keeps the leading (100, 50) dims and has 8 features.
cnt = Controller(6, 3, 8, 2)
controller_out = cnt(torch.rand(100, 50, 6))
print(controller_out.shape)

torch.Size([100, 50, 8])


In [18]:
class ControllerFFN(nn.Module):
    """Feed-forward layer whose hidden activations are gated by a ``Controller`` mask.

    ``layer1`` maps ``dim_in -> dim_hidden``. Its output is multiplied element-wise by
    the controller's block-wise one-hot mask, and ``layer2`` maps the result back to
    ``dim_in``. Leading dimensions of the input are preserved.

    NOTE(review): there is no nonlinearity between ``layer1`` and ``layer2``; confirm
    this is intended.
    """

    def __init__(self, dim_in, dim_lowrank, dim_hidden, num_blocks):
        super().__init__()
        self.dim_in, self.dim_lowrank = dim_in, dim_lowrank
        self.dim_hidden, self.num_blocks = dim_hidden, num_blocks
        assert self.dim_hidden % self.num_blocks == 0, "hidden vector must be divisible into N blocks"
        # Submodules are created in this order so seeded parameter init is reproducible.
        self.controller = Controller(dim_in, dim_lowrank, dim_hidden, num_blocks)
        self.layer1 = nn.Linear(dim_in, dim_hidden)
        self.layer2 = nn.Linear(dim_hidden, dim_in)

    def forward(self, x):
        """Compute ``layer2(controller(x) * layer1(x))``."""
        gate = self.controller(x)
        pre_gate = self.layer1(x)
        return self.layer2(gate * pre_gate)

In [25]:
# Smoke test: the gated FFN maps (100, 50, 6) back to (100, 50, 6).
cntffn = ControllerFFN(6, 3, 8, 2)
ffn_out = cntffn(torch.rand(100, 50, 6))
print(ffn_out.shape)

torch.Size([100, 50, 6])
