In [1]:
import torch
import torch.fx as fx
from transformers import ViTImageProcessor, ViTModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# Bound the wait on the network and fail loudly on HTTP errors, instead of
# hanging indefinitely or handing an HTML error body to PIL (which would then
# fail with a confusing "cannot identify image file" error).
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# Load processor and model from one checkpoint id so preprocessing and
# weights can never drift apart.
MODEL_ID = 'google/vit-base-patch16-224-in21k'
processor = ViTImageProcessor.from_pretrained(MODEL_ID)
model = ViTModel.from_pretrained(MODEL_ID)
inputs = processor(images=image, return_tensors="pt")

In [3]:
# Plain (non-Dynamo) symbolic tracing of the ViT model; the recorded output
# shows it fails because traced values reach Python control flow.
try:
    gm = fx.symbolic_trace(model)
except fx.proxy.TraceError as trace_err:
    print(trace_err)

symbolically traced variables cannot be used as inputs to control flow


In [4]:
from copy import deepcopy

In [5]:
class GraphCollector:
    """``torch.compile`` backend that keeps a copy of every FX graph it receives.

    Usage: ``torch.compile(model, backend=GraphCollector())``. Each invocation
    deep-copies ``gm.graph`` into an internal list and returns ``gm.forward``,
    so the captured graph itself is what runs (no further compilation).
    """

    def __init__(self):
        # Captured graphs, appended in the order the backend is invoked.
        self._subgraphs = []

    def reset(self):
        """Empty the captured-graph list in place."""
        del self._subgraphs[:]

    def __call__(self, gm: fx.GraphModule, sameple_inputs):
        """Backend hook: snapshot ``gm.graph`` and return ``gm.forward`` unchanged.

        Args:
            gm: the graph module handed to the backend.
            sameple_inputs: example inputs passed alongside ``gm`` (unused).
                The misspelled name is kept so the signature stays unchanged.

        Returns:
            ``gm.forward``, used as the compiled callable.
        """
        # Deep copy so later edits to ``gm.graph`` don't alter the snapshot.
        snapshot = deepcopy(gm.graph)
        self._subgraphs.append(snapshot)
        return gm.forward

    @property
    def subgraphs(self):
        """Shallow copy of the captured graphs; mutating it leaves the collector untouched."""
        return list(self._subgraphs)

In [6]:
# Register the collector as the torch.compile backend. Weights are moved to the
# GPU first (Module.to is in-place), then the on-device module is wrapped.
collector = GraphCollector()
model_ = torch.compile(model.to("cuda:0"), backend=collector)

In [7]:
# Start from an empty collector, then run one forward pass on the GPU; the
# call triggers compilation, which hands the captured graph(s) to `collector`.
# The model output itself is discarded.
collector.reset()
_ = model_(**{key: value.to("cuda:0") for key, value in inputs.items()})

In [8]:
# Take the first captured graph. NOTE(review): if Dynamo produced more than one
# graph (presumably from graph breaks), the rest are ignored — check
# len(collector.subgraphs) if that matters.
graph = collector.subgraphs[0]

In [9]:
# The default Graph repr only shows the object address (see output); print(graph) lists the nodes.
graph

<torch.fx.graph.Graph at 0x79d2b0d34880>

In [44]:
# graph.nodes iterates in the graph's (topological) order; materialize it for indexing.
nodes = list(graph.nodes)

In [73]:
# First node in graph order — per the outputs below, the pixel_values input placeholder.
node = nodes[0]

In [50]:
# Op kind, the nodes this one reads from, and its consumers (`users` maps each consumer node to None).
node.op, node.all_input_nodes, node.users

('placeholder', [], {conv2d: None})

In [74]:
# The node's name within the graph; compare with node.target in the next cell.
node.name

'l_pixel_values_'

In [75]:
# For a placeholder, target is the argument name (here the 'L_...'-prefixed local name).
node.target

'L_pixel_values_'

In [52]:
# Module where the node's class is defined (torch.fx.node, per the output).
type(node).__module__

'torch.fx.node'

In [65]:
# Metadata attached during capture: per the output, stack_trace, example_value (a FakeTensor), tensor_dict, grapharg.
node.meta

{'stack_trace': '  File "/home/dboy/open_source/play_torch.fx/.venv/lib/python3.11/site-packages/transformers/models/vit/modeling_vit.py", line 620, in forward\n    if pixel_values is None:\n',
 'example_value': FakeTensor(..., device='cuda:0', size=(1, 3, 224, 224)),
 'tensor_dict': {},
 'grapharg': GraphArg(source=LocalSource(local_name='pixel_values', cell_or_freevar=False), _example=<torch.utils.weak.TensorWeakRef object at 0x7969d4079b90>, pass_arg_as_tensor=False, fake_tensor=FakeTensor(..., device='cuda:0', size=(1, 3, 224, 224)), is_tensor=True, example_strong_ref=None)}

In [66]:
# `obj?` is IPython-only syntax; help() prints the signature and docstring
# and also works when the notebook is exported/run as a plain Python script.
help(graph.create_node)

Signature:
graph.create_node(
    op: str,
    target: 'Target',
    args: Optional[Tuple[ForwardRef('Argument'), ...]] = None,
    kwargs: Optional[Dict[str, ForwardRef('Argument')]] = None,
    name: Optional[str] = None,
    type_expr: Optional[Any ... (output truncated)

In [69]:
# NOTE(review): repeats the earlier `node.target` cell (same output); candidate for deletion.
node.target

'L_pixel_values_'

In [71]:
# Experiment with Graph.create_node on a throwaway graph: create_node inserts
# the new node into the graph it is called on, so calling it on the captured
# `graph` would append a stray placeholder (after its output node) that then
# shows up when the graph's nodes are enumerated below.
scratch_graph = fx.Graph()
node = scratch_graph.create_node("placeholder", target="my_ph_value", name="my_ph")

In [72]:
# Whether FX treats the placeholder created above as impure (True, per the output).
node.is_impure()

True

In [84]:
# Every (target, name) pair in the graph; for the placeholders shown, name and target differ only in the leading L/l.
{(fx_node.target, fx_node.name) for fx_node in graph.nodes}

{('L_pixel_values_', 'l_pixel_values_'),
 ('L_self_modules_embeddings_modules_patch_embeddings_modules_projection_parameters_bias_',
  'l_self_modules_embeddings_modules_patch_embeddings_modules_projection_parameters_bias_'),
 ('L_self_modules_embeddings_modules_patch_embeddings_modules_projection_parameters_weight_',
  'l_self_modules_embeddings_modules_patch_embeddings_modules_projection_parameters_weight_'),
 ('L_self_modules_embeddings_parameters_cls_token_',
  'l_self_modules_embeddings_parameters_cls_token_'),
 ('L_self_modules_embeddings_parameters_position_embeddings_',
  'l_self_modules_embeddings_parameters_position_embeddings_'),
 ('L_self_modules_encoder_modules_layer_modules_0_modules_attention_modules_attention_modules_key_parameters_bias_',
  'l_self_modules_encoder_modules_layer_modules_0_modules_attention_modules_attention_modules_key_parameters_bias_'),
 ('L_self_modules_encoder_modules_layer_modules_0_modules_attention_modules_attention_modules_key_parameters_weight_

In [85]:
# Index every node by its name (FX node names are unique within a graph).
name2node_map = dict((fx_node.name, fx_node) for fx_node in graph.nodes)

In [88]:
# linear_8: per its args below, the layer-1 attention value projection.
node = name2node_map["linear_8"]

In [97]:
# `obj?` is IPython-only syntax; help() shows the docstring of the function
# this call_function node invokes and also runs as plain Python.
help(node.target)

Docstring:
linear(input, weight, bias=None) -> Tensor

Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

This operation supports 2-D :attr:`weight` with :ref:`sparse layout<sparse-docs>`


    Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported,
    or may not have autograd support. If you notice missing functionality please
    open a feature request.

This operator supports :ref:`TensorFloat32<tf32_on_ampere>`.

Shape:

    - Input: :math:`(*, in\_features)` where `*` means any number of
      additional dimensions, including none
    - Weight: :math:`(out\_features, in\_features)` or :math:`(in\_features)`
    - Bias: :math:`(out\_features)` or :math:`()`
    - Output: :math:`(*, out\_features)` or :math:`(*)`, based on the shape of the weight
Type:      builtin_function_or_method

In [96]:
# Positional args, matching linear(input, weight, bias): the input activation node, then the weight and bias placeholders.
node.args

(layer_norm_2,
 l_self_modules_encoder_modules_layer_modules_1_modules_attention_modules_attention_modules_value_parameters_weight_,
 l_self_modules_encoder_modules_layer_modules_1_modules_attention_modules_attention_modules_value_parameters_bias_)

In [94]:
# A Node's repr is just its name (see output).
name2node_map["x_4"]

x_4

In [99]:
# The weight placeholder passed to linear_8 (second element of node.args above).
ph = name2node_map["l_self_modules_encoder_modules_layer_modules_1_modules_attention_modules_attention_modules_value_parameters_weight_"]

In [103]:
# Weight shape from the recorded example value; per the linear docstring this is (out_features, in_features).
ph.meta["example_value"].shape

torch.Size([768, 768])

In [106]:
# Output shape of linear_8; the last dim equals the weight's out_features (768).
node.meta["example_value"].shape

torch.Size([1, 197, 768])