In [8]:
from itertools import islice

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torchvision

In [9]:
########### Set Device ############
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
dtype = torch.float64
torch.set_default_dtype(dtype)
print("Using device: {}".format(device))

Using device: cuda:1


In [10]:
x = torch.randn(10, 3)

In [11]:
b = x.shape[-2]

In [12]:
b

10

In [27]:
M = torch.rand((4, 3))

In [29]:
M.shape

torch.Size([4, 3])

In [30]:
torch.matmul(x, M.T).shape

torch.Size([10, 4])

In [36]:
x.shape

torch.Size([10, 3])

In [38]:
M_ext = torch.tile(M, (b, 1, 1))
M_ext.shape

torch.Size([10, 4, 3])

In [47]:
y = torch.matmul(x, M_ext.transpose(2, 1))
torch.all(y[1, 1] == y[0, 1])

tensor(True)

In [60]:
torch.diagonal(y, dim1=-3, dim2=-2).T.shape

torch.Size([10, 4])

In [50]:
M_ext = torch.tile(M, (b, 1, 1))

In [51]:
M_ext.shape

torch.Size([10, 3, 4])

In [52]:
x.shape

torch.Size([10, 3])

In [54]:
torch.matmul(M_ext.transpose(2, 1), x.T).shape

torch.Size([10, 4, 10])

In [80]:
M_ext.shape

torch.Size([10, 4, 3])

In [41]:
torch.matmul(x, M_ext).shape

torch.Size([10, 10, 4])

In [51]:
torch.diagonal?

In [123]:
class AddNoise(nn.Module):
    def __init__(self, scale=0.05, device=device):
        super(AddNoise, self).__init__()
        self.scale = scale
        self.device = device
    
    def forward(self, x):
        return x + self.scale * torch.randn(x.shape).to(self.device)


class NoisyLinear(nn.Linear):
    def __init__(self, *args, scale, device, **kwargs):
        super(NoisyLinear, self).__init__(*args, **kwargs)
        self.scale = scale
        self.device = device
    
    def forward(self, x):
        b = x.shape[-2]
        weight_tiled = torch.tile(self.weight, (b, 1, 1))
        weight_noise = self.scale * torch.randn(weight_tiled.shape).to(self.device)
        weight_noisy = weight_tiled + weight_noise
        if self.bias:
            bias_tiled = torch.tile(self.bias, (b, 1))
            bias_noise = self.scale * torch.randn(bias_tiled.shape).to(self.device)
            bias_noisy = bias_tiled + bias_noise
        product = torch.matmul(x, weight_noisy.transpose(2, 1))
        product = torch.diagonal(product, dim1=-3, dim2=-2)
        if self.bias:
            return product.T + bias_noisy
        return product.T

In [129]:
torch.set_printoptions(sci_mode=False)

In [124]:
L = NoisyLinear(3, 4, scale=0.0, device=device, bias=False).to(device)

In [135]:
a = torch.randn((10, 3)).to(device)
b = torch.randn((10, 3)).to(device)

In [136]:
(L(a + b) - (L(a) + L(b)))

tensor([[     0.0000,     -0.0000,      0.0000,      0.0000],
        [     0.0000,      0.0000,      0.0000,      0.0000],
        [     0.0000,      0.0000,      0.0000,     -0.0000],
        [     0.0000,      0.0000,      0.0000,      0.0000],
        [     0.0000,      0.0000,      0.0000,      0.0000],
        [     0.0000,      0.0000,      0.0000,      0.0000],
        [    -0.0000,      0.0000,      0.0000,      0.0000],
        [    -0.0000,      0.0000,      0.0000,      0.0000],
        [    -0.0000,     -0.0000,      0.0000,      0.0000],
        [     0.0000,     -0.0000,     -0.0000,      0.0000]], device='cuda:1',
       grad_fn=<SubBackward0>)