In [None]:
import os
import shutil
import glob
import torch
from safetensors import safe_open
from safetensors.torch import save_file
import pandas as pd

### This notebook provides an example corresponding to Table 3 in the paper.

# Seed passed as `random_state` when sampling below; also embedded in output file names.
SEED = 0
# Split identifier embedded in the generated CSV/JSON file names and model output directories.
SPLIT = "light_and_dark_flex_to_dark"

### Process dataset

In [None]:
### Load the Fitzpatrick subset and keep only skin types 1, 2 (light) and 5, 6 (dark).
df_all = pd.read_csv('../data_splits/Fitz_subset.csv')

skin_type_list = [1, 2, 5, 6]
keep_mask = df_all['fitzpatrick_scale'].isin(skin_type_list)
df = df_all[keep_mask].copy()


def _sample_flex_subset(frame, fst_values):
    """Draw 8 rows per label, without replacement, from the rows whose skin type is in `fst_values`."""
    candidates = frame[frame['fitzpatrick_scale'].isin(fst_values)]
    per_label = candidates.groupby(['label'], group_keys=False)
    return per_label.sample(n=8, random_state=SEED, replace=False)


### Flexible subsets: 8 images per condition for FST = [5, 6] and for FST = [1, 2].
df_flex_dark = _sample_flex_subset(df, [5, 6])
df_flex_light = _sample_flex_subset(df, [1, 2])

### Rows that were not drawn into either flexible subset.
in_flex_subset = (
    df['md5hash'].isin(df_flex_dark['md5hash'])
    | df['md5hash'].isin(df_flex_light['md5hash'])
)
df_non_flex = df[~in_flex_subset]

### Train = remaining light-skin rows + dark flexible subset; test = remaining dark-skin rows.
light_remaining = df_non_flex[df_non_flex['fitzpatrick_scale'].isin([1, 2])]
df_train = pd.concat([light_remaining, df_flex_dark], ignore_index=True)

df_test = df_non_flex[df_non_flex['fitzpatrick_scale'].isin([5, 6])]

### Save both splits.
df_train.to_csv(f'../data_splits/train_{SPLIT}_seed={SEED}.csv')
df_test.to_csv(f'../data_splits/test_{SPLIT}_seed={SEED}.csv')


### Quick check: split sizes, then the skin-type distribution of each split.
print(df_train.shape, df_test.shape)
for split_frame in (df_train, df_test):
    print(split_frame['fitzpatrick_scale'].value_counts())

### Textual Inversion

In [None]:
# Skin conditions; one textual-inversion run is launched per entry.
condition_list = [
    "basal_cell_carcinoma",
    "folliculitis",
    "nematode_infection",
    "neutrophilic_dermatoses",
    "prurigo_nodularis",
    "psoriasis",
    "squamous_cell_carcinoma",
]

# Pretrained model identifier handed to the training scripts.
MODEL_NAME = "stabilityai/stable-diffusion-2-1-base"
# Image folder.
TRAIN_DIR = "/data/derm_data/Fitzpatrick17k/finalfitz17k/"
# Training split CSV (same path the dataset cell writes `df_train` to).
TRAIN_SPLIT = f"../data_splits/train_{SPLIT}_seed={SEED}.csv"

In [None]:
### Textual Inversion
### Launch ../scripts/textual_inversion.py once per entry of `condition_list`.
### IPython expands $MODEL_NAME, $TRAIN_DIR, $TRAIN_SPLIT, $TOKEN, $OUTPUT_DIR and $c
### from the notebook namespace before the shell command runs.
### NOTE(review): --fitz_split_csv and --class_name are options of the project script;
### presumably they select the training rows for the current condition -- confirm in the script.

for c in condition_list:
    # Each condition gets its own output directory under the split/seed folder.
    OUTPUT_DIR=f"../models/textual_inversion/{SPLIT}_seed={SEED}/{c}" 
    # Placeholder token: first three characters of the condition name plus "-class".
    TOKEN=f'{c[:3]}-class'

    !accelerate launch ../scripts/textual_inversion.py \
        --pretrained_model_name_or_path=$MODEL_NAME \
        --train_data_dir=$TRAIN_DIR \
        --fitz_split_csv=$TRAIN_SPLIT \
        --learnable_property="object" \
        --placeholder_token=$TOKEN \
        --initializer_token="skin" \
        --resolution=512 \
        --train_batch_size=4 \
        --gradient_accumulation_steps=4 \
        --mixed_precision="fp16" \
        --max_train_steps=500 \
        --learning_rate=5.0e-04 \
        --scale_lr \
        --lr_scheduler="constant" \
        --lr_warmup_steps=0 \
        --output_dir=$OUTPUT_DIR \
        --class_name=$c \
        --repeats=10

In [None]:
### Merge all the learned embeddings into a single file

TI_OUTPUT_DIR = f"../models/textual_inversion/{SPLIT}_seed={SEED}"
path = f"{TI_OUTPUT_DIR}/*/learned_embeds.safetensors"

# glob returns matches in arbitrary order; sort so the merge order is reproducible.
embed_files = sorted(glob.glob(path))
if not embed_files:
    # Fail here instead of saving an empty aggregate that LoRA training would load later.
    raise FileNotFoundError(
        f"No learned embeddings match {path!r}; run the textual inversion cell first."
    )

merged_dict = dict()
for file in embed_files:
    # Tensors are read while the safetensors file is still open.
    with safe_open(file, framework="pt", device="cpu") as f:
        tensors = {key: f.get_tensor(key) for key in f.keys()}

    # A plain dict.update would silently drop one embedding if two files share a key
    # (possible when two condition names start with the same three characters).
    duplicated = sorted(merged_dict.keys() & tensors.keys())
    if duplicated:
        raise ValueError(f"Duplicate embedding keys {duplicated} found in {file!r}")
    merged_dict.update(tensors)

# Persist the merged {key: tensor} mapping; this path is later passed to LoRA training.
TI_EMBED_PATH = f"{TI_OUTPUT_DIR}/aggregated_embeds_SEED={SEED}.pt"
os.makedirs(os.path.dirname(TI_EMBED_PATH), exist_ok=True)
torch.save(merged_dict, TI_EMBED_PATH)
print(merged_dict.keys())

### LoRA Training

In [None]:
import json
import os

### Build a JSON manifest of the training data (image file, label, skin type, prompt).

# Value of the `label` column -> placeholder token inserted into the prompt.
token_mapper = {
    "basal cell carcinoma": "bas-class",
    "folliculitis": "fol-class",
    "nematode infection": "nem-class",
    "neutrophilic dermatoses": "neu-class",
    "prurigo nodularis": "pru-class",
    "psoriasis": "pso-class",
    "squamous cell carcinoma": "squ-class",
}

# Value of the `fitzpatrick_scale` column -> skin-tone phrase inserted into the prompt.
skin_type_mapper = {
    1: 'a very light-skinned',
    2: 'a light-skinned',
    5: 'a dark-skinned',
    6: 'a very dark-skinned',
}


def _make_record(md5hash, label, fst):
    """Return the manifest entry for one training image."""
    image_path = md5hash + '.jpg'
    disease_token = token_mapper[label]
    skin_type = skin_type_mapper[fst]
    return {
        'image_path': image_path,
        'label': label,
        'skin_type': fst,
        'prompt': f"An image of {disease_token} on the skin of {skin_type} individual",
    }


# One entry per row of df_train, in row order.
output_to_json = [
    _make_record(md5hash, label, fst)
    for md5hash, label, fst in zip(
        df_train['md5hash'], df_train['label'], df_train['fitzpatrick_scale']
    )
]

with open(f"../data_splits/train_lora_{SPLIT}_seed={SEED}.json", "w") as manifest_file:
    json.dump(output_to_json, manifest_file, indent=4)
    

In [None]:
### For LoRA training, a JSON file manages the training data (images and prompts).
### Note that the prompts must contain the learned lesion token names derived from textual
### inversion instead of the original disease names. Feel free to try different rank sizes.
###
### `json_path` is the manifest written by the JSON-building cell, and `TI_EMBED_PATH` is the
### merged embeddings file saved by the merge cell; both cells must have been run first.
### NOTE(review): --image_column is "image" while the manifest entries use the key
### 'image_path' -- confirm which key train_text_to_image_lora.py actually reads.

LORA_OUTPUT_DIR=f"../models/lora_weights/{SPLIT}_seed={SEED}"
json_path = f"../data_splits/train_lora_{SPLIT}_seed={SEED}.json"

!accelerate launch --mixed_precision="fp16"  ../scripts/train_text_to_image_lora.py \
    --pretrained_model_name_or_path=$MODEL_NAME \
    --train_data_dir=$TRAIN_DIR \
    --dataloader_num_workers=8 \
    --image_column="image" \
    --caption_column="prompt" \
    --resolution=512 \
    --train_batch_size=4 \
    --gradient_accumulation_steps=1 \
    --max_train_steps=3000 \
    --learning_rate=5e-06 \
    --max_grad_norm=1 \
    --lr_scheduler="constant" \
    --lr_warmup_steps=0 \
    --output_dir=$LORA_OUTPUT_DIR \
    --checkpointing_steps=1000 \
    --rank=8 \
    --embed_path=$TI_EMBED_PATH \
    --json_path=$json_path