In [None]:
from tqdm import tqdm
import os
from easynmt import EasyNMT
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, BitsAndBytesConfig
import torch
from PIL import Image
import re
import shutil
import pandas as pd
import random


# Raw dataset: one folder per title, plus a CSV with title/category metadata.
dataset_path = "../storage/dataset"
dataset_csv = "../storage/dataset.csv"

# Description variants: translated original descriptions and LLaVA-generated ones.
translated = "../storage/translated-description"
generated = "../storage/generated-description"

# Train/test splits of both variants, named "<variant>_<split kind>".
translated_random_split = f"{translated}_random-split"
generated_random_split = f"{generated}_random-split"
translated_category_split = f"{translated}_category-split"
generated_category_split = f"{generated}_category-split"

In [None]:
model = EasyNMT("opus-mt")


def translate_image_description():
    """Translate every ``.txt`` description under ``dataset_path`` from German to English.

    SDXL and its fine-tuning work better with English texts, so the German
    descriptions are translated with the module-level EasyNMT ``model``. Each
    result is written below ``translated``, mirroring the sub-directory layout
    of ``dataset_path``. Existing output files are overwritten.
    """
    # Collect the work list once so that the progress-bar total and the loop
    # always agree and the tree is walked only a single time.
    txt_files = [
        (root, name)
        for root, _dirs, files in os.walk(dataset_path)
        for name in files
        if name.lower().endswith(".txt")
    ]

    # The context manager closes the bar even if a translation raises.
    with tqdm(total=len(txt_files), desc="description translated", unit="description") as progress_bar:
        for root, name in txt_files:
            relative_path = os.path.relpath(root, dataset_path)
            output_file_dir = os.path.join(translated, relative_path)
            os.makedirs(output_file_dir, exist_ok=True)

            with open(os.path.join(root, name), "r", encoding="utf-8") as f:
                content = f.read()

            # The source texts are German; naming the source language avoids
            # relying on automatic language detection, which can pick the
            # wrong opus-mt model for short texts.
            translated_content = model.translate(content, source_lang="de", target_lang="en")

            with open(os.path.join(output_file_dir, name), "w", encoding="utf-8") as f:
                f.write(translated_content)

            progress_bar.update()


translate_image_description()

In [None]:
# Compiled once, because clean_caption runs for every generated caption.
_LLAVA_PROMPT_RE = re.compile(r"\[INST\].*?\[/INST\]", flags=re.DOTALL)
_TRAILING_FRAGMENT_RE = re.compile(r"[^.!?]*$")


def clean_caption(text):
    """Strip the echoed prompt and an unfinished trailing sentence from a LLaVA caption.

    Args:
        text: Decoded LLaVA output. It may still contain the
            ``[INST] ... [/INST]`` prompt when the whole sequence was decoded.

    Returns:
        The caption cut after its last sentence terminator (``.``, ``!`` or
        ``?``), with surrounding whitespace stripped. If the caption contains
        no terminator at all, it is returned whole (stripped) instead of being
        reduced to an empty string.
    """
    # The decoded sequence can repeat the [INST] ... [/INST] prompt; drop it.
    text = _LLAVA_PROMPT_RE.sub("", text)

    # max_new_tokens is capped (70), so generation can stop mid-sentence:
    # remove everything after the last complete sentence.
    truncated = _TRAILING_FRAGMENT_RE.sub("", text).strip()

    # Without any terminator the line above would erase the whole caption;
    # an unfinished caption is more useful than an empty description file.
    return truncated or text.strip()

# 4-bit NF4 quantization reduces the memory footprint of the 7B LLaVA model;
# computations run in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

llava_model_id = "llava-hf/llava-v1.6-mistral-7b-hf"

# The processor only tokenizes text and preprocesses images; it has no weights,
# so quantization and device placement belong to the model, not to it.
processor = LlavaNextProcessor.from_pretrained(llava_model_id)

# Quantization has to be applied while the weights are loaded. bitsandbytes
# 4-bit models cannot be moved with `.to(...)`, so device_map places them.
image_captioning_model = LlavaNextForConditionalGeneration.from_pretrained(
    llava_model_id,
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
)

# German -> English translation of the folder titles used as prompt context.
translation_model = EasyNMT("opus-mt")

def generate_llava_description():
    """Caption the image of every top-level dataset folder with LLaVA.

    LLaVA is used to create additional image descriptions that can be compared
    with the translated original descriptions. For each sub-directory
    ``<title>`` of ``dataset_path``:

    - the image ``<title>/<title>.png`` is captioned;
    - the cleaned caption is written to ``generated/<title>/<title>.txt``.

    The (German) folder title is translated to English and added to the prompt
    as context. Folders without the expected image are skipped with a message.
    """
    # Only the first level of dataset_path holds the per-title folders. A
    # recursive walk would also visit every title folder and open an empty
    # progress bar for each of them.
    title_dirs = sorted(
        entry
        for entry in os.listdir(dataset_path)
        if os.path.isdir(os.path.join(dataset_path, entry))
    )

    with tqdm(total=len(title_dirs), desc="description generated", unit="description") as progress_bar:
        for title_dir in title_dirs:
            input_file_path = os.path.join(dataset_path, title_dir, f"{title_dir}.png")
            if not os.path.isfile(input_file_path):
                tqdm.write(f"skipping {title_dir}: no image at {input_file_path}")
                progress_bar.update()
                continue

            output_file_dir = os.path.join(generated, title_dir)
            output_file_path = os.path.join(output_file_dir, f"{title_dir}.txt")
            os.makedirs(output_file_dir, exist_ok=True)

            # Translate the title to inject additional context into the prompt.
            title = translation_model.translate(title_dir, source_lang="de", target_lang="en")

            conversation = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": f"What is shown in this image? The title is {title}."},
                    ],
                },
            ]
            prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

            # Keyword arguments avoid depending on the positional (text, images)
            # order; `with` closes the image file once it has been processed.
            with Image.open(input_file_path) as image:
                inputs = processor(images=image, text=prompt, return_tensors="pt").to(
                    image_captioning_model.device
                )

            output = image_captioning_model.generate(**inputs, max_new_tokens=70)

            # generate() returns prompt + continuation; decode only the new tokens.
            prompt_length = inputs["input_ids"].shape[-1]
            image_caption = clean_caption(
                processor.decode(output[0][prompt_length:], skip_special_tokens=True)
            )

            with open(output_file_path, "w", encoding="utf-8") as f:
                f.write(image_caption)
            progress_bar.update()


generate_llava_description()

In [None]:
"""Images with at least 512 pixels can produce better fine-tuning results"""
def upscale_images(output_dir):
    for root, dirs, files in os.walk(dataset_path):
        for file in files:
            if file.lower().endswith(".png"):
                input_file_path = os.path.join(root, file)
                relative_path = os.path.relpath(root, dataset_path)
                output_file_dir = os.path.join(output_dir, relative_path)
                output_file_path = os.path.join(output_file_dir, file)

                os.makedirs(output_file_dir, exist_ok=True)

                low_res_img = Image.open(input_file_path).convert("RGB")
                
                original_width, original_height = low_res_img.size
                
                # keep aspect ratio
                if original_width > original_height:
                    new_width = 512
                    new_height = int((512 / original_width) * original_height)
                else:
                    new_height = 512
                    new_width = int((512 / original_height) * original_width)
                
                high_res_img = low_res_img.resize((new_width, new_height))

                high_res_img.save(output_file_path)


# Put a resized copy of every dataset image next to both description variants.
# (A name other than `dir` keeps the builtin usable in the notebook namespace.)
for description_dir in [translated, generated]:
    upscale_images(description_dir)

In [None]:
def random_dataset_split(folder_path, new_folder_name, test_size=0.2, seed=None):
    """Randomly split the sub-folders of *folder_path* into a train and a test set.

    This is the traditional way to split training and test data: a random
    partition. Each sub-folder (one sample) is copied as a whole into
    ``<new_folder_name>/test`` or ``<new_folder_name>/train``.

    Args:
        folder_path: Directory whose immediate sub-folders are the samples.
            Plain files directly inside it are ignored.
        new_folder_name: Output directory; ``train`` and ``test`` are created in it.
        test_size: Fraction of the sub-folders placed in the test set (rounded down).
        seed: Optional seed for a reproducible split. Folder names are sorted
            before shuffling, so the split does not depend on the order in
            which the filesystem lists directories. A private
            ``random.Random`` instance is used, so the global ``random`` state
            is left untouched.

    Sample folders that already exist at the destination are not copied again,
    so the function can be re-run.
    """
    train_dir = os.path.join(new_folder_name, "train")
    test_dir = os.path.join(new_folder_name, "test")
    # makedirs also creates new_folder_name as the parent directory.
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    all_folders = sorted(
        d for d in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, d))
    )
    random.Random(seed).shuffle(all_folders)
    num_test_folders = int(len(all_folders) * test_size)

    # The first num_test_folders shuffled folders form the test set.
    for index, folder in enumerate(all_folders):
        split_dir = test_dir if index < num_test_folders else train_dir
        dest_path = os.path.join(split_dir, folder)
        if not os.path.exists(dest_path):
            shutil.copytree(os.path.join(folder_path, folder), dest_path)


# Seeded 80/20 random split of both description variants into "<variant>_random-split".
# (A name other than `dir` keeps the builtin usable in the notebook namespace.)
for description_dir in [translated, generated]:
    random_dataset_split(description_dir, f"{description_dir}_random-split", test_size=0.2, seed=42)

In [None]:
def split_dataset_by_category(csv_path, category, source_folder_path, target_folder_path):
    """Split sample folders into train/test sets according to a category from a CSV file.

    The dataset from LAG Selbsthilfe is highly unbalanced and biased. Using all
    samples of one category (e.g. "Begriffe") as the test set, instead of a
    random split, reduces that bias a little.

    Args:
        csv_path: CSV file with (at least) the columns ``title`` and ``category``.
        category: Category value whose titles form the test set.
        source_folder_path: Directory whose immediate sub-folders are named
            after titles; plain files in it are ignored.
        target_folder_path: Output directory; ``train`` and ``test`` are created in it.

    Folders that already exist at the destination are not copied again.
    """
    train_dir = os.path.join(target_folder_path, "train")
    test_dir = os.path.join(target_folder_path, "test")
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    # Read titles as text so that numeric-looking titles still match folder names.
    df = pd.read_csv(csv_path, dtype={"title": str})
    # A set makes the per-folder membership test O(1).
    category_titles = set(df.loc[df["category"] == category, "title"].dropna())

    for folder_name in os.listdir(source_folder_path):
        folder = os.path.join(source_folder_path, folder_name)
        if not os.path.isdir(folder):
            continue

        split_dir = test_dir if folder_name in category_titles else train_dir
        dest_path = os.path.join(split_dir, folder_name)
        if not os.path.exists(dest_path):
            shutil.copytree(folder, dest_path)

# Use all "Begriffe" samples as the test set for both description variants.
# (A name other than `dir` keeps the builtin usable in the notebook namespace.)
for description_dir in [translated, generated]:
    split_dataset_by_category(
        csv_path=dataset_csv,
        category="Begriffe",
        source_folder_path=description_dir,
        target_folder_path=f"{description_dir}_category-split",
    )

In [None]:
def format_training_folder(
        dataset_dir,
        suffix=" leichte sprache style",
        number_repeats=20,
        instance_prompt="leichte sprache style",
        class_prompt="style"
):
    """Flatten ``<dataset_dir>/train`` into the folder layout expected by sd-scripts.

    sd-scripts expects image and caption files side by side in one shared
    directory and reads some hyperparameters, e.g. the number of repeats per
    image, from that directory's name:
    ``<dataset_dir>/fine-tuning/<repeats>_<instance_prompt> <class_prompt>``.

    Every ``.png`` (extension compared case-insensitively) found below
    ``<dataset_dir>/train`` that has a ``.txt`` caption with the same stem next
    to it is copied into that directory. The caption is written there as well,
    stripped of surrounding whitespace and with *suffix* appended if *suffix*
    is non-empty. Images without a caption are skipped. The directories
    ``<dataset_dir>/loras`` (LoRA weights) and ``<dataset_dir>/logs`` (training
    logs) are created too.
    """
    # sd-scripts derives the repeat count from the prefix of this folder name.
    training_dir = f"{dataset_dir}/fine-tuning/{int(number_repeats)}_{instance_prompt} {class_prompt}"

    # Training images + captions, LoRA weights, training logs.
    for needed_dir in (training_dir, f"{dataset_dir}/loras", f"{dataset_dir}/logs"):
        os.makedirs(needed_dir, exist_ok=True)

    for folder, _subdirs, filenames in os.walk(f"{dataset_dir}/train"):
        for filename in filenames:
            stem, extension = os.path.splitext(filename)
            if extension.lower() != ".png":
                continue

            caption_name = f"{stem}.txt"
            caption_source = os.path.join(folder, caption_name)
            if not os.path.isfile(caption_source):
                continue  # an image without a caption is not used for training

            shutil.copy2(os.path.join(folder, filename), training_dir)

            with open(caption_source, "r", encoding="utf-8") as reader:
                caption = reader.read().strip()

            # The suffix links every caption to the fine-tuning style prompt.
            if suffix:
                caption += suffix

            with open(os.path.join(training_dir, caption_name), "w", encoding="utf-8") as writer:
                writer.write(caption)


# Build the sd-scripts training folders for all four splits.
# (A name other than `dir` keeps the builtin usable in the notebook namespace.)
for split_dir in [translated_category_split, translated_random_split, generated_category_split, generated_random_split]:
    format_training_folder(split_dir)

In [None]:
def resize_image(image_path, size):
    """Resize the image at *image_path* to ``size`` x ``size`` pixels, overwriting the file.

    The aspect ratio is not preserved.
    """
    target_dimensions = (size, size)
    with Image.open(image_path) as source_image:
        source_image.resize(target_dimensions).save(image_path)

def format_test_folder(dir, size=512):
    """Collect the test images of a split into ``<dir>/test-images-only``, resized to *size* x *size*.

    For evaluation with FID, reference and generated images must have the same
    resolution. Every ``.png`` (extension compared case-insensitively) below
    ``<dir>/test`` is copied flat into ``<dir>/test-images-only``. The copy is
    then resized in place with ``resize_image``, which does not preserve the
    aspect ratio. The originals under ``test`` are left untouched.

    Args:
        dir: Split directory containing a ``test`` sub-directory. (The name
            shadows the builtin but is kept for backward compatibility.)
        size: Edge length in pixels of the square output images (default 512).
    """
    test_path = f"{dir}/test"
    new_test_images_path = os.path.join(dir, "test-images-only")
    os.makedirs(new_test_images_path, exist_ok=True)

    for root, _dirs, files in os.walk(test_path):
        for file in files:
            if not file.lower().endswith(".png"):
                continue

            dest_path = os.path.join(new_test_images_path, file)
            shutil.copy(os.path.join(root, file), dest_path)
            resize_image(dest_path, size)
                
# Prepare the FID reference images for all four splits.
# (A name other than `dir` keeps the builtin usable in the notebook namespace.)
for split_dir in [translated_category_split, translated_random_split, generated_category_split, generated_random_split]:
    format_test_folder(split_dir)

In [None]:
# Optional clean-up. The four split directories are self-contained copies and
# can be used for fine-tuning on their own, so the intermediate directories
# are no longer needed.
# WARNING: this permanently deletes `dataset_path` (the raw dataset) as well.
# Skip this cell if you may want to re-run the preprocessing. Errors (e.g.
# already deleted directories) are ignored.
for obsolete_dir in [dataset_path, generated, translated]:
    shutil.rmtree(obsolete_dir, ignore_errors=True)