In [59]:
import sys

sys.path.append("/home/shinzato/GitHub/power-law-research/power_law_research")
import numpy as np
from data_modules import FashionMNISTDataModule

import pytorch_lightning as pl
import torch


class FashionMNISTDataModuleWhite(FashionMNISTDataModule):
    def __init__(self, batch_size=128):
        super().__init__(self, batch_size)

    def train_dataloader(self):
        self.zca_matrix = zca_whitening_matrix(self.train_datasets.data.view(-1, 784))
        train_white = np.dot(self.zca_matrix, self.train_datasets.data.view(784, -1).numpy())
        self.train_white = (
            torch.from_numpy(train_white.astype(np.float32)).clone().view(60000, 28, 28)
        )
        self.dummy_label = np.zeros(60000)
        return torch.utils.data.DataLoader(
            dataset=(self.train_white, self.train_white), batch_size=self.batch_size, shuffle=True, num_workers=4
        )


def zca_whitening_matrix(X):
    """
    Function to compute ZCA whitening matrix (aka Mahalanobis whitening).
    INPUT:  X: [N x M] matrix.
        Rows: Variables
        Columns: Observations
    OUTPUT: ZCAMatrix: [M x M] matrix
    """
    # Covariance matrix [column-wise variables]: Sigma = (X-mu)' * (X-mu) / N
    sigma = np.cov(X.T, rowvar=True)  # [M x M]
    # Singular Value Decomposition. X = U * np.diag(S) * V
    U, S, V = np.linalg.svd(sigma)
    # U: [M x M] eigenvectors of sigma.
    # S: [M x 1] eigenvalues of sigma.
    # V: [M x M] transpose of U
    # Whitening constant: prevents division by zero
    epsilon = 1e-8
    # ZCA Whitening matrix: U * Lambda * U'
    ZCAMatrix = np.dot(U, np.dot(np.diag(1.0 / np.sqrt(S + epsilon)), U.T))  # [M x M]
    return ZCAMatrix


In [3]:
data_module = FashionMNISTDataModule(batch_size=128)

In [10]:
data_module.train_dataloader.datasets


AttributeError: 'function' object has no attribute 'datasets'

In [61]:
from models import LitVanillaVAE

data_module_white = FashionMNISTDataModuleWhite(batch_size=128)
model_white = LitVanillaVAE(n_vis=784, n_hid=100, optimizer_name="sgd")
trainer = pl.Trainer(
    max_epochs=10,
    devices=1,
    accelerator="gpu",
    logger=pl.loggers.TensorBoardLogger(save_dir=f"../log/playground/white"),
)
trainer.fit(model_white, data_module_white)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name         | Type       | Params
--------------------------------------------
0 | encoder_mean | Sequential | 78.5 K
1 | encoder_var  | Sequential | 78.5 K
2 | decoder      | Sequential | 79.2 K
--------------------------------------------
236 K     Trainable params
0         Non-trainable params
236 K     Total params
0.945     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

In [65]:
from models import LitVanillaVAE

data_module = FashionMNISTDataModule(batch_size=128)
model = LitVanillaVAE(n_vis=784, n_hid=100, optimizer_name="sgd")
trainer = pl.Trainer(
    max_epochs=10,
    devices=1,
    accelerator="gpu",
    logger=pl.loggers.TensorBoardLogger(save_dir=f"../log/playground/white"),
)
trainer.fit(model, data_module)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: ../log/playground/white/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name         | Type       | Params
--------------------------------------------
0 | encoder_mean | Sequential | 78.5 K
1 | encoder_var  | Sequential | 78.5 K
2 | decoder      | Sequential | 79.2 K
--------------------------------------------
236 K     Trainable params
0         Non-trainable params
236 K     Total params
0.945     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

In [64]:
%load_ext tensorboard
%tensorboard --logdir ../log/playground

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6007 (pid 1259375), started 0:00:04 ago. (Use '!kill 1259375' to kill it.)