In [None]:
# Execution environment selector. The default, "TEST", skips the repository
# clone and model-weight download performed by the setup cell; any other value
# (set one when running on Colab or Kaggle) enables them.
ENV_TYPE = "TEST"

In [None]:
# System packages used by this notebook: aria2 (weight download below),
# python3-pip, git (repository clone below) and libgl1-mesa-glx.
# NOTE(review): there is no sudo, so this presumably relies on the kernel
# running as root (as on Colab) — confirm for other environments.
!apt update && apt install -y aria2 python3-pip git libgl1-mesa-glx

if ENV_TYPE != "TEST":
    # Outside the TEST environment: fetch the `lora` branch of the repository
    # (including submodules) and make the checkout the working directory.
    # NOTE(review): re-running this cell in the same session repeats the clone
    # and the %cd from inside the checkout — the cell is not idempotent.
    !git clone https://github.com/kk-digital/kcg-ml-sd1p4.git --recurse-submodules -b lora
    %cd kcg-ml-sd1p4
    # Download the sd-v1-4.ckpt weights to input/model/, the checkpoint path
    # the image-generation cell passes to txt2img.py.
    !aria2c https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt -o ./input/model/sd-v1-4.ckpt -j 10

In [None]:
# Install the Python requirements listed in lora-tuning/requirements.txt.
# The path is relative to the current working directory.
# NOTE(review): presumably the notebook is started from the repository root
# when ENV_TYPE is "TEST" (no %cd happens in that case) — confirm.
!pip3 install -r lora-tuning/requirements.txt

In [None]:
# Run lora-tuning/lora_train.py with no arguments; this run is what the
# notebook treats as its 'test suite'.
# NOTE(review): presumably this writes the .safetensors files under
# ./output/LoRas/Test/output/ that the image-generation cell reads — confirm.
!python3 lora-tuning/lora_train.py

In [None]:
# Generate images for each epoch:
# run txt2img.py once per saved .safetensors file, then stack the results,
# each with a caption, into a single image saved as combined_image.png.
import os
import subprocess
from PIL import Image, ImageDraw, ImageFont

# Set the directory containing the .safetensors files
lora_path = "./output/LoRas/Test/output/"
checkpoint_path = "input/model/sd-v1-4.ckpt"
scale = 512  # forwarded verbatim as --scale; its meaning is defined by txt2img.py
prompt = "a sketch of a pixwaifu, cute, small, chibi, DeviantArt trending"

# Layout constants, in pixels
margin = 10           # left/top inset of every caption and image
caption_height = 60   # vertical space reserved for the three caption lines
row_spacing = 100     # added to each image's height to get the next row's offset

# Get a list of all .safetensors files in the directory.
# os.listdir() returns entries in arbitrary order, so sort to make the row
# order (and the epoch numbers derived from it) deterministic.
# NOTE(review): the epoch numbering assumes the filenames sort in training
# order — confirm against the names lora_train.py writes.
file_list = sorted(
    filename for filename in os.listdir(lora_path) if filename.endswith(".safetensors")
)

# Phase 1: generate one image per .safetensors file.
generated = []
for filename in file_list:
    # Point --lora at this iteration's file; passing the bare directory made
    # every iteration invoke txt2img.py with identical inputs.
    # NOTE(review): assumes --lora accepts the path of a single .safetensors
    # file — confirm against txt2img.py's argument parser.
    lora_file = os.path.join(lora_path, filename)
    output_filename = os.path.splitext(filename)[0] + ".png"
    # An argument list (no shell) keeps the prompt and paths intact whatever
    # quotes, spaces or shell metacharacters they contain.
    command = [
        "python3", "lora-tuning/inference/txt2img.py",
        "--checkpoint_path", checkpoint_path,
        "--lora", lora_file,
        "--output", output_filename,
        "--scale", str(scale),
        "--prompt", prompt,
    ]
    subprocess.run(command, check=True)

    # Copy the pixels into memory so the file handle is closed right away.
    with Image.open(output_filename) as image_file:
        generated.append((filename, image_file.copy()))

# Phase 2: compose the combined image.
# Keep the original 1920x1920 canvas as a minimum, but grow it when the stacked
# rows would otherwise be clipped at the bottom or right edge.
canvas_width = max([1920] + [margin + image.size[0] for _, image in generated])
canvas_height = max(
    1920, margin + sum(image.size[1] + row_spacing for _, image in generated)
)
canvas = Image.new("RGB", (canvas_width, canvas_height))
draw = ImageDraw.Draw(canvas)

# arial.ttf is not installed on every system; fall back to Pillow's built-in
# font instead of aborting after all images have already been generated.
try:
    font = ImageFont.truetype("arial.ttf", 16)
except OSError:
    font = ImageFont.load_default()

y_offset = margin
for epoch, (filename, image) in enumerate(generated, start=1):
    # Add caption and epoch information to the canvas, image right below it
    caption = f"Filename: {filename}\nPrompt: {prompt}\nEpoch: {epoch}"
    draw.text((margin, y_offset), caption, font=font, fill="white")
    canvas.paste(image, (margin, y_offset + caption_height))

    # Advance to the next row
    y_offset += image.size[1] + row_spacing

# Save the final image
canvas.save("combined_image.png")
