In [25]:
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torch

batch_size = 1

torch.manual_seed(0)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

mnist = MNIST('./data', train=True, download=True, transform=transform)
data_loader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

In [36]:
from diff_unet import Unet

img = next(iter(data_loader))[0]
t = torch.randint(0, 100, (batch_size,))

batch_size, channels, height, width = img.shape

unet = Unet(dim=height,
            channels=channels,
            dim_mults=(1,),
            resnet_block_groups=7,
            use_convnext=False)

from torchinfo import summary

summary(unet, [(batch_size, channels, height, width), (batch_size,)], depth=2)

Layer (type:depth-idx)                             Output Shape              Param #
Unet                                               [1, 1, 28, 28]            --
├─Conv2d: 1-1                                      [1, 18, 28, 28]           900
├─Sequential: 1-2                                  [1, 112]                  --
│    └─SinusoidalPositionEmbeddings: 2-1           [1, 28]                   --
│    └─Linear: 2-2                                 [1, 112]                  3,248
│    └─GELU: 2-3                                   [1, 112]                  --
│    └─Linear: 2-4                                 [1, 112]                  12,656
├─ModuleList: 1-3                                  --                        --
│    └─ModuleList: 2-5                             --                        47,376
├─ResBlock: 1-4                                    [1, 28, 28, 28]           --
│    └─Block: 2-6                                  [1, 28, 28, 28]           7,140
│    └─Sequential: 2

In [None]:
import matplotlib.pyplot as plt

out = unet(img, t)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(img[0, 0, :, :].detach().numpy());
ax[1].imshow(out[0, 0, :, :].detach().numpy());

In [None]:
# Print unet attributes
dir(unet)

In [35]:
for param in unet.parameters():
    print(param.shape)

torch.Size([18, 1, 7, 7])
torch.Size([18])
torch.Size([112, 28])
torch.Size([112])
torch.Size([112, 112])
torch.Size([112])
torch.Size([28, 112])
torch.Size([28])
torch.Size([28, 18, 3, 3])
torch.Size([28])
torch.Size([28])
torch.Size([28])
torch.Size([28, 28, 3, 3])
torch.Size([28])
torch.Size([28])
torch.Size([28])
torch.Size([28, 18, 1, 1])
torch.Size([28])
torch.Size([28, 112])
torch.Size([28])
torch.Size([28, 28, 3, 3])
torch.Size([28])
torch.Size([28])
torch.Size([28])
torch.Size([28, 28, 3, 3])
torch.Size([28])
torch.Size([28])
torch.Size([28])
torch.Size([384, 28, 1, 1])
torch.Size([28, 128, 1, 1])
torch.Size([28])
torch.Size([28])
torch.Size([28])
torch.Size([28])
torch.Size([28])
torch.Size([28, 112])
torch.Size([28])
torch.Size([28, 28, 3, 3])
torch.Size([28])
torch.Size([28])
torch.Size([28])
torch.Size([28, 28, 3, 3])
torch.Size([28])
torch.Size([28])
torch.Size([28])
torch.Size([384, 28, 1, 1])
torch.Size([28, 128, 1, 1])
torch.Size([28])
torch.Size([28])
torch.Size([28])