In [None]:
import yaml
import torch
from claymodel.module import ClayMAEModule
import numpy as np
from pathlib import Path
import datetime
from matplotlib import pyplot as plt
from pyproj import Transformer
import rasterio
from rasterio.plot import show
import copy
from torchvision.transforms import v2
import math
from estuary.util import contrast_stretch
from einops import rearrange
from sklearn import decomposition, svm
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import ConfusionMatrixDisplay, accuracy_score, precision_recall_fscore_support, f1_score, balanced_accuracy_score
from estuary.clay.data import EstuaryDataModule
from estuary.clay.module import EstuaryModule
import tqdm

In [None]:
# Load model
# Restore the Clay MAE Lightning module from a local checkpoint. The sensor
# metadata file, mask_ratio=0.0 and shuffle=False are forwarded as keyword
# arguments to load_from_checkpoint.
# NOTE(review): presumably mask_ratio=0.0 / shuffle=False make the encoder see
# every patch in its original order - confirm against ClayMAEModule.
model = ClayMAEModule.load_from_checkpoint(
    "/Users/kyledorman/data/models/clay/clay-v1.5.ckpt",
    metadata_path="/Users/kyledorman/data/models/clay/metadata.yaml",
    mask_ratio=0.0,
    shuffle=False,
)
# Switch to eval mode; assigning the return value to `_` keeps the notebook
# from echoing it as the cell output.
_ = model.eval()

In [None]:
# Load sensor metadata
with open("/Users/kyledorman/data/models/clay/metadata.yaml", "r") as f:
    metadata = yaml.safe_load(f)

# Register a 4-band PlanetScope entry next to the sensors found in the
# metadata file. Its per-band tables are copied from the existing
# 'planetscope-sr' entry, restricted to the four bands listed here.
channel_4_band_order = [
    'blue',
    'green',
    'red',
    'nir',
]
planetscope = metadata['planetscope-sr']

# For every table under 'bands' (prep_datacube reads 'mean', 'std' and
# 'wavelength'), keep only the entries for the four selected bands.
bands = {
    k: {kk: vv for kk, vv in vs.items() if kk in channel_4_band_order}
    for k, vs in planetscope['bands'].items()
}

metadata['planetscope-sr-4'] = {
    'band_order': channel_4_band_order,
    # Positions of red, green and blue inside band_order. Derived from the
    # band names so the indices cannot drift from the order above; the former
    # hard-coded [3, 2, 1] selected nir, red, green instead.
    'rgb_indices': [channel_4_band_order.index(b) for b in ('red', 'green', 'blue')],
    'gsd': 3,
    'bands': bands,
}

# Display the new entry as the cell output.
metadata['planetscope-sr-4']

In [None]:
def normalize_latlon(lat, lon):
    """
    Encode a latitude/longitude pair as sine/cosine components.

    Both angles are converted from degrees to radians before the
    trigonometric functions are applied, so every returned component lies
    between -1 and 1.

    Parameters:
    lat (float): Latitude in degrees.
    lon (float): Longitude in degrees.

    Returns:
    tuple: ((sin(lat), cos(lat)), (sin(lon), cos(lon))).
    """
    lat_rad = lat * np.pi / 180
    lon_rad = lon * np.pi / 180

    lat_enc = (math.sin(lat_rad), math.cos(lat_rad))
    lon_enc = (math.sin(lon_rad), math.cos(lon_rad))
    return lat_enc, lon_enc

def normalize_timestamp(date):
    """
    Encode a datetime's ISO week number and hour of day as sine/cosine pairs.

    The week number is mapped onto a 52-step cycle and the hour onto a
    24-step cycle before taking sine and cosine.

    Parameters:
    date: Object exposing `isocalendar().week` and `hour`
        (e.g. a `datetime.datetime`).

    Returns:
    tuple: ((sin(week), cos(week)), (sin(hour), cos(hour))).
    """
    week_angle = date.isocalendar().week * 2 * np.pi / 52
    hour_angle = date.hour * 2 * np.pi / 24

    week_enc = (math.sin(week_angle), math.cos(week_angle))
    hour_enc = (math.sin(hour_angle), math.cos(hour_angle))
    return week_enc, hour_enc

def prep_datacube(image, lat, lon, date):
    """
    Build the input dictionary passed to the encoder for a single image.

    Band statistics and wavelengths are read from the module-level
    `metadata['planetscope-sr-4']` entry.

    Parameters:
    image (np.array): The input image array; it is cast to float32, resized
        to 128x128 and normalized per band.
    lat (float): Latitude value for the location.
    lon (float): Longitude value for the location.
    date: Acquisition time, passed to `normalize_timestamp`.

    Returns:
    dict: Tensors under the keys "pixels", "time", "latlon", "waves" and
        "gsd". "pixels", "time", "latlon" and "gsd" carry a leading axis of
        size 1 added with `unsqueeze(0)`.
    """
    sensor = metadata['planetscope-sr-4']
    band_tables = sensor['bands']
    band_names = sensor['band_order']

    # Per-band statistics, listed in the sensor's band order.
    mean = [band_tables['mean'][name] for name in band_names]
    std = [band_tables['std'][name] for name in band_names]
    # NOTE(review): the factor 1000 presumably converts micrometres to
    # nanometres - confirm against the metadata file's units.
    waves = [band_tables['wavelength'][name] * 1000 for name in band_names]

    # NOTE(review): interpolation=3 is an integer mode code, presumably
    # bicubic - confirm against torchvision.
    transform = v2.Compose(
        [
            v2.Resize(size=(128, 128), interpolation=3),
            v2.Normalize(mean=mean, std=std),
        ]
    )

    # Cyclical encodings of acquisition time and location.
    week_norm, hour_norm = normalize_timestamp(date)
    lat_norm, lon_norm = normalize_latlon(lat, lon)

    # Resize + normalize the pixels, then add the leading axis.
    pixels = transform(torch.from_numpy(image.astype(np.float32))).unsqueeze(0)

    time_tensor = torch.tensor(
        np.hstack((week_norm, hour_norm)), dtype=torch.float32
    ).unsqueeze(0)
    latlon_tensor = torch.tensor(
        np.hstack((lat_norm, lon_norm)), dtype=torch.float32
    ).unsqueeze(0)

    # NOTE(review): the ground sample distance is doubled, presumably because
    # the source image is downsampled 2x by the resize above - this only
    # holds for 256-px inputs; confirm.
    gsd_tensor = torch.tensor(sensor['gsd'] * 2).unsqueeze(0)

    return {
        "pixels": pixels,
        "time": time_tensor,
        "latlon": latlon_tensor,
        "waves": torch.tensor(waves),
        "gsd": gsd_tensor,
    }

In [None]:
# Build the model input for one scene and run the Clay encoder on it without
# tracking gradients.
# NOTE(review): `data`, `cent_g` and `dt` are not defined in any cell shown in
# this notebook - presumably the image array, a (lat, lon) pair and a datetime
# left over from a removed cell. They must exist before this cell can run on a
# fresh kernel.
datacube = prep_datacube(data, *cent_g, dt)
with torch.no_grad():
    unmsk_patch, unmsk_idx, msk_idx, msk_matrix = model.model.encoder(datacube)

In [None]:
# The first embedding is the class token, which is the overall single embedding.
# Take token index 0 along the second axis and show it as a NumPy array.
unmsk_patch[:, 0, :].cpu().numpy()

In [None]:
# Load an exported encoder program from disk with torch.export and unwrap it
# into a callable module.
# NOTE(review): the file name ends in "_256" while prep_datacube resizes to
# 128x128 - confirm the exported program accepts that input size.
ep_embedder_cpu = torch.export.load("/Users/kyledorman/data/models/clay/clay-v1.5-encoder_256.pt2").module()

In [None]:
%%time
# Time one end-to-end pass: datacube preparation plus a forward call of the
# exported encoder, without tracking gradients.
# NOTE(review): `data`, `cent_g` and `dt` are not defined in any cell shown in
# this notebook; they must exist in the kernel before this cell runs.
datacube = prep_datacube(data, *cent_g, dt)
with torch.no_grad():
    embeddings = ep_embedder_cpu(datacube)
# Show the input and output shapes as the cell result.
datacube["pixels"].shape, embeddings.shape

In [None]:
# Reshape the patch tokens (everything after the class token) into a square
# grid of per-channel feature maps, then plot a random sample of channels next
# to the input image.
patch_tokens = unmsk_patch[:, 1:, :].detach().cpu().numpy()

# Derive the grid side from the number of patch tokens instead of hard-coding
# 32, so the cell keeps working when the input resolution changes.
size = math.isqrt(patch_tokens.shape[1])
assert size * size == patch_tokens.shape[1], "patch tokens do not form a square grid"

embed = rearrange(patch_tokens, "b (h w) d-> b d h w", h=size, w=size)
embed = embed[0]
rows = 4
cols = 4
fig, axs = plt.subplots(rows, cols, figsize=(20, 20))

# The first panel is reserved for the input image; every other panel shows one
# randomly chosen embedding channel.
idxes = np.random.choice(embed.shape[0], rows * cols - 1, replace=False)

for idx, ax in zip(idxes, axs.flatten()[1:]):
    ax.imshow(embed[idx], cmap="bwr")
    ax.set_axis_off()
    ax.set_title(idx)

# Input image panel: log-scale and contrast-stretch three bands for display.
# NOTE(review): `data` is not defined in any cell shown in this notebook, and
# the meaning of band indices 3, 2, 1 depends on its channel order - confirm.
ax = axs.flatten()[0]
dd = np.log10(1 + data[[3, 2, 1]].clip(1, 2000))
dd = contrast_stretch(dd)
show(dd, ax=ax)
ax.set_axis_off()
ax.set_title("input")

plt.tight_layout()

In [None]:
# Read the labels CSV and preview its first five rows.
# NOTE(review): the name `labels` is rebound to a plain list by the evaluation
# cells further down, so this DataFrame does not survive a full run.
labels = pd.read_csv("/Users/kyledorman/data/estuary/label_studio/00025/labels.csv")
labels.head(5)

In [None]:
# Restore the trained EstuaryModule from its last checkpoint.
# NOTE(review): strict=False is passed to load_from_checkpoint, presumably to
# tolerate state-dict key mismatches - confirm this is intended.
module = EstuaryModule.load_from_checkpoint(
    "/Users/kyledorman/data/results/estuary/train/20250805-205230/checkpoints/last.ckpt", 
    strict=False)
# Clear the hold-out region on the config; the data module below is built
# from this same config object.
module.conf.holdout_region = None
# Switch to eval mode for inference.
module = module.eval()

In [None]:
# Build the data module from the checkpoint's (modified) config and run its
# prepare/setup steps so the train/val/test dataloaders can be requested.
dm = EstuaryDataModule(module.conf)
dm.prepare_data()
dm.setup()

In [None]:
# Accuracy of the trained module on the training dataloader.
dl = dm.train_dataloader()

preds = []
labels = []
# Inference only: disable autograd so no graph is built and retained for each
# batch (the original loop ran the forward pass with gradients enabled).
with torch.no_grad():
    for batch, blabel in tqdm.tqdm(dl, total=len(dl)):
        # Move every tensor of the batch dict onto the module's device.
        for k in batch.keys():
            batch[k] = batch[k].to(module.device)
        pred_batch = module.forward(batch)
        # Predicted class = index of the largest score along axis 1.
        preds.extend(pred_batch.argmax(axis=1).detach().cpu().numpy().tolist())
        labels.extend(blabel.detach().cpu().numpy().tolist())

accuracy_score(labels, preds)

In [None]:
# Accuracy of the trained module on the validation dataloader.
dl = dm.val_dataloader()

preds = []
labels = []
# Inference only: disable autograd so no graph is built and retained for each
# batch (the original loop ran the forward pass with gradients enabled).
with torch.no_grad():
    for batch, blabel in tqdm.tqdm(dl, total=len(dl)):
        # Move every tensor of the batch dict onto the module's device.
        for k in batch.keys():
            batch[k] = batch[k].to(module.device)
        pred_batch = module.forward(batch)
        # Predicted class = index of the largest score along axis 1.
        preds.extend(pred_batch.argmax(axis=1).detach().cpu().numpy().tolist())
        labels.extend(blabel.detach().cpu().numpy().tolist())

accuracy_score(labels, preds)

In [None]:
# Accuracy of the trained module on the test dataloader.
# `preds` from this cell is what the per-region comparison further down zips
# against dm.test_ds.df, so this must be the last evaluation cell run before it.
dl = dm.test_dataloader()

preds = []
labels = []
# Inference only: disable autograd so no graph is built and retained for each
# batch (the original loop ran the forward pass with gradients enabled).
with torch.no_grad():
    for batch, blabel in tqdm.tqdm(dl, total=len(dl)):
        # Move every tensor of the batch dict onto the module's device.
        for k in batch.keys():
            batch[k] = batch[k].to(module.device)
        pred_batch = module.forward(batch)
        # Predicted class = index of the largest score along axis 1.
        preds.extend(pred_batch.argmax(axis=1).detach().cpu().numpy().tolist())
        labels.extend(blabel.detach().cpu().numpy().tolist())

accuracy_score(labels, preds)

In [None]:
def _load_embeddings(df):
    """
    Load the stored embedding and integer label for every row of a split.

    Parameters:
    df (pd.DataFrame): Split table with `source_jpeg` and `label_idx` columns.

    Returns:
    tuple: (X, y) where X stacks the loaded embeddings in row order and y
        holds the matching `label_idx` values.
    """
    features = []
    for source_jpeg in df.source_jpeg:
        pth = Path(source_jpeg)
        # Each embedding is read from "<image dir>/../embeddings/<stem>.npy".
        emb_pth = pth.parent.parent / "embeddings" / f"{pth.stem}.npy"
        features.append(np.load(emb_pth))
    return np.array(features), np.array(df.label_idx.tolist())


# Same loading logic for both splits; previously two copy-pasted loops.
X_test, y_test = _load_embeddings(dm.test_ds.df)
X_train, y_train = _load_embeddings(dm.train_ds.df)

# Kept for compatibility: the original cell left `label_df` bound to the
# training table.
label_df = dm.train_ds.df

In [None]:
# Baseline: fit a support-vector classifier on the stored train embeddings.
# (Swap in RandomForestClassifier() here to compare.)
clf = svm.SVC()
clf.fit(X_train, y_train)

# Predict the test split and count exact agreements with the test labels.
svn_pred = clf.predict(X_test)
y_test = dm.test_ds.df.label_idx
match = (y_test == svn_pred).sum()
print(f"Matched {match} out of {len(X_test)} correctly")

# Confusion matrix with human-readable class names.
class_names = ["open", "closed"]
_ = ConfusionMatrixDisplay.from_predictions(
    y_test,
    svn_pred,
    labels=list(range(len(class_names))),
    display_labels=class_names,
)

# Summary metrics; precision/recall/F1 are macro-averaged.
print("Accuracy", accuracy_score(y_test, svn_pred))
prfs = precision_recall_fscore_support(y_test, svn_pred, average='macro')
for metric_name, metric_value in (
    ("F1", prfs[2]),
    ("Precision", prfs[0]),
    ("Recall", prfs[1]),
):
    print(metric_name, round(metric_value, 3))

In [None]:
# One row per test sample: both models' predictions next to the ground truth
# and the sample's region / source image.
# NOTE(review): zip() stops at the shortest input, so a `preds` list left over
# from a different split would be silently truncated or misaligned here -
# make sure the test evaluation cell ran last.
test_df = dm.test_ds.df
pred_columns = {
    'dnn': preds,
    'svn': svn_pred,
    'label': test_df.label_idx.tolist(),
    'region': test_df.region.tolist(),
    'source_jpeg': test_df.source_jpeg.tolist(),
}
pred_df = pd.DataFrame(
    list(zip(*pred_columns.values())),
    columns=list(pred_columns),
)
# Per-group metric used by the region breakdown
def compute_accuracy(group, pred_col):
    """Return the balanced accuracy of `group[pred_col]` against `group['label']`."""
    y_true = group['label']
    y_pred = group[pred_col]
    return balanced_accuracy_score(y_true, y_pred)

# Per-region balanced accuracy for both models, plus the mean of the label
# column per region.
# NOTE(review): "closed_pct" equals the closed fraction only if label 1 means
# "closed" - confirm against the class order.
by_region = pred_df.groupby('region')
dnn_acc_by_region = by_region.apply(lambda g: compute_accuracy(g, 'dnn'))
svn_acc_by_region = by_region.apply(lambda g: compute_accuracy(g, 'svn'))
closed_pct = by_region.apply(lambda g: g.label.sum() / g.label.count())

# Collect the three per-region series into one table.
acc_df = pd.DataFrame({
    'dnn': dnn_acc_by_region,
    'svn': svn_acc_by_region,
    "closed_pct": closed_pct,
}).reset_index()

# Report only the DNN column (renamed "accuracy"; it holds balanced accuracy),
# print it and write it to CSV rounded to two decimals.
region_stats = (
    acc_df
    .drop(columns="svn")
    .rename(columns={"dnn": "accuracy"})
    .sort_values(by="region")
    .set_index("region")
)
region_stats_rounded = region_stats.round(2)
print(region_stats_rounded)
region_stats_rounded.to_csv("/Users/kyledorman/data/estuary/display/per_region_results.csv")

# Overall balanced accuracy of each model on the whole test split.
for model_name in ("dnn", "svn"):
    print(model_name, balanced_accuracy_score(pred_df.label, pred_df[model_name]))

In [None]:
# Share of rows labelled "closed" in each region of the labels CSV.
labels = pd.read_csv("/Users/kyledorman/data/estuary/label_studio/00025/labels.csv")
for region_name, region_rows in labels.groupby("region"):
    is_closed = region_rows.label == "closed"
    closed_share = is_closed.sum() / len(region_rows.label)
    print(region_name, round(closed_share, 3))

In [None]:
from PIL import Image

# Test samples where the DNN prediction disagrees with the label.
aaa = pred_df[(pred_df.dnn != pred_df.label)]

rows = 3
cols = 4
# Use rows x cols; the original passed `cols` twice and ignored `rows`.
fig, axs = plt.subplots(rows, cols, figsize=(15, 15))

# Hide every axis up front so panels without an image stay blank instead of
# showing empty ticks when there are fewer errors than panels.
for ax in axs.flatten():
    ax.set_axis_off()

# zip() stops at whichever runs out first: the error rows or the panels.
for (idx, row), ax in zip(aaa.iterrows(), axs.flatten()):
    # Close the image file as soon as matplotlib has read its pixels.
    with Image.open(row.source_jpeg) as img:
        ax.imshow(img)
    ax.set_title(f"{row.region} pred={module.conf.classes[row.dnn]} label={module.conf.classes[row.label]}")

plt.tight_layout()

In [None]:
dm = 