In [1]:
# Model
import torch
import torch.nn as nn

class CNN(nn.Module):
    """Five-stage convolutional classifier that also returns its last feature map.

    Each stage is a 3x3 convolution (stride 1, padding 1) followed by ReLU;
    stages 2-4 additionally halve the spatial size with a 2x2 max-pool, so
    the final feature map is 1/8 of the input side length. The 512-channel
    feature map is average-pooled and passed to a single linear classifier.

    Args:
        img_size: Input side length; the pooling kernel is ``img_size // 8``
            to match the three 2x2 max-pools above.
        num_class: Number of output logits.
    """

    def __init__(self, img_size, num_class):
        super(CNN, self).__init__()

        def conv_stage(in_channels, out_channels, pool):
            # 3x3 conv -> ReLU -> optional 2x2/stride-2 max-pool.
            modules = [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
            ]
            if pool:
                modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            return nn.Sequential(*modules)

        # Stages are created in the same order as their attribute names so the
        # parameter order (and therefore state-dict layout) is layer1..layer5,
        # then the classifier.
        self.layer1 = conv_stage(3, 32, pool=False)
        self.layer2 = conv_stage(32, 64, pool=True)
        self.layer3 = conv_stage(64, 128, pool=True)
        self.layer4 = conv_stage(128, 256, pool=True)
        self.layer5 = conv_stage(256, 512, pool=False)

        self.gap = nn.AvgPool2d(img_size // 8)

        self.classifier = nn.Linear(512, num_class)
        nn.init.xavier_uniform_(self.classifier.weight)

    def forward(self, x):
        """Return ``(logits, feature)`` where ``feature`` is the layer5 output."""
        feature = x
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            feature = stage(feature)
        pooled = self.gap(feature)
        # Flatten everything except the batch dimension before the linear layer.
        logits = self.classifier(pooled.view(feature.size(0), -1))
        return logits, feature

In [2]:
# Util
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms

def set_trainloader(dataset_name, path, img_size, batch_size):
    """Build the training DataLoader and report how many classes it contains.

    Args:
        dataset_name: 'MNIST' or 'CIFAR10' select the torchvision training
            split (downloaded into ``path`` if needed); any other value loads
            an ``ImageFolder`` rooted at ``path``.
        path: Dataset root directory.
        img_size: Size passed to ``transforms.Resize``.
        batch_size: Samples per batch.

    Returns:
        Tuple ``(loader, num_classes)``. The loader shuffles, drops the last
        incomplete batch and uses two worker processes.
    """
    augment = transforms.Compose([
        transforms.Resize(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    builtin_sets = {'MNIST': dsets.MNIST, 'CIFAR10': dsets.CIFAR10}
    if dataset_name in builtin_sets:
        train_set = builtin_sets[dataset_name](
            root=path, train=True, download=True, transform=augment)
    else:
        train_set = dsets.ImageFolder(root=path, transform=augment)

    loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=2,
    )
    return loader, len(train_set.classes)

def set_testloader(dataset_name, path, img_size, batch_size=1, shuffle=True):
    """Build the evaluation DataLoader and report how many classes it contains.

    Args:
        dataset_name: 'MNIST' or 'CIFAR10' select the torchvision test split
            (downloaded into ``path`` if needed); any other value loads an
            ``ImageFolder`` rooted at ``path``.
        path: Dataset root directory.
        img_size: Size passed to ``transforms.Resize``.
        batch_size: Samples per batch (default 1).
        shuffle: Whether to visit samples in random order. Defaults to True,
            which is what this function always did; pass False for a
            reproducible sample order.

    Returns:
        Tuple ``(loader, num_classes)``.
    """
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
    ])
    if dataset_name == 'MNIST':
        dataset = dsets.MNIST(root=path, train=False, download=True, transform=transform)
    elif dataset_name == 'CIFAR10':
        dataset = dsets.CIFAR10(root=path, train=False, download=True, transform=transform)
    else:
        dataset = dsets.ImageFolder(root=path, transform=transform)

    # drop_last must stay False for evaluation: dropping the final partial
    # batch would silently exclude samples whenever batch_size does not divide
    # the dataset size. With the default batch_size=1 nothing changes.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=False,
        num_workers=2,
    )
    return loader, len(dataset.classes)

In [3]:
# Train with CIFAR10
import os
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

dataset = 'CIFAR10'
data_path = './data'
model_path = './model'
# Bare file name (as in the evaluation and CAM cells); joined with model_path
# when the checkpoint is written.
model_name = 'CAMnet.pth'
img_size = 32
batch_size = 128
epoch = 5
epoch_box = 100  # number of iterations between running-loss reports
learning_rate = 0.001

# exist_ok avoids the check-then-create race and also creates missing parents.
os.makedirs(model_path, exist_ok=True)

train_loader, num_class = set_trainloader(dataset, data_path, img_size, batch_size)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = CNN(img_size=img_size, num_class=num_class).to(device)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

# Start from +inf so the first epoch always writes a checkpoint; a finite
# sentinel could leave no file on disk if the loss never dropped below it.
min_loss = float('inf')

print("Training Starts")
net.train()
for ep in range(epoch):
    running_loss = 0.0  # reset after every report
    total_loss = 0.0    # accumulated over the whole epoch
    for i, (inputs, labels) in enumerate(train_loader, 0):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs, _ = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for both accumulators.
        batch_loss = loss.item()
        running_loss += batch_loss
        total_loss += batch_loss
        if i % epoch_box == epoch_box-1:
            print('Epoch [%d/%d], Iteration [%d/%d], loss: %.4f' %
                  (ep+1, epoch, i+1, len(train_loader), running_loss/epoch_box))
            running_loss = 0.0
    epoch_loss = total_loss/len(train_loader)
    print('Epoch [%d/%d], Total Loss: %.4f' % (ep+1, epoch, epoch_loss))

    # Keep only the weights of the epoch with the lowest mean training loss.
    if epoch_loss < min_loss:
        min_loss = epoch_loss
        torch.save(net.state_dict(), os.path.join(model_path, model_name))
print('Training Finished')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Training Starts
Epoch [1/5], Iteration [100/390], loss: 2.0521
Epoch [1/5], Iteration [200/390], loss: 1.7433
Epoch [1/5], Iteration [300/390], loss: 1.5532
Epoch [1/5], Total Loss: 1.7039
Epoch [2/5], Iteration [100/390], loss: 1.3529
Epoch [2/5], Iteration [200/390], loss: 1.2555
Epoch [2/5], Iteration [300/390], loss: 1.1685
Epoch [2/5], Total Loss: 1.2265
Epoch [3/5], Iteration [100/390], loss: 1.0502
Epoch [3/5], Iteration [200/390], loss: 1.0298
Epoch [3/5], Iteration [300/390], loss: 0.9693
Epoch [3/5], Total Loss: 0.9988
Epoch [4/5], Iteration [100/390], loss: 0.8860
Epoch [4/5], Iteration [200/390], loss: 0.8745
Epoch [4/5], Iteration [300/390], loss: 0.8574
Epoch [4/5], Total Loss: 0.8617
Epoch [5/5], Iteration [100/390], loss: 0.7844
Epoch [5/5], Iteration [200/390], loss: 0.7575
Epoch [5/5], Iteration [300/390], loss: 0.7408
Epoch [5/5], Total Loss: 0.7564
Training Finished


In [4]:
# Test with CIFAR10
import torch
import numpy as np
import os  # used for the checkpoint path below; this cell did not import it itself

dataset = 'CIFAR10'
data_path = './data'
model_path = './model'
model_name = 'CAMnet.pth'
img_size = 32

test_loader, num_class = set_testloader(dataset, data_path, img_size)

# is_available must be *called*: the bare function object is always truthy,
# which selected 'cuda' even on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = CNN(img_size=img_size, num_class=num_class).to(device)
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
net.load_state_dict(torch.load(os.path.join(model_path, model_name), map_location=device))
net.eval()

correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs, _ = net(images)
        # Index of the largest logit is the predicted class.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report the number of samples actually evaluated instead of a fixed count.
print('Accuracy of the network on the %d test images: %d %%' % (
    total, 100 * correct / total))

Files already downloaded and verified
Accuracy of the network on the 10000 test images: 72 %


In [15]:
# CAM with CIFAR10
import os
import cv2
import torch
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms

dataset = 'CIFAR10'
data_path = './data'
model_path = './model'
model_name = 'CAMnet.pth'
result_path = './result'
img_size = 32
result_num = 3  # number of test images to visualise


# exist_ok avoids the check-then-create race and also creates missing parents.
os.makedirs(result_path, exist_ok=True)

test_loader, num_class = set_testloader(dataset, data_path, img_size)

# is_available must be *called*: the bare function object is always truthy,
# which selected 'cuda' even on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = CNN(img_size=img_size, num_class=num_class).to(device)
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
net.load_state_dict(torch.load(os.path.join(model_path, model_name), map_location=device))
net.eval()

# Feature maps captured by get_feature, stored as NumPy arrays.
feature_collection = []


def get_feature(input):
    """Run ``net`` on ``input`` and append its feature map to ``feature_collection``."""
    feature_map = net(input)[1]  # net returns (logits, feature)
    feature_collection.append(feature_map.cpu().data.numpy())


# The second-to-last parameter of the network is the classifier weight matrix
# (the last one is its bias); CAM uses its rows as per-class channel weights.
params = list(net.parameters())
classifier_weight = params[-2]
weight_for_softmax = np.squeeze(classifier_weight.cpu().data.numpy())

def Do_CAM(feature, weigth_for_softmax, class_id):
    """Compute a class activation map for one image.

    Args:
        feature: Array of shape ``(1, c, h, w)``; the reshape below only
            succeeds when the leading dimension is 1.
        weigth_for_softmax: Classifier weight matrix indexed by class, each
            row holding ``c`` channel weights. (The parameter keeps its
            original spelling so existing keyword callers are unaffected.)
        class_id: Row of the weight matrix to visualise.

    Returns:
        ``uint8`` array of size ``(img_size, img_size)`` with values scaled
        to 0..255.
    """
    upsample_size = (img_size, img_size)
    _, c, h, w = feature.shape
    # Weighted sum of the channels using the selected class' weights.
    cam = np.dot(weigth_for_softmax[class_id], feature.reshape(c, h * w))
    cam = cam.reshape(h, w)
    # Min-max normalise in two steps: the maximum has to be taken *after* the
    # minimum is subtracted, otherwise the result is not confined to [0, 1]
    # and the uint8 cast below can wrap around.
    cam = cam - np.min(cam)
    peak = np.max(cam)
    if peak > 0:  # a perfectly flat map stays all-zero instead of dividing by 0
        cam = cam / peak
    cam = np.uint8(255 * cam)
    # cv2.resize returns a new array, so its result must be kept.
    return cv2.resize(cam, upsample_size)

# Inference only: no autograd graph is needed for the visualisation.
with torch.no_grad():
    for i, (image, label) in enumerate(test_loader):
        # Checked before any work so that result_num == 0 produces nothing.
        if i >= result_num:
            break

        # Save the input image; it is re-read with OpenCV below.
        image_file = os.path.join(result_path, 'img%d.png' % (i+1))
        PIL_image = transforms.ToPILImage()(image[0])
        PIL_image.save(image_file)

        image, label = image.to(device), label.to(device)
        # One forward pass yields both the logits and the feature map.
        out, feature = net(image)
        Sc = F.softmax(out, dim=1).squeeze()
        prob, class_ids = Sc.sort(0, True)  # descending: index 0 is the top prediction
        top_class = class_ids[0].item()
        print("GT : %d, Pred : %d, Prob : %.2f" % (label.item(), top_class, prob[0].item()))

        CAM = Do_CAM(feature.cpu().numpy(), weight_for_softmax, top_class)
        image = cv2.imread(image_file)
        height, width, _ = image.shape
        heatmap = cv2.applyColorMap(cv2.resize(CAM, (width, height)), cv2.COLORMAP_JET)
        result = heatmap*0.5 + image*0.5
        cv2.imwrite(os.path.join(result_path, 'cam%d.png' % (i+1)), result)

Files already downloaded and verified
GT : 5, Pred : 3, Prob : 0.50
GT : 9, Pred : 9, Prob : 0.78
GT : 6, Pred : 6, Prob : 1.00


In [16]:
# CAM with my Dataset
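# Images should be in './data/OWN'. set_testloader passes './data' to ImageFolder, which treats
# every sub-directory of './data' as one class, so the printed GT value is a folder index,
# not a CIFAR10 label.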
import os
import cv2
import torch
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms

dataset = 'OWN'
data_path = './data'
model_path = './model'
model_name = 'CAMnet.pth'
result_path = './result'
# Larger than the 32 used in the CIFAR10 cells; the network's pooling kernel
# is derived from img_size, so it is rebuilt for this size below.
img_size = 128
result_num = 3  # number of images to visualise

# exist_ok avoids the check-then-create race and also creates missing parents.
os.makedirs(result_path, exist_ok=True)

# The class count reported for this folder is ignored: the classifier has to
# match the checkpoint being loaded, so num_class is fixed to 10 instead.
test_loader, _ = set_testloader(dataset, data_path, img_size)
num_class = 10

# is_available must be *called*: the bare function object is always truthy,
# which selected 'cuda' even on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = CNN(img_size=img_size, num_class=num_class).to(device)
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
net.load_state_dict(torch.load(os.path.join(model_path, model_name), map_location=device))
net.eval()

# Accumulates the feature maps produced by get_feature (as NumPy arrays).
feature_collection = []


def get_feature(input):
    """Forward ``input`` through ``net`` and store the returned feature map."""
    _logits, feature_map = net(input)
    as_numpy = feature_map.cpu().data.numpy()
    feature_collection.append(as_numpy)


# params[-2] is the second-to-last parameter of the network, i.e. the
# classifier weight matrix (params[-1] is its bias).
params = list(net.parameters())
weight_for_softmax = params[-2].cpu().data.numpy()
weight_for_softmax = np.squeeze(weight_for_softmax)

def Do_CAM(feature, weigth_for_softmax, class_id):
    """Compute a class activation map for one image.

    Args:
        feature: Array of shape ``(1, c, h, w)``; the reshape below only
            succeeds when the leading dimension is 1.
        weigth_for_softmax: Classifier weight matrix indexed by class, each
            row holding ``c`` channel weights. (The parameter keeps its
            original spelling so existing keyword callers are unaffected.)
        class_id: Row of the weight matrix to visualise.

    Returns:
        ``uint8`` array of size ``(img_size, img_size)`` with values scaled
        to 0..255.
    """
    upsample_size = (img_size, img_size)
    _, c, h, w = feature.shape
    # Use the weights that were passed in rather than the module-level global,
    # so the function honours its own argument.
    cam = np.dot(weigth_for_softmax[class_id], feature.reshape(c, h * w))
    cam = cam.reshape(h, w)
    # Min-max normalise to [0, 1].
    cam = cam - np.min(cam)
    peak = np.max(cam)
    if peak > 0:  # a perfectly flat map stays all-zero instead of dividing by 0
        cam = cam / peak
    cam = np.uint8(255 * cam)
    cam = cv2.resize(cam, upsample_size)
    return cam

# Inference only: no autograd graph is needed for the visualisation.
with torch.no_grad():
    for i, (image, label) in enumerate(test_loader):
        # Checked before any work so that result_num == 0 produces nothing.
        if i >= result_num:
            break

        # Save the input image; it is re-read with OpenCV below.
        image_file = os.path.join(result_path, 'img%d.png' % (i+1))
        PIL_image = transforms.ToPILImage()(image[0])
        PIL_image.save(image_file)

        image, label = image.to(device), label.to(device)
        # One forward pass yields both the logits and the feature map.
        out, feature = net(image)
        Sc = F.softmax(out, dim=1).squeeze()
        prob, class_ids = Sc.sort(0, True)  # descending: index 0 is the top prediction
        top_class = class_ids[0].item()
        print("GT : %d, Pred : %d, Prob : %.2f" % (label.item(), top_class, prob[0].item()))

        CAM = Do_CAM(feature.cpu().numpy(), weight_for_softmax, top_class)
        image = cv2.imread(image_file)
        height, width, _ = image.shape
        heatmap = cv2.applyColorMap(cv2.resize(CAM, (width, height)), cv2.COLORMAP_JET)
        # Heat-map weighted 0.3 against 0.5 for the image.
        result = heatmap*0.3 + image*0.5
        cv2.imwrite(os.path.join(result_path, 'cam%d.png' % (i+1)), result)

GT : 1, Pred : 3, Prob : 0.39
GT : 1, Pred : 2, Prob : 0.73
GT : 1, Pred : 2, Prob : 0.81
