<a href="https://colab.research.google.com/github/debamitro/6-s093-day1-finetune-flux/blob/main/6_S093%2C_Day_1_Finetune_Flux_Replicate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# How to use

To run and modify the notebook, open the File menu in the top left and choose "Save a copy in Drive".

Useful shortcuts:
- Shift + Enter: runs a cell

Additional Resources:

For a more in-depth explanation of fine-tuning, see [here](https://civitai.com/articles/4/make-your-own-loras-easy-and-free), [here](https://replicate.com/blog/fine-tune-flux) or [here](https://dreambooth.github.io/).

In [None]:
#@title 0.A Install replicate library
!pip install replicate



In [None]:
#@title 0.B Setup Replicate

import os
import replicate
from IPython.display import Image, display

# YOUR REPLICATE API KEY
REPLICATE_API_KEY = ""

os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_KEY

In [None]:
#@title 0.C Test the image generation model

output = replicate.run(
    "black-forest-labs/flux-dev",
    input={
        "prompt": "A photo of DJENNDOG in a space station",
        "num_inference_steps": 28, # typically ~30 for the "dev" model. Fewer steps means faster generation but lower quality
        "guidance_scale": 7.5,     # how closely the model follows the prompt. Try different values to see the effect
        "model": "schnell",        # this switch matters for the fine-tuned model (see 1.C and 1.D), where "schnell" generates faster with num_inference_steps=4
    }
)

generated_img_url = str(output[0])
print(f"Generated image URL: {generated_img_url}")
display(Image(url=generated_img_url))

Generated image URL: https://replicate.delivery/xezq/HZTvMorlSpLmCpgICjfkoZbcrPUWGweBjzycFOCHfDYcUjLoA/out-0.webp


# 1. Finetuning a text to image model

The first and most important thing to care about when training a custom image generation model is the data. If you train on a bad dataset, it does not matter which model you use or how much compute you throw at the problem: the output model will still not perform the way you want it to.

For image generation, we don't actually need a lot of data to add a new concept or style to the model. As few as 5 images will do, although more is always better; datasets usually contain between 20 and 1000 images. When selecting images, here's what you need to keep in mind:

- Avoid low-quality images, i.e. blurry or low-resolution (<256 px) ones
- Avoid images with extreme aspect ratios (anything more than 2:1, e.g. beyond 1024x512 px)
- Don't worry about getting 4k or super-high-resolution images; they will be downscaled to ~1024 px per side when training
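
The size-related rules in the checklist above can be expressed as a small filter. This is a minimal sketch, not part of the original notebook: the function name `is_usable_image` and its default thresholds are assumptions taken from the numbers in the list (256 px minimum side, 2:1 maximum aspect ratio). It only looks at dimensions, so blurriness still has to be judged by eye.

```python
def is_usable_image(width: int, height: int,
                    min_side: int = 256, max_aspect: float = 2.0) -> bool:
    """Return True if the image dimensions pass the size checklist.

    Hypothetical helper: rejects images whose shorter side is below
    min_side or whose aspect ratio exceeds max_aspect.
    """
    if width <= 0 or height <= 0:
        return False
    if min(width, height) < min_side:
        return False
    aspect = max(width, height) / min(width, height)
    return aspect <= max_aspect


# Example with made-up file names and sizes (width, height)
sizes = {"a.jpg": (1024, 768), "b.jpg": (200, 200), "c.jpg": (3000, 1000)}
usable = [name for name, (w, h) in sizes.items() if is_usable_image(w, h)]
print(usable)
```

In a real dataset the `(width, height)` pairs would come from an image library; they are hard-coded here to keep the sketch dependency-free.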

When training a model, you will typically be teaching it either a person (or other subject) or a new style. Because of this, you will usually include a trigger word that tells the model you are trying to evoke that concept. That way the model keeps its previous understanding of other concepts while gaining a new one. Because we don't want to overwrite existing concepts, the trigger word should be a specific person's name or a "custom" word, e.g. "SUNDAI" or "tr1gg3r w0rd".
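
The training cell in 1.B reads the dataset from a single zip file. One way to produce such an archive from a folder of selected images is sketched below, using only the standard library. The helper name `zip_images`, the accepted file suffixes, and the folder name in the usage comment are assumptions for illustration.

```python
import zipfile
from pathlib import Path

# Assumed set of image extensions; adjust to match your dataset
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp"}


def zip_images(image_dir: str, zip_path: str) -> int:
    """Pack every image file in image_dir into zip_path.

    Files are stored flat (no sub-folders) and the number of
    packed images is returned. Hypothetical helper.
    """
    files = sorted(
        p for p in Path(image_dir).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for path in files:
            archive.write(path, arcname=path.name)
    return len(files)


# Usage (folder name is hypothetical):
# count = zip_images("djeny_photos", "djeny_train_data.zip")
# print(f"Packed {count} images")
```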

In [None]:
#@title 1.A Create the model repository

# Here we are setting up the repository in replicate where the model will go once we have trained it

import replicate
from replicate.exceptions import ReplicateError

# You can see your Replicate username in the top left corner of the site.

# NOTE: we use the sundai-club account name because you logged in via the Sundai org
replicate_username = "sundai-club"
# Name of your fine-tuned model
finetuned_mode_name = "flux-djenydog"

try:
  model = replicate.models.create(
      owner=replicate_username,
      name=finetuned_mode_name,
      visibility="public",  # or "private" if you prefer
      hardware="gpu-t4",  # Replicate will override this for fine-tuned models
      description="A fine-tuned FLUX.1 model",
  )
  print(f"Model created: {model.name}")
except ReplicateError as e:
  # e.detail may be None, so fall back to an empty string before the substring check
  if "already exists" in (e.detail or ""):
    print("Model already exists, loading it.")
    model = replicate.models.get(f"{replicate_username}/{finetuned_mode_name}")
  else:
    raise

print(f"Model URL: https://replicate.com/{model.owner}/{model.name}")

Model already exists, loading it.
Model URL: https://replicate.com/sundai-club/flux-djenydog


In [None]:
#@title 1.B Train the model

# The dataset must be a zip archive of training images. "input_images" takes an open
# local file (used below); a publicly accessible URL such as dataset_url should also work.
dataset_url = "https://huggingface.co/datasets/AMead10/Victor-Perkins/resolve/main/vector.zip"
trigger_word = "DJENNDOG"
steps = 1000

training = replicate.trainings.create(
    version="ostris/flux-dev-lora-trainer:4ffd32160efd92e956d39c5338a9b8fbafca58e03f791f6d8011f3e20e8ea6fa",
    input={
        "input_images": open("djeny_train_data.zip", "rb"),
        "steps": steps,
        "trigger_word": trigger_word,  # the trigger word is a trainer input, so it belongs inside `input`
    },
    destination=f"{model.owner}/{model.name}"
)

print(f"Training started: {training.status}")
print(f"Training URL: https://replicate.com/p/{training.id}")

Training started: starting
Training URL: https://replicate.com/p/d2tmsxf1kdrm80cmdgv8v1r8fw
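
Training takes a while, and the test in 1.C only makes sense once it has finished. A minimal polling sketch follows. The helper `wait_for_status` is hypothetical and takes a callable so it can be tried without network access; the status names are the ones Replicate reports for trainings ("starting", "processing", "succeeded", "failed", "canceled"). The usage comment assumes the `training` object from 1.B and its `reload()` method.

```python
import time
from typing import Callable

TERMINAL_STATUSES = {"succeeded", "failed", "canceled"}


def wait_for_status(get_status: Callable[[], str],
                    poll_seconds: float = 30.0,
                    max_polls: int = 120) -> str:
    """Poll get_status() until it returns a terminal status.

    Gives up after max_polls sleeps and returns the last status seen.
    """
    status = get_status()
    for _ in range(max_polls):
        if status in TERMINAL_STATUSES:
            break
        time.sleep(poll_seconds)
        status = get_status()
    return status


# Usage with the training object from 1.B:
# def current_status() -> str:
#     training.reload()       # refresh the object from the Replicate API
#     return training.status
#
# print(f"Training finished with status: {wait_for_status(current_status)}")
```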


In [None]:
#@title 1.C Test your fine-tuned model

# Run this after the training from 1.B has finished; the trained weights are pushed as a new model version
latest_version = model.versions.list()[0]
output = replicate.run(
    latest_version,
    input={
        "prompt": "DJENNDOG black dog is now wrapped in a small blanket, pretending to be asleep on the couch, still with visible cookie crumbs around its mouth. A confused-looking human mom is standing in the kitchen doorway, looking at the mess. The dog has one eye slightly open, peeking. Cute cartoonish style, warm colors.",
        "num_inference_steps": 28, # typically ~30 for the "dev" model. Fewer steps means faster generation but lower quality
        "guidance_scale": 7.5,     # how closely the model follows the prompt. Try different values to see the effect
        "model": "dev",            # after fine-tuning you can use the "schnell" model to generate images faster. In that case set num_inference_steps=4
    }
)

generated_img_url = str(output[0])
print(f"Generated image URL: {generated_img_url}")
display(Image(url=generated_img_url))

Generated image URL: https://replicate.delivery/xezq/p23v7hOTUYrgIRN5ABwc3IbEbrA4L12cDYeRGArmlgaT14CKA/out-0.webp
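
The inline comments in the cell above describe two regimes: "dev" with roughly 30 steps for quality, and "schnell" with about 4 steps for speed. A small helper can keep the two settings consistent across cells. This is only a sketch; `build_flux_input` is a hypothetical name, and the step counts and default guidance value are the ones used in this notebook rather than recommended values.

```python
def build_flux_input(prompt: str, fast: bool = False,
                     guidance_scale: float = 7.5) -> dict:
    """Build the input dict for the fine-tuned model.

    fast=False -> "dev" with 28 steps (slower, higher quality)
    fast=True  -> "schnell" with 4 steps (faster, lower quality)
    """
    return {
        "prompt": prompt,
        "model": "schnell" if fast else "dev",
        "num_inference_steps": 4 if fast else 28,
        "guidance_scale": guidance_scale,
    }


# Usage with the objects defined earlier in the notebook:
# output = replicate.run(latest_version,
#                        input=build_flux_input("DJENNDOG black dog on a beach", fast=True))
print(build_flux_input("DJENNDOG black dog on a beach", fast=True))
```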


In [None]:
#@title 1.D Faster Generation test

# TODO: play with the parameters of the model to get faster generation of the subject

latest_version = model.versions.list()[0]
output = replicate.run(
    latest_version,
    input={
        "prompt": "DJENNDOG black dog in a drawing vector style as a favicon",
        "num_inference_steps": 8, # "schnell" needs far fewer steps than "dev" (~30). Fewer steps means faster generation but lower quality
        "guidance_scale": 7.5,     # how closely the model follows the prompt. Try different values to see the effect
        "model": "schnell",        # the faster model. Try num_inference_steps=4 for the quickest generation
    }
)

generated_img_url = str(output[0])
print(f"Generated image URL: {generated_img_url}")
display(Image(url=generated_img_url))

Generated image URL: https://replicate.delivery/xezq/zafffNOXVfdSflw0PU5GUS6J1IzAaPXJ80r2qKuG44aBmkugC/out-0.webp



In [None]:
#@title 2. (Extra points) Add LLM calls

!pip install pydantic openai



In [None]:
GITHUB_TOKEN = ""

import json
from pydantic import BaseModel
from openai import OpenAI

client = OpenAI(
    base_url="https://models.inference.ai.azure.com",
    api_key=GITHUB_TOKEN,
)

In [None]:
# Note: Structured Outputs (a response_format with a JSON schema) is a better way of
#       enforcing structured output, but it is still not supported in GitHub Models,
#       so plain JSON mode is used below. If you have an OpenAI key, see these docs:
#         https://platform.openai.com/docs/guides/structured-outputs?lang=python

def generate_comics(user_prompt: str):
    """
    Generate a comic story with image prompts and captions.

    Args:
        user_prompt (str): The story prompt from the user

    Returns:
        dict: The comic data parsed from the model's JSON reply
    """

    # System prompt for consistent output formatting
    system_prompt = """
    Create a 3-panel comic story about a dog's adventure. For each panel, provide:
    1. An image generation prompt that includes 'DJENNDOG black dog' and ends with 'cartoonish style, warm colors'
    2. A caption that refers to the dog as 'Djeny'

    Format the output as JSON with this structure:
    {
        "comics": [
            {
                "prompt": "Image generation prompt here",
                "caption": "Caption text here"
            }
        ]
    }
    """

    response = client.chat.completions.create(
        model="gpt-4o",
        response_format={ "type": "json_object" },
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
    )

    # Parse the JSON reply into a Python dict
    story_json = json.loads(response.choices[0].message.content)
    return story_json

comics = generate_comics("An adventure in a concert")
print(json.dumps(comics, indent=2, ensure_ascii=False))

{
  "comics": [
    {
      "prompt": "DJENNDOG black dog wearing a red bandana, excitedly standing at the edge of a forest, colorful music notes floating in the air, trees swaying gently, old comics style, warm colors",
      "caption": "Djeny’s ears perked up as she heard music drifting through the forest—an adventure was calling!"
    },
    {
      "prompt": "DJENNDOG black dog sneaking through concert grounds, dodging between happy festival-goers and stacks of speakers, food stalls in the background, old comics style, warm colors",
      "caption": "Djeny weaved through the crowd, her nose leading her closer to the tantalizing beats and smells of the concert."
    },
    {
      "prompt": "DJENNDOG black dog dancing on stage under colorful lights, musicians laughing and playing instruments behind her, the crowd cheering joyfully, old comics style, warm colors",
      "caption": "To everyone’s delight, Djeny managed to leap on stage, becoming the star of the show with her wagging t

In [None]:
img_urls = []
for img_description in comics["comics"]:
  output = replicate.run(
    latest_version,
    input={
        "prompt": img_description["prompt"],
        "num_inference_steps": 8,  # "schnell" needs far fewer steps than "dev" (~30). Fewer steps means faster generation but lower quality
        "guidance_scale": 7.5,     # how closely the model follows the prompt. Try different values to see the effect
        "model": "schnell",        # the faster model. Try num_inference_steps=4 for the quickest generation
    }
  )

  generated_img_url = str(output[0])
  img_urls.append(generated_img_url)
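
The loop above reads `comics["comics"]` and each panel's `"prompt"` key directly, so a malformed LLM reply would fail partway through image generation. A minimal validation sketch that could run before the loop is shown here. `extract_panels` is a hypothetical helper written with the standard library only; the same checks could be expressed as pydantic models, which this notebook installs but does not otherwise use.

```python
from typing import Any, List, Tuple


def extract_panels(story: Any, trigger: str = "DJENNDOG") -> List[Tuple[str, str]]:
    """Validate the LLM reply and return (prompt, caption) pairs.

    Raises ValueError if the structure differs from the format requested
    in the system prompt or a prompt lacks the trigger word.
    """
    if not isinstance(story, dict) or not isinstance(story.get("comics"), list):
        raise ValueError('expected an object with a "comics" list')
    panels = []
    for index, panel in enumerate(story["comics"]):
        if not isinstance(panel, dict):
            raise ValueError(f"panel {index} is not an object")
        prompt, caption = panel.get("prompt"), panel.get("caption")
        if not isinstance(prompt, str) or not isinstance(caption, str):
            raise ValueError(f"panel {index} needs string 'prompt' and 'caption' fields")
        if trigger not in prompt:
            raise ValueError(f"panel {index} prompt is missing the trigger word {trigger!r}")
        panels.append((prompt, caption))
    return panels


# Example with a hand-written reply in the expected shape
sample = {"comics": [{"prompt": "DJENNDOG black dog on stage, warm colors",
                      "caption": "Djeny steals the show."}]}
print(extract_panels(sample))
```

Calling `extract_panels(comics)` right after `generate_comics` surfaces a bad reply before any image generation is paid for.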

In [None]:
from IPython.display import HTML

# HTML template for responsive image grid
html_template = '''
<div style="display: flex; flex-wrap: wrap; justify-content: space-around; gap: 20px;">
    {}
</div>
'''

# HTML template for each image and caption
image_template = '''
<div style="flex: 0 1 300px; text-align: center; margin-bottom: 20px;">
    <img src="{}" style="max-width: 100%; height: auto; border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
    <p style="margin-top: 10px; font-style: italic; color: #888;">{}</p>
</div>
'''

# Combine all images and captions, escaping the captions because they are LLM-generated text inserted into HTML
from html import escape

image_elements = []
for url, panel in zip(img_urls, comics["comics"]):
    image_elements.append(image_template.format(url, escape(panel["caption"])))

# Display the final HTML
display(HTML(html_template.format(''.join(image_elements))))