#  Prithvi Satellite Feature Extraction (Optional)

This notebook processes 6-channel satellite imagery tiles through the Prithvi-100M **pre-trained** visual transformer and extracts features from the specified layer(s). This step is optional and can be performed to combine satellite features with geospatial data when executing `05_resnet18_fine_tuning.ipynb`


## File System Structure

## Input

The input satellite tiles (for each DHS location) are located in `Satellite_Tiles` within the hierarchy below. 
<pre style="font-family: monospace;">
./GIS-Image-Stack-Processing
    /AOI/
        PK/
            Image_Tiles/
                    :
            Satellite_Tiles/
                PK_1_C-1_30m.tif
                PK_2_C-2_30m.tif
                    :
                PK_560_C-580_30m.tif
                
            Satellite_Features/
                PK_sat_features_prithvi_L6_L8.npz (generated using 04_prithvi_sat_feature_extraction.ipynb)
                PK_sat_features_resent_layer4.npz (generated using 04_resnet_sat_feature_extraction.ipynb)
</pre>


## Required Configurations

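The following configuration is required for each execution of this notebook: the two-letter country code. The transformer layers to extract features from (`layer_indices`) are set in the same cell, and the remaining model and feature extraction settings are available in the Model and Data Configuration section.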
<pre style="font-family: monospace;">
<span style="color: blue;">country_code  = 'PK'</span>      # Set the country code to one of the available AOIs in the list below

Available AOIs: AM (Armenia)
                MA (Morocco)
                MB (Moldova)
                ML (Mali)
                MR (Mauritania)
                NI (Niger)
                PK (Pakistan)
                SN (Senegal)
                TD (Chad)
                
</pre>


In [1]:
# =========================================================================
# REQUIRED CONFIGURATIONS -- edit the two values below before each run
# =========================================================================
# Two-letter AOI country code.
country_code = 'ML'

# Layers (indices 0 through 11) to extract features from. Listing more than
# one index, e.g. [6, 8], concatenates the features from all listed layers.
layer_indices = [6, 8]
# =========================================================================

# Layer tag embedded in the output file name, e.g. "L6_L8".
layer_string = "_".join(map("L{}".format, layer_indices))

# Name of the .npz archive that will hold the extracted features.
sat_feature_file = '{}_sat_features_prithvi_{}.npz'.format(country_code, layer_string)

In [2]:
# Standard library
import json
import os
import platform
import random
import re
import shutil
import subprocess
import sys
import warnings
from collections import Counter
from dataclasses import dataclass
from enum import Enum
from functools import partial

# Third-party
import numpy as np
import rasterio
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from omegaconf import OmegaConf
from torch.utils.data import Dataset, DataLoader

##  Environment Configuration

In [3]:
# NOTE(review): the star imports below are kept because later cells may rely on
# names they provide; prefer explicit imports once the required names are known.
from prithvi_pytorch.encoder import *
from prithvi_pytorch import PrithviEncoder

# Make the project package importable. Guarded so that re-running this cell
# does not keep appending the same entry to sys.path.
PROJECT_ROOT = './GIS-Image-Stack-Processing'
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Drop stale byte-code so that edits to project_utils are picked up on import.
# NOTE(review): this path is relative to the current working directory, not to
# PROJECT_ROOT -- confirm that this is the intended location.
cache_dir = 'project_utils/__pycache__'
if os.path.exists(cache_dir):
    shutil.rmtree(cache_dir)

from project_utils.plot_utils import display_rgb_images, display_ir_band
from project_utils.aoi_configurations import *
from project_utils.satellite_dataset_utils import HLSFlexibleBandSatDataset, custom_transforms

#----------------------------------------------------------------------------------------
# *** IMPORTANT: SYSTEM PATH TO SET ***
#----------------------------------------------------------------------------------------
# The following path is required, as it contains GDAL binaries used for several 
# pre-processing functions. The pathname corresponds to the Conda virtual environment 
# created for this project (e.g., "py39-pt").
#
# Note: GDAL was adopted as a benchmark to compare the original GIS data produced by 
# another team. However, similar functionality could be implemented using the Rasterio 
# Python package. If Rasterio is used, it would eliminate the need for GDAL binaries 
# and this system path specification.
#----------------------------------------------------------------------------------------

# Adding path to gdal commands for local system. Set the GDAL_BIN_DIR environment
# variable to override the default location without editing the notebook.
GDAL_BIN_DIR = os.environ.get('GDAL_BIN_DIR', '/Users/billk/miniforge3/envs/py39-pt/bin/')

# Append only when missing so repeated runs do not grow PATH, and tolerate an
# environment in which PATH is not defined at all.
current_path = os.environ.get('PATH', '')
if GDAL_BIN_DIR not in current_path.split(os.pathsep):
    os.environ['PATH'] = f'{current_path}{os.pathsep}{GDAL_BIN_DIR}' if current_path else GDAL_BIN_DIR



In [4]:
# Default number of DataLoader worker processes (used on non-Linux systems).
num_workers = 0

# Detect the OS name. platform.system() reports "Linux" on Linux without
# spawning a shell, and also works where the `uname` command is unavailable.
os_name = platform.system()


def _prepend_path_entry(variable, entry):
    """Put `entry` at the front of the path-list variable `variable`, only if it is not already first."""
    current = os.environ.get(variable)
    if not current:
        os.environ[variable] = entry
    elif current.split(':')[0] != entry:
        os.environ[variable] = entry + ':' + current


# Check if the OS is Linux
if os_name == "Linux":

    print("Running on Linux. Setting num_workers to 64.")
    num_workers = 64

    print("Setting OS environment paths...")

    conda_prefix = os.getenv('CONDA_PREFIX')
    if conda_prefix:
        # Set CUDA_HOME to the conda environment prefix
        os.environ['CUDA_HOME'] = conda_prefix

        # Make the CUDA bin and lib64 directories take precedence. The helper is
        # idempotent, so re-running the cell does not duplicate the entries.
        _prepend_path_entry('PATH', os.path.join(conda_prefix, 'bin'))
        _prepend_path_entry('LD_LIBRARY_PATH', os.path.join(conda_prefix, 'lib64'))
    else:
        # os.environ only accepts strings; assigning a missing CONDA_PREFIX
        # (None) would raise TypeError, so the CUDA path setup is skipped.
        print("CONDA_PREFIX is not set; skipping CUDA_HOME, PATH and LD_LIBRARY_PATH setup.")

    # Set the environment variable for PyTorch CUDA memory allocation
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

## Confirm Number of Bands in HLS Satellite Data

In [5]:
#run_gdalinfo(f"./GIS-Image-Stack-Processing/AOI/{country_code}/Satellite_Tiles/{country_code}_1_C-1_30m.tif")

## System Configuration

In [6]:
def system_config(SEED_VALUE=42, package_list=None):
    """
    Configures the system environment for PyTorch-based operations.

    Seeds the Python, NumPy and PyTorch random number generators, selects the
    compute device, and applies the backend flags for that device.

    Args:
        SEED_VALUE (int): Seed value for random number generation.
        package_list (str | list[str] | None): Additional packages to install with
            pip when running on a CPU-only Google Colab or Kaggle runtime. A string
            is split on whitespace. Nothing is installed when this is None or empty.

    Returns:
        tuple: A tuple containing the device name as a string and a boolean indicating GPU availability.
    """

    random.seed(SEED_VALUE)
    np.random.seed(SEED_VALUE)
    torch.manual_seed(SEED_VALUE)

    def is_running_in_colab():
        return 'COLAB_GPU' in os.environ

    def is_running_in_kaggle():
        return 'KAGGLE_KERNEL_RUN_TYPE' in os.environ

    #--------------------------------
    # Check for availability of GPUs.
    #--------------------------------
    if torch.cuda.is_available():
        print('Using CUDA GPU')

        # Set the device to the first CUDA device.
        DEVICE = torch.device('cuda')
        print("Device: ", DEVICE)
        GPU_AVAILABLE = True

        torch.cuda.manual_seed(SEED_VALUE)
        torch.cuda.manual_seed_all(SEED_VALUE)

        # Performance and deterministic behavior.
        torch.backends.cudnn.enabled = True       # Provides highly optimized primitives for DL operations.
        # NOTE(review): cuDNN determinism is left disabled here, so GPU runs are
        # not guaranteed to be bit-reproducible even though the seeds are fixed.
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = False    # Setting to True can cause non-deterministic behavior.

    else:

        print('Using CPU')
        DEVICE = torch.device('cpu')
        print("Device: ", DEVICE)
        GPU_AVAILABLE = False

        if is_running_in_colab() or is_running_in_kaggle():
            if package_list:
                print('Installing required packages...')
                packages = package_list.split() if isinstance(package_list, str) else list(package_list)
                # Install into the interpreter that runs this kernel; an argument
                # list (no shell) avoids quoting problems in package specifiers.
                subprocess.run([sys.executable, '-m', 'pip', 'install', *packages], check=False)
            print('Note: Change runtime type to GPU for better performance.')

        torch.use_deterministic_algorithms(True)

    return str(DEVICE), GPU_AVAILABLE

In [7]:
# Seed the random number generators and select the compute device. DEVICE is
# the device name as a string; GPU_AVAILABLE is the matching boolean flag.
DEVICE, GPU_AVAILABLE = system_config()

# On a CUDA device, release cached allocator memory and print the GPU status
# (the `!` line is an IPython shell escape and only works inside a notebook).
if DEVICE == 'cuda':
    torch.cuda.empty_cache()
    !nvidia-smi

Using CPU
Device:  cpu


## Model and Data Configuration

In [8]:
@dataclass(frozen=True)
class DatasetConfig:
    """Immutable dataset settings: AOI country code, tile size, and data locations."""
    # Two-letter AOI country code; required (no default).
    COUNTRY_CODE:     str  
    # Tile height and width; combined into `img_size` for the transform and the model.
    IMG_HEIGHT:  int = 224
    IMG_WIDTH:   int = 224
    # NOTE(review): BATCH_SIZE and NUM_WORKERS below are not read by the visible
    # cells (the DataLoader takes its values from TrainingConfig) -- confirm use.
    BATCH_SIZE:  int = 32
    NUM_WORKERS: int = 0
    # Project root; the AOI targets.json path is built from it.
    GIS_ROOT:    str = './GIS-Image-Stack-Processing'
    # AOI directory; passed to the dataset as root_dir and used for the feature output folder.
    AOI_ROOT:    str = './GIS-Image-Stack-Processing/AOI/'
    # Directory holding the per-country partition JSON files.
    PRT_ROOT:    str = './GIS-Image-Stack-Processing/AOI/Partitions'
        
@dataclass(frozen=True)
class ModelConfig:   
    """Immutable model settings: Prithvi checkpoint and configuration locations."""
    # Checkpoint file passed to the model as ckpt_path.
    PRITHVI_WEIGHTS_PATH: str = "./prithvi_config/Prithvi_100M.pt"
    # YAML configuration passed to the model as cfg_path; also read for the
    # normalization mean and std.
    PRITHVI_CFG_PATH: str     = "./prithvi_config/Prithvi_100M_config.yaml"
    # When True, the notebook registers the feature hook on the model.
    EXTRACT_FEATURES: bool    = True
        
@dataclass(frozen=True)
class TrainingConfig:
    """Immutable DataLoader settings used when iterating over the AOI dataset."""
    # Number of tiles per batch handed to the DataLoader.
    BATCH_SIZE:      int   = 8
    # Default is captured from the module-level `num_workers` when this class is
    # defined, so the OS-detection cell must run before this one.
    NUM_WORKERS:     int   = num_workers  

In [9]:
# Build the three immutable configuration objects used by the rest of the notebook.
dataset_config = DatasetConfig(COUNTRY_CODE=country_code)
model_config, train_config = ModelConfig(), TrainingConfig()

# Location of the selected AOI's targets.json, relative to the GIS root.
targets_relative_path = f'AOI/{country_code}/Targets/targets.json'
aoi_target_json_path = os.path.join(dataset_config.GIS_ROOT, targets_relative_path)

## Load DHS Cluster Data and Target Values from AOI  `targets.json`

In [10]:
# Parse the AOI targets file into two objects, bound here as the DHS cluster
# table and the geospatial table.
# NOTE(review): process_aoi_target_json is presumably supplied by the star import
# from project_utils.aoi_configurations -- verify and import it explicitly.
dhs_df, geospatial_df = process_aoi_target_json(aoi_target_json_path, country_code)

   cluster_id     lat     lon  fraction_dpt3_vaccinated  \
0           1  14.530 -11.324                     0.778   
1           2  14.789 -11.927                     0.231   
2           3  14.577 -11.844                     0.100   
3           4  15.105 -11.819                     0.167   
4           5  14.735 -11.114                     0.182   

   fraction_with_electricity  fraction_with_fresh_water  mean_wealth_index  \
0                      0.600                       1.00              0.750   
1                      0.680                       0.96              0.700   
2                      0.714                       1.00              0.643   
3                      0.421                       1.00              0.671   
4                      0.750                       1.00              0.625   

   fraction_with_radio  fraction_with_tv country_code  
0                0.040             0.160           ML  
1                0.040             0.160           ML  
2       

In [11]:
class CustomPrithviModel(nn.Module):
    """Prithvi encoder wrapper with an optional task head.

    With no task head the forward pass returns the encoder features directly,
    which is the feature-extraction use case of this notebook.

    Args:
        cfg_path: Path to the Prithvi configuration file, forwarded to PrithviEncoder.
        ckpt_path: Path to the Prithvi checkpoint, forwarded to PrithviEncoder.
        num_classes: Output size of the classification head; only used when
            task_type is 'classification'.
        task_type: 'classification', 'regression', or None for no head.
        in_channels: Number of input bands, forwarded as in_chans.
        img_size: (height, width) of the input tiles. The default is read from
            the module-level dataset_config when the class is defined.
        freeze_encoder: When True, gradients are disabled for all encoder parameters.
    """

    def __init__(self, cfg_path, ckpt_path, 
                 num_classes=None, 
                 task_type=None, 
                 in_channels=6, 
                 img_size=(dataset_config.IMG_HEIGHT, dataset_config.IMG_WIDTH),
                 freeze_encoder=False):

        super().__init__()

        self.encoder = PrithviEncoder(
            cfg_path=cfg_path,
            ckpt_path=ckpt_path,
            num_frames=1,
            in_chans=in_channels,
            img_size=img_size
        )

        self.task_type = task_type
        self.num_classes = num_classes
        self.features = None
        # Handle of the currently registered forward hook, if any.
        self._feature_hook_handle = None

        # Freeze the encoder if requested
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

        # Initialize task-specific head if task_type is specified
        if task_type == 'classification' and num_classes is not None:
            self.head = nn.Linear(self.encoder.embed_dim, num_classes)
        elif task_type == 'regression':
            self.head = nn.Linear(self.encoder.embed_dim, 1)
        else:
            self.head = None  # No task-specific head initialized

    def forward(self, x):
        """Run the encoder and return the head output, or the features when there is no head."""
        # Pass input through encoder
        x = self.encoder(x)
        # Keep index 0 along dim 1 (presumably the class token -- TODO confirm).
        # This assignment replaces whatever a registered forward hook stored.
        self.features = x[:, 0]  # Save features for extraction

        # Explicit None check: module truthiness depends on the module type
        # (container modules define __len__), so it is not a reliable test.
        if self.head is not None:
            return self.head(self.features)

        return self.features  # Return features directly for extraction use case

    def register_feature_hook(self):
        """Register a forward hook on the encoder that stores its detached output.

        A previously registered hook is removed first, so calling this method
        repeatedly (for example by re-running a notebook cell) does not stack hooks.

        Returns:
            The handle returned by register_forward_hook; call .remove() on it
            to detach the hook.
        """
        # Register a hook to capture intermediate features
        def hook(module, input, output):
            self.features = output.detach()

        if self._feature_hook_handle is not None:
            self._feature_hook_handle.remove()

        self._feature_hook_handle = self.encoder.register_forward_hook(hook)
        return self._feature_hook_handle

    def get_extracted_features(self):
        """Return the most recently stored features, or an empty tensor if none were stored."""
        # Retrieve extracted features; return empty tensor if none
        return self.features if self.features is not None else torch.tensor([])


In [12]:
# Instantiate the Prithvi wrapper with a frozen encoder (pure feature extraction).
prithvi_model_kwargs = dict(
    cfg_path=model_config.PRITHVI_CFG_PATH,
    ckpt_path=model_config.PRITHVI_WEIGHTS_PATH,
    freeze_encoder=True,
)
prithvi_model = CustomPrithviModel(**prithvi_model_kwargs)

# Attach the encoder forward hook only when feature extraction is enabled.
if model_config.EXTRACT_FEATURES:
    prithvi_model.register_feature_hook()

##  Create `dataset` and `data_loader`

In [13]:
# Define mean and std from Prithvi configuration file
cfg = OmegaConf.load(model_config.PRITHVI_CFG_PATH)
mean = cfg['train_params']['data_mean']
std = cfg['train_params']['data_std']  

img_size = (dataset_config.IMG_HEIGHT, dataset_config.IMG_WIDTH)

# Define the transform
transform = lambda image: custom_transforms(image, mean=mean, std=std, img_size=img_size)

In [14]:
# Partition map that lists every cluster of the selected AOI.
partition_dir = os.path.join(dataset_config.PRT_ROOT, f'{country_code}')
aoi_partition = os.path.join(partition_dir, f'{country_code}_all.json')

# Transform with the normalization statistics and tile size bound in advance.
transform = partial(custom_transforms, mean=mean, std=std, img_size=img_size)

# Dataset over the whole AOI; used for data exploration and feature
# extraction (not related to model training).
print('\n')
dataset_kwargs = dict(
    root_dir=dataset_config.AOI_ROOT,
    partition_map_path=aoi_partition,
    num_channels=6,
    transform=transform,
)
aoi_dataset = HLSFlexibleBandSatDataset(**dataset_kwargs)

# Sequential (unshuffled) loader so that features stay aligned with cluster order.
print('\n')
loader_kwargs = dict(
    batch_size=train_config.BATCH_SIZE,
    num_workers=train_config.NUM_WORKERS,
    persistent_workers=False,
    shuffle=False,
)
aoi_data_loader = DataLoader(aoi_dataset, **loader_kwargs)

print("Number of samples in the aoi   data loader: ",   len(aoi_data_loader.dataset))



Processing AOI: ML, 322 clusters


Number of samples in the aoi   data loader:  322


## Display Sample Images (Optional)

In [15]:
# display_rgb_images(aoi_data_loader)

In [16]:
# display_ir_band(aoi_data_loader, band_index=3)

## Extract (HLS Satellite) Prithvi Features

In [17]:
def extract_prithvi_features_concatenated(prithvi_model, data_loader, layer_indices, device='cpu'):
    """Extract and concatenate intermediate-layer features for every sample.

    The model is moved to `device` and put in eval mode. For each batch, the
    encoder's get_intermediate_layers is called for `layer_indices`; each layer
    output is flattened per sample and the layers are concatenated along dim 1.

    Args:
        prithvi_model: Model exposing `.encoder.get_intermediate_layers(...)`.
        data_loader: Iterable yielding `(images, (cluster_id, aoi, targets))`,
            where `cluster_id` and `targets` support `.tolist()`.
            Images are presumably (batch, channels, height, width) -- TODO confirm.
        layer_indices: Layer selection forwarded as `n` to get_intermediate_layers.
        device: Device on which the model and the image batches are placed.

    Returns:
        tuple: (features, cluster_ids, target_values) where features is a 2-D
        tensor with one row per sample, and the other two are Python lists.
        An empty loader yields an empty (0, 0) tensor and two empty lists.
    """

    prithvi_model.to(device)
    prithvi_model.eval()

    cluster_ids = []
    target_values_list = []  # To store target values
    features_list = []  # List to store extracted features

    with torch.no_grad():
        for images, (cluster_id, aoi, targets) in data_loader:
            images = images.to(device)

            # Add a temporal dimension (T=1) to make it compatible with PatchEmbed
            images = images.unsqueeze(2)  # Add the temporal dimension

            # Extract intermediate layer features using get_intermediate_layers
            extracted_features = prithvi_model.encoder.get_intermediate_layers(
                images, n=layer_indices, mask_ratio=0.0, reshape=True, norm=True
            )

            # Flatten each layer output per sample and concatenate the layers.
            # flatten() also handles non-contiguous tensors (e.g. permuted
            # outputs), on which view() raises a RuntimeError.
            concatenated_features = torch.cat(
                [feat.flatten(start_dim=1) for feat in extracted_features], dim=1
            )
            features_list.append(concatenated_features)

            # Store the cluster IDs and target values
            cluster_ids.extend(cluster_id.tolist())
            target_values_list.extend(targets.tolist())

    # torch.cat rejects an empty list, so return an empty result explicitly.
    if not features_list:
        return torch.empty((0, 0)), cluster_ids, target_values_list

    # Concatenate all features along the first dimension
    all_features = torch.cat(features_list, dim=0)

    # Return features, cluster IDs, and target values
    return all_features, cluster_ids, target_values_list

In [18]:
# Run the extraction over the whole AOI, then unpack the three results.
extraction_result = extract_prithvi_features_concatenated(
    prithvi_model, aoi_data_loader, layer_indices, device=DEVICE
)
features_prithvi, cluster_ids, target_values_list = extraction_result

# One row per cluster; columns are the flattened, concatenated layer features.
print(features_prithvi.shape)

torch.Size([322, 301056])


## Save Features to Disk

In [19]:
# Build the output folder with os.path.join so that the trailing slash in
# AOI_ROOT does not produce a doubled separator ("AOI//ML").
satellite_features_folder = os.path.join(dataset_config.AOI_ROOT, country_code, 'Satellite_Features')

# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(satellite_features_folder, exist_ok=True)

feature_file = os.path.join(satellite_features_folder, sat_feature_file)

# Convert explicitly to a NumPy array on the host: a tensor that lives on a
# CUDA device cannot be converted implicitly by np.savez.
features_array = features_prithvi.detach().cpu().numpy()

np.savez(feature_file,
         features=features_array,
         cluster_ids=np.asarray(cluster_ids))

print(f"Features saved to {feature_file}")

Features saved to ./GIS-Image-Stack-Processing/AOI//ML/Satellite_Features/ML_sat_features_prithvi_L6_L8.npz
