In [1]:
import pandas as pd
import torch
from fastai.vision.all import *
from sklearn.model_selection import train_test_split
import wandb
from fastai.callback.wandb import *
from statsmodels.stats.proportion import proportion_confint
from itertools import product
import os

from datasets import load_dataset, Image
from synderm.splits.train_test_splitter import synthetic_train_val_split

In [22]:
# TODO: get rid of wandb logging
# TODO: create the huggingface dataset
# TODO: complete this demo, using the huggingface dataset and with some kind of visualization at the end

In [2]:
# Hugging Face Hub dataset id for the synthetic dermatology images
# (440 images per class in all_generations); presumably loaded with
# datasets.load_dataset. Not referenced elsewhere in this notebook yet.
synthetic_derm: str = "lukemelas/synthetic-derm"

In [3]:
# Load the Fitzpatrick17k label tables: the training pool (real images only)
# and the held-out test set. Image `location`s are relative to `image_dir`.
fitz_dir = "/n/data1/hms/dbmi/manrai/derm/Fitzpatrick17k/"
fitz_image_prefix = 'Fitzpatrick17k/finalfitz17k/'
# NOTE(review): the CSVs are "10label" files but only the 9 most frequent
# training labels are kept -- confirm that dropping the 10th class is intended.
n_top_labels = 9

metadata_train = pd.read_csv(fitz_dir + "fitzpatrick17k_10label_clean_training.csv")
top_n_labels = metadata_train["label"].value_counts().index[:n_top_labels]
metadata_train = (
    metadata_train[metadata_train["label"].isin(top_n_labels)]
    .reset_index(drop=True)
    .assign(location=lambda d: fitz_image_prefix + d['md5hash'] + '.jpg',
            synthetic=False)
)

# Test set, restricted to the same classes used for training.
test_data = pd.read_csv(fitz_dir + "fitzpatrick17k_10label_clean_held_out_set.csv")
test_data = (
    test_data[test_data["label"].isin(top_n_labels)]
    .reset_index(drop=True)
    .assign(location=lambda d: fitz_image_prefix + d['md5hash'] + '.jpg',
            synthetic=False,
            is_valid=False)
)

# Guard against train/test leakage: no image (md5hash) may appear in both sets.
# This fails loudly instead of printing a warning and continuing.
ids_train = set(metadata_train["md5hash"])
ids_test = set(test_data["md5hash"])
shared_ids = ids_train & ids_test
if shared_ids:
    raise ValueError(
        f"train/test not mutually exclusive: {len(shared_ids)} md5hash values appear in both."
    )
print("train/test mutually exclusive.")

train/test mutually exclusive.


## Train a model with 32 real training images per disease condition, including synthetic data

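We train two models on the same 32 real images per class (`n_real_per_class`): one that also includes synthetic images and one that does not. Both are evaluated on the same held-out test set.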

In [4]:
# Root directory for all images; the dataloaders resolve each `location`
# relative to `path`.
image_dir = "/n/data1/hms/dbmi/manrai/derm/"
path = Path(image_dir)

# Synthetic images come from the finetuned inpainting generator.
generation_type = "inpaint"
generation_folder = "all_generations/finetune-" + generation_type + "/"


In [5]:
# Experiment parameters
random_state = 111108              # seed passed to the train/val split and logged to wandb
n_real_per_class = 32              # real training images sampled per class
n_synthetic_per_real = 10          # synthetic copies built per real training image
per_class_test_size = 40           # per-class size of the validation split (9 x 40 = 360 below)
wandb_project = "n_real_per_disease_x"

In [6]:
# Build the synthetic-image table. Every real training row is repeated
# n_synthetic_per_real times, and copy n points at the n-th generated image for
# that md5hash. The copy index is attached explicitly with assign() rather than
# recovered via groupby('md5hash').cumcount(), which would give wrong indices
# if an md5hash appeared more than once in metadata_train.
df_synthetic = pd.concat(
    [metadata_train.assign(n=copy_idx) for copy_idx in range(n_synthetic_per_real)],
    ignore_index=True,
)

# Generation subfolders are zero-padded to two digits (00, 01, ..., 09).
# NOTE(review): folders for n >= 10 are assumed to be named '10', '11', ...;
# confirm if n_synthetic_per_real is raised above 10.
df_synthetic['location'] = (
    generation_folder
    + df_synthetic['label'].str.replace(' ', '-')
    + '/' + generation_type + '/'
    + df_synthetic['n'].astype(str).str.zfill(2)
    + '/' + df_synthetic['md5hash'] + '.png'
)
df_synthetic['synthetic'] = True
df_synthetic['Qc'] = ''

# The copy index was only needed to build the file paths.
df_synthetic = df_synthetic.drop(columns=['n'])

In [7]:
# Sample the real train/validation split with synderm's helper. Synthetic rows
# share the `md5hash` of the real image they were generated from, which is the
# column used to link them.
train, val = synthetic_train_val_split(
    real_data=metadata_train,
    synthetic_data=df_synthetic,
    class_column="label",
    mapping_real_to_synthetic="md5hash",
    n_real_per_class=n_real_per_class,
    per_class_test_size=per_class_test_size,
    random_state=random_state,
)

In [11]:
# Flag split membership for fastai's `valid_col`, then stack into one frame
# with a fresh 0..N-1 index.
train = train.assign(is_valid=False)
val = val.assign(is_valid=True)
df = pd.concat([train, val], ignore_index=True)

In [12]:
batch_size: int = 64  # passed as `bs` to the fastai dataloaders

In [13]:
# fastai dataloaders: each image is read from path / location, labelled by
# `label`, and routed to train or validation by the boolean `is_valid` column.
dls = ImageDataLoaders.from_df(
    df,
    path,
    fn_col='location',
    label_col='label',
    valid_col='is_valid',
    bs=batch_size,
    item_tfms=Resize(224),  # resize every image to 224 px before batching
    batch_tfms=[],          # no batch-level augmentation
)

In [14]:
# wandb run configuration for the run that includes synthetic images.
config = {
    "architecture": "EfficientNet-V2-M",
    "gen_folder": generation_folder,
    "random_state": random_state,
    "augmentation": "None",
    "n_training_per_label": n_real_per_class,
    "include_synthetic": True,
    "n_synthetic_per_real": n_synthetic_per_real,
    "generation_type": generation_type,
}

# wandb tags for filtering runs: sample size, seed, synthetic flag, generator type.
sample_tag = f"n_real_per_label_{n_real_per_class}"
seed_tag = f"seed_{random_state}"
include_synthetic_tag = f"include_synthetic_{True}"
generation_type_tag = f"{generation_type}"

wandb.init(
    config=config,
    project=wandb_project,
    tags=[sample_tag, seed_tag, include_synthetic_tag, generation_type_tag],
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtbu[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [15]:
# fastai learner on an EfficientNet-V2-M backbone; error_rate and accuracy are
# reported on the validation split after each epoch.
# NOTE(review): vision_learner freezes a pretrained body by default, so the
# learn.fit() call below may update only the new head -- confirm whether
# fine_tune()/unfreeze() was intended.
learn = vision_learner(
    dls,
    arch=efficientnet_v2_m,
    metrics=[error_rate, accuracy],
)



In [16]:
# Train for up to 10 epochs, logging to wandb and stopping once valid_loss has
# not improved for 3 consecutive epochs.
# NOTE(review): EarlyStoppingCallback does not restore the best-epoch weights.
# The test predictions below therefore use the last trained epoch (epoch 3 in
# the recorded run, although valid_loss was best at epoch 0). Consider
# SaveModelCallback(monitor='valid_loss') if best-epoch weights are wanted.
learn.fit(
    10,
    cbs=[
        WandbCallback(),
        EarlyStoppingCallback(monitor='valid_loss', min_delta=0.0, patience=3),
    ],
)

epoch,train_loss,valid_loss,error_rate,accuracy,time
0,1.915981,2.912237,0.725,0.275,00:56
1,1.180606,2.913154,0.730556,0.269444,00:58
2,0.756596,3.069446,0.738889,0.261111,01:42
3,0.530889,3.091721,0.725,0.275,01:50


No improvement since epoch 0: early stopping


In [17]:
# Predict on the held-out test set with the model trained above.
test_dl = dls.test_dl(test_data)

# Class probabilities for every test image, one row per row of test_data.
preds, _ = learn.get_preds(dl=test_dl)

md5hashes = test_data['md5hash']
true_labels = test_data['label']
n_test = len(true_labels)
if n_test == 0:
    raise ValueError("test_data is empty; cannot compute test accuracy.")

# Top-3 class indices and probabilities, sorted by descending probability, so
# column 0 is the top-1 prediction. top1_prob keeps the (N, 1) shape that
# topk(k=1) would produce.
top3_prob, top3_idx = torch.topk(preds, k=3, dim=1)
top1_prob = top3_prob[:, :1]

# Map class indices to class names. .tolist() yields plain ints, which also
# avoids .squeeze() collapsing a single-image test set to a 0-d tensor.
top3_label = [[learn.dls.vocab[j] for j in row] for row in top3_idx.tolist()]
top1_label = [row[0] for row in top3_label]

# Exact integer counts of correct predictions feed both the accuracies and the
# CIs, so proportion_confint gets a true count rather than a float acc * n.
top1_correct = sum(pred == truth for pred, truth in zip(top1_label, true_labels))
top3_correct = sum(truth in row for row, truth in zip(top3_label, true_labels))
top1_acc = top1_correct / n_test
top3_acc = top3_correct / n_test

# 95% confidence intervals (normal approximation).
top1_ci_lower, top1_ci_upper = proportion_confint(top1_correct, n_test, alpha=0.05, method='normal')
top3_ci_lower, top3_ci_upper = proportion_confint(top3_correct, n_test, alpha=0.05, method='normal')

# Send the test-set accuracies and their 95% CI bounds to the active wandb run.
wandb.log(dict(
    top1_acc=top1_acc,
    top1_ci_lower=top1_ci_lower,
    top1_ci_upper=top1_ci_upper,
    top3_acc=top3_acc,
    top3_ci_lower=top3_ci_lower,
    top3_ci_upper=top3_ci_upper,
))

# Break the (N, 3) top-k probabilities into per-rank (N, 1) tensors.
# This rebinds top1_prob / top3_prob, replacing their earlier values.
top1_prob, top2_prob, top3_prob = top3_prob.split(1, dim=1)

# Flatten each rank to a 1-D NumPy array for the predictions table.
top1_prob_arr, top2_prob_arr, top3_prob_arr = [
    rank_prob.numpy().flatten() for rank_prob in (top1_prob, top2_prob, top3_prob)
]

# Transpose the per-image [1st, 2nd, 3rd] label lists into one list per rank
# (rebinds top1_label / top3_label).
top1_label, top2_label, top3_label = [
    [ranked[k] for ranked in top3_label] for k in range(3)
]

# One row per test image: run metadata plus the top-3 predicted labels and
# probabilities. top3_prob was previously computed but left out of the table.
df_pred = pd.DataFrame({
    'architecture' : "EfficientNet-V2-M",
    'random_state' : random_state,
    'augmentation' : "None",
    'gen_folder' : generation_folder,
    'generation_type' : generation_type,
    'n_training_per_label' : n_real_per_class,
    'n_synthetic_per_real' : n_synthetic_per_real,
    'include_synthetic' : True,  # this run trains with synthetic images
    'md5hash': md5hashes,
    'true_label': true_labels,
    'top1_label': top1_label,
    'top1_prob': top1_prob_arr,
    'top2_label': top2_label,
    'top2_prob': top2_prob_arr,
    'top3_label': top3_label,
    'top3_prob': top3_prob_arr,
})

# Upload the per-image predictions as a wandb table, then close the run;
# the next section calls wandb.init() again for the real-only model.
predictions_table = wandb.Table(dataframe=df_pred)
wandb.log({"test_predictions": predictions_table})
wandb.finish()



0,1
accuracy,█▅▁█
epoch,▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇████
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
error_rate,▁▄█▁
lr_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mom_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mom_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
raw_loss,█▆▅▄▆▃▅▄▅▃▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▂▂▁▁▁▁▁▂▁▂▂▁▁▁

0,1
accuracy,0.275
epoch,4.0
eps_0,1e-05
eps_1,1e-05
error_rate,0.725
lr_0,0.001
lr_1,0.001
mom_0,0.9
mom_1,0.9
raw_loss,0.30807


## Train a model with 32 real training images per disease condition, no synthetic images included

In [18]:
# Same real-image split as the synthetic run (same seed and sizes), but
# without any synthetic rows attached.
train_ns, val_ns = synthetic_train_val_split(
    real_data=metadata_train,
    synthetic_data=None,
    class_column="label",
    mapping_real_to_synthetic="md5hash",
    n_real_per_class=n_real_per_class,
    per_class_test_size=per_class_test_size,
    random_state=random_state,
)

In [19]:
# Flag split membership for fastai's `valid_col`, then stack into one frame
# with a fresh 0..N-1 index.
train_ns = train_ns.assign(is_valid=False)
val_ns = val_ns.assign(is_valid=True)
df_ns = pd.concat([train_ns, val_ns], ignore_index=True)

# fastai dataloaders for the real-only run (rebinds `dls`): same settings as
# the synthetic run, built from df_ns.
dls = ImageDataLoaders.from_df(
    df_ns,
    path,
    fn_col='location',
    label_col='label',
    valid_col='is_valid',
    bs=batch_size,
    item_tfms=Resize(224),  # resize every image to 224 px before batching
    batch_tfms=[],          # no batch-level augmentation
)

# wandb run configuration for the real-images-only baseline.
config = dict (
    architecture = "EfficientNet-V2-M",
    gen_folder = generation_folder,
    random_state = random_state,
    augmentation = "None",
    n_training_per_label = n_real_per_class,
    include_synthetic = False,
    # NOTE(review): recorded for parity with the synthetic run even though no
    # synthetic images are used here -- downstream analysis should filter on
    # include_synthetic.
    n_synthetic_per_real = n_synthetic_per_real,
    generation_type = generation_type
)

# wandb tags. The synthetic flag is derived from config so the tag cannot
# disagree with it (it was previously hard-coded to True for this real-only run).
sample_tag = "n_real_per_label_" + str(n_real_per_class)
seed_tag = "seed_" + str(random_state)
include_synthetic_tag = "include_synthetic_" + str(config["include_synthetic"])
generation_type_tag = str(generation_type)

wandb.init(
    project=wandb_project,
    tags=[sample_tag, seed_tag, include_synthetic_tag, generation_type_tag],
    config=config,
)

# Fresh EfficientNet-V2-M learner for the real-only run; error_rate and
# accuracy are reported on the validation split after each epoch.
# NOTE(review): vision_learner freezes a pretrained body by default, so fit()
# may update only the new head -- confirm whether fine_tune()/unfreeze() was intended.
learn = vision_learner(
    dls,
    arch=efficientnet_v2_m,
    metrics=[error_rate, accuracy],
)

# Train for up to 10 epochs, logging to wandb and stopping once valid_loss has
# not improved for 3 consecutive epochs.
# NOTE(review): EarlyStoppingCallback does not restore the best-epoch weights.
# The test predictions use the last trained epoch (epoch 5 in the recorded run,
# although valid_loss was best at epoch 2).
learn.fit(
    10,
    cbs=[
        WandbCallback(),
        EarlyStoppingCallback(monitor='valid_loss', min_delta=0.0, patience=3),
    ],
)



epoch,train_loss,valid_loss,error_rate,accuracy,time
0,3.32379,2.822096,0.872222,0.127778,00:08
1,3.003263,2.617331,0.836111,0.163889,00:07
2,2.770729,2.589204,0.822222,0.177778,00:07
3,2.57337,2.703261,0.825,0.175,00:07
4,2.402319,2.874825,0.825,0.175,00:07
5,2.238272,2.92658,0.777778,0.222222,00:07


No improvement since epoch 2: early stopping


In [21]:
# Predict on the held-out test set with the real-only model trained above.
test_dl = dls.test_dl(test_data)

# Class probabilities for every test image, one row per row of test_data.
preds, _ = learn.get_preds(dl=test_dl)

md5hashes = test_data['md5hash']
true_labels = test_data['label']
n_test = len(true_labels)
if n_test == 0:
    raise ValueError("test_data is empty; cannot compute test accuracy.")

# Top-3 class indices and probabilities, sorted by descending probability, so
# column 0 is the top-1 prediction. top1_prob keeps the (N, 1) shape that
# topk(k=1) would produce.
top3_prob, top3_idx = torch.topk(preds, k=3, dim=1)
top1_prob = top3_prob[:, :1]

# Map class indices to class names. .tolist() yields plain ints, which also
# avoids .squeeze() collapsing a single-image test set to a 0-d tensor.
top3_label = [[learn.dls.vocab[j] for j in row] for row in top3_idx.tolist()]
top1_label = [row[0] for row in top3_label]

# Exact integer counts of correct predictions feed both the accuracies and the
# CIs, so proportion_confint gets a true count rather than a float acc * n.
top1_correct = sum(pred == truth for pred, truth in zip(top1_label, true_labels))
top3_correct = sum(truth in row for row, truth in zip(top3_label, true_labels))
top1_acc = top1_correct / n_test
top3_acc = top3_correct / n_test

# 95% confidence intervals (normal approximation).
top1_ci_lower, top1_ci_upper = proportion_confint(top1_correct, n_test, alpha=0.05, method='normal')
top3_ci_lower, top3_ci_upper = proportion_confint(top3_correct, n_test, alpha=0.05, method='normal')

# Send the test-set accuracies and their 95% CI bounds to the active wandb run.
wandb.log(dict(
    top1_acc=top1_acc,
    top1_ci_lower=top1_ci_lower,
    top1_ci_upper=top1_ci_upper,
    top3_acc=top3_acc,
    top3_ci_lower=top3_ci_lower,
    top3_ci_upper=top3_ci_upper,
))

# Break the (N, 3) top-k probabilities into per-rank (N, 1) tensors.
# This rebinds top1_prob / top3_prob, replacing their earlier values.
top1_prob, top2_prob, top3_prob = top3_prob.split(1, dim=1)

# Flatten each rank to a 1-D NumPy array for the predictions table.
top1_prob_arr, top2_prob_arr, top3_prob_arr = [
    rank_prob.numpy().flatten() for rank_prob in (top1_prob, top2_prob, top3_prob)
]

# Transpose the per-image [1st, 2nd, 3rd] label lists into one list per rank
# (rebinds top1_label / top3_label).
top1_label, top2_label, top3_label = [
    [ranked[k] for ranked in top3_label] for k in range(3)
]

# One row per test image: run metadata plus the top-3 predicted labels and
# probabilities. This is the real-only baseline, so include_synthetic is False
# (it was previously logged as True, mislabelling these predictions).
df_pred = pd.DataFrame({
    'architecture' : "EfficientNet-V2-M",
    'random_state' : random_state,
    'augmentation' : "None",
    'gen_folder' : generation_folder,
    'generation_type' : generation_type,
    'n_training_per_label' : n_real_per_class,
    # NOTE(review): kept for parity with the synthetic run; not applied here.
    'n_synthetic_per_real' : n_synthetic_per_real,
    'include_synthetic' : False,
    'md5hash': md5hashes,
    'true_label': true_labels,
    'top1_label': top1_label,
    'top1_prob': top1_prob_arr,
    'top2_label': top2_label,
    'top2_prob': top2_prob_arr,
    'top3_label': top3_label,
    'top3_prob': top3_prob_arr,
})

# Upload the per-image predictions as a wandb table, then close the run.
predictions_table = wandb.Table(dataframe=df_pred)
wandb.log({"test_predictions": predictions_table})
wandb.finish()



Error: You must call wandb.init() before wandb.log()

## Run the complete experiment

In [None]:
# The entire experiment can be run using this script, although this will take a while to run
!python skin_classification_with_augmentation.py \    
    --dataset hugginface_repo \ 
    --n_real_per_label_list "[1, 8, 16, 32, 64, 128, 228]" \
    --max_batch_size 32 \
    --arg2 value2

In [None]:
# TODO: plot the results of the complete experiment.
# Provide the code to reproduce the plot, and also ship the saved trial data so users can plot our results even if they cannot run the full experiment.