# Using the torchcam algorithm library

## Import the required libraries

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from PIL import Image
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda', 0)
else:
    device = torch.device('cpu')
print('device', device)

In [None]:
!rm -rf torch-cam # Remove any earlier clone so the next cell can clone the repository again.

In [None]:
# Install torch-cam
!git clone https://github.com/frgfm/torch-cam.git 
!pip install -e torch-cam/.

In [None]:
# Restart the kernel now so the freshly installed torchcam package can be imported.

In [None]:
import torchcam

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
from PIL import Image
# After the kernel restart, re-select the compute device:
# the first CUDA GPU if PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda', 0)
else:
    device = torch.device('cpu')
print('device', device)

In [None]:
# Load the trained classification model. The checkpoint holds a whole pickled
# model object (not just a state_dict -- `.eval()` is called on it below), so full
# unpickling is needed: pass weights_only=False explicitly, because torch >= 2.6
# changed the default to True and would refuse to load this file.
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints you trust.
# map_location=device lets a checkpoint saved on a GPU load on a CPU-only machine.
model = torch.load('checkpoint/(18-1-91).pth', map_location=device, weights_only=False)
# Switch to inference mode (affects layers such as dropout/batch-norm) and move to device.
model = model.eval().to(device)

## Import CAM methods

In [None]:
from torchcam.methods import SSCAM
# Wrap the trained model in torchcam's SSCAM (Smoothed Score-CAM) extractor.
# No target layer is passed, so torchcam presumably picks one automatically --
# NOTE(review): confirm the selected layer is the one you want to visualise, and that
# SSCAM's expected input shape matches the 224x224 crop produced by test_transform.
cam_extractor = SSCAM(model)

In [None]:
from torchvision import transforms  #Image preprocessing: scaling, cropping, conversion to Tensor, normalisation
# Evaluation-time preprocessing applied to every test image:
# resize, take the central 224x224 crop, convert to a tensor, then normalise each
# RGB channel with the given mean/std (the commonly used ImageNet statistics --
# presumably what the model was trained with; confirm against the training code).
test_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

## Generate heatmaps

In [None]:
import os
from torchcam.utils import overlay_mask

In [None]:
!find . -iname '.ipynb_checkpoints'#Delete redundant documents

In [None]:
!for i in `find . -iname '.ipynb_checkpoints'`; do rm -rf $i;done

In [None]:
!find . -iname '.ipynb_checkpoints'#Look again for redundant files.

In [None]:
# Generate an SSCAM heatmap for the top-1 predicted class of every image in
# `img_path` and save it, under the same file name, into `out_dir`.
img_path = 'test'  # The location where the test images are stored.
out_dir = '228sscam'  # The location where the post-prediction results are stored.
# Create the output folder up front; Image.save raises FileNotFoundError otherwise.
os.makedirs(out_dir, exist_ok=True)

for files in sorted(os.listdir(img_path)):
    imgpath = os.path.join(img_path, files)
    # Skip sub-directories (e.g. a Jupyter .ipynb_checkpoints folder).
    if not os.path.isfile(imgpath):
        continue
    try:
        # Open inside `with` so the file handle is closed, and force 3-channel RGB:
        # grayscale/RGBA/palette images would otherwise break the 3-channel
        # Normalize in test_transform and the RGB blend in overlay_mask.
        with Image.open(imgpath) as img:
            img_pil = img.convert('RGB')
    except OSError as err:  # PIL's UnidentifiedImageError is a subclass of OSError.
        print(f'Skipping {imgpath}: not a readable image ({err})')
        continue

    input_tensor = test_transform(img_pil).unsqueeze(0).to(device)  # batch of one
    pred_logits = model(input_tensor)
    # Index of the highest-scoring class for the single image in the batch.
    pred_id = pred_logits.argmax(dim=1).item()
    activation_map = cam_extractor(pred_id, pred_logits)
    # First entry of the returned sequence, first (only) element of the batch.
    activation_map = activation_map[0][0].detach().cpu().numpy()
    # NOTE(review): in torchcam, alpha weights the original image in the blend, so
    # 0.001 yields an almost pure heatmap -- confirm this is intended.
    result = overlay_mask(img_pil, Image.fromarray(activation_map), alpha=0.001)
    path2 = os.path.join(out_dir, files)
    result.save(path2)