### Imports and helpers

---

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
sys.path.insert(0, os.path.abspath('..'))

import torch
from omegaconf import OmegaConf
import wandb
import lightning as L
from torchvision.transforms.functional import to_pil_image
import numpy as np

import ipywidgets as widgets
import matplotlib.pyplot as plt


from helpers.dataset import get_dataloaders
from helpers.diffusion import get_diffusion
from helpers.model import WeightDiffusionTransformer
from helpers.pl_module import WeightDenoiser
from helpers.texture_encoding import GramEncoder, CLIP, VisionTransformer
from helpers.texture_loss import TextureLoss

from helpers.generator import Generator as nca_weight_generator


In [3]:
# Fix the global RNG seed through Lightning so sampling and data loading are reproducible.
L.seed_everything(seed=42);

Seed set to 42


In [4]:
def download_ckpt(model_id, root=None) -> str:
    """Download a checkpoint artifact from the W&B project ``ludekcizinsky/hypernca``.

    Args:
        model_id: Artifact name including its version alias,
            e.g. ``'model-sdnv2k1z:v0'``.
        root: Optional directory to download into. ``None`` (the default)
            keeps wandb's default download location.

    Returns:
        The local directory the artifact was downloaded to, as returned by
        ``Artifact.download``.
    """
    artifact = wandb.Api().artifact(f'ludekcizinsky/hypernca/{model_id}')
    return artifact.download(root=root)

def _build_texture_encoder(cfg):
    """Instantiate the texture encoder named by ``cfg.texture_encoder._target_``.

    The target string is matched by substring in the order Gram, CLIP,
    VisionTransformer.

    Raises:
        ValueError: if the target matches none of the supported encoders.
    """
    enc_cfg = cfg.texture_encoder
    target = enc_cfg._target_
    if "Gram" in target:
        return GramEncoder(hidden_size=enc_cfg.hidden_size, normalize=enc_cfg.normalize)
    if "CLIP" in target:
        return CLIP()
    if "VisionTransformer" in target:
        return VisionTransformer(
            pretrained=enc_cfg.pretrained,
            trainable=enc_cfg.trainable,
            num_hidden_layers=enc_cfg.num_hidden_layers,
            patch_size=enc_cfg.patch_size,
            hidden_dim=enc_cfg.hidden_dim,
            num_layers=enc_cfg.num_layers,
            num_heads=enc_cfg.num_heads,
            mlp_dim=enc_cfg.mlp_dim,
            image_size=enc_cfg.image_size,
        )
    raise ValueError(f"Unsupported texture encoder target: {target!r}")


def get_pl_module(model_id, artifact_dir, username='cizinsky', model_type='newest', device='cuda'):
    """Download a checkpoint from W&B and rebuild the ``WeightDenoiser`` around it.

    Args:
        model_id: W&B artifact name with version, e.g. ``'model-4kkhlc27:v3'``.
        artifact_dir: Directory expected to contain ``<model_id>/model.ckpt``
            after the download. NOTE(review): ``download_ckpt`` is called
            without a target directory, so this must match wandb's default
            download location (presumably ``./artifacts`` under the working
            directory) — confirm.
        username: Cluster user whose scratch space holds the pretrained NCA weights.
        model_type: ``'baseline'`` merges the checkpoint hyper-parameters onto
            ``../configs/train.yaml`` and disables cross-attention; the value is
            always written to ``cfg.model.type``.
        device: Device the module is moved to (default ``'cuda'``).

    Returns:
        ``(pl_module, encoder, val_dataloader)``:
        - the module, with weights loaded strictly and switched to eval mode;
        - the texture encoder instance that was passed to it;
        - the validation dataloader built from the adjusted config.

    Raises:
        ValueError: if the checkpoint's ``texture_encoder._target_`` is not a
            supported encoder (previously this surfaced as ``UnboundLocalError``).
    """
    download_ckpt(model_id)
    path_to_ckpt = os.path.join(f'{artifact_dir}/{model_id}', "model.ckpt")

    # weights_only=False unpickles arbitrary Python objects (presumably needed
    # for the stored hyper-parameter config); only load trusted checkpoints.
    ckpt = torch.load(path_to_ckpt, map_location='cpu', weights_only=False)
    cfg = ckpt['hyper_parameters']
    if model_type == 'baseline':
        # Baseline checkpoints get their config layered on top of the default
        # training config (presumably to fill keys added after they were trained).
        cfg = OmegaConf.merge(OmegaConf.load('../configs/train.yaml'), cfg)

    # Disable struct mode so the overrides below may add or replace keys.
    OmegaConf.set_struct(cfg, False)
    cfg.data.nca_weights_path = f'/scratch/izar/{username}/hypernca/pretrained_nca/Flickr+DTD_NCA'
    cfg.model.type = model_type
    if model_type == 'baseline':
        cfg.model.use_cross_attention = False

    # Same construction order as before, so any RNG consumed by initialisation is unchanged.
    _, val_dataloader, normaliser = get_dataloaders(cfg)
    diffusion = get_diffusion(cfg)
    model = WeightDiffusionTransformer(cfg)
    encoder = _build_texture_encoder(cfg)

    pl_module = WeightDenoiser(
        cfg=cfg, model=model, diffusion=diffusion, normaliser=normaliser, encoder=encoder
    ).to(device)
    pl_module.load_state_dict(ckpt['state_dict'], strict=True)
    # nn.Modules start in training mode; this module is only used for inference,
    # so switch off training-time behaviour such as dropout.
    pl_module.eval()

    return pl_module, encoder, val_dataloader

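### Post-training evaluation of the models

Compare a baseline checkpoint with a comparison checkpoint. Both models predict NCA weights for one validation batch, and the weights are rendered into textures. Each texture is then scored by its OT distance to the conditioning image (lower is better).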

---

In [5]:
# W&B artifact ids (name:version) of the two checkpoints being compared.
baseline_id = 'model-sdnv2k1z:v0'
compare_id = 'model-4kkhlc27:v3'
# `download_ckpt` lets wandb choose the download root. Per the wandb docs this
# is $WANDB_ARTIFACT_DIR, or ./artifacts relative to the working directory when
# that is unset. Resolve it the same way instead of hard-coding a home directory.
artifact_dir = os.path.abspath(os.environ.get('WANDB_ARTIFACT_DIR', 'artifacts'))

In [6]:
# Baseline checkpoint; its validation dataloader is discarded in favour of the comparison run's.
baseline, base_encoder, _ = get_pl_module(
    baseline_id,
    artifact_dir=artifact_dir,
    model_type='baseline',
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Downloading large artifact model-sdnv2k1z:v0, 391.15MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:1.0


In [7]:
# Comparison checkpoint; its validation dataloader supplies the evaluation batch.
comparison_model, comp_encoder, val_dataloader = get_pl_module(
    compare_id,
    artifact_dir=artifact_dir,
    model_type='newest',
)

[34m[1mwandb[0m: Downloading large artifact model-4kkhlc27:v3, 910.55MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:2.8


Number of parameters: 151.28M


In [8]:
# Evaluation tools: the NCA generator renders textures from weight vectors,
# and the OT-based texture loss scores them on the GPU.
weights2image_gen = nca_weight_generator()
texture_loss = TextureLoss(device="cuda", loss_type="OT")



In [9]:
# Take one validation batch, encode its conditioning images with each model's
# encoder, and sample NCA weights from both models.
NUM_SAMPLING_STEPS = 50
SAMPLING_SEED = 42  # same seed for both models (presumably fixes the initial noise)

batch = next(iter(val_dataloader))
cond_images = batch['image']
weights = batch['weights']
text_names = batch['texture']
cond_images_gpu = cond_images.to('cuda')  # transfer once, reuse for both encoders

# Pure inference: disable autograd so no graph or activations are retained.
with torch.no_grad():
    base_enc_cond_images = base_encoder(cond_images_gpu)
    comp_enc_cond_images = comp_encoder(cond_images_gpu)

    base_model_nca_weights = baseline.sample(
        num_steps=NUM_SAMPLING_STEPS, cond=base_enc_cond_images, seed=SAMPLING_SEED
    )
    comp_model_nca_weights = comparison_model.sample(
        num_steps=NUM_SAMPLING_STEPS, cond=comp_enc_cond_images, seed=SAMPLING_SEED
    )

In [10]:
# Generate images from the ground truth weights
weights2image_gen.generate(weights[:100])
gt_gen_images = weights2image_gen.generated_images
weights2image_gen.generated_images = []

# Generate images from the baseline model's predicted weights
weights2image_gen.generate(base_model_nca_weights[:100])
base_pred_gen_images = weights2image_gen.generated_images
weights2image_gen.generated_images = []

# Generate images from the comparison model's predicted weights
weights2image_gen.generate(comp_model_nca_weights[:100])
comp_pred_gen_images = weights2image_gen.generated_images
weights2image_gen.generated_images = []

In [11]:
# Compute the OT distance between the ground truth and the predicted images
gt_ot_losses = texture_loss(cond_images.to('cuda')[:100], torch.stack(gt_gen_images))
base_ot_losses = texture_loss(cond_images.to('cuda')[:100], torch.stack(base_pred_gen_images))
comp_ot_losses = texture_loss(cond_images.to('cuda')[:100], torch.stack(comp_pred_gen_images))

# Compute the median OT distance for each model
gt_ot_med = np.median([ot_loss.item() for ot_loss in gt_ot_losses])
base_ot_med = np.median([ot_loss.item() for ot_loss in base_ot_losses])
comp_ot_med = np.median([ot_loss.item() for ot_loss in comp_ot_losses])

print(f'GT OT median: {gt_ot_med:.2f}, Base OT median: {base_ot_med:.2f}, Comp OT median: {comp_ot_med:.2f}')


GT OT median: 4.92, Base OT median: 8.12, Comp OT median: 8.05


In [12]:
def show_image(idx: int):
    """Plot validation sample ``idx`` next to the textures rendered for it.

    Panels, left to right:
    - the conditioning image;
    - the texture from the ground-truth NCA weights;
    - the texture from the baseline's predicted weights;
    - the texture from the comparison model's predicted weights.

    Each generated panel is titled with its OT distance to the conditioning image.
    """
    panels = [
        (cond_images[idx], 'Condition'),
        (gt_gen_images[idx], f'NCA, OT distance: {gt_ot_losses[idx].item():.2f}'),
        (base_pred_gen_images[idx], f'Baseline, OT distance: {base_ot_losses[idx].item():.2f}'),
        (comp_pred_gen_images[idx], f'Comparison, OT distance: {comp_ot_losses[idx].item():.2f}'),
    ]

    fig, axs = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4))
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(to_pil_image(image))
        ax.set_title(title)
        ax.axis('off')

    # Hard-coded description of the two checkpoints; update it when baseline_id / compare_id change.
    fig.suptitle('Baseline = Gram Encoder, Comparison = Vision Transformer, XA')

    plt.show()


widgets.interact(show_image, idx=widgets.IntSlider(min=0, max=len(gt_ot_losses) - 1, step=1, value=0));

interactive(children=(IntSlider(value=0, description='idx', max=99), Output()), _dom_classes=('widget-interact…