In [75]:
import torch
import numpy as np
import time
from torch import nn

### RevNet 연산 검증

In [35]:
x1 = torch.rand((3, 3))
x2 = torch.rand((3, 3))

In [36]:
x1, x2

(tensor([[0.7386, 0.9690, 0.3168],
         [0.1200, 0.2633, 0.3560],
         [0.7594, 0.8729, 0.0847]]),
 tensor([[0.2724, 0.2673, 0.1191],
         [0.3291, 0.4657, 0.3513],
         [0.1449, 0.0705, 0.1540]]))

In [37]:
# 편의상 f와 g를 모두 Linear 함수로 정의함.
f = nn.Linear(3, 3)
g = nn.Linear(3, 3)

In [38]:
y1 = x1 + f(x2)
y2 = x2 + g(y1)

In [39]:
y1, y2

(tensor([[ 0.3169,  0.6543, -0.3394],
         [-0.3314, -0.2503, -0.4693],
         [ 0.2402,  0.6344, -0.4916]], grad_fn=<AddBackward0>),
 tensor([[0.5671, 0.4740, 0.3522],
         [0.8245, 0.9934, 0.7682],
         [0.4194, 0.3748, 0.4075]], grad_fn=<AddBackward0>))

In [40]:
r2 = y2 - g(y1)
r1 = y1 - f(r2)

In [41]:
r1, r2

(tensor([[0.7386, 0.9690, 0.3168],
         [0.1200, 0.2633, 0.3560],
         [0.7594, 0.8729, 0.0847]], grad_fn=<SubBackward0>),
 tensor([[0.2724, 0.2673, 0.1191],
         [0.3291, 0.4657, 0.3513],
         [0.1449, 0.0705, 0.1540]], grad_fn=<SubBackward0>))

### chunking

In [9]:
x1 = torch.rand((3, 100))
x2 = torch.rand((3, 100))

In [10]:
f = nn.Linear(100, 100)
g = nn.Linear(100, 100)

In [11]:
y1 = x1 + f(x2)
y2 = x2 + g(y1)

In [12]:
y1

tensor([[ 0.5320, -0.2551,  0.2358,  0.6311, -0.0697,  0.9891, -0.0793,  0.3213,
          0.8193, -0.1551,  0.4677, -0.0682,  0.4574,  0.7949,  0.2727,  0.8404,
          1.3370,  0.6841, -0.1828,  0.8333,  0.8891, -0.2493,  0.2731,  0.9432,
          0.5357,  0.1138, -0.0174, -0.1830,  0.2823,  0.1392,  0.9663,  0.4649,
          1.0841,  0.7859, -0.2072,  0.6600, -0.4842, -0.0533, -0.5340,  0.0516,
          1.3811,  0.1737,  0.7980,  0.2084,  1.4415,  0.2268, -0.2614,  0.4700,
          0.7914,  0.0059, -0.0081,  0.6814,  0.4616,  0.0745,  0.3274,  0.3978,
          0.6806,  0.3912,  0.8550,  0.0399,  0.1423,  0.6980,  0.0220,  0.6722,
          0.6643,  0.5073,  0.6355,  0.5778,  0.9340,  0.4726,  0.6476,  0.5052,
          1.0267,  0.2037,  0.3170,  0.9718,  0.8004,  0.3965,  0.1203,  0.2867,
         -0.0635,  0.2141,  0.2622,  0.8858, -0.1250,  0.4489,  1.0582,  0.1382,
          0.5098,  0.0148,  0.5242,  0.5343,  0.4896,  0.1776,  0.4125, -0.0335,
          1.3366, -0.3500,  

In [13]:
# y2 = [y2_1, y2_2, ..., y2_c]

In [14]:
y1_c = x1[0] + f(x2[0])
y2_c = x1[1] + f(x2[1])
y3_c = x1[2] + f(x2[2])

In [15]:
torch.stack([y1_c, y2_c, y3_c])

tensor([[ 0.5320, -0.2551,  0.2358,  0.6311, -0.0697,  0.9891, -0.0793,  0.3213,
          0.8193, -0.1551,  0.4677, -0.0682,  0.4574,  0.7949,  0.2727,  0.8404,
          1.3370,  0.6841, -0.1828,  0.8333,  0.8891, -0.2493,  0.2731,  0.9432,
          0.5357,  0.1138, -0.0174, -0.1830,  0.2823,  0.1392,  0.9663,  0.4649,
          1.0841,  0.7859, -0.2072,  0.6600, -0.4842, -0.0533, -0.5340,  0.0516,
          1.3811,  0.1737,  0.7980,  0.2084,  1.4415,  0.2268, -0.2614,  0.4700,
          0.7914,  0.0059, -0.0081,  0.6814,  0.4616,  0.0745,  0.3274,  0.3978,
          0.6806,  0.3912,  0.8550,  0.0399,  0.1423,  0.6980,  0.0220,  0.6722,
          0.6643,  0.5073,  0.6355,  0.5778,  0.9340,  0.4726,  0.6476,  0.5052,
          1.0267,  0.2037,  0.3170,  0.9718,  0.8004,  0.3965,  0.1203,  0.2867,
         -0.0635,  0.2141,  0.2622,  0.8858, -0.1250,  0.4489,  1.0582,  0.1382,
          0.5098,  0.0148,  0.5242,  0.5343,  0.4896,  0.1776,  0.4125, -0.0335,
          1.3366, -0.3500,  

### LSH 어탠션 성능

In [48]:
from transformers import ReformerModelWithLMHead
from transformers import ReformerModel

In [49]:
model = ReformerModel.from_pretrained("google/reformer-enwik8")

In [72]:
def make_random_inputs(batch_size, sequence_length):
    x = np.random.randint(0, 258, (batch_size, sequence_length))
    inputs = torch.from_numpy(x)
    return inputs

In [82]:
batch_size = 32
sequence_length = 64

In [83]:
for _ in range(6):
    inputs = make_random_inputs(batch_size, sequence_length)
    
    start = time.time()
    o = model(inputs)
    end = time.time()
    
    print(f'{end-start:.2f} seconds for input size of({batch_size},{sequence_length})')
    
    batch_size = batch_size // 2
    sequence_length = sequence_length * 2

1.04 seconds for input size of(32,64)
0.85 seconds for input size of(16,128)
0.96 seconds for input size of(8,256)
1.71 seconds for input size of(4,512)
1.83 seconds for input size of(2,1024)
1.85 seconds for input size of(1,2048)
