<a href="https://colab.research.google.com/github/lzichi/Thin-Materials-ML/blob/main/RUN_2d_grad_cam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; the data and result paths used later
# in this notebook all live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Grad-CAM code

In [None]:
import torch
import argparse
import cv2
import numpy as np
import torch
from torch.autograd import Function
from torchvision import models, transforms

In [None]:
class FeatureExtractor():
    """Run a module's children in order and capture selected activations.

    The output of every child whose name is listed in ``target_layers`` is
    collected, and a hook is registered on it so that its gradient is appended
    to ``self.gradients`` when a backward pass runs.
    """

    def __init__(self, model, target_layers):
        self.model = model
        self.target_layers = target_layers
        self.gradients = []

    def save_gradient(self, grad):
        """Hook callback: store the gradient of a captured activation."""
        self.gradients.append(grad)

    def __call__(self, x):
        """Return ``(captured_activations, final_output)`` for input ``x``."""
        # Gradients recorded for a previous forward pass are discarded.
        self.gradients = []
        captured = []
        for layer_name, layer in self.model._modules.items():
            x = layer(x)
            if layer_name not in self.target_layers:
                continue
            x.register_hook(self.save_gradient)
            captured.append(x)
        return captured, x

class ModelOutputs():
    """Forward pass through a model that also collects what Grad-CAM needs.

    Calling an instance returns a pair:
    1. Activations from the targeted intermediate layers.
    2. The network output.
    Gradients of the targeted intermediate layers are read with
    ``get_gradients`` once a backward pass has been run.
    """

    def __init__(self, model, feature_module, target_layers):
        self.model = model
        self.feature_module = feature_module
        self.feature_extractor = FeatureExtractor(self.feature_module, target_layers)

    def get_gradients(self):
        """Return the gradients recorded by the feature extractor."""
        return self.feature_extractor.gradients

    def __call__(self, x):
        target_activations = []
        for child_name, child in self.model._modules.items():
            if child == self.feature_module:
                # Route this stage through the extractor so the targeted
                # activations (and gradient hooks) are captured.
                target_activations, x = self.feature_extractor(x)
                continue

            x = child(x)
            if "avgpool" in child_name.lower():
                # Flatten everything after the batch dimension.
                x = x.view(x.size(0), -1)

        return target_activations, x

def preprocess_image(img):
    """Turn an image into a normalized tensor with a leading batch dimension.

    A copy of ``img`` is converted with ``ToTensor``, normalized with the fixed
    per-channel mean/std below, and then unsqueezed at dimension 0.
    """
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    tensor = pipeline(img.copy())
    return tensor.unsqueeze(0)

def show_cam_on_image(img, mask):
    """Blend a JET-colored version of ``mask`` with ``img``.

    ``mask`` is scaled by 255 and color-mapped, the result (rescaled by 1/255)
    is added to ``img``, and the sum is divided by its maximum before being
    returned as a uint8 array scaled by 255.
    """
    mask_u8 = np.uint8(255 * mask)
    colored = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
    overlay = np.float32(colored) / 255 + np.float32(img)
    overlay = overlay / np.max(overlay)
    return np.uint8(255 * overlay)

class GradCam:
    """Compute Grad-CAM heatmaps for a model.

    Args:
        model: Network to explain; it is switched to eval mode (and moved to
            the GPU when ``use_cuda`` is true).
        feature_module: Child of ``model`` whose sub-layers are inspected.
        target_layer_names: Names of the children of ``feature_module`` whose
            activations/gradients are captured; the last one captured is used.
        use_cuda: Whether to run the model and inputs on the GPU.
    """

    def __init__(self, model, feature_module, target_layer_names, use_cuda):
        self.model = model
        self.feature_module = feature_module
        self.model.eval()
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()

        self.extractor = ModelOutputs(self.model, self.feature_module, target_layer_names)

    def forward(self, input_img):
        """Plain forward pass through the wrapped model."""
        return self.model(input_img)

    def __call__(self, input_img, target_category=None):
        """Return the heatmap for ``input_img`` as a 2-D float array.

        Args:
            input_img: Batched input; its dimensions 2 and 3 give the size the
                heatmap is resized to.
            target_category: Index of the output unit to explain. When
                ``None``, the arg-max of the network output is used.

        Returns:
            Array of shape ``(input_img.shape[2], input_img.shape[3])`` scaled
            so its maximum is 1, or all zeros when the rectified map is
            constant (e.g. nothing contributes positively to the target).
        """
        if self.cuda:
            input_img = input_img.cuda()

        features, output = self.extractor(input_img)

        if target_category is None:
            target_category = np.argmax(output.cpu().data.numpy())

        # Select the target logit through a one-hot mask so that backward()
        # propagates only from that output unit.
        selector = np.zeros((1, output.size()[-1]), dtype=np.float32)
        selector[0][target_category] = 1
        selector = torch.from_numpy(selector)
        if self.cuda:
            selector = selector.cuda()
        score = torch.sum(selector * output)

        self.feature_module.zero_grad()
        self.model.zero_grad()
        score.backward(retain_graph=True)

        grads_val = self.extractor.get_gradients()[-1].cpu().data.numpy()
        target = features[-1].cpu().data.numpy()[0, :]

        # Channel weights: spatial mean of the gradient for the first sample.
        weights = np.mean(grads_val, axis=(2, 3))[0, :]
        cam = np.zeros(target.shape[1:], dtype=np.float32)
        for weight, channel in zip(weights, target):
            cam += weight * channel

        cam = np.maximum(cam, 0)
        # cv2.resize takes dsize as (width, height), i.e. the reverse of the
        # tensor's (height, width) layout; passing shape[2:] directly would
        # transpose the target size for non-square inputs.
        height, width = input_img.shape[2:]
        cam = cv2.resize(cam, (width, height))
        cam = cam - np.min(cam)
        peak = np.max(cam)
        if peak > 0:
            # Skip the division for a constant map: 0 / 0 would fill the
            # heatmap with NaNs.
            cam = cam / peak
        return cam


class GuidedBackpropReLU(Function):
    """ReLU whose backward pass keeps only positive gradients at positive inputs."""

    @staticmethod
    def forward(ctx, input_img):
        # Forward is an ordinary ReLU expressed as input * (input > 0).
        keep = (input_img > 0).type_as(input_img)
        output = torch.addcmul(torch.zeros_like(input_img), input_img, keep)
        ctx.save_for_backward(input_img, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_img, _ = ctx.saved_tensors

        input_positive = (input_img > 0).type_as(grad_output)
        grad_positive = (grad_output > 0).type_as(grad_output)
        # First mask by the sign of the forward input, then by the sign of
        # the incoming gradient.
        gated_by_input = torch.addcmul(torch.zeros_like(input_img),
                                       grad_output, input_positive)
        grad_input = torch.addcmul(torch.zeros_like(input_img),
                                   gated_by_input, grad_positive)
        return grad_input


class GuidedBackpropReLUModel:
    """Wrap a model so its ReLUs use guided backpropagation.

    On construction the model is put in eval mode (and moved to the GPU when
    ``use_cuda`` is true) and every sub-module whose class is named ``ReLU``
    is replaced in place by ``GuidedBackpropReLU.apply``.

    Args:
        model: Network to wrap; it is modified in place.
        use_cuda: Whether to run the model and inputs on the GPU.
    """

    def __init__(self, model, use_cuda):
        self.model = model
        self.model.eval()
        self.cuda = use_cuda
        if self.cuda:
            self.model = model.cuda()

        def recursive_relu_apply(module_top):
            for idx, module in module_top._modules.items():
                # Entries that are not modules (``None`` placeholders, or a
                # ReLU already swapped for the guided function when the same
                # model is wrapped twice) have no children to descend into.
                if not hasattr(module, '_modules'):
                    continue
                recursive_relu_apply(module)
                if module.__class__.__name__ == 'ReLU':
                    module_top._modules[idx] = GuidedBackpropReLU.apply

        # replace ReLU with GuidedBackpropReLU
        recursive_relu_apply(self.model)

    def forward(self, input_img):
        """Plain forward pass through the wrapped model."""
        return self.model(input_img)

    def __call__(self, input_img, target_category=None):
        """Return the gradient of the target output w.r.t. the first input.

        Args:
            input_img: 4-D input batch; it is marked as requiring gradients.
            target_category: Index of the output unit to explain. When
                ``None``, the arg-max of the network output is used.

        Returns:
            NumPy array holding the gradient for the first sample in the
            batch (the batch dimension is dropped).
        """
        if self.cuda:
            input_img = input_img.cuda()

        input_img = input_img.requires_grad_(True)
        # backward() accumulates into .grad, so a gradient left over from an
        # earlier call on the same tensor would be added to this result.
        if input_img.grad is not None:
            input_img.grad = None

        output = self.forward(input_img)

        if target_category is None:
            target_category = np.argmax(output.cpu().data.numpy())

        # One-hot mask selecting the target logit.
        selector = np.zeros((1, output.size()[-1]), dtype=np.float32)
        selector[0][target_category] = 1
        selector = torch.from_numpy(selector)
        if self.cuda:
            selector = selector.cuda()

        score = torch.sum(selector * output)
        score.backward(retain_graph=True)

        grad = input_img.grad.cpu().data.numpy()
        return grad[0, :, :, :]

def get_args():
    """Parse the command-line options and report the compute device.

    Returns the parsed namespace; ``use_cuda`` is kept true only when the flag
    was given and ``torch.cuda.is_available()`` reports a usable GPU.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ('--use-cuda', dict(action='store_true', default=False,
                            help='Use NVIDIA GPU acceleration')),
        ('--image-path', dict(type=str, default='./examples/both.png',
                              help='Input image path')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    args = cli.parse_args()
    args.use_cuda = args.use_cuda and torch.cuda.is_available()
    message = "Using GPU for acceleration" if args.use_cuda else "Using CPU for computation"
    print(message)

    return args

# def deprocess_image(img):
#     """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """
#     img = img - np.mean(img)
#     img = img / (np.std(img) + 1e-5)
#     img = img * 0.1
#     img = img + 0.5
#     img = np.clip(img, 0, 1)
#     return np.uint8(img*255)

def deprocess_image(img):
    """Scale ``img`` by 255 and convert the result to uint8."""
    scaled = img * 255
    return np.uint8(scaled)


# 2D materials code

In [None]:
import os, argparse, time, random
from functools import partial
from shutil import copyfile

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, confusion_matrix

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import torch.optim as optim
import torchvision.models as models
from PIL import Image

from tqdm.notebook import tqdm

In [None]:
import os

def makedirs(*dirnames):
    """Create each directory in ``dirnames`` (including parents) if missing.

    Directories that already exist are left untouched. ``exist_ok=True`` is
    used instead of a separate existence check so that a directory appearing
    between the check and the creation cannot raise. A path that exists but
    is not a directory raises ``FileExistsError``.
    """
    for dirname in dirnames:
        os.makedirs(dirname, exist_ok=True)

# Use the GPU when torch reports one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) random number
# generators with the same fixed value.
random.seed(41)
np.random.seed(41)
torch.manual_seed(41)
torch.cuda.manual_seed_all(41)


In [None]:
class FlakeDataset(Dataset):
    """Image dataset built from a dataframe with ``paths`` and ``labels`` columns.

    A material tag is derived from each file name: the part of the base name
    after the last ``-``, without extension and without any ``_crop`` suffix.

    Args:
        df: Dataframe providing the ``paths`` and ``labels`` columns.
        raw_only: When true, samples whose path contains ``aug_`` are dropped.
        transform: Callable applied to each loaded PIL image; defaults to
            ``transforms.ToTensor()``.
        material: Optional list of material names; a sample is kept only when
            its tag is a substring of at least one of them.
    """

    def __init__(self, df, raw_only, transform=None, material=None):
        paths, labels, materials, dat_type = [], [], [], []
        # Pair every path with its label positionally so the result does not
        # depend on the dataframe having a default 0..n-1 index.
        for path, label in zip(df['paths'], df['labels']):
            file = path.split('/')[-1].split('.')[0].split('-')[-1]
            if '_crop' in file:
                file = file.split('_crop')[0]

            if material and not any(file in mat for mat in material):
                continue

            is_augmented = 'aug_' in path
            if is_augmented and raw_only:
                continue

            dat_type.append('augment' if is_augmented else 'raw')
            paths.append(path)
            labels.append(label)
            materials.append(file)

        self.path = paths
        self.labels = torch.tensor(labels).float()
        self.materials = np.array(materials)
        self.dat_type = np.array(dat_type)

        if not transform:
            transform = transforms.Compose([transforms.ToTensor()])
        self.transform = transform

    def __len__(self):
        return len(self.path)

    def __getitem__(self, i):
        """Return ``(image, label, material, dat_type)`` for sample ``i``."""
        if torch.is_tensor(i):
            i = i.tolist()

        img = Image.open(self.path[i]).convert("RGB")
        img = self.transform(img)
        # Index with a list so the label keeps a dimension of size one.
        label = self.labels[[i]]
        material = self.materials[i]
        dat_type = self.dat_type[i]
        return img.float(), label, material, dat_type
    

In [None]:
# Folder on the mounted Drive that holds the saved results to compare; it is
# created below when it does not exist yet.
result_path = '/content/drive/Shared drives/2d/results/to_compare'

makedirs(result_path)

In [None]:
def extract_filename(path):
    """Return the last ``/``-separated component of ``path`` up to its first ``.``."""
    base_name = path.rpartition('/')[2]
    stem, _, _ = base_name.partition('.')
    return stem


def train_test_split(data, train_portion=0.75, seed=1, n_expected_originals=332):
    """Split ``data`` into train/test frames grouped by original image.

    All rows derived from the same original image (the raw row and its
    augmented copies) end up on the same side of the split.

    Args:
        data: Dataframe with a ``names`` column.
        train_portion: Fraction of the original images assigned to training.
        seed: Seed passed to ``np.random.seed``; note that this reseeds
            NumPy's global random state.
        n_expected_originals: Sanity check on the number of distinct original
            images (defaults to the 332 of the original data set). Pass
            ``None`` to skip the check for other data sets.

    Returns:
        ``(train_data, test_data)``, each with a fresh index; the previous
        index is kept as an ``index`` column by ``reset_index``.
    """
    np.random.seed(seed)
    assert 'names' in data.columns, '`names` column is not found in df.'

    def original_name(name):
        # Drop one leading '_'-separated token when the name contains 'raw'
        # or 'aug', and a second one when it contains 'aug'.
        n_prefix = int('aug' in name or 'raw' in name) + int('aug' in name)
        return '_'.join(name.split('_')[n_prefix:])

    full_data = data.copy()
    full_data['original_name'] = full_data.names.apply(original_name)
    original_imgs = np.unique(full_data['original_name'])
    if n_expected_originals is not None:
        assert original_imgs.shape[0] == n_expected_originals
    train_img = set(np.random.choice(original_imgs,
                                     int(train_portion * len(original_imgs)),
                                     False))
    train_idx = full_data.original_name.apply(lambda x: x in train_img)
    # Use the ``columns`` keyword: the positional ``axis`` argument of
    # ``DataFrame.drop`` was removed in pandas 2.0.
    train_data = full_data[train_idx].drop(columns='original_name').reset_index()
    test_data = full_data[~train_idx].drop(columns='original_name').reset_index()

    return train_data, test_data


def load_train_test(data):
    """Split ``data`` and wrap both parts in ``DataLoader`` objects.

    Each part becomes a ``FlakeDataset`` (augmented samples included)
    restricted to the three MoSe2 materials, with images resized to 224x224
    and converted to tensors. Both loaders use batch size 4, shuffling and
    pinned memory. Returns ``(train_loader, test_loader)``.
    """
    resize_to_tensor = transforms.Compose([
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor()])
    target_materials = ['MoSe2_on_Si', 'MoSe2_on_si_PDMS', 'MoSe2_on_PDMS']
    batch_size = 4

    loaders = []
    for frame in train_test_split(data):
        dataset = FlakeDataset(frame, raw_only=False,
                               transform=resize_to_tensor,
                               material=target_materials)
        loaders.append(DataLoader(dataset, batch_size=batch_size,
                                  shuffle=True, pin_memory=True))

    train_loader, test_loader = loaders
    return train_loader, test_loader


def load_model(path):
    """Build a ResNet-18 with a single-output head and load weights from ``path``.

    The network is created without pretrained weights, its final fully
    connected layer is replaced by a one-unit linear layer, it is moved to the
    global ``device``, and the state dict stored at ``path`` is loaded into it.
    """
    model = models.resnet18(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, 1)
    model = model.to(device)

    saved_state = torch.load(path, map_location=device)
    model.load_state_dict(saved_state)
    return model

# Only correct predictions

In [None]:
def is_correct_prediction(net, input, label, material, dat_type):
    """Check whether ``net`` classifies a single sample correctly.

    The signature matches the tuple returned by ``FlakeDataset.__getitem__``
    so a dataset item can be unpacked straight into the call; ``material``
    and ``dat_type`` are accepted for that reason and not used.

    Args:
        net: Model producing one logit per sample.
        input: Un-batched input tensor; a batch dimension is added here.
        label: Label tensor compared against the prediction.
        material: Unused.
        dat_type: Unused.

    Returns:
        Boolean tensor ``pred == label``, where the prediction is 1.0 when
        the logit is greater than 0 and 0.0 otherwise.
    """
    # Remember the current mode so it can be restored afterwards instead of
    # unconditionally forcing the model into training mode.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            batch = input.to(device).unsqueeze(0)
            label = label.to(device)
            pred = (net(batch) > 0).float()
    finally:
        net.train(was_training)
    return pred == label

In [None]:
# Write Grad-CAM overlays, for every saved model under result_path, only for
# the samples that model classifies correctly.
for file in os.listdir(result_path):

    # saved results
    # Entries listed here are skipped; edit the list by hand to choose which
    # saved models are processed on this run.
    # if file in ['quantized10']:
    if file in ['resnet18.torch', 'quantized20', 'quantized5']:
        continue

    # Ignore anything that is not a saved model, e.g. the output folders
    # created under result_path by earlier runs; they would otherwise be
    # treated as quantized models and fail when their data file is read.
    if file != 'resnet18.torch' and not file.startswith('quantized'):
        continue

    if file == 'resnet18.torch':

        # original

        k = str(-1)
        data = pd.read_pickle(os.path.join('/content/drive/Shared drives/2d/data',
                                           'pad_augment_data_final.pkl'))
        data['paths'] = data['paths'].apply(
            lambda x: '/content/drive/Shared drives/2d/data/pad_augment_data_final/' +
            x.split('/')[-1])

        net_path = os.path.join(result_path, file)
        cam_fold = os.path.join(result_path, 'cam_correct', 'original')
    else:

        # quantized

        k = file.replace('quantized', '')
        data = pd.read_pickle(os.path.join('/content/drive/Shared drives/2d/data',
                                           f'quantized{k}_pad_augment_data.pkl'))
        data['paths'] = data['paths'].apply(
            lambda x: f'/content/drive/Shared drives/2d/data/quantized{k}_pad_augment_data/' +
            x.split('/')[-1])

        net_path = os.path.join(result_path, file, 'resnet18.torch')
        cam_fold = os.path.join(result_path, 'cam_correct', file)

    net = load_model(net_path)
    train, test = load_train_test(data)
    trainset = train.dataset
    testset = test.dataset

    cam_train = os.path.join(cam_fold, 'train')
    cam_test = os.path.join(cam_fold, 'test')
    makedirs(cam_fold, cam_train, cam_test)

    target_category = 0
    for split_name, dataset, out_dir in (('train', trainset, cam_train),
                                         ('test', testset, cam_test)):
        for i in range(len(dataset)):
            # Load the sample once and reuse it for both the correctness
            # check and the heatmap (each load reads the image from disk).
            sample = dataset[i]
            if not is_correct_prediction(net, *sample):
                continue

            file_name = extract_filename(dataset.path[i])
            img = sample[0]

            input_img = img.unsqueeze(0)
            img = img.permute((1, 2, 0))
            net.zero_grad()
            # GradCam's constructor puts the model in eval mode, so it is
            # built after the correctness check, which may change the mode.
            grad_cam = GradCam(model=net, feature_module=net.layer4,
                               target_layer_names=["1"],
                               use_cuda=(device==torch.device('cuda')))

            grayscale_cam = grad_cam(input_img, target_category)
            grayscale_cam = cv2.resize(grayscale_cam, (img.shape[1], img.shape[0]))
            cam = show_cam_on_image(img, grayscale_cam)

            cv2.imwrite(os.path.join(out_dir, 'cam_' + file_name + '.jpg'), cam)

        print(f'finish {split_name}set: {cam_fold}')




finish trainset: /content/drive/Shared drives/2d/results/to_compare/cam_correct/original
finish testset: /content/drive/Shared drives/2d/results/to_compare/cam_correct/original




finish trainset: /content/drive/Shared drives/2d/results/to_compare/cam_correct/quantized5
finish testset: /content/drive/Shared drives/2d/results/to_compare/cam_correct/quantized5




finish trainset: /content/drive/Shared drives/2d/results/to_compare/cam_correct/quantized10
finish testset: /content/drive/Shared drives/2d/results/to_compare/cam_correct/quantized10




# New result: 07/13

In [None]:
# Write Grad-CAM overlays for every train and test sample of the selected
# saved model.
for file in os.listdir(result_path):
    # saved results
    # Only the entry named here is processed on this run; edit by hand to
    # select another saved model.
    # if file in ['resnet18.torch', 'quantized5', 'quantized10']:
    if file != 'quantized20':
        continue

    if file == 'resnet18.torch':
        # original
        k = str(-1)
        data = pd.read_pickle(os.path.join('/content/drive/Shared drives/2d/data',
                                           'pad_augment_data_final.pkl'))
        data['paths'] = data['paths'].apply(
            lambda x: '/content/drive/Shared drives/2d/data/pad_augment_data_final/' +
            x.split('/')[-1])

        net_path = os.path.join(result_path, file)
        cam_fold = os.path.join(result_path, 'cam', 'original')
    else:
        # quantized
        k = file.replace('quantized', '')
        data = pd.read_pickle(os.path.join('/content/drive/Shared drives/2d/data',
                                           f'quantized{k}_pad_augment_data.pkl'))
        data['paths'] = data['paths'].apply(
            lambda x: f'/content/drive/Shared drives/2d/data/quantized{k}_pad_augment_data/' +
            x.split('/')[-1])

        net_path = os.path.join(result_path, file, 'resnet18.torch')
        cam_fold = os.path.join(result_path, 'cam', file)

    net = load_model(net_path)
    train, test = load_train_test(data)
    trainset = train.dataset
    testset = test.dataset

    cam_train = os.path.join(cam_fold, 'train')
    cam_test = os.path.join(cam_fold, 'test')
    makedirs(cam_fold, cam_train, cam_test)

    target_category = 0
    # The wrapper depends only on the model, so build it once per model
    # rather than once per image; each call zeroes the model's gradients
    # itself before running backward.
    grad_cam = GradCam(model=net, feature_module=net.layer4,
                       target_layer_names=["1"],
                       use_cuda=(device==torch.device('cuda')))

    for split_name, dataset, out_dir in (('train', trainset, cam_train),
                                         ('test', testset, cam_test)):
        for i in range(len(dataset)):
            file_name = extract_filename(dataset.path[i])
            img = dataset[i][0]

            input_img = img.unsqueeze(0)
            img = img.permute((1, 2, 0))

            grayscale_cam = grad_cam(input_img, target_category)
            grayscale_cam = cv2.resize(grayscale_cam, (img.shape[1], img.shape[0]))
            cam = show_cam_on_image(img, grayscale_cam)

            cv2.imwrite(os.path.join(out_dir, 'cam_' + file_name + '.jpg'), cam)

        print(f'finish {split_name}set: {cam_fold}')


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


finish trainset: /content/drive/Shared drives/2d/results/to_compare/cam/quantized20
finish testset: /content/drive/Shared drives/2d/results/to_compare/cam/quantized20


# Old results

In [None]:
def torch_dict_path(exp_path, train_material):
    """Locate the ``.torch`` file for a set of training materials.

    The file name is the material names joined with ``&``. The path built from
    the given order is returned when that file exists; otherwise the path
    built from the reversed order is returned (without checking it exists).
    """
    given_order, reversed_order = (
        os.path.join(exp_path, '&'.join(names) + '.torch')
        for names in (train_material, train_material[::-1])
    )
    return given_order if os.path.exists(given_order) else reversed_order

In [None]:
# Leave-one-material-out Grad-CAM: for each held-out material, load the model
# saved for the remaining materials and write overlays for the raw train and
# test images.
for test_mat in materials:
    train_material = [mat for mat in materials if mat != test_mat]
    trainset = FlakeDataset(data, raw_only=True,
                        transform=transform,
                        material=train_material)
    test_material = [test_mat]
    testset = FlakeDataset(data, raw_only=True,
                        transform=transform,
                        material=test_material)

    cam_fold = os.path.join(exp_path, '&'.join(train_material))
    cam_train = os.path.join(cam_fold, 'train')
    cam_test = os.path.join(cam_fold, 'test')
    makedirs(cam_train, cam_test)

    # The checkpoint depends only on train_material, so the model and its
    # Grad-CAM wrapper are built once per fold instead of being reloaded
    # from disk for every image.
    net = load_model(torch_dict_path(exp_path, train_material))
    grad_cam = GradCam(model=net, feature_module=net.layer4,
                       target_layer_names=["1"],
                       use_cuda=(device==torch.device('cuda')))

    target_category = 0
    for dataset, out_dir in ((trainset, cam_train), (testset, cam_test)):
        for i in range(len(dataset)):
            file_name = extract_filename(dataset.path[i])
            img = dataset[i][0]

            input_img = img.unsqueeze(0)
            img = img.permute((1, 2, 0))

            grayscale_cam = grad_cam(input_img, target_category)
            grayscale_cam = cv2.resize(grayscale_cam, (img.shape[1], img.shape[0]))
            cam = show_cam_on_image(img, grayscale_cam)

            cv2.imwrite(os.path.join(out_dir, 'cam_' + file_name + '.jpg'), cam)




# With quantization

In [None]:
# Experiment folder for the quantized runs.
# NOTE(review): root_path is not defined in the visible part of this
# notebook -- confirm it is set before this cell is run.
exp_path = os.path.join(root_path, 'quant')

In [None]:
# NOTE(review): bsz is not referenced again in the visible cells -- confirm
# it is still needed.
bsz = 4
# Load the pickled table describing the quantized, padded and augmented data.
# NOTE(review): unpickling executes code from the file; only load trusted data.
data = pd.read_pickle(os.path.join('/content/drive/Shared drives/2d/data', 'quantized_pad_augment_data.pkl'))
# Point every path at the shared Drive folder, keeping only the base file name.
data['paths'] = data['paths'].apply(
    lambda x: '/content/drive/Shared drives/2d/data/quantized_pad_augment_data/' + x.split('/')[-1])

In [None]:
# Resize every image to 224x224 and convert it to a tensor.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor()])

# Material names iterated over below; each one is held out in turn as the
# test material.
materials = ['MoSe2_on_Si', 'MoSe2_on_si_PDMS', 'MoSe2_on_PDMS']

In [None]:
# Leave-one-material-out Grad-CAM on the quantized data: for each held-out
# material, load the model saved for the remaining materials and write
# overlays for the raw train and test images.
for test_mat in materials:
    train_material = [mat for mat in materials if mat != test_mat]
    trainset = FlakeDataset(data, raw_only=True,
                        transform=transform,
                        material=train_material)
    test_material = [test_mat]
    testset = FlakeDataset(data, raw_only=True,
                        transform=transform,
                        material=test_material)

    cam_fold = os.path.join(exp_path, '&'.join(train_material))
    cam_train = os.path.join(cam_fold, 'train')
    cam_test = os.path.join(cam_fold, 'test')
    makedirs(cam_train, cam_test)

    # The checkpoint depends only on train_material, so the model and its
    # Grad-CAM wrapper are built once per fold instead of being reloaded
    # from disk for every image.
    net = load_model(torch_dict_path(exp_path, train_material))
    grad_cam = GradCam(model=net, feature_module=net.layer4,
                       target_layer_names=["1"],
                       use_cuda=(device==torch.device('cuda')))

    target_category = 0
    for dataset, out_dir in ((trainset, cam_train), (testset, cam_test)):
        for i in range(len(dataset)):
            file_name = extract_filename(dataset.path[i])
            img = dataset[i][0]

            input_img = img.unsqueeze(0)
            img = img.permute((1, 2, 0))

            grayscale_cam = grad_cam(input_img, target_category)
            grayscale_cam = cv2.resize(grayscale_cam, (img.shape[1], img.shape[0]))
            cam = show_cam_on_image(img, grayscale_cam)

            cv2.imwrite(os.path.join(out_dir, 'cam_' + file_name + '.jpg'), cam)


