# Training
Goals:
- Get the data from the scratch bucket, as saved by the DataPrep_Trial notebook
- Decompress it (in a distributed way?)
- Convert it to PyTorch format
- Train in a distributed way
- Test

In [62]:
# Common imports and settings
import os, sys, re
from pathlib import Path
from IPython.display import Markdown
import pandas as pd
pd.set_option("display.max_rows", None)
import xarray as xr
import dask
from dask.distributed import Client
from dask_gateway import Gateway
from dateutil.parser import parse
from dateutil.relativedelta import relativedelta

# Datacube
import datacube
from datacube.utils.aws import configure_s3_access
import odc.geo.xr                                  # https://github.com/opendatacube/odc-geo
from datacube.utils import masking  # https://github.com/opendatacube/datacube-core/blob/develop/datacube/utils/masking.py
from odc.algo import enum_to_bool                  # https://github.com/opendatacube/odc-tools/blob/develop/libs/algo/odc/algo/_masking.py
from dea_tools.plotting import display_map, rgb    # https://github.com/GeoscienceAustralia/dea-notebooks/tree/develop/Tools

import boto3

# Basic plots
%matplotlib inline
# import matplotlib.pyplot as plt
# plt.rcParams['figure.figsize'] = [12, 8]

# Holoviews
# https://holoviz.org/tutorial/Composing_Plots.html
# https://holoviews.org/user_guide/Composing_Elements.html
import hvplot.pandas
import hvplot.xarray
import panel as pn
import colorcet as cc
import cartopy.crs as ccrs
from datashader import reductions
from holoviews import opts
# hv.extension('bokeh', logo=False)
print("Libraries loaded successfully.")

Libraries loaded successfully.


In [63]:
# EASI defaults
# These are convenience functions so that the notebooks in this repository work in all EASI deployments

# Locate the local easi-notebooks checkout so `easi_tools` can be imported: the git working
# tree containing this notebook if there is one, otherwise ~/easi-notebooks.
# If using the `easi-tools` functions from another path, replace `repo` with your local path to `easi-notebooks` directory
try:
    import git
except ImportError:
    git = None    # GitPython not installed; use the default path below

repo = None
if git is not None:
    try:
        repo = git.Repo('.', search_parent_directories=True).working_tree_dir    # Path to this cloned local directory
    except git.InvalidGitRepositoryError:
        pass    # Not inside a git checkout; use the default path below
if repo is None:
    repo = Path.home() / 'easi-notebooks'    # Reasonable default
    if not repo.is_dir():
        raise RuntimeError('To use `easi-tools` please provide the local path to `https://github.com/csiro-easi/easi-notebooks`')
# sys.path holds strings; normalise so the membership test works in both branches
# (a Path never compares equal to a str, which caused duplicate appends on re-run)
repo = str(repo)
if repo not in sys.path:
    sys.path.append(repo)    # Add the local path to `easi-notebooks` to python

from easi_tools import EasiDefaults
from easi_tools import initialize_dask, xarray_object_size, mostcommon_crs, heading
#from easi_tools.load_s2l2a import load_s2l2a_with_offset
print("EASI librariies loaded successfully.")

EASI librariies loaded successfully.


In [64]:
import torch
import pytorch_lightning as pl
from lightning.pytorch import Trainer
from torch.utils.data import DataLoader, Dataset
import s3fs, zarr, numpy as np

# TerraTorch Imports

from terratorch.tasks import SemanticSegmentationTask
print("ML libraries imported successfully")

ML libraries imported successfully


In [66]:
gateway = Gateway()

# Cluster options are defined by the Gateway deployment; `cuda_worker` presumably
# starts dask-cuda (GPU) workers on this EASI deployment
options = gateway.cluster_options()
options.cuda_worker = True
gpu_worker_num = 2
# options.node_selection = "all"

print("Requesting GPU Cluster via dask-cuda-worker...")
cluster = None
try:
    cluster = gateway.new_cluster(cluster_options=options)
    cluster.scale(gpu_worker_num)
    client = cluster.get_client()
    print("Cluster Dashboard:", client.dashboard_link)
except Exception as e:
    print("Cluster failed to start. You might not have GPU quota or the system is busy.")
    print(e)
    # If the cluster was created but a later step failed, shut it down so it
    # does not keep holding GPU resources
    if cluster is not None:
        cluster.shutdown()
        cluster = None


Requesting GPU Cluster via dask-cuda-worker...


Exception ignored in: <function Gateway.__del__ at 0x7f6a0d46b920>
Traceback (most recent call last):
  File "/env/lib/python3.12/site-packages/dask_gateway/client.py", line 380, in __del__
    self.close()
  File "/env/lib/python3.12/site-packages/dask_gateway/client.py", line 353, in close
    elif self.loop.asyncio_loop.is_running():
         ^^^^^^^^^
  File "/env/lib/python3.12/site-packages/dask_gateway/client.py", line 330, in loop
    return self._loop_runner.loop
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/env/lib/python3.12/site-packages/distributed/utils.py", line 648, in loop
    raise RuntimeError(
RuntimeError: Accessing the loop property while the loop is not running is not supported


KeyboardInterrupt: 

In [None]:
# Reconnect to an already-running Gateway cluster (e.g. after a kernel restart)
gateway = Gateway()
cluster_name = "easihub.e61cd4c8bab34b9ea44a73a2adfef40d"    # from an earlier session; may be stale
# Cluster names change per session, so check the name is still running before connecting
running_clusters = [report.name for report in gateway.list_clusters()]
if cluster_name not in running_clusters:
    if not running_clusters:
        raise RuntimeError("No running Dask Gateway clusters to connect to; create one first.")
    print(f"Cluster {cluster_name} is not running; connecting to {running_clusters[0]} instead.")
    cluster_name = running_clusters[0]
cluster = gateway.connect(cluster_name)
client = cluster.get_client()

In [57]:
# Restart all Dask workers so new work starts from clean worker processes,
# discarding state left behind by earlier tasks
client.restart()

In [67]:
# Configure authenticated, requester-pays S3 access, passing `client` so the
# configuration is also applied on the Dask workers.
#TODO: will this work for training?
# NOTE(review): this looks like GDAL/rasterio-style S3 configuration; the training code
# reads S3 via s3fs (boto credential chain) — confirm whether it relies on this call.
configure_s3_access(aws_unsigned=False, requester_pays=True, client=client);

In [68]:
# Scratch bucket for this EASI deployment
# (EasiDefaults reports which deployment configuration it matched)
easi = EasiDefaults()
bucket = easi.scratch

# The caller's AWS user id namespaces this user's objects within the scratch bucket
userid = boto3.client('sts').get_caller_identity()['UserId']

# Project / dataset names form the rest of the S3 prefix.
# NOTE(review): "traning" looks like a typo, but presumably it must match the prefix the
# DataPrep_Trial notebook wrote to — confirm before renaming.
project_name, dataset_name = "traning_test_project", "training_dataset_v1.zarr"

# Full S3 URL of the Zarr training dataset: s3://<bucket>/<userid>/<project>/<dataset>
s3_path = "s3://{}/{}/{}/{}".format(bucket, userid, project_name, dataset_name)

Successfully found configuration for deployment "csiro"


In [54]:
class ZarrDataset(Dataset):
    """PyTorch dataset of random crops from a Zarr image cube on S3.

    One item per time slice. ``__getitem__`` reads the (band, y, x) slice for
    the requested time index, draws a random ``crop_size`` x ``crop_size``
    window, and re-draws it (up to ``max_tries`` times) when fewer than
    ``min_valid_fraction`` of the window's values are non-zero. If no window is
    accepted, an all-zero sample is returned.

    Parameters
    ----------
    img_path : str
        S3 URL of the image Zarr group. Its first array is used and is indexed
        as ``[band, time, y, x]``.
    lbl_path : str or None
        S3 URL of the label Zarr group (first array, indexed ``[time, y, x]``).
        When falsy, a placeholder mask ``band[3] > 0.1`` is derived from each crop.
    crop_size : int, default 224
        Side length of the square crop.

    Items are dicts ``{"image": float32 tensor (band, crop, crop),
    "mask": int64 tensor (crop, crop)}``, the dict format TerraTorch tasks consume.
    """

    # Random windows tried per item before falling back to an all-zero sample
    max_tries = 10
    # Minimum fraction of non-zero values for a window to be accepted
    min_valid_fraction = 0.5
    # Default crop side length (set per instance by __init__)
    crop_size = 224

    def __init__(self, img_path, lbl_path=None, crop_size=224):
        self.fs = s3fs.S3FileSystem(anon=False)
        self.crop_size = crop_size

        # Images: first array in the group, indexed [band, time, y, x]
        img_store = s3fs.S3Map(root=img_path, s3=self.fs, check=False)
        self.img_root = zarr.open(img_store, mode='r')
        self.images = self.img_root[list(self.img_root.keys())[0]]
        # One sample per time step (axis 1 of the image array)
        self.num_times = self.images.shape[1]

        # Labels (optional): first array in the group, indexed [time, y, x]
        self.labels = None
        if lbl_path:
            lbl_store = s3fs.S3Map(root=lbl_path, s3=self.fs, check=False)
            self.lbl_root = zarr.open(lbl_store, mode='r')
            self.labels = self.lbl_root[list(self.lbl_root.keys())[0]]

    def __len__(self):
        return self.num_times

    def __getitem__(self, idx):
        size = self.crop_size
        # Read the (band, y, x) slice once per item instead of once per retry
        full_img = np.asarray(self.images[:, idx, :, :])
        bands, h, w = full_img.shape
        if h < size or w < size:
            raise ValueError(
                f"Time slice {idx} is {h}x{w} pixels, smaller than crop_size={size}"
            )

        for _ in range(self.max_tries):
            # randint's upper bound is exclusive; +1 lets the window reach the far
            # edge (and makes images exactly `size` pixels across valid)
            top = np.random.randint(0, h - size + 1)
            left = np.random.randint(0, w - size + 1)
            img_crop = full_img[:, top:top + size, left:left + size]

            # Reject windows that are mostly zero (presumably nodata) and draw again
            if np.count_nonzero(img_crop) < img_crop.size * self.min_valid_fraction:
                continue

            if self.labels is not None:
                # Only the accepted window of the label slice is read
                mask_crop = self.labels[idx, top:top + size, left:left + size]
            else:
                # Placeholder mask until real labels exist: band 3 (NIR_NARROW in the
                # training band order) > 0.1
                mask_crop = img_crop[3] > 0.1

            # Clip negative values to 0. Cast to float32 / int64 before torch.from_numpy,
            # which may reject some numpy dtypes (e.g. uint16).
            image = np.maximum(img_crop, 0).astype(np.float32)
            mask = np.asarray(mask_crop).astype(np.int64)

            # TerraTorch tasks expect a dict with keys "image" and "mask"
            return {
                "image": torch.from_numpy(image),
                "mask": torch.from_numpy(mask),
            }

        # No acceptable window found: all-zero sample with this dataset's band count
        return {
            "image": torch.zeros((bands, size, size), dtype=torch.float32),
            "mask": torch.zeros((size, size), dtype=torch.int64),
        }

            

In [55]:
def train_terra_remote(image_path, label_path, save_path, epochs=5, batch_size=8):
    """Fine-tune a Prithvi-EO segmentation model on one GPU and save its weights to S3.

    Designed to be shipped to a Dask worker with ``client.submit``: the dataset,
    model and trainer are all built inside this function.

    Parameters
    ----------
    image_path : str
        S3 URL of the image Zarr store (passed to ``ZarrDataset``).
    label_path : str or None
        S3 URL of the label Zarr store, or None to use ``ZarrDataset``'s
        placeholder masks.
    save_path : str
        S3 URL the model ``state_dict`` is written to with ``torch.save``.
    epochs : int
        Number of training epochs.
    batch_size : int
        DataLoader batch size.

    Returns
    -------
    float
        Last training loss found in ``trainer.callback_metrics``, or 0.0 if
        no training loss was recorded.
    """
    # ### FUTURE EDIT: Change "num_classes" if you have more than 2 classes
    model_args = {
        # The Backbone: Prithvi-EO v1, 100M, with pretrained weights
        "backbone": "prithvi_eo_v1_100",
        "backbone_pretrained": True,
        # Band names in the order of the image array's band axis (matches the data-prep order)
        "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
        # The Necks (needed for Prithvi to work as a segmenter): they reshape the
        # 1D token sequences back into 2D feature maps for the decoder
        "necks": [
            {"name": "ReshapeTokensToImage"},
            {"name": "SelectIndices", "indices": [2, 5, 8, 11]},  # indices for the 100M model
            {"name": "LearnedInterpolateToPyramidal"},
        ],
        # The Decoder and segmentation head
        "decoder": "FCNDecoder",
        "num_classes": 2,
        "head_dropout": 0.1,
    }

    # Lightning task wrapping the model together with its loss and optimizer
    task = SemanticSegmentationTask(
        model_factory="EncoderDecoderFactory",
        model_args=model_args,
        loss="focal",
        optimizer="AdamW",
        lr=1e-4,
        optimizer_hparams={},
        ignore_index=-1,
    )

    dataset = ZarrDataset(image_path, label_path)
    # shuffle=True: visit time slices in random order each epoch instead of time order.
    # num_workers=0: batches are loaded in the worker's own process.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    trainer = Trainer(
        max_epochs=epochs,
        accelerator="gpu",
        devices=1,
        logger=False,
        enable_checkpointing=False,
    )

    print("Starting Training Loop...")
    trainer.fit(model=task, train_dataloaders=loader)

    # Save weights only (state_dict), not the whole task object
    print(f"Saving model to {save_path}...")
    fs = s3fs.S3FileSystem(anon=False)
    with fs.open(save_path, "wb") as f:
        torch.save(task.state_dict(), f)
    print("Model Saved Successfully.")

    # Look the loss up safely: a missing key must not crash after training succeeded.
    # NOTE(review): TerraTorch appears to log the training loss as "train/loss";
    # "train_loss" is checked first — confirm the key for the installed version.
    metrics = trainer.callback_metrics
    final_loss = metrics.get("train_loss", metrics.get("train/loss"))
    return float(final_loss) if final_loss is not None else 0.0

In [70]:
S3_LABEL_PATH = None    # None -> ZarrDataset builds placeholder masks from the imagery
S3_IMAGE_PATH = s3_path
S3_MODEL_SAVE_PATH = f"s3://{bucket}/{userid}/{project_name}/models/prithvi_v1.pt"

EPOCHS = 5
BATCH_SIZE = 8

#print(f"Image Source: {S3_IMAGE_PATH}")
print(f"Label Source: {S3_LABEL_PATH if S3_LABEL_PATH else 'GENERATED ON FLY'}")


# Submit to GPU.
# pure=False: training is stochastic and writes to S3, so Dask must not hand back a
# previous future for the same arguments when this cell is re-run.
future = client.submit(
    train_terra_remote,
    S3_IMAGE_PATH,
    S3_LABEL_PATH,
    S3_MODEL_SAVE_PATH,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    pure=False,
)

# Wait for result
try:
    final_loss = future.result()
    print(f"\nTraining Complete! Final Loss: {final_loss}")
except KeyboardInterrupt:
    # Ask the scheduler to cancel the task rather than leave it running unattended
    # (a task already executing may still run to completion), then re-raise
    future.cancel()
    raise
except Exception as e:
    print("\nTraining Failed on Worker:")
    print(f"{type(e).__name__}: {e}")

Label Source: GENERATED ON FLY


KeyboardInterrupt: 

In [30]:
# Release the GPU workers; shut the Gateway cluster down even if closing the client fails
try:
    client.close()
finally:
    cluster.shutdown()