In [None]:
import requests
from PIL import Image
from transformers import AutoProcessor, CLIPVisionModel

from IPython.core.interactiveshell import InteractiveShell
# Display the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

# Enable the autoreload extension in mode 2 so edited modules are re-imported
# before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from transformers.models.clip.processing_clip import CLIPProcessor

# Checkpoint identifiers for two CLIP vision encoders; only the small one is loaded below.
LLAVA_IMAGE_ENCODER = "openai/clip-vit-large-patch14-336"
LLAVA_IMAGE_ENCODER_SMALL = "openai/clip-vit-base-patch32"

model = CLIPVisionModel.from_pretrained(LLAVA_IMAGE_ENCODER_SMALL)
processor: CLIPProcessor = AutoProcessor.from_pretrained(LLAVA_IMAGE_ENCODER_SMALL)

# Sample image used as the model input throughout the notebook.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Bound the wait and fail on HTTP errors, so a stalled or failed download does
# not hang the kernel or surface later as an opaque image-decoding error.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)



In [None]:
# Concrete class of the processor object and that class's direct base classes.
type(processor)
type(processor).__bases__

# Processor attributes: its model input names and its image-processor class.
processor.model_input_names
processor.image_processor_class


In [None]:
# The image processor held by the processor, and the direct bases of its class.
processor.image_processor
type(processor.image_processor).__bases__

In [None]:
from transformers.image_processing_utils import BaseImageProcessor
# BaseImageProcessor is already a class, so read its bases directly;
# type(BaseImageProcessor) would be its metaclass, not the class itself.
BaseImageProcessor.__bases__

In [None]:
# Run the processor on the sample image, requesting PyTorch ("pt") tensors.
inputs = processor(images=image, return_tensors="pt")

In [None]:
import torch

# Use CUDA when it is available and fall back to CPU otherwise, so the cell
# also runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pixels = inputs.pixel_values.to(device)

In [None]:
from transformers.models.clip.modeling_clip import CLIPVisionModel, CLIPVisionTransformer

In [None]:
# Alias the model under a short name, then print each direct child of its base model.
m = model
for base in m.base_model.children():
    print(base)

In [None]:
from torch import nn

# Register the same Linear instance twice to see how named_modules() reports a
# submodule that appears more than once.
linear = nn.Linear(2, 2)
net = nn.Sequential(linear, linear)
# The loop variable gets its own name so the notebook-level `m` (the model
# alias) is not overwritten with a (name, module) tuple.
for idx, named_module in enumerate(net.named_modules()):
    print(idx, '->', named_module)


In [None]:
# Name -> module lookups built with and without duplicate removal, plus a
# name -> parameter lookup.
modules = dict(model.named_modules(remove_duplicate=True))
modules_dups = dict(model.named_modules(remove_duplicate=False))
params = dict(model.named_parameters())


In [None]:
# Compare the number of entries in the two module lookups.
len(modules), len(modules_dups)

In [None]:
"a" * 5

In [None]:
# `pixels.toli` was an incomplete attribute name and raised AttributeError;
# show only the first few values rather than converting the whole tensor to a list.
pixels.flatten()[:10].tolist()

In [None]:
# Name -> module lookup for the model's direct children only.
children = dict(model.named_children())

In [None]:
# Direct children of the 'vision_model' child, keyed by name.
dict(children['vision_model'].named_children())

In [None]:
# NOTE(review): `s` is not defined anywhere in the visible notebook, so this
# cell raises NameError on a fresh kernel — confirm the intended variable or delete.
list(s)

In [None]:
# Keep the model's state dict for inspection.
sd = model.state_dict()

In [None]:
# Rebind `modules` to the name -> module lookup built without duplicate removal.
modules = dict(model.named_modules(remove_duplicate=False))

In [None]:
# Print the model's string form (print applies str() itself).
print(model)

In [None]:
print("Modules")
# (qualified name, class name) for every submodule. A comprehension keeps the
# loop temporaries out of the notebook namespace; the original loop rebound `m`
# and carried an unused enumerate index.
modules_list = [
    (name, submodule.__class__.__name__)
    for name, submodule in model.named_modules()
]

In [None]:
# First ten (name, class name) entries.
modules_list[:10]

In [None]:

from collections import defaultdict

params_dict = {}
# (parameter name, owning module's class name, shape) for every parameter.
# The owning module's path is the parameter name minus its last dotted
# component (empty string for a parameter on the root module).
params_list = [
    (name, modules[name.rpartition('.')[0]].__class__.__name__, list(param.shape))
    for name, param in model.named_parameters()
]

In [None]:
# String form of the class object of the 'vision_model' submodule.
str(modules['vision_model'].__class__)

In [None]:
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class ParamInfo:
    """Path, shape and optional dtype of a single parameter."""
    path: str  # dotted parameter path
    shape: list  # dimension sizes
    dtype: Optional[torch.dtype] = None  # left out of the repr when None

    def __repr__(self) -> str:
        """Return a one-line description; the dtype is appended only when set."""
        text = f'Param: "{self.path}", Shape: {self.shape}'
        if self.dtype is not None:
            # Separator added: previously the shape ran straight into "dtype:".
            text += f", dtype: {self.dtype}"
        return text

@dataclass
class ModulePath:
    """Hashable (path, class name) pair identifying a submodule.

    Equality is the field-wise comparison generated by ``dataclass``; the
    explicit ``__hash__`` keeps instances usable as dict keys.
    """
    path: str  # dotted path of the submodule
    module_name: str  # class name of the submodule

    def __repr__(self):
        return f"{self.path} = {self.module_name}"

    def __hash__(self):
        # Hash the fields as a tuple so the boundary between them is kept
        # ("a" + "bc" and "ab" + "c" no longer share a hash input).
        return hash((self.path, self.module_name))
        
@dataclass
class TensorMap:
    """Record holding a path, a module type name and a dict of children.

    NOTE(review): nothing in the visible notebook constructs or reads this class.
    """
    path: str
    module_type: str
    children: dict
    
def modules_to_implement(m_dict):
    """Print every key of a nested module tree and return its leaves as nested lists.

    Args:
        m_dict: Either a list, which is returned unchanged, or a dict whose
            keys expose ``path`` and ``module_name`` and whose values are
            further trees of the same form.

    Returns:
        The input itself for a list; otherwise a list with one entry per dict
        item, each entry being the result of recursing into that item's value.
    """
    if isinstance(m_dict, list):
        return m_dict

    collected = []
    for key, subtree in m_dict.items():
        print(f"{key.path}: {key.module_name}")
        collected.append(modules_to_implement(subtree))
    return collected

def walk_modules(module, prefix=''):
    """Build a nested description of ``module``'s submodule tree.

    Args:
        module: Object exposing ``named_children()`` and ``named_parameters()``.
        prefix: Dotted path of ``module`` itself; empty for the root.

    Returns:
        For a module without children, a list of ``ParamInfo`` (one per
        parameter). Otherwise a dict mapping ``ModulePath(path, class name)``
        of each child to the result of walking that child.

    NOTE(review): for a module that has children, parameters registered
    directly on it are not included in the result — only leaf modules
    contribute ``ParamInfo`` entries.
    """
    def _join(name):
        # Skip the separator for the root, so paths never start with '.'.
        return f"{prefix}.{name}" if prefix else name

    children = list(module.named_children())
    if not children:
        return [
            ParamInfo(path=_join(name), shape=list(param.shape))
            for name, param in module.named_parameters()
        ]

    tree = {}
    for name, child in children:
        child_path = _join(name)
        tree[ModulePath(child_path, child.__class__.__name__)] = walk_modules(child, prefix=child_path)
    return tree

In [None]:
# Nested module tree of the model, starting from an empty prefix.
m_dict = walk_modules(model)

In [None]:
# Top-level keys of the module tree.
# NOTE(review): `keys` is rebound from `modules` in a later cell.
keys = list(m_dict.keys())

In [None]:
# Print every (path, class name) in the tree and keep the nested leaf lists.
child_modules = modules_to_implement(m_dict)

In [None]:
# Rebind `modules` to the name -> module lookup built with default arguments.
modules = dict(model.named_modules())

In [239]:
# Display the model's repr (its module tree).
model

CLIPVisionModel(
  (vision_model): CLIPVisionTransformer(
    (embeddings): CLIPVisionEmbeddings(
      (patch_embedding): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
      (position_embedding): Embedding(50, 768)
    )
    (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
        

In [238]:
# Qualified module names; the first five are displayed.
# NOTE(review): this rebinds `keys`, which an earlier cell built from `m_dict`.
keys = list(modules.keys())
keys[:5]

['',
 'vision_model',
 'vision_model.embeddings',
 'vision_model.embeddings.patch_embedding',
 'vision_model.embeddings.position_embedding']

In [251]:

def _make_module_map(model: torch.nn.Module):
    module_map = {}

    for k,v in model.named_modules():
        module_cls = v.__class__.__name__
        if module_cls in module_map:
            module_map[module_cls] = module_map[module_cls] + [k]
        else:
            module_map[module_cls] = [k]
        
    return module_map

def get_module(name: str, model: torch.nn.Module):
    """Return the first submodule of ``model`` whose class is called ``name``.

    Also prints the matching path (or "first, ... ,last" when several
    submodules share the class) followed by the returned module.

    Args:
        name: Class name to look up, e.g. ``'CLIPAttention'``.
        model: Module to search.

    Returns:
        The first matching submodule in ``named_modules()`` order.

    Raises:
        KeyError: If no submodule has that class name; the message lists the
            class names that are available.
    """
    module_map = _make_module_map(model)
    if name not in module_map:
        raise KeyError(
            f"no submodule of class {name!r}; available classes: {sorted(module_map)}"
        )

    paths = module_map[name]
    first = paths[0]
    fmt_path = first + ', ... ,' + paths[-1] if len(paths) > 1 else first

    module = dict(model.named_modules())[first]
    print(f"{fmt_path}: {module}")

    return module

In [252]:
# Look up the first CLIPAttention submodule; get_module also prints its path(s).
get_module('CLIPAttention', model)

vision_model.encoder.layers.0.self_attn, ... ,vision_model.encoder.layers.11.self_attn: CLIPAttention(
  (k_proj): Linear(in_features=768, out_features=768, bias=True)
  (v_proj): Linear(in_features=768, out_features=768, bias=True)
  (q_proj): Linear(in_features=768, out_features=768, bias=True)
  (out_proj): Linear(in_features=768, out_features=768, bias=True)
)


CLIPAttention(
  (k_proj): Linear(in_features=768, out_features=768, bias=True)
  (v_proj): Linear(in_features=768, out_features=768, bias=True)
  (q_proj): Linear(in_features=768, out_features=768, bias=True)
  (out_proj): Linear(in_features=768, out_features=768, bias=True)
)

In [237]:
# Address the embeddings module explicitly instead of through the scratch name
# `m`, whose value depends on which earlier cell ran last. The recorded output
# of this cell shows the CLIPVisionEmbeddings submodules.
dict(model.vision_model.embeddings.named_modules())

{'': CLIPVisionEmbeddings(
   (patch_embedding): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
   (position_embedding): Embedding(50, 768)
 ),
 'patch_embedding': Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False),
 'position_embedding': Embedding(50, 768)}

In [None]:
# Display whatever `m` is currently bound to.
# NOTE(review): `m` is a short scratch name — confirm it refers to the intended
# module when the notebook is run top to bottom.
m

In [None]:
# Display the model's repr.
model

In [None]:
# Take the first key from m_dict itself: `keys` is rebound elsewhere in the
# notebook to the string keys of `modules`, which are not valid m_dict keys.
m_dict[next(iter(m_dict))]

In [None]:
# The original `m_dict[]` was an empty subscript (SyntaxError); list the
# available keys so a concrete one can be chosen.
list(m_dict)

In [None]:
# First ten (name, owning class, shape) entries.
params_list[:10]

In [None]:
# Name -> object lookups for every submodule and parameter of the model.
modules = dict(model.named_modules())
params = dict(model.named_parameters())

type(modules['vision_model.embeddings'])


# Print every named submodule (the root's empty name is skipped), each followed
# by the parameters yielded by its own named_parameters() call.
# NOTE(review): named_parameters() is called with defaults; if that recurses
# into children, parent modules repeat their descendants' parameters — confirm
# whether recurse=False was intended.
ident = " " * 4
for module_name, module in model.named_modules():
    if module_name:
        print(f"MODULE {module_name}")
        
        for p_name, p in module.named_parameters():
            print(f"{ident}PARAM {p_name} {list(p.shape)}")

In [None]:

# Print each parameter with its nesting depth (number of dots in its name),
# indented two spaces per level.
for k, v in params.items():
    depth = k.count('.')
    ident = "  " * depth
    print(f"{depth} {ident}{k} {v.shape}")

In [None]:
from torchinfo import summary

# torchinfo's summary() returns a single statistics object rather than a pair,
# so it cannot be tuple-unpacked; the per-layer records are read from it.
stats = summary(model, input_data=[pixels], verbose=2)
tensors = stats.summary_list

In [None]:

# Inference only: skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output  # pooled CLS states
