In [21]:
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.model_selection import train_test_split

# The project imports below (main, data.*) expect the working directory to be
# the SSL4EO_base repository root. Jump to the nearest ancestor with that name
# instead of blindly stepping up one level: that keeps this cell idempotent, so
# re-running it does not keep climbing the directory tree.
if Path(os.getcwd()).name != "SSL4EO_base":
    _repo_root = next(
        (parent for parent in Path(os.getcwd()).parents if parent.name == "SSL4EO_base"),
        Path(".."),  # fallback: the previous behaviour (one level up)
    )
    os.chdir(_repo_root)

from main import METHODS
from data import constants
from data.constants import MMEARTH_DIR, input_size
from data.mmearth_dataset import MMEarthDataset, create_MMEearth_args, get_mmearth_dataloaders

In [39]:
def load_model(method, verbose=False, ckpt_path=None):
    """Rebuild a pretrained model for ``method`` and load its checkpoint weights.

    Args:
        method: Key into ``METHODS``. It selects the model class
            (``METHODS[method]["model"]``), its training transform
            (``METHODS[method]["transform"]``), and the default checkpoint folder.
        verbose: If True, print the checkpoint's stored hyper-parameters.
        ckpt_path: Optional explicit checkpoint path. When None, uses
            ``/work/data/weights/{method}/50epochs.ckpt`` (the previous
            hard-coded location).

    Returns:
        The model with the checkpoint's ``state_dict`` loaded. The model is not
        switched to eval mode here.
    """
    if ckpt_path is None:
        ckpt_path = f"/work/data/weights/{method}/50epochs.ckpt"  # replace with your own path to data
    # map_location="cpu": no GPU required for running small samples.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    hparams = checkpoint["hyper_parameters"]
    if verbose:
        print(hparams)

    # Initialize the model from the checkpoint's own hyper-parameters so the
    # architecture matches the stored weights.
    model = METHODS[method]["model"](
        backbone=hparams["backbone"],
        batch_size_per_device=hparams["batch_size_per_device"],
        in_channels=hparams["in_channels"],
        num_classes=hparams["num_classes"],
        has_online_classifier=hparams["has_online_classifier"],
        last_backbone_channel=hparams["last_backbone_channel"],
        train_transform=METHODS[method]["transform"],
    )
    model.load_state_dict(checkpoint["state_dict"])
    return model

def init_mmearth(task, split="train"):
    """Build an ``MMEarthDataset`` over the input modalities plus one target modality.

    Args:
        task: Target modality name. It must be a key of
            ``constants.MODALITIES_FULL`` (e.g. ``"biome"``), and all of its
            channels are requested.
        split: Dataset split to load; defaults to ``"train"`` as before.

    Returns:
        The constructed ``MMEarthDataset``. Presumably each item is a dict
        keyed by modality name (``"sentinel2"``, ``task``, ...) — verify
        against ``MMEarthDataset``.
    """
    # Request exactly one output modality (`task`) with every one of its channels.
    target_modalities = {task: constants.MODALITIES_FULL[task]}
    args = create_MMEearth_args(MMEARTH_DIR, constants.INP_MODALITIES, target_modalities)
    return MMEarthDataset(args, split=split)

def retrieve_representations(num_samples, task, model=None, dataset=None, batch_size=256):
    """Embed the first ``num_samples`` dataset items with a frozen encoder.

    Args:
        num_samples: Number of leading dataset items to embed (must be >= 1).
        task: Key of the label field to collect from each item (e.g. ``"biome"``).
        model: Encoder to run. Defaults to the notebook-level ``model``
            global, which keeps the original two-argument call working.
        dataset: Indexable dataset whose items are dicts holding a
            ``"sentinel2"`` numpy array and a ``task`` label. Defaults to the
            notebook-level ``dataset`` global.
        batch_size: Samples per forward pass. This bounds peak memory instead
            of pushing all ``num_samples`` through the model at once.

    Returns:
        ``(embeddings, labels)``:
        - ``embeddings``: a ``torch.Tensor`` of shape ``(num_samples, n_features)``,
          i.e. the model output flattened after the batch dimension.
        - ``labels``: a numpy array stacking the per-item labels along a new
          leading axis.

    Raises:
        ValueError: If ``num_samples`` or ``batch_size`` is < 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Explicit arguments avoid hidden notebook state; fall back to the
    # globals for backward compatibility with existing calls.
    if model is None:
        model = globals()["model"]
    if dataset is None:
        dataset = globals()["dataset"]

    s2 = []
    labels = []
    for idx in range(num_samples):
        data = dataset[idx]
        s2.append(torch.from_numpy(data["sentinel2"]))
        labels.append(data[task])

    # Stack along a new batch axis: (b, *sample_shape); presumably (b, c, h, w).
    s2 = torch.stack(s2, 0)
    # Stack labels along a new leading axis: shape (b, *label_shape).
    labels = np.stack(labels, 0)

    # Batched inference assumes no cross-sample interaction in eval mode
    # (e.g. BatchNorm uses running stats). Results may differ from a single
    # full-batch pass only by floating-point rounding.
    model.eval()
    chunks = []
    with torch.no_grad():
        for start in range(0, num_samples, batch_size):
            chunks.append(model(s2[start:start + batch_size]).flatten(start_dim=1))
    embeddings = torch.cat(chunks, 0)

    return embeddings, labels
    

def linear_probe(embeddings, labels, test_size=0.3, random_state=1, average="weighted"):
    """Fit a logistic-regression probe on frozen embeddings and return the held-out F1.

    Args:
        embeddings: ``(n_samples, n_features)`` features. Either a
            ``torch.Tensor`` or anything ``np.asarray`` accepts.
        labels: ``n_samples`` class labels. A trailing singleton axis
            (shape ``(n_samples, 1)``) is flattened.
        test_size: Fraction of samples held out for evaluation (default 0.3).
        random_state: Seed for the train/test split, so scores are reproducible.
        average: ``f1_score`` averaging mode (default ``"weighted"``).

    Returns:
        The F1 score of the probe on the held-out split.
    """
    # Convert once to dense numpy arrays rather than lists of per-row tensors.
    if isinstance(embeddings, torch.Tensor):
        X = embeddings.detach().cpu().numpy()
    else:
        X = np.asarray(embeddings)
    y = np.asarray(labels)
    if y.ndim == 2 and y.shape[1] == 1:
        # sklearn expects a 1-D target; (n, 1) would otherwise be raveled with a warning.
        y = y.ravel()

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    # Named `clf` so it does not shadow the notebook-level encoder `model`.
    clf = LogisticRegression()
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_test)

    return f1_score(y_test, y_pred, average=average)

In [40]:
# Linear-probe evaluation: embed MMEarth samples with a pretrained encoder,
# then score a logistic-regression probe on those embeddings (weighted F1).
task = "biome"
method = "barlowtwins"

# `model` and `dataset` must keep these exact names: retrieve_representations
# reads them from the notebook namespace.
model = load_model(method)
dataset = init_mmearth(task)

embeddings, labels = retrieve_representations(num_samples=1000, task=task)
linear_probe(embeddings=embeddings, labels=labels)


{'backbone': 'default', 'batch_size_per_device': 512, 'in_channels': 12, 'num_classes': 14, 'has_online_classifier': True, 'last_backbone_channel': None, 'method': 'BarlowTwins'}
Using default backbone: resnet50


0.3632580373518014

In [None]:
# Rebuild the biome dataset through the shared helper instead of duplicating
# init_mmearth's body inline (same modalities, target and "train" split).
dataset = init_mmearth("biome")

In [41]:
# Display the channel names available for every modality. The keys are the
# valid `task` values for init_mmearth, which indexes MODALITIES_FULL[task].
constants.MODALITIES_FULL

{'sentinel2_cloudmask': ['QA60'],
 'sentinel2_cloudprob': ['MSK_CLDPRB'],
 'sentinel2_scl': ['SCL'],
 'sentinel2': ['B1',
  'B2',
  'B3',
  'B4',
  'B5',
  'B6',
  'B7',
  'B8A',
  'B8',
  'B9',
  'B10',
  'B11',
  'B12'],
 'sentinel1': ['asc_VV',
  'asc_VH',
  'asc_HH',
  'asc_HV',
  'desc_VV',
  'desc_VH',
  'desc_HH',
  'desc_HV'],
 'aster': ['elevation', 'slope'],
 'canopy_height_eth': ['height', 'std'],
 'lat': ['sin', 'cos'],
 'lon': ['sin', 'cos'],
 'month': ['sin_month', 'cos_month'],
 'era5': ['prev_month_avg_temp',
  'prev_month_min_temp',
  'prev_month_max_temp',
  'prev_month_total_precip',
  'curr_month_avg_temp',
  'curr_month_min_temp',
  'curr_month_max_temp',
  'curr_month_total_precip',
  'year_avg_temp',
  'year_min_temp',
  'year_max_temp',
  'year_total_precip'],
 'esa_worldcover': ['map'],
 'dynamic_world': ['landcover'],
 'biome': ['biome'],
 'eco_region': ['eco_region']}