In [1]:
import os
from pathlib import Path

import itertools
import torch
from models.ddpm import *
from models.unet import Unet
from models.ddpm_classifier_free import Unet as Unet_class
from utils.image_utils import save_image_to_dir, save_patches_to_dir
from utils.model_utils import (load_model, load_classifier_free_model, generate_whole_image, create_lcl_ctx_channels, 
                               create_inputs, generate_patches, stitch_patches, create_patch_channels)
from config import IS_COND, OVERLAP, MID_IMAGE_SIZE, FINAL_IMAGE_SIZE, total_timesteps, WH_PATH, LC_PATH, PH_PATH, SAVE_DIR
import einops
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets.vindr_local_context import VINDR_Dataset
import matplotlib.pyplot as plt

In [3]:
model_name = 'vindr_local_context_256'
model_dir = '/lustre/mambo/results/vindr_local_context_256/25_02_21_09_24/models/model_'
save_dir = '/lustre/mambo/data/ddim_early_stopping/'
images_path = '/data/vindr-mammo-a-large-scale-benchmark-dataset-for-computer-aided-detection-and-diagnosis-in-full-field-digital-mammography-1.0.0/images'
data_csv_path = '/data/vindr-mammo-a-large-scale-benchmark-dataset-for-computer-aided-detection-and-diagnosis-in-full-field-digital-mammography-1.0.0/finding_annotations.csv'

In [None]:
for model_id in range(1999, 58000, 2000):

    model_path = model_dir + f'{str(model_id)}.pt'
    patch_model = load_model(model_path, channels=3)
    patch_model = patch_model.to('cuda')
    orig_path = os.path.join(save_dir, model_name, 'orig/')
    os.makedirs(orig_path, exist_ok=True)
    save_path = os.path.join(save_dir, model_name, str(model_id))
    print('model loaded ', model_path)
    if os.path.exists(save_path):
        print('exists ', save_path)
        continue
    dataset = VINDR_Dataset(csv_path=data_csv_path, images_path=images_path)
    for i in range(100):
        try:
            x = dataset[i]
    #             save_dir = f'/lustre/ddim-gen/vindr/patches/{i}/'
            os.makedirs(save_path, exist_ok=True)
            print(save_path)
            plt.imsave(orig_path + f'orig_{str(i)}.png', x[0], cmap='gray')

            patch = ddim_sample_patch(patch_model, einops.rearrange(x, 'c h w -> 1 c h w'), sampling_timesteps=150)
            save_image_to_dir(np.array(patch), str(Path(save_path) / f'patch_{str(i)}.png'))
        except:
            print('error with num ', i)
            pass


model loaded  /lustre/mambo/results/vindr_local_context_256/25_02_21_09_24/models/model_1999.pt
exists  /lustre/mambo/data/ddim_early_stopping/vindr_local_context_256/1999
model loaded  /lustre/mambo/results/vindr_local_context_256/25_02_21_09_24/models/model_3999.pt
exists  /lustre/mambo/data/ddim_early_stopping/vindr_local_context_256/3999
model loaded  /lustre/mambo/results/vindr_local_context_256/25_02_21_09_24/models/model_5999.pt
exists  /lustre/mambo/data/ddim_early_stopping/vindr_local_context_256/5999
model loaded  /lustre/mambo/results/vindr_local_context_256/25_02_21_09_24/models/model_7999.pt
exists  /lustre/mambo/data/ddim_early_stopping/vindr_local_context_256/7999
model loaded  /lustre/mambo/results/vindr_local_context_256/25_02_21_09_24/models/model_9999.pt
exists  /lustre/mambo/data/ddim_early_stopping/vindr_local_context_256/9999
model loaded  /lustre/mambo/results/vindr_local_context_256/25_02_21_09_24/models/model_11999.pt
exists  /lustre/mambo/data/ddim_early_stopp

sampling loop time step: 100%|██████████| 150/150 [00:05<00:00, 27.39it/s]


/lustre/mambo/data/ddim_early_stopping/vindr_local_context_256/25999


sampling loop time step: 100%|██████████| 150/150 [00:05<00:00, 29.14it/s]


/lustre/mambo/data/ddim_early_stopping/vindr_local_context_256/25999


sampling loop time step: 100%|██████████| 150/150 [00:05<00:00, 29.17it/s]


/lustre/mambo/data/ddim_early_stopping/vindr_local_context_256/25999


sampling loop time step: 100%|██████████| 150/150 [00:05<00:00, 28.89it/s]


/lustre/mambo/data/ddim_early_stopping/vindr_local_context_256/25999


sampling loop time step: 100%|██████████| 150/150 [00:05<00:00, 28.41it/s]


/lustre/mambo/data/ddim_early_stopping/vindr_local_context_256/25999


sampling loop time step: 100%|██████████| 150/150 [00:05<00:00, 28.52it/s]


/lustre/mambo/data/ddim_early_stopping/vindr_local_context_256/25999


sampling loop time step: 100%|██████████| 150/150 [00:05<00:00, 28.45it/s]


/lustre/mambo/data/ddim_early_stopping/vindr_local_context_256/25999


sampling loop time step: 100%|██████████| 150/150 [00:05<00:00, 28.99it/s]


/lustre/mambo/data/ddim_early_stopping/vindr_local_context_256/25999


sampling loop time step: 100%|██████████| 150/150 [00:05<00:00, 29.11it/s]


/lustre/mambo/data/ddim_early_stopping/vindr_local_context_256/25999


sampling loop time step:  24%|██▍       | 36/150 [00:01<00:03, 28.92it/s]


error with num  9
/lustre/mambo/data/ddim_early_stopping/vindr_local_context_256/25999


sampling loop time step:  92%|█████████▏| 138/150 [00:04<00:00, 29.15it/s]

In [12]:
patch.shape

torch.Size([256, 256])