# Satvision-TOA Reconstruction Notebook

Version: 02.20.24

Env: `Python [conda env:ilab-pytorch]`

In [None]:
# Extra dependencies needed by this notebook (webdataset is pinned to 0.2.86).
!pip install yacs timm segmentation-models-pytorch termcolor webdataset==0.2.86

In [None]:
import os
import sys
import time
import random
import datetime
import numpy as np
import logging

import torch
import torch.cuda.amp as amp

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import warnings

# Suppress every warning for the rest of the session. This also hides
# deprecation and numerical warnings; narrow the filter when debugging.
warnings.filterwarnings('ignore') 

In [None]:
# Make the pytorch-caney checkout at ../../pytorch-caney (resolved relative to
# the current working directory) importable for the imports below.
sys.path.append('../../pytorch-caney')

from pytorch_caney.config import get_config

from pytorch_caney.training.mim_utils import load_checkpoint, load_pretrained

from pytorch_caney.models.build import build_model

from pytorch_caney.ptc_logging import create_logger

from pytorch_caney.data.datamodules import mim_webdataset_datamodule

from pytorch_caney.data.transforms import SimmimTransform, SimmimMaskGenerator

from pytorch_caney.config import _C, _update_config_from_file

## Configuration

### Clone the model checkpoint from Hugging Face

```bash
# On prism/explore
module load git-lfs

git lfs install

git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-base
```

Note: the clone command above uses SSH, so make sure an SSH key is registered with your Hugging Face account before cloning.

In [None]:
# Files from the cloned Hugging Face model repo: the model config and weights.
CONFIG_PATH: str = '../../satvision-toa-base/mim_pretrain_swinv2_satvision-toa_base_192_window12_800ep.yaml'
MODEL_PATH: str = '../../satvision-toa-base/satvision-toa_84M_2M_100.pth'

# Input data: a single webdataset directory, wrapped in a list for the config.
DATA_PATH: str = '/explore/nobackup/projects/ilab/projects/3DClouds/data/mosaic-v3/webdatasets'
DATA_PATHS: list = [DATA_PATH]

# Run settings.
BATCH_SIZE: int = 64  # set to 1 if you want a loss reported for every image
OUTPUT: str = '.'
TAG: str = 'satvision-base-toa-reconstruction'

In [None]:
# Build the run config: library defaults first, then the checkpoint's YAML.
config = _C.clone()
_update_config_from_file(config, CONFIG_PATH)

# Notebook-specific overrides. The yacs node must be defrosted before it can
# be edited; freeze() makes it read-only again afterwards.
config.defrost()
config.OUTPUT = OUTPUT
config.TAG = TAG
config.DATA.BATCH_SIZE = BATCH_SIZE
config.DATA.DATA_PATHS = DATA_PATHS
config.MODEL.RESUME = MODEL_PATH
config.freeze()

In [None]:
# Configure logging: INFO and above go both to app.log and to the cell output.
_LOG_FORMAT = '%(asctime)s [%(levelname)s] %(message)s'
_LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'

# basicConfig only has an effect if the root logger has no handlers yet.
logging.basicConfig(
    filename='app.log',
    level=logging.INFO,
    format=_LOG_FORMAT,
    datefmt=_LOG_DATEFMT,
)

logger = logging.getLogger('')

# Attach the console handler only once: re-running this cell would otherwise
# add another handler each time and print every message repeatedly.
_CONSOLE_HANDLER_NAME = 'satvision-console'
if not any(h.get_name() == _CONSOLE_HANDLER_NAME for h in logger.handlers):
    console = logging.StreamHandler()
    console.set_name(_CONSOLE_HANDLER_NAME)
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter(_LOG_FORMAT, datefmt=_LOG_DATEFMT))
    logger.addHandler(console)

In [None]:
# Load the pretrained weights. map_location='cpu' stops torch.load from trying
# to restore tensors onto the device they were saved from; the model is moved
# to the GPU explicitly below.
# NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
checkpoint = torch.load(MODEL_PATH, map_location='cpu')
model = build_model(config, pretrain=True)
model.load_state_dict(checkpoint['model'])
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"number of params: {n_parameters}")
model.cuda()
model.eval()

## Dataloader

In [None]:
# Build the masked-image-modeling dataloader from the frozen config
# (presumably reads DATA.DATA_PATHS / DATA.BATCH_SIZE set above -- confirm).
dataloader = mim_webdataset_datamodule.build_mim_dataloader(config, logger)

## Prediction helper functions

In [None]:
def predict(model, dataloader, num_batches=5):
    """Run masked reconstruction on the first ``num_batches`` batches.

    Reads the notebook-global ``config`` for ``ENABLE_AMP`` and moves every
    batch to the current CUDA device.

    Args:
        model: model exposing ``encoder(img, mask)``, ``decoder(z)`` and a
            forward ``model(img, mask)`` (presumably returning the
            reconstruction loss -- confirm against pytorch_caney).
        dataloader: iterable whose items are indexed with ``[0]`` to obtain a
            sequence of ``(image, mask)`` pairs.
        num_batches: maximum number of batches to process (values <= 0
            process nothing).

    Returns:
        Tuple ``(inputs, outputs, masks, losses)``. The first three are flat
        lists of per-sample CPU tensors; ``losses`` holds one detached CPU
        tensor per processed batch.
    """
    from itertools import islice

    inputs = []
    outputs = []
    masks = []
    losses = []

    # islice stops after exactly num_batches items without fetching an extra one.
    for batch in islice(dataloader, max(0, num_batches)):
        img_mask = batch[0]

        img = torch.stack([pair[0] for pair in img_mask])
        mask = torch.stack([pair[1] for pair in img_mask])

        img = img.cuda(non_blocking=True)
        mask = mask.cuda(non_blocking=True)

        with torch.no_grad():
            with amp.autocast(enabled=config.ENABLE_AMP):
                z = model.encoder(img, mask)
                img_recon = model.decoder(z)
                loss = model(img, mask)

        inputs.extend(img.cpu())
        masks.extend(mask.cpu())
        outputs.extend(img_recon.cpu())
        # Record this batch's loss value (off the GPU), not the list itself.
        losses.append(loss.detach().cpu())

    return inputs, outputs, masks, losses


def minmax_norm(img_arr):
    """Linearly rescale an array to 0-255 and return it as ``uint8``.

    The global minimum maps to 0 and the global maximum to 255; the final
    ``uint8`` cast truncates fractional values. A constant array
    (max == min) returns all zeros instead of dividing by zero.

    Args:
        img_arr: numpy array of any shape.

    Returns:
        ``uint8`` numpy array with the same shape as ``img_arr``.
    """
    arr_min = img_arr.min()
    arr_max = img_arr.max()
    value_range = arr_max - arr_min
    if value_range == 0:
        # Avoid 0/0 -> NaN, whose uint8 cast is undefined.
        return np.zeros(img_arr.shape, dtype=np.uint8)
    img_arr_scaled = (img_arr - arr_min) / value_range
    return (img_arr_scaled * 255).astype(np.uint8)


def process_mask(mask, scale=4):
    """Upsample a 2-D mask tensor and tile it into three identical channels.

    Each mask cell is expanded into a ``scale x scale`` block so the mask can
    be overlaid on the image (the default of 4 is presumably the model's patch
    size -- confirm against the config).

    Args:
        mask: 2-D CPU torch tensor of shape ``(h, w)``.
        scale: upsampling factor applied to both spatial axes.

    Returns:
        numpy array of shape ``(h * scale, w * scale, 3)``.
    """
    upsampled = mask.repeat_interleave(scale, 0).repeat_interleave(scale, 1)
    upsampled = np.asarray(upsampled)
    return np.stack([upsampled, upsampled, upsampled], axis=-1)


def process_prediction(image, img_recon, mask, rgb_index):
    """Build displayable RGB views of an input, its reconstruction and mask.

    ``rgb_index[0]``, ``rgb_index[1]`` and ``rgb_index[2]`` select bands along
    axis 0 of ``image`` / ``img_recon``; they become, in that order, channels
    0, 1 and 2 of the last axis (red, green, blue when shown with ``imshow``).
    Each tensor is rescaled jointly over all bands with ``minmax_norm``.

    Returns:
        ``(rgb_image, rgb_image_masked, rgb_recon_masked, mask)``:
        the selected input bands as ``(H, W, 3)`` uint8; the input with
        pixels where the mask is 1 set to 0; the input where the mask is 0
        and the reconstruction elsewhere; and the 3-channel mask produced by
        ``process_mask``.
    """
    channel_bands = (rgb_index[0], rgb_index[1], rgb_index[2])

    def _to_display_rgb(tensor):
        # Rescale all bands together, then pick the three display bands.
        scaled = minmax_norm(tensor.numpy())
        return np.stack([scaled[band, :, :] for band in channel_bands], axis=-1)

    rgb_image = _to_display_rgb(image)
    display_mask = process_mask(mask)
    rgb_image_recon = _to_display_rgb(img_recon)

    # Keep observed pixels (mask == 0) and fill masked ones from the reconstruction.
    rgb_recon_masked = np.where(display_mask == 0, rgb_image, rgb_image_recon)
    # Zero out the pixels the model had to reconstruct.
    rgb_image_masked = np.where(display_mask == 1, 0, rgb_image)

    return rgb_image, rgb_image_masked, rgb_recon_masked, display_mask


def plot_export_pdf(path, num_sample, inputs, outputs, masks, rgb_index):
    """Write one 2x2 comparison page per randomly chosen prediction to a PDF.

    Each page shows the model reconstruction (top-left), the masked input
    (top-right), the original input (bottom-left) and the mask (bottom-right).

    Args:
        path: output PDF path.
        num_sample: number of predictions to draw without replacement; it is
            capped at ``len(inputs)`` instead of raising.
        inputs, outputs, masks: parallel per-sample sequences, e.g. as
            returned by ``predict``.
        rgb_index: three band indices forwarded to ``process_prediction``.
    """
    num_sample = min(num_sample, len(inputs))
    random_subsample = random.sample(range(len(inputs)), num_sample)

    # The context manager finalizes the PDF even if plotting raises.
    with PdfPages(path) as pdf_plot_obj:
        for idx in random_subsample:
            rgb_image, rgb_image_masked, rgb_recon_masked, mask = \
                process_prediction(inputs[idx], outputs[idx], masks[idx], rgb_index)

            fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2, figsize=(40, 30))

            ax2.imshow(rgb_image)
            ax2.set_title(f"Idx: {idx} MOD021KM v6.1 Bands: {rgb_index}")

            ax0.imshow(rgb_recon_masked)
            ax0.set_title(f"Idx: {idx} Model reconstruction")

            ax1.imshow(rgb_image_masked)
            ax1.set_title(f"Idx: {idx} MOD021KM Bands: {rgb_index}, masked")

            ax3.matshow(mask[:, :, 0])
            ax3.set_title(f"Idx: {idx} Reconstruction Mask")

            pdf_plot_obj.savefig(fig)
            # Figures are 40x30 in; close each one so pyplot doesn't keep them all alive.
            plt.close(fig)

## Predict

In [None]:
%%time

# Run reconstruction on a limited number of dataloader batches (controlled by
# num_batches) and collect inputs, reconstructions, masks and losses.
inputs, outputs, masks, losses = predict(model, dataloader, num_batches=5)

## Plot and write to PDF

In [None]:
# Render a random subset of the predictions into a multi-page PDF.
num_samples = 10  # predictions drawn at random, without replacement
# Band indices shown in the displayed red, green and blue channels, in order.
# NOTE(review): these were described as [Red band, Blue band, Green band], but
# process_prediction places the second index in the green display channel --
# confirm the band-to-channel mapping.
rgb_index = [0, 3, 2]
pdf_path = '../../satvision-toa-reconstruction-pdf-02.20.pdf'

plot_export_pdf(
    path=pdf_path,
    num_sample=num_samples,
    inputs=inputs,
    outputs=outputs,
    masks=masks,
    rgb_index=rgb_index,
)