# SatVision-TOA Reconstruction Example Notebook

This notebook demonstrates the reconstruction capabilities of the SatVision-TOA model, designed to process and reconstruct MODIS TOA (Top of Atmosphere) imagery using Masked Image Modeling (MIM) for Earth observation tasks.

Follow this step-by-step guide to install necessary dependencies, load model weights, transform data, make predictions, and visualize the results.

## 1. Set Up and Install Dependencies

The following packages are required to run the notebook:
- `yacs` – for handling configuration
- `timm` – PyTorch Image Models, for model and layer utilities
- `segmentation-models-pytorch` – for segmentation utilities
- `termcolor` – for colored terminal text
- `webdataset==0.2.86` – for loading sharded, tar-based datasets (pinned to version 0.2.86)
- `huggingface-hub` – for downloading model and dataset files from the Hugging Face Hub
- `datasets` – the Hugging Face dataset library

#### 1.1 Install necessary modules

In [None]:
# !pip install yacs timm segmentation-models-pytorch termcolor webdataset==0.2.86 huggingface-hub datasets 
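After running the install cell above, you can confirm that the packages are present, in particular the pinned `webdataset` version. This is a minimal sketch using only the standard library's `importlib.metadata`; the helper name `check_requirements` is hypothetical and not part of the satvision-toa repository.

```python
from importlib.metadata import PackageNotFoundError, version

def check_requirements(requirements):
    """Return {package: installed version or None} and print any mismatches.

    `requirements` maps a pip distribution name to a pinned version string,
    or to None when any installed version is acceptable.
    """
    report = {}
    for name, pinned in requirements.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        report[name] = installed
        if installed is None:
            print(f"{name}: NOT INSTALLED")
        elif pinned is not None and installed != pinned:
            print(f"{name}: found {installed}, expected {pinned}")
        else:
            print(f"{name}: {installed}")
    return report

# Distribution names as passed to pip in the install cell
requirements = {
    "yacs": None,
    "timm": None,
    "segmentation-models-pytorch": None,
    "termcolor": None,
    "webdataset": "0.2.86",
    "huggingface-hub": None,
    "datasets": None,
}
status = check_requirements(requirements)
```

A package reported as missing or at an unexpected version is a likely cause of import errors in the next cells.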

#### 1.2 Imports and repository clone
**Some modules may require additional path configuration (for example, when packages are installed under `.local` or `.cache` directories).**

In [None]:
import os
import cv2
import sys
import glob
import math
import time
import torch
import random
import logging
import datasets
import datetime
import warnings
import subprocess
import numpy as np
import matplotlib.pyplot as plt

warnings.filterwarnings('ignore')

from tqdm import tqdm
from huggingface_hub import hf_hub_download
from huggingface_hub import snapshot_download
from matplotlib.backends.backend_pdf import PdfPages

In [None]:
repo_dir = "satvision-toa"

if not os.path.exists(repo_dir):
    subprocess.run(["git", "clone", "https://github.com/nasa-nccs-hpda/satvision-toa"])
else:
    subprocess.run(["git", "-C", repo_dir, "pull"])

#### 1.3 Repository-Specific Imports

We load the necessary modules from the satvision-toa repository cloned above, including the model builder, the data transform, the configuration utilities, and the plotting utilities.

In [None]:
# sys.path.append('../..')
sys.path.append('satvision-toa')
from satvision_toa.models.mim import build_mim_model
from satvision_toa.transforms.mim_modis_toa import MimTransform
from satvision_toa.configs.config import _C, _update_config_from_file
from satvision_toa.plotting.modis_toa import plot_export_pdf

## 2. User-defined variables

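**save_to_pdf** and **pdf_path** control whether the model's reconstruction plots are saved to a PDF and, if so, where.

**rgb_index** lists the indices of the red, green, and blue bands within the 14-band model input.

**model_size** selects which pretrained SatVision-TOA checkpoint to download: `'giant'` or `'huge'`.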

In [None]:
# Whether to save files to a PDF, and where to save them 
save_to_pdf = False
pdf_path = "chip_plot.pdf" # if not saving this can be None

# Indices of RGB bands within 14-band data
rgb_index = [0, 2, 1]

# Model size to download
model_size: str = 'giant'
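`rgb_index` picks three of the 14 bands to build a displayable color image. The following is a minimal sketch, assuming chips are channel-first arrays of shape `[14, 128, 128]` (the layout given in the shape comments in section 5) and using a simple per-image min-max stretch for display. `to_rgb` is a hypothetical helper; it is not the stretch applied by `plot_export_pdf`.

```python
import numpy as np

def to_rgb(chip, rgb_index=(0, 2, 1)):
    """Turn a channel-first [C, H, W] chip into an [H, W, 3] float image in [0, 1]."""
    rgb = np.asarray(chip, dtype=np.float32)[list(rgb_index)]  # [3, H, W]
    rgb = np.transpose(rgb, (1, 2, 0))                         # [H, W, 3]
    lo, hi = rgb.min(), rgb.max()
    # Min-max stretch for display; guard against a constant image
    if hi > lo:
        rgb = (rgb - lo) / (hi - lo)
    else:
        rgb = np.zeros_like(rgb)
    return rgb

# Example with a random 14-band chip
chip = np.random.rand(14, 128, 128)
image = to_rgb(chip, rgb_index=[0, 2, 1])
print(image.shape)  # (128, 128, 3)
```

Because `np.asarray` also accepts CPU tensors, the same helper could be applied to entries of the `inputs` and `outputs` lists built in section 5.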

## 3. Download model and validation files from Hugging Face

### 3.1 Download model and config

In [None]:
model_options_metadata = {
    'giant': {
        'repo_id': 'nasa-cisto-data-science-group/satvision-toa-giant-patch8-window8-128',
        'model_filename': 'mp_rank_00_model_states.pt',
        'config_filename': 'mim_pretrain_swinv2_satvision_giant_128_window08_50ep.yaml'
    },
    'huge': {
        'repo_id': 'nasa-cisto-data-science-group/satvision-toa-huge-patch8-window8-128',
        'model_filename': 'mp_rank_00_model_states.pt',
        'config_filename': 'mim_pretrain_swinv2_satvision_huge_128_window8_patch8_100ep.yaml'
    }
}
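Only the keys of `model_options_metadata` (`'giant'` and `'huge'`) are valid values for `model_size`; any other value fails with a bare `KeyError` in the next cell. A small sketch of a lookup with a clearer error message follows. `resolve_model_files` is a hypothetical helper, and the example dictionary simply mirrors the `'giant'` entry above.

```python
def resolve_model_files(model_size, metadata):
    """Return (repo_id, model_filename, config_filename) for `model_size`."""
    if model_size not in metadata:
        options = ", ".join(sorted(metadata))
        raise ValueError(f"Unknown model_size {model_size!r}; expected one of: {options}")
    entry = metadata[model_size]
    return entry["repo_id"], entry["model_filename"], entry["config_filename"]

# Example dictionary mirroring the 'giant' entry of model_options_metadata
example_metadata = {
    "giant": {
        "repo_id": "nasa-cisto-data-science-group/satvision-toa-giant-patch8-window8-128",
        "model_filename": "mp_rank_00_model_states.pt",
        "config_filename": "mim_pretrain_swinv2_satvision_giant_128_window08_50ep.yaml",
    },
}
repo_id, model_file, config_file = resolve_model_files("giant", example_metadata)
```

In the notebook this would be called as `resolve_model_files(model_size, model_options_metadata)`.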

In [None]:
hf_model_repo_id: str = model_options_metadata[model_size]['repo_id']
hf_model_filename: str = model_options_metadata[model_size]['model_filename']
hf_config_filename: str = model_options_metadata[model_size]['config_filename']
hf_dataset_repo_id: str = 'nasa-cisto-data-science-group/modis_toa_cloud_reconstruction_validation'

model_filename = hf_hub_download(
    repo_id=hf_model_repo_id,
    filename=hf_model_filename)
config_filename = hf_hub_download(
    repo_id=hf_model_repo_id,
    filename=hf_config_filename)

### 3.2 Download and load the validation set

In [None]:
# download the dataset
validation_tiles_dir = snapshot_download(repo_id=hf_dataset_repo_id, allow_patterns="*.npy", repo_type='dataset')
validation_tiles_regex = os.path.join(validation_tiles_dir, '*.npy')
validation_tiles_filename = next(iter(glob.glob(validation_tiles_regex)))
validation_tiles = np.load(validation_tiles_filename)
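`next(iter(glob.glob(...)))` takes whichever `.npy` file `glob` happens to return first and raises a bare `StopIteration` if the download contains none. A slightly more defensive sketch (the helper name `first_npy_file` is hypothetical) sorts the matches so the choice is reproducible and fails with a clear message:

```python
import glob
import os

def first_npy_file(directory):
    """Return the lexicographically first .npy file in `directory`."""
    matches = sorted(glob.glob(os.path.join(directory, "*.npy")))
    if not matches:
        raise FileNotFoundError(f"No .npy files found in {directory}")
    return matches[0]
```

Usage in this notebook would be `validation_tiles_filename = first_npy_file(validation_tiles_dir)`.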

### 3.3 Load and edit model config, then transform the validation set

In [None]:
config = _C.clone()
_update_config_from_file(config, config_filename)

# Point the config at the downloaded checkpoint (MODEL.PRETRAINED),
# the validation tile file (DATA.DATA_PATHS),
# and the output directory (OUTPUT)
config.defrost()
config.MODEL.PRETRAINED = model_filename
config.DATA.DATA_PATHS = validation_tiles_filename
config.OUTPUT = '.'
config.freeze()

In [None]:
# Use the Masked-Image-Modeling transform specific to MODIS TOA data
transform = MimTransform(config)

# The reconstruction evaluation set is a single numpy file
len_batch = range(validation_tiles.shape[0])

# Apply the transform to each image in the batch;
# the transform also auto-generates a mask for each image
imgMasks = [transform(validation_tiles[idx]) for idx in len_batch]

# Separate images and masks, casting masks to torch tensors
img = torch.stack([imgMask[0] for imgMask in imgMasks])
mask = torch.stack([torch.from_numpy(imgMask[1]) for imgMask in imgMasks])
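Each image is paired with a mask that is coarser than the image itself (the shape comments in section 5 give `[32, 32]` masks for `[14, 128, 128]` chips), so each mask cell covers a block of pixels. The sketch below expands a mask to pixel resolution, for example to overlay it on an image. It assumes a value of 1 marks a masked (hidden) cell, and `upsample_mask` is a hypothetical helper.

```python
import numpy as np

def upsample_mask(mask, img_size=128):
    """Repeat each cell of a square [m, m] mask to cover an [img_size, img_size] image."""
    mask = np.asarray(mask)
    m = mask.shape[-1]
    if img_size % m != 0:
        raise ValueError("img_size must be a multiple of the mask size")
    scale = img_size // m  # e.g. 128 // 32 = 4 pixels per mask cell
    return mask.repeat(scale, axis=-2).repeat(scale, axis=-1)

# Example: random binary mask with roughly 60% of cells masked (illustrative ratio)
rng = np.random.default_rng(0)
coarse = (rng.random((32, 32)) < 0.6).astype(np.int64)
fine = upsample_mask(coarse, img_size=128)
print(fine.shape)                              # (128, 128)
print(np.isclose(coarse.mean(), fine.mean()))  # True: masked fraction is unchanged
```

Since `np.asarray` accepts CPU tensors, `upsample_mask(mask[0])` would work on the masks built above.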

## 4. Build model

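The config stores the path to the pretrained checkpoint (`config.MODEL.PRETRAINED`). The model architecture is built from the config, and the checkpoint weights are then loaded into it.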

In [None]:
print('Building un-initialized model')
model = build_mim_model(config)
print('Successfully built uninitialized model')

print(f'Attempting to load checkpoint from {config.MODEL.PRETRAINED}')
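# Recent PyTorch releases default torch.load to weights_only=True, which can reject
# non-tensor metadata a checkpoint may contain; only disable it for trusted files
checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu', weights_only=False)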
model.load_state_dict(checkpoint['module'])
print('Successfully applied checkpoint')
model.cuda()
model.eval();
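The checkpoint loaded above is a dictionary whose `'module'` entry holds the model weights (the `mp_rank_00_model_states.pt` filename appears to follow DeepSpeed's checkpoint naming). If you adapt this notebook to checkpoints saved another way, the weights may sit under a different key. Below is a hedged sketch of a lookup that tries a few keys in order; the fallback names `'model'` and `'state_dict'` are assumptions about other checkpoint formats, not something this repository guarantees, and `extract_state_dict` is a hypothetical helper.

```python
def extract_state_dict(checkpoint, keys=("module", "model", "state_dict")):
    """Return the weights dict from a loaded checkpoint.

    Tries each key in `keys` in order; if none is present, assumes the
    checkpoint itself is already a state dict.
    """
    if isinstance(checkpoint, dict):
        for key in keys:
            if key in checkpoint:
                return checkpoint[key]
    return checkpoint

# Illustrative checkpoint dicts (plain Python values stand in for tensors)
wrapped = {"module": {"encoder.weight": 1.0}, "global_steps": 100}
plain = {"encoder.weight": 1.0}
print(extract_state_dict(wrapped))  # {'encoder.weight': 1.0}
print(extract_state_dict(plain))    # {'encoder.weight': 1.0}
```

With this helper, the load step above would read `model.load_state_dict(extract_state_dict(checkpoint))`.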

## 5. Prediction

Run inference on each sample and compute its reconstruction loss. Images are processed one at a time so that the loss can be tracked per image rather than per batch.

In [None]:
inputs = []
outputs = []
masks = []
losses = []

# We could do this in a single batch; however, we
# want to report the loss per image rather than
# per batch.
for i in tqdm(range(img.shape[0])):
    single_img = img[i].unsqueeze(0)
    single_mask = mask[i].unsqueeze(0)
    single_img = single_img.cuda(non_blocking=True)
    single_mask = single_mask.cuda(non_blocking=True)

    with torch.no_grad():
        z = model.encoder(single_img, single_mask)
        img_recon = model.decoder(z)
        loss = model(single_img, single_mask)

    inputs.extend(single_img.cpu())
    masks.extend(single_mask.cpu())
    outputs.extend(img_recon.cpu())
    losses.append(loss.cpu()) 
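`model(single_img, single_mask)` returns a single reconstruction loss per image. The exact loss is defined in the repository's model code. As an illustration only, a SimMIM-style masked L1 loss, which averages the absolute reconstruction error over masked pixels only, can be sketched in NumPy. This shows the general technique under that assumption; it is not a copy of the SatVision-TOA implementation, and `masked_l1_loss` is a hypothetical name.

```python
import numpy as np

def masked_l1_loss(img, recon, mask, eps=1e-5):
    """SimMIM-style L1 loss computed over masked pixels only.

    img, recon: [C, H, W] arrays; mask: [h, w] array with 1 = masked,
    where H and W are integer multiples of h and w.
    """
    img = np.asarray(img, dtype=np.float64)
    recon = np.asarray(recon, dtype=np.float64)
    mask = np.asarray(mask, dtype=np.float64)
    c, h, w = img.shape
    # Expand the coarse mask to pixel resolution
    pixel_mask = mask.repeat(h // mask.shape[0], axis=0).repeat(w // mask.shape[1], axis=1)
    per_pixel = np.abs(img - recon) * pixel_mask  # broadcasts over channels
    # Normalize by the number of masked pixels and by the channel count
    return per_pixel.sum() / (pixel_mask.sum() + eps) / c

img = np.zeros((14, 128, 128))
recon = np.ones((14, 128, 128))
mask = np.zeros((32, 32))
mask[:16] = 1  # top half of the mask cells are hidden
print(round(masked_l1_loss(img, recon, mask), 4))  # 1.0
```

Errors in unmasked pixels do not contribute, which is why this kind of loss measures how well the model fills in the hidden regions.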

In [None]:
outputs[0].shape
# inputs is a 128-length list of [14, 128, 128] shape chips
# outputs is a 128-length list of [14, 128, 128] shape chips
# masks is a 128-length list of [32, 32] masks
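Because the loop keeps one loss per image, you can summarize the losses and find the best and worst reconstructions to inspect. This is a minimal sketch; `summarize_losses` is a hypothetical helper, and it relies on `float()` to convert each single-element loss tensor (it works equally on plain floats).

```python
import statistics

def summarize_losses(losses):
    """Return mean/stdev and the indices of the lowest and highest per-image losses."""
    values = [float(v) for v in losses]
    return {
        "mean": statistics.fmean(values),
        "stdev": statistics.stdev(values) if len(values) > 1 else 0.0,
        "best_index": min(range(len(values)), key=values.__getitem__),
        "worst_index": max(range(len(values)), key=values.__getitem__),
    }

summary = summarize_losses([0.12, 0.08, 0.30])
print(summary["best_index"], summary["worst_index"])  # 1 2
```

In the notebook, `summarize_losses(losses)` would point to the indices worth plotting first.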

## 6. Plot reconstruction

*Using the plot_export_pdf found in satvision_toa/plotting/modis_toa.py*

This displays the model reconstruction and mask alongside the model input. If `save_to_pdf` is `True`, the plots are saved to the PDF file at `pdf_path`.

In [None]:
plot_export_pdf(pdf_path, inputs, outputs, masks, rgb_index, save_to_pdf)
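`plot_export_pdf` handles the layout used in this notebook. If you want a different layout, the following minimal sketch uses matplotlib's `PdfPages` (already imported at the top of the notebook) to write one page per chip with the input, the reconstruction, and the mask side by side. It assumes `[C, H, W]` chips and a simple min-max stretch for display; `export_comparison_pdf` is a hypothetical name, not a repository function.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages

def _rgb(chip, rgb_index):
    # Select the RGB bands, move channels last, min-max stretch to [0, 1]
    rgb = np.transpose(np.asarray(chip, dtype=np.float32)[list(rgb_index)], (1, 2, 0))
    span = rgb.max() - rgb.min()
    return (rgb - rgb.min()) / span if span > 0 else np.zeros_like(rgb)

def export_comparison_pdf(path, inputs, outputs, masks, rgb_index, max_pages=None):
    """Write one PDF page per chip (input | reconstruction | mask); return the page count."""
    n = len(inputs) if max_pages is None else min(len(inputs), max_pages)
    with PdfPages(path) as pdf:
        for i in range(n):
            fig, axes = plt.subplots(1, 3, figsize=(9, 3))
            axes[0].imshow(_rgb(inputs[i], rgb_index))
            axes[0].set_title("Input")
            axes[1].imshow(_rgb(outputs[i], rgb_index))
            axes[1].set_title("Reconstruction")
            axes[2].imshow(np.asarray(masks[i]), cmap="gray")
            axes[2].set_title("Mask")
            for ax in axes:
                ax.axis("off")
            pdf.savefig(fig)
            plt.close(fig)  # free memory and avoid inline display of every page
    return n
```

For example, `export_comparison_pdf("comparison.pdf", inputs, outputs, masks, rgb_index, max_pages=10)` would write the first ten chips.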