In [1]:
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models
from torchvision.utils import make_grid

import os
import matplotlib.pyplot as plt
from utils import Wrapper
from datasets import MosquitoDL_loaders


os.environ['CUDA_VISIBLE_DEVICES'] ='0'

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
weight_path = [
    '/home/ryan/backup/CutMix-PyTorch/pretrained/R50_Mosquito_Vanilla_L5E-3_W4E-5/best_model.pth',
    '/home/ryan/backup/CutMix-PyTorch/pretrained/R50_Mosquito_MCACM_K5_P09_L5e-3_W4E-5/best_model.pth'
]

save_path = "./samples_MCACM"
dataset_path = '~/datasets/MosquitoDL'
stage_names = ['layer1','layer2','layer3','layer4']

os.makedirs(save_path, exist_ok = True)

In [3]:
train_loader, test_loader, num_classes = MosquitoDL_loaders(dataset_path, batch_size=1)
del train_loader

Dataset with length 2980


In [4]:
model = models.resnet50(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = nn.DataParallel(model)
model = Wrapper(model, stage_names=stage_names)

state_dict = torch.load(weight_path[1])
model.load_state_dict(state_dict['model'])

model.to(device)


Registered forward/backward hook on 'module.layer1'
Registered forward/backward hook on 'module.layer2'
Registered forward/backward hook on 'module.layer3'
Registered forward/backward hook on 'module.layer4'


Wrapper(
  (net): DataParallel(
    (module): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=

In [5]:
model.eval()

test_loss = 0
test_acc = 0

test_n_samples = 0
test_n_corrects = 0

criterion = nn.CrossEntropyLoss()

for i, data in enumerate(test_loader):
    # print(f"- Current Batch [{i}/{len(test_loader)}]")
    
    batch, labels = data[0].to(device), data[1].to(device)

    pred = model(batch)

    pred_max = torch.argmax(pred, 1)

    loss = criterion(pred, labels)

    loss.backward()

    target_fmap = model.dict_activation['layer3']
    target_gradients = model.dict_gradients['layer3'][0]
    
    model.clear_dict()

    N, C, W_f, H_f = target_fmap.shape

    importance_weights = F.adaptive_avg_pool2d(target_gradients, 1) # [N x C x 1 x 1]

    class_activation_map = torch.mul(target_fmap, importance_weights).sum(dim=1, keepdim=True) # [N x 1 x W_f x H_f]
    class_activation_map = F.relu(class_activation_map).squeeze(dim=1) # [N x W_f x H_f]

    class_activation_map = F.interpolate(class_activation_map.unsqueeze(1), 
        size=batch.shape[-2:], mode='nearest').squeeze(0)

    input_np = batch[0].detach().cpu()
    cam_np = class_activation_map[0].detach().cpu()

    fig, ax = plt.subplots(1,2,figsize=(20,10))
    ax[0].imshow(input_np.permute([1,2,0]))
    ax[0].set_title(f"Input (Class:{labels[0]})\n(Pred:{pred_max[0]})")
    ax[0].axis('off')
    ax[1].matshow(cam_np, cmap='viridis')
    ax[1].set_title(f"CAM (Class:{labels[0]})")
    ax[1].axis('off')
    fig.savefig(os.path.join(save_path, f"Test_Batch_{i}.jpg"))
    plt.draw()
    plt.clf()
    plt.close("all")

    test_loss += loss.detach().cpu().numpy()
    test_n_samples += labels.size(0)
    test_n_corrects += torch.sum(pred_max == labels).detach().cpu().numpy()

test_loss /= len(test_loader)
test_acc = test_n_corrects/test_n_samples

In [6]:
print(f"Loss: {test_loss}, Acc.: {test_acc*100:.2f}%")

Loss: 0.07083830005651327, Acc.: 98.82%
