# Using NumPy

In [1]:
import numpy as np

# Fixed seed so the arrays, and therefore every check/timing below, are reproducible.
SEED = 0
rng = np.random.default_rng(SEED)
x = rng.standard_normal(1024)          # input vector
w = rng.standard_normal((1024, 1024))  # weights; rows index outputs in `w @ x`
b = rng.standard_normal(1024)          # bias

w @ x + b

array([ -4.1467563 , -20.45317729,   8.27463481, ...,  32.17259803,
        11.65379888, -22.2881978 ])

In [2]:
def new_multiplication(x, w, b, aggregate=np.sum):
    """Generalized matrix-vector product ``w @ x + b`` with a pluggable reduction.

    Forms every product ``w[j, i] * x[i]`` and reduces over the input axis with
    ``aggregate`` instead of always summing; ``aggregate=np.sum`` reproduces
    ``w @ x + b``.

    Args:
        x: input of shape (in,), or (..., in) for a batch of inputs.
        w: weights of shape (out, in).
        b: bias, broadcastable to the (..., out) result.
        aggregate: reduction called as ``aggregate(arr, axis=-1)``
            (e.g. ``np.sum``, ``np.max``, ``np.mean``).

    Returns:
        Array of shape (out,) for 1-D ``x``, or (..., out) for batched ``x``.
    """
    # (..., 1, in) * (out, in) -> (..., out, in); then reduce over the `in` axis.
    return aggregate(x[..., np.newaxis, :] * w, axis=-1) + b

In [3]:
# With a sum reduction the generalized product must agree with the built-in matmul.
assert np.allclose(new_multiplication(x, w, b, aggregate=np.sum), w @ x + b)

In [4]:
%%timeit
# Baseline: built-in matrix-vector product plus bias, for comparison with new_multiplication.
w @ x + b

71.7 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [14]:
%%timeit
# new_multiplication with a max reduction; it materializes the full product array before reducing.
new_multiplication(x, w, b, aggregate=np.max)

1.08 ms ± 3.86 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# Using PyTorch

In [16]:
import torch
import tqdm

In [53]:
torch.manual_seed(0)  # reproducible x, w, b for the comparisons and timings below

in_size = 128
out_size = 256
x = torch.randn(8, in_size)         # (batch, in)
w = torch.randn(in_size, out_size)  # (in, out), used as `x @ w`
b = torch.randn(out_size)

def new_multiplication(x, w, b, aggregate=torch.sum):
    """Generalized ``x @ w + b`` with a pluggable reduction over the input axis.

    NOTE: this redefines (shadows) the NumPy ``new_multiplication`` for tensors.

    Args:
        x: input of shape (in,) or (..., in).
        w: weights of shape (in, out).
        b: bias, broadcastable to the (..., out) result.
        aggregate: reduction called as ``aggregate(t, dim=-1)`` that returns a
            tensor (``torch.sum``, ``torch.mean``, ``torch.amax``). ``torch.max``
            with ``dim`` returns a (values, indices) tuple and will not work here.

    Returns:
        Tensor of shape (out,) or (..., out).
    """
    # (..., 1, in) * (out, in) -> (..., out, in); reduce over `in`.
    return aggregate(x.unsqueeze(-2) * w.t(), dim=-1) + b


In [3]:
# Mean over the input axis: equals (x @ w) / in_size + b, i.e. the matmul scaled down by in_size.
new_multiplication(x, w, b, aggregate=torch.mean)

tensor([[ 1.0756e+00,  1.3063e-03, -5.8510e-01,  ..., -5.8305e-03,
         -3.4652e-01, -1.1912e+00],
        [ 1.2265e+00, -9.0259e-02, -3.9968e-01,  ..., -1.5300e-01,
         -5.2789e-01, -1.2939e+00],
        [ 1.1770e+00,  1.6524e-01, -5.1461e-01,  ..., -6.5718e-02,
         -4.9589e-01, -1.3619e+00],
        ...,
        [ 1.2839e+00,  1.0722e-02, -3.6673e-01,  ..., -7.4722e-02,
         -4.0168e-01, -1.3429e+00],
        [ 1.1328e+00,  1.0426e-01, -4.5403e-01,  ..., -1.0995e-01,
         -6.0258e-01, -1.4791e+00],
        [ 1.2728e+00,  7.4930e-02, -3.3275e-01,  ...,  2.4351e-02,
         -4.1659e-01, -1.2888e+00]])

In [4]:
# Reference result from the built-in batched matmul plus bias.
torch.matmul(x, w) + b

tensor([[ -7.7497,  -4.6337, -20.3994,  ...,  11.0727,  14.2768,  16.8557],
        [ 11.5624, -16.3541,   3.3344,  ...,  -7.7647,  -8.9386,   3.7175],
        [  5.2206,  16.3496, -11.3767,  ...,   3.4070,  -4.8417,  -4.9850],
        ...,
        [ 18.9017,  -3.4285,   7.5522,  ...,   2.2546,   7.2165,  -2.5537],
        [ -0.4373,   8.5438,  -3.6233,  ...,  -2.2544, -18.4990, -19.9883],
        [ 17.4838,   4.7901,  11.9011,  ...,  14.9358,   5.3080,   4.3627]])

In [7]:
# Check the sum-aggregated product approximately equals the built-in matmul.
# Compare the absolute error: a signed `diff < tol` test would accept any negative error.
torch.allclose(new_multiplication(x, w, b), x @ w + b, rtol=0, atol=1e-4)

True

In [86]:
%%timeit
# Baseline timing for the built-in batched matmul plus bias.
(x @ w + b)

24.9 µs ± 1.07 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [91]:
%%timeit
# new_multiplication with torch.mean materializes the (batch, out, in) product tensor before reducing.
new_multiplication(x, w, b, aggregate=torch.mean)

61.6 µs ± 2.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [129]:
import numpy as np

In [134]:
a = np.arange(0, 10)  # integers 0..9 for the soft-maximum demo

In [135]:
# Show the array used in the log-sum-exp demo.
a

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [145]:
alpha = 5  # sharpness: larger values bring the soft maximum closer to a.max()

In [146]:
# Soft maximum via log-sum-exp: max(a) + log(sum(exp(alpha * (a - max(a))))) / alpha.
# It lies in [max(a), max(a) + log(len(a)) / alpha]; subtracting max(a) first keeps exp from overflowing.
a_max = a.max()
a_max + (1 / alpha) * np.log(np.exp(alpha * (a - a_max)).sum())

9.001352149889899

In [174]:

def matmul_abs_max_with_sign(x, w, b):
    """Replace the sum in ``x @ w`` with a signed max-by-magnitude.

    For every (batch row, output unit) pair, the elementwise products
    ``x[:, i] * w[i, j]`` are formed and the one with the largest absolute
    value is returned with its original sign.

    Args:
        x: input of shape (batch, in_features), as used in this notebook.
        w: weights of shape (in_features, out_features).
        b: accepted for signature parity with ``new_multiplication`` but not
            used. NOTE(review): confirm the bias is meant to be ignored.

    Returns:
        Tensor of shape (batch, out_features), even when batch or
        out_features is 1.
    """
    # (batch, 1, in) * (out, in) -> (batch, out, in)
    intermediate = x.unsqueeze(1) * w.t()
    indxs = intermediate.abs().argmax(dim=-1, keepdim=True)
    # squeeze(-1), not squeeze(): only drop the gathered axis so size-1
    # batch/output axes survive.
    return torch.gather(intermediate, -1, indxs).squeeze(-1)

def matmul_soft_abs_max_with_sign(x, w, b, alpha=10):
    """Smooth version of the signed max-by-magnitude aggregation.

    For each (batch row, output unit) the products ``x[:, i] * w[i, j]`` are
    formed. Their magnitudes are combined with a log-sum-exp soft maximum,
    ``logsumexp(alpha * |p|) / alpha``. The result takes the sign of the product
    with the largest magnitude. The soft maximum lies between ``max|p|`` and
    ``max|p| + log(in_features) / alpha``, so larger ``alpha`` approaches the
    hard max.

    Args:
        x: input of shape (batch, in_features), as used in this notebook.
        w: weights of shape (in_features, out_features).
        b: accepted for signature parity with ``new_multiplication`` but not
            used. NOTE(review): confirm the bias is meant to be ignored.
        alpha: sharpness of the soft maximum (must be non-zero).

    Returns:
        Tensor of shape (batch, out_features), even when batch or
        out_features is 1.
    """
    # (batch, 1, in) * (out, in) -> (batch, out, in)
    intermediate = x.unsqueeze(1) * w.t()
    magnitudes = intermediate.abs()
    max_idxs = magnitudes.argmax(dim=-1, keepdim=True)
    # squeeze(-1), not squeeze(): keep batch/output axes even when they have size 1.
    sign = torch.gather(intermediate, -1, max_idxs).squeeze(-1).sign()
    # torch.logsumexp is numerically stabilized, replacing the manual max-shift.
    soft_max = torch.logsumexp(alpha * magnitudes, dim=-1) / alpha
    return soft_max * sign


In [170]:
# All (batch, out, in) products, then the largest (signed) product per output unit.
intermediate = torch.mul(x.unsqueeze(1), w.T)
x_star = intermediate.max(dim=-1)[0]
x_star.shape

torch.Size([8, 256])

In [171]:
# Hard signed abs-max aggregation; `b` is passed, but the function body does not use it.
matmul_abs_max_with_sign(x, w, b)

tensor([[ 3.9366,  4.8260, -3.1061,  ...,  4.2637, -6.3641, -4.9272],
        [ 2.8481, -2.8832, -2.9196,  ..., -3.7980,  5.1746,  3.6667],
        [-4.6657, -6.2825,  5.2085,  ...,  4.4811,  5.8092, -6.8542],
        ...,
        [-3.8787, -3.8792, -3.4387,  ..., -4.1604,  4.0529,  3.7508],
        [ 3.5586,  2.9782, -3.5428,  ...,  5.6816, -3.7871,  3.9682],
        [-3.4155,  4.1684, -4.6598,  ..., -3.4170,  6.9878, -4.3139]])

In [172]:
# Soft version; with alpha=10 the values come out close to the hard abs-max aggregation.
matmul_soft_abs_max_with_sign(x, w, b, alpha=10)

tensor([[ 3.9407,  4.8260, -3.1115,  ...,  4.2717, -6.3758, -4.9297],
        [ 2.8502, -2.8873, -2.9216,  ..., -3.8021,  5.1746,  3.6708],
        [-4.6657, -6.2825,  5.2085,  ...,  4.4819,  5.8092, -6.8542],
        ...,
        [-3.8787, -3.8792, -3.4390,  ..., -4.1606,  4.0576,  3.7509],
        [ 3.5617,  2.9807, -3.5460,  ...,  5.6816, -3.8231,  3.9682],
        [-3.4157,  4.1684, -4.6598,  ..., -3.4340,  6.9878, -4.3141]])

In [175]:
class NewLinearLayer(torch.nn.Module):
    """Linear-layer analogue that aggregates with a soft signed abs-max.

    The forward pass calls ``matmul_soft_abs_max_with_sign`` instead of
    summing the input-weight products.

    Args:
        in_features: size of each input vector.
        out_features: number of output units.
        alpha: sharpness passed to ``matmul_soft_abs_max_with_sign``; the
            default 10 matches that function's own default.
    """

    def __init__(self, in_features, out_features, alpha=10):
        super().__init__()
        # Standard-normal init; registered under the names `weights` and `bias`.
        self.register_parameter('weights', torch.nn.Parameter(torch.randn(in_features, out_features)))
        self.register_parameter('bias', torch.nn.Parameter(torch.randn(out_features)))
        self.alpha = alpha

    def forward(self, x):
        # NOTE(review): matmul_soft_abs_max_with_sign does not use its `b`
        # argument, so `bias` never affects the output and gets no gradient.
        # Confirm this is intended.
        return matmul_soft_abs_max_with_sign(x, self.weights, self.bias, alpha=self.alpha)


In [176]:
import torchvision
import torch.utils.data

In [177]:
# MLP 784 -> 256 -> 128 -> 64 -> 10 built from NewLinearLayer, with Hardtanh between layers.
_new_layer_sizes = [784, 256, 128, 64, 10]
_new_layers = []
for _idx, (_n_in, _n_out) in enumerate(zip(_new_layer_sizes, _new_layer_sizes[1:])):
    if _idx > 0:
        _new_layers.append(torch.nn.Hardtanh())
    _new_layers.append(NewLinearLayer(_n_in, _n_out))
new_linear_model = torch.nn.Sequential(*_new_layers)

In [178]:
# Baseline with the same architecture, using the standard torch.nn.Linear layers.
_layer_sizes = [784, 256, 128, 64, 10]
_linear_layers = []
for _idx, (_n_in, _n_out) in enumerate(zip(_layer_sizes, _layer_sizes[1:])):
    if _idx > 0:
        _linear_layers.append(torch.nn.Hardtanh())
    _linear_layers.append(torch.nn.Linear(_n_in, _n_out))
linear_model = torch.nn.Sequential(*_linear_layers)

In [179]:
import os

# Location of the MNIST data. Set the MNIST_DIR environment variable rather
# than editing this machine-specific absolute default.
MNIST_DIR = os.environ.get("MNIST_DIR", "/Users/jan/datasets/mnist/")
BATCH_SIZE = 16

dataset = torchvision.datasets.MNIST(MNIST_DIR, transform=torchvision.transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)
# One sample batch, flattened to (batch, features) for quick experiments.
x, y = next(iter(dataloader))
x = x.reshape(x.size(0), -1)

In [180]:
# Classification loss; "mean" (the default) averages the loss over the batch.
criterion = torch.nn.CrossEntropyLoss(reduction="mean")

In [None]:

import tqdm.auto

# Optimizer and logging settings shared by both models.
LEARNING_RATE = 0.001
MOMENTUM = 0.8
LOG_EVERY = 100  # print the mean loss of each consecutive window of this many batches

for model in (linear_model, new_linear_model):
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
    running_loss = 0.0
    # Distinct loop names so the sample batch `x`, `y` from the data-loading cell is not overwritten.
    for i, (images, labels) in enumerate(tqdm.auto.tqdm(dataloader)):
        images = images.reshape(images.size(0), -1)  # flatten each image into a feature vector
        predictions = model(images)
        loss = criterion(predictions, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # (i + 1) makes every window exactly LOG_EVERY batches long; testing
        # `i % LOG_EVERY == 0 and i > 0` would put 101 batches in the first one.
        if (i + 1) % LOG_EVERY == 0:
            print(running_loss / LOG_EVERY)
            running_loss = 0.0


HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))

2.3180199432373048
2.2612040305137633
2.222285380363464
2.1685437893867503
2.12642452955246
2.0391437649726867
1.899020750522613
1.7774287736415864
1.6100853729248048
1.4525646042823788
1.2753579592704778
1.1391883230209354
0.9826255416870119
0.9034445458650587
0.8632834327220914
0.7839462408423423
0.7535398200154305
0.702422222197056
0.7132131338119506
0.7136265051364897
0.6865319877862932
0.5714167705178261
0.5325418104231356
0.6202109879255291
0.5728331223130224
0.5429678289592267
0.5726944382488729
0.48500355511903764
0.5558524391055109
0.4953328366577627
0.49067893758416187
0.5232392819225788
0.4796754842996597
0.4696946971118451
0.42503870829939844
0.402633321136236
0.3190304088592529


HBox(children=(IntProgress(value=0, max=3750), HTML(value='')))

3.98548649072647
4.126735692024231
4.012626132965087
3.960223226547241
3.984845917224885
3.9781308436393745
3.8746247375011427
3.881452393531798
3.7416932892799366
3.8226396346092217
3.7050264263153077
3.772139885425567
3.7775453400611876
3.7736988162994383
3.6761367630958564
3.746480813026428
3.813709783554078
3.750118653774261
3.637923209667205
3.604153666496277
3.629136486053468


In [17]:
vals, indxs = torch.max(w, dim=1)  # per-row maximum of w and where it occurs

In [92]:
intermediate = x[:, None] * w.T  # every input-weight product, one row per output unit

In [95]:
vals, indxs = torch.max(intermediate.abs(), dim=-1)  # largest magnitude per output and its input index

In [100]:
# One index per (batch row, output unit) pair.
indxs.shape

torch.Size([8, 256])

In [104]:
# Largest absolute product for each (batch row, output unit) pair.
vals

tensor([[3.2035, 5.4254, 5.3714,  ..., 5.3970, 5.2756, 4.8329],
        [4.7636, 4.2375, 4.4597,  ..., 2.7539, 4.1946, 3.9239],
        [4.5746, 3.1270, 6.6946,  ..., 4.5977, 6.5752, 3.0569],
        ...,
        [3.7134, 5.4075, 6.0133,  ..., 2.7904, 5.9060, 4.0532],
        [5.4211, 5.0082, 5.0784,  ..., 4.6823, 6.4084, 3.8672],
        [4.9760, 6.0081, 4.3977,  ..., 2.9710, 3.4506, 3.8452]])

In [107]:
# Signed value of the largest-magnitude product. squeeze(-1) drops only the
# gathered axis; squeeze() would also drop a batch or output axis of size 1.
torch.gather(intermediate, -1, indxs.unsqueeze(-1)).squeeze(-1)

tensor([[ 3.2035, -5.4254, -5.3714,  ...,  5.3970,  5.2756, -4.8329],
        [-4.7636, -4.2375, -4.4597,  ...,  2.7539, -4.1946, -3.9239],
        [ 4.5746,  3.1270, -6.6946,  ..., -4.5977,  6.5752, -3.0569],
        ...,
        [ 3.7134,  5.4075,  6.0133,  ..., -2.7904, -5.9060, -4.0532],
        [ 5.4211, -5.0082, -5.0784,  ..., -4.6823,  6.4084,  3.8672],
        [-4.9760, -6.0081, -4.3977,  ..., -2.9710, -3.4506,  3.8452]])