![](https://production-media.paperswithcode.com/methods/Screen_Shot_2021-01-26_at_9.43.31_PM_uI4jjMq.png)

In [None]:
!pip install transformers

In [2]:
import os
from colorsys import hls_to_rgb

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification

In [3]:
model_name = 'google/vit-base-patch16-224'

# Use the GPU when one is available, otherwise stay on the CPU so the notebook
# still runs (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

feature_extractor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(model_name).to(device)

# Keep references to every original encoder layer: later cells shrink
# model.vit.encoder.layer in place and rebuild it from this list.
layers = list(model.vit.encoder.layer.children())

In [4]:
def set_num_layers(n):
    """Make the global model's encoder consist of the first ``n`` saved layers.

    The encoder's layer container is emptied in place and then refilled from
    the module-level ``layers`` list, so the same container object stays
    attached to ``model``.
    """
    encoder_stack = model.vit.encoder.layer
    # Drain the container from the front until nothing is left.
    while len(encoder_stack) > 0:
        encoder_stack.pop(0)
    encoder_stack.extend(layers[:n])
    

def predict_image(image):
    """Classify ``image`` with the global model in its current configuration.

    Args:
        image: Image accepted by ``feature_extractor`` (the notebook passes
            PIL images).

    Returns:
        The label from ``model.config.id2label`` for the highest logit.
    """
    inputs = feature_extractor(images=image, return_tensors='pt')
    # Follow the model's own device instead of assuming CUDA, so this works
    # wherever the model was placed.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


def predict_image_by_all_layers(image):
    """Predict ``image`` once per encoder depth, from 0 layers up to all of them.

    Args:
        image: Image forwarded unchanged to ``predict_image``.

    Returns:
        A list of predicted labels with ``len(layers) + 1`` entries; entry
        ``d`` is the prediction made with only the first ``d`` encoder layers,
        so the last entry is the prediction of the complete model.
    """
    res = []
    try:
        # The upper bound is inclusive: range(len(layers)) would stop one
        # short and never evaluate the full-depth model.
        for depth in range(len(layers) + 1):
            set_num_layers(depth)
            res.append(predict_image(image))
    finally:
        # set_num_layers mutates the shared global model; always put every
        # layer back so a failure mid-loop does not leave it truncated.
        set_num_layers(len(layers))
    return res

In [5]:
base_path = 'images_hammerhead_shark'
flows = []

# sorted() fixes the processing order; os.listdir returns entries in
# arbitrary order, which made the resulting table order vary between runs.
for path in sorted(os.listdir(base_path)):
    # The context manager closes the file handle. convert('RGB') produces a
    # separate image, so it remains usable after the file is closed, and it
    # normalises grayscale/RGBA inputs to three channels.
    with Image.open(os.path.join(base_path, path)) as image_file:
        image = image_file.convert('RGB')
    prediction_by_layers = predict_image_by_all_layers(image)
    # One record per pair of consecutive depths: 'L<idx>:<label>' is the
    # prediction at position idx of the list, 'L<idx + 1>:<label>' the next.
    flows.extend(
        {
            'path': path,
            'source': f'L{idx}:{s}',
            'target': f'L{idx + 1}:{t}',
        }
        for idx, (s, t) in enumerate(zip(prediction_by_layers,
                                         prediction_by_layers[1:]))
    )

In [6]:
# Count how many recorded flows share each (source, target) pair and expose
# the tally as a regular 'count' column.
transitions = pd.DataFrame(flows)[['source', 'target']]
transition_counts = transitions.value_counts().rename('count')
links = transition_counts.reset_index()

In [7]:
# Every distinct node id, sources first, in order of first appearance.
nodes = pd.concat([links['source'],
                   links['target']]).drop_duplicates().tolist()
# Node ids have the form 'L<idx>:<label>'. Split on the first ':' only, so a
# label that itself contains a colon is kept whole instead of being truncated.
node_classes = [n.split(':', 1)[1] for n in nodes]

In [8]:
# Build the node -> position lookup once; calling nodes.index per row rescans
# the whole list each time. `nodes` holds no duplicates, so the positions are
# the same ones list.index would return.
node_positions = {node: position for position, node in enumerate(nodes)}
links['source_node_index'] = links['source'].map(node_positions)
links['target_node_index'] = links['target'].map(node_positions)

In [9]:
def get_distinct_colors(n):
    """Return ``n`` colours whose hues are evenly spaced around the colour wheel.

    Lightness is drawn from [0.50, 0.60) and saturation from [0.90, 1.00)
    using NumPy's global random state, so seed ``np.random`` beforehand for
    repeatable output.

    Args:
        n: Number of colours wanted (an integer). Zero or a negative value
            yields an empty list.

    Returns:
        A list of ``(r, g, b)`` tuples of plain Python floats in [0, 1].
    """
    colors = []
    # Integer steps give exactly n hues. A float-step np.arange can return one
    # element too many because of rounding, and computing 360. / n raised
    # ZeroDivisionError for n == 0.
    for k in range(n):
        h = k / n
        l = (50 + np.random.rand() * 10) / 100.
        s = (90 + np.random.rand() * 10) / 100.
        # float() drops the NumPy scalar type, whose repr under NumPy 2
        # ('np.float64(...)') would leak into strings built from the tuple.
        colors.append(tuple(float(c) for c in hls_to_rgb(h, l, s)))

    return colors


# sorted() gives a stable class order. Iteration order of a set of strings can
# change between interpreter runs (string hashing is randomised unless
# PYTHONHASHSEED is fixed), which silently reassigned colours to classes.
node_classes_unique = sorted(set(node_classes))
distinct_colors = get_distinct_colors(len(node_classes_unique))
# get_distinct_colors yields channels in [0, 1], while 'rgb(r, g, b)' colour
# strings use 0-255 channel values, so scale and round each channel.
class_colors = ['rgb({}, {}, {})'.format(*(int(round(255 * c)) for c in color))
                for color in distinct_colors]
color_by_class = dict(zip(node_classes_unique, class_colors))
node_colors = [color_by_class[node_class] for node_class in node_classes]

In [18]:
# Re-derive the colour strings and shuffle them with a fixed seed, so the
# permutation of colours over classes is the same on every run.
np.random.seed(1122)
# Channels in distinct_colors are in [0, 1], while 'rgb(r, g, b)' colour
# strings use 0-255 channel values, so scale and round each channel.
class_colors = ['rgb({}, {}, {})'.format(*(int(round(255 * c)) for c in color))
                for color in distinct_colors]
np.random.shuffle(class_colors)
color_by_class = dict(zip(node_classes_unique, class_colors))
node_colors = [color_by_class[node_class] for node_class in node_classes]

# Node styling: one node per entry of node_classes, coloured by class.
sankey_nodes = {
    'pad': 15,
    'thickness': 15,
    'line': {'color': 'black', 'width': 0.5},
    'label': node_classes,
    'color': node_colors,
}

# Link geometry: endpoints are node positions, width and label are the counts.
sankey_links = {
    'source': links['source_node_index'],
    'target': links['target_node_index'],
    'value': links['count'],
    'label': links['count'],
}

sankey_trace = go.Sankey(node=sankey_nodes, link=sankey_links)
fig = go.Figure(data=[sankey_trace])
fig.show()