In [1]:
#https://github.com/openai/CLIP

import threading
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from matplotlib.colors import LinearSegmentedColormap


In [2]:
# Originally made by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings)
# The original BigGAN+CLIP method was by https://twitter.com/advadnoun

import argparse
import math
import random
# from email.policy import default
from urllib.request import urlopen
from tqdm import tqdm
import sys
import os

sys.path.append('taming-transformers')

from omegaconf import OmegaConf
from taming.models import cond_transformer, vqgan
#import taming.modules 

import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from torch.cuda import get_device_properties
torch.backends.cudnn.benchmark = False		# NR: True is a bit faster, but can lead to OOM. False is more deterministic.
#torch.use_deterministic_algorithms(True)	# NR: grid_sampler_2d_backward_cuda does not have a deterministic implementation

from torch_optimizer import DiffGrad, AdamP, RAdam

from CLIP import clip
import kornia.augmentation as K
import numpy as np
import imageio

from PIL import ImageFile, Image, PngImagePlugin, ImageChops
ImageFile.LOAD_TRUNCATED_IMAGES = True

from subprocess import Popen, PIPE
import re

# Default output image size in pixels.
# NOTE(review): no GPU/VRAM check is performed here (`get_device_properties` is imported
# above but not used for this), so the value is fixed regardless of available VRAM.
# The ">8GB VRAM" label presumably records the hardware this value was tuned for.
default_image_size = 256

class ReplaceGrad(torch.autograd.Function):
    """Return one tensor in the forward pass but send its gradient to another.

    ``forward(x_forward, x_backward)`` returns ``x_forward`` unchanged. ``backward``
    gives no gradient to ``x_forward`` and passes the incoming gradient to
    ``x_backward``, summed down to ``x_backward``'s shape.
    """

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)


# Functional handle used by `Prompt.forward`; it was referenced there but not defined
# anywhere in this notebook, so calling a Prompt raised NameError.
replace_grad = ReplaceGrad.apply


class Prompt(nn.Module):
    """Weighted spherical-distance loss between input embeddings and a target embedding.

    Args:
        embed: target embedding(s). Assumed to be (n_embeds, dim), matching the
            ``unsqueeze(0)`` / ``dim=2`` usage in ``forward`` -- TODO confirm.
        weight: loss weight. Its sign flips the distances (negative = push away)
            and its absolute value scales the returned loss.
        stop: gradient floor. The forward value is the raw distance, but gradients
            flow through ``max(distance, stop)``, so distances below ``stop`` get
            no gradient. The default ``-inf`` never cuts gradients.
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        # Buffers move with the module (.to(device)) but are not trainable parameters.
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        """Return ``|weight| * mean(sign(weight) * dist)`` over all input/embed pairs.

        For unit vectors, ``||a - b|| = 2 sin(theta / 2)``, so each distance below
        equals ``theta**2 / 2``, where ``theta`` is the angle between them.
        ``input`` is assumed to be (batch, dim) -- TODO confirm.
        """
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        # Forward value is `dists`; gradients go through the stop-clamped version.
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()

def split_prompt(prompt):
    """Split a ``"text[:weight[:stop]]"`` prompt into ``(text, weight, stop)``.

    Only the last two colons act as separators, so colons earlier in the text are
    kept. Missing fields default to weight ``1.0`` and stop ``-inf``.
    """
    fallback = ['', '1', '-inf']
    pieces = prompt.rsplit(':', 2)
    pieces.extend(fallback[len(pieces):])
    text, weight_field, stop_field = pieces
    return text, float(weight_field), float(stop_field)

def resize_image(image, out_size):
    """Resample ``image`` to a pixel-area budget while keeping its aspect ratio.

    The budget is ``min(current area, out_size[0] * out_size[1])``, so the image is
    not enlarged beyond its current area (up to rounding to whole pixels).

    Args:
        image: PIL image (anything with ``.size`` and ``.resize``).
        out_size: pair whose product gives the pixel-area budget.

    Returns:
        The resized image, resampled with Lanczos filtering.
    """
    width, height = image.size
    aspect = width / height
    budget = min(width * height, out_size[0] * out_size[1])
    new_size = (round((budget * aspect) ** 0.5), round((budget / aspect) ** 0.5))
    return image.resize(new_size, Image.LANCZOS)



In [3]:
from captum.concept import TCAV
from captum.attr import (
    IntegratedGradients,
    LayerIntegratedGradients,
    TokenReferenceBase,
    configure_interpretable_embedding_layer,
    remove_interpretable_embedding_layer,
    visualization
)
from captum.attr._utils.input_layer_wrapper import ModelInputWrapper


In [4]:
# Use the first CUDA GPU when one is present, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Load CLIP

In [5]:
# `clip.load` returns the model plus the image preprocessing transform to use with it.
model, preprocess = clip.load('ViT-B/32', device=device)


# Image Encoding

In [8]:
IMAGE_PATH = "square.jpg"

# Preprocess inside `with` so the file handle is closed once the tensor is built.
# `unsqueeze(0)` adds a leading batch dimension of size 1.
with Image.open(IMAGE_PATH) as pil_image:
    image = preprocess(pil_image).unsqueeze(0).to(device)

# `image` is already on `device`. A hard-coded `.cuda()` would crash on CPU-only
# machines, where `device` falls back to "cpu". This is inference only, so skip autograd.
with torch.no_grad():
    image_features = model.encode_image(image)

# Text Encoding

In [10]:
# Prompts use the "text[:weight[:stop]]" format parsed by split_prompt.
prompt = "square shaped cat"
txt, weight, stop = split_prompt(prompt)

# Inference only: no autograd graph is needed for the text embedding.
with torch.no_grad():
    text_features = model.encode_text(clip.tokenize(txt).to(device)).float()

# Running Inference from CLIP (from CLIP repo)

In [27]:
attr = IntegratedGradients(model)

# Candidate captions to score against the image.
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

# Handle to CLIP's token-embedding layer for later attribution experiments.
# Do NOT also assign it to `configure_interpretable_embedding_layer`: that rebinds
# captum's imported helper to this nn.Module, so the helper could no longer be called.
interpretable_embedding = model.token_embedding

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

    # Embedding-space inputs and all-zero baselines, kept for attribution experiments.
    inputs = (image_features, text_features)
    baselines = (torch.zeros_like(image_features), torch.zeros_like(text_features))

    logits_per_image, logits_per_text = model(image, text)

    # Softmax over the last dim: presumably one probability per caption for each
    # image (as in the CLIP README example) -- confirm.
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    probs_max = probs.max()
    probs_argmax = probs.argmax()

# Backward-compatible alias for the previously misspelled name.
progs_argmax = probs_argmax

# TODO: attribution over the raw (image, text) inputs is still disabled. Integrated
# Gradients interpolates between baselines and inputs, which is presumably not
# meaningful for the integer token ids in `text` -- confirm before enabling.
# attributions = attr.attribute(inputs=(image, text),
#                               baselines=(image * 0, text * 0),
#                               target=probs_argmax,
#                               n_steps=30)

In [28]:
# inputs = (image_features, text_features)
# baselines = (image_features * 0.0, text_features * 0.0)

# Still to try: TCAV concept attribution (not yet run)

In [None]:
# TCAV needs names of layers inside `model`. GoogLeNet names such as "inception4c"
# do not exist in CLIP. Assuming the ViT-based CLIP loaded above ('ViT-B/32'), probe
# the last few residual blocks of its vision transformer instead.
N_TCAV_BLOCKS = 3
n_vision_blocks = len(model.visual.transformer.resblocks)
layers = [
    f"visual.transformer.resblocks.{i}"
    for i in range(n_vision_blocks - N_TCAV_BLOCKS, n_vision_blocks)
]

# NOTE(review): CLIP's forward takes (image, text) and returns a pair of logits (see
# the `model(image, text)` call above). TCAV / LayerIntegratedGradients presumably
# need a wrapper exposing a single input and output. TODO before tcav.interpret().
tcav = TCAV(
    model=model,
    layers=layers,
    # layer=None: TCAV is expected to assign each entry of `layers` to this method.
    layer_attr_method=LayerIntegratedGradients(model, None, multiply_by_inputs=False),
)