# Training a Dermatology Image Classifier with Synthetic Data

This notebook demonstrates how to train an image classifier for dermatological conditions using both real and synthetic images. We'll explore how synthetic data can improve model performance, especially in scenarios with limited real training data.

## Step 1: Setup and Imports
First, let's import all necessary libraries:

Use `pip install -e .` to install the synderm package 

In [1]:
from synderm.splits.train_test_splitter import synthetic_train_val_split
from statsmodels.stats.proportion import proportion_confint
from datasets import load_dataset, Image
import pandas as pd
import torch
from fastai.vision.all import *
from sklearn.model_selection import train_test_split
import wandb
from fastai.callback.wandb import *
from statsmodels.stats.proportion import proportion_confint
from itertools import product
import os

## Step 2: Data Loading and Preparation

We'll use the Fitzpatrick17k dataset, focusing on the top 9 most common skin conditions:

We will:
1. Load training data from Fitzpatrick17k dataset
2. Select the top 9 most frequent skin conditions
3. Set up proper file paths
4. Verify train/test set separation

In [2]:
# Set image directory and fastai path
image_dir = "/n/data1/hms/dbmi/manrai/derm/"
path = Path(image_dir)

# Set the generation folder, this is where images are stored
generation_folder = "all_generations/finetune-inpaint/"
generation_type = "inpaint"

# Load in the training data
metadata_train = pd.read_csv("/n/data1/hms/dbmi/manrai/derm/Fitzpatrick17k/fitzpatrick17k_10label_clean_training.csv")
top_n_labels = metadata_train["label"].value_counts().index[:9]
metadata_train = metadata_train[metadata_train["label"].isin(top_n_labels)].reset_index(drop=True)
metadata_train['location'] = 'Fitzpatrick17k/finalfitz17k/' + metadata_train['md5hash'] + '.jpg'
metadata_train["synthetic"] = False

# Load in testing data
test_data = pd.read_csv("/n/data1/hms/dbmi/manrai/derm/Fitzpatrick17k/fitzpatrick17k_10label_clean_held_out_set.csv")
test_data = test_data[test_data["label"].isin(top_n_labels)].reset_index(drop=True)
test_data['location'] = 'Fitzpatrick17k/finalfitz17k/' + test_data['md5hash'] + '.jpg'
test_data['synthetic'] = False
test_data['is_valid'] = False

ids_train = set(metadata_train["md5hash"])
ids_test = set(test_data["md5hash"])

if ids_train.isdisjoint(ids_test):
    print("train/test mutually exclusive.")
else:
    print("train/test not mutually exclusive.")

train/test mutually exclusive.


## Step 3: Experiment Parameters

Define key parameters for our experiment:
- `per_class_test_size`: Number of test images per class (40)
- `n_real_per_class`: Number of real training images per class (32)
- `n_synthetic_per_real`: Number of synthetic images generated per real image (10)

In [3]:
# Experiment parameters 
per_class_test_size = 40
n_real_per_class = 32
n_synthetic_per_real = 10
random_state = 111108

## Step 4: Synthetic Data Generation

Create our synthetic dataset by:
1. Duplicating the real dataset
2. Assigning unique identifiers
3. Defining the paths to the synthetic images

In [4]:
# First, the dataset is duplicated n_synthetic_per_real times
df_synthetic = pd.concat([metadata_train]*n_synthetic_per_real, ignore_index=True)

# create a variable that represents the nth copy of the image
df_synthetic['n'] = df_synthetic.groupby('md5hash').cumcount()
df_synthetic['location'] = generation_folder + df_synthetic['label'].str.replace(' ', '-')  + '/' + generation_type +'/0' + df_synthetic['n'].astype(str) + '/' + df_synthetic['md5hash'] + '.png'
df_synthetic['synthetic'] = True
df_synthetic['Qc'] = ''

# drop the 'n' column
df_synthetic = df_synthetic.drop(columns=['n'])

## Step 5: Model Training with Synthetic Data

Now we'll train our first model using both real and synthetic images:

In [5]:
train, val = synthetic_train_val_split(
    real_data = metadata_train, 
    synthetic_data = df_synthetic, 
    per_class_test_size = per_class_test_size,
    n_real_per_class = n_real_per_class,
    random_state = random_state,
    class_column = "label",
    mapping_real_to_synthetic = "md5hash"
    )

# Add 'is_valid' column
train['is_valid'] = False
val['is_valid'] = True

df = pd.concat([train, val]).reset_index(drop=True)

# adjust batch size based on number of images
if (len(df[df.is_valid == False])/10 >= 100):
    batch_size = 64
elif (len(df[df.is_valid == False])/10 >= 10):
    batch_size = 32
else:
    batch_size = 8

The model uses:
- EfficientNetV2-M architecture
- Early stopping to prevent overfitting
- Dynamic batch sizing based on dataset size

In [6]:
# Create a fastai dataloader
dls = ImageDataLoaders.from_df(df, 
                        path,
                        fn_col='location',
                        label_col='label',
                        valid_col='is_valid', 
                        bs=64,
                        item_tfms=Resize(224),
                        batch_tfms=[])            

# Create the learner
learn = vision_learner(
    dls,
    arch=efficientnet_v2_m,
    metrics=[error_rate, accuracy]
)

# Fit without wandb callback
learn.fit(10, cbs=[EarlyStoppingCallback(monitor='valid_loss', min_delta=0.0, patience=3)])



epoch,train_loss,valid_loss,error_rate,accuracy,time
0,1.959998,2.758843,0.716667,0.283333,00:45
1,1.18279,2.90002,0.705556,0.294444,00:57
2,0.739314,3.136784,0.725,0.275,00:44
3,0.518641,3.127845,0.708333,0.291667,00:54


No improvement since epoch 0: early stopping


## Step 6: Model Evaluation (with Synthetic Data)

In [7]:
# Predict on test data
test_dl = dls.test_dl(test_data)

# Get predictions and probabilities for test set
preds, _ = learn.get_preds(dl=test_dl)

# Get top-3 probabilities and labels
top3_prob, top3_label = torch.topk(preds, k=3, dim=1)

# Convert top3_label indices to class labels
top3_label = [[learn.dls.vocab[idx] for idx in indices] for indices in top3_label]

# Get true labels for test set
true_labels = test_data['label'].reset_index(drop=True)

# Calculate top-1 accuracy
top1_label = [labels[0] for labels in top3_label]
top1_acc = np.mean(np.array(top1_label) == np.array(true_labels))

# Calculate top-3 accuracy
top3_acc = np.mean([
    true_labels.iloc[i] in top3_label[i]
    for i in range(len(true_labels))
])

top1_ci_lower, top1_ci_upper = proportion_confint(
    count=top1_acc * len(true_labels),
    nobs=len(true_labels),
    alpha=0.05,
    method='normal'
)
top3_ci_lower, top3_ci_upper = proportion_confint(
    count=top3_acc * len(true_labels),
    nobs=len(true_labels),
    alpha=0.05,
    method='normal'
)

# Print accuracy scores
print("Accuracy of the model including synthetic images: ")
print(f'Top-1 Accuracy: {top1_acc}')
print(f'Top-1 95% CI: {top1_ci_lower} - {top1_ci_upper}')
print(f'Top-3 Accuracy: {top3_acc}')
print(f'Top-3 95% CI: {top3_ci_lower} - {top3_ci_upper}')

# Extract top-1, top-2, and top-3 probabilities
top1_prob_arr = top3_prob[:, 0].numpy()
top2_prob_arr = top3_prob[:, 1].numpy()
top3_prob_arr = top3_prob[:, 2].numpy()

# Extract top-1, top-2, and top-3 labels
top1_label = [labels[0] for labels in top3_label]
top2_label = [labels[1] for labels in top3_label]
top3_label = [labels[2] for labels in top3_label]

# Get md5hashes
md5hashes = test_data['md5hash'].reset_index(drop=True)

# Create dataframe of predictions
df_pred = pd.DataFrame({
    'architecture': "EfficientNet-V2-M",
    'random_state': random_state,
    'augmentation': "None",
    'gen_folder': generation_folder,
    'generation_type': generation_type,
    'n_training_per_label': n_real_per_class,
    'n_synthetic_per_real': n_synthetic_per_real,
    'include_synthetic': True,
    'md5hash': md5hashes,
    'true_label': true_labels,
    'top1_label': top1_label,
    'top1_prob': top1_prob_arr,
    'top2_label': top2_label,
    'top2_prob': top2_prob_arr,
    'top3_label': top3_label,
    'top3_prob': top3_prob_arr
})

Accuracy of the model including synthetic images: 
Top-1 Accuracy: 0.24166666666666667
Top-1 95% CI: 0.19744498136010485 - 0.2858883519732285
Top-3 Accuracy: 0.5277777777777778
Top-3 95% CI: 0.47620795950042477 - 0.5793475960551309


## Step 7: Baseline Model (No Synthetic Data)

For comparison, we'll train and evaluate a model using only the real images:

In [8]:
train_ns, val_ns = synthetic_train_val_split(
    real_data = metadata_train, 
    synthetic_data = None, 
    per_class_test_size = per_class_test_size,
    n_real_per_class = n_real_per_class,
    random_state = random_state,
    class_column = "label",
    mapping_real_to_synthetic = "md5hash"
    )

# Add 'is_valid' column
train_ns['is_valid'] = False
val_ns['is_valid'] = True

df_ns = pd.concat([train_ns, val_ns]).reset_index(drop=True)

# adjust batch size based on number of images
if (len(df_ns[df_ns.is_valid == False])/10 >= 100):
    batch_size = 64
elif (len(df_ns[df_ns.is_valid == False])/10 >= 10):
    batch_size = 32
else:
    batch_size = 8

In [9]:
# Create a fastai dataloader
dls_ns = ImageDataLoaders.from_df(df_ns, 
                        path,
                        fn_col='location',
                        label_col='label',
                        valid_col='is_valid', 
                        bs=64,
                        item_tfms=Resize(224),
                        batch_tfms=[])            

# Create the learner
learn = vision_learner(
    dls_ns,
    arch=efficientnet_v2_m,
    metrics=[error_rate, accuracy]
)

# Fit without wandb callback
learn.fit(10, cbs=[EarlyStoppingCallback(monitor='valid_loss', min_delta=0.0, patience=3)])



epoch,train_loss,valid_loss,error_rate,accuracy,time
0,3.826281,2.661138,0.827778,0.172222,00:05
1,3.263444,2.559533,0.816667,0.183333,00:05
2,2.932132,2.562016,0.825,0.175,00:05
3,2.645862,2.635431,0.797222,0.202778,00:05
4,2.453256,2.700307,0.777778,0.222222,00:04


No improvement since epoch 1: early stopping


In [11]:
# Predict on test data
test_dl = dls.test_dl(test_data)

# Get predictions and probabilities for test set
preds, _ = learn.get_preds(dl=test_dl)

# Get top-3 probabilities and labels
top3_prob, top3_label = torch.topk(preds, k=3, dim=1)

# Convert top3_label indices to class labels
top3_label = [[learn.dls.vocab[idx] for idx in indices] for indices in top3_label]

# Get true labels for test set
true_labels = test_data['label'].reset_index(drop=True)

# Calculate top-1 accuracy
top1_label = [labels[0] for labels in top3_label]
top1_acc = np.mean(np.array(top1_label) == np.array(true_labels))

# Calculate top-3 accuracy
top3_acc = np.mean([
    true_labels.iloc[i] in top3_label[i]
    for i in range(len(true_labels))
])

top1_ci_lower, top1_ci_upper = proportion_confint(
    count=top1_acc * len(true_labels),
    nobs=len(true_labels),
    alpha=0.05,
    method='normal'
)
top3_ci_lower, top3_ci_upper = proportion_confint(
    count=top3_acc * len(true_labels),
    nobs=len(true_labels),
    alpha=0.05,
    method='normal'
)

# Print accuracy scores
print("Accuracy of the model (no synthetic images): ")
print(f'Top-1 Accuracy: {top1_acc}')
print(f'Top-1 95% CI: {top1_ci_lower} - {top1_ci_upper}')
print(f'Top-3 Accuracy: {top3_acc}')
print(f'Top-3 95% CI: {top3_ci_lower} - {top3_ci_upper}')

# Extract top-1, top-2, and top-3 probabilities
top1_prob_arr = top3_prob[:, 0].numpy()
top2_prob_arr = top3_prob[:, 1].numpy()
top3_prob_arr = top3_prob[:, 2].numpy()

# Extract top-1, top-2, and top-3 labels
top1_label = [labels[0] for labels in top3_label]
top2_label = [labels[1] for labels in top3_label]
top3_label = [labels[2] for labels in top3_label]

# Get md5hashes
md5hashes = test_data['md5hash'].reset_index(drop=True)

# Create dataframe of predictions
df_pred = pd.DataFrame({
    'architecture': "EfficientNet-V2-M",
    'random_state': random_state,
    'augmentation': "None",
    'gen_folder': generation_folder,
    'generation_type': generation_type,
    'n_training_per_label': n_real_per_class,
    'n_synthetic_per_real': n_synthetic_per_real,
    'include_synthetic': True,
    'md5hash': md5hashes,
    'true_label': true_labels,
    'top1_label': top1_label,
    'top1_prob': top1_prob_arr,
    'top2_label': top2_label,
    'top2_prob': top2_prob_arr,
    'top3_label': top3_label,
    'top3_prob': top3_prob_arr
})

Accuracy of the model (no synthetic images): 
Top-1 Accuracy: 0.18888888888888888
Top-1 95% CI: 0.14845549263775715 - 0.22932228514002062
Top-3 Accuracy: 0.49166666666666664
Top-3 95% CI: 0.4400242546931338 - 0.5433090786401995


## Run the complete experiment

In [None]:
# The entire experiment can be run using this script, although this will take a while to run
!python skin_classification_with_augmentation.py \    
    --dataset hugginface_repo \ 
    --n_real_per_label_list "[1, 8, 16, 32, 64, 128, 228]" \
    --max_batch_size 32 \
    --arg2 value2