In [1]:
import torch
from accelerate import init_empty_weights
from accelerate.utils import (
    calculate_maximum_sizes,
    convert_bytes
)
from accelerate.commands.estimate import check_has_model, create_empty_model
import transformers
from transformers import AutoConfig, AutoModel
from prettytable import PrettyTable

In [19]:
import os
from google.colab import userdata

# Copy the Colab secret into the `HF_TOKEN` environment variable, which the
# Hugging Face libraries read for authentication (presumably needed here
# because the `meta-llama/*` repos below are gated).
os.environ.update(HF_TOKEN=userdata.get('HF_TOKEN'))

In [23]:
def create_empty_model(model_name: str, trust_remote_code: bool = False, verbose: bool = True):
    """Build `model_name` inside ``init_empty_weights()`` so no real weights are allocated.

    This notebook-local definition shadows the ``create_empty_model`` imported
    from ``accelerate.commands.estimate``. It only handles ``transformers`` models.

    Args:
        model_name: repo id or local path accepted by ``AutoConfig.from_pretrained``.
        trust_remote_code: forwarded to both ``AutoConfig.from_pretrained`` and
            ``from_config``.
        verbose: when True (the default), print the loaded config.

    Returns:
        The model produced by ``AutoModel.from_config``. If the config's
        ``auto_map`` names an ``AutoModelFor*`` class, that ``transformers``
        class is used instead.
    """
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    if verbose:
        print(config)

    # Remote-code configs can point at a specific `AutoModelFor*` class via
    # `auto_map`. Read it from the loaded config. When the key is absent, fall
    # back to the plain `AutoModel` backbone.
    auto_map = getattr(config, "auto_map", None)
    constructor = AutoModel
    if isinstance(auto_map, dict):
        auto_cls_name = next(
            (key for key in auto_map if key.startswith("AutoModelFor")), None
        )
        if auto_cls_name is not None:
            constructor = getattr(transformers, auto_cls_name)

    with init_empty_weights():
        model = constructor.from_config(config, trust_remote_code=trust_remote_code)

    return model

In [24]:
def count_parameters(model, count_untied_lm_head: bool = True):
    """Print a per-tensor table of trainable parameters and return their total.

    Only parameters with ``requires_grad`` set are tabulated and summed. On top
    of that sum, the size of every tensor whose name contains ``embed_tokens``
    is added once more.

    NOTE(review): this extra term presumably approximates an output projection
    (``lm_head``) of the same shape, which a bare ``AutoModel`` backbone lacks.
    Confirm this per architecture.

    The estimate is skipped when it would double count:
    - the model's config ties input and output embeddings (one shared tensor), or
    - the model already has a parameter whose name contains ``lm_head``.

    Args:
        model: a ``torch.nn.Module`` (typically a ``transformers`` model).
        count_untied_lm_head: set False to return only the tabulated sum.

    Returns:
        int: sum of trainable ``numel()`` values plus the optional head estimate.
    """
    table = PrettyTable(["Modules", "Parameters"])
    tabulated_params = 0
    embedding_params = 0
    has_lm_head = False
    for name, parameter in model.named_parameters():
        lowered = name.lower()
        # Detect a real head even if it is frozen, so it is never estimated twice.
        if "lm_head" in lowered:
            has_lm_head = True
        if not parameter.requires_grad:
            continue
        numel = parameter.numel()
        if "embed_tokens" in lowered:
            embedding_params += numel
        table.add_row([name, numel])
        tabulated_params += numel

    # Models without a `config` (plain nn.Module) are treated as untied.
    tied = bool(getattr(getattr(model, "config", None), "tie_word_embeddings", False))
    head_estimate = 0
    if count_untied_lm_head and not tied and not has_lm_head:
        head_estimate = embedding_params
    if head_estimate:
        # Make the extra term visible so the printed table adds up to the return value.
        table.add_row(["(estimated lm_head = embed_tokens)", head_estimate])
    print(table)
    return tabulated_params + head_estimate

In [25]:
def calc(model_name, trust_remote_code: bool = False):
    """Print the trainable-parameter breakdown for `model_name` and return the total.

    Builds a weightless model with ``create_empty_model`` and tabulates it with
    ``count_parameters``.

    Args:
        model_name: repo id or local path of the model to inspect.
        trust_remote_code: forwarded to ``create_empty_model`` (default False).

    Returns:
        The total reported by ``count_parameters``. It is also printed.
    """
    model = create_empty_model(model_name, trust_remote_code=trust_remote_code)
    total_params = count_parameters(model)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [None]:
# Llama 3.1 8B — a gated repo, so this relies on HF_TOKEN being set.
calc(model_name="meta-llama/Llama-3.1-8B")

In [None]:
# Qwen2.5 7B, for comparison with the Llama count.
calc(model_name="Qwen/Qwen2.5-7B")