#### DEFINE DATA LOADER FUNCTIONS

In [27]:
#### DEFINE DATA LOADER FUNCTIONS ####

from torch.utils.data import Dataset

import numpy as np
from scipy.stats import ortho_group

class DataGeneratorPPCA(Dataset):
    """Synthetic dataset sampled from a linear-Gaussian latent-variable model.

    A sample is generated as ``x ~ Normal(W z, sigma)`` with ``z ~ Normal(0, I)``,
    where ``W = V * diag(sqrt(eigs - sigma_sq))``, ``V`` holds the first ``hdims``
    columns of a random orthogonal ``dims x dims`` matrix and ``eigs`` is evenly
    spaced between ``min_sv`` and ``max_sv``.

    Args:
        dims: Dimensionality of each observed sample.
        hdims: Dimensionality of the latent variable ``z``.
        min_sv: First value of the evenly spaced ``eigs`` array.
        max_sv: Last value of the evenly spaced ``eigs`` array.
        sigma_sq: Variance of the isotropic observation noise.
        deterministic: When true, ``total`` samples are drawn once at
            construction and ``__getitem__`` indexes into them; otherwise every
            ``__getitem__`` call draws a fresh sample and ignores the index.
        total: Value reported by ``__len__`` (and number of pre-drawn samples
            in deterministic mode).

    Raises:
        ValueError: If ``sigma_sq`` is negative or larger than any entry of
            ``eigs``; the square roots below would otherwise silently turn
            into NaN and poison every generated sample.
    """

    def __init__(self, dims, hdims, min_sv=0.11, max_sv=5.0, sigma_sq=0.1, deterministic=True, total=10000):
        self.dims = dims
        self.hdims = hdims

        self.eigs = min_sv + (max_sv - min_sv) * np.linspace(0, 1, hdims)

        # Validate before any random numbers are drawn so a bad configuration
        # fails loudly instead of producing NaN weights.
        signal_var = self.eigs - sigma_sq
        if sigma_sq < 0:
            raise ValueError("sigma_sq must be non-negative, got {}".format(sigma_sq))
        if np.any(signal_var < 0):
            raise ValueError(
                "sigma_sq ({}) must not exceed the smallest eigenvalue ({})".format(
                    sigma_sq, float(np.min(self.eigs))))

        self.eigvectors = ortho_group.rvs(dims)[:, :hdims]
        self.w = np.matmul(self.eigvectors, np.diag(np.sqrt(signal_var)))

        self.sigma_sq = sigma_sq
        self.sigma = np.sqrt(sigma_sq)

        self.total = total
        self.deterministic = deterministic
        if self.deterministic:
            # Draw the whole dataset once so repeated indexing is reproducible.
            self.z_sample = np.random.normal(size=(total, self.hdims))
            self.x_sample = np.random.normal(np.matmul(self.z_sample, self.w.T), self.sigma).astype(np.float32)

    def __getitem__(self, i):
        """Return sample ``i`` (deterministic mode) or a freshly drawn sample."""
        if self.deterministic:
            return self.x_sample[i]
        else:
            z_sample = np.random.normal(size=self.hdims)
            return np.random.normal(self.w.dot(z_sample), self.sigma).astype(np.float32)

    def __len__(self):
        """Return ``total``, the nominal number of samples per epoch."""
        return self.total


class DataGeneratorPCA(Dataset):
    """Dataset of Gaussian samples with a prescribed spectrum, or a wrapper around given data.

    Generation mode (``load_data is None``): ``total`` samples are produced as
    ``x = V diag(sqrt(full_eigs)) z`` with ``z ~ Normal(0, I)`` and ``V`` a
    random orthogonal ``dims x dims`` matrix, then cast to float32.

    Load mode (``load_data`` given): the array is stored as-is and its leading
    ``hdims`` left singular vectors / singular values are computed.

    Args:
        dims: Dimensionality of each sample.
        hdims: Number of leading components kept in ``eigs`` / ``eigvectors``.
        min_sv: Smallest value of the default (linearly spaced) spectrum.
        max_sv: Largest value of the default (linearly spaced) spectrum.
        total: Number of samples to generate (ignored in load mode).
        sv_list: Optional sequence of ``dims`` spectrum values (list, tuple or
            1-D array); it is sorted in descending order. ``None`` selects the
            default spectrum.
        load_data: Optional array of shape ``(n_samples, dims)`` to wrap
            instead of generating data.

    Raises:
        ValueError: If ``sv_list`` is given and its length differs from ``dims``.
    """

    def __init__(self, dims, hdims, min_sv=0.11, max_sv=5.0, total=10000, sv_list=None,
                 load_data=None):
        self.dims = dims
        self.hdims = hdims

        if load_data is None:
            # Any explicit spectrum is honoured; previously non-list sequences
            # were silently ignored in favour of the default.
            if sv_list is not None:
                if len(sv_list) != dims:
                    raise ValueError(
                        "sv_list must contain exactly dims={} values, got {}".format(dims, len(sv_list)))
                self.full_eigs = np.array(sorted(sv_list, reverse=True))
            else:
                self.full_eigs = min_sv + (max_sv - min_sv) * np.linspace(1, 0, dims)
            self.eigs = self.full_eigs[:hdims]

            self.full_svs = np.sqrt(self.full_eigs)

            self.full_eigvectors = ortho_group.rvs(dims)
            self.eigvectors = self.full_eigvectors[:, :hdims]

            self.total = total

            self.full_z_sample = np.random.normal(size=(total, self.dims))
            self.x_sample = (self.full_eigvectors @ np.diag(self.full_svs) @ self.full_z_sample.T).T.astype(np.float32)

        else:
            self.x_sample = load_data
            u, s, vh = np.linalg.svd(self.x_sample.T, full_matrices=False)
            # NOTE(review): in this branch ``eigs`` holds the raw singular values
            # of the data matrix (not squared or divided by the sample count),
            # unlike the generation branch -- confirm downstream expectations.
            self.eigs = s[:self.hdims]
            self.eigvectors = u[:, :self.hdims]
            self.total = len(self.x_sample)

    def __getitem__(self, i):
        """Return the ``i``-th stored sample."""
        return self.x_sample[i]

    def __len__(self):
        """Return the number of stored samples."""
        return self.total

    @property
    def shape(self):
        """Shape of the stored sample array."""
        return self.x_sample.shape


#### DEFINE MODEL CLASSES

In [20]:
#### DEFINE MODEL CLASSES ####

import os
import torch
import torch.nn as nn
import numpy as np
# Select the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class ModelConfig:
    """Bundle a model, its optimizer and the hyper-parameters used to build them.

    The model is instantiated as
    ``model_class(input_dim=..., hidden_dim=..., init_scale=..., **extra_model_args)``
    and moved to ``cuda:0`` when CUDA is available (CPU otherwise). The
    optimizer is instantiated as
    ``optim_class(model.parameters(), lr=lr, **extra_optim_args)``.

    Args:
        model_name: Identifier exposed through the ``name`` property.
        model_type: Category label exposed through the ``type`` property.
        model_class: Callable that builds the model (see call form above).
        input_dim: Passed to ``model_class`` as ``input_dim``.
        hidden_dim: Passed to ``model_class`` as ``hidden_dim``.
        init_scale: Passed to ``model_class`` as ``init_scale``.
        optim_class: Callable that builds the optimizer (see call form above).
        lr: Learning rate handed to ``optim_class``.
        extra_model_args: Optional dict of additional keyword arguments for
            ``model_class``; ``None`` means no extra arguments.
        extra_optim_args: Optional dict of additional keyword arguments for
            ``optim_class``; ``None`` means no extra arguments.
    """

    def __init__(self, model_name, model_type, model_class, input_dim, hidden_dim, init_scale, optim_class, lr,
                 extra_model_args=None, extra_optim_args=None):
        # Use a fresh dict per instance: a mutable ``{}`` default would be shared
        # by every ModelConfig built without explicit extra arguments.
        extra_model_args = {} if extra_model_args is None else extra_model_args
        extra_optim_args = {} if extra_optim_args is None else extra_optim_args

        self.model_name = model_name
        self.model_type = model_type
        self.model_class = model_class
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.init_scale = init_scale
        self.extra_model_args = extra_model_args

        self.optim_class = optim_class
        self.lr = lr
        self.extra_optim_args = extra_optim_args
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = model_class(input_dim=input_dim, hidden_dim=hidden_dim, init_scale=init_scale, **extra_model_args).to(device)

        # Built after ``.to(device)`` so the optimizer references the moved parameters.
        self.optimizer = optim_class(self.model.parameters(), lr=lr, **extra_optim_args)

    @property
    def name(self):
        """The ``model_name`` given at construction."""
        return self.model_name

    @property
    def type(self):
        """The ``model_type`` given at construction."""
        return self.model_type

    def get_model(self):
        """Return the model instance built by this config."""
        return self.model

    def get_optimizer(self):
        """Return the optimizer instance built by this config."""
        return self.optimizer

class LinearAE(nn.Module):
    """Bias-free linear autoencoder whose ``forward`` returns its training loss.

    The encoder maps ``input_dim -> hidden_dim`` and the decoder maps back;
    both weight matrices are initialised from ``Normal(0, init_scale)``.

    Args:
        input_dim: Size of each input vector.
        hidden_dim: Size of the latent code.
        init_scale: Standard deviation of the weight initialisation.
        weight_reg_type: One of ``"uniform_product"``, ``"uniform_sum"``,
            ``"non_uniform_sum"`` or ``None`` (no regularisation).
        l2_reg_list: List of ``hidden_dim`` regularisation coefficients;
            required whenever ``weight_reg_type`` is not ``None``. The uniform
            modes only use its first entry.
    """

    def __init__(self,
                 input_dim, hidden_dim, init_scale=0.001,
                 weight_reg_type=None, l2_reg_list=None):
        super().__init__()

        self.input_dim, self.hidden_dim = input_dim, hidden_dim

        self.encoder = nn.Linear(input_dim, hidden_dim, bias=False)
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)

        self.weight_reg_type = weight_reg_type
        self.l2_reg_scalar = None
        self.l2_reg_list = l2_reg_list

        # Overwrite the default nn.Linear initialisation, encoder first.
        for layer in (self.encoder, self.decoder):
            layer.weight.data.normal_(0.0, init_scale)

        # Sanity-check the regularisation settings.
        assert self.weight_reg_type is None or isinstance(self.l2_reg_list, list), \
            "l2_reg_list must be a list if weight_reg_type is not None"

        assert self.l2_reg_list is None or len(self.l2_reg_list) == hidden_dim, \
            "Length of l2_reg_list must match latent dimension"

        if weight_reg_type == "non_uniform_sum":
            self.reg_weights = torch.tensor(
                np.array(self.l2_reg_list).astype(np.float32)
            )
            # Stored as a frozen parameter so it travels with the module.
            self.diag_weights = nn.Parameter(torch.diag(self.reg_weights), requires_grad=False)
        elif weight_reg_type in ("uniform_product", "uniform_sum"):
            # A single squared coefficient is cheaper than a diagonal matrix.
            self.l2_reg_scalar = l2_reg_list[0] ** 2

    def forward(self, x):
        """Return reconstruction loss plus the configured regularisation term."""
        recon_term = self.get_reconstruction_loss(x)
        reg_term = self._get_reg_loss()
        return recon_term + reg_term

    def compute_trace_norm(self):
        """
        Nuclear (trace) norms of the autoencoder and of its two factors.
        :return: tuple (norm of decoder @ encoder, norm of encoder, norm of decoder)
        """
        enc_w, dec_w = self.encoder.weight, self.decoder.weight
        product_norm = (dec_w @ enc_w).norm(p='nuc')
        return product_norm, enc_w.norm(p='nuc'), dec_w.norm(p='nuc')

    def get_reconstruction_loss(self, x):
        """Sum of squared reconstruction errors divided by ``len(x)``."""
        residual = x - self.decoder(self.encoder(x))
        return torch.sum(residual ** 2) / len(x)

    def get_reg_weights_np(self):
        """Return the regularisation coefficients as a NumPy array (zeros when unregularised)."""
        return np.zeros(self.hidden_dim) if self.weight_reg_type is None else np.array(self.l2_reg_list)

    def _get_reg_loss(self):
        """Compute the penalty selected by ``weight_reg_type``.

        :raises ValueError: if ``weight_reg_type`` is not a recognised mode
        """
        mode = self.weight_reg_type
        enc_w, dec_w = self.encoder.weight, self.decoder.weight

        # Product loss: squared norm of the end-to-end map decoder @ encoder.
        if mode == 'uniform_product':
            return self.l2_reg_scalar * (torch.norm(torch.matmul(dec_w, enc_w)) ** 2)

        # Sum loss: squared norms of encoder and decoder, same coefficient.
        if mode == 'uniform_sum':
            return self.l2_reg_scalar * (torch.norm(enc_w) ** 2 + torch.norm(dec_w) ** 2)

        # Sum loss with one coefficient per latent dimension.
        if mode == 'non_uniform_sum':
            encoder_penalty = torch.norm(self.diag_weights @ enc_w) ** 2
            decoder_penalty = torch.norm(dec_w @ self.diag_weights) ** 2
            return encoder_penalty + decoder_penalty

        # No regularisation requested.
        if mode is None:
            return 0.0

        raise ValueError("weight_reg_type should be one of (uniform_product, uniform_sum, non_uniform_sum, None)")
            


#### DEFINE MODEL TRAINING FUNCTION train_models

In [37]:
#### DEFINE MODEL TRAINING FUNCTION train_models ####

import os
import torch
import numpy as np
# Select the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def _rotation_train_step(model, optimizer, x):
    """Run one optimizer step on batch ``x`` with the rotation term applied; return the loss value."""
    # Follow the model's own placement rather than a module-level device global.
    x_dev = x.to(model.encoder.weight.device)

    optimizer.zero_grad()
    loss = model(x_dev)
    loss.backward()

    # ROTATION: adjust the encoder/decoder gradients with an antisymmetric
    # matrix built from the strictly upper triangle of the batch-averaged
    # latent second-moment matrix. Done under no_grad so the hand-written
    # gradient edit is not itself recorded by autograd.
    with torch.no_grad():
        y = model.encoder.weight @ x_dev.T
        yy_t_norm = y @ y.T / float(len(x))
        yy_t_upper = yy_t_norm - yy_t_norm.tril()
        gamma = 0.5 * (yy_t_upper - yy_t_upper.T)
        model.encoder.weight.grad -= gamma @ model.encoder.weight
        model.decoder.weight.grad -= model.decoder.weight @ gamma.T

    optimizer.step()
    return loss.item()


def train_models(data_loader, train_itr, metrics_dict, model_configs, eval_metrics_list=None):
    """Train one or more models with the rotation-augmented gradient update.

    For every pass over ``data_loader`` and every batch, each config's model is
    evaluated (its forward call must return the scalar loss), back-propagated,
    has the rotation term subtracted from the ``encoder.weight`` /
    ``decoder.weight`` gradients, and is stepped with the config's optimizer.
    The losses of the last batch are printed on the first pass and on every
    10th pass.

    Args:
        data_loader: Iterable yielding batches as tensors with samples along
            the first axis.
        train_itr: Number of passes over ``data_loader``.
        metrics_dict: Accepted for interface compatibility; not used here.
        model_configs: A single config object, or an iterable of config
            objects, each exposing ``get_model()``, ``get_optimizer()`` and
            ``name``.
        eval_metrics_list: Accepted for interface compatibility; not used here.

    Returns:
        The model of the last config (the only model when a single config is
        passed), or ``None`` when ``model_configs`` is an empty iterable.
    """
    # The argument is honoured instead of a notebook-level global; a lone
    # config object is wrapped so both call styles share one code path.
    if hasattr(model_configs, "get_model"):
        configs = [model_configs]
    else:
        configs = list(model_configs)

    for train_i in range(train_itr):
        # Bound before the batch loop so logging works even for an empty loader.
        losses = {}
        for x in data_loader:
            # ---- Optimize ----
            for config in configs:
                losses[config.name] = _rotation_train_step(
                    config.get_model(), config.get_optimizer(), x)

        # ---- Log statistics ----
        if train_i == 0 or (train_i + 1) % 10 == 0:
            print("".join(["Iteration = {}, Losses: ".format(train_i + 1)]
                          + ["{} = {} ".format(key, val) for key, val in losses.items()]))

    return configs[-1].get_model() if configs else None


### TRAIN A MODEL

#### Get the data

In [30]:
##### GET DATA ####
import os
import torch
import numpy as np
# Select the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

seed=1
# Seed both NumPy and PyTorch so data generation and weight init are repeatable.
np.random.seed(seed)
torch.manual_seed(seed)

# Problem size: observed dimension and latent (bottleneck) dimension.
input_dim = 1000
hidden_dim = 5

# Full-batch setup: the batch size equals the dataset size, so the loader
# yields a single batch per pass.
n_data = 5000
batch_size = n_data

# Range of the linearly spaced spectrum passed to DataGeneratorPCA.
max_sv = float(input_dim) * 0.1
min_sv = 1.0
# NOTE(review): sigma is not referenced anywhere else in the visible code.
sigma = 0.5

# Generate synthetic data, then re-wrap the same samples through the load_data
# path, which recomputes eigs/eigvectors from an SVD of the samples.
gt_data = DataGeneratorPCA(input_dim, hidden_dim, min_sv=min_sv, max_sv=max_sv, total=n_data)
data = DataGeneratorPCA(input_dim, hidden_dim, load_data=gt_data.x_sample)

# shuffle=False keeps the sample order fixed.
loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False)


#### Define the model

In [39]:
#### Define the model ####

import os
import torch
import numpy as np
# Select the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#### DEFINE MODEL #####
# Hyper-parameters for the run. 'train_itr' and 'seed' are stored here for
# reference; they are not forwarded to ModelConfig below.
model_dict = dict(
    model_name='rotation',
    model_type='rotation',
    model_class=LinearAE,
    extra_model_args = {"weight_reg_type": None},
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    init_scale=0.0001,
    optim_class=torch.optim.Adam,
    lr=0.0003,
    extra_optim_args={},
    train_itr=50000,
    seed=seed
)
# Alternative optimizer settings left as a note. NOTE(review): the two keys in
# the original note looked swapped (SGD is an optimizer class, the dict holds
# optimizer keyword arguments) -- presumably the intent was:
#     optim_class=torch.optim.SGD
#     extra_optim_args={'momentum': 0.9, 'nesterov': True},

# ModelConfig builds and holds both the model and its optimizer.
model_config = ModelConfig(
        model_name=model_dict['model_name'],
        model_type=model_dict['model_type'],
        model_class=model_dict['model_class'],
        input_dim=model_dict['input_dim'], 
        hidden_dim=model_dict['hidden_dim'],
        init_scale=model_dict['init_scale'],
        extra_model_args=model_dict['extra_model_args'],
        optim_class=model_dict['optim_class'],
        lr=model_dict['lr'],
        extra_optim_args=model_dict['extra_optim_args']
    )



In [None]:
# Run training for model_dict['train_itr'] passes over the loader; a single
# ModelConfig is passed as model_configs and no metrics dict is supplied.
trained_model = train_models(data_loader=loader, train_itr=model_dict['train_itr'], metrics_dict=None, model_configs=model_config)


Iteration = 1, Losses: rotation = 50466.04296875 
Iteration = 10, Losses: rotation = 50461.56640625 
Iteration = 20, Losses: rotation = 50439.94921875 
Iteration = 30, Losses: rotation = 50394.66796875 
Iteration = 40, Losses: rotation = 50324.515625 
Iteration = 50, Losses: rotation = 50235.48046875 
Iteration = 60, Losses: rotation = 50140.59375 
Iteration = 70, Losses: rotation = 50055.4609375 
Iteration = 80, Losses: rotation = 49990.68359375 
Iteration = 90, Losses: rotation = 49947.28125 
Iteration = 100, Losses: rotation = 49919.53125 
Iteration = 110, Losses: rotation = 49900.75390625 
Iteration = 120, Losses: rotation = 49886.57421875 
Iteration = 130, Losses: rotation = 49874.92578125 
Iteration = 140, Losses: rotation = 49864.92578125 
Iteration = 150, Losses: rotation = 49856.15234375 
Iteration = 160, Losses: rotation = 49848.33203125 
Iteration = 170, Losses: rotation = 49841.32421875 
Iteration = 180, Losses: rotation = 49835.0 
Iteration = 190, Losses: rotation = 49829.