In [None]:
# Run the script initnotebook.py from the current folder
# (presumably it defines projectFolder, which later cells use -- TODO confirm)
%run initnotebook.py

In [None]:
import torch
from  torch.utils.data import DataLoader
import torch.nn as nn
from models.builder import EncoderDecoder as segmodel
from dataloader.cfg_defaults import get_cfg_defaults
from config_cityscapes import *
import os
from dataloader.cityscapes_dataloader import CityscapesDataset
from val_segformer_rgbonly import val_cityscape
import torch.nn.functional as F
from utils.visualize import unnormalize_img_numpy
import matplotlib.pyplot as plt
from visualizer.visualizer import *

In [None]:
# Resolve the dataset YAML relative to the project folder, then layer it on
# top of the default configuration and freeze the result.
config_path = os.path.join(projectFolder, 'dataloader/cityscapes_rgbd_config.yaml')

cfg = get_cfg_defaults()
cfg.merge_from_file(config_path)
cfg.freeze()

# Presumably per-channel mean / std of the dataset images (TODO confirm);
# neither value is referenced again in the cells visible here.
data_mean, data_std = [0.291, 0.329, 0.291], [0.190, 0.190, 0.185]

In [None]:
# Build the dataset and an unshuffled loader with one image per batch, so the
# batch counter of the loader can be compared with an index into the dataset.
# NOTE(review): the variables are called "test" but the split loaded here is
# 'train'; find_index_of_image below also searches files['train'], so the two
# agree -- confirm that visualising a training image is intended.
cityscapes_test = CityscapesDataset(cfg, split='train')
test_loader = DataLoader(cityscapes_test, batch_size=1, shuffle=False, num_workers=4) # TODO: confirm batch size
print(f'total test sample: {len(cityscapes_test)} v_iteration:{len(test_loader)}')


In [None]:
def find_index_of_image(test_loader, image_name, split='train'):
    """Return the position of the first file whose path contains ``image_name``.

    Args:
        test_loader: Object exposing ``dataset.files``, a mapping from split
            name to a sequence of file paths.
        image_name: Substring to look for in each path (e.g. a file name).
        split: Key of ``dataset.files`` to search. Defaults to ``'train'``,
            the split that used to be hard-coded.

    Returns:
        The 0-based index of the first matching path, or ``-1`` when no path
        of the chosen split contains ``image_name``.
    """
    paths = test_loader.dataset.files[split]
    # next() with a default stops at the first hit and yields -1 on no match.
    return next((idx for idx, path in enumerate(paths) if image_name in path), -1)

In [None]:
# Find where the image of interest sits in the dataset file list so that the
# inference cell can pick exactly that sample out of the loader.
image_name = 'cologne_000008_000019_leftImg8bit.png'
index = find_index_of_image(test_loader, image_name)
print(index)
# find_index_of_image returns -1 when nothing matches. Stop here in that case:
# otherwise the inference cell skips every sample and the visualisation cells
# fail later on variables that were never assigned.
if index < 0:
    raise ValueError(f'{image_name} was not found in the dataset file list')

In [None]:
# Checkpoint path relative to the project folder; the next cell joins it with
# projectFolder and loads it.
pretrained_model_path = './pretrained/model_400.pth'
# Mean-reduced cross-entropy that ignores targets equal to config.background.
criterion = nn.CrossEntropyLoss(reduction='mean', ignore_index=config.background)

# NOTE(review): the model is built from `config` (presumably brought in by the
# star import of config_cityscapes), not from the `cfg` object loaded from the
# YAML file above -- confirm this is intended.
model = segmodel(cfg=config, criterion=criterion, norm_layer=nn.BatchNorm2d, test=True)
# Wrap the model in DataParallel over config.device_ids and move it to the
# first of those CUDA devices.
model = nn.DataParallel(model, device_ids = config.device_ids)
model.to(f'cuda:{model.device_ids[0]}', non_blocking=True)

In [None]:
saved_model_path = os.path.join(projectFolder, pretrained_model_path)
print(saved_model_path)
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
# map_location='cpu' makes loading independent of the GPU layout the
# checkpoint was saved with; load_state_dict copies the values into the
# model's existing parameters on their own device.
state_dict = torch.load(saved_model_path, map_location='cpu')
# strict=False ignores key mismatches without raising (for example a missing
# 'module.' prefix would leave the DataParallel model with its initial
# weights), so report what was skipped instead of hiding it.
load_result = model.load_state_dict(state_dict['model'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f'missing keys: {len(load_result.missing_keys)}, '
          f'unexpected keys: {len(load_result.unexpected_keys)}')
print('model loaded')

In [None]:
# Run the model once on the selected sample and keep `img`, `gt`, `out` and
# `atten` as notebook globals for the visualisation cells below.
target_device = f'cuda:{model.device_ids[0]}'
model.eval()
with torch.no_grad():
    for idx, sample in enumerate(test_loader):
        # The loader is unshuffled with batch_size=1, so the batch counter is
        # compared directly with the index found by find_index_of_image.
        if idx != index:
            continue
        imgs = sample['image'].to(target_device, non_blocking=True)   # B, 3, 1024, 2048
        gts = sample['label'].to(target_device, non_blocking=True)    # B, 1024, 2048

        # Crop the right half (columns 1024 onwards) of the image. The label
        # must cover the same columns: the previous `gts[:, :, :1024]` took
        # the LEFT half, so the loss compared different regions.
        # NOTE(review): relies on the shapes given in the comments above.
        img = imgs[:, :, :, 1024:]
        gt = gts[:, :, 1024:]
        loss, out, atten = model(img, gt, visualize=True, attention=True)
        print('loss = ', loss.shape)
        print('out = ', out.shape)
        print('atten = ', len(atten))

        print(img.shape, gt.shape)
        break
    else:
        # The loop ended without reaching `index` (e.g. index == -1): fail here
        # rather than leave img / atten undefined or stale for later cells.
        raise IndexError(f'sample index {index} was not produced by test_loader')

In [None]:
def get_attention_matrix(attention, layer, head):
    """Return the attention weights of one head as a NumPy array.

    Args:
        attention: Nested sequence indexed as ``attention[layer][head]`` whose
            entries are tensors.
        layer: 1-based layer number.
        head: 1-based head number within that layer.

    Returns:
        The selected tensor converted to a NumPy array. When the tensor is
        4-D, only its first entry along the leading axis is returned.
        ``None`` is returned, after printing a message, when ``layer`` or
        ``head`` is out of range.
    """
    # Convert the 1-based arguments to 0-based indices.
    layer_idx = layer - 1
    head_idx = head - 1
    # Valid indices are 0 .. len - 1. Negative values are rejected as well:
    # layer=0 would otherwise wrap around and select the last layer.
    if not 0 <= layer_idx < len(attention):
        print('layer index out of range')
        return None
    if not 0 <= head_idx < len(attention[layer_idx]):
        print('head index out of range')
        return None
    # detach() keeps the conversion working for tensors that track gradients.
    atten = attention[layer_idx][head_idx].detach().cpu().numpy()
    if atten.ndim == 4:
        atten = atten[0]
    return atten



In [None]:
def plot_attention(img, pixel, attention, layer, head, target_size, alpha):
    """Overlay the attention of one query pixel on the input image.

    The attention matrix of (``layer``, ``head``) is indexed as
    ``[patch, query position inside the patch, key position inside the patch]``.
    The row belonging to ``pixel`` is reshaped to a ``patch_size x patch_size``
    map, upscaled, and drawn on top of the resized image at the location of
    the patch that contains the pixel.

    Args:
        img: Image passed to ``unnormalize_img_numpy``.
        pixel: ``(x, y)`` position of the query pixel in input-image
            coordinates.
        attention: Nested attention container accepted by
            ``get_attention_matrix``.
        layer: 1-based layer number.
        head: 1-based head number.
        target_size: Size the image is resized to for display; also used as
            the shape of the overlay array.
        alpha: Opacity of the attention overlay.

    Returns:
        None. Shows a matplotlib figure; returns early without plotting when
        ``get_attention_matrix`` rejects ``layer`` / ``head``.

    NOTE(review): the overlay is placed in units of ``patch_size * layer_factor``
    while a patch spans ``patch_size * factor * layer_factor`` input pixels, so
    the overlay lines up with the image only when ``target_size`` is the input
    size divided by ``factor`` (e.g. 256 for a 1024-pixel crop). ``target_size``
    is also used both as a cv2 (width, height) size and as a NumPy
    (rows, cols) shape, which agree only for square sizes -- confirm before
    using other sizes.
    """
    # Assumes the features of `layer` are the input downsampled by
    # 4 * 2 ** (layer - 1) -- TODO confirm against the model definition.
    factor = 4
    layer_factor = 2 ** (layer - 1)
    downsample_factor = factor * layer_factor

    attn_matrix = get_attention_matrix(attention, layer, head)
    if attn_matrix is None:
        # get_attention_matrix has already reported the invalid layer / head.
        return
    print(attn_matrix.shape)

    # Axis 1 holds the positions inside one patch, so its square root is the
    # side length of a patch (in feature-map cells).
    patch_size = int(np.sqrt(attn_matrix.shape[1]))
    # Assumes the patches form a square grid when recovering the patch row
    # and column from the flat patch index further down.
    patches_per_side = int(np.sqrt(attn_matrix.shape[0]))

    unnormalized_image = unnormalize_img_numpy(img)
    rescaled_image_layer = cv2.resize(unnormalized_image, target_size)
    attention_map_layer = np.zeros(target_size)

    # Only the shape of the strided image is needed: it gives the feature-map
    # extent at this layer.
    downsized_image = unnormalized_image[::downsample_factor, ::downsample_factor]
    array_shape = downsized_image.shape

    # Flat index of the patch containing the pixel:
    # patch row (from y) * patches per row + patch column (from x).
    patch_idx = (pixel[1] // downsample_factor // patch_size) * (array_shape[1] // patch_size) + (pixel[0] // downsample_factor // patch_size)

    # Position of the pixel inside its patch as (x, y), then flattened
    # row-major to select the query row of the attention matrix.
    pixel_inside_patch = (((pixel[0] // downsample_factor) % patch_size), ((pixel[1] // downsample_factor) % patch_size))
    pixel_idx_inside_patch = pixel_inside_patch[1] * patch_size + pixel_inside_patch[0]
    attention_patch = attn_matrix[patch_idx]
    attention_pixel = attention_patch[pixel_idx_inside_patch].reshape(patch_size, patch_size)

    # Upscale the per-patch map and paste it at the patch's (row, col) origin
    # in the overlay; everything outside the patch stays zero.
    upscaled_attention_pixel = cv2.resize(attention_pixel, (layer_factor * patch_size, layer_factor * patch_size))
    patch_start = (
        (patch_idx // patches_per_side) * patch_size * layer_factor,
        (patch_idx % patches_per_side) * patch_size * layer_factor,
    )
    attention_map_layer[patch_start[0]:patch_start[0] + upscaled_attention_pixel.shape[0], patch_start[1]:patch_start[1] + upscaled_attention_pixel.shape[1]] = upscaled_attention_pixel
    plt.figure(figsize=(7, 7))
    plt.imshow(rescaled_image_layer)
    plt.imshow(attention_map_layer, alpha=alpha, cmap='viridis')
    plt.show()

In [None]:
# Query pixel as (x, y): the first entry is used as the horizontal coordinate
# and the second as the vertical one by the plotting cells below.
pixel = (250, 420)
# Size handed to plot_attention, which uses it both to resize the image and
# as the shape of the attention overlay.
target_size = (256, 256)

In [None]:
# Show the un-normalised input crop with the query pixel marked, as a visual
# check of the location the attention maps refer to.
# (unnormalize_img_numpy presumably returns an array that imshow can draw -- TODO confirm)
unnormalized_image = unnormalize_img_numpy(img)
image_array = unnormalized_image

plt.figure(figsize=(10, 10))
plt.imshow(image_array)

# Highlight the query pixel (x = pixel[0], y = pixel[1]) in red
plt.scatter(pixel[0], pixel[1], c='red', marker='o')

plt.title('Input Image with Highlighted Pixel')
plt.show()

In [None]:
# Draw the attention of the query pixel for every head of every layer.
# Layer and head numbers are 1-based because get_attention_matrix subtracts
# one from each before indexing.
for layer_no, layer_heads in enumerate(atten, start=1):
    for head_no in range(1, len(layer_heads) + 1):
        print('layer', layer_no, 'head', head_no)
        plot_attention(img, pixel, atten, layer_no, head_no, target_size, 0.5)

    