### First attempt to continue Presto pre-training on WorldCereal data

Most of the code is taken from https://github.com/nasaharvest/presto/blob/main/train.py; only the core parts are retained, so we can test whether Presto can ingest the WorldCereal data.

In [3]:
import json
import logging
import os
import warnings
from pathlib import Path
from typing import List, Tuple, cast
import pandas as pd

import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm

from presto import Presto
from presto.dataops import BANDS_GROUPS_IDX, MASK_STRATEGIES, MaskParams


from presto.model import LossWrapper, adjust_learning_rate, param_groups_weight_decay
from presto.utils import (
    DEFAULT_SEED,
    config_dir,
    device,
    initialize_logging,
    seed_everything,
    timestamp_dirname,
    update_data_dir,
)

# Logger used by every cell below (setup messages, metrics, checkpoints).
logger = logging.getLogger(name="__main__")

In [15]:
model_name = 'presto_worldcereal'
seed = DEFAULT_SEED
seed_everything(seed)
output_parent_dir = Path('.')
run_id = None

# Timestamped output directory holding the console log and model checkpoints.
logging_dir = output_parent_dir / "output" / timestamp_dirname(run_id)
logging_dir.mkdir(exist_ok=True, parents=True)
# NOTE(review): the captured output shows every log line twice, which suggests
# a second handler is attached (e.g. this cell was re-run or records propagate
# to an already-configured root logger) -- confirm in initialize_logging.
initialize_logging(logging_dir)
# Lazy %-args: the message is only formatted if the record is emitted.
logger.info("Using output dir: %s", logging_dir)

# Using the default training hyperparameters for now
num_epochs = 20
val_per_n_steps = 1000
dynamic_world_loss_weight = 2  # Set to 0 if we don't have DW?
max_learning_rate = 0.0001  # 0.001 is default, for finetuning max should be lower?
min_learning_rate = 0
warmup_epochs = 2
weight_decay = 0.05
batch_size = 1  # default 4096

# Default mask strategies and mask_ratio
mask_strategies = MASK_STRATEGIES
mask_ratio: float = 0.75

path_to_config = config_dir / "default.json"
# Use a context manager so the config file handle is closed after reading.
with Path(path_to_config).open("r") as config_file:
    model_kwargs = json.load(config_file)

25-10-2023 18:42:15 - INFO - Initialized logging to output/2023_10_25_18_42_15_093321/console-output.log
25-10-2023 18:42:15 - INFO - Initialized logging to output/2023_10_25_18_42_15_093321/console-output.log
25-10-2023 18:42:15 - INFO - Using output dir: output/2023_10_25_18_42_15_093321
25-10-2023 18:42:15 - INFO - Using output dir: output/2023_10_25_18_42_15_093321


### Load a (very) small test dataframe from WorldCereal

It has been reprocessed to better match Presto's requirements, most notably:
- monthly instead of 10-day compositing
- outputting as 'features' the composited time series of all required bands, instead of expert features

In [4]:
# Small WorldCereal extract, already reprocessed to suit Presto.
parquet_path = 'worldcereal_testdf.parquet'
df = pd.read_parquet(parquet_path)
df.head()

Unnamed: 0,OPTICAL-B02-ts0-10m,OPTICAL-B02-ts1-10m,OPTICAL-B02-ts2-10m,OPTICAL-B02-ts3-10m,OPTICAL-B02-ts4-10m,OPTICAL-B02-ts5-10m,OPTICAL-B02-ts6-10m,OPTICAL-B02-ts7-10m,OPTICAL-B02-ts8-10m,OPTICAL-B02-ts9-10m,...,lat,lon,CT,OUTPUT,IRR,location_id,ref_id,start_date,end_date,aez_zoneid
0,468,0,581,534,670,381,329,458,653,1261,...,50.897915,2.668585,1110,11,0,0000280664EDE418,2018_BE_LPIS-Flanders,2017-10-27,2018-10-26,46172
1,741,0,267,354,567,275,249,425,712,958,...,50.897903,2.664141,1510,11,0,0000280664EDE418,2018_BE_LPIS-Flanders,2017-10-27,2018-10-26,46172
2,359,5651,339,436,609,289,389,509,663,1100,...,50.893848,2.665414,1110,11,0,0000280664EDE418,2018_BE_LPIS-Flanders,2017-10-27,2018-10-26,46172
3,510,7468,518,516,781,1158,1008,1236,540,456,...,50.892136,2.660255,9520,12,0,0000280664EDE418,2018_BE_LPIS-Flanders,2017-10-27,2018-10-26,46172
4,358,6283,389,453,719,1141,1141,758,587,484,...,50.892136,2.660255,9520,12,0,0000280664EDE418,2018_BE_LPIS-Flanders,2017-10-27,2018-10-26,46172


### Making WorldCereal data compatible with Presto

Uses a `WorldCerealDataset` PyTorch dataset that performs the required conversions to Presto inputs whenever an item is requested.
If `mask_params` is provided, the inputs take the shape that Presto pre-training normally requires, including a generated `mask`.

In [16]:
from ewoc_presto import WorldCerealDataset
from torch.utils.data import DataLoader

logger.info("Setting up dataloaders")

# Masking configuration (strategies + ratio) used to build pre-training inputs.
mask_params = MaskParams(mask_strategies, mask_ratio)

# Dataset that converts WorldCereal rows into Presto inputs when indexed.
ds = WorldCerealDataset(df, mask_params=mask_params)

# NOTE: training and validation intentionally share the same data for now;
# this notebook only checks that the pipeline runs end to end.
loader_kwargs = {"batch_size": batch_size}
train_dataloader = DataLoader(ds, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(ds, shuffle=False, **loader_kwargs)


25-10-2023 18:42:18 - INFO - Setting up dataloaders
25-10-2023 18:42:18 - INFO - Setting up dataloaders


Check what an item from this dataset looks like:

In [8]:
# Inspect the first (masked) example; indexing calls Dataset.__getitem__.
ds[0]

MaskedExample(mask_eo=array([[ True,  True,  True,  True,  True, False, False, False,  True,
         True, False, False,  True,  True,  True,  True, False],
       [ True,  True,  True,  True,  True, False, False, False,  True,
         True, False, False,  True,  True,  True,  True, False],
       [ True,  True,  True,  True,  True, False, False, False,  True,
         True, False, False,  True,  True,  True,  True, False],
       [ True,  True,  True,  True,  True, False, False, False,  True,
         True, False, False,  True,  True,  True,  True, False],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True, False],
       [ True,  True,  True,  True,  True, False, False, False,  True,
         True,  True,  True,  True,  True,  True,  True, False],
       [ True,  True,  True,  True,  True, False, False, False,  True,
         True, False, False,  True,  True,  True,  True, False],
       [ True,  True,  Tru

### Set up the Presto model and load pretrained weights

In [17]:
logger.info("Setting up model")
# nn.Module.to moves the parameters and returns the module itself.
model = Presto.load_pretrained().to(device)
model  # echo the architecture as the cell output

25-10-2023 18:42:24 - INFO - Setting up model
25-10-2023 18:42:24 - INFO - Setting up model


Presto(
  (encoder): Encoder(
    (eo_patch_embed): ModuleDict(
      (S1): Linear(in_features=2, out_features=128, bias=True)
      (S2_RGB): Linear(in_features=3, out_features=128, bias=True)
      (S2_Red_Edge): Linear(in_features=3, out_features=128, bias=True)
      (S2_NIR_10m): Linear(in_features=1, out_features=128, bias=True)
      (S2_NIR_20m): Linear(in_features=1, out_features=128, bias=True)
      (S2_SWIR): Linear(in_features=2, out_features=128, bias=True)
      (ERA5): Linear(in_features=2, out_features=128, bias=True)
      (SRTM): Linear(in_features=2, out_features=128, bias=True)
      (NDVI): Linear(in_features=1, out_features=128, bias=True)
    )
    (dw_embed): Embedding(10, 128)
    (latlon_embed): Linear(in_features=3, out_features=128, bias=True)
    (blocks): ModuleList(
      (0-1): 2 x Block(
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=128, out_features=384, bias=True)

In [18]:
# Optimizer and losses; hyperparameters are kept unchanged for now.
param_groups = param_groups_weight_decay(model, weight_decay)
optimizer = optim.AdamW(param_groups, lr=max_learning_rate, betas=(0.9, 0.95))
mse = LossWrapper(nn.MSELoss())
ce = LossWrapper(nn.CrossEntropyLoss())

# Run metadata: component classes / names plus where outputs are written.
training_config = dict(
    model=type(model),
    encoder=type(model.encoder),
    decoder=type(model.decoder),
    optimizer=type(optimizer).__name__,
    eo_loss=type(mse.loss).__name__,
    dynamic_world_loss=type(ce.loss).__name__,
    device=device,
    logging_dir=logging_dir,
)

### Training loop

The relevant parts are copied from the original Presto training code.

In [19]:
# Running state shared across epochs.
lowest_validation_loss = None
best_val_epoch = 0
best_model_path = None  # stays None if no validation ever runs
training_step = 0
num_validations = 0
# Number of *batches* per epoch, used to express progress as fractional epochs
# for the learning-rate schedule (df.shape[0] only matches when batch_size == 1).
dataloader_length = len(train_dataloader)


def _forward_and_losses(batch):
    """Run Presto on one dataloader batch and compute its losses.

    The batch is indexed positionally: (0) EO mask, (1) Dynamic World mask,
    (2) EO inputs, (3) EO targets, (4) DW inputs, (5) DW targets,
    (6) start month, (7) lat/lons.

    Returns:
        (total_loss, eo_loss, dw_loss, num_eo_masked, num_dw_masked, batch_size)
    """
    mask = batch[0].to(device)
    dw_mask = batch[1].to(device)
    x, y = batch[2].to(device), batch[3].to(device)
    x_dw, y_dw = batch[4].to(device).long(), batch[5].to(device).long()
    start_month = batch[6]
    latlons = batch[7].to(device)

    y_pred, dw_pred = model(
        x, mask=mask, dynamic_world=x_dw, latlons=latlons, month=start_month
    )
    # set all SRTM timesteps except the first one to unmasked, so that
    # they will get ignored by the loss function even if the SRTM
    # value was masked (done after the forward pass, as before)
    mask[:, 1:, BANDS_GROUPS_IDX["SRTM"]] = False
    eo_loss = mse(y_pred[mask], y[mask])
    dw_loss = ce(dw_pred[dw_mask], y_dw[dw_mask])
    num_eo_masked, num_dw_masked = len(y_pred[mask]), len(dw_pred[dw_mask])

    # Scale the DW loss by the ratio of masked DW to masked EO values; the
    # weight shouldn't be > 1. These are plain Python numbers, so no autograd
    # context is needed here.
    ratio = num_dw_masked / max(num_eo_masked, 1)
    weight = min(1, dynamic_world_loss_weight * ratio)
    total_loss = eo_loss + weight * dw_loss
    return total_loss, eo_loss, dw_loss, num_eo_masked, num_dw_masked, len(x)


def _run_validation():
    """Evaluate on `val_dataloader`; return (val_loss, val_eo_loss, val_dw_loss, val_size)."""
    total_val_loss = 0.0
    total_eo_val_loss = 0.0
    total_dw_val_loss = 0.0
    total_val_num_eo_values_masked = 0
    total_val_num_dw_values_masked = 0
    num_val_updates_captured = 0
    val_size = 0
    model.eval()
    with torch.no_grad():
        for val_batch in tqdm(val_dataloader, desc="Validate"):
            (
                total_loss,
                eo_loss,
                dw_loss,
                num_eo_masked,
                num_dw_masked,
                current_batch_size,
            ) = _forward_and_losses(val_batch)
            val_size += current_batch_size
            total_val_loss += total_loss.item()
            total_eo_val_loss += eo_loss.item() * num_eo_masked
            total_dw_val_loss += dw_loss.item() * num_dw_masked
            total_val_num_eo_values_masked += num_eo_masked
            total_val_num_dw_values_masked += num_dw_masked
            num_val_updates_captured += 1

    val_loss = total_val_loss / num_val_updates_captured
    val_eo_loss = total_eo_val_loss / max(total_val_num_eo_values_masked, 1)
    val_dw_loss = total_dw_val_loss / max(total_val_num_dw_values_masked, 1)
    return val_loss, val_eo_loss, val_dw_loss, val_size


with tqdm(range(num_epochs), desc="Epoch") as tqdm_epoch:
    for epoch in tqdm_epoch:
        # ------------------------ Training ----------------------------------------
        total_train_loss = 0.0
        total_eo_train_loss = 0.0
        total_dw_train_loss = 0.0
        total_num_eo_values_masked = 0
        total_num_dw_values_masked = 0
        num_updates_being_captured = 0
        train_size = 0
        model.train()
        for epoch_step, b in enumerate(tqdm(train_dataloader, desc="Train", leave=False)):
            # zero the parameter gradients
            optimizer.zero_grad()
            lr = adjust_learning_rate(
                optimizer,
                epoch_step / dataloader_length + epoch,
                warmup_epochs,
                num_epochs,
                max_learning_rate,
                min_learning_rate,
            )
            (
                total_loss,
                eo_loss,
                dw_loss,
                num_eo_masked,
                num_dw_masked,
                current_batch_size,
            ) = _forward_and_losses(b)
            total_loss.backward()
            optimizer.step()

            total_train_loss += total_loss.item()
            total_eo_train_loss += eo_loss.item() * num_eo_masked
            total_dw_train_loss += dw_loss.item() * num_dw_masked
            total_num_eo_values_masked += num_eo_masked
            total_num_dw_values_masked += num_dw_masked
            num_updates_being_captured += 1
            train_size += current_batch_size
            training_step += 1

            # ------------------------ Validation --------------------------------------
            if training_step % val_per_n_steps != 0:
                continue
            val_loss, val_eo_loss, val_dw_loss, val_size = _run_validation()

            # ------------------------ Metrics + Logging -------------------------------
            # train_loss now reflects the value against which we calculate gradients
            train_loss = total_train_loss / num_updates_being_captured
            train_eo_loss = total_eo_train_loss / max(total_num_eo_values_masked, 1)
            train_dw_loss = total_dw_train_loss / max(total_num_dw_values_masked, 1)

            if "train_size" not in training_config and "val_size" not in training_config:
                training_config["train_size"] = train_size
                training_config["val_size"] = val_size
                # wandb is not set up in this notebook (the original referenced
                # undefined `wandb_enabled`/`wandb`), so log the config instead.
                logger.info("Training config: %s", training_config)

            to_log = {
                "train_loss": train_loss,
                "val_loss": val_loss,
                "train_eo_loss": train_eo_loss,
                "val_eo_loss": val_eo_loss,
                "train_dynamic_world_loss": train_dw_loss,
                "val_dynamic_world_loss": val_dw_loss,
                "training_step": training_step,
                "epoch": epoch,
                "lr": lr,
            }
            logger.info("Validation metrics: %s", to_log)
            tqdm_epoch.set_postfix(loss=val_loss)

            if lowest_validation_loss is None or val_loss < lowest_validation_loss:
                lowest_validation_loss = val_loss
                best_val_epoch = epoch

                model_path = logging_dir / Path("models")
                model_path.mkdir(exist_ok=True, parents=True)

                best_model_path = model_path / f"{model_name}{epoch}.pt"
                logger.info(f"Saving best model to: {best_model_path}")
                torch.save(model.state_dict(), best_model_path)

            # reset training logging
            total_train_loss = 0.0
            total_eo_train_loss = 0.0
            total_dw_train_loss = 0.0
            total_num_eo_values_masked = 0
            total_num_dw_values_masked = 0
            num_updates_being_captured = 0
            train_size = 0
            num_validations += 1

            model.train()

if best_model_path is None:
    logger.warning("Done training, but no validation was run so no model was saved")
else:
    logger.info(f"Done training, best model saved to {best_model_path}")

Epoch:   0%|          | 0/20 [00:00<?, ?it/s]
Train:   0%|          | 0/500 [00:00<?, ?it/s][A
Epoch:   0%|          | 0/20 [00:01<?, ?it/s] [A


NameError: name 'BANDS_GROUPS_IDX' is not defined