In [None]:
%pip install torchgeo

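1. Visit https://torchgeo.readthedocs.io/en/stable/api/models.html#pretrained-weights

2. Under Sentinel-2, choose the pretrained weights whose number of input channels matches your imagery (e.g. RGB-only vs. multispectral), and use them in `load_model` below.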

In [None]:
import torch
import torch.nn as nn
from torchvision.transforms import functional as TF
from torchgeo.models import resnet50, ResNet50_Weights

from PIL import Image
import argparse
import os

In [None]:
# Run on the CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def load_model(device, weights=None):
    """Build a frozen torchgeo ResNet-50 feature extractor and its preprocessing.

    Args:
        device: torch.device (or device string) the model is moved to.
        weights: Member of ``ResNet50_Weights`` to load. ``None`` (the default)
            selects ``ResNet50_Weights.SENTINEL2_RGB_SECO``. Pass the member whose
            input-channel count matches your imagery instead of editing this
            function.

    Returns:
        (model, transform): the ResNet-50 in eval mode on ``device`` with its
        ``fc`` classification layer replaced by ``nn.Identity`` (so the forward
        pass returns the features that used to feed the classifier), and the
        object returned by ``weights.transforms()``.
    """
    # Resolve the default lazily so merely defining this function does not
    # touch the torchgeo enum.
    if weights is None:
        weights = ResNet50_Weights.SENTINEL2_RGB_SECO
    model = resnet50(weights=weights)
    model.fc = nn.Identity()  # drop the classification head; keep the features that fed it
    model.to(device)
    model.eval()  # inference behaviour for BatchNorm layers
    # NOTE(review): confirm `weights.transforms()` (called, not just accessed)
    # is the right usage for the installed torchgeo version.
    transform = weights.transforms()
    return model, transform

In [None]:
# Build the embedding model and its preprocessing once; the helper functions
# below read `model` and `transform` as notebook globals.
model, transform = load_model(device=device)

In [None]:
from PIL import Image
import numpy as np
import rasterio
import os

def load_image(path):
    """Load one image file as a float32 tensor of shape [C, H, W] on the CPU.

    Supported inputs (extension matched case-insensitively):
      * ``.tif`` / ``.tiff``: every band read with rasterio, raw values kept.
      * ``.npy``: loaded with NumPy. A 2-D array becomes a single band
        [1, H, W]. Other arrays are returned in their stored layout
        (NOTE(review): assumed channels-first; confirm for your data).
      * anything else: opened with PIL and converted to 3-band RGB.

    All branches return the same layout and leave pixel values unscaled, so
    ``get_embedding`` can apply the model preprocessing exactly once.

    Args:
        path: str or os.PathLike pointing at the image file.

    Returns:
        torch.Tensor of dtype float32 and shape [C, H, W].
    """
    path_str = os.fspath(path)  # also accept pathlib.Path inputs
    suffix = os.path.splitext(path_str)[1].lower()

    if suffix in (".tif", ".tiff"):
        with rasterio.open(path_str) as src:
            # read() returns (bands, rows, cols). Casting in NumPy first avoids
            # torch.from_numpy rejecting dtypes such as uint16 on older torch.
            img = torch.from_numpy(src.read().astype(np.float32))
            #img /= 1000.0 #add normalization if needed
    elif suffix == ".npy":
        img = torch.from_numpy(np.load(path_str).astype(np.float32))
        if img.ndim == 2:
            img = img.unsqueeze(0)  # single band -> [1, H, W]
    else:
        with Image.open(path_str) as pil_img:
            # PIL gives (H, W, C); move channels first to match the other branches.
            rgb = np.asarray(pil_img.convert("RGB"), dtype=np.float32)
        img = torch.from_numpy(rgb).permute(2, 0, 1)

    return img

In [None]:
from torchvision.transforms import Resize

def get_embedding(img_tensor):
    """Return the model's feature vector for one image as a CPU tensor.

    The image is resized to 224x224 and converted to a channels-last
    (H, W, C) NumPy array before the global ``transform`` is applied. The
    result gets a batch axis, moves to ``device``, and goes through the global
    ``model`` with gradients disabled. The batch axis is removed again before
    returning.

    Args:
        img_tensor: image tensor; assumed [C, H, W] (permute(1, 2, 0) below
            requires exactly three dims).

    Returns:
        The embedding with the batch axis squeezed away, e.g. [2048] for a
        ResNet-50 with its classifier removed (confirm for other weights).
    """
    resized = Resize((224, 224))(img_tensor)
    # NOTE(review): the preprocessing is fed a channels-last NumPy array;
    # confirm the chosen weights' transform accepts that input type.
    channels_last = resized.permute(1, 2, 0).cpu().numpy()
    batch = transform(channels_last).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model(batch)
    return features.squeeze(0).cpu()

In [None]:
from pathlib import Path
import tqdm
import pandas as pd

def process_images_to_feather(image_paths, output_feather):
    """Embed each image and write one row per image to a Feather file.

    Every path is loaded with ``load_image`` and embedded with
    ``get_embedding``. The output table has an ``image`` column (the file's
    base name) followed by ``f0``, ``f1``, ... holding the embedding values as
    plain floats. A file that fails to load or embed is reported and skipped,
    so one bad image does not abort the whole batch.

    Args:
        image_paths: Iterable of image file paths (str or os.PathLike).
        output_feather: Destination path passed to ``DataFrame.to_feather``.
    """
    # A plain ``import tqdm`` binds the module, which is not callable. The
    # progress-bar class lives in tqdm.auto (a widget in Jupyter, text elsewhere).
    from tqdm.auto import tqdm

    records = []
    n_failed = 0

    for path in tqdm(image_paths, desc="embedding"):
        try:
            img = load_image(path)
            #may need other processing, or to select certain bands
            embedding = get_embedding(img)
            record = {"image": os.path.basename(path)}
            # .tolist() yields Python floats. Storing 0-d tensors would create
            # object-dtype columns that Arrow/Feather cannot serialize.
            record.update({f"f{i}": v for i, v in enumerate(embedding.tolist())})
            records.append(record)
        except Exception as e:
            # Best-effort batch: report the bad file and continue.
            n_failed += 1
            print(f"Failed: {path} — {e}")

    df = pd.DataFrame(records)
    df.to_feather(output_feather)
    print(f"Saved to {output_feather}")
    if n_failed:
        print(f"{n_failed} image(s) failed and were skipped")


### Example

In [None]:
# Fill with paths to individual image files (.tif, .npy, or any PIL-readable
# format), e.g. sorted(Path("path/to/images").glob("*.tif")).
image_paths = []
output_feather = "sentinel2_embeddings.feather"

process_images_to_feather(image_paths, output_feather)
