In [1]:
from pathlib import Path
from glob import glob

from tqdm.notebook import tqdm
import tifffile as tiff
import matplotlib.pyplot as plt
import numpy as np
from skimage import exposure
import pandas as pd

: 

In [6]:
data_path = Path('/local/home/lhauptmann/CV4A/data')
data_raw_path = data_path / 'raw'
data_processed_path = data_path / 'processed'
data_split_path = data_raw_path / 'FieldIds.csv'
azcopy_path = '/local/home/lhauptmann/CV4A//local/home/lhauptmann/CV4A/data/azcopy_linux_amd64_10.28.0'

In [93]:
def plot_sentinel2_image(image_data, title="", ax = None):
    """
    Plots Sentinel-2 image data (RGB composite or individual bands).

    Parameters:
        - image_data (numpy array): The image data (RGB composite or individual band).
        - title (str): The title of the plot.
        - show_colorbar (bool): Whether to show the colorbar. Default is True.
    """
    if not ax:
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(image_data, cmap='viridis')
    

    ax.set_title(title, fontsize=14)
    ax.axis('off')  # Turn off axes for better presentation


In [123]:
from torchgeo import datasets
import os
os.environ["PATH"] = f"{azcopy_path}:" + os.environ["PATH"]

dataset = datasets.CV4AKenyaCropType(root=data_path/"raw", download=True, chip_size = 224, stride=16)

split = pd.read_csv(data_split_path)
test_id = split["test"].dropna().values
train_id = split["train"].dropna().values

In [133]:
sample = dataset[300]
print(sample.keys(), sample["image"].shape)
print(sample["field_ids"].unique(), sample["mask"].unique(), sample["tile_index"].unique())

dict_keys(['image', 'mask', 'field_ids', 'tile_index', 'x', 'y']) torch.Size([13, 13, 224, 224])
tensor([   0, 1469, 1470, 1689, 1748, 1749, 1750, 1808, 2072, 2345, 2928, 2929,
        3180, 3220, 3340, 4034, 4035, 4743, 4775], dtype=torch.int32) tensor([0, 1, 2, 4, 5], dtype=torch.uint8) tensor([0])


In [134]:
# Image consists of [n_timepoints, n_bands, height, width]
print(sample["image"].shape)
fig, axes = plt.subplots(1, 2, figsize=(20, 10))
timepoint = 3
plot_sentinel2_image(sample["image"][timepoint,0:-1,...].mean(dim=0), ax = axes[0])
plot_sentinel2_image(sample["mask"], ax = axes[1])
plt.show()

torch.Size([13, 13, 224, 224])


In [119]:
from torchgeo import models
model = models.resnet50(weights=models.ResNet50_Weights.SENTINEL2_ALL_DECUR)

In [None]:
sample.keys()

dict_keys(['image', 'mask', 'field_ids', 'tile_index', 'x', 'y'])


In [None]:
labels, field_ids, split = [], [], []
for sample in dataset:
    labels.append(sample["mask"].unique().numpy())
    fids = sample["field_ids"].unique().numpy()
    field_ids.append(fids)
    if all([fid in np.append(test_id, [0]) for fid in fids]) and len(fids) > 1:
        split.append("test")
    elif all([fid in np.append(train_id, [0]) for fid in fids]) and len(fids) > 1:
        split.append("train")
    else:
        split.append("none")
        

In [None]:
# make labels one_hot encoded
labels_onehot = []
for l in labels:
    onehot = np.zeros(8)
    onehot[l] = 1
    labels_onehot.append(onehot)

In [None]:
split = np.array(split)

In [None]:
split[split == "train"], np.array(labels_onehot)[split == "train"]

(array(['train', 'train', 'train', ..., 'train', 'train', 'train'],
      shape=(1186,), dtype='<U5'), array([[1., 1., 0., ..., 0., 0., 0.],
       [1., 1., 0., ..., 0., 0., 0.],
       [1., 1., 0., ..., 0., 0., 0.],
       ...,
       [1., 0., 0., ..., 0., 0., 0.],
       [1., 1., 0., ..., 0., 0., 0.],
       [1., 0., 1., ..., 1., 0., 0.]], shape=(1186, 8)))


In [None]:
from skmultilearn.model_selection import IterativeStratification
splits = {}
k_fold = IterativeStratification(n_splits=5, order=1)
for i, (train, test) in enumerate(k_fold.split(split[split == "train"], np.array(labels_onehot)[split == "train"])):
    splits[i] = {"train": train, "test": test}

In [None]:
test_id = np.array(test_id).astype(int).tolist()
train_id = np.array(splits[0].get("train")).astype(int).tolist()
val_id = np.array(splits[0].get("test")).astype(int).tolist()


Traceback (most recent call last):
  File "/local/home/lhauptmann/.vscode-server/extensions/ms-python.python-2025.0.0-linux-x64/python_files/python_server.py", line 133, in exec_user_input
    retval = callable_(user_input, user_globals)
  File "<string>", line 1, in <module>
AttributeError: 'list' object has no attribute 'astype'



In [None]:
train_samples = [dataset[i] for i in train_id]
test_samples = [dataset[i] for i in test_id]
val_samples = [dataset[i] for i in val_id]

for val_sample in val_samples:
    val_sample["image"] = process_images(val_sample["image"])

for train_sample in train_samples:
    train_sample["image"] = process_images(train_sample["image"])

for test_sample in test_samples:
    test_sample["image"] = process_images(test_sample["image"])


# save as torch dataset
import torch

torch.save(train_samples, data_processed_path / "train_samples.pt")
torch.save(test_samples, data_processed_path / "test_samples.pt")
torch.save(val_samples, data_processed_path / "val_samples.pt")


In [None]:
def process_images(image):
    image[:,-1,...] = image[:,-1,...] / 100
    image_band = image[:,:-1,...]
    image_bands_mean, image_band_std = image_band.mean(dim=(0,2,3),  keepdims=True), image_band.std(dim=(0,2,3), keepdims=True)
    image_band_min = image_bands_mean - image_band_std * 2
    image_band_max = image_bands_mean + image_band_std * 2
    image_band_norm = (image_band - image_band_min) / (image_band_max - image_band_min)
    image_band_norm = image_band_norm.clamp(0, 1)
    return image

torch.Size([13, 13, 224, 224])


tensor(2.0000e-14)


In [None]:
masks[8000]

[0]
