In [1]:
import copy
import os

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as tt
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from tqdm import tqdm

from ViT_CX.ViT_CX import ViT_CX

In [2]:
# Choose where tensors and the model live: the GPU if torch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

def to_device(data, device):
    """Transfer a tensor, or an arbitrarily nested list/tuple of tensors, to `device`.

    Args:
        data: a tensor (anything with a ``.to`` method) or a list/tuple of such
            objects, possibly nested.
        device: destination passed straight to ``.to``.

    Returns:
        The moved tensor, or a *list* (tuples are converted to lists) holding
        the moved elements in their original order.
    """
    if not isinstance(data, (list, tuple)):
        # Leaf: non_blocking lets the copy overlap with host work when possible.
        return data.to(device, non_blocking=True)
    return list(map(lambda item: to_device(item, device), data))

class DeviceDataLoader:
    """Iterable wrapper that moves every batch of an underlying loader onto a device.

    Args:
        dl: an iterable of batches that also supports ``len()``
            (e.g. a ``torch.utils.data.DataLoader``).
        device: target handed to ``to_device`` for each batch.
    """

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __iter__(self):
        """Lazily yield each batch of ``self.dl`` after moving it to ``self.device``."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Return the batch count, delegated to the wrapped loader."""
        return len(self.dl)

Using device: cuda


In [3]:
# Load the black-box classifier that ViT-CX will explain.
# SECURITY: torch.load unpickles arbitrary Python objects; only load trusted checkpoints.
# NOTE(review): on torch>=2.6 torch.load defaults to weights_only=True and will refuse a
# whole pickled nn.Module; pass weights_only=False there (trusted files only).
MODEL_PATH = './data/pwc_ViT_timm_pure_10_folds'
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine too.
model = torch.load(MODEL_PATH, map_location=device)
model.eval()  # inference only: disable training-time behaviour such as dropout
labels = ['ASD', 'NC']  # labels[i] is the class name for output index i

In [92]:
def predict_image(img):
    """Classify one PWC image and render its ViT-CX saliency map.

    Used as the ``fn`` of the Gradio interface: takes the uploaded image and
    returns the two values displayed in the UI.

    Args:
        img: PIL image from the ``gr.Image(type='pil')`` input, or ``None``
            when the user submits without an image.

    Returns:
        tuple[str, str]: the predicted class name (an entry of the module-level
        ``labels`` list) and the path of the PNG the saliency map was saved to.

    Raises:
        gr.Error: if ``img`` is ``None`` (Gradio shows the message to the user).
    """
    if img is None:
        raise gr.Error("Please upload an image first.")

    stats = (0.5, 0.5)  # (mean, std) applied identically to every channel
    img_tfms = tt.Compose([
        tt.ToTensor(),
        tt.Resize((224, 224), antialias=True),  # must be a bool, not the string 'True'
        tt.Normalize(*stats),
    ])
    # Uploaded PNGs may be RGBA, palette or grayscale; force 3 channels.
    # NOTE(review): assumes the model was trained on RGB images -- confirm.
    img_t = img_tfms(img.convert('RGB'))
    xb = to_device(img_t.unsqueeze(0), device)  # add batch dim -> (1, 3, 224, 224)

    # Plain forward pass for the label; no autograd graph is needed here.
    with torch.no_grad():
        logits = model(xb)
    pred_idx = torch.argmax(logits, dim=1).item()  # Python int, safe for list indexing

    # Explain w.r.t. the `norm1` submodule of the last entry in `model.blocks`.
    target_layer = model.blocks[-1].norm1
    cx_result = ViT_CX(model, xb, target_layer, target_category=None,
                       distance_threshold=0.1, gpu_batch=50)

    # NOTE(review): a fixed output path means concurrent requests overwrite each
    # other's map; switch to a per-request temp file if that matters.
    img_path = './data/gradio/cx_result.png'
    os.makedirs(os.path.dirname(img_path), exist_ok=True)
    plt.imsave(img_path, cx_result, cmap='jet')

    return labels[pred_idx], img_path

In [93]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

In [104]:
# Web UI: one image in; predicted label text and saliency-map image out.
# NOTE(review): `shape` is a Gradio 3.x option (resizes the upload before it
# reaches predict_image); it was removed in Gradio 4 -- confirm pinned version.
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type='pil', shape=(224, 224), label="PWC Image Input"),
    outputs=[
        gr.Textbox(label="Predicted:"),
        gr.Image(type='filepath', label="Saliency Map"),  # predict_image returns a PNG path
    ],
    title="ViT-Base16 PWC ASD Classifier",
)

In [105]:
# Start the Gradio server. inline=False skips embedding the UI in the notebook
# output; open the local URL printed in the cell output instead.
demo.launch(inline=False)

Running on local URL:  http://127.0.0.1:7886

To create a public link, set `share=True` in `launch()`.


