In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import time
import numpy as np
import matplotlib.pyplot as plt
from itertools import islice, cycle
from torch.utils.data.dataset import random_split
from sklearn.neighbors import NearestNeighbors
from windows_inhibitor import WindowsInhibitor
from sklearn.base import BaseEstimator, TransformerMixin
from numpy.random import default_rng
import os.path as path
# Optional seed: leave as None for fresh randomness each run, or set an int to make
# the neighbour-rank sampling (numpy) and weight initialisation (torch) reproducible.
SEED = None
rng = default_rng(SEED)
if SEED is not None:
    torch.manual_seed(SEED)
# Fall back to CPU when no CUDA device is present (the previous `if True` always
# picked "cuda:0", which fails as soon as a tensor is moved there on a CPU-only box).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DATA_PATH = 'strict_dataset.npy'

In [2]:
class TrainSet(torch.utils.data.Dataset):
    """Row-indexed dataset backed by a ``.npy`` array loaded fully into memory.

    Indexing (by int or slice) returns the selected rows as a float32 tensor,
    optionally passed through ``transform``.

    Args:
        path: location of the ``.npy`` file, handed to ``np.load``.
        transform: optional callable applied to every returned tensor.

    Attributes:
        rows: size of the array's first dimension (also the dataset length).
        cols: size of the array's second dimension.
    """

    def __init__(self, path, transform=None):
        super().__init__()
        array = np.load(path)
        n_rows, n_cols = array.shape[0], array.shape[1]
        self.path, self.data, self.transform = path, array, transform
        self.rows, self.cols = n_rows, n_cols

    def __len__(self):
        return self.rows

    def __getitem__(self, idx):
        item = torch.tensor(self.data[idx], dtype=torch.float)
        # A falsy transform (None) means "return the raw tensor".
        return self.transform(item) if self.transform else item


In [3]:
class LinearTransformEM(BaseEstimator, TransformerMixin):
    """Placeholder scikit-learn transformer; the transformation itself is not implemented yet.

    ``fit`` follows the scikit-learn estimator convention of returning ``self``.
    Chained calls such as ``est.fit(X).transform(X)`` therefore reach
    ``transform`` instead of failing with an ``AttributeError`` on ``None``.
    """

    def __init__(self):
        # No hyper-parameters yet.
        pass

    def fit(self, X, y=0):
        """No-op fit: ``X`` and ``y`` are ignored. Returns ``self``."""
        return self

    def transform(self, X):
        """Not implemented yet.

        Raises:
            NotImplementedError: always. This replaces the previous silent ``None`` return.
        """
        raise NotImplementedError("LinearTransformEM.transform is not implemented yet")


In [4]:
# Normalisation pipeline kept for reference. NOTE: it is not passed to TrainSet
# below; TrainSet returns float tensors built from the loaded array.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])  # 1-tuples: one mean/std per channel

classes = list(range(10))  # NOTE(review): not referenced by any visible cell
trainset = TrainSet(DATA_PATH)


In [5]:
# Flatten every training sample into one row of the matrix the k-NN index is fitted
# on. Stacking the per-sample float32 arrays directly produces the same matrix as
# before, without first expanding every value into a Python list of scalars.
neighbor_data = np.stack([sample.view(-1).numpy() for sample in trainset])
# 5 neighbours per query; the training loop samples a target rank from 0..n_neighbors-1.
neighbors = NearestNeighbors(n_neighbors=5, n_jobs=6)

neighbors.fit(neighbor_data)

def get_neighbors(data, neighbors, data_points, k=0):
    """Return, for each query point, the row of ``data`` that is its ``k``-th nearest neighbour.

    Args:
        data: array the index was fitted on. Rows are gathered from it by the
            indices ``neighbors`` returns.
        neighbors: fitted object exposing ``kneighbors(X) -> (distances, indices)``
            (a scikit-learn ``NearestNeighbors`` in this notebook).
        data_points: query tensor, one query per row. It is detached and moved to
            CPU before the lookup, so it may live on any device and may require grad.
        k: which neighbour column to use (0 = closest). Must be smaller than the
            number of neighbours the index returns.

    Returns:
        float32 CPU tensor with one gathered row of ``data`` per query. It has no
        autograd history, so it is safe to use as a fixed regression target.
    """
    # detach() first: converting a tensor that requires grad to numpy raises.
    queries = data_points.detach().cpu().numpy()
    nb_indices = neighbors.kneighbors(queries)[1][:, k]
    # Fancy indexing already returns a fresh array, so as_tensor avoids a second copy.
    return torch.as_tensor(data[nb_indices], dtype=torch.float)

def get_neighbor_indices(data, neighbors, data_points, k=0):
    """Return the index of each query point's ``k``-th nearest neighbour.

    Args:
        data: unused; kept so the signature matches ``get_neighbors``.
        neighbors: fitted object exposing ``kneighbors(X) -> (distances, indices)``.
        data_points: query tensor, one query per row. It is detached and moved to
            CPU before the lookup, so it may live on any device or require grad.
        k: neighbour column to select (0 = closest).

    Returns:
        Column ``k`` of the index array returned by ``kneighbors``. For
        scikit-learn this is a 1-D integer ndarray with one entry per query.
    """
    # detach() first: converting a tensor that requires grad to numpy raises.
    queries = data_points.detach().cpu().numpy()
    return neighbors.kneighbors(queries)[1][:, k]

In [6]:

class Net(nn.Module):
    """Linear map ``x -> x @ W.T @ P @ W`` with a learnable square ``W`` and a fixed ``P``.

    ``W`` is ``self.fc1.weight`` (no bias). ``P`` is ``self.trans``, a permutation
    matrix that swaps each adjacent pair of coordinates (0<->1, 2<->3, ...). ``P``
    is symmetric, so ``P.T == P``.

    Args:
        n_features: input/output dimension. Defaults to ``trainset.cols`` (the
            notebook-global dataset), so ``Net()`` keeps working. Pass it explicitly
            to build the model without that global.

    Raises:
        ValueError: if ``n_features`` is odd. The pair-swap ``P`` would then be one
            row short, and ``forward`` would fail with a shape mismatch.
    """

    def __init__(self, n_features=None):
        super().__init__()
        if n_features is None:
            n_features = trainset.cols
        if n_features % 2:
            raise ValueError(f"n_features must be even to pair coordinates, got {n_features}")
        self.fc1 = nn.Linear(n_features, n_features, bias=False)
        swap = torch.tensor([[0, 1], [1, 0]], dtype=torch.float)
        # Registering P as a buffer lets `net.to(device)` move it together with the weights.
        # persistent=False keeps it out of state_dict, so checkpoints saved without
        # a "trans" entry still load with strict key matching.
        self.register_buffer("trans", torch.block_diag(*[swap] * (n_features // 2)), persistent=False)

    def forward(self, x):
        x = self.fc1(x)                           # x @ W.T
        x = F.linear(x, self.trans)               # @ P.T (== @ P, P is symmetric)
        return F.linear(x, self.fc1.weight.t())   # @ W

def init_weights(m):
    """Orthogonally initialise the weight of an ``nn.Linear`` (meant for ``Module.apply``).

    ``isinstance`` also covers subclasses of ``nn.Linear``. Other module types are
    left untouched, and any bias keeps its default initialisation.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.orthogonal_(m.weight)


# Build the model and give every nn.Linear an orthogonal starting point.
# Module.apply returns the module itself, so the two steps chain.
net = Net().apply(init_weights)

# Bookkeeping carried across re-runs of the training cell: seconds spent and epochs done.
total_time, true_epoch = 0, 0
# Error histories; in the visible cells they are only referenced by commented-out evaluation code.
train_error_list, test_error_list = [], []


In [7]:
# state_dict = torch.load('symmetry_net.pkl')
# net.load_state_dict(state_dict)

In [8]:

# Shuffled mini-batches of 20 rows; pinned host memory speeds up host->GPU copies.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=20,
    shuffle=True,
    pin_memory=True,
)
# Used for the neighbour-matching loss, the orthogonality penalty and the ground-truth diagnostic.
criterion = nn.MSELoss()
optimizer = optim.SGD(params=net.parameters(), lr=0.003, momentum=0.5, weight_decay=0)

In [9]:
# Train for 100 epochs while keeping Windows awake. Each step pulls the network output
# towards a randomly chosen one of its nearest training rows, plus a penalty keeping W
# orthogonal. `true_epoch` and `total_time` persist, so re-running the cell resumes training.
with WindowsInhibitor():
    net.to(device)
    # Move `trans` explicitly in case it is a plain tensor attribute (Module.to skips those).
    net.trans = net.trans.to(device)
    id_mat = torch.eye(trainset.cols, device=device)
    n_nb = neighbors.n_neighbors  # target rank k is drawn from 0 .. n_nb - 1
    log_every = 100  # mini-batches between progress lines
    for _ in range(100):  # loop over the dataset multiple times
        start_time = time.time()
        true_epoch += 1
        running_loss = 0.0
        epoch_loss = 0.0
        running_ortho_loss = 0.0
        running_ground_truth_loss = 0.0
        true_train_total = 0.0
        total_train = 0.0
        for i, inputs in enumerate(trainloader):
            inputs = inputs.to(device)
            optimizer.zero_grad()

            # Data term: match each output to its k-th nearest training row.
            outputs = net(inputs)
            k = rng.integers(n_nb)
            nbs = get_neighbors(neighbor_data, neighbors, outputs.detach(), k=k).to(device)
            loss = criterion(outputs, nbs)
            running_loss += loss.detach()
            epoch_loss += loss.detach()

            # Orthogonality penalty on W, weighted 3000x relative to the data term.
            orth_loss = criterion(net.fc1.weight.t() @ net.fc1.weight, id_mat) * 3_000
            running_ortho_loss += orth_loss.detach()

            # Diagnostic only (no gradients): distance of W^T P W from P.
            with torch.no_grad():
                mat = net.fc1.weight.t() @ net.trans @ net.fc1.weight
                running_ground_truth_loss += criterion(mat, net.trans)

            total_train += 1
            true_train_total += 1

            (loss + orth_loss).backward()
            optimizer.step()

            if i % log_every == log_every - 1:    # print every `log_every` mini-batches
                print(f'[{true_epoch}, {i + 1}] '
                      f'loss: {running_loss / total_train:.4f}, '
                      f'ortho_loss: {running_ortho_loss / total_train:.4f}, '
                      f'ground_truth_loss: {running_ground_truth_loss / total_train:.4f}, ')
                running_loss = 0.0
                running_ortho_loss = 0.0
                running_ground_truth_loss = 0.0
                total_train = 0.0
        print(f'total error = {epoch_loss / true_train_total:.4f}')
        with torch.no_grad():
            print(get_neighbor_indices(neighbor_data, neighbors, net(trainset[:10].to(device)), k=0))

        total_time += time.time() - start_time
        print(f'Finished epoch, cumulative time: {total_time}s')
    print("Finished training!")



Preventing Windows from going to sleep
[1, 100] loss: 1.6717, ortho_loss: 0.0924, ground_truth_loss: 0.0634, 
Allowing Windows to go to sleep


KeyboardInterrupt: 

In [None]:
# Inspect the learned conjugate W^T P W: zero out entries with |value| <= 1e-3 and show
# a 6x6 window. detach() keeps this display-only value out of the autograd graph.
mat = (net.fc1.weight.t() @ net.trans @ net.fc1.weight).detach()
mat.masked_fill(mat.abs() <= 1e-3, 0.0)[4:10, 4:10]

In [None]:
# Orthogonality check: for orthogonal W, det(W^T W) == 1. slogdet returns (sign, log|det|)
# and stays finite where a plain .det() of a large float32 matrix can over/underflow.
# Expect sign == 1 and logabsdet close to 0.
torch.linalg.slogdet((net.fc1.weight.t() @ net.fc1.weight).detach())

In [None]:
trainset[0:3]  # first three rows as one float tensor


In [None]:
trainset[500:500 + 3]  # three rows starting at index 500

In [None]:
# Put the model (and its pair-swap matrix) on `device`, then print the index of the
# nearest training row to the model output for samples 600-624.
net.to(device)
net.trans = net.trans.to(device)
with torch.no_grad():
    probe_batch = trainset[600:625].to(device)
    nearest_idx = get_neighbor_indices(neighbor_data, neighbors, net(probe_batch), k=0)
    print(nearest_idx)


In [None]:
with torch.no_grad():
    # Second neighbour (k=1) of the model outputs for samples 400-424.
    print(get_neighbor_indices(neighbor_data, neighbors, k=1,
                               data_points=net(trainset[400:425].to(device))))