In [None]:
!git clone https://github.com/explainingai-code/StableDiffusion-PyTorch.git
%cd StableDiffusion-PyTorch

In [None]:
!pip install -r requirements.txt

In [None]:
!mkdir -p models/weights/v0.1
!wget -O models/weights/v0.1/vgg.pth https://github.com/richzhang/PerceptualSimilarity/raw/master/lpips/weights/v0.1/vgg.pth

In [None]:
# Create the dataset module (dataset/doppler_dataset.py) for the Doppler images.
import os

os.makedirs('dataset', exist_ok=True)

# The module source contains non-ASCII text, so it is written explicitly as
# UTF-8 (Python's default source encoding) instead of the platform default.
with open('dataset/doppler_dataset.py', 'w', encoding='utf-8') as f:
    f.write("""import glob
import json
import os
import pickle
import random
import torchvision
from PIL import Image
from tqdm import tqdm
from utils.diffusion_utils import load_latents
from torch.utils.data.dataset import Dataset

class DopplerDataset(Dataset):
    '''Images found under <im_path>/ID_*/ with a seeded train/val split.

    Items are image tensors scaled to [-1, 1], or the stored latent for the
    image path when use_latents is set and latents were found in latent_path.
    '''

    # PIL mode that makes ToTensor() produce the requested number of channels
    CHANNELS_TO_MODE = {1: 'L', 3: 'RGB', 4: 'RGBA'}

    def __init__(self, split, im_path, im_size, im_channels,
                 use_latents=False, latent_path=None, condition_config=None,
                 train_ratio=0.8, seed=42):
        self.split = split
        self.im_size = im_size
        self.im_channels = im_channels
        self.latent_maps = None
        self.use_latents = False
        self.condition_types = [] if condition_config is None else condition_config['condition_types']
        self.image_to_class = {}

        # 加载所有图像路径
        all_images = self.load_nested_images(im_path)

        # 划分数据集(8:2)
        self.images = self.split_dataset(all_images, train_ratio, seed)

        # 加载潜在表示
        if use_latents and latent_path is not None:
            latent_maps = load_latents(latent_path)
            if len(latent_maps) > 0:
                self.use_latents = True
                self.latent_maps = latent_maps
                print(f'找到{len(self.latent_maps)}个潜在表示')
            else:
                print('未找到潜在表示')

    def load_nested_images(self, dataset_root):
        assert os.path.exists(dataset_root), f"数据集目录{dataset_root}不存在"
        all_images = []

        # 获取所有ID子目录
        id_folders = [f for f in os.listdir(dataset_root) if os.path.isdir(os.path.join(dataset_root, f)) and f.startswith('ID_')]
        id_folders.sort()

        # 构建ID到类别索引的映射
        id_to_class = {folder: idx for idx, folder in enumerate(id_folders)}

        for id_folder in tqdm(id_folders):
            folder_path = os.path.join(dataset_root, id_folder)

            # 支持多种图像格式
            for ext in ['*.png', '*.jpg', '*.jpeg']:
                # Sorted: glob order depends on the filesystem, and the seeded
                # shuffle in split_dataset only reproduces the same split in
                # every process when it starts from the same ordering.
                image_paths = sorted(glob.glob(os.path.join(folder_path, ext)))

                for img_path in image_paths:
                    self.image_to_class[img_path] = id_to_class[id_folder]
                    all_images.append(img_path)

        print(f'共找到{len(all_images)}张图像，来自{len(id_folders)}个用户ID')

        # 保存类别映射信息 - 修改为使用Kaggle可写目录
        os.makedirs(os.path.join('/kaggle/working', 'metadata'), exist_ok=True)
        with open(os.path.join('/kaggle/working', 'metadata', 'class_mapping.json'), 'w') as f:
            json.dump({
                'id_to_class': id_to_class,
                'num_classes': len(id_folders)
            }, f)

        return all_images

    def split_dataset(self, all_images, train_ratio, seed):
        # Private RNG on a copy of the list: gives the same permutation as
        # random.seed(seed) + random.shuffle, without re-seeding the
        # process-wide generator or reordering the caller's list.
        all_images = list(all_images)
        random.Random(seed).shuffle(all_images)

        train_size = int(len(all_images) * train_ratio)

        if self.split == 'train':
            images = all_images[:train_size]
        else:  # val
            images = all_images[train_size:]

        print(f'{self.split}集包含{len(images)}张图像')

        if self.split == 'train':
            split_info = {
                'train': all_images[:train_size],
                'val': all_images[train_size:]
            }

            # 修改为使用Kaggle可写目录
            os.makedirs(os.path.join('/kaggle/working', 'metadata'), exist_ok=True)
            with open(os.path.join('/kaggle/working', 'metadata', 'split_info.pkl'), 'wb') as f:
                pickle.dump(split_info, f)

            print(f"数据集划分完成: 训练集={len(split_info['train'])}张, 验证集={len(split_info['val'])}张")

        return images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        if self.use_latents:
            latent = self.latent_maps[self.images[index]]
            return latent
        else:
            # The context manager closes the file handle. convert() forces the
            # channel count given by im_channels; RGBA, palette or greyscale
            # files would otherwise produce tensors with a different depth.
            with Image.open(self.images[index]) as im:
                mode = self.CHANNELS_TO_MODE.get(self.im_channels)
                if mode is not None and im.mode != mode:
                    im = im.convert(mode)
                im = im.resize((self.im_size, self.im_size))
                im_tensor = torchvision.transforms.ToTensor()(im)

            # 转换为-1到1范围
            im_tensor = (2 * im_tensor) - 1
            return im_tensor
""")

print("数据集类文件已创建")

In [None]:
# Write the experiment configuration read by the tools.* scripts.
# `dataset_params` is defined exactly once: the original text repeated the
# section, and duplicate mapping keys are not valid YAML (lenient loaders keep
# only the last occurrence, strict ones reject the file).
# The text contains non-ASCII comments, hence the explicit UTF-8 encoding.
with open('config/doppler.yaml', 'w', encoding='utf-8') as f:
    f.write("""dataset_params:
  im_path: '/kaggle/input/dataset'
  im_channels: 3
  im_size: 256  # 保持原分辨率
  name: 'doppler'

diffusion_params:
  num_timesteps: 1000
  beta_start: 0.0015
  beta_end: 0.0195

ldm_params:
  down_channels: [128, 256, 256, 256]  # 调整为能被8整除
  mid_channels: [256, 256]  # 调整为能被8整除
  down_sample: [True, True, False]
  attn_down: [False, False, True]
  time_emb_dim: 256  # 调整为与通道数匹配
  norm_channels: 64  # 调整为能被8整除
  num_heads: 8  # 从4增加到8
  conv_out_channels: 128  # 调整为能被8整除
  num_down_layers: 1
  num_mid_layers: 1
  num_up_layers: 1

autoencoder_params:
  z_channels: 3
  codebook_size: 512  # 已提高到512
  down_channels: [64, 128, 128]  # 调整为能被8整除
  mid_channels: [128, 128]  # 调整为能被8整除
  down_sample: [True, True]
  attn_down: [False, False]
  norm_channels: 64  # 调整为能被8整除
  num_heads: 8  # 从4增加到8
  num_down_layers: 1
  num_mid_layers: 1
  num_up_layers: 1

train_params:
  # 其他参数保持不变
  seed: 1111
  task_name: '/kaggle/working/doppler'
  ldm_batch_size: 4
  autoencoder_batch_size: 4
  disc_start: 500
  disc_weight: 0.8
  codebook_weight: 2.0
  commitment_beta: 0.5
  perceptual_weight: 1.0
  kl_weight: 0.0
  ldm_epochs: 120
  autoencoder_epochs: 40
  num_samples: 8
  num_grid_rows: 2
  ldm_lr: 0.00001
  autoencoder_lr: 0.0001
  autoencoder_acc_steps: 8
  autoencoder_img_save_steps: 8
  save_latents: True
  vae_latent_dir_name: 'vae_latents'
  vqvae_latent_dir_name: 'vqvae_latents'
  ldm_ckpt_name: 'ddpm_ckpt.pth'
  vqvae_autoencoder_ckpt_name: 'vqvae_autoencoder_ckpt.pth'
  vae_autoencoder_ckpt_name: 'vae_autoencoder_ckpt.pth'
  vqvae_discriminator_ckpt_name: 'vqvae_discriminator_ckpt.pth'
  vae_discriminator_ckpt_name: 'vae_discriminator_ckpt.pth'
""")

In [None]:
# Register DopplerDataset in the scripts that look a dataset class up by name.
# The scripts are patched by plain text substitution, one loop for all files.
#
# Each patch is (anchor expected in the script, text that replaces it).
# The replacement still contains the anchor, so "already applied" is detected
# by looking for the replacement text first; this keeps the cell idempotent
# (re-running it no longer stacks duplicate imports / dict entries).
DOPPLER_PATCHES = [
    # extra entry in the dataset-name -> class dictionary
    ("'celebhq': CelebDataset,",
     "'celebhq': CelebDataset,\n        'doppler': DopplerDataset,"),
    # matching import
    ("from dataset.mnist_dataset import MnistDataset",
     "from dataset.mnist_dataset import MnistDataset\nfrom dataset.doppler_dataset import DopplerDataset"),
]
SCRIPTS_TO_PATCH = [
    'tools/train_vqvae.py',
    'tools/train_ddpm_vqvae.py',
    'tools/infer_vqvae.py',
]

for file_path in SCRIPTS_TO_PATCH:
    with open(file_path, 'r', encoding='utf-8') as file:
        content = file.read()

    for anchor, patched in DOPPLER_PATCHES:
        if patched in content:
            continue  # this patch is already present in the file
        if anchor not in content:
            # A silent no-op here would leave the dataset unregistered and
            # only surface later, inside the training script.
            raise RuntimeError(
                f"Cannot patch {file_path}: expected text not found: {anchor!r}"
            )
        content = content.replace(anchor, patched)

    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(content)

print("已修改所有训练脚本")

In [None]:
# 准备工作目录
!mkdir -p /kaggle/working/doppler
!mkdir -p /kaggle/working/metadata
#1. 训练VQ-VAE自编码器
!python -m tools.train_vqvae --config config/doppler.yaml

In [None]:
# Delete everything matched by the glob inside the VQ-VAE latent directory
# (task_name + vqvae_latent_dir_name from config/doppler.yaml), so latents
# left over from an earlier run are removed before new ones are generated.
# NOTE(review): presumably tools.infer_vqvae expects this directory to be
# empty -- confirm against the script.
!rm -rf /kaggle/working/doppler/vqvae_latents/*

In [None]:
# 2. Generate the latent representations
# Runs tools.infer_vqvae as a module from the repository root with the Doppler
# config (which sets save_latents: True).
!python -m tools.infer_vqvae --config config/doppler.yaml

In [None]:
# 3. Train the LDM
# Runs tools.train_ddpm_vqvae as a module from the repository root with the
# Doppler config.
!python -m tools.train_ddpm_vqvae --config config/doppler.yaml

In [None]:
# 4. Generate samples
# Runs tools.sample_ddpm_vqvae as a module from the repository root with the
# Doppler config.
!python -m tools.sample_ddpm_vqvae --config config/doppler.yaml