In [1]:
# generate the dataset
from synderm.synderm.generation.generate import generate_synthetic_dataset
from synderm.synderm.fine_tune.text_to_image_diffusion import fine_tune_text_to_image

from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path
import os

## Train with synthetic images
This notebook will walk you through the entire process of augmenting a classifier with synthetic images.

### 1. Creating the dataset
The first step is to create a PyTorch dataset. Example datasets are listed in `sample_datasets.py`. For this example, we will use a simplified dataset that contains 10 classes, with 110 training and 32 validation samples per class. This dataset is included at `sample_dataset/`.

For datasets to work with methods in this package, each entry must contain an `image` field returning a PIL Image, a `label` field with the label, and an `id` field containing a unique ID for each image.

In [4]:
class SampleDataset(Dataset):
    """Image-folder dataset laid out as ``<dataset_dir>/<split>/<class_name>/*.png``.

    Each item is a dict with three keys:

    * ``id``    - the image file name without its extension,
    * ``image`` - the image loaded as an RGB ``PIL.Image``,
    * ``label`` - the name of the class folder that contains the image.

    Class folders and the files inside them are visited in sorted order, so the
    index -> sample mapping is the same on every run and every filesystem
    (``os.listdir`` alone returns entries in arbitrary order).

    NOTE(review): ``id`` is only the file stem, so two class folders holding a
    file with the same name would produce duplicate ids - confirm that file
    names are unique across classes.

    Args:
        dataset_dir: Root directory of the dataset (str or path-like).
        split: Name of the sub-directory of ``dataset_dir`` to read.

    Raises:
        FileNotFoundError: Propagated from ``os.listdir`` when
            ``<dataset_dir>/<split>`` does not exist.
    """

    def __init__(self, dataset_dir, split="train"):
        self.dataset_dir = Path(dataset_dir)
        self.image_paths = []
        self.labels = []
        self.split = split

        # Walk through class folders in a deterministic (sorted) order.
        data_dir = self.dataset_dir / self.split
        for class_name in sorted(os.listdir(data_dir)):
            class_dir = data_dir / class_name
            if not class_dir.is_dir():
                # Stray files next to the class folders are ignored.
                continue

            # Collect every PNG (extension matched case-insensitively).
            for img_name in sorted(os.listdir(class_dir)):
                if img_name.lower().endswith('.png'):
                    self.image_paths.append(class_dir / img_name)
                    self.labels.append(class_name)

    def __len__(self):
        """Return the number of images found for this split."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as an ``id``/``image``/``label`` dict."""
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        # The context manager closes the underlying file promptly; convert()
        # returns a separate in-memory image that stays valid afterwards.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image_name = image_path.stem

        return {"id": image_name, "image": image, "label": label}



In [None]:
# Build the training split; "sample_dataset" is a relative path, so it is
# resolved against the notebook's current working directory.
sample_dataset = SampleDataset(
    dataset_dir="sample_dataset",
    split="train",
)

### 2. Train the synthetic image generator
Now that we have a dataset, we will fine-tune a diffusion model with DreamBooth on our training images, so that the images it generates more closely resemble our training data.

In [None]:
# One prompt template shared by the training and validation prompts, so the
# two cannot drift apart. The "{}" placeholder is presumably filled with the
# class label by fine_tune_text_to_image - TODO confirm.
SKIN_DISEASE_PROMPT = "An image of {}, a skin disease"

# Value passed as `output_dir`. Set the DREAMBOOTH_OUTPUT_DIR environment
# variable to redirect it (e.g. to a scratch filesystem); the fallback is a
# directory relative to the working directory rather than a user-specific
# absolute path, so the notebook runs on any machine.
DREAMBOOTH_OUTPUT_DIR = os.environ.get("DREAMBOOTH_OUTPUT_DIR", "dreambooth-outputs")

fine_tune_text_to_image(
    train_dataset=sample_dataset,
    pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base",
    instance_prompt=SKIN_DISEASE_PROMPT,
    validation_prompt_format=SKIN_DISEASE_PROMPT,
    output_dir=DREAMBOOTH_OUTPUT_DIR,
    # Optional, currently disabled; presumably restricts training to a single
    # class label - TODO confirm against fine_tune_text_to_image.
    #label_filter = "allergic-contact-dermatitis",
    resolution=512,
    train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=5e-6,
    lr_scheduler="constant",
    lr_warmup_steps=0,
    num_train_epochs=4,
    report_to="wandb",
)


In [None]:
# Fine-tune a DreamBooth diffusion model using the train split

In [None]:
# Generate synthetic images using the diffusion model that we just fine-tuned

In [9]:
# Produce the synthetic images with an inpainting pipeline.
# NOTE(review): model_path below is the stock
# "runwayml/stable-diffusion-inpainting" checkpoint, not the output of the
# fine-tuning step - confirm whether the fine-tuned weights should be used.
generate_synthetic_dataset(
    # Where results go
    output_dir_path=Path("test_outputs"),
    # Pipeline and prompt
    generation_type="inpaint",
    model_path="runwayml/stable-diffusion-inpainting",
    input_prompt="An image of {}, a skin disease",
    # Batching
    batch_size=16,
    start_index=0,
    num_generations_per_image=1,
    # Sampling settings
    seed=42,
    guidance_scale=3.0,
    num_inference_steps=50,
    # Masking; strength_outpaint is supplied even though generation_type is
    # "inpaint".
    mask_fraction=0.25,
    strength_inpaint=0.97,
    strength_outpaint=0.95,
)

Loading model


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

An error occurred while trying to fetch /home/thb286/.cache/huggingface/hub/models--runwayml--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /home/thb286/.cache/huggingface/hub/models--runwayml--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
An error occurred while trying to fetch /home/thb286/.cache/huggingface/hub/models--runwayml--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/unet: Error no file named diffusion_pytorch_model.safetensors found in directory /home/thb286/.cache/huggingface/hub/models--runwayml--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/unet.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.


Output()

Loaded pipeline with 859_535_364 unet parameters


  for batch_idx, batch in enumerate(tqdm(dataloader)):


KeyboardInterrupt: 

In [None]:
# Create a synthetic train/test split to train the final classifier

In [None]:
# Train the classifier

In [None]:
# Evaluate the performance of the classifier