In [6]:
import torch
import torch.nn as nn


class ClassificationLayer(nn.Module):
    """Two-layer MLP head that turns feature vectors into class probabilities.

    Pipeline: Linear(D -> 2*D) -> ReLU -> Dropout -> Linear(2*D -> dout) -> softmax.

    :param input_dim: size D of the last dimension of the input features.
    :param output_dim: number of output classes (dout).
    :param dropout: drop probability applied to the hidden activations
        (defaults to 0.2).
    """

    def __init__(self, input_dim, output_dim, dropout=0.2) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.ln1 = nn.Linear(input_dim, input_dim * 2)
        self.ln2 = nn.Linear(input_dim * 2, output_dim)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        '''
        Compute a probability distribution over the output classes.

        :param x: features of shape (N, L, D); nn.Linear acts on the last
            dimension, so other leading shapes work as well.
        :return: probabilities of shape (N, L, dout); every vector along the
            last dimension sums to 1.

        NOTE: the result is already softmax-normalised. Losses that expect raw
        logits (e.g. nn.CrossEntropyLoss) should not be fed this output directly.
        '''
        hidden = self.act(self.ln1(x))  # (N, L, D*2)
        # Regularise the hidden representation. Dropout must not be applied to
        # the logits themselves: zeroing/rescaling logits right before softmax
        # corrupts the predicted distribution instead of regularising features.
        hidden = self.drop(hidden)
        logits = self.ln2(hidden)  # (N, L, dout)
        # Functional softmax avoids building a new nn.Softmax module per call.
        return torch.softmax(logits, dim=-1)

In [7]:
# Smoke-test setup: a 768 -> 4000 classification head plus a random input batch
# of shape (2, 30, 768). No RNG seed is set in this cell, so the sampled values
# (and the output shown below) will differ from run to run.
clf = ClassificationLayer(input_dim=768, output_dim=4000)
x = torch.rand(size=(2, 30, 768))  # uniform samples in [0, 1)
x

tensor([[[0.1482, 0.3409, 0.7603,  ..., 0.5581, 0.6056, 0.0717],
         [0.8745, 0.2910, 0.6281,  ..., 0.3196, 0.6573, 0.2279],
         [0.5815, 0.7399, 0.7406,  ..., 0.9813, 0.5961, 0.7176],
         ...,
         [0.9861, 0.0863, 0.4987,  ..., 0.8114, 0.2574, 0.0361],
         [0.8659, 0.9893, 0.1339,  ..., 0.8401, 0.3677, 0.4428],
         [0.4787, 0.5477, 0.8247,  ..., 0.9343, 0.7346, 0.1771]],

        [[0.3019, 0.8349, 0.9631,  ..., 0.8688, 0.6517, 0.3198],
         [0.9515, 0.3885, 0.1247,  ..., 0.7233, 0.2207, 0.6691],
         [0.6933, 0.1910, 0.7373,  ..., 0.3218, 0.2361, 0.0115],
         ...,
         [0.6495, 0.9091, 0.0282,  ..., 0.5070, 0.8919, 0.5179],
         [0.0026, 0.2610, 0.5566,  ..., 0.6671, 0.1637, 0.9657],
         [0.1130, 0.6909, 0.4184,  ..., 0.5202, 0.2963, 0.6070]]])

In [8]:
# Forward pass through the head. NOTE(review): none of the visible cells calls
# clf.eval(), so the module is presumably still in training mode and dropout is
# active here - confirm before treating these numbers as inference output.
out = clf(x)
# Show the output shape next to the softmax probabilities.
(out.size(), out)

(torch.Size([2, 30, 4000]),
 tensor([[[0.0003, 0.0003, 0.0002,  ..., 0.0003, 0.0002, 0.0002],
          [0.0002, 0.0002, 0.0003,  ..., 0.0002, 0.0002, 0.0002],
          [0.0002, 0.0003, 0.0002,  ..., 0.0003, 0.0002, 0.0002],
          ...,
          [0.0002, 0.0002, 0.0003,  ..., 0.0002, 0.0002, 0.0002],
          [0.0003, 0.0003, 0.0002,  ..., 0.0003, 0.0002, 0.0002],
          [0.0002, 0.0003, 0.0003,  ..., 0.0002, 0.0003, 0.0002]],
 
         [[0.0002, 0.0003, 0.0002,  ..., 0.0002, 0.0002, 0.0002],
          [0.0003, 0.0002, 0.0003,  ..., 0.0002, 0.0002, 0.0002],
          [0.0003, 0.0003, 0.0003,  ..., 0.0002, 0.0002, 0.0002],
          ...,
          [0.0002, 0.0002, 0.0003,  ..., 0.0002, 0.0002, 0.0002],
          [0.0002, 0.0003, 0.0003,  ..., 0.0003, 0.0002, 0.0002],
          [0.0002, 0.0003, 0.0002,  ..., 0.0002, 0.0002, 0.0002]]],
        grad_fn=<SoftmaxBackward0>))