In [None]:
# VAE with LoRA fine-tuning script in notebook format

# ----------------------------------------------------------------------------
# This script demonstrates how to fine-tune a VAE with LoRA on the MNIST dataset
# ----------------------------------------------------------------------------

# Import libraries
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

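# Define VAE with LoRA
# (The VAE architecture and LoRA injection are defined in an earlier cell and are
# not reproduced here. This cell expects `VAEWithLoRA`, which exposes its adapters
# as a `.lora` submodule, and `compute_loss(reconstructed, imgs)` to already be in scope.)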

# Prepare dataset: the MNIST training split, converted to tensors and normalized
# with mean 0.5 / std 0.5, which maps pixel values from [0, 1] to [-1, 1].
# NOTE(review): confirm `compute_loss` expects targets in [-1, 1]; a BCE-style
# reconstruction loss would need the un-normalized [0, 1] range instead.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = datasets.MNIST(root='mnist_data', train=True, transform=transform, download=True)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Use the GPU when one is available, otherwise fall back to CPU instead of
# crashing on a hard-coded 'cuda' device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize VAE and LoRA.
# TODO: replace the `...` placeholders with the layer sizes VAEWithLoRA actually
# expects (its constructor is defined elsewhere and not shown here).
vae = VAEWithLoRA(in_features=..., out_features=...)
# The model must live on the same device as the input batches; without this the
# forward pass fails with a device-mismatch error once batches are on the GPU.
vae.to(device)
vae.train()

# Only the trainable parameters of the `lora` submodule are handed to the
# optimizer, so optimizer.step() updates the adapters and nothing else.
optimizer = optim.Adam([p for p in vae.lora.parameters() if p.requires_grad], lr=1e-4)

# Train
NUM_EPOCHS = 10
for epoch in range(NUM_EPOCHS):
    # Accumulate on-device and detached so logging does not force a host sync
    # (or keep the autograd graph alive) on every step.
    running_loss = torch.zeros((), device=device)
    for imgs, _ in loader:
        imgs = imgs.to(device)
        reconstructed = vae(imgs)
        loss = compute_loss(reconstructed, imgs)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.detach()
    mean_loss = running_loss.item() / max(len(loader), 1)
    print(f'epoch {epoch + 1}/{NUM_EPOCHS}: mean loss {mean_loss:.4f}')

# Save the trained model: every parameter/buffer registered on `vae`, which
# includes the `lora` submodule's weights. Tensors keep their device, so pass
# `map_location` to torch.load when reloading on a machine without the same GPU.
torch.save(vae.state_dict(), 'vae_lora_trained.pt')
