In [None]:
import base64
import io
import json
import logging
import time
import uuid

import ipywidgets as widgets
import numpy as np
import timm
import torch
from datasets import load_dataset
from IPython.display import display, HTML, Javascript
from PIL import Image
from torchvision import transforms

from src.utils.analysis.clmim_hook import ActivationCache

# Emit INFO-level records so the START/END timing messages logged by the
# processing functions in this notebook are visible.
logging.basicConfig(level=logging.INFO)

In [1]:
def load_partmae():
    """Create a `vit_base_patch16_224` timm model from a local backbone checkpoint.

    The `pretrained_cfg_overlay` entry points timm at the checkpoint file under
    `../../artifacts/`, and `pretrained_strict=False` is passed through
    (presumably so that keys missing from the backbone-only checkpoint are
    tolerated -- confirm against timm).

    Returns:
        The model moved to CUDA and switched to eval mode.
    """
    checkpoint_overlay = {
        "file": "../../artifacts/model-2knf0d16:v0/backbone.ckpt"
    }
    network = timm.create_model(
        "vit_base_patch16_224",
        pretrained=True,
        pretrained_cfg_overlay=checkpoint_overlay,
        pretrained_strict=False,
    )
    network = network.cuda()
    return network.eval()

In [2]:
def load_vit():
    """Create timm's pretrained `vit_base_patch16_224`.

    Returns:
        The model moved to CUDA and switched to eval mode.
    """
    network = timm.create_model("vit_base_patch16_224", pretrained=True)
    network = network.cuda()
    return network.eval()

In [3]:
def load_mae():
    """Create timm's pretrained `vit_base_patch16_224.mae`.

    Returns:
        The model moved to CUDA and switched to eval mode.
    """
    network = timm.create_model("vit_base_patch16_224.mae", pretrained=True)
    network = network.cuda()
    return network.eval()

def load_dino():
    """Create timm's pretrained `vit_base_patch16_224.dino`.

    Returns:
        The model moved to CUDA and switched to eval mode.
    """
    network = timm.create_model("vit_base_patch16_224.dino", pretrained=True)
    network = network.cuda()
    return network.eval()


In [4]:
def load_dora():
    """Load the DoRA entry point from the `dgcnz/DoRA_ICLR24` torch.hub repository.

    Note that the requested entry point is a ViT-small variant
    (`vit_small_patch16_224_dora_wt_venice_ep100`), unlike the ViT-base
    loaders defined alongside it.

    Returns:
        The model moved to CUDA and switched to eval mode.
    """
    hub_model = torch.hub.load(
        "dgcnz/DoRA_ICLR24",
        model="vit_small_patch16_224_dora_wt_venice_ep100",
    )
    hub_model = hub_model.cuda()
    return hub_model.eval()

In [5]:
# Demo data: the 160px Imagenette validation split. The first four rows are kept
# as `batch`; later cells read the images from it as batch["image"][i].
dataset = load_dataset(
    "frgfm/imagenette",
    name="160px",
    split="validation",
)
batch = dataset[:4]

NameError: name 'load_dataset' is not defined

In [None]:
def generate_interactive_html(all_attn_maps, image_base64, num_layers, num_heads):
    """Render an interactive attention viewer in the notebook output.

    Draws the input image on a 224x224 canvas and, below it, one small heatmap
    canvas per (layer, head). Moving the mouse over the image selects the patch
    under the cursor as the query token and redraws every heatmap from that
    token's attention row.

    Args:
        all_attn_maps: JSON-serialisable nested lists indexed as
            [layer][head][query_token][key_token]. The embedded script assumes
            one leading non-patch token followed by 14 * 14 patch tokens (it
            offsets the query index by one and drops the first key column).
        image_base64: Base64 text of the JPEG drawn on the top canvas.
        num_layers: Number of layers to draw (one grid row per layer).
        num_heads: Number of heads per layer (number of grid columns).

    Returns:
        None. The markup is emitted through display(HTML(...)).
    """
    all_attns_json = json.dumps(all_attn_maps)
    # Each call gets its own DOM-id suffix. With fixed ids, a second render in
    # the same page makes getElementById resolve to the FIRST render's canvases,
    # so the new script would draw into (and listen on) the stale viewer.
    uid = uuid.uuid4().hex
    html_code = f'''
    <h3>Hover over the top image to update attention maps</h3>
    <canvas id="input-image-{uid}" width="224" height="224" style="border:1px solid #000;"></canvas>
    <div id="attention-grid-{uid}" style="display: grid; grid-template-columns: repeat({num_heads}, auto); grid-gap: 5px; margin-top: 10px;"></div>
    
    <script>
    // Wrapped in a function so each rendered viewer keeps its own state.
    (function() {{
    var uid = "{uid}";
    var allAttentionMaps = {all_attns_json};
    var numLayers = {num_layers};
    var numHeads = {num_heads};
    var canvasSize = 224;
    var gridSize = 14;
    var patchSize = canvasSize / gridSize;
    var inputCanvas = document.getElementById("input-image-" + uid);
    
    // Define viridis colormap using key points.
    function getViridisColor(t) {{
        t = Math.min(Math.max(t, 0), 1);
        var viridis = [
            {{t: 0.0, color: [68, 1, 84]}},
            {{t: 0.125, color: [71, 44, 122]}},
            {{t: 0.25, color: [59, 81, 139]}},
            {{t: 0.375, color: [44, 113, 142]}},
            {{t: 0.5, color: [33, 144, 141]}},
            {{t: 0.625, color: [39, 173, 129]}},
            {{t: 0.75, color: [92, 200, 99]}},
            {{t: 0.875, color: [170, 220, 50]}},
            {{t: 1.0, color: [253, 231, 37]}}
        ];
        for (var i = 0; i < viridis.length - 1; i++) {{
            if (t >= viridis[i].t && t <= viridis[i+1].t) {{
                var ratio = (t - viridis[i].t) / (viridis[i+1].t - viridis[i].t);
                var r = Math.floor(viridis[i].color[0] + ratio * (viridis[i+1].color[0] - viridis[i].color[0]));
                var g = Math.floor(viridis[i].color[1] + ratio * (viridis[i+1].color[1] - viridis[i].color[1]));
                var b = Math.floor(viridis[i].color[2] + ratio * (viridis[i+1].color[2] - viridis[i].color[2]));
                return [r, g, b];
            }}
        }}
        return viridis[viridis.length - 1].color;
    }}
    
    function attnCanvasId(l, h) {{
        return "attmap-" + uid + "-" + l + "-" + h;
    }}
    
    function drawTopImage() {{
        var ctx = inputCanvas.getContext("2d");
        var img = new Image();
        img.onload = function() {{
            ctx.drawImage(img, 0, 0, canvasSize, canvasSize);
        }};
        img.src = "data:image/jpeg;base64,{image_base64}";
    }}
    
    function createGrid() {{
        var gridDiv = document.getElementById("attention-grid-" + uid);
        gridDiv.innerHTML = "";
        for(var l=0; l<numLayers; l++) {{
            for(var h=0; h<numHeads; h++) {{
                var canvas = document.createElement("canvas");
                canvas.id = attnCanvasId(l, h);
                canvas.width = gridSize;
                canvas.height = gridSize;
                canvas.style.width = (canvasSize/4) + "px";
                canvas.style.height = (canvasSize/4) + "px";
                canvas.style.border = "1px solid #000";
                gridDiv.appendChild(canvas);
            }}
        }}
    }}
    
    function reshapeToMatrix(arr, size) {{
        var matrix = [];
        for(var i=0; i<size; i++) {{
            matrix.push(arr.slice(i*size, (i+1)*size));
        }}
        return matrix;
    }}
    
    function drawAttentionHeatmap(canvasId, dataMatrix) {{
        var canvas = document.getElementById(canvasId);
        var ctx = canvas.getContext("2d");
        var size = dataMatrix.length;
        var flat = dataMatrix.flat();
        var minVal = Math.min(...flat);
        var maxVal = Math.max(...flat);
        var imgData = ctx.createImageData(size, size);
        
        for(var i=0; i<size; i++) {{
            for(var j=0; j<size; j++) {{
                var value = dataMatrix[i][j];
                var normVal = (maxVal - minVal) ? (value - minVal) / (maxVal - minVal) : 0;
                var rgb = getViridisColor(normVal);
                var index = (i * size + j) * 4;
                imgData.data[index] = rgb[0];
                imgData.data[index+1] = rgb[1];
                imgData.data[index+2] = rgb[2];
                imgData.data[index+3] = 255;
            }}
        }}
        ctx.putImageData(imgData, 0, 0);
    }}
    
    function updateAttentionMaps(queryPatch) {{
        // Token 0 is not a patch, so patch p lives at row p + 1.
        var tokenIndex = queryPatch + 1;
        for(var l=0; l<numLayers; l++) {{
            for(var h=0; h<numHeads; h++) {{
                var attnRow = allAttentionMaps[l][h][tokenIndex];
                // Skip rather than throw if the data has no row for this token.
                if (!attnRow) {{
                    continue;
                }}
                var heatmap = reshapeToMatrix(attnRow.slice(1), gridSize);
                drawAttentionHeatmap(attnCanvasId(l, h), heatmap);
            }}
        }}
    }}
    
    inputCanvas.addEventListener("mousemove", function(event) {{
        var rect = inputCanvas.getBoundingClientRect();
        // The rect covers the border box; clientLeft/clientTop are the border
        // widths, so subtracting them puts the origin on the first drawn pixel.
        var px = event.clientX - rect.left - inputCanvas.clientLeft;
        var py = event.clientY - rect.top - inputCanvas.clientTop;
        // Clamp: on the border or far edge the raw index can be -1 or gridSize,
        // which would wrap to a neighbouring row or run past the last token.
        var x = Math.min(Math.max(Math.floor(px / patchSize), 0), gridSize - 1);
        var y = Math.min(Math.max(Math.floor(py / patchSize), 0), gridSize - 1);
        updateAttentionMaps(y * gridSize + x);
    }});
    
    drawTopImage();
    createGrid();
    updateAttentionMaps(0);
    }})();
    </script>
    '''
    display(HTML(html_code))


In [None]:
# Collect attention maps through an ActivationCache hook; each tensor returned by
# the cache is expected to have shape (1, heads, tokens, tokens).
def get_attention_maps(model, img_tensor):
    """Run one forward pass and return the recorded attention as nested lists.

    Args:
        model: Model to hook and run.
        img_tensor: Input passed straight to `model(...)`; the squeeze(0) below
            suggests a batch of one is expected.

    Returns:
        A list with one entry per tensor yielded by `ActivationCache.get_attns()`,
        each converted to plain nested Python lists so it can be JSON-encoded.
    """
    recorder = ActivationCache()
    recorder.hook(model)
    with torch.no_grad():
        model(img_tensor)
    # squeeze(0) drops the leading batch axis; cpu().tolist() leaves plain Python
    # floats that json.dumps can serialise.
    return [
        per_layer.squeeze(0).cpu().tolist()
        for per_layer in recorder.get_attns()
    ]

# Preprocess a PIL image: returns the model input tensor, JPEG bytes for display,
# and the PIL image that was passed in.
def preprocess_image(pil_image):
    """Turn a PIL image into a model input plus a JPEG copy for the browser.

    Args:
        pil_image: Source PIL image in any mode (RGB, L, RGBA, P, ...).

    Returns:
        A tuple `(img_tensor, image_bytes, pil_image)` where `img_tensor` is the
        resized, normalised image with a leading batch axis of size one,
        `image_bytes` is the image encoded as JPEG, and `pil_image` is the
        unmodified argument.
    """
    # Normalize below is configured for exactly three channels and the JPEG
    # writer rejects alpha/palette modes, so grayscale or RGBA inputs would
    # raise. Work on an RGB view; the caller still gets the original back.
    rgb_image = pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    img_tensor = transform(rgb_image).unsqueeze(0)
    with io.BytesIO() as buf:
        rgb_image.save(buf, format="JPEG")
        image_bytes = buf.getvalue()
    return img_tensor, image_bytes, pil_image

# Widgets and module-level state for the interactive viewer.

# Instead of a file uploader, use a dropdown to select one of the batch images.
# `batch` is a slice of the dataset, i.e. a mapping from column name to a list of
# values, so len(batch) counts columns. Count the images themselves instead.
dropdown = widgets.Dropdown(
    options=[(f"Image {i}", i) for i in range(len(batch["image"]))],
    description="Select image:",
)

# Zero-argument loader functions keyed by the name shown in the model dropdown.
models = {
    "vit": load_vit,
    "partmae": load_partmae,
    "dora": load_dora,
    "dino": load_dino,
    "mae": load_mae,
}
# Currently active model and image index; both are rebound (via `global`) by the
# dropdown callbacks defined further down.
model = models["vit"]()
image_id = 0

model_dropdown = widgets.Dropdown(
    options=[(name, name) for name in models.keys()],
    description="Select model:",
)

def process_image(index):
    """Preprocess one batch image, extract its attention maps and render the viewer.

    Reads the module-level `batch` and `model`. Each stage (preprocessing,
    attention extraction, HTML rendering) is bracketed by START/END log records
    with the wall-clock time taken.

    Args:
        index: Position of the image inside batch["image"].
    """
    # Assuming the key for images in the dataset is "image"
    selected_image = batch["image"][index]

    logging.info("[Image Preprocessing] START")
    started_at = time.time()
    img_tensor, image_bytes, _ = preprocess_image(selected_image)
    preprocessed_at = time.time()
    logging.info(f"[Image Preprocessing] END (time: {preprocessed_at - started_at:.2f}s)")

    logging.info("[Attention Maps] START")
    attn_maps = get_attention_maps(model, img_tensor.cuda())
    attended_at = time.time()
    logging.info(f"[Attention Maps] END (time: {attended_at - preprocessed_at:.2f}s)")

    logging.info("[Interactive HTML] START")
    encoded_image = base64.b64encode(image_bytes).decode("utf-8")
    layer_count = len(attn_maps)
    # No layers means there is no first layer to count heads on.
    head_count = 0 if layer_count == 0 else len(attn_maps[0])
    generate_interactive_html(attn_maps, encoded_image, layer_count, head_count)
    rendered_at = time.time()
    logging.info(f"[Interactive HTML] END (time: {rendered_at - attended_at:.2f}s)")

def dropdown_changed(change):
    """Observer for the image dropdown: remember the new index and re-render.

    Args:
        change: Change dictionary delivered by the widget; only events with
            type "change" on the "value" trait are acted on.
    """
    global image_id
    is_value_change = change["type"] == "change" and change["name"] == "value"
    if not is_value_change:
        return
    image_id = change["new"]
    logging.info(f"Processing image {image_id}")
    process_image(image_id)
    logging.info(f"Image {image_id} processed")

def model_dropdown_changed(change):
    """Observer for the model dropdown: load the chosen model and re-render.

    Rebinds the module-level `model` to the result of the selected loader, then
    reprocesses the currently selected image (`image_id`). Both steps are timed
    and logged.

    Args:
        change: Change dictionary delivered by the widget; only events with
            type "change" on the "value" trait are acted on.
    """
    global model
    if change["type"] != "change" or change["name"] != "value":
        return

    logging.info("[Model Loading] START")
    load_started = time.time()
    selected_name = change["new"]
    logging.info(f"Loading model {selected_name}")
    model = models[selected_name]()
    load_finished = time.time()
    logging.info(f"[Model Loading] END (time: {load_finished - load_started:.2f}s)")

    logging.info("[Image Processing] START")
    render_started = time.time()
    process_image(image_id)
    render_finished = time.time()
    logging.info(f"[Image Processing] END (time: {render_finished - render_started:.2f}s)")

# Register the observers for the "value" trait of each dropdown, then show both
# widgets (image selector first, model selector second).
dropdown.observe(dropdown_changed, names="value")
model_dropdown.observe(model_dropdown_changed, names="value")
display(dropdown, model_dropdown)

# Process the initially selected image.
process_image(dropdown.value)
