In [4]:
!pip install torchinfo
import logging
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torchvista
import denoising_diffusion_pytorch as ddp
from torchinfo import summary

from edm2.training.networks_edm2 import Precond
from edm2.training.networks_edm2 import UNet as EDM2_UNet

import end_to_end_phantom_QPAT.utils.networks as e2eQPAT_networks
import utility_functions as uf
from epoch_steps import *
from nn_modules.time_conditioned_residual_unet import TimeConditionedResUNet
from nn_modules.DiT import DiT
from nn_modules.swin_unet import SwinTransformerSys



In [19]:
# arguments (from uf.get_config() in run_model.py)
# Every model_name handled by the model-building cell below.
MODEL_OPTIONS = ('UNet_e2eQPAT', 'EDM2', 'UNet_wl_pos_emb', 'UNet_diffusion_ablation', 'Swin_UNet', 'DiT', 'DDIM')
model_name = 'UNet_diffusion_ablation'
# fail here, at configuration time, rather than deep inside the model cell
assert model_name in MODEL_OPTIONS, f'unknown model_name {model_name!r}; choose from {MODEL_OPTIONS}'
image_size = 288          # height and width of the (square) dummy inputs
channels = 1              # number of input image channels
predict_fluence = True    # when True the networks output channels * 2 maps instead of channels
attention = False         # when True the EDM2-based UNets use attention at resolutions [16, 8]
use_torchsummary = True   # print a torchinfo `summary` of the selected model
use_torchvista = False # torchvista currently does not support some layers
col_names = ["input_size", "output_size", "num_params", "kernel_size"]  # torchinfo summary columns

In [None]:
logging.basicConfig(level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S')
device = torch.device('cpu')
logging.info(f'using device: {device}')

# ==================== Data ====================

match model_name:
    case 'UNet_e2eQPAT' | 'UNet_wl_pos_emb' | 'UNet_diffusion_ablation' | 'Swin_UNet':
        example_input = torch.ones((8, 1, image_size, image_size))
    case 'DDIM' | 'DiT':
        raise NotImplementedError
    case 'EDM2':
        example_input = torch.ones((8, 2, image_size, image_size))
        x_cond = torch.ones((8, 1, image_size, image_size))
        t = torch.ones((8, 1))
        wavelengths_one_hot = torch.ones((8, 1000))

# ==================== Model ====================
channels = 1
out_channels = channels * 2 if predict_fluence else channels
match model_name:
    case 'UNet_e2eQPAT':
        model = e2eQPAT_networks.RegressionUNet(
            in_channels=channels, 
            out_channels=out_channels,
            initial_filter_size=64, 
            kernel_size=3
        )
        if use_torchsummary:
            summary(model, input_size=(8, channels, image_size, image_size), device='cpu', verbose=1, col_names=col_names)
        if use_torchvista:
            torchvista.trace_model(model, example_input)
    case 'UNet_wl_pos_emb' | 'UNet_diffusion_ablation':
        model = EDM2_UNet(
            img_resolution=image_size,
            img_channels_in=channels,
            img_channels_out=out_channels,
            label_dim=0,
            model_channels=64,
            attn_resolutions=[16, 8] if attention else [],
            noise_emb=False,
            num_blocks=1,
            channel_mult=[1,2,4,8,16],
        )
        if use_torchsummary:
            summary(model, input_size=(32, channels, image_size, image_size), device='cpu', verbose=1, col_names=col_names)
        if use_torchvista:
            torchvista.trace_model(model, example_input)
    case 'Swin_UNet':
        model = SwinTransformerSys(
            img_size=image_size[0], patch_size=4, in_chans=channels, num_classes=out_channels,
            embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],
            window_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
            drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
            norm_layer=nn.LayerNorm, ape=False, patch_norm=False,
            final_upsample="expand_first"
        )
        uf.remove_softmax(model)
    case 'DDIM':
        model = ddp.Unet(
            dim=32, channels=out_channels, out_dim=out_channels,
            self_condition=True, image_condition=True, 
            image_condition_channels=channels, use_attn=attention,
            full_attn=False, flash_attn=False
        )
        #model = TimeConditionedResUNet(
        #    dim_in=out_channels, dim_out=out_channels, dim_first_layer=64,
        #    kernel_size=3, theta_pos_emb=10000, self_condition=args.self_condition,
        #    image_condition=True, dim_image_condition=channels
        #)
        diffusion = ddp.GaussianDiffusion(
            # objecive='pred_v' predicts the velocity field, objective='pred_noise' predicts the noise
            model, image_size=image_size, timesteps=1000,
            sampling_timesteps=100, objective='pred_v', auto_normalize=False,
        )
    case 'DiT':
        # parameters depth=12, hidden_size=384, and num_heads=6 are the same as DiT-S/8.
        # with an image size of 256 and patch size of 16, we have the 
        # same number of patches as ViT from an image is worth 16x16 words
        #if image_size[0] % 16 != 0:
        #    raise ValueError('image size must be divisible by 16 for DiT model')
        #patch_size = image_size[0] // 16
        patch_size = 4
        model = DiT(
            dim_in=out_channels, dim_out=out_channels, input_size=image_size, 
            depth=12, hidden_size=384, patch_size=patch_size, num_heads=6,
            self_condition=True, image_condition=True
        )
        diffusion = ddp.GaussianDiffusion(
            # objecive='pred_v' predicts the velocity field, objective='pred_noise' predicts the noise
            model, image_size=image_size, timesteps=1000,
            sampling_timesteps=100, objective='pred_v', auto_normalize=False,
        )
    case 'EDM2':
        attn_resolutions = [16, 8] if attention else []
        in_channels = out_channels+1 # plus 1 for conditional information
        loss_fn = EDM2Loss(P_mean=-0.8, P_std=1.6, sigma_data=0.5)
        model = Precond(
            img_resolution=image_size, img_channels_in=in_channels, img_channels_out=out_channels, #img_channels_in=in_channels, img_channels_out=out_channels,
            label_dim=1000, model_channels=64, attn_resolutions=attn_resolutions, 
            use_fp16=False, sigma_data=0.5
        )
        model.unet.forward = partial(model.unet.forward, noise_labels=t.flatten(), class_labels=wavelengths_one_hot)
        if use_torchsummary:
            summary(model.unet, input_size=((8, in_channels, image_size, image_size)), device='cpu', verbose=1, col_names=col_names)
        if use_torchvista:
            torchvista.trace_model(model.unet, torch.cat([example_input, x_cond], dim=1), collapse_modules_after_depth=0)

INFO:root:using device: cpu


Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Kernel Shape
UNet                                     [32, 1, 288, 288]         [32, 2, 288, 288]         1                         --
├─ModuleDict: 1-1                        --                        --                        --                        --
│    └─MPConv: 2-1                       [32, 2, 288, 288]         [32, 64, 288, 288]        1,152                     --
│    └─Block: 2-2                        [32, 64, 288, 288]        [32, 64, 288, 288]        1                         --
│    │    └─MPConv: 3-1                  [32, 64, 288, 288]        [32, 64, 288, 288]        36,864                    --
│    │    └─MPConv: 3-2                  [32, 64, 288, 288]        [32, 64, 288, 288]        36,864                    --
│    └─Block: 2-3                        [32, 64, 288, 288]        [32, 64, 288, 288]        1                         --
│    │    └─MP