# Satvision-TOA Reconstruction Notebook

Version: 03.15.24

Env: `Python [conda env:ilab-pytorch]`

In [None]:
!pip install yacs timm segmentation-models-pytorch termcolor webdataset==0.2.86

In [None]:
import os
import sys
import time
import random
import datetime
import numpy as np
import logging

import torch
import torch.cuda.amp as amp

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import warnings
from tqdm import tqdm

warnings.filterwarnings('ignore') 

In [None]:
sys.path.append('../../pytorch-caney')

from pytorch_caney.config import get_config

from pytorch_caney.models.build import build_model

from pytorch_caney.ptc_logging import create_logger

from pytorch_caney.data.datasets.mim_modis_22m_dataset import MODIS22MDataset

from pytorch_caney.data.transforms import SimmimTransform, SimmimMaskGenerator

from pytorch_caney.config import _C, _update_config_from_file

## Configuration

### Clone model ckpt from huggingface

```bash
# On prism/explore
module load git-lfs

git lfs install

git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-huge
```

Note: If you clone over SSH, make sure an SSH key is set up for your Hugging Face account so that SSH authentication succeeds.

If experiencing ssh-related authentication issues:
```bash
eval `ssh-agent -s` # starts ssh-agent

ssh-add -l # is your ssh key added to the agent?

ssh-add ~/.ssh/id_xxxx # adds ssh ID to ssh-agent

ssh -T git@hf.co # Should return "Hi <user-id>, welcome to Hugging Face."
```

In [None]:
# Directory holding the cloned checkpoint together with the YAML config it was trained with.
_MODEL_DIR: str = '/explore/nobackup/people/cssprad1/projects/satvision-toa/models/huge-16-224'

# Checkpoint weights and the matching model/training configuration.
MODEL_PATH: str = f'{_MODEL_DIR}/mp_rank_00_model_states.pt'
CONFIG_PATH: str = f'{_MODEL_DIR}/mim_pretrain_swinv2_satvision_huge_224_patch16_100ep_lr3e4.yaml'

# Run settings that are written into the config in the next cell.
BATCH_SIZE: int = 64 # Want to report loss on every image? Change to 1.
OUTPUT: str = '.'
TAG: str = 'satvision-base-toa-reconstruction'

# Dataset location; the config expects a list of paths, so wrap the single path.
DATA_PATH: str = '/explore/nobackup/projects/ilab/projects/3DClouds/data/mosaic-v3/webdatasets'
DATA_PATHS: list = [DATA_PATH]

In [None]:
# Update config given configurations

# Work on a clone so the imported default node `_C` is left untouched, then apply
# the file at CONFIG_PATH on top of it.
# NOTE(review): presumably `_update_config_from_file` merges the YAML into `config`
# in place — confirm against pytorch_caney.config.
config = _C.clone()
_update_config_from_file(config, CONFIG_PATH)

# Override the notebook-specific values. The assignments are bracketed by
# defrost()/freeze() so the config is immutable again once they are applied.
config.defrost()
config.MODEL.RESUME = MODEL_PATH
config.DATA.DATA_PATHS = DATA_PATHS
config.DATA.BATCH_SIZE = BATCH_SIZE
config.OUTPUT = OUTPUT
config.TAG = TAG
config.freeze()

In [None]:
# Configure logging: INFO and above go to 'app.log' in the working directory.
# basicConfig() does nothing if the root logger already has handlers, so
# re-running this cell will not add a second file handler.
logging.basicConfig(
    filename='app.log',  # Specify the log file name
    level=logging.INFO,  # Set logging level to INFO
    format='%(asctime)s [%(levelname)s] %(message)s',  # Specify log message format
    datefmt='%Y-%m-%d %H:%M:%S'  # Specify date format
)

# Mirror log records to the console. StreamHandler() with no argument writes
# to sys.stderr.
logger = logging.getLogger('')

# The handler is tagged with a name so that re-running this cell reuses it
# instead of stacking another handler (which would print every message twice).
_CONSOLE_HANDLER_NAME = 'satvision-console'
console = next(
    (h for h in logger.handlers if h.get_name() == _CONSOLE_HANDLER_NAME),
    None,
)
if console is None:
    console = logging.StreamHandler()
    console.set_name(_CONSOLE_HANDLER_NAME)
    logger.addHandler(console)

console.setLevel(logging.INFO)  # Set logging level for the console handler
console.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(message)s'))  # Set log message format for the console handler

In [None]:
# Deserialize the checkpoint onto the CPU. Without map_location, torch.load
# restores tensors to the device they were saved from, which can place a second
# full copy of the weights on the GPU next to the model itself.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
checkpoint = torch.load(MODEL_PATH, map_location='cpu')

# Build the pretraining (masked-image-modeling) model and restore its weights,
# which are stored under the 'module' key of the checkpoint.
model = build_model(config, pretrain=True)
model.load_state_dict(checkpoint['module'])

n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"number of params: {n_parameters}")

# Move to the GPU and switch to inference mode. eval() is the last expression,
# so the notebook also displays the model.
model.cuda()
model.eval()

## Dataloader

In [None]:
# Directory that the cells below search for '*.npy' batch files.
# NOTE(review): despite the name, the visible cells only read from this directory.
outputDir = '/explore/nobackup/people/cssprad1/projects/satvision-toa/data/sampleModisToa2m'

In [None]:
import glob

In [None]:
def loadBatchApply(idx, outputDir, regex, transform, model):
    """Load one saved batch file, transform it, and run the model on it.

    Args:
        idx: Position of the batch file within the glob results.
        outputDir: Directory that is searched for batch files.
        regex: Glob pattern (not a regular expression) joined onto outputDir.
        transform: Callable applied to each sample; its result is indexed as
            [0] for the image tensor and [1] for a NumPy mask.
        model: Model exposing `encoder`, `decoder`, and a forward call that
            returns the loss.

    Returns:
        Tuple (inputs, outputs, masks, losses). The first three are lists of
        per-sample CPU tensors; losses is a one-element list with the batch loss.

    NOTE(review): glob.glob() returns paths in arbitrary order, so `idx` only
    identifies the same file while the directory listing order is stable.
    """
    batches = glob.glob(os.path.join(outputDir, regex))
    print(batches)
    print(f'Found {len(batches)}')

    batchPath = batches[idx]
    print(batchPath)

    # Each transformed sample is an (image tensor, NumPy mask) pair.
    transformed = [transform(sample) for sample in np.load(batchPath)]

    img = torch.stack([pair[0] for pair in transformed]).cuda(non_blocking=True)
    mask = torch.stack(
        [torch.from_numpy(pair[1]) for pair in transformed]
    ).cuda(non_blocking=True)

    # Inference only: no gradients, mixed precision as configured globally.
    with torch.no_grad(), amp.autocast(enabled=config.ENABLE_AMP):
        # Encoder and decoder are called explicitly to get the reconstruction;
        # the forward call is made separately because it yields the loss.
        img_recon = model.decoder(model.encoder(img, mask))
        loss = model(img, mask)

    # Split every batch tensor along dim 0 into a list of per-sample tensors.
    inputs = list(img.cpu())
    masks = list(mask.cpu())
    outputs = list(img_recon.cpu())
    losses = [loss.cpu()]

    return inputs, outputs, masks, losses
        

## Prediction helper functions

In [None]:
def predict(model, dataloader, num_batches=5):
    """Run reconstruction on the first `num_batches` batches of a dataloader.

    Args:
        model: Model exposing `encoder`, `decoder`, and a forward call that
            returns the loss.
        dataloader: Iterable whose items are indexed at [0] to obtain a
            sequence of (image tensor, mask tensor) pairs.
        num_batches: Maximum number of batches to process. Exactly this many
            are consumed (fewer if the dataloader runs out first).

    Returns:
        Tuple (inputs, outputs, masks, losses): lists of per-sample CPU
        tensors for the first three, and one loss tensor per batch.
    """
    inputs = []
    outputs = []
    masks = []
    losses = []

    with tqdm(total=num_batches) as pbar:
        # Pairing with range() stops after exactly num_batches items, without
        # pulling an extra batch from the dataloader. The previous check
        # `idx > num_batches` processed num_batches + 1 batches.
        for _, img_mask in zip(range(num_batches), dataloader):

            img_mask = img_mask[0]

            img = torch.stack([pair[0] for pair in img_mask])
            mask = torch.stack([pair[1] for pair in img_mask])

            img = img.cuda(non_blocking=True)
            mask = mask.cuda(non_blocking=True)

            # Inference only: no gradients, mixed precision as configured.
            with torch.no_grad():
                with amp.autocast(enabled=config.ENABLE_AMP):
                    # Explicit encoder/decoder calls give the reconstruction;
                    # the forward call gives the loss.
                    z = model.encoder(img, mask)
                    img_recon = model.decoder(z)
                    loss = model(img, mask)

            # extend() splits each batch tensor into per-sample tensors.
            inputs.extend(img.cpu())
            masks.extend(mask.cpu())
            outputs.extend(img_recon.cpu())
            losses.append(loss.cpu())

            # Advance the bar only once the batch has actually been processed.
            pbar.update(1)

    return inputs, outputs, masks, losses


def minmax_norm(img_arr):
    """Linearly rescale an array to the 0-255 range and return it as uint8.

    Args:
        img_arr: NumPy array; its global minimum maps to 0 and its global
            maximum to 255.

    Returns:
        uint8 array with the same shape as img_arr. A constant-valued input
        (max == min) yields all zeros.
    """
    arr_min = img_arr.min()
    arr_max = img_arr.max()
    value_range = arr_max - arr_min

    # A flat image has zero range: dividing would give NaN, and casting NaN
    # to uint8 is undefined. Return a black image instead.
    if value_range == 0:
        return np.zeros(img_arr.shape, dtype=np.uint8)

    img_arr_scaled = (img_arr - arr_min) / value_range
    img_arr_scaled = img_arr_scaled * 255
    # astype truncates toward zero (e.g. 127.5 -> 127).
    return img_arr_scaled.astype(np.uint8)


def process_mask(mask):
    """Upsample a 2-D patch mask by a factor of 4 and replicate it to 3 channels.

    Args:
        mask: 2-D torch tensor of shape (H, W).

    Returns:
        NumPy array of shape (4*H, 4*W, 3) in which every mask entry covers a
        4x4 block and the three channels are identical.
    """
    # Repeat each entry 4 times along rows, then along columns.
    upscaled = mask.repeat_interleave(4, 0).repeat_interleave(4, 1)
    # Copy the plane into three channels so it lines up with the image arrays.
    return np.stack([upscaled] * 3, axis=-1)


def process_prediction(image, img_recon, mask, rgb_index):
    """Build the display arrays for one sample.

    Args:
        image: Input image tensor, indexed as [channel, row, column].
        img_recon: Reconstructed image tensor with the same layout.
        mask: 2-D patch mask tensor; it is expanded by process_mask().
        rgb_index: Sequence of three channel indices. The channels are stacked
            in the given order to form the three display channels.

    Returns:
        Tuple (rgb_image, rgb_image_masked, rgb_recon_masked, mask):
        the normalized input composite, the same composite with masked pixels
        set to 0, the composite with masked pixels replaced by the
        reconstruction, and the expanded 3-channel mask.
    """
    mask = process_mask(mask)

    first_idx, second_idx, third_idx = rgb_index[0], rgb_index[1], rgb_index[2]

    def _composite(tensor):
        # Pick the three requested channels, put channels last, scale to uint8.
        bands = tensor.numpy()
        stacked = np.stack(
            (bands[first_idx, :, :], bands[second_idx, :, :], bands[third_idx, :, :]),
            axis=-1,
        )
        return minmax_norm(stacked)

    rgb_image = _composite(image)
    rgb_image_recon = _composite(img_recon)

    # mask == 1 marks the pixels hidden from the model.
    rgb_recon_masked = np.where(mask == 0, rgb_image, rgb_image_recon)
    rgb_image_masked = np.where(mask == 1, 0, rgb_image)

    return rgb_image, rgb_image_masked, rgb_recon_masked, mask


def plot_export_pdf(path, num_sample, inputs, outputs, masks, rgb_index,
                    close_figures=False):
    """Write a randomly chosen subset of predictions to a multi-page PDF.

    Each page is a 2x2 grid: reconstruction and masked input on the top row,
    original input and mask on the bottom row.

    NOTE(review): a later cell redefines plot_export_pdf without `num_sample`,
    which shadows this version once that cell has run.

    Args:
        path: Destination PDF path.
        num_sample: Number of samples to draw without replacement. Values
            larger than len(inputs) are clamped to len(inputs).
        inputs: Sequence of input image tensors.
        outputs: Sequence of reconstructed image tensors, aligned with inputs.
        masks: Sequence of patch masks, aligned with inputs.
        rgb_index: Three channel indices forwarded to process_prediction().
        close_figures: When True, close each figure after saving it so memory
            does not grow with the number of pages. The default False keeps
            the figures open, as before.
    """
    # Clamp so random.sample() does not raise when fewer samples are available.
    sample_size = min(num_sample, len(inputs))
    random_subsample = random.sample(range(len(inputs)), sample_size)

    # The context manager finalizes the PDF even if a page fails to render.
    with PdfPages(path) as pdf_plot_obj:
        for idx in random_subsample:
            # prediction processing
            rgb_image, rgb_image_masked, rgb_recon_masked, mask = \
                process_prediction(inputs[idx], outputs[idx], masks[idx], rgb_index)

            # matplotlib code
            fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2, figsize=(40, 30))

            ax2.imshow(rgb_image)
            ax2.set_title(f"Idx: {idx} MOD021KM v6.1 Bands: {rgb_index}")

            ax0.imshow(rgb_recon_masked)
            ax0.set_title(f"Idx: {idx} Model reconstruction")

            ax1.imshow(rgb_image_masked)
            ax1.set_title(f"Idx: {idx} MOD021KM Bands: {rgb_index}, masked")

            ax3.matshow(mask[:, :, 0])
            ax3.set_title(f"Idx: {idx} Reconstruction Mask")

            # Save this figure explicitly rather than the implicit current one.
            pdf_plot_obj.savefig(fig)
            if close_figures:
                plt.close(fig)

## Predict

## Plot and write to PDF

In [None]:
# Channel indices handed to process_prediction(), which stacks them in this
# order as the three display channels.
# NOTE(review): the band-to-color mapping cannot be verified from this notebook — confirm against the dataset's band order.
rgb_index = [0, 2, 1] # Indices of [Red band, Blue band, Green band]

In [None]:
# Hand-picked chips: batch-file index -> indices of the chips to keep from that batch.
# 128 chips in total.
#
# NOTE(review): the original literal listed key 17 twice. Python keeps the LAST
# value for a repeated key (at the position of its first occurrence), so the
# first list was silently discarded. The mapping below spells out that
# effective result. The discarded list was:
#   [47, 59, 11, 54, 42, 25, 61, 45, 7, 20, 3, 49, 22, 51, 19, 46]
# If it was meant for a different batch, add it back under the correct key.
selectedBatches = {
    0: [58, 24, 17, 2, 48, 29, 41, 18, 55, 23, 43, 61, 40, 7, 63, 16],
    1: [47, 24, 50, 30, 2, 44, 21, 36, 13, 40, 46, 53, 62, 27, 8, 63, 18],
    10: [41, 18, 19, 13, 6, 34, 10, 39, 0, 32, 52, 23, 55, 58, 5, 59],
    17: [52, 60, 20, 61, 14, 22, 26, 15, 25, 43, 5, 49, 19, 47, 46, 3, 27, 29],
    4: [11, 51, 29, 18, 31, 13, 53, 34, 39, 52, 20, 2, 6, 26, 21, 10, 60, 3, 30],
    12: [41, 25, 14, 62, 10, 46, 26, 52, 4, 54, 23, 49, 16, 33, 30, 34, 53, 59, 63],
    9: [5, 42, 21, 53, 60, 58, 24, 55, 50, 47, 43, 3, 37, 2, 33, 18, 29, 48, 62, 45, 38, 8, 22],
}

In [None]:
# Collect the hand-picked chips from each selected batch file.
# NOTE(review): glob.glob() returns paths in arbitrary order, so the integer
# keys of selectedBatches only point at the same files while the directory
# listing order stays the same — confirm before moving or re-creating the data.
batches = glob.glob(os.path.join(outputDir, '*.npy'))
print(f'Found {len(batches)}')
allImages = []

for batchIdx, chipIndices in selectedBatches.items():
    batch = np.load(batches[batchIdx])
    # Keep only the listed chips of this batch, in the listed order.
    selectedImgs = np.stack([batch[chipIdx] for chipIdx in chipIndices])
    allImages.append(selectedImgs)
    print(selectedImgs.shape)

# Total number of chips gathered across all selected batches.
total = sum(chips.shape[0] for chips in allImages)
print(total)

In [None]:
# Merge the per-batch selections into a single array along the first axis.
allBatches = np.concatenate(allImages)
allBatches.shape  # last expression, so the notebook displays it as the cell output

In [None]:
transform = SimmimTransform(config)

inputs, outputs, masks, losses = [], [], [], []

# One forward pass per chip (batch of one), so every entry in `losses`
# belongs to exactly one image.
for idx, image in enumerate(allBatches):

    # The transform result is indexed as [0] image tensor and [1] NumPy mask.
    imgMasks = [transform(image)]

    img = torch.stack([pair[0] for pair in imgMasks]).cuda(non_blocking=True)
    mask = torch.stack(
        [torch.from_numpy(pair[1]) for pair in imgMasks]
    ).cuda(non_blocking=True)

    # Inference only: no gradients, mixed precision as configured.
    with torch.no_grad(), amp.autocast(enabled=config.ENABLE_AMP):
        # Explicit encoder/decoder calls give the reconstruction;
        # the forward call gives the loss.
        z = model.encoder(img, mask)
        img_recon = model.decoder(z)
        loss = model(img, mask)

    # extend() drops the leading batch dimension, storing per-sample tensors.
    inputs.extend(img.cpu())
    masks.extend(mask.cpu())
    outputs.extend(img_recon.cpu())
    losses.append(loss.cpu())

In [None]:
def plot_export_pdf(path, inputs, outputs, masks, rgb_index, close_figures=False):
    """Write every prediction to a multi-page PDF, one page per sample.

    Each page is a 2x2 grid: reconstruction and masked input on the top row,
    original input and mask on the bottom row.

    Args:
        path: Destination PDF path.
        inputs: Sequence of input image tensors.
        outputs: Sequence of reconstructed image tensors, aligned with inputs.
        masks: Sequence of patch masks, aligned with inputs.
        rgb_index: Three channel indices forwarded to process_prediction().
        close_figures: When True, close each figure after saving it so memory
            does not grow with the number of pages. The default False keeps
            the figures open, as before.
    """
    # The context manager finalizes the PDF even if a page fails to render.
    with PdfPages(path) as pdf_plot_obj:
        for idx, image in enumerate(inputs):
            # prediction processing
            rgb_image, rgb_image_masked, rgb_recon_masked, mask = \
                process_prediction(image, outputs[idx], masks[idx], rgb_index)

            # matplotlib code
            fig, ((ax0, ax1), (ax2, ax3)) = plt.subplots(2, 2, figsize=(40, 30))

            ax2.imshow(rgb_image)
            ax2.set_title(f"Idx: {idx} MOD021KM v6.1 Bands: {rgb_index}")

            ax0.imshow(rgb_recon_masked)
            ax0.set_title(f"Idx: {idx} Model reconstruction")

            ax1.imshow(rgb_image_masked)
            ax1.set_title(f"Idx: {idx} MOD021KM Bands: {rgb_index}, masked")

            ax3.matshow(mask[:, :, 0])
            ax3.set_title(f"Idx: {idx} Reconstruction Mask")

            # Save this figure explicitly rather than the implicit current one.
            pdf_plot_obj.savefig(fig)
            if close_figures:
                plt.close(fig)

In [None]:
# Export one PDF page per selected chip, using the plot_export_pdf from the
# preceding cell (the version without random sub-sampling).
plot_export_pdf('satvision_toa_2m_224_patch16_128testchips_predictions.pdf', inputs, outputs, masks, rgb_index)