[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nbiish/fine-tune-sdxl-replicate/blob/main/fine-tune-sdxl.ipynb)

# SDXL Fine-tuning on Replicate

In [None]:
!pip install replicate
import os
import replicate
from google.colab import output
output.clear()

Authenticate by setting your token in an environment variable:

In [None]:
# get your token from https://replicate.com/account
from getpass import getpass

REPLICATE_API_TOKEN = getpass()
os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN

## Prepare your training images

The training API expects a zip file containing your training images. A handful of images (5-6) is enough to fine-tune SDXL on a single person, but you might need more if your training subject is more complex or the images are very different. Keep the following guidelines in mind when preparing your images:

- Images can be of yourself, your pet, your favorite stuffed animal, or any unique object. For best results, your images should contain only the subject itself, with a minimum of background noise or other objects.
- Images can be in JPEG or PNG format.
- Dimensions and size don't matter.
- Filenames don't matter.
- Do not use images of other people without their consent.
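
The guidelines above can be turned into a small packaging step. The following is a minimal sketch, assuming your images sit in a single local folder; the helper name `zip_training_images` and the folder name in the usage comment are hypothetical, and placing the files at the root of the archive is an assumption rather than a documented requirement.

```python
import zipfile
from pathlib import Path

# Extensions matching the formats listed above (JPEG or PNG).
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def zip_training_images(image_dir, zip_path="data.zip"):
    """Bundle every JPEG/PNG file in image_dir into one zip; return the image count."""
    image_dir = Path(image_dir)
    images = sorted(
        p for p in image_dir.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )
    if not images:
        raise ValueError(f"No JPEG or PNG images found in {image_dir}")
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as archive:
        for image in images:
            # arcname drops the folder prefix so files sit at the zip root (assumption)
            archive.write(image, arcname=image.name)
    return len(images)


# Usage (hypothetical folder name):
# zip_training_images("my_images", "/content/data.zip")
```

Files with other extensions are skipped rather than rejected, so notes or hidden files in the folder do not end up in the archive.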

## Upload your .zip file to this Colab and rename it data.zip


## Create a model

You also need to create a Replicate model that will be the destination for the trained SDXL version. Go to [replicate.com/create](https://replicate.com/create) to create the model. In the example below we call it `my-name/my-model`.

You can make your model public or private. If your model is private, only you will be able to run it. If your model is public, anyone will be able to run it, but only you will be able to update it.

## Start the training

In [None]:
import requests

# Ask Replicate's experimental upload endpoint for a signed upload URL for data.zip
api_url = "https://dreambooth-api-experimental.replicate.com/v1/upload/data.zip"
headers = {"Authorization": f"Token {REPLICATE_API_TOKEN}"}

response = requests.post(api_url, headers=headers)
response.raise_for_status()  # stop here if the request was rejected

# The response holds the URL to upload to and the URL the file will be served from
upload_info = response.json()
upload_url = upload_info["upload_url"]
serving_url = upload_info["serving_url"]

# Upload data.zip to the signed upload URL
with open("/content/data.zip", "rb") as file:
    upload_response = requests.put(upload_url, data=file, headers={"Content-Type": "application/zip"})
upload_response.raise_for_status()

# Use this URL as `input_images` in the training cells below
print(serving_url)


### Normal SDXL Training (includes clothing)

In [None]:
import replicate

training = replicate.trainings.create(
    version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
    input={
        "input_images": serving_url,  # URL printed by the upload cell
    },
    destination="",  # "yourReplicateName/yourReplicateRepo", the model created above
)

### LoRA with Facetrain (excludes clothing)

In [None]:
import replicate

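training = replicate.trainings.create(
    version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
    input={
        "input_images": serving_url,  # URL printed by the upload cell
        "use_face_detection_instead": True,
    },
    destination="",  # "yourReplicateName/yourReplicateRepo", the model created above
)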

### SDXL Train Style
* Increase the LoRA learning rate (`lora_lr`). A higher rate stops the training from focusing too closely on the details. Experiment with values like 1e-4 or 2e-4; our Barbie fine-tune used 4e-4.

* Use a different `caption_prefix` to refer to a style, for example `In the style of TOK,`.

In [None]:
import replicate

training = replicate.trainings.create(
    version="stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b",
    input={
        "input_images": "https://my-domain/style-images.zip",
        "lora_lr": 2e-4,
        "caption_prefix": 'In the style of TOK,',
    },
    destination=""
)

## Monitor training progress

To follow the progress of the training job, visit [replicate.com/trainings](https://replicate.com/trainings) or inspect the training programmatically:

In [None]:
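training.reload()
print(training.status)
# While the job is running, show the last 10 lines of its logs (if any exist yet)
if training.status == 'processing' and training.logs:
    print("\n".join(training.logs.split("\n")[-10:]))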
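
Re-running the cell above by hand works, but the check can also be looped. The following is a minimal sketch, assuming the `training` object exposes `reload()` and `status` as used above and that the job ends in one of the states `succeeded`, `failed`, or `canceled`; the helper name `wait_for_training` and its default intervals are hypothetical.

```python
import time

# States after which a training no longer changes (assumed set of final states).
TERMINAL_STATUSES = {"succeeded", "failed", "canceled"}


def wait_for_training(training, poll_seconds=30, max_polls=240, sleep=time.sleep):
    """Reload `training` until it reaches a final state, then return that state."""
    for _ in range(max_polls):
        training.reload()
        if training.status in TERMINAL_STATUSES:
            return training.status
        sleep(poll_seconds)
    raise TimeoutError(f"Training still '{training.status}' after {max_polls} polls")


# Usage in this notebook (commented out because it blocks until the job ends):
# print(wait_for_training(training))
```

The `max_polls` cap keeps the cell from running forever if the job stalls; with the defaults it gives up after roughly two hours.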

## Run the model

When the model has finished training, you can run it using the web interface at replicate.com/my-name/my-model or via the API:


In [None]:
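# Once the training has succeeded, training.output["version"] should hold the
# "owner/model:version" identifier of the trained model
destination_repo = training.output["version"]

output = replicate.run(
    destination_repo,
    input={"prompt": "a photo of TOK riding a rainbow unicorn"},
)
print(output)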

The trained concept is named `TOK` by default, but you can change that by setting the `token_string` and `caption_prefix` inputs when you start the training.
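
Because the two inputs have to agree, it can help to build them from one value. The following is a minimal sketch; the helper `build_training_input`, the token `SKS`, the example URL, and the `a photo of {token}, ` prefix template are all hypothetical choices for illustration, not defaults confirmed by this notebook.

```python
def build_training_input(input_images, token="TOK", prefix_template="a photo of {token}, "):
    """Return a training `input` dict that uses the same token in both settings."""
    if not token.strip():
        raise ValueError("token must be a non-empty string")
    return {
        "input_images": input_images,
        "token_string": token,
        # The caption prefix mentions the same token so captions and prompts match
        "caption_prefix": prefix_template.format(token=token),
    }


# Hypothetical URL and token, for illustration only
training_input = build_training_input("https://example.com/data.zip", token="SKS")
print(training_input)
```

The resulting dict can be passed as `input=` to `replicate.trainings.create`, and prompts for the trained model would then use `SKS` in place of `TOK`.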