In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from monai import transforms
from monai.apps import MedNISTDataset
from monai.config import print_config
from monai.data import DataLoader, Dataset
from monai.utils import first, set_determinism
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from PIL import Image

from generative.inferers import LatentDiffusionInferer
from generative.losses.adversarial_loss import PatchAdversarialLoss
from generative.losses.perceptual import PerceptualLoss
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator
from generative.networks.schedulers import DDPMScheduler

print_config()

MONAI version: 1.3.1
Numpy version: 1.26.3
Pytorch version: 2.1.0+cu118
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 96bfda00c6bd290297f5e3514ea227c6be4d08b4
MONAI __file__: /home/<username>/.pyenv/versions/3.9.13/envs/Medical/lib/python3.9/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
scipy version: 1.13.0
Pillow version: 10.2.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: 5.2.0
TorchVision version: 0.16.0+cu118
tqdm version: 4.66.2
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.8
pandas version: 2.2.2
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INST

## config

In [3]:
random_seed = 26
image_size = 128

gen_type = "covid"
root_dir = f"./model"
model_path = f"./{root_dir}"
device = torch.device("cuda")

gentype2scale_factor = {
    "covid": 0.9045405983924866,
    'normal': 0.9455165863037109,
    "pneumonia_vir": 0.9373854398727417,
    "pneumonia_bac": 0.9436339735984802,
}

In [4]:
import torch.nn as nn


autoencoderkl = AutoencoderKL(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    num_channels=(128, 128, 256),
    latent_channels=3,
    num_res_blocks=2,
    attention_levels=(False, False, False),
    with_encoder_nonlocal_attn=False,
    with_decoder_nonlocal_attn=False,
)
# autoencoderkl = nn.DataParallel(autoencoderkl)
autoencoderkl = autoencoderkl.to(device)

In [5]:
discriminator = PatchDiscriminator(spatial_dims=2, num_layers_d=3, num_channels=64, in_channels=3, out_channels=3)
discriminator = discriminator.to(device)

adv_loss = PatchAdversarialLoss(criterion="least_squares")
adv_weight = 0.01

In [6]:
autoencoderkl.load_state_dict(torch.load(os.path.join(model_path, f"{gen_type}_best_autoencoderkl.pth"), map_location='cuda:0'))
autoencoderkl.eval()

AutoencoderKL(
  (encoder): Encoder(
    (blocks): ModuleList(
      (0): Convolution(
        (conv): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (1-2): 2 x ResBlock(
        (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
        (conv1): Convolution(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
        (conv2): Convolution(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (nin_shortcut): Identity()
      )
      (3): Downsample(
        (conv): Convolution(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
        )
      )
      (4-5): 2 x ResBlock(
        (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
        (conv1): Convolution(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (norm2): GroupNorm(3

In [7]:
unet = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=3,
    num_res_blocks=2,
    num_channels=(128, 256, 512),
    attention_levels=(False, True, True),
    num_head_channels=(0, 256, 512),
)

In [8]:
unet.load_state_dict(torch.load(os.path.join(model_path, f"{gen_type}_best_unet_model.pth"), map_location='cuda:0'))
unet = unet.to(device)
unet.eval()

DiffusionModelUNet(
  (conv_in): Convolution(
    (conv): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (time_embed): Sequential(
    (0): Linear(in_features=128, out_features=512, bias=True)
    (1): SiLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
  )
  (down_blocks): ModuleList(
    (0): DownBlock(
      (resnets): ModuleList(
        (0-1): 2 x ResnetBlock(
          (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
          (nonlinearity): SiLU()
          (conv1): Convolution(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (time_emb_proj): Linear(in_features=512, out_features=128, bias=True)
          (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
          (conv2): Convolution(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (skip_connection): Identity()
        )
      )
      (downsampler): Downsample

In [9]:
scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="linear_beta", beta_start=0.0015, beta_end=0.0195)
inferer = LatentDiffusionInferer(scheduler, scale_factor=gentype2scale_factor[gen_type])

In [10]:
scheduler.set_timesteps(num_inference_steps=1000)


In [11]:
import torchvision.transforms.functional as F
from torchvision.utils import save_image

cnt = 0
save_dir=f"latent_diffusion2/{gen_type}"
os.makedirs(save_dir, exist_ok=True)

for i in range(0, 3000, 100):
    noise = torch.randn((100, 3, 32, 32))
    noise = noise.to(device)
    
    with torch.no_grad():
        image, intermediates = inferer.sample(
            input_noise=noise,
            diffusion_model=unet,
            scheduler=scheduler,
            save_intermediates=True,
            intermediate_steps=100,
            autoencoder_model=autoencoderkl,
        )
        
    for i in image:
        save_image(i.cpu()[0, :, :].permute(1, 0), os.path.join(save_dir, f"{cnt}.png"))
        cnt += 1
        if cnt==3000:
            break

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:37<00:00,  4.60it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:37<00:00,  4.59it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:37<00:00,  4.59it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:37<00:00,  4.59it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:37<00:00,  4.59it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:37<00:00,  4.59it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [03:37<00:00,  4.59it/s]
100%|█