In [99]:
from torchvision import models, transforms
import torch
import numpy as np
from PIL import Image
import os

In [100]:
# Load AlexNet with its ImageNet-pretrained weights.
# `pretrained=True` is deprecated since torchvision 0.13 in favour of the
# `weights=` enum (IMAGENET1K_V1 is what `pretrained=True` maps to), so use the
# enum when available and fall back to the legacy flag on older releases.
if hasattr(models, "AlexNet_Weights"):
    model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
else:
    model = models.alexnet(pretrained=True)

In [101]:
input_size = 224  # target length, in pixels, passed to Resize below

# Pipeline applied to the numpy RGB array produced by `read_image`:
#   array -> PIL image -> resized -> float tensor -> channel-normalized.
# NOTE: Resize with a single-element size matches the image's *shorter* edge
# and keeps the aspect ratio, so the result is not necessarily
# input_size x input_size (feature sizes then depend on the aspect ratio).
preprocess = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize([input_size]),
    transforms.ToTensor(),
    # ImageNet per-channel mean / std used with torchvision's pretrained weights.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [102]:
# Module-level store filled by the forward hooks below: maps a layer label
# to the most recent output recorded for that layer.
activation = {}


def get_activ(name):
    """Build a forward hook that records a module's output under ``name``.

    Parameters
    ----------
    name : str
        Key under which the output is stored in the module-level
        ``activation`` dict.

    Returns
    -------
    callable
        A ``hook(module, input, output)`` function suitable for
        ``nn.Module.register_forward_hook``.
    """

    def hook(model, input, output):
        # clone(): detach() alone shares storage with ``output``, so an
        # in-place op later in the network (e.g. a ReLU(inplace=True) right
        # after a Linear layer) would silently overwrite the stored values.
        activation[name] = output.detach().clone()

    return hook


def layer_activation(model, layer_fun, layer_str, input, text=None):
    """Run ``model`` on ``input`` and return what ``layer_fun`` produced.

    A forward hook is attached to ``layer_fun`` only for the duration of this
    call, so repeated calls do not accumulate hooks on the layer.

    Parameters
    ----------
    model : torch.nn.Module
        Network to run; it is switched to eval mode.
    layer_fun : torch.nn.Module
        Sub-module of ``model`` whose output is captured.
    layer_str : str
        Key under which the output is stored in ``activation``.
    input : torch.Tensor
        Input batch passed to ``model``.
    text : optional
        Unused; kept for backward compatibility.

    Returns
    -------
    torch.Tensor
        Detached copy of ``layer_fun``'s output (also left in
        ``activation[layer_str]``).

    Raises
    ------
    KeyError
        If ``layer_fun`` did not run during the forward pass.
    """
    # Drop any value from an earlier call so a layer that never fires cannot
    # return a stale result.
    activation.pop(layer_str, None)
    handle = layer_fun.register_forward_hook(get_activ(layer_str))
    try:
        model.eval()
        # Inference only: the captured tensor is detached anyway, so skip
        # building the autograd graph.
        with torch.no_grad():
            model(input)
    finally:
        handle.remove()
    if layer_str not in activation:
        raise KeyError(
            f"layer {layer_str!r} was not executed during the forward pass"
        )
    return activation[layer_str]

In [120]:
def read_image(image_path, preprocess):
    """Load an image file as RGB and return it preprocessed, as a batch of one.

    Parameters
    ----------
    image_path : str or path-like
        Path to the image file.
    preprocess : callable
        Transform applied to the RGB numpy array (e.g. a
        ``transforms.Compose`` pipeline starting with ``ToPILImage``).

    Returns
    -------
    torch.Tensor
        ``preprocess`` output with a new leading dimension of size 1
        (assumes ``preprocess`` returns a torch.Tensor).
    """
    # The context manager closes the file handle Image.open keeps open;
    # without it, handles leak when this is called over many images.
    with Image.open(image_path) as img:
        rgb_array = np.array(img.convert("RGB"))

    return preprocess(rgb_array).unsqueeze(0)

In [115]:
# Modules whose outputs are recorded: indices 2, 5, 12 of `model.features` and
# 1, 4, 6 of `model.classifier` (in torchvision's AlexNet these should be the
# MaxPool2d and Linear modules — confirm with print(model)).
layers = [model.features[idx] for idx in (2, 5, 12)] + [
    model.classifier[idx] for idx in (1, 4, 6)
]

# Hook keys are simply the positions in `layers`: "0", "1", ..., "5".
layers_name = list(map(str, range(len(layers))))

In [116]:
# Pick one example image from the tree
#   <image_path>/<triplet>/<config>/<image file>.
# os.listdir returns entries in arbitrary, filesystem-dependent order, so each
# listing is sorted to make the indices below select the same file everywhere.
image_path = './trialscenes/trialscenes'
triplets = sorted(os.listdir(image_path))

one_triplet = os.path.join(image_path, triplets[2])
configs = sorted(os.listdir(one_triplet))

one_config = os.path.join(one_triplet, configs[0])
images = sorted(os.listdir(one_config))

one_image = os.path.join(one_config, images[0])

In [117]:
# Preprocessed batch (one image) fed to the network below.
Images = read_image(image_path=one_image, preprocess=preprocess)

In [118]:
# For each hooked layer, collect its activations for `Images` as a 2-D array.
cnn_data = []
# Inference only: don't build an autograd graph for any of the forward passes.
with torch.no_grad():
    for layer, layer_name in zip(layers, layers_name):
        activs = (
            layer_activation(model, layer, layer_name, Images.float())
            .cpu()
            .detach()
            .numpy()
        )
        # Flatten everything except the batch axis into an n*m matrix, where n
        # is the number of images and m is the layer's total number of
        # features. copy() keeps cnn_data from aliasing the stored tensors.
        flatten_activs = activs.reshape(activs.shape[0], -1).copy()
        cnn_data.append(flatten_activs)

In [119]:
# (n_images, n_features) of the first hooked layer.
np.shape(cnn_data[0])

(1, 62208)