In [1]:
import torch
import torch.nn as nn

In [2]:
class PositionwiseFeedforwardLayer(nn.Module):
    """Two-layer feed-forward block applied to the last dimension of its input.

    The last dimension is projected from ``hid_dim`` up to ``pf_dim``, passed
    through ReLU and dropout, then projected back down to ``hid_dim``.

    Args:
        hid_dim: size of the last dimension of the input and of the output.
        pf_dim: size of the intermediate projection.
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, hid_dim, pf_dim, dropout):
        super().__init__()

        # Expansion and contraction projections.
        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return a tensor with the same shape as ``x``.

        ``x`` is expected as [batch size, seq len, hid dim].
        """
        # [batch size, seq len, hid dim] -> [batch size, seq len, pf dim]
        hidden = self.fc_1(x)
        hidden = torch.relu(hidden)
        hidden = self.dropout(hidden)

        # [batch size, seq len, pf dim] -> [batch size, seq len, hid dim]
        return self.fc_2(hidden)

In [3]:
# Small, hand-checkable dimensions used by the cells that follow.
batch_size, seq_len, hid_dim, pf_dim = 2, 3, 4, 5

# Input holding 0..N-1 as floats, laid out as [batch size, seq len, hid dim].
# The element count is derived from the dimensions (rather than a literal 24)
# so the view stays valid if any of the dimensions is changed.
x = torch.arange(batch_size * seq_len * hid_dim, dtype=torch.float).view(
    batch_size, seq_len, hid_dim
)

In [4]:
# Display the input tensor: consecutive floats arranged as [batch size, seq len, hid dim].
x

tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])

In [5]:
# Confirm the input shape is (batch_size, seq_len, hid_dim).
x.shape

torch.Size([2, 3, 4])

In [6]:
# Bias-free layer, so fc_1(x) is a pure matrix product with its weight.
fc_1 = nn.Linear(hid_dim, pf_dim, bias=False)

# Hand-made [hid dim, pf dim] matrices. Shapes are taken from hid_dim / pf_dim
# instead of repeating the literals 4 and 5.
weights1 = torch.ones(hid_dim, pf_dim, dtype=torch.float)
weights0 = torch.zeros(hid_dim, pf_dim, dtype=torch.float)
# Every row is [1, 2, ..., pf_dim].
weights_same_column = torch.arange(1, pf_dim + 1).repeat(hid_dim, 1).float()
# Row i is filled with the value i + 1.
# NOTE: unlike the matrices above, this one is square: [hid dim, hid dim].
weights_same_row = torch.arange(1, hid_dim + 1).view(hid_dim, 1).repeat(1, hid_dim).float()

# Seed right before sampling so weights_rand is reproducible.
torch.manual_seed(0)
weights_rand = torch.randn(hid_dim * pf_dim).view(hid_dim, pf_dim)

# nn.Linear stores its weight as [out_features, in_features] and computes
# x @ weight.T, so the transpose of weights_rand is loaded. The values are
# copied in place under no_grad rather than rebinding `.data` to a transposed
# view, so the parameter keeps its own contiguous storage and does not alias
# weights_rand.
with torch.no_grad():
    fc_1.weight.copy_(weights_rand.T)

In [7]:
# The stored weight is [pf dim, hid dim], i.e. [out_features, in_features].
fc_1.weight.shape

torch.Size([5, 4])

In [8]:
# Display the all-ones weight matrix.
weights1

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

In [9]:
# Display the matrix whose rows are all identical (each column holds one constant).
weights_same_column 

tensor([[1., 2., 3., 4., 5.],
        [1., 2., 3., 4., 5.],
        [1., 2., 3., 4., 5.],
        [1., 2., 3., 4., 5.]])

In [10]:
# Display the matrix whose columns are all identical (each row holds one constant).
weights_same_row

tensor([[1., 1., 1., 1.],
        [2., 2., 2., 2.],
        [3., 3., 3., 3.],
        [4., 4., 4., 4.]])

In [11]:
# With all-ones weights every output feature is the sum over the last dim of x.
x @ weights1

tensor([[[ 6.,  6.,  6.,  6.,  6.],
         [22., 22., 22., 22., 22.],
         [38., 38., 38., 38., 38.]],

        [[54., 54., 54., 54., 54.],
         [70., 70., 70., 70., 70.],
         [86., 86., 86., 86., 86.]]])

In [12]:
# Output feature j is the sum over the last dim of x, scaled by the constant in column j.
x @ weights_same_column

tensor([[[  6.,  12.,  18.,  24.,  30.],
         [ 22.,  44.,  66.,  88., 110.],
         [ 38.,  76., 114., 152., 190.]],

        [[ 54., 108., 162., 216., 270.],
         [ 70., 140., 210., 280., 350.],
         [ 86., 172., 258., 344., 430.]]])

In [13]:
# Every output feature is the same weighted sum of x's last dim (weights 1, 2, 3, 4).
x @ weights_same_row

tensor([[[ 20.,  20.,  20.,  20.],
         [ 60.,  60.,  60.,  60.],
         [100., 100., 100., 100.]],

        [[140., 140., 140., 140.],
         [180., 180., 180., 180.],
         [220., 220., 220., 220.]]])

In [14]:
# All-zero weights map every input to zeros.
x @ weights0

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])

In [15]:
# Reference result: plain matrix product with the random weights,
# for comparison against the output of fc_1.
x @ weights_rand

tensor([[[ -1.1988,   2.9862,   8.4088,   9.6557,   1.8961],
         [ -9.5226,   1.8028,  25.4120,  24.1578,   7.6607],
         [-17.8464,   0.6195,  42.4152,  38.6600,  13.4254]],

        [[-26.1702,  -0.5639,  59.4184,  53.1621,  19.1900],
         [-34.4940,  -1.7473,  76.4216,  67.6642,  24.9547],
         [-42.8178,  -2.9307,  93.4248,  82.1664,  30.7193]]])

In [16]:
# Forward pass through the bias-free layer. nn.Linear applies x @ weight.T and
# the weight was loaded from weights_rand.T, so x1 should equal
# torch.matmul(x, weights_rand).
x1 = fc_1(x)

In [17]:
# Display the layer's weight parameter (the transpose of weights_rand).
fc_1.weight

Parameter containing:
tensor([[-1.1258, -1.5551,  1.4437, -0.8437],
        [-1.1524, -0.3414,  0.2660,  0.9318],
        [-0.2506,  1.8530,  1.3894,  1.2590],
        [-0.4339,  0.4681,  1.5863,  2.0050],
        [ 0.5988, -0.1577,  0.9463,  0.0537]], requires_grad=True)

In [18]:
# The weight is stored as [pf dim, hid dim].
fc_1.weight.shape

torch.Size([5, 4])

In [19]:
# The layer maps the last dim from hid_dim to pf_dim; leading dims are unchanged.
x1.shape

torch.Size([2, 3, 5])

In [20]:
# Display the layer output for comparison with torch.matmul(x, weights_rand).
x1

tensor([[[ -1.1988,   2.9862,   8.4088,   9.6557,   1.8961],
         [ -9.5226,   1.8028,  25.4120,  24.1578,   7.6607],
         [-17.8464,   0.6195,  42.4152,  38.6600,  13.4254]],

        [[-26.1702,  -0.5639,  59.4184,  53.1621,  19.1900],
         [-34.4940,  -1.7473,  76.4216,  67.6642,  24.9547],
         [-42.8178,  -2.9307,  93.4248,  82.1664,  30.7193]]],
       grad_fn=<UnsafeViewBackward0>)

In [21]:
# One row of indices 0..src_len-1, tiled once per batch element,
# giving a [batch size, src len] tensor.
src_len = 10
batch_size = 2
torch.arange(src_len).unsqueeze(0).repeat(batch_size, 1)

tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])