#### DEFINE DATA LOADER FUNCTIONS

In [1]:
#### DEFINE DATA LOADER FUNCTIONS ####

from torch.utils.data import Dataset

import numpy as np
from scipy.stats import ortho_group

class DataGeneratorPPCA(Dataset):

    def __init__(self, dims, hdims, min_sv=0.11, max_sv=5.0, sigma_sq=0.1, deterministic=True, total=10000):
        self.dims = dims
        self.hdims = hdims

        self.eigs = min_sv + (max_sv - min_sv) * np.linspace(0, 1, hdims)
        self.eigvectors = ortho_group.rvs(dims)[:, :hdims]
        self.w = np.matmul(self.eigvectors, np.diag(np.sqrt(self.eigs - sigma_sq)))

        self.sigma_sq = sigma_sq
        self.sigma = np.sqrt(sigma_sq)

        self.total = total
        self.deterministic = deterministic
        if self.deterministic:
            self.z_sample = np.random.normal(size=(total, self.hdims))
            self.x_sample = np.random.normal(np.matmul(self.z_sample, self.w.T), self.sigma).astype(np.float32)

    def __getitem__(self, i):
        if self.deterministic:
            return self.x_sample[i]
        else:
            z_sample = np.random.normal(size=self.hdims)
            return np.random.normal(self.w.dot(z_sample), self.sigma).astype(np.float32)

    def __len__(self):
        # Return a large number for an epoch
        return self.total


class DataGeneratorPCA(Dataset):
    def __init__(self, dims, hdims, min_sv=0.11, max_sv=5.0, total=10000, sv_list=None,
                 load_data=None):
        self.dims = dims
        self.hdims = hdims

        if load_data is None:
            if isinstance(sv_list, list):
                assert len(sv_list) == dims
                self.full_eigs = np.array(sorted(sv_list, reverse=True))
            else:
                self.full_eigs = min_sv + (max_sv - min_sv) * np.linspace(1, 0, dims)
            self.eigs = self.full_eigs[:hdims]

            self.full_svs = np.sqrt(self.full_eigs)

            self.full_eigvectors = ortho_group.rvs(dims)
            self.eigvectors = self.full_eigvectors[:, :hdims]

            self.total = total

            self.full_z_sample = np.random.normal(size=(total, self.dims))
            self.x_sample = (self.full_eigvectors @ np.diag(self.full_svs) @ self.full_z_sample.T).T.astype(np.float32)

        else:
            self.x_sample = load_data
            u, s, vh = np.linalg.svd(self.x_sample.T, full_matrices=False)
            self.eigs = s[:self.hdims]
            self.eigvectors = u[:, :self.hdims]
            self.total = len(self.x_sample)

    def __getitem__(self, i):
        return self.x_sample[i]

    def __len__(self):
        return self.total

    @property
    def shape(self):
        return self.x_sample.shape


#### DEFINE MODEL CLASSES

In [2]:
#### DEFINE MODEL CLASSES ####

import os
import torch
import torch.nn as nn
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class ModelConfig:
    def __init__(self, model_name, model_type, model_class, input_dim, hidden_dim, init_scale, optim_class, lr,
                 extra_model_args={}, extra_optim_args={}):
        self.model_name = model_name
        self.model_type = model_type
        self.model_class = model_class
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.init_scale = init_scale
        self.extra_model_args = extra_model_args

        self.optim_class = optim_class
        self.lr = lr
        self.extra_optim_args = extra_optim_args
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = model_class(input_dim=input_dim, hidden_dim=hidden_dim, init_scale=init_scale, **extra_model_args).to(device)

        self.optimizer = optim_class(self.model.parameters(), lr=lr, **extra_optim_args)

    @property
    def name(self):
        return self.model_name

    @property
    def type(self):
        return self.model_type

    def get_model(self):
        return self.model

    def get_optimizer(self):
        return self.optimizer

class LinearAE(nn.Module):
    def __init__(self,
                 input_dim, hidden_dim, init_scale=0.001,
                 weight_reg_type=None, l2_reg_list=None):
        super(LinearAE, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.encoder = nn.Linear(input_dim, hidden_dim, bias=False)
        self.decoder = nn.Linear(hidden_dim, input_dim, bias=False)

        self.weight_reg_type = weight_reg_type
        self.l2_reg_scalar = None
        self.l2_reg_list = l2_reg_list

        self.encoder.weight.data.normal_(0.0, init_scale)
        self.decoder.weight.data.normal_(0.0, init_scale)

        # configure regularization parameters

        assert self.weight_reg_type is None or isinstance(self.l2_reg_list, list), \
            "l2_reg_list must be a list if weight_reg_type is not None"

        assert self.l2_reg_list is None or len(self.l2_reg_list) == hidden_dim, \
            "Length of l2_reg_list must match latent dimension"

        if weight_reg_type in ("uniform_product", "uniform_sum"):
            self.l2_reg_scalar = l2_reg_list[0] ** 2    # more efficient to use scalar than diag_weights

        elif weight_reg_type == "non_uniform_sum":
            self.reg_weights = torch.tensor(
                np.array(self.l2_reg_list).astype(np.float32)
            )
            self.diag_weights = nn.Parameter(torch.diag(self.reg_weights), requires_grad=False)

    def forward(self, x):
        return self.get_reconstruction_loss(x) + self._get_reg_loss()

    def compute_trace_norm(self):
        """
        Computes the trace norm of the autoencoder, as well as decoder and encoder individually
        :return: trace_norm(W2W1), trace_norm(W1), trace_norm(W2)
        """
        return torch.matmul(self.decoder.weight, self.encoder.weight).norm(p='nuc'), \
               self.encoder.weight.norm(p='nuc'), \
               self.decoder.weight.norm(p='nuc'),

    def get_reconstruction_loss(self, x):
        z = self.encoder(x)
        recon = self.decoder(z)

        recon_loss = torch.sum((x - recon) ** 2) / len(x)
        return recon_loss

    def get_reg_weights_np(self):
        if self.weight_reg_type is None:
            return np.zeros(self.hidden_dim)
        return np.array(self.l2_reg_list)

    def _get_reg_loss(self):
        # Standard L2 regularization, applied to W2W1 (product loss)
        if self.weight_reg_type == 'uniform_product':
            return self.l2_reg_scalar * (torch.norm(torch.matmul(self.decoder.weight, self.encoder.weight)) ** 2)

        # Standard L2 regularization for encoder and decoder separately (sum loss)
        elif self.weight_reg_type == 'uniform_sum':
            # regularize both encoder and decoder
            return self.l2_reg_scalar * (torch.norm(self.encoder.weight) ** 2 + torch.norm(self.decoder.weight) ** 2)

        # non-uniform sum
        elif self.weight_reg_type == 'non_uniform_sum':
            return torch.norm(self.diag_weights @ self.encoder.weight) ** 2 \
                   + torch.norm(self.decoder.weight @ self.diag_weights) ** 2

        # Do not apply regularization
        elif self.weight_reg_type is None:
            return 0.0

        else:
            raise ValueError("weight_reg_type should be one of (uniform_product, uniform_sum, non_uniform_sum, None)")
            


#### DEFINE MODEL TRAINING FUNCTION train_models

In [11]:
#### DEFINE MODEL TRAINING FUNCTION train_models ####

import os
import torch
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def train_models(data_loader, train_itr, metrics_dict, model_configs, eval_metrics_list=None):
    # Initialize model
#     print('new')
    model = model_config.get_model()
    optimizer = model_config.get_optimizer()
    
    for train_i in range(train_itr):
        for x in data_loader:
            x_cuda = x.to(device)

            # ---- Optimize ----
            losses = {}

#             model = model_config.get_model()
#             optimizer = model_config.get_optimizer()

            optimizer.zero_grad()

            loss = model(x_cuda)

            loss.backward()
            
            print('Before rotation: ', model.decoder.weight.grad)

            # ROTATION
            y = model.encoder.weight @ x_cuda.T
            yy_t_norm = y @ y.T / float(len(x))
            yy_t_upper = yy_t_norm - yy_t_norm.tril()
            gamma = 0.5 * (yy_t_upper - yy_t_upper.T)
            print('Gamma: ', gamma)
            model.encoder.weight.grad -= gamma @ model.encoder.weight
            model.decoder.weight.grad -= model.decoder.weight @ gamma.T
            
            print('After rotation: ', model.decoder.weight.grad)

            optimizer.step()

            losses[model_config.name] = loss.item()

        # ---- Log statistics ----
        if train_i == 0 or (train_i + 1) % 10 == 0:
            print("".join(["Iteration = {}, Losses: ".format(train_i + 1)]
                          + ["{} = {} ".format(key, val) for key, val in losses.items()]))
            
#     model_config.model = model
    
    return model


In [4]:
# x=next(enumerate(loader))[1]
# x_cuda = x.to(device)
# print(model.encoder.weight.grad)
# y = model.encoder.weight @ x_cuda.T
# yy_t_norm = y @ y.T / float(len(x))
# yy_t_upper = yy_t_norm - yy_t_norm.tril()
# gamma = 0.5 * (yy_t_upper - yy_t_upper.T)
# model.encoder.weight.grad -= gamma @ model.encoder.weight
# print(model.encoder.weight.grad)


#### DEFINE EVALUATION METRICS

In [5]:
import os
import torch
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def get_weight_tensor_from_seq(weight_seq):
    if isinstance(weight_seq, nn.Linear):
        return weight_seq.weight.detach()
    elif isinstance(weight_seq, nn.Sequential):
        weight_tensor = None
        for layer in weight_seq:
            if isinstance(layer, nn.Linear):
                layer_weight = layer.weight.detach()
                if weight_tensor is None:
                    weight_tensor = layer_weight
                else:
                    weight_tensor = layer_weight @ weight_tensor
            elif isinstance(layer, nn.BatchNorm1d):
                bn_weight = layer.weight.detach()

                # ignore bias

                if weight_tensor is None:
                    weight_tensor = torch.diag(bn_weight)
                else:
                    weight_tensor = torch.diag(bn_weight) @ weight_tensor
            else:
                raise ValueError("Layer type {} not supported!".format(type(layer)))
        return weight_tensor


def metric_transpose_theorem(model):
    """
    Metric for how close encoder and decoder.T are
    :param model: LinearAE model
    :return: ||W1 - W2^T||_F^2 / hidden_dim
    """
    encoder_weight = get_weight_tensor_from_seq(model.encoder)
    decoder_weight = get_weight_tensor_from_seq(model.decoder)

    transpose_metric = torch.norm(encoder_weight - decoder_weight.T) ** 2
    return transpose_metric.item() / float(model.hidden_dim)


def metric_alignment(model, gt_eigvectors):
    """
    Metric for alignment of decoder columns to ground truth eigenvectors
    :param model: Linear AE model
    :param gt_eigvectors: ground truth eigenvectors (input_dims,hidden_dims)
    :return: sum_i (1 - max_j (cos(eigvector_i, normalized_decoder column_j)))
    """
    decoder_weight = get_weight_tensor_from_seq(model.decoder)
    decoder_np = decoder_weight.detach().cpu().numpy()

    # normalize columns of gt_eigvectors
    norm_gt_eigvectors = gt_eigvectors / np.linalg.norm(gt_eigvectors, axis=0)
    # normalize columns of decoder
    norm_decoder = decoder_np / (np.linalg.norm(decoder_np, axis=0) + 1e-8)

    total_angles = 0.0
    for eig_i in range(gt_eigvectors.shape[1]):
        eigvector = norm_gt_eigvectors[:, eig_i]
        total_angles += 1. - np.max(np.abs(norm_decoder.T @ eigvector)) ** 2

    return total_angles / float(model.hidden_dim)


def metric_subspace(model, gt_eigvectors, gt_eigs):
    decoder_weight = get_weight_tensor_from_seq(model.decoder)
    decoder_np = decoder_weight.detach().cpu().numpy()

    # k - tr(UU^T WW^T), where W is left singular vector matrix of decoder
    u, s, vh = np.linalg.svd(decoder_np, full_matrices=False)
    return 1 - np.trace(gt_eigvectors @ gt_eigvectors.T @ u @ u.T) / float(model.hidden_dim)


def metric_loss(model, data_loader):
    """
    Measures the full batch loss
    :param model: a linear (variational) AE model
    :param data_loader: full batch data loader. Should be different from the training data loader, if in minibatch mode
    """
    loss = None
    for x in data_loader:
        loss = model(x.to(device)).item()
    return loss


def metric_recon_loss(model, data_loader):
    recon_loss = None
    for x in data_loader:
        recon_loss = model.get_reconstruction_loss(x.to(device)).item()
    return recon_loss


### TRAIN A MODEL

#### Get the data

In [6]:
##### GET DATA ####
import os
import torch
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

seed=1234
# set random seed
np.random.seed(seed)
torch.manual_seed(seed)

input_dim = 1000
hidden_dim = 400

n_data = 5000
batch_size = n_data

max_sv = float(input_dim) * 0.1
min_sv = 1.0
sigma = 0.5

gt_data = DataGeneratorPCA(input_dim, hidden_dim, min_sv=min_sv, max_sv=max_sv, total=n_data)
data = DataGeneratorPCA(input_dim, hidden_dim, load_data=gt_data.x_sample)

loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False)


#### Define the model

In [7]:
#### Define the model ####

import os
import torch
import numpy as np
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#### DEFINE MODEL #####
model_dict = dict(
    model_name='rotation',
    model_type='rotation',
    model_class=LinearAE,
    extra_model_args = {"weight_reg_type": None},
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    init_scale=0.0001,
    optim_class=torch.optim.SGD,
    extra_optim_args={'momentum': 0.9, 'nesterov': True},
    lr=0.0001,
#     optim_class=torch.optim.Adam,
#     extra_optim_args={},
#     lr=0.0003,
    train_itr=3000,#50000,
    seed=seed
)

# model config contains the model 
model_config = ModelConfig(
        model_name=model_dict['model_name'],
        model_type=model_dict['model_type'],
        model_class=model_dict['model_class'],
        input_dim=model_dict['input_dim'], 
        hidden_dim=model_dict['hidden_dim'],
        init_scale=model_dict['init_scale'],
        extra_model_args=model_dict['extra_model_args'],
        optim_class=model_dict['optim_class'],
        lr=model_dict['lr'],
        extra_optim_args=model_dict['extra_optim_args']
    )

print(model_dict,'\n')
print(model_config.get_model(),'\n')
print(model_config.get_optimizer())

print('Transpose:', metric_transpose_theorem(model_config.get_model()),'\n') # how close encoder and decoder.T are 
print('Distance to axis-aligned solution:', metric_alignment(model_config.get_model(), gt_data.eigvectors),'\n') # alignment of decoder columns to ground truth eigenvectors
print('Distance to optimal subspace):', metric_subspace(model_config.get_model(), gt_data.eigvectors, gt_data.eigs),'\n')



{'model_name': 'rotation', 'model_type': 'rotation', 'model_class': <class '__main__.LinearAE'>, 'extra_model_args': {'weight_reg_type': None}, 'input_dim': 1000, 'hidden_dim': 400, 'init_scale': 0.0001, 'optim_class': <class 'torch.optim.sgd.SGD'>, 'extra_optim_args': {'momentum': 0.9, 'nesterov': True}, 'lr': 0.0001, 'train_itr': 3000, 'seed': 1234} 

LinearAE(
  (encoder): Linear(in_features=1000, out_features=400, bias=False)
  (decoder): Linear(in_features=400, out_features=1000, bias=False)
) 

SGD (
Parameter Group 0
    dampening: 0
    foreach: None
    lr: 0.0001
    maximize: False
    momentum: 0.9
    nesterov: True
    weight_decay: 0
)
Transpose: 1.9994410686194896e-05 

Distance to axis-aligned solution: 0.9899812176786509 

Distance to optimal subspace): 0.6012058744201634 



In [12]:
modelIn = model_config.get_model()
print('Reconstrution Loss:', metric_recon_loss(modelIn, loader),'\n') # full batch loss
print('Loss:', metric_loss(modelIn, loader),'\n')
print('Transpose:', metric_transpose_theorem(modelIn),'\n') # how close encoder and decoder.T are 
print('Distance to axis-aligned solution:', metric_alignment(modelIn, gt_data.eigvectors),'\n') # alignment of decoder columns to ground truth eigenvectors
print('Distance to optimal subspace):', metric_subspace(modelIn, gt_data.eigvectors, gt_data.eigs),'\n')

trained_model = train_models(data_loader=loader, train_itr=100, metrics_dict=None, model_configs=model_config)

modelIn = model_config.get_model()
print('Reconstrution Loss:', metric_recon_loss(modelIn, loader),'\n') # full batch loss
print('Loss:', metric_loss(modelIn, loader),'\n')
print('Transpose:', metric_transpose_theorem(modelIn),'\n') # how close encoder and decoder.T are 
print('Distance to axis-aligned solution:', metric_alignment(modelIn, gt_data.eigvectors),'\n') # alignment of decoder columns to ground truth eigenvectors
print('Distance to optimal subspace):', metric_subspace(modelIn, gt_data.eigvectors, gt_data.eigs),'\n')



Reconstrution Loss: 17339.2109375 

Loss: 17339.2109375 

Transpose: 0.0009765464812517167 

Distance to axis-aligned solution: 0.9782801040323659 

Distance to optimal subspace): 0.23624933546645333 

Before rotation:  tensor([[-0.4204, -0.2094, -0.0861,  ...,  0.3865, -0.4836,  0.0648],
        [ 0.2481, -0.0397,  0.4613,  ...,  1.1730,  0.1908,  1.6082],
        [-0.0101,  0.0526, -0.3319,  ..., -0.5726, -0.4435, -1.1281],
        ...,
        [-0.3594, -0.1234,  0.1040,  ...,  1.0016,  0.4299,  1.1393],
        [-0.1519, -0.0150,  0.0352,  ...,  0.8421,  1.2439, -0.8766],
        [-0.1648,  0.1782, -0.3119,  ..., -0.0118,  0.2491,  1.2206]])
Gamma:  tensor([[ 0.0000, -1.0953, -0.2407,  ..., -0.1197,  0.1726,  0.2847],
        [ 1.0953,  0.0000,  0.2409,  ..., -0.1833, -0.1547,  0.2392],
        [ 0.2407, -0.2409,  0.0000,  ..., -0.2412,  0.2063, -0.0773],
        ...,
        [ 0.1197,  0.1833,  0.2412,  ...,  0.0000,  2.0928, -1.0289],
        [-0.1726,  0.1547, -0.2063,  ..., -2.

Before rotation:  tensor([[ 0.0871, -0.2177,  0.0109,  ...,  1.0200, -0.4090, -0.0132],
        [-0.0264,  0.0553,  0.0684,  ...,  1.5591, -0.2567,  1.4992],
        [-0.1773, -0.0898, -0.0547,  ...,  0.3533, -0.0113, -1.6889],
        ...,
        [-0.2237, -0.0222, -0.0228,  ...,  0.1921, -0.1642,  0.8790],
        [-0.0742, -0.0246, -0.1391,  ...,  0.3208,  0.2448, -1.1624],
        [ 0.0590, -0.0020, -0.1557,  ...,  0.8860, -0.5686,  1.5896]])
Gamma:  tensor([[ 0.0000, -1.1080, -0.1975,  ..., -0.0036,  0.1958,  0.1213],
        [ 1.1080,  0.0000,  0.4038,  ..., -0.1301, -0.0430,  0.1525],
        [ 0.1975, -0.4038,  0.0000,  ..., -0.0658,  0.2113,  0.0197],
        ...,
        [ 0.0036,  0.1301,  0.0658,  ...,  0.0000,  1.7876, -1.5745],
        [-0.1958,  0.0430, -0.2113,  ..., -1.7876,  0.0000,  0.2452],
        [-0.1213, -0.1525, -0.0197,  ...,  1.5745, -0.2452,  0.0000]],
       grad_fn=<MulBackward0>)
After rotation:  tensor([[-0.1441, -0.1650,  0.1831,  ...,  1.1459, -0.1930

Before rotation:  tensor([[ 0.1101,  0.0060,  0.0278,  ...,  0.9195, -0.6073,  0.4821],
        [-0.0532,  0.1228, -0.0683,  ...,  1.1505, -0.1193,  0.9147],
        [-0.0471, -0.1141, -0.0453,  ...,  0.7935,  0.1234, -0.9759],
        ...,
        [-0.0371,  0.0728,  0.0028,  ..., -0.3329, -0.2094,  0.1553],
        [ 0.0195,  0.0747, -0.0868,  ..., -0.2560, -0.4807, -0.6456],
        [ 0.0591, -0.0403, -0.0019,  ...,  1.1587, -1.0952,  1.0896]])
Gamma:  tensor([[ 0.0000, -1.1129, -0.1591,  ..., -0.0414,  0.1304,  0.0628],
        [ 1.1129,  0.0000,  0.4957,  ..., -0.0680,  0.0209,  0.0469],
        [ 0.1591, -0.4957,  0.0000,  ..., -0.0403,  0.2314, -0.0133],
        ...,
        [ 0.0414,  0.0680,  0.0403,  ...,  0.0000,  1.1413, -1.2447],
        [-0.1304, -0.0209, -0.2314,  ..., -1.1413,  0.0000, -0.0593],
        [-0.0628, -0.0469,  0.0133,  ...,  1.2447,  0.0593,  0.0000]],
       grad_fn=<MulBackward0>)
After rotation:  tensor([[-1.3622e-01,  5.5424e-02,  1.9800e-01,  ...,  1.1

Before rotation:  tensor([[-0.0402,  0.1090,  0.0287,  ...,  0.4469, -0.5417,  0.4010],
        [-0.0029,  0.0183, -0.0636,  ...,  0.0698, -0.0891,  0.0591],
        [ 0.0408, -0.0852, -0.0415,  ...,  0.8654,  0.2050,  0.2449],
        ...,
        [ 0.0419,  0.0592,  0.0348,  ..., -0.2132, -0.0510, -0.4394],
        [ 0.0838,  0.0183,  0.0084,  ..., -0.5087, -0.6569,  0.2508],
        [-0.0480, -0.0033,  0.0279,  ...,  0.6285, -0.6877, -0.0151]])
Gamma:  tensor([[ 0.0000, -1.0690, -0.1095,  ..., -0.0779,  0.0685,  0.0641],
        [ 1.0690,  0.0000,  0.5850,  ..., -0.0299,  0.0271, -0.0111],
        [ 0.1095, -0.5850,  0.0000,  ..., -0.0434,  0.1903, -0.0602],
        ...,
        [ 0.0779,  0.0299,  0.0434,  ...,  0.0000,  0.4767, -0.5691],
        [-0.0685, -0.0271, -0.1903,  ..., -0.4767,  0.0000, -0.2741],
        [-0.0641,  0.0111,  0.0602,  ...,  0.5691,  0.2741,  0.0000]],
       grad_fn=<MulBackward0>)
After rotation:  tensor([[-0.2628,  0.1690,  0.1987,  ...,  0.6405, -0.6016

Before rotation:  tensor([[-0.0939,  0.0510,  0.0023,  ..., -0.0295, -0.1501, -0.1256],
        [-0.0223, -0.0305, -0.0258,  ..., -0.6390,  0.0120, -0.1824],
        [ 0.0201, -0.0468, -0.0192,  ...,  0.4461,  0.2105,  0.6477],
        ...,
        [ 0.0189,  0.0219,  0.0118,  ..., -0.0279,  0.0103, -0.4712],
        [ 0.0696, -0.0247,  0.0330,  ..., -0.3475, -0.3358,  0.4642],
        [-0.0879,  0.0117, -0.0050,  ..., -0.0677, -0.1409, -0.4617]])
Gamma:  tensor([[ 0.0000, -1.0179, -0.0731,  ..., -0.0696,  0.0539,  0.0592],
        [ 1.0179,  0.0000,  0.6663,  ..., -0.0309,  0.0135, -0.0119],
        [ 0.0731, -0.6663,  0.0000,  ..., -0.0450,  0.1362, -0.0667],
        ...,
        [ 0.0696,  0.0309,  0.0450,  ...,  0.0000,  0.2206, -0.1859],
        [-0.0539, -0.0135, -0.1362,  ..., -0.2206,  0.0000, -0.2690],
        [-0.0592,  0.0119,  0.0667,  ...,  0.1859,  0.2690,  0.0000]],
       grad_fn=<MulBackward0>)
After rotation:  tensor([[-0.3144,  0.1374,  0.1746,  ...,  0.1082, -0.1752

Before rotation:  tensor([[-6.9593e-02, -5.7224e-03, -1.4952e-02,  ..., -3.3617e-01,
          2.2186e-01, -2.7235e-01],
        [-4.7253e-02,  1.0103e-02, -3.2079e-03,  ..., -6.7844e-01,
          1.9498e-01,  6.5742e-02],
        [-1.1601e-02, -2.6732e-02, -3.7849e-02,  ..., -1.1042e-02,
          9.5302e-02,  3.9753e-01],
        ...,
        [-1.7137e-02,  1.0144e-02, -1.4534e-02,  ..., -1.5821e-02,
         -6.1762e-02, -1.6721e-01],
        [ 1.6885e-02, -6.2850e-04, -3.2735e-03,  ..., -9.4363e-02,
         -4.2333e-02,  1.0391e-01],
        [-4.8006e-02,  2.3472e-03, -1.2577e-02,  ..., -5.0976e-01,
          1.4256e-01, -2.6812e-01]])
Gamma:  tensor([[ 0.0000, -0.9646, -0.0370,  ..., -0.0370,  0.0571,  0.0282],
        [ 0.9646,  0.0000,  0.7629,  ..., -0.0422,  0.0024, -0.0023],
        [ 0.0370, -0.7629,  0.0000,  ..., -0.0389,  0.0845, -0.0522],
        ...,
        [ 0.0370,  0.0422,  0.0389,  ...,  0.0000,  0.3300, -0.1833],
        [-0.0571, -0.0024, -0.0845,  ..., -0.3300

Before rotation:  tensor([[-0.0506,  0.0047, -0.0037,  ..., -0.1757,  0.1739, -0.0748],
        [-0.0230,  0.0319, -0.0018,  ..., -0.2776,  0.2139,  0.3230],
        [-0.0072, -0.0341, -0.0547,  ..., -0.0174, -0.0034,  0.0736],
        ...,
        [-0.0124,  0.0207, -0.0109,  ..., -0.0064, -0.1104,  0.0028],
        [ 0.0114,  0.0276, -0.0196,  ..., -0.0286, -0.0693, -0.1132],
        [-0.0284, -0.0022,  0.0046,  ..., -0.3139,  0.0570,  0.0290]])
Gamma:  tensor([[ 0.0000e+00, -9.3290e-01, -1.6050e-02,  ..., -1.7907e-02,
          5.1358e-02,  1.0614e-02],
        [ 9.3290e-01,  0.0000e+00,  8.2641e-01,  ..., -3.9387e-02,
         -1.4046e-03, -6.2644e-04],
        [ 1.6050e-02, -8.2641e-01,  0.0000e+00,  ..., -2.2107e-02,
          5.7068e-02, -3.5098e-02],
        ...,
        [ 1.7907e-02,  3.9387e-02,  2.2107e-02,  ...,  0.0000e+00,
          5.0008e-01, -4.0065e-01],
        [-5.1358e-02,  1.4046e-03, -5.7068e-02,  ..., -5.0008e-01,
          0.0000e+00, -8.7882e-02],
        [-1.

Before rotation:  tensor([[-0.0472,  0.0139,  0.0018,  ..., -0.0047,  0.0376,  0.0457],
        [-0.0108,  0.0214, -0.0069,  ..., -0.0621,  0.1738,  0.3630],
        [-0.0012, -0.0358, -0.0458,  ...,  0.0919, -0.0280, -0.0646],
        ...,
        [-0.0024,  0.0258, -0.0027,  ...,  0.0065, -0.0899, -0.0149],
        [ 0.0182,  0.0268, -0.0141,  ..., -0.0547, -0.1534, -0.1552],
        [-0.0272,  0.0008,  0.0074,  ..., -0.0650, -0.1009,  0.1139]])
Gamma:  tensor([[ 0.0000e+00, -9.1078e-01, -4.4159e-03,  ..., -1.0736e-02,
          3.8076e-02,  6.2902e-03],
        [ 9.1078e-01,  0.0000e+00,  8.7003e-01,  ..., -2.9168e-02,
         -1.3794e-03,  2.1521e-05],
        [ 4.4159e-03, -8.7003e-01,  0.0000e+00,  ..., -8.5450e-03,
          4.0917e-02, -2.1561e-02],
        ...,
        [ 1.0736e-02,  2.9168e-02,  8.5450e-03,  ...,  0.0000e+00,
          5.6024e-01, -5.2940e-01],
        [-3.8076e-02,  1.3794e-03, -4.0917e-02,  ..., -5.6024e-01,
          0.0000e+00, -1.0574e-01],
        [-6.

After rotation:  tensor([[-2.9921e-01,  1.5609e-01,  1.5516e-01,  ...,  1.4044e-01,
         -1.4658e-02,  2.3136e-02],
        [ 1.2775e-01,  4.0893e-02,  1.4445e-01,  ...,  2.3088e-02,
          7.9400e-02,  5.0085e-01],
        [-2.0821e-01, -2.1341e-01,  1.0497e-02,  ...,  2.5867e-01,
         -1.5990e-01, -2.2209e-01],
        ...,
        [-5.4401e-02, -2.7400e-01, -2.3445e-02,  ...,  1.8715e-01,
         -1.3450e-01,  5.4762e-06],
        [-1.8232e-01, -2.3049e-01,  2.1768e-01,  ..., -4.1281e-02,
         -2.7992e-01, -2.1715e-01],
        [-3.4418e-01, -1.1227e-01,  1.4525e-01,  ...,  8.0244e-02,
         -1.4852e-01, -2.7789e-02]], grad_fn=<SubBackward0>)
Before rotation:  tensor([[-0.0441,  0.0122,  0.0010,  ...,  0.0591, -0.0549,  0.0752],
        [-0.0135,  0.0096, -0.0113,  ..., -0.0258,  0.1270,  0.2772],
        [-0.0003, -0.0297, -0.0315,  ...,  0.1795, -0.0147, -0.0712],
        ...,
        [ 0.0009,  0.0255,  0.0018,  ...,  0.0019, -0.0474, -0.0931],
        [ 0.0190

Before rotation:  tensor([[-4.0205e-02,  8.8777e-03, -1.1613e-04,  ...,  4.3021e-02,
         -7.4890e-02,  5.0366e-02],
        [-1.5895e-02,  7.1004e-03, -1.1723e-02,  ..., -7.0885e-02,
          9.6380e-02,  1.8228e-01],
        [-7.0250e-04, -2.3959e-02, -2.5865e-02,  ...,  2.0745e-01,
          2.9652e-03, -2.7104e-02],
        ...,
        [ 5.1256e-04,  2.3693e-02,  2.6532e-03,  ..., -6.8072e-03,
         -2.1292e-02, -1.4176e-01],
        [ 1.6406e-02,  9.8792e-03, -6.7393e-03,  ..., -1.2343e-01,
         -2.1738e-01, -9.8777e-02],
        [-2.1821e-02,  7.2012e-03, -2.3061e-04,  ...,  9.3033e-02,
         -2.4136e-01,  6.8972e-04]])
Gamma:  tensor([[ 0.0000e+00, -8.7614e-01,  6.8191e-03,  ..., -6.6799e-03,
          1.7081e-02,  1.6559e-03],
        [ 8.7614e-01,  0.0000e+00,  9.3012e-01,  ..., -1.3676e-02,
          1.6960e-03, -6.9348e-04],
        [-6.8191e-03, -9.3012e-01,  0.0000e+00,  ..., -8.2573e-04,
          2.2769e-02, -1.0205e-02],
        ...,
        [ 6.6799e-03

Before rotation:  tensor([[-3.5657e-02,  7.4138e-03,  2.1246e-04,  ..., -8.4869e-03,
         -4.7373e-02,  6.4309e-03],
        [-1.3155e-02,  8.7404e-03, -9.6077e-03,  ..., -1.4078e-01,
          7.7512e-02,  1.0112e-01],
        [ 1.0955e-04, -2.0274e-02, -2.5588e-02,  ...,  1.9500e-01,
          1.4006e-02,  2.6198e-02],
        ...,
        [ 5.9398e-04,  2.1950e-02,  2.1948e-03,  ..., -1.4039e-02,
         -1.2915e-02, -1.5556e-01],
        [ 1.3610e-02,  1.0781e-02, -6.7565e-03,  ..., -1.3019e-01,
         -1.9867e-01, -7.2949e-02],
        [-1.8745e-02,  8.1650e-03,  8.4870e-04,  ...,  4.8235e-02,
         -2.1558e-01, -6.6591e-02]])
Gamma:  tensor([[ 0.0000e+00, -8.6162e-01,  8.5644e-03,  ..., -5.5010e-03,
          1.1124e-02, -1.4720e-03],
        [ 8.6162e-01,  0.0000e+00,  9.5121e-01,  ..., -1.0430e-02,
          1.8391e-03, -2.0486e-03],
        [-8.5644e-03, -9.5121e-01,  0.0000e+00,  ..., -7.9599e-04,
          1.5850e-02, -8.5001e-03],
        ...,
        [ 5.5010e-03

Before rotation:  tensor([[-0.0301,  0.0079,  0.0007,  ..., -0.0597,  0.0136, -0.0172],
        [-0.0078,  0.0096, -0.0072,  ..., -0.1890,  0.0796,  0.0684],
        [ 0.0022, -0.0180, -0.0251,  ...,  0.1510,  0.0043,  0.0506],
        ...,
        [ 0.0029,  0.0217,  0.0016,  ..., -0.0166, -0.0269, -0.1278],
        [ 0.0111,  0.0156, -0.0062,  ..., -0.1186, -0.1683, -0.0705],
        [-0.0165,  0.0074,  0.0033,  ..., -0.0262, -0.1484, -0.1021]])
Gamma:  tensor([[ 0.0000e+00, -8.4758e-01,  6.7641e-03,  ..., -3.5353e-03,
          6.3714e-03, -3.5099e-03],
        [ 8.4758e-01,  0.0000e+00,  9.6801e-01,  ..., -8.3938e-03,
          6.7065e-04, -2.4113e-03],
        [-6.7641e-03, -9.6801e-01,  0.0000e+00,  ...,  1.8757e-04,
          8.7613e-03, -5.2303e-03],
        ...,
        [ 3.5353e-03,  8.3938e-03, -1.8757e-04,  ...,  0.0000e+00,
          4.4332e-01, -4.6980e-01],
        [-6.3714e-03, -6.7065e-04, -8.7613e-03,  ..., -4.4332e-01,
          0.0000e+00, -1.0823e-01],
        [ 3.

Before rotation:  tensor([[-0.0254,  0.0075, -0.0002,  ..., -0.0676,  0.0449, -0.0042],
        [-0.0067,  0.0080, -0.0065,  ..., -0.1841,  0.0894,  0.0735],
        [ 0.0033, -0.0149, -0.0220,  ...,  0.1266, -0.0149,  0.0360],
        ...,
        [ 0.0045,  0.0220,  0.0020,  ..., -0.0154, -0.0380, -0.1042],
        [ 0.0092,  0.0167, -0.0056,  ..., -0.1104, -0.1590, -0.0842],
        [-0.0137,  0.0063,  0.0033,  ..., -0.0481, -0.1130, -0.1018]])
Gamma:  tensor([[ 0.0000e+00, -8.3890e-01,  2.2082e-03,  ..., -1.9516e-03,
          3.7245e-03, -3.5029e-03],
        [ 8.3890e-01,  0.0000e+00,  9.7404e-01,  ..., -6.7327e-03,
         -3.5358e-05, -1.4416e-03],
        [-2.2082e-03, -9.7404e-01,  0.0000e+00,  ...,  1.3060e-03,
          4.8824e-03, -1.9507e-03],
        ...,
        [ 1.9516e-03,  6.7327e-03, -1.3060e-03,  ...,  0.0000e+00,
          4.3362e-01, -4.6397e-01],
        [-3.7245e-03,  3.5358e-05, -4.8824e-03,  ..., -4.3362e-01,
          0.0000e+00, -9.0730e-02],
        [ 3.

Before rotation:  tensor([[-0.0218,  0.0067, -0.0011,  ..., -0.0603,  0.0514,  0.0108],
        [-0.0064,  0.0067, -0.0060,  ..., -0.1665,  0.0929,  0.0747],
        [ 0.0038, -0.0119, -0.0196,  ...,  0.1204, -0.0285,  0.0164],
        ...,
        [ 0.0051,  0.0214,  0.0026,  ..., -0.0143, -0.0382, -0.0982],
        [ 0.0080,  0.0160, -0.0052,  ..., -0.1104, -0.1604, -0.0965],
        [-0.0108,  0.0058,  0.0029,  ..., -0.0391, -0.1086, -0.0977]])
Gamma:  tensor([[ 0.0000e+00, -8.3319e-01, -3.3807e-03,  ..., -1.1186e-03,
          2.2837e-03, -3.0135e-03],
        [ 8.3319e-01,  0.0000e+00,  9.7373e-01,  ..., -5.2289e-03,
          8.9478e-06, -6.3730e-04],
        [ 3.3807e-03, -9.7373e-01,  0.0000e+00,  ...,  1.8220e-03,
          2.9158e-03, -1.9872e-04],
        ...,
        [ 1.1186e-03,  5.2289e-03, -1.8220e-03,  ...,  0.0000e+00,
          4.2941e-01, -4.6569e-01],
        [-2.2837e-03, -8.9478e-06, -2.9158e-03,  ..., -4.2941e-01,
          0.0000e+00, -8.1901e-02],
        [ 3.

Before rotation:  tensor([[-0.0189,  0.0061, -0.0015,  ..., -0.0537,  0.0493,  0.0186],
        [-0.0053,  0.0059, -0.0054,  ..., -0.1522,  0.0900,  0.0655],
        [ 0.0043, -0.0094, -0.0178,  ...,  0.1194, -0.0363,  0.0017],
        ...,
        [ 0.0053,  0.0204,  0.0028,  ..., -0.0141, -0.0323, -0.1008],
        [ 0.0071,  0.0154, -0.0047,  ..., -0.1145, -0.1628, -0.1048],
        [-0.0087,  0.0057,  0.0029,  ..., -0.0237, -0.1163, -0.0977]])
Gamma:  tensor([[ 0.0000e+00, -8.2868e-01, -1.0280e-02,  ..., -6.9563e-04,
          1.4354e-03, -2.4766e-03],
        [ 8.2868e-01,  0.0000e+00,  9.6909e-01,  ..., -3.8997e-03,
          3.4292e-04, -2.4974e-04],
        [ 1.0280e-02, -9.6909e-01,  0.0000e+00,  ...,  1.9670e-03,
          1.6696e-03,  4.7370e-04],
        ...,
        [ 6.9563e-04,  3.8997e-03, -1.9670e-03,  ...,  0.0000e+00,
          4.2305e-01, -4.6431e-01],
        [-1.4354e-03, -3.4292e-04, -1.6696e-03,  ..., -4.2305e-01,
          0.0000e+00, -7.4566e-02],
        [ 2.

Before rotation:  tensor([[-0.0163,  0.0056, -0.0018,  ..., -0.0528,  0.0475,  0.0191],
        [-0.0041,  0.0053, -0.0047,  ..., -0.1455,  0.0837,  0.0498],
        [ 0.0049, -0.0075, -0.0163,  ...,  0.1167, -0.0396, -0.0059],
        ...,
        [ 0.0055,  0.0194,  0.0027,  ..., -0.0147, -0.0258, -0.1043],
        [ 0.0064,  0.0151, -0.0042,  ..., -0.1191, -0.1621, -0.1082],
        [-0.0073,  0.0055,  0.0032,  ..., -0.0144, -0.1236, -0.1021]])
Gamma:  tensor([[ 0.0000e+00, -8.2514e-01, -1.8297e-02,  ..., -4.6243e-04,
          9.8561e-04, -2.0123e-03],
        [ 8.2514e-01,  0.0000e+00,  9.6088e-01,  ..., -2.9027e-03,
          6.1436e-04, -1.7730e-04],
        [ 1.8297e-02, -9.6088e-01,  0.0000e+00,  ...,  1.8792e-03,
          8.3970e-04,  5.6338e-04],
        ...,
        [ 4.6243e-04,  2.9027e-03, -1.8792e-03,  ...,  0.0000e+00,
          4.1296e-01, -4.5685e-01],
        [-9.8561e-04, -6.1436e-04, -8.3970e-04,  ..., -4.1296e-01,
          0.0000e+00, -6.5527e-02],
        [ 2.