In [1]:
import torch
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from timm.data.transforms import _pil_interp
from classification.focalnet import FocalNet, build_transforms, build_transforms4display

In [6]:
# Build the FocalNet model that the visualization cells below inspect.

# Input resolution; also handed to the transform builders in the next cell.
img_size = 224

# multi-scale FocalNets
model = FocalNet(
    depths=[2, 2, 18, 2],
    embed_dim=128,
    focal_levels=[3, 3, 3, 3],
)
model = model.cuda()

# isotropic FocalNets (alternative configuration, kept for reference)
# model = FocalNet(depths=[12], patch_size=16, embed_dim=768, focal_levels=[3], focal_windows=[3], use_layerscale=True, use_postln=True,).cuda()

In [7]:
# Build the two image pipelines used by the visualization cells:
#   eval_transforms    -> tensor fed to the model
#   display_transforms -> tensor shown with imshow
# Both are built for `img_size` with center_crop=False.
eval_transforms, display_transforms = (
    build_transforms(img_size, center_crop=False),
    build_transforms4display(img_size, center_crop=False),
)

In [None]:
# Load the pretrained weights into `model` and switch it to inference mode.
ckpt_path = "focalnet_base_lrf.pth"

# Deserialize onto the CPU so loading does not depend on the device the
# checkpoint was saved from; load_state_dict copies the values into the
# model's existing parameters on whatever device they already live.
ckpt = torch.load(ckpt_path, map_location="cpu")

# load_state_dict returns a report of missing/unexpected keys, not the module,
# so eval() has to be called on the model itself. The original chained a
# misspelled `.eva()` onto that report, which raised AttributeError and left
# the model in training mode.
model.load_state_dict(ckpt['model'])
model.eval();  # trailing ';' keeps the notebook from printing the module repr

In [None]:
import os
import numpy as np
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as image

In [None]:
# visualize modulator
#
# For each image in `img_folder`: run one forward pass, then plot the display
# version of the image next to the channel-averaged |modulator| read from the
# last block of stage `layer`.
upsampler = nn.Upsample(scale_factor=4, mode='bilinear')

img_folder = "./figures"

# `layer` was referenced below but never assigned anywhere in this notebook
# (NameError on a fresh kernel). -1 selects the last stage, the same stage the
# gating-map cell reads from. NOTE(review): confirm this is the intended stage.
layer = -1

# Send inputs to wherever the model's weights live instead of assuming CUDA.
model_device = next(model.parameters()).device

# sorted() keeps the figure order stable across runs and filesystems.
for img_name in sorted(os.listdir(img_folder)):
    # os.path.join supplies the path separator; "./figures" + name produced
    # "./figures<name>", which does not point inside the folder.
    img = Image.open(os.path.join(img_folder, img_name))
    img_t = eval_transforms(img)
    img_d = display_transforms(img)

    # The output is not used for plotting. The pass is run so that the
    # `modulation.modulator` attribute read below corresponds to this image
    # (presumably the module stores it during forward; verify in focalnet.py).
    with torch.no_grad():
        out = model(img_t.unsqueeze(0).to(model_device))

    fig, (ax_img, ax_mod) = plt.subplots(1, 2, figsize=(16, 8))

    # Left panel: the input image, channels moved last for imshow.
    img2d = img_d.permute(1, 2, 0).cpu().detach().contiguous().numpy()
    ax_img.imshow(img2d)
    ax_img.axis('off')

    # Right panel: mean absolute modulator over the channel dimension,
    # upsampled 4x for display. [0, 0] drops the batch and the kept channel
    # dimension, leaving a 2-D map.
    modulator = model.layers[layer].blocks[-1].modulation.modulator
    modulator = upsampler(torch.abs(modulator).mean(1, keepdim=True))
    ax_mod.imshow(modulator[0, 0].cpu().detach().numpy())
    ax_mod.axis('off')

    fig.subplots_adjust(wspace=0, hspace=0)


In [None]:
# visualize gating maps
#
# For each image in `img_folder`: run one forward pass, then plot the display
# version of the image followed by one panel per gate channel read from the
# last block of the last stage.
upsampler = nn.Upsample(scale_factor=4, mode='bilinear')

img_folder = "./figures"

# Send inputs to wherever the model's weights live instead of assuming CUDA.
model_device = next(model.parameters()).device

# sorted() keeps the figure order stable across runs and filesystems.
for img_name in sorted(os.listdir(img_folder)):
    # os.path.join supplies the path separator; "./figures" + name produced
    # "./figures<name>", which does not point inside the folder.
    img = Image.open(os.path.join(img_folder, img_name))
    img_t = eval_transforms(img)
    img_d = display_transforms(img)

    # The output is not used for plotting. The pass is run so that the
    # `modulation.gates` attribute read below corresponds to this image
    # (presumably the module stores it during forward; verify in focalnet.py).
    with torch.no_grad():
        out = model(img_t.unsqueeze(0).to(model_device))

    gates = model.layers[-1].blocks[-1].modulation.gates
    # One panel per gate channel (dim 1). The original hard-coded 4 panels.
    num_gates = gates.shape[1]

    # squeeze=False keeps `axes` 2-D regardless of the column count.
    fig, axes = plt.subplots(1, num_gates + 1, figsize=(16, 8), squeeze=False)
    axes = axes[0]

    # First panel: the input image, channels moved last for imshow.
    img2d = img_d.permute(1, 2, 0).cpu().detach().contiguous().numpy()
    axes[0].imshow(img2d)
    axes[0].axis('off')

    for i in range(num_gates):
        # The bilinear upsampler takes and returns a 4-D tensor, so the
        # original 3-index permute(1, 2, 0) on its output raised RuntimeError.
        # [0, 0] selects the single image and single channel as a 2-D map.
        gates_i = upsampler(gates[:, i:i+1])[0, 0].cpu().detach()
        # Hide the axes of this panel; the original reused a handle that still
        # pointed at the first panel.
        axes[i + 1].imshow(gates_i.numpy())
        axes[i + 1].axis('off')

In [None]:
# display learned focal kernel weights
#
# One figure per focal level; within a figure, one panel per stage showing the
# weight of the first submodule of `focal_layers[level]` in the stage's last
# block, averaged over dim 0.
# NOTE(review): that submodule is presumably the focal convolution; confirm
# against focalnet.py.

stages = model.layers
# Counts are taken from the model rather than hard-coded (the original assumed
# 4 stages and 3 focal levels). The level count is read from the first stage;
# assumes every stage has at least that many focal layers.
num_stages = len(stages)
num_levels = len(stages[0].blocks[-1].modulation.focal_layers)

for id_level in range(num_levels):
    # squeeze=False keeps `axes` 2-D even for a single-stage model.
    fig, axes = plt.subplots(1, num_stages, figsize=(8, 8), squeeze=False)
    for id_layer in range(num_stages):
        weight = stages[id_layer].blocks[-1].modulation.focal_layers[id_level][0].weight.data
        # Average over dim 0, then move the remaining leading dim last for imshow.
        kernel = weight.mean(0).cpu().permute(1, 2, 0).contiguous().numpy()
        # Each panel hides its own axes; the original's second and third loops
        # reused an image handle left over from the first loop.
        ax = axes[0, id_layer]
        ax.imshow(kernel)
        ax.axis('off')