In [1]:
import pandas as pd
import os
import numpy as np
import pickle as pkl
import json
import yaml
from torch.utils.data import Dataset, DataLoader
from vit_foundry.perceiver import Perceiver, PerceiverConfig
import torch
from pathlib import Path
import collections
import shutil
from util_xgb import xgb_process_data, xgb_train_and_infer
from tqdm import tqdm

In [2]:
run = 'lr2e-06_cl32_cscscscscscscscsssss_NEE_lhid64_iemb64_nf10_causal_aux-0.2_drop0.3_ws_v3_MILA'
checkpoint = 'checkpoint-12.pth'

DATA_DIR = Path('batch', 'data', 'processed', 'v3')
ALL_RUN_DIR = Path('batch/runs')
RUN_DIR = ALL_RUN_DIR / run


def _read_site_list(path):
    """Return the site IDs listed one per line in `path`.

    Lines are whitespace-stripped and blank lines are dropped. A plain
    ``split`` on newlines would keep an empty entry for a trailing newline
    (and a stray carriage return on CRLF files); an empty site ID is a
    substring of every path and would select every data folder downstream.
    """
    with open(path, 'r') as f:
        return [line.strip() for line in f.read().splitlines() if line.strip()]


TRAIN_SITES = _read_site_list(RUN_DIR / 'train_sites.txt')
VAL_SITES = _read_site_list(RUN_DIR / 'val_sites.txt')
CHECKPOINT_PATH = RUN_DIR / checkpoint

In [3]:
# Tabulate how many sites of each IGBP type went into the train and val splits
def site_configuration(train_sites=None, val_sites=None, run_dir=None,
                       meta_path='processed_site_meta.csv'):
    """Count sites per IGBP class in each split and save the table as CSV.

    Args:
        train_sites: training site IDs (default: the ``TRAIN_SITES`` global).
        val_sites: validation site IDs (default: the ``VAL_SITES`` global).
        run_dir: directory that receives ``site_type_distribution.csv``
            (default: the ``RUN_DIR`` global).
        meta_path: CSV with at least ``SITE_ID`` and ``IGBP`` columns.

    Returns:
        DataFrame indexed by IGBP class with integer columns ``train`` and
        ``val``. Every IGBP class present in the metadata gets a row.

    Raises:
        KeyError: if a (non-blank) site ID is absent from the metadata.
    """
    if train_sites is None:
        train_sites = TRAIN_SITES
    if val_sites is None:
        val_sites = VAL_SITES
    if run_dir is None:
        run_dir = RUN_DIR

    site_meta = pd.read_csv(meta_path)
    # First metadata row wins for a duplicated SITE_ID (matches `.values[0]`).
    igbp_by_site = site_meta.drop_duplicates('SITE_ID').set_index('SITE_ID')['IGBP']

    counts = {igbp: [0, 0] for igbp in site_meta['IGBP'].unique()}
    for column, sites in enumerate((train_sites, val_sites)):
        # Blank entries (e.g. from a trailing newline in a site list) are ignored.
        selected = [s.strip() for s in sites if s.strip()]
        missing = [s for s in selected if s not in igbp_by_site.index]
        if missing:
            raise KeyError(f'Sites not found in {meta_path}: {missing}')
        for site in selected:
            counts[igbp_by_site[site]][column] += 1

    site_type_distribution = pd.DataFrame(data=counts).T.rename(columns={0: 'train', 1: 'val'})
    site_type_distribution.to_csv(os.path.join(run_dir, 'site_type_distribution.csv'))
    return site_type_distribution

In [4]:
# Check for other runs with the same distribution
def find_identical_run(all_run_dir=None, run_name=None, train_sites=None, val_sites=None):
    """Return another run directory that used exactly the same site split.

    A sibling run qualifies when it has ``train_sites.txt``, ``val_sites.txt``
    and ``xgb_inference.csv``. Its train and val site lists must also equal
    this run's as multisets (whitespace-stripped, blank lines ignored).
    Candidates are scanned in sorted name order, so the first match is
    deterministic.

    Args:
        all_run_dir: directory containing all runs (default: ``ALL_RUN_DIR``).
        run_name: this run's directory name, which is skipped (default: ``run``).
        train_sites: this run's training sites (default: ``TRAIN_SITES``).
        val_sites: this run's validation sites (default: ``VAL_SITES``).

    Returns:
        Path string of the first matching run directory, or None.
    """
    if all_run_dir is None:
        all_run_dir = ALL_RUN_DIR
    if run_name is None:
        run_name = run
    if train_sites is None:
        train_sites = TRAIN_SITES
    if val_sites is None:
        val_sites = VAL_SITES

    def _site_counter(sites):
        # Normalise so trailing newlines / CRLF endings don't break the comparison.
        return collections.Counter(s.strip() for s in sites if s.strip())

    def _read_sites(path):
        with open(path, 'r') as f:
            return f.read().splitlines()

    wanted_train = _site_counter(train_sites)
    wanted_val = _site_counter(val_sites)

    for d in sorted(os.listdir(all_run_dir)):
        if d == run_name:
            continue
        other_run_dir = os.path.join(all_run_dir, d)
        train_path = os.path.join(other_run_dir, 'train_sites.txt')
        val_path = os.path.join(other_run_dir, 'val_sites.txt')
        xgb_path = os.path.join(other_run_dir, 'xgb_inference.csv')
        # Require every file up front; a run missing val_sites.txt used to crash here.
        if not all(os.path.exists(p) for p in (train_path, val_path, xgb_path)):
            continue
        if _site_counter(_read_sites(train_path)) == wanted_train \
                and _site_counter(_read_sites(val_path)) == wanted_val:
            return other_run_dir
    return None
identical_run = find_identical_run()

# Three cases: this run already has XGBoost results; an identical run has
# results that can be copied; or nothing is reusable and XGBoost is trained here.
already_done = os.path.exists(os.path.join(RUN_DIR, 'xgb_inference.csv'))
if already_done:
    print('All files found')
elif identical_run is None:
    site_configuration()
    xgb_process_data(DATA_DIR, TRAIN_SITES, VAL_SITES, RUN_DIR)
    xgb_train_and_infer(RUN_DIR, n_iter=20)
else:
    print('Copying files')
    items_to_copy = ['xgb.pkl', 'site_type_distribution.csv', 'xgb_inference.csv']
    for item in items_to_copy:
        shutil.copy(os.path.join(identical_run, item), os.path.join(RUN_DIR, item))

Processing data for XGBoost...
  train data complete
  val data complete


Parameters: { "njobs" } are not used.



MSE: 12.873177479536645 with params: {'colsample_bytree': 0.7578115700148477, 'gamma': 0.08287634566961916, 'min_child_weight': 3, 'learning_rate': 0.12536528639802086, 'max_depth': 12, 'n_estimators': 147, 'subsample': 0.6583160158021772}


Parameters: { "njobs" } are not used.



MSE: 13.772317715454664 with params: {'colsample_bytree': 0.5428623482935462, 'gamma': 0.42753182623476804, 'min_child_weight': 4, 'learning_rate': 0.195335025424085, 'max_depth': 13, 'n_estimators': 148, 'subsample': 0.868490383137451}


Parameters: { "njobs" } are not used.



MSE: 13.620876725890229 with params: {'colsample_bytree': 0.9373933366506375, 'gamma': 0.334650068726291, 'min_child_weight': 7, 'learning_rate': 0.09471102815068018, 'max_depth': 17, 'n_estimators': 96, 'subsample': 0.8472767261811002}


Parameters: { "njobs" } are not used.



MSE: 13.681062090348542 with params: {'colsample_bytree': 0.6659707166963071, 'gamma': 0.2574097229916544, 'min_child_weight': 6, 'learning_rate': 0.17118258588458526, 'max_depth': 15, 'n_estimators': 103, 'subsample': 0.78634716525238}


Parameters: { "njobs" } are not used.



MSE: 13.214840766176026 with params: {'colsample_bytree': 0.8595014774580252, 'gamma': 0.053164449712578166, 'min_child_weight': 4, 'learning_rate': 0.16358737399035927, 'max_depth': 14, 'n_estimators': 75, 'subsample': 0.8466152640066291}


Parameters: { "njobs" } are not used.



MSE: 13.65329293082133 with params: {'colsample_bytree': 0.50596164616668, 'gamma': 0.3058749990163182, 'min_child_weight': 6, 'learning_rate': 0.18867793487557463, 'max_depth': 13, 'n_estimators': 145, 'subsample': 0.9626349228192372}


Parameters: { "njobs" } are not used.



MSE: 15.651814947808404 with params: {'colsample_bytree': 0.5171735273002764, 'gamma': 0.33004558578977117, 'min_child_weight': 3, 'learning_rate': 0.020043046357217372, 'max_depth': 14, 'n_estimators': 57, 'subsample': 0.954149019488095}


Parameters: { "njobs" } are not used.



MSE: 14.032901964572883 with params: {'colsample_bytree': 0.9366261649717827, 'gamma': 0.19355643265313816, 'min_child_weight': 5, 'learning_rate': 0.1778782013132973, 'max_depth': 17, 'n_estimators': 83, 'subsample': 0.5556243971940721}


Parameters: { "njobs" } are not used.



MSE: 11.951121637402274 with params: {'colsample_bytree': 0.6271489401347604, 'gamma': 0.4961522349121672, 'min_child_weight': 2, 'learning_rate': 0.08566470358150552, 'max_depth': 9, 'n_estimators': 57, 'subsample': 0.8714300786930459}


Parameters: { "njobs" } are not used.



MSE: 12.1268062112446 with params: {'colsample_bytree': 0.6571866577372808, 'gamma': 0.4402882371131698, 'min_child_weight': 3, 'learning_rate': 0.15426342518584454, 'max_depth': 9, 'n_estimators': 101, 'subsample': 0.5060748218042318}


Parameters: { "njobs" } are not used.



MSE: 12.05635378322994 with params: {'colsample_bytree': 0.6821607533776404, 'gamma': 0.24691354252333325, 'min_child_weight': 2, 'learning_rate': 0.1915258815131692, 'max_depth': 8, 'n_estimators': 110, 'subsample': 0.7095200827485078}


Parameters: { "njobs" } are not used.



MSE: 12.081574038290078 with params: {'colsample_bytree': 0.5611967864367868, 'gamma': 0.231501413267994, 'min_child_weight': 3, 'learning_rate': 0.11701695906283917, 'max_depth': 10, 'n_estimators': 55, 'subsample': 0.8614767210375238}


Parameters: { "njobs" } are not used.



MSE: 12.660364552532704 with params: {'colsample_bytree': 0.9663087518105566, 'gamma': 0.06685774059254646, 'min_child_weight': 5, 'learning_rate': 0.027107788566750335, 'max_depth': 12, 'n_estimators': 104, 'subsample': 0.8439261859395983}


Parameters: { "njobs" } are not used.



MSE: 12.154337502082008 with params: {'colsample_bytree': 0.8125189953952469, 'gamma': 0.03862050533710559, 'min_child_weight': 2, 'learning_rate': 0.055639964011472, 'max_depth': 10, 'n_estimators': 74, 'subsample': 0.7726089732475734}


Parameters: { "njobs" } are not used.



MSE: 11.754383225817286 with params: {'colsample_bytree': 0.7665240684286052, 'gamma': 0.23449997699795766, 'min_child_weight': 4, 'learning_rate': 0.11148863493502222, 'max_depth': 8, 'n_estimators': 76, 'subsample': 0.6986365847441489}


Parameters: { "njobs" } are not used.



MSE: 12.50922013807316 with params: {'colsample_bytree': 0.7011816322561479, 'gamma': 0.3926079853515963, 'min_child_weight': 3, 'learning_rate': 0.10529490985894574, 'max_depth': 11, 'n_estimators': 143, 'subsample': 0.721472783762807}


Parameters: { "njobs" } are not used.



MSE: 12.502539248238422 with params: {'colsample_bytree': 0.840790932505333, 'gamma': 0.0640727083447552, 'min_child_weight': 2, 'learning_rate': 0.11469778002816683, 'max_depth': 12, 'n_estimators': 110, 'subsample': 0.5248909726002511}


Parameters: { "njobs" } are not used.



MSE: 12.142246894947673 with params: {'colsample_bytree': 0.9087844085436133, 'gamma': 0.16338809903779045, 'min_child_weight': 5, 'learning_rate': 0.12121242588273946, 'max_depth': 10, 'n_estimators': 98, 'subsample': 0.8123753247613568}


Parameters: { "njobs" } are not used.



MSE: 13.60687250054473 with params: {'colsample_bytree': 0.7744824680442506, 'gamma': 0.0066344322575006, 'min_child_weight': 3, 'learning_rate': 0.01942403530860936, 'max_depth': 17, 'n_estimators': 122, 'subsample': 0.8168795835596803}


Parameters: { "njobs" } are not used.



MSE: 15.574862764036558 with params: {'colsample_bytree': 0.8433546259061147, 'gamma': 0.02807806202543872, 'min_child_weight': 5, 'learning_rate': 0.011351177781153942, 'max_depth': 9, 'n_estimators': 96, 'subsample': 0.5355049023195102}
Best MSE: 11.754383225817286
Best params: {'colsample_bytree': 0.7665240684286052, 'gamma': 0.23449997699795766, 'min_child_weight': 4, 'learning_rate': 0.11148863493502222, 'max_depth': 8, 'n_estimators': 76, 'subsample': 0.6986365847441489}


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  val_final['XGBoost'] = best_predictions


In [5]:
# New dataset / dataloader specifically for analysis
# Includes site ID and timestamp with every sample

class FluxDataset(Dataset):
    """Sliding-window dataset over per-site flux tables, for analysis.

    Each selected site folder under ``data_dir`` must hold:
      - ``data.csv``: a ``timestamp`` column plus numeric columns,
      - ``modis.pkl``: a pickled dict mapping timestamp -> pixel array whose
        first axis is treated as the channel axis,
      - ``meta.json``: a dict with at least a ``SITE_ID`` key.

    Each item is a window of ``context_length`` consecutive rows. Besides the
    model inputs, every item carries the window's last timestamp and its site
    ID, so predictions can be joined back to other results.
    """

    def __init__(self, data_dir, sites, context_length=48, targets=None):
        """
        Args:
            data_dir: root directory searched recursively for site folders.
            sites: site IDs to load. A folder containing ``data.csv`` is loaded
                when any site ID occurs as a substring of its path. Blank
                entries are ignored.
            context_length: number of rows per window.
            targets: target column names (default ``['GPP_NT_VUT_REF']``).

        Raises:
            ValueError: if no matching folder with a ``data.csv`` is found.
        """
        self.data_dir = data_dir
        self.sites = sites
        self.data = []
        self.context_length = context_length
        self.targets = ['GPP_NT_VUT_REF'] if targets is None else targets
        # Excluded from the model's tabular inputs (timestamp + all flux targets).
        self.remove_columns = ['timestamp', 'NEE_VUT_REF', 'GPP_NT_VUT_REF', 'RECO_NT_VUT_REF']

        # '' is a substring of every path and would select every folder, so
        # blank entries (e.g. from a trailing newline in a site list) are dropped.
        wanted_sites = [s.strip() for s in sites if s.strip()]

        for root, dirs, files in os.walk(self.data_dir):
            # Sorted traversal makes the dataset (and lookup table) order reproducible.
            dirs.sort()
            if 'data.csv' not in files:
                continue
            if any(site in root for site in wanted_sites):
                self.data.append(self._load_site(root))

        if not self.data:
            raise ValueError(f'No data.csv found under {self.data_dir} for sites {list(sites)}')

        # One entry per (site, window end); the window covers rows
        # [row_max - context_length, row_max).
        self.lookup_table = [
            (site_idx, row_max)
            for site_idx, (_, df, _) in enumerate(self.data)
            for row_max in range(self.context_length, len(df) + 1)
        ]

        col_df = self.data[0][1].drop(columns=self.remove_columns)
        self.tabular_columns = list(col_df.columns)
        self.modis_bands = max(v.shape[0] for v in self.data[0][2].values())

    @staticmethod
    def _load_site(site_dir):
        """Load and return ``(meta, dataframe, modis)`` for one site folder."""
        df = pd.read_csv(os.path.join(site_dir, 'data.csv'))
        float_cols = [c for c in df.columns if c != 'timestamp']
        df[float_cols] = df[float_cols].astype(np.float32)
        df['timestamp'] = pd.to_datetime(df['timestamp'])
        # NOTE: pickle.load can execute arbitrary code; only load trusted data.
        with open(os.path.join(site_dir, 'modis.pkl'), 'rb') as f:
            modis_data = pkl.load(f)
        with open(os.path.join(site_dir, 'meta.json'), 'r') as f:
            meta = json.load(f)
        return meta, df, modis_data

    def num_channels(self):
        """Size of the first axis of the first MODIS array of the first site."""
        _, _, modis = self.data[0]
        return modis[list(modis.keys())[0]].shape[0]

    def __len__(self):
        return len(self.lookup_table)

    def __getitem__(self, idx):
        """Return one window as a tuple.

        The tuple is ``(row_values, mask, modis_data, targets, timestamp,
        site_id)``:
          - row_values: (context_length, n_tabular) float tensor, NaN -> -1.0
          - mask: same shape, True where the original value was NaN
          - modis_data: list of (row_index_in_window, pixel tensor) for rows
            that have a MODIS observation
          - targets: (context_length, len(targets)) tensor
          - timestamp: timestamp of the window's last row
          - site_id: ``meta['SITE_ID']`` of the window's site
        """
        site_num, row_max = self.lookup_table[idx]
        row_min = row_max - self.context_length

        meta, df, modis = self.data[site_num]
        rows = df.iloc[row_min:row_max].reset_index(drop=True)

        modis_data = []
        timestamps = list(rows['timestamp'])
        for i, ts in enumerate(timestamps):
            pixels = modis.get(ts, None)
            if pixels is not None:
                # Spatial crop to indices 1..8 on both pixel axes (8x8).
                # NOTE(review): presumably trims a 1-pixel border from 10x10 tiles — confirm.
                modis_data.append((i, torch.tensor(pixels[:, 1:9, 1:9], dtype=torch.float32)))

        targets = torch.tensor(rows[self.targets].values)
        row_values = torch.tensor(rows.drop(columns=self.remove_columns).values)
        mask = row_values.isnan()
        row_values = row_values.nan_to_num(-1.0)  # masked anyway; any numeric placeholder works

        ### Analysis variables
        timestamp = timestamps[-1]
        site_id = meta['SITE_ID']
        return row_values, mask, modis_data, targets, timestamp, site_id


def custom_collate_fn(batch):
    """Merge a list of FluxDataset samples into one batch.

    Row values, masks and targets are stacked along a new leading batch axis.
    The variable-length MODIS observations are flattened into a single list of
    ``(batch_index, timestep_index, pixels)`` tuples. Timestamps and site IDs
    come back as plain lists aligned with the batch axis.
    """
    (row_values_seq, mask_seq, modis_seq,
     targets_seq, timestamp_seq, site_id_seq) = zip(*batch)

    stacked_rows = torch.stack(row_values_seq, dim=0)
    stacked_mask = torch.stack(mask_seq, dim=0)
    stacked_targets = torch.stack(targets_seq, dim=0)

    flat_modis = [
        (sample_idx, step_idx, pixels)
        for sample_idx, observations in enumerate(modis_seq)
        for step_idx, pixels in observations
    ]

    return (stacked_rows, stacked_mask, flat_modis, stacked_targets,
            list(timestamp_seq), list(site_id_seq))

def FluxDataLoader(data_dir, sites, context_length = 32, targets=None, **kwargs):
    """Build a torch DataLoader over a FluxDataset, batched with custom_collate_fn.

    Args:
        data_dir: root directory of the processed per-site data.
        sites: site IDs to include.
        context_length: rows per window. NOTE(review): 32 presumably mirrors
            the 'cl32' tag in this notebook's run name — confirm against the
            run config.
        targets: target column names (default ``['NEE_VUT_REF']``). A fresh
            list is created per call. The dataset stores this list by reference,
            so a shared mutable default would leak edits across instances.
        **kwargs: forwarded to ``DataLoader`` (batch_size, num_workers, shuffle, ...).
    """
    if targets is None:
        targets = ['NEE_VUT_REF']
    ds = FluxDataset(data_dir, sites, context_length=context_length, targets=targets)
    return DataLoader(ds, collate_fn=custom_collate_fn, **kwargs)


In [6]:
# Predictions are matched back to rows by (SITE_ID, timestamp), so sample order
# is irrelevant here; no shuffling keeps the inference pass deterministic.
dl = FluxDataLoader(DATA_DIR, VAL_SITES, num_workers=16, batch_size=256, shuffle=False)

In [7]:
with open(os.path.join(RUN_DIR, 'config.yml'), 'r') as file:
    config = yaml.safe_load(file)

# Rows keyed by (SITE_ID, timestamp). The XGBoost column is dropped and an
# empty 'Deep Model' column is appended for this model's predictions.
inference_df = pd.read_csv(os.path.join(RUN_DIR, 'xgb_inference.csv'))
inference_df['timestamp'] = pd.to_datetime(inference_df['timestamp'])
inference_df = (
    inference_df
    .set_index(['SITE_ID', 'timestamp'])
    .drop(columns=['XGBoost'])
    .assign(**{'Deep Model': np.nan})
)
config['model']['spectral_data_channels'] = dl.dataset.num_channels()

device = torch.device('cuda')
model = Perceiver(PerceiverConfig(**config['model']))
# Load onto CPU first (no second copy of weights/optimizer state on the GPU),
# then move the model. Note: this rebinds `checkpoint` from the file name to the loaded dict.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.to(device)

model.eval()
# Collect predictions and write them into inference_df once at the end;
# calling DataFrame.update per batch re-aligns the whole frame every iteration.
pred_keys = []
pred_values = []
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.bfloat16):
    for row_values, mask, modis_data, targets, timestamp, site_id in tqdm(dl):
        # NOTE(review): inputs stay on CPU; this relies on Perceiver moving them
        # to its device internally — confirm.
        # NOTE(review): targets are passed at inference; confirm the model only
        # uses them for the loss and not as inputs (would leak the label).
        op = model(row_values, mask, modis_data, targets)
        # Prediction at the window's last step, i.e. at `timestamp`.
        outputs = op['logits'][:, -1].float().reshape(-1).cpu().tolist()
        assert len(outputs) == len(site_id), 'expected one prediction per sample'
        pred_values.extend(outputs)
        pred_keys.extend(zip(site_id, timestamp))

deep_preds = pd.DataFrame(
    {'Deep Model': pred_values},
    index=pd.MultiIndex.from_tuples(pred_keys, names=['SITE_ID', 'timestamp']),
)
# update() needs a unique index on `other`; the last prediction wins, as with per-batch updates.
deep_preds = deep_preds[~deep_preds.index.duplicated(keep='last')]
inference_df.update(deep_preds)


  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 22866/22866 [1:20:01<00:00,  4.76it/s]


In [8]:
# Persist the remaining baseline columns plus the 'Deep Model' predictions next to the run.
inference_df.to_csv(RUN_DIR / 'deep_inference.csv')