In [1]:
import sys
import os
# Put the parent of the current working directory on the import path.
# Despite its name, SCRIPT_DIR is that parent directory, not the notebook's own
# folder; presumably this is what lets the project-local imports below resolve.
SCRIPT_DIR = os.path.split(os.path.abspath(os.curdir))[0]
sys.path.append(SCRIPT_DIR)
import helper
from utils import data_utils
import matplotlib.pyplot as plt
from utils import training_utils
from utils import data_utils
import torch
from model import models
import json
import os
from model import lightning_models
import math
from torchvision import datasets
import analysis_utils
import numpy as np

  check_for_updates()


In [2]:
# Notebook-level settings shared by the cells below.
# Root directory of the simulation run (config and checkpoints are read from it).
model_dir = "../simulations"
# Images per view in one loader batch.
batch_size = 8
# Number of augmented views per image.
# NOTE(review): the config printed further down reports DATA n_views = 16;
# confirm which value is intended before relying on this one.
n_views = 100

In [3]:
# Read the run configuration from `model_dir`, with the CIFAR-10 defaults file
# passed as `default_config_file`.
config = helper.Config(model_dir,default_config_file="../default_configs/default_config_cifar10.ini")

# The `prune` flag is switched on whenever the dataset name contains
# "CIFAR" or "MNIST".
prune_backbone = any(tag in config.DATA["dataset"] for tag in ("CIFAR","MNIST"))

# Build the CLAP module from the config values. The learning rate is fixed to 1.0
# here rather than read from the config.
ssl_model = lightning_models.CLAP(
    backbone_name = config.SSL["backbone"],
    prune = prune_backbone,
    use_projection_head = config.SSL["use_projection_head"],
    proj_dim = config.SSL["proj_dim"],
    proj_out_dim = config.SSL["proj_out_dim"],
    loss_name = config.SSL["loss_function"],
    optim_name = config.SSL["optimizer"],
    lr = 1.0,
    scheduler_name = config.SSL["lr_scheduler"],
    momentum = config.SSL["momentum"],
    weight_decay = config.SSL["weight_decay"],
    eta = config.SSL["lars_eta"],
    warmup_epochs = config.SSL["warmup_epochs"],
    n_epochs = config.SSL["n_epochs"],
    n_views = config.DATA["n_views"],
    batch_size = config.SSL["batch_size"],
    lw0 = config.SSL["lw0"],
    lw1 = config.SSL["lw1"],
    lw2 = config.SSL["lw2"],
    rs = config.SSL["rs"],
    pot_pow = config.SSL["pot_pow"],
)

Loading default settings...
[SemiSL]does not exist in the config file
[TL]does not exist in the config file
[SemiSL]does not exist in the config file
[TL]does not exist in the config file
[INFO]
num_nodes = 1
gpus_per_node = 2
cpus_per_gpu = 12
prefetch_factor = 2
precision = 16-mixed
fix_random_seed = True
strategy = ddp
if_profile = False

[DATA]
dataset = CIFAR10
n_views = 16
augmentations = ['RandomResizedCrop', 'GaussianBlur', 'RandomGrayscale', 'ColorJitter', 'RandomHorizontalFlip']
augmentation_package = albumentations
crop_size = 32
crop_min_scale = 0.08
crop_max_scale = 1.0
hflip_prob = 0.5
blur_kernel_size = 3
blur_prob = 0.5
grayscale_prob = 0.2
jitter_brightness = 0.8
jitter_contrast = 0.8
jitter_saturation = 0.8
jitter_hue = 0.2
jitter_prob = 0.8

[SSL]
backbone = resnet18
use_projection_head = True
proj_dim = 2048
proj_out_dim = 256
optimizer = LARS
lr = 7.2
lr_scale = linear
lr_scheduler = cosine-warmup
grad_accumulation_steps = 1
momentum = 0.0
weight_decay = 0.0001
lar

In [4]:
# Build the train / test / validation loaders for the SSL data.
# The batch size comes from the notebook-level `batch_size` setting instead of a
# repeated literal, so the loader and the code that reshapes its batches cannot
# drift apart.
ssl_train_loader,ssl_test_loader,ssl_val_loader = data_utils.get_dataloader(
    config.DATA,
    batch_size = batch_size,
    num_workers = config.INFO["cpus_per_gpu"],
    standardized_to_imagenet = False,
    augment_val_set = True,
    prefetch_factor = config.INFO["prefetch_factor"],
    aug_pkg = "torchvision",
)

In [5]:
# Pull one batch from the test loader built above. That loader is named
# `ssl_test_loader`; a bare `test_loader` does not exist yet at this point on a
# fresh kernel.
imgs,labels = next(iter(ssl_test_loader))
# Disabled preview code, kept as a string literal (it is echoed as the cell output).
'''
img_list, label_list = [],[]
for i_view in range(2):
    for j_img in range(2):
        img_list.append(imgs[i_view][j_img])
        #label_list.append(classes[labels[i_view][j_img]])
data_utils.show_images(img_list,2,2,label_list)
'''

'\nimg_list, label_list = [],[]\nfor i_view in range(2):\n    for j_img in range(2):\n        img_list.append(imgs[i_view][j_img])\n        #label_list.append(classes[labels[i_view][j_img]])\ndata_utils.show_images(img_list,2,2,label_list)\n'

In [6]:
# Use GPU 0 when CUDA is usable, otherwise fall back to the CPU so the notebook
# still runs on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trained_filename = os.path.join(model_dir,"ssl",'best_val.ckpt')
if os.path.isfile(trained_filename):
    print(f'Found pretrained model at {trained_filename}, loading...')
    ssl_model = lightning_models.CLAP.load_from_checkpoint(trained_filename)
else:
    # Make the fall-through visible: without a checkpoint the model constructed
    # earlier in the notebook is used as-is.
    print(f'No checkpoint found at {trained_filename}, keeping the current ssl_model')
backbone = ssl_model.backbone.to(device)

Found pretrained model at ../simulations/ssl/best_val.ckpt, loading...


In [144]:
def sample_embedding(net,data_loader,count=1000,device=None):
    """Sample per-image embedding centers from multi-view batches.

    Each item of ``data_loader`` must unpack into ``(imgs, labels)`` where
    ``imgs`` is a sequence of V tensors (one per view) that share the same
    leading batch size B. Up to ``count`` batches are passed through ``net``;
    the outputs are gathered into a ``[V, B_total, O]`` tensor, shifted so their
    overall mean sits at the origin, L2-normalised along the last axis and
    finally averaged over the view axis.

    Parameters
    ----------
    net : callable
        Maps the concatenated ``[V*B, ...]`` image tensor to ``[V*B, O]`` outputs.
    data_loader : iterable
        Yields ``(imgs, labels)`` pairs; ``labels`` is not used.
    count : int
        Maximum number of batches to consume.
    device : torch.device or None
        Device the images are moved to. When ``None`` the device of ``net``'s
        first parameter is used (CPU if ``net`` has no parameters).

    Returns
    -------
    torch.Tensor
        CPU tensor of shape ``[B_total, O]`` with one center per sampled image.

    Raises
    ------
    ValueError
        If no batch was sampled (empty loader or ``count <= 0``).
    """
    if device is None:
        params = net.parameters() if hasattr(net,"parameters") else ()
        first_param = next(iter(params),None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # NOTE(review): the mode of `net` is left untouched; call net.eval() first if
    # layers such as batch-norm/dropout should run in inference mode.
    embedding_vecs = []
    for i,(imgs,_labels) in enumerate(data_loader):
        if i >= count:
            break
        # Take the view count from the batch itself so a loader whose settings
        # differ from notebook-level constants (or a short last batch) still works.
        views_in_batch = len(imgs)
        imgs = torch.cat(list(imgs),dim=0).to(device)
        with torch.no_grad():
            preds = net(imgs)
        preds = torch.reshape(preds,(views_in_batch,-1,preds.shape[-1]))
        # keep the result on the CPU to save GPU memory
        embedding_vecs.append(preds.detach().cpu())

    if not embedding_vecs:
        raise ValueError("sample_embedding: no batches were sampled (empty loader or count <= 0)")

    # embedding_vecs becomes a [V, B*count, O] tensor
    embedding_vecs = torch.cat(embedding_vecs,dim=1)
    # move the center of mass of all embeddings to the origin
    com = torch.mean(embedding_vecs,dim=(0,1))
    embedding_vecs = embedding_vecs - com
    # normalize every embedding vector to unit length
    embedding_vecs = torch.nn.functional.normalize(embedding_vecs,dim=-1)
    # centers.shape = [B*count, O]: the mean over views for every image
    centers = torch.mean(embedding_vecs,dim=0)
    return centers

In [145]:
# NOTE(review): neither `sample_submanifolds` nor `test_dataset` is defined in any
# cell visible in this notebook, so on a fresh kernel this cell appears to depend
# on leftover state. Confirm where they come from, or replace/remove this cell.
# The call is unpacked into four values: centers, eigenvectors, traces and labels
# (meanings presumed from the names only).
center_vecs,eigen_vecs,traces,all_labels = sample_submanifolds(backbone,test_dataset,count=100)

In [146]:
# NOTE(review): `centers` is not assigned in any earlier visible cell, and the
# three-index permutes below assume it is a 3-D tensor. Confirm its origin/shape.
# Squared mean of `centers` over its first axis.
avg = torch.mean(centers,dim=0)**2
# Batched product over axis 1 that contracts axis 0. The divisor is now
# parenthesised as (size - 1); the original line had an unbalanced ")" and could
# not be parsed.
# NOTE(review): the product contracts axis 0 while the divisor uses shape[1];
# confirm the intended normalisation.
cov = torch.matmul(torch.permute(centers,(1,2,0)), torch.permute(centers,(1,0,2)))/(centers.shape[1]-1.0) # size B*O*O


In [None]:
# Sweep the periodically saved SSL checkpoints and record, for each one, how far
# the sampled embedding centers are from zero mean (q1) and from an identity
# second-moment matrix (q2).
ckpt_format = os.path.join(model_dir,"ssl",'ssl-epoch=XXX.ckpt')
epochs = []
q1 = []
q2 = []

# Device and data loaders do not depend on the checkpoint, so build them once
# instead of on every iteration.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
train_loader,test_loader,val_loader = data_utils.get_dataloader(
    config.DATA,
    batch_size = batch_size,
    num_workers = config.INFO["cpus_per_gpu"],
    standardized_to_imagenet = False,
    augment_val_set = True,
    prefetch_factor = config.INFO["prefetch_factor"],
    aug_pkg = "torchvision",
)

for epoch in range(0,90,10):
    ckpt_file = ckpt_format.replace("XXX",str(epoch))
    if not os.path.isfile(ckpt_file):
        # Skip instead of silently re-measuring whichever model was loaded last.
        print(f'No checkpoint found at {ckpt_file}, skipping...')
        continue
    print(f'Found pretrained model at {ckpt_file}, loading...')
    # Load this epoch's checkpoint (not the best-val file used earlier).
    ssl_model = lightning_models.CLAP.load_from_checkpoint(ckpt_file)
    backbone = ssl_model.backbone.to(device)

    # centers is [N, O]: one embedding center per sampled image
    centers = sample_embedding(backbone,val_loader,count=1000)
    avg = torch.mean(centers,dim=0)
    # Uncentered second-moment matrix over the N samples, size O*O
    cov = torch.matmul(centers.T, centers)/(centers.shape[0]-1.0)
    # Take the dimension from the data: the backbone output width need not equal
    # the projection head's output width.
    out_dim = cov.shape[-1]
    identity = torch.eye(out_dim,device=cov.device,dtype=cov.dtype)
    # NOTE(review): both quantities are kept element-wise, as in the original
    # formulas; reduce them (e.g. sum/mean) if a scalar per epoch is wanted.
    _q1 = torch.sqrt(avg*avg + 1e-12)
    _q2 = torch.sqrt((cov - identity)**2 / out_dim  + 1e-12)
    epochs.append(epoch)
    q1.append(_q1)
    q2.append(_q2)