# Visualization

## Visualizing Image and Steering Angle

In [57]:
def show_sample(sample, pred_angle):
    r"""Plot a (batch) sample annotated with recorded and predicted steering angles.

    Every frame is resized to 3x480x640 and drawn with its timestamp as the
    title and its frame id, recorded angle ("man-angle") and predicted angle
    overlaid in red. For 4-D images the absolute error is overlaid as well.

    Args:
        sample: Dictionary with the keys ``'image'``, ``'angle'``,
            ``'timestamp'`` and ``'frame_id'``. ``sample['image']`` must be
            4-D (one frame per batch entry) or 5-D (a sequence of frames per
            batch entry). The dictionary is left unmodified.
        pred_angle: Predicted angles, indexed ``[i]`` for 4-D images and
            ``[i0][i1]`` for 5-D images, the same way as ``sample['angle']``.

    Raises:
        AssertionError: If ``sample['image']`` is neither 4-D nor 5-D.
    """
    image_shape = sample['image'].shape
    image_dims = len(image_shape)
    # Ranks below 4 used to slip past this check and fail later inside permute().
    assert image_dims in (4, 5), "Unsupported image shape: {}".format(image_shape)

    if image_dims == 4:
        error = abs(pred_angle - sample['angle'])
        n = image_shape[0]
        # Work on a local tensor: writing the resized frames back into
        # sample['image'] would silently change the caller's batch.
        images = torch.Tensor(resize(sample['image'], (n, 3, 480, 640), anti_aliasing=True))
        images = images.permute(0, 2, 3, 1)  # channels last, as imshow expects
        n_cols = 5
        # Keep the original 10x5 grid, but grow it when there are more than 50 frames.
        n_rows = max(10, -(-n // n_cols))
        fig = plt.figure(figsize=(155, 135))
        for i in range(n):
            ax = fig.add_subplot(n_rows, n_cols, i + 1)
            ax.imshow(images[i])
            ax.axis('off')
            ax.set_title("t={}".format(sample['timestamp'][i]))
            ax.text(10, 30, sample['frame_id'][i], color='red')
            ax.text(10, 390, "error {:.3}".format(error[i].item()), color='red')
            ax.text(10, 430, "man-angle {:.3}".format(sample['angle'][i].item()), color='red')
            ax.text(10, 470, "pred-angle {:.3}".format(pred_angle[i].item()), color='red')
    else:
        # Swap axes 1 and 2 so the sequence axis precedes the channel axis
        # (assumes the input is batch x channel x seq_len x H x W — TODO confirm).
        # Kept local: permuting sample['image'] in place made a second call on
        # the same sample swap the axes back and plot the wrong layout.
        images = sample['image'].permute(0, 2, 1, 3, 4)
        batch_size, seq_len = images.shape[0], images.shape[1]
        images = torch.Tensor(
            resize(images, (batch_size, seq_len, 3, 480, 640), anti_aliasing=True)
        )
        # squeeze=False keeps `ax` two-dimensional even when batch_size or
        # seq_len is 1, so ax[i0, i1] is always valid.
        fig, ax = plt.subplots(batch_size, seq_len, figsize=(25, 15), squeeze=False)
        for i0 in range(batch_size):
            for i1 in range(seq_len):
                axis = ax[i0, i1]
                axis.imshow(images[i0, i1].permute(1, 2, 0))
                axis.axis('off')
                axis.set_title("t={}".format(sample['timestamp'][i0][i1]))
                axis.text(10, 30, sample['frame_id'][i0][i1], color='red')
                axis.text(10, 430, "man-angle {:.3}".format(sample['angle'][i0][i1].item()), color='red')
                axis.text(10, 470, "pred-angle {:.3}".format(pred_angle[i0][i1].item()), color='red')
    

## Visualizing CNN filters

In [49]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_cnn(cnn, max_row=10, max_col=10, select_channel=None, tb_writer=None):
    """Plot the kernels of a convolution layer as a grid of small images.

    The grid has one column per output channel and one row per input channel,
    matching the "Output Channel" / "Input Channel" axis labels. The shape of
    the full weight array is printed before plotting.

    Args:
        cnn: An ``nn.Conv2d`` or ``nn.Conv3d`` layer.
        max_row: Maximum number of input channels (grid rows) to draw.
        max_col: Maximum number of output channels (grid columns) to draw.
        select_channel: ``nn.Conv3d`` only. Integer index along weight axis 2
            (the depth axis) to display; ``0`` is a valid choice. When
            ``None`` the kernels are averaged over that axis instead.
        tb_writer: Currently unused.

    Raises:
        AssertionError: If ``cnn`` is not a 2-D or 3-D convolution layer, or
            if ``select_channel`` is given but is not an ``int``.
    """
    assert isinstance(cnn, (nn.Conv3d, nn.Conv2d))

    filters = cnn.weight.cpu().detach().numpy()  # output_ch x input_ch x D x H x W
    output_ch = min(filters.shape[0], max_col)
    input_ch = min(filters.shape[1], max_row)
    print(filters.shape)

    if isinstance(cnn, nn.Conv3d):
        # Compare against None: a plain truthiness test treated index 0 as
        # "not given" and silently fell back to the depth average.
        if select_channel is not None:
            assert isinstance(select_channel, int)
            filters = filters[:, :, select_channel, :, :]
        else:
            filters = np.mean(filters, axis=2)

    fig = plt.figure(figsize=(output_ch, input_ch))
    plt.xlabel("Output Channel")
    plt.ylabel("Input Channel")
    frame1 = plt.gca()
    frame1.axes.xaxis.set_ticklabels([])
    frame1.axes.yaxis.set_ticklabels([])
    for o in range(output_ch):
        for i in range(input_ch):
            # Subplot indices run row by row, so row i / column o is cell
            # i * output_ch + o + 1. A running counter over this loop order
            # put kernels in cells that did not match the axis labels.
            ax = fig.add_subplot(input_ch, output_ch, i * output_ch + o + 1)
            ax.imshow(filters[o, i, :, :])
            ax.axis('off')
    plt.show()

## GradCAM (class activation mapping)

In [58]:
from skimage.transform import resize

class CamExtractor():
    """
        Base class for building class activation maps from a model.

        Hooks store the output of one layer and the gradient flowing into
        that output as numpy arrays; ``to_image`` combines the two. Subclasses
        choose the layer by implementing ``register_hooks``.
    """
    def __init__(self, model, register_hooks=True):
        # Nothing is captured until a forward and a backward pass have run.
        self.conv_output = None
        self.grad = None
        self.model = model

        if register_hooks:
            self.register_hooks()

    def gradient_hook(self, model, grad_input, grad_output):
        """Backward hook: keep the first output gradient as a numpy array."""
        first_grad = grad_output[0]
        self.grad = first_grad.cpu().detach().numpy()

    def conv_output_hook(self, model, input, output):
        """Forward hook: keep the layer output as a numpy array."""
        activations = output.cpu()
        self.conv_output = activations.detach().numpy()

    def register_hooks(self):
        """Attach the two hooks to the layer of interest (subclass responsibility)."""
        raise NotImplementedError("You should implement this method for your own model!")

    def forward(self, x):
        """Run the model on ``x`` (subclass responsibility)."""
        raise NotImplementedError("You should implement this method for your own model!")

    def to_image(self, height=None, width=None):
        """Combine the stored gradient and activations into activation map(s).

        Both stored arrays are expected to have shape (*, channel, H, W); the
        result has shape (*, H, W). The map is resized only when both
        ``height`` and ``width`` are given.
        """
        captured = self.grad is not None and self.conv_output is not None
        assert captured, "You should perform both forward pass and backward propagation first!"

        spatial_axes = [-2, -1]
        leading_axes = [0, 1]

        # One weight per channel: the gradient averaged over H and W -> (*, channel).
        weights = np.mean(self.grad, axis=(-2, -1))
        # Move H, W to the front so the weights broadcast over the trailing
        # (*, channel) axes -> (H, W, *, channel).
        activations = np.moveaxis(self.conv_output, spatial_axes, leading_axes)
        # Weight each channel, then average the channels away -> (H, W, *).
        weighted = np.mean(weights * activations, axis=-1)
        # Put H, W back at the end -> (*, H, W).
        cam_image = np.moveaxis(weighted, leading_axes, spatial_axes)

        if height is None or width is None:
            return cam_image

        target_shape = list(cam_image.shape)
        target_shape[-2], target_shape[-1] = height, width
        return resize(cam_image, target_shape)
        
    
class CamExtractorTLModel(CamExtractor):
    """CAM extractor that taps ``model.ResNet.layer4``."""

    def register_hooks(self):
        """Attach the forward and backward hooks to ``model.ResNet.layer4``."""
        target_layer = self.model.ResNet.layer4
        target_layer.register_forward_hook(self.conv_output_hook)
        target_layer.register_backward_hook(self.gradient_hook)

class CamExtractorTLModel_regnetx(CamExtractor):
    """CAM extractor that taps ``model.pretrained.trunk_output.block4``."""

    def register_hooks(self):
        """Attach the forward and backward hooks to ``trunk_output.block4``."""
        target_layer = self.model.pretrained.trunk_output.block4
        target_layer.register_forward_hook(self.conv_output_hook)
        target_layer.register_backward_hook(self.gradient_hook)

class CamExtractorTLModel_VGG(CamExtractor):
    """CAM extractor that taps the end of ``model.pretrained.features[19:30]``."""

    def register_hooks(self):
        """Attach the hooks to the last layer of ``features[19:30]``.

        Slicing a module container returns a new container object wrapping
        the same layers. That temporary object is not part of the model, so
        hooks registered on it would never be triggered by the model's own
        forward/backward pass. The last layer of the slice is a real member
        of the model and its output is the output of the slice, so it is
        hooked instead.
        """
        target_layer = self.model.pretrained.features[19:30][-1]
        target_layer.register_forward_hook(self.conv_output_hook)
        target_layer.register_backward_hook(self.gradient_hook)

class CamExtractorTLModel_EffNetB7(CamExtractor):
    """CAM extractor that taps the last entry of ``model.pretrained.features``."""

    def register_hooks(self):
        """Attach the forward and backward hooks to ``features[-1]``."""
        target_layer = self.model.pretrained.features[-1]
        target_layer.register_forward_hook(self.conv_output_hook)
        target_layer.register_backward_hook(self.gradient_hook)


class CamExtractorTLModel_wideresnet(CamExtractor):
    """CAM extractor that taps ``model.pretrained.layer4``."""

    def register_hooks(self):
        """Attach the forward and backward hooks to ``model.pretrained.layer4``."""
        target_layer = self.model.pretrained.layer4
        target_layer.register_forward_hook(self.conv_output_hook)
        target_layer.register_backward_hook(self.gradient_hook)

class CamExtractor3DCNN(CamExtractor):
    """CAM extractor that taps ``model.Convolution6``.

    Captured arrays have axis 1 moved to position 2 before they are stored,
    which is intended to restore a batch x seq_len x channel x H x W layout.
    """

    @staticmethod
    def _restore_layout(tensor):
        """Convert ``tensor`` to numpy and move axis 1 to position 2."""
        array = tensor.cpu().detach().numpy()
        return np.moveaxis(array, 1, 2)

    def gradient_hook(self, model, grad_input, grad_output):
        """Backward hook: store the first output gradient in the restored layout."""
        self.grad = self._restore_layout(grad_output[0])

    def conv_output_hook(self, model, input, output):
        """Forward hook: store the layer output in the restored layout."""
        self.conv_output = self._restore_layout(output)

    def register_hooks(self):
        """Attach the forward and backward hooks to ``model.Convolution6``."""
        target_layer = self.model.Convolution6
        target_layer.register_forward_hook(self.conv_output_hook)
        target_layer.register_backward_hook(self.gradient_hook)
    