#### DEFINE DATA LOADER FUNCTIONS

In [1]:
#### DEFINE DATA LOADER FUNCTIONS ####

from torch.utils.data import Dataset

import numpy as np
from scipy.stats import ortho_group

class DataGeneratorPPCA(Dataset):
    """Synthetic dataset drawn from a linear-Gaussian (probabilistic PCA style) model.

    A loading matrix ``w`` is built from ``hdims`` columns of a random
    orthogonal matrix, each scaled by ``sqrt(eig - sigma_sq)``.  Samples are
    ``Normal(w @ z, sigma)`` with ``z`` standard normal.

    :param dims: dimensionality of each observation
    :param hdims: number of latent dimensions
    :param min_sv: smallest entry of ``eigs``
    :param max_sv: largest entry of ``eigs``
    :param sigma_sq: observation noise variance
    :param deterministic: if True, ``total`` samples are drawn once up front
        and served by index; otherwise every lookup draws a new sample
    :param total: value reported by ``len()`` (and number of cached samples)
    """

    def __init__(self, dims, hdims, min_sv=0.11, max_sv=5.0, sigma_sq=0.1, deterministic=True, total=10000):
        self.dims, self.hdims = dims, hdims

        # Evenly spaced values running from min_sv up to max_sv.
        spread = max_sv - min_sv
        self.eigs = min_sv + spread * np.linspace(0, 1, hdims)

        # First hdims columns of a random orthogonal (dims x dims) matrix.
        basis = ortho_group.rvs(dims)
        self.eigvectors = basis[:, :hdims]

        column_scales = np.sqrt(self.eigs - sigma_sq)
        self.w = np.matmul(self.eigvectors, np.diag(column_scales))

        self.sigma_sq = sigma_sq
        self.sigma = np.sqrt(sigma_sq)

        self.total = total
        self.deterministic = deterministic
        if not self.deterministic:
            return

        # Cache the whole dataset so that indexing is repeatable.
        self.z_sample = np.random.normal(size=(total, self.hdims))
        sample_means = np.matmul(self.z_sample, self.w.T)
        self.x_sample = np.random.normal(sample_means, self.sigma).astype(np.float32)

    def __getitem__(self, i):
        """Return cached sample ``i``, or a freshly drawn sample when not deterministic."""
        if not self.deterministic:
            latent = np.random.normal(size=self.hdims)
            mean = self.w.dot(latent)
            return np.random.normal(mean, self.sigma).astype(np.float32)
        return self.x_sample[i]

    def __len__(self):
        # Nominal epoch length; in the non-deterministic mode nothing is stored.
        return self.total


class DataGeneratorPCA(Dataset):
    """Dataset of Gaussian samples with a known principal-component structure.

    Two modes are supported:

    * ``load_data is None``: draw ``total`` samples whose covariance has a
      random orthogonal eigenbasis and eigenvalues ``full_eigs`` (taken from
      ``sv_list`` sorted in descending order, or linearly spaced from
      ``max_sv`` down to ``min_sv``).
    * ``load_data`` given: wrap the provided array and take ``eigs`` /
      ``eigvectors`` from the SVD of its transpose (leading ``hdims``
      singular values and left singular vectors).

    :param dims: dimensionality of each observation
    :param hdims: number of leading components exposed as ``eigs`` / ``eigvectors``
    :param min_sv: smallest generated eigenvalue (ignored if ``sv_list`` is a list)
    :param max_sv: largest generated eigenvalue (ignored if ``sv_list`` is a list)
    :param total: number of samples to generate
    :param sv_list: optional list of ``dims`` eigenvalues
    :param load_data: optional array of existing samples, one per row
    """

    def __init__(self, dims, hdims, min_sv=0.11, max_sv=5.0, total=10000, sv_list=None,
                 load_data=None):
        self.dims = dims
        self.hdims = hdims

        if load_data is not None:
            self._init_from_samples(load_data)
        else:
            self._init_synthetic(min_sv, max_sv, total, sv_list)

    def _init_synthetic(self, min_sv, max_sv, total, sv_list):
        # Generate samples from a full-rank Gaussian with a random eigenbasis.
        if isinstance(sv_list, list):
            assert len(sv_list) == self.dims
            spectrum = np.array(sorted(sv_list, reverse=True))
        else:
            spectrum = min_sv + (max_sv - min_sv) * np.linspace(1, 0, self.dims)
        self.full_eigs = spectrum
        self.eigs = spectrum[:self.hdims]

        self.full_svs = np.sqrt(self.full_eigs)

        self.full_eigvectors = ortho_group.rvs(self.dims)
        self.eigvectors = self.full_eigvectors[:, :self.hdims]

        self.total = total

        # x = V diag(sqrt(eigs)) z with z standard normal; one sample per row.
        self.full_z_sample = np.random.normal(size=(total, self.dims))
        scaled_basis = self.full_eigvectors @ np.diag(self.full_svs)
        self.x_sample = (scaled_basis @ self.full_z_sample.T).T.astype(np.float32)

    def _init_from_samples(self, samples):
        # Wrap existing samples and read the leading components off their SVD.
        self.x_sample = samples
        left_vecs, sing_vals, _ = np.linalg.svd(self.x_sample.T, full_matrices=False)
        self.eigs = sing_vals[:self.hdims]
        self.eigvectors = left_vecs[:, :self.hdims]
        self.total = len(self.x_sample)

    def __getitem__(self, i):
        return self.x_sample[i]

    def __len__(self):
        return self.total

    @property
    def shape(self):
        """Shape of the underlying sample array."""
        return self.x_sample.shape


#### DEFINE MODEL CLASSES

In [2]:
#### DEFINE MODEL CLASSES ####

import os
import torch
import torch.nn as nn
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class ModelConfig:
    """Build a model and its optimizer and keep them together with their settings.

    The model is created as ``model_class(input_dim=..., hidden_dim=...,
    init_scale=..., **extra_model_args)`` and moved to CUDA device 0 when
    CUDA is available (CPU otherwise).  The optimizer is created as
    ``optim_class(model.parameters(), lr=lr, **extra_optim_args)``.

    :param model_name: label exposed through the ``name`` property
    :param model_type: label exposed through the ``type`` property
    :param model_class: callable that builds the model
    :param input_dim: forwarded to ``model_class``
    :param hidden_dim: forwarded to ``model_class``
    :param init_scale: forwarded to ``model_class``
    :param optim_class: callable that builds the optimizer
    :param lr: learning rate forwarded to ``optim_class``
    :param extra_model_args: optional dict of extra keyword arguments for
        ``model_class``; ``None`` means no extra arguments
    :param extra_optim_args: optional dict of extra keyword arguments for
        ``optim_class``; ``None`` means no extra arguments
    """

    def __init__(self, model_name, model_type, model_class, input_dim, hidden_dim, init_scale, optim_class, lr,
                 extra_model_args=None, extra_optim_args=None):
        # None sentinels instead of `{}` defaults: a dict literal in the
        # signature is created once and shared by every instance that omits
        # the argument, so mutating it on one config would leak into others.
        if extra_model_args is None:
            extra_model_args = {}
        if extra_optim_args is None:
            extra_optim_args = {}

        self.model_name = model_name
        self.model_type = model_type
        self.model_class = model_class
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.init_scale = init_scale
        self.extra_model_args = extra_model_args

        self.optim_class = optim_class
        self.lr = lr
        self.extra_optim_args = extra_optim_args

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = model_class(
            input_dim=input_dim, hidden_dim=hidden_dim, init_scale=init_scale, **extra_model_args
        ).to(device)

        # The optimizer is bound to the parameters after the move to `device`.
        self.optimizer = optim_class(self.model.parameters(), lr=lr, **extra_optim_args)

    @property
    def name(self):
        """The ``model_name`` given at construction."""
        return self.model_name

    @property
    def type(self):
        """The ``model_type`` given at construction."""
        return self.model_type

    def get_model(self):
        """Return the model instance built by this config."""
        return self.model

    def get_optimizer(self):
        """Return the optimizer instance built by this config."""
        return self.optimizer

class LinearAE(nn.Module):
    """Bias-free linear autoencoder with optional L2-style weight penalties.

    ``forward(x)`` returns the training loss (reconstruction loss plus the
    selected penalty), not the reconstruction itself.

    :param input_dim: size of each input vector
    :param hidden_dim: size of the latent code
    :param init_scale: standard deviation of the zero-mean normal weight init
    :param weight_reg_type: one of ``None``, ``'uniform_product'``,
        ``'uniform_sum'`` or ``'non_uniform_sum'``
    :param l2_reg_list: list of ``hidden_dim`` penalty weights; required
        whenever ``weight_reg_type`` is not ``None``
    """

    def __init__(self,
                 input_dim, hidden_dim, init_scale=0.001,
                 weight_reg_type=None, l2_reg_list=None):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.encoder = nn.Linear(input_dim, hidden_dim, bias=False)
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)

        self.weight_reg_type = weight_reg_type
        self.l2_reg_scalar = None
        self.l2_reg_list = l2_reg_list

        # Overwrite the default nn.Linear init with N(0, init_scale^2).
        with torch.no_grad():
            self.encoder.weight.normal_(0.0, init_scale)
            self.decoder.weight.normal_(0.0, init_scale)

        # Validate the regularization settings.
        assert self.weight_reg_type is None or isinstance(self.l2_reg_list, list), \
            "l2_reg_list must be a list if weight_reg_type is not None"

        assert self.l2_reg_list is None or len(self.l2_reg_list) == hidden_dim, \
            "Length of l2_reg_list must match latent dimension"

        if weight_reg_type == "non_uniform_sum":
            per_unit = np.array(self.l2_reg_list).astype(np.float32)
            self.reg_weights = torch.tensor(per_unit)
            # Stored as a frozen Parameter so it follows the module across devices.
            self.diag_weights = nn.Parameter(torch.diag(self.reg_weights), requires_grad=False)
        elif weight_reg_type in ("uniform_product", "uniform_sum"):
            # Uniform penalties only need the first entry, squared, as a scalar.
            self.l2_reg_scalar = l2_reg_list[0] ** 2

    def forward(self, x):
        """Return reconstruction loss plus regularization loss for batch ``x``."""
        recon_term = self.get_reconstruction_loss(x)
        reg_term = self._get_reg_loss()
        return recon_term + reg_term

    def compute_trace_norm(self):
        """
        Nuclear (trace) norms of the composed map and of each weight matrix.
        :return: tuple (norm of decoder @ encoder, norm of encoder, norm of decoder)
        """
        enc_w = self.encoder.weight
        dec_w = self.decoder.weight
        composed = torch.matmul(dec_w, enc_w)
        return (composed.norm(p='nuc'), enc_w.norm(p='nuc'), dec_w.norm(p='nuc'))

    def get_reconstruction_loss(self, x):
        """Sum of squared reconstruction errors divided by the batch size."""
        recon = self.decoder(self.encoder(x))
        return torch.sum((x - recon) ** 2) / len(x)

    def get_reg_weights_np(self):
        """Penalty weights as a NumPy array (zeros when no regularization is set)."""
        if self.weight_reg_type is not None:
            return np.array(self.l2_reg_list)
        return np.zeros(self.hidden_dim)

    def _get_reg_loss(self):
        """Evaluate the penalty selected by ``weight_reg_type``."""
        mode = self.weight_reg_type

        # No regularization.
        if mode is None:
            return 0.0

        # Penalty on the product decoder @ encoder.
        if mode == 'uniform_product':
            product = torch.matmul(self.decoder.weight, self.encoder.weight)
            return self.l2_reg_scalar * (torch.norm(product) ** 2)

        # Same scalar penalty on encoder and decoder separately.
        if mode == 'uniform_sum':
            enc_sq = torch.norm(self.encoder.weight) ** 2
            dec_sq = torch.norm(self.decoder.weight) ** 2
            return self.l2_reg_scalar * (enc_sq + dec_sq)

        # Per-latent-unit penalty: encoder rows and decoder columns are scaled.
        if mode == 'non_uniform_sum':
            enc_sq = torch.norm(self.diag_weights @ self.encoder.weight) ** 2
            dec_sq = torch.norm(self.decoder.weight @ self.diag_weights) ** 2
            return enc_sq + dec_sq

        raise ValueError("weight_reg_type should be one of (uniform_product, uniform_sum, non_uniform_sum, None)")
            

            
            
class LinearAENestedDropout(nn.Module):
    """Bias-free linear autoencoder trained with nested dropout on the latent code.

    For every sample a cut-off index ``b`` is drawn from ``prior_probs`` and
    all latent units with index greater than ``b`` are zeroed.
    ``forward(x)`` returns the resulting reconstruction loss.

    :param input_dim: size of each input vector
    :param hidden_dim: size of the latent code
    :param init_scale: standard deviation of the zero-mean normal weight init
    :param prior_probs: probabilities over cut-off indices ``0 .. hidden_dim-1``;
        when ``None`` a geometric distribution with ``geom_p = 0.9`` is used,
        the last entry taking the remaining mass
    :param use_expectation: if True, ``forward`` evaluates the loss from the
        precomputed expected-mask matrices instead of sampling a mask
    """

    def __init__(self,
                 input_dim, hidden_dim, init_scale=0.001, prior_probs=None, use_expectation=False):
        super().__init__()

        self.use_expectation = use_expectation

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.encoder = nn.Linear(input_dim, hidden_dim, bias=False)
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)

        # Overwrite the default nn.Linear init (encoder first, then decoder).
        for layer in (self.encoder, self.decoder):
            layer.weight.data.normal_(0.0, init_scale)

        if prior_probs is None:
            # Geometric prior:
            #   p(b) = rho^b (1 - rho)            for b = 0 ... k - 2
            #   p(k - 1) = 1 - sum_{b < k-1} p(b)
            self.geom_p = 0.9
            prior_probs = [self.geom_p ** b * (1 - self.geom_p) for b in range(self.hidden_dim - 1)]
            prior_probs.append(1.0 - sum(prior_probs))

        self.prior_probs = torch.tensor(prior_probs)

        # cum_probs[i] = 1 - sum(prior_probs[:i]): probability that unit i is kept.
        cum_probs = [1. - sum(prior_probs[:i]) for i in range(self.hidden_dim)]
        self.cum_probs = torch.tensor(cum_probs)
        self.diag_expected_mask = nn.Parameter(torch.diag(self.cum_probs), requires_grad=False)

        # Symmetric matrix whose (a, b) entry is cum_probs[max(a, b)].
        pair_keep = np.zeros((self.hidden_dim, self.hidden_dim))
        for i, keep_prob in enumerate(cum_probs):
            pair_keep[:i + 1, i] = keep_prob
            pair_keep[i, :i] = keep_prob
        self.l_expected_mask = nn.Parameter(torch.from_numpy(pair_keep).float(), requires_grad=False)

    def forward(self, x):
        """Return the nested-dropout reconstruction loss for batch ``x``."""
        if self.use_expectation:
            return self._expected_recon_loss(x)
        return self._sampled_recon_loss(x)

    def _expected_recon_loss(self, x):
        # Loss assembled from three trace terms using the expected masks.
        x_sq = torch.norm(x) ** 2
        codes_t = self.encoder(x).T        # (k, n)
        cross = torch.trace(codes_t @ x @ self.decoder.weight @ self.diag_expected_mask)
        gram_masked = (self.decoder.weight.T @ self.decoder.weight) * self.l_expected_mask
        quad = torch.trace(codes_t @ codes_t.T @ gram_masked)
        return (x_sq - 2 * cross + quad) / len(x)

    def _sampled_recon_loss(self, x):
        # Loss with one sampled nested-dropout mask per row of the batch.
        codes = self._nested_dropout(self.encoder(x))
        recon = self.decoder(codes)
        return torch.sum((x - recon) ** 2) / len(x)

    def _nested_dropout(self, hidden_units):
        """Zero, for each row, every latent unit past that row's sampled cut-off."""
        cutoffs = torch.multinomial(self.prior_probs, len(hidden_units), replacement=True)
        keep = torch.ones_like(hidden_units)
        # Column 0 is always kept; column c is kept where cutoff >= c.
        for col in range(1, self.hidden_dim):
            keep[:, col] = (cutoffs >= col).float()
        return hidden_units * keep

    def get_reconstruction_loss(self, x):
        """Reconstruction loss with no dropout applied."""
        recon = self.decoder(self.encoder(x))
        return torch.sum((x - recon) ** 2) / len(x)
    


#### DEFINE MODEL TRAINING FUNCTION train_models

In [3]:
#### DEFINE MODEL TRAINING FUNCTION train_models ####

import os
import torch
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def _apply_rotation_augmented_gradient(model, x):
    """Subtract the rotation correction from the encoder/decoder gradients in place.

    ``gamma`` is the antisymmetric matrix built from the strictly upper
    triangle of the batch-averaged latent second-moment matrix ``y y^T / n``,
    where ``y = encoder.weight @ x^T``.
    """
    # The correction is a direct edit of .grad, not part of the loss, so it
    # must not be recorded by autograd (otherwise .grad ends up carrying a
    # grad_fn and keeps a graph alive).
    with torch.no_grad():
        enc_w = model.encoder.weight
        dec_w = model.decoder.weight
        y = enc_w @ x.T
        yy_t_norm = y @ y.T / float(len(x))
        yy_t_upper = yy_t_norm - yy_t_norm.tril()
        gamma = 0.5 * (yy_t_upper - yy_t_upper.T)
        enc_w.grad -= gamma @ enc_w
        dec_w.grad -= dec_w @ gamma.T


def train_models(data_loader, train_itr, metrics_dict, model_configs, eval_metrics_list=None):
    """Train the model held by ``model_configs`` and return it.

    :param data_loader: iterable of input batches (tensors)
    :param train_itr: number of passes over ``data_loader``
    :param metrics_dict: accepted for interface compatibility; not used
    :param model_configs: a single config object providing ``get_model()``,
        ``get_optimizer()``, ``name`` and ``type``
    :param eval_metrics_list: accepted for interface compatibility; not used
    :return: the trained model (the same object as ``model_configs.get_model()``)

    When ``model_configs.type == 'rotation'`` the Rotation Augmented Gradient
    (RAG) correction is applied to the gradients before each optimizer step.
    """
    # Use the config that was passed in rather than a notebook-level global.
    model_config = model_configs
    model = model_config.get_model()
    optimizer = model_config.get_optimizer()

    # Send each batch to wherever the model's parameters live.
    model_device = next(model.parameters()).device
    use_rotation = model_config.type == 'rotation'

    # Defined up front so logging also works when the loader yields no batches.
    losses = {}

    for train_i in range(train_itr):
        for x in data_loader:
            x_dev = x.to(model_device)

            # ---- Optimize ----
            optimizer.zero_grad()
            loss = model(x_dev)
            loss.backward()

            if use_rotation:
                _apply_rotation_augmented_gradient(model, x_dev)

            optimizer.step()

            # Only the most recent batch loss is kept for logging.
            losses[model_config.name] = loss.item()

        # ---- Log statistics ----
        if train_i == 0 or (train_i + 1) % 10 == 0:
            print("".join(["Iteration = {}, Losses: ".format(train_i + 1)]
                          + ["{} = {} ".format(key, val) for key, val in losses.items()]))

    return model


In [4]:
# x=next(enumerate(loader))[1]
# x_cuda = x.to(device)
# print(model.encoder.weight.grad)
# y = model.encoder.weight @ x_cuda.T
# yy_t_norm = y @ y.T / float(len(x))
# yy_t_upper = yy_t_norm - yy_t_norm.tril()
# gamma = 0.5 * (yy_t_upper - yy_t_upper.T)
# model.encoder.weight.grad -= gamma @ model.encoder.weight
# print(model.encoder.weight.grad)


#### DEFINE EVALUATION METRICS

In [5]:
import os
import torch
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def get_weight_tensor_from_seq(weight_seq):
    """Collapse a linear layer, or a sequence of them, into one weight matrix.

    :param weight_seq: an ``nn.Linear`` or an ``nn.Sequential`` made of
        ``nn.Linear`` / ``nn.BatchNorm1d`` layers
    :return: the detached weight of a single ``nn.Linear``; for a sequence,
        the product of the layer matrices with later layers multiplied on the
        left (batch norm contributes ``diag(weight)``, its bias is ignored).
        ``None`` for an empty sequence or any other argument type.
    :raises ValueError: if the sequence contains an unsupported layer type
    """
    if isinstance(weight_seq, nn.Linear):
        return weight_seq.weight.detach()
    if not isinstance(weight_seq, nn.Sequential):
        return None

    composed = None
    for layer in weight_seq:
        if isinstance(layer, nn.Linear):
            factor = layer.weight.detach()
        elif isinstance(layer, nn.BatchNorm1d):
            # Only the scale is used; the bias is ignored.
            factor = torch.diag(layer.weight.detach())
        else:
            raise ValueError("Layer type {} not supported!".format(type(layer)))
        composed = factor if composed is None else factor @ composed
    return composed


def metric_transpose_theorem(model):
    """
    Measure how far the encoder is from the transpose of the decoder.
    :param model: model exposing ``encoder``, ``decoder`` and ``hidden_dim``
    :return: squared Frobenius norm of (encoder - decoder^T), divided by hidden_dim
    """
    w_enc = get_weight_tensor_from_seq(model.encoder)
    w_dec = get_weight_tensor_from_seq(model.decoder)

    gap_sq = torch.norm(w_enc - w_dec.T) ** 2
    return gap_sq.item() / float(model.hidden_dim)


def metric_alignment(model, gt_eigvectors):
    """
    Measure how well decoder columns line up with the ground-truth eigenvectors.
    :param model: model exposing ``decoder`` and ``hidden_dim``
    :param gt_eigvectors: ground truth eigenvectors, one per column
    :return: sum_i (1 - max_j cos^2(eigvector_i, decoder column_j)) / hidden_dim
    """
    decoder_np = get_weight_tensor_from_seq(model.decoder).detach().cpu().numpy()

    # Unit-normalize columns; the decoder gets a small epsilon against zero columns.
    unit_gt = gt_eigvectors / np.linalg.norm(gt_eigvectors, axis=0)
    unit_decoder = decoder_np / (np.linalg.norm(decoder_np, axis=0) + 1e-8)

    misalignment = 0.0
    # Rows of the transpose are the normalized eigenvectors.
    for eigvector in unit_gt.T:
        best_abs_cos = np.max(np.abs(unit_decoder.T @ eigvector))
        misalignment += 1. - best_abs_cos ** 2

    return misalignment / float(model.hidden_dim)


def metric_subspace(model, gt_eigvectors, gt_eigs):
    """
    Distance between the decoder's column space and the ground-truth subspace.
    :param model: model exposing ``decoder`` and ``hidden_dim``
    :param gt_eigvectors: ground truth eigenvectors, one per column
    :param gt_eigs: accepted but not used by the computation
    :return: 1 - tr(U U^T W W^T) / hidden_dim, where U is gt_eigvectors and
             W holds the left singular vectors of the decoder
    """
    decoder_np = get_weight_tensor_from_seq(model.decoder).detach().cpu().numpy()

    # Only the left singular vectors of the decoder are needed.
    left_vecs = np.linalg.svd(decoder_np, full_matrices=False)[0]

    gt_projector = gt_eigvectors @ gt_eigvectors.T
    overlap = np.trace(gt_projector @ left_vecs @ left_vecs.T)
    return 1 - overlap / float(model.hidden_dim)


def metric_loss(model, data_loader):
    """
    Evaluate the model's training loss on the data loader.
    :param model: a model whose forward pass returns a scalar loss
    :param data_loader: intended to be a full-batch loader; if it yields several
                        batches, every batch is evaluated and the last loss is returned
    :return: the loss as a Python float, or None if the loader yields nothing
    """
    latest = None
    for batch in data_loader:
        batch_loss = model(batch.to(device))
        latest = batch_loss.item()
    return latest


def metric_recon_loss(model, data_loader):
    """
    Evaluate the model's reconstruction loss on the data loader.
    :param model: a model exposing ``get_reconstruction_loss``
    :param data_loader: iterable of batches; every batch is evaluated and the
                        value for the last one is returned
    :return: the loss as a Python float, or None if the loader yields nothing
    """
    latest = None
    for batch in data_loader:
        batch_loss = model.get_reconstruction_loss(batch.to(device))
        latest = batch_loss.item()
    return latest


### TRAIN A MODEL

#### Get the data

In [6]:
##### GET DATA ####
import os
import torch
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Experiment setup: seed the RNGs, generate a synthetic dataset and wrap it
# in a DataLoader.
seed=1234
# set random seed (NumPy drives the data generation, torch the weight init)
np.random.seed(seed)
torch.manual_seed(seed)

# Observation dimensionality and latent size.
input_dim = 1000
hidden_dim = 400

# batch_size equals the dataset size, so the loader yields one full batch per pass.
n_data = 5000
batch_size = n_data

# Passed to DataGeneratorPCA as max_sv / min_sv, where they bound the generated
# eigenvalue range (full_eigs).
max_sv = float(input_dim) * 0.1
min_sv = 1.0
sigma = 0.5  # not referenced again in the visible cells

# gt_data generates the samples; data re-wraps the same samples through
# load_data, so its eigs / eigvectors come from an SVD of the samples themselves.
gt_data = DataGeneratorPCA(input_dim, hidden_dim, min_sv=min_sv, max_sv=max_sv, total=n_data)
data = DataGeneratorPCA(input_dim, hidden_dim, load_data=gt_data.x_sample)

# shuffle=False: batches are served in dataset order.
loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False)


#### Define the model

In [7]:
#### Define the model ####

import os
import torch
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#### DEFINE MODEL #####
# Hyperparameters for a single run. model_type='rotation' is the value that
# train_models checks to enable the rotation-augmented gradient step.
model_dict = dict(
    model_name='rotation',
    model_type='rotation',
    model_class=LinearAE,
    extra_model_args = {"weight_reg_type": None},  # no weight regularization
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    init_scale=0.0001,
    optim_class=torch.optim.SGD,
    extra_optim_args={'momentum': 0.9, 'nesterov': True},
    lr=0.0001,
    # Alternative optimizer settings, currently disabled:
#     optim_class=torch.optim.Adam,
#     extra_optim_args={},
#     lr=0.0003,
    # NOTE: train_itr and seed are recorded here but are not passed to ModelConfig below.
    train_itr=1000,#50000,
    seed=seed
)

# model config contains the model (ModelConfig builds both the model and its optimizer)
model_config = ModelConfig(
        model_name=model_dict['model_name'],
        model_type=model_dict['model_type'],
        model_class=model_dict['model_class'],
        input_dim=model_dict['input_dim'], 
        hidden_dim=model_dict['hidden_dim'],
        init_scale=model_dict['init_scale'],
        extra_model_args=model_dict['extra_model_args'],
        optim_class=model_dict['optim_class'],
        lr=model_dict['lr'],
        extra_optim_args=model_dict['extra_optim_args']
    )

# Show the configuration, the model architecture and the optimizer settings.
print(model_dict,'\n')
print(model_config.get_model(),'\n')
print(model_config.get_optimizer())

# Metrics of the untrained model, measured against the eigenvectors stored on `data`.
print('Transpose:', metric_transpose_theorem(model_config.get_model()),'\n') # how close encoder and decoder.T are 
print('Distance to axis-aligned solution:', metric_alignment(model_config.get_model(), data.eigvectors),'\n') # alignment of decoder columns to ground truth eigenvectors
print('Distance to optimal subspace):', metric_subspace(model_config.get_model(), data.eigvectors, data.eigs),'\n')



{'model_name': 'rotation', 'model_type': 'rotation', 'model_class': <class '__main__.LinearAE'>, 'extra_model_args': {'weight_reg_type': None}, 'input_dim': 1000, 'hidden_dim': 400, 'init_scale': 0.0001, 'optim_class': <class 'torch.optim.sgd.SGD'>, 'extra_optim_args': {'momentum': 0.9, 'nesterov': True}, 'lr': 0.0001, 'train_itr': 1000, 'seed': 1234} 

LinearAE(
  (encoder): Linear(in_features=1000, out_features=400, bias=False)
  (decoder): Linear(in_features=400, out_features=1000, bias=False)
) 

SGD (
Parameter Group 0
    dampening: 0
    foreach: None
    lr: 0.0001
    maximize: False
    momentum: 0.9
    nesterov: True
    weight_decay: 0
)
Transpose: 1.9994410686194896e-05 

Distance to axis-aligned solution: 0.9899812176786509 

Distance to optimal subspace): 0.6012058744201634 



In [8]:
# Report all metrics before training.
modelIn = model_config.get_model()
print('Reconstrution Loss:', metric_recon_loss(modelIn, loader),'\n') # full batch loss
print('Loss:', metric_loss(modelIn, loader),'\n')
print('Transpose:', metric_transpose_theorem(modelIn),'\n') # how close encoder and decoder.T are 
print('Distance to axis-aligned solution:', metric_alignment(modelIn, data.eigvectors),'\n') # alignment of decoder columns to ground truth eigenvectors
print('Distance to optimal subspace):', metric_subspace(modelIn, data.eigvectors, data.eigs),'\n')

# Train for 100 passes over the loader (the train_itr stored in model_dict is not used here).
trained_model = train_models(data_loader=loader, train_itr=100, metrics_dict=None, model_configs=model_config)

# Report the same metrics after training; the model is fetched from model_config again.
modelIn = model_config.get_model()
print('Reconstrution Loss:', metric_recon_loss(modelIn, loader),'\n') # full batch loss
print('Loss:', metric_loss(modelIn, loader),'\n')
print('Transpose:', metric_transpose_theorem(modelIn),'\n') # how close encoder and decoder.T are 
print('Distance to axis-aligned solution:', metric_alignment(modelIn, data.eigvectors),'\n') # alignment of decoder columns to ground truth eigenvectors
print('Distance to optimal subspace):', metric_subspace(modelIn, data.eigvectors, data.eigs),'\n')



Reconstrution Loss: 50555.9375 

Loss: 50555.9375 

Transpose: 1.9994410686194896e-05 

Distance to axis-aligned solution: 0.9899812176786509 

Distance to optimal subspace): 0.6012058744201634 

Gamma:  tensor([[ 0.0000e+00,  7.7891e-06, -8.5934e-06,  ..., -6.3761e-06,
         -6.1098e-06, -1.5279e-05],
        [-7.7891e-06,  0.0000e+00,  4.4910e-06,  ...,  7.6507e-06,
          2.2580e-05,  4.7529e-06],
        [ 8.5934e-06, -4.4910e-06,  0.0000e+00,  ..., -8.2844e-06,
         -1.0614e-05, -7.0378e-06],
        ...,
        [ 6.3761e-06, -7.6507e-06,  8.2844e-06,  ...,  0.0000e+00,
         -2.1074e-06,  1.2068e-06],
        [ 6.1098e-06, -2.2580e-05,  1.0614e-05,  ...,  2.1074e-06,
          0.0000e+00, -1.4485e-05],
        [ 1.5279e-05, -4.7529e-06,  7.0378e-06,  ..., -1.2068e-06,
          1.4485e-05,  0.0000e+00]], grad_fn=<MulBackward0>)
Iteration = 1, Losses: rotation = 50555.9375 
Gamma:  tensor([[ 0.0000e+00,  7.9253e-06, -8.2184e-06,  ..., -7.0982e-06,
         -5.8463e-0

Gamma:  tensor([[ 0.0000e+00, -8.4320e-06, -2.5722e-05,  ..., -8.1189e-06,
          1.3046e-05, -4.5721e-05],
        [ 8.4320e-06,  0.0000e+00, -2.7197e-05,  ..., -9.3232e-05,
          3.5901e-05,  2.5303e-05],
        [ 2.5722e-05,  2.7197e-05,  0.0000e+00,  ..., -1.1172e-04,
          4.5427e-05, -7.3799e-05],
        ...,
        [ 8.1189e-06,  9.3232e-05,  1.1172e-04,  ...,  0.0000e+00,
         -5.6206e-06, -5.1651e-05],
        [-1.3046e-05, -3.5901e-05, -4.5427e-05,  ...,  5.6206e-06,
          0.0000e+00, -1.1296e-04],
        [ 4.5721e-05, -2.5303e-05,  7.3799e-05,  ...,  5.1651e-05,
          1.1296e-04,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00, -1.5545e-05, -3.2187e-05,  ..., -6.6926e-06,
          1.5123e-05, -5.5364e-05],
        [ 1.5545e-05,  0.0000e+00, -3.6151e-05,  ..., -1.2141e-04,
          3.9389e-05,  3.2808e-05],
        [ 3.2187e-05,  3.6151e-05,  0.0000e+00,  ..., -1.4372e-04,
          5.9510e-05, -9.0828e-05],
        ...,
      

Gamma:  tensor([[ 0.0000e+00, -1.9219e-03, -9.0752e-04,  ...,  2.5944e-04,
         -3.6319e-04, -1.6311e-03],
        [ 1.9219e-03,  0.0000e+00, -1.0049e-03,  ..., -3.6253e-03,
         -1.7898e-04,  1.8666e-03],
        [ 9.0752e-04,  1.0049e-03,  0.0000e+00,  ..., -5.3413e-03,
          1.5301e-03, -2.7186e-03],
        ...,
        [-2.5944e-04,  3.6253e-03,  5.3413e-03,  ...,  0.0000e+00,
          5.6667e-05, -3.0591e-03],
        [ 3.6319e-04,  1.7898e-04, -1.5301e-03,  ..., -5.6667e-05,
          0.0000e+00, -4.9320e-03],
        [ 1.6311e-03, -1.8666e-03,  2.7186e-03,  ...,  3.0591e-03,
          4.9320e-03,  0.0000e+00]], grad_fn=<MulBackward0>)
Iteration = 30, Losses: rotation = 50504.01953125 
Gamma:  tensor([[ 0.0000, -0.0026, -0.0012,  ...,  0.0004, -0.0005, -0.0021],
        [ 0.0026,  0.0000, -0.0013,  ..., -0.0046, -0.0003,  0.0025],
        [ 0.0012,  0.0013,  0.0000,  ..., -0.0069,  0.0019, -0.0035],
        ...,
        [-0.0004,  0.0046,  0.0069,  ...,  0.0000,  0.

Gamma:  tensor([[ 0.0000, -0.3452, -0.1128,  ...,  0.1100, -0.0444, -0.2136],
        [ 0.3452,  0.0000, -0.0445,  ..., -0.3283, -0.1489,  0.2678],
        [ 0.1128,  0.0445,  0.0000,  ..., -0.6658,  0.1089, -0.3268],
        ...,
        [-0.1100,  0.3283,  0.6658,  ...,  0.0000,  0.0657, -0.3414],
        [ 0.0444,  0.1489, -0.1089,  ..., -0.0657,  0.0000, -0.5493],
        [ 0.2136, -0.2678,  0.3268,  ...,  0.3414,  0.5493,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000, -0.4285, -0.1407,  ...,  0.1345, -0.0535, -0.2517],
        [ 0.4285,  0.0000, -0.0572,  ..., -0.3970, -0.1971,  0.3236],
        [ 0.1407,  0.0572,  0.0000,  ..., -0.8252,  0.1382, -0.4112],
        ...,
        [-0.1345,  0.3970,  0.8252,  ...,  0.0000,  0.0892, -0.4072],
        [ 0.0535,  0.1971, -0.1382,  ..., -0.0892,  0.0000, -0.6699],
        [ 0.2517, -0.3236,  0.4112,  ...,  0.4072,  0.6699,  0.0000]],
       grad_fn=<MulBackward0>)
Iteration = 50, Losses: rotation = 45504.73828125 
Gam

Gamma:  tensor([[ 0.0000, -1.6949, -0.7395,  ..., -0.4698,  0.7689,  1.3051],
        [ 1.6949,  0.0000, -1.0907,  ..., -0.5486, -0.7683, -0.1145],
        [ 0.7395,  1.0907,  0.0000,  ..., -1.4974,  1.8430, -0.1016],
        ...,
        [ 0.4698,  0.5486,  1.4974,  ...,  0.0000,  0.6682, -0.4160],
        [-0.7689,  0.7683, -1.8430,  ..., -0.6682,  0.0000, -1.1063],
        [-1.3051,  0.1145,  0.1016,  ...,  0.4160,  1.1063,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000, -1.6642, -0.7319,  ..., -0.3729,  0.9490,  1.2150],
        [ 1.6642,  0.0000, -0.9933,  ..., -0.5634, -0.5153, -0.1059],
        [ 0.7319,  0.9933,  0.0000,  ..., -1.3617,  1.7611, -0.0575],
        ...,
        [ 0.3729,  0.5634,  1.3617,  ...,  0.0000,  0.6377, -0.3312],
        [-0.9490,  0.5153, -1.7611,  ..., -0.6377,  0.0000, -1.0407],
        [-1.2150,  0.1059,  0.0575,  ...,  0.3312,  1.0407,  0.0000]],
       grad_fn=<MulBackward0>)
Iteration = 70, Losses: rotation = 28950.736328125 
Ga

Gamma:  tensor([[ 0.0000, -1.5686, -0.5154,  ..., -0.4386,  0.0191, -0.0072],
        [ 1.5686,  0.0000,  0.0982,  ..., -0.2656,  0.3665,  0.1365],
        [ 0.5154, -0.0982,  0.0000,  ..., -0.8401,  0.9262, -0.5503],
        ...,
        [ 0.4386,  0.2656,  0.8401,  ...,  0.0000,  0.4961,  0.1406],
        [-0.0191, -0.3665, -0.9262,  ..., -0.4961,  0.0000, -0.8702],
        [ 0.0072, -0.1365,  0.5503,  ..., -0.1406,  0.8702,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000, -1.5243, -0.4974,  ..., -0.5007, -0.0507,  0.0448],
        [ 1.5243,  0.0000,  0.1008,  ..., -0.2387,  0.2917,  0.1413],
        [ 0.4974, -0.1008,  0.0000,  ..., -0.8109,  0.8693, -0.5488],
        ...,
        [ 0.5007,  0.2387,  0.8109,  ...,  0.0000,  0.5975,  0.1439],
        [ 0.0507, -0.2917, -0.8693,  ..., -0.5975,  0.0000, -0.8747],
        [-0.0448, -0.1413,  0.5488,  ..., -0.1439,  0.8747,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000, -1.4776, -0.4766,  ..., -0.5

In [9]:
##### GET DATA ####
import os
import torch
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Same setup as the earlier data/model cells, repeated with hidden_dim = 50.
seed=1234
# set random seed (reseeding makes this cell's data generation and init repeatable)
np.random.seed(seed)
torch.manual_seed(seed)

# Observation dimensionality and latent size.
input_dim = 1000
hidden_dim = 50

# batch_size equals the dataset size, so the loader yields one full batch per pass.
n_data = 5000
batch_size = n_data

# Passed to DataGeneratorPCA as max_sv / min_sv.
max_sv = float(input_dim) * 0.1
min_sv = 1.0
sigma = 0.5  # not referenced again in the visible part of this cell

# gt_data generates the samples; data re-wraps them via load_data.
gt_data = DataGeneratorPCA(input_dim, hidden_dim, min_sv=min_sv, max_sv=max_sv, total=n_data)
data = DataGeneratorPCA(input_dim, hidden_dim, load_data=gt_data.x_sample)

loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False)

#### Define the model ####

import os
import torch
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#### DEFINE MODEL #####
# Hyperparameters for this run; model_type='rotation' is what train_models checks.
model_dict = dict(
    model_name='rotation',
    model_type='rotation',
    model_class=LinearAE,
    extra_model_args = {"weight_reg_type": None},  # no weight regularization
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    init_scale=0.0001,
    optim_class=torch.optim.SGD,
    extra_optim_args={'momentum': 0.9, 'nesterov': True},
    lr=0.0001,
    # Alternative optimizer settings, currently disabled:
#     optim_class=torch.optim.Adam,
#     extra_optim_args={},
#     lr=0.0003,
    # NOTE: train_itr and seed are recorded here but are not passed to ModelConfig below.
    train_itr=1000,#50000,
    seed=seed
)

# model config contains the model (ModelConfig builds both the model and its optimizer)
model_config = ModelConfig(
        model_name=model_dict['model_name'],
        model_type=model_dict['model_type'],
        model_class=model_dict['model_class'],
        input_dim=model_dict['input_dim'], 
        hidden_dim=model_dict['hidden_dim'],
        init_scale=model_dict['init_scale'],
        extra_model_args=model_dict['extra_model_args'],
        optim_class=model_dict['optim_class'],
        lr=model_dict['lr'],
        extra_optim_args=model_dict['extra_optim_args']
    )

# Show the configuration, the model architecture and the optimizer settings.
print(model_dict,'\n')
print(model_config.get_model(),'\n')
print(model_config.get_optimizer())

# Metrics of the untrained model, measured against the eigenvectors stored on `data`.
print('Transpose:', metric_transpose_theorem(model_config.get_model()),'\n') # how close encoder and decoder.T are 
print('Distance to axis-aligned solution:', metric_alignment(model_config.get_model(), data.eigvectors),'\n') # alignment of decoder columns to ground truth eigenvectors
print('Distance to optimal subspace):', metric_subspace(model_config.get_model(), data.eigvectors, data.eigs),'\n')



{'model_name': 'rotation', 'model_type': 'rotation', 'model_class': <class '__main__.LinearAE'>, 'extra_model_args': {'weight_reg_type': None}, 'input_dim': 1000, 'hidden_dim': 50, 'init_scale': 0.0001, 'optim_class': <class 'torch.optim.sgd.SGD'>, 'extra_optim_args': {'momentum': 0.9, 'nesterov': True}, 'lr': 0.0001, 'train_itr': 1000, 'seed': 1234} 

LinearAE(
  (encoder): Linear(in_features=1000, out_features=50, bias=False)
  (decoder): Linear(in_features=50, out_features=1000, bias=False)
) 

SGD (
Parameter Group 0
    dampening: 0
    foreach: None
    lr: 0.0001
    maximize: False
    momentum: 0.9
    nesterov: True
    weight_decay: 0
)
Transpose: 1.9993053283542393e-05 

Distance to axis-aligned solution: 0.993799557974467 

Distance to optimal subspace): 0.9513405523538548 



In [10]:
def _report_metrics(model):
    """Print the evaluation metrics for ``model``.

    Uses the notebook-level ``loader`` for the two loss metrics and the
    eigenvectors / eigenvalues stored on ``data`` for the alignment and
    subspace metrics.  The printed labels are kept verbatim (including their
    original spelling) so new output stays comparable with earlier runs.

    Args:
        model: the model object returned by ``model_config.get_model()``.
    """
    print('Reconstrution Loss:', metric_recon_loss(model, loader), '\n')  # full batch loss
    print('Loss:', metric_loss(model, loader), '\n')
    # How close encoder and decoder.T are.
    print('Transpose:', metric_transpose_theorem(model), '\n')
    # Alignment of decoder columns to the ground-truth eigenvectors.
    print('Distance to axis-aligned solution:', metric_alignment(model, data.eigvectors), '\n')
    print('Distance to optimal subspace):', metric_subspace(model, data.eigvectors, data.eigs), '\n')


# Metrics before training.
modelIn = model_config.get_model()
_report_metrics(modelIn)

# Take the iteration count from model_dict instead of repeating the literal,
# so that editing 'train_itr' in the configuration actually changes training.
trained_model = train_models(
    data_loader=loader,
    train_itr=model_dict['train_itr'],
    metrics_dict=None,
    model_configs=model_config,
)

# Metrics after training; the model is re-fetched from the config
# (presumably train_models updates the model held by model_config — verify).
modelIn = model_config.get_model()
_report_metrics(modelIn)



Reconstrution Loss: 50555.9375 

Loss: 50555.9375 

Transpose: 1.9993053283542393e-05 

Distance to axis-aligned solution: 0.993799557974467 

Distance to optimal subspace): 0.9513405523538548 

Gamma:  tensor([[ 0.0000e+00,  2.4825e-08,  1.1136e-05,  ...,  1.7017e-05,
          6.6775e-06, -5.8729e-06],
        [-2.4825e-08,  0.0000e+00,  1.3265e-05,  ..., -5.7284e-06,
          2.0758e-05, -1.2807e-06],
        [-1.1136e-05, -1.3265e-05,  0.0000e+00,  ...,  2.8009e-06,
         -6.3987e-06, -6.9829e-06],
        ...,
        [-1.7017e-05,  5.7284e-06, -2.8009e-06,  ...,  0.0000e+00,
          1.6798e-06, -5.9501e-06],
        [-6.6775e-06, -2.0758e-05,  6.3987e-06,  ..., -1.6798e-06,
          0.0000e+00,  6.1862e-06],
        [ 5.8729e-06,  1.2807e-06,  6.9829e-06,  ...,  5.9501e-06,
         -6.1862e-06,  0.0000e+00]], grad_fn=<MulBackward0>)
Iteration = 1, Losses: rotation = 50555.9375 
Gamma:  tensor([[ 0.0000e+00,  7.2502e-07,  1.1781e-05,  ...,  1.7055e-05,
          6.4766e-06

Gamma:  tensor([[ 0.0000e+00,  3.4359e-05,  6.6931e-05,  ...,  5.4405e-06,
          1.2831e-05, -5.5701e-05],
        [-3.4359e-05,  0.0000e+00,  5.1028e-05,  ..., -4.6271e-05,
          1.7901e-05,  1.1047e-05],
        [-6.6931e-05, -5.1028e-05,  0.0000e+00,  ...,  5.1058e-05,
         -3.5043e-06,  7.0114e-06],
        ...,
        [-5.4405e-06,  4.6271e-05, -5.1058e-05,  ...,  0.0000e+00,
          4.3875e-05, -4.1429e-05],
        [-1.2831e-05, -1.7901e-05,  3.5043e-06,  ..., -4.3875e-05,
          0.0000e+00, -3.4105e-05],
        [ 5.5701e-05, -1.1047e-05, -7.0114e-06,  ...,  4.1429e-05,
          3.4105e-05,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  4.3127e-05,  8.2547e-05,  ...,  1.9163e-06,
          1.4856e-05, -6.6221e-05],
        [-4.3127e-05,  0.0000e+00,  6.2539e-05,  ..., -5.9569e-05,
          2.0328e-05,  1.4906e-05],
        [-8.2547e-05, -6.2539e-05,  0.0000e+00,  ...,  6.4217e-05,
         -4.9444e-06,  1.0509e-05],
        ...,
      

Gamma:  tensor([[ 0.0000,  0.0009,  0.0015,  ..., -0.0004, -0.0002, -0.0009],
        [-0.0009,  0.0000,  0.0012,  ..., -0.0016,  0.0005,  0.0007],
        [-0.0015, -0.0012,  0.0000,  ...,  0.0015, -0.0004,  0.0003],
        ...,
        [ 0.0004,  0.0016, -0.0015,  ...,  0.0000,  0.0012, -0.0020],
        [ 0.0002, -0.0005,  0.0004,  ..., -0.0012,  0.0000, -0.0016],
        [ 0.0009, -0.0007, -0.0003,  ...,  0.0020,  0.0016,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.0011,  0.0019,  ..., -0.0006, -0.0003, -0.0012],
        [-0.0011,  0.0000,  0.0014,  ..., -0.0021,  0.0007,  0.0009],
        [-0.0019, -0.0014,  0.0000,  ...,  0.0019, -0.0006,  0.0004],
        ...,
        [ 0.0006,  0.0021, -0.0019,  ...,  0.0000,  0.0015, -0.0026],
        [ 0.0003, -0.0007,  0.0006,  ..., -0.0015,  0.0000, -0.0021],
        [ 0.0012, -0.0009, -0.0004,  ...,  0.0026,  0.0021,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.0014,  0.0023,  ..., -0.0

Gamma:  tensor([[ 0.0000,  0.0744,  0.1020,  ..., -0.1044, -0.1140, -0.0671],
        [-0.0744,  0.0000,  0.0965,  ..., -0.1908,  0.1243,  0.1268],
        [-0.1020, -0.0965,  0.0000,  ...,  0.1808, -0.0757,  0.0561],
        ...,
        [ 0.1044,  0.1908, -0.1808,  ...,  0.0000,  0.1338, -0.3157],
        [ 0.1140, -0.1243,  0.0757,  ..., -0.1338,  0.0000, -0.2058],
        [ 0.0671, -0.1268, -0.0561,  ...,  0.3157,  0.2058,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.0941,  0.1268,  ..., -0.1371, -0.1520, -0.0839],
        [-0.0941,  0.0000,  0.1214,  ..., -0.2450,  0.1642,  0.1649],
        [-0.1268, -0.1214,  0.0000,  ...,  0.2328, -0.0980,  0.0724],
        ...,
        [ 0.1371,  0.2450, -0.2328,  ...,  0.0000,  0.1705, -0.4086],
        [ 0.1520, -0.1642,  0.0980,  ..., -0.1705,  0.0000, -0.2635],
        [ 0.0839, -0.1649, -0.0724,  ...,  0.4086,  0.2635,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.1187,  0.1575,  ..., -0.1

Gamma:  tensor([[ 0.0000,  0.9428,  3.2963,  ..., -3.0275, -1.6901, -0.4048],
        [-0.9428,  0.0000,  1.7736,  ..., -5.5357,  2.2181,  2.7341],
        [-3.2963, -1.7736,  0.0000,  ...,  4.4944, -2.6574, -0.0698],
        ...,
        [ 3.0275,  5.5357, -4.4944,  ...,  0.0000,  1.2860, -4.0295],
        [ 1.6901, -2.2181,  2.6574,  ..., -1.2860,  0.0000, -4.0672],
        [ 0.4048, -2.7341,  0.0698,  ...,  4.0295,  4.0672,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.7338,  3.4898,  ..., -3.1017, -1.3185, -0.2377],
        [-0.7338,  0.0000,  1.8108,  ..., -5.6872,  1.8522,  2.6397],
        [-3.4898, -1.8108,  0.0000,  ...,  4.3879, -2.6795, -0.0780],
        ...,
        [ 3.1017,  5.6872, -4.3879,  ...,  0.0000,  0.8471, -3.5859],
        [ 1.3185, -1.8522,  2.6795,  ..., -0.8471,  0.0000, -4.2368],
        [ 0.2377, -2.6397,  0.0780,  ...,  3.5859,  4.2368,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.4676,  3.5426,  ..., -3.1

Gamma:  tensor([[ 0.0000,  0.0851, -0.8961,  ..., -0.4899, -0.9162, -0.2901],
        [-0.0851,  0.0000, -0.1543,  ...,  0.2320,  1.0668,  0.6112],
        [ 0.8961,  0.1543,  0.0000,  ...,  0.3489, -0.3732,  0.2524],
        ...,
        [ 0.4899, -0.2320, -0.3489,  ...,  0.0000,  0.1532,  0.0892],
        [ 0.9162, -1.0668,  0.3732,  ..., -0.1532,  0.0000,  0.2377],
        [ 0.2901, -0.6112, -0.2524,  ..., -0.0892, -0.2377,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.1354, -0.7333,  ..., -0.6285, -0.9560, -0.3042],
        [-0.1354,  0.0000, -0.0559,  ..., -0.0326,  1.1311,  0.7152],
        [ 0.7333,  0.0559,  0.0000,  ...,  0.5537, -0.4977,  0.2518],
        ...,
        [ 0.6285,  0.0326, -0.5537,  ...,  0.0000,  0.1569, -0.0458],
        [ 0.9560, -1.1311,  0.4977,  ..., -0.1569,  0.0000, -0.0098],
        [ 0.3042, -0.7152, -0.2518,  ...,  0.0458,  0.0098,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.1715, -0.5758,  ..., -0.7

Gamma:  tensor([[ 0.0000, -0.0468, -0.3023,  ..., -0.9291, -0.4132, -0.2344],
        [ 0.0468,  0.0000,  0.3123,  ..., -0.6987,  0.6415,  0.4757],
        [ 0.3023, -0.3123,  0.0000,  ...,  0.7842, -0.5150,  0.2898],
        ...,
        [ 0.9291,  0.6987, -0.7842,  ...,  0.0000, -0.7952, -0.4256],
        [ 0.4132, -0.6415,  0.5150,  ...,  0.7952,  0.0000, -0.8253],
        [ 0.2344, -0.4757, -0.2898,  ...,  0.4256,  0.8253,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000, -0.0385, -0.3602,  ..., -0.8730, -0.4186, -0.2415],
        [ 0.0385,  0.0000,  0.2802,  ..., -0.6091,  0.6477,  0.4476],
        [ 0.3602, -0.2802,  0.0000,  ...,  0.7330, -0.4787,  0.2781],
        ...,
        [ 0.8730,  0.6091, -0.7330,  ...,  0.0000, -0.7633, -0.3893],
        [ 0.4186, -0.6477,  0.4787,  ...,  0.7633,  0.0000, -0.7344],
        [ 0.2415, -0.4476, -0.2781,  ...,  0.3893,  0.7344,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000, -0.0260, -0.4117,  ..., -0.8

Gamma:  tensor([[ 0.0000,  0.1209, -0.4951,  ..., -0.7101, -0.4169, -0.3432],
        [-0.1209,  0.0000,  0.2375,  ..., -0.4598,  0.7506,  0.4203],
        [ 0.4951, -0.2375,  0.0000,  ...,  0.7435, -0.5113,  0.1872],
        ...,
        [ 0.7101,  0.4598, -0.7435,  ...,  0.0000, -0.6362, -0.3199],
        [ 0.4169, -0.7506,  0.5113,  ...,  0.6362,  0.0000, -0.5964],
        [ 0.3432, -0.4203, -0.1872,  ...,  0.3199,  0.5964,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.1169, -0.5001,  ..., -0.7006, -0.4030, -0.3407],
        [-0.1169,  0.0000,  0.2345,  ..., -0.4530,  0.7393,  0.4095],
        [ 0.5001, -0.2345,  0.0000,  ...,  0.7347, -0.5023,  0.1841],
        ...,
        [ 0.7006,  0.4530, -0.7347,  ...,  0.0000, -0.6440, -0.3277],
        [ 0.4030, -0.7393,  0.5023,  ...,  0.6440,  0.0000, -0.5903],
        [ 0.3407, -0.4095, -0.1841,  ...,  0.3277,  0.5903,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.1132, -0.5071,  ..., -0.6

Gamma:  tensor([[ 0.0000,  0.1121, -0.6569,  ..., -0.4972, -0.2970, -0.3178],
        [-0.1121,  0.0000,  0.1290,  ..., -0.2627,  0.6673,  0.3096],
        [ 0.6569, -0.1290,  0.0000,  ...,  0.5851, -0.4186,  0.1131],
        ...,
        [ 0.4972,  0.2627, -0.5851,  ...,  0.0000, -0.6022, -0.3215],
        [ 0.2970, -0.6673,  0.4186,  ...,  0.6022,  0.0000, -0.4060],
        [ 0.3178, -0.3096, -0.1131,  ...,  0.3215,  0.4060,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.1117, -0.6603,  ..., -0.4923, -0.2916, -0.3168],
        [-0.1117,  0.0000,  0.1262,  ..., -0.2603,  0.6636,  0.3073],
        [ 0.6603, -0.1262,  0.0000,  ...,  0.5820, -0.4181,  0.1108],
        ...,
        [ 0.4923,  0.2603, -0.5820,  ...,  0.0000, -0.6024, -0.3238],
        [ 0.2916, -0.6636,  0.4181,  ...,  0.6024,  0.0000, -0.4041],
        [ 0.3168, -0.3073, -0.1108,  ...,  0.3238,  0.4041,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.1110, -0.6637,  ..., -0.4

Gamma:  tensor([[ 0.0000,  0.0908, -0.7465,  ..., -0.3728, -0.1938, -0.2788],
        [-0.0908,  0.0000,  0.0472,  ..., -0.1800,  0.5748,  0.2488],
        [ 0.7465, -0.0472,  0.0000,  ...,  0.4854, -0.3735,  0.0642],
        ...,
        [ 0.3728,  0.1800, -0.4854,  ...,  0.0000, -0.5829, -0.3667],
        [ 0.1938, -0.5748,  0.3735,  ...,  0.5829,  0.0000, -0.3263],
        [ 0.2788, -0.2488, -0.0642,  ...,  0.3667,  0.3263,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.0903, -0.7504,  ..., -0.3678, -0.1908, -0.2772],
        [-0.0903,  0.0000,  0.0435,  ..., -0.1764,  0.5713,  0.2468],
        [ 0.7504, -0.0435,  0.0000,  ...,  0.4812, -0.3719,  0.0623],
        ...,
        [ 0.3678,  0.1764, -0.4812,  ...,  0.0000, -0.5811, -0.3674],
        [ 0.1908, -0.5713,  0.3719,  ...,  0.5811,  0.0000, -0.3225],
        [ 0.2772, -0.2468, -0.0623,  ...,  0.3674,  0.3225,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.0899, -0.7541,  ..., -0.3

Gamma:  tensor([[ 0.0000,  0.0839, -0.8045,  ..., -0.2933, -0.1417, -0.2472],
        [-0.0839,  0.0000, -0.0157,  ..., -0.1267,  0.5059,  0.2133],
        [ 0.8045,  0.0157,  0.0000,  ...,  0.4126, -0.3406,  0.0353],
        ...,
        [ 0.2933,  0.1267, -0.4126,  ...,  0.0000, -0.5545, -0.3883],
        [ 0.1417, -0.5059,  0.3406,  ...,  0.5545,  0.0000, -0.2676],
        [ 0.2472, -0.2133, -0.0353,  ...,  0.3883,  0.2676,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.0838, -0.8071,  ..., -0.2896, -0.1395, -0.2456],
        [-0.0838,  0.0000, -0.0188,  ..., -0.1243,  0.5023,  0.2116],
        [ 0.8071,  0.0188,  0.0000,  ...,  0.4090, -0.3387,  0.0341],
        ...,
        [ 0.2896,  0.1243, -0.4090,  ...,  0.0000, -0.5530, -0.3893],
        [ 0.1395, -0.5023,  0.3387,  ...,  0.5530,  0.0000, -0.2646],
        [ 0.2456, -0.2116, -0.0341,  ...,  0.3893,  0.2646,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.0837, -0.8097,  ..., -0.2

Gamma:  tensor([[ 0.0000,  0.0886, -0.8472,  ..., -0.2284, -0.1075, -0.2158],
        [-0.0886,  0.0000, -0.0715,  ..., -0.0861,  0.4389,  0.1823],
        [ 0.8472,  0.0715,  0.0000,  ...,  0.3462, -0.3040,  0.0150],
        ...,
        [ 0.2284,  0.0861, -0.3462,  ...,  0.0000, -0.5266, -0.4037],
        [ 0.1075, -0.4389,  0.3040,  ...,  0.5266,  0.0000, -0.2122],
        [ 0.2158, -0.1823, -0.0150,  ...,  0.4037,  0.2122,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.0892, -0.8488,  ..., -0.2256, -0.1062, -0.2143],
        [-0.0892,  0.0000, -0.0740,  ..., -0.0845,  0.4357,  0.1809],
        [ 0.8488,  0.0740,  0.0000,  ...,  0.3432, -0.3021,  0.0142],
        ...,
        [ 0.2256,  0.0845, -0.3432,  ...,  0.0000, -0.5254, -0.4043],
        [ 0.1062, -0.4357,  0.3021,  ...,  0.5254,  0.0000, -0.2097],
        [ 0.2143, -0.1809, -0.0142,  ...,  0.4043,  0.2097,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.0898, -0.8504,  ..., -0.2

Gamma:  tensor([[ 0.0000,  0.1047, -0.8722,  ..., -0.1784, -0.0858, -0.1870],
        [-0.1047,  0.0000, -0.1160,  ..., -0.0595,  0.3778,  0.1559],
        [ 0.8722,  0.1160,  0.0000,  ...,  0.2902, -0.2671,  0.0026],
        ...,
        [ 0.1784,  0.0595, -0.2902,  ...,  0.0000, -0.5058, -0.4137],
        [ 0.0858, -0.3778,  0.2671,  ...,  0.5058,  0.0000, -0.1641],
        [ 0.1870, -0.1559, -0.0026,  ...,  0.4137,  0.1641,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.1058, -0.8730,  ..., -0.1762, -0.0849, -0.1856],
        [-0.1058,  0.0000, -0.1180,  ..., -0.0584,  0.3749,  0.1547],
        [ 0.8730,  0.1180,  0.0000,  ...,  0.2876, -0.2653,  0.0022],
        ...,
        [ 0.1762,  0.0584, -0.2876,  ...,  0.0000, -0.5050, -0.4140],
        [ 0.0849, -0.3749,  0.2653,  ...,  0.5050,  0.0000, -0.1619],
        [ 0.1856, -0.1547, -0.0022,  ...,  0.4140,  0.1619,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.1068, -0.8738,  ..., -0.1

Gamma:  tensor([[ 0.0000,  0.1237, -0.8813,  ..., -0.1460, -0.0738, -0.1659],
        [-0.1237,  0.0000, -0.1446,  ..., -0.0450,  0.3336,  0.1376],
        [ 0.8813,  0.1446,  0.0000,  ...,  0.2516, -0.2385, -0.0034],
        ...,
        [ 0.1460,  0.0450, -0.2516,  ...,  0.0000, -0.4940, -0.4172],
        [ 0.0738, -0.3336,  0.2385,  ...,  0.4940,  0.0000, -0.1305],
        [ 0.1659, -0.1376,  0.0034,  ...,  0.4172,  0.1305,  0.0000]],
       grad_fn=<MulBackward0>)
Iteration = 240, Losses: rotation = 44052.796875 
Gamma:  tensor([[ 0.0000,  0.1250, -0.8816,  ..., -0.1442, -0.0731, -0.1646],
        [-0.1250,  0.0000, -0.1462,  ..., -0.0443,  0.3310,  0.1365],
        [ 0.8816,  0.1462,  0.0000,  ...,  0.2494, -0.2367, -0.0037],
        ...,
        [ 0.1442,  0.0443, -0.2494,  ...,  0.0000, -0.4933, -0.4173],
        [ 0.0731, -0.3310,  0.2367,  ...,  0.4933,  0.0000, -0.1286],
        [ 0.1646, -0.1365,  0.0037,  ...,  0.4173,  0.1286,  0.0000]],
       grad_fn=<MulBackward0>)
Gamm

Gamma:  tensor([[ 0.0000,  0.1508, -0.8815,  ..., -0.1144, -0.0630, -0.1431],
        [-0.1508,  0.0000, -0.1715,  ..., -0.0334,  0.2866,  0.1186],
        [ 0.8815,  0.1715,  0.0000,  ...,  0.2123, -0.2063, -0.0076],
        ...,
        [ 0.1144,  0.0334, -0.2123,  ...,  0.0000, -0.4838, -0.4166],
        [ 0.0630, -0.2866,  0.2063,  ...,  0.4838,  0.0000, -0.0972],
        [ 0.1431, -0.1186,  0.0076,  ...,  0.4166,  0.0972,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.1524, -0.8812,  ..., -0.1129, -0.0625, -0.1419],
        [-0.1524,  0.0000, -0.1728,  ..., -0.0329,  0.2843,  0.1177],
        [ 0.8812,  0.1728,  0.0000,  ...,  0.2104, -0.2047, -0.0078],
        ...,
        [ 0.1129,  0.0329, -0.2104,  ...,  0.0000, -0.4833, -0.4164],
        [ 0.0625, -0.2843,  0.2047,  ...,  0.4833,  0.0000, -0.0957],
        [ 0.1419, -0.1177,  0.0078,  ...,  0.4164,  0.0957,  0.0000]],
       grad_fn=<MulBackward0>)
Iteration = 260, Losses: rotation = 44048.0234375 
Gam

Gamma:  tensor([[ 0.0000,  0.1807, -0.8718,  ..., -0.0890, -0.0544, -0.1226],
        [-0.1807,  0.0000, -0.1919,  ..., -0.0259,  0.2453,  0.1021],
        [ 0.8718,  0.1919,  0.0000,  ...,  0.1790, -0.1769, -0.0099],
        ...,
        [ 0.0890,  0.0259, -0.1790,  ...,  0.0000, -0.4757, -0.4108],
        [ 0.0544, -0.2453,  0.1769,  ...,  0.4757,  0.0000, -0.0717],
        [ 0.1226, -0.1021,  0.0099,  ...,  0.4108,  0.0717,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.1823, -0.8711,  ..., -0.0878, -0.0540, -0.1216],
        [-0.1823,  0.0000, -0.1928,  ..., -0.0256,  0.2433,  0.1013],
        [ 0.8711,  0.1928,  0.0000,  ...,  0.1774, -0.1754, -0.0100],
        ...,
        [ 0.0878,  0.0256, -0.1774,  ...,  0.0000, -0.4753, -0.4103],
        [ 0.0540, -0.2433,  0.1754,  ...,  0.4753,  0.0000, -0.0706],
        [ 0.1216, -0.1013,  0.0100,  ...,  0.4103,  0.0706,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.1839, -0.8703,  ..., -0.0

Gamma:  tensor([[ 0.0000,  0.2113, -0.8543,  ..., -0.0686, -0.0472, -0.1045],
        [-0.2113,  0.0000, -0.2065,  ..., -0.0210,  0.2093,  0.0879],
        [ 0.8543,  0.2065,  0.0000,  ...,  0.1511, -0.1507, -0.0110],
        ...,
        [ 0.0686,  0.0210, -0.1511,  ...,  0.0000, -0.4675, -0.4001],
        [ 0.0472, -0.2093,  0.1507,  ...,  0.4675,  0.0000, -0.0540],
        [ 0.1045, -0.0879,  0.0110,  ...,  0.4001,  0.0540,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.2129, -0.8532,  ..., -0.0676, -0.0468, -0.1036],
        [-0.2129,  0.0000, -0.2071,  ..., -0.0208,  0.2075,  0.0872],
        [ 0.8532,  0.2071,  0.0000,  ...,  0.1498, -0.1494, -0.0110],
        ...,
        [ 0.0676,  0.0208, -0.1498,  ...,  0.0000, -0.4671, -0.3995],
        [ 0.0468, -0.2075,  0.1494,  ...,  0.4671,  0.0000, -0.0532],
        [ 0.1036, -0.0872,  0.0110,  ...,  0.3995,  0.0532,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.2145, -0.8520,  ..., -0.0

Gamma:  tensor([[ 0.0000,  0.2413, -0.8306,  ..., -0.0524, -0.0409, -0.0886],
        [-0.2413,  0.0000, -0.2163,  ..., -0.0178,  0.1780,  0.0755],
        [ 0.8306,  0.2163,  0.0000,  ...,  0.1277, -0.1277, -0.0114],
        ...,
        [ 0.0524,  0.0178, -0.1277,  ...,  0.0000, -0.4576, -0.3852],
        [ 0.0409, -0.1780,  0.1277,  ...,  0.4576,  0.0000, -0.0433],
        [ 0.0886, -0.0755,  0.0114,  ...,  0.3852,  0.0433,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.2429, -0.8292,  ..., -0.0516, -0.0406, -0.0878],
        [-0.2429,  0.0000, -0.2167,  ..., -0.0177,  0.1765,  0.0749],
        [ 0.8292,  0.2167,  0.0000,  ...,  0.1266, -0.1265, -0.0114],
        ...,
        [ 0.0516,  0.0177, -0.1266,  ...,  0.0000, -0.4571, -0.3843],
        [ 0.0406, -0.1765,  0.1265,  ...,  0.4571,  0.0000, -0.0429],
        [ 0.0878, -0.0749,  0.0114,  ...,  0.3843,  0.0429,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.2444, -0.8278,  ..., -0.0

Gamma:  tensor([[ 0.0000,  0.2699, -0.8025,  ..., -0.0397, -0.0352, -0.0748],
        [-0.2699,  0.0000, -0.2220,  ..., -0.0156,  0.1510,  0.0647],
        [ 0.8025,  0.2220,  0.0000,  ...,  0.1081, -0.1077, -0.0114],
        ...,
        [ 0.0397,  0.0156, -0.1081,  ...,  0.0000, -0.4450, -0.3667],
        [ 0.0352, -0.1510,  0.1077,  ...,  0.4450,  0.0000, -0.0387],
        [ 0.0748, -0.0647,  0.0114,  ...,  0.3667,  0.0387,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.2713, -0.8009,  ..., -0.0391, -0.0349, -0.0742],
        [-0.2713,  0.0000, -0.2221,  ..., -0.0155,  0.1496,  0.0642],
        [ 0.8009,  0.2221,  0.0000,  ...,  0.1071, -0.1067, -0.0114],
        ...,
        [ 0.0391,  0.0155, -0.1071,  ...,  0.0000, -0.4443, -0.3657],
        [ 0.0349, -0.1496,  0.1067,  ...,  0.4443,  0.0000, -0.0386],
        [ 0.0742, -0.0642,  0.0114,  ...,  0.3657,  0.0386,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.2727, -0.7993,  ..., -0.0

Gamma:  tensor([[ 0.0000,  0.2976, -0.7697,  ..., -0.0294, -0.0299, -0.0624],
        [-0.2976,  0.0000, -0.2241,  ..., -0.0139,  0.1265,  0.0549],
        [ 0.7697,  0.2241,  0.0000,  ...,  0.0908, -0.0897, -0.0111],
        ...,
        [ 0.0294,  0.0139, -0.0908,  ...,  0.0000, -0.4283, -0.3442],
        [ 0.0299, -0.1265,  0.0897,  ...,  0.4283,  0.0000, -0.0392],
        [ 0.0624, -0.0549,  0.0111,  ...,  0.3442,  0.0392,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.2989, -0.7680,  ..., -0.0289, -0.0297, -0.0618],
        [-0.2989,  0.0000, -0.2241,  ..., -0.0138,  0.1253,  0.0544],
        [ 0.7680,  0.2241,  0.0000,  ...,  0.0900, -0.0889, -0.0111],
        ...,
        [ 0.0289,  0.0138, -0.0900,  ...,  0.0000, -0.4274, -0.3431],
        [ 0.0297, -0.1253,  0.0889,  ...,  0.4274,  0.0000, -0.0393],
        [ 0.0618, -0.0544,  0.0111,  ...,  0.3431,  0.0393,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.3002, -0.7663,  ..., -0.0

Gamma:  tensor([[ 0.0000,  0.3215, -0.7369,  ..., -0.0219, -0.0255, -0.0523],
        [-0.3215,  0.0000, -0.2232,  ..., -0.0126,  0.1065,  0.0468],
        [ 0.7369,  0.2232,  0.0000,  ...,  0.0771, -0.0752, -0.0107],
        ...,
        [ 0.0219,  0.0126, -0.0771,  ...,  0.0000, -0.4093, -0.3208],
        [ 0.0255, -0.1065,  0.0752,  ...,  0.4093,  0.0000, -0.0433],
        [ 0.0523, -0.0468,  0.0107,  ...,  0.3208,  0.0433,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.3227, -0.7352,  ..., -0.0215, -0.0253, -0.0518],
        [-0.3227,  0.0000, -0.2231,  ..., -0.0125,  0.1056,  0.0464],
        [ 0.7352,  0.2231,  0.0000,  ...,  0.0764, -0.0745, -0.0106],
        ...,
        [ 0.0215,  0.0125, -0.0764,  ...,  0.0000, -0.4082, -0.3196],
        [ 0.0253, -0.1056,  0.0745,  ...,  0.4082,  0.0000, -0.0436],
        [ 0.0518, -0.0464,  0.0106,  ...,  0.3196,  0.0436,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.3238, -0.7334,  ..., -0.0

Gamma:  tensor([[ 0.0000,  0.3429, -0.7035,  ..., -0.0162, -0.0216, -0.0437],
        [-0.3429,  0.0000, -0.2198,  ..., -0.0114,  0.0895,  0.0397],
        [ 0.7035,  0.2198,  0.0000,  ...,  0.0654, -0.0629, -0.0101],
        ...,
        [ 0.0162,  0.0114, -0.0654,  ...,  0.0000, -0.3875, -0.2961],
        [ 0.0216, -0.0895,  0.0629,  ...,  0.3875,  0.0000, -0.0500],
        [ 0.0437, -0.0397,  0.0101,  ...,  0.2961,  0.0500,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.3439, -0.7017,  ..., -0.0159, -0.0214, -0.0433],
        [-0.3439,  0.0000, -0.2196,  ..., -0.0113,  0.0886,  0.0394],
        [ 0.7017,  0.2196,  0.0000,  ...,  0.0649, -0.0623, -0.0100],
        ...,
        [ 0.0159,  0.0113, -0.0649,  ...,  0.0000, -0.3863, -0.2948],
        [ 0.0214, -0.0886,  0.0623,  ...,  0.3863,  0.0000, -0.0504],
        [ 0.0433, -0.0394,  0.0100,  ...,  0.2948,  0.0504,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.3450, -0.7000,  ..., -0.0

Gamma:  tensor([[ 0.0000,  0.3618, -0.6702,  ..., -0.0118, -0.0181, -0.0365],
        [-0.3618,  0.0000, -0.2144,  ..., -0.0103,  0.0749,  0.0336],
        [ 0.6702,  0.2144,  0.0000,  ...,  0.0555, -0.0525, -0.0094],
        ...,
        [ 0.0118,  0.0103, -0.0555,  ...,  0.0000, -0.3636, -0.2708],
        [ 0.0181, -0.0749,  0.0525,  ...,  0.3636,  0.0000, -0.0585],
        [ 0.0365, -0.0336,  0.0094,  ...,  0.2708,  0.0585,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.3627, -0.6685,  ..., -0.0116, -0.0180, -0.0362],
        [-0.3627,  0.0000, -0.2141,  ..., -0.0102,  0.0742,  0.0333],
        [ 0.6685,  0.2141,  0.0000,  ...,  0.0551, -0.0520, -0.0093],
        ...,
        [ 0.0116,  0.0102, -0.0551,  ...,  0.0000, -0.3622, -0.2694],
        [ 0.0180, -0.0742,  0.0520,  ...,  0.3622,  0.0000, -0.0589],
        [ 0.0362, -0.0333,  0.0093,  ...,  0.2694,  0.0589,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.3637, -0.6668,  ..., -0.0

Gamma:  tensor([[ 0.0000,  0.3776, -0.6394,  ..., -0.0087, -0.0154, -0.0307],
        [-0.3776,  0.0000, -0.2077,  ..., -0.0092,  0.0630,  0.0286],
        [ 0.6394,  0.2077,  0.0000,  ...,  0.0476, -0.0442, -0.0087],
        ...,
        [ 0.0087,  0.0092, -0.0476,  ...,  0.0000, -0.3394, -0.2466],
        [ 0.0154, -0.0630,  0.0442,  ...,  0.3394,  0.0000, -0.0674],
        [ 0.0307, -0.0286,  0.0087,  ...,  0.2466,  0.0674,  0.0000]],
       grad_fn=<MulBackward0>)
Iteration = 430, Losses: rotation = 44032.4609375 
Gamma:  tensor([[ 0.0000,  0.3785, -0.6378,  ..., -0.0086, -0.0152, -0.0304],
        [-0.3785,  0.0000, -0.2073,  ..., -0.0092,  0.0624,  0.0284],
        [ 0.6378,  0.2073,  0.0000,  ...,  0.0472, -0.0437, -0.0086],
        ...,
        [ 0.0086,  0.0092, -0.0472,  ...,  0.0000, -0.3380, -0.2453],
        [ 0.0152, -0.0624,  0.0437,  ...,  0.3380,  0.0000, -0.0680],
        [ 0.0304, -0.0284,  0.0086,  ...,  0.2453,  0.0680,  0.0000]],
       grad_fn=<MulBackward0>)
Gam

Gamma:  tensor([[ 0.0000,  0.3923, -0.6081,  ..., -0.0063, -0.0128, -0.0256],
        [-0.3923,  0.0000, -0.1994,  ..., -0.0082,  0.0524,  0.0241],
        [ 0.6081,  0.1994,  0.0000,  ...,  0.0403, -0.0367, -0.0079],
        ...,
        [ 0.0063,  0.0082, -0.0403,  ...,  0.0000, -0.3130, -0.2214],
        [ 0.0128, -0.0524,  0.0367,  ...,  0.3130,  0.0000, -0.0774],
        [ 0.0256, -0.0241,  0.0079,  ...,  0.2214,  0.0774,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.3930, -0.6065,  ..., -0.0061, -0.0127, -0.0253],
        [-0.3930,  0.0000, -0.1989,  ..., -0.0081,  0.0519,  0.0239],
        [ 0.6065,  0.1989,  0.0000,  ...,  0.0400, -0.0364, -0.0079],
        ...,
        [ 0.0061,  0.0081, -0.0400,  ...,  0.0000, -0.3116, -0.2201],
        [ 0.0127, -0.0519,  0.0364,  ...,  0.3116,  0.0000, -0.0780],
        [ 0.0253, -0.0239,  0.0079,  ...,  0.2201,  0.0780,  0.0000]],
       grad_fn=<MulBackward0>)
Iteration = 450, Losses: rotation = 44031.8046875 
Gam

Gamma:  tensor([[ 0.0000,  0.4045, -0.5797,  ..., -0.0045, -0.0108, -0.0215],
        [-0.4045,  0.0000, -0.1904,  ..., -0.0072,  0.0439,  0.0204],
        [ 0.5797,  0.1904,  0.0000,  ...,  0.0345, -0.0308, -0.0072],
        ...,
        [ 0.0045,  0.0072, -0.0345,  ...,  0.0000, -0.2877, -0.1980],
        [ 0.0108, -0.0439,  0.0308,  ...,  0.2877,  0.0000, -0.0871],
        [ 0.0215, -0.0204,  0.0072,  ...,  0.1980,  0.0871,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.4051, -0.5782,  ..., -0.0044, -0.0107, -0.0213],
        [-0.4051,  0.0000, -0.1899,  ..., -0.0071,  0.0435,  0.0202],
        [ 0.5782,  0.1899,  0.0000,  ...,  0.0342, -0.0305, -0.0072],
        ...,
        [ 0.0044,  0.0071, -0.0342,  ...,  0.0000, -0.2862, -0.1967],
        [ 0.0107, -0.0435,  0.0305,  ...,  0.2862,  0.0000, -0.0876],
        [ 0.0213, -0.0202,  0.0072,  ...,  0.1967,  0.0876,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.4057, -0.5767,  ..., -0.0

Gamma:  tensor([[ 0.0000,  0.4157, -0.5514,  ..., -0.0031, -0.0090, -0.0179],
        [-0.4157,  0.0000, -0.1801,  ..., -0.0062,  0.0363,  0.0171],
        [ 0.5514,  0.1801,  0.0000,  ...,  0.0292, -0.0255, -0.0065],
        ...,
        [ 0.0031,  0.0062, -0.0292,  ...,  0.0000, -0.2610, -0.1741],
        [ 0.0090, -0.0363,  0.0255,  ...,  0.2610,  0.0000, -0.0971],
        [ 0.0179, -0.0171,  0.0065,  ...,  0.1741,  0.0971,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.4163, -0.5500,  ..., -0.0031, -0.0089, -0.0177],
        [-0.4163,  0.0000, -0.1796,  ..., -0.0061,  0.0359,  0.0170],
        [ 0.5500,  0.1796,  0.0000,  ...,  0.0290, -0.0253, -0.0065],
        ...,
        [ 0.0031,  0.0061, -0.0290,  ...,  0.0000, -0.2597, -0.1729],
        [ 0.0089, -0.0359,  0.0253,  ...,  0.2597,  0.0000, -0.0976],
        [ 0.0177, -0.0170,  0.0065,  ...,  0.1729,  0.0976,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.4168, -0.5485,  ..., -0.0

Gamma:  tensor([[ 0.0000,  0.4250, -0.5262,  ..., -0.0021, -0.0075, -0.0150],
        [-0.4250,  0.0000, -0.1698,  ..., -0.0053,  0.0302,  0.0144],
        [ 0.5262,  0.1698,  0.0000,  ...,  0.0249, -0.0214, -0.0059],
        ...,
        [ 0.0021,  0.0053, -0.0249,  ...,  0.0000, -0.2364, -0.1524],
        [ 0.0075, -0.0302,  0.0214,  ...,  0.2364,  0.0000, -0.1064],
        [ 0.0150, -0.0144,  0.0059,  ...,  0.1524,  0.1064,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.4255, -0.5248,  ..., -0.0021, -0.0075, -0.0149],
        [-0.4255,  0.0000, -0.1692,  ..., -0.0053,  0.0299,  0.0143],
        [ 0.5248,  0.1692,  0.0000,  ...,  0.0247, -0.0212, -0.0058],
        ...,
        [ 0.0021,  0.0053, -0.0247,  ...,  0.0000, -0.2350, -0.1513],
        [ 0.0075, -0.0299,  0.0212,  ...,  0.2350,  0.0000, -0.1069],
        [ 0.0149, -0.0143,  0.0058,  ...,  0.1513,  0.1069,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.4260, -0.5235,  ..., -0.0

Gamma:  tensor([[ 0.0000,  0.4337, -0.5013,  ..., -0.0014, -0.0063, -0.0125],
        [-0.4337,  0.0000, -0.1584,  ..., -0.0045,  0.0249,  0.0121],
        [ 0.5013,  0.1584,  0.0000,  ...,  0.0211, -0.0177, -0.0052],
        ...,
        [ 0.0014,  0.0045, -0.0211,  ...,  0.0000, -0.2113, -0.1307],
        [ 0.0063, -0.0249,  0.0177,  ...,  0.2113,  0.0000, -0.1157],
        [ 0.0125, -0.0121,  0.0052,  ...,  0.1307,  0.1157,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.4341, -0.5000,  ..., -0.0013, -0.0062, -0.0124],
        [-0.4341,  0.0000, -0.1578,  ..., -0.0044,  0.0246,  0.0119],
        [ 0.5000,  0.1578,  0.0000,  ...,  0.0209, -0.0175, -0.0052],
        ...,
        [ 0.0013,  0.0044, -0.0209,  ...,  0.0000, -0.2100, -0.1295],
        [ 0.0062, -0.0246,  0.0175,  ...,  0.2100,  0.0000, -0.1162],
        [ 0.0124, -0.0119,  0.0052,  ...,  0.1295,  0.1162,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.4345, -0.4988,  ..., -0.0

Gamma:  tensor([[ 0.0000,  0.4420, -0.4757,  ..., -0.0007, -0.0051, -0.0102],
        [-0.4420,  0.0000, -0.1455,  ..., -0.0036,  0.0200,  0.0099],
        [ 0.4757,  0.1455,  0.0000,  ...,  0.0174, -0.0143, -0.0046],
        ...,
        [ 0.0007,  0.0036, -0.0174,  ...,  0.0000, -0.1850, -0.1080],
        [ 0.0051, -0.0200,  0.0143,  ...,  0.1850,  0.0000, -0.1254],
        [ 0.0102, -0.0099,  0.0046,  ...,  0.1080,  0.1254,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.4423, -0.4746,  ..., -0.0007, -0.0051, -0.0101],
        [-0.4423,  0.0000, -0.1449,  ..., -0.0036,  0.0198,  0.0098],
        [ 0.4746,  0.1449,  0.0000,  ...,  0.0173, -0.0142, -0.0046],
        ...,
        [ 0.0007,  0.0036, -0.0173,  ...,  0.0000, -0.1838, -0.1070],
        [ 0.0051, -0.0198,  0.0142,  ...,  0.1838,  0.0000, -0.1258],
        [ 0.0101, -0.0098,  0.0046,  ...,  0.1070,  0.1258,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.4427, -0.4734,  ..., -0.0

Gamma:  tensor([[ 0.0000e+00,  4.4826e-01, -4.5548e-01,  ..., -3.2585e-04,
         -4.3271e-03, -8.6209e-03],
        [-4.4826e-01,  0.0000e+00, -1.3424e-01,  ..., -2.9610e-03,
          1.6553e-02,  8.2839e-03],
        [ 4.5548e-01,  1.3424e-01,  0.0000e+00,  ...,  1.4835e-02,
         -1.1986e-02, -4.0791e-03],
        ...,
        [ 3.2585e-04,  2.9610e-03, -1.4835e-02,  ...,  0.0000e+00,
         -1.6387e-01, -8.9894e-02],
        [ 4.3271e-03, -1.6553e-02,  1.1986e-02,  ...,  1.6387e-01,
          0.0000e+00, -1.3314e-01],
        [ 8.6209e-03, -8.2839e-03,  4.0791e-03,  ...,  8.9894e-02,
          1.3314e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  4.4859e-01, -4.5439e-01,  ..., -3.0702e-04,
         -4.2876e-03, -8.5393e-03],
        [-4.4859e-01,  0.0000e+00, -1.3362e-01,  ..., -2.9276e-03,
          1.6380e-02,  8.2036e-03],
        [ 4.5439e-01,  1.3362e-01,  0.0000e+00,  ...,  1.4701e-02,
         -1.1867e-02, -4.0532e-03],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  4.5299e-01, -4.3968e-01,  ..., -6.6956e-05,
         -3.7603e-03, -7.4855e-03],
        [-4.5299e-01,  0.0000e+00, -1.2486e-01,  ..., -2.4736e-03,
          1.4130e-02,  7.1642e-03],
        [ 4.3968e-01,  1.2486e-01,  0.0000e+00,  ...,  1.2947e-02,
         -1.0318e-02, -3.7004e-03],
        ...,
        [ 6.6956e-05,  2.4736e-03, -1.2947e-02,  ...,  0.0000e+00,
         -1.4738e-01, -7.5671e-02],
        [ 3.7603e-03, -1.4130e-02,  1.0318e-02,  ...,  1.4738e-01,
          0.0000e+00, -1.3921e-01],
        [ 7.4855e-03, -7.1642e-03,  3.7004e-03,  ...,  7.5671e-02,
          1.3921e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  4.5329e-01, -4.3867e-01,  ..., -5.2417e-05,
         -3.7250e-03, -7.4162e-03],
        [-4.5329e-01,  0.0000e+00, -1.2423e-01,  ..., -2.4436e-03,
          1.3981e-02,  7.0949e-03],
        [ 4.3867e-01,  1.2423e-01,  0.0000e+00,  ...,  1.2829e-02,
         -1.0215e-02, -3.6753e-03],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  4.5733e-01, -4.2483e-01,  ...,  1.3895e-04,
         -3.2744e-03, -6.5085e-03],
        [-4.5733e-01,  0.0000e+00, -1.1550e-01,  ..., -2.0367e-03,
          1.2051e-02,  6.1942e-03],
        [ 4.2483e-01,  1.1550e-01,  0.0000e+00,  ...,  1.1291e-02,
         -8.8817e-03, -3.3525e-03],
        ...,
        [-1.3895e-04,  2.0367e-03, -1.1291e-02,  ...,  0.0000e+00,
         -1.3192e-01, -6.2259e-02],
        [ 3.2744e-03, -1.2051e-02,  8.8817e-03,  ...,  1.3192e-01,
          0.0000e+00, -1.4490e-01],
        [ 6.5085e-03, -6.1942e-03,  3.3525e-03,  ...,  6.2259e-02,
          1.4490e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  4.5761e-01, -4.2387e-01,  ...,  1.5142e-04,
         -3.2455e-03, -6.4490e-03],
        [-4.5761e-01,  0.0000e+00, -1.1488e-01,  ..., -2.0098e-03,
          1.1923e-02,  6.1356e-03],
        [ 4.2387e-01,  1.1488e-01,  0.0000e+00,  ...,  1.1189e-02,
         -8.7933e-03, -3.3308e-03],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  4.6082e-01, -4.1268e-01,  ...,  2.8439e-04,
         -2.9111e-03, -5.7716e-03],
        [-4.6082e-01,  0.0000e+00, -1.0747e-01,  ..., -1.6972e-03,
          1.0495e-02,  5.4604e-03],
        [ 4.1268e-01,  1.0747e-01,  0.0000e+00,  ...,  1.0024e-02,
         -7.8001e-03, -3.0764e-03],
        ...,
        [-2.8439e-04,  1.6972e-03, -1.0024e-02,  ...,  0.0000e+00,
         -1.1939e-01, -5.1282e-02],
        [ 2.9111e-03, -1.0495e-02,  7.8001e-03,  ...,  1.1939e-01,
          0.0000e+00, -1.4955e-01],
        [ 5.7716e-03, -5.4604e-03,  3.0764e-03,  ...,  5.1282e-02,
          1.4955e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  4.6108e-01, -4.1177e-01,  ...,  2.9497e-04,
         -2.8850e-03, -5.7192e-03],
        [-4.6108e-01,  0.0000e+00, -1.0685e-01,  ..., -1.6720e-03,
          1.0384e-02,  5.4084e-03],
        [ 4.1177e-01,  1.0685e-01,  0.0000e+00,  ...,  9.9337e-03,
         -7.7230e-03, -3.0558e-03],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  4.6409e-01, -4.0117e-01,  ...,  4.0558e-04,
         -2.5932e-03, -5.1268e-03],
        [-4.6409e-01,  0.0000e+00, -9.9522e-02,  ..., -1.3909e-03,
          9.1350e-03,  4.8148e-03],
        [ 4.0117e-01,  9.9522e-02,  0.0000e+00,  ...,  8.8965e-03,
         -6.8527e-03, -2.8212e-03],
        ...,
        [-4.0558e-04,  1.3909e-03, -8.8965e-03,  ...,  0.0000e+00,
         -1.0766e-01, -4.0895e-02],
        [ 2.5932e-03, -9.1350e-03,  6.8527e-03,  ...,  1.0766e-01,
          0.0000e+00, -1.5393e-01],
        [ 5.1268e-03, -4.8148e-03,  2.8212e-03,  ...,  4.0895e-02,
          1.5393e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  4.6433e-01, -4.0031e-01,  ...,  4.1403e-04,
         -2.5708e-03, -5.0788e-03],
        [-4.6433e-01,  0.0000e+00, -9.8916e-02,  ..., -1.3697e-03,
          9.0380e-03,  4.7694e-03],
        [ 4.0031e-01,  9.8916e-02,  0.0000e+00,  ...,  8.8163e-03,
         -6.7845e-03, -2.8024e-03],
        ...,
      

Gamma:  tensor([[ 0.0000,  0.4676, -0.3886,  ...,  0.0005, -0.0023, -0.0045],
        [-0.4676,  0.0000, -0.0905,  ..., -0.0011,  0.0078,  0.0042],
        [ 0.3886,  0.0905,  0.0000,  ...,  0.0078, -0.0059, -0.0026],
        ...,
        [-0.0005,  0.0011, -0.0078,  ...,  0.0000, -0.0951, -0.0296],
        [ 0.0023, -0.0078,  0.0059,  ...,  0.0951,  0.0000, -0.1587],
        [ 0.0045, -0.0042,  0.0026,  ...,  0.0296,  0.1587,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.4678, -0.3878,  ...,  0.0005, -0.0023, -0.0044],
        [-0.4678,  0.0000, -0.0899,  ..., -0.0011,  0.0077,  0.0041],
        [ 0.3878,  0.0899,  0.0000,  ...,  0.0077, -0.0058, -0.0025],
        ...,
        [-0.0005,  0.0011, -0.0077,  ...,  0.0000, -0.0943, -0.0289],
        [ 0.0023, -0.0077,  0.0058,  ...,  0.0943,  0.0000, -0.1590],
        [ 0.0044, -0.0041,  0.0025,  ...,  0.0289,  0.1590,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.4681, -0.3870,  ...,  0.0

Gamma:  tensor([[ 0.0000,  0.4722, -0.3723,  ...,  0.0006, -0.0019, -0.0037],
        [-0.4722,  0.0000, -0.0782,  ..., -0.0007,  0.0062,  0.0034],
        [ 0.3723,  0.0782,  0.0000,  ...,  0.0064, -0.0048, -0.0022],
        ...,
        [-0.0006,  0.0007, -0.0064,  ...,  0.0000, -0.0793, -0.0151],
        [ 0.0019, -0.0062,  0.0048,  ...,  0.0793,  0.0000, -0.1648],
        [ 0.0037, -0.0034,  0.0022,  ...,  0.0151,  0.1648,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.4724, -0.3716,  ...,  0.0006, -0.0019, -0.0037],
        [-0.4724,  0.0000, -0.0776,  ..., -0.0007,  0.0062,  0.0034],
        [ 0.3716,  0.0776,  0.0000,  ...,  0.0063, -0.0047, -0.0022],
        ...,
        [-0.0006,  0.0007, -0.0063,  ...,  0.0000, -0.0786, -0.0144],
        [ 0.0019, -0.0062,  0.0047,  ...,  0.0786,  0.0000, -0.1650],
        [ 0.0037, -0.0034,  0.0022,  ...,  0.0144,  0.1650,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000,  0.4726, -0.3708,  ...,  0.0

Gamma:  tensor([[ 0.0000e+00,  4.7563e-01, -3.6007e-01,  ...,  7.1874e-04,
         -1.6597e-03, -3.2087e-03],
        [-4.7563e-01,  0.0000e+00, -6.8545e-02,  ..., -4.5688e-04,
          5.1831e-03,  2.8943e-03],
        [ 3.6007e-01,  6.8545e-02,  0.0000e+00,  ...,  5.4568e-03,
         -4.0568e-03, -1.9831e-03],
        ...,
        [-7.1874e-04,  4.5688e-04, -5.4568e-03,  ...,  0.0000e+00,
         -6.7881e-02, -4.2994e-03],
        [ 1.6597e-03, -5.1831e-03,  4.0568e-03,  ...,  6.7881e-02,
          0.0000e+00, -1.6925e-01],
        [ 3.2087e-03, -2.8943e-03,  1.9831e-03,  ...,  4.2994e-03,
          1.6925e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  4.7582e-01, -3.5937e-01,  ...,  7.2299e-04,
         -1.6458e-03, -3.1811e-03],
        [-4.7582e-01,  0.0000e+00, -6.7987e-02,  ..., -4.4243e-04,
          5.1278e-03,  2.8674e-03],
        [ 3.5937e-01,  6.7987e-02,  0.0000e+00,  ...,  5.4063e-03,
         -4.0181e-03, -1.9709e-03],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  4.7811e-01, -3.5125e-01,  ...,  7.6515e-04,
         -1.4978e-03, -2.8717e-03],
        [-4.7811e-01,  0.0000e+00, -6.1381e-02,  ..., -2.9307e-04,
          4.5116e-03,  2.5597e-03],
        [ 3.5125e-01,  6.1381e-02,  0.0000e+00,  ...,  4.8400e-03,
         -3.5734e-03, -1.8209e-03],
        ...,
        [-7.6515e-04,  2.9307e-04, -4.8400e-03,  ...,  0.0000e+00,
         -5.9981e-02,  3.3682e-03],
        [ 1.4978e-03, -4.5116e-03,  3.5734e-03,  ...,  5.9981e-02,
          0.0000e+00, -1.7243e-01],
        [ 2.8717e-03, -2.5597e-03,  1.8209e-03,  ..., -3.3682e-03,
          1.7243e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  4.7830e-01, -3.5059e-01,  ...,  7.6782e-04,
         -1.4870e-03, -2.8477e-03],
        [-4.7830e-01,  0.0000e+00, -6.0839e-02,  ..., -2.8081e-04,
          4.4635e-03,  2.5361e-03],
        [ 3.5059e-01,  6.0839e-02,  0.0000e+00,  ...,  4.7947e-03,
         -3.5389e-03, -1.8093e-03],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  4.8086e-01, -3.4162e-01,  ...,  8.0841e-04,
         -1.3341e-03, -2.5336e-03],
        [-4.8086e-01,  0.0000e+00, -5.3360e-02,  ..., -1.2876e-04,
          3.8459e-03,  2.2241e-03],
        [ 3.4162e-01,  5.3360e-02,  0.0000e+00,  ...,  4.2148e-03,
         -3.0897e-03, -1.6512e-03],
        ...,
        [-8.0841e-04,  1.2876e-04, -4.2148e-03,  ...,  0.0000e+00,
         -5.1697e-02,  1.1635e-02],
        [ 1.3341e-03, -3.8459e-03,  3.0897e-03,  ...,  5.1697e-02,
          0.0000e+00, -1.7586e-01],
        [ 2.5336e-03, -2.2241e-03,  1.6512e-03,  ..., -1.1635e-02,
          1.7586e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  4.8104e-01, -3.4100e-01,  ...,  8.1096e-04,
         -1.3244e-03, -2.5124e-03],
        [-4.8104e-01,  0.0000e+00, -5.2834e-02,  ..., -1.1899e-04,
          3.8063e-03,  2.2037e-03],
        [ 3.4100e-01,  5.2834e-02,  0.0000e+00,  ...,  4.1762e-03,
         -3.0603e-03, -1.6400e-03],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  4.8313e-01, -3.3371e-01,  ...,  8.3821e-04,
         -1.2123e-03, -2.2780e-03],
        [-4.8313e-01,  0.0000e+00, -4.6630e-02,  ..., -7.5195e-06,
          3.3512e-03,  1.9735e-03],
        [ 3.3371e-01,  4.6630e-02,  0.0000e+00,  ...,  3.7386e-03,
         -2.7258e-03, -1.5180e-03],
        ...,
        [-8.3821e-04,  7.5195e-06, -3.7386e-03,  ...,  0.0000e+00,
         -4.5201e-02,  1.8322e-02],
        [ 1.2123e-03, -3.3512e-03,  2.7258e-03,  ...,  4.5201e-02,
          0.0000e+00, -1.7863e-01],
        [ 2.2780e-03, -1.9735e-03,  1.5180e-03,  ..., -1.8322e-02,
          1.7863e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  4.8331e-01, -3.3311e-01,  ...,  8.4039e-04,
         -1.2026e-03, -2.2594e-03],
        [-4.8331e-01,  0.0000e+00, -4.6120e-02,  ...,  4.8828e-08,
          3.3158e-03,  1.9553e-03],
        [ 3.3311e-01,  4.6120e-02,  0.0000e+00,  ...,  3.7051e-03,
         -2.7000e-03, -1.5083e-03],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  4.8566e-01, -3.2504e-01,  ...,  8.6574e-04,
         -1.0875e-03, -2.0197e-03],
        [-4.8566e-01,  0.0000e+00, -3.9127e-02,  ...,  1.1104e-04,
          2.8612e-03,  1.7206e-03],
        [ 3.2504e-01,  3.9127e-02,  0.0000e+00,  ...,  3.2587e-03,
         -2.3631e-03, -1.3795e-03],
        ...,
        [-8.6574e-04, -1.1104e-04, -3.2587e-03,  ...,  0.0000e+00,
         -3.8441e-02,  2.5518e-02],
        [ 1.0875e-03, -2.8612e-03,  2.3631e-03,  ...,  3.8441e-02,
          0.0000e+00, -1.8160e-01],
        [ 2.0197e-03, -1.7206e-03,  1.3795e-03,  ..., -2.5518e-02,
          1.8160e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  4.8582e-01, -3.2448e-01,  ...,  8.6674e-04,
         -1.0796e-03, -2.0039e-03],
        [-4.8582e-01,  0.0000e+00, -3.8637e-02,  ...,  1.1826e-04,
          2.8309e-03,  1.7044e-03],
        [ 3.2448e-01,  3.8637e-02,  0.0000e+00,  ...,  3.2298e-03,
         -2.3407e-03, -1.3709e-03],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  4.8792e-01, -3.1737e-01,  ...,  8.8497e-04,
         -9.8552e-04, -1.8099e-03],
        [-4.8792e-01,  0.0000e+00, -3.2384e-02,  ...,  2.0464e-04,
          2.4703e-03,  1.5183e-03],
        [ 3.1737e-01,  3.2384e-02,  0.0000e+00,  ...,  2.8688e-03,
         -2.0696e-03, -1.2621e-03],
        ...,
        [-8.8497e-04, -2.0464e-04, -2.8688e-03,  ...,  0.0000e+00,
         -3.2798e-02,  3.1763e-02],
        [ 9.8552e-04, -2.4703e-03,  2.0696e-03,  ...,  3.2798e-02,
          0.0000e+00, -1.8418e-01],
        [ 1.8099e-03, -1.5183e-03,  1.2621e-03,  ..., -3.1763e-02,
          1.8418e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  4.8807e-01, -3.1684e-01,  ...,  8.8468e-04,
         -9.7859e-04, -1.7952e-03],
        [-4.8807e-01,  0.0000e+00, -3.1912e-02,  ...,  2.1084e-04,
          2.4449e-03,  1.5041e-03],
        [ 3.1684e-01,  3.1912e-02,  0.0000e+00,  ...,  2.8425e-03,
         -2.0513e-03, -1.2541e-03],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  4.9009e-01, -3.1007e-01,  ...,  8.9822e-04,
         -8.9541e-04, -1.6259e-03],
        [-4.9009e-01,  0.0000e+00, -2.5895e-02,  ...,  2.8213e-04,
          2.1361e-03,  1.3421e-03],
        [ 3.1007e-01,  2.5895e-02,  0.0000e+00,  ...,  2.5278e-03,
         -1.8165e-03, -1.1572e-03],
        ...,
        [-8.9822e-04, -2.8213e-04, -2.5278e-03,  ...,  0.0000e+00,
         -2.7751e-02,  3.7583e-02],
        [ 8.9541e-04, -2.1361e-03,  1.8165e-03,  ...,  2.7751e-02,
          0.0000e+00, -1.8658e-01],
        [ 1.6259e-03, -1.3421e-03,  1.1572e-03,  ..., -3.7583e-02,
          1.8658e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  4.9025e-01, -3.0956e-01,  ...,  8.9971e-04,
         -8.8916e-04, -1.6134e-03],
        [-4.9025e-01,  0.0000e+00, -2.5441e-02,  ...,  2.8755e-04,
          2.1147e-03,  1.3302e-03],
        [ 3.0956e-01,  2.5441e-02,  0.0000e+00,  ...,  2.5047e-03,
         -1.7984e-03, -1.1491e-03],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  4.9234e-01, -3.0263e-01,  ...,  9.0826e-04,
         -8.1135e-04, -1.4538e-03],
        [-4.9234e-01,  0.0000e+00, -1.9227e-02,  ...,  3.5029e-04,
          1.8311e-03,  1.1798e-03],
        [ 3.0263e-01,  1.9227e-02,  0.0000e+00,  ...,  2.2086e-03,
         -1.5814e-03, -1.0557e-03],
        ...,
        [-9.0826e-04, -3.5029e-04, -2.2086e-03,  ...,  0.0000e+00,
         -2.2951e-02,  4.3379e-02],
        [ 8.1135e-04, -1.8311e-03,  1.5814e-03,  ...,  2.2951e-02,
          0.0000e+00, -1.8897e-01],
        [ 1.4538e-03, -1.1798e-03,  1.0557e-03,  ..., -4.3379e-02,
          1.8897e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Iteration = 770, Losses: rotation = 44027.4140625 
Gamma:  tensor([[ 0.0000e+00,  4.9249e-01, -3.0215e-01,  ...,  9.0927e-04,
         -8.0657e-04, -1.4423e-03],
        [-4.9249e-01,  0.0000e+00, -1.8794e-02,  ...,  3.5542e-04,
          1.8115e-03,  1.1690e-03],
        [ 3.0215e-01,  1.8794e-02,  0.0000e+00,  ...,  2.1897e-03,
    

Gamma:  tensor([[ 0.0000e+00,  4.9450e-01, -2.9556e-01,  ...,  9.1334e-04,
         -7.3691e-04, -1.3023e-03],
        [-4.9450e-01,  0.0000e+00, -1.2852e-02,  ...,  4.0591e-04,
          1.5703e-03,  1.0399e-03],
        [ 2.9556e-01,  1.2852e-02,  0.0000e+00,  ...,  1.9337e-03,
         -1.3794e-03, -9.6528e-04],
        ...,
        [-9.1334e-04, -4.0591e-04, -1.9337e-03,  ...,  0.0000e+00,
         -1.8732e-02,  4.8753e-02],
        [ 7.3691e-04, -1.5703e-03,  1.3794e-03,  ...,  1.8732e-02,
          0.0000e+00, -1.9118e-01],
        [ 1.3023e-03, -1.0399e-03,  9.6528e-04,  ..., -4.8753e-02,
          1.9118e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  4.9464e-01, -2.9510e-01,  ...,  9.1413e-04,
         -7.3191e-04, -1.2925e-03],
        [-4.9464e-01,  0.0000e+00, -1.2437e-02,  ...,  4.0928e-04,
          1.5547e-03,  1.0324e-03],
        [ 2.9510e-01,  1.2437e-02,  0.0000e+00,  ...,  1.9163e-03,
         -1.3677e-03, -9.5898e-04],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  4.9658e-01, -2.8882e-01,  ...,  9.1478e-04,
         -6.7063e-04, -1.1703e-03],
        [-4.9658e-01,  0.0000e+00, -6.7682e-03,  ...,  4.5049e-04,
          1.3512e-03,  9.2024e-04],
        [ 2.8882e-01,  6.7682e-03,  0.0000e+00,  ...,  1.6950e-03,
         -1.2068e-03, -8.8191e-04],
        ...,
        [-9.1478e-04, -4.5049e-04, -1.6950e-03,  ...,  0.0000e+00,
         -1.5034e-02,  5.3737e-02],
        [ 6.7063e-04, -1.3512e-03,  1.2068e-03,  ...,  1.5034e-02,
          0.0000e+00, -1.9322e-01],
        [ 1.1703e-03, -9.2024e-04,  8.8191e-04,  ..., -5.3737e-02,
          1.9322e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Iteration = 800, Losses: rotation = 44027.2265625 
Gamma:  tensor([[ 0.0000e+00,  4.9672e-01, -2.8838e-01,  ...,  9.1520e-04,
         -6.6736e-04, -1.1623e-03],
        [-4.9672e-01,  0.0000e+00, -6.3732e-03,  ...,  4.5293e-04,
          1.3375e-03,  9.1321e-04],
        [ 2.8838e-01,  6.3732e-03,  0.0000e+00,  ...,  1.6796e-03,
    

Gamma:  tensor([[ 0.0000e+00,  4.9858e-01, -2.8239e-01,  ...,  9.1212e-04,
         -6.1333e-04, -1.0547e-03],
        [-4.9858e-01,  0.0000e+00, -9.7084e-04,  ...,  4.8384e-04,
          1.1661e-03,  8.1882e-04],
        [ 2.8239e-01,  9.7084e-04,  0.0000e+00,  ...,  1.4888e-03,
         -1.0566e-03, -8.0955e-04],
        ...,
        [-9.1212e-04, -4.8384e-04, -1.4888e-03,  ...,  0.0000e+00,
         -1.1814e-02,  5.8360e-02],
        [ 6.1333e-04, -1.1661e-03,  1.0566e-03,  ...,  1.1814e-02,
          0.0000e+00, -1.9512e-01],
        [ 1.0547e-03, -8.1882e-04,  8.0955e-04,  ..., -5.8360e-02,
          1.9512e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  4.9871e-01, -2.8197e-01,  ...,  9.1219e-04,
         -6.0820e-04, -1.0474e-03],
        [-4.9871e-01,  0.0000e+00, -5.9528e-04,  ...,  4.8535e-04,
          1.1537e-03,  8.1223e-04],
        [ 2.8197e-01,  5.9528e-04,  0.0000e+00,  ...,  1.4763e-03,
         -1.0479e-03, -8.0413e-04],
        ...,
      

Gamma:  tensor([[ 0.0000,  0.5005, -0.2762,  ...,  0.0009, -0.0006, -0.0010],
        [-0.5005,  0.0000,  0.0045,  ...,  0.0005,  0.0010,  0.0007],
        [ 0.2762, -0.0045,  0.0000,  ...,  0.0013, -0.0009, -0.0007],
        ...,
        [-0.0009, -0.0005, -0.0013,  ...,  0.0000, -0.0090,  0.0627],
        [ 0.0006, -0.0010,  0.0009,  ...,  0.0090,  0.0000, -0.1969],
        [ 0.0010, -0.0007,  0.0007,  ..., -0.0627,  0.1969,  0.0000]],
       grad_fn=<MulBackward0>)
Iteration = 830, Losses: rotation = 44027.0625 
Gamma:  tensor([[ 0.0000,  0.5006, -0.2758,  ...,  0.0009, -0.0006, -0.0009],
        [-0.5006,  0.0000,  0.0049,  ...,  0.0005,  0.0010,  0.0007],
        [ 0.2758, -0.0049,  0.0000,  ...,  0.0013, -0.0009, -0.0007],
        ...,
        [-0.0009, -0.0005, -0.0013,  ...,  0.0000, -0.0088,  0.0629],
        [ 0.0006, -0.0010,  0.0009,  ...,  0.0088,  0.0000, -0.1970],
        [ 0.0009, -0.0007,  0.0007,  ..., -0.0629,  0.1970,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:

Gamma:  tensor([[ 0.0000,  0.5027, -0.2692,  ...,  0.0009, -0.0005, -0.0008],
        [-0.5027,  0.0000,  0.0108,  ...,  0.0005,  0.0008,  0.0006],
        [ 0.2692, -0.0108,  0.0000,  ...,  0.0011, -0.0008, -0.0007],
        ...,
        [-0.0009, -0.0005, -0.0011,  ...,  0.0000, -0.0062,  0.0674],
        [ 0.0005, -0.0008,  0.0008,  ...,  0.0062,  0.0000, -0.1988],
        [ 0.0008, -0.0006,  0.0007,  ..., -0.0674,  0.1988,  0.0000]],
       grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  5.0281e-01, -2.6882e-01,  ...,  8.9624e-04,
         -5.0266e-04, -8.4221e-04],
        [-5.0281e-01,  0.0000e+00,  1.1133e-02,  ...,  5.3008e-04,
          8.4229e-04,  6.3667e-04],
        [ 2.6882e-01, -1.1133e-02,  0.0000e+00,  ...,  1.1171e-03,
         -7.9022e-04, -6.6792e-04],
        ...,
        [-8.9624e-04, -5.3008e-04, -1.1171e-03,  ...,  0.0000e+00,
         -6.0295e-03,  6.7655e-02],
        [ 5.0266e-04, -8.4229e-04,  7.9022e-04,  ...,  6.0295e-03,
          0.0000e+00, -1.98

Gamma:  tensor([[ 0.0000e+00,  5.0421e-01, -2.6433e-01,  ...,  8.8774e-04,
         -4.6917e-04, -7.7988e-04],
        [-5.0421e-01,  0.0000e+00,  1.5071e-02,  ...,  5.3804e-04,
          7.5298e-04,  5.8728e-04],
        [ 2.6433e-01, -1.5071e-02,  0.0000e+00,  ...,  1.0138e-03,
         -7.1586e-04, -6.2515e-04],
        ...,
        [-8.8774e-04, -5.3804e-04, -1.0138e-03,  ...,  0.0000e+00,
         -4.4201e-03,  7.0586e-02],
        [ 4.6917e-04, -7.5298e-04,  7.1586e-04,  ...,  4.4201e-03,
          0.0000e+00, -2.0012e-01],
        [ 7.7988e-04, -5.8728e-04,  6.2515e-04,  ..., -7.0586e-02,
          2.0012e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  5.0432e-01, -2.6396e-01,  ...,  8.8695e-04,
         -4.6558e-04, -7.7606e-04],
        [-5.0432e-01,  0.0000e+00,  1.5391e-02,  ...,  5.3931e-04,
          7.4678e-04,  5.8206e-04],
        [ 2.6396e-01, -1.5391e-02,  0.0000e+00,  ...,  1.0042e-03,
         -7.1014e-04, -6.2244e-04],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  5.0577e-01, -2.5927e-01,  ...,  8.7468e-04,
         -4.3267e-04, -7.1553e-04],
        [-5.0577e-01,  0.0000e+00,  1.9455e-02,  ...,  5.4355e-04,
          6.6418e-04,  5.3445e-04],
        [ 2.5927e-01, -1.9455e-02,  0.0000e+00,  ...,  9.0469e-04,
         -6.3851e-04, -5.8071e-04],
        ...,
        [-8.7468e-04, -5.4355e-04, -9.0469e-04,  ...,  0.0000e+00,
         -2.7864e-03,  7.3799e-02],
        [ 4.3267e-04, -6.6418e-04,  6.3851e-04,  ...,  2.7864e-03,
          0.0000e+00, -2.0142e-01],
        [ 7.1553e-04, -5.3445e-04,  5.8071e-04,  ..., -7.3799e-02,
          2.0142e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  5.0588e-01, -2.5891e-01,  ...,  8.7333e-04,
         -4.3040e-04, -7.1114e-04],
        [-5.0588e-01,  0.0000e+00,  1.9760e-02,  ...,  5.4399e-04,
          6.5767e-04,  5.3167e-04],
        [ 2.5891e-01, -1.9760e-02,  0.0000e+00,  ...,  8.9661e-04,
         -6.3308e-04, -5.7695e-04],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  5.0736e-01, -2.5404e-01,  ...,  8.5942e-04,
         -3.9790e-04, -6.5391e-04],
        [-5.0736e-01,  0.0000e+00,  2.3905e-02,  ...,  5.4536e-04,
          5.8099e-04,  4.8652e-04],
        [ 2.5404e-01, -2.3905e-02,  0.0000e+00,  ...,  8.0405e-04,
         -5.6735e-04, -5.3660e-04],
        ...,
        [-8.5942e-04, -5.4536e-04, -8.0405e-04,  ...,  0.0000e+00,
         -1.2968e-03,  7.7015e-02],
        [ 3.9790e-04, -5.8099e-04,  5.6735e-04,  ...,  1.2968e-03,
          0.0000e+00, -2.0273e-01],
        [ 6.5391e-04, -4.8652e-04,  5.3660e-04,  ..., -7.7015e-02,
          2.0273e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Iteration = 890, Losses: rotation = 44026.77734375 
Gamma:  tensor([[ 0.0000e+00,  5.0747e-01, -2.5370e-01,  ...,  8.5788e-04,
         -3.9673e-04, -6.5011e-04],
        [-5.0747e-01,  0.0000e+00,  2.4193e-02,  ...,  5.4590e-04,
          5.7660e-04,  4.8423e-04],
        [ 2.5370e-01, -2.4193e-02,  0.0000e+00,  ...,  7.9727e-04,
   

Gamma:  tensor([[ 0.0000e+00,  5.0867e-01, -2.4966e-01,  ...,  8.4567e-04,
         -3.7134e-04, -6.0505e-04],
        [-5.0867e-01,  0.0000e+00,  2.7560e-02,  ...,  5.4370e-04,
          5.1977e-04,  4.5044e-04],
        [ 2.4966e-01, -2.7560e-02,  0.0000e+00,  ...,  7.2668e-04,
         -5.1266e-04, -5.0188e-04],
        ...,
        [-8.4567e-04, -5.4370e-04, -7.2668e-04,  ...,  0.0000e+00,
         -1.9962e-04,  7.9628e-02],
        [ 3.7134e-04, -5.1977e-04,  5.1266e-04,  ...,  1.9962e-04,
          0.0000e+00, -2.0378e-01],
        [ 6.0505e-04, -4.5044e-04,  5.0188e-04,  ..., -7.9628e-02,
          2.0378e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  5.0877e-01, -2.4933e-01,  ...,  8.4535e-04,
         -3.6919e-04, -6.0151e-04],
        [-5.0877e-01,  0.0000e+00,  2.7834e-02,  ...,  5.4316e-04,
          5.1661e-04,  4.4749e-04],
        [ 2.4933e-01, -2.7834e-02,  0.0000e+00,  ...,  7.2249e-04,
         -5.0919e-04, -4.9919e-04],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  5.0992e-01, -2.4542e-01,  ...,  8.3051e-04,
         -3.4585e-04, -5.6193e-04],
        [-5.0992e-01,  0.0000e+00,  3.1033e-02,  ...,  5.3867e-04,
          4.6841e-04,  4.1846e-04],
        [ 2.4542e-01, -3.1033e-02,  0.0000e+00,  ...,  6.5977e-04,
         -4.6389e-04, -4.7051e-04],
        ...,
        [-8.3051e-04, -5.3867e-04, -6.5977e-04,  ...,  0.0000e+00,
          7.3705e-04,  8.2091e-02],
        [ 3.4585e-04, -4.6841e-04,  4.6389e-04,  ..., -7.3705e-04,
          0.0000e+00, -2.0476e-01],
        [ 5.6193e-04, -4.1846e-04,  4.7051e-04,  ..., -8.2091e-02,
          2.0476e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  5.1001e-01, -2.4510e-01,  ...,  8.2865e-04,
         -3.4338e-04, -5.5937e-04],
        [-5.1001e-01,  0.0000e+00,  3.1294e-02,  ...,  5.3960e-04,
          4.6410e-04,  4.1643e-04],
        [ 2.4510e-01, -3.1294e-02,  0.0000e+00,  ...,  6.5474e-04,
         -4.6068e-04, -4.6848e-04],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  5.1119e-01, -2.4098e-01,  ...,  8.1302e-04,
         -3.2065e-04, -5.1959e-04],
        [-5.1119e-01,  0.0000e+00,  3.4578e-02,  ...,  5.3188e-04,
          4.1853e-04,  3.8938e-04],
        [ 2.4098e-01, -3.4578e-02,  0.0000e+00,  ...,  5.9470e-04,
         -4.1765e-04, -4.3882e-04],
        ...,
        [-8.1302e-04, -5.3188e-04, -5.9470e-04,  ...,  0.0000e+00,
          1.5855e-03,  8.4590e-02],
        [ 3.2065e-04, -4.1853e-04,  4.1765e-04,  ..., -1.5855e-03,
          0.0000e+00, -2.0575e-01],
        [ 5.1959e-04, -3.8938e-04,  4.3882e-04,  ..., -8.4590e-02,
          2.0575e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Iteration = 930, Losses: rotation = 44026.6171875 
Gamma:  tensor([[ 0.0000e+00,  5.1127e-01, -2.4067e-01,  ...,  8.1206e-04,
         -3.1899e-04, -5.1729e-04],
        [-5.1127e-01,  0.0000e+00,  3.4824e-02,  ...,  5.3213e-04,
          4.1541e-04,  3.8738e-04],
        [ 2.4067e-01, -3.4824e-02,  0.0000e+00,  ...,  5.8999e-04,
    

Gamma:  tensor([[ 0.0000e+00,  5.1229e-01, -2.3697e-01,  ...,  7.9485e-04,
         -2.9863e-04, -4.8352e-04],
        [-5.1229e-01,  0.0000e+00,  3.7693e-02,  ...,  5.2429e-04,
          3.7817e-04,  3.6431e-04],
        [ 2.3697e-01, -3.7693e-02,  0.0000e+00,  ...,  5.4229e-04,
         -3.8038e-04, -4.1096e-04],
        ...,
        [-7.9485e-04, -5.2429e-04, -5.4229e-04,  ...,  0.0000e+00,
          2.2419e-03,  8.6774e-02],
        [ 2.9863e-04, -3.7817e-04,  3.8038e-04,  ..., -2.2419e-03,
          0.0000e+00, -2.0661e-01],
        [ 4.8352e-04, -3.6431e-04,  4.1096e-04,  ..., -8.6774e-02,
          2.0661e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  5.1237e-01, -2.3667e-01,  ...,  7.9379e-04,
         -2.9746e-04, -4.8134e-04],
        [-5.1237e-01,  0.0000e+00,  3.7925e-02,  ...,  5.2463e-04,
          3.7630e-04,  3.6299e-04],
        [ 2.3667e-01, -3.7925e-02,  0.0000e+00,  ...,  5.3718e-04,
         -3.7717e-04, -4.0967e-04],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  5.1333e-01, -2.3308e-01,  ...,  7.7716e-04,
         -2.7861e-04, -4.5079e-04],
        [-5.1333e-01,  0.0000e+00,  4.0641e-02,  ...,  5.1431e-04,
          3.4429e-04,  3.4248e-04],
        [ 2.3308e-01, -4.0641e-02,  0.0000e+00,  ...,  4.9500e-04,
         -3.4645e-04, -3.8640e-04],
        ...,
        [-7.7716e-04, -5.1431e-04, -4.9500e-04,  ...,  0.0000e+00,
          2.7865e-03,  8.8840e-02],
        [ 2.7861e-04, -3.4429e-04,  3.4645e-04,  ..., -2.7865e-03,
          0.0000e+00, -2.0742e-01],
        [ 4.5079e-04, -3.4248e-04,  3.8640e-04,  ..., -8.8840e-02,
          2.0742e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  5.1340e-01, -2.3278e-01,  ...,  7.7617e-04,
         -2.7627e-04, -4.4940e-04],
        [-5.1340e-01,  0.0000e+00,  4.0862e-02,  ...,  5.1382e-04,
          3.4119e-04,  3.4050e-04],
        [ 2.3278e-01, -4.0862e-02,  0.0000e+00,  ...,  4.9102e-04,
         -3.4413e-04, -3.8508e-04],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  5.1436e-01, -2.2899e-01,  ...,  7.5754e-04,
         -2.5781e-04, -4.1898e-04],
        [-5.1436e-01,  0.0000e+00,  4.3636e-02,  ...,  5.0354e-04,
          3.1185e-04,  3.2292e-04],
        [ 2.2899e-01, -4.3636e-02,  0.0000e+00,  ...,  4.4844e-04,
         -3.1394e-04, -3.6165e-04],
        ...,
        [-7.5754e-04, -5.0354e-04, -4.4844e-04,  ...,  0.0000e+00,
          3.2611e-03,  9.0944e-02],
        [ 2.5781e-04, -3.1185e-04,  3.1394e-04,  ..., -3.2611e-03,
          0.0000e+00, -2.0823e-01],
        [ 4.1898e-04, -3.2292e-04,  3.6165e-04,  ..., -9.0944e-02,
          2.0823e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Iteration = 970, Losses: rotation = 44026.48046875 
Gamma:  tensor([[ 0.0000e+00,  5.1443e-01, -2.2870e-01,  ...,  7.5618e-04,
         -2.5620e-04, -4.1677e-04],
        [-5.1443e-01,  0.0000e+00,  4.3842e-02,  ...,  5.0142e-04,
          3.0935e-04,  3.2087e-04],
        [ 2.2870e-01, -4.3842e-02,  0.0000e+00,  ...,  4.4631e-04,
   

Gamma:  tensor([[ 0.0000e+00,  5.1532e-01, -2.2501e-01,  ...,  7.3672e-04,
         -2.3804e-04, -3.8895e-04],
        [-5.1532e-01,  0.0000e+00,  4.6451e-02,  ...,  4.9053e-04,
          2.8498e-04,  3.0396e-04],
        [ 2.2501e-01, -4.6451e-02,  0.0000e+00,  ...,  4.0994e-04,
         -2.8552e-04, -3.3804e-04],
        ...,
        [-7.3672e-04, -4.9053e-04, -4.0994e-04,  ...,  0.0000e+00,
          3.6340e-03,  9.2927e-02],
        [ 2.3804e-04, -2.8498e-04,  2.8552e-04,  ..., -3.6340e-03,
          0.0000e+00, -2.0898e-01],
        [ 3.8895e-04, -3.0396e-04,  3.3804e-04,  ..., -9.2927e-02,
          2.0898e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  5.1538e-01, -2.2472e-01,  ...,  7.3523e-04,
         -2.3716e-04, -3.8650e-04],
        [-5.1538e-01,  0.0000e+00,  4.6644e-02,  ...,  4.9009e-04,
          2.8157e-04,  3.0217e-04],
        [ 2.2472e-01, -4.6644e-02,  0.0000e+00,  ...,  4.0649e-04,
         -2.8342e-04, -3.3857e-04],
        ...,
      

Gamma:  tensor([[ 0.0000e+00,  5.1618e-01, -2.2112e-01,  ...,  7.1685e-04,
         -2.2109e-04, -3.6249e-04],
        [-5.1618e-01,  0.0000e+00,  4.9089e-02,  ...,  4.7698e-04,
          2.6014e-04,  2.8652e-04],
        [ 2.2112e-01, -4.9089e-02,  0.0000e+00,  ...,  3.7209e-04,
         -2.6082e-04, -3.1763e-04],
        ...,
        [-7.1685e-04, -4.7698e-04, -3.7209e-04,  ...,  0.0000e+00,
          3.9196e-03,  9.4801e-02],
        [ 2.2109e-04, -2.6014e-04,  2.6082e-04,  ..., -3.9196e-03,
          0.0000e+00, -2.0968e-01],
        [ 3.6249e-04, -2.8652e-04,  3.1763e-04,  ..., -9.4801e-02,
          2.0968e-01,  0.0000e+00]], grad_fn=<MulBackward0>)
Gamma:  tensor([[ 0.0000e+00,  5.1624e-01, -2.2085e-01,  ...,  7.1545e-04,
         -2.1963e-04, -3.6140e-04],
        [-5.1624e-01,  0.0000e+00,  4.9271e-02,  ...,  4.7676e-04,
          2.5790e-04,  2.8586e-04],
        [ 2.2085e-01, -4.9271e-02,  0.0000e+00,  ...,  3.7085e-04,
         -2.5837e-04, -3.1606e-04],
        ...,
      

In [11]:
##### GET DATA ####
# Notebook cell: fix the random seeds, choose the problem dimensions, build the
# synthetic PCA dataset, and wrap it in a DataLoader.
import os
import torch
import numpy as np
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

seed=1234
# set random seed
# Seeding happens before the data generators are built: DataGeneratorPCA draws
# its eigenvectors with ortho_group.rvs, so the order of these statements
# determines the generated data.
np.random.seed(seed)
torch.manual_seed(seed)

# Dimensionality of each data point and of the latent (hidden) code.
input_dim = 1000
hidden_dim = 20

n_data = 5000
# The batch size is set to the number of generated samples.
batch_size = n_data

# Bounds of the linearly spaced eigenvalue spectrum used by DataGeneratorPCA
# (largest value is 10% of input_dim, smallest is 1.0).
max_sv = float(input_dim) * 0.1
min_sv = 1.0
# NOTE(review): `sigma` is not referenced again in this cell — confirm whether
# a later cell relies on it.
sigma = 0.5

# First generate the ground-truth data set, then construct a second generator
# from its samples through the `load_data` argument.
gt_data = DataGeneratorPCA(input_dim, hidden_dim, min_sv=min_sv, max_sv=max_sv, total=n_data)
data = DataGeneratorPCA(input_dim, hidden_dim, load_data=gt_data.x_sample)

# shuffle=False keeps the samples in their stored order.
loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False)

#### Define the model ####

# These imports and the device selection repeat the ones at the top of this
# cell; they are kept so this section can also be run on its own.
import os
import torch
import numpy as np
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


#### DEFINE MODEL #####
# Single source of truth for the experiment settings. `train_itr` and `seed`
# are recorded here but are not forwarded to ModelConfig below.
model_dict = {
    'model_name': 'nd_expectation',
    'model_type': 'nested_dropout',
    'model_class': LinearAENestedDropout,
    'extra_model_args': {'use_expectation': True},
    'input_dim': input_dim,
    'hidden_dim': hidden_dim,
    'init_scale': 0.0001,
    'optim_class': torch.optim.Adam,
    'extra_optim_args': {},
    'lr': 0.003,
    'train_itr': 3000,  # alternative value left by the author: 50000
    'seed': seed,
}

# ModelConfig owns the model and its optimizer; hand it only the settings it
# accepts, looked up by name from model_dict.
model_config = ModelConfig(**{
    key: model_dict[key]
    for key in (
        'model_name',
        'model_type',
        'model_class',
        'input_dim',
        'hidden_dim',
        'init_scale',
        'extra_model_args',
        'optim_class',
        'lr',
        'extra_optim_args',
    )
})

# Echo the configuration, the model architecture and the optimizer settings.
print(model_dict, '\n')
print(model_config.get_model(), '\n')
print(model_config.get_optimizer())

# Metrics of the freshly built model, before any training.
# Transpose: how close encoder and decoder.T are.
print('Transpose:', metric_transpose_theorem(model_config.get_model()), '\n')
# Alignment of decoder columns to ground truth eigenvectors.
print(
    'Distance to axis-aligned solution:',
    metric_alignment(model_config.get_model(), data.eigvectors),
    '\n',
)
print(
    'Distance to optimal subspace):',
    metric_subspace(model_config.get_model(), data.eigvectors, data.eigs),
    '\n',
)





{'model_name': 'nd_expectation', 'model_type': 'nested_dropout', 'model_class': <class '__main__.LinearAENestedDropout'>, 'extra_model_args': {'use_expectation': True}, 'input_dim': 1000, 'hidden_dim': 20, 'init_scale': 0.0001, 'optim_class': <class 'torch.optim.adam.Adam'>, 'extra_optim_args': {}, 'lr': 0.003, 'train_itr': 3000, 'seed': 1234} 

LinearAENestedDropout(
  (encoder): Linear(in_features=1000, out_features=20, bias=False)
  (decoder): Linear(in_features=20, out_features=1000, bias=False)
) 

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    eps: 1e-08
    foreach: None
    lr: 0.003
    maximize: False
    weight_decay: 0
)
Transpose: 2.013526827795431e-05 

Distance to axis-aligned solution: 0.994424752923425 

Distance to optimal subspace): 0.9777534324698326 



In [12]:
def _report_metrics(model, data_loader, dataset):
    """Print the loss and PCA-recovery metrics of `model`.

    The same report is needed before and after training, so it lives in one
    place instead of being copy-pasted. The printed labels are kept verbatim
    (including their original spelling) so the output stays comparable with
    earlier recorded runs of this notebook.

    Args:
        model: the autoencoder to evaluate.
        data_loader: loader handed to the loss metrics.
        dataset: object exposing `eigvectors` and `eigs`, used as the
            reference solution by the alignment and subspace metrics.
    """
    print('Reconstrution Loss:', metric_recon_loss(model, data_loader),'\n') # full batch loss
    print('Loss:', metric_loss(model, data_loader),'\n')
    print('Transpose:', metric_transpose_theorem(model),'\n') # how close encoder and decoder.T are 
    print('Distance to axis-aligned solution:', metric_alignment(model, dataset.eigvectors),'\n') # alignment of decoder columns to ground truth eigenvectors
    print('Distance to optimal subspace):', metric_subspace(model, dataset.eigvectors, dataset.eigs),'\n')


# Baseline: metrics of the model before training.
modelIn = model_config.get_model()
_report_metrics(modelIn, loader, data)

# Train for the number of iterations recorded in model_dict.
trained_model = train_models(data_loader=loader, train_itr=model_dict['train_itr'], metrics_dict=None, model_configs=model_config)

# Fetch the model from the config again and repeat the report after training.
modelIn = model_config.get_model()
_report_metrics(modelIn, loader, data)



Reconstrution Loss: 50555.9375 

Loss: 50544.8828125 

Transpose: 2.013526827795431e-05 

Distance to axis-aligned solution: 0.994424752923425 

Distance to optimal subspace): 0.9777534324698326 

Iteration = 1, Losses: nd_expectation = 50544.8828125 
Iteration = 10, Losses: nd_expectation = 50005.375 
Iteration = 20, Losses: nd_expectation = 49659.06640625 
Iteration = 30, Losses: nd_expectation = 49504.6328125 
Iteration = 40, Losses: nd_expectation = 49428.1875 
Iteration = 50, Losses: nd_expectation = 49391.37109375 
Iteration = 60, Losses: nd_expectation = 49371.125 
Iteration = 70, Losses: nd_expectation = 49358.13671875 
Iteration = 80, Losses: nd_expectation = 49350.11328125 
Iteration = 90, Losses: nd_expectation = 49344.70703125 
Iteration = 100, Losses: nd_expectation = 49340.89453125 
Iteration = 110, Losses: nd_expectation = 49338.078125 
Iteration = 120, Losses: nd_expectation = 49335.91796875 
Iteration = 130, Losses: nd_expectation = 49334.1875 
Iteration = 140, Losses:

Iteration = 1400, Losses: nd_expectation = 49319.25390625 
Iteration = 1410, Losses: nd_expectation = 49319.24609375 
Iteration = 1420, Losses: nd_expectation = 49319.24609375 
Iteration = 1430, Losses: nd_expectation = 49319.23828125 
Iteration = 1440, Losses: nd_expectation = 49319.234375 
Iteration = 1450, Losses: nd_expectation = 49319.23046875 
Iteration = 1460, Losses: nd_expectation = 49319.22265625 
Iteration = 1470, Losses: nd_expectation = 49319.21875 
Iteration = 1480, Losses: nd_expectation = 49319.21484375 
Iteration = 1490, Losses: nd_expectation = 49319.2109375 
Iteration = 1500, Losses: nd_expectation = 49319.2109375 
Iteration = 1510, Losses: nd_expectation = 49319.203125 
Iteration = 1520, Losses: nd_expectation = 49319.19921875 
Iteration = 1530, Losses: nd_expectation = 49319.1953125 
Iteration = 1540, Losses: nd_expectation = 49319.19140625 
Iteration = 1550, Losses: nd_expectation = 49319.1875 
Iteration = 1560, Losses: nd_expectation = 49319.1875 
Iteration = 157

Iteration = 2820, Losses: nd_expectation = 49319.02734375 
Iteration = 2830, Losses: nd_expectation = 49319.02734375 
Iteration = 2840, Losses: nd_expectation = 49319.02734375 
Iteration = 2850, Losses: nd_expectation = 49319.02734375 
Iteration = 2860, Losses: nd_expectation = 49319.0234375 
Iteration = 2870, Losses: nd_expectation = 49319.0234375 
Iteration = 2880, Losses: nd_expectation = 49319.02734375 
Iteration = 2890, Losses: nd_expectation = 49319.0234375 
Iteration = 2900, Losses: nd_expectation = 49319.02734375 
Iteration = 2910, Losses: nd_expectation = 49319.0234375 
Iteration = 2920, Losses: nd_expectation = 49319.0234375 
Iteration = 2930, Losses: nd_expectation = 49319.01953125 
Iteration = 2940, Losses: nd_expectation = 49319.0234375 
Iteration = 2950, Losses: nd_expectation = 49319.01953125 
Iteration = 2960, Losses: nd_expectation = 49319.0234375 
Iteration = 2970, Losses: nd_expectation = 49319.0234375 
Iteration = 2980, Losses: nd_expectation = 49319.0234375 
Iterat