# Notebook Purpose

Implement TCAV (Testing with Concept Activation Vectors) for CLIP using PyTorch.

# Load Dependencies

In [1]:
#https://github.com/openai/CLIP
# authors Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings), nerdyrodent
# authors vivian
# The original BigGAN+CLIP method was by https://twitter.com/advadnoun
import threading
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

from torch.autograd import Variable

from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import math
import random
from urllib.request import urlopen
from tqdm import tqdm
import sys
import os
sys.path.append('taming-transformers')
from omegaconf import OmegaConf
from taming.models import cond_transformer, vqgan
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from torch.cuda import get_device_properties
torch.backends.cudnn.benchmark = False
from torch_optimizer import DiffGrad, AdamP, RAdam
from CLIP import clip
import kornia.augmentation as K
import imageio
from PIL import ImageFile, Image, PngImagePlugin, ImageChops
ImageFile.LOAD_TRUNCATED_IMAGES = True
from subprocess import Popen, PIPE
import re

In [13]:
# Show the model names that this `clip` module reports as loadable.
clip.available_models()

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']

In [14]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load CLIP

In [15]:
# Load the ViT-B/32 CLIP checkpoint onto `device`; the call returns the model
# together with the image preprocessing callable that goes with it.
model, preprocess = clip.load('ViT-B/32', device=device)

In [5]:
class Hook:
    """Forward-hook wrapper that keeps a module's latest output and its gradient.

    Works as a context manager: leaving the ``with`` block removes the hook
    from the module again.
    """

    def __init__(self, module: nn.Module):
        # Nothing is recorded until the hooked module runs a forward pass.
        self.data = None
        self.hook = module.register_forward_hook(self.save_grad)

    def save_grad(self, module, input, output):
        """Forward-hook callback: remember ``output`` and make it keep its gradient."""
        self.data = output
        # requires_grad_ operates in place and returns the same tensor, so the
        # two calls can be chained; retain_grad asks autograd to fill ``.grad``
        # on this tensor during backward.
        output.requires_grad_(True).retain_grad()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Unregister the forward hook when the ``with`` block ends.
        self.hook.remove()

    @property
    def activation(self) -> torch.Tensor:
        """Output captured by the most recent forward pass (``None`` before any)."""
        return self.data

    @property
    def gradient(self) -> torch.Tensor:
        """The ``.grad`` attribute of the captured output."""
        return self.data.grad

In [6]:
class EarlyStopping():
    """
    Early stopping to stop the training when the loss does not improve after
    certain epochs.

    Call the instance once per epoch with the validation loss and check
    ``early_stop`` afterwards.
    """
    def __init__(self, patience=5, min_delta=0):
        """
        :param patience: how many epochs to wait before stopping when loss is
               not improving
        :param min_delta: minimum difference between new loss and old loss for
               new loss to be considered as an improvement
        """
        self.patience = patience
        self.min_delta = min_delta
        # number of consecutive calls without a sufficient improvement
        self.counter = 0
        # lowest loss seen so far; None until the first call
        self.best_loss = None
        # set to True once `counter` reaches `patience`
        self.early_stop = False

    def __call__(self, val_loss):
        """
        Record one validation loss and update the stopping state.

        :param val_loss: validation loss of the epoch that just finished
        """
        # `is None` rather than `== None`: the latter is evaluated element-wise
        # for array/tensor losses instead of testing for "no value yet".
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        if self.best_loss - val_loss > self.min_delta:
            self.best_loss = val_loss
            # reset counter if validation loss improves
            self.counter = 0
            return

        # Everything else counts as "no improvement", including a change of
        # exactly `min_delta` (e.g. a perfectly flat loss with min_delta=0),
        # which previously matched neither branch and so never triggered a stop.
        self.counter += 1
        print(f"INFO: Early stopping counter {self.counter} of {self.patience}")
        if self.counter >= self.patience:
            print('INFO: Early stopping')
            self.early_stop = True

# Register hooks

In [22]:
# assist from https://web.stanford.edu/~nanbhas/blog/forward-hooks-pytorch/
# Latest forward output of each hooked layer, detached from the autograd graph.
activations = {}
# Gradient with respect to each hooked layer's output. An entry is None after a
# forward pass and is filled in once backward() reaches that output.
gradients = {}
def getActivation(name):
    """
    Build a forward hook that records a layer's output and, later, its gradient.

    :param name: key under which the layer's data is stored in the module-level
           ``activations`` and ``gradients`` dictionaries
    :return: a callable suitable for ``nn.Module.register_forward_hook``
    """
    def store_gradient(grad):
        # Tensor backward hook: runs when autograd has computed the gradient of
        # the hooked output. Returning None leaves the gradient unmodified.
        gradients[name] = grad.detach()

    # the hook signature
    def hook(model, input, output):
        output.requires_grad_(True)
        output.retain_grad()
        # `output.grad` is always None while the forward pass is still running,
        # so reading it here can never capture a gradient. Clear any value left
        # over from an earlier pass and let the backward hook fill it in.
        gradients[name] = None
        output.register_hook(store_gradient)
        activations[name] = output.detach()
    return hook

In [23]:
# Re-running this cell must not stack a second set of hooks on the model, so
# remove whatever an earlier run of the cell registered before starting over.
for stale_hook in globals().get("hooks", []):
    stale_hook.remove()

hooks = []
# Hooked modules: model.visual.conv1 followed by every second entry of
# model.visual.transformer.resblocks, starting at index 1.
layers = np.concatenate([[model.visual.conv1], model.visual.transformer.resblocks[1::2]])
# 'layer0' names conv1; the remaining names carry the resblock index (1, 3, ..., 11).
layernames = np.concatenate([['layer0'], [f'layer{i}' for i in range(1,13,2)]], dtype=str)
# zip() would silently drop hooks if the model had a different number of blocks
# than the names assume, so fail loudly instead.
if len(layers) != len(layernames):
    raise ValueError(
        f"{len(layers)} layers selected but {len(layernames)} layer names defined"
    )
for layer, layer_name in zip(layers, layernames):
    hooks.append(layer.register_forward_hook(getActivation(layer_name)))

# Image Encoding

# Text Encoding

In [25]:
def get_img_tensors(img_filename, img_dir=""):
    """
    Load one image file and turn it into a single-item batch on ``device``.

    :param img_filename: file name of the image, relative to ``img_dir``
    :param img_dir: directory that contains the image; a trailing path
           separator is optional. The default "" uses ``img_filename`` as is.
    :return: the output of ``preprocess`` with a leading batch axis added,
             moved to ``device``
    """
    # os.path.join inserts the separator only when it is missing, so both
    # "images/" and "images" work as img_dir.
    img_path = os.path.join(img_dir, img_filename)
    # The context manager closes the underlying file as soon as preprocessing
    # has produced the tensor.
    with Image.open(img_path) as img:
        image = preprocess(img).unsqueeze(0).to(device)
    return image

In [26]:
def encode_images(img_filename, img_dir=""):
    """
    Load one image file and run it through the CLIP image encoder.

    Runs with autograd enabled (there is no ``torch.no_grad()`` wrapper).

    :param img_filename: file name of the image, relative to ``img_dir``
    :param img_dir: directory prefix for the image (default: "")
    :return: the result of ``model.encode_image`` for a batch holding this one image
    """
    # Reuse the shared loader so both helpers resolve and preprocess files the same way.
    image = get_img_tensors(img_filename, img_dir)

    # The batch is already on `device`; forcing `.cuda()` here would crash on
    # machines where `device` fell back to the CPU.
    image_features = model.encode_image(image)
    return image_features

Load an example image

# Define Linear Classifiers

In [2]:
class LinearClassifier(torch.nn.Module):
    """Single linear layer that maps ``num_features`` inputs to one output value."""

    def __init__(self, num_features):
        super().__init__()
        # One output unit; no activation follows, so forward() returns the
        # layer's raw output.
        self.linear1 = torch.nn.Linear(in_features=num_features, out_features=1)

    def forward(self, input_x):
        """Apply the linear layer to ``input_x`` and return the result."""
        return self.linear1(input_x)



In [3]:
import numpy as np

In [19]:
# Load the saved classifier sets from disk.
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects and can
# execute code while loading — only use this with a file from a trusted source.
# Presumably one entry per concept, each holding one classifier per layer
# (going by the file name and how the array is indexed later) — TODO confirm.
linear_classifiers_sets = np.load("classifiers_perclass_perlayer_smeared_dotted_knitted_spiralled_chequered.npy", allow_pickle=True)

In [20]:
def get_orthogonal_vector(classifier, classifier_size):
    weight, bias = [param for param in classifier.parameters()]
    weight_vector = weight.squeeze().cpu().detach().numpy()
    orthonormal_vector = np.random.randn(classifier_size)  # take a random vector
    orthonormal_vector -= orthonormal_vector.dot(weight_vector) * weight_vector / np.linalg.norm(weight_vector)**2
    orthonormal_vector /= np.linalg.norm(orthonormal_vector) 
    return orthonormal_vector, weight_vector

In [23]:
# Maps each concept name to the vector stored for it.
cavs = {}

# Same order as the concept names in the classifier file name.
concepts = ["smeared", "dotted", "knitted", "spiralled", "chequered"]

# zip() stops at the shorter input, so a mismatch would silently leave
# concepts without a vector (or pair them with the wrong classifiers).
if len(linear_classifiers_sets) != len(concepts):
    raise ValueError(
        f"{len(linear_classifiers_sets)} classifier sets loaded but {len(concepts)} concepts listed"
    )

for linear_classifier_set, concept in zip(linear_classifiers_sets, concepts):
    # Only the last classifier of each set is used.
    last_classifier = linear_classifier_set[-1]
    orthonormal_vector, weight_vector = get_orthogonal_vector(last_classifier, last_classifier.linear1.in_features)
    # NOTE(review): this stores a *random* unit vector orthogonal to the
    # classifier weights. In TCAV the concept activation vector is normally the
    # normal of the decision boundary, i.e. the direction of `weight_vector`.
    # Confirm that the orthogonal vector is really what should be saved here;
    # it also changes on every run because the draw is not seeded.
    cavs[concept] = orthonormal_vector
    
    

In [22]:
# Sanity check: input width of the last classifier in the first set.
# (The loaded array is named `linear_classifiers_sets`; the bare name
# `linear_classifiers` is never defined in this notebook and raises NameError
# on a fresh kernel.)
linear_classifiers_sets[0][-1].linear1.in_features

512

In [25]:
import pickle

# Write the concept -> vector dictionary to a binary pickle file in the
# current working directory.
with open("concept_cavs.pkl", "wb") as cav_file:
    pickle.dump(cavs, cav_file)