In [17]:
import os

import einops
import lpips
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
import seaborn as sns
import torch
from lucent.modelzoo import inceptionv1
from lucent.optvis import render, param, transform, objectives
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageNet
from tqdm import tqdm

### Initial steps in automating the identification of polysemantic neuons in image models

Galen Pogoncheff

___
#### Feature Visualizations and Dataset Example Analysis for Neurons in InceptionV1

Create a dataloader for the ImageNet validation dataset and load the pre-trained InceptionV1 model.

In [11]:
# Number of images per forward pass over the ImageNet validation set.
batch_size = 128
# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Per-channel statistics used to normalize images here and to un-normalize
# them again when plotting dataset examples.
img_ch_means = [0.485, 0.456, 0.406]
img_ch_stds = [0.229, 0.224, 0.225]

# NOTE(review): Lucent's render_vis appears to apply its own
# InceptionV1-specific preprocessing instead of mean/std normalization --
# confirm this normalization matches what the lucent `inceptionv1` weights expect.
validation_data_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=img_ch_means, std=img_ch_stds),
])
validation_data = ImageNet(
    './data/imagenet',
    split='val',
    transform=validation_data_transform,
)
validation_dataloader = DataLoader(
    validation_data,
    batch_size=batch_size,
    shuffle=False,  # keep dataset order: activation row k must map to validation_data[k]
    num_workers=12,
    pin_memory=True,
)

model = inceptionv1(pretrained=True)
# Binding the result to `_` keeps the notebook from echoing the module repr.
_ = model.to(device).eval()

Randomly sample 128 neuron channels from the 832-channel Mixed4e layer of InceptionV1.  For each image in the ImageNet validation dataset, compute the intermediate activations of each sampled channel at the Mixed4e layer.  Finally, for each sampled channel, select the 25 dataset images that maximize the channel's activation averaged over all of its spatial positions (sorted in descending order of activation).

In [9]:
# Maps a caller-chosen layer name to the list of per-batch activation arrays
# recorded by the hooks built below.
activations = {}
def get_activation(name):
    '''
    Build a forward hook that records activations under `activations[name]`.

    Each call of the returned hook slices the layer output down to the channels
    listed in the global `sampled_neurons`, copies the result into a NumPy array
    on the CPU and appends it to the list stored at `activations[name]`.

    Input:
        name: key under which the recorded batches are stored

    Output:
        hook function with the (module, input, output) forward-hook signature
    '''
    def record(module, inputs, output):
        recorded = activations.setdefault(name, [])
        selected = output[:, sampled_neurons, :, :]
        recorded.append(selected.detach().clone().cpu().numpy())
    return record

n_neurons = 128
# NOTE(review): this draw is unseeded, so every fresh run samples different
# channels, while the manually labeled neuron lists later in the notebook refer
# to one specific draw. Persist/reload the original sample to make the notebook
# re-runnable end to end.
sampled_neurons = np.random.permutation(np.arange(832))[:n_neurons]
activation_hook = model.mixed4e.register_forward_hook(get_activation('4e'))
try:
    # Inference only: skip building the autograd graph for every batch.
    with torch.no_grad():
        # Labels are not needed; only the hooked Mixed4e activations are used.
        for images, _labels in tqdm(validation_dataloader):
            _ = model(images.to(device))
finally:
    # Always detach the hook so a failed or interrupted run does not leave it
    # registered (a second registration would record every batch twice).
    activation_hook.remove()

# (n_images, n_neurons, h, w) -> per-image mean activation of each sampled channel.
activations['4e'] = np.concatenate(activations['4e'])
mean_ch_activations = einops.reduce(activations['4e'], 'n c h w -> n c', 'mean')

n_dataset_examples = 25
# Row i holds the dataset indices of the top images for sampled channel i,
# ordered from highest to lowest mean activation.
top_n_img_inds = np.empty((n_neurons, n_dataset_examples), dtype=int)
for i in range(n_neurons):
    neuron_activations = mean_ch_activations[:, i]
    top_n_img_inds[i] = np.argsort(neuron_activations)[-n_dataset_examples:][::-1]

100%|██████████| 391/391 [00:30<00:00, 12.79it/s]


For each sampled neuron channel, save a grid of its top-activating dataset examples (feature visualizations are generated in the next step).

In [15]:
def plot_img_grid(dataset, img_idxs, img_ch_means=None, img_ch_stds=None, n_rows=5, n_cols=5, fname=None):
    '''
    Plots a grid of images from a given dataset.

    Input:
        dataset: torch.utils.data.Dataset of (image, label) pairs, where image
            is a (c, h, w) tensor
        img_idxs: dataset indices of the images to plot; only the first
            n_rows*n_cols entries are used
        img_ch_means: per-channel normalization means used to unnormalize the
            images (None: no mean is added back)
        img_ch_stds: per-channel normalization stds used to unnormalize the
            images (None: images are not rescaled)
        n_rows: number of rows in the grid
        n_cols: number of columns in the grid
        fname: if not None, saves the figure to the given path

    Output:
        None
    '''
    img_idxs = img_idxs[:n_rows*n_cols]
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows), squeeze=False)
    # Hide every axis up front so cells without an image stay blank.
    for ax in axes.ravel():
        ax.axis('off')
    for img_idx, ax in zip(img_idxs, axes.ravel()):
        img, label = dataset[img_idx]
        img = einops.rearrange(img, 'c h w -> h w c')
        # Undo the normalization only for the statistics that were provided.
        if img_ch_stds is not None:
            img = img * torch.as_tensor(img_ch_stds, dtype=img.dtype).view(1, 1, -1)
        if img_ch_means is not None:
            img = img + torch.as_tensor(img_ch_means, dtype=img.dtype).view(1, 1, -1)
        # Un-normalizing can overshoot [0, 1] slightly; clamp to imshow's valid float range.
        ax.imshow(img.clamp(0, 1).numpy())
        ax.set_title(label)
    plt.tight_layout()
    if fname is not None:
        plt.savefig(fname, transparent=True)
        # Close saved figures so repeated calls in a loop do not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('./test', exist_ok=True)
# Wrap the array itself (not `enumerate`) so tqdm can read its length and
# report real progress.
for i, neuron in enumerate(tqdm(sampled_neurons)):
    plot_img_grid(validation_data, top_n_img_inds[i], img_ch_means, img_ch_stds, 5, 5, f'./test/{neuron}.png')

0it [00:00, ?it/s]


For each sampled neuron channel, compute and save 5 diverse feature visualizations using Lucent

In [None]:
# Neither render_vis' image saving nor np.save is relied on to create
# directories, so make sure both output locations exist.
os.makedirs('./figures', exist_ok=True)
os.makedirs('./data/feature_vis', exist_ok=True)

# Loop-invariant: a batch of 5 jointly optimized 128x128 images per neuron.
batch_param_f = lambda: param.image(128, batch=5)
for neuron in sampled_neurons:
    # Maximize the channel while penalizing similarity among the batch images.
    obj = objectives.channel("mixed4e", neuron) - 1e2 * objectives.diversity("mixed4e")
    output = render.render_vis(model, obj, batch_param_f, show_inline=False, save_image=True, image_name=f'./figures/{neuron}_featurevis.png')
    try:
        # presumably `output` is a list of NumPy image batches (one per
        # threshold) -- TODO confirm; stacked directly instead of going through torch.
        np.save(f'./data/feature_vis/{neuron}.npy', np.asarray(output))
    except Exception as err:
        # Best effort: keep rendering the remaining neurons, but report why the
        # save failed (and let KeyboardInterrupt through).
        print(f'Failed to save data for neuron {neuron}: {err}')

Manual labels for the sampled neuron channels, assigned after manually inspecting each channel's dataset examples and feature visualizations.

In [None]:
# Manual labels of the sampled neuron channels (Mixed4e channel indices).

# Channels judged to be monosemantic.
monos_neurons = [
    296, 28, 176, 254, 547, 780, 392, 337, 523, 636,
    399, 695, 477, 80, 705, 280, 657, 513, 662, 430,
    743, 689, 30, 68, 457, 471, 327, 453, 44, 817, 171,
    691, 286, 213, 504, 297, 420, 727, 425, 626, 21, 139,
    746, 522, 236, 273, 794, 749, 451, 713, 692, 113, 789,
    363, 548, 14, 552, 791, 445, 715, 157, 386, 556, 771,
    19, 805, 336, 728, 700, 539, 305, 204, 54, 595, 690,
    535, 777, 65, 117, 6, 604, 533, 172,
]

# Channels judged to be polysemantic.
polys_neurons = [
    623, 132, 55, 452, 591, 608, 24, 129, 289, 459, 159, 274,
    134, 501, 206, 648, 682, 578, 405, 460, 228, 415, 606, 308,
    119, 570, 614, 644, 584, 156,
]

Compute pairwise perceptual image dissimilarities (LPIPS) among the dataset examples of each neuron.

In [16]:
# One row per sampled neuron, one column per unordered pair of its top dataset examples.
neuron_dissimilarities = torch.zeros((n_neurons, int((n_dataset_examples*(n_dataset_examples-1))/2)))

# NOTE(review): the LPIPS package documents inputs scaled to [-1, 1] (or
# [0, 1] with normalize=True); the images passed here are ImageNet
# mean/std-normalized -- confirm this is intended.
lpips_loss = lpips.LPIPS(net='alex').to(device)

# Evaluation only: no autograd graph is needed for the LPIPS forward passes.
with torch.no_grad():
    for neuron_i in tqdm(range(len(sampled_neurons))):
        img_inds = top_n_img_inds[neuron_i]
        imgs = torch.stack([validation_data[ind][0] for ind in img_inds])
        # Move the comparison batch to the device once per neuron rather than
        # once per source image.
        cmp_imgs = imgs.to(device)
        perceptual_dissimilarity = torch.zeros((n_dataset_examples, n_dataset_examples))
        for i in range(perceptual_dissimilarity.shape[0]):
            # Compare image i against every example of this neuron in one batch.
            src_img = einops.repeat(cmp_imgs[i], 'c h w -> n c h w', n=n_dataset_examples)
            perceptual_dissimilarity[i, :] = lpips_loss(src_img, cmp_imgs).view(-1).detach().clone().cpu()
            perceptual_dissimilarity[:, i] = perceptual_dissimilarity[i, :]
        # Keep each unordered pair once: the strict upper triangle, in row-major order.
        mask = torch.triu(torch.ones_like(perceptual_dissimilarity), diagonal=1)
        neuron_dissimilarities[neuron_i] = perceptual_dissimilarity[mask==1]

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /home/gpogoncheff/anaconda3/envs/ps_neurons/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth


100%|██████████| 128/128 [00:36<00:00,  3.53it/s]


Perform Kruskal-Wallis test to compare pairwise dataset example LPIPS distributions and plot KDE plots for manually labeled monosemantic and polysemantic neurons.

In [None]:
# Row positions of the manually labeled neurons within `sampled_neurons`
# (and therefore within the per-neuron result arrays). Raises IndexError if a
# labeled neuron is not part of the current sample.
monos_neuron_idxs = np.array([np.where(sampled_neurons == neuron)[0][0] for neuron in monos_neurons])
polys_neuron_idxs = np.array([np.where(sampled_neurons == neuron)[0][0] for neuron in polys_neurons])

# Mean pairwise LPIPS among each neuron's dataset examples.
mean_dissimilarities = torch.mean(neuron_dissimilarities, dim=1).numpy()


print('Kruskal-Wallis Test for Comparison of Distribution Medians')
# Use the Kruskal-Wallis H-test announced above (previously this called
# scipy.stats.kstest, i.e. a Kolmogorov-Smirnov test).
print(scipy.stats.kruskal(mean_dissimilarities[monos_neuron_idxs], mean_dissimilarities[polys_neuron_idxs]))

# Mean pairwise dataset examples LPIPS KDEs
fig, ax = plt.subplots(figsize=(5, 5))
# `fill=` replaces the deprecated `shade=` keyword.
sns.kdeplot([mean_dissimilarities[monos_neuron_idxs], mean_dissimilarities[polys_neuron_idxs]], fill=True, common_norm=False, ax=ax)
ax.set_xlabel('Mean LPIPS Loss', size=12)
ax.set_ylabel('Density', size=12)
# NOTE(review): the labels are listed in the reverse of the data order above;
# this presumably compensates for seaborn drawing the groups in reverse order --
# confirm against the rendered figure.
plt.legend(loc='upper right', labels=['Polysemantic', 'Monosemantic'], fontsize=11)
ax.set_title('Perceptual Differences Among Dataset Examples', fontsize=13)
# Save next to the other figures (the path was previously the absolute '/figures/...').
plt.savefig('./figures/lpips_means_dataset_examples.png')
plt.show()

Perform Kruskal-Wallis test to compare pairwise feature visualization LPIPS distributions and plot KDE plots for manually labeled monosemantic and polysemantic neurons.

In [21]:
# Resize, crop and normalize the feature visualizations with the same
# parameters as the dataset-example transform.
lpips_img_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

n_fv_examples = 5

# One row per sampled neuron, one column per unordered pair of its feature visualizations.
neuron_fv_dissimilarities = torch.zeros((n_neurons, int((n_fv_examples*(n_fv_examples-1))/2)))

# NOTE(review): the LPIPS package documents inputs scaled to [-1, 1] (or
# [0, 1] with normalize=True); the images passed here are mean/std-normalized
# -- confirm this is intended.
lpips_loss = lpips.LPIPS(net='alex').to(device)

# Evaluation only: no autograd graph is needed for the LPIPS forward passes.
with torch.no_grad():
    # Wrap the array itself (not `enumerate`) so tqdm knows the total.
    for neuron_i, neuron in enumerate(tqdm(sampled_neurons)):
        # The saved array has a leading singleton axis; [0] selects the
        # (n, h, w, c) batch of visualizations.
        fvs = np.load(f'./data/feature_vis/{neuron}.npy')[0]
        fvs = einops.rearrange(fvs, 'n h w c -> n c h w')
        fv_imgs = lpips_img_transform(torch.Tensor(fvs))
        # Move the comparison batch to the device once per neuron.
        cmp_imgs = fv_imgs.to(device)
        perceptual_dissimilarity = torch.zeros((n_fv_examples, n_fv_examples))
        for i in range(perceptual_dissimilarity.shape[0]):
            # Compare visualization i against all visualizations of this neuron.
            src_img = einops.repeat(cmp_imgs[i], 'c h w -> n c h w', n=n_fv_examples)
            perceptual_dissimilarity[i, :] = lpips_loss(src_img, cmp_imgs).view(-1).detach().clone().cpu()
            perceptual_dissimilarity[:, i] = perceptual_dissimilarity[i, :]
        # Keep each unordered pair once: the strict upper triangle, in row-major order.
        mask = torch.triu(torch.ones_like(perceptual_dissimilarity), diagonal=1)
        neuron_fv_dissimilarities[neuron_i] = perceptual_dissimilarity[mask==1]

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /home/gpogoncheff/anaconda3/envs/ps_neurons/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth


256it [00:04, 59.11it/s]


In [None]:
# Mean pairwise LPIPS among each neuron's feature visualizations.
mean_neuron_fv_dissimilarities = torch.mean(neuron_fv_dissimilarities, dim=1).numpy()

print('Kruskal-Wallis Test for Comparison of Distribution Medians')
# Use the Kruskal-Wallis H-test announced above (previously this called
# scipy.stats.kstest, i.e. a Kolmogorov-Smirnov test).
print(scipy.stats.kruskal(mean_neuron_fv_dissimilarities[monos_neuron_idxs], mean_neuron_fv_dissimilarities[polys_neuron_idxs]))

fig, ax = plt.subplots(figsize=(5, 5))
# `fill=` replaces the deprecated `shade=` keyword.
sns.kdeplot([mean_neuron_fv_dissimilarities[monos_neuron_idxs], mean_neuron_fv_dissimilarities[polys_neuron_idxs]], fill=True, common_norm=False, ax=ax)
ax.set_xlabel('Mean LPIPS Loss', size=12)
ax.set_ylabel('Density', size=12)
# NOTE(review): the labels are listed in the reverse of the data order above;
# this presumably compensates for seaborn drawing the groups in reverse order --
# confirm against the rendered figure.
plt.legend(loc='upper left', labels=['Polysemantic', 'Monosemantic'], fontsize=11)
ax.set_title('Perceptual Differences Among Feature Visualizations', fontsize=13)
plt.savefig('./figures/lpips_means_fvs.png')
plt.show()

___
#### Macaque IT Ridge Regression

In [18]:
from macaque_neural.it_data import MajajHong2015Dataset
from macaque_neural.train_convex import compute_layer

Create train and validation dataloaders for the MajajHong2015 dataset.

In [22]:
# Same resize/crop/normalization as the ImageNet validation transform; ToTensor
# comes first here, so the resize and crop operate on tensors.
it_img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

img_fpaths_fpath = './data/majajhong2015_data/img_paths.npy'
response_fpath = './data/majajhong2015_data/neuron_responses.npy'

it_dataset = MajajHong2015Dataset(img_fpaths_fpath, response_fpath, img_transform=it_img_transform)
# Seed the 70/30 split so the train/validation partition (and the regression
# scores derived from it) is the same on every run.
it_train_set, it_val_set = torch.utils.data.random_split(
    it_dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(0)
)

it_train_dataloader = DataLoader(it_train_set, batch_size=32, shuffle=True)
it_val_dataloader = DataLoader(it_val_set, batch_size=32, shuffle=False)

Get pairs of training data of the form (Mixed4e activations, biological neuron firing rates) for each image stimulus

In [23]:
# Holds the most recent batch of recorded activations per layer name; each
# forward pass overwrites the previous entry.
artificial_activations = {}
def get_it_activation(name):
    '''
    Build a forward hook that stores the latest activations under
    `artificial_activations[name]`.

    The returned hook slices the layer output down to the channels listed in
    the global `sampled_neurons` and stores a detached CPU copy, replacing
    whatever was stored for `name` before.
    '''
    def store_latest(module, inputs, output):
        selected = output[:, sampled_neurons, :, :]
        artificial_activations[name] = selected.detach().clone().cpu()
    return store_latest

it_hook = model.mixed4e.register_forward_hook(get_it_activation('4e'))

# X / Y: training activations and responses; X_report / Y_report: held-out pairs.
X, Y, X_report, Y_report = [], [], [], []
try:
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for dataloader, batch_activations, batch_responses in (
            (it_train_dataloader, X, Y),
            (it_val_dataloader, X_report, Y_report),
        ):
            for img, response in dataloader:
                _ = model(img.to(device))
                # The hook stored this batch's Mixed4e activations during the forward pass.
                batch_activations.append(artificial_activations['4e'])
                batch_responses.append(response)
finally:
    # Always detach the hook, even if a forward pass fails part-way through.
    it_hook.remove()

X = torch.concat(X)
Y = torch.concat(Y)
X_report = torch.concat(X_report)
Y_report = torch.concat(Y_report)

Learn linear mapping from artificial neuronal activations to biological neuron firing rates (using [code derived from Patrick Mineault's YHIT repository](https://github.com/patrickmineault/your-head-is-there-to-move-you-around))

In [24]:
# Flatten each (channel, h, w) activation volume into one feature vector per
# image, then fit the regression. Pass the notebook-wide `device` (instead of a
# hard-coded 'cuda:0') so this cell also runs where the notebook fell back to the CPU.
regression_results, weights = compute_layer(X.view(X.size(0), -1), Y, X_report.view(X_report.size(0), -1), Y_report, pca=-1, method='ridge', device=device)
print(regression_results['corrs_report_mean'])

tensor([0.6943, 0.7952, 0.9137, 0.9415, 0.9494, 0.9571, 0.9621, 0.9615, 0.9571,
        0.9486], device='cuda:0')
0.55835795


Analyze weights of monosemantic and polysemantic neurons

In [None]:
W = weights['W']

# Each artificial neuron has a weight for each biological neuron; compute the
# maximum absolute weight for each artificial neuron.
max_weights = np.max(np.abs(W), axis=1)

mixed4e_h, mixed4e_w = 14, 14

# The regression features are the flattened (channel, h, w) activations, with
# channels in `sampled_neurons` order. Group the weights by channel and select
# the labeled channels by their position in `sampled_neurons`; slicing by
# position alone would only be right if `sampled_neurons` were ordered as all
# monosemantic neurons followed by all polysemantic ones.
max_weights_by_channel = max_weights.reshape(-1, mixed4e_h * mixed4e_w)
monos_neuron_weights = max_weights_by_channel[monos_neuron_idxs].ravel()
polys_neuron_weights = max_weights_by_channel[polys_neuron_idxs].ravel()

print('Kuskall-Wallis Test for Comparison of Distribution Medians of Neuron Weight Magnitudes')
print(scipy.stats.kruskal(monos_neuron_weights, polys_neuron_weights))

# Weight histograms
fig, ax = plt.subplots(figsize=(6,4))
ax.hist(monos_neuron_weights, density=True, alpha=0.5, label='Monosemantic', bins=20)
ax.hist(polys_neuron_weights, density=True, alpha=0.5, label='Polysemantic', bins=20)
ax.set_ylabel('Density')
ax.set_xlabel('Magnitude of Neuron Weight')
ax.set_title('Ridge Regression Weights')
plt.legend(loc='upper right')
plt.savefig('./figures/weights.png')
plt.show()