https://huggingface.co/ibm-nasa-geospatial

https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M

In [1]:
%load_ext autoreload
%autoreload 2

The model accepts remote sensing data in a video format (B, C, T, H, W): batch, channels, time, height, width. The model can also handle static imagery, which can be fed into the model with T=1.


The model was pre-trained with NASA's HLS V2 L30 product (30m granularity) from the contiguous United States. The bands that were used are the following:

- Blue
- Green
- Red
- Narrow NIR
- SWIR 1
- SWIR 2

Image size is 224x224. 
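
As a minimal sketch of the T=1 case, a single-date chip with the six bands above can be given a batch axis and a time axis so that it matches (B, C, T, H, W). The array `chip` is random stand-in data, not a real HLS scene.

```python
import numpy as np

# Hypothetical single-date chip: 6 bands, 224x224 pixels, laid out as (C, H, W).
chip = np.random.rand(6, 224, 224).astype(np.float32)

# Add a batch axis in front and a time axis right after the channel axis.
batch = chip[np.newaxis, :, np.newaxis, :, :]

print(batch.shape)  # (1, 6, 1, 224, 224)
```

`torch.from_numpy(batch)` then gives a tensor with the same layout.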

In [2]:
import os
import torch
import matplotlib.pyplot as plt
import numpy as np
import rasterio
import yaml
from prithvi.Prithvi import MaskedAutoencoderViT

NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
PERCENTILES = (0.1, 99.9)

In [3]:
# load weights
weights_path = "./prithvi/Prithvi_100M.pt"
checkpoint = torch.load(weights_path, map_location="cpu")

# read model config
model_cfg_path = "./prithvi/Prithvi_100M_config.yaml"
with open(model_cfg_path) as f:
    model_config = yaml.safe_load(f)

model_args, train_args = model_config["model_args"], model_config["train_params"]

# let us use only 1 frame for now (the model was trained on 3 frames)
model_args["num_frames"] = 1

# instantiate model
model = MaskedAutoencoderViT(**model_args)
model.eval()

# load weights into model
# the positional embeddings are deleted because their shape depends on the number of frames;
# strict=False lets the load proceed without them, so the missing-key report is expected
del checkpoint['pos_embed']
del checkpoint['decoder_pos_embed']
_ = model.load_state_dict(checkpoint, strict=False)

In [4]:
output = model(torch.rand(1, 6, 1, 224, 224))
output

(tensor(0.0943, grad_fn=<DivBackward0>),
 tensor([[[0.4565, 0.4553, 0.4768,  ..., 0.4566, 0.4865, 0.4737],
          [0.4606, 0.4514, 0.4598,  ..., 0.4075, 0.5058, 0.5207],
          [0.4671, 0.4578, 0.4693,  ..., 0.4039, 0.5093, 0.5265],
          ...,
          [0.4267, 0.3978, 0.3769,  ..., 0.7087, 0.7390, 0.6877],
          [0.2289, 0.1958, 0.1634,  ..., 0.2549, 0.2275, 0.2487],
          [0.3055, 0.3261, 0.3436,  ..., 0.3255, 0.6125, 0.6397]]],
        grad_fn=<SliceBackward0>),
 tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1.,
          1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 0.,
          1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1.,
          1., 1., 1., 0., 0., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1.,
          0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 0., 1., 1., 1., 1.,
          1., 1., 1

The model returns a tuple with:

- loss
- reconstructed image
- mask used
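
The mask has one entry per patch. Assuming the convention of the original MAE implementation (1 = patch was masked, 0 = patch stayed visible) and a mask ratio of 0.75, neither of which this notebook states, a mask like the one above could be produced roughly as follows. `random_mask` is a hypothetical helper written for illustration, not a function from the Prithvi code.

```python
import numpy as np

def random_mask(num_patches, mask_ratio, rng):
    """Return a 0/1 vector where 1 marks a masked patch (assumed MAE convention)."""
    len_keep = int(num_patches * (1 - mask_ratio))
    # Shuffle patch indices by sorting random noise, then keep the first len_keep.
    ids_shuffle = np.argsort(rng.random(num_patches))
    mask = np.ones(num_patches, dtype=np.float32)
    mask[ids_shuffle[:len_keep]] = 0.0
    return mask

# 14 x 14 = 196 patches for a single 224x224 frame with 16x16 patches.
mask = random_mask(196, 0.75, np.random.default_rng(0))
print(int(mask.sum()), "of", mask.size, "patches masked")
```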

Images are normalized per band, using the mean and standard deviation stored in the training configuration.

In [5]:
# statistics used to normalize images before passing to the model
means = train_args["data_mean"]
stds = train_args["data_std"]

means, stds

([775.2290211032589,
  1080.992780391705,
  1228.5855250417867,
  2497.2022620507532,
  2204.2139147975554,
  1610.8324823273745],
 [1281.526139861424,
  1270.0297974547493,
  1399.4802505642526,
  1368.3446143747644,
  1291.6764008585435,
  1154.505683480695])
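
A minimal sketch of how these statistics might be applied, assuming a (C, H, W) array with bands in the order listed earlier. The `normalize` helper and the random `raw` chip are illustrative only; no-data handling is left out.

```python
import numpy as np

# Per-band statistics copied from the output above.
means = np.array([775.2290211032589, 1080.992780391705, 1228.5855250417867,
                  2497.2022620507532, 2204.2139147975554, 1610.8324823273745])
stds = np.array([1281.526139861424, 1270.0297974547493, 1399.4802505642526,
                 1368.3446143747644, 1291.6764008585435, 1154.505683480695])

def normalize(img, means, stds):
    """Standardize a (C, H, W) array band by band."""
    # Reshape the statistics to (C, 1, 1) so they broadcast over H and W.
    return (img - means[:, None, None]) / stds[:, None, None]

# Hypothetical raw chip with integer-like reflectance values.
raw = np.random.randint(0, 5000, size=(6, 224, 224)).astype(np.float64)
normed = normalize(raw, means, stds)
print(normed.shape)  # (6, 224, 224)
```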

To fine-tune, write a standard PyTorch training loop for your dataset. Extract the backbone (the encoder) from the model with some surgery and run only the feature forward pass, with masking disabled (mask_ratio=0).

In general, some recommendations are:

- At least in the beginning, experiment with freezing the backbone. This will give you much faster iteration through experiments.
- Err on the side of a smaller learning rate.
- With an unfrozen encoder, regularization is your friend! (Weight decay, dropout, batchnorm...)
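
The first two recommendations can be sketched as below. The `backbone` here is a small stand-in module, not the Prithvi encoder, and the 2-class `head`, the learning rate, and the weight decay are illustrative assumptions rather than tuned values.

```python
import torch
from torch import nn

# Stand-in encoder (hypothetical); in practice this would be the pretrained backbone.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(6 * 16 * 16, 768))
head = nn.Linear(768, 2)  # hypothetical 2-class task head

# Freeze the backbone: no gradients, and eval mode for dropout/normalization layers.
backbone.requires_grad_(False)
backbone.eval()

# Only the head is optimized; a small learning rate plus weight decay.
optimizer = torch.optim.AdamW(head.parameters(), lr=1e-4, weight_decay=0.05)

# One training step on random stand-in data.
x = torch.rand(4, 6, 16, 16)
y = torch.randint(0, 2, (4,))

with torch.no_grad():
    feats = backbone(x)
loss = nn.functional.cross_entropy(head(feats), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

To unfreeze later, call `backbone.requires_grad_(True)` and pass the backbone parameters to the optimizer as well.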

In [6]:
features, _, _ = model.forward_encoder(torch.rand(1, 6, 1, 224, 224), mask_ratio=0)
features.shape

torch.Size([1, 197, 768])

This is the standard output of a ViT.

- Dim 1: Batch size
- Dim 2: [cls_token] + tokens representing flattened image
- Dim 3: embedding dimension

In [9]:
features.flatten().shape

torch.Size([151296])

In [13]:
(14**2 + 1)*768

151296

In [7]:
model_args

{'decoder_depth': 8,
 'decoder_embed_dim': 512,
 'decoder_num_heads': 16,
 'depth': 12,
 'embed_dim': 768,
 'img_size': 224,
 'in_chans': 6,
 'num_frames': 1,
 'num_heads': 12,
 'patch_size': 16,
 'tubelet_size': 1}
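
The 197 tokens follow from these arguments. The sketch below assumes the patch grid is (num_frames / tubelet_size) x (img_size / patch_size)^2 plus one cls token; `num_tokens` is a hypothetical helper, not part of the model code.

```python
def num_tokens(img_size, patch_size, num_frames, tubelet_size, cls_token=True):
    """Sequence length of the encoder output under the assumed patch grid."""
    spatial = (img_size // patch_size) ** 2
    temporal = num_frames // tubelet_size
    return spatial * temporal + (1 if cls_token else 0)

print(num_tokens(224, 16, 1, 1))  # 197
print(num_tokens(224, 16, 3, 1))  # 589
```

Under this assumption the sequence length changes with the number of frames, which is why positional embeddings sized for 3 frames do not fit a 1-frame model.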

In [8]:
print(f"Encoder features have shape {features.shape}")

# drop cls token
reshaped_features = features[:, 1:, :]

# reshape
feature_img_side_length = int(np.sqrt(reshaped_features.shape[1]))
reshaped_features = reshaped_features.view(-1, feature_img_side_length, feature_img_side_length, model_args["embed_dim"])
# channels first
reshaped_features = reshaped_features.permute(0, 3, 1, 2)
print(f"Encoder features have new shape {reshaped_features.shape}")

Encoder features have shape torch.Size([1, 197, 768])
Encoder features have new shape torch.Size([1, 768, 14, 14])
