This guide shows how to train and infer [Clay Foundation Model](https://clay-foundation.github.io/model/index.html) for segmentation task ([Chesapeake Land Cover
](https://lila.science/datasets/chesapeakelandcover)) using our API.

<a target="_blank" href="https://colab.research.google.com/github/bluesightai/docs/blob/main/guides/train-a-land-segmentation.ipynb">
  <img noZoom src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [2]:
API_URL = "https://api.bluesight.ai"
API_URL = "http://localhost:8000"

GSD = 1.0
CHIP_SIZE = 224
COLLECTION = "naip"
BANDS = ["red", "green", "blue", "nir"]

DATA_DIR = "./data"
N_TRAIN_SAMPLES = 2
N_TEST_SAMPLES = 100

RANDOM_SEED = 42
HEADERS = {"Content-Type": "application/json"}

## Loading data



## Processing data

https://github.com/Clay-foundation/model/blob/main/finetune/segment/preprocess_data.py

In [3]:
import sys
import gc
from pathlib import Path

import numpy as np
import rasterio as rio

import os
import re
from tqdm import tqdm

In [4]:
DATA_DIR = Path(DATA_DIR)
OUTPUT_DIR = DATA_DIR / "output"

def read_and_chip(file_path, chip_size, output_dir):
    """
    Reads a GeoTIFF file, creates chips of specified size, and saves them as
    numpy arrays.

    Args:
        file_path (str or Path): Path to the GeoTIFF file.
        chip_size (int): Size of the square chips.
        output_dir (str or Path): Directory to save the chips.
    """
    os.makedirs(output_dir, exist_ok=True)

    with rio.open(file_path) as src:
        data = src.read()

        n_chips_x = src.width // chip_size
        n_chips_y = src.height // chip_size

        chip_number = 0
        for i in range(n_chips_x):
            for j in range(n_chips_y):
                x1, y1 = i * chip_size, j * chip_size
                x2, y2 = x1 + chip_size, y1 + chip_size

                chip = data[:, y1:y2, x1:x2]
                chip_path = os.path.join(
                    output_dir,
                    f"{Path(file_path).stem}_chip_{chip_number}.npy",
                )
                np.save(chip_path, chip)
                chip_number += 1


def process_files(file_paths, output_dir, chip_size):
    """
    Processes a list of files, creating chips and saving them.

    Args:
        file_paths (list of Path): List of paths to the GeoTIFF files.
        output_dir (str or Path): Directory to save the chips.
        chip_size (int): Size of the square chips.
    """
    for file_path in tqdm(file_paths):
        read_and_chip(file_path, chip_size, output_dir)
        

train_image_paths = list((DATA_DIR / "train").glob("*_naip-new.tif"))
train_label_paths = list((DATA_DIR / "train").glob("*_lc.tif"))
val_image_paths = list((DATA_DIR / "val").glob("*_naip-new.tif"))
val_label_paths = list((DATA_DIR / "val").glob("*_lc.tif"))

process_files(train_image_paths, OUTPUT_DIR / "train/chips", CHIP_SIZE)
process_files(train_label_paths, OUTPUT_DIR / "train/labels", CHIP_SIZE)
process_files(val_image_paths, OUTPUT_DIR / "val/chips", CHIP_SIZE)
process_files(val_label_paths, OUTPUT_DIR / "val/labels", CHIP_SIZE)



  0%|                                                                                                                                                                                       | 0/100 [00:02<?, ?it/s]


KeyboardInterrupt: 

In [5]:
train_chip_names = [chip_path.name for chip_path in list((OUTPUT_DIR / "train/chips").glob("*.npy"))]
train_label_names = [re.sub("_naip-new_", "_lc_", chip) for chip in train_chip_names]

val_chip_names = [chip_path.name for chip_path in list((OUTPUT_DIR / "val/chips").glob("*.npy"))]
val_label_names = [re.sub("_naip-new_", "_lc_", chip) for chip in val_chip_names]

In [23]:
label_mapping = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4, 6: 5, 15: 6}

X_train = [{
    "gsd": GSD,
    "bands": BANDS,
    "pixels": np.load(OUTPUT_DIR / "train/chips" / chip_name).tolist(),
    "platform": COLLECTION
} for chip_name in tqdm(train_chip_names[:N_TRAIN_SAMPLES])]
y_train = [{"label": np.vectorize(label_mapping.get)(np.load(OUTPUT_DIR / "train/labels" / chip_name).squeeze()).tolist()} for chip_name in tqdm(train_label_names[:N_TRAIN_SAMPLES])]

X_test = [{
    "gsd": GSD,
    "bands": BANDS,
    "pixels": np.load(OUTPUT_DIR / "val/chips" / chip_name).tolist(),
    "platform": COLLECTION
} for chip_name in tqdm(val_chip_names[:N_TEST_SAMPLES])]
y_test = [{"label": np.vectorize(label_mapping.get)(np.load(OUTPUT_DIR / "val/labels" / chip_name).squeeze()).tolist()} for chip_name in tqdm(val_label_names[:N_TEST_SAMPLES])]

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 320.65it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 206.13it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 180.41it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 237.22it/s]


In [24]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_image_mask_pair(image, mask, title="Image and Mask Pair"):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    
    ax1.imshow(image)
    ax1.set_title("Original Image")
    ax1.axis('off')
    
    ax2.imshow(mask, cmap='viridis')
    ax2.set_title("Segmentation Mask")
    ax2.axis('off')
    
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

random_indices = np.random.choice(len(X_train), 5, replace=False)

for i, idx in enumerate(random_indices):
    image = np.array(X_train[idx]['pixels']).transpose(1, 2, 0)
    mask = np.array(y_train[idx]['labels'])
    
    if image.shape[2] > 3:
        image = image[:, :, :3]
    
    visualize_image_mask_pair(image, mask, f"Random Sample {i+1}")

ValueError: Cannot take a larger sample than population when 'replace=False'

## Training model

Now we need to load data to the format which is accepted by our API. Essentially, it is raw pixels data from all bands with some metadata.

You can check detailed endpoint specification [in the docs](https://docs.bluesight.ai/api-reference/train/train-segmentation-model).

In [32]:
import requests

url = API_URL + "/train/segmentation"
payload = {"images": X_train, "labels": y_train}
response = requests.request("POST", url, json=payload, headers=HEADERS)
if response.status_code != 200:
    print(f"{response.status_code}: {response.json()}")
else:
    data = response.json()
    model_id, train_details = data["model_id"], data["train_details"]
    print(f"Model id: '{model_id}'")

Model id: '123'


## Run inference

For inference we need only images.

You can check detailed endpoint specification [in the docs](https://docs.bluesight.ai/api-reference/inference/infer-segmentation-model).

In [7]:
url = API_URL + "/inference/classification"
payload = {"images": X_test, "model_id": model_id}
response = requests.request("POST", url, json=payload, headers=HEADERS)
if response.status_code != 200:
    print(f"{response.status_code}: {response.json()}")
else:
    y_pred = response.json()["labels"]

In [8]:
from sklearn.metrics import classification_report, accuracy_score

match = np.sum(np.array(y_test) == np.array(y_pred))
print(f"Matched {match} out of {len(y_test)} correctly")

print(f"Accuracy: {accuracy_score(y_test, y_pred)}")
print(classification_report(y_test, y_pred))

Matched 778 out of 800 correctly
Accuracy: 0.9725
              precision    recall  f1-score   support

           0       0.98      0.98      0.98       600
           1       0.94      0.95      0.95       200

    accuracy                           0.97       800
   macro avg       0.96      0.96      0.96       800
weighted avg       0.97      0.97      0.97       800

