In [1]:
# %pip install --upgrade torch diffusers

In [2]:
# !pip install -r requirements.txt

In [3]:
# dbutils.library.restartPython()

In [5]:
# change working dir
%cd ..

d:\Projects\Github\UNet Diffusion


In [6]:
import os
import torch
from tqdm import tqdm

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTokenizer, CLIPTextModel

from utils.ema import create_ema_model
from utils.checkpoint import save_training_state, load_training_state
from utils.celeba_with_caption import CelebAloader
from utils.metrics.gpu import init_nvml, gpu_info
from omegaconf import OmegaConf
import lpips

In [7]:
if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        torch.cuda.empty_cache()

torch.manual_seed(99)
handle = init_nvml()

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Enable mixed precision training
scaler = torch.amp.GradScaler('cuda') if device == "cuda" else None
print("Mixed precision training enabled" if scaler is not None else "Mixed precision training disabled")

Using device: cuda
Mixed precision training enabled


In [8]:
# Load configuration
config = OmegaConf.load("configs/train_config_256.yaml")
# config = OmegaConf.load("configs/train_config_512.yaml"
print(f"Configuration loaded: {OmegaConf.to_yaml(config)}")
#==================================================================

# === Load VAE from diffusers ===
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device).eval()

# === Load DiT from diffusers ===
block_out_ch = (256, 512, 1024, 1024)

model = UNet2DConditionModel(
    sample_size=config.model.sample_size,
    in_channels=config.model.in_channels,
    out_channels=config.model.out_channels,
    down_block_types=config.model.down_block_types,
    up_block_types=config.model.up_block_types,
    block_out_channels=config.model.block_out_channels,
    layers_per_block=config.model.layers_per_block,
    cross_attention_dim=config.model.cross_attention_dim,
).to(device)


# === Load noise scheduler from diffusers ===
scheduler = DDPMScheduler(
    num_train_timesteps=config.scheduler.timesteps,
    beta_start=config.scheduler.beta_start,
    beta_end=config.scheduler.beta_end,
    beta_schedule="linear",
)
# === Load CLIP tokenizer and encoder ===
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
clip_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device).eval()

# EMA model
unet_ema_model, ema = create_ema_model(model, beta=config.training.ema_beta, step_start_ema=config.training.step_start_ema)

# Optimizer for Unet model
optimizer = torch.optim.AdamW(model.parameters(), lr=config.training.lr)

# losses
MSE_LOSS_Unet = torch.nn.MSELoss()
LPIPS_LOSS = lpips.LPIPS(net='vgg').to(device).eval()

print("Models, optimizers, losses initialized successfully.")

Configuration loaded: data:
  path: data/CelebA-HQ/images_512
  parquet_path: /dbfs/mnt/ds-space/Hitesh/Datasets/CelebA-HQ/parquet_files/CelebA-HQ.parquet
  image_size: 256
  normalize: true
  caption_path: data/CelebA-HQ/captions.csv
checkpoint:
  path: Checkpoints
  ckpt_name: UNet_ckpt_256.pth
  ema_ckpt_name: UNet_ema_ckpt_256.pth
output_dir:
  train: output/train
  test: output/test
training:
  batch_size: 8
  validation_split: 0.995
  epochs: 100
  warmup_epochs: 10
  lr: 0.0001
  grad_accum_steps: 2
  use_ema: true
  ema_beta: 0.995
  step_start_ema: 2000
  num_workers: 4
sampling:
  dir: output/samples
  num_samples: 25
  steps: 50
model:
  type: unet
  sample_size: 32
  in_channels: 4
  out_channels: 4
  block_out_channels:
  - 320
  - 640
  - 1280
  - 1280
  down_block_types:
  - CrossAttnDownBlock2D
  - CrossAttnDownBlock2D
  - DownBlock2D
  - DownBlock2D
  up_block_types:
  - UpBlock2D
  - UpBlock2D
  - CrossAttnUpBlock2D
  - CrossAttnUpBlock2D
  layers_per_block: 2
  cross



Loading model from: c:\Users\Incognito-R\miniconda3\envs\ml_env\Lib\site-packages\lpips\weights\v0.1\vgg.pth
Models, optimizers, losses initialized successfully.


In [9]:
# === Load data ===
dataloader, _ = CelebAloader(data_config=config.data, train_config=config.training)
print(f"Dataset size: {len(dataloader.dataset)} images")
print(f"Batch size: {dataloader.batch_size}")

# batch = next(iter(dataloader))
# print(f"Batch image shape: {batch['image'].shape}, Batch captions: {len(batch['caption'])}, Batch images path: {len(batch['img_path'])}")
# Dataset size: 30000 images
# Batch image shape: torch.Size([12, 3, 256, 256]), Batch captions: 12, Batch images path: 12

# === Load checkpoint ===
checkpoint_dir = config.checkpoint.path
unet_ckpt_path = os.path.join(checkpoint_dir, config.checkpoint.ckpt_name)
# ckpt_path = "checkpoints/unet_diffusion_ckpt_256.pth"
# ema_ckpt_path = "checkpoints/unet_diffusion_ema_ckpt_256.pth"

start_epoch, best_loss = load_training_state(unet_ckpt_path, model, optimizer, device)
print(f"Resuming training from epoch {start_epoch} with best loss {best_loss:.4f}")

Dataset size: 150 images
Batch size: 8
✅ Resuming from checkpoint: Checkpoints\UNet_ckpt_256.pth


RuntimeError: Error(s) in loading state_dict for UNet2DConditionModel:
	size mismatch for conv_in.weight: copying a param with shape torch.Size([256, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([320, 4, 3, 3]).
	size mismatch for conv_in.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for time_embedding.linear_1.weight: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([1280, 320]).
	size mismatch for time_embedding.linear_1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for time_embedding.linear_2.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for time_embedding.linear_2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.0.attentions.0.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.0.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.0.proj_in.weight: copying a param with shape torch.Size([256, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([320, 320, 1, 1]).
	size mismatch for down_blocks.0.attentions.0.proj_in.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([256, 768]) from checkpoint, the shape in current model is torch.Size([320, 768]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([256, 768]) from checkpoint, the shape in current model is torch.Size([320, 768]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_out.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.norm3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.norm3.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.weight: copying a param with shape torch.Size([2048, 256]) from checkpoint, the shape in current model is torch.Size([2560, 320]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.ff.net.0.proj.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
	size mismatch for down_blocks.0.attentions.0.transformer_blocks.0.ff.net.2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.0.proj_out.weight: copying a param with shape torch.Size([256, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([320, 320, 1, 1]).
	size mismatch for down_blocks.0.attentions.0.proj_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.1.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.1.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.1.proj_in.weight: copying a param with shape torch.Size([256, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([320, 320, 1, 1]).
	size mismatch for down_blocks.0.attentions.1.proj_in.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.attn1.to_out.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([256, 768]) from checkpoint, the shape in current model is torch.Size([320, 768]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([256, 768]) from checkpoint, the shape in current model is torch.Size([320, 768]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.attn2.to_out.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.norm3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.norm3.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.weight: copying a param with shape torch.Size([2048, 256]) from checkpoint, the shape in current model is torch.Size([2560, 320]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.ff.net.0.proj.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
	size mismatch for down_blocks.0.attentions.1.transformer_blocks.0.ff.net.2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.attentions.1.proj_out.weight: copying a param with shape torch.Size([256, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([320, 320, 1, 1]).
	size mismatch for down_blocks.0.attentions.1.proj_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.resnets.0.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.resnets.0.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.resnets.0.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([320, 320, 3, 3]).
	size mismatch for down_blocks.0.resnets.0.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.resnets.0.time_emb_proj.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
	size mismatch for down_blocks.0.resnets.0.time_emb_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.resnets.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.resnets.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.resnets.0.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([320, 320, 3, 3]).
	size mismatch for down_blocks.0.resnets.0.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.resnets.1.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.resnets.1.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.resnets.1.conv1.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([320, 320, 3, 3]).
	size mismatch for down_blocks.0.resnets.1.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.resnets.1.time_emb_proj.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
	size mismatch for down_blocks.0.resnets.1.time_emb_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.resnets.1.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.resnets.1.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.resnets.1.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([320, 320, 3, 3]).
	size mismatch for down_blocks.0.resnets.1.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.0.downsamplers.0.conv.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([320, 320, 3, 3]).
	size mismatch for down_blocks.0.downsamplers.0.conv.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.1.attentions.0.norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.0.norm.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.0.proj_in.weight: copying a param with shape torch.Size([512, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([640, 640, 1, 1]).
	size mismatch for down_blocks.1.attentions.0.proj_in.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_v.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_out.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_q.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([640, 768]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([640, 768]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.attn2.to_out.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.norm3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.norm3.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([5120, 640]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.ff.net.0.proj.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([5120]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.weight: copying a param with shape torch.Size([512, 2048]) from checkpoint, the shape in current model is torch.Size([640, 2560]).
	size mismatch for down_blocks.1.attentions.0.transformer_blocks.0.ff.net.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.0.proj_out.weight: copying a param with shape torch.Size([512, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([640, 640, 1, 1]).
	size mismatch for down_blocks.1.attentions.0.proj_out.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.1.norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.1.norm.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.1.proj_in.weight: copying a param with shape torch.Size([512, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([640, 640, 1, 1]).
	size mismatch for down_blocks.1.attentions.1.proj_in.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_q.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_k.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_v.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.attn1.to_out.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_q.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([640, 768]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([640, 768]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.attn2.to_out.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.norm3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.norm3.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([5120, 640]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.ff.net.0.proj.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([5120]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.weight: copying a param with shape torch.Size([512, 2048]) from checkpoint, the shape in current model is torch.Size([640, 2560]).
	size mismatch for down_blocks.1.attentions.1.transformer_blocks.0.ff.net.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.attentions.1.proj_out.weight: copying a param with shape torch.Size([512, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([640, 640, 1, 1]).
	size mismatch for down_blocks.1.attentions.1.proj_out.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.resnets.0.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.1.resnets.0.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for down_blocks.1.resnets.0.conv1.weight: copying a param with shape torch.Size([512, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([640, 320, 3, 3]).
	size mismatch for down_blocks.1.resnets.0.conv1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.resnets.0.time_emb_proj.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([640, 1280]).
	size mismatch for down_blocks.1.resnets.0.time_emb_proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.resnets.0.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.resnets.0.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.resnets.0.conv2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([640, 640, 3, 3]).
	size mismatch for down_blocks.1.resnets.0.conv2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.resnets.0.conv_shortcut.weight: copying a param with shape torch.Size([512, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([640, 320, 1, 1]).
	size mismatch for down_blocks.1.resnets.0.conv_shortcut.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.resnets.1.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.resnets.1.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.resnets.1.conv1.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([640, 640, 3, 3]).
	size mismatch for down_blocks.1.resnets.1.conv1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.resnets.1.time_emb_proj.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([640, 1280]).
	size mismatch for down_blocks.1.resnets.1.time_emb_proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.resnets.1.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.resnets.1.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.resnets.1.conv2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([640, 640, 3, 3]).
	size mismatch for down_blocks.1.resnets.1.conv2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.1.downsamplers.0.conv.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([640, 640, 3, 3]).
	size mismatch for down_blocks.1.downsamplers.0.conv.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.2.resnets.0.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.2.resnets.0.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for down_blocks.2.resnets.0.conv1.weight: copying a param with shape torch.Size([1024, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 640, 3, 3]).
	size mismatch for down_blocks.2.resnets.0.conv1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.2.resnets.0.time_emb_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for down_blocks.2.resnets.0.time_emb_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.2.resnets.0.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.2.resnets.0.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.2.resnets.0.conv2.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for down_blocks.2.resnets.0.conv2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.2.resnets.0.conv_shortcut.weight: copying a param with shape torch.Size([1024, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([1280, 640, 1, 1]).
	size mismatch for down_blocks.2.resnets.0.conv_shortcut.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.2.resnets.1.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.2.resnets.1.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.2.resnets.1.conv1.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for down_blocks.2.resnets.1.conv1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.2.resnets.1.time_emb_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for down_blocks.2.resnets.1.time_emb_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.2.resnets.1.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.2.resnets.1.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.2.resnets.1.conv2.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for down_blocks.2.resnets.1.conv2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.2.downsamplers.0.conv.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for down_blocks.2.downsamplers.0.conv.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.3.resnets.0.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.3.resnets.0.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.3.resnets.0.conv1.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for down_blocks.3.resnets.0.conv1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.3.resnets.0.time_emb_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for down_blocks.3.resnets.0.time_emb_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.3.resnets.0.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.3.resnets.0.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.3.resnets.0.conv2.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for down_blocks.3.resnets.0.conv2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.3.resnets.1.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.3.resnets.1.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.3.resnets.1.conv1.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for down_blocks.3.resnets.1.conv1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.3.resnets.1.time_emb_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for down_blocks.3.resnets.1.time_emb_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.3.resnets.1.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.3.resnets.1.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for down_blocks.3.resnets.1.conv2.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for down_blocks.3.resnets.1.conv2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.0.norm1.weight: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for up_blocks.0.resnets.0.norm1.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for up_blocks.0.resnets.0.conv1.weight: copying a param with shape torch.Size([1024, 2048, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 2560, 3, 3]).
	size mismatch for up_blocks.0.resnets.0.conv1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.0.time_emb_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for up_blocks.0.resnets.0.time_emb_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.0.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.0.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.0.conv2.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for up_blocks.0.resnets.0.conv2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.0.conv_shortcut.weight: copying a param with shape torch.Size([1024, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([1280, 2560, 1, 1]).
	size mismatch for up_blocks.0.resnets.0.conv_shortcut.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.1.norm1.weight: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for up_blocks.0.resnets.1.norm1.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for up_blocks.0.resnets.1.conv1.weight: copying a param with shape torch.Size([1024, 2048, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 2560, 3, 3]).
	size mismatch for up_blocks.0.resnets.1.conv1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.1.time_emb_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for up_blocks.0.resnets.1.time_emb_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.1.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.1.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.1.conv2.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for up_blocks.0.resnets.1.conv2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.1.conv_shortcut.weight: copying a param with shape torch.Size([1024, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([1280, 2560, 1, 1]).
	size mismatch for up_blocks.0.resnets.1.conv_shortcut.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.2.norm1.weight: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for up_blocks.0.resnets.2.norm1.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for up_blocks.0.resnets.2.conv1.weight: copying a param with shape torch.Size([1024, 2048, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 2560, 3, 3]).
	size mismatch for up_blocks.0.resnets.2.conv1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.2.time_emb_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for up_blocks.0.resnets.2.time_emb_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.2.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.2.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.2.conv2.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for up_blocks.0.resnets.2.conv2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.resnets.2.conv_shortcut.weight: copying a param with shape torch.Size([1024, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([1280, 2560, 1, 1]).
	size mismatch for up_blocks.0.resnets.2.conv_shortcut.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.0.upsamplers.0.conv.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for up_blocks.0.upsamplers.0.conv.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.0.norm1.weight: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for up_blocks.1.resnets.0.norm1.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for up_blocks.1.resnets.0.conv1.weight: copying a param with shape torch.Size([1024, 2048, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 2560, 3, 3]).
	size mismatch for up_blocks.1.resnets.0.conv1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.0.time_emb_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for up_blocks.1.resnets.0.time_emb_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.0.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.0.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.0.conv2.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for up_blocks.1.resnets.0.conv2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.0.conv_shortcut.weight: copying a param with shape torch.Size([1024, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([1280, 2560, 1, 1]).
	size mismatch for up_blocks.1.resnets.0.conv_shortcut.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.1.norm1.weight: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for up_blocks.1.resnets.1.norm1.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for up_blocks.1.resnets.1.conv1.weight: copying a param with shape torch.Size([1024, 2048, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 2560, 3, 3]).
	size mismatch for up_blocks.1.resnets.1.conv1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.1.time_emb_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for up_blocks.1.resnets.1.time_emb_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.1.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.1.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.1.conv2.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for up_blocks.1.resnets.1.conv2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.1.conv_shortcut.weight: copying a param with shape torch.Size([1024, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([1280, 2560, 1, 1]).
	size mismatch for up_blocks.1.resnets.1.conv_shortcut.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.2.norm1.weight: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1920]).
	size mismatch for up_blocks.1.resnets.2.norm1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1920]).
	size mismatch for up_blocks.1.resnets.2.conv1.weight: copying a param with shape torch.Size([1024, 1536, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1920, 3, 3]).
	size mismatch for up_blocks.1.resnets.2.conv1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.2.time_emb_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for up_blocks.1.resnets.2.time_emb_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.2.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.2.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.2.conv2.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for up_blocks.1.resnets.2.conv2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.resnets.2.conv_shortcut.weight: copying a param with shape torch.Size([1024, 1536, 1, 1]) from checkpoint, the shape in current model is torch.Size([1280, 1920, 1, 1]).
	size mismatch for up_blocks.1.resnets.2.conv_shortcut.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.1.upsamplers.0.conv.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for up_blocks.1.upsamplers.0.conv.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.2.attentions.0.norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.0.norm.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.0.proj_in.weight: copying a param with shape torch.Size([512, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([640, 640, 1, 1]).
	size mismatch for up_blocks.2.attentions.0.proj_in.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_q.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_k.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_v.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.attn1.to_out.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_q.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([640, 768]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([640, 768]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.attn2.to_out.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.norm3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.norm3.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([5120, 640]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.ff.net.0.proj.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([5120]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.weight: copying a param with shape torch.Size([512, 2048]) from checkpoint, the shape in current model is torch.Size([640, 2560]).
	size mismatch for up_blocks.2.attentions.0.transformer_blocks.0.ff.net.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.0.proj_out.weight: copying a param with shape torch.Size([512, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([640, 640, 1, 1]).
	size mismatch for up_blocks.2.attentions.0.proj_out.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.1.norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.1.norm.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.1.proj_in.weight: copying a param with shape torch.Size([512, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([640, 640, 1, 1]).
	size mismatch for up_blocks.2.attentions.1.proj_in.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_q.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_k.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_v.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.attn1.to_out.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_q.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([640, 768]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([640, 768]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.attn2.to_out.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.norm3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.norm3.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([5120, 640]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.ff.net.0.proj.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([5120]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.weight: copying a param with shape torch.Size([512, 2048]) from checkpoint, the shape in current model is torch.Size([640, 2560]).
	size mismatch for up_blocks.2.attentions.1.transformer_blocks.0.ff.net.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.1.proj_out.weight: copying a param with shape torch.Size([512, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([640, 640, 1, 1]).
	size mismatch for up_blocks.2.attentions.1.proj_out.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.2.norm.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.2.norm.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.2.proj_in.weight: copying a param with shape torch.Size([512, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([640, 640, 1, 1]).
	size mismatch for up_blocks.2.attentions.2.proj_in.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_q.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_k.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_v.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.attn1.to_out.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_q.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([640, 768]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([512, 768]) from checkpoint, the shape in current model is torch.Size([640, 768]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.weight: copying a param with shape torch.Size([512, 512]) from checkpoint, the shape in current model is torch.Size([640, 640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.attn2.to_out.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.norm3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.norm3.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.weight: copying a param with shape torch.Size([4096, 512]) from checkpoint, the shape in current model is torch.Size([5120, 640]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.ff.net.0.proj.bias: copying a param with shape torch.Size([4096]) from checkpoint, the shape in current model is torch.Size([5120]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.weight: copying a param with shape torch.Size([512, 2048]) from checkpoint, the shape in current model is torch.Size([640, 2560]).
	size mismatch for up_blocks.2.attentions.2.transformer_blocks.0.ff.net.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.attentions.2.proj_out.weight: copying a param with shape torch.Size([512, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([640, 640, 1, 1]).
	size mismatch for up_blocks.2.attentions.2.proj_out.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.0.norm1.weight: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1920]).
	size mismatch for up_blocks.2.resnets.0.norm1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([1920]).
	size mismatch for up_blocks.2.resnets.0.conv1.weight: copying a param with shape torch.Size([512, 1536, 3, 3]) from checkpoint, the shape in current model is torch.Size([640, 1920, 3, 3]).
	size mismatch for up_blocks.2.resnets.0.conv1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.0.time_emb_proj.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([640, 1280]).
	size mismatch for up_blocks.2.resnets.0.time_emb_proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.0.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.0.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.0.conv2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([640, 640, 3, 3]).
	size mismatch for up_blocks.2.resnets.0.conv2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.0.conv_shortcut.weight: copying a param with shape torch.Size([512, 1536, 1, 1]) from checkpoint, the shape in current model is torch.Size([640, 1920, 1, 1]).
	size mismatch for up_blocks.2.resnets.0.conv_shortcut.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.1.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.2.resnets.1.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for up_blocks.2.resnets.1.conv1.weight: copying a param with shape torch.Size([512, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([640, 1280, 3, 3]).
	size mismatch for up_blocks.2.resnets.1.conv1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.1.time_emb_proj.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([640, 1280]).
	size mismatch for up_blocks.2.resnets.1.time_emb_proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.1.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.1.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.1.conv2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([640, 640, 3, 3]).
	size mismatch for up_blocks.2.resnets.1.conv2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.1.conv_shortcut.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([640, 1280, 1, 1]).
	size mismatch for up_blocks.2.resnets.1.conv_shortcut.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.2.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([960]).
	size mismatch for up_blocks.2.resnets.2.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([960]).
	size mismatch for up_blocks.2.resnets.2.conv1.weight: copying a param with shape torch.Size([512, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([640, 960, 3, 3]).
	size mismatch for up_blocks.2.resnets.2.conv1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.2.time_emb_proj.weight: copying a param with shape torch.Size([512, 1024]) from checkpoint, the shape in current model is torch.Size([640, 1280]).
	size mismatch for up_blocks.2.resnets.2.time_emb_proj.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.2.norm2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.2.norm2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.2.conv2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([640, 640, 3, 3]).
	size mismatch for up_blocks.2.resnets.2.conv2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.resnets.2.conv_shortcut.weight: copying a param with shape torch.Size([512, 768, 1, 1]) from checkpoint, the shape in current model is torch.Size([640, 960, 1, 1]).
	size mismatch for up_blocks.2.resnets.2.conv_shortcut.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.2.upsamplers.0.conv.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([640, 640, 3, 3]).
	size mismatch for up_blocks.2.upsamplers.0.conv.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.3.attentions.0.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.0.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.0.proj_in.weight: copying a param with shape torch.Size([256, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([320, 320, 1, 1]).
	size mismatch for up_blocks.3.attentions.0.proj_in.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.attn1.to_out.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([256, 768]) from checkpoint, the shape in current model is torch.Size([320, 768]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([256, 768]) from checkpoint, the shape in current model is torch.Size([320, 768]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.attn2.to_out.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.norm3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.norm3.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.weight: copying a param with shape torch.Size([2048, 256]) from checkpoint, the shape in current model is torch.Size([2560, 320]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.ff.net.0.proj.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
	size mismatch for up_blocks.3.attentions.0.transformer_blocks.0.ff.net.2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.0.proj_out.weight: copying a param with shape torch.Size([256, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([320, 320, 1, 1]).
	size mismatch for up_blocks.3.attentions.0.proj_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.1.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.1.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.1.proj_in.weight: copying a param with shape torch.Size([256, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([320, 320, 1, 1]).
	size mismatch for up_blocks.3.attentions.1.proj_in.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.attn1.to_out.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([256, 768]) from checkpoint, the shape in current model is torch.Size([320, 768]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([256, 768]) from checkpoint, the shape in current model is torch.Size([320, 768]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.attn2.to_out.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.norm3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.norm3.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.weight: copying a param with shape torch.Size([2048, 256]) from checkpoint, the shape in current model is torch.Size([2560, 320]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.ff.net.0.proj.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
	size mismatch for up_blocks.3.attentions.1.transformer_blocks.0.ff.net.2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.1.proj_out.weight: copying a param with shape torch.Size([256, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([320, 320, 1, 1]).
	size mismatch for up_blocks.3.attentions.1.proj_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.2.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.2.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.2.proj_in.weight: copying a param with shape torch.Size([256, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([320, 320, 1, 1]).
	size mismatch for up_blocks.3.attentions.2.proj_in.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.attn1.to_out.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([256, 768]) from checkpoint, the shape in current model is torch.Size([320, 768]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([256, 768]) from checkpoint, the shape in current model is torch.Size([320, 768]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([320, 320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.attn2.to_out.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.norm3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.norm3.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.weight: copying a param with shape torch.Size([2048, 256]) from checkpoint, the shape in current model is torch.Size([2560, 320]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.ff.net.0.proj.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
	size mismatch for up_blocks.3.attentions.2.transformer_blocks.0.ff.net.2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.attentions.2.proj_out.weight: copying a param with shape torch.Size([256, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([320, 320, 1, 1]).
	size mismatch for up_blocks.3.attentions.2.proj_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.0.norm1.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([960]).
	size mismatch for up_blocks.3.resnets.0.norm1.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([960]).
	size mismatch for up_blocks.3.resnets.0.conv1.weight: copying a param with shape torch.Size([256, 768, 3, 3]) from checkpoint, the shape in current model is torch.Size([320, 960, 3, 3]).
	size mismatch for up_blocks.3.resnets.0.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.0.time_emb_proj.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
	size mismatch for up_blocks.3.resnets.0.time_emb_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.0.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([320, 320, 3, 3]).
	size mismatch for up_blocks.3.resnets.0.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.0.conv_shortcut.weight: copying a param with shape torch.Size([256, 768, 1, 1]) from checkpoint, the shape in current model is torch.Size([320, 960, 1, 1]).
	size mismatch for up_blocks.3.resnets.0.conv_shortcut.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.1.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.3.resnets.1.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.3.resnets.1.conv1.weight: copying a param with shape torch.Size([256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([320, 640, 3, 3]).
	size mismatch for up_blocks.3.resnets.1.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.1.time_emb_proj.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
	size mismatch for up_blocks.3.resnets.1.time_emb_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.1.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.1.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.1.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([320, 320, 3, 3]).
	size mismatch for up_blocks.3.resnets.1.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.1.conv_shortcut.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([320, 640, 1, 1]).
	size mismatch for up_blocks.3.resnets.1.conv_shortcut.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.2.norm1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.3.resnets.2.norm1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([640]).
	size mismatch for up_blocks.3.resnets.2.conv1.weight: copying a param with shape torch.Size([256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([320, 640, 3, 3]).
	size mismatch for up_blocks.3.resnets.2.conv1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.2.time_emb_proj.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([320, 1280]).
	size mismatch for up_blocks.3.resnets.2.time_emb_proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.2.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.2.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.2.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([320, 320, 3, 3]).
	size mismatch for up_blocks.3.resnets.2.conv2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for up_blocks.3.resnets.2.conv_shortcut.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([320, 640, 1, 1]).
	size mismatch for up_blocks.3.resnets.2.conv_shortcut.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for mid_block.attentions.0.norm.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.attentions.0.norm.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.attentions.0.proj_in.weight: copying a param with shape torch.Size([1024, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 1, 1]).
	size mismatch for mid_block.attentions.0.proj_in.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.attn1.to_q.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.attn1.to_k.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.attn1.to_v.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.attn1.to_out.0.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.attn2.to_q.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.attn2.to_k.weight: copying a param with shape torch.Size([1024, 768]) from checkpoint, the shape in current model is torch.Size([1280, 768]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.attn2.to_v.weight: copying a param with shape torch.Size([1024, 768]) from checkpoint, the shape in current model is torch.Size([1280, 768]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.attn2.to_out.0.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.norm3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.norm3.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.weight: copying a param with shape torch.Size([8192, 1024]) from checkpoint, the shape in current model is torch.Size([10240, 1280]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.ff.net.0.proj.bias: copying a param with shape torch.Size([8192]) from checkpoint, the shape in current model is torch.Size([10240]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.ff.net.2.weight: copying a param with shape torch.Size([1024, 4096]) from checkpoint, the shape in current model is torch.Size([1280, 5120]).
	size mismatch for mid_block.attentions.0.transformer_blocks.0.ff.net.2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.attentions.0.proj_out.weight: copying a param with shape torch.Size([1024, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 1, 1]).
	size mismatch for mid_block.attentions.0.proj_out.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.resnets.0.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.resnets.0.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.resnets.0.conv1.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for mid_block.resnets.0.conv1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.resnets.0.time_emb_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for mid_block.resnets.0.time_emb_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.resnets.0.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.resnets.0.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.resnets.0.conv2.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for mid_block.resnets.0.conv2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.resnets.1.norm1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.resnets.1.norm1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.resnets.1.conv1.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for mid_block.resnets.1.conv1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.resnets.1.time_emb_proj.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1280, 1280]).
	size mismatch for mid_block.resnets.1.time_emb_proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.resnets.1.norm2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.resnets.1.norm2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for mid_block.resnets.1.conv2.weight: copying a param with shape torch.Size([1024, 1024, 3, 3]) from checkpoint, the shape in current model is torch.Size([1280, 1280, 3, 3]).
	size mismatch for mid_block.resnets.1.conv2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1280]).
	size mismatch for conv_norm_out.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for conv_norm_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([320]).
	size mismatch for conv_out.weight: copying a param with shape torch.Size([4, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([4, 320, 3, 3]).

In [None]:
# ===== Training Loop =====
warmup_ep = config.training.warmup_epochs
for epoch in range(start_epoch, config.training.epochs):
    # ---- Memory Management ----
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    model.train()
    running_loss = 0

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Epoch {epoch+1}/{config.training.epochs}")

    for batch_idx, batch in pbar:
        if batch_idx % config.training.grad_accum_steps == 0:
            optimizer.zero_grad(set_to_none=True)

        images = batch['image'].to(device).float()

        # === Text Conditioning with CLIP ===
        captions = batch["caption"]
        text_inputs = clip_tokenizer(
            captions,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            text_outputs = clip_encoder(**text_inputs)
            text_embeddings = text_outputs.last_hidden_state   # (B, 77, 768)

        # ---- VAE Encoding ----
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * 0.18215

        # ---- Forward Diffusion ----
        t = torch.randint(0, scheduler.config.num_train_timesteps, (latents.size(0),), device=device)

        noise = torch.randn_like(latents)
        x_t = scheduler.add_noise(latents, noise, t)

        # ---- Noise Prediction ----
        with torch.amp.autocast('cuda', enabled=(scaler is not None)):
        
            noise_pred = model(x_t, timestep=t, encoder_hidden_states=text_embeddings).sample # here -----
            
            # ==== Loss Calculation ===
            mse_loss = MSE_LOSS_Unet(noise_pred, noise) / config.training.grad_accum_steps

            # -----
            if epoch+1 < warmup_ep: # No other loss
                lpips_loss = 0.0
                lpips_weight = 0.0

            else: # Compute LPIPS loss or other losses
                pred_x0 = scheduler.step(noise_pred, t[0].item(), x_t).pred_original_sample
                pred_rgb = vae.decode(pred_x0 / 0.18215).sample.clamp(-1, 1)

                lpips_loss = LPIPS_LOSS(pred_rgb, images).mean()
                
                # gradually increase loss weights
                if epoch+1 < 50:
                    lpips_weight = 0.05 * (epoch+1 - warmup_ep) / (30 - warmup_ep)
                else:
                    lpips_weight = 0.1
            # -----
    
            # Total loss
            total_loss= mse_loss + lpips_weight * lpips_loss
            # ======================
            
        loss = total_loss

        # ---- Backward Pass ----
        scaler.scale(loss).backward()  

        # ---- Gradient Accumulation ----
        if (batch_idx + 1) % config.training.grad_accum_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            ema.step_ema(unet_ema_model, model)

        # ---- Progress Tracking ----
        running_loss += loss.item() * config.training.grad_accum_steps
        avg_loss = running_loss / (batch_idx + 1)
        
        if avg_loss < best_loss:
            best_loss = avg_loss
    
        pbar.set_postfix(avg_loss=avg_loss, mem=gpu_info(handle))

    # ---- Checkpointing ----
    save_training_state(
        checkpoint_path=unet_ckpt_path,
        epoch=epoch,
        model=model,
        optimizer=optimizer,
        avg_loss=avg_loss,
        best_loss=best_loss,
    )
    unet_ema_ckpt_path = os.path.join(checkpoint_dir, config.checkpoint.ema_ckpt_name)
    torch.save(unet_ema_model.state_dict(), unet_ema_ckpt_path)
    print(f"Epoch {epoch+1} completed. Avg Loss: {avg_loss:.4f}")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("Training Completed!")

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: c:\Users\Incognito-R\miniconda3\envs\ml_env\Lib\site-packages\lpips\weights\v0.1\vgg.pth
Models, optimizers, losses initialized successfully.
