In [4]:
from matplotlib import pyplot as plt
%matplotlib inline

## Boilerplate needed for notebooks

In [5]:
run_dir = "../results/TransOpFISTA_09-22-2022_20-06-15/"
current_checkpoint = 5
device_idx = [0]

In [6]:
import sys
import os 
import math
sys.path.append(os.path.dirname(os.getcwd()) + "/src/")

import numpy as np
import torch
from omegaconf import OmegaConf
import omegaconf

from eval.utils import encode_features
from model.model import Model
from model.config import ModelConfig
from experiment import ExperimentConfig
from dataloader.contrastive_dataloader import get_dataloader
from dataloader.utils import get_unaugmented_dataloader

# Set the default device
default_device = torch.device("cuda:0")
# Load config
cfg = omegaconf.OmegaConf.load(run_dir + ".hydra/config.yaml")
cfg.model_cfg.backbone_cfg.load_backbone = None
# Load model
default_model_cfg = ModelConfig()
model = Model.initialize_model(cfg.model_cfg, device_idx)
state_dict = torch.load(run_dir + f"checkpoints/checkpoint_epoch{current_checkpoint}.pt")
model.load_state_dict(state_dict['model_state'])
# Manually override directory for dataloaders
cfg.train_dataloader_cfg.dataset_cfg.dataset_dir = "../datasets"
cfg.train_dataloader_cfg.batch_size = 32
cfg.eval_dataloader_cfg.dataset_cfg.dataset_dir = "../datasets"
# Load dataloaders
train_dataset, train_dataloader = get_dataloader(cfg.train_dataloader_cfg)
eval_dataset, eval_dataloader = get_dataloader(cfg.eval_dataloader_cfg)
unaugmented_train_dataloader = get_unaugmented_dataloader(train_dataloader)
# Get encoding of entire dataset
train_eval_input = encode_features(model, unaugmented_train_dataloader, default_device)
# Load transport operators
psi = model.contrastive_header.transop_header.transop.get_psi()

Using cache found in /home/kion/.cache/torch/hub/pytorch_vision_v0.10.0


In [29]:
unaugmented_train_dataloader.dataset.dataset.data.dtype

dtype('uint8')

In [34]:
from PIL import Image
import torchvision.transforms.functional as F

transform = lambda x: F.adjust_brightness(x, 0.2)

coeff_list = []
z0_list = []
z1_list = []

y_list = []

for idx, batch in enumerate(unaugmented_train_dataloader):
    x, y, idx = batch

    x_tilde = [transform(Image.fromarray(single_x.permute(1, 2, 0).detach().numpy().astype(np.uint8))) for single_x in x]
    x1 = torch.stack([F.to_tensor(single_x) for single_x in x_tilde])

    z0 = model.backbone.backbone_network(x.to(default_device))
    z1 = model.backbone.backbone_network(x1.to(default_device))
    z0_list.append(z0.detach().cpu())
    z1_list.append(z1.detach().cpu())
    y_list.append(y)

z0_list = torch.concat(z0_list)
z1_list = torch.concat(z1_list)
y_list = torch.concat(y_list)

torch.Size([50000, 512])


In [46]:
import torch.nn.functional as F
from model.manifold.transop import TransOp_expm
from model.manifold.l1_inference import infer_coefficients
import os, contextlib

class_to_use = 8
zeta = 0.2

class_idx = y_list == class_to_use
z0_use = z0_list[class_idx]
z1_use = z1_list[class_idx]
z0_train, z1_train = z0_use[:int(len(z0_use)*.8)], z1_use[:int(len(z1_use)*.8)]
z0_test, z1_test = z0_use[int(len(z0_use)*.8):], z1_use[int(len(z1_use)*.8):]

transop = TransOp_expm(10, 512).to(default_device)
opt = torch.optim.SGD(transop.parameters(), lr=1e-2, weight_decay=1e-3)
opt_scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.99)

for i in range(100):
    train_error = []
    for j in range(len(z0_train) // 30):
        z0, z1 = z0_train[j*30:(j+1)*30].to(default_device), z1_train[j*30:(j+1)*30].to(default_device)
        with open(os.devnull, 'w') as devnull:
            with contextlib.redirect_stdout(devnull):
                _, c = infer_coefficients(
                    z0,
                    z1,
                    transop.get_psi(),
                    zeta,
                    max_iter=500,
                    num_trials=1,
                    device=default_device,
                )
        with open(os.devnull, 'w') as devnull:
            with contextlib.redirect_stdout(devnull):
                z1_hat = transop(z0.unsqueeze(-1), c).squeeze(dim=-1)
        
        loss = F.mse_loss(z1_hat, z1)
        loss.backward()
        opt.step()
        opt_scheduler.step()
        train_error.append(loss.item())
        break
    print(np.mean(train_error))
    print(c.count_nonzero(dim=-1).float().mean())




  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  out = torch.matrix_exp(T) @ x


3.8133301734924316
tensor(5.9667, device='cuda:0')


  x1_hat = (torch.matrix_exp(T) @ x0.unsqueeze(-1)).squeeze(-1)
  x1_hat = (torch.matrix_exp(T) @ x0.unsqueeze(-1)).squeeze(-1)


3.41575288772583
tensor(8.7667, device='cuda:0')


KeyboardInterrupt: 

In [4]:
import torchvision.models as models
import torchvision.datasets
import torchvision.transforms.transforms as T
import torch

cifar10 = torchvision.datasets.CIFAR10(
    "../datasets",
    train=True,
    transform=T.Compose(
        [
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize(
                mean=[0.50707516, 0.48654887, 0.44091784],
                std=[0.26733429, 0.25643846, 0.27615047],
            ),
        ]
    ),
)
dataloader = torch.utils.data.DataLoader(
    cifar10, batch_size=100, shuffle=False, num_workers=20
)

backbone = models.resnet18(pretrained=True).to('cuda:0')
backbone.fc = torch.nn.Identity()

z_list = []
for idx, batch in enumerate(dataloader):
    x, y = batch
    x = x.to('cuda:0')
    z = backbone(x)

    z_list.append(z.detach().cpu())
z = torch.concat(z_list)
print(z.shape)



torch.Size([50000, 512])


In [9]:
from sklearn.metrics import pairwise_distances
import numpy as np

simm = pairwise_distances(z, z)
sort_sim = np.argsort(simm, axis=-1)[:,1:6]
print(sort_sim.shape)

(50000, 5)
