## 의류 이미지 분류와 Grad-CAM을 이용한 영역 모델 해석

> 이미지 데이터를 어떻게 처리하고 딥러닝을 해석할 것인지 고민할 필요가 있다.


In [None]:
%matplotlib inline  

import matplotlib as mpl 
import matplotlib.pyplot as plt 
import matplotlib.font_manager as fm  

!apt-get update -qq
!apt-get install fonts-nanum* -qq

path = '/usr/share/fonts/truetype/nanum/NanumBarunGothic.ttf' 
font_name = fm.FontProperties(fname=path, size=10).get_name()
print(font_name)
plt.rc('font', family=font_name)

fm._rebuild()
mpl.rcParams['axes.unicode_minus'] = False

In [None]:
# Mount Google Drive; the dataset CSV, the images and the model checkpoint are all read from /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import numpy as np
import pandas as pd
from glob import glob
import random

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
from matplotlib import font_manager, rc

from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary
import torchvision

from torchvision import transforms
import albumentations as A
from albumentations import Compose
from albumentations.pytorch import ToTensorV2
import cv2
from google.colab.patches import cv2_imshow
from PIL import Image
import re
from tqdm import tqdm

import gc
import torch

from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.metrics import f1_score, precision_score

import warnings  
# Silence every Python warning for the rest of the session (this also hides deprecation warnings).
warnings.filterwarnings(action = 'ignore')

# Record the albumentations version and GPU status for this run.
print(A.__version__)
!nvidia-smi

In [None]:
# Run on the first CUDA GPU when one is visible; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
# Seed every RNG the notebook uses so that runs are reproducible.
seed = 42

os.environ['PYTHONHASHSEED'] = str(seed)

random.seed(seed)
np.random.seed(seed)

torch.manual_seed(seed)
# Seeds every visible GPU (a superset of torch.cuda.manual_seed); the duplicate
# torch.manual_seed call has been dropped.
torch.cuda.manual_seed_all(seed)
# cuDNN autotuning (benchmark=True) and non-deterministic kernels make GPU results vary
# between runs even with fixed seeds. Pin them so the seeding above actually takes effect.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Load the image index table; later cells use its 'path' (image file path) and 'label' columns.
df = pd.read_csv("/content/drive/MyDrive/data/musinsa_high/musinsa.csv", index_col = 0)
df.info()

In [None]:
# Random 1% subsample for quick experiments; only referenced by the commented-out split cell further down.
# NOTE(review): no random_state is passed, so reproducibility depends on pandas' default RNG — verify.
df_small = df.sample(frac=0.01)
df_small.info()

In [None]:
# Helpers for building the clothing-name column.
def re_h(x):
    """Remove ASCII letters, digits, '/', '_' and '.' from *x*, then trim surrounding whitespace.

    Applied to an image path, this keeps the remaining characters (e.g. Hangul,
    spaces, hyphens).
    """
    stripped = re.sub('[a-zA-Z0-9/_.]', '', x)
    return stripped.strip()

def cut_label(x):
    """Return the leading half of *x* (``len(x) // 2`` items; the odd middle item is dropped).

    Presumably the cleaned path repeats the clothing name twice — TODO confirm.
    """
    half = len(x) // 2
    return x[:half]

In [None]:
# Derive each row's clothing name: strip ASCII/path characters from the image path,
# then keep the first half of what remains. Show the resulting class counts.
df['name'] = df['path'].map(lambda p: cut_label(re_h(p)))
df['name'].value_counts()

In [None]:
# fig = plt.figure(figsize = (15, 15)),
# label_count = df['label'].value_counts()
# label_count
# label_count = df['name'].value_counts()

# sns.barplot(x = df['name'], y = df['name'].value_counts())

In [None]:
# fig = plt.figure(figsize = (15, 15))

# sns.barplot(x = df_small['name'], y = df_small['label'])

In [None]:
# Custom dataset definition

class Custom_Dataset(Dataset):
    """Map-style dataset that reads RGB images from file paths.

    Args:
        x: sequence of image file paths, indexed positionally.
        y: optional sequence of labels aligned with ``x``. When ``None`` (unlabelled data),
            items are the image alone instead of ``(image, label)``.
        transforms: optional albumentations-style callable, invoked as
            ``transforms(image=array)`` and returning a dict with key ``'image'``.
    """
    def __init__(self, x, y = None, transforms = None):
        self.x = x
        self.y = y
        self.transforms = transforms
    
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, idx):
        # Close the file deterministically instead of relying on PIL's lazy loading.
        with Image.open(self.x[idx]) as im:
            img = np.array(im.convert("RGB"))

        if self.transforms:
            img = self.transforms(image = img)['image']

        # Previously `self.y[idx]` raised TypeError for the default y=None.
        if self.y is None:
            return img
        return img, self.y[idx]

In [None]:
def _resize_normalize_to_tensor():
    """Build the shared preprocessing: resize to 256x256, normalise, convert to a CHW tensor."""
    return A.Compose([
        A.Resize(256, 256),
        # A.CenterCrop(224, 224, always_apply=True, p=1.0),
        A.Normalize(mean = (0.48, 0.48, 0.48), std = (0.48, 0.48, 0.48)),
        ToTensorV2()
    ])


# The three pipelines are currently identical. They are separate objects so that each
# can diverge later (e.g. to add train-time augmentation).
# NOTE(review): a per-channel mean/std of 0.48 is not the ImageNet statistics the pretrained
# backbones were trained with — confirm this is intended.
train_val_transform = _resize_normalize_to_tensor()
test_transform = _resize_normalize_to_tensor()
transform = _resize_normalize_to_tensor()

### train-val-test(7:1.5:1.5) 나누기

In [None]:
# Stratified 70/15/15 split: hold out 30% of the rows, then halve that hold-out into
# validation and test sets (both splits stratified on the label, seed 42).
x_train, x_test, y_train, y_test = train_test_split(
    df['path'], df['label'],
    test_size = 0.3, shuffle = True, stratify = df['label'], random_state = 42,
)
x_val, x_test, y_val, y_test = train_test_split(
    x_test, y_test,
    test_size = 0.5, shuffle = True, stratify = y_test, random_state = 42,
)

# Custom_Dataset indexes with self.x[idx]. Plain lists make that positional rather than
# using the shuffled pandas index labels.
x_train, x_val, x_test = map(list, (x_train, x_val, x_test))
y_train, y_val, y_test = map(list, (y_train, y_val, y_test))

In [None]:
# x_train, x_test, y_train, y_test = train_test_split(df_small['path'], df_small['label'],
#                                                   test_size = 0.3, shuffle = True,
#                                                   stratify = df_small['label'], random_state = 42)

# x_val, x_test, y_val, y_test = train_test_split(x_test, y_test,
#                                                   test_size = 0.5, shuffle = True,
#                                                   random_state = 42)

# x_train = list(x_train)
# x_val = list(x_val)
# y_train = list(y_train)
# y_val = list(y_val)
# x_test = list(x_test)
# y_test = list(y_test)

In [None]:
# Sanity check: features and labels must have the same length in every split.
print(f'x_train: {len(x_train)}, y_train: {len(y_train)}\nx_val: {len(x_val)}, y_val: {len(y_val)}\nx_test: {len(x_test)}, y_test: {len(y_test)}')

In [None]:
# DataLoaders, batch size 64. Only the training loader shuffles.
# The validation loader keeps its final partial batch (drop_last=False). train() divides
# the number of correct predictions by len(val_loader.dataset), so dropping samples
# understated validation accuracy.

train_dataset = Custom_Dataset(x_train, y_train, transforms = transform)
train_loader = DataLoader(train_dataset, batch_size = 64, shuffle = True, drop_last = True, num_workers=6)

val_dataset = Custom_Dataset(x_val, y_val, transforms = transform)
val_loader = DataLoader(val_dataset, batch_size = 64, shuffle = False, drop_last = False, num_workers=6)

# NOTE(review): drop_last=True silently skips the last partial test batch (up to 63 images).
# It is kept because the CAM plotting code indexes a fixed number of images per batch.
test_dataset = Custom_Dataset(x_test, y_test, transforms = test_transform)
test_loader = DataLoader(test_dataset, batch_size = 64, shuffle = False, drop_last = True, num_workers=0)

In [None]:
# Compare an original image with its transformed (resized + normalised) version.

def custom_imshow(img_transform, img):
    """Show ``img`` (an H x W x 3 array) next to ``img_transform`` (a 3 x H x W tensor).

    Each call creates its own figure (previously a module-level figure was reused).
    The transformed image is normalised, so it is min-max rescaled to [0, 1] for display;
    imshow would otherwise clip the out-of-range float values.
    """
    img_transform = img_transform.numpy()
    print(img_transform.shape)
    # Print the shape only; printing the full pixel array flooded the cell output.
    print(img.shape)

    fig = plt.figure(figsize=(15, 15))

    ax1 = fig.add_subplot(1, 2, 1)
    ax1.imshow(img)
    ax1.set_title('Original_Image', size = 20)
    ax1.axis("off")

    shown = np.transpose(img_transform, (1, 2, 0))  # (channel, H, W) -> (H, W, channel) for imshow
    lo, hi = shown.min(), shown.max()
    ax2 = fig.add_subplot(1, 2, 2)
    ax2.imshow((shown - lo) / (hi - lo) if hi > lo else shown)
    ax2.set_title('Transform_Image', size = 20)
    ax2.axis("off")
    plt.show()

img = Image.open(df.path[323]).convert('RGB')
img = np.array(img)
img_transform = transform(image = img)
img_transform = img_transform['image']

custom_imshow(img_transform, img)

In [None]:
class EarlyStopping:
    """Stop training once the validation loss stops improving, checkpointing the best model.

    Args:
        patience: number of consecutive non-improving calls tolerated before
            ``early_stop`` is set to True.
        verbose: print a message whenever a new best checkpoint is saved.
        delta: minimum decrease in validation loss that counts as an improvement.
        path: file the best ``state_dict`` is written to (same default location as before).
    """
    def __init__(self, patience, verbose = False, delta = 0,
                 path = '/content/drive/MyDrive/data/save_data/best_model.pth'):
        self.patience = patience
        self.verbose = verbose
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # the `np.Inf` alias was removed in NumPy 2.0
        self.counter = 0
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record one epoch's ``val_loss``; save ``model`` if it is the best so far."""
        score = -val_loss  # higher is better

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
    
    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path`` and record ``val_loss`` as the new minimum."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).\n Saving model ...\n')
        # Saving must not depend on verbosity: it was previously nested under `if self.verbose`,
        # so nothing was ever saved when verbose=False.
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
            

In [None]:
def train(model, optimizer, epoch, scheduler, device):
    """Train ``model`` with per-epoch validation and early stopping.

    Uses the notebook-level ``criterion``, ``train_loader`` and ``val_loader`` defined in
    other cells. Each epoch runs one pass over ``train_loader``, an optional scheduler step,
    and an evaluation on ``val_loader`` that prints loss, micro precision, weighted F1 and
    accuracy. ``EarlyStopping`` (patience 5) checkpoints the best weights to best_model.pth,
    and those weights are loaded back into ``model`` once after training ends.

    Args:
        model: network to train; moved to ``device``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: maximum number of epochs to run.
        scheduler: LR scheduler stepped once per epoch, or ``None``.
        device: device on which the model and batches are placed.

    Returns:
        ``(train_losses, val_losses)``: per-epoch means of the batch losses.
    """
    best_path = '/content/drive/MyDrive/data/save_data/best_model.pth'
    model = model.to(device)
    early_stopping = EarlyStopping(patience = 5, verbose = True)

    train_losses, val_losses = [], []  # per-epoch mean losses

    for ep in range(epoch):  # `ep`, not `epoch`, so the parameter is not shadowed
        print(f"\n------ {ep} epoch -------\n")
        model.train()
        train_loss_list = []

        for img, label in tqdm(train_loader):
            img, label = img.to(device), label.to(device)

            optimizer.zero_grad()
            loss = criterion(model(img), label)
            loss.backward()
            optimizer.step()

            train_loss_list.append(loss.item())

        if scheduler is not None:
            scheduler.step()

        model.eval()
        val_loss_list, model_preds, true_labels = [], [], []

        with torch.no_grad():
            for img, label in tqdm(val_loader):
                img, label = img.to(device), label.to(device)

                val_pred = model(img)
                val_loss_list.append(criterion(val_pred, label).item())
                model_preds += val_pred.argmax(1).cpu().tolist()
                true_labels += label.cpu().tolist()

        train_lossed = np.mean(train_loss_list)
        val_lossed = np.mean(val_loss_list)
        train_losses.append(train_lossed)
        val_losses.append(val_lossed)

        # Divide by the number of samples actually scored. With drop_last=True the loader
        # skips a tail that len(val_loader.dataset) would still count.
        n_seen = len(true_labels)
        correct = sum(p == t for p, t in zip(model_preds, true_labels))
        val_accuracy = 100 * correct / n_seen if n_seen else 0.0
        # NOTE: micro-averaged precision equals accuracy for single-label multiclass predictions.
        precision = precision_score(true_labels, model_preds, average = 'micro')
        f1 = f1_score(true_labels, model_preds, average = "weighted")

        print(f"\nTrain loss: {train_lossed:.4f}")
        print(f"Val Loss: {val_lossed:.4f}")
        print(f"precsion ------> {precision:.5f}")
        print(f"f1_score ------> {f1:.5f}")
        print(f"{correct} / {n_seen}, Accuracy: {val_accuracy:.3f}%\n")

        early_stopping(val_lossed, model)
        if early_stopping.early_stop:
            print('Early Stopping')
            break

    # Restore the best checkpoint once, after training. It used to be reloaded inside the
    # epoch loop, reverting the weights to the best checkpoint after every epoch.
    # Skip when no epoch ran (nothing was checkpointed).
    if train_losses:
        model.load_state_dict(torch.load(best_path, map_location=device))
    return train_losses, val_losses
   

In [None]:
def predict(model, test_loader, device,
            checkpoint_path='/content/drive/MyDrive/data/save_data/best_model.pth'):
    """Load a saved checkpoint into ``model``, run it on ``test_loader``, and print metrics.

    Prints micro-averaged precision (equal to accuracy for single-label predictions),
    weighted F1 and accuracy. Returns None.

    Args:
        model: network whose weights are replaced by the checkpoint's.
        test_loader: loader yielding ``(images, labels)`` batches.
        device: device used for inference and as the checkpoint's ``map_location``.
        checkpoint_path: ``state_dict`` file to load. Defaults to the path every ``train``
            run in this notebook checkpoints to.
    """
    state = torch.load(checkpoint_path, map_location=device)
    # strict=False is kept, but a partial load is now reported. Every model trained in this
    # notebook checkpoints to the same path, so the file may hold another architecture's weights.
    incompatible = model.load_state_dict(state, strict = False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"[predict] checkpoint/model mismatch: "
              f"{len(incompatible.missing_keys)} missing, "
              f"{len(incompatible.unexpected_keys)} unexpected keys")

    model.to(device)
    model.eval()
    model_pred, true_labels = [], []

    with torch.no_grad():
        for img, label in tqdm(test_loader):
            test_pred = model(img.to(device))
            model_pred += test_pred.argmax(1).cpu().tolist()
            true_labels += label.cpu().tolist()

    # Divide by the number of samples actually scored. A loader with drop_last=True skips
    # its final partial batch, which len(test_loader.dataset) would still count.
    n_seen = len(true_labels)
    correct = sum(p == t for p, t in zip(model_pred, true_labels))
    test_accuracy = 100 * correct / n_seen if n_seen else 0.0
    precision = precision_score(true_labels, model_pred, average = 'micro')
    f1 = f1_score(true_labels, model_pred, average = "weighted")

    print(f"\nprecsion ------> {precision:.5f}")
    print(f"f1_score ------> {f1:.5f}")
    print(f"{correct} / {n_seen}, Accuracy: {test_accuracy:.3f}%\n")


In [None]:
class ResNet34(nn.Module):
    """ImageNet-pretrained torchvision ResNet with a new ``num_classes``-way output layer.

    NOTE(review): despite the class name, the backbone built here is ``resnet50``.
    Confirm which depth was intended before comparing results.

    Args:
        num_classes: size of the replacement output layer (29 by default, as before).
    """
    def __init__(self, num_classes=29):
        super(ResNet34, self).__init__()
        self.model = torchvision.models.resnet50(pretrained = True)
        num_ftrs = self.model.fc.in_features
        # Replace the ImageNet output layer with one sized for this dataset.
        self.model.fc = nn.Linear(num_ftrs, num_classes)
    
    def forward(self, x):
        return self.model(x)

class MobilenetV2(nn.Module):
    """ImageNet-pretrained MobileNetV2 (torch.hub) with a ``num_classes``-way head.

    Dropout (p=0.3) is applied in the classifier head, just before the final Linear layer.

    Args:
        num_classes: size of the replacement output layer (29 by default, as before).
    """
    def __init__(self, num_classes=29):
        super(MobilenetV2, self).__init__()
        self.model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
        self.drop = nn.Dropout(p = 0.3)
        # torchvision's MobileNetV2 head is Sequential(Dropout, Linear). Put our dropout in
        # slot 0 instead of applying it to the raw input image in forward(), which randomly
        # zeroed ~30% of the input pixel values during training. Dropout has no parameters,
        # so state_dict keys and classifier.parameters() are unchanged.
        self.model.classifier[0] = self.drop
        self.model.classifier[1] = nn.Linear(self.model.classifier[1].in_features, num_classes)

    def forward(self, x):
        return self.model(x)


class Efficientnet(nn.Module):
    """NVIDIA torch.hub EfficientNet-B0 (ImageNet-pretrained) with a ``num_classes``-way head.

    Args:
        num_classes: size of the replacement output layer (29 by default, as before).
    """
    def __init__(self, num_classes=29):
        super(Efficientnet, self).__init__()
        self.model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_efficientnet_b0', pretrained=True)
        self.drop = nn.Dropout(p = 0.3)
        # Replace the ImageNet output layer, reusing its input width instead of hard-coding 1280.
        self.model.classifier.fc = nn.Linear(self.model.classifier.fc.in_features, num_classes)
    
    def forward(self, x):
        # NOTE(review): dropout is applied to the raw input image, not to features, so during
        # training it zeroes ~30% of input pixel values. Dropout in front of the classifier
        # head was probably intended — confirm.
        x = self.drop(x)
        return self.model(x)

In [None]:
# Free memory: run Python's garbage collector, then release PyTorch's unused cached CUDA memory.

gc.collect()
torch.cuda.empty_cache()

In [None]:
# Multi-class classification loss; train() reads it as a notebook global.
criterion = nn.CrossEntropyLoss()

# Measured training time per epoch for each model. Instantiating these downloads pretrained
# weights (torchvision / torch.hub). Note: ResNet34 actually builds torchvision's resnet50.
resnet = ResNet34() # 7 min 42 s
mobilenet = MobilenetV2() # 3 min 13 s
efficientnet = Efficientnet() # 4 min 5 s

In [None]:
# Fine-tune ResNet for up to 20 epochs with Adam (lr=1e-3). LambdaLR scales the base lr by
# 0.95**step, i.e. a 5% decay per scheduler step (one step per epoch in train()).
# NOTE(review): every train() call checkpoints to the same best_model.pth, so later training
# cells overwrite this model's best weights.
resnet_optimizer = optim.Adam(params=resnet.parameters(), lr=0.001)
resnet_scheduler = optim.lr_scheduler.LambdaLR(
    resnet_optimizer,
    lr_lambda=lambda step: 0.95 ** step,
    last_epoch=-1,
    verbose=False,
)

train_loss, val_loss = train(resnet, resnet_optimizer, 20, resnet_scheduler, device)

In [None]:
# Fine-tune MobileNetV2 for up to 20 epochs with Adam (lr=1e-3) and a 5%-per-epoch
# exponential LR decay via LambdaLR.
# NOTE(review): the scheduler is named `scheduler` here (the other cells use a model prefix),
# and every train() call checkpoints to the same best_model.pth.
mobilenet_optimizer = optim.Adam(params=mobilenet.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.LambdaLR(
    mobilenet_optimizer,
    lr_lambda=lambda step: 0.95 ** step,
    last_epoch=-1,
    verbose=False,
)

train_loss, val_loss = train(mobilenet, mobilenet_optimizer, 20, scheduler, device)

In [None]:
# Fine-tune EfficientNet-B0 for up to 20 epochs with Adam (lr=1e-3) and a 5%-per-epoch
# exponential LR decay via LambdaLR.
# NOTE(review): this run overwrites best_model.pth, which later cells load for MobileNetV2.
efficientnet_optimizer = optim.Adam(params=efficientnet.parameters(), lr=0.001)
efficientnet_scheduler = optim.lr_scheduler.LambdaLR(
    efficientnet_optimizer,
    lr_lambda=lambda step: 0.95 ** step,
    last_epoch=-1,
    verbose=False,
)

train_loss, val_loss = train(efficientnet, efficientnet_optimizer, 20, efficientnet_scheduler, device)

In [None]:
# Plot train vs. validation loss per epoch.
# NOTE(review): train_loss / val_loss are reassigned by every training cell, so this plots
# only the most recently trained model.

fig = plt.figure(figsize = (12, 12))
plt.plot(train_loss, marker = 'o', label='train_loss')
plt.plot(val_loss, marker = 'o', label='val_loss')
plt.legend()
plt.title('train vs valid')

In [None]:
# Evaluate MobileNetV2 on the test set using the checkpoint stored at best_model.pth.
# NOTE(review): every train() run writes that same path. If the EfficientNet cell ran after
# the MobileNet one, the file holds EfficientNet weights — confirm before trusting these numbers.
predict(mobilenet, test_loader, device)

In [None]:
# Feature layer hooked for the CAM visualisation: index 18 of MobileNetV2's `features`
# (presumably its last block — verify with the print below).
final_conv = mobilenet.model.features[18]
print(final_conv)
# Classification-head parameters; the CAM code uses fc_params[0] as the class-weight matrix.
fc_params = list(mobilenet.model.classifier.parameters())

# Grad-CAM visualisation is still a work in progress.

In [None]:
class SaveFeatures():
    """Capture a module's forward output as a NumPy array through a forward hook.

    ``features`` holds the output of the most recent forward pass (``None`` before the
    first one). Call ``remove()`` to detach the hook.
    """
    features = None

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Detach from autograd and move to host memory before converting to NumPy.
        self.features = output.detach().cpu().numpy()

    def remove(self):
        self.hook.remove()


# def getCAM(feature_conv, weight_fc, class_idx, cur_images):
#     _, nc, h, w = feature_conv.shape

#     cam = weight_fc[class_idx].dot(feature_conv[0, :, :, ].reshape((nc, h*w)))
#     # print(cam)
#     fig = plt.figure(figsize=(30, 30))
#     # cam = cam[0, :].reshape(h, w)
#     # cam = cam - np.min(cam)
#     # cam_img = cam / np.max(cam)
#     # print(cam_img.shape)
#     for i in range(10):
#         cam = weight_fc[class_idx].dot(feature_conv[0, :, :, ].reshape((nc, h*w)))
#         cam = cam[i, :].reshape(h, w)
#         cam = cam - np.min(cam)
#         cam_img = cam / np.max(cam)
#         # cam_img = cam_img.astype(np.uint16)
        
#         ax = fig.add_subplot(1, 10, i+1, xticks=[], yticks=[])
#         plt.imshow(cv2.cvtColor(cur_images[i], cv2.COLOR_BGR2RGB))
#         plt.imshow(cv2.resize(cam_img, (224, 224), interpolation=cv2.INTER_LINEAR), alpha=0.4, cmap='jet')
#         ax.set_title('Label:%d, Predict:%d' % (target, pred_idx), fontsize=14)
#         # return cam_img



def plotGradCAM(model, final_conv, fc_params, train_loader, 
                row=1, col=10, img_size=224, device='cpu', original=False):
    """Overlay class activation maps (CAM) on the first ``row * col`` images of a loader.

    Despite the name, this is plain CAM: the final Linear weights for the predicted class
    applied to the hooked feature map. No gradients are involved.

    Args:
        model: classifier; its weights are loaded from best_model.pth (``strict=False``).
        final_conv: module whose forward output is captured (via ``SaveFeatures``) as the
            feature map, presumably (batch, channels, h, w) — TODO confirm for other models.
        fc_params: head parameters. ``fc_params[0]`` is used as the (num_classes, channels)
            weight matrix, which assumes the head's first parameter is the final Linear weight.
        train_loader: loader to draw images from. Despite the name, the notebook passes
            ``test_loader``; this argument used to be ignored in favour of the global.
        row, col: subplot grid; at most ``row * col`` overlays are drawn in one figure.
        img_size: unused. Heatmaps are resized to each image's own size; kept for compatibility.
        device: device to run inference on.
        original: unused; kept for compatibility.
    """
    state = torch.load('/content/drive/MyDrive/data/save_data/best_model.pth', map_location=device)
    model.load_state_dict(state, strict = False)
    model.to(device)
    model.eval()

    # Forward hook storing final_conv's latest output as a NumPy array.
    activated_features = SaveFeatures(final_conv)
    weight = np.squeeze(fc_params[0].detach().cpu().numpy())

    n_plots = row * col
    fig = plt.figure(figsize=(3 * col, 3 * row))
    plotted = 0

    try:
        # Use no_grad instead of permanently setting requires_grad=False on the caller's model.
        with torch.no_grad():
            for img, target in train_loader:
                if plotted >= n_plots:
                    break
                output = model(img.to(device))
                pred_idx = output.argmax(1).cpu().numpy()

                feats = activated_features.features
                _, nc, fh, fw = feats.shape  # channels, height, width of the feature map
                cur_images = img.cpu().numpy().transpose((0, 2, 3, 1))  # NCHW -> NHWC for imshow

                for k in range(min(n_plots - plotted, len(cur_images))):
                    # The CAM for image k uses image k's own features (previously always image 0's).
                    cam = (weight[pred_idx[k]] @ feats[k].reshape(nc, fh * fw)).reshape(fh, fw)
                    cam = cam - cam.min()
                    peak = cam.max()
                    if peak > 0:  # a flat map stays all-zero instead of becoming NaN
                        cam = cam / peak

                    image = cur_images[k]
                    lo, hi = image.min(), image.max()
                    ax = fig.add_subplot(row, col, plotted + 1, xticks=[], yticks=[])
                    # Images are normalised floats, so min-max rescale them for display. There is
                    # no BGR->RGB swap: the channel order is whatever the dataset produced.
                    ax.imshow((image - lo) / (hi - lo) if hi > lo else np.zeros_like(image))
                    ax.imshow(cv2.resize(cam, (image.shape[1], image.shape[0]),
                                         interpolation=cv2.INTER_LINEAR),
                              alpha=0.4, cmap='jet')
                    ax.set_title('Label:%d, Predict:%d' % (int(target[k]), int(pred_idx[k])), fontsize=14)
                    plotted += 1
    finally:
        # Detach the hook so repeated calls don't accumulate hooks on final_conv.
        activated_features.remove()
    plt.show()

In [None]:
class SaveFeatures():
    """Forward-hook helper that keeps the hooked module's latest output as a NumPy array.

    ``features`` is ``None`` until the first forward pass. ``remove()`` unregisters the hook.
    """
    features = None

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        # Detach from the autograd graph and copy to the CPU before converting.
        self.features = output.detach().cpu().numpy()

    def remove(self):
        self.hook.remove()



def getCAM(feature_conv, weight_fc, class_idx, cur_images, targets=None, max_images=10):
    """Compute class activation maps (CAM) and optionally overlay them on images.

    For image i, the map is ``weight_fc[class_idx[i]] . feature_conv[i]``: the classifier
    weights of image i's class applied to image i's own feature map. Each map is then
    min-max scaled to [0, 1]; a constant map becomes all zeros.

    Args:
        feature_conv: array of shape (batch, channels, h, w), e.g. captured by SaveFeatures.
        weight_fc: array of shape (num_classes, channels), the final Linear weights.
        class_idx: one class index per image (at least ``batch`` entries).
        cur_images: (batch, H, W, 3) channel-last images to draw under the maps, or
            ``None`` to only compute the maps without plotting.
        targets: optional per-image ground-truth labels for the subplot titles. This used
            to come from an undefined global ``target``.
        max_images: maximum number of overlays to draw (previously hard-coded to 10).

    Returns:
        Array of shape (batch, h, w) holding the normalised maps.
    """
    batch, nc, h, w = feature_conv.shape
    class_idx = np.asarray(class_idx).reshape(-1)

    # (batch, channels) x (batch, channels, h*w) -> (batch, h*w). Each image uses its own
    # features; previously every map was computed from feature_conv[0].
    cams = np.einsum('bc,bcp->bp', weight_fc[class_idx[:batch]], feature_conv.reshape(batch, nc, h * w))
    cams = cams.reshape(batch, h, w)
    cams = cams - cams.min(axis=(1, 2), keepdims=True)
    peak = cams.max(axis=(1, 2), keepdims=True)
    # Guard against division by zero for constant maps.
    cams = np.divide(cams, peak, out=np.zeros_like(cams), where=peak > 0)

    if cur_images is None:
        return cams

    n_show = min(max_images, len(cur_images), batch)
    fig = plt.figure(figsize=(30, 30))
    for i in range(n_show):
        image = cur_images[i]
        ax = fig.add_subplot(1, n_show, i + 1, xticks=[], yticks=[])
        # Images are normalised floats, so min-max rescale them for display. The former
        # BGR->RGB conversion is dropped; presumably the images are already RGB (the notebook's
        # dataset loads them with PIL .convert("RGB")) — confirm for other callers.
        lo, hi = image.min(), image.max()
        ax.imshow((image - lo) / (hi - lo) if hi > lo else np.zeros_like(image))
        ax.imshow(cv2.resize(cams[i], (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR),
                  alpha=0.4, cmap='jet')
        if targets is None:
            title = 'Predict:%d' % int(class_idx[i])
        else:
            title = 'Label:%d, Predict:%d' % (int(targets[i]), int(class_idx[i]))
        ax.set_title(title, fontsize=14)
    return cams



def plotGradCAM(model, final_conv, fc_params, train_loader, 
                row=1, col=10, img_size=224, device='cpu', original=False):
    """Evaluate ``model`` on a loader, overlay CAMs on the first ``row * col`` images, print metrics.

    Despite the name, this is plain CAM: the final Linear weights for the predicted class
    applied to the hooked feature map. No gradients are involved.

    Args:
        model: classifier. Its weights are loaded from best_model.pth (``strict=False``),
            and a mismatch summary is printed.
        final_conv: module whose forward output is captured (via ``SaveFeatures``) as the
            feature map, presumably (batch, channels, h, w) — TODO confirm for other models.
        fc_params: head parameters. ``fc_params[0]`` is used as the (num_classes, channels)
            weight matrix, which assumes the head's first parameter is the final Linear weight.
        train_loader: loader to evaluate and visualise. Despite the name, the notebook passes
            ``test_loader``; this argument used to be ignored in favour of the global.
        row, col: subplot grid; at most ``row * col`` overlays are drawn in a single figure
            (previously one figure was created per batch).
        img_size: unused. Heatmaps are resized to each image's own size; kept for compatibility.
        device: device to run inference on.
        original: unused; kept for compatibility.
    """
    state = torch.load('/content/drive/MyDrive/data/save_data/best_model.pth', map_location=device)
    incompatible = model.load_state_dict(state, strict = False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"[plotGradCAM] checkpoint/model mismatch: {len(incompatible.missing_keys)} missing, "
              f"{len(incompatible.unexpected_keys)} unexpected keys")
    model.to(device)
    model.eval()

    # Forward hook storing final_conv's latest output as a NumPy array.
    activated_features = SaveFeatures(final_conv)
    weight = np.squeeze(fc_params[0].detach().cpu().numpy())

    n_plots = row * col
    fig = plt.figure(figsize=(3 * col, 3 * row))
    plotted = 0
    model_pred, true_labels = [], []

    try:
        # Use no_grad instead of permanently setting requires_grad=False on the caller's model.
        with torch.no_grad():
            for img, target in tqdm(train_loader):
                output = model(img.to(device))
                pred_idx = output.argmax(1).cpu().numpy()
                # Metrics use CPU copies, so `target` never has to be on `output`'s device.
                model_pred += pred_idx.tolist()
                true_labels += target.cpu().numpy().tolist()

                if plotted >= n_plots:
                    continue

                feats = activated_features.features
                _, nc, fh, fw = feats.shape  # channels, height, width of the feature map
                cur_images = img.cpu().numpy().transpose((0, 2, 3, 1))  # NCHW -> NHWC for imshow

                for k in range(min(n_plots - plotted, len(cur_images))):
                    # The CAM for image k uses image k's own features (previously always image 0's).
                    cam = (weight[pred_idx[k]] @ feats[k].reshape(nc, fh * fw)).reshape(fh, fw)
                    cam = cam - cam.min()
                    peak = cam.max()
                    if peak > 0:  # a flat map stays all-zero instead of becoming NaN
                        cam = cam / peak

                    image = cur_images[k]
                    lo, hi = image.min(), image.max()
                    ax = fig.add_subplot(row, col, plotted + 1, xticks=[], yticks=[])
                    # Images are normalised floats, so min-max rescale them for display. There is
                    # no BGR->RGB swap: the channel order is whatever the dataset produced.
                    ax.imshow((image - lo) / (hi - lo) if hi > lo else np.zeros_like(image))
                    ax.imshow(cv2.resize(cam, (image.shape[1], image.shape[0]),
                                         interpolation=cv2.INTER_LINEAR),
                              alpha=0.4, cmap='jet')
                    ax.set_title('Label:%d, Predict:%d' % (int(target[k]), int(pred_idx[k])), fontsize=14)
                    plotted += 1
    finally:
        # Detach the hook so repeated calls don't accumulate hooks on final_conv.
        activated_features.remove()

    # Divide by the number of samples actually scored; drop_last=True loaders skip a tail
    # that len(dataset) would still count.
    n_seen = len(true_labels)
    correct = sum(p == t for p, t in zip(model_pred, true_labels))
    test_accuracy = 100 * correct / n_seen if n_seen else 0.0
    precision = precision_score(true_labels, model_pred, average = 'micro')
    f1 = f1_score(true_labels, model_pred, average = "weighted")

    print(f"\nprecsion ------> {precision:.5f}")
    print(f"f1_score ------> {f1:.5f}")
    print(f"{correct} / {n_seen}, Accuracy: {test_accuracy:.3f}%\n")

    plt.show()

In [None]:
import warnings
# Re-apply the ignore-all warnings filter (the same effect as the earlier filterwarnings call).
warnings.simplefilter(action='ignore')

In [None]:
# class SaveFeatures():
#     features = None
#     def __init__(self, m):
#         self.hook = m.register_forward_hook(self.hook_fn)

#     def hook_fn(self, module, input, output):
#         self.features = ((output.cpu()).data).numpy()

#     def remove(self):
#         self.hook.remove()

# def getCAM(feature_conv, weight_fc, class_idx):
#     _, nc, h, w = feature_conv.shape

#     cam = weight_fc[class_idx].dot(feature_conv[0, :, :, ].reshape((nc, h*w)))
    
#     cam = cam[0, :].reshape(h, w)

#     print(cam)
#     cam = cam - np.min(cam)
#     cam_img = cam / np.max(cam)
#     return cam_img

# def plotGradCAM(model, final_conv, fc_params, test_loader, row = 1, col = 8,
#                 img_size = 224, device = device, original = False):
#     model.to(device)
#     model.eval()

#     model_pred = []
#     true_labels = []

#     activated_features = SaveFeatures(final_conv)
#     weight = np.squeeze(fc_params[0].cpu().data.numpy())


#     for i, (img, label) in enumerate(test_loader):
#         img, label = img.to(device), label.to(device)

#         test_pred = model(img)
#         pred_idx = test_pred.detach().cpu().numpy().argmax(1)
#         cur_images = img.detach().cpu().numpy()
#         cur_images = img.detach().cpu().numpy().transpose((0, 2, 3, 1))
#         heatmap = getCAM(activated_features.features, weight, pred_idx)
#         print(heatmap)
            

In [None]:
# Draw CAM overlays for MobileNetV2 on the test set; `device` is not passed, so the default 'cpu' is used.
plotGradCAM(mobilenet, final_conv, fc_params, test_loader)