In [2]:
from pytorch_grad_cam import FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import numpy as np
import os, json
from glob import glob
from tqdm.notebook import tqdm
from joblib import Parallel, delayed

import torch
from torchvision import models, transforms
from torch.autograd import Variable
import torch.nn.functional as F

import timm

In [3]:
def get_input_transform(size=600):
    """Build the preprocessing pipeline applied to every image fed to the model.

    The image is resized to a fixed ``size`` x ``size`` square (aspect ratio is
    not preserved), converted to a float tensor, and normalized with the
    standard ImageNet per-channel mean/std.

    Args:
        size: Output height and width in pixels. Defaults to 600, which matches
            the resolution used elsewhere in this notebook (presumably the
            checkpoint's training resolution — confirm).

    Returns:
        A ``transforms.Compose`` mapping a PIL image to a (C, size, size) tensor.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # No CenterCrop: after Resize((size, size)) the image is already exactly
    # size x size, so a center crop of the same size would be a no-op.
    transf = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        normalize,
    ])

    return transf

def get_input_tensors(img, transf=None):
    """Preprocess a PIL image into a model-ready batch containing one image.

    Args:
        img: PIL image of any mode. It is converted to RGB first so the
            3-channel normalization always applies; grayscale, RGBA, or palette
            images would otherwise produce the wrong number of channels.
        transf: Optional pre-built transform to reuse across calls. When None,
            a new one is built with ``get_input_transform()``.

    Returns:
        Tensor of shape (1, 3, H, W), where H and W are set by the transform.
    """
    if transf is None:
        transf = get_input_transform()
    return transf(img.convert("RGB")).unsqueeze(0)

In [4]:
ckpt_fn = "models/resolution/20220722-163041-sequencer2d_l-600/last.pth.tar"
im_path = "data/orig_o"
out_path = "resolution/600_s2d"
model_type = "sequencer2d_l"


def build_class_index(root):
    """Map each class sub-folder name directly under ``root`` to an integer label.

    Only directories count as classes, and names are sorted so the mapping is
    deterministic; raw ``glob`` order depends on the filesystem.
    NOTE(review): this must reproduce the label order used at training time. If
    the checkpoint was trained with timm's folder dataset, that dataset sorts
    class names with a case-insensitive natural-number key, which can differ
    from plain ``sorted`` (e.g. "2" vs "10") — confirm.
    """
    class_dirs = sorted(
        os.path.basename(p)
        for p in glob(os.path.join(root, "*"))
        if os.path.isdir(p)
    )
    return {name: idx for idx, name in enumerate(class_dirs)}


def list_images(root):
    """Return the sorted paths of all files under ``root`` (any depth) whose name contains a dot."""
    # isfile() drops directories that happen to match "*.*" (e.g. "v1.0/").
    return sorted(
        p
        for p in glob(os.path.join(root, "**", "*.*"), recursive=True)
        if os.path.isfile(p)
    )


fl_dict = build_class_index(im_path)
images = list_images(im_path)

print(len(images))

4577


In [5]:
# Instantiate the timm architecture named by `model_type` with a 300-way
# classifier head and 3 input channels, loading weights from `ckpt_fn`.
model = timm.create_model(
    model_type,
    checkpoint_path=ckpt_fn,
    in_chans=3,
    num_classes=300,
)

In [7]:
# Render a CAM overlay for every image and save it as
# outs/<out_path>/<method>/<basename>. Outputs that already exist are skipped
# before any computation, so re-running this cell resumes an interrupted pass.
img_size = 600  # should match the size used by get_input_transform()

# Build each CAM once, outside the loop. Constructing a CAM registers hooks on
# `model`, so re-creating it per image piles up hooks, which slows every pass
# and leaks memory. FullGrad takes no target layers.
cam_list = [["fullgrad", FullGrad(model=model, target_layers=[])]]

for im_fn in tqdm(images):
    bn = os.path.basename(im_fn)

    # Collect only the methods whose output file is still missing.
    pending = []
    for cam_str, cam_method in cam_list:
        out_dir = os.path.join("outs", out_path, cam_str)
        out_fn = os.path.join(out_dir, bn)
        if not os.path.exists(out_fn):
            pending.append((cam_str, cam_method, out_dir, out_fn))
    if not pending:
        continue

    # The label is the index of the image's parent folder name in fl_dict.
    cat = fl_dict[os.path.basename(os.path.dirname(im_fn))]
    targets = [ClassifierOutputTarget(cat)]

    # convert("RGB") keeps grayscale/RGBA/palette inputs at 3 channels and
    # loads the pixels, so the file handle can be closed by `with`.
    with Image.open(im_fn) as raw_img:
        img = raw_img.convert("RGB")
    img_t = get_input_tensors(img)
    # show_cam_on_image expects a float RGB image in [0, 1] with the same
    # spatial size as the CAM.
    test_img = np.asarray(img.resize((img_size, img_size)), dtype=np.float32) / 255

    for cam_str, cam_method, out_dir, out_fn in pending:
        grayscale_cam = cam_method(input_tensor=img_t, targets=targets)[0, :]
        visualization = show_cam_on_image(test_img, grayscale_cam, use_rgb=True)

        # Create the directory *before* saving. Previously the first image per
        # method hit FileNotFoundError and was never written.
        os.makedirs(out_dir, exist_ok=True)
        fig, ax = plt.subplots()
        ax.imshow(visualization)
        fig.savefig(out_fn)
        # Always close; otherwise open figures accumulate across iterations.
        plt.close(fig)

  0%|          | 0/4577 [00:00<?, ?it/s]


KeyboardInterrupt

