In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import numpy as np
from numpy.random import default_rng
from core import train_unsup
rng = default_rng()
device = torch.device("cuda:0" if True else "cpu")
DATA_PATH = '../data/gumbel_dataset.npy'

In [2]:
class TrainSet(torch.utils.data.Dataset):
    def __init__(self, path, transform=None):
        super().__init__()
        self.path = path
        self.data = np.load(path)
        self.rows = self.data.shape[0]
        self.cols = self.data.shape[1]
        self.transform = transform

    def __len__(self):
        return self.rows

    def __getitem__(self, idx):
        sample = torch.tensor(self.data[idx], dtype=torch.float)
        if self.transform:
            sample = self.transform(sample)

        return sample

trainset = TrainSet(DATA_PATH)

In [3]:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(trainset.cols, trainset.cols, bias=False)
        # self.trans = torch.diag(torch.tensor([-1 if k < 28*28/2 else 1 for k in range(28*28)], dtype=torch.float))
        # self.trans = torch.diag(torch.tensor([-1 if i < 28 * 28 / 2 else 1 for i in range(28 * 28)],
        #                                      dtype=torch.float, requires_grad=False))
        self.trans = torch.block_diag(*[torch.tensor([[0, 1], [1, 0]], dtype=torch.float) for _ in range(trainset.cols // 2)])
        # torch.nn.init.uniform_(self.fc1.weight, -10 ** -4, 10 ** -4)
        # self.fc1.weight += torch.eye(28, 28).view(-1)


    def forward(self, x):
        # x = x.view(-1, 28 * 28)
        x = self.fc1(x)
        # y = x

        x = F.linear(x, self.trans)

        # y = F.linear(y, self.fc1.weight.t())
        x = F.linear(x, self.fc1.weight.t())

        # y = y.view(-1, 1, 28, 28)
        # x = x.view(-1, 1, 28, 28)
        return x #, y

def init_weights(m):
    if type(m) == nn.Linear:
        # torch.nn.init.uniform_(m.weight, -10 ** -4, 10 ** -4)
        torch.nn.init.orthogonal_(m.weight)
        # with torch.no_grad():
        #     m.weight += torch.eye(28 * 28)


In [4]:
def split(X):
    return X[:X.shape[0]//2], X[X.shape[0]//2:]

def se_kernel(X, Y, sig2=1):
    X_norms =  torch.mean(X ** 2, dim=1)
    Y_norms = torch.mean(Y ** 2, dim=1)
    # print(torch.exp(X_norms.unsqueeze(1) / (2 * sig2)).shape, torch.exp(Y_norms.unsqueeze(0) / (2 * sig2)).shape)
    # print(X.shape, Y.shape)
    return torch.exp(X @ Y.t() / (2 * sig2)) / (torch.exp(X_norms.unsqueeze(1) / (2 * sig2)) @
                                                torch.exp(Y_norms.unsqueeze(0) / (2 * sig2))) * sig2

def poly_kernel(X, Y, r=1, m=2, gamma=0.01):
    return (r + gamma * X @ Y.t()) ** m


class MMDLoss(nn.Module):
    def __init__(self, kernel = se_kernel, **kwargs):
        super().__init__()
        self.kernel = kernel
        self.kwargs = kwargs

    def forward(self, X, Y):
        kernel_dists = self.kernel(X, X, **self.kwargs) + self.kernel(Y, Y, **self.kwargs) - 2 * self.kernel(X, Y, **self.kwargs)
        loss = torch.mean(kernel_dists)
        return loss

class SplitMMDLoss(nn.Module):
    def __init__(self, kernel = se_kernel, **kwargs):
        super().__init__()
        self.kernel = kernel
        self.kwargs = kwargs

    def forward(self, X, Y):
        X1, X2 = split(X)
        Y1, Y2 = split(Y)
        kernel_dists = self.kernel(X1, X2, **self.kwargs) + self.kernel(Y1, Y2, **self.kwargs) - \
                       self.kernel(X1, Y2, **self.kwargs) - self.kernel(X2, Y1, **self.kwargs)
        loss = torch.mean(kernel_dists)
        return loss

class DebiasedMMDLoss(nn.Module):
    def __init__(self, kernel = se_kernel, **kwargs):
        super().__init__()
        self.kernel = kernel
        self.kwargs = kwargs

    def forward(self, X, Y):
        kernel_dists = self.kernel(X, X, **self.kwargs) + self.kernel(Y, Y, **self.kwargs) - 2 * self.kernel(X, Y, **self.kwargs)
        mask = torch.eye(*kernel_dists.shape, device=kernel_dists.device).byte()
        kernel_dists.masked_fill_(mask, 0)
        loss = torch.mean(kernel_dists)
        return loss


In [5]:
BATCH_SIZE = 512
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, pin_memory=True)
# criterion = SplitMMDLoss(sig2=100)
# criterion = DebiasedMMDLoss(kernel = se_kernel, sig2=100)
criterion = DebiasedMMDLoss(kernel = poly_kernel, r=1, m=2, gamma=0.3)
weight_criterion = nn.MSELoss()

stride = len(trainset) // (BATCH_SIZE * 3)

In [40]:

train_unsup(trainloader, Net, device, optim.SGD, criterion, weight_criterion, init=init_weights,
            use_saved=True, error_display_stride=5, inter_error_stride=stride, epochs=1_000,
            optimizer_params={'lr': 0.0001, 'momentum': 0.5, 'weight_decay': 0},
            weight_penalty_adj=100)


Preventing Windows from going to sleep
[41, 130] loss: 0.2654, ortho_loss: 0.0011, ground_truth_loss: 0.4386, base change det: -0.99
[41, 260] loss: 0.3351, ortho_loss: 0.0010, ground_truth_loss: 0.4384, base change det: -0.99
[41, 390] loss: 0.3829, ortho_loss: 0.0010, ground_truth_loss: 0.4386, base change det: -0.99
total error = 0.3199
Finished epoch, cumulative time: 105.2368733882904s
[46, 130] loss: 0.2042, ortho_loss: 0.0010, ground_truth_loss: 0.4409, base change det: -0.99
[46, 260] loss: 0.3821, ortho_loss: 0.0010, ground_truth_loss: 0.4405, base change det: -0.99
[46, 390] loss: 0.2941, ortho_loss: 0.0009, ground_truth_loss: 0.4407, base change det: -0.99
total error = 0.2947
Finished epoch, cumulative time: 118.05787444114685s
[51, 130] loss: 0.4463, ortho_loss: 0.0010, ground_truth_loss: 0.4409, base change det: -0.99
[51, 260] loss: 0.4096, ortho_loss: 0.0011, ground_truth_loss: 0.4415, base change det: -0.99
[51, 390] loss: 0.1976, ortho_loss: 0.0009, ground_truth_loss:



KeyboardInterrupt: 

In [None]:
torch.std(trainset[:], dim=0)

In [None]:
np.std(trainset[:12_000].numpy(), axis=0)

In [None]:
torch.mean(trainset[:50_000], dim=0)

In [47]:
net = Net()
net.load_state_dict(torch.load('../data/state_dict.pt'))
net

Net(
  (fc1): Linear(in_features=6, out_features=6, bias=False)
)

In [48]:
mat = net.fc1.weight.t() @ net.trans @ net.fc1.weight

In [49]:
mat

tensor([[-0.3173, -0.4071, -0.1587, -0.3930,  0.6884, -0.2732],
        [-0.4071, -0.3225, -0.3990, -0.1313, -0.2818,  0.6845],
        [-0.1587, -0.3990,  0.5268, -0.4013, -0.5373, -0.2816],
        [-0.3930, -0.1313, -0.4013,  0.5300, -0.2767, -0.5488],
        [ 0.6884, -0.2818, -0.5373, -0.2767, -0.2085, -0.1834],
        [-0.2732,  0.6845, -0.2816, -0.5488, -0.1834, -0.1906]],
       grad_fn=<MmBackward>)

In [50]:
net.fc1.weight.detach()[:4]

tensor([[ 0.2769, -0.5538,  0.6043, -0.1832,  0.0666,  0.4596],
        [ 0.0316,  0.1434,  0.4021, -0.6491, -0.0276, -0.6263],
        [-0.4638, -0.7104, -0.0347,  0.2325,  0.1280, -0.4514],
        [ 0.5654,  0.0592,  0.3076,  0.6295, -0.1060, -0.4122]])

In [8]:
cov = np.cov(trainset[::10], rowvar=False)
cov

array([[ 7.69070564, -0.77565903,  0.8414195 , -2.26890737, -1.40195854,
         0.49187868],
       [-0.77565903,  7.69070564, -2.26890737,  0.8414195 ,  0.49187868,
        -1.40195854],
       [ 0.8414195 , -2.26890737, 11.35803414, -4.1179805 , -0.29665543,
        -1.94111775],
       [-2.26890737,  0.8414195 , -4.1179805 , 11.35803414, -1.94111775,
        -0.29665543],
       [-1.40195854,  0.49187868, -0.29665543, -1.94111775,  6.56501936,
        -0.8006946 ],
       [ 0.49187868, -1.40195854, -1.94111775, -0.29665543, -0.8006946 ,
         6.56501936]])

In [None]:
mat @ torch.tensor(cov, dtype=torch.float) @ mat.t()


In [None]:
mat @ torch.mean(trainset[:], dim=0)

In [56]:
torch.mean(trainset[:], dim=0)

tensor([ 0.4696,  0.4696, -0.5229, -0.5229,  0.6039,  0.6039])

In [57]:
mat @ torch.tensor(cov, dtype=torch.float) @ mat.t()


tensor([[ 7.7665, -0.6745,  0.8568, -2.3333, -1.1600,  0.7857],
        [-0.6745,  7.8840, -2.4450,  1.0088,  0.7521, -1.1641],
        [ 0.8568, -2.4450, 10.5508, -4.7146, -0.5094, -1.9335],
        [-2.3333,  1.0088, -4.7146, 10.6666, -1.9786, -0.5153],
        [-1.1600,  0.7521, -0.5094, -1.9786,  7.1002, -0.3529],
        [ 0.7857, -1.1641, -1.9335, -0.5153, -0.3529,  7.0733]],
       grad_fn=<MmBackward>)

In [58]:
torch.tensor(cov, dtype=torch.float)

tensor([[ 7.6907, -0.7757,  0.8414, -2.2689, -1.4020,  0.4919],
        [-0.7757,  7.6907, -2.2689,  0.8414,  0.4919, -1.4020],
        [ 0.8414, -2.2689, 11.3580, -4.1180, -0.2967, -1.9411],
        [-2.2689,  0.8414, -4.1180, 11.3580, -1.9411, -0.2967],
        [-1.4020,  0.4919, -0.2967, -1.9411,  6.5650, -0.8007],
        [ 0.4919, -1.4020, -1.9411, -0.2967, -0.8007,  6.5650]])

In [None]:
torch.mean(trainset[:], dim=0)

In [None]:
mat

In [None]:
torch.tensor(cov, dtype=torch.float)

In [None]:
mat @ torch.tensor(cov, dtype=torch.float) @ mat.t()

In [54]:
np.linalg.eig(cov)

(array([16.74224347,  3.42187876,  4.93927248,  7.40701679,  9.62657731,
         9.09052946]),
 array([[-0.23363209,  0.29516458, -0.45455862, -0.59393632, -0.48866401,
          0.24518875],
        [ 0.23363209,  0.29516458,  0.45455862, -0.59393632,  0.48866401,
          0.24518875],
        [-0.66379473,  0.40345443,  0.21197205, -0.03869398,  0.12018489,
         -0.5794198 ],
        [ 0.66379473,  0.40345443, -0.21197205, -0.03869398, -0.12018489,
         -0.5794198 ],
        [-0.06922863,  0.50010239, -0.49844189,  0.38176227,  0.49675254,
          0.32273081],
        [ 0.06922863,  0.50010239,  0.49844189,  0.38176227, -0.49675254,
          0.32273081]]))

In [55]:
np.linalg.eig(mat.detach())

(array([-0.9830776 , -0.9993397 , -0.99879336,  1.0014286 ,  0.9985549 ,
         0.9991174 ], dtype=float32),
 array([[ 0.4058    , -0.63258165, -0.30924073,  0.5789646 , -0.04740828,
          0.04556678],
        [ 0.3774907 , -0.3808367 ,  0.61255133, -0.3836642 ,  0.047991  ,
          0.4333082 ],
        [ 0.4608624 ,  0.11206639,  0.11394423, -0.04079129,  0.6573224 ,
         -0.5730069 ],
        [ 0.46207172,  0.05292811, -0.14203218, -0.36606646, -0.6830215 ,
         -0.40371877],
        [ 0.35122034,  0.58286643,  0.3776067 ,  0.5601702 , -0.2092571 ,
          0.19165519],
        [ 0.37890235,  0.3157694 , -0.5944811 , -0.26104134,  0.23036858,
          0.5311404 ]], dtype=float32))

In [None]:
np.linalg.eig(cov)[1][0] @ mat.detach().numpy()

In [None]:
np.linalg.eig(cov)[1][0]

In [68]:
cov @ np.linalg.eig(cov)[1][:, 0]

array([ -3.91152527,   3.91152527, -11.113413  ,  11.113413  ,
        -1.15904257,   1.15904257])

In [69]:
np.linalg.eig(cov)[0][0] * np.linalg.eig(cov)[1][:, 0]


array([ -3.91152527,   3.91152527, -11.113413  ,  11.113413  ,
        -1.15904257,   1.15904257])

In [73]:
mat.detach().numpy() @ np.linalg.eig(cov)[1][:, 0]

array([-0.24313127,  0.26436015, -0.65455919,  0.66056184, -0.05195789,
        0.04587332])

In [76]:
mat.detach().numpy() @ np.linalg.eig(cov)[1][:, 1]

array([-0.22878697, -0.22790091, -0.52350771, -0.515707  , -0.40444743,
       -0.40069406])

In [11]:
eigenvectors = np.linalg.eig(cov)[1]
mu = torch.mean(trainset[:], dim=0)

In [15]:
sol = np.linalg.solve(eigenvectors, mu)

In [16]:
eigenvectors

array([[-0.23363209,  0.29516458, -0.45455862, -0.59393632, -0.48866401,
         0.24518875],
       [ 0.23363209,  0.29516458,  0.45455862, -0.59393632,  0.48866401,
         0.24518875],
       [-0.66379473,  0.40345443,  0.21197205, -0.03869398,  0.12018489,
        -0.5794198 ],
       [ 0.66379473,  0.40345443, -0.21197205, -0.03869398, -0.12018489,
        -0.5794198 ],
       [-0.06922863,  0.50010239, -0.49844189,  0.38176227,  0.49675254,
         0.32273081],
       [ 0.06922863,  0.50010239,  0.49844189,  0.38176227, -0.49675254,
         0.32273081]])

In [22]:
rel_vectors = (np.abs(sol) > 0.0001) * 2. - 1
rel_vectors

array([-1.,  1., -1.,  1., -1.,  1.])

In [31]:
new_mat = eigenvectors @ np.diag(rel_vectors) @ eigenvectors.T
new_mat * (np.abs(new_mat) > 10 ** (-5) )

array([[-0.,  1.,  0.,  0.,  0., -0.],
       [ 1.,  0., -0., -0.,  0., -0.],
       [ 0., -0., -0.,  1., -0.,  0.],
       [ 0., -0.,  1.,  0., -0.,  0.],
       [ 0.,  0., -0., -0.,  0.,  1.],
       [-0., -0.,  0.,  0.,  1., -0.]])