# Flamingo Model Stats
Notebook to check the number of trainable and frozen parameters of the pretrained `dhansmair/flamingo-tiny` model.

In [1]:
%load_ext autoreload
%autoreload 2
from flamingo_mini import FlamingoConfig, FlamingoModel
import torch

In [2]:
# Load the pretrained tiny Flamingo checkpoint (a `user/repo` id, presumably resolved via the
# Hugging Face Hub by `from_pretrained` — confirm). To inspect a freshly initialised model
# instead, build it from a config:
#     model = FlamingoModel(FlamingoConfig(lm='facebook/opt-125m'))
CHECKPOINT = 'dhansmair/flamingo-tiny'

model = FlamingoModel.from_pretrained(CHECKPOINT)
model.train()
print('model loaded.')

model loaded.


In [3]:
def get_named_parameters_trainable(named_params):
    """Keep only the ``(name, tensor)`` pairs whose tensor has ``requires_grad`` set.

    Args:
        named_params: iterable of ``(name, tensor)`` pairs, e.g. the result of
            ``module.named_parameters()``.

    Returns:
        A new list of ``(name, tensor)`` tuples, in their original order.
    """
    trainable = []
    for name, tensor in named_params:
        if tensor.requires_grad:
            trainable.append((name, tensor))
    return trainable
    
    
def print_nicely(big_number, sep=' '):
    """Format a number with its thousands groups separated, e.g. 570123032 -> '570 123 032'.

    Despite the name, nothing is printed: the formatted string is returned so the
    caller can pass it to ``print``.

    Args:
        big_number: the int (or float) to format.
        sep: string placed between thousands groups; defaults to a single space.

    Returns:
        The formatted number as a string.
    """
    # The ',' format spec inserts commas between thousands groups; swap them for `sep`.
    return '{:,}'.format(big_number).replace(',', sep)
    
    
def find_redundant(l):
    """Split ``l`` into repeated occurrences and first occurrences.

    Elements must be hashable. Order is preserved in both outputs.

    Returns:
        ``(dups, uniq)``: ``uniq`` holds the first occurrence of every distinct
        element; ``dups`` holds every later occurrence (an element appearing three
        times contributes two entries).
    """
    already_seen = set()
    first_occurrences = []
    repeats = []
    for item in l:
        target = repeats if item in already_seen else first_occurrences
        target.append(item)
        already_seen.add(item)  # no-op when the item was already recorded
    return repeats, first_occurrences
    

# General Stats

In [4]:

# NOTE: despite the names, `state_dict` / `state_dict_trainable` hold only the *key names*
# of model.flamingo's full and trainable state dicts, not the tensors.
state_dict = list(model.flamingo.state_dict().keys())
state_dict_trainable = list(model.flamingo.state_dict_trainable().keys())

# Parameter tensors come from `model`, parameter names from `model.flamingo`.
# NOTE(review): the printed counts (795 / 209) match, so both appear to cover the same set.
parameters = [*model.parameters()]
named_parameters = [name for name, _ in model.flamingo.named_parameters()]
parameters_trainable = [*model.parameters_trainable()]
named_parameters_trainable = [
    name for name, _ in get_named_parameters_trainable(model.flamingo.named_parameters())
]

In [5]:
# Report the size of every collection built in the previous cell, one per line.
for label, items in (
    ('length state_dict:', state_dict),
    ('length state_dict_trainable:', state_dict_trainable),
    ('length parameters:', parameters),
    ('length named_parameters:', named_parameters),
    ('length parameters_trainable:', parameters_trainable),
    ('length named_parameters_trainable:', named_parameters_trainable),
):
    print(label, len(items))

length state_dict: 797
length state_dict_trainable: 209
length parameters: 795
length named_parameters: 795
length parameters_trainable: 209
length named_parameters_trainable: 209


In [6]:
# State-dict keys that state_dict_trainable() leaves out. Besides frozen weights this also
# includes non-parameter entries such as `...embeddings.position_ids`.
list(filter(lambda key: key not in state_dict_trainable, state_dict))

['vision_encoder.vision_model.embeddings.class_embedding',
 'vision_encoder.vision_model.embeddings.position_ids',
 'vision_encoder.vision_model.embeddings.patch_embedding.weight',
 'vision_encoder.vision_model.embeddings.position_embedding.weight',
 'vision_encoder.vision_model.pre_layrnorm.weight',
 'vision_encoder.vision_model.pre_layrnorm.bias',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.k_proj.weight',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.k_proj.bias',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.v_proj.weight',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.v_proj.bias',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.q_proj.weight',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.q_proj.bias',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.out_proj.weight',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.out_proj.bias',
 'vision_encoder.vision_model.encoder.layers.0.layer_norm1.weight',


In [7]:
# Names of all parameters with requires_grad=True, i.e. the ones that receive gradients.
named_parameters_trainable

['resampler.latents',
 'resampler.time_pos_emb',
 'resampler.layers.0.0.norm_media.weight',
 'resampler.layers.0.0.norm_media.bias',
 'resampler.layers.0.0.norm_latents.weight',
 'resampler.layers.0.0.norm_latents.bias',
 'resampler.layers.0.0.to_q.weight',
 'resampler.layers.0.0.to_k.weight',
 'resampler.layers.0.0.to_v.weight',
 'resampler.layers.0.0.to_out.weight',
 'resampler.layers.0.1.0.weight',
 'resampler.layers.0.1.0.bias',
 'resampler.layers.0.1.1.weight',
 'resampler.layers.0.1.3.weight',
 'resampler.layers.1.0.norm_media.weight',
 'resampler.layers.1.0.norm_media.bias',
 'resampler.layers.1.0.norm_latents.weight',
 'resampler.layers.1.0.norm_latents.bias',
 'resampler.layers.1.0.to_q.weight',
 'resampler.layers.1.0.to_k.weight',
 'resampler.layers.1.0.to_v.weight',
 'resampler.layers.1.0.to_out.weight',
 'resampler.layers.1.1.0.weight',
 'resampler.layers.1.1.0.bias',
 'resampler.layers.1.1.1.weight',
 'resampler.layers.1.1.3.weight',
 'resampler.layers.2.0.norm_media.weigh

In [8]:
# Names of parameters with requires_grad=False (frozen). Unlike the state-dict comparison,
# this lists parameters only, so non-parameter entries like `position_ids` do not appear.
list(filter(lambda name: name not in named_parameters_trainable, named_parameters))

['vision_encoder.vision_model.embeddings.class_embedding',
 'vision_encoder.vision_model.embeddings.patch_embedding.weight',
 'vision_encoder.vision_model.embeddings.position_embedding.weight',
 'vision_encoder.vision_model.pre_layrnorm.weight',
 'vision_encoder.vision_model.pre_layrnorm.bias',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.k_proj.weight',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.k_proj.bias',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.v_proj.weight',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.v_proj.bias',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.q_proj.weight',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.q_proj.bias',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.out_proj.weight',
 'vision_encoder.vision_model.encoder.layers.0.self_attn.out_proj.bias',
 'vision_encoder.vision_model.encoder.layers.0.layer_norm1.weight',
 'vision_encoder.vision_model.encoder.layers.0.layer_nor

# Total number of parameters in the model
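Note: per the original author, the model counted here does not include the vision encoder. **TODO:** verify this. The parameter listings above do contain `vision_encoder.*` weights, so the total below may include them.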

In [9]:
# Ask the model for its parameter count twice: first everything, then the trainable subset.
num_params_total, num_params_trainable = (
    model.num_parameters(only_trainable=trainable_only) for trainable_only in (False, True)
)

for label, count in (
    ('params trainable:', num_params_trainable),
    ('params total:', num_params_total),
):
    print(label, print_nicely(count))

params trainable: 180 312 856
params total: 570 123 032


In [10]:
# Element count over every parameter tensor of the resampler submodule (trainable or not).
resampler_param_sizes = map(lambda param: param.numel(), model.flamingo.resampler.parameters())
num_resampler_params = sum(resampler_param_sizes)
print('params resampler:', print_nicely(num_resampler_params))

params resampler: 63 023 104
