In [None]:
# This mounts your Google Drive to the Colab VM.
from google.colab import drive
import os
import sys

# Mount Google Drive inside the Colab VM. force_remount=True re-mounts even
# when an earlier run of this cell already mounted the drive.
drive.mount('/content/drive', force_remount=True)

# TODO: Enter the foldername in your Drive where you have saved the unzipped
# assignment folder, e.g. 'cs231n/assignments/assignment3/'
FOLDERNAME = "cs231n/project/"
assert FOLDERNAME is not None, "[!] Enter the foldername."
PROJECT_PATH = f"/content/drive/My Drive/{FOLDERNAME}"

# Fail early with a readable message instead of a bare chdir error when the
# folder name above does not exist in Drive.
if not os.path.isdir(PROJECT_PATH):
    raise FileNotFoundError(f"[!] Project folder not found in Drive: {PROJECT_PATH}")

# Make modules stored in the project folder importable. Guarded so that
# re-running this cell does not keep appending duplicate sys.path entries.
if PROJECT_PATH not in sys.path:
    sys.path.append(PROJECT_PATH)

# Change working directory so relative paths used later resolve inside the project.
os.chdir(PROJECT_PATH)

# Confirm
print("✅ Current working directory:", os.getcwd())
print("📁 Contents:", os.listdir('.'))

# 💡 LoRA Training on Stable Diffusion
This notebook trains custom LoRA adapters on Stable Diffusion using your own image-caption pairs.


In [None]:
!pip install -q diffusers transformers accelerate torchvision safetensors kornia wandb

## Data Preparation

In [None]:
import os

import torch
from torch.utils.data import Dataset, DataLoader
from dataset import ImageTextDataset

# Build the image-caption dataset from the "data" folder
# (presumably resolved relative to the current working directory — confirm in dataset.py).
dataset = ImageTextDataset("data")

# sanity check for datasets
print(f"Number of image-caption pairs: {len(dataset)}")
# Stop here when nothing was found; otherwise the problem only surfaces later,
# inside the training loop.
assert len(dataset) > 0, "[!] No image-caption pairs found in 'data'."

# One sample per batch, reshuffled every epoch.
loader = DataLoader(dataset, batch_size=1, shuffle=True)

In [None]:
# preview a couple of images.
from matplotlib import pyplot as plt
import warnings
import re

# Silence the "Glyph ... missing from font(s) DejaVu Sans." UserWarning
# (presumably emitted by matplotlib when a title contains characters the font
# lacks — confirm). The `message` regex is matched against the start of the
# warning text.
warnings.filterwarnings("ignore", category=UserWarning, message=r"Glyph .* missing from font\(s\) DejaVu Sans\.")

for i in range(min(2, len(dataset))):  # Show up to 2 examples
    image, caption = dataset[i]
    # permute(1, 2, 0) reorders the axes for imshow (presumably CHW -> HWC — TODO confirm);
    # "* 0.5 + 0.5" assumes the dataset normalised pixels to [-1, 1] — TODO confirm.
    plt.imshow(image.permute(1, 2, 0).numpy() * 0.5 + 0.5)  # Undo normalization
    plt.axis("off")
    plt.title(caption)
    plt.show()


## Training Logic for LoRA Fine-tuning

In [None]:
import getpass
import os

import wandb

# SECURITY: never store an API key in the notebook itself. The key is read from
# the WANDB_API_KEY environment variable and, when that is unset, requested
# interactively so it never ends up in the saved file.
# If a key was ever committed in this notebook's history, revoke it in the
# Weights & Biases account settings.
if not os.environ.get("WANDB_API_KEY"):
    os.environ["WANDB_API_KEY"] = getpass.getpass("Weights & Biases API key: ")
wandb.login()

# Start the run. `config` is metadata recorded with the run; nothing in this
# cell feeds these values back into the training code.
# NOTE(review): the run name mentions 10 epochs, the config says 2 and the
# training loop further down iterates range(3) — reconcile before comparing runs.
# NOTE(review): "augmentations" is False although the training loop calls
# augment(images) — confirm what augment() does.
wandb.init(
    project="stable-diffusion-calligraphy",
    name="米芾_褚遂良1_10epock_agumentNone_EnglishData",  # give each experiment a unique name
    config={
        "lora_rank": 8,
        "lora_alpha": 8,
        "lr": 1e-4,
        "epochs": 2,
        "augmentations": False,
        "conv_lora": True,
    }
)

In [None]:
from lora import LoRALinear, LoRAConv2d
from patch_unet import patch_unet_with_lora, conv_filter
from diffusers import StableDiffusionPipeline
import torch.nn as nn

device = "cuda"

# Load the pretrained Stable Diffusion v1.4 pipeline in float32 on the GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float32,
).to(device)

# Freeze every pretrained sub-model.
for pretrained_part in (pipe.vae, pipe.text_encoder, pipe.unet):
    pretrained_part.requires_grad_(False)

# ADD LORA
patch_unet_with_lora(pipe.unet, r=8, alpha=8, dropout=0.0, conv_filter=None)

# Move again after patching so that modules created by the patch live on `device`.
pipe.unet.to(device)

# Unfreeze the parameters of the LoRA modules.
# NOTE(review): parameters() also yields parameters of any sub-module the LoRA
# classes hold; if they wrap the original layer this would unfreeze the base
# weights too — confirm in lora.py.
lora_layers = [
    layer for layer in pipe.unet.modules()
    if isinstance(layer, (LoRALinear, LoRAConv2d))
]
for layer in lora_layers:
    for weight in layer.parameters():
        weight.requires_grad = True

# Optimise only what is trainable.
trainable_params = [weight for weight in pipe.unet.parameters() if weight.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)


In [None]:
# Disabled alternative LoRA configuration, kept for reference only (nothing in
# this cell executes): rank 4, alpha 1.0, dropout 0.1, linear layers off and
# conv layers on, with a conv_filter selecting modules whose name contains
# "down_blocks" and whose kernel_size is (3, 3).
# patch_unet_with_lora(
#     pipe.unet,
#     r=4,
#     alpha=1.0,
#     dropout=0.1,
#     enable_linear=False,
#     enable_conv=True,
#     conv_filter=lambda name, mod: "down_blocks" in name and mod.kernel_size == (3, 3)
# )

In [None]:
# Loads the CLIP model and processor
from clip import calculate_clip_score
from transformers import CLIPProcessor, CLIPModel

# Checkpoint id shared by the model and the processor so both come from the same release.
clip_model_name = "openai/clip-vit-base-patch32"
# CLIP model, moved to `device`; later cells pass it to calculate_clip_score.
clip_model = CLIPModel.from_pretrained(clip_model_name).to(device)
# Processor for the same checkpoint; passed to calculate_clip_score together with the model.
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)

In [None]:
def generate_and_log_image(pipe, clip_model, clip_processor, device, prompt, epoch):
    """
    Sample one image for ``prompt``, score it with CLIP and log the result to
    Weights & Biases.

    Args:
        pipe: Diffusion pipeline; called as ``pipe(prompt, num_inference_steps=30)``.
        clip_model: CLIP model forwarded to ``calculate_clip_score``.
        clip_processor: CLIP processor forwarded to ``calculate_clip_score``.
        device: Device forwarded to ``calculate_clip_score``.
        prompt (str): Text prompt to generate the image from.
        epoch (int): Current training epoch, stored with the log entry.

    Returns:
        tuple: ``(image, clip_score)`` — the first image returned by the
        pipeline and the value returned by ``calculate_clip_score``.
    """
    # Sampling only, so no autograd bookkeeping is needed.
    with torch.no_grad():
        pipeline_output = pipe(prompt, num_inference_steps=30)
    generated = pipeline_output.images[0]

    # Score how well the generated image matches the prompt.
    score = calculate_clip_score(generated, prompt, clip_model, clip_processor, device)

    # One W&B entry carrying the image, its score and the prompt.
    caption_text = f"{prompt} | score: {score:.3f}"
    log_entry = {
        "epoch": epoch,
        "generated_image": wandb.Image(generated, caption=caption_text),
        "clip_score": score,
        "generation_prompt": prompt,
    }
    wandb.log(log_entry)

    return generated, score

In [None]:
import torch
from augment import augment
from patch_unet import save_lora_weights, load_lora_weights
import torch.nn as nn

NUM_EPOCHS = 3
# Factor applied to the sampled VAE latents (presumably the Stable Diffusion
# VAE scaling factor — confirm against pipe.vae.config).
LATENT_SCALE = 0.18215
# Prompt used for the per-epoch sample image.
GEN_PROMPT = "chinese calligraphy"

# Built once instead of on every step.
loss_fn = nn.MSELoss()
# Upper bound for the random timesteps, read from the scheduler instead of hard-coding it.
num_train_timesteps = pipe.scheduler.config.num_train_timesteps

# training
for epoch in range(NUM_EPOCHS):
    for i, (images, captions) in enumerate(loader):
        # Keep images in float32
        images = images.to(device)
        images = augment(images)

        # truncation=True caps captions at max_length tokens; without it a long
        # caption yields a sequence longer than 77 and the text encoder fails.
        text_input = pipe.tokenizer(
            captions,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt",
        ).to(device)

        # Text encoder and VAE are frozen, so their forward passes need no graph.
        with torch.no_grad():
            text_embeds = pipe.text_encoder(**text_input).last_hidden_state
            latents = pipe.vae.encode(images).latent_dist.sample() * LATENT_SCALE

        # Noise the latents at a random timestep per sample.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, num_train_timesteps, (latents.shape[0],), device=device
        ).long()
        noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

        # The UNet is trained to predict the noise that was added.
        noise_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeds).sample
        loss = loss_fn(noise_pred, noise)

        # Clear gradients first so nothing stale (e.g. from an interrupted
        # earlier run of this cell) leaks into this step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        print(f"Epoch {epoch} Step {i} Loss: {loss_value:.4f}")
        wandb.log({
            "epoch": epoch,
            "step": i,
            "loss": loss_value,
            "lr": optimizer.param_groups[0]["lr"],
        })

    # Log one of the input images (first image of the epoch's last batch, after augment()).
    sample_image = (images[0].detach().cpu().numpy() * 0.5 + 0.5).transpose(1, 2, 0)
    wandb.log({
        "input_image": wandb.Image(sample_image, caption=captions[0])
    })

    # Generate and log output image + clip score
    gen_prompt = GEN_PROMPT
    generate_and_log_image(pipe, clip_model, clip_processor, device, gen_prompt, epoch)


# Persist the trained LoRA weights next to the notebook.
save_lora_weights(pipe.unet, path="lora_weights.pth")

## Generate images with Trained LoRA

In [None]:
import os
import torch
import wandb

from lora import LoRALinear, LoRAConv2d
from patch_unet import patch_unet_with_lora

# Load LoRA weights and patch UNet
state_dict = torch.load("lora_weights.pth", map_location=device)

# Patch only when the UNet has no LoRA layers yet. When the training cells ran
# in this session it is already patched, and patching again is at best
# redundant (depending on patch_unet_with_lora it may wrap layers a second time).
# Rank and alpha must equal the training values (r=8, alpha=8): a different
# rank changes the LoRA tensor shapes, and load_state_dict(strict=False) only
# tolerates missing/unexpected keys, not shape mismatches.
unet_has_lora = any(
    isinstance(layer, (LoRALinear, LoRAConv2d)) for layer in pipe.unet.modules()
)
if not unet_has_lora:
    patch_unet_with_lora(pipe.unet, r=8, alpha=8, dropout=0.0, conv_filter=None)
    pipe.unet.to(device)

# strict=False: the checkpoint is expected to hold only a subset of the UNet's keys.
load_result = pipe.unet.load_state_dict(state_dict, strict=False)
if load_result.unexpected_keys:
    # Keys that matched nothing were silently dropped; make that visible.
    print(f"[!] {len(load_result.unexpected_keys)} checkpoint keys did not match the UNet, "
          f"e.g. {load_result.unexpected_keys[:3]}")
pipe.unet.eval()

# Setup output: folder that receives the generated PNG files.
output_dir = "generated_images"
os.makedirs(output_dir, exist_ok=True)  # exist_ok makes re-running the cell safe

# Chinese prompts as (prompt text, output filename prefix) pairs.
# NOTE(review): not used by the generation loop at the end of this cell, which
# iterates english_prompts only.
chinese_prompts = [
    ("米芾 书法", "米芾_书法_generated"),
    ("褚遂良 书法", "褚遂良_书法_generated"),
    ("书法", "书法_generated")
]

# English prompts as (prompt text, output filename prefix) pairs.
# NOTE(review): these spell "caligraphy", while the prompt sampled during
# training is "chinese calligraphy" — confirm which spelling the training
# captions use, since the prompt text is what the model is conditioned on.
english_prompts = [
    ("mi fu chinese caligraphy", "米芾_caligraphy_generated"),
    ("chu suiliang chinese caligraphy", "褚遂良_caligraphy_generated"),
    ("deng shiru chinese caligraphy", "邓石如_caligraphy_generated"),
    ("chinese caligraphy", "caligraphy_generated")
]


def generate_and_log_images(prompt, filename_prefix, num_images=3):
    """
    Generate several images for one prompt; score, log, save and display each.

    Relies on notebook globals: ``pipe``, ``clip_model``, ``clip_processor``,
    ``device``, ``output_dir`` and ``wandb``.

    Args:
        prompt (str): Text prompt passed to the pipeline.
        filename_prefix (str): Files are written as
            ``<output_dir>/<filename_prefix>_<index>.png``.
        num_images (int): Number of images to generate. Defaults to 3.

    Returns:
        None
    """
    for i in range(num_images):
        image = pipe(prompt, num_inference_steps=30).images[0]
        clip_score = calculate_clip_score(
            image, prompt, clip_model, clip_processor, device
        )

        wandb.log({
            "prompt": prompt,
            "clip_score": clip_score,
            "output image": wandb.Image(image, caption=f"{prompt} | score: {clip_score:.3f}")
        })

        image_path = os.path.join(output_dir, f"{filename_prefix}_{i}.png")
        image.save(image_path)
        # Inline display only: image.show() would try to launch an external
        # viewer, which a headless Colab VM does not have.
        display(image)

# Generate, log and save images for every (prompt, filename prefix) pair in english_prompts.
for prompt_entry in english_prompts:
    prompt_text, filename_prefix = prompt_entry
    generate_and_log_images(prompt_text, filename_prefix)


In [None]:
# Close the Weights & Biases run opened by wandb.init() earlier in the notebook.
wandb.finish()