In [None]:
!pip install -q diffusers==0.19.3 transformers accelerate safetensors huggingface_hub


# Mount data from Google Drive

In [None]:
import os

from google.colab import drive
drive.mount('/content/drive')  # authorize when prompted

# Dataset root: adjust if your data is stored elsewhere in Drive.
DATA_ROOT = "/content/drive/MyDrive/sd_dataset"   # <-- change this path
# Or, if you uploaded the data directly into the Colab session instead of Drive:
# DATA_ROOT = "/content/dataset"

# Expected layout under DATA_ROOT:
#   train/images/   training images
#   captions.json   captions file passed to the training script
TRAIN_IMAGES = f"{DATA_ROOT}/train/images"
CAPTIONS_JSON = f"{DATA_ROOT}/captions.json"

print("Train images folder:", TRAIN_IMAGES)
print("Captions JSON:", CAPTIONS_JSON)

# Fail fast here, with a clear message, instead of deep inside the training run.
if not os.path.isdir(TRAIN_IMAGES):
    raise FileNotFoundError(f"Train images folder not found: {TRAIN_IMAGES}")
if not os.path.isfile(CAPTIONS_JSON):
    raise FileNotFoundError(f"Captions JSON not found: {CAPTIONS_JSON}")


# Log in to Hugging Face

In [None]:
from getpass import getpass

from huggingface_hub import login

# getpass hides the token while typing and keeps it out of the saved cell output;
# input() would echo the secret in plain text into the notebook.
hf_token = getpass("Paste your Hugging Face token (read access): ").strip()
if not hf_token:
    raise ValueError("No Hugging Face token entered.")
login(hf_token)


# Run the LoRA training script

In [None]:
# Launch LoRA fine-tuning. IPython expands $TRAIN_IMAGES / $CAPTIONS_JSON from the Python
# variables defined in the data-mounting cell, so run that cell first.
# NOTE(review): train_lora_colab.py is not part of this notebook; it must already exist in the
# current working directory (e.g. uploaded to /content) -- confirm before running.
# --merge_lora presumably writes the merged UNet to <output_dir>/merged_unet, which the
# inference cell below loads -- confirm against the script.
!python train_lora_colab.py \
  --images_dir "$TRAIN_IMAGES" \
  --captions_json "$CAPTIONS_JSON" \
  --output_dir "/content/outputs/lora_event" \
  --resolution 512 \
  --train_batch_size 1 \
  --learning_rate 1e-4 \
  --max_train_steps 1200 \
  --lora_rank 4 \
  --lora_alpha 16.0 \
  --save_every 300 \
  --log_every 20 \
  --merge_lora


# Inference: load the merged UNet into the base pipeline and generate an image

In [None]:
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTokenizer, CLIPTextModel
import torch
from PIL import Image
import os

device = "cuda" if torch.cuda.is_available() else "cpu"
# fp16 on GPU to save memory; fp32 on CPU, where fp16 kernels are poorly supported.
dtype = torch.float16 if device == "cuda" else torch.float32
# The original runwayml/stable-diffusion-v1-5 repo was removed from the Hub;
# this is the official mirror with the same SD 1.5 weights.
base_model = "stable-diffusion-v1-5/stable-diffusion-v1-5"
merged_unet_dir = "/content/outputs/lora_event/merged_unet"  # produced by --merge_lora

if not os.path.isdir(merged_unet_dir):
    raise FileNotFoundError(
        f"Merged UNet not found at {merged_unet_dir!r}; run the training cell with --merge_lora first."
    )

# Load the merged UNet in the SAME dtype as the rest of the pipeline. Without torch_dtype it
# loads as fp32 while the fp16 text encoder produces fp16 embeddings/latents, which fails on
# CUDA with a Half-vs-Float dtype mismatch inside the UNet.
merged_unet = UNet2DConditionModel.from_pretrained(merged_unet_dir, torch_dtype=dtype)

# Passing unet= overrides that component, so the base UNet is never loaded just to be discarded.
pipe = StableDiffusionPipeline.from_pretrained(base_model, unet=merged_unet, torch_dtype=dtype)
pipe = pipe.to(device)

prompt = "A modern tech conference with people networking in a large exhibition hall, cinematic lighting"
generator = torch.Generator(device=device).manual_seed(42)  # fixed seed for reproducible output
out = pipe(prompt, num_inference_steps=25, guidance_scale=7.5, generator=generator, height=512, width=512)
img = out.images[0]
out_path = "/content/lora_inference.png"
img.save(out_path)
print("Saved inference image to:", out_path)
img  # display the generated image inline
