In [34]:
import os
import sys
import random
import argparse
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
import torch 

from mvdream.camera_utils import get_camera
from mvdream.ldm.util import instantiate_from_config
from mvdream.ldm.models.diffusion.ddim import DDIMSampler
from mvdream.model_zoo import build_model
from diffusers.models.attention_processor import LoRAAttnProcessor

from peft import LoraConfig,get_peft_model

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    
    Args:
    model: The model to inspect.
    """
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {trainable_params}")
    print(f"Total parameters: {total_params}")
    print(f"Percentage of trainable parameters: {100 * trainable_params / total_params:.2f}%")

In [7]:

args_dict = {'model_name': 'sd-v2.1-base-4view', 'config_path': None, 'ckpt_path': None, 'text': 'an astronaut riding a horse', 'suffix': ', 3d asset', 'size': 256, 'num_frames': 4, 'use_camera': 1, 'camera_elev': 15, 'camera_azim': 90, 'camera_azim_span': 360, 'seed': 23, 'fp16': False, 'device': 'cuda:2'}
args = argparse.Namespace(**args_dict)
dtype = torch.float16 if args.fp16 else torch.float32
device = args.device
batch_size = max(4, args.num_frames)

In [8]:
print("load t2i model ... ")
if args.config_path is None:
    model = build_model(args.model_name, ckpt_path=args.ckpt_path)
else:
    assert args.ckpt_path is not None, "ckpt_path must be specified!"
    config = OmegaConf.load(args.config_path)
    model = instantiate_from_config(config.model)
    model.load_state_dict(torch.load(args.ckpt_path, map_location=device))
model.device = device
model.to(device)
model.eval()

import re
model_modules = str(model.modules)


load t2i model ... 
Loading model from config: sd-v2-base.yaml
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention

In [44]:
model.named_parameters

<bound method Module.named_parameters of LatentDiffusionInterface(
  (model): DiffusionWrapper(
    (diffusion_model): MultiViewUNetModel(
      (time_embed): Sequential(
        (0): Linear(in_features=320, out_features=1280, bias=True)
        (1): SiLU()
        (2): lora.Linear(
          (base_layer): Linear(in_features=1280, out_features=1280, bias=True)
          (lora_dropout): ModuleDict(
            (default): Dropout(p=0.1, inplace=False)
          )
          (lora_A): ModuleDict(
            (default): Linear(in_features=1280, out_features=32, bias=False)
          )
          (lora_B): ModuleDict(
            (default): Linear(in_features=32, out_features=1280, bias=False)
          )
          (lora_embedding_A): ParameterDict()
          (lora_embedding_B): ParameterDict()
        )
      )
      (camera_embed): Sequential(
        (0): Linear(in_features=16, out_features=1280, bias=True)
        (1): SiLU()
        (2): lora.Linear(
          (base_layer): Linear(in_fe

In [46]:
config = LoraConfig(
r=20,
lora_alpha=32,
# target_modules=['proj_out', 'to_k', 'c_proj', 'to_q', 'proj_in', 'to_v', 'c_fc', 'proj'],
lora_dropout=0.1,
bias="lora_only",
modules_to_save=["decode_head"],
)
lora_model = get_peft_model(model, config)
print_trainable_parameters(lora_model)

Trainable parameters: 11870144
Total parameters: 1316877228
Percentage of trainable parameters: 0.90%
