In [None]:

import os
import csv
import time
import copy
import json
import pickle
import random
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
import dnnlib
import numpy as np
import torch
from torch import autocast
from torch_utils import distributed as dist
from torch_utils import training_stats
from torch_utils import misc
from solver_utils import get_schedule
from models.ldm.util import instantiate_from_config
from torch_utils.download_util import check_file_by_key


def create_model(dataset_name=None, model_path=None, guidance_type=None, guidance_rate=None, device=None, is_second_stage=False, num_repeats=4):
    """Load a pre-trained diffusion model and, for EDM datasets, a student copy.

    Args:
        dataset_name: One of 'cifar10', 'ffhq', 'afhqv2', 'imagenet64' (EDM
            checkpoints) or 'lsun_bedroom_ldm', 'ffhq_ldm', 'ms_coco' (LDM /
            Stable Diffusion checkpoints).
        model_path: Checkpoint location. Required when ``is_second_stage`` is
            true; otherwise resolved through ``check_file_by_key`` when None.
        guidance_type: Must be 'uncond' for the two LDM datasets and 'cfg' for
            'ms_coco' (checked with ``assert``); not read for EDM datasets.
        guidance_rate: Forwarded to ``CFGPrecond`` for 'ms_coco' only.
        device: Device the loaded networks are moved to.
        is_second_stage: When true, only the pickled 'model' entry is loaded
            and returned.
        num_repeats: Passed as ``repeat=`` when constructing the EDM student.

    Returns:
        ``(net, model_source)`` when ``is_second_stage`` is true, otherwise
        ``(net, net_student, model_source)``. ``net_student`` stays None for
        the LDM datasets. ``model_source`` is 'edm' or 'ldm'.

    Raises:
        AssertionError: ``model_path`` is missing in second-stage mode, or
            ``guidance_type`` does not match the selected LDM dataset.
        ValueError: ``dataset_name`` is not one of the supported names.
    """
    print("Function Parameters:")
    # Snapshot the mapping before iterating: the loop binds new locals, and a
    # live locals() mapping may be refreshed mid-iteration (e.g. under a tracer).
    for param_name, param_value in tuple(locals().items()):
        print(f"{param_name}: {param_value}")

    # Single source of truth for the datasets served by EDM checkpoints.
    edm_datasets = ('cifar10', 'ffhq', 'afhqv2', 'imagenet64')

    net_student = None
    if is_second_stage: # for second-stage distillation
        assert model_path is not None
        dist.print0(f'Loading the second-stage teacher model from "{model_path}"...')
        with dnnlib.util.open_url(model_path, verbose=(dist.get_rank() == 0)) as f:
            # SECURITY: pickle.load can execute arbitrary code; only use trusted checkpoints.
            net = pickle.load(f)['model'].to(device)
        model_source = 'edm' if dataset_name in edm_datasets else 'ldm'
        # NOTE: this path returns a 2-tuple, unlike the 3-tuple returned below.
        return net, model_source

    if model_path is None:
        model_path, _ = check_file_by_key(dataset_name)
    dist.print0(f'Loading the pre-trained diffusion model from "{model_path}"...')
    if dataset_name in edm_datasets:         # models from EDM
        with dnnlib.util.open_url(model_path, verbose=(dist.get_rank() == 0)) as f:
            # SECURITY: pickle.load can execute arbitrary code; only use trusted checkpoints.
            net_temp = pickle.load(f)['ema'].to(device)
        network_kwargs = dnnlib.EasyDict()
        if dataset_name in ['cifar10']:
            network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard')
            network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[2,2,2])
            network_kwargs.update(dropout=0.13, use_fp16=False)
            network_kwargs.augment_dim = 9
            interface_kwargs = dict(img_resolution=32, img_channels=3, label_dim=0)
        elif dataset_name in ['ffhq', 'afhqv2']:
            network_kwargs.update(model_type='SongUNet', embedding_type='positional', encoder_type='standard', decoder_type='standard')
            network_kwargs.update(channel_mult_noise=1, resample_filter=[1,1], model_channels=128, channel_mult=[1,2,2,2])
            network_kwargs.update(dropout=0.05, use_fp16=False)
            network_kwargs.augment_dim = 9
            interface_kwargs = dict(img_resolution=64, img_channels=3, label_dim=0)
        else:
            network_kwargs.update(model_type='DhariwalUNet', model_channels=192, channel_mult=[1,2,3,4])
            interface_kwargs = dict(img_resolution=64, img_channels=3, label_dim=1000)

        network_kwargs.class_name = 'models.networks_edm.EDMPrecond'
        net = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs) # subclass of torch.nn.Module
        net.to(device)
        # strict=False: keys that do not match between checkpoint and network are skipped.
        net.load_state_dict(net_temp.state_dict(), strict=False)

        # Save the parameter names to a text file for inspection. Only rank 0
        # writes, so parallel processes do not rewrite the same file concurrently.
        if dist.get_rank() == 0:
            with open("model_keys.txt", "w") as keys_file:
                keys_file.writelines(f"{key}\n" for key in net.model.state_dict().keys())

        # The student uses the same configuration plus `repeat=num_repeats`.
        network_kwargs.update(repeat=num_repeats)
        net_student = dnnlib.util.construct_class_by_name(**network_kwargs, **interface_kwargs)
        net_student.to(device)
        net_student.load_state_dict(net_temp.state_dict(), strict=False)

        del net_temp

        net.sigma_min = 0.006
        net.sigma_max = 80.0
        net_student.sigma_min = 0.006
        net_student.sigma_max = 80.0
        model_source = 'edm'
    elif dataset_name in ['lsun_bedroom_ldm', 'ffhq_ldm', 'ms_coco']:   # models from LDM
        from omegaconf import OmegaConf
        from models.networks_edm import CFGPrecond
        if dataset_name in ['lsun_bedroom_ldm']:
            assert guidance_type == 'uncond'
            config = OmegaConf.load('./models/ldm/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml')
            net = load_ldm_model(config, model_path)
            net = CFGPrecond(net, img_resolution=64, img_channels=3, guidance_rate=1., guidance_type='uncond', label_dim=0).to(device)
            net.sigma_min = 0.006
        elif dataset_name in ['ffhq_ldm']:
            assert guidance_type == 'uncond'
            config = OmegaConf.load('./models/ldm/configs/latent-diffusion/ffhq-ldm-vq-4.yaml')
            net = load_ldm_model(config, model_path)
            net = CFGPrecond(net, img_resolution=64, img_channels=3, guidance_rate=1., guidance_type='uncond', label_dim=0).to(device)
            net.sigma_min = 0.006
        elif dataset_name in ['ms_coco']:
            assert guidance_type == 'cfg'
            config = OmegaConf.load('./models/ldm/configs/stable-diffusion/v1-inference.yaml')
            net = load_ldm_model(config, model_path)
            # NOTE(review): label_dim is passed as the boolean True here — confirm CFGPrecond expects that.
            net = CFGPrecond(net, img_resolution=64, img_channels=4, guidance_rate=guidance_rate, guidance_type='classifier-free', label_dim=True).to(device)
            net.sigma_min = 0.1
        model_source = 'ldm'
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name}")

    # net_student is still None here for the LDM datasets.
    return net, net_student, model_source



# First-stage call: create_model returns (net, net_student, model_source), so
# `net_copy` receives the original network and `net` the student built with `repeat=`.
net_copy, net, model_source = create_model(
    dataset_name="cifar10",
    model_path=None,
    guidance_type=None,
    guidance_rate=0.0,
    device="cpu",
    is_second_stage=False,
)


Function Parameters:
dataset_name: cifar10
model_path: None
guidance_type: None
guidance_rate: 0.0
device: cpu
is_second_stage: False
num_repeats: 4
Model already exists: ../training/src/cifar10/edm-cifar10-32x32-uncond-vp.pkl
Loading the pre-trained diffusion model from "../training/src/cifar10/edm-cifar10-32x32-uncond-vp.pkl"...
model.dec.32x32_aux_conv LASTLASYER
calling tilingg
calling tilingg


In [19]:
# Call .eval() on both networks; the tuple of returned modules is the cell's displayed value.
(net_copy.eval(), net.eval())

(EDMPrecond(
   (model): SongUNet(
     (map_noise): PositionalEmbedding()
     (map_augment): Linear()
     (map_layer0): Linear()
     (map_layer1): Linear()
     (map_step): PositionalEmbedding()
     (map_step_layer0): Linear()
     (map_step_layer1): Linear()
     (enc): ModuleDict(
       (32x32_conv): Conv2d()
       (32x32_block0): UNetBlock(
         (norm0): GroupNorm()
         (conv0): Conv2d()
         (affine): Linear()
         (norm1): GroupNorm()
         (conv1): Conv2d()
         (affine_step): Linear()
         (skip): Conv2d()
       )
       (32x32_block1): UNetBlock(
         (norm0): GroupNorm()
         (conv0): Conv2d()
         (affine): Linear()
         (norm1): GroupNorm()
         (conv1): Conv2d()
         (affine_step): Linear()
       )
       (32x32_block2): UNetBlock(
         (norm0): GroupNorm()
         (conv0): Conv2d()
         (affine): Linear()
         (norm1): GroupNorm()
         (conv1): Conv2d()
         (affine_step): Linear()
       )
 

In [20]:
import torch
def get_denoised(net, x, t, class_labels=None, condition=None, unconditional_condition=None, step_condition=None):
    """Evaluate ``net`` at ``(x, t)`` and return whatever it returns.

    Networks exposing ``guidance_type`` (models from LDM and Stable Diffusion),
    either directly or on ``net.module`` when wrapped for training, also
    receive ``condition`` and ``unconditional_condition``. Every other network
    is called with ``class_labels`` and ``step_condition`` only.

    Args:
        net: Callable network.
        x: First positional input forwarded to ``net``.
        t: Second positional input forwarded to ``net``.
        class_labels: Forwarded as ``class_labels=``.
        condition: Forwarded only to networks with ``guidance_type``.
        unconditional_condition: Forwarded only to networks with ``guidance_type``.
        step_condition: Forwarded as ``step_condition=``.

    Returns:
        The value returned by ``net``.
    """
    uses_guidance = hasattr(net, 'guidance_type') or (
        hasattr(net, 'module') and hasattr(net.module, 'guidance_type')
    )
    if uses_guidance:
        return net(
            x,
            t,
            class_labels=class_labels,
            condition=condition,
            unconditional_condition=unconditional_condition,
            step_condition=step_condition,
        )
    # No debug print here: this function runs once per forward pass.
    return net(x, t, class_labels=class_labels, step_condition=step_condition)

# Model config
img_resolution = 32
img_channels = 3
# NOTE(review): create_model builds the cifar10 network with label_dim=0 — confirm these labels are meant to be ignored.
label_dim = 10


# Dummy input
B = 2
# A dedicated, seeded generator makes the dummy batch identical after every kernel
# restart without changing the global RNG state.
DUMMY_SEED = 0
dummy_rng = torch.Generator().manual_seed(DUMMY_SEED)
x = torch.randn(B, img_channels, img_resolution, img_resolution, generator=dummy_rng)
sigma = torch.full((B,), 1.0)  # constant noise level
class_labels = torch.zeros(B, label_dim)  # zero dummy labels
# step_condition = torch.full((B,), 7.0)    # constant step for conditioning
step_condition = None


In [21]:
# Keyword arguments shared by every forward pass in this cell.
shared_kwargs = dict(class_labels=class_labels, step_condition=step_condition)

# One pass through `net` ...
out_stu_1 = get_denoised(net, x, sigma, **shared_kwargs)
print("Output shape:", out_stu_1.shape)

# ... then three passes through `net_copy` on the same inputs.
out_teacher_1, out_teacher_2, out_teacher_3 = (
    get_denoised(net_copy, x, sigma, **shared_kwargs) for _ in range(3)
)
print("Output shape:", out_teacher_1.shape)

calling


Output shape: torch.Size([2, 12, 32, 32])
calling
calling
calling
Output shape: torch.Size([2, 3, 32, 32])


In [22]:
# Mean absolute difference between two passes of the same network on the same input.
torch.mean(torch.abs(out_teacher_1 - out_teacher_2))

tensor(0., grad_fn=<MeanBackward0>)

In [23]:
# Compare each group of student output channels with the teacher output.
# The absolute difference is averaged so that positive and negative errors
# cannot cancel out (a signed mean can be zero for outputs that differ).
teacher_channels = out_teacher_1.shape[1]
for start in range(0, out_stu_1.shape[1], teacher_channels):
    student_group = out_stu_1[:, start:start + teacher_channels]
    print((student_group - out_teacher_1).abs().mean())

tensor(0., grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<MeanBackward0>)
tensor(0., grad_fn=<MeanBackward0>)


In [15]:
model = net.model
model_orig = net_copy.model


# Get last layer prefix (e.g., 'dec.32x32_aux_conv')
last_layer_prefix = model.last_layer.replace('model.', '')

# Build the name -> parameter lookup once instead of once per matching parameter.
orig_params = dict(model_orig.named_parameters())

# Collect weights and biases
# `orig` and `tiled` stay bound to the last matching parameter after the loop.
for name, param in model.named_parameters():
    if name.startswith(last_layer_prefix):
        orig = orig_params[name]
        tiled = param
        print(f"\nAssigned:\n  orig -> {name} ({orig.shape})\n  tiled -> {name} ({tiled.shape})")



Assigned:
  orig -> dec.32x32_aux_conv.weight (torch.Size([3, 256, 3, 3]))
  tiled -> dec.32x32_aux_conv.weight (torch.Size([3, 256, 3, 3]))

Assigned:
  orig -> dec.32x32_aux_conv.bias (torch.Size([3]))
  tiled -> dec.32x32_aux_conv.bias (torch.Size([3]))


In [5]:
# Mean absolute difference between the leading slice of `tiled` and `orig`;
# a signed mean could be zero even when the two differ.
(tiled[:orig.shape[0]] - orig).abs().mean()

tensor(0., grad_fn=<MeanBackward0>)

In [None]:
import torch
# Random reference tensor and a copy repeated 4x along the first dimension.
base_shape = (3, 256, 32, 32)
repeat_counts = (4, 1, 1, 1)
a = torch.rand(base_shape)
b = a.repeat(*repeat_counts)
b.shape

In [None]:
# `a.repeat(4, 1, 1, 1)` lays copies of `a` at b[0:3], b[3:6], b[6:9], b[9:12].
# Compare the second copy; b[4:7] would straddle two copies and not line up with `a`.
(b[3:6] - a).abs().mean()

In [None]:
print(b.shape)         # torch.Size([12, 256, 32, 32])
# Compare the whole first copy with `a`; `b[0] == a` would broadcast the single
# slice b[0] against every slice of `a`.
print((b[:a.shape[0]] == a).all())  # Should be True




UnpicklingError: Weights only load failed. In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
Please file an issue with the following so that we can make `weights_only=True` compatible with your use case: WeightsUnpickler error: Unsupported operand 149

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.