In [6]:
import inspect
import math
import os
import pdb
import datetime
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Iterable, List, Optional, Union, Literal
import pandas as pd
import diffusers
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler
from IPython.display import display
from PIL import Image
from torch import Tensor, autocast
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import functional as TVF
from torchvision.utils import make_grid
from tqdm.rich import tqdm, trange
from transformers import AutoTokenizer

In [7]:
# Device / precision. Inference mode is entered once here and never exited, so every
# later cell in this session runs without autograd tracking.
ctx = torch.inference_mode()
ctx.__enter__()
device, dtype = 'cuda', torch.float16

# Model selection
model_type = "text-to-image"
pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1-base"
#pretrained_model_name_or_path = "runwayml/stable-diffusion-inpainting"  # presumably paired with model_type = "inpaint"

# Data / prompting: the "{}" in the prompt is filled with each image's class label.
prompt = 'An image of {}, a skin disease'
resolution, batch_size = 512, 16

# Sampling
seed, start_index = 42, 0
num_generations_per_image = 1  # number of passes over the dataset
guidance_scale, num_inference_steps = 3.0, 50

In [8]:
# Model: load the diffusion pipeline selected by `model_type`. Both pipeline classes take the
# same loading arguments, so dispatch on the class rather than duplicating the call.
print('Loading model')
pipeline_classes = {
    'inpaint': StableDiffusionInpaintPipeline,
    'text-to-image': StableDiffusionPipeline,
}
if model_type not in pipeline_classes:
    raise ValueError(f"Unknown model_type {model_type!r}; expected one of {sorted(pipeline_classes)}")
pipeline = pipeline_classes[model_type].from_pretrained(
    pretrained_model_name_or_path,
    torch_dtype=dtype,
    # No safety checker / feature extractor is loaded for this research setting.
    safety_checker=None,
    feature_extractor=None,
    requires_safety_checker=False,
)
# Disable the pipeline's own per-call progress bar; the batch loop shows its own.
pipeline.set_progress_bar_config(disable=True)
pipeline.to(device)

print(f'Loaded pipeline with {sum(p.numel() for p in pipeline.unet.parameters()):_} unet parameters')


Loading model


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Loaded pipeline with 865_910_724 unet parameters


In [9]:
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
import os
from pathlib import Path


class CustomDataset(Dataset):
    """Image-folder dataset that pairs each image with a class-conditioned text prompt.

    ``dataset_dir`` is expected to contain one sub-directory per class; the
    sub-directory name is used as the label and substituted into
    ``instance_prompt``. Entries whose names start with "." (for example
    ``.ipynb_checkpoints`` or ``._img.png`` resource forks) are ignored, and
    discovery is sorted so item indices are stable across runs and machines.

    Args:
        dataset_dir: Root directory containing the per-class folders (str or path-like).
        instance_prompt: Format string with one ``{}`` placeholder for the label.
        transform: Optional callable applied to each loaded RGB ``PIL.Image``.
        extensions: File suffix (or collection of suffixes) to include, matched
            case-insensitively. Defaults to PNG only.

    Each item is a dict with keys ``"prompt"``, ``"image_name"`` (file stem),
    ``"pixel_values"`` (the possibly transformed image) and ``"label"``.
    """

    def __init__(self, dataset_dir, instance_prompt, transform=None, extensions=('.png',)):
        self.dataset_dir = Path(dataset_dir)
        self.transform = transform
        self.instance_prompt = instance_prompt
        # Accept a single suffix string as well as a collection; lowercase for case-insensitive matching.
        if isinstance(extensions, str):
            extensions = (extensions,)
        self.extensions = tuple(ext.lower() for ext in extensions)

        # Parallel lists: image_paths[i] belongs to class labels[i].
        self.image_paths = []
        self.labels = []

        # Sorted so the index -> image mapping does not depend on filesystem listing order.
        for class_dir in sorted(self.dataset_dir.iterdir()):
            if class_dir.name.startswith('.') or not class_dir.is_dir():
                continue
            for img_path in sorted(class_dir.iterdir()):
                if img_path.name.startswith('.') or not img_path.is_file():
                    continue
                if img_path.name.lower().endswith(self.extensions):
                    self.image_paths.append(img_path)
                    self.labels.append(class_dir.name)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        # The context manager closes the file handle promptly; convert() returns a new
        # image, so it remains usable after the source file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return {
            "prompt": self.instance_prompt.format(label),
            "image_name": image_path.stem,
            "pixel_values": image,
            "label": label,
        }


In [10]:
# Preprocessing for the real images: resize to the generation resolution (aspect ratio is
# not preserved), convert to a float tensor (in [0, 1] for 8-bit images), then map every
# channel to [-1, 1] via (x - 0.5) / 0.5. The single-element mean/std lists broadcast
# across all channels.
transform = transforms.Compose([
    transforms.Resize((resolution, resolution)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]; invert with x * 0.5 + 0.5
])

In [11]:
from torch.utils.data import DataLoader

dataset_directory = "sample_dataset"

custom_dataset = CustomDataset(dataset_directory, instance_prompt=prompt, transform=transform)
# Seed the shuffle explicitly. Without a generator, DataLoader draws its permutation from
# torch's global RNG (not seeded anywhere in this notebook). Batch composition, and thus
# which noise each image receives from the shared sampler `generator`, would then change
# on every run despite `seed`.
dataloader = DataLoader(
    custom_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(seed),
)


In [12]:
# Randomness
generator = torch.Generator(device=device)
generator.manual_seed(seed + start_index)

<torch._C.Generator at 0x7f1a2d445050>

In [13]:
# # Parse args
# output_dir_path = Path("/n/scratch/users/t/thb286/generation_test")

# def get_output_paths(batch: dict, stage: str, idx: int) -> list[Path]:
#     return [
#         output_dir_path / stage / f'{idx:02d}' / f'{image_name}.png'
#         for image_name in batch['image_name']
#     ]

def save(image, path):
    """Save ``image`` to ``path``, creating any missing parent directories first.

    Args:
        image: Object with a ``save(path)`` method (PIL images in this notebook).
        path: Destination as a ``str``, ``Path``, or any other ``os.PathLike``.
    """
    # Path() accepts str, Path and any os.PathLike, so no type check is needed.
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    image.save(path)



In [28]:
# Text-to-image generation. Each real image contributes one prompt built from its class
# label. Outputs are written to <output_dir_path>/text-to-image/<repeat>/<label>/<image_name>.png,
# and batches whose outputs all exist are skipped, so re-running the cell resumes.
# The output root can be overridden with the GENERATION_OUTPUT_DIR environment variable.
output_dir_path = Path(os.environ.get("GENERATION_OUTPUT_DIR", "/n/scratch/users/t/thb286/generation_test"))
num_grid_batches = 10  # save a generated-vs-real comparison grid for the first N batches of each repeat

# Only the text-to-image branch is implemented below. Fail loudly rather than iterate
# the whole dataset without generating anything.
if model_type != 'text-to-image':
    raise NotImplementedError(f"Generation loop only supports model_type='text-to-image', got {model_type!r}")

for idx in range(start_index, start_index + num_generations_per_image):
    for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"repeat {idx}")):
        output_paths = [
            output_dir_path / "text-to-image" / f"{idx:02d}" / label / f"{name}.png"
            for label, name in zip(batch["label"], batch["image_name"])
        ]
        if all(output_path.is_file() for output_path in output_paths):
            # Already generated. The sampler's generator is not advanced for skipped batches,
            # so a resumed run draws different noise than an uninterrupted one.
            continue

        images = pipeline(
            prompt=batch["prompt"],
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_inference_steps,
            height=resolution,
            width=resolution,
        ).images
        assert len(images) == len(output_paths)
        for image, path in zip(images, output_paths):
            save(image, path)

        if batch_idx < num_grid_batches:
            # Top row: generated images. Bottom row: the real images that supplied their
            # prompts, de-normalized from [-1, 1] back to [0, 1]. The row width is the actual
            # batch length, so the final (possibly smaller) batch stays aligned.
            generated = [transforms.ToTensor()(img) for img in images]
            originals = [img * 0.5 + 0.5 for img in batch["pixel_values"]]
            grid = make_grid(generated + originals, nrow=len(images), padding=4, pad_value=1.0)
            grid_path = output_dir_path / "grid" / f'{idx:02d}-batch-{batch_idx:02d}.png'
            save(transforms.ToPILImage()(grid), grid_path)
            print(f'[Repeat {idx}, batch {batch_idx}] Saved image grid to {grid_path}')


Output()

  for batch_idx, batch in enumerate(tqdm(dataloader)):


In [29]:
from PIL import Image
from IPython.display import display

# Inspect one generated sample. The generation loop writes files under
# <output_dir_path>/text-to-image/<repeat>/<label>/<image_name>.png.
image_path = output_dir_path / "text-to-image" / "00" / "allergic-contact-dermatitis" / "0005.png"
if image_path.is_file():
    with Image.open(image_path) as img:
        display(img)
else:
    print(f"No generated image at {image_path}")

In [30]:
from PIL import Image
from IPython.display import display

# Inspect a real training image from the same class for comparison.
image_path = Path(dataset_directory) / "allergic-contact-dermatitis" / "0001.png"
if image_path.is_file():
    with Image.open(image_path) as img:
        display(img)
else:
    print(f"No real image at {image_path}")

In [None]:
# Generate a single batch of synthetic images and display them in this notebook using the grid format Luke designed.

In [None]:
# Run the following script to generate images at scale.

In [None]:
# In this notebook, generate about 100 synthetic images from a small labeled dummy dataset to demonstrate the workflow. This should be streamlined: the user only chooses the method, backbone, etc.

# generate_synthetic_dataset(
# real_images, # needs to be a metadata dataframe -- containing where to find an image, the label, etc.
# map_real_to_synthetic_label,  (this would be the hash for fitz, or some kind of unique ID which we need to store)
# method = "text-to-image",  # one of "text-to-image", "inpaint", "outpaint"
# text_label, (for text-to-image, this is the image description "label")
# text_prompt,   # the prompt that will be applied to the label -- this is optional
# num_synthetic_per_real,  (defaults to 10)
# num_total, # option for specifying the total dataset size we want -- we will handle dataset balance
# num_total_type, # "balanced" or "same" -- used with num_total: build a class-balanced dataset, or keep the real dataset's class proportions
# output_dir,
# model_path = "xxx" can set this if we want to use a custom model
# )

# Then, we can show how to run the script for generating larger amounts of images

In [None]:


# generate_synthetic_dataset(
# real_images, 
# map_real_to_synthetic_label,  (this would be the hash for fitz, or some kind of unique ID which we need to store)
# method = "text-to-image",  # one of "text-to-image", "inpaint", "outpaint"
# text_label, (for text-to-image, this is the image description "label")
# text_prompt,   # the prompt that will be applied to the label -- this is optional
# num_synthetic_per_real,  (defaults to 10)
# num_total, # option for specifying the total dataset size we want -- we will handle dataset balance
# num_total_type, # "balanced" or "same" -- used with num_total: build a class-balanced dataset, or keep the real dataset's class proportions
# output_dir,
# model_path = "xxx" can set this if we want to use a custom model
# )



In [None]:
# Full-scale generation: run the standalone generate.py script as a subprocess.
# ${FITZPATRICK17K_DATASET_DIR} is meant to be resolved from the environment, so set it
# (e.g. with %env) before running this cell. NOTE(review): confirm generate.py's CLI flags.
! python generate.py --output_root generations-pretrained --instance_data_dir=${FITZPATRICK17K_DATASET_DIR} --model_type "text-to-image" --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base" --instance_prompt="An image of {}, a skin disease" --disease_class=allergic-contact-dermatitis

In [None]:
# def generate_synthetic_dataset(
#     real_images,
#     map_real_to_synthetic_label,
#     method="text-to-image",
#     text_label=None,
#     text_prompt=None,
#     num_synthetic_per_real=10,
#     num_total=None,
#     num_total_type='balanced',
#     output_dir=None,
#     model_path="xxx"
# ):
#     for image, label in real_images:
#         # Use image and label to generate synthetic images
#         pass


# The real_images parameter should ideally be a PyTorch Dataset that returns (image, label) tuples. This keeps the function flexible and compatible with PyTorch's data-handling utilities, making it easier to integrate into training pipelines. We should use the same format for the train/test split functionality we are adding. (Note: CustomDataset above currently returns a dict, not a tuple.)


# The function should also be able to generate an augmented dataset that includes the masks, etc.

# The longer version of this will run in a .py script with command line args

# TODO: make sure the train/test split function uses the same types of inputs, we need to have consistency across functions