In [1]:
import json
from collections import Counter

import pandas as pd
import timm
import torch
import torch.nn as nn
import tqdm

# Names of every timm architecture for which pretrained weights are listed;
# the metadata sweep iterates over this list.
pt = timm.list_models(pretrained=True)

In [2]:
def get_children(model: torch.nn.Module):
    """Flatten a module tree into the class names of its leaf modules.

    Walks ``model.children()`` recursively, depth first, and collects
    ``type(module).__name__`` for every module that has no children.

    Args:
        model: The module to inspect.

    Returns:
        A list of leaf class names in traversal order (duplicates kept), or a
        bare ``str`` holding the model's own class name when ``model`` itself
        has no children.
    """
    # Get the direct children of the model.
    children = list(model.children())
    if not children:
        # No submodules: this module is itself a leaf.
        return type(model).__name__

    flat_children = []
    for child in children:
        # Recurse exactly once per child and reuse the result; recursing
        # repeatedly here makes the traversal exponential in the tree depth.
        leaves = get_children(child)
        if isinstance(leaves, list):
            flat_children.extend(leaves)
        else:
            flat_children.append(leaves)
    return flat_children

In [27]:
# NOTE(review): leftover debugging cell. `m` is only bound by the sweep loop in
# the next cell, so this raises NameError under Restart & Run All; its captured
# output ('tresnet_l') presumably shows where an earlier sweep stopped. Remove
# before sharing.
m

'tresnet_l'

In [28]:
# Architectures excluded from the sweep (substring match on the model name).
# NOTE(review): the stray `m` cell shows an earlier run stopped at 'tresnet_l';
# presumably tresnet models fail to build/run in this environment — confirm.
SKIP_SUBSTRINGS = ('tresnet',)

model_meta = {}
for m in tqdm.tqdm(pt):
    if any(s in m for s in SKIP_SUBSTRINGS):
        continue
    # num_classes=0 asks timm for a head-less feature extractor.
    model = timm.create_model(m, num_classes=0).eval()
    config = timm.data.resolve_data_config({}, model=model)
    config['model_type'] = type(model).__name__
    config['n_feature_params'] = sum(p.numel() for p in model.parameters())
    config['model_name'] = m

    # Count leaf-module class names (one column per type). get_children returns
    # a bare str for a childless model; wrap it so it is not counted per character.
    c = get_children(model)
    if isinstance(c, str):
        c = [c]
    config.update(Counter(c))

    # Infer the feature dimension from one dummy forward pass; no gradients are
    # needed, so skip building the autograd graph.
    with torch.no_grad():
        out = model(torch.randn(config['input_size']).unsqueeze(0))
    if isinstance(out, (list, tuple)):
        out = out[0]
    f_dim = out.shape[1]
    config['feature_dim'] = f_dim
    del model

    model_meta[m] = config

 87%|████████▋ | 513/592 [08:51<01:55,  1.46s/it] Removing representation layer for fine-tuning.
 87%|████████▋ | 517/592 [08:56<01:33,  1.25s/it]Removing representation layer for fine-tuning.
 88%|████████▊ | 518/592 [08:57<01:29,  1.21s/it]Removing representation layer for fine-tuning.
 88%|████████▊ | 520/592 [09:01<01:57,  1.63s/it]Removing representation layer for fine-tuning.
 89%|████████▊ | 524/592 [09:26<05:27,  4.82s/it]Removing representation layer for fine-tuning.
100%|██████████| 592/592 [10:50<00:00,  1.10s/it]


In [33]:
# Attach the evaluation resolution from the precomputed size table and export.
# NOTE(review): the JSON appears to map a resolution string (e.g. "224") to a
# list of model names, judging by how it is consumed here — confirm with its producer.
with open('timm_model_input_sizes.json', 'r', encoding='utf-8') as f:
    sizes = json.load(f)

for k, v in sizes.items():
    for m in v:
        # Only profiled models have an entry; this also skips the excluded
        # tresnet family and any name absent from this timm version.
        if m in model_meta:
            model_meta[m]['test_input_size'] = (3, int(k), int(k))

df = pd.DataFrame(list(model_meta.values()))
df.to_csv('timm_model_metadata.csv', index=False)
    