# Satvision-TOA Reconstruction Notebook

Version: 04.30.24

Env: `Python [conda env:ilab-pytorch]`

In [None]:
!pip install yacs timm segmentation-models-pytorch termcolor webdataset==0.2.86

In [None]:
import os
import sys
import time
import random
import datetime
from tqdm import tqdm
import numpy as np
import logging

import torch
import torch.cuda.amp as amp

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import warnings

# Silence every warning to keep notebook output clean.
# NOTE(review): this also hides numerical RuntimeWarnings (e.g. divide-by-zero)
# and library deprecation notices; narrow the filter if those matter.
warnings.filterwarnings('ignore') 

In [None]:
# Make the local pytorch-caney checkout importable. The entry is relative, so it
# is resolved against the notebook's working directory (checkout two levels up).
sys.path.append('../../pytorch-caney')

from pytorch_caney.config import get_config

from pytorch_caney.training.mim_utils import load_checkpoint, load_pretrained

from pytorch_caney.models.build import build_model

from pytorch_caney.ptc_logging import create_logger

from pytorch_caney.data.datamodules import mim_webdataset_datamodule

from pytorch_caney.data.transforms import SimmimTransform, SimmimMaskGenerator

from pytorch_caney.config import _C, _update_config_from_file

## 1. Configuration

### Clone the model checkpoint from Hugging Face

```bash
# On prism/explore
module load git-lfs

git lfs install

git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-huge-patch8-window12-192
```

Note: cloning from the `git@hf.co:` URL uses SSH authentication, so an SSH key must be added to your Hugging Face account first.
See https://huggingface.co/docs/hub/security-git-ssh for setup instructions.

```bash
eval $(ssh-agent)

# If this greets you as "anonymous", your key is not being used; follow the next steps.
ssh -T git@hf.co

# Check if ssh-agent is using the proper key
ssh-add -l

# If not
ssh-add ~/.ssh/your-key

# Or if you want to use the default id_* key, just do
ssh-add

```

In [None]:
# Local clone of the Hugging Face model repository (see section 1).
MODEL_DIR: str = '../../satvision-toa-huge-patch8-window12-192'
MODEL_PATH: str = MODEL_DIR + '/mp_rank_00_model_states.pt'
CONFIG_PATH: str = (
    MODEL_DIR + '/mim_pretrain_swinv2_satvision_huge_192_window12_100ep.yaml'
)

# Copied into config.OUTPUT / config.TAG by the next cell.
OUTPUT: str = '.'
TAG: str = 'satvision-huge-toa-reconstruction'

# Reconstruction evaluation set: a single .npy file of validation chips.
DATA_PATH: str = (
    '/explore/nobackup/projects/ilab/projects/3DClouds/data/validation/'
    'sv_toa_128_chip_validation_04_24.npy'
)
DATA_PATHS: list = [DATA_PATH]

In [None]:
# Start from the library defaults, then layer the model's YAML on top.
config = _C.clone()
_update_config_from_file(config, CONFIG_PATH)

# Notebook-level overrides. The config is unfrozen only for these
# assignments and frozen again so later cells cannot change it by accident.
config.defrost()
config.OUTPUT, config.TAG = OUTPUT, TAG
config.MODEL.RESUME = MODEL_PATH
config.DATA.DATA_PATHS = DATA_PATHS
config.freeze()

In [None]:
# Configure logging: INFO and above go both to app.log and to the notebook output.
LOG_FORMAT = '%(asctime)s [%(levelname)s] %(message)s'
LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'

# basicConfig does nothing once the root logger already has handlers, so
# re-running this cell leaves the existing file handler untouched.
logging.basicConfig(
    filename='app.log',  # Log file, relative to the working directory
    level=logging.INFO,  # Root logger level
    format=LOG_FORMAT,
    datefmt=LOG_DATEFMT,
)

logger = logging.getLogger('')

# Drop the console handler added by an earlier run of this cell; without this,
# every re-run adds another handler and each message is printed once more.
for handler in list(logger.handlers):
    if handler.get_name() == 'console':
        logger.removeHandler(handler)

# Mirror log records to standard output with the same format as the log file.
console = logging.StreamHandler()
console.set_name('console')
console.setLevel(logging.INFO)
console.setFormatter(logging.Formatter(LOG_FORMAT, datefmt=LOG_DATEFMT))
logger.addHandler(console)

## 2. Load model weights from checkpoint

In [None]:
# Load the checkpoint onto the CPU first. It may have been saved from a GPU
# process, and loading it there would hold a second copy of the weights in GPU
# memory. The model is moved to the GPU explicitly below.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
checkpoint = torch.load(MODEL_PATH, map_location='cpu')
model = build_model(config, pretrain=True)

# This checkpoint stores its weights under 'module'; fall back to 'model' for
# checkpoints that use that key instead.
state_dict_key = 'module' if 'module' in checkpoint else 'model'
model.load_state_dict(checkpoint[state_dict_key])

n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"number of params: {n_parameters}")
model.cuda()
# Trailing ';' suppresses the (very long) model repr in the cell output.
model.eval();

## 3. Load evaluation set (from numpy file)

In [None]:
# Seed the global RNGs so the randomly generated masks are reproducible across
# runs. This also makes the reconstructions, losses and PDF reproducible.
# NOTE(review): the mask generator's RNG source is not visible here, so the
# python, numpy and torch RNGs are all seeded.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the Masked-Image-Modeling transform
transform = SimmimTransform(config)

# The reconstruction evaluation set is a single numpy file
validation_dataset_path = config.DATA.DATA_PATHS[0]
validation_dataset = np.load(validation_dataset_path)

# Apply the transform to each chip (iterating over the first axis). The
# transform auto-generates a mask and returns an (image, mask) pair.
img_masks = [transform(chip) for chip in validation_dataset]

# Separate images and masks; masks come back as numpy arrays, so cast them to tensors
img = torch.stack([img_mask[0] for img_mask in img_masks])
mask = torch.stack([torch.from_numpy(img_mask[1]) for img_mask in img_masks])

assert img.shape[0] == mask.shape[0] == validation_dataset.shape[0]

## 4. Prediction helper functions

In [None]:
def predict(model, dataloader, num_batches=5):
    """Run masked reconstruction over at most ``num_batches`` batches.

    Args:
        model: model exposing ``encoder(img, mask)``, ``decoder(z)`` and a
            forward call ``model(img, mask)`` whose result is recorded as the loss.
        dataloader: iterable whose items are indexed with ``[0]`` to obtain a
            sequence of ``(image, mask)`` pairs (presumably the webdataset
            batch layout; TODO confirm).
        num_batches: maximum number of batches to process.

    Returns:
        ``(inputs, outputs, masks, losses)``. The first three hold per-sample
        CPU tensors (input image, reconstruction, mask). ``losses`` holds one
        CPU loss per batch.

    Note:
        Reads the notebook-level ``config.ENABLE_AMP`` and requires CUDA.
    """
    inputs = []
    outputs = []
    masks = []
    losses = []
    with tqdm(total=num_batches) as pbar:

        for idx, img_mask in enumerate(dataloader):

            # Checked before processing so exactly ``num_batches`` batches run
            # (indices 0 .. num_batches - 1).
            if idx >= num_batches:
                break

            img_mask = img_mask[0]

            img = torch.stack([pair[0] for pair in img_mask])
            mask = torch.stack([pair[1] for pair in img_mask])

            img = img.cuda(non_blocking=True)
            mask = mask.cuda(non_blocking=True)

            with torch.no_grad():
                with amp.autocast(enabled=config.ENABLE_AMP):
                    z = model.encoder(img, mask)
                    img_recon = model.decoder(z)
                    loss = model(img, mask)

            # ``extend`` iterates the batch dimension: one entry per sample.
            inputs.extend(img.cpu())
            masks.extend(mask.cpu())
            outputs.extend(img_recon.cpu())
            losses.append(loss.cpu())

            # Advance the bar only once a batch has actually been processed.
            pbar.update(1)

    return inputs, outputs, masks, losses


def minmax_norm(img_arr):
    """Linearly rescale an array to the 0-255 range and cast to ``uint8``.

    The array minimum maps to 0 and the maximum to 255. Intermediate values
    are truncated (not rounded) by the ``uint8`` cast.

    Args:
        img_arr: numpy array of any shape and numeric dtype (must be non-empty).

    Returns:
        ``uint8`` array with the same shape as ``img_arr``. A constant array
        (max == min) has no range to stretch and maps to all zeros instead of
        dividing by zero.
    """
    arr_min = img_arr.min()
    arr_range = img_arr.max() - arr_min
    if arr_range == 0:
        return np.zeros(img_arr.shape, dtype=np.uint8)
    img_arr_scaled = (img_arr - arr_min) / arr_range * 255
    return img_arr_scaled.astype(np.uint8)


def process_mask(mask, scale=4):
    """Upsample a 2-D mask to pixel resolution and repeat it over 3 channels.

    Args:
        mask: 2-D mask (torch tensor or numpy array), one value per mask cell.
        scale: pixels covered by one mask cell along each axis. The default 4
            matches the previously hard-coded factor; presumably it is the
            image size divided by the mask-grid size (TODO confirm).

    Returns:
        numpy array of shape ``(H * scale, W * scale, 3)`` with the mask's dtype.
        Each cell is expanded to a ``scale x scale`` block, identical in all
        three channels.
    """
    mask_t = torch.as_tensor(mask)
    upsampled = mask_t.repeat_interleave(scale, 0).repeat_interleave(scale, 1)
    # .cpu() makes this work for GPU-resident masks as well.
    mask_np = upsampled.cpu().numpy()
    return np.stack([mask_np, mask_np, mask_np], axis=-1)


def process_prediction(image, img_recon, mask, rgb_index):
    """Turn one (input, reconstruction, mask) triple into displayable RGB arrays.

    Args:
        image: input image tensor indexed as ``image[band, row, col]``.
        img_recon: reconstructed image tensor with the same layout.
        mask: low-resolution mask, upsampled to pixels via ``process_mask``.
        rgb_index: band indices. ``rgb_index[k]`` fills channel ``k`` of the
            channel-last output, which imshow displays as R, G, B.

    Returns:
        ``(rgb_image, rgb_image_masked, rgb_recon_masked, mask)``:
        - the min-max scaled input;
        - the input with masked pixels set to 0;
        - the input with masked pixels replaced by the reconstruction;
        - the upsampled 3-channel mask.
    """
    pixel_mask = process_mask(mask)

    # Channels are filled positionally: rgb_index[0] -> R, [1] -> G, [2] -> B.
    # NOTE(review): confirm this matches the intended band-to-colour mapping.
    channel_bands = (rgb_index[0], rgb_index[1], rgb_index[2])

    def _to_rgb(tensor):
        # Pick the three bands, stack them on the last axis, scale to uint8.
        bands = tensor.numpy()
        composite = np.stack([bands[band, :, :] for band in channel_bands], axis=-1)
        return minmax_norm(composite)

    rgb_image = _to_rgb(image)
    rgb_image_recon = _to_rgb(img_recon)

    # Visible pixels (mask == 0) come from the input, all others from the reconstruction.
    rgb_recon_masked = np.where(pixel_mask == 0, rgb_image, rgb_image_recon)
    # Masked pixels (mask == 1) of the input are blacked out.
    rgb_image_masked = np.where(pixel_mask == 1, 0, rgb_image)

    return rgb_image, rgb_image_masked, rgb_recon_masked, pixel_mask


def plot_export_pdf(path, inputs, outputs, masks, rgb_index):
    """Write one 2x2 comparison page per sample to a multi-page PDF.

    Each page shows the model reconstruction (top left), the masked input
    (top right), the original input (bottom left) and the mask (bottom right).

    Args:
        path: output PDF path.
        inputs: per-sample input image tensors.
        outputs: per-sample reconstruction tensors, aligned with ``inputs``.
        masks: per-sample mask tensors, aligned with ``inputs``.
        rgb_index: three band indices used for the displayed R, G, B channels.

    Raises:
        ValueError: if ``inputs``, ``outputs`` and ``masks`` differ in length.
    """
    if not (len(inputs) == len(outputs) == len(masks)):
        raise ValueError(
            f"inputs, outputs and masks must have equal lengths, got "
            f"{len(inputs)}, {len(outputs)} and {len(masks)}"
        )

    # The context manager guarantees the PDF is finalised even if plotting fails.
    with PdfPages(path) as pdf_plot_obj:
        for idx, (image, img_recon, mask) in enumerate(zip(inputs, outputs, masks)):
            # prediction processing
            rgb_image, rgb_image_masked, rgb_recon_masked, mask = \
                process_prediction(image, img_recon, mask, rgb_index)

            # matplotlib code
            fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2, figsize=(40, 30))
            ax2.imshow(rgb_image)
            ax2.set_title(f"Idx: {idx} MOD021KM v6.1 Bands: {rgb_index}")

            ax0.imshow(rgb_recon_masked)
            ax0.set_title(f"Idx: {idx} Model reconstruction")

            ax1.imshow(rgb_image_masked)
            ax1.set_title(f"Idx: {idx} MOD021KM Bands: {rgb_index}, masked")

            ax3.matshow(mask[:, :, 0])
            ax3.set_title(f"Idx: {idx} Reconstruction Mask")

            # Save this specific figure, then close it. Otherwise every large
            # figure stays registered with pyplot (and is rendered inline)
            # for the rest of the session.
            pdf_plot_obj.savefig(fig)
            plt.close(fig)

## 5. Predict

In [None]:
inputs, outputs, masks, losses = [], [], [], []

# One image per forward pass (instead of a single batch) so that a loss is
# recorded for every image rather than once per batch.
for i in tqdm(range(img.shape[0])):
    # Give each sample a leading batch dimension of 1 and move it to the GPU.
    single_img = img[i].unsqueeze(0).cuda(non_blocking=True)
    single_mask = mask[i].unsqueeze(0).cuda(non_blocking=True)

    with torch.no_grad():
        latent = model.encoder(single_img, single_mask)
        recon = model.decoder(latent)
        sample_loss = model(single_img, single_mask)

    # ``extend`` iterates the batch dimension, so each list gains one tensor.
    inputs.extend(single_img.cpu())
    masks.extend(single_mask.cpu())
    outputs.extend(recon.cpu())
    losses.append(sample_loss.cpu())

## 6. Plot and write to PDF

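Write every prediction to a multi-page PDF, one page per image. Each page shows the reconstruction, the masked input, the original input, and the mask.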

In [None]:
# Band indices for the displayed [R, G, B] channels, in that positional order.
# NOTE(review): these were described as [Red band, Blue band, Green band], which
# puts band 2 in the green display channel; confirm this is intended.
rgb_index = [0, 2, 1]
pdf_path = '../../satvision-toa-reconstruction-pdf-huge-patch-8-04.30.pdf'

plot_export_pdf(pdf_path, inputs, outputs, masks, rgb_index=rgb_index)