In [1]:
import os
from pathlib import Path
import sys
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch.utils.data as tdata
from tqdm.auto import tqdm

# Project packages and files are loaded from the repo root: step one directory up
# unless the current working directory is already the `SSL4EO_base` folder.
if Path.cwd().name != "SSL4EO_base":
    os.chdir("..")

from main import METHODS
from data import constants, GeobenchDataset
from data.constants import MMEARTH_DIR, input_size
from data.mmearth_dataset import MMEarthDataset, create_MMEearth_args, get_mmearth_dataloaders

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_pretrained_model(model_path, method, device):
    """Rebuild a pre-trained model from a checkpoint file.

    The checkpoint must contain a ``"hyper_parameters"`` mapping (used to
    construct the model) and a ``"state_dict"`` entry (the trained weights).

    Args:
        model_path: Location of the checkpoint, handed to ``torch.load``.
        method: Key into ``METHODS``; selects the model class and the transform
            passed to it as ``train_transform``.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The constructed model with the checkpoint weights loaded.
    """
    # map_location allows reading the checkpoint without a GPU (small sample runs).
    # NOTE(review): torch.load is pickle-based -- only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)

    # Initialize the model from the hyper-parameters stored in the checkpoint.
    saved_hparams = checkpoint["hyper_parameters"]
    model_cls = METHODS[method]["model"]
    hparam_keys = (
        "backbone",
        "batch_size_per_device",
        "in_channels",
        "num_classes",
        "has_online_classifier",
        "last_backbone_channel",
    )
    init_kwargs = {key: saved_hparams[key] for key in hparam_keys}
    init_kwargs["train_transform"] = METHODS[method]["transform"]
    net = model_cls(**init_kwargs)

    # Copy the trained weights into the freshly built model.
    net.load_state_dict(checkpoint["state_dict"])
    return net



def get_layer_from_name(model, layer_name):
    """Resolve a dotted path such as ``'backbone.layer1.2.act3'`` below ``model``.

    The path is split on ``'.'`` and followed component by component, starting
    at ``model``: a purely numeric component indexes into the current object
    (``obj[int(component)]``), any other component is read as an attribute.

    Args:
        model: Root object the path starts from.
        layer_name: Dot-separated attribute names and integer indices.

    Returns:
        The object found at the end of the path.

    Raises:
        AttributeError: If a named component does not exist on its parent.
        IndexError, TypeError: If a numeric component cannot be used to index
            its parent.
    """
    layer = model
    for name in layer_name.split('.'):
        if name.isdigit():
            layer = layer[int(name)]
        else:
            # getattr() uses the regular attribute lookup, so this also works
            # for objects that do not define __getattr__ themselves.
            layer = getattr(layer, name)
    return layer
    
class Model2Embeddings:
    """Capture the outputs of selected layers of a model via forward hooks.

    One forward hook is registered per entry of ``layers_to_save``. Each hook
    stores the layer output (detached, on CPU, as a numpy array) in
    ``self.d_embeddings`` under the layer name. The first dimension of every
    hooked output is treated as the batch dimension and is always kept, also
    for batches that contain a single sample.

    Args:
        model: Model to run; layers are looked up with ``get_layer_from_name``.
        layers_to_save: Iterable of dotted layer names to hook.
        flatten_output: If True, every hooked output is passed through
            ``nn.AdaptiveAvgPool2d(1)`` before being stored, so the outputs must
            be accepted by that pooling layer. If False, the output is stored
            without pooling.

    Note:
        The class does not change the train/eval mode of ``model``; callers
        that want deterministic embeddings should call ``model.eval()`` first.
    """

    def __init__(self, model, layers_to_save, flatten_output=True):

        self.model = model
        self.layers_to_save = layers_to_save
        self.flatten_output = flatten_output
        if flatten_output:
            self.global_pool = nn.AdaptiveAvgPool2d(1)

        # layer name -> numpy array captured during the most recent forward pass
        self.d_embeddings = {}
        # handles returned by register_forward_hook, kept so hooks can be removed
        self._hook_handles = []
        for layer_to_save in layers_to_save:
            self.register_forward_hook(model, layer_to_save)
        print('Forward hooks registered')

    @staticmethod
    def _squeeze_keep_batch(activation):
        """Drop every size-1 dimension of ``activation`` except the first one."""
        if activation.dim() < 2:
            return activation
        kept = [size for size in activation.shape[1:] if size != 1]
        return activation.reshape(activation.shape[0], *kept)

    def get_activation(self, name):
        """Return a forward hook that stores the layer output under ``name``."""
        def hook(module, inputs, output):
            activation = output.detach().cpu()
            if self.flatten_output:
                activation = self.global_pool(activation)
            # A plain squeeze() would also remove the batch axis when the batch
            # holds one sample, so that per-sample indexing hits the features.
            self.d_embeddings[name] = self._squeeze_keep_batch(activation).numpy()
        return hook

    def register_forward_hook(self, model, layer_name):
        """Hook the layer called ``layer_name`` of ``model`` and remember the handle."""
        layer = get_layer_from_name(model, layer_name)
        handle = layer.register_forward_hook(self.get_activation(layer_name))
        self._hook_handles.append(handle)

    def remove_hooks(self):
        """Detach all hooks registered through ``register_forward_hook``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def forward_pass(self, data):
        """Run ``model(data)`` and return the activations captured by the hooks.

        Args:
            data: Input forwarded unchanged to the model.

        Returns:
            A dict mapping layer name to a numpy array whose first axis is the
            batch axis. A new dict is created for every call, so results of
            earlier calls are not overwritten.
        """
        self.d_embeddings = {}
        # The hooks detach their outputs anyway, so no autograd graph is needed.
        with torch.no_grad():
            self.model(data)
        return self.d_embeddings

In [3]:
def create_embeddings_geobench(model_path, method, dataset_name, layers_to_save, base_output_dir=None):
    """Compute layer embeddings for a GeoBench dataset and save them as .npy files.

    For the ``'test'`` and ``'train'`` splits every sample is passed through the
    pre-trained model; the globally pooled output of each hooked layer and the
    sample label are written to::

        <base_output_dir>/<method>/<dataset_name>/<split>/embeddings/<layer_name>/<idx>.npy
        <base_output_dir>/<method>/<dataset_name>/<split>/labels/<idx>.npy

    where ``idx`` is the running position of the sample in DataLoader order.

    Args:
        model_path: Checkpoint path passed to ``get_pretrained_model``.
        method: Method name, used to build the model and as a directory level.
        dataset_name: GeoBench dataset name, also used as a directory level.
        layers_to_save: Dotted names of the layers whose outputs are saved.
        base_output_dir: Root directory for the output files (``str`` or
            ``Path``). Required, despite the ``None`` default.

    Returns:
        None. All results are written to disk.

    Raises:
        ValueError: If ``base_output_dir`` is None.
    """
    if base_output_dir is None:
        raise ValueError('base_output_dir must be provided: it is the root directory the embeddings are written to.')
    base_output_dir = Path(base_output_dir)

    print(f'Creating embeddings for {len(layers_to_save)} layers with {method} and {dataset_name}.')

    # Load pre-trained model and actually place it on the selected device
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = get_pretrained_model(model_path, method, device)
    model = model.to(device)
    # Inference mode for layers such as BatchNorm/Dropout: embeddings must not
    # depend on the batch composition and running statistics must not change.
    model.eval()

    # Create class to create the embeddings for each layer
    model2emb = Model2Embeddings(model, layers_to_save, flatten_output=True)

    batch_size = 10
    for split in ['test', 'train']:

        output_dir = base_output_dir / method / dataset_name / split

        # Load dataset from GeoBench
        dataset = GeobenchDataset(dataset_name=dataset_name, split=split, transform=None)

        # Dataloader
        dl = tdata.DataLoader(dataset, batch_size=batch_size)

        labels_dir = output_dir / 'labels'
        labels_dir.mkdir(parents=True, exist_ok=True)

        # loop over dataset and save embeddings with labels
        idx = 0  # running sample index; stays correct when the last batch is short
        for batch in tqdm(dl, total=len(dl)):
            # assumes batch[0] is an input tensor and batch[1] a label tensor -- TODO confirm
            data = batch[0].to(device)
            labels = batch[1]
            n_samples = len(labels)

            # Inference (no autograd graph needed, outputs are only saved)
            with torch.no_grad():
                d_embeddings = model2emb.forward_pass(data)

            # Bring every layer to (n_samples, features) so that indexing by
            # sample is valid even for a batch that holds a single sample.
            per_layer = {}
            for layer_name, embeddings in d_embeddings.items():
                layer_dir = output_dir / 'embeddings' / layer_name
                layer_dir.mkdir(parents=True, exist_ok=True)
                per_layer[layer_dir] = np.reshape(embeddings, (n_samples, -1))

            for j in range(n_samples):
                for layer_dir, embeddings in per_layer.items():
                    np.save(layer_dir / f'{idx}.npy', embeddings[j])

                np.save(labels_dir / f'{idx}.npy', labels[j].detach().cpu().numpy())
                idx += 1

        print(f'All embeddings saved for {split}')

In [4]:
# Run configuration: checkpoint to embed, dataset to embed it on, output root.
method = "barlowtwins"
dataset_name = "m-bigearthnet"
model_path = "/work/data/weights/barlowtwins/100epochs.ckpt"
base_output_dir = Path('/work/groupdrive/embeddings')

# Layers whose outputs are stored, addressed by their dotted path in the model.
layers_to_save = [
    'backbone.backbone.layer1.2.act3',
    'backbone.backbone.layer2.3.act3',
    'backbone.backbone.layer3.5.act3',
    'backbone.backbone.layer4.0.act3',
    'backbone.backbone.layer4.1.act3',
    'backbone.backbone.layer4.2.act3',
]

# NOTE(review): create_embeddings_geobench has no return statement, so
# `d_embeddings` is None afterwards; the results are the .npy files on disk.
d_embeddings = create_embeddings_geobench(
    model_path=model_path,
    method=method,
    dataset_name=dataset_name,
    layers_to_save=layers_to_save,
    base_output_dir=base_output_dir,
)

Creating embeddings for 6 layers with barlowtwins and m-bigearthnet.
Using default backbone: resnet50
Forward hooks registered


100%|██████████| 100/100 [02:05<00:00,  1.26s/it]


All embeddings saved for test


100%|██████████| 2000/2000 [40:52<00:00,  1.23s/it]

All embeddings saved for train



