In [1]:
### Prepare your image data

from torch.utils.data import TensorDataset
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from torch.utils.data import Dataset
import urllib.request
from urllib.error import HTTPError

# Generate a reproducible set of placeholder images (TODO: replace with your images)
SEED = 0  # fixed so the placeholder data is identical after every Restart & Run All
num_images = 600
image_height, image_width = 224, 224

# A dedicated generator keeps the seeding local instead of changing torch's global RNG state
image_rng = torch.Generator().manual_seed(SEED)
random_grayscale_images = torch.rand(
    (num_images, 1, image_height, image_width), dtype=torch.float32, generator=image_rng
)  # Shape: (N, 1, H, W), values in [0, 1)

# Copy the single grayscale channel three times to obtain RGB-shaped tensors
random_rgb_images = random_grayscale_images.repeat(1, 3, 1, 1)  # Shape: (N, 3, H, W)

# One integer label per image (0 .. N-1); here it only serves as an index
random_labels = torch.arange(0, num_images)
image_dataset = TensorDataset(random_rgb_images, random_labels)  # Images shape: torch.Size([600, 3, 224, 224])

In [2]:
### Run images through a pretrained SimCLR model and extract features
class SimCLR(nn.Module):
    """
    ResNet18 with a SimCLR projection head that records the output of every Conv2d layer.

    Forward hooks on all ``nn.Conv2d`` modules of ``self.convnet`` append each layer's
    output (detached, moved to CPU) to ``self.activations[layer_name]`` on every forward
    pass. ``self.activations`` therefore maps layer names to lists with one tensor per
    forward call; reset it (``model.activations = {}``) before starting a new extraction.

    Args:
        hidden_dim: Output size of the projection head; the head's hidden layer has
            ``4 * hidden_dim`` units.

    Attributes:
        convnet: The ResNet18 backbone whose ``fc`` layer is replaced by the projection head.
        activations: Dict ``layer_name -> list of recorded output tensors``.
        num_workers: ``os.cpu_count()``, or 0 when the CPU count cannot be determined.
        device: ``cuda:0`` when CUDA is available, otherwise ``cpu``.
    """

    # Defaults used by load_pretrained() when the corresponding argument is omitted
    DEFAULT_BASE_URL = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial17/"
    DEFAULT_MODELS_DIR = "../../models"
    DEFAULT_CHECKPOINT_FILENAME = "SimCLR.ckpt"

    def __init__(self, hidden_dim=128):
        super().__init__()

        # Base ResNet18 backbone without pretrained weights, because we load custom weights later
        # from the SimCLR checkpoint file. Calling resnet18() without arguments gives random
        # initialisation and avoids the deprecated `pretrained` argument.
        self.convnet = torchvision.models.resnet18()

        # This is the projection head, only needed during training. For downstream tasks it is disposed of
        # and the final linear layer output is used (Chen et al., 2020)
        self.convnet.fc = nn.Sequential(
            nn.Linear(self.convnet.fc.in_features, 4 * hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(4 * hidden_dim, hidden_dim)
        )

        self.activations = {}
        self._hook_handles = []
        # os.cpu_count() may return None; DataLoader needs an int, so fall back to 0 (main process)
        self.num_workers = os.cpu_count() or 0
        self.device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
        self._register_hooks()

    def _register_hooks(self):
        """
        Register forward hooks to capture activations of convolutional layers.

        Previously registered hooks are removed first, so calling this method again
        does not record every activation twice.
        """
        def hook_fn(layer_name):
            def hook(module, inputs, output):
                # Detach and move to CPU so stored activations do not hold on to the
                # autograd graph or to accelerator memory
                self.activations.setdefault(layer_name, []).append(output.detach().cpu())
            return hook

        self.remove_hooks()

        # Register hooks for all convolutional layers and keep the handles for later removal
        for name, layer in self.convnet.named_modules():
            if isinstance(layer, nn.Conv2d):
                self._hook_handles.append(layer.register_forward_hook(hook_fn(name)))

    def remove_hooks(self):
        """
        Remove all forward hooks registered by ``_register_hooks``.

        Already recorded activations are left untouched.
        """
        for handle in getattr(self, "_hook_handles", []):
            handle.remove()
        self._hook_handles = []

    def load_pretrained(self, models_dir=None, base_url=None, filename=None):
        """
        Load pretrained SimCLR weights, downloading the checkpoint first if it is missing.

        After loading, the model is moved to ``self.device`` and switched to eval mode.

        Args:
            models_dir: Directory holding the checkpoint (created if missing).
                Defaults to ``DEFAULT_MODELS_DIR``.
            base_url: URL prefix the checkpoint is downloaded from.
                Defaults to ``DEFAULT_BASE_URL``.
            filename: Checkpoint file name, appended to both ``models_dir`` and ``base_url``.
                Defaults to ``DEFAULT_CHECKPOINT_FILENAME``.

        Raises:
            FileNotFoundError: If the checkpoint is not available locally and the
                download fails with an HTTP error.
        """
        models_dir = self.DEFAULT_MODELS_DIR if models_dir is None else models_dir
        base_url = self.DEFAULT_BASE_URL if base_url is None else base_url
        pretrained_simclr_filename = self.DEFAULT_CHECKPOINT_FILENAME if filename is None else filename
        pretrained_simclr_path = os.path.join(models_dir, pretrained_simclr_filename)
        os.makedirs(models_dir, exist_ok=True)

        # Check whether the pretrained model file already exists locally. If not, try downloading it
        file_url = base_url + pretrained_simclr_filename
        if os.path.isfile(pretrained_simclr_path):
            print(f"Already downloaded pretrained model: {file_url}")
        else:
            print(f"Downloading pretrained SimCLR model {file_url}...")
            # Download to a temporary name and rename on success, so an interrupted
            # download is never mistaken for a complete checkpoint on the next run
            partial_path = pretrained_simclr_path + ".part"
            try:
                urllib.request.urlretrieve(file_url, partial_path)
                os.replace(partial_path, pretrained_simclr_path)
            except HTTPError as e:
                print("Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n", e)
                # Stop here with a clear message instead of failing later inside torch.load
                raise FileNotFoundError(
                    f"Pretrained checkpoint not found at {pretrained_simclr_path} and download from {file_url} failed"
                ) from e
            finally:
                if os.path.isfile(partial_path):
                    os.remove(partial_path)

        # Load pretrained model
        # NOTE(review): torch.load unpickles a file fetched from the network. Consider
        # weights_only=True once it is confirmed that this checkpoint loads with it.
        checkpoint = torch.load(pretrained_simclr_path, map_location=self.device)
        self.load_state_dict(checkpoint['state_dict'])
        self.to(self.device)
        self.eval()

    def forward(self, x):
        """
        Forward pass of the model.

        Returns the output of ``self.convnet`` (backbone followed by the projection head).
        As a side effect, the registered hooks append the convolutional layer outputs
        to ``self.activations``.
        """
        return self.convnet(x)

# Build the model and load the pretrained SimCLR checkpoint (downloads it if missing)
sim_clr = SimCLR()
sim_clr.load_pretrained()

# sim_clr.num_workers is taken from os.cpu_count(), which may be None; DataLoader needs an int,
# so fall back to 0 (load batches in the main process).
# shuffle=False keeps the activations in the same order as the images in image_dataset.
data_loader = DataLoader(
    image_dataset,
    batch_size=64,
    shuffle=False,
    drop_last=False,
    num_workers=sim_clr.num_workers or 0,
)

# Start from an empty store so re-running this cell does not append to earlier results
sim_clr.activations = {}

# Process images in batches
with torch.no_grad():
    for imgs, _ in data_loader:
        imgs = imgs.to(sim_clr.device)
        sim_clr(imgs) # Forward pass (hooks will collect activations)

# Convert list of activations to tensors.
# Each layer is replaced in place, one at a time, so its per-batch list can be freed
# before the next layer is concatenated.
for layer_name in sim_clr.activations:
    sim_clr.activations[layer_name] = torch.cat(sim_clr.activations[layer_name], dim=0)

# Print stored activations
for layer_name, activation in sim_clr.activations.items():
    print(f"Layer: {layer_name}, Activation Shape: {activation.shape}")

  checkpoint = torch.load(pretrained_simclr_path, map_location=self.device)


Already downloaded pretrained model: https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial17/SimCLR.ckpt
Layer: conv1, Activation Shape: torch.Size([600, 64, 112, 112])
Layer: layer1.0.conv1, Activation Shape: torch.Size([600, 64, 56, 56])
Layer: layer1.0.conv2, Activation Shape: torch.Size([600, 64, 56, 56])
Layer: layer1.1.conv1, Activation Shape: torch.Size([600, 64, 56, 56])
Layer: layer1.1.conv2, Activation Shape: torch.Size([600, 64, 56, 56])
Layer: layer2.0.conv1, Activation Shape: torch.Size([600, 128, 28, 28])
Layer: layer2.0.conv2, Activation Shape: torch.Size([600, 128, 28, 28])
Layer: layer2.0.downsample.0, Activation Shape: torch.Size([600, 128, 28, 28])
Layer: layer2.1.conv1, Activation Shape: torch.Size([600, 128, 28, 28])
Layer: layer2.1.conv2, Activation Shape: torch.Size([600, 128, 28, 28])
Layer: layer3.0.conv1, Activation Shape: torch.Size([600, 256, 14, 14])
Layer: layer3.0.conv2, Activation Shape: torch.Size([600, 256, 14, 14])
Layer: layer3.0.downs