In [None]:
# Re-import edited project modules (e.g. the visionlab helpers imported below)
# automatically before each cell runs, so library edits are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
from visionlab.datasets import StreamingDataset
from visionlab.project_demo.utils.feature_extractor_recurrent import FeatureExtractorRecurrent, get_layer_names
from litdata import StreamingDataLoader
from torchvision import models, transforms

In [None]:
# Pretrained AlexNet from torchvision. Displaying it shows the module tree whose
# dotted names (features.N / classifier.N) are used to select layers further down.
model = models.alexnet(
    weights='DEFAULT',
)
model

In [None]:
# Per-channel normalisation statistics passed to Normalize (the standard ImageNet values).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize the shorter side to 224, centre-crop to 224x224, convert to a tensor,
# then standardise each channel.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [None]:
# Stream the stimulus images from S3; `transform` is applied to the "image" field.
s3_path = "s3://visionlab-datasets/exploring-objects-images"
image_pipelines = {"image": transform}
dataset = StreamingDataset(
    s3_path,
    decode_images=True,
    to_pil=True,
    pipelines=image_pipelines,
)
dataset

In [None]:
# Spot-check one item; presumably 'image' is the transformed tensor at this point — TODO confirm.
sample = dataset[0]
sample["image"]

In [None]:
# NOTE(review): 72 is presumably chosen so one batch covers the whole stimulus set — confirm.
BATCH_SIZE = 72
dataloader = StreamingDataLoader(dataset, batch_size=BATCH_SIZE)
dataloader

In [None]:
# Pull a single batch; everything below analyses only this first batch.
batch = next(iter(dataloader))
batch.keys()

In [None]:
# Batched image tensor shape; presumably (batch, channels, height, width) — TODO confirm.
batch["image"].shape

In [None]:
# Drop the first N_SKIPPED_LAYERS names returned by get_layer_names before hooking.
# NOTE(review): which layers these are depends on get_layer_names' ordering — confirm.
N_SKIPPED_LAYERS = 4
layer_names = get_layer_names(model)[N_SKIPPED_LAYERS:]
layer_names

In [None]:
import torch
from tqdm import tqdm

# Run on the GPU when one is present; eval() disables dropout for deterministic activations.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device).eval()

# Hook `layer_names` for one forward pass over the batch, without tracking gradients.
# NOTE(review): the resulting activations stay on `device`; move them to CPU before numpy use.
with torch.no_grad():
    images = batch['image'].to(device)
    with FeatureExtractorRecurrent(model, layer_names) as extractor:
        feats = extractor(images)

In [None]:
# Layer names present in the extractor output (expected to correspond to `layer_names` — TODO confirm).
feats.keys()

In [None]:
# Shape of the first stored output per layer (each value is indexed with [0] throughout).
for layer_name, layer_outputs in feats.items():
    print(layer_name, layer_outputs[0].shape)

In [None]:
import numpy as np

# One image-by-image Pearson correlation matrix per hooked layer.
# NOTE(review): despite the name, these hold correlations (similarities); the
# conventional RSA dissimilarity matrix would be `1 - r`.
RDMS = {}
for layer_name, layer_outputs in feats.items():
    # [0]: first stored output for the layer (presumably the first/only time step — TODO confirm).
    # flatten(1) keeps the batch dim, so each row is one image's activation vector.
    # detach().cpu().numpy() is required: np.corrcoef cannot convert a CUDA tensor.
    flat_acts = layer_outputs[0].flatten(1).detach().cpu().numpy()
    RDMS[layer_name] = np.corrcoef(flat_acts)

In [None]:
import seaborn as sns

# Image-by-image correlation of features.4 activations (one row/column per image in the batch).
ax = sns.heatmap(RDMS['features.4'], square=True)
ax.set(title="features.4: image-by-image correlation", xlabel="image", ylabel="image");

In [None]:
# Image-by-image correlation of classifier.2 activations.
ax = sns.heatmap(RDMS['classifier.2'], square=True)
ax.set(title="classifier.2: image-by-image correlation", xlabel="image", ylabel="image");

In [None]:
# features.4 activations as an (images, units) matrix: everything after the batch dim is flattened.
act = feats["features.4"][0].flatten(start_dim=1)
act.shape

In [None]:
# Keep only units whose |activation| exceeds `threshold` for at least one image,
# i.e. drop units that are (near-)silent across the whole batch.
# NOTE(review): this changes each row's mean and thus the correlations, but it does
# not remove NaNs caused by an image whose whole activation vector is constant.
threshold = 1e-6
mask = (act.abs() > threshold).any(dim=0)

act_filtered = act[:, mask]
# np.corrcoef needs a host array; a CUDA tensor cannot be converted implicitly.
rdm = np.corrcoef(act_filtered.detach().cpu().numpy())

In [None]:
# Correlation of features.4 after removing units that are silent for every image.
ax = sns.heatmap(rdm, square=True)
ax.set(title="features.4 (silent units removed): image-by-image correlation", xlabel="image", ylabel="image");