In [1]:
# Verify environment setup
import sys

# Inside a virtual environment sys.prefix points at the environment while
# sys.base_prefix still points at the base interpreter. Comparing the two works
# no matter what the environment directory is called, unlike a substring test
# for ".venv" in the executable path.
in_virtual_env = sys.prefix != sys.base_prefix

print(f"Python executable: {sys.executable}")
print(f"Virtual environment: {in_virtual_env}")

# Test synderm import; report the failure instead of aborting the notebook
try:
    import synderm
    print("✓ synderm package is available")
    print(f"  Package location: {synderm.__file__}")
except ImportError as e:
    print(f"✗ synderm import failed: {e}")


Python executable: /workspace/synthetic-derm/.venv/bin/python
Virtual environment: True


  from .autonotebook import tqdm as notebook_tqdm


✓ synderm package is available
  Package location: /workspace/synthetic-derm/synderm/__init__.py


# Vignette: Augmenting Your Classifier with Synthetic Images

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from synderm.generation.generate import generate_synthetic_dataset
from synderm.utils.utils import synthetic_train_val_split
from webdataset import WebDataset, RandomMix
from huggingface_hub import get_token
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from huggingface_hub import HfApi
import matplotlib.pyplot as plt
import webdataset as wds
from pathlib import Path
from PIL import Image
import pandas as pd
import random
import os
import json
import io
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set path to root directory of package.
# NOTE: the path is relative to the current working directory, so running this
# cell a second time moves up another three levels; restart the kernel before
# re-running the notebook.
%cd ../../../

/workspace/synthetic-derm


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


## Introduction

This notebook will demonstrate how to augment a small dermatology dataset with our large collection of synthetic images. We will start by loading in a sample dermatology dataset. These are also synthetic images, but we are pretending they are real images for the purposes of this vignette. You should replace this dataset with your own dataset (adjusting the labels/format as necessary).

After we load in these images, we will select the desired labels from the synthetic-derm training dataset hosted on [HuggingFace](https://huggingface.co/datasets/tbuckley/synthetic-derm-1M-train). We will then mix these in with our real images, and use a subset of images for validation.

## 1. Load your dataset

First, create a Torch dataset based on the structure of your data. We provide a sample dataset called "sample_derm_dataset," with a folder for "train" and "val." Each folder is organized into subfolders for each label (similar to ImageNet). For use with this package, it is standard to return dictionary entries containing a "label" and "image" (PIL) field.

In [4]:
class SampleDataset(Dataset):
    """Folder-per-class image dataset that yields dictionary samples.

    The expected layout is ``<dataset_dir>/<split>/<class_name>/*.png``. Every
    PNG found (extension matched case-insensitively) is recorded together with
    the name of its parent folder as the label, and the collected entries are
    shuffled once at construction time.

    Each item is a dict with the keys ``"id"`` (file name without extension),
    ``"image"`` (PIL image converted to RGB) and ``"label"`` (class folder name).
    """

    def __init__(self, dataset_dir, split="train"):
        self.dataset_dir = Path(dataset_dir)
        self.split = split

        split_root = self.dataset_dir / self.split

        # Gather (path, label) pairs, one class folder at a time
        entries = []
        for class_name in os.listdir(split_root):
            class_dir = split_root / class_name
            if class_dir.is_dir():
                entries.extend(
                    (class_dir / file_name, class_name)
                    for file_name in os.listdir(class_dir)
                    if file_name.lower().endswith('.png')
                )

        # Shuffle paths and labels together so they stay aligned
        random.shuffle(entries)
        self.image_paths = [path for path, _ in entries]
        self.labels = [label for _, label in entries]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        class_label = self.labels[idx]

        # Load and convert image to RGB
        rgb_image = Image.open(path).convert('RGB')

        return {"id": path.stem, "image": rgb_image, "label": class_label}


In [5]:
# Both splits live under the same root folder; only the sub-folder differs
train_data = SampleDataset("sample_derm_dataset", split="train")
test_data = SampleDataset("sample_derm_dataset", split="val")

In [6]:
# Print a sample entry (nothing is printed if the dataset is empty)
first_item = next(iter(train_data), None)
if first_item is not None:
    print(first_item)

{'id': '0064', 'image': <PIL.Image.Image image mode=RGB size=512x512 at 0x7558001073A0>, 'label': 'squamous-cell-carcinoma'}


## 2. Loading the synthetic images from HuggingFace

We will now load synthetic images using the train version of the dataset hosted on HuggingFace. This dataset contains 1 million images, separated into four generation methods: finetune-inpaint, finetune-text-to-image, pretrained-inpaint, and pretrained-text-to-image. These have already been shuffled, then broken into shards. This is ideal for training as no reshuffling needs to be done, and shards only need to be loaded one at a time (saving lots of memory).

Based on the results in our paper, images produced from finetune-text-to-image perform the best, and this is also the largest split in the dataset. So, we will select this split of the dataset, and all of its shards (133 is the last numbered shard; the shards can be viewed at [this link](https://huggingface.co/datasets/tbuckley/synthetic-derm-1M-train/tree/main/data)).

This is selected using the following URL:

In [7]:
# Select finetune-text-to-image shards.
# The {00000..00133} brace range covers shard numbers 00000 through 00133.
url = (
    "https://huggingface.co/datasets/tbuckley/synthetic-derm-1M-train/resolve/main/data/"
    "shard-finetune-text-to-image-{00000..00133}.tar"
)

Now, we will list the labels we would like to include, and create a WebDataset pipeline to filter and format each entry as the dataset is iterated.

In [8]:
# Conditions to keep from the synthetic dataset, one label per line
LABELS = """
    allergic-contact-dermatitis
    basal-cell-carcinoma
    folliculitis
    lichen-planus
    lupus-erythematosus
    neutrophilic-dermatoses
    photodermatoses
    psoriasis
    sarcoidosis
    squamous-cell-carcinoma
""".split()

def to_dict(sample):
    """Reduce a decoded sample to the ``id`` / ``image`` / ``label`` fields.

    ``sample["json"]`` supplies the ``md5hash`` (used as the id) and the
    ``label``; ``sample["png"]`` supplies the image.
    """
    metadata = sample["json"]
    return {
        "id": metadata["md5hash"],
        "image": sample["png"],
        "label": metadata["label"],
    }

def select_label(sample):
    """Return ``sample`` unchanged if its label is in ``LABELS``, otherwise ``None``."""
    return sample if sample["label"] in LABELS else None

# Create a WebDataset: open the shard stream first, then attach the processing stages
shard_stream = wds.WebDataset(url, shardshuffle=True)
synthetic_data = (
    shard_stream
    .shuffle(40000)      # sample-level shuffle with a buffer size of 40000
    .decode("pil")
    .map(to_dict)        # keep only the id / image / label fields
    .map(select_label)   # select_label returns None for unwanted labels;
                         # presumably WebDataset's map drops None results - verify
)



## 3. Mixing the real and synthetic training images

Now, we need to combine our real and synthetic images for model training. We can use the convenient `RandomMix` class from the WebDataset package. This class allows us to combine two PyTorch datasets and specify the sampling probability for each one.

We are going to create a dataset with a relative sampling weight of 1.5 for real data and 1.0 for synthetic data. The weights are relative rather than absolute probabilities, so on average about 60% of the sampled images will be real and 40% synthetic. **We encourage you to try different mixing ratios for the best performance with your data.**

In [9]:
# Relative sampling weights for the two sources (real first, synthetic second)
REAL_WEIGHT, SYNTHETIC_WEIGHT = 1.5, 1.0
mixed_dataset = RandomMix(
    [train_data, synthetic_data],
    [REAL_WEIGHT, SYNTHETIC_WEIGHT],
)

## 4. Model training and validation

Finally, now that we have our dataset of real and synthetic images, we will train a PyTorch EfficientNet_V2_M model to classify these images. We will validate our model on the held-out set of real images.

In [10]:
# Class <-> index lookup tables; indices follow the sorted order of the labels
unique_labels = sorted(set(LABELS))
idx_to_label = dict(enumerate(unique_labels))
label_to_idx = {label: idx for idx, label in idx_to_label.items()}

def collate_fn(batch):
    """Turn a list of sample dicts into an ``(images, labels)`` tensor pair.

    Every sample's ``'image'`` is resized to 224x224 and converted to a tensor;
    its ``'label'`` string is mapped to an integer through ``label_to_idx``.
    The images are stacked along a new first dimension.
    """
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),  # Ensure all images have same size
        transforms.ToTensor(),
    ])

    # Process image and label together, sample by sample
    converted = [
        (preprocess(sample['image']), label_to_idx[sample['label']])
        for sample in batch
    ]

    images = torch.stack([image for image, _ in converted], dim=0)
    labels = torch.tensor([target for _, target in converted])

    return images, labels

# Both loaders share the batch size and the collate function
BATCH_SIZE = 32
loader_options = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn)
train_loader = DataLoader(mixed_dataset, **loader_options)
val_loader = DataLoader(test_data, **loader_options)

In [11]:
# Start from the ImageNet-pretrained EfficientNetV2-M weights
model = models.efficientnet_v2_m(weights=models.EfficientNet_V2_M_Weights.IMAGENET1K_V1)

# Replace the final classifier layer so it outputs one logit per class.
# The output size comes from label_to_idx instead of a hard-coded 10, so the
# head stays in sync with LABELS when the label list is edited.
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, len(label_to_idx))

# Use the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

Downloading: "https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_v2_m-dc08266a.pth
100%|██████████| 208M/208M [00:47<00:00, 4.61MB/s] 


In [12]:
num_epochs = 10
for epoch in range(num_epochs):
    # Training
    model.train()
    train_loss = 0.0  # running sum of the batch losses in this epoch (accumulated, not printed)
    for i, (imgs, lbls) in enumerate(train_loader):
        imgs = imgs.to(device)
        lbls = lbls.to(device)

        # Standard optimisation step: reset gradients, forward, backward, update
        optimizer.zero_grad()
        loss = criterion(model(imgs), lbls)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss

        # Log the first batch and every 100th batch after it
        if i % 100 == 0:
            print(f'Epoch: {epoch+1}/{num_epochs}, Batch: {i+1}, Train Loss: {batch_loss}')

Epoch: 1/10, Batch: 1, Train Loss: 2.3465139865875244
Epoch: 2/10, Batch: 1, Train Loss: 1.8198271989822388
Epoch: 3/10, Batch: 1, Train Loss: 1.6083544492721558
Epoch: 4/10, Batch: 1, Train Loss: 1.3128741979599
Epoch: 5/10, Batch: 1, Train Loss: 0.9874697923660278
Epoch: 6/10, Batch: 1, Train Loss: 0.8346286416053772
Epoch: 7/10, Batch: 1, Train Loss: 0.7049002051353455
Epoch: 8/10, Batch: 1, Train Loss: 0.33139896392822266
Epoch: 9/10, Batch: 1, Train Loss: 0.6521413326263428
Epoch: 10/10, Batch: 1, Train Loss: 0.4039866030216217


In [13]:
# Evaluation
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

def evaluate_model(model, data_loader, device, criterion):
    """Evaluate ``model`` on ``data_loader`` and summarise its performance.

    Args:
        model: Network to evaluate. It is switched to eval mode and left there.
        data_loader: Iterable yielding ``(images, labels)`` batches. It does not
            need to support ``len()``, so loaders over iterable datasets work.
        device: Device the batches are moved to before the forward pass.
        criterion: Loss called as ``criterion(outputs, labels)``. Assumed to
            return the batch mean (the ``nn.CrossEntropyLoss`` default) -
            confirm if a different reduction is used.

    Returns:
        Tuple ``(loss, accuracy, report, conf_matrix)``: the per-sample average
        loss, the accuracy in percent, the text produced by
        ``classification_report`` and the array produced by ``confusion_matrix``
        (both computed on the string labels).

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.eval()
    all_predictions = []
    all_labels = []
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for imgs, lbls in data_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            batch_size = lbls.size(0)

            outputs = model(imgs)
            # Weight each batch by its size so a short final batch does not
            # skew the average, and so len(data_loader) is never needed.
            running_loss += criterion(outputs, lbls).item() * batch_size

            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += (predicted == lbls).sum().item()

            all_predictions.extend(predicted.cpu().tolist())
            all_labels.extend(lbls.cpu().tolist())

    if total == 0:
        raise ValueError("data_loader yielded no samples to evaluate")

    # Map numeric labels back to string labels
    all_predictions_str = [idx_to_label[pred] for pred in all_predictions]
    all_labels_str = [idx_to_label[lbl] for lbl in all_labels]

    # Calculate metrics
    loss = running_loss / total
    accuracy = 100 * correct / total

    # Generate detailed classification report
    report = classification_report(all_labels_str, all_predictions_str, digits=4)
    conf_matrix = confusion_matrix(all_labels_str, all_predictions_str)

    return loss, accuracy, report, conf_matrix

# Score the trained model on the held-out real images and print the summary
print("\nValidation Set Evaluation:")
val_loss, val_accuracy, val_report, val_conf_matrix = evaluate_model(
    model=model,
    data_loader=val_loader,
    device=device,
    criterion=criterion,
)
print(
    f"Validation Loss: {val_loss:.4f}",
    f"Validation Accuracy: {val_accuracy:.2f}%",
    sep="\n",
)
print("\nDetailed Validation Metrics:", val_report, sep="\n")
print("\nValidation Confusion Matrix:", val_conf_matrix, sep="\n")



Validation Set Evaluation:
Validation Loss: 2.9727
Validation Accuracy: 37.19%

Detailed Validation Metrics:
                             precision    recall  f1-score   support

allergic-contact-dermatitis     0.3600    0.2812    0.3158        32
       basal-cell-carcinoma     0.8750    0.2188    0.3500        32
               folliculitis     0.5517    0.5000    0.5246        32
              lichen-planus     0.2474    0.7500    0.3721        32
        lupus-erythematosus     0.3095    0.4062    0.3514        32
    neutrophilic-dermatoses     0.5294    0.2812    0.3673        32
            photodermatoses     0.3469    0.5312    0.4198        32
                  psoriasis     0.3333    0.0938    0.1463        32
                sarcoidosis     0.3103    0.2812    0.2951        32
    squamous-cell-carcinoma     0.8000    0.3750    0.5106        32

                   accuracy                         0.3719       320
                  macro avg     0.4664    0.3719    0.3653  

Given that this is a sample dataset, this model appears to perform OK. We encourage you to use your own data, augmented with our large collection of synthetic images. For next steps, you can try training your model with and without data augmentation, trying different mixing ratios, and different models. Best of luck!