In [2]:
import os
from pathlib import Path
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# The project-local imports below (`main`, `data.*`) are resolved from the
# current working directory, so step up one level when the kernel was started
# outside the repo root (presumably from a sub-folder of SSL4EO_base — confirm).
if Path.cwd().name != "SSL4EO_base":
    os.chdir("..")

from main import METHODS
from data import constants
from data.constants import MMEARTH_DIR, input_size
from data.mmearth_dataset import MMEarthDataset, create_MMEearth_args, get_mmearth_dataloaders

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
def load_model(method, weights_dir="/work/data/weights", ckpt_name="50epochs.ckpt"):
    """Build the ``METHODS[method]`` model and load its pretrained checkpoint.

    The model is re-instantiated from the hyper-parameters stored in the
    checkpoint, and then the checkpoint's ``state_dict`` is loaded into it.

    Args:
        method: Key into ``METHODS``. It is also the sub-directory of
            ``weights_dir`` that holds the checkpoint.
        weights_dir: Root directory of the saved weights. Replace it with
            your own path to the data.
        ckpt_name: Checkpoint file name inside ``weights_dir/method``.

    Returns:
        The model with the checkpoint weights loaded. Tensors are mapped to CPU.
    """
    model_path = Path(weights_dir) / method / ckpt_name
    # Map to CPU: no GPU is required for running small samples.
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    model_ckpt = torch.load(model_path, map_location="cpu")

    # Initialize the model from the checkpoint's hyper-parameters.
    hparams = model_ckpt["hyper_parameters"]
    print(hparams)
    model = METHODS[method]["model"](
        backbone=hparams["backbone"],
        batch_size_per_device=hparams["batch_size_per_device"],
        in_channels=hparams["in_channels"],
        num_classes=hparams["num_classes"],
        has_online_classifier=hparams["has_online_classifier"],
        last_backbone_channel=hparams["last_backbone_channel"],
        train_transform=METHODS[method]["transform"],
    )
    model.load_state_dict(model_ckpt["state_dict"])
    return model

def init_mmearth(split="train"):
    """Create an ``MMEarthDataset`` over the input modalities plus the biome modality.

    Args:
        split: Dataset split to load. Defaults to ``"train"``, as before.

    Returns:
        The ``MMEarthDataset`` instance for ``split``.
    """
    # The third argument adds the "biome" modality so each sample exposes a
    # "biome" entry (read by retrieve_representations). Presumably it is
    # treated as a target/label modality — confirm against create_MMEearth_args.
    args = create_MMEearth_args(
        MMEARTH_DIR,
        constants.INP_MODALITIES,
        {"biome": constants.MODALITIES_FULL["biome"]},
    )
    return MMEarthDataset(args, split=split)

def retrieve_representations(model=None, dataset=None, num_samples=1000,
                             batch_size=None, return_labels=False):
    """Embed the first ``num_samples`` Sentinel-2 samples of ``dataset`` with ``model``.

    Args:
        model: Network to embed with. Defaults to the notebook-level ``model``
            global, which keeps the original zero-argument call working.
        dataset: Indexable dataset whose items have ``"sentinel2"`` (numpy
            array) and ``"biome"`` entries. Defaults to the notebook-level
            ``dataset`` global.
        num_samples: Number of leading samples to embed. Must be >= 1.
        batch_size: If given, run the forward pass in chunks of this size to
            bound memory. ``None`` feeds all samples in a single batch.
        return_labels: If True, also return the stacked biome labels.

    Returns:
        ``z``, a 2-D tensor of shape ``(num_samples, features)``, or
        ``(z, biome)`` when ``return_labels`` is True.

    Raises:
        ValueError: If ``num_samples`` < 1.

    Side effect: leaves ``model`` in eval mode.
    """
    # Fall back to the notebook globals so existing `retrieve_representations()` calls keep working.
    if model is None:
        model = globals()["model"]
    if dataset is None:
        dataset = globals()["dataset"]
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    s2 = []
    biome = []
    for idx in range(num_samples):
        sample = dataset[idx]
        s2.append(torch.from_numpy(sample["sentinel2"]))
        biome.append(sample["biome"])

    # Stack to shape (b, c, h, w). This assumes each sentinel2 array is
    # (c, h, w) with a dtype the model accepts — TODO confirm.
    s2 = torch.stack(s2, 0)
    # Stack the per-sample labels along a new leading axis of length b.
    biome = np.stack(biome, 0)

    chunk = batch_size or len(s2)
    model.eval()
    with torch.no_grad():
        z = torch.cat(
            [model(batch).flatten(start_dim=1) for batch in torch.split(s2, chunk)],
            dim=0,
        )

    if return_labels:
        return z, biome
    return z
    
    

In [12]:
# Pretrained Barlow Twins model and the MMEarth dataset; retrieve_representations
# reads both of these names as notebook globals.
model = load_model('barlowtwins')
dataset = init_mmearth()



{'backbone': 'default', 'batch_size_per_device': 512, 'in_channels': 12, 'num_classes': 14, 'has_online_classifier': True, 'last_backbone_channel': None, 'method': 'BarlowTwins'}
Using default backbone: resnet50
