In [11]:
import torch
import math
import time
from cacheLSTM.cacheLSTM import cacheLSTM
from refLSTM import refLSTM

assert torch.cuda.is_available()
cuda_device = torch.device("cuda")

# init dimensions
B = 32 # batch-size
Di = 64 # input-features
Dh = 32 # hidden-state size
T = 1000 # input sequence len

# init inputs
X = torch.randn(T, B, Di, device=cuda_device)
h = torch.randn(B, Dh, device=cuda_device)
C = torch.randn(B, Dh, device=cuda_device)

# init weights
Wii = torch.randn(Dh, Di, device=cuda_device)
Wif = torch.randn(Dh, Di, device=cuda_device)
Wig = torch.randn(Dh, Di, device=cuda_device)
Wio = torch.randn(Dh, Di, device=cuda_device)

Whi = torch.randn(Dh, Dh, device=cuda_device)
Whf = torch.randn(Dh, Dh, device=cuda_device)
Whg = torch.randn(Dh, Dh, device=cuda_device)
Who = torch.randn(Dh, Dh, device=cuda_device)

# test each algorithm multiple iters
test_iters = 100

In [18]:
# ========================================
# test cache cuda lstm on gpu

clstm = cacheLSTM(Di, Dh)
clstm_uu = torch.cat([Wii.clone(), Wif.clone(), Wio.clone(), Wig.clone()], 0).transpose(0,1).contiguous()
clstm_ww = torch.cat([Whi.clone(), Whf.clone(), Who.clone(), Whg.clone()], 0).contiguous()

clstm.setUW(clstm_uu,clstm_ww)
clstm = clstm.to(cuda_device)

clstm_x = X.clone()
clstm_h = h.clone()
clstm_c = C.clone()
torch.cuda.synchronize()

# forward
clstm_forward = 0
for _ in range(test_iters):
    # prepare input
    clstm_x = X.clone()
    clstm_h = h.clone()
    clstm_c = C.clone()
    
    # one forward pass
    start = time.time()
    clstm_out_h, clstm_out_c = clstm(clstm_x, clstm_h, clstm_c)    
    torch.cuda.synchronize()
    clstm_forward += time.time() - start

clstm_forward_us = clstm_forward * 1e6 / T / test_iters
print('c-lstm cuda on gpu: Forward: {:.3f} us '.format(clstm_forward_us))

# get rid of the first in hout which is h0
clstm_out_x = clstm_out_h[1:,:,:] 

c-lstm cuda on gpu: Forward: 2.030 us 


In [20]:
# ========================================
# test torch lstm on gpu

lstm = torch.nn.LSTM(Di, Dh, bias=False)
lstm_uu = torch.cat([Wii.clone(), Wif.clone(), Wig.clone(), Wio.clone()], 0).contiguous()
lstm_ww = torch.cat([Whi.clone(), Whf.clone(), Whg.clone(), Who.clone()], 0).contiguous()
lstm.weight_ih_l0 = torch.nn.Parameter(lstm_uu)
lstm.weight_hh_l0 = torch.nn.Parameter(lstm_ww)
lstm = lstm.to(cuda_device)

h0 = h.unsqueeze(0)
c0 = C.unsqueeze(0)

# forward
lstm_gpu_forward = 0
for _ in range(test_iters):
    # prepare input
    lstm_x = X.clone()
    lstm_h = h0.clone()
    lstm_c = c0.clone()
    
    # forward pass
    start = time.time()
    lstm_out_x, (lstm_out_h, lstm_out_c) = lstm(lstm_x, (lstm_h, lstm_c)) 
    torch.cuda.synchronize()
    lstm_gpu_forward += time.time() - start

lstm_gpu_forward_us = lstm_gpu_forward * 1e6/T/test_iters
print('pytorch cuda-optimized lstm on gpu: Forward: {:.3f} us '.format(lstm_gpu_forward_us))


pytorch cuda-optimized lstm on gpu: Forward: 7.085 us 


In [21]:
# ========================================
# test naive python lstm on gpu

ref_lstm = refLSTM(B, Di, Dh, [Wii.clone(), Wif.clone(), Wig.clone(), Wio.clone(), Whi.clone(), Whf.clone(), Whg.clone(), Who.clone()])

ref_lstm_gpu_forward = 0
for _ in range(test_iters):
    ref_X = X.clone()
    ref_h0 = h.clone()
    ref_c0 = C.clone()
    
    start = time.time()
    reflstm_out_x, ref_cout = ref_lstm.forward(ref_X, ref_h0, ref_c0) 
    torch.cuda.synchronize()
    ref_lstm_gpu_forward += time.time() - start

ref_lstm_gpu_forward_us = ref_lstm_gpu_forward * 1e6/T/test_iters
print('simple ref lstm on gpu: Forward: {:.3f} us '.format(ref_lstm_gpu_forward_us))

simple ref lstm on gpu: Forward: 199.717 us 


In [22]:
# check correctness
print(torch.allclose(clstm_out_x, lstm_out_x, rtol=0, atol=1e-5))
# print all results
print('simple ref lstm on gpu: Forward: {:.3f} us '.format(ref_lstm_gpu_forward_us))
print('pytorch cuda lstm on gpu: Forward: {:.3f} us '.format(lstm_gpu_forward_us))
print('cache lstm cuda on gpu: Forward: {:.3f} us '.format(clstm_forward_us))

True
simple ref lstm on gpu: Forward: 199.717 us 
pytorch cuda lstm on gpu: Forward: 7.085 us 
cache lstm cuda on gpu: Forward: 2.030 us 
