In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from util import ActivationDataset
from learned_dict import TiedSAE
from transformers import ViTForImageClassification, ViTImageProcessor
from feature_vis import *
from collections import Counter
from datasets import load_dataset

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Augmentation pipeline handed to the visualisation calls below via `augment=`.
# The stage classes are not defined in this notebook (presumably they come from
# the `feature_vis` star import — TODO confirm what each stage does).
augment_stages = [
    RepeatBatch(8),
    ColorJitter(8),
    GausianNoise(8),
    Tile(1),
    Jitter(),
]
augment = nn.Sequential(*augment_stages)

In [None]:
# The classifier and its image processor are loaded from the same Hub checkpoint id.
checkpoint = "nateraw/vit-base-patch16-224-cifar10"
model = ViTForImageClassification.from_pretrained(checkpoint)
processor = ViTImageProcessor.from_pretrained(checkpoint)

# Move to the selected device and switch to inference mode.
model.to(device)
_ = model.eval()  # bind to `_` so the cell does not echo the module repr

In [None]:
# Keep the first ten CIFAR-10 test examples as a small inspection batch.
dataset = load_dataset("cifar10", split="test")
first_ten = dataset[:10]
images, labels = first_ten["img"], first_ten["label"]

In [None]:
# Show the ten sample images in a 2x5 grid, each titled with its integer label.
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
# axes.flat walks the grid row by row, i.e. the same order as index i*5+j.
for ax, img, label in zip(axes.flat, images, labels):
    ax.set_title(label)
    ax.imshow(img)

plt.tight_layout()
plt.show()

In [None]:
# Run the sample images through the checkpoint's processor and keep the
# resulting pixel tensor on the target device.
encoded = processor(images=images, return_tensors="pt")
inputs = encoded.pixel_values.to(device)

In [None]:
# For each input, the index of the hidden unit with the largest activation
# after averaging over axis 1 of the last hidden state
# (assumes hidden states are (batch, tokens, hidden) — TODO confirm).
# no_grad: this is inspection only, so do not build an autograd graph.
with torch.no_grad():
    last_hidden = model(inputs, output_hidden_states=True)['hidden_states'][-1]
last_hidden.cpu().numpy().mean(1).argmax(axis=1)
# feature 187 is always the most activated in last layer

In [None]:
# Last hidden state for the sample batch. It is only read afterwards (SAE
# encoding and heatmaps), so compute it without tracking gradients.
with torch.no_grad():
    activations = model(inputs, output_hidden_states=True)['hidden_states'][-1]

In [None]:

# Optimise an input image for raw hidden unit 187.
# The positional `11` is presumably the target layer index — TODO confirm
# against feature_vis's signature.
neuron_index = 187  # most activated feature for class 3 cat
input_size = 224
optimized_input = feature_vis(
    model,
    11,
    neuron_index,
    input_size,
    num_iterations=1000,
    lr=0.1,
    device=device,
    augment=augment,
    lambda_tv=0.0005,
)
# Reorder the first result's axes with permute(1, 2, 0) before plotting
# (assumes channel-first output — TODO confirm).
rendered = optimized_input[0].permute(1, 2, 0).cpu().numpy()
plt.imshow(rendered)
plt.show()

In [None]:

# Optimise an input image for raw hidden unit 154, showing intermediate steps.
neuron_index = 154  # most activated feature after 187 for airplane
input_size = 224
optimized_input = feature_vis(
    model,
    11,
    neuron_index,
    input_size,
    num_iterations=1000,
    lr=0.01,
    device=device,
    augment=augment,
    lambda_tv=0.0005,
    show_intermediate=True,
)
# permute(1, 2, 0) reorders the axes for imshow (assumes channel-first output).
rendered = optimized_input[0].permute(1, 2, 0).cpu().numpy()
plt.imshow(rendered)
plt.show()

In [None]:

# Optimise an input image for raw hidden unit 432.
neuron_index = 432  # most activated feature after 187 for class truck
input_size = 224
optimized_input = feature_vis(
    model,
    11,
    neuron_index,
    input_size,
    num_iterations=1000,
    lr=0.01,
    device=device,
    augment=augment,
    lambda_tv=0.0005,
)
# permute(1, 2, 0) reorders the axes for imshow (assumes channel-first output).
rendered = optimized_input[0].permute(1, 2, 0).cpu().numpy()
plt.imshow(rendered)
plt.show()

In [None]:

# Optimise an input image for raw hidden unit 86.
neuron_index = 86  # most activated feature after 187 for class dog
input_size = 224
optimized_input = feature_vis(
    model,
    11,
    neuron_index,
    input_size,
    num_iterations=1000,
    lr=0.01,
    device=device,
    augment=augment,
    lambda_tv=0.0005,
)
# permute(1, 2, 0) reorders the axes for imshow (assumes channel-first output).
rendered = optimized_input[0].permute(1, 2, 0).cpu().numpy()
plt.imshow(rendered)
plt.show()

### Initialize the SAE

In [None]:
# initialize sae
# Placeholder parameters: a (7680, 768) encoder matrix and a 7680-long bias.
# load_state_dict below presumably replaces them with the trained weights
# (TODO confirm against TiedSAE).
encoder = torch.randn((7680, 768), device = device) # encoder
nn.init.xavier_uniform_(encoder)
encoder_bias = torch.zeros(7680, device = device) # encoder bias
sae = TiedSAE(encoder, encoder_bias)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads when this notebook falls back to the CPU.
sae.load_state_dict(torch.load('SAE_models/cifar10/SAE_ratio10_epoch100_lr0.0001.pth', map_location=device))
sae.to_device(device)


def top_act(act_path:str, target_layer:int, cls:int, sae, k:int=30) -> torch.Tensor:
    """Collect the top-k SAE feature indices for the stored activations of one class.

    Reads ``{act_path}/cifar10_activations_{cls}.h5`` under the key
    ``vit.encoder.layer.{target_layer}.output``, encodes every batch with
    ``sae`` and keeps, for each batch, the indices of the ``k`` largest values
    of the codes averaged over dim 1.

    Uses the notebook-level ``device`` global, and as a side effect moves
    ``sae`` to that device and turns off ``requires_grad`` on its encoder
    weight and bias.

    Args:
        act_path: Directory holding the per-class activation ``.h5`` files.
        target_layer: Encoder layer index used to build the dataset key.
        cls: Class id used to pick the activation file.
        sae: Object exposing ``to_device``, ``encode``, ``encoder`` and
            ``encoder_bias`` (a ``TiedSAE`` in this notebook).
        k: Number of top features kept per row.

    Returns:
        The per-batch ``topk`` index tensors concatenated along dim 0
        (presumably one row of ``k`` indices per sample, assuming the codes
        are (batch, tokens, features) — TODO confirm). An empty ``(0, k)``
        long tensor is returned when the loader yields no batches.
    """
    # A single file is loaded, so it is handed to the DataLoader directly.
    dataset = ActivationDataset(f'{act_path}/cifar10_activations_{cls}.h5',f'vit.encoder.layer.{target_layer}.output')
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)
    sae.to_device(device)
    sae.encoder.requires_grad = False
    sae.encoder_bias.requires_grad = False

    top_activations = []
    with torch.no_grad():
        # Each item is a pair; only the second element (the activations) is used.
        for _, activations in data_loader:
            activations = activations.to(device)
            c = sae.encode(activations)
            top_activations.append(torch.topk(c.mean(dim=1), k, largest=True).indices)

    if not top_activations:
        # torch.cat raises on an empty list; return an empty result instead.
        return torch.empty((0, k), dtype=torch.long, device=device)
    return torch.cat(top_activations, dim=0)

In [None]:
# For every sample, print the index of the SAE feature with the largest
# activation after averaging over dim 0 of the squeezed codes.
for sample_acts in activations:
    codes = sae.encode(sample_acts.unsqueeze(0)).squeeze()
    print(codes.mean(0).argmax())

In [None]:
import seaborn as sns

# Heatmap of hidden unit 187 for each of the ten samples.
fig, axes = plt.subplots(2, 5, figsize=(12, 4))
# axes.flat visits the grid row by row, i.e. sample index i*5+j.
for sample_idx, ax in enumerate(axes.flat):
    # [1:, 187] skips the first token (presumably the CLS token — TODO confirm)
    # and keeps unit 187; the remaining values are laid out as a 14x14 grid.
    unit_map = activations[sample_idx][1:, 187].reshape(14, 14).cpu().detach().numpy()
    sns.heatmap(unit_map, ax=ax)
    ax.set_title(labels[sample_idx])
plt.tight_layout()
plt.show()

# Insert the SAE into the ViT model (ViT-SAE)

## Features most activated in the final layers

In [None]:
# Copy the SAE parameters into frozen encoder/decoder modules.
enc_weight = nn.Parameter(sae.encoder.data, requires_grad=False)
enc_bias = nn.Parameter(sae.encoder_bias.data, requires_grad=False)
# The decoder weight is the transpose of the SAE's learned dictionary.
dec_weight = nn.Parameter(sae.get_learned_dict().data.T, requires_grad=False)

encode = sae_encoder()
encode.encode.weight = enc_weight
encode.encode.bias = enc_bias

decode = sae_decoder()
decode.decode.weight = dec_weight

In [None]:
# Replace encoder block 11 with an Added_layer built from the original block
# and the SAE encode/decode modules.
# NOTE: re-running this cell wraps the already-wrapped block again.
encoder_blocks = model.vit.encoder.layer
encoder_blocks[11] = Added_layer(encoder_blocks[11], encode, decode)


In [None]:

# Optimise an input image for SAE feature 6775.
neuron_index = 6775  # most activated feature for class 3 cat
input_size = 224
optimized_input = sae_feature_vis(
    model,
    neuron_index,
    input_size,
    num_iterations=1000,
    lr=0.1,
    device=device,
    augment=augment,
    lambda_tv=0.0005,
)
# permute(1, 2, 0) reorders the axes for imshow (assumes channel-first output).
rendered = optimized_input[0].permute(1, 2, 0).cpu().numpy()
plt.imshow(rendered)
plt.show()

In [None]:
# Optimise an input image for SAE feature 5115 (no intermediate previews).
neuron_index = 5115  # most activated feature for class 0 airplane
input_size = 224
optimized_input = sae_feature_vis(
    model,
    neuron_index,
    input_size,
    num_iterations=1000,
    lr=0.1,
    device=device,
    augment=augment,
    lambda_tv=0.0005,
    show_intermediate=False,
)
# permute(1, 2, 0) reorders the axes for imshow (assumes channel-first output).
rendered = optimized_input[0].permute(1, 2, 0).cpu().numpy()
plt.imshow(rendered)
plt.show()

In [None]:

# Optimise an input image for SAE feature 1454, showing intermediate steps.
neuron_index = 1454
input_size = 224
optimized_input = sae_feature_vis(
    model,
    neuron_index,
    input_size,
    num_iterations=1000,
    lr=0.1,
    device=device,
    augment=augment,
    lambda_tv=0.0005,
    show_intermediate=True,
)
# permute(1, 2, 0) reorders the axes for imshow (assumes channel-first output).
rendered = optimized_input[0].permute(1, 2, 0).cpu().numpy()
plt.imshow(rendered)
plt.show()

In [None]:

# Optimise an input image for SAE feature 2525, showing intermediate steps.
neuron_index = 2525  # most activated feature for class dog
input_size = 224
optimized_input = sae_feature_vis(
    model,
    neuron_index,
    input_size,
    num_iterations=1000,
    lr=0.1,
    device=device,
    augment=augment,
    lambda_tv=0.0005,
    show_intermediate=True,
)
# permute(1, 2, 0) reorders the axes for imshow (assumes channel-first output).
rendered = optimized_input[0].permute(1, 2, 0).cpu().numpy()
plt.imshow(rendered)
plt.show()

In [None]:
# Repeat the visualisation of SAE feature 2525 for several `lambda_tv` values
# (presumably the total-variation weight — TODO confirm) to compare results.
lambdas = [0, 0.0001, 0.0005, 0.001]
neuron_index = 2525  # most activated feature for class dog
input_size = 224
for i in lambdas:
    optimized_input = sae_feature_vis(
        model,
        neuron_index,
        input_size,
        num_iterations=1000,
        lr=0.1,
        device=device,
        augment=augment,
        lambda_tv=i,
        show_intermediate=False,
    )
    print(f'lambda: {i}')
    rendered = optimized_input[0].permute(1, 2, 0).cpu().numpy()
    plt.imshow(rendered)
    plt.show()

In [None]:

# Optimise an input image for SAE feature 4753, showing intermediate steps.
neuron_index = 4753
input_size = 224
optimized_input = sae_feature_vis(
    model,
    neuron_index,
    input_size,
    num_iterations=1000,
    lr=0.1,
    device=device,
    augment=augment,
    lambda_tv=0.0005,
    show_intermediate=True,
)
# permute(1, 2, 0) reorders the axes for imshow (assumes channel-first output).
rendered = optimized_input[0].permute(1, 2, 0).cpu().numpy()
plt.imshow(rendered)
plt.show()

In [None]:

# Optimise an input image for SAE feature 880, showing intermediate steps.
neuron_index = 880
input_size = 224
optimized_input = sae_feature_vis(
    model,
    neuron_index,
    input_size,
    num_iterations=1000,
    lr=0.1,
    device=device,
    augment=augment,
    lambda_tv=0.0005,
    show_intermediate=True,
)
# permute(1, 2, 0) reorders the axes for imshow (assumes channel-first output).
rendered = optimized_input[0].permute(1, 2, 0).cpu().numpy()
plt.imshow(rendered)
plt.show()

In [None]:
# Visualise each SAE feature in `most_by_class` in turn (ten entries; the name
# suggests one top feature per CIFAR-10 class — TODO confirm the ordering).
most_by_class = [4775, 5266, 2792, 6775, 1189, 2525, 1454, 599, 319, 3673]
input_size = 224
for n in most_by_class:
    neuron_index = n
    optimized_input = sae_feature_vis(
        model,
        neuron_index,
        input_size,
        num_iterations=1000,
        lr=0.1,
        device=device,
        augment=augment,
        lambda_tv=0.0005,
        show_intermediate=False,
    )
    print(f'feature {n}')
    rendered = optimized_input[0].permute(1, 2, 0).cpu().numpy()
    plt.imshow(rendered)
    plt.show()

In [None]:
# 2nd most activated feature for automobile.
# This run uses a smaller lambda_tv (0.00005) than the other cells.
neuron_index = 509
input_size = 224
optimized_input = sae_feature_vis(
    model,
    neuron_index,
    input_size,
    num_iterations=1000,
    lr=0.1,
    device=device,
    augment=augment,
    lambda_tv=0.00005,
    show_intermediate=True,
)
# permute(1, 2, 0) reorders the axes for imshow (assumes channel-first output).
rendered = optimized_input[0].permute(1, 2, 0).cpu().numpy()
plt.imshow(rendered)
plt.show()

In [None]:
# 2nd most activated feature for horse.
neuron_index = 7100
input_size = 224
optimized_input = sae_feature_vis(
    model,
    neuron_index,
    input_size,
    num_iterations=1000,
    lr=0.1,
    device=device,
    augment=augment,
    lambda_tv=0.0005,
    show_intermediate=False,
)
# permute(1, 2, 0) reorders the axes for imshow (assumes channel-first output).
rendered = optimized_input[0].permute(1, 2, 0).cpu().numpy()
plt.imshow(rendered)
plt.show()

## Features most activated in earlier (non-final) layers

In [None]:
#for early layers
# Reload a fresh model so only encoder block 0 carries the SAE wrapper.
model = ViTForImageClassification.from_pretrained("nateraw/vit-base-patch16-224-cifar10")
# Keep the model on the same device as the SAE encode/decode weights, matching
# the first model load (harmless if sae_feature_vis also moves it).
model.to(device)
model.eval()
model.vit.encoder.layer[0] = Added_layer(model.vit.encoder.layer[0], encode, decode)


In [None]:
# Layer 0, class 0: tally how often each SAE feature appears in the per-row
# top-k returned by top_act, then visualise the most frequent ones.
act_path ='activations_cifar10_vit_b'
target_layer = 0
cls = 0
top_activations = top_act(act_path, target_layer, cls, sae)
value_counts = Counter(top_activations.cpu().numpy().flatten().tolist())

# most_common orders by frequency. Plain .keys() is first-seen order, which
# ignores the counts and just returns the first rows' indices.
idx = [feature for feature, _ in value_counts.most_common(30)]
input_size = 224
for neuron_index in idx:
    print(f'feature {neuron_index}')
    optimized_input = sae_feature_vis(model, neuron_index, input_size, sae_layer=0, num_iterations=1000, lr=0.1, device=device, augment=augment, lambda_tv = 0.0005)
    plt.imshow(optimized_input[0].permute(1, 2, 0).cpu().numpy())
    plt.show()

In [None]:
#for mid layers
# Reload a fresh model so only encoder block 5 carries the SAE wrapper.
model = ViTForImageClassification.from_pretrained("nateraw/vit-base-patch16-224-cifar10")
# Keep the model on the same device as the SAE encode/decode weights
# (harmless if sae_feature_vis also moves it).
model.to(device)
model.eval()
model.vit.encoder.layer[5] = Added_layer(model.vit.encoder.layer[5], encode, decode)

act_path ='activations_cifar10_vit_b'
target_layer = 5
cls = 0
top_activations = top_act(act_path, target_layer, cls, sae)
value_counts = Counter(top_activations.cpu().numpy().flatten().tolist())

# most_common orders by frequency. Plain .keys() is first-seen order, which
# ignores the counts and just returns the first rows' indices.
idx = [feature for feature, _ in value_counts.most_common(30)]
input_size = 224
for neuron_index in idx:
    print(f'feature {neuron_index}')
    optimized_input = sae_feature_vis(model, neuron_index, input_size, sae_layer=5, num_iterations=1000, lr=0.1, device=device, augment=augment, lambda_tv = 0.0005)
    plt.imshow(optimized_input[0].permute(1, 2, 0).cpu().numpy())
    plt.show()

# vit-dino

In [None]:
# Resolve the device before anything in this cell uses it (it used to be
# re-declared only at the end of the cell, after its uses).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ViTForImageClassification.from_pretrained("facebook/dino-vitb16")
# Keep the model on the same device as the SAE encode/decode weights
# (harmless if sae_feature_vis also moves it).
model.to(device)
model.eval()

# NOTE: reuses the `encoder` / `encoder_bias` tensors created in the SAE
# initialisation cell, so that cell must have been run first.
sae = TiedSAE(encoder, encoder_bias)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads on a CPU-only machine.
sae.load_state_dict(torch.load('SAE_models/cifar10_dino/SAE_ratio10_epoch100_lr0.0001.pth', map_location=device))
sae.to_device(device)

# Copy the SAE parameters into frozen encoder/decoder modules.
encode = sae_encoder()
encode.encode.weight = nn.Parameter(sae.encoder.data, requires_grad=False)
encode.encode.bias = nn.Parameter(sae.encoder_bias.data, requires_grad=False)

decode = sae_decoder()
# The decoder weight is the transpose of the SAE's learned dictionary.
decode.decode.weight = nn.Parameter(sae.get_learned_dict().data.T, requires_grad=False)

In [None]:
# Layer 11, class 0: collect the per-row top-k SAE feature indices from the
# dino activations directory and tally how often each index appears.
act_path = 'activations_cifar10_dino_vitb16'
target_layer = 11
cls = 0
top_activations = top_act(act_path, target_layer, cls, sae)
flat_indices = top_activations.cpu().numpy().flatten().tolist()
value_counts = Counter(flat_indices)

In [None]:
# Replace encoder block 11 with an Added_layer built from the original block
# and the SAE encode/decode modules.
# NOTE: re-running this cell wraps the already-wrapped block again.
encoder_blocks = model.vit.encoder.layer
encoder_blocks[11] = Added_layer(encoder_blocks[11], encode, decode)


In [None]:

# most activated features for class 0 airplane.
# most_common orders by frequency. Plain .keys() is first-seen order, which
# ignores the counts and just returns the first rows' indices.
idx = [feature for feature, _ in value_counts.most_common(30)]
input_size = 224
for neuron_index in idx:
    print(f'feature {neuron_index}')
    optimized_input = sae_feature_vis(model, neuron_index, input_size, num_iterations=1000, lr=0.1, device=device, augment=augment, lambda_tv = 0.0005)
    plt.imshow(optimized_input[0].permute(1, 2, 0).cpu().numpy())
    plt.show()

In [None]:
# Layer 0: reload a fresh model so only encoder block 0 carries the SAE wrapper.
model = ViTForImageClassification.from_pretrained("facebook/dino-vitb16")
# Keep the model on the same device as the SAE encode/decode weights
# (harmless if sae_feature_vis also moves it).
model.to(device)
model.eval()
model.vit.encoder.layer[0] = Added_layer(model.vit.encoder.layer[0], encode, decode)

act_path ='activations_cifar10_dino_vitb16'
target_layer = 0
cls = 0
top_activations = top_act(act_path, target_layer, cls, sae)
value_counts = Counter(top_activations.cpu().numpy().flatten().tolist())

# most_common orders by frequency. Plain .keys() is first-seen order, which
# ignores the counts and just returns the first rows' indices.
idx = [feature for feature, _ in value_counts.most_common(30)]
input_size = 224
for neuron_index in idx:
    print(f'feature {neuron_index}')
    optimized_input = sae_feature_vis(model, neuron_index, input_size, sae_layer=0, num_iterations=1000, lr=0.1, device=device, augment=augment, lambda_tv = 0.0005)
    plt.imshow(optimized_input[0].permute(1, 2, 0).cpu().numpy())
    plt.show()

In [None]:
# Layer 5: reload a fresh model so only encoder block 5 carries the SAE wrapper.
model = ViTForImageClassification.from_pretrained("facebook/dino-vitb16")
# Keep the model on the same device as the SAE encode/decode weights
# (harmless if sae_feature_vis also moves it).
model.to(device)
model.eval()
model.vit.encoder.layer[5] = Added_layer(model.vit.encoder.layer[5], encode, decode)

act_path ='activations_cifar10_dino_vitb16'
target_layer = 5

# For every class id 0-9, visualise the five most frequent top-k SAE features.
for cls in range(10):
    top_activations = top_act(act_path, target_layer, cls, sae)
    value_counts = Counter(top_activations.cpu().numpy().flatten().tolist())
    # most_common orders by frequency. Plain .keys() is first-seen order,
    # which ignores the counts.
    idx = [feature for feature, _ in value_counts.most_common(5)]
    input_size = 224
    for neuron_index in idx:
        print(f'feature {neuron_index}')
        optimized_input = sae_feature_vis(model, neuron_index, input_size, sae_layer=5, num_iterations=1000, lr=0.1, device=device, augment=augment, lambda_tv = 0.0005)
        plt.imshow(optimized_input[0].permute(1, 2, 0).cpu().numpy())
        plt.show()

# vitmae

In [None]:
# most activated features for layer 0, class 0 airplane
# NOTE(review): no ViT-MAE SAE is loaded in this section. `sae`, `encode` and
# `decode` are whatever earlier cells left in the kernel, which in notebook
# order is the SAE loaded from SAE_models/cifar10_dino. Confirm that is
# intended or load the matching checkpoint first.
model = ViTForImageClassification.from_pretrained("facebook/vit-mae-base")
# Keep the model on the same device as the SAE encode/decode weights
# (harmless if sae_feature_vis also moves it).
model.to(device)
model.eval()
model.vit.encoder.layer[0] = Added_layer(model.vit.encoder.layer[0], encode, decode)

act_path ='activations_cifar10_facebook_vitmae'
target_layer = 0
cls = 0
top_activations = top_act(act_path, target_layer, cls, sae)
value_counts = Counter(top_activations.cpu().numpy().flatten().tolist())

# most_common orders by frequency. Plain .keys() is first-seen order, which
# ignores the counts and just returns the first rows' indices.
idx = [feature for feature, _ in value_counts.most_common(30)]
input_size = 224
for neuron_index in idx:
    print(f'feature {neuron_index}')
    optimized_input = sae_feature_vis(model, neuron_index, input_size, sae_layer=0, num_iterations=1000, lr=0.1, device=device, augment=augment, lambda_tv = 0.0005)
    plt.imshow(optimized_input[0].permute(1, 2, 0).cpu().numpy())
    plt.show()

In [None]:
# most activated features for layer 5, class 0 airplane
# NOTE(review): no ViT-MAE SAE is loaded in this section. `sae`, `encode` and
# `decode` are whatever earlier cells left in the kernel, which in notebook
# order is the SAE loaded from SAE_models/cifar10_dino. Confirm that is
# intended or load the matching checkpoint first.
model = ViTForImageClassification.from_pretrained("facebook/vit-mae-base")
# Keep the model on the same device as the SAE encode/decode weights
# (harmless if sae_feature_vis also moves it).
model.to(device)
model.eval()
model.vit.encoder.layer[5] = Added_layer(model.vit.encoder.layer[5], encode, decode)

act_path ='activations_cifar10_facebook_vitmae'
target_layer = 5
cls = 0
top_activations = top_act(act_path, target_layer, cls, sae)
value_counts = Counter(top_activations.cpu().numpy().flatten().tolist())

# most_common orders by frequency. Plain .keys() is first-seen order, which
# ignores the counts and just returns the first rows' indices.
idx = [feature for feature, _ in value_counts.most_common(30)]
input_size = 224
for neuron_index in idx:
    print(f'feature {neuron_index}')
    optimized_input = sae_feature_vis(model, neuron_index, input_size, sae_layer=5, num_iterations=1000, lr=0.1, device=device, augment=augment, lambda_tv = 0.0005)
    plt.imshow(optimized_input[0].permute(1, 2, 0).cpu().numpy())
    plt.show()

In [None]:
# most activated features for layer 11, class 0 airplane
# NOTE(review): no ViT-MAE SAE is loaded in this section. `sae`, `encode` and
# `decode` are whatever earlier cells left in the kernel, which in notebook
# order is the SAE loaded from SAE_models/cifar10_dino. Confirm that is
# intended or load the matching checkpoint first.
model = ViTForImageClassification.from_pretrained("facebook/vit-mae-base")
# Keep the model on the same device as the SAE encode/decode weights
# (harmless if sae_feature_vis also moves it).
model.to(device)
model.eval()
model.vit.encoder.layer[11] = Added_layer(model.vit.encoder.layer[11], encode, decode)

act_path ='activations_cifar10_facebook_vitmae'
target_layer = 11
cls = 0
top_activations = top_act(act_path, target_layer, cls, sae)
value_counts = Counter(top_activations.cpu().numpy().flatten().tolist())

# most_common orders by frequency. Plain .keys() is first-seen order, which
# ignores the counts and just returns the first rows' indices.
idx = [feature for feature, _ in value_counts.most_common(30)]
input_size = 224
for neuron_index in idx:
    print(f'feature {neuron_index}')
    optimized_input = sae_feature_vis(model, neuron_index, input_size, sae_layer=11, num_iterations=1000, lr=0.1, device=device, augment=augment, lambda_tv = 0.0005)
    plt.imshow(optimized_input[0].permute(1, 2, 0).cpu().numpy())
    plt.show()