In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import Dataset
from torchvision import transforms, models

# import skimage
# from skimage import io
import pandas as pd
import numpy as np
# from sklearn.model_selection import train_test_split
import random

import os
from glob import glob

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
SEED = 0
SPLIT = 'light_dark_seed_to_dark'

# Seed every RNG this notebook draws from (Python `random` for the tone choice,
# NumPy, and torch on CPU/GPU) so that Restart Kernel -> Run All is repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)  # silently ignored when CUDA is unavailable
torch.backends.cudnn.deterministic = True

# Single device handle shared by the cells below
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [13]:
### Train/test split: the training rows feed both img2img generation and classifier training.

df = pd.read_csv('./data_splits/Fitz_subset.csv') # 7-class subset
image_dir='./data/finalfitz17k'

# File stems (name without extension) of the images that make up the "seed" split.
# NOTE(review): this glob uses the absolute path "/data_split/..." while the CSV above
# lives under the relative "./data_splits/" -- confirm the location is intended.
test_files = [
    os.path.splitext(os.path.basename(jpg_path))[0]
    for jpg_path in glob(f"/data_split/data_seed={SEED}/seed_4/*/*/*.jpg")
]

# Row-level masks over the full table
seed_mask = df.md5hash.isin(test_files)
is_dark = df.fitzpatrick_scale.isin([5, 6])
is_light = df.fitzpatrick_scale.isin([1, 2])

# Seed images, and the dark-skinned (Fitzpatrick 5/6) subset of them
seed = df[seed_mask]
seed_dark = df[seed_mask & is_dark]

# Test set: dark-skinned rows that are NOT part of the seed images
real_test = df[~seed_mask & is_dark]

# Training set: light-skinned (Fitzpatrick 1/2) non-seed rows plus the dark seed rows
real_train_light = df[~seed_mask & is_light]
real_train = pd.concat([real_train_light, seed_dark])

print(SEED)
print(real_train.shape[0], real_test.shape[0])
print(real_train.fitzpatrick_scale.unique(), real_test.fitzpatrick_scale.unique())

0
1284 291
[1 2 5 6] [5 6]


In [4]:
import PIL
from PIL import Image
import torch
# from pipeline import StableDiffusionPipeline
from diffusers_ import UNet2DConditionModel, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline

# Placeholder token inserted into the prompt for each disease class; the list
# position is the class index stored in `diseases_name`.
# (A previous naming scheme, now disabled, used the full disease name, e.g.
#  "<basal_cell_carcinoma>", "<folliculitis>", ..., "<squamous_cell_carcinoma>".)
disease_mapper = [
    "<bas-class>", "<fol-class>", "<nem-class>", "<neu-class>",
    "<pru-class>", "<pso-class>", "<squ-class>",
]

# Dataset label -> class index (index into `disease_mapper`)
diseases_name = {
    disease: class_idx
    for class_idx, disease in enumerate((
        "basal cell carcinoma",
        "folliculitis",
        "nematode infection",
        "neutrophilic dermatoses",
        "prurigo nodularis",
        "psoriasis",
        "squamous cell carcinoma",
    ))
}

# Fitzpatrick scale value -> skin-tone phrase used in the prompt (only 1, 2, 5, 6 are mapped)
tone_mapper = dict([
    (1, 'a very light-skinned'),
    (2, 'a light-skinned'),
    (5, 'a dark-skinned'),
    (6, 'a very dark-skinned'),
])

# Slots, in order: disease token, skin-tone phrase
prompt_template = "An image of {} on the skin of {} individual"

In [5]:
# Base Stable Diffusion checkpoint used for img2img generation
MODEL_NAME = "stabilityai/stable-diffusion-2-1-base"

# Load the img2img pipeline in half precision, then move it to the shared device.
# NOTE(review): the recorded output of this cell warns that the repo "is missing
# fp16 files" -- confirm that revision="fp16" is still the right way to select
# the half-precision weights.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    MODEL_NAME,
    revision="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to(device)

 The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title 'stabilityai/stable-diffusion-2-1-base is missing fp16 files' so that the correct variant file can be added.
unet/diffusion_pytorch_model.safetensors not found
  deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)
  deprecate("LoRACompatibleConv", "1.0.0", deprecation_message)
  deprecate("LoRACompatibleConv", "1.0.0", deprecation_message)
  deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)
Loading pipeline components...: 100%|██████████| 5/5 [00:01<00:00,  3.04it/s]


In [6]:
from transformers import CLIPTokenizer, CLIPTextModel

def load_embeddings(embed_path: str, 
                    model_path: str = "CompVis/stable-diffusion-v1-4"
                    ):
    """Load a CLIP tokenizer/text encoder and inject extra token embeddings.

    Args:
        embed_path: Path to a ``torch.save``-d mapping of
            ``token string -> embedding tensor``.
        model_path: Model id or directory whose ``tokenizer`` and
            ``text_encoder`` subfolders are loaded.

    Returns:
        ``(tokenizer, text_encoder)`` where the tokenizer knows every token in
        the file and the text encoder (moved to the notebook-level ``device``)
        holds the corresponding embedding rows.

    Raises:
        ValueError: If a token from the file could not be added to the
            tokenizer (``add_tokens`` reported 0 new tokens).
    """
    tokenizer = CLIPTokenizer.from_pretrained(
        model_path, use_auth_token=True,
        subfolder="tokenizer")

    text_encoder = CLIPTextModel.from_pretrained(
        model_path, use_auth_token=True,
        subfolder="text_encoder")

    # NOTE(review): torch.load unpickles the file -- only use embedding files
    # from a trusted source.
    learned_embeds = torch.load(embed_path, map_location="cpu")

    # Register all new tokens first so the embedding matrix is resized only once
    for token in learned_embeds:
        if tokenizer.add_tokens(token) == 0:
            raise ValueError(
                f"Token {token!r} from {embed_path} could not be added to the tokenizer")

    if learned_embeds:
        # Grow the embedding matrix to cover the newly added token ids
        text_encoder.resize_token_embeddings(len(tokenizer))

    embeddings = text_encoder.get_input_embeddings()
    for token, token_embedding in learned_embeds.items():
        added_token_id = tokenizer.convert_tokens_to_ids(token)
        # Overwrite the freshly created row with the stored embedding
        embeddings.weight.data[added_token_id] = \
            token_embedding.to(embeddings.weight.dtype)

    return tokenizer, text_encoder.to(device)

# Location of the aggregated token embeddings for this split / seed
embed_path = (
    f'./textual_inversion_weights/{SPLIT}/SEED={SEED}'
    f'/aggregated_embeds_SEED={SEED}.pt'
)
print(embed_path)

# Build a tokenizer / text encoder that contain the extra tokens, then swap
# them into the already-loaded pipeline.
tokenizer, text_encoder = load_embeddings(embed_path, model_path=MODEL_NAME)
pipe.tokenizer, pipe.text_encoder = tokenizer, text_encoder

/ssd/janet/lora_textual_inversion/textual_inversion_weights/light_only_to_dark/SEED=1234/aggregated_embeds_SEED=1234.pt




In [7]:
import os
# Img2img sampling settings. They are also embedded in the output folder name,
# so defining them once keeps the folder name in sync with what is actually run.
STRENGTH = 0.5  # default: 0.8, < 0.7 -> cannot change the skin color, > 0.75 -> change the skin color
NUM_INFERENCE_STEPS = 100
GUIDANCE_SCALE = 2  # default: 7.5
NUM_IMAGES_PER_PROMPT = 5
DARK_TONES = [5, 6]  # Fitzpatrick values every generated image is prompted with

grouped_data = real_train.groupby('label')
res = []
lora_path = f'./textual_inversion_weights/{SPLIT}/ti_lora_SEED={SEED}'
path = (
    f'./inference_img/{SPLIT}_SEED={SEED}'
    f'_steps={NUM_INFERENCE_STEPS}_strength={STRENGTH}_guidance={GUIDANCE_SCALE}'
)

os.makedirs(path, exist_ok=True)

# Dedicated RNGs seeded from SEED, so the sampled target tones and the
# diffusion noise do not depend on whatever global RNG state earlier cells left.
tone_rng = random.Random(SEED)
generator = torch.Generator(device=device).manual_seed(SEED)

for label, group_df in grouped_data:
    # One LoRA per disease class, stored under the first three letters of the label.
    # NOTE(review): the previous class's LoRA is not explicitly unloaded before the
    # next one is loaded -- confirm the installed diffusers build replaces it.
    lora_weights = f'{lora_path}/{label[:3]}/pytorch_lora_weights.safetensors'
    pipe.load_lora_weights(lora_weights)
    # Re-cast after loading (presumably the new weights are not fp16 -- TODO confirm)
    pipe.to(dtype=torch.float16)

    d_type = diseases_name[label]

    for _, row in group_df.iterrows():
        stype = row.fitzpatrick_scale
        # Close the file handle as soon as the pixels are decoded
        with Image.open(f'{image_dir}/{row.md5hash}.jpg') as src:
            img = src.convert("RGB")
        img = img.resize((512, 512), resample=PIL.Image.BILINEAR)
        # Sources that are not already dark-skinned get a random dark target tone
        if stype not in DARK_TONES:
            stype = tone_rng.choice(DARK_TONES)
        color = tone_mapper[stype]
        images = pipe(
            prompt=prompt_template.format(disease_mapper[d_type], color),
            image=img,
            strength=STRENGTH,
            num_inference_steps=NUM_INFERENCE_STEPS,
            guidance_scale=GUIDANCE_SCALE,
            num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
            generator=generator,
        ).images
        for idx, image in enumerate(images):
            # Output name = source hash + running index of the generated variant
            name = f"{row.md5hash}_{idx}"
            image.resize(size=(256, 256)).save(f'{path}/{name}.jpg')
            res.append([name, label, stype])

synthetic_train = pd.DataFrame(res, columns=['md5hash', 'label', 'fitzpatrick_scale'])
synthetic_train.to_csv(f'{path}.csv', index=False)
synthetic_train # the output csv will be used for training the classifier

  deprecate("set_lora_layer", "1.0.0", deprecation_message)
  deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
100%|██████████| 30/30 [00:05<00:00,  5.42it/s]
100%|██████████| 30/30 [00:05<00:00,  5.41it/s]
100%|██████████| 30/30 [00:05<00:00,  5.42it/s]
100%|██████████| 30/30 [00:05<00:00,  5.38it/s]
100%|██████████| 30/30 [00:05<00:00,  5.39it/s]
100%|██████████| 30/30 [00:05<00:00,  5.38it/s]
 90%|█████████ | 27/30 [00:05<00:00,  5.59it/s]

100%|██████████| 30/30 [00:05<00:00,  5.39it/s]
100%|██████████| 30/30 [00:05<00:00,  5.38it/s]
100%|██████████| 30/30 [00:05<00:00,  5.35it/s]
100%|██████████| 30/30 [00:05<00:00,  5.35it/s]
100%|██████████| 30/30 [00:05<00:00,  5.35it/s]
100%|██████████| 30/30 [00:05<00:00,  5.34it/s]
100%|██████████| 30/30 [00:05<00:00,  5.32it/s]
100%|██████████| 30/30 [00:05<00:00,  5.31it/s]
100%|██████████| 30/30 [00:05<00:00,  5.31it/s]
100%|██████████| 30/30 [00:05<00:00,  5.32it/s]
100%|██████████| 30/30 [00:05<00:00,  5.31it/s]
100%|██████████| 30/30 [00:05<00:00,  5.31it/s]
100%|██████████| 30/30 [00:05<00:00,  5.32it/s]
100%|██████████| 30/30 [00:05<00:00,  5.32it/s]
100%|██████████| 30/30 [00:05<00:00,  5.32it/s]
100%|██████████| 30/30 [00:05<00:00,  5.32it/s]
100%|██████████| 30/30 [00:05<00:00,  5.32it/s]
100%|██████████| 30/30 [00:05<00:00,  5.32it/s]
100%|██████████| 30/30 [00:05<00:00,  5.32it/s]
100%|██████████| 30/30 [00:05<00:00,  5.32it/s]
100%|██████████| 30/30 [00:05<00:00,  5.

Unnamed: 0,md5hash,label,fitzpatrick_scale
0,d2e96d346577a155b0125dda66b16395_0,basal cell carcinoma,5
1,d2e96d346577a155b0125dda66b16395_1,basal cell carcinoma,5
2,d2e96d346577a155b0125dda66b16395_2,basal cell carcinoma,5
3,d2e96d346577a155b0125dda66b16395_3,basal cell carcinoma,5
4,d2e96d346577a155b0125dda66b16395_4,basal cell carcinoma,5
...,...,...,...
6135,d7c15764c5a698f47809c97f8836d442_0,squamous cell carcinoma,6
6136,d7c15764c5a698f47809c97f8836d442_1,squamous cell carcinoma,6
6137,d7c15764c5a698f47809c97f8836d442_2,squamous cell carcinoma,6
6138,d7c15764c5a698f47809c97f8836d442_3,squamous cell carcinoma,6
