<a href="https://colab.research.google.com/github/gutsssssss/7150-final/blob/main/diffusion_clean.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torch

# Base checkpoint every component below is loaded from, and the device the
# models are placed on (GPU when one is available, otherwise CPU).
model_id = "stabilityai/stable-diffusion-2-1-base"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")

# The text encoder and VAE are only used for inference, so they go to eval mode.
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
text_encoder = text_encoder.to(device).eval()

vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
vae = vae.to(device).eval()

# The UNet is the model that gets fine-tuned later in the notebook.
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(device)
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

# Freeze the VAE and text-encoder weights so no gradients are tracked for them.
for frozen_module in (vae, text_encoder):
    for p in frozen_module.parameters():
        p.requires_grad = False


In [None]:
from datasets import load_dataset

# Load only the first 500 training examples of the captioned image dataset.
dataset = load_dataset(
    "lambdalabs/naruto-blip-captions",
    split="train[:500]",
)


In [None]:
from torch.utils.data import Dataset
import torchvision.transforms as T
import torch

class NarutoDataset(Dataset):
    """Builds UNet training samples (noisy latent, timestep, text embedding, noise).

    Every item is computed on the fly: the caption is tokenized and run through
    the text encoder, the image is encoded by the VAE, and noise for a randomly
    drawn timestep is added by the noise scheduler.

    Args:
        dataset: Indexable collection whose items expose ``"image"`` (an object
            with a PIL-style ``convert`` method) and ``"text"`` (the caption).
        tokenizer: Callable tokenizer returning an object with ``input_ids``.
        text_encoder: Text encoder; its first output is used as conditioning.
        vae: Autoencoder used to encode images into latents.
        noise_scheduler: Scheduler providing ``config.num_train_timesteps`` and
            ``add_noise``.
        image_size: Side length images are resized to before encoding.
    """

    # Latent scale used when the VAE config does not expose ``scaling_factor``.
    DEFAULT_LATENT_SCALE = 0.18215

    def __init__(self, dataset, tokenizer, text_encoder, vae, noise_scheduler, image_size=256):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.vae = vae
        self.noise_scheduler = noise_scheduler

        self.image_transforms = T.Compose([
            T.Resize((image_size, image_size)),
            T.ToTensor(),                # pixel values in [0, 1]
            T.Normalize([0.5], [0.5]),   # (x - 0.5) / 0.5 -> values in [-1, 1]
        ])

    def __len__(self):
        """Return the number of examples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return one training sample as a dict of CPU tensors.

        Keys: ``noisy_latents``, ``timesteps`` (scalar tensor),
        ``encoder_hidden_states`` and ``target_noise``.
        """
        item = self.dataset[idx]
        # Force three channels so palette/greyscale/RGBA images all produce the
        # same tensor layout for the VAE.
        image = self.image_transforms(item["image"].convert("RGB"))
        caption = item["text"]

        # Tokenize + encode caption -> encoder_hidden_states.
        tokenized = self.tokenizer(
            caption,
            padding="max_length",
            max_length=77,
            return_tensors="pt",
            truncation=True,
        )
        # The ids must live on the text encoder's own device, which is not
        # necessarily the VAE's.
        input_ids = tokenized.input_ids.to(self.text_encoder.device)
        with torch.no_grad():
            encoder_hidden_states = self.text_encoder(input_ids)[0].squeeze(0).cpu()

        # Encode image to latent. The transform pipeline has already mapped the
        # pixels to [-1, 1], so the image is passed through unchanged; rescaling
        # it again would feed the VAE values in [-3, 1].
        image = image.unsqueeze(0).to(self.vae.device)
        latent_scale = getattr(self.vae.config, "scaling_factor", self.DEFAULT_LATENT_SCALE)
        with torch.no_grad():
            latents = self.vae.encode(image).latent_dist.sample() * latent_scale
        latents = latents.squeeze(0).cpu()

        # Add noise for a uniformly drawn training timestep.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (1,),
            dtype=torch.long,
        )
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        return {
            "noisy_latents": noisy_latents,
            "timesteps": timesteps.squeeze(0),
            "encoder_hidden_states": encoder_hidden_states,
            "target_noise": noise,
        }


In [None]:
from torch.utils.data import DataLoader

# Wrap the raw caption dataset so each item becomes a ready-to-train sample.
train_dataset = NarutoDataset(
    dataset=dataset,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    vae=vae,
    noise_scheduler=scheduler,
)

# Shuffled mini-batches of four samples; all other loader options stay at
# their defaults.
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=4)


In [None]:
from torch.amp import autocast, GradScaler
from torch import nn
from tqdm import tqdm
import gc

# Fine-tune every UNet parameter to predict the noise that was added to the
# latents, using a mean-squared-error objective.
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-5)
loss_fn = nn.MSELoss()

# Mixed precision and loss scaling are only enabled when training on CUDA; on
# the CPU fallback both are switched off so the loop runs in plain fp32
# instead of relying on "cuda" settings that do not apply to the device.
use_cuda_amp = device.type == "cuda"
scaler = GradScaler(device.type, enabled=use_cuda_amp)
unet.train()

for epoch in range(80):
    pbar = tqdm(train_dataloader)
    for step, batch in enumerate(pbar):
        noisy_latents = batch["noisy_latents"].to(device)
        timesteps = batch["timesteps"].to(device)
        encoder_hidden_states = batch["encoder_hidden_states"].to(device)
        target = batch["target_noise"].to(device)

        with autocast(device_type=device.type, enabled=use_cuda_amp):
            noise_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states
            ).sample
            loss = loss_fn(noise_pred, target)

        # Scale the loss before backward, then let the scaler drive the
        # optimizer step and update its scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        # Memory cleanup: collect unreachable Python objects first so any
        # tensors they held are freed, then release the cached GPU blocks.
        gc.collect()
        torch.cuda.empty_cache()

        pbar.set_description(f"loss: {loss.item():.4f}")


In [None]:
# Save the fine-tuned UNet weights into a directory relative to the current
# working directory.
save_path = "fine_tuned_unet"
unet.save_pretrained(save_directory=save_path)

In [None]:
import os

# Show what is currently stored at the top level of the Drive folder.
print(os.listdir(path="/content/drive/MyDrive"))

In [None]:
import shutil
import os

# Copy the saved UNet directory from Drive to local disk.
# dirs_exist_ok=True allows copying into a destination that already exists
# (Python 3.8+).
shutil.copytree(
    src="/content/drive/MyDrive/fine_tuned_unet",
    dst="/content/fine_tuned_unet",
    dirs_exist_ok=True,
)

In [None]:
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
import torch
import os

# 1. Set the prompt.
prompt = "A girl wears red skirt"

# 2. Load the fine-tuned UNet directly in fp16, so no fp32 copy is ever placed
#    on the GPU.
fine_tuned_unet = UNet2DConditionModel.from_pretrained(
    "/content/drive/MyDrive/fine_tuned_unet",
    torch_dtype=torch.float16,
)

# 3. Build the Stable Diffusion pipeline (fp16) around the fine-tuned UNet.
#    Passing it as a component means the base checkpoint's UNet weights are
#    not loaded only to be thrown away.
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base",
    unet=fine_tuned_unet,
    torch_dtype=torch.float16,
).to("cuda")

# 4. Run inference to generate the image.
image = pipe(
    prompt,
    num_inference_steps=50,    # more steps give finer detail; 30-50 recommended
    guidance_scale=8           # higher follows the prompt more closely; 7.0-10.0 recommended
).images[0]

# 5. Save the image under a name derived from the prompt. Spaces become
#    underscores and commas are dropped as before; any other character that is
#    not alphanumeric, "_" or "-" is removed so the prompt cannot inject path
#    separators or other characters that are invalid in file names.
stem = prompt.replace(" ", "_").replace(",", "")
stem = "".join(ch for ch in stem if ch.isalnum() or ch in "_-")[:60]
filename = (stem or "image") + ".png"
image.save(filename)


In [None]:
from google.colab import drive
# Mount Google Drive so it is reachable under /content/drive.
drive.mount("/content/drive")

# Write the in-memory UNet to local disk.
unet.save_pretrained(save_directory="/content/fine_tuned_unet")


In [None]:
import os

# Sanity check: confirm the local save directory exists, then list its files.
print(os.path.exists("/content/fine_tuned_unet"))
print(os.listdir(path="/content/fine_tuned_unet"))

In [None]:
import shutil
# Move the locally saved UNet to Drive. This is done as copy-then-delete with
# dirs_exist_ok=True rather than shutil.move: when the Drive directory already
# exists (e.g. this cell is re-run), shutil.move would place the source
# *inside* it as fine_tuned_unet/fine_tuned_unet and leave the old weights at
# the expected path, whereas this overwrites the existing files in place.
shutil.copytree(
    "/content/fine_tuned_unet",
    "/content/drive/MyDrive/fine_tuned_unet",
    dirs_exist_ok=True,
)
shutil.rmtree("/content/fine_tuned_unet")

In [None]:
# List the files of the UNet directory stored on Drive.
print(os.listdir(path="/content/drive/MyDrive/fine_tuned_unet"))

In [None]:
# Copy the UNet directory from Drive back to local disk; an existing
# destination directory is allowed and is copied into.
shutil.copytree(
    src="/content/drive/MyDrive/fine_tuned_unet",
    dst="/content/fine_tuned_unet",
    dirs_exist_ok=True,
)
