In [None]:
### Configuration — Fitzpatrick17k is used as the example dataset.
### The training cells below read these names.

import os

# Conditions to learn one textual-inversion token for. Each name is also used as the
# per-condition output folder name and passed to the training script as --class_name.
condition_list = ['acne',
                  'actinic_keratosis',
                  'allergic_contact_dermatitis',
                  'basal_cell_carcinoma',
                  'eczema',
                  'erythema_multiforme',
                  'folliculitis',
                  'granuloma_annulare',
                  'keloid',
                  'lichen_planus',
                  'lupus_erythematosus',
                  'melanoma',
                  'mycosis_fungoides',
                  'pityriasis_rosea',
                  'prurigo_nodularis',
                  'psoriasis',
                  'sarcoidosis',
                  'scabies',
                  'squamous_cell_carcinoma',
                  'vitiligo']

MODEL_NAME = "stabilityai/stable-diffusion-2-1-base"
# Image folder. Set the FITZPATRICK17K_IMAGE_DIR environment variable on machines where the
# dataset lives elsewhere; otherwise the original hard-coded location is used.
TRAIN_DIR = os.environ.get("FITZPATRICK17K_IMAGE_DIR", "/data/derm_data/Fitzpatrick17k/finalfitz17k/")
TRAIN_SPLIT = '../splits/Fitzpatrick17k/train.csv'

In [None]:
### Textual Inversion 

for c in condition_list:
    OUTPUT_DIR=f"../models/textual_inversion_weights/fitzpatrick17k/{c}" 
    TOKEN=f'xxx{c[:3]}xxx'

    !accelerate launch ../scripts/textual_inversion.py \
        --pretrained_model_name_or_path=$MODEL_NAME \
        --train_data_dir=$TRAIN_DIR \
        --train_data_split=$TRAIN_SPLIT \
        --learnable_property="object" \
        --placeholder_token=$TOKEN \
        --initializer_token="skin" \
        --resolution=512 \
        --train_batch_size=4 \
        --gradient_accumulation_steps=4 \
        --max_train_steps=500 \   
        --learning_rate=5.0e-04 \
        --scale_lr \
        --lr_scheduler="constant" \
        --lr_warmup_steps=0 \
        --output_dir=$OUTPUT_DIR \
        --class_name=$c \
        --repeats=1 \
        #   --push_to_hub

In [None]:
### Merge all the learned embeddings into a single torch file.
### Each per-condition run writes `learned_embeds.safetensors`. All of their tensors are combined
### into one dict, keyed by the safetensors keys, and saved to TI_EMBED_PATH for the LoRA step.

import os
import shutil
import glob
import torch
from safetensors import safe_open
from safetensors.torch import save_file

TI_OUTPUT_DIR = "../models/textual_inversion_weights/fitzpatrick17k"
path = f"{TI_OUTPUT_DIR}/*/learned_embeds.safetensors"
# glob order is arbitrary; sorting makes the merge (and any error message) reproducible.
embed_files = sorted(glob.glob(path))
if not embed_files:
    raise FileNotFoundError(f"no learned embeddings match {path!r}; run textual inversion first")

merged_dict = dict()
key_source = dict()  # key -> file it was read from, used to report collisions
for file in embed_files:
    with safe_open(file, framework="pt", device="cpu") as f:
        for key in f.keys():
            # A plain dict.update would silently keep only the last file's tensor for a shared key.
            if key in merged_dict:
                raise ValueError(f"duplicate embedding key {key!r} in {file} and {key_source[key]}")
            merged_dict[key] = f.get_tensor(key)
            key_source[key] = file

TI_EMBED_PATH = f"{TI_OUTPUT_DIR}/aggregated_embed_sd2_1_base.pt"
os.makedirs(os.path.dirname(TI_EMBED_PATH), exist_ok=True)
torch.save(merged_dict, TI_EMBED_PATH)

sorted(merged_dict)

In [None]:
### for lora training, we use a json file to manage the training data with images and prompts
### note that, here, we should put the learned lesion token names derived from textual inversion
### in the prompts, instead of the original disease names. Feel free to try different rank sizes.

LORA_OUTPUT_DIR="/media/janet/DermDPO/models/lora_weights/fitzpatrick17k"
json_path = "../splits/Fitzpatrick17k/fitzpatrick17k.json"

!accelerate launch --mixed_precision="fp16"  ../scripts/train_text_to_image_lora.py \
    --pretrained_model_name_or_path=$MODEL_NAME \
    --train_data_dir=$TRAIN_DIR \
    --dataloader_num_workers=8 \
    --image_column="image" \
    --caption_column="prompt" \
    --resolution=512 \
    --center_crop \
    --random_flip \
    --train_batch_size=4 \
    --gradient_accumulation_steps=1 \
    --max_train_steps=3000 \
    --learning_rate=5e-06 \
    --max_grad_norm=1 \
    --lr_scheduler="cosine" \
    --lr_warmup_steps=0 \
    --output_dir=$LORA_OUTPUT_DIR \
    --report_to=wandb \
    --checkpointing_steps=1000 \
    --seed=1337 \
    --rank=32 \
    --embed_path=$TI_EMBED_PATH \
    --json_path=$json_path