In [41]:
import torch
import dill
import os.path as path
import time
import logging

# File-name constant; it is not referenced in any of the visible cells.
# NOTE(review): presumably a pickle/dill cache file -- confirm against its users.
UTIL_DATA = 'util_data.pkl'
import numpy as np
from numpy.random import default_rng
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import resample
from sklearn.model_selection import train_test_split
from scipy.stats import rankdata, ortho_group, skewtest, kurtosis
from metrics import DebiasedMMDLoss, se_kernel, poly_kernel
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from itertools import islice, cycle

# Notebook-wide logger; records at INFO level and above are accepted.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Re-running this cell must not stack up duplicate handlers, so the file
# handler is attached only when the logger has no handler yet.
if len(logger.handlers) == 0:
    file_handler = logging.FileHandler('label_exploitation.log')
    logger.addHandler(file_handler)


In [42]:
# Display the handlers currently attached to `logger` (quick check of the
# logging setup done in the first cell).
logger.handlers

[<FileHandler C:\Users\Work\PycharmProjects\Conda_test\group_action_learning\label_exploitation.log (NOTSET)>]

In [43]:
def roundrobin(*iterables):
    """Yield items from each iterable in turn until all are exhausted.

    roundrobin('ABC', 'D', 'EF') --> A D E B F C

    Iterables that run out early are dropped from the rotation while the
    remaining ones keep alternating in their original order.
    """
    # Iterators are created lazily, the first time the rotation reaches them.
    iterators = map(iter, iterables)
    for num_active in range(len(iterables), 0, -1):
        # Rotate over the iterators that are still alive. ``map(next, ...)``
        # stops as soon as one of them is exhausted; the next pass rebuilds
        # the rotation without it.
        iterators = cycle(islice(iterators, num_active))
        yield from map(next, iterators)

In [44]:
# Target size every image is resized to before being flattened.
dim = (10, 10)

# Preprocessing pipeline: random horizontal flip -> resize to `dim` ->
# convert to a tensor -> flatten into a 1-D vector.
transform = transforms.Compose(
    [
    transforms.RandomHorizontalFlip(),
    transforms.Resize(dim),
    transforms.ToTensor(),
     lambda x: x.view(-1)])


# One independent MNIST training set per digit class; each copy is filtered
# below so that it only contains the samples of "its" digit.
classes = list(range(10))
trainsets = [torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform) for _ in classes]
for i, trainset in zip(classes, trainsets):
    # Use `targets` for the mask as well: `train_labels` is only a deprecated
    # alias of `targets` that emits a warning on every access.
    idx = trainset.targets == i
    trainset.targets = trainset.targets[idx]
    trainset.data = trainset.data[idx]

class Model(nn.Module):
    """A single square, bias-free linear layer on flattened inputs.

    ``dim`` is a 2-element sequence; the layer maps vectors of length
    ``dim[0] * dim[1]`` to vectors of the same length.
    """

    def __init__(self, dim=(10, 10)):
        """Create the layer ``fc1`` sized from ``dim`` and remember ``dim``."""
        super().__init__()
        self.dim = dim
        n_features = dim[0] * dim[1]
        self.fc1 = nn.Linear(n_features, n_features, bias=False)

    def forward(self, x):
        """Return ``fc1`` applied to ``x``."""
        return self.fc1(x)

def init_weights(m):
    """Orthogonally initialise the weight of a linear layer, in place.

    Meant to be used as the callback of ``nn.Module.apply``: modules that are
    not linear layers are left untouched.

    Args:
        m: any ``nn.Module``; only ``nn.Linear`` instances (including
            subclasses) are re-initialised.
    """
    # isinstance instead of an exact type comparison, so subclasses of
    # nn.Linear are initialised as well rather than silently skipped.
    if isinstance(m, nn.Linear):
        torch.nn.init.orthogonal_(m.weight)

# Number of epochs trained so far; it is incremented by the training cell,
# so re-running that cell continues the count instead of restarting it.
true_epoch = 0

# Fresh model whose linear layer starts from an orthogonal weight matrix.
model = Model(dim)
model.apply(init_weights)



In [45]:
# Train on the first GPU when one is available; fall back to the CPU
# otherwise instead of failing on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# One shuffled loader per digit class. Pinned host memory only helps (and is
# only supported without a warning) when batches are copied to a CUDA device.
trainloaders = [torch.utils.data.DataLoader(trainset, batch_size=1_024,
                                            shuffle=True,
                                            pin_memory=device.startswith('cuda'))
                for trainset in trainsets]

# Move the model first and build the optimizer afterwards, as recommended by
# the torch.optim documentation.
model = model.to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.5)

In [46]:
# Training cell.
#
# Each step (1) rescales the weight matrix so that |det(W)| == 1, (2) computes
# the MMD criterion between a batch and its image under the model, (3) adds a
# squared-trace penalty and (4) takes one optimizer step. Batches come from
# the per-digit loaders in round-robin order.
#
# `poly_kernel` and `DebiasedMMDLoss` are already imported in the first cell,
# so no import is repeated here.

epochs = 100
inter_error_stride = 20     # report every this many mini-batches ...
error_display_stride = 5    # ... but only in epochs with true_epoch % stride == 1
weight_penalty_adj = 10     # weight of the (currently disabled) orthogonality penalty
criterion = DebiasedMMDLoss(kernel = poly_kernel, r=0, m=2, gamma=1)
weight_criterion = nn.MSELoss()


logger.info(f"Model={model}, device={device}, optimizer={optimizer}, "
            f"criterion={criterion}, epochs={epochs}")


total_time = 0.
n_features = dim[0] * dim[1]
id_mat = torch.eye(model.fc1.weight.shape[1], requires_grad=False, device=device)
epoch_loss = 0.0
true_train_total = 0.0
# Mean criterion value of the most recently completed epoch. It is kept
# separately because the per-epoch accumulators are reset after every epoch.
last_epoch_avg_loss = float('inf')
for epoch in range(epochs):
    # loop over the dataset multiple times
    start_time = time.time()
    true_epoch += 1
    running_loss = 0.0
    running_ortho_loss = 0.0
    running_ground_truth_loss = 0.0
    total_train = 0.0
    for i, data in enumerate(roundrobin(*trainloaders), 0):
        with torch.no_grad():
            # Normalise W so that |det(W)| == 1. Working with log|det| avoids
            # the overflow/underflow a plain determinant of a large matrix is
            # prone to; exp(log|det| / n) equals |det| ** (1 / n).
            _, logabsdet = torch.slogdet(model.fc1.weight)
            # A singular (or NaN) matrix cannot be normalised; leave it as is
            # rather than dividing by zero and filling the weights with inf.
            if torch.isfinite(logabsdet):
                model.fc1.weight /= torch.exp(logabsdet / n_features)
        # get the inputs
        inputs, _ = data
        inputs = inputs.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(inputs, outputs)
        running_loss += loss.detach()
        epoch_loss += loss.detach()

        # orth_loss = weight_criterion(model.fc1.weight.t() @ model.fc1.weight, id_mat) * weight_penalty_adj
        # orth_loss = (model.fc1.weight.det() - 1) ** 2 * weight_penalty_adj
        # running_ortho_loss += orth_loss.detach()

        # loss += orth_loss

        # Prevent learning the identity map
        id_loss = (torch.trace(model.fc1.weight)) ** 2
        loss = loss + id_loss

        total_train += 1
        true_train_total += 1

        loss.backward()
        # Gradients are cleared at the top of the next iteration, so no second
        # zero_grad() is needed after the step.
        optimizer.step()

        if i % inter_error_stride == (inter_error_stride - 1) and (true_epoch % error_display_stride) == 1:
            # print every n mini-batches
            partial_err_msg = f'[{true_epoch}, {i + 1}] ' \
                              f'loss: {running_loss / total_train:.4f}, ' \
                              f'ortho_loss: {running_ortho_loss / total_train:.4f}'
            print(partial_err_msg)
            logger.info(partial_err_msg)

            running_loss = 0.0
            running_ortho_loss = 0.0
            total_train = 0.0
    total_time += time.time() - start_time
    # Guarded mean: an epoch without any batch reports inf instead of raising
    # ZeroDivisionError.
    last_epoch_avg_loss = epoch_loss / true_train_total if true_train_total > 0 else float('inf')
    if true_epoch % error_display_stride == 1:
        tot_err_msg = f'total error = {last_epoch_avg_loss:.4f}'
        print(tot_err_msg)
        logger.info(tot_err_msg)
        time_msg = f'Finished epoch, cumulative time: {total_time}s'
        print(time_msg)
        logger.info(time_msg)
    epoch_loss = 0.0
    true_train_total = 0.0


finish_msg = "Finished training!"
print(finish_msg)
logger.info(finish_msg)
# Report the mean loss of the last completed epoch. The accumulators were
# already reset above, so dividing them here would always print inf.
print(float(last_epoch_avg_loss))

tensor([0, 0, 0,  ..., 0, 0, 0])
tensor([1, 1, 1,  ..., 1, 1, 1])
tensor([2, 2, 2,  ..., 2, 2, 2])
tensor([3, 3, 3,  ..., 3, 3, 3])
tensor([4, 4, 4,  ..., 4, 4, 4])
tensor([5, 5, 5,  ..., 5, 5, 5])
tensor([6, 6, 6,  ..., 6, 6, 6])
tensor([7, 7, 7,  ..., 7, 7, 7])
tensor([8, 8, 8,  ..., 8, 8, 8])
tensor([9, 9, 9,  ..., 9, 9, 9])
tensor([0, 0, 0,  ..., 0, 0, 0])
tensor([1, 1, 1,  ..., 1, 1, 1])
tensor([2, 2, 2,  ..., 2, 2, 2])
tensor([3, 3, 3,  ..., 3, 3, 3])
tensor([4, 4, 4,  ..., 4, 4, 4])
tensor([5, 5, 5,  ..., 5, 5, 5])
tensor([6, 6, 6,  ..., 6, 6, 6])
tensor([7, 7, 7,  ..., 7, 7, 7])
tensor([8, 8, 8,  ..., 8, 8, 8])
tensor([9, 9, 9,  ..., 9, 9, 9])
[1, 20] loss: 30.9183, ortho_loss: 0.0000
tensor([0, 0, 0,  ..., 0, 0, 0])
tensor([1, 1, 1,  ..., 1, 1, 1])
tensor([2, 2, 2,  ..., 2, 2, 2])
tensor([3, 3, 3,  ..., 3, 3, 3])
tensor([4, 4, 4,  ..., 4, 4, 4])
tensor([5, 5, 5,  ..., 5, 5, 5])
tensor([6, 6, 6,  ..., 6, 6, 6])
tensor([7, 7, 7,  ..., 7, 7, 7])
tensor([8, 8, 8,  ..., 8, 8, 8])
t

KeyboardInterrupt: 

In [None]:
# model.fc1.weight
# weights = model.fc1.weight.detach().numpy()[:, 0]
# plt.imshow(weights.reshape(dim[0], dim[1]))

In [None]:
# Compare one training image of digit 1 with its image under the model.
with torch.no_grad():
    # The arrays handed to matplotlib must live on the CPU.
    model = model.to("cpu")
    # First sample of the digit-1 training set; [0] selects the (flattened)
    # image of the (image, label) pair.
    orig_img = trainsets[1][0][0]
    # Evaluate the model once and reuse the result for both output panels.
    model_out = model(orig_img).numpy()

    fig, ax = plt.subplots(1,3)
    # Every panel uses the same x255 scaling (the input used to be scaled
    # twice, unlike the two output panels).
    img1 = orig_img.reshape(dim) * 255
    ax[0].imshow(img1.numpy(), cmap='gray')
    ax[0].set_title('input')
    img2 = (np.clip(model_out, 0, 1)).reshape(dim) * 255
    ax[1].imshow(img2, cmap='gray')
    ax[1].set_title('output, clipped to [0, 1]')
    img3 = model_out.reshape(dim) * 255
    ax[2].imshow(img3, cmap='gray')
    ax[2].set_title('output, raw')

    plt.show()

In [None]:
# Eigen-decomposition of the learned weight matrix.
# `.cpu()` makes this cell work even when the model is still on the GPU
# (a CUDA tensor cannot be converted to a NumPy array directly); it is a
# no-op when the model already lives on the CPU.
weights = model.fc1.weight.detach().cpu().numpy()
# eigens[0] holds the eigenvalues, eigens[1] the eigenvectors.
eigens = np.linalg.eig(weights)

In [None]:
# 2-D histogram of the eigenvalues in the complex plane
# (x axis: real part, y axis: imaginary part).
eig_real = np.real(eigens[0])
eig_imag = np.imag(eigens[0])
plt.hist2d(eig_real, eig_imag, bins=50)