# How to use

To run and modify the notebook, go to File -> Save a copy in Drive in the top-left menu.

Useful shortcuts:
- Shift + Enter: run the current cell

Additional Resources:

A more in-depth explanation of fine-tuning can be found [here](https://civitai.com/articles/4/make-your-own-loras-easy-and-free).

In [1]:
!pip install replicate

Collecting replicate
  Downloading replicate-1.0.4-py3-none-any.whl.metadata (29 kB)
Downloading replicate-1.0.4-py3-none-any.whl (48 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/48.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.0/48.0 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: replicate
Successfully installed replicate-1.0.4


Test Run

# Fine-tuning a text-to-image model

The first and most important thing to care about when training a custom image generation model is the data. If you train on a bad dataset, it does not matter which model you use or how much compute you throw at the problem: the resulting model will still not perform the way you want it to.

For image generation, we don't actually need a lot of data to add a new concept or style to the model. As few as 5 images can work, although more images generally help, and datasets usually contain between 20 and 1000 images. When selecting images, keep the following in mind:

- Avoid low-quality images, e.g. blurry images or images with low resolution (<256 px)
- Avoid extreme aspect ratios (anything beyond 2:1; for reference, 1024x512 px is exactly 2:1)
- Don't worry about getting 4K or very high-resolution images; they will be downscaled to ~1024 px per side during training
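The guidelines above can be checked automatically before uploading. This is a minimal sketch, assuming the two thresholds from the list (256 px minimum side, at most 2:1 aspect ratio); the helper names `check_image_size` and `scan_folder` are hypothetical, and `scan_folder` additionally assumes Pillow is installed.

```python
from pathlib import Path

# Thresholds taken from the guidelines above.
MIN_SIDE_PX = 256       # images with a side smaller than this count as low resolution
MAX_ASPECT_RATIO = 2.0  # anything more elongated than 2:1 is flagged


def check_image_size(width: int, height: int) -> list[str]:
    """Return a list of problems for an image of the given size (empty if it looks usable)."""
    problems = []
    if min(width, height) < MIN_SIDE_PX:
        problems.append(f"low resolution ({width}x{height}px)")
    ratio = max(width, height) / min(width, height)
    if ratio > MAX_ASPECT_RATIO:
        problems.append(f"extreme aspect ratio ({ratio:.2f}:1)")
    return problems


def scan_folder(folder: str) -> dict[str, list[str]]:
    """Check every image file in `folder` and report only the ones with problems."""
    from PIL import Image  # lazy import: needs Pillow, and avoids clashing with IPython's Image

    report = {}
    for path in sorted(Path(folder).iterdir()):
        if path.suffix.lower() not in {".jpg", ".jpeg", ".png", ".webp"}:
            continue
        with Image.open(path) as img:
            problems = check_image_size(*img.size)  # img.size is (width, height)
        if problems:
            report[path.name] = problems
    return report
```

For example, `scan_folder("/content/my_images")` (a hypothetical folder) returns only the files worth replacing; the 4K guideline needs no check, since oversized images are simply downscaled.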

When training a model, you will typically be teaching it either a person or a new style. Because of this, you usually include a trigger word that tells the model you are trying to evoke that concept. That way, the model keeps its previous understanding of existing concepts while also having a new one added. Because we don't want to overwrite existing concepts, the trigger word should be a specific person's name or a "custom" word, e.g. "Andrew Mead" or "tr1gg3r w0rd".
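A minimal sketch of how the trigger word is used at generation time: the same word passed as the training trigger has to appear in the prompt, otherwise the model falls back to its existing concepts. The `build_prompt` helper and the "a photo of ..." template are illustrative assumptions, not part of Replicate's API.

```python
def build_prompt(trigger_word: str, description: str) -> str:
    """Prefix the description with the trigger word so the fine-tuned concept is evoked."""
    if not trigger_word.strip():
        raise ValueError("trigger word must not be empty")
    # Hypothetical template; any prompt works as long as it contains the trigger word.
    return f"a photo of {trigger_word}, {description}"


prompt = build_prompt("tr1gg3r w0rd", "playing basketball")
```

Generating once with and once without the trigger word is a quick way to see whether the new concept was learned.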

In [5]:
#@title Setup Replicate

#@markdown To get your Replicate API key, go to [Replicate](https://replicate.com/signin?next=/docs) and sign up. You can then find your API key on the [API tokens page](https://replicate.com/account/api-tokens) and paste it here.

import os
import replicate
from IPython.display import Image, display

# YOUR REPLICATE API KEY
replicate_api_key = "" #@param {type: 'string'}

os.environ["REPLICATE_API_TOKEN"] = replicate_api_key
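Pasting the key into a form field stores it in the notebook itself. As an alternative sketch, the key can be kept in Colab's secrets panel and read with `google.colab.userdata.get`, falling back to an already-set environment variable outside Colab. The helper name `get_replicate_token` and the secret name `REPLICATE_API_TOKEN` are assumptions.

```python
import os


def get_replicate_token(secret_name: str = "REPLICATE_API_TOKEN") -> str:
    """Look up the Replicate token in Colab secrets first, then in the environment."""
    try:
        from google.colab import userdata  # only importable inside Colab
        token = userdata.get(secret_name)
    except Exception:  # not in Colab, secret missing, or notebook access not granted
        token = None
    token = token or os.environ.get(secret_name, "")
    if not token:
        raise RuntimeError(f"No Replicate API token found under {secret_name!r}")
    return token
```

Usage: `os.environ["REPLICATE_API_TOKEN"] = get_replicate_token()` in place of the form field.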

Test Run

In [13]:
output = replicate.run(
    "stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
    input={
        "width": 768,
        "height": 768,
        "prompt": "Asian Indian Basketball player 13 year old playing basketball with multiple actions",
        "refine": "expert_ensemble_refiner",
        "scheduler": "K_EULER",
        "lora_scale": 0.6,
        "num_outputs": 1,
        "guidance_scale": 7.5,
        "apply_watermark": False,
        "high_noise_frac": 0.8,
        "negative_prompt": "",
        "prompt_strength": 0.8,
        "num_inference_steps": 25
    }
)

print(output)

[<replicate.helpers.FileOutput object at 0x7f703024cb20>]


In [14]:
image_url = output[0].url

display(Image(url=image_url))
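The cell above only displays the image by URL. As a minimal sketch, the outputs can also be saved to disk, assuming each item is a replicate `FileOutput` (the object printed earlier), which exposes `.read()` for the raw bytes and `.url` for its location. The helper name `save_outputs`, the `outputs/` folder and the `.png` fallback are illustrative assumptions.

```python
from pathlib import Path
from urllib.parse import urlparse


def save_outputs(outputs, out_dir: str = "outputs", prefix: str = "image") -> list[Path]:
    """Write each output (anything with a .read() method) to out_dir and return the paths."""
    folder = Path(out_dir)
    folder.mkdir(parents=True, exist_ok=True)
    saved = []
    for i, item in enumerate(outputs):
        # Reuse the file extension from the output URL when there is one, else assume .png
        suffix = Path(urlparse(getattr(item, "url", "") or "").path).suffix or ".png"
        path = folder / f"{prefix}_{i}{suffix}"
        path.write_bytes(item.read())
        saved.append(path)
    return saved
```

Usage: `save_outputs(output)` right after `replicate.run(...)`.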

In [16]:
#@title Create the model repository

#@markdown This cell sets up the model repository on Replicate where the model will be stored once it has been trained.

import replicate
from replicate.exceptions import ReplicateError

#@markdown You can see your Replicate username in the top-left corner of the site.
replicate_username = "sundai-club" #@param {type: 'string'}
#@markdown Name of your fine-tuned model
finetuned_mode_name = "kanch-sundai" #@param {type: 'string'}

try:
  model = replicate.models.create(
      owner=replicate_username,
      name=finetuned_mode_name,
      visibility="public",  # or "private" if you prefer
      hardware="gpu-t4",  # Replicate will override this for fine-tuned models
      description="A fine-tuned sdxl model"
  )
  print(f"Model created: {model.name}")
except ReplicateError as e:
  # e.detail can be None, so check it before searching it
  if e.detail and "already exists" in e.detail:
    print("Model already exists, loading it.")
    model = replicate.models.get(f"{replicate_username}/{finetuned_mode_name}")
  else:
    raise

print(f"Model URL: https://replicate.com/{model.owner}/{model.name}")

Model created: kanch-sundai
Model URL: https://replicate.com/sundai-club/kanch-sundai


In [101]:
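import base64

# Path to the zip archive of training images
filename = "/tanay.zip"
with open(filename, "rb") as file_input:
    encoded_data = base64.b64encode(file_input.read())

# Turn the base64 bytes into a string so it can be embedded in a data URL for training
encoded_data_str = encoded_data.decode('utf-8')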
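The cell above assumes the zip already exists. As a minimal sketch, a folder of images can be packed into a flat zip with the standard library before encoding it; the helper name `zip_images`, the accepted extensions, and the example folder are assumptions.

```python
import zipfile
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".webp"}


def zip_images(src_dir: str, zip_path: str) -> int:
    """Pack the images at the top level of src_dir into a flat zip; return how many were added."""
    count = 0
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for path in sorted(Path(src_dir).iterdir()):
            if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES:
                zf.write(path, arcname=path.name)  # store without the folder prefix
                count += 1
    return count
```

For example, `zip_images("/content/tanay_images", "/tanay.zip")`, where the source folder is hypothetical; non-image files such as notes are skipped.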

In [102]:
#@title Train the model

#@markdown The dataset needs to be a zip folder, with
#dataset_url = "https://drive.google.com/file/d/1inrIw_ObaIE6laOifwwxyme-Y5uVtniW/view?usp=drive_link" #@param {type: 'string'}
trigger_word = "tanay" #@param {type: 'string'}
steps = 1000 #@param {type: 'number'}

training = replicate.trainings.create(
    version="stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
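    input={
        "input_images": f"data:application/zip;base64,{encoded_data_str}",
        # The SDXL trainer's step count input is max_train_steps; an unrecognized "steps" key has no effect
        "max_train_steps": steps,
        "use_face_detection_instead": True,
        "token_string": trigger_word,
        "is_lora": True,  # boolean rather than the string "true"
    },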
    destination=f"{model.owner}/{model.name}",
)

print(f"Training started: {training.status}")
print(f"Training URL: https://replicate.com/p/{training.id}")

Training started: starting
Training URL: https://replicate.com/p/fa0kjbbkt9rm80cm71pa3sdsww
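Training runs asynchronously, so the new model version used below only exists once the training has finished. This is a minimal polling sketch, assuming `replicate.trainings.get(id)` returns an object with a `status` field whose terminal values are "succeeded", "failed" and "canceled". The helper name `wait_for_training`, the `fetch` parameter (added so the loop can be tried without an API call) and the timing defaults are assumptions.

```python
import time


def wait_for_training(training_id, fetch=None, poll_seconds=30.0, timeout_seconds=3 * 60 * 60):
    """Poll a training until it reaches a terminal status, then return the latest training object."""
    if fetch is None:
        import replicate  # only needed when talking to the real API
        fetch = replicate.trainings.get
    terminal = {"succeeded", "failed", "canceled"}
    deadline = time.monotonic() + timeout_seconds
    while True:
        training = fetch(training_id)
        print(f"status: {training.status}")
        if training.status in terminal:
            return training
        if time.monotonic() > deadline:
            raise TimeoutError(f"training still {training.status!r} after {timeout_seconds}s")
        time.sleep(poll_seconds)
```

Usage: `finished = wait_for_training(training.id)`, then check `finished.status == "succeeded"` before listing model versions.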


In [108]:
print(model.versions.list()[0].id)

ab48229064407355e46d0b84d2a1ea58c5c62c25b9fa6ae4d75a0ff7da864754


In [109]:
gen_model = f"{model.owner}/{model.name}:{model.versions.list()[0].id}"
output = replicate.run(
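    gen_model,
    input={
        "width": 768,
        "height": 768,
        # Include the trigger word so the fine-tuned concept is actually evoked
        "prompt": f"a photo of {trigger_word}, 13 year old indian asian kid playing basketball",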
        "refine": "expert_ensemble_refiner",
        "scheduler": "K_EULER",
        "lora_scale": 0.6,
        "num_outputs": 1,
        "guidance_scale": 7.5,
        "apply_watermark": False,
        "high_noise_frac": 0.8,
        "negative_prompt": "",
        "prompt_strength": 0.8,
        "num_inference_steps": 25
    }
)

print(output)
image_url = output[0].url
display(Image(url=image_url))

[<replicate.helpers.FileOutput object at 0x7f6ff14e5600>]
