In [1]:
# Enable inline plotting for matplotlib
%matplotlib inline

import os
import torch
import torchvision
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from einops import rearrange
from sklearn.metrics import pairwise_distances
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
import argparse  # We won't use command-line parsing but keep for reference

# Import your custom modules; make sure your PYTHONPATH is set correctly so that these modules are found.
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from models import DiT_models, get_2d_sincos_pos_embed
from datasets import MET 


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def imshow_tensor(img_tensor, title=""):
    """
    Render a channel-first image tensor in a 4x4 inch matplotlib figure.

    The tensor is assumed to have been normalized with mean=0.5 and std=0.5;
    that normalization is undone before display.

    Args:
        img_tensor: image tensor whose first axis is the channel axis.
        title: text shown above the image.
    """
    # Move channels last (what imshow expects) and undo the normalization.
    pixels = (img_tensor.permute(1, 2, 0) * 0.5 + 0.5).cpu().numpy()
    # Clamp to the valid display range so imshow does not warn or rescale.
    pixels = np.clip(pixels, 0, 1)

    _, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(pixels)
    ax.set_title(title)
    ax.axis("off")
    plt.show()


def center_crop_arr(pil_image, image_size):
    """
    Resize a PIL image so its shorter side equals ``image_size`` and return
    the central ``image_size`` x ``image_size`` square (ADM-inspired).

    Args:
        pil_image: source PIL image.
        image_size: side length, in pixels, of the returned square image.

    Returns:
        A new PIL image of size ``image_size`` x ``image_size``.
    """
    # Stage 1: halve with a box filter while the short side is still at
    # least twice the target size.
    double_target = 2 * image_size
    while min(pil_image.size) >= double_target:
        width, height = pil_image.size
        pil_image = pil_image.resize((width // 2, height // 2), resample=Image.BOX)

    # Stage 2: bicubic resize so the short side lands on the target size.
    width, height = pil_image.size
    factor = image_size / min(width, height)
    pil_image = pil_image.resize(
        (round(width * factor), round(height * factor)), resample=Image.BICUBIC
    )

    # Stage 3: cut the centred square out of the pixel array.
    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    return Image.fromarray(pixels[top:top + image_size, left:left + image_size])


def find_permutation(distance_matrix):
    """
    Greedily assign one row to every column of a distance matrix.

    Columns are visited from left to right. Each column takes the row with the
    smallest distance in that column among the rows not already taken by an
    earlier column.

    Args:
        distance_matrix: 2-D array-like of distances, indexed [row, column].
            It is not modified.

    Returns:
        A list with one row index per column, in column order.
    """
    # Work on a private float copy: the caller's array must stay intact, and a
    # float dtype is needed so that used rows can be masked with +inf.
    remaining = np.array(distance_matrix, dtype=float)

    sort_list = []
    for column in range(remaining.shape[1]):
        order = remaining[:, column].argmin()
        sort_list.append(order)
        # +inf (rather than a finite magic number) guarantees this row can
        # never beat an unused row in a later column, whatever the distances.
        remaining[order, :] = np.inf
    return sort_list


In [None]:
def _show_patch_grid(patch_tensor, filename, title):
    """Save the nine patches of the first sample as a 3x3 grid image and display it."""
    # patch_tensor is indexed [sample, channel, patch, height, width]; only
    # sample 0 is visualised.
    grid = torch.stack([patch_tensor[0, :, i, :, :] for i in range(9)])
    save_image(grid, filename, nrow=3, normalize=True)
    plt.figure(figsize=(4, 4))
    plt.imshow(torchvision.utils.make_grid(grid, nrow=3, normalize=True).permute(1, 2, 0).cpu().numpy())
    plt.title(title)
    plt.axis("off")
    plt.show()


def main(args):
    """
    Run a single jigsaw-reconstruction demo: load the model, shuffle the 3x3
    patches of one image, sample patch positions with the diffusion process
    and report whether the predicted permutation matches the true one.

    Only the first batch of the dataset is processed. Debug images are written
    to ``debug_patches_before.png`` and ``debug_patches_after.png`` in the
    current working directory. Gradient computation is disabled globally.

    Args:
        args: namespace providing ``seed``, ``model``, ``image_size``,
            ``ckpt``, ``num_sampling_steps``, ``dataset`` ("met" or
            "imagenet"), ``data_path`` and ``crop``.

    Raises:
        ValueError: if ``args.dataset`` is neither "met" nor "imagenet".
    """
    # Fail fast on an unsupported dataset instead of hitting a NameError on
    # `dataset` after the (slow) model load.
    if args.dataset not in ("met", "imagenet"):
        raise ValueError(
            f"Unsupported dataset {args.dataset!r}; expected 'met' or 'imagenet'."
        )

    # Set up PyTorch
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The patch shuffle is drawn with NumPy, which torch.manual_seed does not
    # cover; a dedicated seeded generator makes the shuffle reproducible
    # without touching NumPy's global random state.
    shuffle_rng = np.random.default_rng(args.seed)

    # Load model:
    model = DiT_models[args.model](
        input_size=args.image_size,
    ).to(device)
    print("Loading model from:", args.ckpt)
    ckpt_path = args.ckpt
    model_dict = model.state_dict()
    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # execute code embedded in the file. Only load checkpoints you trust.
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state_dict = torch.load(ckpt_path, map_location=device, weights_only=False)
    model_state_dict = state_dict['model']
    # Keep only the checkpoint entries whose names exist in the model.
    pretrained_dict = {k: v for k, v in model_state_dict.items() if k in model_dict}
    load_result = model.load_state_dict(pretrained_dict, strict=False)
    if load_result.missing_keys:
        # strict=False would otherwise hide parameters left at their initial values.
        print(
            "Warning:", len(load_result.missing_keys),
            "model parameters were not found in the checkpoint (first 10):",
            load_result.missing_keys[:10],
        )

    print("Model keys (first 10):", list(model_dict.keys())[:10])
    print("Checkpoint keys:", list(model_state_dict.keys()))

    # Side length of one puzzle piece in the image fed to the model.
    patch_size = args.image_size // 3
    crop_enabled = args.dataset == 'imagenet' and args.crop
    # With cropping, the image is split into 3x3 pieces of 96 px and each piece
    # is centre-cropped to 64 px, so the image has to be loaded at 3 * 96 px.
    # Without cropping it is loaded directly at the 3 * patch_size px that the
    # patch rearrangement below expects.
    load_size = 3 * 96 if crop_enabled else 3 * patch_size

    # Define the transformation for the dataset.
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, load_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])

    # Set the model to training mode so batchnorm behaves (needed for batch size 1).
    model.train()

    # Create the diffusion process.
    diffusion = create_diffusion(str(args.num_sampling_steps))

    # Choose dataset based on argument (validated at the top of the function).
    if args.dataset == "met":
        dataset = MET(args.data_path, 'test')
    else:
        dataset = ImageFolder(args.data_path, transform=transform)

    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        pin_memory=False,
        drop_last=True
    )

    # Reference position embeddings that predictions are matched against.
    time_emb = torch.tensor(get_2d_sincos_pos_embed(8, 3)).unsqueeze(0).float().to(device)
    # The second table only supplies shape/dtype/device: it is immediately
    # replaced by Gaussian noise, which is the starting point for sampling.
    time_emb_noise = torch.tensor(get_2d_sincos_pos_embed(8, 12)).unsqueeze(0).float().to(device)
    time_emb_noise = torch.randn_like(time_emb_noise).repeat(1, 1, 1)
    model_kwargs = None

    abs_results = []
    for x in loader:
        # For ImageFolder datasets, unpack the (image, label) tuple.
        if args.dataset == 'imagenet':
            x, _ = x
        x = x.to(device)
        imshow_tensor(x[0], title="Original Image")

        # Optionally centre-crop every piece from 96 px to 64 px.
        if crop_enabled:
            centercrop = transforms.CenterCrop((64, 64))
            patchs = rearrange(x, 'b c (p1 h1) (p2 w1) -> b c (p1 p2) h1 w1',
                               p1=3, p2=3, h1=96, w1=96)
            patchs = centercrop(patchs)
            x = rearrange(patchs, 'b c (p1 p2) h1 w1 -> b c (p1 h1) (p2 w1)',
                          p1=3, p2=3, h1=64, w1=64)

        # Draw the shuffle: piece i of the scrambled image is original piece indices[i].
        indices = shuffle_rng.permutation(9)
        print("Shuffle indices:", indices)
        # Rearrange image into patches
        x = rearrange(x, 'b c (p1 h1) (p2 w1) -> b c (p1 p2) h1 w1',
                      p1=3, p2=3, h1=patch_size, w1=patch_size)

        # Display patches before permutation
        _show_patch_grid(x, "debug_patches_before.png", "Patches Before Permutation")

        # Apply the permutation to the patches
        x = x[:, :, indices, :, :]
        _show_patch_grid(x, "debug_patches_after.png", "Patches After Permutation")

        # Reconstruct the scrambled image.
        x = rearrange(x, 'b c (p1 p2) h1 w1 -> b c (p1 h1) (p2 w1)',
                      p1=3, p2=3, h1=patch_size, w1=patch_size)
        imshow_tensor(x[0], title="Final Scrambled Image")
        print("Scrambled image shape:", x.shape)

        # Generate samples using the diffusion process.
        samples = diffusion.p_sample_loop(
            model.forward, x, time_emb_noise.shape, time_emb_noise,
            clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
        )
        print("Generated samples shape:", samples.shape)

        # Process and compare each sample.
        for sample, _ in zip(samples, x):
            # Group the predicted tokens by puzzle piece and average them into
            # one embedding per piece.
            sample = rearrange(sample, '(p1 h1 p2 w1) d -> (p1 p2) (h1 w1) d',
                               p1=3, p2=3, h1=args.image_size//48, w1=args.image_size//48)
            sample = sample.mean(1)
            # Rows are scrambled pieces, columns are reference positions.
            dist = pairwise_distances(sample.cpu().numpy(), time_emb[0].cpu().numpy(), metric='manhattan')
            order = find_permutation(dist)
            # `order` maps position -> piece; argsort inverts it to piece ->
            # position, which is directly comparable with `indices`.
            pred = np.asarray(order).argsort()
            print("Predicted order:", pred)
            abs_results.append(int((pred == indices).all()))

        # Report accuracy on this batch.
        acc = np.asarray(abs_results).sum() / len(abs_results) if abs_results else 0
        print("Test result on", len(abs_results), "samples:", acc)
        # For demonstration, process only one batch.
        break


In [5]:
# Run configuration. Instead of using argparse, the arguments are collected
# directly in a SimpleNamespace.
from types import SimpleNamespace

# Machine-specific locations. Set the JPDVT_DATA_ROOT / JPDVT_CKPT environment
# variables to run elsewhere; the defaults are the original cluster paths.
base_data_path = os.environ.get(
    "JPDVT_DATA_ROOT", "/cluster/home/muhamhz/data/imagenet/"
)
ckpt_path = os.environ.get(
    "JPDVT_CKPT",
    "/cluster/home/muhamhz/JPDVT/image_model/results/009-imagenet-JPDVT-crop/checkpoints/2850000.pt",
)

# NOTE(review): the default checkpoint directory name ends in "-crop" while
# crop=False below — confirm this matches how the model was trained.
args = SimpleNamespace(
    model="JPDVT",
    dataset="imagenet",
    # "val" is the subfolder inside the base data directory.
    data_path=os.path.join(base_data_path, "val"),
    crop=False,
    image_size=192,
    num_sampling_steps=250,
    seed=0,
    ckpt=ckpt_path,
)

print("Using data path:", args.data_path)
print("Using checkpoint:", args.ckpt)

# Run the main function.
main(args)


Using data path: /cluster/home/muhamhz/data/imagenet/val
Using checkpoint: /cluster/home/muhamhz/JPDVT/image_model/results/009-imagenet-JPDVT-crop/checkpoints/2850000.pt
Loading model from: /cluster/home/muhamhz/JPDVT/image_model/results/009-imagenet-JPDVT-crop/checkpoints/2850000.pt


UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL argparse.Namespace was not an allowed global by default. Please use `torch.serialization.add_safe_globals([Namespace])` or the `torch.serialization.safe_globals([Namespace])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.