In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import Dataset
from torchvision import transforms, models

# import skimage
# from skimage import io
import pandas as pd
import numpy as np
# from sklearn.model_selection import train_test_split
import random

import os
from glob import glob

import PIL
from PIL import Image
import torch
# from pipeline import StableDiffusionPipeline
from diffusers import UNet2DConditionModel, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
import json

In [None]:
# Run configuration. SEED and SPLIT are interpolated into the model, data-split
# and output paths used further down the notebook.
SEED = 0
SPLIT = 'light_and_dark_flex_to_dark'

# Seed every RNG the notebook draws from so that a fresh-kernel "Run All"
# gives the same random skin-tone reassignment and the same sampling noise.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Checkpoint the img2img pipeline is built from.
MODEL_NAME = "stabilityai/stable-diffusion-2-1-base"

# `variant="fp16"` selects the half-precision weight files of the checkpoint.
# It replaces the older `revision="fp16"` branch-based selection, which
# diffusers has deprecated.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    variant="fp16",
).to(device)

In [None]:
from transformers import CLIPTokenizer, CLIPTextModel

def load_embeddings(embed_path: str, 
                    model_path: str = "stabilityai/stable-diffusion-2-1-base"
                    ):
    """Load a CLIP tokenizer/text encoder and inject learned token embeddings.

    The file at ``embed_path`` is read with ``torch.load`` and iterated as a
    mapping of ``token -> embedding``. Every token is registered with the
    tokenizer, the text encoder's input embedding matrix is resized to the new
    vocabulary size, and each token's row is overwritten with the stored
    embedding (cast to the matrix dtype). A token that the tokenizer already
    knows is not added again; its existing row is overwritten.

    The tokenizer length is printed before and after the tokens are added.

    Args:
        embed_path: Path to the serialized ``{token: embedding}`` mapping.
        model_path: Checkpoint whose ``tokenizer`` and ``text_encoder``
            subfolders are loaded.

    Returns:
        ``(tokenizer, text_encoder)``; the text encoder is moved to the
        notebook-level ``device`` before it is returned.
    """
    # `use_auth_token=True` is intentionally not passed: it is deprecated and
    # makes loading fail when no Hub login is stored, which a public checkpoint
    # does not need. A stored login is still picked up by default.
    tokenizer = CLIPTokenizer.from_pretrained(
        model_path, subfolder="tokenizer")

    text_encoder = CLIPTextModel.from_pretrained(
        model_path, subfolder="text_encoder")

    print(len(tokenizer))

    # NOTE(security): torch.load unpickles the file, so only point this at
    # embedding files you produced yourself or otherwise trust.
    learned_embeds = torch.load(embed_path, map_location="cpu")

    new_tokens = list(learned_embeds.keys())
    if new_tokens:
        # Register every token first so the embedding matrix is resized once
        # instead of once per token.
        tokenizer.add_tokens(new_tokens)
        text_encoder.resize_token_embeddings(len(tokenizer))

        embeddings = text_encoder.get_input_embeddings()
        for token, token_embedding in learned_embeds.items():
            token_id = tokenizer.convert_tokens_to_ids(token)
            # Write the learned vector into the token's row, matching the
            # dtype of the existing embedding matrix.
            embeddings.weight.data[token_id] = \
                token_embedding.to(embeddings.weight.dtype)
    print(len(tokenizer))

    return tokenizer, text_encoder.to(device)

### add textual inversion tokens
# The aggregated embedding file is looked up by split name and seed.
embed_path = (
    f"../models/textual_inversion/{SPLIT}_seed={SEED}/"
    f"aggregated_embeds_SEED={SEED}.pt"
)
print(embed_path)

# Build a tokenizer/text encoder pair that contains the extra tokens, then
# swap both into the pipeline in place of the ones it was loaded with.
tokenizer, text_encoder = load_embeddings(embed_path, model_path=MODEL_NAME)
pipe.tokenizer, pipe.text_encoder = tokenizer, text_encoder

In [None]:
# Input/output locations, all keyed by split name and seed.
# The path gets its own name so that `real_train` only ever refers to the
# loaded records, never to a path string.
real_train_path = f'../data_splits/train_lora_{SPLIT}_seed={SEED}.json'
lora_path = f'../models/lora_weights/{SPLIT}_seed={SEED}'
# NOTE(review): the steps/strength/guidance values in this folder name are
# typed by hand; keep them in sync with the arguments passed to `pipe(...)`.
output_path = f'../inference_img/{SPLIT}_seed={SEED}_steps=100_strength=0.5_guidance=2'
# NOTE(review): absolute, machine-specific path; consider making it configurable.
image_dir = '/data/derm_data/Fitzpatrick17k/finalfitz17k/'
os.makedirs(output_path, exist_ok=True)

pipe.load_lora_weights(lora_path)
# Cast the whole pipeline to half precision after the LoRA weights are loaded
# (presumably needed so every component ends up in the same dtype; confirm).
pipe.to(dtype=torch.float16)

# Prompt phrase for each skin-type code. Only codes 1, 2, 5 and 6 are listed.
tone_mapper = {
    1: 'a very light-skinned',
    2: 'a light-skinned',
    5: 'a dark-skinned',
    6: 'a very dark-skinned',
}

# Placeholder token inserted into the prompt for each condition label
# (presumably the tokens stored in the textual-inversion embedding file; confirm).
token_mapper ={
    "basal cell carcinoma": "bas-class",
    "folliculitis": "fol-class",
    "nematode infection": "nem-class",
    "neutrophilic dermatoses": "neu-class",
    "prurigo nodularis": "pru-class",
    "psoriasis": "pso-class",
    "squamous cell carcinoma": "squ-class",
}

# JSON is read as UTF-8 explicitly instead of relying on the locale default.
with open(real_train_path, "r", encoding="utf-8") as f:
    real_train = json.load(f)

print(len(real_train))

In [None]:
# Generate synthetic variants of the real training images with the img2img
# pipeline, save them as 256x256 JPEGs, and collect one metadata row per image.

# Dedicated, seeded RNGs so that the random skin-tone reassignment and the
# sampling noise are reproducible for a given SEED.
rng = random.Random(SEED)
generator = torch.Generator(device=device).manual_seed(SEED)

res = []
# NOTE(review): only the first 10 records are processed; confirm whether this
# cap is intentional or a leftover from a trial run.
# The loop variable is named `record` so it does not shadow the
# `torch.utils.data as data` import.
for record in real_train[:10]:
    label = record['label']
    stype = record['skin_type']
    image_path = record['image_path']
    # File name without directory or extension; used as the output name stem.
    img_id = os.path.splitext(os.path.basename(image_path))[0]

    # Open inside a context manager so the file handle is released right after
    # the pixel data has been converted to an in-memory RGB image.
    with Image.open(os.path.join(image_dir, image_path)) as src:
        img = src.convert("RGB")
    img = img.resize((512, 512), resample=PIL.Image.BILINEAR)

    # Any record whose skin type is not 5 or 6 is assigned one of those two at
    # random; the assigned value is what goes into the prompt and the CSV.
    if stype not in [5, 6]:
        stype = rng.choice([5, 6])

    images = pipe( 
        prompt=f"An image of {token_mapper[label]} on the skin of {tone_mapper[stype]} individual",
        image=img,
        strength=0.5, # default: 0.8,  0.7 - cannot change the skin color,  0.75 - change the skin color
        num_inference_steps=100,
        guidance_scale=2, # default: 7.5
        num_images_per_prompt=5,
        generator=generator,
    ).images

    for idx, image in enumerate(images):
        name = f"{img_id}_{idx}"
        resized_img = image.resize(size=(256, 256))
        resized_img.save(f'{output_path}/{name}.jpg')
        res.append([name, label, stype])

# One row per generated image, written next to the output folder as
# `<output_path>.csv`.
synthetic_train = pd.DataFrame(res, columns=['md5hash', 'label', 'fitzpatrick_scale'])
synthetic_train.to_csv(f'{output_path}.csv', index=False) 
synthetic_train