## Sentinel-1: SoftCon embeddings

Experimental notebook that downloads Sentinel-1 data from Google Earth Engine (GEE) and runs
foundation-model (SoftCon) inference on it to produce embeddings.

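Depends on the external repo https://github.com/zhu-xlab/softcon and on the pretrained model backbone (weights) linked from that repo's README, which is not included in the repo itself. The local path to the cloned repo is set in the cell below (`SOFTCON_PATH`); the weights are expected in its `pretrained/` subdirectory.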

The softcon model is trained on the SSL4EO-S12 dataset. Dataset statistics for normalization come from: https://arxiv.org/abs/2211.07044, App. 1, p. 8.

Note that GEE data comes down in shape (h, w, bands), whereas rasterio, torch, etc. use (bands, h, w). For consistency we will assume the latter order after data download.

In [None]:
from datetime import datetime
import glob
import os
import sys

import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from rasterio.transform import Affine
import torch
from torchvision.transforms import transforms
from tqdm import tqdm

import gee
import utils

SOFTCON_PATH = 'softcon/'
sys.path.append(SOFTCON_PATH)
from models.dinov2 import vision_transformer 

# Per-band Sentinel-1 normalization statistics from SSL4EO-S12 (see the paper cited above),
# consumed by normalize(). The negative means suggest backscatter in dB — confirm against the paper.
# NOTE(review): the inference loop zips these values with the raster bands in dict order
# (VV, then VH), so this order must match the band order of the downloaded GeoTIFFs — confirm.
SSL4EO_S1_STATS = {  
    'VV': {'mean': -12.59, 'std': 5.26},
    'VH': {'mean': -20.26, 'std': 5.91}
}

# IPython autoreload: re-import edited modules (e.g. gee, utils) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
def normalize(band, mean, std):
    """Rescale band values onto 0-255 using the window mean +/- 2 std.

    Values at or below ``mean - 2*std`` become 0 and values at or above ``mean + 2*std``
    become 255; everything in between is mapped linearly. The output is float32: the
    Zhu lab suggest casting to uint8 at this point, but the model requires float32 input.
    """
    lower, upper = mean - 2 * std, mean + 2 * std
    scaled = (band - lower) / (upper - lower) * 255.0
    return np.clip(scaled, 0, 255).astype(np.float32)

# We split a tile into geographic patches ("chips") whose size is chosen for object detection
# applications. Experience indicates a chip size on the order of hundreds of meters rather than
# kilometers. Chips are later resized to the input dimension expected by the model; empirically
# this works, although why it works is still a mystery.

def cut_chips(tile_pixels, tile_info, geo_chip_size=32, stride_frac=2):
    """Cut a band-first tile into chips to be embedded.

    ``tile_pixels`` is expected as (bands, h, w). It is handed to ``utils.chips_from_tile``
    channel-last, with chip size ``geo_chip_size`` and step ``geo_chip_size // stride_frac``
    (units as defined by that helper — presumably pixels). Assuming the helper returns
    channel-last chips, the chips come back as a (n_chips, bands, size, size) array,
    together with their footprints reprojected in place to EPSG:4326.
    """
    step = geo_chip_size // stride_frac
    channels_last = np.moveaxis(tile_pixels, 0, -1)
    chip_list, footprints = utils.chips_from_tile(channels_last, tile_info, geo_chip_size, step)
    footprints.to_crs("EPSG:4326", inplace=True)
    band_first_chips = np.moveaxis(np.array(chip_list), -1, 1)
    return band_first_chips, footprints

def describe(arr):
    """Return pandas-``describe()``-style summary statistics for an array.

    Every statistic is taken over the whole (flattened) array. ``count`` is the total
    element count and ``std`` is the sample standard deviation (``ddof=1``).
    """
    q25, q75 = (np.percentile(arr, q) for q in (25, 75))
    return {
        "count": arr.size,
        "mean": arr.mean(),
        "std": arr.std(ddof=1),
        "min": arr.min(),
        "25%": q25,
        "50% (median)": np.median(arr),
        "75%": q75,
        "max": arr.max(),
    }

### Tiling an AOI

In [None]:
# Area of interest: the first geometry in the region's GeoJSON boundary file, as a GeoJSON-like dict.
region_name = 'tapajos_test_region'
region = gpd.read_file(f'../data/boundaries/{region_name}.geojson').geometry[0].__geo_interface__

# Tiling parameters passed to utils.create_tiles (units as defined there — presumably pixels).
tilesize = 1344 # previously 576, which was around the max size allowed for GEE export for 12-band imagery 
padding = 24

# Date range handed to the GEE data extractor.
start_date = datetime(2024, 12, 1)
end_date = datetime(2024, 12, 31)

In [None]:
# Split the region into tiles of `tilesize` with `padding`, then show the count and one example.
tiles = utils.create_tiles(region, tilesize=tilesize, padding=padding)
print(f"Created {len(tiles):,} tiles")
print(f'Sample tile data:\n{tiles[0]}')

### GEE S1 data download

In [None]:
# Earth Engine extractor for the Sentinel-1 ('S1') collection over the tiles and date range above.
# `batch_size` is forwarded to gee.GEE_Data_Extractor (see that class for its meaning).
data_pipeline = gee.GEE_Data_Extractor(
    tiles, 
    start_date, 
    end_date, 
    batch_size=500,
    collection='S1'
    )

In [None]:
# List the band names of the extractor's composite image (getInfo() fetches the value from Earth Engine).
data_pipeline.composite.bandNames().getInfo()

In [None]:
# Output directory for the downloaded GeoTIFFs. exist_ok makes the cell safely re-runnable
# and avoids the race between an existence check and mkdir.
data_dir = 'S1datav2'
os.makedirs(data_dir, exist_ok=True)

In [None]:
save_visual = True  # Save bandwise, uint8 copies of the data for easier visualization
overwrite = False  # Set True to re-download tiles whose GeoTIFF already exists in data_dir


def _utm_crs_and_geotrans(tile):
    """Return the (crs, GDAL geotransform) pair to write for `tile`.

    DLTile uses a pseudo-UTM system with only UTM North CRSs (EPSG:326xx). When
    tile.bounds[1] is negative (presumably the tile's southern edge lies below the
    equator — confirm which CRS `tile.bounds` is expressed in), switch to the matching
    UTM South zone (EPSG:327xx) and add the 10,000,000 m UTM South false northing to the
    top-left y coordinate (index 3 of the GDAL geotransform). Otherwise the tile's own
    CRS and geotransform are returned unchanged.
    """
    if tile.bounds[1] < 0 and tile.crs.upper().startswith("EPSG:326"):
        utm_zone = tile.crs.split(":")[1][-2:]
        geotrans = list(tile.geotrans)
        geotrans[3] = geotrans[3] + 10000000
        return f"EPSG:327{utm_zone}", geotrans
    return tile.crs, tile.geotrans


# Download every tile, skipping tiles already on disk so an interrupted run can be resumed.
# This replaces a hard-coded `tiles[20:]` restart offset, which left the first tiles missing
# in a fresh data_dir even though the inference loop below reads every tile.
for tile in tqdm(tiles):
    path = os.path.join(data_dir, f"{region_name}S1_{tile.key}.tif")
    if os.path.exists(path) and not overwrite:
        continue

    img, tile_info = data_pipeline.get_tile_data(tile)
    img = utils.pad_patch(img, tile_info.tilesize)

    # GEE data is (h, w, bands); after padding each tile must be exactly tilesize x tilesize.
    assert tile.tilesize == img.shape[0]
    assert tile.tilesize == img.shape[1]

    img = img.astype('float32')
    crs, geotrans = _utm_crs_and_geotrans(tile)

    profile = {
        'driver': 'GTiff',
        'count': img.shape[-1],
        'height': img.shape[0],
        'width': img.shape[1],
        'crs': crs,
        'transform': Affine.from_gdal(*geotrans),
        'dtype': img.dtype
    }

    with rasterio.open(path, 'w', **profile) as f:
        # Write every band so the file content matches the declared 'count'.
        for band in range(img.shape[-1]):
            f.write(img[:, :, band], band + 1)  # rasterio band indexes are 1-based

    if save_visual:
        visual_profile = {**profile, 'count': 1, 'dtype': 'uint8'}
        stem = os.path.splitext(path)[0]
        for i, (band_name, stats) in enumerate(SSL4EO_S1_STATS.items()):
            raster = normalize(img[np.newaxis, :, :, i], stats['mean'], stats['std'])
            with rasterio.open(f"{stem}{band_name}.tif", 'w', **visual_profile) as f:
                # normalize() returns float32 in [0, 255]; cast explicitly to the file's uint8 dtype.
                f.write(raster.astype(np.uint8))
            print(describe(raster))

    print(f"Saved {path}")


#### Reload a few downloaded tiles from disk for inspection. (For inference, tiles are loaded one at a time to save RAM; see below.)

In [None]:
# Directory written by the download cell; re-set here so the cells below can run without it.
data_dir = 'S1datav2'

In [None]:
# Collect the multi-band GeoTIFFs of the first few tiles by building each tile's exact filename
# (the same pattern the download and inference cells use). The previous substring match of
# `tile.key` against a glob could pick up another tile's file when one key is a prefix of another.
# Its 'VV'/'VH' exclusion would also drop every file if those letters appeared in
# region_name or data_dir.
paths = []
for tile in tiles[:2]:
    path = os.path.join(data_dir, f"{region_name}S1_{tile.key}.tif")
    if os.path.exists(path):
        paths.append(path)
paths

In [None]:
# Read each selected GeoTIFF fully into memory. rasterio's read() returns (bands, h, w);
# np.array then stacks them to (n_tiles, bands, h, w), which requires every tile to share one shape.
pixels = []
for path in paths:
    with rasterio.open(path, 'r') as f:
        S1image = f.read()
        pixels.append(S1image)
    
pixels = np.array(pixels)
pixels.shape

In [None]:
# Print summary statistics for every band of every loaded tile.
for p in pixels:
    for arr in p:
        print(describe(arr))

In [None]:
# Plot up to the first three loaded tiles: one row per tile, two columns (one per band).
to_view = pixels[:3]

fig, axes = plt.subplots(len(to_view), 2, figsize=(10, 10*len(to_view)))

# With a single row, plt.subplots returns a 1-D axes array; promote it to 2-D so the loop works.
if axes.ndim == 1:
    axes = axes[np.newaxis, :]  

for row,img in zip(axes, to_view):
    # NOTE(review): titles come from the extractor's band IDs, assuming they follow the same
    # order as the bands stored in the GeoTIFF — confirm.
    for (ax, band, band_name) in zip(row, img, data_pipeline.bandIds):
        ax.imshow(band)
        ax.set_title(band_name)
        ax.axis("off") 

# plt.savefig(f'{data_dir}/{region_name}_S1to{end_date.date().isoformat()}.png', bbox_inches='tight')

### Inference

In [None]:
# Model

model_chip_size = 224

# Prefer an accelerator when available: CUDA, then Apple-silicon MPS (Mac Mx chips), else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f'Device: {device}')

# ViT-S/14 backbone from the SoftCon repo; in_chans=2 for the two Sentinel-1 bands.
model = vision_transformer.__dict__['vit_small'](
    img_size=model_chip_size,
    patch_size=14,
    in_chans=2,
    block_chunks=0,
    init_values=1e-5,
    num_register_tokens=0,
)

model_name = 'B2_vits14_softcon.pth'
# map_location='cpu' lets a checkpoint saved from a CUDA device load on MPS/CPU-only machines;
# the model is moved to `device` afterwards.
ckpt_vits14 = torch.load(
    os.path.join(SOFTCON_PATH, f'pretrained/{model_name}'),
    map_location='cpu',
)
model.load_state_dict(ckpt_vits14)

model.to(device)
model.eval()

In [None]:
# Inference

geo_chip_size = 32
batch_size = 128 
feature_columns = [f"vit-dino-patch14_{i}" for i in range(features.shape[-1])] 

gdfs = []
for tile in tqdm(tiles):
    path = os.path.join(data_dir, f"{region_name}S1_{tile.key}.tif")
    with rasterio.open(path, 'r') as f:
        pixels = f.read()
    
    normed = [normalize(band, stats['mean'], stats['std']) for band, stats in zip(pixels, SSL4EO_S1_STATS.values())]
    normed = np.array(normed)
    
    chips, chip_geoms = cut_chips(normed, tile, geo_chip_size=geo_chip_size)
    tensor = torch.from_numpy(chips)
    if geo_chip_size != model_chip_size:
        tensor = transforms.Resize((model_chip_size, model_chip_size), antialias=False).__call__(tensor)

    print(f'Input tensor shape {tensor.shape}')
    tensor = tensor.to(device)
    
    batch_outputs = []
    for i in tqdm(range(0, len(tensor), batch_size)):
        batch = tensor[i : i + batch_size]
        with torch.no_grad():
            batch_output = model(batch)
        batch_outputs.append(batch_output)
    batch_outputs = torch.cat(batch_outputs).cpu().numpy()    
    
    features_df = gpd.pd.DataFrame(batch_outputs, columns=feature_columns)
    gdf = gpd.pd.concat([chip_geoms, features_df], axis=1)
    gdfs.append(gdf)
    
gdf = gpd.pd.concat(gdfs).reset_index(drop=True)

In [None]:
# Write the float embeddings to <region_name>/ (created if needed; exist_ok makes re-runs safe).
os.makedirs(region_name, exist_ok=True)

gdf.to_parquet(f"{region_name}/{region_name}_{model_name.split('.pth')[0]}_{geo_chip_size}chip_S1to{end_date.date().isoformat()}.parquet", index=False)

#### Optional embedding quantization to save memory

In [None]:
# Inspect and adjust upper / lower bound to ensure sufficient variance after quantization.
# (In principle the bounds should be set once across all S1 embeddings.)
# DataFrame.describe() summarizes the numeric (embedding) columns by default.
gdf.describe()

In [None]:
def quantize(embeddings, lower_bound=-5, upper_bound=5):
    """Linearly map embedding values onto uint8 codes 0-255.

    Values are first clipped to [lower_bound, upper_bound]. That interval is then mapped
    onto [0, 255], and the cast to uint8 truncates (does not round) the result.
    """
    span = upper_bound - lower_bound
    unit = (np.clip(embeddings, lower_bound, upper_bound) - lower_bound) / span
    return (unit * 255).astype(np.uint8)

# Quantize only the embedding columns. Selecting them by name (rather than dropping 'geometry')
# keeps the column count equal to `feature_columns` even if chip_geoms carried extra attributes.
quantized = quantize(gdf[feature_columns].to_numpy())
features_df = gpd.pd.DataFrame(quantized, columns=feature_columns, index=gdf.index)
q_gdf = gpd.pd.concat([gdf['geometry'], features_df], axis=1)
q_gdf.head()

In [None]:
# Check the distribution of the quantized (uint8, 0-255) embedding values.
q_gdf.describe()

In [None]:
# Save the quantized embeddings next to the float version, with a `_quant` filename suffix.
q_gdf.to_parquet(f"{region_name}/{region_name}_{model_name.split('.pth')[0]}_{geo_chip_size}chip_S1to{end_date.date().isoformat()}_quant.parquet", index=False)