In [1]:
import torch
import math
import time
from cacheLSTM.cacheLSTM import cacheLSTM

assert torch.cuda.is_available()
cuda_device = torch.device("cuda")


B = 16 # batch-size
Di = 64 # input-features
Dh = 32 # hidden-state size
T = 100 # input sequence len

X = torch.randn(T, B, Di, device=cuda_device)
h = torch.randn(B, Dh, device=cuda_device)
C = torch.randn(B, Dh, device=cuda_device)

Wii = torch.randn(Dh, Di, device=cuda_device)
Wif = torch.randn(Dh, Di, device=cuda_device)
Wig = torch.randn(Dh, Di, device=cuda_device)
Wio = torch.randn(Dh, Di, device=cuda_device)

Whi = torch.randn(Dh, Dh, device=cuda_device)
Whf = torch.randn(Dh, Dh, device=cuda_device)
Whg = torch.randn(Dh, Dh, device=cuda_device)
Who = torch.randn(Dh, Dh, device=cuda_device)



In [2]:
# ========================================
# test cache cuda lstm on gpu

clstm = cacheLSTM(Di, Dh)
clstm_uu = torch.cat([Wii.clone(), Wif.clone(), Wio.clone(), Wig.clone()], 0).transpose(0,1).contiguous()
clstm_ww = torch.cat([Whi.clone(), Whf.clone(), Who.clone(), Whg.clone()], 0).contiguous()
print(clstm_uu.shape, clstm_ww.shape)

clstm.setUW(clstm_uu,clstm_ww)
clstm = clstm.to(cuda_device)

clstm_x = X.clone()
clstm_h = h.clone()
clstm_c = C.clone()
torch.cuda.synchronize()

# forward
start = time.time()
clstm_out_h, clstm_out_c = clstm(clstm_x, clstm_h, clstm_c)    
torch.cuda.synchronize()
clstm_forward = time.time() - start

clstm_out_h0 = clstm_out_h[0, :, :]
clstm_out_x = clstm_out_h[1:,:,:]
print(clstm_out_x.shape)
print('clstm cuda on gpu: Forward: {:.3f} us '.format(clstm_forward * 1e6/T))


torch.Size([64, 128]) torch.Size([128, 32])
torch.Size([100, 16, 128])
torch.Size([100, 16, 32])
clstm cuda on gpu: Forward: 20.590 us 


In [8]:
# ========================================
# test torch lstm on gpu

lstm = torch.nn.LSTM(Di, Dh, bias=False)
lstm_uu = torch.cat([Wii.clone(), Wif.clone(), Wig.clone(), Wio.clone()], 0).contiguous()
lstm_ww = torch.cat([Whi.clone(), Whf.clone(), Whg.clone(), Who.clone()], 0).contiguous()
lstm.weight_ih_l0 = torch.nn.Parameter(lstm_uu)
lstm.weight_hh_l0 = torch.nn.Parameter(lstm_ww)
lstm = lstm.to(cuda_device)

lstm_x = X.clone()
lstm_h = h.unsqueeze(0).clone()
lstm_c = C.unsqueeze(0).clone()
torch.cuda.synchronize()

# forward
start = time.time()
lstm_out_x, (lstm_out_h, lstm_out_c) = lstm(lstm_x, (lstm_h, lstm_c))    
print(lstm_out_x.shape)
lstm_gpu_forward = time.time() - start

print('pytorch lstm on gpu: Forward: {:.3f} us '.format(lstm_gpu_forward * 1e6/T))


torch.Size([100, 16, 32])
pytorch lstm on gpu: Forward: 14.484 us 


In [4]:
# print all results
print('cache lstm cuda on gpu: Forward: {:.3f} us '.format(clstm_forward * 1e6/T))
print('pytorch lstm on gpu: Forward: {:.3f} us '.format(lstm_gpu_forward * 1e6/T))

cache lstm cuda on gpu: Forward: 20.590 us 
pytorch lstm on gpu: Forward: 20.566 us 


In [40]:
torch.allclose(clstm_out_x, lstm_out_x, rtol=0, atol=10)
lstm_out_x[9,3,:], clstm_out_x[9,3,:]

(tensor([ 6.7095e-01, -5.2552e-02,  2.3361e-03, -2.3130e-01,  5.2138e-04,
         -7.0899e-01, -8.7565e-07,  3.7662e-02, -1.2957e-07,  7.4257e-05,
         -3.1100e-04,  2.3549e-04, -8.4916e-03, -7.6990e-02, -8.9720e-01,
         -5.6277e-02,  4.7381e-01, -1.0210e-04, -1.1794e-02,  4.7863e-02,
          1.4177e-04,  3.4258e-06,  5.7679e-01, -9.0748e-04,  3.2574e-05,
          5.6050e-01, -7.8106e-01, -5.3330e-01,  3.0315e-02, -5.8880e-07,
          1.8275e-01, -5.8799e-01], device='cuda:0', grad_fn=<SliceBackward>),
 tensor([ 6.7357e-01, -5.3687e-02,  1.8194e-03, -2.3104e-01,  7.8410e-04,
         -7.0561e-01, -7.2859e-07,  4.1269e-02, -1.2753e-07,  6.5641e-05,
         -3.5516e-04,  2.1308e-04, -8.9914e-03, -8.3167e-02, -8.5797e-01,
         -5.5491e-02,  4.8048e-01, -1.0232e-04, -1.2082e-02,  4.6014e-02,
          1.3134e-04,  3.0148e-06,  5.7085e-01, -6.4880e-04,  2.5527e-05,
          5.7099e-01, -7.8045e-01, -5.1147e-01,  3.2279e-02, -5.7538e-07,
          1.0764e-01, -5.8434e-01

In [36]:
lstm_out_x[10,1,:] - clstm_out_x[10,1,:]

tensor([ 4.6566e-10,  5.9605e-08,  0.0000e+00,  0.0000e+00,  7.2760e-12,
         1.8626e-09, -5.9605e-08, -5.9605e-08, -1.7881e-07,  6.4393e-15,
        -2.9104e-10,  5.9605e-08,  0.0000e+00, -2.3283e-09, -2.7285e-11,
         0.0000e+00, -4.5475e-12,  5.9605e-08,  2.2352e-08,  1.7764e-15,
        -7.2760e-11, -1.1642e-10,  1.6007e-10, -1.3411e-07,  9.7145e-17,
         2.5466e-11,  4.5475e-13, -2.9559e-12, -1.1921e-07,  2.9843e-13,
         1.0800e-12,  2.3283e-10], device='cuda:0', grad_fn=<SubBackward0>)

In [23]:
# manually reproducing one time step using the same input and parameters
# see if numbers match
one_x = lstm_x[0,0,:].squeeze() # (Di)
one_h = lstm_h[0,0,:].squeeze() # (Dh)
one_c = lstm_c[0,0,:].squeeze() # (Dh)

sig = torch.nn.Sigmoid()
tanh = torch.nn.Tanh()
it = sig(torch.matmul(one_x, Wii.transpose(0,1)) + torch.matmul(one_h, Whi.transpose(0,1)))
ft = sig(torch.matmul(one_x, Wif.transpose(0,1)) + torch.matmul(one_h, Whf.transpose(0,1)))
gt = tanh(torch.matmul(one_x, Wig.transpose(0,1)) + torch.matmul(one_h, Whg.transpose(0,1)))
ot = sig(torch.matmul(one_x, Wio.transpose(0,1)) + torch.matmul(one_h, Who.transpose(0,1)))

new_c = ft * one_c + it * gt
new_h = tanh(new_c) * ot

new_h

tensor([ 4.3733e-10,  1.2713e-02,  2.1779e-04,  3.4586e-07,  2.4343e-05,
        -7.1491e-01, -9.4858e-01, -3.4728e-01, -9.1524e-04, -1.0786e-02,
         7.4157e-01,  9.7296e-01,  3.7717e-05,  2.4591e-01, -7.4742e-01,
        -6.9213e-01,  4.2613e-04, -5.1385e-07,  2.2045e-02,  7.0434e-01,
        -9.9542e-01, -2.5549e-01,  8.5770e-04, -9.1106e-01, -4.0740e-05,
         8.0394e-03, -7.5224e-01,  9.4008e-05, -2.2697e-01,  9.1121e-02,
         6.5521e-02, -7.7733e-06], device='cuda:0')