In [2]:
import torch
import torch.nn as nn

In [51]:
# Toy dataset: `batch_size` drawings, each a sequence of T steps
# with 5 values per step (uniform random placeholders).
batch_size, T = 3, 120

S = torch.rand((batch_size, T, 5))
print("input size:", S.size())

input size: torch.Size([3, 120, 5])


In [92]:
# Encoder hyper-parameters: hidden width and number of stacked layers.
H, N_layers = 64, 2

# Batch-first, unidirectional LSTM over sequences with 5 values per step.
lstm_layer = nn.LSTM(
    input_size=5,
    hidden_size=H,
    num_layers=N_layers,
    batch_first=True,
)  # dropout?

In [93]:
# num_layers=2 stacks two LSTMs, feeding the first one's outputs into the second; structurally it is equivalent to:
# lstm_A = nn.LSTM(input_size=5, hidden_size=H, num_layers=1, batch_first=True)
# lstm_B = nn.LSTM(input_size=H, hidden_size=H, num_layers=1, batch_first=True)
# lstm_B(lstm_A(S)[0])

In [94]:
# Run the encoder over the whole batch; with no initial state given,
# the LSTM starts from zeros. Returns per-step outputs plus the final (h, c).
output, final_state = lstm_layer(S)
hidden, cell = final_state

In [95]:
# One H-dimensional output vector per timestep, per sequence.
print("output size:", output.size())

output size: torch.Size([3, 120, 64])


In [97]:
# The final hidden state has one slice per layer (and per direction) along dim 0.
print("seq embedding (final state):", hidden.size())

seq embedding (final state): torch.Size([2, 3, 64])


In [98]:
# Same encoder configuration, but reading each sequence in both directions.
lstm_layer_bi = nn.LSTM(
    input_size=5,
    hidden_size=H,
    num_layers=N_layers,
    batch_first=True,
    bidirectional=True,
)

# NOTE: this rebinds `output`, `hidden` and `cell` from the unidirectional run.
output, final_state = lstm_layer_bi(S)
hidden, cell = final_state

In [99]:
# Forward and backward outputs are concatenated, so the last dim is 2*H.
print("output size:", output.size())

output size: torch.Size([3, 120, 128])


In [100]:
# dim 0 now holds num_layers * 2 slices (one per layer and direction).
print("seq embedding ('final' state):", hidden.size())

seq embedding ('final' state): torch.Size([4, 3, 64])


In [101]:
# Merge the forward (h->) and backward (h<-) final states of every layer
# into one flat embedding vector per sequence.
print(hidden.size())

# Bring the batch dimension to the front: (layers*dirs, batch, H) -> (batch, layers*dirs, H).
h = hidden.permute(1, 0, 2)
print(h.size())

# Flatten everything but the batch dimension: (batch, layers*dirs*H).
h = h.reshape(batch_size, -1)
print(h.size())

torch.Size([4, 3, 64])
torch.Size([3, 4, 64])
torch.Size([3, 256])


In [102]:
# Side note on what transpose does: for a 2-D tensor, `A.T` and
# `A.transpose(0, 1)` both swap rows and columns.
# `torch.tensor` with an explicit dtype is the documented way to build a
# tensor from Python data (replaces the legacy `torch.FloatTensor` constructor).
A = torch.tensor([[1, 2, 3], [1, 2, 2]], dtype=torch.float32)
print(A)
print(A.T)
print(A.transpose(0, 1))

tensor([[1., 2., 3.],
        [1., 2., 2.]])
tensor([[1., 1.],
        [2., 2.],
        [3., 2.]])
tensor([[1., 1.],
        [2., 2.],
        [3., 2.]])


In [106]:
# Size of the latent vector z.
Nz = 100

# Two linear heads on top of the flattened encoder state (2 directions *
# N_layers * H features): one for the mean, one for the log std-dev of z.
# Built in this order (mean_ first) so parameter initialisation is unchanged.
mean_, log_std_ = (nn.Linear(2 * N_layers * H, Nz) for _ in range(2))

mean_

Linear(in_features=256, out_features=100, bias=True)

In [104]:
import torch.distributions as dist

# Reparameterised draw: rsample() keeps z differentiable w.r.t. the two
# linear heads, whereas sample() would block gradients.
# The second head predicts log(std), hence the exp() to get a positive scale.
z = dist.Normal(mean_(h), torch.exp(log_std_(h))).rsample()

In [105]:
# One Nz-dimensional latent vector per sequence in the batch.
z.size()

torch.Size([3, 100])

### generation

In [110]:
# Decoder sketch: an LSTM whose input at each step is (z, S_{t-1}).
# Inputs are all steps but the last; targets are the same sequence shifted by one.

# z still has to be concatenated onto the input at every timestep.
print("input", S[:, :-1].size(), z.size())
print("target", S[:, 1:].size())

input torch.Size([3, 119, 5]) torch.Size([3, 100])
target torch.Size([3, 119, 5])


In [117]:
# Concrete example for one sequence and one timestep:
# the decoder sees (S[b, t], z[b]) and must predict S[b, t+1].
batch_index, t = 0, 80
print("from s (at step t) and z:", S[batch_index][t], z[batch_index].size())
print("predict s (at t+1):", S[batch_index][t + 1])

from s (at step t) and z: tensor([0.5226, 0.3804, 0.1333, 0.9342, 0.6991]) torch.Size([100])
predict s (at t+1): tensor([0.2794, 0.8601, 0.7867, 0.3570, 0.9223])


In [112]:
# z must be fed in at every timestep: insert a time axis and broadcast it
# to T-1 steps (expand creates a view, it does not copy data).
z.size(), z[:, None, :].expand(batch_size, T - 1, -1).size()

(torch.Size([3, 100]), torch.Size([3, 119, 100]))

In [120]:
# Decoder inputs: previous step values and the broadcast z, joined along
# the feature axis.
torch.cat([S[:, :-1], z[:, None, :].expand(batch_size, T - 1, -1)], dim=-1).size()

torch.Size([3, 119, 105])

In [121]:
# Decoder targets: the sequence shifted one step ahead of the inputs.
S[:, 1:].size()

torch.Size([3, 119, 5])

In [122]:
# Full dataset shape, for comparison with the shifted inputs/targets.
S.size()

torch.Size([3, 120, 5])

In [17]:
# A one-element list holding a single length-5 one-hot vector.
s = [torch.tensor([0., 0., 1., 0., 0.])]

# List repetition: m is a Python list of 100 references to that same tensor
# (no copies are made).
m = s * 100
print(m)
# -> a list of 100 tensors

[tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0., 0., 1., 0., 0.]), tensor([0

In [20]:
# Stack the list of tensors along a new leading axis into one 2-D tensor
# (one row per list entry).
q = torch.stack(m)
print(q.shape)

# Add a leading axis of size 1, giving a 3-D tensor nested like [[[...]]].
r = q.unsqueeze(0)
print(r.shape)

torch.Size([100, 5])
torch.Size([1, 100, 5])
