In [1]:
import os
from pathlib import Path
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# Run from the repository root so the repo-local packages (main, data) are
# importable. Search the current directory and its ancestors for the
# "SSL4EO_base" checkout; if none matches, fall back to the parent directory.
# Re-running this cell is safe: once inside the repo root it resolves to itself.
notebook_dir = Path.cwd()
REPO_ROOT = next(
    (p for p in (notebook_dir, *notebook_dir.parents) if p.name == "SSL4EO_base"),
    notebook_dir.parent,
)
os.chdir(REPO_ROOT)

from main import METHODS
from data import constants
from data.constants import MMEARTH_DIR, input_size
from data.mmearth_dataset import MMEarthDataset, create_MMEearth_args, get_mmearth_dataloaders

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def get_pretrained_model(model_path, method, device):
    """Rebuild a pretrained model of type ``method`` from a checkpoint file.

    The checkpoint is read as a dict holding the model's constructor arguments
    under ``"hyper_parameters"`` and its weights under ``"state_dict"``
    (presumably a PyTorch Lightning checkpoint — TODO confirm).

    Args:
        model_path: path to the checkpoint file.
        method: key into ``METHODS``; selects both the model class and the
            ``train_transform`` handed to its constructor.
        device: used only as ``map_location`` for ``torch.load``; the model
            itself is not moved to ``device`` here.

    Returns:
        The constructed model with the checkpoint weights loaded.
    """
    # NOTE: torch.load unpickles the file, which can execute arbitrary code —
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=device)
    hparams = checkpoint["hyper_parameters"]

    spec = METHODS[method]
    model_cls = spec["model"]
    # Constructor arguments are taken verbatim from the checkpoint's hyper-parameters.
    init_kwargs = {
        key: hparams[key]
        for key in (
            "backbone",
            "batch_size_per_device",
            "in_channels",
            "num_classes",
            "has_online_classifier",
            "last_backbone_channel",
        )
    }
    transform = spec["transform"]
    model = model_cls(**init_kwargs, train_transform=transform)

    # nn.Module.load_state_dict defaults to strict=True: missing or unexpected
    # keys raise instead of being ignored.
    model.load_state_dict(checkpoint["state_dict"])
    return model

False

In [6]:
method = "barlowtwins"
# Path to the pretrained checkpoint (weights, not the dataset). Set the
# SSL4EO_CKPT environment variable to point at your own file instead of
# editing this absolute path.
model_path = os.environ.get("SSL4EO_CKPT", "/work/data/weights/barlowtwins/100epochs.ckpt")
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = get_pretrained_model(model_path, method, device)

Using default backbone: resnet50


In [7]:
from torchinfo import summary

# Layer-by-layer parameter counts. No input_size/input_data is given, so
# torchinfo does not run a forward pass and reports no output shapes.
summary(model=model)

Layer (type:depth-idx)                                       Param #
BarlowTwins                                                  --
├─BackboneExpander: 1-1                                      --
│    └─FeatureListNet: 2-1                                   --
│    │    └─Conv2d: 3-1                                      37,632
│    │    └─BatchNorm2d: 3-2                                 128
│    │    └─ReLU: 3-3                                        --
│    │    └─MaxPool2d: 3-4                                   --
│    │    └─Sequential: 3-5                                  215,808
│    │    └─Sequential: 3-6                                  1,219,584
│    │    └─Sequential: 3-7                                  7,098,368
│    │    └─Sequential: 3-8                                  14,964,736
│    └─Identity: 2-2                                         --
├─AdaptiveAvgPool2d: 1-2                                     --
├─OnlineLinearClassifier: 1-3                                --
│  

In [8]:
# Full module tree; the bare last expression uses the notebook's rich display.
model

BarlowTwins(
  (backbone): FeatureListNet and BackboneExpander(
    (backbone): FeatureListNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act1): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (drop_block): Identity()
          (act2): ReLU(inplace=True)
          (aa): Identity()
          (conv3): Conv2d(64, 256

In [9]:
import torch.nn as nn

def get_layer_from_name(model, layer_name):
    """Resolve a dotted path such as ``'backbone.layer1.2.act3'`` to a sub-object.

    Each dot-separated part is applied to the object reached so far:
    - a purely numeric part (e.g. ``'2'``) indexes into it (``obj[2]``, as for ``nn.Sequential``);
    - any other part is an attribute lookup.

    Args:
        model: root object to start from (typically an ``nn.Module``).
        layer_name: dotted path; an empty string returns ``model`` itself.

    Returns:
        The object at the end of the path.

    Raises:
        AttributeError, IndexError, or KeyError if a part of the path does not exist.
    """
    layer = model
    if not layer_name:
        return layer
    for name in layer_name.split('.'):
        if name.isdigit():
            layer = layer[int(name)]
        else:
            # getattr (not a direct __getattr__ call) also finds ordinary
            # attributes and works on objects that define no __getattr__.
            layer = getattr(layer, name)
    return layer
    
class Model2Embeddings:
    """Capture intermediate activations of a model through forward hooks.

    Construction:
    - Registers one hook per dotted layer path in ``layers_to_save``, resolved
      with ``get_layer_from_name``.
    - The hooks stay attached to the model until ``remove_hooks()`` is called.

    Each ``forward_pass`` call runs the model once and returns the recorded
    activations keyed by layer path.

    Args:
        model: the ``nn.Module`` to inspect.
        layers_to_save: dotted paths of the sub-modules whose outputs to keep,
            e.g. ``'backbone.backbone.layer1.2.act3'``.
        flatten_output: if True, each activation is adaptive-average-pooled to
            1x1 spatially before being stored. This presumes the hooked outputs
            are conv feature maps (``(N, C, H, W)``) — TODO confirm for any
            non-conv layer you hook.
    """

    def __init__(self, model, layers_to_save, flatten_output=True):
        self.model = model
        self.layers_to_save = layers_to_save
        self.flatten_output = flatten_output
        if flatten_output:
            self.global_pool = nn.AdaptiveAvgPool2d(1)

        self.d_embeddings = {}
        # Handles returned by register_forward_hook, kept so the hooks can be
        # detached again with remove_hooks().
        self.hook_handles = []
        for layer_to_save in layers_to_save:
            self.register_forward_hook(model, layer_to_save)
        print('Forward hooks registered')

    def get_activation(self, name):
        """Build a forward hook that stores a module's output under ``name``.

        Processing of the module's output:
        - It is detached and copied to CPU.
        - With ``flatten_output`` it is pooled to 1x1 spatially.
        - ``squeeze()`` then drops every size-1 dimension, so for a batch of
          one the stored array has no batch axis (e.g. ``(256,)``).
        - The result is converted to a numpy array.
        """
        def hook(module, inputs, output):
            activation = output.detach().cpu()
            if self.flatten_output:
                activation = self.global_pool(activation)
            self.d_embeddings[name] = activation.squeeze().numpy()
        return hook

    def register_forward_hook(self, model, layer_name):
        """Attach an activation-recording hook to the sub-module at ``layer_name``.

        Returns the PyTorch ``RemovableHandle``, which is also kept in
        ``self.hook_handles``.
        """
        layer = get_layer_from_name(model, layer_name)
        handle = layer.register_forward_hook(self.get_activation(layer_name))
        self.hook_handles.append(handle)
        return handle

    def forward_pass(self, data):
        """Run ``data`` through the model and return the captured activations.

        How the model is run:
        - Eval mode: layers such as BatchNorm use their stored running
          statistics instead of updating them from ``data``.
        - ``torch.no_grad()``: no autograd graph is built (the hooks detach
          their outputs anyway).
        - The model's previous train/eval mode is restored afterwards, even if
          the forward pass raises.

        Returns:
            A new dict for this pass, mapping each hooked layer path to its
            numpy activation. Results from earlier calls are not overwritten.
        """
        # Fresh dict per pass: the hooks look up self.d_embeddings at call time,
        # so stale entries from a previous pass cannot leak into this result.
        self.d_embeddings = {}
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                self.model(data)
        finally:
            self.model.train(was_training)
        return self.d_embeddings

    def remove_hooks(self):
        """Detach every hook this instance registered on the model."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles.clear()
        


layers_to_save = [
    'backbone.backbone.layer1.2.act3',
    'backbone.backbone.layer2.3.act3',
    'backbone.backbone.layer3.5.act3',
    'backbone.backbone.layer4.0.act3',
    'backbone.backbone.layer4.1.act3',
    'backbone.backbone.layer4.2.act3',
]
model2emb = Model2Embeddings(model, layers_to_save, flatten_output=True)

# Smoke test on a dummy batch: one sample, 12 input channels, 224x224 pixels.
# The channel count must match the checkpoint's `in_channels` hyper-parameter.
# A seeded generator keeps the dummy input (and so the embeddings) reproducible.
DUMMY_SHAPE = (1, 12, 224, 224)
dummy_input = torch.randn(*DUMMY_SHAPE, generator=torch.Generator().manual_seed(0))
d_embeddings = model2emb.forward_pass(dummy_input)
for name, output in d_embeddings.items():
    print(name, output.shape)

Forward hooks registered
backbone.backbone.layer1.2.act3 (256,)
backbone.backbone.layer2.3.act3 (512,)
backbone.backbone.layer3.5.act3 (1024,)
backbone.backbone.layer4.0.act3 (2048,)
backbone.backbone.layer4.1.act3 (2048,)
backbone.backbone.layer4.2.act3 (2048,)
