In [2]:
import pandas as pd
import torch
from fastai.vision.all import *
from sklearn.model_selection import train_test_split
import wandb
from fastai.callback.wandb import *
from statsmodels.stats.proportion import proportion_confint
from itertools import product
import os

from datasets import load_dataset, Image
from synderm.splits.train_test_splitter import synthetic_train_val_split



In [3]:
# Datasets
# Identifier of the synthetic dermatology image dataset.
# NOTE(review): presumably a Hugging Face Hub dataset id (the notebook imports
# `datasets.load_dataset`) — confirm. It is not referenced by any other visible cell.
synthetic_derm = "lukemelas/synthetic-derm"  # 440 total images in each class using all_generations

In [17]:
# --- Training metadata -------------------------------------------------------
# Read the training label file and keep only rows whose label is among the
# nine most frequent labels in that file.
train_raw = pd.read_csv("/n/data1/hms/dbmi/manrai/derm/Fitzpatrick17k/fitzpatrick17k_10label_clean_training.csv")
top_n_labels = train_raw["label"].value_counts().index[:9]
train_mask = train_raw["label"].isin(top_n_labels)
metadata_train = (
    train_raw.loc[train_mask]
    .reset_index(drop=True)
    .assign(
        # image path relative to the image root, built from the md5 hash
        location=lambda frame: 'Fitzpatrick17k/finalfitz17k/' + frame['md5hash'] + '.jpg',
        synthetic=False,
    )
)

# --- Held-out test metadata --------------------------------------------------
# Same label filter as the training set; every row is a real (non-synthetic) image.
test_raw = pd.read_csv("/n/data1/hms/dbmi/manrai/derm/Fitzpatrick17k/fitzpatrick17k_10label_clean_held_out_set.csv")
test_mask = test_raw["label"].isin(top_n_labels)
metadata_test = test_raw.loc[test_mask].reset_index(drop=True).assign(synthetic=False)

# --- Leakage check -----------------------------------------------------------
# Report whether any image hash appears in both the training and test sets.
ids_train = set(metadata_train["md5hash"])
ids_test = set(metadata_test["md5hash"])

print(
    "train/test mutually exclusive."
    if ids_train.isdisjoint(ids_test)
    else "train/test not mutually exclusive."
)

train/test mutually exclusive.


## Train a model with 32 real training images per disease condition

We will train a model with 32 real images per class, with and without synthetic images

In [18]:
# Seed passed to the train/validation split below
random_state = 111108

# Which set of generated images to use: folder (relative to the image root)
# and the generation method sub-folder name
generation_type = "inpaint"
generation_folder = "all_generations/finetune-inpaint/"

# Root directory holding the images; `path` is the same location as a Path
# object, handed to fastai when building the dataloaders
image_dir = "/n/data1/hms/dbmi/manrai/derm/"
path = Path(image_dir)

In [19]:
# Number of synthetic copies to pair with each real training image
n_synthetic_per_real = 10

# Build the synthetic-image metadata: one row per (real image, copy index).
# Start by repeating the real training metadata n_synthetic_per_real times.
df = pd.concat([metadata_train] * n_synthetic_per_real, ignore_index=True)

# 0-based index of each repeat within its md5hash group. Kept as a local
# Series so no temporary column has to be added to (and dropped from) df.
copy_index = df.groupby('md5hash').cumcount()

# Sub-folder name for each copy, zero-padded to two digits ('00' ... '09').
# The previous literal '0' prefix produced '010', '011', ... as soon as
# n_synthetic_per_real exceeded 10; zfill yields '10', '11', ... instead.
# NOTE(review): assumes the generation folders use two-digit names — confirm
# before raising n_synthetic_per_real above 10.
copy_folder = copy_index.astype(str).str.zfill(2)

# <generation_folder><label-with-dashes>/<generation_type>/<NN>/<md5hash>.png
df['location'] = (
    generation_folder
    + df['label'].str.replace(' ', '-', regex=False)
    + '/' + generation_type + '/'
    + copy_folder
    + '/' + df['md5hash'] + '.png'
)
df['synthetic'] = True
df['Qc'] = ''

In [20]:
# Split real + synthetic metadata into training and validation frames.
# NOTE(review): the exact meaning of per_class_size / n_real_per_class is
# defined by synderm's synthetic_train_val_split — presumably 32 real images
# per class are used for training out of a per-class pool of 40; confirm.
split_options = {
    "real_data": metadata_train,
    "synthetic_data": df,
    "per_class_size": 40,
    "n_real_per_class": 32,
    "random_state": random_state,
    "class_column": "label",
    # column used to link each synthetic row back to its real image
    "mapping_real_to_synthetic": "md5hash",
}
train, val = synthetic_train_val_split(**split_options)

In [21]:
# Flag which split each row belongs to (fastai reads this via valid_col)
train['is_valid'], val['is_valid'] = False, True

# Stack both splits into a single frame with a fresh 0..n-1 index
df = pd.concat([train, val], ignore_index=True)

In [13]:
# Choose the batch size from the number of training (non-validation) rows:
#   >= 1000 rows -> 64,  >= 100 rows -> 32,  otherwise 8
n_train_rows = len(df[df.is_valid == False])

if n_train_rows >= 1000:
    batch_size = 64
elif n_train_rows >= 100:
    batch_size = 32
else:
    batch_size = 8

In [None]:
# Build fastai dataloaders from the combined train/validation frame.
# NOTE(review): `augmentation_dict` and `augmentation` are not defined in any
# visible cell — they must exist in the kernel before this cell runs.
dls = ImageDataLoaders.from_df(
    df,
    path,
    fn_col='location',        # image path, relative to `path`
    label_col='label',
    valid_col='is_valid',     # True rows form the validation set
    bs=batch_size,
    item_tfms=Resize(224),
    batch_tfms=augmentation_dict[augmentation],
)

In [None]:
# Run configuration recorded with the wandb run.
# NOTE(review): `augmentation`, `n_per_label`, `include_synthetic` and
# `wandb_project` are not defined in any visible cell.
config = {
    "architecture": "EfficientNet-V2-M",
    "gen_folder": generation_folder,
    "random_state": random_state,
    "augmentation": augmentation,
    "n_training_per_label": n_per_label,
    "include_synthetic": include_synthetic,
    "n_synthetic_per_real": n_synthetic_per_real,
    "generation_type": generation_type,
}

# Tags attached to the wandb run
sample_tag = f"n_real_per_label_{n_per_label}"
seed_tag = f"seed_{random_state}"
include_synthetic_tag = f"include_synthetic_{include_synthetic}"
generation_type_tag = str(generation_type)
run_tags = [sample_tag, seed_tag, include_synthetic_tag, generation_type_tag]

# Start the run
wandb.init(project=wandb_project, tags=run_tags, config=config)



In [None]:
# Train for up to 50 epochs, logging to wandb and stopping early once the
# validation loss has not improved for 3 consecutive epochs.
# NOTE(review): `learn` is not created in any visible cell — a Learner built
# from `dls` must exist in the kernel before this cell runs.
training_callbacks = [
    WandbCallback(),
    EarlyStoppingCallback(monitor='valid_loss', min_delta=0.0, patience=3),
]
learn.fit(50, cbs=training_callbacks)

In [None]:
# Evaluate the trained model on the held-out test set, log top-1 / top-3
# accuracy (with 95% confidence intervals) and the per-image predictions to
# wandb, then close the run.

# `test_data` was never defined in the notebook. Build it from the held-out
# metadata and add the 'location' column that the dataloaders read file paths
# from (fn_col='location').
# NOTE(review): the path pattern is the one used for the training metadata —
# confirm it also applies to the held-out images.
test_data = metadata_test.assign(
    location='Fitzpatrick17k/finalfitz17k/' + metadata_test['md5hash'] + '.jpg'
)

# predict on test data
test_dl = dls.test_dl(test_data)
# class probabilities, one row per test image
preds, _ = learn.get_preds(dl=test_dl)

md5hashes = test_data['md5hash']
true_labels = test_data['label']

# A single top-3 query is enough: column 0 is the top-1 prediction and the
# columns are ordered by decreasing probability.
topk_prob, topk_index = torch.topk(preds, k=3, dim=1)

# Convert class indices to class names. Rows are indexed explicitly rather
# than with .squeeze(), which would collapse a one-row test set to a 0-d
# tensor that cannot be iterated.
vocab = learn.dls.vocab
topk_label = [[vocab[int(idx)] for idx in row] for row in topk_index]
top1_label = [row[0] for row in topk_label]
top2_label = [row[1] for row in topk_label]
top3_label = [row[2] for row in topk_label]

# Count hits positionally (independent of the DataFrame index) and keep the
# counts as integers so the confidence intervals get exact success counts.
actual_labels = true_labels.tolist()
n_test = len(actual_labels)
n_top1_correct = sum(pred == actual for pred, actual in zip(top1_label, actual_labels))
n_top3_correct = sum(actual in row for row, actual in zip(topk_label, actual_labels))

# accuracy scores as plain Python floats
top1_acc = n_top1_correct / n_test
top3_acc = n_top3_correct / n_test

# calculate upper and lower bounds for 95% confidence interval
top1_ci_lower, top1_ci_upper = proportion_confint(n_top1_correct, n_test, alpha=0.05, method='normal')
top3_ci_lower, top3_ci_upper = proportion_confint(n_top3_correct, n_test, alpha=0.05, method='normal')

# log accuracy scores to wandb
wandb.log({'top1_acc': top1_acc,
            'top1_ci_lower': top1_ci_lower,
            'top1_ci_upper': top1_ci_upper,
            'top3_acc': top3_acc,
            'top3_ci_lower': top3_ci_lower,
            'top3_ci_upper': top3_ci_upper})

# per-rank probabilities as flat NumPy arrays
top1_prob_arr = topk_prob[:, 0].numpy()
top2_prob_arr = topk_prob[:, 1].numpy()
top3_prob_arr = topk_prob[:, 2].numpy()

# create dataframe of predictions
df_pred = pd.DataFrame({
    'architecture' : "EfficientNet-V2-M",
    'random_state' : random_state,
    'augmentation' : augmentation,
    'gen_folder' : generation_folder,
    'generation_type' : generation_type,
    'n_training_per_label' : n_per_label,
    'n_synthetic_per_real' : n_synthetic_per_real,
    'include_synthetic' : include_synthetic,
    'md5hash': md5hashes,
    'true_label': true_labels,
    'top1_label': top1_label,
    'top1_prob': top1_prob_arr,
    'top2_label': top2_label,
    'top2_prob': top2_prob_arr,
    'top3_label': top3_label,
    # previously computed but left out of the table
    'top3_prob': top3_prob_arr
})

# log the test predictions
wandb.log({"test_predictions": wandb.Table(dataframe=df_pred)})

# Finish the run
wandb.finish()



## Train a model with 32 real training images per disease condition

In [None]:
# We need to produce a train/test split

In [None]:
# Train a model with 32 real training images per disease condition, plus an additional 10 synthetic images

In [None]:
# plot some data to compare the performance of these two trials

In [None]:
# The entire experiment can be run using this script, although this will take a while to run
!python skin_classification_with_augmentation.py \    
    --dataset hugginface_repo \ 
    --n_real_per_label_list "[1, 8, 16, 32, 64, 128, 228]" \
    --max_batch_size 32 \
    --arg2 value2

In [None]:
# Now we can plot the data from the complete experiment.
# We can provide the code for users to reproduce this, but we should also save the trial data so that users can still plot our results if they are unable to run the model.