<a href="https://colab.research.google.com/github/kelsdoerksen/giga-connectivity/blob/main/SatCLIP_Embedding_Extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Notebook for extracting SatCLIP feature embeddings at school (lat, lon) locations, to be used as features for school connectivity prediction.

In [None]:
!rm -r sample_data .config # Empty current directory
!git clone https://github.com/microsoft/satclip.git . # Clone SatCLIP repository

Cloning into '.'...
remote: Enumerating objects: 189, done.[K
remote: Counting objects: 100% (94/94), done.[K
remote: Compressing objects: 100% (78/78), done.[K
remote: Total 189 (delta 44), reused 44 (delta 15), pack-reused 95[K
Receiving objects: 100% (189/189), 10.10 MiB | 20.24 MiB/s, done.
Resolving deltas: 100% (78/78), done.


In [None]:
!pip install lightning --quiet
!pip install rasterio --quiet
!pip install torchgeo --quiet
!pip install basemap --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.4/840.4 kB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m800.9/800.9 kB[0m [31m17.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.6/20.6 MB[0m [31m29.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m378.5/378.5 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m756.0/756.0 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m733.1/733.1 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━

In [None]:
# Loading required packages
import sys

import pandas as pd
import torch

# The SatCLIP repository was cloned into the working directory; put its
# `satclip` folder on the import path so its `load` module can be imported.
sys.path.append('./satclip')
from load import get_satclip

# Run on the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Get [lon, lat] of schools as a float64 tensor to extract embeddings for

def get_coords(df):
  """
  Return the school locations in `df` as a 2D tensor for SatCLIP embedding extraction.

  Args:
    df: DataFrame with numeric 'lon' and 'lat' columns, one row per school.

  Returns:
    torch.float64 tensor of shape (len(df), 2). Row i holds [lon, lat] of
    the i-th row of `df`, counted by position and independent of the
    DataFrame's index labels.
  """
  # Select the columns in [lon, lat] order and convert them in one vectorized
  # step. The previous per-row loop used label-based `df.loc[i]`, so it raised
  # KeyError whenever the index was not a 0..n-1 RangeIndex. It was also slow
  # on large frames and crashed in torch.stack on an empty frame, whereas
  # this returns a (0, 2) tensor.
  lon_lat = df[['lon', 'lat']].to_numpy(dtype='float64')
  # torch.tensor copies the data, so the result never aliases the DataFrame.
  return torch.tensor(lon_lat, dtype=torch.float64)

Available pretrained SatCLIP models (each is downloaded as a `<name>.ckpt` checkpoint):

*   satclip-resnet18-l10
*   satclip-resnet18-l40
*   satclip-resnet50-l10
*   satclip-resnet50-l40
*   satclip-vit16-l10
*   satclip-vit16-l40

In [None]:
# Area of interest and data split to embed; both also name the output CSV
aoi = 'RWA'
split = 'Testing'
# NOTE(review): the input filename depends only on `split`, not on `aoi` —
# confirm the CSV in the working directory belongs to this AOI.
aoi_df = pd.read_csv(f'{split}Data_uncorrelated_fixed.csv')

In [None]:
# Turn the AOI's school locations into an (N, 2) [lon, lat] tensor for SatCLIP
coords = get_coords(df=aoi_df)

In [None]:
satclip_model = "satclip-resnet50-l40"  # one of the SatCLIP model names listed above; also used to name the output CSV

In [None]:
# Grab embeddings for each model type
embeddings = []

!wget 'https://satclip.z13.web.core.windows.net/satclip/satclip-resnet50-l40.ckpt'
model = get_satclip('satclip-resnet50-l40.ckpt', device=device)
model.eval()
with torch.no_grad():
  x  = model(coords.double().to(device)).detach().cpu()


--2024-02-22 06:01:52--  https://satclip.z13.web.core.windows.net/satclip/satclip-resnet50-l40.ckpt
Resolving satclip.z13.web.core.windows.net (satclip.z13.web.core.windows.net)... 52.239.221.231
Connecting to satclip.z13.web.core.windows.net (satclip.z13.web.core.windows.net)|52.239.221.231|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 129923067 (124M) [application/zip]
Saving to: ‘satclip-resnet50-l40.ckpt’


2024-02-22 06:01:54 (86.8 MB/s) - ‘satclip-resnet50-l40.ckpt’ saved [129923067/129923067]

using pretrained moco resnet50


Downloading: "https://huggingface.co/torchgeo/resnet50_sentinel2_all_moco/resolve/main/resnet50_sentinel2_all_moco-df8b932e.pth" to /root/.cache/torch/hub/checkpoints/resnet50_sentinel2_all_moco-df8b932e.pth
100%|██████████| 90.1M/90.1M [00:00<00:00, 126MB/s]


In [None]:
# Identifying columns for each school, kept alongside its embedding
identifying_info_df = aoi_df[['giga_id_school', 'connectivity', 'lat', 'lon']]
# Row i of `x` is the embedding of row i of aoi_df. Reusing aoi_df's index keeps the
# index-aligned column concat correct even when that index is not 0..n-1, and
# raises immediately if the number of embeddings differs from the number of schools.
emb_df = pd.DataFrame(x.numpy(), index=aoi_df.index)

In [None]:
# Place identifying columns and embedding columns side by side (rows matched on index)
emb_df_labelled = pd.concat(objs=[identifying_info_df, emb_df], axis='columns')

In [None]:
emb_df_labelled = emb_df_labelled.assign(data_split=split)  # tag every row with its data split

In [None]:
# Export the labelled embeddings to a CSV named by AOI, model and split
output_csv = f'{aoi}_{satclip_model}_embeddings_{split}.csv'
emb_df_labelled.to_csv(output_csv)