In [1]:
import torch
import time
import numpy as np
from linear_atomic import *
from utilities import *

def timing_fwd(layer, x, n_iters=10000, n_warmup=100):
    """Benchmark ``layer(x)`` and print the mean +/- std of the call time in ms.

    Args:
        layer: Callable invoked as ``layer(x)`` on every iteration.
        x: Input handed unchanged to each call.
        n_iters: Total number of timed calls.
        n_warmup: Number of leading measurements dropped before the
            statistics are computed.

    Returns:
        None. The result is printed as ``"<mean> +/- <std> ms"``.

    Raises:
        ValueError: If ``n_warmup`` would leave no measurement to average.
    """
    if not 0 <= n_warmup < n_iters:
        raise ValueError(
            f"n_warmup must satisfy 0 <= n_warmup < n_iters, got "
            f"n_warmup={n_warmup}, n_iters={n_iters}"
        )
    eval_times = []
    for _ in range(n_iters):
        # perf_counter is the monotonic, highest-resolution clock available;
        # time.time() can be too coarse for calls that take microseconds.
        start = time.perf_counter()
        layer(x)
        stop = time.perf_counter()
        eval_times.append(stop - start)
    # Drop the warm-up samples and convert seconds -> milliseconds.
    eval_times = np.array(eval_times)[n_warmup:] * 1_000
    print(f"{np.mean(eval_times)} +/- {np.std(eval_times)} ms")

def timing_bwd(layer, x, n_iters=10000, n_warmup=100):
    """Benchmark forward + cross-entropy loss + backward and print mean +/- std in ms.

    Every sample in the batch is given class index 0 as its target; the
    target tensor is sized from ``x.shape[0]`` so any batch size works.
    Gradients are not cleared between iterations.

    Args:
        layer: Callable invoked as ``layer(x)``; its output is fed to
            ``torch.nn.CrossEntropyLoss`` and must support ``backward()``.
        x: Input tensor whose first dimension is the batch dimension.
        n_iters: Total number of timed iterations.
        n_warmup: Number of leading measurements dropped before the
            statistics are computed.

    Returns:
        None. The result is printed as ``"<mean> +/- <std> ms"``.

    Raises:
        ValueError: If ``n_warmup`` would leave no measurement to average.
    """
    if not 0 <= n_warmup < n_iters:
        raise ValueError(
            f"n_warmup must satisfy 0 <= n_warmup < n_iters, got "
            f"n_warmup={n_warmup}, n_iters={n_iters}"
        )
    criterion = torch.nn.CrossEntropyLoss()
    # Loop-invariant: build the all-zero class targets once, outside the
    # timed region, so tensor construction does not pollute the measurement.
    target = torch.zeros(x.shape[0], dtype=torch.long)
    eval_times = []
    for _ in range(n_iters):
        # perf_counter is the monotonic, highest-resolution clock available.
        start = time.perf_counter()
        y = layer(x)
        loss = criterion(y, target)
        loss.backward()
        stop = time.perf_counter()
        eval_times.append(stop - start)
    # Drop the warm-up samples and convert seconds -> milliseconds.
    eval_times = np.array(eval_times)[n_warmup:] * 1_000
    print(f"{np.mean(eval_times)} +/- {np.std(eval_times)} ms")

def verify(m1, m2, x, rtol=1e-05, atol=1e-08, n_repeats=100):
    """Assert that ``m1(x)`` and ``m2(x)`` agree within tolerance.

    The comparison is repeated ``n_repeats`` times, so a mismatch that only
    shows up on some calls is still caught. Evaluation happens under
    ``torch.no_grad()``.

    Args:
        m1: First callable, invoked as ``m1(x)``.
        m2: Second callable, invoked as ``m2(x)``.
        x: Input passed to both callables.
        rtol: Relative tolerance forwarded to ``torch.allclose``
            (default equals the ``torch.allclose`` default).
        atol: Absolute tolerance forwarded to ``torch.allclose``
            (default equals the ``torch.allclose`` default).
        n_repeats: How many times the comparison is performed.

    Raises:
        AssertionError: If the outputs are not close on any repetition.
    """
    with torch.no_grad():
        for _ in range(n_repeats):
            out1 = m1(x)
            out2 = m2(x)
            # Tolerances are the torch.allclose defaults unless overridden.
            assert torch.allclose(out1, out2, rtol=rtol, atol=atol), (
                f"outputs differ: max abs diff = "
                f"{(out1 - out2).abs().max().item()}"
            )

In [2]:
# Seed the global torch RNG so the random input below is the same on every
# Restart & Run All.
torch.manual_seed(0)
# Five random samples with 50 features each, matching the (50, 5) layers below.
x = torch.stack([torch.rand((50,)) for _ in range(5)])
torchlinear = torch.nn.Linear(50,5)
# NOTE(review): presumably gives both layers identical parameters so their
# outputs are comparable - confirm against utilities.assign_fixed_params.
assign_fixed_params(torchlinear)
# AtomicLinearTorch variant is currently disabled.
#tatomiclinear = AtomicLinearTorch(50,5)
atomiclinear = AtomicLinear(50,5)
assign_fixed_params(atomiclinear)
#verify(torchlinear, tatomiclinear, x)
# Fail fast if the two implementations disagree on this input.
verify(torchlinear, atomiclinear, x)

In [3]:
# Forward-pass benchmark of the built-in Linear layer, repeated four times
# so the run-to-run spread is visible.
torchlinear.train()
for repeat in range(4):
    timing_fwd(torchlinear, x)

0.01210056170068606 +/- 0.006539036947144636 ms
0.012019523466476287 +/- 0.00419139083020768 ms
0.012140081386373501 +/- 0.007454046388700754 ms
0.011664737354625355 +/- 0.0024179359891095906 ms


In [4]:
# Forward-pass benchmark of AtomicLinear, repeated four times so the
# run-to-run spread is visible.
atomiclinear.train()
for repeat in range(4):
    timing_fwd(atomiclinear, x)

0.0506510879054214 +/- 0.008208702605867724 ms
0.052216968151054 +/- 0.009567621694343578 ms
0.05474940694943823 +/- 0.018164062326832413 ms
0.051467587249447604 +/- 0.009452837395259926 ms


In [5]:
# Forward + backward benchmark of the built-in Linear layer, repeated four
# times so the run-to-run spread is visible.
torchlinear.train()
for repeat in range(4):
    timing_bwd(torchlinear, x)

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


0.132609015763408 +/- 0.02686646575764254 ms
0.1340988910559452 +/- 0.019778505688054143 ms
0.13396706243958137 +/- 0.033465481558104984 ms
0.14896775736953272 +/- 0.10048259670002613 ms


In [6]:
# Forward + backward benchmark of AtomicLinear, repeated four times so the
# run-to-run spread is visible.
atomiclinear.train()
for repeat in range(4):
    timing_bwd(atomiclinear, x)

0.18504224642358644 +/- 0.025913288807389006 ms
0.18506413758403123 +/- 0.029830971279655907 ms
0.18878874152597755 +/- 0.029151304595182576 ms
0.1903938524650805 +/- 0.03263770293720678 ms


In [4]:
# Keep one AtomicLinear output as the reference that later evaluations are
# compared against.
base = atomiclinear(x)

In [5]:
# Re-evaluate the layer on the same input and print the signed sum of the
# element-wise deviation from the reference output `base`.
for i in range(100):
    deviation = atomiclinear(x) - base
    print(deviation.sum())

tensor(9.2387e-07, grad_fn=<SumBackward0>)
tensor(8.1956e-08, grad_fn=<SumBackward0>)
tensor(-1.6093e-06, grad_fn=<SumBackward0>)
tensor(1.6689e-06, grad_fn=<SumBackward0>)
tensor(2.9802e-07, grad_fn=<SumBackward0>)
tensor(5.9605e-07, grad_fn=<SumBackward0>)
tensor(3.4571e-06, grad_fn=<SumBackward0>)
tensor(1.7881e-07, grad_fn=<SumBackward0>)
tensor(1.6093e-06, grad_fn=<SumBackward0>)
tensor(1.6093e-06, grad_fn=<SumBackward0>)
tensor(-2.3842e-07, grad_fn=<SumBackward0>)
tensor(4.1723e-07, grad_fn=<SumBackward0>)
tensor(2.6226e-06, grad_fn=<SumBackward0>)
tensor(-3.5763e-07, grad_fn=<SumBackward0>)
tensor(3.4273e-06, grad_fn=<SumBackward0>)
tensor(-5.3644e-07, grad_fn=<SumBackward0>)
tensor(2.5928e-06, grad_fn=<SumBackward0>)
tensor(6.5565e-07, grad_fn=<SumBackward0>)
tensor(2.1458e-06, grad_fn=<SumBackward0>)
tensor(1.5795e-06, grad_fn=<SumBackward0>)
tensor(-2.6226e-06, grad_fn=<SumBackward0>)
tensor(1.7285e-06, grad_fn=<SumBackward0>)
tensor(5.9605e-07, grad_fn=<SumBackward0>)
tensor

In [7]:
# Two random samples with 10 features each (this rebinds `x` from the
# earlier 5x50 input).
x = torch.stack([torch.rand((10,)) for _ in range(2)])
# Two classifiers that differ only in the first argument.
# NOTE(review): presumably that flag toggles the atomic layers
# (nam = non-atomic, am = atomic) - confirm against Classifier.
nam = Classifier(False, 10, 2, 40)
am = Classifier(True, 10, 2, 40)

In [8]:
# Show the output of the classifier constructed with first argument False.
nam(x)

tensor([[-1.0121, -3.1367],
        [ 6.2277, 11.6180]], grad_fn=<AddmmBackward0>)

In [9]:
# Show the output of the classifier constructed with first argument True,
# for comparison with nam(x).
am(x)

tensor([[-1.0121, -3.1367],
        [ 6.2277, 11.6180]], grad_fn=<AddBackward0>)

In [10]:
# Snapshot one output of `am`, then print the signed sum of the deviation of
# each fresh evaluation from that snapshot.
base = am(x)
for i in range(20):
    deviation = am(x) - base
    print(deviation.sum())

tensor(7.1526e-06, grad_fn=<SumBackward0>)
tensor(8.1062e-06, grad_fn=<SumBackward0>)
tensor(9.2983e-06, grad_fn=<SumBackward0>)
tensor(6.4373e-06, grad_fn=<SumBackward0>)
tensor(6.1989e-06, grad_fn=<SumBackward0>)
tensor(7.1526e-07, grad_fn=<SumBackward0>)
tensor(4.7684e-06, grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
tensor(3.8147e-06, grad_fn=<SumBackward0>)
tensor(2.3842e-06, grad_fn=<SumBackward0>)
tensor(6.4373e-06, grad_fn=<SumBackward0>)
tensor(6.6757e-06, grad_fn=<SumBackward0>)
tensor(-3.8147e-06, grad_fn=<SumBackward0>)
tensor(7.6294e-06, grad_fn=<SumBackward0>)
tensor(8.5831e-06, grad_fn=<SumBackward0>)
tensor(5.4836e-06, grad_fn=<SumBackward0>)
tensor(6.4373e-06, grad_fn=<SumBackward0>)
tensor(8.8215e-06, grad_fn=<SumBackward0>)
tensor(3.0994e-06, grad_fn=<SumBackward0>)
tensor(9.0599e-06, grad_fn=<SumBackward0>)


## Try

Indexing with `[:, None, :]`

vs.

calling `.unsqueeze(dim=1)`

Is there a timing difference between the two?