In [1]:
import cv2
import copy
from ViT_CX.ViT_CX import ViT_CX
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from torchvision import transforms
import torchvision.transforms as tt
from torchvision.utils import make_grid
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import gradio as gr

In [2]:
# Device selection: use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

def to_device(data, device):
    """Recursively transfer a tensor, or a (nested) list/tuple of tensors, to ``device``.

    Lists and tuples are walked element by element and always come back as a
    ``list``. Any other object is moved with ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

class DeviceDataLoader:
    """Iterable wrapper that moves every batch of an inner loader to a device.

    Attributes:
        dl: the wrapped iterable of batches (e.g. a ``DataLoader``).
        device: target device passed to ``to_device`` for each batch.
    """

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __iter__(self):
        """Lazily yield each batch of the wrapped loader after moving it to ``self.device``."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Delegate the batch count to the wrapped loader."""
        return len(self.dl)

Using device: cuda


In [3]:
# Load the trained black-box classifier that ViT-CX will explain.
# NOTE(review): torch.load unpickles the whole nn.Module, which can execute arbitrary
# code, so only load checkpoints from a trusted source. On torch >= 2.6 (where
# weights_only defaults to True) this call also needs weights_only=False to load a full module.
MODEL_PATH = './data/pwc_ViT_timm_pure_5_folds'
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
model = torch.load(MODEL_PATH, map_location=device)
# Keep the weights on the same device that predict_image sends its inputs to, and
# switch off training-mode behavior (e.g. dropout) for inference.
model = model.to(device).eval()
# Class names indexed by the model's predicted class index.
labels = ['ASD', 'NC']

In [4]:
def predict_image(img):
    """Classify one image and compute its ViT-CX explanation map.

    Used as the Gradio callback for the uploaded image.

    Args:
        img: the image handed over by ``gr.Image``. This is presumably an HxWxC
            uint8 numpy array (gradio's default input type); TODO confirm. It may
            be ``None`` when the form is submitted without an image.

    Returns:
        tuple: ``(label, cx_result)``. ``label`` is the entry of ``labels`` at the
        predicted class index, and ``cx_result`` is whatever ``ViT_CX`` returns
        (displayed as an image in the UI). When no image was provided, returns
        ``("No image provided", None)`` instead.
    """
    if img is None:
        # Gradio passes None on an empty submit; ToTensor would raise on it.
        return "No image provided", None

    # Image -> CxHxW float tensor, then add a leading batch dimension (batch of 1).
    img_t = tt.ToTensor()(img).unsqueeze(0)
    xb = to_device(img_t, device)

    # Inference only: skip autograd bookkeeping so each request doesn't build a graph.
    with torch.no_grad():
        yb = model(xb)
    # Index of the highest-scoring class as a plain int, for indexing the labels list.
    pred_idx = torch.argmax(yb, dim=1).item()

    # Layer whose features ViT-CX uses to build the explanation.
    target_layer = model.blocks[-1].norm1
    cx_result = ViT_CX(model, img_t, target_layer, target_category=None,
                       distance_threshold=0.1, gpu_batch=50)
    return labels[pred_idx], cx_result

In [5]:
# Gradio UI: an uploaded image goes in; the predicted label and the ViT-CX map come out.
# NOTE(review): gr.Image(shape=...) is a gradio 3.x parameter (removed in gradio 4).
demo = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(shape=(224, 224)),
    outputs=["text", gr.Image(type="pil", image_mode="L")],
)

In [6]:
# TODO: debug the demo. First call predict_image on an image loaded from a file path and confirm the predicted label is correct.

In [8]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

In [9]:
# Start the Gradio server for the demo (it prints the local URL); pass share=True to launch() for a public link.
demo.launch()

Rerunning server... use `close()` to stop if you need to change `launch()` parameters.
----
Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


