# Notebook Purpose

Implement TCAV using Pytorch for CLIP

# Load Dependencies

In [121]:
#https://github.com/openai/CLIP
# authors Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings), nerdyrodent
# authors vivian
# The original BigGAN+CLIP method was by https://twitter.com/advadnoun
import threading
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

from torch.autograd import Variable

from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import math
import random
from urllib.request import urlopen
from tqdm import tqdm
import sys
import os
sys.path.append('taming-transformers')
from omegaconf import OmegaConf
from taming.models import cond_transformer, vqgan
import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from torch.cuda import get_device_properties
torch.backends.cudnn.benchmark = False
from torch_optimizer import DiffGrad, AdamP, RAdam
from CLIP import clip
import kornia.augmentation as K
import imageio
from PIL import ImageFile, Image, PngImagePlugin, ImageChops
ImageFile.LOAD_TRUNCATED_IMAGES = True
from subprocess import Popen, PIPE
import re

In [3]:
clip.available_models()

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']

In [4]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load CLIP

In [5]:
model, preprocess = clip.load('ViT-B/32', device)

In [6]:
class Hook:
    """Attaches to a module and records its activations and gradients."""

    def __init__(self, module: nn.Module):
        self.data = None
        self.hook = module.register_forward_hook(self.save_grad)
        
    def save_grad(self, module, input, output):
        self.data = output
        output.requires_grad_(True)
        output.retain_grad()
        
    def __enter__(self):
        return self
    
    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.hook.remove()
        
    @property
    def activation(self) -> torch.Tensor:
        return self.data
    
    @property
    def gradient(self) -> torch.Tensor:
        return self.data.grad

# Register hooks

In [7]:
# assist from https://web.stanford.edu/~nanbhas/blog/forward-hooks-pytorch/
activations = {}
gradients = {}
def getActivation(name):
    # the hook signature 
    def hook(model, input, output):
        
        output.requires_grad_(True)
        output.retain_grad()
        gradients[name] = output.grad
        activations[name] = output.detach()
    return hook

In [8]:
hooks = []
layers = np.concatenate([[model.visual.conv1], model.visual.transformer.resblocks[1::2]])
layernames = np.concatenate([['layer0'], [f'layer{i}' for i in range(1,13,2)]], dtype=str)
for l, n in zip(layers, layernames):
    hooks.append(l.register_forward_hook(getActivation(n)))

In [9]:
# possibly needed in future if using larger dataset w/ dataloader

# embedding_list = np.empty(layers.shape, dtype=object)
# for i in range(len(embedding_list)):
#     embedding_list[i] = []
# for num_layer, name  in enumerate(layernames):
#     embedding_list[num_layer].append(activations[n])

# Image Encoding

# Text Encoding

In [10]:
def get_img_tensors(img_filename, img_dir=""):
    image = preprocess(Image.open(img_dir + img_filename)).unsqueeze(0).to(device)
    return image

In [11]:
def encode_images(img_filename, img_dir=""):
    image = preprocess(Image.open(img_dir + img_filename)).unsqueeze(0).to(device)

    image_features = model.encode_image(image.cuda())
    return image_features

Load an example image

# Define Linear Classifiers

In [116]:
class LinearClassifier(torch.nn.Module):
    
    def __init__(self, num_features):
        super().__init__()
        
        self.linear1 = torch.nn.Linear(num_features, 1)
        

    def forward(self, input_x):
        x = self.linear1(input_x)

        return x



In [15]:
positive_filenames = os.listdir('tcav/concepts/striped')
negative_filenames = os.listdir('tcav/concepts/random_0')

In [16]:
positive_concept = [encode_images(filename, 'tcav/concepts/striped/') for filename in positive_filenames]
#positive_concept = torch.vstack(positive_concept)

In [17]:
negative_concept = [encode_images(filename, 'tcav/concepts/random_0/') for filename in negative_filenames]
#negative_concept = torch.vstack(negative_concept)

In [18]:
# all_concept = positive_concept + negative_concept

In [19]:

positive_concepts = torch.vstack(positive_concept)
negative_concepts = torch.vstack(negative_concept)

In [20]:
positive_img_tensors = [get_img_tensors(img,'tcav/concepts/striped/') for img in positive_filenames]
positive_img_tensors = torch.vstack(positive_img_tensors)
negative_img_tensors = [get_img_tensors(img,'tcav/concepts/random_0/') for img in negative_filenames]
negative_img_tensors = torch.vstack(negative_img_tensors)


In [21]:
all_img_tensors = torch.vstack([positive_img_tensors, negative_img_tensors])

# Collect features

In [97]:
outputs = model.encode_image(all_img_tensors)


In [98]:
text_inputs = ["zebra"] * len(positive_concept) + ["not zebra"] * len(negative_concept)


In [35]:

text_inputs = [clip.tokenize([text_input]).to(device) for text_input in text_inputs]
target = torch.vstack([model.encode_text(text_input).float() for text_input in text_inputs])

In [84]:
all_layer_gradients = {}
all_layer_activations = {}
for layer, name in zip(layers, layernames):
    layer_gradients = []
    layer_activations = []
    with Hook(layer) as hook:

        # Do a forward and backward pass.
        output = model.encode_image(all_img_tensors)
        output.backward(target)

        grad = hook.gradient.float()
        act = hook.activation.float()
        layer_gradients.append(grad)
        layer_activations.append(act)
    all_layer_gradients[name] = layer_gradients
    all_layer_activations[name] = layer_activations

# Process features

In [117]:
training_data =  []
linear_classifier_sizes = []
for key in all_layer_gradients.keys():
    training_data.append(all_layer_gradients[key][0].view(85,-1))
    linear_classifier_sizes.append(all_layer_gradients[key][0].view(85,-1).shape[-1])
    

# Assemble training data for all layers

In [89]:
positive_labels = torch.tensor(positive_concepts.shape[0] * [1])
negative_labels = torch.tensor(negative_concepts.shape[0] * [0])
# training_data = torch.vstack([positive_concepts, negative_concepts])
class_labels = torch.cat([positive_labels, negative_labels])

# Create dataloaders

In [115]:
criterion = torch.nn.BCEWithLogitsLoss()
dataloaders = []
for train in training_data:
    dataset = TensorDataset(train, class_labels)
    loader = DataLoader(dataset, batch_size=2,
                    pin_memory=False, shuffle=True)
    dataloaders.append(loader)


# Create classifiers

In [127]:
classifiers = []
for classifier_size in linear_classifier_sizes:
    
    classifiers.append(LinearClassifier(classifier_size))




In [128]:
def train_classifier(classifier, dataloader):
    clf = classifier.cuda()
    optimizer = torch.optim.SGD(clf.parameters(), lr=0.001)
    for it in range(n_epochs):
        for i, data in enumerate(dataloader,0):
            inputs, labels = data
            inputs = Variable(inputs, requires_grad=True)
            optimizer.zero_grad()
            outputs = clf(inputs.cuda().float())

            loss = criterion(outputs.cuda().float(), labels.cuda().reshape(-1,1).float())
            loss.backward()
            optimizer.step()
        if it % 10 == 0:
            print(loss)

In [129]:
dataloaders

[<torch.utils.data.dataloader.DataLoader at 0x7f3a1a534070>,
 <torch.utils.data.dataloader.DataLoader at 0x7f3a1a5344c0>,
 <torch.utils.data.dataloader.DataLoader at 0x7f3a1a513d90>,
 <torch.utils.data.dataloader.DataLoader at 0x7f3a1a513ca0>,
 <torch.utils.data.dataloader.DataLoader at 0x7f3a1a513cd0>,
 <torch.utils.data.dataloader.DataLoader at 0x7f3a1add9550>,
 <torch.utils.data.dataloader.DataLoader at 0x7f3a1add9640>]

# Training Classifiers Features

In [130]:
n_epochs = 100
for classifier, dataloader in zip(classifiers, dataloaders):
    print(classifier)
    train_classifier(classifier, dataloader)
    print("trained a classifier")


LinearClassifier(
  (linear1): Linear(in_features=37632, out_features=1, bias=True)
)
tensor(0.5984, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.4483, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.3994, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.2717, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0681, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.1250, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.2572, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.1261, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0337, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
tensor(0.0632, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
trained a classifier
LinearClassifier(
  (linear1): Linear(in_features=38400, out_features=1, bias=True)
)
tensor(

# Get orthogonal vector

In [142]:
def get_orthogonal_vector(classifier, classifier_size):
    weight, bias = [param for param in classifier.parameters()]
    cav_vector = weight.squeeze().cpu().detach().numpy()
    orthonormal_vector = np.random.randn(classifier_size)  # take a random vector
    orthonormal_vector -= orthonormal_vector.dot(cav_vector) * cav_vector / np.linalg.norm(cav_vector)**2
    orthonormal_vector /= np.linalg.norm(orthonormal_vector) 
    return orthonormal_vector, cav_vector

In [143]:
cavs = [get_orthogonal_vector(classifier, classifier_size) for classifier, classifier_size in zip(classifiers, linear_classifier_sizes)]

# Check orthogonality

In [145]:
[np.dot(orthonormal_vector,cav_vector) for orthonormal_vector, cav_vector in cavs]

[4.934755295366022e-12,
 5.17960492205094e-11,
 1.0493194464375466e-10,
 1.6028538392387492e-10,
 -4.232600728237834e-11,
 7.668263853760626e-11,
 5.1244056415745975e-11]

In [167]:
for i in range(85):
    print(i)
    print(torch.dot(all_layer_gradients['layer0'][0].view(85,-1)[i], torch.tensor(cavs[0][1]).cuda() ) )

0
tensor(3.8289, device='cuda:0')
1
tensor(1.6838, device='cuda:0')
2
tensor(3.1859, device='cuda:0')
3
tensor(1.9707, device='cuda:0')
4
tensor(2.5689, device='cuda:0')
5
tensor(3.0514, device='cuda:0')
6
tensor(1.7011, device='cuda:0')
7
tensor(4.2928, device='cuda:0')
8
tensor(2.4806, device='cuda:0')
9
tensor(3.2958, device='cuda:0')
10
tensor(4.3089, device='cuda:0')
11
tensor(2.1080, device='cuda:0')
12
tensor(3.6501, device='cuda:0')
13
tensor(4.1375, device='cuda:0')
14
tensor(2.1915, device='cuda:0')
15
tensor(1.9096, device='cuda:0')
16
tensor(2.8473, device='cuda:0')
17
tensor(2.8616, device='cuda:0')
18
tensor(3.5900, device='cuda:0')
19
tensor(3.6509, device='cuda:0')
20
tensor(2.1468, device='cuda:0')
21
tensor(2.3773, device='cuda:0')
22
tensor(3.5763, device='cuda:0')
23
tensor(3.9159, device='cuda:0')
24
tensor(3.1260, device='cuda:0')
25
tensor(2.9013, device='cuda:0')
26
tensor(6.1829, device='cuda:0')
27
tensor(2.2190, device='cuda:0')
28
tensor(2.3353, device='cuda

In [169]:
for i in range(85):
    print(i)
    print(torch.dot(all_layer_gradients['layer1'][0].view(85,-1)[i], torch.tensor(cavs[1][1]).cuda() ) )

0
tensor(6.7607, device='cuda:0')
1
tensor(5.4853, device='cuda:0')
2
tensor(2.3870, device='cuda:0')
3
tensor(1.7024, device='cuda:0')
4
tensor(1.4893, device='cuda:0')
5
tensor(2.0168, device='cuda:0')
6
tensor(1.2613, device='cuda:0')
7
tensor(1.8180, device='cuda:0')
8
tensor(1.6822, device='cuda:0')
9
tensor(1.2384, device='cuda:0')
10
tensor(2.0185, device='cuda:0')
11
tensor(1.6535, device='cuda:0')
12
tensor(2.5462, device='cuda:0')
13
tensor(1.9294, device='cuda:0')
14
tensor(2.2377, device='cuda:0')
15
tensor(1.2286, device='cuda:0')
16
tensor(1.0363, device='cuda:0')
17
tensor(1.2599, device='cuda:0')
18
tensor(1.2935, device='cuda:0')
19
tensor(1.6102, device='cuda:0')
20
tensor(1.4870, device='cuda:0')
21
tensor(1.5280, device='cuda:0')
22
tensor(1.5480, device='cuda:0')
23
tensor(1.5319, device='cuda:0')
24
tensor(1.8075, device='cuda:0')
25
tensor(1.9941, device='cuda:0')
26
tensor(1.7329, device='cuda:0')
27
tensor(1.2422, device='cuda:0')
28
tensor(1.3469, device='cuda

In [171]:
for i in range(85):
    print(i)
    print(torch.dot(all_layer_gradients['layer3'][0].view(85,-1)[i], torch.tensor(cavs[2][1]).cuda() ) )

0
tensor(7.4027, device='cuda:0')
1
tensor(5.5752, device='cuda:0')
2
tensor(2.2588, device='cuda:0')
3
tensor(1.4836, device='cuda:0')
4
tensor(1.3638, device='cuda:0')
5
tensor(1.6372, device='cuda:0')
6
tensor(0.9985, device='cuda:0')
7
tensor(1.5408, device='cuda:0')
8
tensor(1.3526, device='cuda:0')
9
tensor(1.3111, device='cuda:0')
10
tensor(1.6788, device='cuda:0')
11
tensor(1.1999, device='cuda:0')
12
tensor(2.1676, device='cuda:0')
13
tensor(1.3634, device='cuda:0')
14
tensor(1.7251, device='cuda:0')
15
tensor(1.0625, device='cuda:0')
16
tensor(1.0167, device='cuda:0')
17
tensor(0.9452, device='cuda:0')
18
tensor(1.1123, device='cuda:0')
19
tensor(1.1649, device='cuda:0')
20
tensor(1.2882, device='cuda:0')
21
tensor(1.2823, device='cuda:0')
22
tensor(1.1245, device='cuda:0')
23
tensor(1.2403, device='cuda:0')
24
tensor(1.4950, device='cuda:0')
25
tensor(1.1967, device='cuda:0')
26
tensor(1.5545, device='cuda:0')
27
tensor(0.8596, device='cuda:0')
28
tensor(1.2808, device='cuda

In [172]:
for i in range(85):
    print(i)
    print(torch.dot(all_layer_gradients['layer5'][0].view(85,-1)[i], torch.tensor(cavs[3][1]).cuda() ) )

0
tensor(6.4838, device='cuda:0')
1
tensor(4.9547, device='cuda:0')
2
tensor(1.8843, device='cuda:0')
3
tensor(1.1201, device='cuda:0')
4
tensor(0.9666, device='cuda:0')
5
tensor(1.2579, device='cuda:0')
6
tensor(0.7736, device='cuda:0')
7
tensor(1.1353, device='cuda:0')
8
tensor(0.9888, device='cuda:0')
9
tensor(0.8668, device='cuda:0')
10
tensor(1.2264, device='cuda:0')
11
tensor(0.9067, device='cuda:0')
12
tensor(1.7112, device='cuda:0')
13
tensor(1.0382, device='cuda:0')
14
tensor(1.2414, device='cuda:0')
15
tensor(0.6945, device='cuda:0')
16
tensor(0.7157, device='cuda:0')
17
tensor(0.5540, device='cuda:0')
18
tensor(0.8065, device='cuda:0')
19
tensor(0.8173, device='cuda:0')
20
tensor(0.8123, device='cuda:0')
21
tensor(0.9648, device='cuda:0')
22
tensor(0.7711, device='cuda:0')
23
tensor(0.9760, device='cuda:0')
24
tensor(0.9625, device='cuda:0')
25
tensor(0.8729, device='cuda:0')
26
tensor(1.0537, device='cuda:0')
27
tensor(0.6170, device='cuda:0')
28
tensor(0.8727, device='cuda

In [173]:
for i in range(85):
    print(i)
    print(torch.dot(all_layer_gradients['layer7'][0].view(85,-1)[i], torch.tensor(cavs[4][1]).cuda() ) )

0
tensor(5.8021, device='cuda:0')
1
tensor(5.1094, device='cuda:0')
2
tensor(0.7930, device='cuda:0')
3
tensor(0.4495, device='cuda:0')
4
tensor(0.4938, device='cuda:0')
5
tensor(0.4972, device='cuda:0')
6
tensor(0.4100, device='cuda:0')
7
tensor(0.5353, device='cuda:0')
8
tensor(0.4647, device='cuda:0')
9
tensor(0.4723, device='cuda:0')
10
tensor(0.5793, device='cuda:0')
11
tensor(0.4223, device='cuda:0')
12
tensor(0.8306, device='cuda:0')
13
tensor(0.4524, device='cuda:0')
14
tensor(0.6440, device='cuda:0')
15
tensor(0.5499, device='cuda:0')
16
tensor(0.4344, device='cuda:0')
17
tensor(0.3086, device='cuda:0')
18
tensor(0.4966, device='cuda:0')
19
tensor(0.4798, device='cuda:0')
20
tensor(0.5295, device='cuda:0')
21
tensor(0.5987, device='cuda:0')
22
tensor(0.3613, device='cuda:0')
23
tensor(0.6090, device='cuda:0')
24
tensor(0.5517, device='cuda:0')
25
tensor(0.4618, device='cuda:0')
26
tensor(0.6015, device='cuda:0')
27
tensor(0.3761, device='cuda:0')
28
tensor(0.5619, device='cuda

In [174]:
for i in range(85):
    print(i)
    print(torch.dot(all_layer_gradients['layer9'][0].view(85,-1)[i], torch.tensor(cavs[5][1]).cuda() ) )

0
tensor(5.1552, device='cuda:0')
1
tensor(4.5935, device='cuda:0')
2
tensor(0.1010, device='cuda:0')
3
tensor(0.0665, device='cuda:0')
4
tensor(0.0755, device='cuda:0')
5
tensor(0.0787, device='cuda:0')
6
tensor(0.0740, device='cuda:0')
7
tensor(0.0859, device='cuda:0')
8
tensor(0.1022, device='cuda:0')
9
tensor(0.0742, device='cuda:0')
10
tensor(0.1247, device='cuda:0')
11
tensor(0.0703, device='cuda:0')
12
tensor(0.1135, device='cuda:0')
13
tensor(0.0756, device='cuda:0')
14
tensor(0.0863, device='cuda:0')
15
tensor(0.0699, device='cuda:0')
16
tensor(0.0754, device='cuda:0')
17
tensor(0.0511, device='cuda:0')
18
tensor(0.1023, device='cuda:0')
19
tensor(0.1085, device='cuda:0')
20
tensor(0.1270, device='cuda:0')
21
tensor(0.1358, device='cuda:0')
22
tensor(0.0842, device='cuda:0')
23
tensor(0.1055, device='cuda:0')
24
tensor(0.0665, device='cuda:0')
25
tensor(0.0640, device='cuda:0')
26
tensor(0.1035, device='cuda:0')
27
tensor(0.0752, device='cuda:0')
28
tensor(0.1029, device='cuda

In [175]:
for i in range(85):
    print(i)
    print(torch.dot(all_layer_gradients['layer11'][0].view(85,-1)[i], torch.tensor(cavs[6][1]).cuda() ) )

0
tensor(6.1153, device='cuda:0')
1
tensor(5.0276, device='cuda:0')
2
tensor(0., device='cuda:0')
3
tensor(0., device='cuda:0')
4
tensor(0., device='cuda:0')
5
tensor(0., device='cuda:0')
6
tensor(0., device='cuda:0')
7
tensor(0., device='cuda:0')
8
tensor(0., device='cuda:0')
9
tensor(0., device='cuda:0')
10
tensor(0., device='cuda:0')
11
tensor(0., device='cuda:0')
12
tensor(0., device='cuda:0')
13
tensor(0., device='cuda:0')
14
tensor(0., device='cuda:0')
15
tensor(0., device='cuda:0')
16
tensor(0., device='cuda:0')
17
tensor(0., device='cuda:0')
18
tensor(0., device='cuda:0')
19
tensor(0., device='cuda:0')
20
tensor(0., device='cuda:0')
21
tensor(0., device='cuda:0')
22
tensor(0., device='cuda:0')
23
tensor(0., device='cuda:0')
24
tensor(0., device='cuda:0')
25
tensor(0., device='cuda:0')
26
tensor(0., device='cuda:0')
27
tensor(0., device='cuda:0')
28
tensor(0., device='cuda:0')
29
tensor(0., device='cuda:0')
30
tensor(0., device='cuda:0')
31
tensor(0., device='cuda:0')
32
tensor(

You need labels to get gradients, to get a loss to backpropagate with

# Calculate TCAV score

In [167]:
image = preprocess(Image.open('tcav/concepts/striped/striped_0086.jpg')).unsqueeze(0).to(device)
image_features = model.encode_image(image.cuda())

In [182]:
x = Variable(image.cuda(), requires_grad=True)

In [183]:
y = model.encode_image(x)

In [184]:
y.backward(torch.ones(1,1024).cuda())

In [195]:
model.visual.proj

torch.Size([768, 512])