In [2]:
import os
# Expose only GPU index 3 to this process. CUDA reads this variable when it is
# initialised, which is why it is set here, before `torch` is imported below.
gpu_idx = '3'
os.environ.update(CUDA_VISIBLE_DEVICES=gpu_idx)


import numpy as np
import glob
from tqdm import tqdm

import torch
from torch.cuda.amp import autocast

import yaml
import nibabel as nib

from monai import transforms
from generative.networks.schedulers import DDIMScheduler
from generative.networks.nets import DiffusionModelUNet, AutoencoderKL

from model.vcm import VCM
from config.model_config import defaultCFG

In [4]:
# Preprocessing for the edited semantic maps: load the label volume, make it
# channel-first, reorient to LAS and one-hot encode it into 9 classes.
train_transforms = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        transforms.EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
        transforms.EnsureTyped(keys=["image"]),
        transforms.Orientationd(keys=["image"], axcodes="LAS"),  # torch.Size([240, 240, 155])
        transforms.AsDiscreted(keys=["image"], to_onehot=9),
    ]
)


def _onehot_without_background(nifti_path):
    """Apply ``train_transforms`` to one NIfTI file and return its one-hot tensor minus channel 0."""
    onehot = train_transforms({"image": nifti_path})["image"]
    # Channel 0 (label value 0, presumably background) is discarded; the rest are kept.
    return onehot[1:, ...]


newinput_list = glob.glob('./data/*/new_semantics.nii.gz')

# Write the converted tensor next to its source file as `new_semantics.pt`.
for newinput_path in tqdm(newinput_list):
    case_dir = os.path.dirname(newinput_path)
    torch.save(_onehot_without_background(newinput_path), f'{case_dir}/new_semantics.pt')

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [00:25<00:00,  1.93it/s]


In [8]:
class NiftiSaver:
    """Write volumes as ``<output_dir>/<file_name>.nii.gz`` NIfTI-1 files.

    Every file written by one saver is stamped with the same affine (``self.affine``)
    and a fresh header, regardless of where the data came from, so generated samples
    and re-saved inputs share one voxel-to-world mapping.
    """

    # Affine used when none is passed to the constructor.
    # NOTE(review): presumably the affine of the dataset's reference volumes — confirm.
    DEFAULT_AFFINE = np.array(
        [
            [-1.0, 0.0, 0.0, 96.48149872],
            [0.0, 1.0, 0.0, -141.47715759],
            [0.0, 0.0, 1.0, -156.55375671],
            [0.0, 0.0, 0.0, 1.0],
        ]
    )

    def __init__(self, output_dir: str, affine=None) -> None:
        """Create the saver and its output directory.

        Args:
            output_dir: directory that receives the ``.nii.gz`` files; created if missing.
            affine: optional 4x4 voxel-to-world matrix; defaults to ``DEFAULT_AFFINE``.
        """
        # Copy so instances never share (or mutate) the class-level default.
        self.affine = np.array(self.DEFAULT_AFFINE if affine is None else affine, dtype=float)
        self.set_output_dir(output_dir)

    def set_output_dir(self, output_dir: str) -> None:
        """Redirect subsequent saves to ``output_dir``, creating it if needed."""
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def _path_for(self, file_name: str) -> str:
        """Full output path for ``file_name`` (``.nii.gz`` is appended)."""
        return os.path.join(str(self.output_dir), f"{file_name}.nii.gz")

    @staticmethod
    def _to_volume(image_data: torch.Tensor, is_label: bool = False) -> np.ndarray:
        """Return ``image_data[0, 0, ...]`` as a numpy array ready to be written.

        ``detach()`` allows tensors that still require grad. Labels are cast to uint8;
        everything else to float32 (NIfTI-1 defines no float16 datatype, and tensors
        computed under autocast may be float16).
        """
        volume = image_data.detach().cpu().numpy()[0, 0, ...]
        return volume.astype(np.uint8 if is_label else np.float32)

    def _write(self, volume: np.ndarray, file_name: str, data_dtype=None) -> None:
        """Save ``volume`` with ``self.affine`` and a fresh header; optionally fix its on-disk dtype."""
        header = nib.Nifti1Header()
        if data_dtype is not None:
            header.set_data_dtype(data_dtype)
        nib.save(nib.Nifti1Image(volume, self.affine, header), self._path_for(file_name))

    def save(self, image_data: torch.Tensor, file_name: str, is_label=False) -> None:
        """Save ``image_data[0, 0]`` as ``<output_dir>/<file_name>.nii.gz``.

        Args:
            image_data: tensor indexed as ``[batch, channel, ...]``; only the first
                batch element and the first channel are written.
            file_name: output name without extension.
            is_label: store as uint8 (numpy ``astype`` cast, fractions truncated)
                instead of float32.
        """
        volume = self._to_volume(image_data, is_label)
        # Put the array's dtype in the header so the uint8 cast for labels is what
        # actually lands on disk (a fresh header otherwise carries its own default dtype).
        self._write(volume, file_name, data_dtype=volume.dtype)

    def save_label(self, img_path, file_name: str) -> None:
        """Re-save the NIfTI at ``img_path`` with this saver's affine and a fresh header.

        The source file's own affine and header are discarded; only its voxel data is kept.
        NOTE(review): the header's data dtype is left at the fresh header's default rather
        than the source file's dtype — confirm that is intended.
        """
        img_nib = nib.load(img_path)
        self._write(np.asarray(img_nib.dataobj), file_name)

In [11]:
# Build the autoencoder and the diffusion UNet from the default config and load
# their pretrained weights for inference.
cfg = defaultCFG()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('cuda setup')
print(f"\tUsing {device}: {gpu_idx}")


def _load_pretrained(model, weight_path):
    """Load ``weight_path`` into ``model``, move it to ``device`` and switch it to eval mode.

    The checkpoint is read onto the CPU first so GPU-saved weights also load on a
    CPU-only machine; ``eval()`` because these models are only used for sampling.
    """
    model.load_state_dict(torch.load(weight_path, map_location='cpu'))
    model.to(device)
    model.eval()
    return model


print('model construction, load the weights')
AE_weight_path = 'weights/autoencoder.pth'
autoencoder = _load_pretrained(AutoencoderKL(**cfg.get_AE_CFG()), AE_weight_path)
print("\t AE done")

Diff_weight_path = 'weights/diffusion_model.pth'
diffusion = _load_pretrained(DiffusionModelUNet(**cfg.get_DM_CFG()), Diff_weight_path)
print("\t Diffusion done")



cuda setup
	Using cuda: 3
model construction, load the weights
	 AE done
	 Diffusion done


In [12]:
# For every case, sample one MRI twice from the same initial noise: once with the VCM
# modulating the diffusion model's noise prediction using the edited semantic map
# ("VCM"), and once with the diffusion model alone ("LDM"). Both decoded volumes, the
# original T1 and the input segmentation are written through `saver`.

# DDIM sampler shared by both trajectories: 1000 training steps, 200 inference steps.
val_scheduler = DDIMScheduler(num_train_timesteps=1000, schedule="scaled_linear_beta", beta_start=0.0015, beta_end=0.0205, clip_sample=False)
val_scheduler.set_timesteps(num_inference_steps=200)

# One sub-directory per case; sorted so the case order (and therefore the seeded
# noise each case receives) is reproducible across runs.
src_path = 'data/*'
val_data_list = sorted(p for p in glob.glob(src_path) if os.path.isdir(p))


def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from each key.

    Keys that do not start with ``prefix`` are kept unchanged (the previous loop left
    its target key unbound or stale for such keys).
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


VCM_enc_CFG, enc_CFG = cfg.get_VCM_enc_CFG()
vcm = VCM(out_dim=3, diff_CFG=VCM_enc_CFG, enc_CFG=enc_CFG)
VCM_weight_path = 'weights/vcm_wegith_complexSemantics.pt'
# Read onto the CPU; the module is moved to `device` once its weights are in place.
new_d = _strip_module_prefix(torch.load(VCM_weight_path, map_location='cpu'))
vcm_load_result = vcm.load_state_dict(new_d, strict=False)
# strict=False tolerates mismatches, so report them instead of ignoring them silently.
if vcm_load_result.missing_keys or vcm_load_result.unexpected_keys:
    print(f"\t VCM missing keys: {vcm_load_result.missing_keys}")
    print(f"\t VCM unexpected keys: {vcm_load_result.unexpected_keys}")
vcm.to(device)
vcm.eval()
print("\t VCM done")

# Latents are divided by this factor before decoding.
# NOTE(review): presumably the latent scale used when training the diffusion model — confirm.
scale_factor = 0.8962649106979370
print(scale_factor)

saver = NiftiSaver('generated/VCM_complexSemantics')

# Latent grid the samplers run on (3 channels, matching VCM's out_dim above).
LATENT_SHAPE = (3, 20, 28, 20)
# Seed for the initial noise so every sample can be regenerated.
SEED = 0
noise_generator = torch.Generator().manual_seed(SEED)


def _diffusion_eps(latent, timesteps, cond, cond_concat):
    """Noise prediction of the diffusion UNet for ``latent``; the condition is passed
    both as extra input channels (``cond_concat``) and as ``context`` (``cond``)."""
    return diffusion(torch.cat((latent, cond_concat), dim=1), timesteps=timesteps, context=cond)


for path in tqdm(val_data_list):
    sid = os.path.basename(path)

    # Edited semantic map written by the preprocessing cell, with a batch dim added;
    # moved to the device once instead of at every timestep.
    val_label = torch.load(f'{path}/new_semantics.pt').unsqueeze(0).to(device)
    cond = torch.load(f'{path}/sex-age-ventV-brainV.pt')
    val_BZ = val_label.shape[0]

    # Both trajectories start from identical noise so their outputs are comparable.
    noise = torch.randn((val_BZ, *LATENT_SHAPE), generator=noise_generator).to(device)
    image = noise
    image4LDM = noise.detach().clone()

    # Four condition values per case (presumably sex, age, ventricle and brain volume,
    # per the file name), used as a (B, 1, 4) context and a (B, 4, *latent grid) map.
    val_cond = cond.unsqueeze(0).view(val_BZ, 1, 4).to(device)
    val_cond_concat = val_cond.view(val_BZ, 4, 1, 1, 1).expand(-1, -1, *image.shape[2:])

    with torch.no_grad(), autocast():
        for t in val_scheduler.timesteps:
            timesteps = torch.tensor([int(t)], device=device, dtype=torch.long)

            # Controlled by VCM: modulate the predicted noise with a scale and a shift.
            epsilon = _diffusion_eps(image, timesteps, val_cond, val_cond_concat)
            scale, shift = vcm(x=torch.cat([image, epsilon], dim=1), y=val_label, timesteps=timesteps)
            image, _ = val_scheduler.step(epsilon * (1 + scale) + shift, t, image)

            # Uncontrolled baseline (BrainLDM): the diffusion model alone.
            ldm_out = _diffusion_eps(image4LDM, timesteps, val_cond, val_cond_concat)
            image4LDM, _ = val_scheduler.step(ldm_out, t, image4LDM)

        # Decode both latents to MRI volumes.
        vcm_sample = autoencoder.decode_stage_2_outputs(image / scale_factor)
        ldm_sample = autoencoder.decode_stage_2_outputs(image4LDM / scale_factor)

    saver.save(vcm_sample, f'{sid}__2-VCM')
    saver.save(ldm_sample, f'{sid}__1-LDM')
    saver.save_label(f'{path}/T1.nii.gz', f'{sid}__0-ori_T1')
    saver.save_label(f'{path}/new_semantics.nii.gz', f'{sid}__3-input_seg')
    

3500
/root/vcm/VCM/out/data500/newSemantics/VCM/2024-09-02/log/3500
	 VCM done
0.896264910697937


  return F.conv3d(
  ret = func(*args, **kwargs)
100%|██████████| 50/50 [38:46<00:00, 46.54s/it]
