<a href="https://colab.research.google.com/github/kelsdoerksen/giga-connectivity/blob/main/SatCLIP_Embedding_Extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook extracts SatCLIP feature embeddings for the (lat, lon) coordinates of school locations, for use as input features in connectivity prediction.

In [1]:
# Environment bootstrap: remove the listed directories so the working directory
# is empty, then clone the SatCLIP repository directly into it. `git clone <url> .`
# only works when the target directory is empty, hence the `rm` first.
# NOTE(review): `sample_data` and `.config` are presumably Colab's default
# contents - this cell is not meant to be run in a directory holding other files.
!rm -r sample_data .config # Empty current directory
!git clone https://github.com/microsoft/satclip.git . # Clone SatCLIP repository

Cloning into '.'...
remote: Enumerating objects: 250, done.[K
remote: Counting objects: 100% (155/155), done.[K
remote: Compressing objects: 100% (135/135), done.[K
remote: Total 250 (delta 61), reused 85 (delta 19), pack-reused 95[K
Receiving objects: 100% (250/250), 30.46 MiB | 12.06 MiB/s, done.
Resolving deltas: 100% (95/95), done.


In [2]:
!pip install lightning --quiet
!pip install rasterio --quiet
!pip install torchgeo --quiet
!pip install basemap --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m808.5/808.5 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m868.8/868.8 kB[0m [31m34.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m812.3/812.3 kB[0m [31m38.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.3/21.3 MB[0m [31m32.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.5/21.5 MB[0m [31m48.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m381.1/381.1 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m833.3/833.3 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━

In [3]:
# Loading required packages
import sys
import pandas as pd
sys.path.append('./satclip')

import torch
from load import get_satclip

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [4]:
# Get the [lon, lat] of each school as a float64 tensor to extract embeddings for

def get_coords(df):
  """
  Return the school locations in `df` as a 2D tensor of (lon, lat) pairs
  to extract embeddings for.

  Args:
    df: DataFrame with numeric 'lon' and 'lat' columns, one row per location.

  Returns:
    A torch.float64 tensor of shape (len(df), 2) with longitude in column 0
    and latitude in column 1. Rows follow the positional row order of `df`;
    the index labels are not consulted, so a filtered or re-indexed
    DataFrame works too. An empty DataFrame yields a (0, 2) tensor.

  Raises:
    KeyError: if `df` has no 'lon' or 'lat' column.
  """

  # Select both columns in one vectorised step. Looking rows up one at a time
  # with df.loc[i] is slow and fails with KeyError whenever the index is not
  # exactly 0..len(df)-1.
  lon_lat = df[['lon', 'lat']].to_numpy(dtype='float64')

  # torch.tensor copies the data, so the result does not share memory with df
  locations = torch.tensor(lon_lat)

  return locations

SatCLIP model names:

*   satclip-resnet18-l10
*   satclip-resnet18-l40
*   satclip-resnet50-l10
*   satclip-resnet50-l40
*   satclip-vit16-l10
*   satclip-vit16-l40

In [5]:
# Load the table of locations to embed. Its identifying columns are reused
# further down to label the extracted features.
RWA_df = pd.read_csv(filepath_or_buffer='RWA_id_info.csv')

In [7]:
# Build the (lon, lat) coordinate tensor for the area of interest
coords = get_coords(df=RWA_df)

In [40]:
# Name of the SatCLIP variant (one of the model names listed above); it is
# inserted into the names of the exported CSV files.
satclip_model = 'satclip-vit16-l40'

In [41]:
# Grab embeddings for each model type
embeddings = []

!wget 'https://satclip.z13.web.core.windows.net/satclip/satclip-vit16-l40.ckpt'
model = get_satclip('satclip-vit16-l40.ckpt', device=device)
model.eval()
with torch.no_grad():
  x  = model(coords.double().to(device)).detach().cpu()


--2024-07-22 21:14:32--  https://satclip.z13.web.core.windows.net/satclip/satclip-vit16-l40.ckpt
Resolving satclip.z13.web.core.windows.net (satclip.z13.web.core.windows.net)... 52.239.221.231
Connecting to satclip.z13.web.core.windows.net (satclip.z13.web.core.windows.net)|52.239.221.231|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 120982795 (115M) [application/zip]
Saving to: ‘satclip-vit16-l40.ckpt’


2024-07-22 21:14:34 (52.6 MB/s) - ‘satclip-vit16-l40.ckpt’ saved [120982795/120982795]

using pretrained moco vit16


Downloading: "https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth" to /root/.cache/torch/hub/checkpoints/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth
100%|██████████| 86.5M/86.5M [00:01<00:00, 45.9MB/s]


In [42]:
# Identifying/label columns to keep next to the embeddings
identifying_info_df = RWA_df[[
    'giga_id_school',
    'connectivity',
    'lat',
    'lon',
    'split',
    'fid',
]]
# Wrap the embedding tensor in a DataFrame (default integer index and column names)
emb_df = pd.DataFrame(data=x.numpy())

In [43]:
# Place the identifying columns and the embedding columns side by side
# (pandas aligns the two frames on their index)
emb_df_labelled = pd.concat(
    objs=[identifying_info_df, emb_df],
    axis='columns',
)

In [44]:
# Partition the labelled embeddings by the value of the 'split' column
emb_train = emb_df_labelled[emb_df_labelled['split'].eq('Train')]
emb_test = emb_df_labelled[emb_df_labelled['split'].eq('Test')]
emb_val = emb_df_labelled[emb_df_labelled['split'].eq('Val')]

In [45]:
# Write each split to its own CSV file; the model name is part of the file name
emb_train.to_csv(path_or_buf='RWA_{}_embeddings_TrainingData.csv'.format(satclip_model))
emb_test.to_csv(path_or_buf='RWA_{}_embeddings_TestingData.csv'.format(satclip_model))
emb_val.to_csv(path_or_buf='RWA_{}_embeddings_ValData.csv'.format(satclip_model))