In [None]:
#@markdown ### Mount Google Drive (optional)
#@markdown Mount Drive if you want to save the output there (set the output path below
#@markdown to a folder under `/content/drive/MyDrive`). Untick to skip mounting.
mount_drive = True #@param {type:"boolean"}

from google.colab import drive

# The training image is uploaded directly in a later cell, so Drive is only
# needed when the output should be persisted beyond the Colab session.
if mount_drive:
    drive.mount('/content/drive')

In [None]:
#@markdown ### Clone the B-LoRA repository
!git clone https://github.com/itsitgroup/B-LoRA.git

# Change directory to the cloned repository.
%cd B-LoRA

In [None]:
#@markdown ### Install the required libraries from requirements.txt
!pip install -r requirements.txt

In [None]:
#@markdown ### Upload the Training Image

import os
from google.colab import files

# Function to upload the training image and get its path.
def get_image_path():
    """Prompt the user to upload a training image and return its absolute path.

    `files.upload()` returns a dict keyed by the uploaded file names, with the
    files saved relative to the current working directory. The name is made
    absolute so it stays valid even if a later cell changes directory.

    Returns:
        Absolute path of the first uploaded file.

    Raises:
        ValueError: if no file was uploaded.
    """
    uploaded = files.upload()
    if not uploaded:
        raise ValueError("Please upload a training image.")
    file_name = next(iter(uploaded))
    if len(uploaded) > 1:
        # Only one training image is used; say which one instead of silently
        # ignoring the others.
        print(f"{len(uploaded)} files uploaded; using {file_name!r} as the training image.")
    return os.path.abspath(file_name)

# Get the path to the training image from the user.
image_path = get_image_path()

In [None]:
#@markdown ### Specify the Output Path
#@markdown Enter the path to save the output (leave blank to use default `/content/B-LoRA_output`)

output_path = "" #@param {type:"string"}

# Fall back to the default location when no path is given, then make sure the
# directory that will actually be used exists (a user-supplied path needs to be
# created too, not just the default one).
default_output_path = "/content/B-LoRA_output"
output_path = output_path.strip() or default_output_path
os.makedirs(output_path, exist_ok=True)

# Display the paths for confirmation.
print(f"Training image path: {image_path}")
print(f"Output path: {output_path}")

In [None]:
#@markdown ### Specify Training Parameters
#@markdown Leave blank to use default values.

#@param {type:"string"}
steps = "" #@param {type:"string"}

#@param {type:"string"}
learning_rate = "" #@param {type:"string"}

In [None]:
#@markdown ### Run the Training Script

# Construct the command to run the training script.
command = f"python train_dreambooth_b-lora_sdxl.py --image_path {image_path} --output_dir {output_path}"

if steps.strip():
    command += f" --steps {steps}"
if learning_rate.strip():
    command += f" --lr {learning_rate}"

# Execute the command.
!{command}