In [17]:
import pandas as pd
import torch
from fastai.vision.all import *
from sklearn.model_selection import train_test_split
import wandb
from fastai.callback.wandb import *
from statsmodels.stats.proportion import proportion_confint
from itertools import product
import os

In [21]:
# Datasets
# Hugging Face dataset id for the synthetic images.
synthetic_derm = "lukemelas/synthetic-derm"  # 440 total images in each class using all_generations
# Local Fitzpatrick17k directory. The path is absolute (leading "/"), matching the
# other paths to this directory used later in the notebook.
fitz17 = "/n/data1/hms/dbmi/manrai/derm/Fitzpatrick17k/" # Replace this with your own data directory

In [30]:
# Training split: read the label CSV and keep only rows whose label is one of the
# nine most frequent labels in that CSV.
# TODO: these should be stored in the huggingface repo
metadata_train = pd.read_csv("/n/data1/hms/dbmi/manrai/derm/Fitzpatrick17k/fitzpatrick17k_10label_clean_training.csv")
top_n_labels = metadata_train["label"].value_counts().index[:9]
metadata_train = metadata_train.loc[metadata_train["label"].isin(top_n_labels)].reset_index(drop=True)
metadata_train["synthetic"] = False

# Held-out test split, restricted to the same label set as the training split.
metadata_test = pd.read_csv("/n/data1/hms/dbmi/manrai/derm/Fitzpatrick17k/fitzpatrick17k_10label_clean_held_out_set.csv")
metadata_test = metadata_test.loc[metadata_test["label"].isin(top_n_labels)].reset_index(drop=True)
metadata_test["synthetic"] = False

# Sanity check: report whether any md5hash appears in both splits.
ids_train = set(metadata_train["md5hash"])
ids_test = set(metadata_test["md5hash"])

print(
    "train/test mutually exclusive."
    if ids_train.isdisjoint(ids_test)
    else "train/test not mutually exclusive."
)

train/test mutually exclusive.


## Train a model with 32 real training images per disease condition

We will train a model on 32 real images per class, once with and once without additional synthetic images.

In [32]:
# Seed passed to the sampling/splitting steps further down.
random_state = 111108

# Image root directory; `path` is the same location wrapped in a Path object.
image_dir = "/n/data1/hms/dbmi/manrai/derm/"
path = Path(image_dir)

# Synthetic images: generation method and the folder prefix used to build their locations.
generation_type = "inpaint"
generation_folder = "all_generations/finetune-inpaint/"

In [33]:
# Number of synthetic variants listed for each real training image.
n_synthetic_per_real = 10

# Repeat the real metadata once per synthetic variant, so every real row
# appears n_synthetic_per_real times.
df = pd.concat([metadata_train] * n_synthetic_per_real, ignore_index=True)

# 'n' is the 0-based index of each copy of an image (per md5hash).
df['n'] = df.groupby('md5hash').cumcount()

# Build the relative location of each synthetic image:
#   <generation_folder><label with spaces as dashes>/<generation_type>/<NN>/<md5hash>.png
# zfill(2) produces the two-digit folder names "00".."09" and stays two-digit for
# indices >= 10 ("10", "11", ...), where a hard-coded "0" prefix would give "010".
# assumes the generation folders are two-digit zero-padded — TODO confirm for n >= 10
df['location'] = (
    generation_folder
    + df['label'].str.replace(' ', '-', regex=False)
    + '/' + generation_type + '/'
    + df['n'].astype(str).str.zfill(2)
    + '/' + df['md5hash'] + '.png'
)

# Every row in this frame describes a synthetic image (overrides the False
# value inherited from metadata_train).
df['synthetic'] = True

# NOTE(review): this adds a new 'Qc' column next to the lowercase 'qc' column
# already present in the metadata — confirm which name downstream code expects.
df['Qc'] = ''

# The copy index was only needed to build 'location'.
df = df.drop(columns=['n'])

In [37]:
# Preview the first five rows of the synthetic metadata frame.
df.head(n=5)

Unnamed: 0,md5hash,fitzpatrick_scale,fitzpatrick_centaur,label,nine_partition_label,three_partition_label,qc,url,url_alphanum,synthetic,location,Qc
0,fa2911a9b13b6f8af79cb700937cc14f,1,1,photodermatoses,inflammatory,non-neoplastic,,https://www.dermaamin.com/site/images/clinical-pic/p/photosensitivity/photosensitivity18.jpg,httpwwwdermaamincomsiteimagesclinicalpicpphotosensitivityphotosensitivity18jpg.jpg,True,all_generations/finetune-inpaint/photodermatoses/inpaint/00/fa2911a9b13b6f8af79cb700937cc14f.png,
1,0a94359e7eaacd7178e06b2823777789,1,1,psoriasis,inflammatory,non-neoplastic,,https://www.dermaamin.com/site/images/clinical-pic/p/psoriasis/psoriasis38.jpg,httpwwwdermaamincomsiteimagesclinicalpicppsoriasispsoriasis38jpg.jpg,True,all_generations/finetune-inpaint/psoriasis/inpaint/00/0a94359e7eaacd7178e06b2823777789.png,
2,a39ec3b1f22c08a421fa20535e037bba,1,1,psoriasis,inflammatory,non-neoplastic,,https://www.dermaamin.com/site/images/clinical-pic/p/psoriasis-scalp/psoriasis-scalp20.jpg,httpwwwdermaamincomsiteimagesclinicalpicppsoriasisscalppsoriasisscalp20jpg.jpg,True,all_generations/finetune-inpaint/psoriasis/inpaint/00/a39ec3b1f22c08a421fa20535e037bba.png,
3,6c395be9325dbb10e55497304b398253,2,2,neutrophilic dermatoses,inflammatory,non-neoplastic,,https://www.dermaamin.com/site/images/clinical-pic/s/sweet-syndrome/sweet-syndrome98.jpg,httpwwwdermaamincomsiteimagesclinicalpicssweetsyndromesweetsyndrome98jpg.jpg,True,all_generations/finetune-inpaint/neutrophilic-dermatoses/inpaint/00/6c395be9325dbb10e55497304b398253.png,
4,09d46db9589ff45436cda87c4abc946b,3,2,allergic contact dermatitis,inflammatory,non-neoplastic,,https://www.dermaamin.com/site/images/clinical-pic/a/allergic_contact_dermatitis/allergic_contact_dermatitis114.jpg,httpwwwdermaamincomsiteimagesclinicalpicaallergiccontactdermatitisallergiccontactdermatitis114jpg.jpg,True,all_generations/finetune-inpaint/allergic-contact-dermatitis/inpaint/00/09d46db9589ff45436cda87c4abc946b.png,


In [43]:
# Build the train and validation frames from the real and synthetic metadata.
# NOTE(review): synthetic_train_val_split is not defined or explicitly imported in the
# cells shown in this notebook — confirm where it comes from so Restart & Run All works.
# NOTE(review): per_class_size=40 presumably is the validation size per class — confirm.
train, val = synthetic_train_val_split(
    real_data=metadata_train,
    synthetic_data=df,
    per_class_size=40,
    n_real_per_class=32,
    random_state=random_state,
    class_column="label",
    id_column="md5hash",
)

In [44]:
# Display the training frame returned by the split (pandas truncates the middle rows).
train

Unnamed: 0,md5hash,fitzpatrick_scale,fitzpatrick_centaur,label,nine_partition_label,three_partition_label,qc,url,url_alphanum,synthetic,location,Qc,is_valid
0,161e5aa485b14a77a2dee45d20052a19,2,1,allergic contact dermatitis,inflammatory,non-neoplastic,,https://www.dermaamin.com/site/images/clinical-pic/r/rhus_dermatitis/rhus_dermatitis21.jpg,httpwwwdermaamincomsiteimagesclinicalpicrrhusdermatitisrhusdermatitis21jpg.jpg,False,,,False
1,667743ac9cf8862ac3ffcc143a6825af,2,1,allergic contact dermatitis,inflammatory,non-neoplastic,,https://www.dermaamin.com/site/images/clinical-pic/a/allergic_contact_dermatitis/allergic_contact_dermatitis46.jpg,httpwwwdermaamincomsiteimagesclinicalpicaallergiccontactdermatitisallergiccontactdermatitis46jpg.jpg,False,,,False
2,47df3741647f00f3713af546424df866,2,1,allergic contact dermatitis,inflammatory,non-neoplastic,,https://www.dermaamin.com/site/images/clinical-pic/c/contact_dermatitis/contact_dermatitis14.jpg,httpwwwdermaamincomsiteimagesclinicalpicccontactdermatitiscontactdermatitis14jpg.jpg,False,,,False
3,dc4852bdb88ee7e126d9d2078cdb1612,1,1,allergic contact dermatitis,inflammatory,non-neoplastic,5 Potentially,https://www.dermaamin.com/site/images/clinical-pic/I/irritant-contact-dermatitis/irritant-contact-dermatitis21.jpg,httpwwwdermaamincomsiteimagesclinicalpicIirritantcontactdermatitisirritantcontactdermatitis21jpg.jpg,False,,,False
4,fd2e0985058e54e4ced9910e56f9f6b5,2,2,allergic contact dermatitis,inflammatory,non-neoplastic,,https://www.dermaamin.com/site/images/clinical-pic/a/allergic_contact_dermatitis/allergic_contact_dermatitis55.jpg,httpwwwdermaamincomsiteimagesclinicalpicaallergiccontactdermatitisallergiccontactdermatitis55jpg.jpg,False,,,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...
33673,de727a04c900af1fd7f413c121a18199,4,3,lichen planus,inflammatory,non-neoplastic,,http://atlasdermatologico.com.br/img?imageId=3744,httpwwwatlasdermatologicocombrimgimageId3744.jpg,True,all_generations/finetune-inpaint/lichen-planus/inpaint/09/de727a04c900af1fd7f413c121a18199.png,,False
33674,cb5b385f3f6d18339194aeb3955f30ac,6,6,lupus erythematosus,inflammatory,non-neoplastic,,http://atlasdermatologico.com.br/img?imageId=4099,httpwwwatlasdermatologicocombrimgimageId4099.jpg,True,all_generations/finetune-inpaint/lupus-erythematosus/inpaint/09/cb5b385f3f6d18339194aeb3955f30ac.png,,False
33675,218ea232ca4b0255a395886ee86cbada,3,2,basal cell carcinoma,malignant epidermal,malignant,,http://atlasdermatologico.com.br/img?imageId=586,httpwwwatlasdermatologicocombrimgimageId586.jpg,True,all_generations/finetune-inpaint/basal-cell-carcinoma/inpaint/09/218ea232ca4b0255a395886ee86cbada.png,,False
33676,642718dcaa658d3fb796ca2f07f6e6d2,3,3,neutrophilic dermatoses,inflammatory,non-neoplastic,,http://atlasdermatologico.com.br/img?imageId=6921,httpwwwatlasdermatologicocombrimgimageId6921.jpg,True,all_generations/finetune-inpaint/neutrophilic-dermatoses/inpaint/09/642718dcaa658d3fb796ca2f07f6e6d2.png,,False


In [40]:
# Display the validation frame returned by the split (pandas truncates the middle rows).
val

Unnamed: 0,md5hash,fitzpatrick_scale,fitzpatrick_centaur,label,nine_partition_label,three_partition_label,qc,url,url_alphanum,synthetic,is_valid
0,e4eaa82b962f827f69009449d0a8cbc2,-1,1,allergic contact dermatitis,inflammatory,non-neoplastic,,https://www.dermaamin.com/site/images/clinical-pic/a/acute_contact_dermatitis/acute_contact_dermatitis21.jpg,httpwwwdermaamincomsiteimagesclinicalpicaacutecontactdermatitisacutecontactdermatitis21jpg.jpg,False,True
1,2f3ec777d8f2af1a89f86ef3cef7a409,2,1,allergic contact dermatitis,inflammatory,non-neoplastic,,https://www.dermaamin.com/site/images/clinical-pic/c/contact_dermatitis/contact_dermatitis79.jpg,httpwwwdermaamincomsiteimagesclinicalpicccontactdermatitiscontactdermatitis79jpg.jpg,False,True
2,76027292304f381e5567eae73242932c,2,1,allergic contact dermatitis,inflammatory,non-neoplastic,,https://www.dermaamin.com/site/images/clinical-pic/c/contact_dermatitis/contact_dermatitis90.jpg,httpwwwdermaamincomsiteimagesclinicalpicccontactdermatitiscontactdermatitis90jpg.jpg,False,True
3,b920dcbc6e05b75aaf5b4159b4682f36,2,1,allergic contact dermatitis,inflammatory,non-neoplastic,,https://www.dermaamin.com/site/images/clinical-pic/a/allergic_contact_dermatitis/allergic_contact_dermatitis58.jpg,httpwwwdermaamincomsiteimagesclinicalpicaallergiccontactdermatitisallergiccontactdermatitis58jpg.jpg,False,True
4,523cdfb89a2bc15e2663b0f584b59a51,3,3,allergic contact dermatitis,inflammatory,non-neoplastic,,https://www.dermaamin.com/site/images/clinical-pic/c/contact_dermatitis/contact_dermatitis54.jpg,httpwwwdermaamincomsiteimagesclinicalpicccontactdermatitiscontactdermatitis54jpg.jpg,False,True
...,...,...,...,...,...,...,...,...,...,...,...
355,09a75e2a4e9bbc303e6e0cb3e16a32d8,3,3,squamous cell carcinoma,malignant epidermal,malignant,,http://atlasdermatologico.com.br/img?imageId=8874,httpwwwatlasdermatologicocombrimgimageId8874.jpg,False,True
356,b89758afa9975102aa1f58322cf3c8fa,6,5,squamous cell carcinoma,malignant epidermal,malignant,,https://www.dermaamin.com/site/images/clinical-pic/k/keratoacanthoma/keratoacanthoma84.jpg,httpwwwdermaamincomsiteimagesclinicalpickkeratoacanthomakeratoacanthoma84jpg.jpg,False,True
357,fd356daecf2740e664f15b74711d0c4c,2,1,squamous cell carcinoma,malignant epidermal,malignant,,https://www.dermaamin.com/site/images/clinical-pic/e/erythroplasia_of_queyrat/erythroplasia_of_queyrat13.jpg,httpwwwdermaamincomsiteimagesclinicalpiceerythroplasiaofqueyraterythroplasiaofqueyrat13jpg.jpg,False,True
358,9748eb88f98e00ea924badf9e3aa89b2,-1,-1,squamous cell carcinoma,malignant epidermal,malignant,,https://www.dermaamin.com/site/images/clinical-pic/s/squamous-cell-carcinoma/squamous-cell-carcinoma18.jpg,httpwwwdermaamincomsiteimagesclinicalpicssquamouscellcarcinomasquamouscellcarcinoma18jpg.jpg,False,True


## Train a model with 32 real training images per disease condition

In [None]:
# We need to produce a train/test split

In [None]:

# NOTE(review): `metadata` and `generate_synthetic_metadata` are not defined in any cell
# shown before this one, so this cell raises NameError on a fresh kernel — confirm where
# they are supposed to come from before running it.

# Keep only the 9 most common labels
top_n_labels = metadata['label'].value_counts().index[:9]
metadata = metadata[metadata['label'].isin(top_n_labels)].reset_index(drop=True)
metadata['location'] = 'Fitzpatrick17k/finalfitz17k/' + metadata['md5hash'] + '.jpg'
metadata['synthetic'] = False

# Held-out test set: read from a fixed CSV rather than re-sampled here.
# Earlier in-notebook sampling, kept for reference:
# test_data = pd.DataFrame(metadata.groupby(['label']).apply(lambda x: x.sample(n=40, random_state=random_state)).reset_index(drop=True))
# NOTE(review): this relative path differs from the absolute path used for the same file
# name earlier in the notebook — confirm which location is intended.
test_data = pd.read_csv("../Metadata/fitzpatrick17k_10label_clean_held_out_set.csv")
test_data = test_data[test_data['label'].isin(top_n_labels)].reset_index(drop=True)
test_data['location'] = 'Fitzpatrick17k/finalfitz17k/' + test_data['md5hash'] + '.jpg'
test_data['synthetic'] = False
test_data['is_valid'] = False

# Remove the test data from the train/val data
train_val_data = pd.DataFrame(metadata[~metadata['md5hash'].isin(test_data['md5hash'])])

# One metadata row per synthetic image derived from the remaining real images.
synthetic_metadata = generate_synthetic_metadata(train_val_data, generation_type='inpaint', n_synthetic_per_real=10, folder = generation_folder)

# Concatenate the synthetic data with the train/val data
train_val_data = pd.concat([train_val_data, synthetic_metadata]).reset_index(drop=True)

# Validation set: 40 real images sampled (without replacement) from each label.
real_train_val_data = train_val_data[train_val_data['synthetic'] == False]
val_data = pd.DataFrame(real_train_val_data.groupby(['label']).apply(lambda x: x.sample(n=40, random_state=random_state, replace=False)).reset_index(drop=True))

# Training set: every row (real or synthetic) whose md5hash is not in the validation set,
# so synthetic images derived from validation images are excluded as well.
train_data = pd.DataFrame(train_val_data[~train_val_data['md5hash'].isin(val_data['md5hash'])])

# Add an 'is_valid' column to the train and val data
train_data['is_valid'] = False
val_data['is_valid'] = True

# Show the number of real and synthetic images in the train, val, and test data
for split_name, split_df in (("train", train_data), ("val", val_data), ("test", test_data)):
    for kind, flag in (("real", False), ("synthetic", True)):
        print(f"Number of {kind} images in {split_name} data: ", len(split_df[split_df['synthetic'] == flag]))



In [None]:
# Train a model with 32 real training images per disease condition, plus an additional 10 synthetic images

In [None]:
# plot some data to compare the performance of these two trials

In [None]:
# The entire experiment can be run using this script, although this will take a while to run
!python skin_classification_with_augmentation.py \    
    --dataset hugginface_repo \ 
    --n_real_per_label_list "[1, 8, 16, 32, 64, 128, 228]" \
    --max_batch_size 32 \
    --arg2 value2

In [None]:
# Now we can plot the data from the complete experiment
# We can provide users with code to reproduce this, but we should also save the trial data so that users can still plot our results if they are unable to run the model