# Exploring a U-Net

In [4]:
from diffusers import UNet2DModel
import torch

def explore_unet():
    """Load the pretrained DDPM CIFAR-10 U-Net and print an overview of it.

    Prints, in order:
      1. the name and class of every top-level child module,
      2. the class of each encoder (down) and decoder (up) block,
      3. the attention modules of the mid block (or 'None' if it has none),
      4. the output shape of a forward pass on a random 32x32 input.

    Returns:
        None. Everything is reported via ``print``.
    """
    # Load a U-Net model from Hugging Face Diffusers library
    unet = UNet2DModel.from_pretrained("google/ddpm-cifar10-32")

    # Tip: `print(unet)` shows the full nested module tree if more detail is needed.

    # Explore the model components
    print("\nModel Components:\n")
    for name, module in unet.named_children():
        print(f"{name}: {module.__class__.__name__}")

    # Inspect the encoder and decoder
    print("\nEncoder and Decoder Blocks:\n")
    for idx, block in enumerate(unet.down_blocks):
        print(f"Encoder Block {idx}: {block.__class__.__name__}")
    for idx, block in enumerate(unet.up_blocks):
        print(f"Decoder Block {idx}: {block.__class__.__name__}")

    # Check the attention blocks
    print("\nAttention Blocks:\n")
    if hasattr(unet, "mid_block"):
        # Fall back to the literal string 'None' when the mid block exposes no attentions.
        mid_attentions = getattr(unet.mid_block, "attentions", "None")
        print(f"Mid Block Attention: {mid_attentions}")

    # Optionally, pass a dummy input through the model.
    # Read the channel count from `unet.config`: accessing config values directly
    # on the model (`unet.in_channels`) is deprecated and emits a FutureWarning.
    sample_input = torch.randn(1, unet.config.in_channels, 32, 32)  # Batch size 1, channels, height, width
    timestep = torch.tensor([10])  # Example timestep
    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        output = unet(sample_input, timestep).sample

    print("\nOutput Shape:", output.shape)

In [5]:
# Run the exploration; the function prints its findings and returns nothing.
explore_unet()

Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.



Model Components:

conv_in: Conv2d
time_proj: Timesteps
time_embedding: TimestepEmbedding
down_blocks: ModuleList
up_blocks: ModuleList
mid_block: UNetMidBlock2D
conv_norm_out: GroupNorm
conv_act: SiLU
conv_out: Conv2d

Encoder and Decoder Blocks:

Encoder Block 0: DownBlock2D
Encoder Block 1: AttnDownBlock2D
Encoder Block 2: DownBlock2D
Encoder Block 3: DownBlock2D
Decoder Block 0: UpBlock2D
Decoder Block 1: UpBlock2D
Decoder Block 2: AttnUpBlock2D
Decoder Block 3: UpBlock2D

Attention Blocks:

Mid Block Attention: ModuleList(
  (0): Attention(
    (group_norm): GroupNorm(32, 256, eps=1e-06, affine=True)
    (to_q): Linear(in_features=256, out_features=256, bias=True)
    (to_k): Linear(in_features=256, out_features=256, bias=True)
    (to_v): Linear(in_features=256, out_features=256, bias=True)
    (to_out): ModuleList(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): Dropout(p=0.0, inplace=False)
    )
  )
)


  sample_input = torch.randn(1, unet.in_channels, 32, 32)  # Batch size 1, channels, height, width



Output Shape: torch.Size([1, 3, 32, 32])
