In [1]:
import torch
import numpy as np
import time
from torch import nn

### RevNet 연산 검증

In [2]:
x1 = torch.rand((3, 3))
x2 = torch.rand((3, 3))

In [3]:
x1, x2

(tensor([[0.9502, 0.6581, 0.6349],
         [0.4874, 0.4134, 0.5632],
         [0.1164, 0.2004, 0.7470]]),
 tensor([[0.6721, 0.9979, 0.9987],
         [0.6816, 0.1874, 0.1850],
         [0.4833, 0.5918, 0.3176]]))

In [4]:
# 편의상 f와 g를 모두 Linear 함수로 정의함.
f = nn.Linear(3, 3)
g = nn.Linear(3, 3)

In [5]:
y1 = x1 + f(x2)
y2 = x2 + g(y1)

In [6]:
y1, y2

(tensor([[ 2.0585,  0.2100,  0.6484],
         [ 0.8482, -0.0078,  0.2465],
         [ 0.6933, -0.2690,  0.5804]], grad_fn=<AddBackward0>),
 tensor([[-0.8666,  0.1399,  1.8037],
         [-0.1278, -0.0986,  0.8124],
         [-0.1854,  0.3420,  1.0780]], grad_fn=<AddBackward0>))

In [7]:
r2 = y2 - g(y1)
r1 = y1 - f(r2)

In [8]:
r1, r2

(tensor([[0.9502, 0.6581, 0.6349],
         [0.4874, 0.4134, 0.5632],
         [0.1164, 0.2004, 0.7470]], grad_fn=<SubBackward0>),
 tensor([[0.6721, 0.9979, 0.9987],
         [0.6816, 0.1874, 0.1850],
         [0.4833, 0.5918, 0.3176]], grad_fn=<SubBackward0>))

### chunking

In [9]:
x1 = torch.rand((3, 100))
x2 = torch.rand((3, 100))

In [10]:
f = nn.Linear(100, 100)
g = nn.Linear(100, 100)

In [11]:
y1 = x1 + f(x2)
y2 = x2 + g(y1)

In [12]:
y1

tensor([[ 0.8403,  0.5778,  0.0716,  1.3934,  0.5052,  0.6488,  0.3716,  1.0432,
         -0.2125,  0.3223,  0.0048,  0.3135,  0.9538,  0.8498,  0.1229,  0.2124,
          0.3272,  0.5906,  0.4368,  0.9378,  0.5013,  0.2088,  0.5233,  0.4851,
          0.1878,  0.6155, -0.2481,  0.0555,  0.0798,  1.3488,  0.4776,  0.7540,
          0.6649,  1.0658,  0.1558, -0.1444,  0.1998,  0.0879,  0.8397,  1.1685,
          0.7320,  1.4036,  0.3449,  0.3430,  0.7232,  0.3845,  0.8835,  0.6652,
          1.2258,  0.4545,  0.8888,  0.8005, -0.0577, -0.4609,  0.8082,  0.8837,
          0.1098, -0.2864,  1.1300,  0.7461,  0.8941, -0.1740,  0.5886,  0.9263,
          1.0437,  0.9485, -0.3407,  0.5569, -0.1521,  0.3088,  0.6934,  0.6455,
          0.5970,  0.8061,  0.3682,  0.6992, -0.1592,  0.2782,  1.0481, -0.0694,
          0.0120, -0.0049,  0.5984,  0.9994,  0.4521,  0.4095,  0.5784,  0.1074,
          0.6448,  0.9382, -0.0088, -0.1270,  0.3325,  0.4197,  0.8027, -0.1187,
          1.2479,  0.6108,  

In [13]:
# y2 = [y2_1, y2_2, ..., y2_c]

In [14]:
y1_c = x1[0] + f(x2[0])
y2_c = x1[1] + f(x2[1])
y3_c = x1[2] + f(x2[2])

In [15]:
torch.stack([y1_c, y2_c, y3_c])

tensor([[ 0.8403,  0.5778,  0.0716,  1.3934,  0.5052,  0.6488,  0.3716,  1.0432,
         -0.2125,  0.3223,  0.0048,  0.3135,  0.9538,  0.8498,  0.1229,  0.2124,
          0.3272,  0.5906,  0.4368,  0.9378,  0.5013,  0.2088,  0.5233,  0.4851,
          0.1878,  0.6155, -0.2481,  0.0555,  0.0798,  1.3488,  0.4776,  0.7540,
          0.6649,  1.0658,  0.1558, -0.1444,  0.1998,  0.0879,  0.8397,  1.1685,
          0.7320,  1.4036,  0.3449,  0.3430,  0.7232,  0.3845,  0.8835,  0.6652,
          1.2258,  0.4545,  0.8888,  0.8005, -0.0577, -0.4609,  0.8082,  0.8837,
          0.1098, -0.2864,  1.1300,  0.7461,  0.8941, -0.1740,  0.5886,  0.9263,
          1.0437,  0.9485, -0.3407,  0.5569, -0.1521,  0.3088,  0.6934,  0.6455,
          0.5970,  0.8061,  0.3682,  0.6992, -0.1592,  0.2782,  1.0481, -0.0694,
          0.0120, -0.0049,  0.5984,  0.9994,  0.4521,  0.4095,  0.5784,  0.1074,
          0.6448,  0.9382, -0.0088, -0.1270,  0.3325,  0.4197,  0.8027, -0.1187,
          1.2479,  0.6108,  

### LSH 어탠션 성능

In [16]:
from transformers import ReformerModelWithLMHead
from transformers import ReformerModel

In [17]:
model = ReformerModel.from_pretrained("google/reformer-enwik8")

In [18]:
def make_random_inputs(batch_size, sequence_length):
    x = np.random.randint(0, 258, (batch_size, sequence_length))
    inputs = torch.from_numpy(x)
    return inputs

In [19]:
batch_size = 32
sequence_length = 64

In [20]:
for _ in range(6):
    inputs = make_random_inputs(batch_size, sequence_length)
    
    start = time.time()
    o = model(inputs)
    end = time.time()
    
    print(f'{end-start:.2f} seconds for input size of({batch_size},{sequence_length})')
    
    batch_size = batch_size // 2
    sequence_length = sequence_length * 2

0.82 seconds for input size of(32,64)
0.86 seconds for input size of(16,128)
0.97 seconds for input size of(8,256)
1.56 seconds for input size of(4,512)
1.53 seconds for input size of(2,1024)
1.52 seconds for input size of(1,2048)
