In [1]:
from matplotlib import pyplot as plt
%matplotlib inline

## Setup: choose the run, checkpoint and GPU, then load the model and data

In [2]:
# Which training run to analyze, which checkpoint epoch to restore from it, and the
# CUDA device indices handed to Model.initialize_model.
run_dir, current_checkpoint, device_idx = (
    "../results/Transop_sep-loss/",
    400,
    [0],
)

In [3]:
import sys
import os 
import math
sys.path.append(os.path.dirname(os.getcwd()) + "/src/")

import numpy as np
import torch
from omegaconf import OmegaConf
import omegaconf

from eval.utils import encode_features
from model.model import Model
from model.config import ModelConfig
from experiment import ExperimentConfig
from dataloader.contrastive_dataloader import get_dataloader
from dataloader.utils import get_unaugmented_dataloader

# Use the first entry of device_idx as the default device, so the checkpoint, the model
# and the tensors below all live on the same GPU (rather than hard-coding "cuda:0").
default_device = torch.device(f"cuda:{device_idx[0]}")

# Load the Hydra config saved alongside the training run.
cfg = OmegaConf.load(os.path.join(run_dir, ".hydra", "config.yaml"))
# NOTE(review): presumably disables loading separate pretrained backbone weights, because
# the checkpoint loaded below supplies them — confirm against Model.initialize_model.
cfg.model_cfg.backbone_cfg.load_backbone = None

# Build the model and restore the requested checkpoint epoch.
model = Model.initialize_model(cfg.model_cfg, cfg.train_dataloader_cfg.dataset_cfg.dataset_name, device_idx)
checkpoint_path = os.path.join(run_dir, "checkpoints", f"checkpoint_epoch{current_checkpoint}.pt")
state_dict = torch.load(checkpoint_path, map_location=default_device)
model.load_state_dict(state_dict["model_state"])

# Point the dataloader at the local dataset directory (overriding the path stored in the
# run's config) and set the batch size used throughout this notebook.
cfg.train_dataloader_cfg.dataset_cfg.dataset_dir = "../datasets"
cfg.train_dataloader_cfg.batch_size = 128
train_dataset, train_dataloader = get_dataloader(cfg.train_dataloader_cfg)

# Pieces of the trained model reused below: the operator dictionary psi (combined later as
# matrix_exp(sum_m c_m * psi_m)), the feature backbone, and the model's nn_memory_bank.
psi = model.contrastive_header.transop_header.transop.get_psi()
backbone = model.backbone
nn_bank = model.contrastive_header.transop_header.nn_memory_bank

Using cache found in ~/.cache/torch/hub/pytorch_vision_v0.10.0


In [4]:
import torch.nn as nn
from dataclasses import dataclass

@dataclass
class ExperimentConfig:
    """Training settings for this notebook's coefficient-encoder experiment.

    NOTE(review): this shadows ``experiment.ExperimentConfig`` imported in the setup cell,
    and nothing in this notebook instantiates it (the training loop does not read it).
    Consider renaming or removing it.
    """

    num_epochs: int = 300

    enable_c_l2: bool = False
    c_l2_weight: float = 1.0e-3

    # Annotated so it is a real dataclass field (settable through the constructor); without
    # the annotation it was only a class attribute. Kept last so the positional order of the
    # existing fields is unchanged.
    kl_weight: float = 1.0e-3

@dataclass
class CoeffEncoderConfig:
    """Hyper-parameters for ``VIEncoder`` and the SGD optimizer that trains it."""

    # True: predict a Laplace posterior (log-scale + shift) per coefficient and sample it;
    # False: predict the coefficients directly.
    variational: bool = True

    # Layer widths. They are passed straight to ``nn.Linear``, so they must be integers.
    feature_dim: int = 512
    hidden_dim: int = 2048
    # Scale of the zero-mean Laplace prior used in the KL term.
    scale_prior: float = 0.02
    # Soft-threshold applied to the sampled Laplace noise to make coefficients sparse.
    threshold: float = 0.032

    # Optional L2 penalty on the coefficients, applied in the training loop.
    enable_c_l2: bool = False
    c_l2_weight: float = 1.0e-3

    # SGD learning rate and weight decay.
    lr: float = 0.03
    weight_decay: float = 1.0e-5

class VIEncoder(nn.Module):
    """Predict transport-operator coefficients ``c`` for a pair of feature blocks.

    The pair ``(x0, x1)`` is concatenated along the last dim and mapped through an MLP.
    In variational mode the encoder predicts the shift and log-scale of a Laplace
    distribution per coefficient, draws a reparameterized sample whose noise is
    soft-thresholded (so the returned coefficients are sparse), and also returns the KL
    divergence to a zero-mean Laplace prior with scale ``cfg.scale_prior``. Otherwise it
    returns a point estimate and a zero KL term.

    Args:
        cfg: encoder hyper-parameters; ``variational``, ``feature_dim``, ``hidden_dim``,
            ``scale_prior`` and ``threshold`` are read.
        input_dim: size of ``cat(x0, x1)`` along the last dim. Defaults to 128, i.e. two
            64-dim feature blocks as produced by the training loop.
        dict_size: number of coefficients predicted per pair; must match the number of
            dictionary elements in ``psi``. Defaults to 100.
    """

    def __init__(self, cfg: "CoeffEncoderConfig", input_dim: int = 128, dict_size: int = 100):
        super().__init__()
        self.feat_extract = nn.Sequential(
            nn.Linear(input_dim, cfg.hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(cfg.hidden_dim, cfg.hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(cfg.hidden_dim, cfg.feature_dim),
        )

        self.cfg = cfg
        if cfg.variational:
            self.scale = nn.Linear(cfg.feature_dim, dict_size)
            self.shift = nn.Linear(cfg.feature_dim, dict_size)
        else:
            self.pred = nn.Linear(cfg.feature_dim, dict_size)

    def soft_threshold(self, z: torch.Tensor, lambda_: "float | torch.Tensor") -> torch.Tensor:
        """Shrink ``z`` toward zero by ``lambda_``; entries with ``|z| <= lambda_`` become 0."""
        return torch.nn.functional.relu(torch.abs(z) - lambda_) * torch.sign(z)

    def kl_loss(self, log_scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
        """KL(Laplace(shift, exp(log_scale)) || Laplace(0, cfg.scale_prior)).

        Uses the closed form for Laplace distributions with locations mu1, 0 and scales
        b1, b2:  log(b2 / b1) + |mu1| / b2 + (b1 / b2) * exp(-|mu1| / b1) - 1.
        The result is summed over the last dim and averaged over all remaining elements.
        """
        prior_scale = torch.ones_like(log_scale) * self.cfg.scale_prior
        prior_log_scale = torch.log(prior_scale)
        scale = torch.exp(log_scale)
        abs_shift = shift.abs()
        laplace_kl = abs_shift / prior_scale + prior_log_scale - log_scale - 1
        laplace_kl = laplace_kl + (scale / prior_scale) * torch.exp(-abs_shift / scale)
        return laplace_kl.sum(dim=-1).mean()

    def reparameterize(self, log_scale: torch.Tensor, shift: torch.Tensor, psi=None) -> torch.Tensor:
        """Draw a sparse coefficient sample from Laplace(shift, exp(log_scale)).

        Forward value: the Laplace noise ``eps`` is soft-thresholded by ``cfg.threshold``;
        entries whose noise survives become ``shift + soft(eps)``, all others are exactly 0.
        Backward: gradients flow as if the un-thresholded sample ``shift + eps`` had been
        returned (straight-through estimator).

        ``psi`` is accepted for call compatibility but is not used.
        """
        # Inverse-CDF Laplace sampling: u ~ U[-0.5, 0.5), eps = -b * sign(u) * log(1 - 2|u|).
        # The clamp guards against log(0) when u == -0.5 (1 - 2|u| never exceeds 1).
        noise = torch.rand_like(log_scale) - 0.5
        scale = torch.exp(log_scale)
        eps = -scale * torch.sign(noise) * torch.log((1.0 - 2.0 * torch.abs(noise)).clamp(min=1e-6))
        c = shift + eps

        # Sparse forward value, computed without gradient.
        eps_thresh = self.soft_threshold(eps.detach(), self.cfg.threshold)
        c_sparse = torch.where(eps_thresh != 0, shift.detach() + eps_thresh, torch.zeros_like(eps_thresh))

        # Straight-through: value of c_sparse, gradient of c. (c - c.detach()) is exactly zero
        # in the forward pass, so the returned value equals c_sparse with no rounding error.
        return c_sparse + (c - c.detach())

    def forward(self, x0: torch.Tensor, x1: torch.Tensor, psi=None):
        """Encode the pair ``(x0, x1)`` into coefficients.

        Returns:
            ``(c, kl)`` where ``c`` has last dim ``dict_size``. ``kl`` is the scalar KL term
            in variational mode, and a zero scalar on the features' device/dtype otherwise.
        """
        z = self.feat_extract(torch.cat((x0, x1), dim=-1))
        if not self.cfg.variational:
            # Keep the zero KL on the same device/dtype as the features so it combines
            # cleanly with losses computed on the GPU.
            return self.pred(z), z.new_zeros(())

        # Offset by log(scale_prior) so a zero network output corresponds to the prior scale.
        # Out-of-place, so the Linear layer's output is not mutated.
        log_scale = self.scale(z) + math.log(self.cfg.scale_prior)
        shift = self.shift(z)

        c = self.reparameterize(log_scale, shift, psi)
        return c, self.kl_loss(log_scale, shift)

In [5]:
from model.public.linear_warmup_cos_anneal import LinearWarmupCosineAnnealingLR
import warnings
# Silence library warnings so the training log below stays readable.
# NOTE(review): this hides *all* warnings for the rest of the session; narrow it to specific
# categories/modules if something important might be masked.
warnings.filterwarnings("ignore")

# Seed torch's global RNG so encoder initialization and data order are reproducible.
SEED = 0
torch.manual_seed(SEED)

# train_dataset / train_dataloader were already built in the setup cell; reuse them rather
# than constructing the dataset a second time.

vi_cfg = CoeffEncoderConfig()
# Keep the encoder on the same device as the restored model.
encoder = VIEncoder(vi_cfg).to(default_device)
opt = torch.optim.SGD(
    encoder.parameters(),
    lr=vi_cfg.lr,
    momentum=0.9,
    nesterov=True,
    weight_decay=vi_cfg.weight_decay,
)

# The scheduler is stepped once per iteration in the training loop, so its warmup and
# horizon are expressed in iterations (epochs * iterations per epoch).
NUM_EPOCHS = 300
WARMUP_EPOCHS = 10
iters_per_epoch = len(train_dataloader)
scheduler = LinearWarmupCosineAnnealingLR(
    opt,
    warmup_epochs=WARMUP_EPOCHS * iters_per_epoch,
    max_epochs=NUM_EPOCHS * iters_per_epoch,
    eta_min=1e-4,
)

In [6]:
from collections import deque

# Loop hyper-parameters (previously inline magic numbers).
NUM_EPOCHS = 300   # should match the horizon the LR scheduler was built with
BLOCK_DIM = 64     # feature vectors are split into blocks of this size before transport
# NOTE(review): differs from ExperimentConfig.kl_weight (1e-3), which this loop does not read.
KL_WEIGHT = 8.0e-3
LOG_EVERY = 500    # print a running summary every LOG_EVERY iterations


def split_into_blocks(z: torch.Tensor, block_dim: int = BLOCK_DIM) -> torch.Tensor:
    """Cut each feature vector into contiguous ``block_dim`` chunks and stack them as rows.

    For a (batch, k * block_dim) input the result is (batch * k, block_dim), with the k
    blocks of each sample kept adjacent. Assumes z is 2-D (batch, dim) — TODO confirm
    against the backbone's output shape.
    """
    return torch.stack(torch.split(z, block_dim, dim=-1)).transpose(0, 1).reshape(-1, block_dim)


transop_loss_save = []
kl_loss_save = []
dw_loss_save = []
# Only the most recent LOG_EVERY coefficient batches are ever read. A bounded deque keeps
# memory flat instead of holding every batch's coefficients on the CPU for the whole run.
c_save = deque(maxlen=LOG_EVERY)

for epoch in range(NUM_EPOCHS):
    for idx, batch in enumerate(train_dataloader):
        curr_iter = epoch * len(train_dataloader) + idx
        # batch[0] appears to hold the two views (x0, x1) of each sample — confirm in the dataloader.
        x0 = batch[0][0].to(default_device)
        x1 = batch[0][1].to(default_device)

        # NOTE(review): the backbone is frozen only via no_grad. If it is still in train() mode,
        # BatchNorm layers use batch statistics and keep updating their running stats here.
        with torch.no_grad():
            z0, z1 = backbone(x0), backbone(x1)
        # Presumably swaps z1 for its nearest neighbour in the memory bank and updates the bank.
        z1 = nn_bank(z1.detach(), update=True).detach()

        z0 = split_into_blocks(z0)
        z1 = split_into_blocks(z1)

        # Predict coefficients, build T = expm(sum_m c_m psi_m) per row and transport z0 toward z1.
        c, kl_loss = encoder(z0, z1, psi)
        T = torch.matrix_exp(torch.einsum("bm,mpk->bpk", c, psi))
        z1_hat = (T @ z0.unsqueeze(-1)).squeeze(-1)
        transop_loss = torch.nn.functional.mse_loss(z1_hat, z1, reduction="none")

        loss = transop_loss.mean() + KL_WEIGHT * kl_loss
        if vi_cfg.enable_c_l2:
            loss = loss + vi_cfg.c_l2_weight * (c ** 2).sum(dim=-1).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        scheduler.step()

        # Logged metrics need no autograd graph.
        with torch.no_grad():
            transop_loss_save.append(transop_loss.mean().item())
            kl_loss_save.append(kl_loss.item())
            # Per-row transport error relative to the error of doing nothing (T = I);
            # values below 1 mean the operator moved z0 closer to z1.
            dw_bw_points = torch.nn.functional.mse_loss(z0, z1, reduction="none").mean(dim=-1)
            dw_loss_save.append((transop_loss.mean(dim=-1) / dw_bw_points).mean().item())
            c_save.append(c.detach().cpu())

        if curr_iter % LOG_EVERY == 0:
            last_c = torch.cat(list(c_save))
            c_nz = torch.count_nonzero(last_c, dim=-1).float().mean()
            c_mag = last_c[last_c.abs() > 0].abs().mean()
            print(f"Iter {curr_iter} -- TO loss: {np.mean(transop_loss_save[-LOG_EVERY:]):.3E}," +
                  f" KL loss: {np.mean(kl_loss_save[-LOG_EVERY:]):.3E}," +
                  f" dist improve: {np.mean(dw_loss_save[-LOG_EVERY:]):.3E}," +
                  f" c nonzero: {c_nz:.3f}," +
                  f" c mag: {c_mag:.3f}")

Iter 0 -- TO loss: 1.447E-01, KL loss: 5.691E+01, dist improve: 1.084E+00, c nonzero: 20.200, c mag: 0.031
Iter 500 -- TO loss: 1.064E-01, KL loss: 1.041E+01, dist improve: 1.083E+00, c nonzero: 20.159, c mag: 0.022
Iter 1000 -- TO loss: 7.846E-02, KL loss: 1.367E+00, dist improve: 1.114E+00, c nonzero: 20.100, c mag: 0.020
Iter 1500 -- TO loss: 7.699E-02, KL loss: 4.222E-01, dist improve: 1.113E+00, c nonzero: 19.892, c mag: 0.020
Iter 2000 -- TO loss: 7.622E-02, KL loss: 3.197E-01, dist improve: 1.109E+00, c nonzero: 19.685, c mag: 0.020
Iter 2500 -- TO loss: 7.562E-02, KL loss: 3.430E-01, dist improve: 1.107E+00, c nonzero: 19.567, c mag: 0.020


KeyboardInterrupt: 

In [31]:
from model.manifold.l1_inference import infer_coefficients

# Infer coefficients for the first 128 block pairs with infer_coefficients, for comparison
# with the amortized encoder trained above.
# NOTE(review): z0/z1 are whatever the (interrupted) training loop left from its last batch,
# so this cell only works after that loop has run, and results vary between runs.
# NOTE(review): 0.05 is presumably the L1 penalty weight — confirm against infer_coefficients.
(loss, _, k), c = infer_coefficients(
    z0[:128],
    z1[:128],
    psi,
    0.05,
    max_iter=1200,
    tol=1e-5,
    num_trials=50,
    device=str(default_device),  # stay on the same GPU as the model instead of hard-coding cuda:0
    lr=1e-2,
    decay=0.99,
    c_init=None,
)

In [32]:
# Evaluate the inferred coefficients on the same 128 block pairs used for inference.
z0_eval, z1_eval = z0[:128], z1[:128]
with torch.no_grad():  # evaluation only; avoid building a graph through psi
    T = torch.matrix_exp(torch.einsum("bm,mpk->bpk", c, psi))
    z1_hat = (T @ z0_eval.unsqueeze(-1)).squeeze(-1)
    # Per-pair error after transport, and the error of doing nothing (T = I) for reference.
    transop_loss = torch.nn.functional.mse_loss(z1_hat, z1_eval, reduction="none").mean(dim=-1)
    dw_bw_points = torch.nn.functional.mse_loss(z0_eval, z1_eval, reduction="none").mean(dim=-1)

print("k:", k)
print("mean nonzero coefficients per pair:", torch.count_nonzero(c, dim=-1).float().mean())
print("dist improve (transport MSE / raw MSE):", (transop_loss / dw_bw_points).mean())


810
tensor(61.1953, device='cuda:0')
tensor(0.0561, device='cuda:0')
