In [None]:
# !git clone https://github.com/nick01as/CLIP-dissect
!pip install --upgrade pip
!pip install ftfy regex
!pip install diffusers
!pip install accelerate
!pip install pandas
!pip install torch==2.0

import os
home_dir = os.getcwd()
os.chdir('CLIP-dissect')

import torch
import pandas as pd
import numpy as np
import random
import statistics
from itertools import permutations
from torch.utils.data import DataLoader

In [None]:
# Change number of workers before running cell
# NOTE(review): the "workers" setting is not visible in this notebook —
# presumably it lives in the CLIP-dissect sources; confirm before running.
# requirements.txt is resolved relative to the CLIP-dissect directory that the
# first cell changed into.
!pip install -r requirements.txt
# Explicit version pins installed on top of the repository requirements.
!pip install transformers==4.28.0
!pip install torchvision==0.15.1
!pip install scipy
!pip install matplotlib
!pip install tornado==5.1.1
import clip
import data_utils
import similarity
import utils
from transformers import pipeline
from diffusers import StableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler
import torch
from matplotlib import pyplot as plt

In [None]:
# Initialize Stable Diffusion
model_id = "runwayml/stable-diffusion-v1-5"
# Load the pipeline with half-precision weights.
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
# Swap in a DPMSolverMultistepScheduler built from the existing scheduler's config.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
# Replace the safety checker with a pass-through: images are returned unchanged
# and none of them is flagged.
pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))
# Seeded CUDA generator; the same object is passed to the pipe(...) call in the main loop.
generator = torch.Generator(device="cuda").manual_seed(0)
pipe = pipe.to("cuda")

In [5]:
# Experiment configuration. These names are forwarded as keyword arguments to
# utils.save_activations / utils.get_save_names in the cells below.
clip_name = 'ViT-B/16'  # also passed to clip.load in the main loop
target_name = 'resnet50'
target_layer = 'layer3'
d_probe = 'broden' 
concept_set = 'data/20k.txt'  # opened relative to the CLIP-dissect directory; read as one word per line

batch_size = 200
device = 'cuda'
pool_mode = 'avg'

save_dir = 'saved_activations'
# Similarity function handed to utils.get_similarity_from_activations.
# Note this reads the imported `similarity` module, so that name must still
# refer to the module when this cell runs.
similarity_fn = similarity.soft_wpmi

In [6]:
# Keyword arguments shared by the activation-saving call and the save-name lookup.
activation_kwargs = dict(
    clip_name=clip_name,
    target_name=target_name,
    d_probe=d_probe,
    concept_set=concept_set,
    pool_mode=pool_mode,
    save_dir=save_dir,
)

# Save activations for the configured target layer.
utils.save_activations(
    target_layers=[target_layer],
    batch_size=batch_size,
    device=device,
    **activation_kwargs,
)

# Resolve the save names for the same configuration and unpack them.
save_names = utils.get_save_names(target_layer=target_layer, **activation_kwargs)
target_save_name, clip_save_name, text_save_name = save_names

In [7]:
import random
from PIL import Image

# Go back to the CLIP-dissect checkout, whatever the current directory is.
os.chdir(os.path.join(home_dir, 'CLIP-dissect'))

# Load the concept set: one entry per newline-separated line.
with open(concept_set, 'r') as concept_file:
    words = concept_file.read().split('\n')

# Vague words
discard_set = ['design','designs','visual','visuals','item','items','object','objects','imagery','image','images','element','elements']

# Remove vague words
def filter_word(word):
    """Strip vague filler tokens from a (possibly multi-word) label.

    The label is split on whitespace, every token listed in ``discard_set`` is
    dropped, and the remaining tokens are re-joined with single spaces.

    Args:
        word: Label string, e.g. ``'purple-themed designs'``.

    Returns:
        The filtered label. An empty string is returned when the input is
        empty or consists only of discarded tokens (the previous
        implementation raised IndexError in that case by indexing ``[-1]`` on
        an empty string).
    """
    return ' '.join(token for token in word.split() if token not in discard_set)

# Find index of word in concept set
def get_id_for_word(word):
    """Return the index of ``word`` in the concept-set file.

    The file at the module-level ``concept_set`` path is re-read on every
    call and split on newlines; the index is the position of the first exact
    match.

    Args:
        word: Concept to look up.

    Returns:
        The zero-based index of ``word``, or ``None`` (after printing an
        error message) when the word is not present.
    """
    with open(concept_set, 'r') as f:
        words = f.read().split('\n')
    try:
        return words.index(word)
    except ValueError:
        # list.index only raises ValueError; catching just that keeps
        # KeyboardInterrupt and genuine bugs from being swallowed.
        print("Error: Word is not in concept set, {} found expected {}".format(type(word), type('str')))
        return None

# Get pre-generated images
def get_images(word, images_to_pull, old_path, new_path, home_dir, num_available = 21):
    """Load a random selection of pre-generated images for a concept.

    Images are read from ``new_path`` (resolved after changing into
    ``home_dir``) using the file name pattern
    ``imagenet_labels_concept_{concept_id}_image_{img_id}.png``.

    Args:
        word: Concept whose images should be loaded; must be in the concept set.
        images_to_pull: Number of distinct images to load.
        old_path: Directory to change back to before returning.
        new_path: Directory that holds the pre-generated images.
        home_dir: Directory changed into before ``new_path`` is resolved.
        num_available: Number of images per concept; ids ``0 .. num_available-1``
            are sampled. The default of 21 matches the ids 0-20 drawn previously.

    Returns:
        A list of ``PIL.Image`` objects, one per sampled image id.

    Raises:
        ValueError: If ``word`` is not in the concept set, or if more images
            are requested than exist (the previous rejection-sampling loop
            never terminated in that case).
    """
    count = max(0, images_to_pull)
    if count > num_available:
        raise ValueError(
            "Requested {} images but only {} are available per concept".format(count, num_available))

    concept_id = get_id_for_word(word)
    if concept_id is None:
        # Fail early instead of looking for a 'concept_None' file.
        raise ValueError("Word '{}' is not in the concept set".format(word))

    os.chdir(home_dir)
    try:
        os.chdir(new_path)

        # Pull random, distinct images from image set
        random_ids = random.sample(range(num_available), count)

        image_set = [
            Image.open(r'imagenet_labels_concept_{}_image_{}.png'.format(concept_id, img_id))
            for img_id in random_ids
        ]
    finally:
        # Always restore the caller's directory, even if a file is missing.
        os.chdir(old_path)

    return image_set

In [None]:
!pip install -U scikit-learn
import statistics
from sklearn.linear_model import LinearRegression

In [9]:
# scoring methods accepted by get_score
mode_list = [
    'topk-sq-mean',
    'reg',
    'mean',
    'median',
    'sq-mean',
]

# remove outliers from image rankings
def rm_outliers(ranks, rm_low_outliers, rm_high_outliers):
    """Drop IQR outliers from every label's rank list, in place.

    For each label the quartiles q1/q3 of its ranks are computed; a rank is
    discarded when it lies more than 1.5 * IQR below q1 (only if
    ``rm_low_outliers`` is True) or more than 1.5 * IQR above q3 (only if
    ``rm_high_outliers`` is True).

    Args:
        ranks: Mapping label_id -> sequence of ranks. Its values are replaced.
        rm_low_outliers: Remove outliers on the low side.
        rm_high_outliers: Remove outliers on the high side.

    Returns:
        The same ``ranks`` mapping, with each value replaced by the list of
        surviving ranks (as numpy scalars, in their original order).
    """
    drop_low = rm_low_outliers == True
    drop_high = rm_high_outliers == True

    for label_id in ranks:
        rank_arr = np.array(ranks[label_id])
        q1 = np.quantile(rank_arr, 0.25)
        q3 = np.quantile(rank_arr, 0.75)
        iqr = q3 - q1
        low_cut = q1 - (iqr * 1.5)
        high_cut = q3 + (iqr * 1.5)

        ranks[label_id] = [
            pos for pos in rank_arr
            if not ((drop_low and pos < low_cut) or (drop_high and pos > high_cut))
        ]
    return ranks

# mean of top-k values squared
def topk_sq_mean(ranks, k = 5):
    """Score each label by the mean of the squares of its first ``k`` ranks.

    The first ``k`` entries of each list are used as given, so callers that
    want the *best* k ranks must pass lists sorted ascending.

    Args:
        ranks: Mapping label_id -> list of ranks.
        k: Number of leading ranks to average. ``None`` is treated as the
            default of 5 so callers forwarding an unset hyper-parameter work.

    Returns:
        A list of ``(score, label_id)`` tuples sorted ascending. When a label
        has fewer than ``k`` ranks (e.g. after outlier removal) the mean is
        taken over the ranks that exist instead of raising IndexError; a label
        with no ranks at all scores ``inf`` so it sorts last.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k is None:
        k = 5
    if k < 1:
        raise ValueError("k must be at least 1, got {}".format(k))

    top_vals = []
    for label_id in ranks:
        head = ranks[label_id][:k]
        if len(head) > 0:
            score = sum(val ** 2 for val in head) / len(head)
        else:
            score = float('inf')
        top_vals.append((score, label_id))
    top_vals.sort()
    return top_vals

# regression + prediction
def reg(ranks, quartile = 0.25):
    """Score each label by a linear fit of rank against position.

    For every label a least-squares line ``rank ~ position`` is fitted through
    the points ``(0, r0), (1, r1), ...`` and evaluated at position
    ``quartile * n`` (``n`` = number of ranks for that label). The prediction
    is cubed to form the score.

    The earlier version passed the whole position vector as a single sample
    with n features, so the fit was degenerate and the prediction was always
    the label's first rank. It also read the list length from ``ranks[0]``,
    which assumes that key exists and that all lists have the same length.

    Args:
        ranks: Mapping label_id -> list of ranks.
        quartile: Fraction of the list length at which to predict. ``None`` is
            treated as the default 0.25.

    Returns:
        A list of ``(score, label_id)`` tuples sorted ascending. A label with a
        single rank scores that rank cubed; a label with no ranks scores
        ``inf``.
    """
    if quartile is None:
        quartile = 0.25

    top_vals = []
    for label_id in ranks:
        y_vals = np.asarray(ranks[label_id], dtype=float)
        n = len(y_vals)
        if n == 0:
            pred = float('inf')
        elif n == 1:
            # A line cannot be fitted through one point; use the point itself.
            pred = float(y_vals[0])
        else:
            slope, intercept = np.polyfit(np.arange(n), y_vals, 1)
            pred = float(slope * (quartile * n) + intercept)
        top_vals.append((pred ** 3, label_id))
    top_vals.sort()
    return top_vals

def mean(ranks):
    """Score each label by the arithmetic mean of its ranks.

    Args:
        ranks: Mapping label_id -> non-empty list of ranks.

    Returns:
        A list of ``(mean_rank, label_id)`` tuples sorted ascending.
    """
    scored = [
        (sum(ranks[label_id]) / len(ranks[label_id]), label_id)
        for label_id in ranks
    ]
    return sorted(scored)

def median(ranks):
    """Score each label by the median of its ranks.

    Args:
        ranks: Mapping label_id -> non-empty list of ranks.

    Returns:
        A list of ``(median_rank, label_id)`` tuples sorted ascending.
    """
    top_vals = []
    for label_id in ranks:
        # The notebook imports the `statistics` module; the previous `stats`
        # name was never defined and raised NameError.
        top_vals.append((statistics.median(ranks[label_id]), label_id))
    top_vals.sort()
    return top_vals

# mean of squared values
def sq_mean(ranks):
    """Score each label by the mean of its squared ranks.

    Args:
        ranks: Mapping label_id -> non-empty list of ranks.

    Returns:
        A list of ``(mean_of_squares, label_id)`` tuples sorted ascending.
    """
    scored = []
    for label_id in ranks:
        values = ranks[label_id]
        squares_total = sum(val ** 2 for val in values)
        scored.append((squares_total / len(values), label_id))
    return sorted(scored)
    
# get score of label
def get_score(ranks, mode = 'topk-sq-mean', hyp_param = None, rm_low_outliers = False, rm_high_outliers = False):
    """Score every label from the ranks of its images using the chosen mode.

    Args:
        ranks: Mapping label_id -> list of ranks. Modified in place when
            outlier removal is requested.
        mode: One of the entries of ``mode_list``.
        hyp_param: Hyper-parameter for the modes that take one (``k`` for
            'topk-sq-mean', ``quartile`` for 'reg'). ``None`` selects the
            scoring function's own default.
        rm_low_outliers: Remove low-side IQR outliers before scoring.
        rm_high_outliers: Remove high-side IQR outliers before scoring.

    Returns:
        A list of ``(score, label_id)`` tuples sorted ascending.

    Raises:
        ValueError: If ``mode`` is not in ``mode_list``. (ValueError is a
            subclass of the Exception raised previously.)
    """
    if mode not in mode_list:
        # Was `"...",format(mode)` - a comma, so the placeholder was never filled.
        raise ValueError("Invalid score mode '{}'".format(mode))

    if rm_low_outliers or rm_high_outliers:
        ranks = rm_outliers(ranks, rm_low_outliers, rm_high_outliers)

    if mode == 'topk-sq-mean':
        # Only forward the hyper-parameter when one was given, so that None
        # does not override the callee's default.
        return topk_sq_mean(ranks) if hyp_param is None else topk_sq_mean(ranks, hyp_param)
    if mode == 'reg':
        return reg(ranks) if hyp_param is None else reg(ranks, hyp_param)
    if mode == 'mean':
        return mean(ranks)
    if mode == 'median':
        return median(ranks)
    if mode == 'sq-mean':
        return sq_mean(ranks)
    
    

In [19]:
"""
# for debuging purposes

import sys
del sys.modules['utils']
del utils
import utils
"""

In [None]:
# neuron to check + generative label
comp_words = {
    791: 'checker',
    880: 'stripes',
    844: 'dotted backgrounds',
    774: 'bike',
    658: 'household spaces',
    776: 'pantry or store display',
    188: 'spider webs',
    162: 'purple-themed designs',
    513: 'polka dot patterns',
    121: 'dotted',
    381: 'dotted surfaces',
    357: 'polka dot patterns',
    516: 'striped clothing',
    414: 'knitting or crochet',
    426: 'corridor',
    148: 'dotted',
    59: 'spiral-themed elements',
    772: 'leaf',
    945: 'corridor',
    277: 'kitchens',
}

# filter out vague words (values are replaced, keys stay the same)
for neuron_id, raw_label in list(comp_words.items()):
    comp_words[neuron_id] = filter_word(raw_label)

In [None]:
os.chdir(home_dir)
os.chdir('CLIP-dissect')

# del image_set
torch.cuda.empty_cache()

# Read labels and get D_probe
with open(concept_set, 'r') as f:
    words = f.read().split('\n')
pil_data = data_utils.get_data(d_probe)
d_probe_len = len(pil_data)

# Get directory of new saved activations
save_names = utils.get_save_names(clip_name = clip_name, target_name = target_name,
                              target_layer = target_layer, d_probe = d_probe,
                              concept_set = concept_set, pool_mode=pool_mode,
                              save_dir = save_dir, newSet = True)
# Unpack into the new_* names. The third value used to go into
# `text_save_name`, which left `new_text_save_name` undefined for the cleanup
# below and raised NameError on a fresh kernel.
new_target_save_name, new_clip_save_name, new_text_save_name = save_names

# Ensure all previous files are deleted
# NOTE(review): these names are built with d_probe = d_probe, while the main
# loop requests its newSet names with d_probe = 'cifar100_train'. If
# get_save_names encodes d_probe in the path, this cleanup targets different
# files than the loop writes - confirm against utils.get_save_names.
for stale_path in (new_target_save_name, new_clip_save_name, new_text_save_name):
    if os.path.exists(stale_path):
        os.remove(stale_path)
print('Removed files')

# Neurons to check: every neuron id that has a generative label
neurons_to_check = list(comp_words)

# Block configuration = (# labels to collect, #image per label, (#scoring model, hyperparameter if required))
it_settings = [
    (15, 10, ('topk-sq-mean', 5)),
]
# it_settings = [(15, 10, ('topk-sq-mean', 5)), (10, 8, ('topk-sq-mean', 5)), (3, 15, ('topk-sq-mean', 3))]

# Main code
#
# For every neuron in neurons_to_check:
#   1. append its generative label to the concept set and recompute the
#      text features / similarities,
#   2. take the top labels by similarity as candidates (plus the generative
#      label if it ranked lower),
#   3. for each block in it_settings, generate images per candidate label,
#      rank the images by the neuron's activation and re-score the labels,
#   4. print the original top label next to the re-scored best label.

# Load CLIP once up front: the model is the same for every neuron, so loading
# it inside the loop repeated the work on every iteration.
clip_model, clip_preprocess = clip.load(clip_name, device=device)

for list_id, orig_id in enumerate(neurons_to_check):

    # Add the generative word to the concept set (removed again at the end of this iteration)
    words.append(comp_words[orig_id])

    save_names = utils.get_save_names(clip_name = clip_name, target_name = target_name,
                              target_layer = target_layer, d_probe = d_probe,
                              concept_set = concept_set, pool_mode=pool_mode,
                              save_dir = save_dir)
    target_save_name, clip_save_name, text_save_name = save_names

    # Make sure saved text files from previous runs are deleted
    if os.path.exists(text_save_name):
        os.remove(text_save_name)
        print('Removed files')

    # Save new concept set
    text = clip.tokenize(["{}".format(word) for word in words]).to(device)
    utils.save_clip_text_features(clip_model, text, text_save_name, batch_size)

    # Get similarities and target_feats
    similarities, orig_target_feats = utils.get_similarity_from_activations(target_save_name, clip_save_name,
                                                             text_save_name, similarity_fn, device=device)

    # Sort labels by similarity highest -> lowest
    _, ids = torch.topk(similarities[orig_id], k=len(similarities[orig_id]), largest=True)
    # Plain Python ints, so the lookups below do not convert one tensor element at a time
    ranked_ids = ids.tolist()

    # Initialize starting concepts
    word_list = []
    for label_id in range(it_settings[0][0]):
        print('id:{}'.format(ranked_ids[label_id]))
        word_list.append(words[ranked_ids[label_id]])

    # add generative label if generative label CLIP-dissect rank > # labels to collect in first block
    if comp_words[orig_id] not in word_list:
        for lb_id in range(it_settings[0][0], len(ranked_ids)):
            if words[ranked_ids[lb_id]] == comp_words[orig_id]:
                print('Word found at position {}'.format(lb_id))
                word_list.append(words[ranked_ids[lb_id]])
                break

    print("Neuron {}".format(orig_id))

    best_label = ""

    gathered = False

    # CLIP-dissect's best label
    print('best: {}'.format(word_list[0]))

    # For each block
    for it_num, it in enumerate(it_settings):

        # Block iteration
        print("Iteration: {}".format(it_num))

        # Print where generative label was found
        if comp_words[orig_id] in word_list:
            if it_num == 0:
                gathered = True
            print('Word found at position {}'.format(word_list.index(comp_words[orig_id])))
        else: # Rank > number labels to gather for the block
            print('Not in list')
            break

        # Get block settings
        labels_to_check, num_images_per_prompt, mode_description = it
        mode, hyp_param = mode_description

        # Check exactly the candidates collected for this block. This accounts
        # for an added generative label and, unlike max(...), can never index
        # past the end of word_list.
        labels_to_check = len(word_list)

        print(labels_to_check, num_images_per_prompt)

        add_im = {}
        add_im_id = {}
        labels = {}

        print('Gathering images...', end = "")

        # Generate images for each label
        for label_id in range(labels_to_check):
            pred_label = word_list[label_id]
            labels[label_id] = pred_label # maps label_id to label

            add_im_id[label_id] = [] # initialize image list

            # Generate images
            image_set = pipe(pred_label, generator = generator, num_images_per_prompt = num_images_per_prompt, num_inference_steps=15)

            # Use this if using pre-generated images
            # image_set = get_images(pred_label, num_images_per_prompt, old_path = os.getcwd(), home_dir = home_dir, new_path = '/expanse/lustre/scratch/nbai/temp_project/generated_images')

            for i in range(num_images_per_prompt):
                # Use this if using pre-generated images
                #image = image_set[i]

                # Rescale image
                image = image_set.images[i]
                image = image.resize([32,32])

                new_idx = len(add_im)
                add_im[new_idx] = image # Add image to list
                add_im_id[label_id].append(new_idx) # map new image indices to corresponding label_id
        print('Done')
        del image_set
        torch.cuda.empty_cache()

        # save the new concept set and d_probe
        # reuse cifar100 class to store information
        utils.save_new_activations(clip_name = clip_name, target_name = target_name, target_layers = [target_layer],
                          d_probe = 'cifar100_train', new_images = add_im,
                          concept_set = concept_set, wordList = word_list, batch_size = batch_size,
                          device = device, pool_mode=pool_mode, save_dir = save_dir)

        # Get new similarity and target_feats
        save_names = utils.get_save_names(clip_name = clip_name, target_name = target_name,
                                    target_layer = target_layer, d_probe = 'cifar100_train',
                                    concept_set = concept_set, pool_mode=pool_mode,
                                    save_dir = save_dir, newSet = True)

        new_target_save_name, new_clip_save_name, new_text_save_name = save_names

        # Named new_similarity so the imported `similarity` module is not
        # rebound to a tensor; that broke re-running the configuration cell.
        new_similarity, target_feats = utils.get_similarity_from_activations(new_target_save_name, new_clip_save_name,
                                                                 new_text_save_name, similarity_fn, k=len(add_im), device=device)

        # Sort images based on activation
        top_vals, top_ids = torch.sort(target_feats, dim=0, descending = True)
        top_image_id = top_ids[:,orig_id]

        # Ranks: label_id -> (indicies of corresponding images in sorted target_feats)
        ranks = {label_id:[] for label_id in range(labels_to_check)}

        # Reverse lookup image index -> label_id, so the ranking is built in a
        # single pass instead of one membership scan per label.
        image_owner = {}
        for label_id in range(labels_to_check):
            for img_idx in add_im_id[label_id]:
                image_owner[img_idx] = label_id

        # Insert indices of image activations into ranks. Positions are visited
        # in ascending order, so every list ends up sorted.
        for position, img_id in enumerate(top_image_id.tolist()):
            owner = image_owner.get(img_id)
            if owner is not None:
                ranks[owner].append(position)

        # Reset word_list
        word_list = []
        top_avg = []

        # Score labels based on ranks of corresponding generated images
        if mode != 'soft-wpmi':
            top_avg = get_score(ranks, mode, hyp_param, rm_high_outliers = True)
        else:
            new_val, new_id = torch.topk(new_similarity[orig_id], k=labels_to_check, largest=True)
            print(new_id.shape)
            for i in range(len(new_val)):
                # float(), not int(): truncating would discard the similarity score
                top_avg.append((float(new_val[i]), int(new_id[i])))

        if it_num < len(it_settings) - 1: # Generate concept set for next block iteration
            # Never ask for more labels than were actually scored
            next_count = min(it_settings[it_num + 1][0], len(top_avg))
            for next_word in range(next_count):
                word_list.append(labels[top_avg[next_word][1]])
            print("new list size: {}".format(len(word_list)))
        else: # Record best label
            best_label = labels[top_avg[0][1]]

        # Get position of generative label
        in_list = False
        for i in range(len(top_avg)):
            if labels[top_avg[i][1]] == comp_words[orig_id]:
                print("Generative Label found at position: {}".format(i))
                in_list = True
                break

        # For debugging purposes
        # if in_list == False:
        #     print("Not found in top {}".format(it_settings[0][0]))

        #for i in range(3):
        #   print('Rank {} ({}): {}'.format(i, labels[top_avg[i][1]], ranks[top_avg[i][1]]))

        # Remove files for next iteration
        for stale_path in (new_target_save_name, new_clip_save_name, new_text_save_name):
            os.remove(stale_path)

    # Print results
    if gathered == True:
        print('------------------------------\n')
        print('Neuron {}:'.format(orig_id))
        print('CLIP Label: {}'.format(words[ranked_ids[0]]))
        print('New Label: {}'.format(best_label))
        print('\n------------------------------')

    # Remove added generative label
    words.pop()