In [1]:
import os
from pathlib import Path

import itertools
import torch
from models.ddpm import *
from models.unet import Unet
from models.ddpm_classifier_free import Unet as Unet_class
from utils.image_utils import save_image_to_dir, save_patches_to_dir
from utils.model_utils import (load_model, load_classifier_free_model, generate_whole_image, create_lcl_ctx_channels, 
                               create_inputs, generate_patches, stitch_patches, create_patch_channels)
from config import IS_COND, OVERLAP, MID_IMAGE_SIZE, FINAL_IMAGE_SIZE, total_timesteps, RSNA_WH_PATH, RSNA_LC_PATH, RSNA_PH_PATH, SAVE_DIR
import einops

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets.rsna_patch import RSNA_Dataset
import matplotlib.pyplot as plt

In [3]:
# Load the patch-level diffusion checkpoint (built with 3 channels) and move it onto the GPU.
patch_model = load_model(RSNA_PH_PATH, channels=3).to('cuda')

In [4]:
# For every RSNA patch, save the original alongside DDIM samples drawn with
# several sampling-step counts, one directory per dataset index.
# NOTE(review): absolute cluster path — consider moving it to config.py next to SAVE_DIR.
PATCH_SAVE_ROOT = Path('/lustre/ddim-gen/rsna/patches')
SAMPLING_STEPS = (50, 100, 150, 200)
# Set True to resume an interrupted run: samples whose outputs all exist are skipped.
SKIP_EXISTING = False

failed_samples = []  # (dataset index, exception) for every sample that raised
for i, x in enumerate(RSNA_Dataset()):
    save_dir = PATCH_SAVE_ROOT / str(i)
    expected_outputs = [save_dir / 'orig.png', *(save_dir / f'{steps}.png' for steps in SAMPLING_STEPS)]
    if SKIP_EXISTING and all(path.exists() for path in expected_outputs):
        continue

    try:
        save_dir.mkdir(parents=True, exist_ok=True)
        # x is rank-3 (c, h, w) per the rearrange pattern below; channel 0 is saved as the reference image.
        plt.imsave(save_dir / 'orig.png', x[0], cmap='gray')

        # Add a leading batch axis of size 1 for the sampler; invariant across step counts, so built once.
        x_batch = einops.rearrange(x, 'c h w -> 1 c h w')
        for steps in SAMPLING_STEPS:
            patch = ddim_sample_patch(patch_model, x_batch, sampling_timesteps=steps)
            plt.imsave(save_dir / f'{steps}.png', patch, cmap='gray')
    except Exception as err:
        # Best-effort over the dataset: keep going past a bad sample, but record why
        # instead of silently discarding the error. Catching Exception (not a bare
        # except) lets KeyboardInterrupt propagate, so interrupting stops the whole loop.
        failed_samples.append((i, err))
        print(f'sample {i} failed: {err!r}')

print(f'{len(failed_samples)} sample(s) failed')

10722
14898


sampling loop time step: 100%|██████████| 50/50 [00:02<00:00, 24.90it/s]
sampling loop time step: 100%|██████████| 100/100 [00:03<00:00, 29.31it/s]
sampling loop time step: 100%|██████████| 150/150 [00:05<00:00, 29.31it/s]
sampling loop time step: 100%|██████████| 200/200 [00:06<00:00, 29.31it/s]


10722
14898


sampling loop time step: 100%|██████████| 50/50 [00:01<00:00, 29.32it/s]
sampling loop time step: 100%|██████████| 100/100 [00:03<00:00, 29.30it/s]
sampling loop time step: 100%|██████████| 150/150 [00:05<00:00, 29.29it/s]
sampling loop time step: 100%|██████████| 200/200 [00:06<00:00, 29.27it/s]


10722
14898


sampling loop time step: 100%|██████████| 50/50 [00:01<00:00, 29.28it/s]
sampling loop time step: 100%|██████████| 100/100 [00:03<00:00, 29.27it/s]
sampling loop time step: 100%|██████████| 150/150 [00:05<00:00, 29.26it/s]
sampling loop time step: 100%|██████████| 200/200 [00:06<00:00, 29.27it/s]


10722
14898


sampling loop time step: 100%|██████████| 50/50 [00:01<00:00, 29.28it/s]
sampling loop time step: 100%|██████████| 100/100 [00:03<00:00, 29.27it/s]
sampling loop time step: 100%|██████████| 150/150 [00:05<00:00, 29.27it/s]
sampling loop time step: 100%|██████████| 200/200 [00:06<00:00, 29.25it/s]


10722
14898


sampling loop time step: 100%|██████████| 50/50 [00:01<00:00, 29.29it/s]
sampling loop time step: 100%|██████████| 100/100 [00:03<00:00, 29.27it/s]
sampling loop time step: 100%|██████████| 150/150 [00:05<00:00, 29.27it/s]
sampling loop time step: 100%|██████████| 200/200 [00:06<00:00, 29.24it/s]


4348
14898


sampling loop time step: 100%|██████████| 50/50 [00:01<00:00, 29.24it/s]
sampling loop time step: 100%|██████████| 100/100 [00:03<00:00, 29.23it/s]
sampling loop time step: 100%|██████████| 150/150 [00:05<00:00, 29.24it/s]
sampling loop time step: 100%|██████████| 200/200 [00:06<00:00, 29.25it/s]


4348
14898


sampling loop time step: 100%|██████████| 50/50 [00:01<00:00, 29.22it/s]
sampling loop time step: 100%|██████████| 100/100 [00:03<00:00, 29.23it/s]
sampling loop time step:  65%|██████▍   | 97/150 [00:03<00:01, 28.95it/s]

KeyboardInterrupt

