In [None]:
%load_ext autoreload
%autoreload 2

import sys    
import os

# Pin the notebook to GPU 7 unless the caller (shell, scheduler, papermill, ...)
# has already chosen a device. Must run before `torch` is imported below.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")

import pickle
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
from IPython.display import clear_output
from sklearn.utils import shuffle
from torch import Tensor, LongTensor
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

warnings.filterwarnings("ignore", category=DeprecationWarning) 
    
from constants import ARTIFACTS_DIR, PROJECT_DIR

from ibmd.core import IBMD
from pfgmpp.core import PFGMPP
from pfgmpp.training.losses import EDMLoss
from pfgmpp.utils.data import InfiniteDataLoader

# Expose the vendored original PFGM++ sources on the import path.
# NOTE(review): presumably required so that unpickling the checkpoint can
# resolve the modules it references - confirm.
sys.path.append(f"{PROJECT_DIR}/src/pfgmpp_original")

DS_NAME = "cifar"
ckpt_name = "cifar10_ncsnpp_D_2048_conditional.pkl"

# Location of the pretrained teacher checkpoint.
EDM_CHEKPOINT_PATH = os.path.join(ARTIFACTS_DIR, "pfgmpp_original", "checkpoints", ckpt_name)

# Output layout: <artifacts>/ibmd_on_images/<dataset>/<checkpoint name up to the first dot>/
RUN_DIR = os.path.join(ARTIFACTS_DIR, "ibmd_on_images", DS_NAME)
IBMD_DIR = os.path.join(RUN_DIR, ckpt_name.partition(".")[0])

os.makedirs(IBMD_DIR, exist_ok=True)

# Utils

In [None]:
# Problem params (CIFAR-10)
IMG_CHANNELS = 3
IMG_RESOLUTION = 32
# Length of one image once flattened to a vector.
DATA_DIM = IMG_CHANNELS * IMG_RESOLUTION * IMG_RESOLUTION
N_CLASSES = 10

SIGMA_MIN = 0.002
SIGMA_MAX = 80.0
# The PFGMPP object is built with D = 2**POWER.
POWER = 11

# Loss params
SIGMA_PRIOR_MODE = "log_normal"
SIGMA_DATA = 0.5

# Sampling params: N_GENS preview images, all with class index 0.
N_GENS = 4
LABELS = torch.zeros(N_GENS, dtype=torch.long)

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
def generate(
    *,
    pfgmpp: PFGMPP,
    net: nn.Module,
    sample_size: int,
    num_steps: int=32,
    label: LongTensor=None,
    seed: int=0,
):
    """Sample images by running ``pfgmpp.sample`` with a drift built from ``net``.

    The drift handed to the sampler is ``(x - net(x, t, label)) / t``.

    Args:
        pfgmpp: Object providing ``sample(drift=..., sample_size=..., num_steps=...,
            label=..., device=..., seed=...)``.
        net: Callable accepting keyword arguments ``x``, ``t`` and ``label``.
        sample_size: Number of samples to draw.
        num_steps: Forwarded to ``pfgmpp.sample``.
        label: Forwarded to ``pfgmpp.sample`` and, through the drift, to ``net``.
        seed: Forwarded to ``pfgmpp.sample``.

    Returns:
        NumPy array of shape
        ``(sample_size, IMG_CHANNELS, IMG_RESOLUTION, IMG_RESOLUTION)``.
    """
    def _drift(*, x, t, label):
        net_out = net(x=x, t=t, label=label)
        return (x - net_out) / t

    flat_samples = pfgmpp.sample(
        drift=_drift,
        sample_size=sample_size,
        num_steps=num_steps,
        label=label,
        device=DEVICE,
        seed=seed,
    )
    as_numpy = flat_samples.cpu().numpy()
    return as_numpy.reshape(sample_size, IMG_CHANNELS, IMG_RESOLUTION, IMG_RESOLUTION)

def visualize_ibmd(
    *,
    ibmd: IBMD,
    sample_size: int,
    label: LongTensor=None,
    seed: int=0,
):
    """Draw samples from ``ibmd`` and display them in a grid four images wide.

    Args:
        ibmd: Object providing ``sample(sample_size=..., label=..., seed=...)``.
        sample_size: Number of samples requested; also determines the row count.
        label: Forwarded to ``ibmd.sample``.
        seed: Forwarded to ``ibmd.sample``.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    samples = ibmd.sample(sample_size=sample_size, label=label, seed=seed)
    images = samples.reshape(-1, IMG_CHANNELS, IMG_RESOLUTION, IMG_RESOLUTION).cpu().numpy()

    n_cols = 4
    # Ceiling division: a partially filled last row still needs its own row.
    n_rows, leftover = divmod(sample_size, n_cols)
    if leftover:
        n_rows += 1

    fig, grid = plt.subplots(nrows=n_rows, ncols=n_cols, squeeze=False)
    fig.set_figheight(n_rows * 2)
    fig.set_figwidth(n_cols * 2)

    for idx, ax in enumerate(grid.flat):
        if idx >= len(images):
            # No image for this cell: remove the empty axes from the figure.
            fig.delaxes(ax)
            continue
        ax.axis("off")
        # NOTE(review): imshow clips float RGB input to [0, 1]. If ibmd.sample
        # returns values in [-1, 1], rescale before plotting - confirm the range.
        ax.imshow(images[idx].transpose(1, 2, 0))  # CHW -> HWC
    plt.show()

In [None]:
class NetWrapper(nn.Module):
    """Adapter between flat-vector callers and an image network.

    Callers pass ``x`` as flat vectors plus ``t`` and an optional integer
    ``label``; the wrapped network is called with image-shaped ``x``,
    ``sigma=t`` and one-hot ``class_labels``. Its output is flattened back to
    ``(-1, DATA_DIM)``.
    """

    def __init__(self, net: nn.Module):
        """Store the wrapped network as the ``net`` submodule."""
        super().__init__()
        self.net = net

    def forward(self, x: Tensor, t: Tensor, label: LongTensor=None):
        """Run the wrapped network on flat inputs.

        Args:
            x: Tensor reshaped to ``(-1, IMG_CHANNELS, IMG_RESOLUTION, IMG_RESOLUTION)``.
            t: Passed to the wrapped network as ``sigma``.
            label: Integer class indices, one-hot encoded with ``N_CLASSES``
                classes. ``None`` is forwarded as ``class_labels=None``.

        Returns:
            The network output reshaped to ``(-1, DATA_DIM)``.
        """
        x = x.reshape(-1, IMG_CHANNELS, IMG_RESOLUTION, IMG_RESOLUTION)

        # `label` defaults to None, and F.one_hot(None) raises a TypeError, so
        # the default was unusable. Pass None straight through instead.
        # NOTE(review): assumes the wrapped net accepts class_labels=None - confirm.
        if label is None:
            class_labels = None
        else:
            class_labels = F.one_hot(label, num_classes=N_CLASSES)

        out = self.net(x=x, sigma=t, class_labels=class_labels)

        return out.reshape(-1, DATA_DIM)

# PFGM teacher init

In [None]:
# Teacher dynamic; D = 2**POWER.
pfgmpp = PFGMPP(
    data_dim=DATA_DIM,
    sigma_min=SIGMA_MIN,
    sigma_max=SIGMA_MAX,
    D=2**POWER,
)

# SECURITY: pickle.load can execute arbitrary code. Only load checkpoints that
# come from a trusted source.
with open(EDM_CHEKPOINT_PATH, "rb") as f:
    # The checkpoint is indexed by key; the "ema" entry is the network used here.
    net = pickle.load(f)["ema"].to(DEVICE)

# Adapt the image network to the flat-vector (x, t, label) interface.
pfgmpp_cond = NetWrapper(net)

# Turn gradients on for every parameter of the wrapped network.
# NOTE(review): presumably the checkpoint stores the weights with
# requires_grad=False and IBMD needs them trainable - confirm.
for param in pfgmpp_cond.parameters():
    param.requires_grad = True

In [None]:
# Draw N_GENS samples from the teacher and preview the first one.
gens = generate(
    pfgmpp=pfgmpp,
    net=pfgmpp_cond,
    sample_size=N_GENS,
    label=LABELS.to(DEVICE),
    seed=0,
)

fig, ax = plt.subplots(figsize=(2, 2))
ax.axis("off")
ax.imshow(gens[0].transpose(1, 2, 0));  # CHW -> HWC; trailing ';' hides the repr

In [None]:
# IBMD init

In [None]:
# Passed to `ibmd.train(batch_size=...)` in the training loop.
BATCH_SIZE = 32
# Passed to `ibmd.train(inner_problem_iters=...)` in the training loop.
INNER_PROBLEM_ITERS = 5

In [None]:
# Build the IBMD object from the teacher dynamic, the wrapped teacher network
# and an EDM loss defined on the same dynamic.
# Both optimizer-style configs use the same learning rate (5e-5).
# NOTE(review): the exact meaning of `ema_decay` and of the two config dicts is
# defined by the IBMD class, which is not visible here - confirm.
ibmd = IBMD(
    teacher_dynamic=pfgmpp,
    teacher_net=pfgmpp_cond,
    teacher_loss_fn=EDMLoss(pfgmpp=pfgmpp),
    student_net_optimizer_config={"lr": 5e-5},
    student_data_estimator_net_config={"lr": 5e-5},
    n_classes=N_CLASSES,
    ema_decay=0.99,
)

In [None]:
# Show N_GENS samples from the freshly constructed IBMD object (before any
# checkpoint loading or training) using a fixed seed.
visualize_ibmd(ibmd=ibmd, sample_size=N_GENS, label=LABELS.to(DEVICE), seed=0)

In [None]:
n_epochs = 500
sample_every = 50

# Resume from an existing checkpoint when there is one. On a fresh run the file
# does not exist yet, so loading unconditionally would abort the cell.
ibmd_ckpt_path = os.path.join(IBMD_DIR, f"{int(POWER)}.pt")
if os.path.exists(ibmd_ckpt_path):
    ibmd.load(ibmd_ckpt_path)
else:
    print(f"No IBMD checkpoint found at {ibmd_ckpt_path}; training from the current state.")

# The preview labels never change, so move them to the device once.
preview_labels = LABELS.to(DEVICE)

for _ in range(n_epochs):
    # Replace the previous preview with samples from the current state, then
    # train for `sample_every` iterations.
    clear_output()
    visualize_ibmd(ibmd=ibmd, sample_size=N_GENS, label=preview_labels, seed=0)

    ibmd.train(
        batch_size=BATCH_SIZE,
        n_iters=sample_every,
        inner_problem_iters=INNER_PROBLEM_ITERS,
        log_every=10,
        # Checkpoint saving is disabled; uncomment to persist progress.
        # save_path=os.path.join(IBMD_DIR, f"{int(POWER)}.pt"),
    )