In [1]:
import torch
import torch.nn as nn

# Seed before constructing the layer so its randomly initialised parameters
# come out the same on every run.
torch.manual_seed(1)
rnn_layer = nn.RNN(input_size=5, hidden_size=2, num_layers=1, batch_first=True)

# Layer-0 parameters: the input-to-hidden pair, then the hidden-to-hidden pair.
W_xh, b_xh = rnn_layer.weight_ih_l0, rnn_layer.bias_ih_l0
W_hh, b_hh = rnn_layer.weight_hh_l0, rnn_layer.bias_hh_l0

# Show them in the same order they are used in the recurrence.
print(f'W_xh {W_xh}\n')
print(f'b_xh {b_xh}\n')
print(f'W_hh {W_hh}\n')
print(f'b_hh {b_hh}\n')

W_xh Parameter containing:
tensor([[ 0.3643, -0.3121, -0.1371,  0.3319, -0.6657],
        [ 0.4241, -0.1455,  0.3597,  0.0983, -0.0866]], requires_grad=True)

b_xh Parameter containing:
tensor([-0.0516, -0.0637], requires_grad=True)

W_hh Parameter containing:
tensor([[ 0.1961,  0.0349],
        [ 0.2583, -0.2756]], requires_grad=True)

b_hh Parameter containing:
tensor([ 0.1025, -0.0028], requires_grad=True)



In [35]:
# Three time steps of 5 features each. The tensor is 2-D, so the layer is given
# a single sequence without a batch dimension.
x_seq = torch.tensor([[1.0] * 5, [2.0] * 5, [3.0] * 5]).float()
print(f'x_seq:\n{x_seq}\n')

# Reference result from the built-in layer: per-step outputs and final hidden state.
out_seq, hidden_states = rnn_layer(x_seq)

# Re-derive the recurrence by hand from the layer's own parameters:
#   o_t = tanh(W_xh x_t + b_xh + W_hh o_{t-1} + b_hh),  with o_{-1} = 0
seq_length = len(x_seq)  # T
output = []
for t in range(seq_length):
    x_t = x_seq[t]
    print(f'Time step {t} =>')
    print(f'    Input           : {x_t.numpy()}')

    # Input contribution only; the "Hidden" line below is printed before the
    # recurrent term is added.
    h_t = W_xh.matmul(x_t) + b_xh
    print(f'    Hidden          : {h_t.detach().numpy()}')

    # The previous step's output feeds back in; the first step starts from zeros.
    h_prev = output[t - 1] if t > 0 else torch.zeros_like(h_t)
    h_t += W_hh.matmul(h_prev) + b_hh
    o_t = torch.tanh(h_t)
    output.append(o_t)
    print(f'    Output          : {o_t.detach().numpy()}')
    print(f'    RNN output      : {out_seq[t].detach().numpy()}')

# Turn the side-by-side printout into an actual check: the hand-rolled
# recurrence must reproduce the layer's outputs and its final hidden state.
torch.testing.assert_close(torch.stack(output), out_seq)
torch.testing.assert_close(output[-1], hidden_states[-1])

x_seq:
tensor([[1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3.]])

Time step 0 =>
    Input           : [1. 1. 1. 1. 1.]
    Hidden          : [-0.4701929  0.5863904]
    Output          : [-0.3519801   0.52525216]
    RNN output      : [-0.3519801   0.52525216]
Time step 1 =>
    Input           : [2. 2. 2. 2. 2.]
    Hidden          : [-0.88883156  1.2364397 ]
    Output          : [-0.68424344  0.76074266]
    RNN output      : [-0.68424344  0.76074266]
Time step 2 =>
    Input           : [3. 3. 3. 3. 3.]
    Hidden          : [-1.3074701  1.886489 ]
    Output          : [-0.8649416   0.90466356]
    RNN output      : [-0.8649416   0.90466356]


In [33]:
# Second value returned by rnn_layer(x_seq); the output below matches the last row of out_seq.
hidden_states

tensor([[-0.8649,  0.9047]], grad_fn=<SqueezeBackward1>)

In [34]:
# First value returned by rnn_layer(x_seq): one 2-element output row per time step.
out_seq

tensor([[-0.3520,  0.5253],
        [-0.6842,  0.7607],
        [-0.8649,  0.9047]], grad_fn=<SqueezeBackward1>)

In [7]:
# Draw the matrix first and the vector second so the random stream is consumed
# in a fixed order.
mat = torch.rand(15, 8)
arr = torch.rand(8)

# Two 2-D views of the same vector: a (1, 8) row and an (8, 1) column.
row = arr[None, :]
col = arr[:, None]

print(f'mat:\n{mat}')
print(f'arr:\n{arr}')
print(f'row:\n{row}')
print(f'col:\n{col}')

mat:
tensor([[0.9967, 0.5107, 0.8953, 0.6219, 0.9482, 0.7885, 0.2451, 0.8512],
        [0.5032, 0.0986, 0.3952, 0.0487, 0.2409, 0.1458, 0.8420, 0.3950],
        [0.4136, 0.7396, 0.7035, 0.1700, 0.6843, 0.7124, 0.1024, 0.9347],
        [0.1958, 0.3045, 0.0882, 0.9570, 0.5952, 0.9210, 0.5327, 0.7718],
        [0.3404, 0.1011, 0.5803, 0.4554, 0.4495, 0.2544, 0.5104, 0.6436],
        [0.6804, 0.0972, 0.7416, 0.5514, 0.8449, 0.8534, 0.1062, 0.9802],
        [0.0083, 0.7874, 0.5352, 0.7366, 0.2296, 0.8006, 0.2526, 0.0581],
        [0.6675, 0.7737, 0.9956, 0.4477, 0.9810, 0.8212, 0.2520, 0.1143],
        [0.7804, 0.2261, 0.7293, 0.0718, 0.8648, 0.6499, 0.4425, 0.7293],
        [0.1440, 0.3907, 0.5049, 0.2111, 0.3722, 0.2844, 0.9765, 0.9248],
        [0.9312, 0.9087, 0.0211, 0.6673, 0.6843, 0.1333, 0.7027, 0.2287],
        [0.5795, 0.6656, 0.9314, 0.2377, 0.1066, 0.9626, 0.7099, 0.7430],
        [0.6800, 0.3230, 0.8008, 0.0635, 0.5696, 0.7260, 0.8416, 0.2503],
        [0.1621, 0.6874, 0.1029, 

In [14]:
# The same 15 dot products laid out two ways: mat times the column gives a
# (15, 1) result, and the row times mat transposed gives a (1, 15) result.
c = torch.mm(mat, col)
r = torch.mm(row, mat.T)

print(f'c:\n{c}')
print(f'r:\n{r}')

c:
tensor([[1.7639],
        [0.6409],
        [1.1158],
        [1.0865],
        [0.7411],
        [1.5113],
        [0.7082],
        [1.4187],
        [1.3678],
        [0.5774],
        [1.1225],
        [1.2687],
        [1.2112],
        [0.3440],
        [0.8876]])
r:
tensor([[1.7639, 0.6409, 1.1158, 1.0865, 0.7411, 1.5113, 0.7082, 1.4187, 1.3678,
         0.5774, 1.1225, 1.2687, 1.2112, 0.3440, 0.8876]])


In [11]:
# 2-D matrix times 1-D vector: the result is 1-D, one value per row of mat.
mat @ arr

tensor([1.7639, 0.6409, 1.1158, 1.0865, 0.7411, 1.5113, 0.7082, 1.4187, 1.3678,
        0.5774, 1.1225, 1.2687, 1.2112, 0.3440, 0.8876])

In [12]:
# With two 2-D operands matmul is a plain matrix product, so (15, 8) times
# (1, 8) has mismatched inner dimensions. The failure is the point of this
# cell, so it is caught and shown instead of aborting a Restart & Run All.
try:
    mat.matmul(row)
except RuntimeError as err:
    print(f'RuntimeError: {err}')

RuntimeError: mat1 and mat2 shapes cannot be multiplied (15x8 and 1x8)

In [26]:
# Project a single in_dim-sized token to out_dim values with an
# (out_dim, in_dim) matrix.
in_dim = 32
out_dim = 512
token = torch.rand(in_dim)
trans = torch.rand(out_dim, in_dim)
r1 = torch.matmul(trans, token)
r1.shape

torch.Size([512])

In [30]:
# Apply the same projection to a whole batch at once:
# (batch, in_dim) times (in_dim, out_dim) gives (batch, out_dim).
batch = 3333
X = torch.rand(batch, in_dim)
y = X @ trans.T
y.shape

torch.Size([3333, 512])

In [31]:
# Built-in layer with the same input and output sizes, applied to the same batch X.
linear = nn.Linear(in_features=in_dim, out_features=out_dim)
y1 = linear(X)
y1.shape

torch.Size([3333, 512])

In [34]:
# The layer's weight is stored as (out_dim, in_dim), i.e. (512, 32) here.
linear.weight.shape

torch.Size([512, 32])

In [35]:
# trans was built as (out_dim, in_dim), the same layout as linear.weight,
# which is why the batched product above used trans.t().
trans.shape

torch.Size([512, 32])