In [24]:
from activation_visualization import effective_receptive_field, normalizeZeroOne, multiplot
from models.model_builder import load_model
import numpy as np
import torch
from matplotlib import pyplot as plt
import os
import fnmatch
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from tqdm import tqdm
from joblib import Parallel, delayed
from torchinfo import summary

from util import fit_gabor_filter, normalize

def find_files_in_folder(folder, partial_name):
    matching_files = []
    
    for root, dirs, files in os.walk(folder):
        for filename in fnmatch.filter(files, f'*{partial_name}*'):
            matching_files.append(os.path.join(root, filename))
    matching_files.sort()
    return matching_files

In [25]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using", device)

Using cuda


In [46]:
model_base_path = "../models/rf_development/lindsey32grey"
files = find_files_in_folder(model_base_path, "e**.pth")

In [47]:
epoch_effs=[]
for i, model_path in enumerate(files[:]):
    model = load_model(model_base_path, weights_file=model_path)
    model = model.get_sequential()
    model = model.to(device)
    eff_rfs = effective_receptive_field(model[:1], n_batch=1, fill_value=0.5, device=device)
    epoch_effs.append(eff_rfs)

100%|██████████| 32/32 [00:00<00:00, 718.63it/s]
100%|██████████| 32/32 [00:00<00:00, 1029.15it/s]
100%|██████████| 32/32 [00:00<00:00, 799.98it/s]
100%|██████████| 32/32 [00:00<00:00, 805.99it/s]
100%|██████████| 32/32 [00:00<00:00, 863.36it/s]
100%|██████████| 32/32 [00:00<00:00, 903.79it/s]
100%|██████████| 32/32 [00:00<00:00, 907.17it/s]
100%|██████████| 32/32 [00:00<00:00, 1015.65it/s]
100%|██████████| 32/32 [00:00<00:00, 951.19it/s]
100%|██████████| 32/32 [00:00<00:00, 1005.26it/s]
100%|██████████| 32/32 [00:00<00:00, 825.00it/s]
100%|██████████| 32/32 [00:00<00:00, 867.72it/s]
100%|██████████| 32/32 [00:00<00:00, 807.74it/s]


In [48]:
summary(model)

Layer (type:depth-idx)                   Param #
Sequential                               --
├─Conv2d: 1-1                            2,624
├─ReLU: 1-2                              --
├─Conv2d: 1-3                            82,976
├─ReLU: 1-4                              --
├─Conv2d: 1-5                            82,976
├─ReLU: 1-6                              --
├─Conv2d: 1-7                            82,976
├─ReLU: 1-8                              --
├─Flatten: 1-9                           --
├─Linear: 1-10                           33,555,456
├─ReLU: 1-11                             --
├─Linear: 1-12                           10,250
├─Softmax: 1-13                          --
Total params: 33,817,258
Trainable params: 33,817,258
Non-trainable params: 0

In [49]:
model

Sequential(
  (0): Conv2d(1, 32, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (1): ReLU(inplace=True)
  (2): Conv2d(32, 32, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (3): ReLU(inplace=True)
  (4): Conv2d(32, 32, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (5): ReLU(inplace=True)
  (6): Conv2d(32, 32, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (7): ReLU(inplace=True)
  (8): Flatten(start_dim=1, end_dim=-1)
  (9): Linear(in_features=32768, out_features=1024, bias=True)
  (10): ReLU(inplace=True)
  (11): Linear(in_features=1024, out_features=10, bias=True)
  (12): Softmax(dim=-1)
)

In [50]:
max_plots = 64*3
num_cols=16
max_epochs=50
num_rows = max(1,min(len(epoch_effs[0]),max_plots)//num_cols)
fig, axes = plt.subplots(num_rows, num_cols, figsize=(24,num_rows*2))
imshows = []
for i, (eff_rf, ax) in enumerate(zip(epoch_effs[-1][:max_plots], axes.flat)):
    if eff_rf.shape[0] == 3:
        eff_rf = eff_rf.swapaxes(0,2)
    else:
        eff_rf = eff_rf[0]
    imshows.append(ax.imshow(normalizeZeroOne(eff_rf), cmap="grey"))
    ax.set_title(str(i))
for ax in axes.flat:
    ax.axis('off')
title = fig.suptitle("Epoch 0", fontsize=16)

def update(frame):
    for img, eff_rf in zip(imshows,epoch_effs[frame]):
        if eff_rf.shape[0] == 3:
            eff_rf = eff_rf.swapaxes(0,2)
        else:
            eff_rf = eff_rf[0]
        img.set_data(normalizeZeroOne(eff_rf))
    title.set_text(f"Epoch {frame}")

# Create the animation
ani = FuncAnimation(fig, update, frames=min(len(epoch_effs), max_epochs), interval=200)
display(HTML(ani.to_jshtml()))
# with open("../imgs/rf_development/alexnet_sgd.html", "w") as f:
#     print(ani.to_jshtml(), file=f)
plt.close()

In [51]:
aaa

NameError: name 'aaa' is not defined

In [None]:
with open("../imgs/rf_development/alexnet_sgd.html", "w") as f:
    print(ani.to_jshtml(), file=f)

In [None]:
epoch_fits = []
fit_mses=np.zeros((len(epoch_effs), len(epoch_effs[0])))

for i_epoch, effs in enumerate(tqdm(epoch_effs)):
    result_generator = Parallel(n_jobs=48)(delayed(fit_gabor_filter)(normalize(eff_rf.numpy()), wavelength=None, maxiter=200) for eff_rf in effs)
    result = list(result_generator)
    fit_mses[i_epoch] = np.array([res[1] for res in result])
    epoch_fits.append([res[0] for res in result])

: 

In [None]:
max_plots = 32
num_rows = max(1,min(len(epoch_fits[0]),max_plots)//8)
fig, axes = plt.subplots(num_rows, 8, figsize=(12,num_rows*2))
imshows = []
cmap = plt.get_cmap('winter_r')
good_fit=np.quantile(fit_mses[~np.isnan(fit_mses)], 0.5)

for i, (eff_rf, ax, fit_mse) in enumerate(zip(epoch_fits[-1][:max_plots], axes.flat,fit_mses[-1])):
    imshows.append(ax.imshow(eff_rf, cmap="gray"))
    ax.set_title(str(i), color=cmap(fit_mse/good_fit))
    ax.axis('off')
title = fig.suptitle("Epoch 0", fontsize=16)

def update(frame):
    for img, eff_rf, ax, fit_mse in zip(imshows,epoch_fits[frame], axes.flat, fit_mses[frame]):
        img.set_data(eff_rf)
        ax.title.set_color(cmap(fit_mse/good_fit))
    title.set_text(f"Epoch {frame}")

# Create the animation
ani = FuncAnimation(fig, update, frames=len(epoch_fits), interval=200)
display(HTML(ani.to_jshtml()))
plt.close()

: 

In [None]:
gaborish_rfs = np.argsort(fit_mses[-1])#np.argwhere(fit_mses[-1]<good_fit).flatten()

num_rows = max(1,np.ceil(min(len(gaborish_rfs),max_plots)/8).astype(int))
fig, axes = plt.subplots(num_rows, 8, figsize=(12,num_rows*2))

for i, eff_rf, ax, fit_mse in zip(gaborish_rfs, np.array(epoch_effs[-1])[gaborish_rfs], axes.flat,fit_mses[-1, gaborish_rfs]):
    imshows.append(ax.imshow(eff_rf, cmap="gray"))
    ax.set_title(str(i)+': {:.1e}'.format(fit_mse), color=cmap(fit_mse/good_fit))
for ax in axes.flatten():
    ax.axis('off')

title = fig.suptitle("Filters sorted by their Gaborishness", fontsize=16)
plt.show()

: 

In [None]:
plt.plot(fit_mses[:])
plt.show()

: 