Before running inference:
1. Run [1.0_Model_download.ipynb](1.0_Model_download.ipynb) to download models locally.

#### Necessary imports

In [1]:
import os
import cv2
import rootutils
from PIL import Image
from IPython.display import display

import torch
import torchvision.transforms as transforms

# Ask rootutils to locate the project root (a directory containing `.git` or
# `pyproject.toml`), searching from the current working directory, and add it
# to the Python path (`pythonpath=True`) so the `src.*` imports below resolve.
# NOTE(review): the relative `../docs` and `../trained_models` paths used later
# assume the working directory is still the notebook's folder after this call —
# confirm against the `cwd` default of the installed rootutils version.
project_root_markers = ['.git', 'pyproject.toml']
rootutils.setup_root(
    os.path.abspath(''),
    indicator=project_root_markers,
    pythonpath=True,
)

from helpers.processing import display_img_with_map
from src.models.components.cnn_cam_multihead import CNNCAMMultihead
from src.models.components.vit_rollout_multihead import VitRolloutMultihead
from src.models.components.nn_utils import weight_load

#### Load a sample image and its label

In [None]:
# Paths are relative to the notebook's working directory.
IMAGE_PATH = '../docs/sample_data/01_short_04_1926_1070.png'
LABEL_PATH = '../docs/sample_data/01_short_04_1926_1070_label.png'

# cv2.imread does not raise on a missing/unreadable file; it returns None, which
# would otherwise only fail later with an unhelpful error inside cv2.cvtColor.
image = cv2.imread(IMAGE_PATH, cv2.IMREAD_COLOR)
if image is None:
    raise FileNotFoundError(f'Could not read image: {IMAGE_PATH}')
# OpenCV loads colour images as BGR; convert to RGB for PIL display and the model.
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
display(Image.fromarray(image))

# Label mask, loaded as a single-channel grayscale image.
label = cv2.imread(LABEL_PATH, cv2.IMREAD_GRAYSCALE)
if label is None:
    raise FileNotFoundError(f'Could not read label: {LABEL_PATH}')
display(Image.fromarray(label))

#### Device setup

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device_name = 'cuda'
else:
    device_name = 'cpu'
device = torch.device(device_name)
print(device)

#### Transforms

In [4]:
# ToTensor converts the HxWxC uint8 RGB array into a CxHxW float tensor in [0, 1].
transform = transforms.Compose([transforms.ToTensor()])
single_image_tensor = transform(image)
# Add a leading batch dimension of size 1, then move the batch to `device`.
image_tensor = single_image_tensor[None].to(device)

#### CNN with CAM (class activation map) output

In [None]:
# Build the EfficientNetV2-S based multi-head CAM model on `device`, load its
# trained weights from the local checkpoint directory, and switch to eval mode.
# NOTE(review): `return_node` presumably selects the feature layer used to build
# the CAM — confirm in CNNCAMMultihead.
cnn_ckpt_dir = '../trained_models/models--DeepVisionXplain--efficientnet_v2_s_downscaled_pcb/'
cnn_kwargs = dict(
    backbone='torchvision.models/efficientnet_v2_s',
    return_node='features.6.0.block.0',
    multi_head=True,
)
model = CNNCAMMultihead(**cnn_kwargs).to(device)
weights = weight_load(ckpt_path=cnn_ckpt_dir, weights_only=True)
model.load_state_dict(weights)
model.eval()

In [None]:
%%timeit
with torch.no_grad():
    model(image_tensor)

In [None]:
# Single forward pass without autograd. The model returns an (output, map) pair,
# shown together with the input image. The map is named `cam_map` rather than
# `map` so the builtin `map` is not shadowed for the rest of the kernel session.
with torch.no_grad():
    out, cam_map = model(image_tensor)
display_img_with_map(out.cpu().numpy(), cam_map.cpu().numpy(), image)

#### ViT with attention rollout output

In [None]:
# Build the ViT-Tiny based multi-head attention-rollout model on `device`, load
# its trained weights from the local checkpoint directory, and switch to eval mode.
vit_ckpt_dir = '../trained_models/models--DeepVisionXplain--vit_tiny_patch16_224.augreg_in21k_ft_in1k_pcb/'
vit_backbone = 'timm/vit_tiny_patch16_224.augreg_in21k_ft_in1k'
model = VitRolloutMultihead(backbone=vit_backbone, multi_head=True)
model = model.to(device)
weights = weight_load(ckpt_path=vit_ckpt_dir, weights_only=True)
model.load_state_dict(weights)
model.eval()

In [None]:
%%timeit
with torch.no_grad():
    model(image_tensor)

In [None]:
# Single forward pass without autograd. The model returns an (output, map) pair,
# shown together with the input image. The map is named `rollout_map` rather
# than `map` so the builtin `map` is not shadowed for the rest of the session.
with torch.no_grad():
    out, rollout_map = model(image_tensor)
display_img_with_map(out.cpu().numpy(), rollout_map.cpu().numpy(), image)