In [2]:
import torch
import torch.nn as nn
import numpy as np

In [275]:
class HierarchicalSoftmax(nn.Module):
    
    def __init__(self, ntokens, nhid, ntokens_per_class = None):
        super(HierarchicalSoftmax, self).__init__()

        # Parameters
        self.ntokens = ntokens
        self.nhid = nhid

        if ntokens_per_class is None:
            ntokens_per_class = int(np.ceil(np.sqrt(ntokens)))

        self.ntokens_per_class = ntokens_per_class

        self.nclasses = int(np.ceil(self.ntokens * 1. / self.ntokens_per_class))
        self.ntokens_actual = self.nclasses * self.ntokens_per_class

        self.layer_top_W = nn.Parameter(torch.FloatTensor(self.nhid, self.nclasses), requires_grad=True)
        print(self.layer_top_W.shape)
        self.layer_top_b = nn.Parameter(torch.FloatTensor(self.nclasses), requires_grad=True)
        

        self.layer_bottom_W = nn.Parameter(torch.FloatTensor(self.nclasses, self.nhid, self.ntokens_per_class), requires_grad=True)
        print(self.layer_bottom_W.shape)
        self.layer_bottom_b = nn.Parameter(torch.FloatTensor(self.nclasses, self.ntokens_per_class), requires_grad=True)

        self.softmax = nn.Softmax(dim=1)

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.layer_top_W.data.uniform_(-initrange, initrange)
        self.layer_top_b.data.fill_(0)
        self.layer_bottom_W.data.uniform_(-initrange, initrange)
        self.layer_bottom_b.data.fill_(0)


    def forward(self, inputs, labels):

        batch_size, d = inputs.size()
    
        label_position_top = (labels / self.ntokens_per_class).long()
        label_position_bottom = (labels % self.ntokens_per_class).long()

        layer_top_logits = torch.matmul(inputs, self.layer_top_W) + self.layer_top_b
        layer_top_probs = self.softmax(layer_top_logits)

        layer_bottom_logits = torch.squeeze(torch.bmm(torch.unsqueeze(inputs, dim=1), self.layer_bottom_W[label_position_top]), dim=1) + self.layer_bottom_b[label_position_top]
        layer_bottom_probs = self.softmax(layer_bottom_logits)

        target_probs = layer_top_probs[torch.arange(batch_size).long(), label_position_top] * layer_bottom_probs[torch.arange(batch_size).long(), label_position_bottom]

        # print(f"top {layer_top_probs.shape} {layer_top_probs}")
        # print(f"bottom {layer_bottom_probs.shape} {layer_bottom_probs}")
        top_indx = torch.argmax(layer_top_probs, dim=1)
        botton_indx = torch.argmax(layer_bottom_probs, dim=1)

        real_indx = (top_indx * self.ntokens_per_class) + botton_indx
        # print(top_indx, self.nclasses, botton_indx)
        # print(f"target {target_probs.shape} {target_probs}")
        
        loss = -torch.mean(torch.log(target_probs))

        return loss, target_probs, layer_top_probs, layer_bottom_probs, top_indx, botton_indx, real_indx


In [302]:
s = HierarchicalSoftmax(ntokens = 8, nhid = 256, ntokens_per_class=2)
s.nclasses, s.ntokens_per_class, s.ntokens_actual, s.nhid

torch.Size([256, 4])
torch.Size([4, 256, 2])


(4, 2, 8, 256)

In [303]:
x = torch.rand(10, 256).float()
y = torch.ones(10).long()
# y = torch.tensor([1, 2, 79000, 3]).long()
l, p, _, _, ti, bi, ri = s(x, y)

In [304]:
ti, bi, ri, y

(tensor([3, 1, 3, 3, 3, 1, 1, 1, 3, 3]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 tensor([7, 3, 7, 7, 7, 3, 3, 3, 7, 7]),
 tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]))

In [248]:
import torchmetrics

train_acc = torchmetrics.Accuracy()
train_acc(ri, y)

tensor(0.7000)

In [182]:
l

tensor(12.9449, grad_fn=<NegBackward>)

In [16]:
import timm

ModuleNotFoundError: No module named 'timm'

In [44]:
rnn = nn.RNN(input_size =10, hidden_size=20, num_layers=2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)
output.shape, hn.shape

(torch.Size([5, 3, 20]), torch.Size([2, 3, 20]))

In [31]:
output

tensor([[[-0.7175,  0.7147,  0.1411,  0.9045, -0.3927, -0.5797, -0.8769,
          -0.3946,  0.4277,  0.0444,  0.9158,  0.5043, -0.3930,  0.1022,
           0.0757,  0.9047,  0.2502,  0.6870,  0.3141, -0.0672],
         [-0.9026,  0.2158,  0.5263, -0.1289,  0.6552, -0.3140, -0.8748,
           0.4386,  0.3600, -0.4555,  0.8233,  0.2973, -0.5606, -0.8467,
          -0.2112,  0.4747, -0.4633,  0.1217, -0.3112, -0.2013],
         [-0.4059,  0.4937, -0.1862,  0.2150, -0.4308,  0.6778,  0.1412,
          -0.0387, -0.2283, -0.6684, -0.8446,  0.8311,  0.2351,  0.8425,
          -0.4483,  0.0921, -0.7801, -0.1060,  0.0533,  0.3374]],

        [[-0.3310, -0.0913,  0.0705, -0.0304,  0.6011,  0.1184, -0.0502,
          -0.5929, -0.5561,  0.6226,  0.2480,  0.2018,  0.2935,  0.5392,
           0.2178, -0.3965, -0.5496,  0.1775, -0.5934, -0.0623],
         [-0.7821, -0.1966,  0.3740,  0.0997,  0.2173, -0.0241, -0.3716,
          -0.2147, -0.6032,  0.1225,  0.7787,  0.2232,  0.3798,  0.1481,
        

In [42]:
output.view(-1, output.size(2)).shape

torch.Size([15, 20])