In [52]:
import sys
import os
# abspath(".") is the current working directory, so SCRIPT_DIR is its *parent*.
# Appending it to sys.path presumably makes the project packages imported below
# (helper, utils, model, analysis_utils) resolvable - TODO confirm repo layout.
SCRIPT_DIR = os.path.dirname(os.path.abspath("."))
sys.path.append(SCRIPT_DIR)
import helper
from utils import data_utils
import matplotlib.pyplot as plt
from utils import training_utils
from utils import data_utils
import torch
from model import models
import json
import os
from model import lightning_models
import math
from torchvision import datasets
import analysis_utils
import numpy as np

In [53]:
# Augmented views per image / images per batch. Both are read as module-level
# globals by the dataset wrapper, the DataLoader and the sampling functions.
n_views, batch_size = 8, 2

In [54]:
# Load the simulation config, falling back to the CIFAR-10 defaults file.
config = helper.Config(
    "../simulations",
    default_config_file="../default_configs/default_config_cifar10.ini",
)

# The `prune` flag is switched on whenever the dataset name contains
# "CIFAR" or "MNIST", and off for everything else.
prune_backbone = (
    "CIFAR" in config.DATA["dataset"] or "MNIST" in config.DATA["dataset"]
)

# Constructor arguments, grouped by concern. Everything is taken from the
# config except `lr`, which is pinned to 1.0 here.
# NOTE(review): the config's own [SSL] lr is not used - confirm this is intended.
clap_kwargs = {
    # architecture
    "backbone_name": config.SSL["backbone"],
    "backbone_out_dim": config.SSL["backbone_out_dim"],
    "prune": prune_backbone,
    "use_projection_head": config.SSL["use_projection_head"],
    "proj_dim": config.SSL["proj_dim"],
    "proj_out_dim": config.SSL["proj_out_dim"],
    # objective / optimisation
    "loss_name": config.SSL["loss_function"],
    "optim_name": config.SSL["optimizer"],
    "lr": 1.0,
    "scheduler_name": config.SSL["lr_scheduler"],
    "momentum": config.SSL["momentum"],
    "weight_decay": config.SSL["weight_decay"],
    "eta": config.SSL["lars_eta"],
    "warmup_epochs": config.SSL["warmup_epochs"],
    "n_epochs": config.SSL["n_epochs"],
    # data layout
    "n_views": config.DATA["n_views"],
    "batch_size": config.SSL["batch_size"],
    # loss hyper-parameters
    "lw0": config.SSL["lw0"],
    "lw1": config.SSL["lw1"],
    "lw2": config.SSL["lw2"],
    "rs": config.SSL["rs"],
    "pot_pow": config.SSL["pot_pow"],
}

# NOTE(review): no checkpoint is loaded in the visible cells, so the weights
# are presumably whatever CLAP's constructor initialises - confirm.
ssl_model = lightning_models.CLAP(**clap_kwargs)

Loading default settings...
[SemiSL]does not exist in the config file
[TL]does not exist in the config file
[SemiSL]does not exist in the config file
[TL]does not exist in the config file
[INFO]
num_nodes = 1
gpus_per_node = 1
cpus_per_gpu = 8
prefetch_factor = 2
precision = 16-mixed
fix_random_seed = True
strategy = auto
if_profile = True

[DATA]
dataset = CIFAR10
n_views = 8
augmentations = ['RandomResizedCrop', 'GaussianBlur', 'RandomGrayscale', 'ColorJitter', 'RandomHorizontalFlip']
augmentation_package = albumentations
crop_size = 32
crop_min_scale = 0.08
crop_max_scale = 1.0
hflip_prob = 0.5
blur_kernel_size = 1
blur_prob = 0.5
grayscale_prob = 0.2
jitter_brightness = 0.8
jitter_contrast = 0.8
jitter_saturation = 0.8
jitter_hue = 0.2
jitter_prob = 0.8

[SSL]
backbone = resnet18
backbone_out_dim = 2048
use_projection_head = True
proj_dim = 2048
proj_out_dim = 128
optimizer = LARS
lr = 0.8
lr_scale = linear
lr_scheduler = cosine-warmup
grad_accumulation_steps = 1
momentum = 0.0
wei

In [55]:
# Raw CIFAR-10 test split, read from disk (download is disabled).
cifar10_test_raw = datasets.CIFAR10(
    root="../datasets/cifar10",
    train=False,
    download=False,
)

# Configured augmentations followed by tensor conversion and normalisation.
aug_ops = [*config.DATA["augmentations"], "ToTensor", "Normalize"]

# Normalisation statistics are stored on config.DATA, which is then handed to
# get_transform as `aug_params` (presumably where they are read - TODO confirm).
config.DATA["mean4norm"] = [0.491,0.482,0.446]
config.DATA["std4norm"] = [0.247,0.243,0.261]
transform = data_utils.get_transform(
    aug_ops,
    aug_params=config.DATA,
    aug_pkg="torchvision",
)

# Wrap the raw split so each item is expanded into `n_views` augmented views.
test_dataset = data_utils.WrappedDataset(
    cifar10_test_raw,
    transform,
    n_views=n_views,
    aug_pkg="torchvision",
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)

In [56]:
# Pull a single shuffled batch for quick inspection.
imgs,labels = next(iter(test_loader))

# Disabled preview of the first 2 views x 2 images (indexed as imgs[view][image]).
# Kept as real comments: the former triple-quoted string was evaluated as the
# cell's last expression and echoed back as cell output.
# img_list, label_list = [],[]
# for i_view in range(2):
#     for j_img in range(2):
#         img_list.append(imgs[i_view][j_img])
#         #label_list.append(classes[labels[i_view][j_img]])
# data_utils.show_images(img_list,2,2,label_list)

In [57]:
# Prefer GPU 0, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Drop the projection head so the backbone output is what gets analysed.
# NOTE(review): this mutates ssl_model.backbone; whether re-running the cell
# is safe depends on remove_projection_head - confirm.
ssl_model.backbone.remove_projection_head()

# NOTE(review): the backbone is never switched to eval() in the visible cells,
# so any BatchNorm/Dropout layers would run in training mode during sampling -
# confirm whether that is intended.
backbone = ssl_model.backbone.to(device)

In [58]:
def sample_submanifolds(backbone,test_dataset,count=100):
    """Collect per-image centres and covariances of backbone outputs over views.

    A fresh shuffled DataLoader is built over ``test_dataset`` and at most
    ``count`` batches are consumed. For every batch, all views are concatenated
    along dim 0, passed through ``backbone`` without gradients, and reshaped to
    ``(n_views, batch_size, O)``. The mean over the view axis is the centre;
    the covariance is the view-averaged outer product of the centred outputs.

    Relies on the module-level globals ``batch_size``, ``n_views`` and ``device``.

    Args:
        backbone: Callable applied to the concatenated image tensor.
        test_dataset: Dataset whose batches unpack as ``(imgs, labels)``, with
            ``imgs`` a sequence of tensors accepted by ``torch.cat`` and
            ``labels`` indexable (only ``labels[0]`` is kept).
        count: Maximum number of batches to process.

    Returns:
        Tuple ``(center_vecs, cov_mats, all_labels)`` of CPU tensors: centres
        concatenated along dim 0, covariances concatenated along dim 0
        (``O x O`` each), and the kept labels concatenated along the last dim.
    """
    loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=4,
        pin_memory=True,
    )
    batch_centers, batch_covs, batch_labels = [], [], []

    for batch_idx, (views, view_labels) in enumerate(loader):
        if batch_idx >= count:
            break
        # Only the first entry of the labels is recorded for the batch.
        batch_labels.append(view_labels[0].detach().cpu())
        stacked = torch.cat(views, dim=0).to(device)
        with torch.no_grad():
            feats = backbone(stacked)
            feat_dim = feats.shape[-1]
            feats = feats.reshape(n_views, batch_size, feat_dim)
            view_mean = feats.mean(dim=0)
            feats.sub_(view_mean)  # centre in place, broadcasting over views
            # (B, O, V) @ (B, V, O) -> (B, O, O), averaged over the views
            per_image_cov = (feats.permute(1, 2, 0) @ feats.permute(1, 0, 2)) / n_views
            # results are moved to the CPU so GPU memory is not accumulated
            batch_centers.append(view_mean.detach().cpu())
            batch_covs.append(per_image_cov.detach().cpu())

    return (
        torch.cat(batch_centers, dim=0),
        torch.cat(batch_covs, dim=0),
        torch.cat(batch_labels, dim=-1),
    )

In [59]:
def sample_manifolds(backbone,test_dataset,count=100):
    """Estimate a centre and a covariance of backbone outputs for each image.

    Builds a shuffled DataLoader over ``test_dataset`` and processes up to
    ``count`` batches. All views of a batch are concatenated along dim 0, run
    through ``backbone`` under ``torch.no_grad()`` and reshaped to
    ``(n_views, batch_size, O)``. The centre is the mean over the view axis and
    the covariance is the view-averaged outer product of the centred outputs.

    Relies on the module-level globals ``batch_size``, ``n_views`` and ``device``.

    Args:
        backbone: Callable applied to the concatenated image tensor.
        test_dataset: Dataset whose batches unpack as ``(imgs, labels)``, with
            ``imgs`` a sequence of tensors accepted by ``torch.cat`` and
            ``labels`` indexable (only ``labels[0]`` is kept).
        count: Maximum number of batches to process. Previously the body read
            an undefined ``count`` and raised ``NameError``; the default now
            matches ``sample_submanifolds``. Pass ``None`` to consume the whole
            loader - note that one ``O x O`` matrix is kept per image, so this
            can require a lot of host memory.

    Returns:
        Tuple ``(center_vecs, cov_mats, all_labels)`` of CPU tensors: centres
        concatenated along dim 0, covariances concatenated along dim 0
        (``O x O`` each), and the kept labels concatenated along the last dim.

    Raises:
        ValueError: If no batch was processed (``count <= 0`` or the loader
            yields nothing, e.g. fewer than ``batch_size`` items with
            ``drop_last=True``).
    """
    test_loader = torch.utils.data.DataLoader(test_dataset,batch_size = batch_size,shuffle=True,drop_last=True,
                                              num_workers = 4,pin_memory=True)
    center_vecs = []
    cov_mats = []
    all_labels = []
    for i,(imgs,labels) in enumerate(test_loader):
        if count is not None and i >= count:
            break
        all_labels.append(labels[0].detach().cpu())
        imgs = torch.cat(imgs,dim=0).to(device)
        with torch.no_grad():
            preds = backbone(imgs)
            preds = torch.reshape(preds,(n_views,batch_size,preds.shape[-1]))
            centers = torch.mean(preds,dim=0)
            preds -= centers
            cov = torch.matmul(torch.permute(preds,(1,2,0)), torch.permute(preds,(1,0,2)))/n_views # size B*O*O
            # save as CPU tensor to save GPU memory
            center_vecs.append(centers.detach().cpu())
            cov_mats.append(cov.detach().cpu())
    if not center_vecs:
        raise ValueError("sample_manifolds processed no batches; check `count` and the dataset size.")
    center_vecs = torch.cat(center_vecs,dim=0)
    cov_mats = torch.cat(cov_mats,dim=0)
    all_labels = torch.cat(all_labels,dim=-1)
    return center_vecs,cov_mats,all_labels

In [60]:
# count=1 processes a single batch, i.e. `batch_size` (= 2) images.
center_vecs,cov_mats,all_labels = sample_submanifolds(backbone,test_dataset,count=1)

In [61]:
# Pairwise quantity between the sampled centre vectors (printed shape: (2, 2)).
# Presumably a distance matrix - the metric is defined in analysis_utils.get_dist.
dist_matrix = analysis_utils.get_dist(center_vecs)


In [62]:
# Pairwise quantity between the sampled covariance matrices (printed shape: (2, 2)).
# Presumably an alignment score - see analysis_utils.get_cov_alignments for the definition.
align_matrix = analysis_utils.get_cov_alignments(cov_mats)

In [63]:
# Sanity check: both matrices should be square with one row/column per sampled image.
print(align_matrix.shape)
print(dist_matrix.shape)

(2, 2)
(2, 2)


In [86]:
# Closed-form power method: estimate each covariance matrix's dominant
# eigenvector by raising the matrix to a high power and applying it to a random
# start vector.
_tiny = torch.finfo(cov_mats.dtype).tiny

# Scale by the Frobenius norm (matrix_norm's default) so the power does not
# overflow; clamp so an all-zero covariance gives zeros instead of NaN.
corr_norm = torch.linalg.matrix_norm(cov_mats,keepdim=True)
normalized_corr = cov_mats/corr_norm.clamp_min(_tiny)

# matrix_power is batched, so no per-matrix Python loop is needed.
corr_pow = torch.linalg.matrix_power(normalized_corr, 80)

# The start vector follows the covariance width instead of a hard-coded 2048.
b0 = torch.rand(cov_mats.shape[-1], dtype=cov_mats.dtype, device=cov_mats.device)
eigens = torch.matmul(corr_pow,b0) # size = B*O

# Clamp rather than add an epsilon: the powered product can be far smaller than
# 1e-9, in which case an additive epsilon leaves rows far from unit length.
eigens = eigens/torch.norm(eigens,dim=1,keepdim=True).clamp_min(_tiny)

In [85]:
# NOTE(review): execution count 85 precedes the defining cell (86), so the saved
# output presumably reflects an earlier definition of `eigens` - re-run to refresh.
print(eigens)

tensor([[ 4.1512e-04,  1.2667e-04, -8.1966e-05,  ...,  3.6015e-04,
         -3.3840e-05,  1.8224e-04],
        [ 5.5162e-14, -1.8867e-14,  1.2064e-14,  ...,  7.3666e-14,
         -1.4255e-14, -2.8650e-14]])


In [87]:
# Show the row-normalised direction vectors, one row per covariance matrix.
print(eigens)

tensor([[ 0.0303,  0.0092, -0.0060,  ...,  0.0263, -0.0025,  0.0133],
        [-0.0235,  0.0080, -0.0051,  ..., -0.0314,  0.0061,  0.0122]])


In [71]:
# NOTE(review): execution count 71 is older than the defining cell (86); the saved
# output is presumably stale - re-run or remove this duplicate cell.
print(eigens)

tensor([[-2.3906e-02, -7.2947e-03,  4.7201e-03,  ..., -2.0740e-02,
          1.9487e-03, -1.0494e-02],
        [-8.9599e-05,  3.0645e-05, -1.9596e-05,  ..., -1.1965e-04,
          2.3154e-05,  4.6536e-05]])


In [51]:
# NOTE(review): execution count 51 predates even the import cell (52) and the saved
# output has 4 rows rather than 2 - presumably from an earlier session; re-run or remove.
print(eigens)

tensor([[ 2.3126e-03,  2.3965e-02, -9.7355e-03,  ...,  1.7474e-02,
         -1.6764e-02,  2.9104e-02],
        [-9.7606e-05,  2.0198e-04, -3.1222e-04,  ...,  4.0034e-04,
         -4.3767e-05,  4.0307e-04],
        [ 2.3270e-03, -4.6885e-03, -9.4833e-03,  ..., -4.6659e-03,
          4.7270e-02, -1.2521e-02],
        [ 2.6684e-09, -1.8261e-08,  2.2064e-08,  ..., -1.4754e-08,
          8.0365e-09, -3.2576e-08]])


In [92]:
def power_iteration(matrix, num_iterations=100, epsilon=1e-6):
    """
    Finds the principal eigenvector of a matrix using the power iteration method.

    The start vector is drawn with ``torch.rand`` on the matrix's device, so the
    result depends on the global torch RNG state.

    Args:
        matrix (torch.Tensor): The matrix for which to find the principal eigenvector.
                               Shape: (n, n)
        num_iterations (int): Maximum number of iterations.
        epsilon (float): Convergence tolerance on the change between successive
                         unit-norm iterates (compared up to sign).

    Returns:
        torch.Tensor: Unit-norm estimate of the principal eigenvector, shape (n,).
        If the product ``matrix @ vec`` becomes exactly zero (for example an
        all-zero matrix), the last valid iterate is returned instead of NaNs.

    Raises:
        ValueError: If ``matrix`` is not a square 2-D tensor.
    """
    # Ensure the matrix is a square 2-D tensor
    if matrix.dim() != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError("Matrix must be square for this method.")
    n = matrix.shape[0]

    # Match the matrix's floating dtype so e.g. float64 inputs do not hit a
    # dtype mismatch inside matmul.
    dtype = matrix.dtype if matrix.is_floating_point() else torch.get_default_dtype()

    # Initialize a random vector
    vec = torch.rand(n, device=matrix.device, dtype=dtype)
    vec = vec / vec.norm()  # Normalize the initial vector

    for _ in range(num_iterations):
        # Multiply matrix by the vector
        next_vec = torch.matmul(matrix, vec)

        # A zero product cannot be normalised; stop rather than divide by zero.
        next_norm = next_vec.norm()
        if next_norm == 0:
            break
        next_vec = next_vec / next_norm

        # An eigenvector is only defined up to sign: with a negative dominant
        # eigenvalue the iterate flips sign every step, so compare both ways.
        change = torch.minimum(torch.norm(next_vec - vec), torch.norm(next_vec + vec))

        # Always keep the newest iterate so the converged vector is returned.
        vec = next_vec
        if change < epsilon:
            break
    return vec

In [93]:
# Cross-check: iterative estimate for the second covariance matrix; the saved
# tensor matches the second row of `eigens` printed by cell 87.
# NOTE(review): the saved output also contains a stray "10" that nothing in the
# current power_iteration prints - presumably from an earlier version; re-run to refresh.
power_iteration(cov_mats[1])

10


tensor([-0.0235,  0.0080, -0.0051,  ..., -0.0314,  0.0061,  0.0122])