### Script

In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import rasterio
import numpy as np
import matplotlib.pyplot as plt
from torchmetrics import Accuracy
from utils import SaveFeatures, imshow_transform
from custom_model import model_4D
from torch.autograd import Variable
from skimage.transform import resize
from skimage.io import imshow
import wandb
import torch.optim.lr_scheduler as lr_scheduler

%matplotlib inline

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Resume the existing W&B run and download the model artifact logged to it.
run = wandb.init(project= 'VGG_CAMs', id= 'v6ax9crk', resume = 'must')
artifact = run.use_artifact('nadjaflechner/VGG_CAMs/model:v9', type='model')
artifact_dir = artifact.download()

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# holds CUDA tensors can still be loaded on a CPU-only machine.
state_dict = torch.load(f"{artifact_dir}/model.pth", map_location=device)
model = model_4D()
model.load_state_dict(state_dict)
model.to(device)

In [None]:
####################
# generating CAMs #

# Upper bound on the number of validation samples to inspect, and the number
# of palsa (label == 1) images to plot and log before stopping.
MAX_DRAWS = 100
N_PALSA_CAMS = 20

# Batch size 1 so that every step yields a single image.
valid_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1,
                                           shuffle=True,
                                           num_workers=1)

# Load best model to produce CAMs with
# NOTE(review): `best_model` is not defined in the visible part of this file;
# confirm it exists in the session, otherwise the artifact weights loaded
# earlier are the ones to use.
model.load_state_dict(best_model)
model.eval()

palsa_imgs = 0
# Iterate the loader itself instead of calling next(iter(loader)) each time:
# one shuffled pass, no repeated images, and no iterator/worker re-creation
# per sample. zip() with range() caps the number of samples drawn.
for _, (im, lab) in zip(range(MAX_DRAWS), valid_loader):
    # Only palsa images are visualised.
    if lab.item() != 1:
        continue

    palsa_imgs += 1
    im = im.to(device)

    # Hook the chosen feature layer only for the forward pass we need, and
    # always detach the hook afterwards so hooks do not pile up on the model.
    sf = SaveFeatures(model.features[-4])
    try:
        with torch.no_grad():  # inference only, no autograd graph needed
            outputs = model(im)
    finally:
        sf.remove()
    res = int(torch.argmax(outputs))

    # generate CAM
    # The dot product with a length-2 vector means the hooked layer must emit
    # exactly two channels; the one-hot weights pick the channel of the
    # predicted class (index 1 when res == 1, index 0 otherwise).
    feature_maps = sf.features.detach().cpu().numpy()[0]
    class_weights = [0, 1] if res == 1 else [1, 0]
    class_map = np.dot(np.moveaxis(feature_maps, 0, -1), class_weights)
    CAM = resize(class_map, (im_size*2, im_size*2))

    # Plot image with CAM
    # Drop the batch axis and move channels last for imshow.
    cpu_img = im.squeeze(0).detach().cpu().permute(1, 2, 0).long().numpy()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 7))

    ax1.imshow(cpu_img)
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_title('original image')

    ax2.imshow(cpu_img)
    ax2.imshow(CAM, alpha=.4, cmap='jet')
    ax2.set_xticks([])
    ax2.set_yticks([])
    ax2.set_title('image with CAM')

    plt.tight_layout()

    # Log while the figure is still open, then display it and release it so
    # figures do not accumulate in memory across iterations.
    wandb.log({'generated_CAM': fig})
    plt.show()
    plt.close(fig)

    if palsa_imgs == N_PALSA_CAMS:
        break
