In [None]:
# This mounts your Google Drive to the Colab VM.
from google.colab import drive
import os
import sys

drive.mount('/content/drive', force_remount=True)

# TODO: Enter the folder (relative to "My Drive") where you saved this
# project, e.g. 'cs231n/project/'.
FOLDERNAME = "cs231n/project/"
# A truthiness check also rejects an empty string; `is not None` could never
# fail for a string literal.
assert FOLDERNAME, "[!] Enter the foldername."
PROJECT_PATH = f"/content/drive/My Drive/{FOLDERNAME}"
# Fail with an actionable message instead of a bare os.chdir error.
assert os.path.isdir(PROJECT_PATH), f"[!] Folder not found on Drive: {PROJECT_PATH}"

# Make project-local modules (presumably the `lora` / `patch_unet` modules
# imported later) importable. The guard keeps re-runs of this cell from
# appending duplicate entries.
if PROJECT_PATH not in sys.path:
    sys.path.append(PROJECT_PATH)

# Change working directory so relative paths such as "data" resolve here.
os.chdir(PROJECT_PATH)

# Confirm
print("✅ Current working directory:", os.getcwd())
print("📁 Contents:", os.listdir('.'))

# 💡 LoRA Training on Stable Diffusion
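This notebook trains custom LoRA adapters for Stable Diffusion v1.4 on your own image–caption pairs. Put each image (`.jpg` or `.png`) in the `data/` folder, next to a caption file with the same base name and a `.txt` extension (e.g. `001.png` + `001.txt`).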


In [None]:
!pip install -q diffusers transformers accelerate torchvision safetensors

In [None]:
from lora import LoRALinear
from patch_unet import patch_unet_with_lora

## Data Preparation

In [None]:
import io
import os

import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from patch_unet import patch_unet_with_lora

class ImageTextDataset(Dataset):
    """Image-caption dataset built from an in-memory mapping of file name -> bytes.

    NOTE(review): ``ImageTextDataset`` is defined a second time right after
    this class (the directory-based version taking ``root_dir``). That later
    definition shadows this one, so the rest of the notebook never uses this
    class. Consider renaming it (e.g. ``UploadedImageTextDataset``) or
    deleting it.

    Args:
        files: mapping from file name to raw file contents as ``bytes``. Every
            ``.png``/``.jpg`` entry (case-insensitive) whose same-named
            ``.txt`` entry is also present becomes one sample. Presumably the
            dict returned by ``google.colab.files.upload()`` -- TODO confirm.

    Each item is ``(image, caption)``: the RGB image resized to 512x512 and
    normalized with mean/std 0.5, and the UTF-8 caption with surrounding
    whitespace stripped.
    """

    def __init__(self, files):
        # Keep our own reference: __getitem__ previously read a module-level
        # ``files`` global, which fails when that name is missing or rebound.
        self.files = files
        self.samples = []
        for name in files.keys():
            # Case-insensitive, matching the directory-based dataset below.
            if name.lower().endswith((".png", ".jpg")):
                txt = name.rsplit(".", 1)[0] + ".txt"
                if txt in files:
                    self.samples.append((name, txt))
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    def __len__(self):
        """Number of image-caption pairs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image_tensor, caption)`` for sample ``idx``."""
        img_file, txt_file = self.samples[idx]
        # Decode from the stored bytes instead of opening ``img_file`` as a
        # path, so loading does not depend on the current working directory
        # (the setup cell chdirs into the Drive project folder). The ``with``
        # releases the source image once the converted copy exists.
        with Image.open(io.BytesIO(self.files[img_file])) as img:
            image = img.convert("RGB")
        image = self.transform(image)
        caption = self.files[txt_file].decode("utf-8").strip()
        return image, caption


class ImageTextDataset(Dataset):
    """Image-caption pairs read from a directory on disk.

    Every ``.jpg``/``.png`` file (case-insensitive) in ``root_dir`` that has a
    sibling ``<basename>.txt`` caption file becomes one sample. File names are
    sorted, so sample indices are the same on every run and file system
    (``os.listdir`` order is arbitrary).

    Args:
        root_dir: directory containing the images and caption files.
        image_size: side length that images are resized to (square). The
            default of 512 matches the original behaviour.

    Each item is ``(image, caption)``. ``image`` is a float tensor, resized to
    ``image_size x image_size``, mapped to [0, 1] by ``ToTensor`` and then to
    [-1, 1] by ``Normalize([0.5], [0.5])``. ``caption`` is the caption text
    with surrounding whitespace stripped.
    """

    def __init__(self, root_dir="data", image_size=512):
        self.root_dir = root_dir
        self.samples = []
        for fname in sorted(os.listdir(root_dir)):
            if fname.lower().endswith((".jpg", ".png")):
                basename = os.path.splitext(fname)[0]
                txt_path = os.path.join(root_dir, basename + ".txt")
                if os.path.exists(txt_path):
                    self.samples.append((os.path.join(root_dir, fname), txt_path))

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    def __len__(self):
        """Number of image-caption pairs found in ``root_dir``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image_tensor, caption)`` for sample ``idx``."""
        img_path, txt_path = self.samples[idx]
        # ``with`` closes the file handle deterministically; ``convert`` makes
        # an independent in-memory copy first.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        image = self.transform(image)
        # Explicit UTF-8: captions may contain non-ASCII (e.g. Chinese) text,
        # and the platform default encoding is not guaranteed to be UTF-8.
        with open(txt_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return image, caption


In [None]:
DATA_DIR = "data"  # folder with <name>.jpg|.png + <name>.txt pairs

dataset = ImageTextDataset(DATA_DIR)
# sanity check for datasets
print(f"Number of image-caption pairs: {len(dataset)}")
# DataLoader(shuffle=True) on an empty dataset raises an opaque
# "num_samples should be a positive integer" error; fail with a clear one.
assert len(dataset) > 0, (
    f"No image-caption pairs found in '{DATA_DIR}' "
    "(expected <name>.jpg|.png next to <name>.txt)."
)
loader = DataLoader(dataset, batch_size=1, shuffle=True)

In [None]:
# suppress warnings due to missing fonts.
import warnings
import re

# matplotlib emits "Glyph <n> (...) missing from font(s) <font names>." when a
# title contains characters the current font cannot render (e.g. Chinese
# captions). The pattern silences that warning for any font, not only DejaVu
# Sans. `warnings.filterwarnings` matches it case-insensitively against the
# start of the message.
FONT_GLYPH_WARNING_PATTERN = r"Glyph .* missing from font\(s\)"
warnings.filterwarnings("ignore", category=UserWarning, message=FONT_GLYPH_WARNING_PATTERN)


In [None]:
# preview a couple of images.
from matplotlib import pyplot as plt

NUM_PREVIEW = 3  # show up to this many examples

for i in range(min(NUM_PREVIEW, len(dataset))):
    preview_img, preview_caption = dataset[i]
    fig, ax = plt.subplots()
    # Undo Normalize([0.5], [0.5]) to get back to [0, 1]. Clamp because float
    # round-off can push values just outside the range imshow accepts for
    # float RGB data, which triggers a clipping warning.
    ax.imshow((preview_img.permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1).numpy())
    ax.axis("off")
    ax.set_title(preview_caption)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures


## Training Logic for LoRA Fine-Tuning

In [None]:
# Fall back to CPU so the notebook still runs (very slowly) without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_ID = "CompVis/stable-diffusion-v1-4"
LORA_RANK = 4
LORA_ALPHA = 1.0
LEARNING_RATE = 1e-4
SEED = 0

# Seeds LoRA initialisation, DataLoader shuffling, noise and timestep sampling.
torch.manual_seed(SEED)

pipe = StableDiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.float32).to(device)

# Freeze the whole base model; only LoRA parameters are re-enabled below.
pipe.vae.requires_grad_(False)
pipe.text_encoder.requires_grad_(False)
pipe.unet.requires_grad_(False)

# ADD LORA
patch_unet_with_lora(pipe.unet, r=LORA_RANK, alpha=LORA_ALPHA, dropout=0.0)
pipe.unet.to(device)  # Move after patching (new LoRA params may start on the CPU)

# unfreeze lora weights
for module in pipe.unet.modules():
    if "LoRALinear" in str(type(module)):
        for p in module.parameters():
            p.requires_grad = True

trainable = [(name, p) for name, p in pipe.unet.named_parameters() if p.requires_grad]
assert trainable, "No trainable parameters: did patch_unet_with_lora insert any LoRALinear layers?"
n_trainable = sum(p.numel() for _, p in trainable)
n_total = sum(p.numel() for p in pipe.unet.parameters())
print(f"Trainable UNet parameters: {n_trainable:,} / {n_total:,} ({100 * n_trainable / n_total:.3f}%)")

# The save cell keeps only state_dict keys containing "lora", so any other
# trainable tensor would be trained and then silently dropped.
# NOTE(review): if LoRALinear wraps the base nn.Linear as a submodule,
# module.parameters() above also unfreezes the base weights; confirm.
not_saved = [name for name, _ in trainable if "lora" not in name]
if not_saved:
    print(f"⚠️ {len(not_saved)} trainable tensors lack 'lora' in their name and "
          f"will not be saved, e.g. {not_saved[:3]}")

optimizer = torch.optim.Adam([p for _, p in trainable], lr=LEARNING_RATE)


In [None]:
# training
NUM_EPOCHS = 10

# Read model constants from the configs instead of hard-coding them
# (0.18215 and 1000 for SD v1.x).
latent_scale = pipe.vae.config.scaling_factor
num_train_timesteps = pipe.scheduler.config.num_train_timesteps
loss_fn = nn.MSELoss()

# The generation cell calls eval(); switch back so re-training after
# generating uses training-mode behaviour (e.g. LoRA dropout, if enabled).
pipe.unet.train()

for epoch in range(NUM_EPOCHS):
    for i, (images, captions) in enumerate(loader):
        # Keep images in float32
        images = images.to(device)

        # The frozen text encoder and VAE need no autograd graph.
        with torch.no_grad():
            # truncation=True: captions longer than the CLIP context length
            # (tokenizer.model_max_length, 77 for SD v1.4) would otherwise
            # overflow the text encoder's position embeddings.
            input_ids = pipe.tokenizer(
                captions,
                padding="max_length",
                max_length=pipe.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            ).input_ids.to(device)
            # Pass input_ids only. The SD v1.x pipeline does not pass an
            # attention mask to CLIP at inference, so conditioning on masked
            # embeddings here would train on a different embedding than the
            # one used for generation.
            text_embeds = pipe.text_encoder(input_ids).last_hidden_state  # float32
            latents = pipe.vae.encode(images).latent_dist.sample() * latent_scale  # float32

        noise = torch.randn_like(latents)  # float32 noise, already on `device`
        timesteps = torch.randint(0, num_train_timesteps, (latents.shape[0],), device=device).long()

        noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

        noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeds).sample  # float32

        # Epsilon-prediction objective: the UNet predicts the added noise.
        loss = loss_fn(noise_pred, noise)
        # Clear grads before backward, so stale grads (e.g. from an
        # interrupted earlier run of this cell) never leak into the first step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch} Step {i} Loss: {loss.item():.4f}")

In [None]:
# save model parameters
# Keep only the LoRA tensors (keys containing "lora"); the base UNet weights
# are reloaded from the pretrained checkpoint at generation time.
lora_state = {k: v.detach().cpu() for k, v in pipe.unet.state_dict().items() if "lora" in k}
# An empty dict would "save" successfully while silently discarding training.
assert lora_state, "No UNet state_dict keys contain 'lora'; nothing would be saved."
torch.save(lora_state, "lora_weights.pth")
print(f"Saved {len(lora_state)} LoRA tensors to lora_weights.pth")

## Generate Images with the Trained LoRA

In [None]:
# weights_only=True: the file holds only a dict of tensors, so refuse
# arbitrary pickled objects.
state_dict = torch.load("lora_weights.pth", map_location=device, weights_only=True)

# Patch only if the UNet has no LoRA layers yet. In the normal flow the setup
# cell has already patched it. Patching again would presumably wrap the
# existing layers a second time and change the parameter names the saved
# weights are keyed by. NOTE(review): depends on patch_unet_with_lora.
already_patched = any("LoRALinear" in str(type(m)) for m in pipe.unet.modules())
if not already_patched:
    patch_unet_with_lora(pipe.unet, r=4, alpha=1.0)  # must match the training rank/alpha
    pipe.unet.to(device)  # newly created LoRA parameters may start on the CPU

# strict=False is required because the file only holds LoRA tensors, so every
# base weight shows up as "missing". An *unexpected* key, however, means a
# saved LoRA tensor matched nothing and would be silently ignored.
load_result = pipe.unet.load_state_dict(state_dict, strict=False)
assert not load_result.unexpected_keys, (
    f"{len(load_result.unexpected_keys)} saved LoRA tensors did not match the UNet, "
    f"e.g. {load_result.unexpected_keys[:3]}"
)
pipe.unet.eval()

prompt = "米芾"
# A fixed seed makes the generated image reproducible.
generator = torch.Generator(device=device).manual_seed(0)
image = pipe(prompt, num_inference_steps=30, generator=generator).images[0]
image.save("generated.png")
display(image)   # shows inline in the notebook
