# Using Presto for downstream tasks

The purpose of this notebook is to demonstrate how Presto (and utility functions in the Presto package) can be used for downstream tasks.

To demonstrate that Presto is useful even when the input looks very different from what it was pre-trained on, we will consider tree-type mapping using single-timestep images.

To do this, we will use the [TreeSat](https://essd.copernicus.org/articles/15/681/2023/) benchmark dataset. This tutorial requires the S2 data to be downloaded from [Zenodo](https://zenodo.org/record/6780578) and unzipped in the [treesat folder](data/treesat).

In [None]:
import xarray
from pyproj import Transformer
import numpy as np
from scipy import stats
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score

from tqdm import tqdm

import torch
from torch.utils.data import DataLoader, TensorDataset

import presto

# this is to silence the xarray deprecation warning.
# Our version of xarray is pinned, but we'll need to fix this
# when we upgrade
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 

In [3]:
from pathlib import Path

basepath = Path("/Net/Groups/BGI/work_1/scratch/DeepCube/earthnet2023_preprocessing/")
train_files = list(basepath.glob("train/*/*.nc"))
test_files = list(basepath.glob("test/*/*.nc"))
print("len train set: ", len(train_files))
print("len test set: ", len(test_files))

If the TreeSat data has been correctly downloaded from Zenodo (see the Markdown cell above), these assert statements should pass.
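The assert statements themselves are not shown in this section. As a minimal sketch, assuming the 60 m Sentinel-2 tiles sit in an `s2/60m` subfolder of the TreeSat folder, the check could look like the following. Both the folder layout and the helper name `check_treesat_layout` are assumptions made for illustration; they are not part of the Presto package.

```python
from pathlib import Path


def check_treesat_layout(root: Path) -> Path:
    """Return the folder holding the 60 m S2 tiles, or fail with a clear message."""
    # Assumed layout: <root>/s2/60m/*.tif
    s2_data_60m = root / "s2" / "60m"
    assert root.exists(), f"TreeSat folder not found: {root}"
    assert s2_data_60m.exists(), f"S2 60 m folder not found: {s2_data_60m}"
    return s2_data_60m
```

Under these assumptions, `s2_data_60m = check_treesat_layout(Path("data/treesat"))` would give the folder that the processing code reads from.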

For simplicity, we will only consider classification between 2 tree species: Abies_alba and Acer_pseudoplatanus.

The TreeSatAI S2 data contains the following bands: ["B2", "B3", "B4", "B8", "B5", "B6", "B7", "B8A", "B11", "B12", "B1", "B9"]
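In code, the band list and the two species can be kept as constants. The sketch below assumes that each filename starts with the species name, which is what the labelling code in this notebook relies on for `Abies`. `SPECIES` and `keep_species` are illustrative names, and the example filenames are made up.

```python
# Bands of the TreeSatAI S2 tiles, as listed above.
TREESATAI_S2_BANDS = [
    "B2", "B3", "B4", "B8", "B5", "B6", "B7", "B8A", "B11", "B12", "B1", "B9",
]

# The two species considered in this tutorial.
SPECIES = ("Abies_alba", "Acer_pseudoplatanus")


def keep_species(filenames):
    """Keep only the filenames whose prefix is one of the two species."""
    # str.startswith accepts a tuple of prefixes.
    return [name for name in filenames if name.startswith(SPECIES)]


# Made-up filenames, for illustration only.
example = ["Abies_alba_1.tif", "Fagus_sylvatica_2.tif", "Acer_pseudoplatanus_3.tif"]
print(keep_species(example))
```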

### 1. Processing the data

The TreeSatAI data is stored in `.tif` files. We will extract 9 pixels (from the 36 total pixels in each tif file) to construct our input data.
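One way to pick 9 of the 36 pixels is to take every second row and column of a 6 × 6 tile. The sketch below assumes this stride-2 pattern for `INDICES_IN_TIF_FILE`. The text only states that 9 pixels are used, so the tile shape and the exact indices are assumptions.

```python
from itertools import product

TILE_SIZE = 6  # assumed: the 36 pixels form a 6 x 6 tile

# Assumed sampling pattern: every second index along each axis.
INDICES_IN_TIF_FILE = list(range(0, TILE_SIZE, 2))

# All (x_idx, y_idx) pairs visited by a nested loop over these indices.
pixel_indices = list(product(INDICES_IN_TIF_FILE, repeat=2))

print(INDICES_IN_TIF_FILE)
print(len(pixel_indices))
```

Sampling on a regular grid keeps the selected pixels spread across the tile rather than clustered in one corner.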

We use `presto.construct_single_presto_input` to transform the TreeSat S2 data into the tensors expected by Presto.

In [5]:
def process_images(filenames):
    arrays, masks, latlons, image_names, labels, dynamic_worlds = [], [], [], [], [], []
    
    for filename in tqdm(filenames):
        tif_file = xarray.open_rasterio(s2_data_60m / filename.strip())
        crs = tif_file.crs.split("=")[-1]
        transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
        
        for x_idx in INDICES_IN_TIF_FILE:
            for y_idx in INDICES_IN_TIF_FILE:
                
                # firstly, get the latitudes and longitudes
                x, y = tif_file.x[x_idx], tif_file.y[y_idx]
                lon, lat = transformer.transform(x, y) 
                latlons.append(torch.tensor([lat, lon]))
                
                # then, get the eo_data, mask and dynamic world
                s2_data_for_pixel = torch.from_numpy(tif_file.values[:, x_idx, y_idx].astype(int)).float()
                s2_data_with_time_dimension = s2_data_for_pixel.unsqueeze(0)
                x, mask, dynamic_world = presto.construct_single_presto_input(
                    s2=s2_data_with_time_dimension, s2_bands=TREESATAI_S2_BANDS
                )
                arrays.append(x)
                masks.append(mask)
                dynamic_worlds.append(dynamic_world)
                
                labels.append(0 if filename.startswith("Abies") else 1)
                image_names.append(filename)

    return (torch.stack(arrays, axis=0),
            torch.stack(masks, axis=0),
            torch.stack(dynamic_worlds, axis=0),
            torch.stack(latlons, axis=0),
            torch.tensor(labels),
            image_names,
        )

In [6]:
train_data = process_images(train_files)
test_data = process_images(test_files)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3430/3430 [01:03<00:00, 54.26it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 375/375 [00:06<00:00, 62.25it/s]


### 2. Using Presto as a feature extractor for a random forest

One way to use Presto is as a feature extractor for a simple model (e.g. a Random Forest). We do this below.

We load the pretrained Presto model using `Presto.load_pretrained()`.

In [7]:
batch_size = 64

pretrained_model = presto.Presto.load_pretrained()
pretrained_model.eval()

Presto(
  (encoder): Encoder(
    (eo_patch_embed): ModuleDict(
      (S1): Linear(in_features=2, out_features=128, bias=True)
      (S2_RGB): Linear(in_features=3, out_features=128, bias=True)
      (S2_NIR_10m): Linear(in_features=1, out_features=128, bias=True)
      (S2_NIR_20m): Linear(in_features=1, out_features=128, bias=True)
      (S2_Red_Edge): Linear(in_features=3, out_features=128, bias=True)
      (S2_SWIR): Linear(in_features=2, out_features=128, bias=True)
      (ERA5): Linear(in_features=2, out_features=128, bias=True)
      (SRTM): Linear(in_features=2, out_features=128, bias=True)
      (NDVI): Linear(in_features=1, out_features=128, bias=True)
    )
    (dw_embed): Embedding(10, 128)
    (latlon_embed): Linear(in_features=3, out_features=128, bias=True)
    (blocks): ModuleList(
      (0-1): 2 x Block(
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=128, out_features=384, bias=True)

We start by constructing features for the training data, and then use them to train a Random Forest.

In [8]:
# the TreeSatAI data was collected during the summer,
# so we estimate the month to be July (index 6, assuming zero-indexed months)
month = torch.tensor([6] * train_data[0].shape[0]).long()

dl = DataLoader(
    TensorDataset(
        train_data[0].float(),  # x
        train_data[1].bool(),  # mask
        train_data[2].long(),  # dynamic world
        train_data[3].float(),  # latlons
        month
    ),
    batch_size=batch_size,
    shuffle=False,
)

In [9]:
features_list = []
for (x, mask, dw, latlons, month) in tqdm(dl):
    with torch.no_grad():
        encodings = (
            pretrained_model.encoder(
                x, dynamic_world=dw, mask=mask, latlons=latlons, month=month
            )
            .cpu()
            .numpy()
        )
        features_list.append(encodings)
features_np = np.concatenate(features_list)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 483/483 [00:05<00:00, 92.62it/s]


We use `features_np` to train a Random Forest classifier:

In [10]:
model = RandomForestClassifier(class_weight="balanced", random_state=42)
model.fit(features_np, train_data[4].numpy())

We can then use this trained random forest to make some predictions on the test data.

In [11]:
# the TreeSatAI data was collected during the summer,
# so we estimate the month to be July (index 6, assuming zero-indexed months)
month = torch.tensor([6] * test_data[0].shape[0]).long()

dl = DataLoader(
    TensorDataset(
        test_data[0].float(),  # x
        test_data[1].bool(),  # mask
        test_data[2].long(),  # dynamic world
        test_data[3].float(),  # latlons
        month
    ),
    batch_size=batch_size,
    shuffle=False,
)

In [12]:
test_preds = []
for (x, mask, dw, latlons, month) in tqdm(dl):
    with torch.no_grad():
        # the model was already put in eval mode when it was loaded
        encodings = (
            pretrained_model.encoder(
                x, dynamic_world=dw, mask=mask, latlons=latlons, month=month
            )
            .cpu()
            .numpy()
        )
        test_preds.append(model.predict_proba(encodings))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 53/53 [00:01<00:00, 27.22it/s]


Each image contributed 9 pixels, so we take the mode of the per-pixel predicted classes as the prediction for the whole image.

In [13]:
pix_per_image = len(INDICES_IN_TIF_FILE) ** 2

test_preds_np = np.concatenate(test_preds, axis=0)
test_preds_np = np.reshape(
    test_preds_np,
    (int(len(test_preds_np) / pix_per_image), pix_per_image, test_preds_np.shape[-1]),
)
# then, take the mode of the model predictions
test_preds_np_argmax = stats.mode(
    np.argmax(test_preds_np, axis=-1), axis=1, keepdims=False
)[0]
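To make the reshape-and-vote step concrete, here is a toy example with 2 images of 3 pixels each and made-up probabilities. It uses a NumPy-only majority vote (`np.bincount`) in place of `stats.mode`. `majority_vote` is an illustrative helper, not part of Presto, and in this sketch ties resolve to the lowest class index.

```python
import numpy as np


def majority_vote(pixel_probs: np.ndarray, pix_per_image: int) -> np.ndarray:
    """Aggregate per-pixel class probabilities into one label per image."""
    n_classes = pixel_probs.shape[-1]
    # (n_pixels, n_classes) -> (n_images, pix_per_image, n_classes)
    per_image = pixel_probs.reshape(-1, pix_per_image, n_classes)
    # Hard label for each pixel.
    pixel_labels = per_image.argmax(axis=-1)
    # Count the votes per class for each image and keep the most frequent class.
    votes = np.stack([np.bincount(row, minlength=n_classes) for row in pixel_labels])
    return votes.argmax(axis=1)


# Made-up probabilities, for illustration only.
toy_probs = np.array([
    [0.9, 0.1], [0.6, 0.4], [0.2, 0.8],  # image 0: pixel labels 0, 0, 1
    [0.3, 0.7], [0.4, 0.6], [0.8, 0.2],  # image 1: pixel labels 1, 1, 0
])
print(majority_vote(toy_probs, pix_per_image=3))
```

This relies on the pixels of each image being contiguous in the array, which holds here because the data loaders are built with `shuffle=False`.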

And finally, we can compute the F1 score of the test predictions.

In [14]:
# every pixel of an image shares the same label, so keep the first one per image
target = test_data[4].numpy().reshape(-1, pix_per_image)[:, 0]

f1_score(target, test_preds_np_argmax, average="weighted")

0.9681132732641782
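With `average="weighted"`, the per-class F1 scores are averaged using each class's number of true samples as its weight. The pure-Python sketch below reproduces that calculation on made-up labels. `f1_for_class` and `weighted_f1` are illustrative helpers, not scikit-learn functions.

```python
def f1_for_class(y_true, y_pred, cls):
    """F1 score for one class, treating it as the positive label."""
    tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
    fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
    fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def weighted_f1(y_true, y_pred):
    """Average the per-class F1 scores, weighted by class support in y_true."""
    n = len(y_true)
    return sum(
        f1_for_class(y_true, y_pred, cls) * y_true.count(cls) / n
        for cls in sorted(set(y_true))
    )


# Made-up labels, for illustration only.
y_true = [0, 0, 0, 1]
y_pred = [0, 0, 1, 1]
# Class 0: F1 = 0.8 (support 3); class 1: F1 = 2/3 (support 1).
print(weighted_f1(y_true, y_pred))
```

Weighting by support means the more frequent class dominates the score, which is worth keeping in mind when the two species are unevenly represented in the test set.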