In [21]:
import os
import sys
import random
import argparse
from PIL import Image
import numpy as np
from omegaconf import OmegaConf
import torch 

from mvdream.camera_utils import get_camera
from mvdream.ldm.util import instantiate_from_config
from mvdream.ldm.models.diffusion.ddim import DDIMSampler
from mvdream.model_zoo import build_model
from diffusers.models.attention_processor import LoRAAttnProcessor

from peft import LoraConfig,get_peft_model

def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's legacy global
    generator, PyTorch's default generator and the generators of all CUDA
    devices, in that order.

    Args:
        seed: Integer seed forwarded unchanged to each library.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def print_trainable_parameters(model):
    """Print parameter-count statistics for ``model``.

    Five lines are written to stdout:

    * trainable parameters that live inside ``torch.nn.Linear`` modules,
    * all trainable parameters,
    * trainable parameters as a percentage of the Linear trainable count,
    * all parameters (trainable or not),
    * trainable parameters as a percentage of all parameters.

    Args:
        model: A ``torch.nn.Module`` (for example the object returned by
            ``get_peft_model``).

    Notes:
        * ``torch.nn`` is reached through the top-level ``torch`` import so the
          function does not depend on a later cell having run
          ``import torch.nn as nn``.
        * Each parameter tensor is counted once, even when a Linear module
          contains further Linear sub-modules (their parameters would otherwise
          be reported by both the parent and the child).
        * A percentage whose denominator is zero is printed as ``nan`` instead
          of raising ``ZeroDivisionError``.
    """
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())

    counted_ids = set()  # ids of parameter tensors already added to linear_params
    linear_params = 0
    for module in model.modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        for param in module.parameters():
            if param.requires_grad and id(param) not in counted_ids:
                counted_ids.add(id(param))
                linear_params += param.numel()

    def _percent(part, whole):
        # Undefined ratio (empty denominator) is reported as nan rather than crashing.
        return 100 * part / whole if whole else float("nan")

    print(f"Linear Trainable parameters: {linear_params}")
    print(f"LORA Trainable parameters: {trainable_params}")
    print(f"LORA Aprrox Percentage : {_percent(trainable_params, linear_params):.2f}")
    print(f"Total parameters: {total_params}")
    print(f"Percentage of trainable parameters: {_percent(trainable_params, total_params):.2f}%")

In [2]:

# Run configuration, kept both as a plain dict and as an argparse-style namespace.
args_dict = dict(
    model_name='sd-v2.1-base-4view',
    config_path=None,
    ckpt_path=None,
    text='an astronaut riding a horse',
    suffix=', 3d asset',
    size=256,
    num_frames=4,
    use_camera=1,
    camera_elev=15,
    camera_azim=90,
    camera_azim_span=360,
    seed=23,
    fp16=False,
    device='cuda:2',
)
args = argparse.Namespace(**args_dict)

# Tensor precision follows the fp16 switch.
if args.fp16:
    dtype = torch.float16
else:
    dtype = torch.float32

device = args.device

# Never use fewer than 4 samples per batch, even when fewer frames are requested.
batch_size = max(4, args.num_frames)

In [30]:
# Build the text-to-image model and put it on the target device in eval mode.
print("load t2i model ... ")
if args.config_path is None:
    # No explicit config: let build_model construct the model from its name
    # (ckpt_path may be None here and is passed through as-is).
    model = build_model(args.model_name, ckpt_path=args.ckpt_path)
else:
    # Explicit config: a checkpoint path is mandatory on this path.
    assert args.ckpt_path is not None, "ckpt_path must be specified!"
    config = OmegaConf.load(args.config_path)
    model = instantiate_from_config(config.model)
    # NOTE(review): torch.load may unpickle arbitrary Python objects depending on
    # the installed torch version's defaults — only load checkpoints from a
    # trusted source.
    model.load_state_dict(torch.load(args.ckpt_path, map_location=device))
# The device is stored as an attribute AND the weights are moved explicitly.
model.device = device
model.to(device)
model.eval()

# NOTE(review): the two lines below are repeated verbatim at the start of the
# next cell. `model.modules` is referenced without being called, so this
# stringifies the bound method rather than a list of modules.
import re
model_modules = str(model.modules)


load t2i model ... 
Loading model from config: sd-v2-base.yaml
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention

In [32]:
import re

# `model.modules` is referenced without being called; the string form of that
# bound method includes the model's printed layer tree, which is the text the
# patterns below scan for "(<name>): Linear" / "(<name>): Conv2d" entries.
model_modules = str(model.modules)

pattern = r'\((\w+)\): Linear'
linear_layer_names = re.findall(pattern, model_modules)
pattern = r'\((\w+)\): Conv2d'
conv_layer_names = re.findall(pattern, model_modules)

# Linear names first, then Conv2d names (duplicates still present here).
names = linear_layer_names + conv_layer_names

# De-duplicate and sort: iterating a bare set of strings yields a
# hash-dependent order that can change between kernel restarts, which made the
# printed list non-reproducible. Purely numeric entries such as '0' or '1' are
# positional child names picked up by the regex, not distinctive layer names.
target_modules = sorted(set(names))
print(target_modules)


import torch.nn as nn


# Function to get the number of parameters in a module
def get_num_params(module):
    """Return how many scalar elements ``module``'s trainable parameters hold.

    Only parameters with ``requires_grad`` set are counted; the result is 0
    when the module has no such parameters.
    """
    count = 0
    for param in module.parameters():
        if param.requires_grad:
            count += param.numel()
    return count

# Trainable parameter count of the whole model
total_params = get_num_params(model)

# Trainable parameter count restricted to nn.Linear modules
linear_params = sum(
    get_num_params(submodule)
    for _, submodule in model.named_modules()
    if isinstance(submodule, nn.Linear)
)

# Share of the trainable parameters that sits in Linear layers, in percent
percentage_linear_params = linear_params / total_params * 100

print(
    f"Total Parameters: {total_params}",
    f"Linear Layer Parameters: {linear_params}",
    f"Percentage of Parameters in Linear Layers: {percentage_linear_params:.2f}%",
    sep="\n",
)



['2', 'proj_out', '1', 'proj_in', 'conv_out', 'conv2', '0', '3', 'conv_in', 'v', 'post_quant_conv', 'to_v', 'q', 'quant_conv', 'skip_connection', 'to_k', 'conv', 'to_q', 'op', 'nin_shortcut', 'proj', 'c_fc', 'c_proj', 'conv1', 'k']
Total Parameters: 951226027
Linear Layer Parameters: 303066240
Percentage of Parameters in Linear Layers: 31.86%


In [33]:
# LoRA adapter settings: rank-20 adapters on the listed projection and
# convolution module names.
# NOTE(review): confirm the model really contains a module named "decode_head";
# nothing in this notebook's printed module names shows one.
config = LoraConfig(
    r=20,
    lora_alpha=32,
    target_modules=[
        'proj_out', 'to_k', 'c_proj', 'to_q',
        'proj_in', 'to_v', 'c_fc', 'proj',
        'conv_out', 'conv2', 'conv_in', 'conv',
    ],
    lora_dropout=0.1,
    bias="lora_only",
    modules_to_save=["decode_head"],
)

# NOTE(review): get_peft_model presumably modifies `model` itself, so running
# this cell twice without reloading the model may stack adapters — confirm.
lora_model = get_peft_model(model, config)

print_trainable_parameters(lora_model)

Linear Trainable parameters: 11559040
LORA Trainable parameters: 14638419
LORA Aprrox Percentage : 126.64
Total parameters: 1319633268
Percentage of trainable parameters: 1.11%
