In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torch.utils.data import Dataset
from torchvision import transforms, models

# import skimagea``
# from skimage import io

import pandas as pd
import numpy as np
# from sklearn.model_selection import train_test_split
import random

import PIL
from PIL import Image
import torch
# from pipeline import StableDiffusionPipeline
from diffusers import UNet2DConditionModel, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline

import os
import json

In [None]:
SEED = 42

# Seed every random source used by this notebook: torch (CPU, then CUDA),
# Python's `random`, and NumPy's global RNG — same value, same order.
for _seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    random.seed,
    np.random.seed,
):
    _seed_fn(SEED)
del _seed_fn

# Disable cuDNN autotuning and request deterministic cuDNN kernels so
# repeated runs produce the same results.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


In [None]:
MODEL_NAME = "stabilityai/stable-diffusion-2-1-base"

# Img2img pipeline for MODEL_NAME, loaded in fp32 and placed on GPU 0.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
)
pipe = pipe.to('cuda:0')

# Disable the pipeline's safety checker.
pipe.safety_checker = None
pipe.requires_safety_checker = False

In [None]:
from transformers import CLIPTokenizer, CLIPTextModel

def load_embeddings(embed_path: str, 
                    model_path: str = "stabilityai/stable-diffusion-2-1-base",
                    device: str = 'cuda:0',
                    ):
    """Load a CLIP tokenizer/text encoder and register textual-inversion tokens.

    Every token stored in ``embed_path`` is added to the tokenizer, the text
    encoder's input-embedding matrix is grown once to the new vocabulary size,
    and each new token's row is overwritten with its stored vector (cast to the
    encoder's weight dtype).

    Args:
        embed_path: Path to a ``torch.load``-able mapping; presumably
            ``{token_string: embedding_tensor}`` — TODO confirm format.
        model_path: Model id or path containing ``tokenizer`` and
            ``text_encoder`` subfolders.
        device: Device the returned text encoder is moved to
            (defaults to ``'cuda:0'``, the previous hard-coded value).

    Returns:
        ``(tokenizer, text_encoder)`` with the new tokens registered; the
        text encoder is on ``device``.

    Raises:
        AssertionError: if ``tokenizer.add_tokens`` reports that a token was
            not added (e.g. it is already part of the vocabulary).
    """
    tokenizer = CLIPTokenizer.from_pretrained(
        model_path, use_auth_token=True,
        subfolder="tokenizer")

    text_encoder = CLIPTextModel.from_pretrained(
        model_path, use_auth_token=True,
        subfolder="text_encoder")

    print(len(tokenizer))  # vocabulary size before adding tokens

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code — only point this at trusted embedding files.
    learned_embeds = torch.load(embed_path, map_location="cpu")

    # Register all new tokens first ...
    for token in learned_embeds:
        num_added_tokens = tokenizer.add_tokens(token)
        assert num_added_tokens > 0, f"token {token!r} was not added to the tokenizer"

    # ... then resize the embedding matrix once, instead of reallocating it per token.
    text_encoder.resize_token_embeddings(len(tokenizer))
    embeddings = text_encoder.get_input_embeddings()

    # Write each learned vector into its token's row, matching the encoder dtype.
    for token, token_embedding in learned_embeds.items():
        added_token_id = tokenizer.convert_tokens_to_ids(token)
        embeddings.weight.data[added_token_id] = \
            token_embedding.to(embeddings.weight.dtype)
    print(len(tokenizer))  # vocabulary size after adding tokens

    return tokenizer, text_encoder.to(device)

### add textual inversion tokens
# Learned token embeddings for the disease placeholder tokens.
embed_path = 'aggregated_embed_sd2_1_base.pt'
print(embed_path)

tokenizer, text_encoder = load_embeddings(embed_path, model_path=MODEL_NAME)

# Swap in the tokenizer/encoder that know the new tokens.
pipe.tokenizer, pipe.text_encoder = tokenizer, text_encoder

In [None]:
import os

# Class index -> human-readable diagnosis (20 classes).
mapping = {
    0: 'acne',
    1: 'actinic keratosis',
    2: 'allergic contact dermatitis',
    3: 'basal cell carcinoma',
    4: 'eczema',
    5: 'erythema multiforme',
    6: 'folliculitis',
    7: 'granuloma annulare',
    8: 'keloid',
    9: 'lichen planus',
    10: 'lupus erythematosus',
    11: 'melanoma',
    12: 'mycosis fungoides',
    13: 'pityriasis rosea',
    14: 'prurigo nodularis',
    15: 'psoriasis',
    16: 'sarcoidosis',
    17: 'scabies',
    18: 'squamous cell carcinoma',
    19: 'vitiligo',
}

# Class index -> placeholder token used in prompts: the first three letters of
# the diagnosis wrapped in 'xxx' (e.g. 'acne' -> 'xxxacnxxx').
# NOTE(review): these presumably must match the token keys in the embedding
# file loaded by `load_embeddings` — verify if labels are ever renamed.
prompt_mapping = {idx: f'xxx{label[:3]}xxx' for idx, label in mapping.items()}

In [None]:
image_dir = '/data/derm_data/Fitzpatrick17k/finalfitz17k'

### load model after MAGIC-DPO
# LoRA checkpoint produced by the DPO fine-tuning run.
lora_path = '/data/DermDPO/logs/fitz_20class_3100_w_bodyparts_new_cklist_2025.04.27_20.36.50/checkpoints/checkpoint_20'
output = 'output'  # directory the generated PNGs are written to

os.makedirs(output, exist_ok=True)

pipe.load_lora_weights(lora_path)
# Cast the whole pipeline (now including the LoRA weights) back to fp32.
pipe.to(dtype=torch.float32)

In [None]:
train_data_path = 'MAGIC/magic_pytorch/assets/all_prompts.json'

# img2img generation settings.
STRENGTH = 0.3              # default: 0.8
NUM_INFERENCE_STEPS = 100
GUIDANCE_SCALE = 5          # default: 7.5
NUM_IMAGES_PER_PROMPT = 1
GEN_SIZE = (512, 512)       # size of the image fed to the pipeline
SAVE_SIZE = (256, 256)      # size of the PNG written to disk

with open(train_data_path, 'r') as f:
    train_data = json.load(f)


def _pick_target_class(original_class):
    """Return a random class index whose label differs from ``original_class``.

    Prefers classes whose body-part distribution overlaps the image's body
    parts. NOTE(review): ``body_part_dist`` and ``gross_body_part`` are not
    defined in any visible cell, so on a fresh kernel the preference raises
    NameError and the uniform fallback is used — define them or drop this branch.
    RNG consumption matches the previous try/except: exactly one draw.
    """
    candidates = [x for x in mapping if mapping[x] != original_class]
    try:
        preferred = [
            x for x in candidates
            if any(part in body_part_dist[mapping[x]] for part in gross_body_part)
        ]
    except (NameError, KeyError, TypeError):
        # Body-part metadata missing/incomplete: fall back to all candidates.
        preferred = []
    # Empty preference list (no overlap) also falls back to all candidates.
    return random.choice(preferred or candidates)


res = []
# `record`, not `data`: `data` would clobber the `torch.utils.data` module alias.
for record in train_data:
    image_path = record['image_path']
    # Take the basename first so a dot in a directory name can't truncate the id.
    img_id = os.path.basename(image_path).split('.')[0]
    original_class = record['label']

    target_class = _pick_target_class(original_class)
    target_label = mapping[target_class]
    d_type = prompt_mapping[target_class]

    with Image.open(f'{image_dir}/{image_path}') as src:
        img = src.convert("RGB").resize(GEN_SIZE, resample=PIL.Image.BILINEAR)

    images = pipe(
        prompt=f"an image of {d_type} on human skin",  # {d_type} around {body_part} of a person
        image=img,
        strength=STRENGTH,
        num_inference_steps=NUM_INFERENCE_STEPS,
        guidance_scale=GUIDANCE_SCALE,
        num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
    ).images

    for idx, image in enumerate(images):
        # Keep the `<id>_<class>` name for a single image; add an index otherwise
        # so multiple images per prompt don't overwrite each other.
        name = f"{img_id}_{target_class}" if len(images) == 1 else f"{img_id}_{target_class}_{idx}"
        image.resize(size=SAVE_SIZE).save(f'{output}/{name}.png')
        res.append([name, original_class, target_label])

synthetic_train = pd.DataFrame(res, columns=['md5hash', 'original_class', 'label'])
synthetic_train.to_csv('output.csv', index=False)
synthetic_train.head()