In [1]:
from diffusers import UNet2DConditionModel

In [2]:
# 1. Pull only the UNet architecture settings (config.json, not the weights)
#    from the InstructPix2Pix repo. Note: this is NOT the plain SD v1.5 UNet —
#    the config displayed below has `in_channels=8`. To inspect SD v1.5
#    instead, use model_id = "runwayml/stable-diffusion-v1-5".
model_id = "timbrooks/instruct-pix2pix"
config = UNet2DConditionModel.load_config(
    model_id, subfolder="unet"
)
# 2. Build a fresh UNet from that config. No pretrained weights are loaded,
#    so the parameters are freshly initialised: fine for inspecting the
#    architecture and parameter counts, not for inference.
original_unet = UNet2DConditionModel.from_config(config)

In [9]:
# Display the full UNet architecture config dict loaded above.
config

{'_class_name': 'UNet2DConditionModel',
 '_diffusers_version': '0.12.0.dev0',
 'act_fn': 'silu',
 'attention_head_dim': 8,
 'block_out_channels': [320, 640, 1280, 1280],
 'center_input_sample': False,
 'class_embed_type': None,
 'cross_attention_dim': 768,
 'down_block_types': ['CrossAttnDownBlock2D',
  'CrossAttnDownBlock2D',
  'CrossAttnDownBlock2D',
  'DownBlock2D'],
 'downsample_padding': 1,
 'dual_cross_attention': False,
 'flip_sin_to_cos': True,
 'freq_shift': 0,
 'in_channels': 8,
 'layers_per_block': 2,
 'mid_block_scale_factor': 1,
 'mid_block_type': 'UNetMidBlock2DCrossAttn',
 'norm_eps': 1e-05,
 'norm_num_groups': 32,
 'num_class_embeds': None,
 'only_cross_attention': False,
 'out_channels': 4,
 'resnet_time_scale_shift': 'default',
 'sample_size': 64,
 'up_block_types': ['UpBlock2D',
  'CrossAttnUpBlock2D',
  'CrossAttnUpBlock2D',
  'CrossAttnUpBlock2D'],
 'upcast_attention': False,
 'use_linear_projection': False}

In [5]:
# Confirm which class was instantiated from the config (this shows the
# model's type, not its size).
type(original_unet)

diffusers.models.unets.unet_2d_condition.UNet2DConditionModel

In [7]:
# Count every scalar in every parameter tensor of the original UNet
# (all parameters are included, trainable or not).
total_params = 0
for param in original_unet.parameters():
    total_params += param.numel()

# Thousands separators keep the large count readable.
print(f"Total number of parameters: {total_params:,}")

Total number of parameters: 859,532,484


In [13]:
# Build a slimmer variant of the same architecture to compare model sizes.
# Only these two width settings change; block types, `layers_per_block`,
# `in_channels`, etc. are all inherited from `config`.
CUSTOM_CROSS_ATTENTION_DIM = 256                   # original config: 768
CUSTOM_BLOCK_OUT_CHANNELS = (128, 256, 512, 512)   # original config: [320, 640, 1280, 1280]

# Build a new dict instead of editing in place, so `config` is never touched
# and the cell is safe to re-run.
custom_config = {
    **config,
    "cross_attention_dim": CUSTOM_CROSS_ATTENTION_DIM,
    "block_out_channels": CUSTOM_BLOCK_OUT_CHANNELS,
}
# NOTE: with cross_attention_dim=256 this UNet expects 256-dim
# `encoder_hidden_states`. It therefore cannot directly consume the 768-dim
# conditioning the original config was sized for; an extra projection (or a
# different text encoder) would be needed.

# Sanity check before building: the UNet's GroupNorm layers need every
# channel width to be divisible by `norm_num_groups`.
num_groups = custom_config["norm_num_groups"]
assert all(width % num_groups == 0 for width in CUSTOM_BLOCK_OUT_CHANNELS), (
    f"block_out_channels {CUSTOM_BLOCK_OUT_CHANNELS} must all be divisible "
    f"by norm_num_groups={num_groups}"
)

custom_unet = UNet2DConditionModel.from_config(custom_config)

# Show custom model size (same counting method as for the original UNet)
custom_total_params = sum(p.numel() for p in custom_unet.parameters())
print(f"Custom model size: {custom_total_params:,}")

# Show the difference
print(f"Difference: {total_params - custom_total_params:,}")

Custom model size: 137,128,836
Difference: 722,403,648
