In [2]:
import os
import re
import subprocess
from collections import defaultdict
import numpy as np
import rasterio
from rasterio.merge import merge
from rasterio.warp import reproject, Resampling
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import glob
import json
from tqdm import tqdm
import yaml

#import geopandas as gpd
import pandas as pd

import torch

import sys
sys.path.append('../')

import utils.basics as bsc 
import utils.plotting as pt
import utils.processing as proc
import utils.eval_pipe as eval

import utils.model_loader as md
import utils.data_loader as dt
import utils.config_loader as cf

%load_ext autoreload
%autoreload 2


## Main run of experiments

In [None]:
# SELECT EXPERIMENTAL CONFIG
# Load every experiment definition, then keep a sub-range of the names for this run.
with open('../configs/experiments.yaml', 'r') as f:
    experiments = yaml.safe_load(f)

# Slice [1:8]: skips the entry at index 0 and keeps up to 7 names (indices 1..7),
# in the order the keys appear in the YAML file.
experiment_names = list(experiments.keys())[1:8]

with open('../configs/normparams.yaml', 'r') as f:
    normparams = yaml.safe_load(f)
# Normalisation parameters stored under the '_111' key of the 'chm' section.
joint_normparams = normparams['chm']['_111']

combos = ["110","101","011"]

# Must be the last expression of the cell, otherwise the notebook does not display it.
experiment_names

In [None]:
import random

# --- Run configuration -------------------------------------------------------
global_config = md.global_config
# Fixed base seed. It is never mutated; each repetition derives its own seed from it.
base_seed = global_config['seed']
# NOTE(review): the base ends with "_" and the suffix built below starts with "_",
# so run IDs contain a double underscore (e.g. "251025_GEN__110_0"). Left unchanged
# in case existing results or downstream code match on that exact form.
run_id_base = "251025_GEN_"  # customize as needed
repetitions = range(5)
# 1 is training data, 0 is test. Logic is LSB: 001 -> SITE1 is training data.
combos = ["110", "101", "011"]

PATCH_SIZE = 32            # patch_size passed to the dataset builders
NAN_PERCENT_ALLOWED = 20   # nan_percent_allowed passed to the dataset builders
VAL_FRACTION = 0.2         # share of the training patches held out for validation

# --- Experiment grid: repetition x site combo x experiment config -------------
for i in repetitions:  # one repetition per seed (5 with the range above)
    # Derive the seed from the immutable base so it depends only on the repetition
    # index. Accumulating into a running variable inside the combo loop made the
    # seed drift and differ between combos of the same repetition.
    seed = base_seed + i

    for combo in combos:
        run_id = run_id_base + f"_{combo}_{i}"

        # Re-seed every RNG used in this cell at the start of each combo.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        # CHM normalisation parameters depend only on the combo, so look them up
        # once here; a missing key then fails before any training is started.
        normparams_used = normparams['chm'][f"_{combo}"]

        for exp_name in experiment_names:

            sites, cfg = cf.get_config(exp_name)
            cfg.update(global_config)  # Ensure cfg has the latest global_config
            cfg.update({"combo": combo})
            print(f" --> Name: {exp_name}, Combo: {combo}, Run ID: {run_id}, Seed: {seed}")

            # Build dataset
            # NOTE(review): X, Y and sitenums are not used anywhere in this cell, so this
            # call looks redundant next to build_patched_dataset_generalization. It is
            # kept because it may have side effects, consume RNG state, or feed later
            # cells -- confirm and remove if not needed.
            X, Y, sitenums = dt.build_patched_dataset(
                cfg, sites, patch_size=PATCH_SIZE, nan_percent_allowed=NAN_PERCENT_ALLOWED
            )
            X_patch_train, Y_patch_train, X_test, Y_test = dt.build_patched_dataset_generalization(
                cfg, sites, combo, patch_size=PATCH_SIZE, nan_percent_allowed=NAN_PERCENT_ALLOWED
            )

            # Split the training patches into train (80%) and val (20%).
            X_train, X_val, Y_train, Y_val = train_test_split(
                X_patch_train, Y_patch_train, test_size=VAL_FRACTION, random_state=seed
            )

            train_dataset = md.S2CanopyHeightDataset(X_train, Y_train, cfg)
            val_dataset = md.S2CanopyHeightDataset(X_val, Y_val, cfg)
            test_dataset = md.S2CanopyHeightDataset(X_test, Y_test, cfg)

            batch_size = global_config['batch_size']
            train_loader = md.DataLoader(
                train_dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=0
            )
            val_loader = md.DataLoader(val_dataset, batch_size=batch_size)
            test_loader = md.DataLoader(test_dataset, batch_size=batch_size)

            # Channel counts are taken from the first dimension of the first sample.
            model = md.build_unet(
                in_channels=X_train[0].shape[0],
                out_channels=Y_train[0].shape[0],
                cfg=cfg
            )
            # Train the model according to the config.
            model, logs = md.train_model(model, train_loader, val_loader, cfg)

            # Evaluate on the val and test sets and save the results under this run ID.
            md.save_results(model, val_loader, test_loader, normparams_used, logs, cfg, run_id=run_id)

    print("DONE WITH ALL EXPERIMENTS for iteration:  ", i)
    