In [1]:
import os 
# CUDA VISIBLE DEVICE 
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from monai import transforms
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader
from monai.utils import first, set_determinism
from torch.cuda.amp import GradScaler, autocast
from torch.nn import L1Loss
from tqdm import tqdm

from generative.inferers import LatentDiffusionInferer
from generative.losses import PatchAdversarialLoss, PerceptualLoss
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator
from generative.networks.schedulers import DDPMScheduler

print_config()

  from .autonotebook import tqdm as notebook_tqdm


MONAI version: 1.3.0
Numpy version: 1.26.4
Pytorch version: 2.2.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 865972f7a791bf7b42efbcd87c8402bd865b329e
MONAI __file__: /home/anaconda3/envs/m3d/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.2.1
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: 1.13.0
Pillow version: 10.3.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.17.1+cu121
tqdm version: 4.66.2
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 7.0.0
pandas version: 2.2.2
einops version: 0.8.0
transformers version: 4.39.1
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For 

In [3]:
import logging
import os
import re
import time
from typing import Dict, List, Literal, Optional, Tuple

import torch


class StateDictAdapter:
    """
    StateDictAdapter for adapting the state dict of a model to a checkpoint state dict.

    This class will iterate over all keys in the checkpoint state dict and filter them by a list of regex keys.
    For each matching key, the class will adapt the checkpoint state dict to the model state dict.
    Depending on the target size, the class will add missing blocks or cut the block.
    When adding missing blocks, the class will use a strategy to fill the missing blocks: either adding zeros or normal random values.

    Example:

    ```
    adapter = StateDictAdapter()
    new_state_dict = adapter(
        model_state_dict=model.state_dict(),
        checkpoint_state_dict=state_dict,
        regex_keys=[
            r"class_embedding.linear_1.weight",
            r"conv_in.weight",
            r"(down_blocks|up_blocks)\.\d+\.attentions\.\d+\.transformer_blocks\.\d+\.attn\d+\.(to_k|to_v)\.weight",
            r"mid_block\.attentions\.\d+\.transformer_blocks\.\d+\.attn\d+\.(to_k|to_v)\.weight"
        ]
    )
    ```

    Args:
        model_state_dict (Dict[str, torch.Tensor]): The model state dict.
        checkpoint_state_dict (Dict[str, torch.Tensor]): The checkpoint state dict.
        regex_keys (Optional[List[str]]): A list of regex keys to adapt the checkpoint state dict. Defaults to None.
            Passing a list of regex will drastically reduce the latency.
            If None, all keys in the checkpoint state dict will be adapted.
        strategy (Literal["zeros", "normal"], optional): The strategy to fill the missing blocks. Defaults to "normal".

    """

    def _create_block(
        self,
        shape: List[int],
        strategy: Literal["zeros", "normal"],
        input: torch.Tensor = None,
    ):
        if strategy == "zeros":
            return torch.zeros(shape)
        elif strategy == "normal":
            if input is not None:
                mean = input.mean().item()
                std = input.std().item()
                return torch.randn(shape) * std + mean
            else:
                return torch.randn(shape)
        else:
            raise ValueError(f"Unknown strategy {strategy}")

    def __call__(
        self,
        model_state_dict: Dict[str, torch.Tensor],
        checkpoint_state_dict: Dict[str, torch.Tensor],
        regex_keys: Optional[List[str]] = None,
        strategy: Literal["zeros", "normal"] = "normal",
    ):
        start = time.perf_counter()
        # if no regex keys are provided, we use all keys in the model state dict
        if regex_keys is None:
            regex_keys = list(model_state_dict.keys())

        # iterate over all keys in the checkpoint state dict
        for checkpoint_key in list(checkpoint_state_dict.keys()):
            # iterate over all regex keys
            for regex_key in regex_keys:
                if re.match(regex_key, checkpoint_key):
                    dst_shape = model_state_dict[checkpoint_key].shape
                    src_shape = checkpoint_state_dict[checkpoint_key].shape

                    ## Sizes adapter
                    # if length of shapes are different, we need to unsqueeze or squeeze the tensor
                    if len(dst_shape) != len(src_shape):
                        # in the case [a] vs [a, b] -> unsqueeze [a, 1]
                        if len(src_shape) == 1:
                            checkpoint_state_dict[checkpoint_key] = (
                                checkpoint_state_dict[checkpoint_key].unsqueeze(1)
                            )
                            logging.info(
                                f"Unsqueeze {checkpoint_key}: {src_shape} -> {checkpoint_state_dict[checkpoint_key].shape}"
                            )
                        # in the case [a, b] vs [a] -> squeeze [a]
                        elif len(dst_shape) == 1:
                            checkpoint_state_dict[checkpoint_key] = (
                                checkpoint_state_dict[checkpoint_key][:, 0]
                            )
                            logging.info(
                                f"Squeeze {checkpoint_key}: {src_shape} -> {checkpoint_state_dict[checkpoint_key].shape}"
                            )
                        # in the other cases, raise an error
                        else:
                            raise ValueError(
                                f"Shapes of {checkpoint_key} are different: {dst_shape} != {src_shape}"
                            )

                        # update the shapes
                        dst_shape = model_state_dict[checkpoint_key].shape
                        src_shape = checkpoint_state_dict[checkpoint_key].shape
                        assert len(dst_shape) == len(
                            src_shape
                        ), f"Shapes of {checkpoint_key} are different: {dst_shape} != {src_shape}"

                    ## Shapes adapter
                    # modify the checkpoint state dict only if the shapes are different
                    if dst_shape != src_shape:
                        # create a copy of the tensor
                        tmp = torch.clone(checkpoint_state_dict[checkpoint_key])

                        # iterate over all dimensions
                        for i in range(len(dst_shape)):
                            if dst_shape[i] != src_shape[i]:
                                diff = dst_shape[i] - src_shape[i]

                                # if the difference is greater than 0, we need to add missing blocks
                                if diff > 0:
                                    missing_shape = list(tmp.shape)
                                    missing_shape[i] = diff
                                    missing = self._create_block(
                                        shape=missing_shape,
                                        strategy=strategy,
                                        input=tmp,
                                    )
                                    tmp = torch.cat((tmp, missing), dim=i)
                                    logging.info(
                                        f"Adapting {checkpoint_key} with strategy:{strategy} from shape {src_shape} to {dst_shape}"
                                    )
                                # if the difference is less than 0, we need to cut the block
                                else:
                                    tmp = tmp.narrow(i, 0, dst_shape[i])
                                    logging.info(
                                        f"Adapting {checkpoint_key} by narrowing from shape {src_shape} to {dst_shape}"
                                    )

                        checkpoint_state_dict[checkpoint_key] = tmp
        end = time.perf_counter()
        logging.info(f"StateDictAdapter took {end-start:.2f} seconds")
        return checkpoint_state_dict


class StateDictRenamer:
    """
    StateDictRenamer for renaming keys in a checkpoint state dict.
    This class will iterate over all keys in the checkpoint state dict and rename them according to a rename dict.

    Example:

        ```
        renamer = StateDictRenamer()
        new_state_dict = renamer(
            checkpoint_state_dict=state_dict,
            rename_dict={
                "add_embedding.linear_1.weight": "class_embedding.linear_1.weight",
                "add_embedding.linear_1.bias": "class_embedding.linear_1.bias",
                "add_embedding.linear_2.weight": "class_embedding.linear_2.weight",
                "add_embedding.linear_2.bias": "class_embedding.linear_2.bias",
            }
        )
        ```

    Args:

        checkpoint_state_dict (Dict[str, torch.Tensor]): The checkpoint state dict.
        rename_dict (Dict[str, str]): The dictionary mapping the old keys to new keys
    """

    def __call__(
        self,
        checkpoint_state_dict: Dict[str, torch.Tensor],
        rename_dict: Dict[str, str],
    ) -> Dict[str, torch.Tensor]:
        for old_key, new_key in rename_dict.items():
            if old_key not in checkpoint_state_dict:
                logging.warning(f"Key {old_key} not found in checkpoint state dict")
                continue
            else:
                assert (
                    new_key not in checkpoint_state_dict
                ), f"Key {new_key} already exists in checkpoint state dict"
                checkpoint_state_dict[new_key] = checkpoint_state_dict.pop(old_key)
                logging.info(f"Renaming {old_key} to {new_key}")
        return checkpoint_state_dict


In [4]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

In [7]:
batch_size = 2 
def get_autoencoder_model_from_ckpt(ckpt_path):
    # load state dict from ckpt
    state_dict = torch.load(ckpt_path)
    autoencoder_state_dict = state_dict["autoencoder_state"]
    autoencoder = AutoencoderKL(
        spatial_dims=3,
        in_channels=1,
        out_channels=1,
        num_channels=(64, 128, 128, 128),
        latent_channels=3,
        num_res_blocks=2,
        norm_num_groups=32,
        norm_eps=1e-6,
        attention_levels=(False, False, False, False),
        with_encoder_nonlocal_attn=False,
        with_decoder_nonlocal_attn=False,
    )

    autoencoder.load_state_dict(autoencoder_state_dict, strict=True)
    return autoencoder

autoencoder = get_autoencoder_model_from_ckpt('/home/huutien/sources/GenerativeModels/3D_training/results/3D_training_KL_20251229_101810/best_checkpoint.pth')


In [8]:
inputs = torch.randn(2, 1, 160, 160, 96).to(device)

inputs = inputs.to(device)

# forward pass
# with torch.no_grad():
#     latent = autoencoder.encode_stage_2_inputs(inputs)

# latent.shape



In [7]:

discriminator = PatchDiscriminator(spatial_dims=3, num_layers_d=3, num_channels=32, in_channels=1, out_channels=1)
discriminator.to(device)
adv_loss = PatchAdversarialLoss(criterion="least_squares")
loss_perceptual = PerceptualLoss(spatial_dims=3, network_type="squeeze", is_fake_3d=True, fake_3d_ratio=0.2)
loss_perceptual.to(device)


def KL_loss(z_mu, z_sigma):
    kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3, 4])
    return torch.sum(kl_loss) / kl_loss.shape[0]



In [9]:
reconstruction, z_mu, z_sigma = autoencoder(inputs)
# kl_loss = KL_loss(z_mu, z_sigma)

# logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
# loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True)
# logits_real = discriminator(inputs.contiguous().detach())[-1]
# loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)
# discriminator_loss = (loss_d_fake + loss_d_real) * 0.5

In [23]:
latent = autoencoder.encode_stage_2_inputs(inputs)
latent.shape

torch.Size([2, 3, 20, 20, 12])

In [10]:
logits_fake.shape, reconstruction.shape, z_mu.shape, z_sigma.shape, discriminator_loss

(torch.Size([2, 1, 18, 18, 10]),
 torch.Size([2, 1, 160, 160, 96]),
 torch.Size([2, 3, 20, 20, 12]),
 torch.Size([2, 3, 20, 20, 12]),
 tensor(0.6692, grad_fn=<MulBackward0>))

In [10]:
diffusion_model = DiffusionModelUNet(
    spatial_dims=3,
    in_channels=3,
    out_channels=3,
    num_channels=(256, 512, 768),
    num_res_blocks=2,
    attention_levels=(False, True, True),
    norm_num_groups=32,
    norm_eps=1e-6,
    resblock_updown=True,
    num_head_channels=[0, 512, 768],
    with_conditioning=True,
    transformer_num_layers=1,
    cross_attention_dim=4,
    upcast_attention=True,
    use_flash_attention=False,
)



In [12]:
ckpt = torch.load(os.path.join('/home/huutien/sources/GenerativeModels/large_files', "diffusion_model.pth"), map_location=device)
ckpt.keys()

odict_keys(['conv_in.conv.weight', 'conv_in.conv.bias', 'time_embed.0.weight', 'time_embed.0.bias', 'time_embed.2.weight', 'time_embed.2.bias', 'down_blocks.0.resnets.0.norm1.weight', 'down_blocks.0.resnets.0.norm1.bias', 'down_blocks.0.resnets.0.conv1.conv.weight', 'down_blocks.0.resnets.0.conv1.conv.bias', 'down_blocks.0.resnets.0.time_emb_proj.weight', 'down_blocks.0.resnets.0.time_emb_proj.bias', 'down_blocks.0.resnets.0.norm2.weight', 'down_blocks.0.resnets.0.norm2.bias', 'down_blocks.0.resnets.0.conv2.conv.weight', 'down_blocks.0.resnets.0.conv2.conv.bias', 'down_blocks.0.resnets.1.norm1.weight', 'down_blocks.0.resnets.1.norm1.bias', 'down_blocks.0.resnets.1.conv1.conv.weight', 'down_blocks.0.resnets.1.conv1.conv.bias', 'down_blocks.0.resnets.1.time_emb_proj.weight', 'down_blocks.0.resnets.1.time_emb_proj.bias', 'down_blocks.0.resnets.1.norm2.weight', 'down_blocks.0.resnets.1.norm2.bias', 'down_blocks.0.resnets.1.conv2.conv.weight', 'down_blocks.0.resnets.1.conv2.conv.bias', 'dow

In [13]:
ckpt['conv_in.conv.weight'].shape, ckpt['conv_in.conv.bias'].shape

(torch.Size([256, 7, 3, 3, 3]), torch.Size([256]))

In [15]:
diffusion_model.state_dict()['conv_in.conv.weight'].shape , diffusion_model.state_dict()['conv_in.conv.bias'].shape

(torch.Size([256, 3, 3, 3, 3]), torch.Size([256]))

In [16]:
state_dict_adapter = StateDictAdapter()
state_dict = state_dict_adapter(
    model_state_dict=diffusion_model.state_dict(),
    checkpoint_state_dict=ckpt,
    regex_keys=[
        r"conv_in.conv.weight",
    ],
    strategy="zeros",
)

In [17]:
state_dict['conv_in.conv.weight'].shape

torch.Size([256, 3, 3, 3, 3])

In [18]:
ckpt['conv_in.conv.weight'][:,:3,:,:,:].shape

torch.Size([256, 3, 3, 3, 3])

In [19]:
(state_dict['conv_in.conv.weight'] == ckpt['conv_in.conv.weight'][:,:3,:,:,:]).sum()

tensor(20736)

In [16]:
device

device(type='cuda')

In [20]:
diffusion_model.load_state_dict(state_dict, strict=True)
diffusion_model.to(device)

diffusion_model.eval()

DiffusionModelUNet(
  (conv_in): Convolution(
    (conv): Conv3d(3, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  )
  (time_embed): Sequential(
    (0): Linear(in_features=256, out_features=1024, bias=True)
    (1): SiLU()
    (2): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (down_blocks): ModuleList(
    (0): DownBlock(
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock(
          (norm1): GroupNorm(32, 256, eps=1e-06, affine=True)
          (nonlinearity): SiLU()
          (conv1): Convolution(
            (conv): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          )
          (time_emb_proj): Linear(in_features=1024, out_features=256, bias=True)
          (norm2): GroupNorm(32, 256, eps=1e-06, affine=True)
          (conv2): Convolution(
            (conv): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          )
          (skip_connection): Identity()
        )
      )


In [21]:
# conditioning = torch.tensor([[0., 0, 0, 0]]).to( device).unsqueeze(1)
conditioning = torch.zeros(2,1, 4).to(device)
# use yaml config 


In [22]:
conditioning.shape

torch.Size([2, 1, 4])

In [None]:
latent = latent.to(device)
timesteps = torch.zeros(2, device=device)
prediction = diffusion_model(latent, timesteps, context=conditioning)
prediction.shape


torch.Size([2, 3, 20, 20, 12])

: 

In [27]:
import math

def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic
    Models" https://arxiv.org/abs/2006.11239.

    Args:
        timesteps: a 1-D Tensor of N indices, one per batch element.
        embedding_dim: the dimension of the output.
        max_period: controls the minimum frequency of the embeddings.
    """
    if timesteps.ndim != 1:
        raise ValueError("Timesteps should be a 1d-array")

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    freqs = torch.exp(exponent / half_dim)

    args = timesteps[:, None].float() * freqs[None, :]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))

    return embedding


In [33]:
t_emb = get_timestep_embedding(timesteps, diffusion_model.block_out_channels[0])

# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=latent.dtype)
print

emb = diffusion_model.time_embed(t_emb)



RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

In [31]:
t_emb.device

device(type='cuda', index=0)

In [38]:
diffusion_model.time_embed[0].weight.device

device(type='cpu')

In [39]:
diffusion_model.conv_in.conv.weight.device

device(type='cpu')