# PaddockTS 'digital case study'@Milgadara

## Purpose

## Input data
- paddock polygons, manually drawn (analysis with autogenerated polygons is done in 03_paddock-ts.ipynb)
- paddock-level management annotations and yield data
- sentinel time series (ds2.pkl)

## Steps
1. Merge paddock geometries with annotation info.
2. Get paddock annotations into a usable format.
3. Load ds, compute new indices, estimate vegetation cover fractions, resample weekly and interpolate missing values.
4. Save MP4 animations of RGB and vegetation fractions.
5. Generate the paddock-variable-week dataset (xarray object) and save this. (this takes a lot of effort to make)

## Outputs
- Annual heatmap time series (single variable) showing paddock annotations.
- Annual paddock clustering (multivariate) with paddock annotations.
- Multi-paddock video time series with paddock trackers.
- 2017-2024 calendar plot for each paddock with crop type printed in the margin for each year. 

### Setup

In [1]:
import geopandas as gpd
import pandas as pd
import numpy as np
import seaborn as sns
import rasterio #
import xarray as xr
import matplotlib.pyplot as plt 
import matplotlib
import rioxarray
from shapely.geometry import mapping


from joblib import Parallel, delayed
from tqdm import tqdm

%matplotlib inline

import cv2 
from matplotlib.gridspec import GridSpec
from dea_tools.plotting import display_map, rgb, xr_animation
import skimage

from tslearn.preprocessing import TimeSeriesScalerMeanVariance
from tslearn.preprocessing import TimeSeriesResampler
from tslearn.clustering import KShape, KernelKMeans, silhouette_score
from tslearn.preprocessing import TimeSeriesScalerMeanVariance
from tslearn.clustering import TimeSeriesKMeans

from IPython.display import Image
from IPython.core.display import Video

import pickle
import os
import shutil

import ffmpeg # REQUIRES a module load ffmpeg/4.3.1 (in jupyterlab, must do when setting up sesh)

# for veg cover fraction part:
import tensorflow # REQUIRES module load tensorflow/2.15.0 
from fractionalcover3 import unmix_fractional_cover
from fractionalcover3 import data

2025-12-11 02:16:20.488091: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-12-11 02:16:20.488575: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-12-11 02:16:20.564731: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Work from the project directory (presumably so project-relative paths/imports resolve).
os.chdir('/home/106/jb5097/Projects/PaddockTS')

# Prefix shared by every input/output file for this site and period.
stub = 'MILG_b033_2017-24'
outdir = "/g/data/xe2/John/Data/PadSeg/" # best if output is stored in gdata

# Temporary directory for animations, frames and plots
# clear the tmp directory and create anew
# NOTE: rmtree deletes any existing contents of tmp_dir without prompting.
tmp_dir = '/scratch/xe2/jb5097/tmp2/'+stub+'/'
shutil.rmtree(tmp_dir, ignore_errors=True)
os.makedirs(tmp_dir)

paddocks_manual = "/g/data/xe2/John/Data/PadSeg/milg_manualpaddocks2.gpkg" # hand-drawn paddock polygons with name column that MAY match with annotation data (not all rows will have annotations)
paddock_annotations = "/g/data/xe2/John/Data/PadSeg/MILG_paddocks_tmp.csv" # the latest version of paddock management annotation data (assumes format stays the same!)

### Functions

In [None]:
'''
create paddocks-variable-time xarray given polygons and xarray
'''


In [None]:
def calculate_and_add_fractional_cover(ds, band_names, i, correction=True):
    """
    Unmix fractional cover from six reflectance bands and add it to the Dataset.

    Parameters:
    ds (xarray.Dataset): Input Dataset. Each band in ``band_names`` is read as a
        (time, y, x) array; values are assumed to be reflectance scaled by 10000
        (they are multiplied by 0.0001 below) -- TODO confirm for each input product.
    band_names (list): Exactly 6 band names, in the order the model expects
        (e.g. blue, green, red, NIR, SWIR1, SWIR2).
    i (int): Index of the pretrained fractionalcover3 model (``data.get_model(n=i)``).
    correction (bool): If True, apply the per-band Sentinel-2 -> Landsat linear
        adjustment ``gain * reflectance + offset`` on 0-1 reflectance before unmixing.

    Returns:
    xarray.Dataset: A new Dataset (via ``ds.assign``) with added (time, y, x)
        variables ``bg``, ``pv`` and ``npv``.

    Raises:
    ValueError: If ``band_names`` does not contain exactly 6 names.
    """
    if len(band_names) != 6:
        raise ValueError("Exactly 6 band names must be provided")

    # Stack the bands into shape (time, bands, y, x).
    inref = np.stack([ds[band].values for band in band_names], axis=1)
    print('Shape of input (should be time, bands, y, x):', inref.shape)

    # Scale to 0-1 reflectance in both branches: the unmixing model takes this scale,
    # and the correction offsets below are expressed in these units.
    reflectance = inref * 0.0001

    if correction:
        print('Using correction factors that attempt to fudge S2 data to better match Landsat.. be careful?')
        # Gains/offsets from https://github.com/petescarth/fractionalcover/blob/main/notebooks/ApplyModel.ipynb
        # (after Neil Flood's Landsat / Sentinel-2 reflectance comparison). The adjustment
        # is linear per band: adjusted = gain * reflectance + offset.
        gains = np.array([0.9551, 1.0582, 0.9871, 1.0187, 0.9528, 0.9688])
        offsets = np.array([-0.0022, 0.0031, 0.0064, 0.012, 0.0079, -0.0042])
        # Shape (bands, 1, 1) broadcasts over the trailing (bands, y, x) axes.
        reflectance = (reflectance * gains[:, np.newaxis, np.newaxis]
                       + offsets[:, np.newaxis, np.newaxis])
    else:
        print('Not applying correction factors')

    # Load the model once rather than once per time slice.
    fc_model = data.get_model(n=i)

    # Unmix each time slice; output channels per slice are (3, y, x).
    fractions = np.empty((reflectance.shape[0], 3, reflectance.shape[2], reflectance.shape[3]))
    for t in range(reflectance.shape[0]):
        fractions[t] = unmix_fractional_cover(reflectance[t], fc_model=fc_model)

    # Channels 0/1/2 are stored as bg / pv / npv respectively.
    dims = ['time', 'y', 'x']
    coords = [ds.coords['time'], ds.coords['y'], ds.coords['x']]
    bg, pv, npv = (
        xr.DataArray(fractions[:, k, :, :], coords=coords, dims=dims) for k in range(3)
    )

    return ds.assign(bg=bg, pv=pv, npv=npv)

# # Example usage:
# band_names = ['nbart_blue', 'nbart_green', 'nbart_red', 'nbart_nir_2', 'nbart_swir_2', 'nbart_swir_3']
# i = 1  # or whichever model index you want to use
# ds_updated = calculate_and_add_fractional_cover(ds, band_names, i)
# print(ds_updated)

In [None]:
#Functions to add selected spectral indices to an xarray
def calculate_indices(ds, indices):
    """
    Calculate spectral indices and append them to ``ds.pvt`` along the 'variable' dim.

    Indices are evaluated in dictionary order, and each one is appended before the
    next is computed. Later index functions can therefore use earlier ones; for
    example, ``calculate_cfi`` reads 'NDVI', which ``calculate_ndvi`` can supply
    earlier in the same dict.

    Parameters:
    ds (xarray.Dataset): Dataset whose ``pvt`` variable has dims (paddock, variable, time),
        plus any other data variables (e.g. per-paddock metadata).
    indices (dict): Mapping of index name -> function that takes a Dataset and returns
        the index as a DataArray (without a 'variable' dim).

    Returns:
    xarray.Dataset: New Dataset whose ``pvt`` contains the original variables followed
        by the new indices, with every other original data variable copied across.
    """
    def _with_pvt(pvt):
        # Wrap ``pvt`` in a Dataset and re-attach the other original data variables.
        out = pvt.to_dataset(name='pvt')
        for var in ds.data_vars:
            if var != 'pvt':  # Avoid overwriting the 'pvt' variable
                out[var] = ds[var]
        return out

    pvt = ds.pvt
    working = ds
    for index_name, index_func in indices.items():
        index_data = index_func(working)

        # Insert a length-1 'variable' dim at axis 1 to match (paddock, variable, time).
        index_expanded = index_data.expand_dims(variable=[index_name], axis=1)
        print(index_name, 'has shape: ', index_data.shape)

        # xr.concat aligns the other dims (outer join by default), so an index with
        # fewer time steps (e.g. a first difference) is NaN-padded.
        pvt = xr.concat([pvt, index_expanded], dim='variable')

        # Expose the new index to subsequent index functions.
        working = _with_pvt(pvt)

    return _with_pvt(pvt)

def calculate_ndvi(ds):
    '''Normalised Difference Vegetation Index from the paddock ``pvt`` array.

    NDVI = (NIR - red) / (red + NIR), using the 'nbart_nir_1' and 'nbart_red' entries
    of the 'variable' coordinate.

    NOTE(review): the original author asked why this differs from DEA's NDVI. One
    plausible (unconfirmed) cause: if ``pvt`` holds per-paddock medians of each band,
    a ratio of medians is not the median of per-pixel ratios.
    '''
    bands = {name: ds.sel(variable=name).pvt for name in ('nbart_red', 'nbart_nir_1')}
    red = bands['nbart_red']
    nir = bands['nbart_nir_1']
    return (nir - red) / (red + nir)

def calculate_cfi(ds):
    '''Canola Flower Index (CFI) from the paddock ``pvt`` array.

    Tian et al 2022 Remote Sensing https://www.mdpi.com/2072-4292/14/5/1113#sec2dot4-remotesensing-14-01113

    Computed here as NDVI * ((red + green) + (green - blue)). An 'NDVI' entry must
    already exist in the 'variable' coordinate (e.g. produced by calculate_ndvi).
    NOTE(review): confirm the (green - blue) term against the paper's definition.
    '''
    def band(name):
        return ds.sel(variable=name).pvt

    ndvi, red, green, blue = (band(n) for n in ('NDVI', 'nbart_red', 'nbart_green', 'nbart_blue'))
    return ndvi * ((red + green) + (green - blue))

def calculate_nirv(ds):
    '''Near-Infrared Reflectance of Vegetation: NIRv = NDVI * NIR ('nbart_nir_1').

    An 'NDVI' entry must already exist in the 'variable' coordinate.
    '''
    ndvi, nir = (ds.sel(variable=v).pvt for v in ('NDVI', 'nbart_nir_1'))
    return ndvi * nir

def calculate_dnirv(ds):
    '''Change in NIRv between consecutive time steps.

    Uses xarray's first difference along 'time', so the result has one fewer time
    step than the input. With xarray's default ``label='upper'``, each difference is
    stamped with the later timestamp of its pair. Callers that need the full time axis
    must pad the first step themselves (e.g. via alignment when concatenating).
    The original author noted this index is "currently not working well".
    '''
    nirv = calculate_nirv(ds)
    return nirv.diff(dim='time', n=1)

def calculate_ndti(ds):
    """Normalized Difference Tillage Index (NDTI).

    NDTI = (R1610 - R2200) / (R1610 + R2200)
    Described here and ref within: https://www.mdpi.com/2072-4292/13/18/3718

    'nbart_swir_2' is used as the ~1610 nm band and 'nbart_swir_3' as the ~2200 nm
    band (presumably Sentinel-2 B11/B12 in DEA naming -- TODO confirm).
    """
    short_swir = ds.sel(variable='nbart_swir_2').pvt
    long_swir = ds.sel(variable='nbart_swir_3').pvt
    return (short_swir - long_swir) / (short_swir + long_swir)

def calculate_cai(ds):
    """Cellulose Absorption Index (CAI) proxy.

    Published definition: CAI = 0.5 * (R2000 + R2200) - R2100
    see https://www.mdpi.com/2072-4292/13/18/3718
    see also for calibration/nuance with Sentinel data: https://www.spiedigitallibrary.org/conference-proceedings-of-spie/11155/2533761/Identification-of-non-photosynthetic-vegetation-areas-in-Sentinel-2-satellite/10.1117/12.2533761.full

    NOTE(review): what is actually computed is
    0.5 * (nbart_swir_2 + nbart_swir_3) - nbart_nir_1. The two SWIR bands stand in
    for R2000/R2200 and the NIR band for R2100, so treat this as a proxy rather than
    the published CAI.
    """
    swir_a, swir_b, nir = (
        ds.sel(variable=v).pvt for v in ('nbart_swir_2', 'nbart_swir_3', 'nbart_nir_1')
    )
    return 0.5 * (swir_a + swir_b) - nir

# # Example usage
# indices = {
#     'CFI': calculate_cfi,
#     'NIRv': calculate_nirv
# }

# updated_ds = calculate_indices(ds_paddocks_weekly, indices)
# print(updated_ds.variable)

### Load data
(for everything except paddock-year yield analysis. For that, skip to bottom)

In [3]:
# Read in the polygons from SAMGeo (these will not necessarily match user-provided paddocks)
# NOTE: `pol` is reassigned to the named manual paddocks in a later cell.
pol = gpd.read_file(outdir+stub+'_filt.gpkg')

# have to set a paddock id. Preferably do this in earlier step in future... 
# Ids are 1..n in file row order, stored as a categorical.
pol['paddock'] = range(1,len(pol)+1)
pol['paddock'] = pol.paddock.astype('category')

# Read in the array of paddocks by variables (e.g. bands) by time -- the pvt array
# NOTE: `pvt` is recomputed from the manual paddocks in a later cell.
pvt = np.load(outdir+stub+'_pvt.npy')
# pvt = np.load(data_path+stub+'_pvt2.npy')

# Load the labels for the pvt 'variable' axis (only unpickle trusted files):
with open(outdir+stub+'_pvt_vars.pkl', 'rb') as handle:
    var_names = pickle.load(handle)
print('No. vars:',len(var_names))

No. vars: 23


In [None]:
# Open the satellite data stack
#year = 2023
# Calendar year to keep; None keeps all years. Later cells also use `year` to pick
# the '<year>_Crop' annotation column and to name output files.
year = None

# Only unpickle files you trust.
with open(outdir+stub+'_ds2.pkl', 'rb') as handle:
    ds = pickle.load(handle)
    # Filter the data if year is not null
    if year is not None:
        ds = ds.sel(time=ds['time'].dt.year == year)

## Add veg fractions to ds
# Band order passed to the fractional cover model.
band_names = ['nbart_blue', 'nbart_green', 'nbart_red', 'nbart_nir_1', 'nbart_swir_2', 'nbart_swir_3']
i = 3  # or whichever model index you want to use
ds = calculate_and_add_fractional_cover(ds, band_names, i, correction=False)

# Resample to weekly timestamps, filling each by linear interpolation in time
ds_weekly = ds.resample(time="1W").interpolate("linear")

print('No. weeks in time series:', len(ds_weekly.time.values))
ds_weekly

In [None]:
# #### ? Save time, save/read in the dsi.pkl (this script is using a weekly resampled version)

# ds2i_weekly_filename = f"{outdir}{stub}_{'ds2i-w.pkl'}"

# print(ds2i_weekly_filename)

# with open(ds2i_weekly_filename, 'wb') as f:
#     pickle.dump(ds_weekly, f, protocol=pickle.HIGHEST_PROTOCOL)

# with open(ds2i_weekly_filename, "rb") as handle:
#     ds_weekly = pickle.load(handle)

# print(ds_weekly)

# '''THIS CRASHED SO HAS NOT FULLY RUN YET'''
# # Might not be neccesary if I just instead load in the paddockTS. 

In [4]:
# Read in manual polygons and paddock annotation data. Merge and keep as a geopandas df:

# Paddock annotation data:
pad_an = pd.read_csv(paddock_annotations)

# Load the manually drawn polygons GeoDataFrame
pad_man = gpd.read_file(paddocks_manual)

# Remove rows that have no geometry
pad_man = pad_man[pad_man.geometry.notnull()]
print(len(pad_man))

# Identify rows with invalid geometries
invalid_geometries = pad_man[~pad_man.is_valid]
print("Invalid geometries:")
print(invalid_geometries)

# Remove rows with invalid geometries. `.copy()` makes this an independent frame, so the
# column assignments below (and in later cells) don't raise SettingWithCopyWarning.
pad_man = pad_man[pad_man.is_valid].copy()
print("------")
print(len(pad_man))

# Add unique paddock ids 1..n in row order (used to join with annotations later)
pad_man['paddock'] = range(1, len(pad_man) + 1)

# Left-merge annotations onto polygons by paddock name; polygons without a matching
# annotation keep NaN in the annotation columns.
pad_manan = pd.merge(pad_man, pad_an, left_on='name', right_on='Name', how='left').drop(columns=['Name'])
print(pad_manan.crs)
print(len(pad_manan), len(pad_man), len(pad_an))

90
Invalid geometries:
   name    type                                           geometry
86  fr1  forest  POLYGON ((14325498.506 -4132100.273, 14325511....
------
89
EPSG:6933
89 89 28


In [5]:
# Add colors to plot the polygons with:
# Create the 'color' column and set it to the *string* 'None' for all rows
# (presumably so matplotlib treats it as "no fill" -- TODO confirm intended use).
pad_man['color'] = 'None'

# Define the conditions and corresponding values for 'edge_color'
conditions = [
    (pad_man['type'] == 'forest'),
    (pad_man['type'] == 'tree_row'),
    (pad_man['type'] == 'named'),
    (pad_man['type'] == 'unnamed'),
    # Add more conditions here if needed
]

values = ['green', 'yellow', 'red', 'blue']
# Add corresponding values for additional conditions here if needed

# Create the 'edge_color' column based on the conditions; unmatched types get 'other'.
# NOTE: 'other' is not a matplotlib colour, so only pass this column to plotting
# when every paddock type is covered by `conditions`.
pad_man['edge_color'] = np.select(conditions, values, default='other')
pad_man

Unnamed: 0,name,type,geometry,paddock,color,edge_color
0,,unnamed,"POLYGON ((14324740.000 -4132290.000, 14324760....",1,,blue
1,,unnamed,"POLYGON ((14324120.000 -4132270.000, 14324140....",2,,blue
2,,unnamed,"POLYGON ((14323570.000 -4132320.000, 14323710....",3,,blue
3,,unnamed,"POLYGON ((14324930.000 -4132490.000, 14325020....",4,,blue
4,,unnamed,"POLYGON ((14324990.000 -4132750.000, 14325110....",5,,blue
...,...,...,...,...,...,...
84,tr21,tree_row,"POLYGON ((14325892.173 -4136177.198, 14326086....",85,,yellow
85,tr22,tree_row,"POLYGON ((14325452.296 -4135098.629, 14325451....",86,,yellow
87,fr2,forest,"POLYGON ((14326017.563 -4132451.584, 14326094....",87,,green
88,fr3,forest,"POLYGON ((14327713.350 -4135076.423, 14327951....",88,,green


### Visual checks of data before analysis
- map the paddocks
- animate the RGB

In [None]:
# Colour the paddocks according to whether named, un-named, forest block, tree row.

# num_frames = 10
# xr_animation(ds_weekly, 
#              bands = ['nbart_red', 'nbart_green', 'nbart_blue'], 
#              output_path = tmp_dir+'quick_animation.mp4', 
#              show_gdf = pad_man, 
#              gdf_kwargs={"edgecolor": pad_man['edge_color']}, 
#              #gdf_kwargs={"edgecolor": 'red'}, 
#              limit = num_frames)
# plt.close()
# Video(tmp_dir+'quick_animation.mp4', embed = True)

# This version only shows the labelled paddocks
# NOTE: this reassigns `pol` (previously the SAMGeo polygons) to the named manual paddocks.
pol = pad_man[pad_man['type'] == 'named']
# Animate every weekly time step.
num_frames = len(ds_weekly.time)
xr_animation(ds_weekly, 
             bands = ['nbart_red', 'nbart_green', 'nbart_blue'], 
             output_path = outdir+stub+'manpad_RGB.mp4', 
             show_gdf = pol, 
             #gdf_kwargs={"edgecolor": pol['edge_color']}, 
             gdf_kwargs={"edgecolor": 'white'}, 
             limit = num_frames)
plt.close()
# NOTE(review): the file name has no separator after stub ("...2017-24manpad_RGB.mp4").
Video(outdir+stub+'manpad_RGB.mp4', embed = True) # rename

In [6]:
# Unique paddock names (set order is arbitrary; None marks unnamed polygons).
list(set(pad_man.name))

['tr10',
 'No 2',
 'No 3',
 'tr17',
 'Washpool East',
 'fr3',
 'tr1',
 'Rocky South',
 'tr18',
 'Horse Paddock Big',
 'Rocky',
 'Bottom Range',
 'tr16',
 'tr9',
 'tr5',
 'tr7',
 'Contour',
 'Little Tank',
 'tr23',
 'tr3',
 'tr20',
 'Rubbish Paddock',
 'tr22',
 'tr21',
 'tr13',
 'tr12',
 'Fingerboard',
 'tr8',
 'Rocky East',
 'tr19',
 None,
 'tr6',
 'tr4',
 'Air Strip',
 'fr2',
 'fr4',
 'Ranch Paddock',
 'Pine Hill',
 '135 Acre',
 'Washpool',
 'tr14',
 'tr11',
 'Old Sheep Yard',
 'Scramble Paddock',
 "Johnny's",
 'No 4',
 'tr15',
 'tr2',
 'Tank',
 'No 1',
 'Centre Paddock',
 'Little Horse Paddock']

In [None]:
### STOP HERE IF USING Year=None, or else the next step will crash due to too much data... 
### Skip through to next plots that don't use paddockTS, e.g. calendar plots

### Paddocks x variable x time
- clustering using hierarchical clustering and/or time series clustering
- make interactive heat map to hover over labels (single variable)
- make interactive pairwise dissimilarity matrix (multi variable)
- dimensionality reduction plot (k-means? tSNE?), also interactive. 

In [None]:

# Function to process each geometry row
def process_geometry(datarow, ds):
    """
    Extract per-time-step median band values for one paddock polygon.

    Args:
        datarow: A row from ``GeoDataFrame.itertuples()``; only ``datarow.geometry`` is used.
        ds: The xarray Dataset with time series satellite imagery. It needs the rioxarray
            ``.rio`` accessor and is presumably in the same CRS as the geometry
            (no ``crs`` is passed to ``rio.clip``) -- TODO confirm.
    Returns:
        A float32 numpy array of shape (1, n_variables, n_time). Each value is the median
        over the clipped pixels of one data variable at one time step, ignoring values <= 0.
    """
    # Imported inside the function, after the docstring so it remains ``__doc__``.
    # Presumably this ensures the ``.rio`` accessor is registered in joblib worker processes.
    import rioxarray  # noqa: F401

    # Clip the xarray dataset to the polygon
    ds_clipped = ds.rio.clip([datarow.geometry])

    # Mask non-positive values (treated as no-data) before the spatial median
    pol_ts = ds_clipped.where(ds_clipped > 0).median(dim=['x', 'y'])
    array = pol_ts.to_array().transpose('variable', 'time').values.astype(np.float32)

    # Leading length-1 axis so per-paddock results can be stacked with np.vstack.
    return array[np.newaxis, :]

# Use parallel processing to extract time series data for each paddock
# Extract a (variable, time) median time series for every manual paddock in parallel.
# NOTE: tqdm wraps the task generator, so the bar roughly tracks task dispatch, not completion.
results = Parallel(n_jobs=-1)(
    delayed(process_geometry)(datarow, ds_weekly) 
    for datarow in tqdm(pad_man.itertuples(index=True), total=len(pad_man))
)

# Each result is (1, n_variables, n_time); stacking gives (n_paddocks, n_variables, n_time).
pvt = np.vstack(results)

print("Processing complete")
print("pvt shape: ", pvt.shape)

## Should turn this all into a clean function that outputs the pvt.

In [None]:
# Create ds_paddocks
def create_paddock_xarray(pol, pvt_array, ds):
    '''Wrap a paddock x variable x time array and per-paddock attributes in a Dataset.

    Parameters:
    pol (GeoDataFrame): one row per paddock with geometry, 'name', 'type' and 'paddock'
        columns; row order must match the first axis of ``pvt_array``.
    pvt_array (numpy.ndarray): values with dims (paddock, variable, time).
    ds (xarray.Dataset): source imagery. Its time axis becomes the 'time' coordinate, and
        its data variable names (in order) become the 'variable' coordinate.

    Returns:
    xarray.Dataset with 1-D 'geometry', 'name' and 'type' variables over 'paddock', and
    a 'pvt' variable over (paddock, variable, time).
    '''
    # Per-paddock attribute arrays.
    geometry_da = xr.DataArray(pol.geometry.values, dims=["paddock"], name="geometry")
    name_da = xr.DataArray(pol['name'].values, dims=["paddock"], name="name")
    type_da = xr.DataArray(pol['type'].values, dims=["paddock"], name="type")

    coords = {
        "paddock": pol.paddock.values,
        "variable": list(ds.data_vars),
        "time": ds.time,
    }
    pvt_da = xr.DataArray(pvt_array, dims=["paddock", "variable", "time"], coords=coords, name="pvt")

    return xr.Dataset({
        "geometry": geometry_da,
        "name": name_da,
        "type": type_da,
        "pvt": pvt_da,
    })

In [7]:
### Dec 10 2025, trying to read in previously made paddockTS.

# Load the previously built paddock time-series object (see caution below); only unpickle trusted files.
with open(outdir+stub+'_paddockTS_raw.pkl', 'rb') as handle:
    print(handle)  # debug: shows which file is being read
    ds_weekly_paddocks = pickle.load(handle)

ds_weekly_paddocks

'''Caution:
Just realised that this *_paddockTS_raw.pkl was made in canola_paddockyear_feature_extraction.ipynb
Might not jive with rest of this notebook (which uses the pvt variable in xarray.
might be best to start a new notebook for the case study... 
'''

<_io.BufferedReader name='/g/data/xe2/John/Data/PadSeg/MILG_b033_2017-24_paddockTS_raw.pkl'>


In [None]:
# Build the paddock x variable x time Dataset from the extracted medians, then append spectral indices.
# NOTE(review): in the visible notebook, `indices` (dict of name -> index function) only appears in a
# commented-out example; define it before running this cell or it raises NameError.
ds_weekly_paddocks = create_paddock_xarray(pad_man, pvt, ds_weekly)
ds_weekly_paddocks = calculate_indices(ds_weekly_paddocks, indices)
print(ds_weekly_paddocks)

### Clustering/heatmaps
- produce interactive plots of single-variable paddock time series.


In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches

def plot_clustermap1(ds_paddocks, variable_name, outdir, stub, figsize=(10, 10)):
    '''Plot and save a row-clustered heatmap of one variable for all paddocks.

    Rows (paddocks) are clustered with average-linkage Euclidean distance and colour-coded
    by paddock 'type'. Columns are time steps, with ticks at the last timestamp on or
    before the first day of each month.

    Parameters:
    ds_paddocks (xarray.Dataset): needs 'pvt' over (paddock, variable, time) plus 1-D
        'name' and 'type' variables over 'paddock'.
    variable_name (str): entry of the 'variable' coordinate to plot.
    outdir, stub (str): the figure is saved to ``outdir + stub + "_pt-<variable_name>.png"``.
    figsize (tuple): passed to seaborn.clustermap; (10, 10) matches seaborn's default.

    Raises:
    ValueError: if ``variable_name`` is not in ``ds_paddocks.variable``.
    '''
    if variable_name not in ds_paddocks.variable.values:
        raise ValueError(f"Variable '{variable_name}' not found in the dataset.")

    # Time series of the variable for every paddock: (paddock, time)
    pt_variable = ds_paddocks.sel(variable=variable_name).pvt.values
    # Count gaps in the raw series for the summary at the end (before any filling).
    raw_missing = np.count_nonzero(np.isnan(pt_variable))

    # Linearly interpolate gaps along time (axis 1). limit_direction='both' also fills
    # leading/trailing gaps, so only rows with no valid values should remain NaN.
    pt_variable = np.apply_along_axis(
        lambda x: pd.Series(x).interpolate(method='linear', limit_direction='both').to_numpy(),
        axis=1,
        arr=pt_variable,
    )

    print(pt_variable.shape)
    nan_count = np.sum(np.isnan(pt_variable))
    print(f"Number of NaN values: {nan_count}")
    print("replacing nans with 0, for now, so the clustering can work..")
    pt_variable = np.nan_to_num(pt_variable, nan=0)

    time_index = pd.to_datetime(ds_paddocks.time.values)

    # Start the monthly tick sequence on 1 January of the first year.
    start_date = time_index.min()
    if start_date.month != 1:
        start_date = pd.Timestamp(year=start_date.year, month=1, day=1)
    monthly_start = pd.date_range(start=start_date, end=time_index.max(), freq='MS')

    # For each month start, tick the last available timestamp on or before it.
    monthly_ticks = []
    for date in monthly_start:
        prior_dates = time_index[time_index <= date]
        if not prior_dates.empty:
            monthly_ticks.append(prior_dates[-1])
    monthly_ticks_str = [str(t)[:10] for t in monthly_ticks]

    row_names = ds_paddocks.name.values

    # One colour per paddock type, used for the row-colour strip and the legend.
    species_types = ds_paddocks.type.values
    unique_types = np.unique(species_types)
    palette = sns.color_palette("hsv", len(unique_types))
    type_color_map = {species: palette[i] for i, species in enumerate(unique_types)}
    species_colors = np.array([type_color_map[species] for species in species_types])

    # clustermap creates its own figure, so size it here. A separate plt.figure() call
    # would only leave an extra empty figure open.
    g = sns.clustermap(pt_variable, method='average', metric='euclidean',
                       row_cluster=True, col_cluster=False, cmap='viridis',
                       row_colors=species_colors, figsize=figsize)

    # Clustering reorders the rows; reorder the labels to match.
    row_order = g.dendrogram_row.reordered_ind
    ordered_row_names = [row_names[i] for i in row_order]

    g.ax_heatmap.set_xlabel('Time')

    tick_positions = [time_index.get_loc(t) for t in monthly_ticks]
    g.ax_heatmap.set_xticks(tick_positions)
    g.ax_heatmap.set_xticklabels(monthly_ticks_str, rotation=45, ha='right')

    # Heatmap rows are centred on i + 0.5.
    g.ax_heatmap.set_yticks(np.arange(len(ordered_row_names)) + 0.5)
    g.ax_heatmap.set_yticklabels(ordered_row_names, fontsize=8, rotation=0)

    # Move the colour bar to the right and title it with the variable name.
    g.cax.set_position([1.1, 0.2, 0.03, 0.45])  # [left, bottom, width, height]
    g.cax.set_title(variable_name, pad=10)

    legend_handles = [mpatches.Patch(color=palette[i], label=species)
                      for i, species in enumerate(unique_types)]
    # plt.legend attaches to the current axes; bbox_to_anchor is tuned to that placement.
    plt.legend(handles=legend_handles, title='Paddock type', bbox_to_anchor=(1.05, 1.4), loc='upper left')

    plt.savefig(outdir + stub + f"_pt-{variable_name}.png", bbox_inches='tight')

    # Report gaps in the raw series (the filled array no longer contains NaNs).
    print(f'Number of missing pixels across all {variable_name} time series:', raw_missing)

def plot_clustermap2(ds_paddocks, variable_name, outdir, stub, crop_year=None, annotations=None, figsize=(10, 10)):
    '''Plot and save a row-clustered heatmap of one variable, labelling rows with crop type.

    Intended for the named paddocks only, so rows are not colour-coded by 'type'. Each
    row label is "<paddock name> / <crop>". The crop comes from the '<crop_year>_Crop'
    column of the annotation table, matched to rows by paddock id.

    Parameters:
    ds_paddocks (xarray.Dataset): needs 'pvt' over (paddock, variable, time), a 1-D 'name'
        variable and a 'paddock' coordinate whose ids appear in ``annotations['paddock']``.
    variable_name (str): entry of the 'variable' coordinate to plot.
    outdir, stub (str): the figure is saved to ``outdir + stub + "_pt-<variable_name>.png"``.
    crop_year (optional): year whose crop column labels the rows; defaults to the
        notebook-level global ``year``.
    annotations (DataFrame, optional): table with 'paddock' and '<year>_Crop' columns;
        defaults to the notebook-level global ``pad_manan``.
    figsize (tuple): passed to seaborn.clustermap; (10, 10) matches seaborn's default.

    Raises:
    ValueError: if ``variable_name`` is not in the dataset, or a paddock id in
        ``ds_paddocks`` has no row in ``annotations``.
    KeyError: if the crop column is missing (e.g. year None gives 'None_Crop').
    '''
    if variable_name not in ds_paddocks.variable.values:
        raise ValueError(f"Variable '{variable_name}' not found in the dataset.")

    # Fall back to the notebook globals when not passed explicitly.
    if crop_year is None:
        crop_year = year
    if annotations is None:
        annotations = pad_manan

    # Time series of the variable for every paddock: (paddock, time)
    pt_variable = ds_paddocks.sel(variable=variable_name).pvt.values
    # Count gaps in the raw series for the summary at the end (before any filling).
    raw_missing = np.count_nonzero(np.isnan(pt_variable))

    # Linearly interpolate gaps along time (axis 1), including leading/trailing gaps.
    pt_variable = np.apply_along_axis(
        lambda x: pd.Series(x).interpolate(method='linear', limit_direction='both').to_numpy(),
        axis=1,
        arr=pt_variable,
    )

    print(pt_variable.shape)
    nan_count = np.sum(np.isnan(pt_variable))
    print(f"Number of NaN values: {nan_count}")
    print("replacing nans with 0, for now, so the clustering can work..")
    pt_variable = np.nan_to_num(pt_variable, nan=0)

    time_index = pd.to_datetime(ds_paddocks.time.values)

    # Start the monthly tick sequence on 1 January of the first year.
    start_date = time_index.min()
    if start_date.month != 1:
        start_date = pd.Timestamp(year=start_date.year, month=1, day=1)
    monthly_start = pd.date_range(start=start_date, end=time_index.max(), freq='MS')

    # For each month start, tick the last available timestamp on or before it.
    monthly_ticks = []
    for date in monthly_start:
        prior_dates = time_index[time_index <= date]
        if not prior_dates.empty:
            monthly_ticks.append(prior_dates[-1])
    monthly_ticks_str = [str(t)[:10] for t in monthly_ticks]

    # Crop planted in crop_year for each paddock, matched by id rather than row position.
    crop_col = f'{crop_year}_Crop'
    crops = annotations[crop_col].fillna('').apply(lambda x: x.strip() if isinstance(x, str) else x)
    crop_by_paddock = dict(zip(annotations['paddock'], crops))

    paddock_ids = ds_paddocks.paddock.values
    missing = [p for p in paddock_ids if p not in crop_by_paddock]
    if missing:
        raise ValueError(f"No annotation rows for paddock ids: {missing}")

    row_names = ds_paddocks.name.values
    names_crops = [f"{r} / {crop_by_paddock[p]}" for r, p in zip(row_names, paddock_ids)]

    # clustermap creates its own figure, so size it here. A separate plt.figure() call
    # would only leave an extra empty figure open.
    g = sns.clustermap(pt_variable, method='average', metric='euclidean',
                       row_cluster=True, col_cluster=False, cmap='viridis', figsize=figsize)

    # Clustering reorders the rows; reorder the labels to match.
    row_order = g.dendrogram_row.reordered_ind
    ordered_row_names = [names_crops[i] for i in row_order]

    g.ax_heatmap.set_xlabel('Time')

    tick_positions = [time_index.get_loc(t) for t in monthly_ticks]
    g.ax_heatmap.set_xticks(tick_positions)
    g.ax_heatmap.set_xticklabels(monthly_ticks_str, rotation=45, ha='right')

    # Heatmap rows are centred on i + 0.5.
    g.ax_heatmap.set_yticks(np.arange(len(ordered_row_names)) + 0.5)
    g.ax_heatmap.set_yticklabels(ordered_row_names, fontsize=8, rotation=0)

    # Colour bar stays in seaborn's default position; title it with the variable name.
    g.cax.set_title(variable_name, pad=10)

    plt.savefig(outdir + stub + f"_pt-{variable_name}.png", bbox_inches='tight')

    # Report gaps in the raw series (the filled array no longer contains NaNs).
    print(f'Number of missing pixels across all {variable_name} time series:', raw_missing)



In [None]:
# Example usage: clustered NDVI heatmap for all paddocks (the stub contains "_None_" when year is None).
plot_clustermap1(ds_paddocks=ds_weekly_paddocks, variable_name='NDVI', outdir=outdir, stub=stub+"_"+str(year)+'_heatmap_dendro_NDVI')

In [None]:
# Example usage: heatmap of `the_var` for the named paddocks only, with crop-labelled rows.
# NOTE(review): no index function visible here produces 'CFI3'; confirm it exists in ds_weekly_paddocks.variable.
ds_weekly_paddocks_named = ds_weekly_paddocks.where(ds_weekly_paddocks['type'] == 'named', drop=True)
the_var='CFI3'
plot_clustermap2(ds_paddocks=ds_weekly_paddocks_named, variable_name=the_var, outdir=outdir, stub=stub+"_"+str(year)+'_heatmap_dendro_'+the_var+'_namedpaddocks')


### Read in temporal climatic data
- pandas time series would be best. Should be easy to subset the same time frame as the ds, and plot in 'dashboard plot'


### Create 'Dashboard' plots

In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
from plotly.offline import plot, iplot
import plotly.io as pio

def plot_time_series(ds_paddocks, bands, pad_manan, year, outdir, stub):
    ''' Plot per-paddock time series for several bands as stacked interactive panels.

    One subplot is drawn per band, with a shared x (time) axis. Each paddock is one line,
    coloured by the crop recorded for it in ``pad_manan`` for ``year``. Paddocks with a
    missing or blank crop entry are labelled 'Unknown' and drawn light grey. The figure
    is shown inline and written to ``outdir + stub + "_time_series.html"``.

    Parameters
    ----------
    ds_paddocks : xarray.Dataset
        Needs ``time``, ``paddock``, ``name`` and ``variable`` coordinates and a ``pvt``
        variable. NOTE(review): after selecting one ``variable``, ``pvt`` is assumed to be
        indexed (paddock, time), because rows are taken as ``band_data[paddock_index]``.
    bands : sequence of str
        Values of the ``variable`` coordinate to plot, one panel each (in order).
    pad_manan : DataFrame
        Paddock annotations with a ``paddock`` column and an ``f"{year}_Crop"`` column.
        It is not modified.
    year : int or str
        Year whose crop column is used for colouring and the title.
    outdir, stub : str
        Concatenated to form the output file prefix.

    Raises
    ------
    ValueError
        If ``bands`` is empty, or if any band is not in ``ds_paddocks.variable``.
        All bands are checked before any plotting is done.
    '''
    if len(bands) == 0:
        raise ValueError("At least one band must be given.")

    # Validate every requested band up front so no half-built figure is produced.
    available_bands = set(ds_paddocks.variable.values)
    for band in bands:
        if band not in available_bands:
            raise ValueError(f"Band '{band}' not found in the dataset.")

    time_values = pd.to_datetime(ds_paddocks.time.values)
    paddock_values = ds_paddocks.paddock.values
    paddock_names = ds_paddocks.name.values

    crop_col = f'{year}_Crop'

    def _crop_label(crop):
        # Missing or blank annotations become 'Unknown'. str() guards against
        # non-string cells (e.g. numbers), which have no .strip().
        return 'Unknown' if pd.isna(crop) or str(crop).strip() == '' else crop

    # Copy the subset. The crop column is rewritten below, and writing into a filtered
    # slice of the caller's frame raises SettingWithCopyWarning.
    pad_manan_subset = pad_manan[pad_manan['paddock'].isin(paddock_values)].copy()
    pad_manan_subset[crop_col] = pad_manan_subset[crop_col].apply(_crop_label)
    crop_types = pad_manan_subset.set_index('paddock')[crop_col].to_dict()

    # Real crop types (excluding 'Unknown'), in order of first appearance.
    unique_crop_types = [ct for ct in pad_manan_subset[crop_col].unique() if ct != 'Unknown']

    # Cycle through the palette if there are more crop types than colours.
    palette = px.colors.qualitative.Plotly
    color_map = {ct: palette[i % len(palette)] for i, ct in enumerate(unique_crop_types)}
    color_map['Unknown'] = 'lightgrey'

    fig = make_subplots(
        rows=len(bands), cols=1,
        shared_xaxes=True,
        vertical_spacing=0.05,
        subplot_titles=[None] * len(bands)
    )

    for row_number, band in enumerate(bands, start=1):
        band_data = ds_paddocks.sel(variable=band).pvt.values

        for paddock_index, (paddock_id, paddock_name) in enumerate(zip(paddock_values, paddock_names)):
            crop_type = crop_types.get(paddock_id, 'Unknown')
            line_color = color_map.get(crop_type, 'lightgrey')
            hover_text = f'Paddock: {paddock_name}<br>Crop: {crop_type}'
            fig.add_trace(
                go.Scatter(
                    x=time_values,
                    y=band_data[paddock_index],
                    mode='lines',
                    name=f'Paddock {paddock_name}',
                    line=dict(color=line_color),
                    hoverinfo='text',
                    text=hover_text
                ),
                row=row_number, col=1
            )

        fig.update_yaxes(title_text=band, row=row_number, col=1)

    fig.update_layout(
        height=300 * len(bands),  # scale the figure with the number of panels
        title_text=(f"Year: {year}. Hover over lines to see paddock name and crop planted: {unique_crop_types}"),
        showlegend=False,
        plot_bgcolor='white',
        paper_bgcolor='white',
        margin=dict(l=50, r=50, t=50, b=50)
    )

    # Classic framed look on EVERY subplot. A layout-level xaxis/yaxis dict would only
    # style the first panel, so the grids are switched off via update_*axes instead.
    fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=True, showgrid=False)
    fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=True, showgrid=False)

    # Show inline in the notebook, then save a standalone HTML copy.
    iplot(fig)

    output_file = outdir + stub + "_time_series.html"
    plot(fig, filename=output_file, auto_open=False)
    print(f"Interactive plot saved as {output_file}")
    print(unique_crop_types)
# Example usage: stacked NDVI / NIRv / CFI3 / NDTI2 / CAI panels for the named paddocks.
bands = ['NDVI', 'NIRv', 'CFI3', 'NDTI2', 'CAI']
plot_time_series(
    ds_paddocks=ds_weekly_paddocks_named,
    bands=bands,
    pad_manan=pad_manan,
    year=year,
    outdir=outdir,
    stub=f"{stub}_{year}",
)



### Create calendar plots for each paddock
- For each row (year), include text alongside it describing the management.

1. subset the paddocks gpd

2. For each row:
   - clip the ds_weekly
   - make the RGB animation
   - make and save the calendar plot

In [None]:
# Names of all paddocks whose type is 'named', shown as the cell output.
named_pads = pad_man.loc[pad_man['type'] == 'named', 'name'].tolist()
named_pads

In [None]:
# Show the output path prefix (outdir + stub) that the file names in the cells below are built from
outdir+stub

In [None]:
def prepare_dataset(outdir, stub, model_index=3, pad_year=2017):
    """Reads and prepares the satellite data (xarray) for further processing (e.g. calendar plots).

    Steps:
    1. Load ``outdir + stub + '_ds2.pkl'`` and keep only the six NBART bands listed below.
    2. Add veg-cover fraction layers with ``calculate_and_add_fractional_cover`` (defined
       elsewhere in this notebook), using model ``model_index`` and ``correction=False``.
    3. Resample to weekly steps (``"1W"``) and fill gaps by linear interpolation.
    4. Prepend all-NaN weekly steps. Starting 7 days before the first weekly step, step
       back 7 days at a time for as long as the date is still in ``pad_year``. Presumably
       this is so the week-by-year calendar rows start near 1 January.

    Parameters
    ----------
    outdir, stub : str
        Concatenated to form the input pickle path prefix.
    model_index : int, default 3
        Fractional-cover model index passed through to ``calculate_and_add_fractional_cover``.
    pad_year : int, default 2017
        Year to pad back into. Nothing is prepended if the first weekly step is not in it.

    Returns
    -------
    xarray.Dataset
        The weekly dataset, with any NaN padding steps prepended along ``time``.
    """
    band_names = ['nbart_blue', 'nbart_green', 'nbart_red', 'nbart_nir_1', 'nbart_swir_2', 'nbart_swir_3']

    # NOTE: pickle.load can execute code embedded in the file; only load pickles this project wrote.
    with open(outdir + stub + '_ds2.pkl', 'rb') as handle:
        ds = pickle.load(handle)[band_names]

    ## Add veg fractions to ds
    ds = calculate_and_add_fractional_cover(ds, band_names, model_index, correction=False)

    # Resample the dataset to weekly intervals and interpolate
    ds_weekly_allyears = ds.resample(time="1W").interpolate("linear")

    # Walk back from the first weekly step in 7-day increments while still inside pad_year.
    earliest_time = pd.Timestamp(ds_weekly_allyears.time.values[0])
    new_time_steps = []
    current_time = earliest_time - pd.Timedelta(days=7)
    while current_time.year == pad_year:
        new_time_steps.append(current_time)
        current_time -= pd.Timedelta(days=7)

    if not new_time_steps:
        # Nothing to pad. Concatenating an empty-time dataset could change the dtype
        # of the time coordinate, so return the weekly data unchanged.
        return ds_weekly_allyears

    # Collected newest-first; put them in ascending order.
    new_time_steps.reverse()

    # All-NaN variables with the same non-time dims as the weekly data.
    # NOTE(review): assumes 'time' is the first dim of every data variable (dims[1:] / shape[1:]).
    nan_data_vars = {
        var: (
            ('time',) + ds_weekly_allyears[var].dims[1:],
            np.full((len(new_time_steps),) + ds_weekly_allyears[var].shape[1:], np.nan),
        )
        for var in ds_weekly_allyears.data_vars
    }
    new_coords = {coord: ds_weekly_allyears.coords[coord] for coord in ds_weekly_allyears.coords if coord != 'time'}
    new_ds = xr.Dataset(
        data_vars=nan_data_vars,
        coords={**new_coords, 'time': new_time_steps}
    )

    return xr.concat([new_ds, ds_weekly_allyears], dim='time')

# Build the weekly all-years dataset a single time; the calendar-plot cell reuses it.
ds_weekly_allyears = prepare_dataset(outdir=outdir, stub=stub)

In [None]:
# Display the weekly, padded dataset returned by prepare_dataset
ds_weekly_allyears

In [None]:
def calendar_plots(ds_weekly_allyears, pad_names, outdir, stub, pad_man, exclude_year=2024):
    '''Save week-by-year 'calendar plots' (RGB and veg-fraction thumbnails) for named paddocks.

    For every row of ``pad_man`` whose ``name`` is in ``pad_names``, the weekly dataset is
    clipped to that paddock's geometry. Two thumbnail grids are then saved, each with one
    panel per weekly time step and 52 panels per row:

    * ``outdir + stub + '_calendar_' + name + '_RGB.png'`` from nbart_red/green/blue
    * ``outdir + stub + '_calendar_' + name + '_vegfrac.png'`` from bg/pv/npv

    This can be used to visualize phenology across years for the same paddock.

    Parameters
    ----------
    ds_weekly_allyears : xarray.Dataset
        Weekly dataset with rioxarray spatial info (needed by ``rio.clip``) and the
        nbart_red/green/blue and bg/pv/npv variables.
    pad_names : iterable of str
        Paddock names to plot, matched against ``pad_man['name']``.
    outdir, stub : str
        Concatenated to form the output file prefix.
    pad_man : DataFrame
        Paddock table with ``name``, ``paddock`` and ``geometry`` columns.
        NOTE(review): geometries are assumed to be in the dataset's CRS, because
        ``rio.clip`` is called without a ``crs`` argument.
    exclude_year : int or None, default 2024
        Time steps in this year are dropped before plotting. Pass None to keep all years.
    '''
    # Import locally so that rebinding the notebook-level name `rgb` (e.g. to an image
    # array in another cell) cannot break this function.
    from dea_tools.plotting import rgb as dea_rgb

    rgb_bands = ['nbart_red', 'nbart_green', 'nbart_blue']
    vegfrac_bands = ['bg', 'pv', 'npv']

    print(pad_names)

    for _, row in pad_man[pad_man['name'].isin(pad_names)].iterrows():
        geometry = row['geometry']
        paddock_id = row['paddock']
        paddock_name = row['name']
        print(paddock_id, paddock_name)

        # Clip the dataset to the current paddock polygon.
        clipped_ds = ds_weekly_allyears.rio.clip([mapping(geometry)])

        if exclude_year is not None:
            clipped_ds = clipped_ds.sel(time=clipped_ds['time'].dt.year != exclude_year)

        # Replace NaN (e.g. the all-NaN padding weeks added by prepare_dataset) with 0 before plotting.
        clipped_ds = clipped_ds.fillna(0)

        out_name_RGB = outdir + stub + '_calendar_' + paddock_name + '_RGB.png'
        out_name_vegfrac = outdir + stub + '_calendar_' + paddock_name + '_vegfrac.png'

        for bands, out_path in ((rgb_bands, out_name_RGB), (vegfrac_bands, out_name_vegfrac)):
            # Thumbnail grid: one panel per time step, 52 per row (one row ~ one year of weeks).
            dea_rgb(clipped_ds,
                    bands=bands,
                    robust=True,
                    size=4,
                    col="time",
                    col_wrap=52,
                    savefig_path=out_path)
            # Close the figure so figures do not accumulate across paddocks.
            plt.close()

        print('Finished: ', out_name_vegfrac)

# To do: 
# 1 update plotting function to show text for year and crop alongside rows
# 2 would be good to also indicate which thumbnails are interpolated (grey box outline? colour the text above the thumbnail?)

# add a close plot function.. 

In [None]:
# Choose which paddocks get calendar plots.
# Set USE_ALL_NAMED_PADDOCKS = True to plot every paddock with type == 'named';
# otherwise the hand-picked subset below is used.
# Alternative subset: ['No 4', 'Pine Hill', 'Washpool', 'Rocky East']
USE_ALL_NAMED_PADDOCKS = False

if USE_ALL_NAMED_PADDOCKS:
    pad_names = pad_man.loc[pad_man['type'] == 'named', 'name'].values
else:
    pad_names = ['Rocky', 'Rocky East',
                 'Scramble Paddock', 'Washpool East', 'Bottom Range', 'Rocky South',
                 'Washpool', 'Rubbish Paddock', 'Contour', 'Fingerboard', "Johnny's"]

calendar_plots(ds_weekly_allyears, pad_names, outdir, stub, pad_man)

### Create an interactive map of yearly crop type.
In the next version, add other management info and yield data when available.

In [None]:

import geopandas as gpd
import plotly.express as px
import json
from plotly.offline import plot

# NOTE: the saved .html files reportedly don't load properly; consider making this a
# normal (static) plot instead.

crop_col = f'{year}_Crop'

# Reproject a COPY to WGS84 (EPSG:4326) for the web map. The notebook-level `pad_manan`
# keeps its original CRS and columns, so other cells are unaffected.
pad_manan_wgs84 = pad_manan.to_crs(epsg=4326)

# GeoJSON for the map geometries; features are matched on properties.paddock below.
geojson_data = json.loads(pad_manan_wgs84.to_json())

# Hover text: paddock name and the crop recorded for `year`.
pad_manan_wgs84['hover_text'] = pad_manan_wgs84.apply(
    lambda row: f'Paddock: {row["name"]}<br>Crop: {row[crop_col]}', axis=1
)

# Map centre: take centroids in the original CRS, then reproject. Centroids computed
# directly in geographic (lat/lon) coordinates are inaccurate.
centroids_wgs84 = pad_manan.geometry.centroid.to_crs(epsg=4326)
map_center = {"lat": centroids_wgs84.y.mean(), "lon": centroids_wgs84.x.mean()}

fig = px.choropleth_mapbox(
    pad_manan_wgs84,
    geojson=geojson_data,
    locations='paddock',
    featureidkey="properties.paddock",
    color=crop_col,  # colour paddocks by crop type
    hover_name='hover_text',
    hover_data={crop_col: False},  # crop is already in hover_text
    mapbox_style="carto-positron",
    center=map_center,
    zoom=10,
    opacity=0.5
)

fig.update_layout(
    title_text=f'Crop Types for the Year {year}',
    title_x=0.5,
    # Keep a top margin: with t=0 there is no room for the title.
    margin={"r": 0, "t": 40, "l": 0, "b": 0}
)

# Display the map in the Jupyter Notebook
fig.show()

# Save using the same `outdir + stub` prefix convention as the other outputs.
output_file = f'{outdir}{stub}_{year}_crop_map.html'
plot(fig, filename=output_file, auto_open=False)
print(f"Interactive map saved as {output_file}")

In [None]:
pad_man['name']

In [None]:
### Static map of labelled paddocks using the Fourier Transform tif
# The image array is called `rgb_img`, not `rgb`, so that it does not shadow
# dea_tools.plotting.rgb, which the calendar plots rely on.

# Load the Fourier Transform image
raster_path = outdir + stub + '.tif'

# Only the manually named paddocks are drawn.
pol = pad_man[pad_man['type'] == 'named']

with rasterio.open(raster_path) as src:
    # Bands 1-3 as red, green, blue -> (rows, cols, 3) float32 image.
    rgb_img = np.dstack([src.read(band) for band in (1, 2, 3)]).astype('float32')
    # Capture what is needed from the raster while the dataset is still open.
    raster_bounds = src.bounds
    raster_crs = src.crs

# Scale to 0-1 for imshow. nanmax ignores NaN nodata; skip scaling if the max is
# non-finite or zero to avoid an all-NaN image.
max_val = np.nanmax(rgb_img)
if np.isfinite(max_val) and max_val > 0:
    rgb_img /= max_val

# Align the local polygon subset with the raster CRS so the overlay lines up.
# Only `pol` is reprojected; the notebook-level pad_man keeps its CRS.
if pol.crs is not None and raster_crs is not None and pol.crs != raster_crs:
    pol = pol.to_crs(raster_crs)

fig, ax = plt.subplots(figsize=(10, 10))

# Display the RGB image in map coordinates
ax.imshow(rgb_img, extent=(raster_bounds.left, raster_bounds.right, raster_bounds.bottom, raster_bounds.top))

# Overlay the paddock polygons
pol.plot(ax=ax, facecolor='none', edgecolor='red', linewidth=1)

# Label each paddock at its centroid; spaces become line breaks to keep labels compact.
label_points = pol.geometry.centroid
for x, y, label in zip(label_points.x, label_points.y, pol['name']):
    ax.text(x, y, label.replace(' ', '\n'), fontsize=8, ha='center', va='center', color='yellow', weight='bold')

# Turn the axes off BEFORE saving so the saved file matches what is displayed.
ax.set_axis_off()
fig.savefig(outdir + stub + '_map-padman.tif', dpi=300, bbox_inches='tight')

plt.show()


### Yield x NPP relationships
Preliminary analysis of reported paddock yields as a function of total growing-season NPP, using NIRv as a proxy for NPP.

Requires:
- paddock-variable-time xarray data for each year to be created and saved. (e.g. MILG_b033_2017-24_ds_weekly_paddocks_2022.pkl)
- paddock annotations including crop and yield (paddock-year-yield.csv)
- some function for estimating key phenology transition dates from paddock-level satellite data (to be developed more later...)

Steps:
1. For each paddock-year, estimate total growing-season NPP as (something like) the total NIRv during the crop growing season. Also derive other variables, such as growing-season length.
2. Plot total estimated NPP by reported yield for each paddock-year. Use plotly to label points by year and crop

In [None]:
import statsmodels.api as sm
import plotly.express as px
import plotly.graph_objects as go



In [None]:
def clean_paddock_year_df(df):
    """Return a tidy copy of the raw paddock-year annotation table.

    Keeps only ``Year``, ``Paddock``, ``Crop``, ``ha`` and ``Yield_actuals`` (renamed to
    ``Yield``) and drops rows with no ``Year``. The input frame is not modified.

    Column conversions:

    * ``Year`` -> integer year. The column may arrive as float (e.g. 2019.0) when the
      CSV has blank rows.
    * ``Paddock`` and ``Crop`` -> pandas categorical.
    * ``ha`` -> integer. Fractional values are truncated; raises if any ``ha`` is missing.
    * ``Yield`` -> numeric. Non-numeric entries such as blanks (" ") become NaN.

    Rows with a missing or zero ``Yield`` are kept; filter them downstream if needed.
    """
    # Drop rows where 'Year' is NaN
    df_clean = df.dropna(subset=['Year'])

    # Keep the required columns and rename 'Yield_actuals' to 'Yield'
    df_clean = df_clean[['Year', 'Paddock', 'Crop', 'ha', 'Yield_actuals']].rename(
        columns={'Yield_actuals': 'Yield'}
    )

    # Parse years numerically. A float year such as 2019.0 is not guaranteed to parse
    # with pd.to_datetime(format='%Y').
    df_clean['Year'] = pd.to_numeric(df_clean['Year']).astype(int)

    # Paddock and Crop as categorical types
    df_clean['Paddock'] = df_clean['Paddock'].astype('category')
    df_clean['Crop'] = df_clean['Crop'].astype('category')

    # Convert 'ha' to integer
    df_clean['ha'] = df_clean['ha'].astype('int')

    # errors='coerce' already turns blanks (" ") and any other non-numeric text into NaN,
    # so no separate blank-replacement step is needed.
    df_clean['Yield'] = pd.to_numeric(df_clean['Yield'], errors='coerce')

    return df_clean


In [None]:
# Read the newer paddock annotation file, which has yield estimates for every named
# paddock each year, and strip it back to paddock name, year, crop, area and yield.
# TODO: replace the pad_an file loaded at the start of this notebook with this
# paddock annotation file, as it has more info.
paddock_annotations2 = "/g/data/xe2/John/Data/PadSeg/paddock-year-yield.csv" # the latest version of paddock management annotation data (assumes format stays the same!)
pad_year = clean_paddock_year_df(
    pd.read_csv(paddock_annotations2)
)
print(pad_year)

In [None]:
# Paddock-years with no missing values in any column.
df = pad_year.dropna().copy()
print(len(pad_year))
print(len(df))

# Paddock and Crop may be categorical (clean_paddock_year_df makes them so). seaborn
# draws a slot for every category in the dtype, even categories with no rows left
# after dropna. Dropping unused categories stops empty crop types/paddocks appearing.
for col in ('Paddock', 'Crop'):
    if isinstance(df[col].dtype, pd.CategoricalDtype):
        df[col] = df[col].cat.remove_unused_categories()

# First plot: Yield by Year
plt.figure(figsize=(8, 4))
sns.boxplot(x='Year', y='Yield', data=df)
plt.title('Yield by Year')
plt.show()

# Second plot: Yield by Paddock and Yield by Crop side-by-side, horizontal boxes
fig, axes = plt.subplots(1, 2, figsize=(8, 6))

sns.boxplot(ax=axes[0], y='Paddock', x='Yield', data=df)
axes[0].set_title('Yield by Paddock')

sns.boxplot(ax=axes[1], y='Crop', x='Yield', data=df)
axes[1].set_title('Yield by Crop')

plt.tight_layout()
plt.show()

In [None]:
# Refactoring ideas for next time:
# - Keep the pad_year table (paddock, year, crop, yield) separate from the time-series data.
# - Write functions that run through the time series and add new predictor columns to pad_year.
#   Structure them like the spectral-index functions: write the functions, list the
#   variables to get, then add them.

def paddock_year_ts(pad_year, variable_name, year):
    """Attach per-paddock summaries and the weekly series of one variable for one year.

    Loads ``outdir + stub + '_ds_weekly_paddocks_<year>.pkl'``. ``outdir`` and ``stub``
    are read from the notebook's global namespace. The function then looks at each
    paddock listed in ``pad_year`` for ``year``. If its name matches the dataset's
    ``name`` coordinate and its flattened data has one value per time step, it computes:

    * ``median_<variable_name>``: NaN-ignoring median over all time steps.
    * ``median_growth_season_sumNIRv``: NaN-skipping sum of ``variable_name`` over
      May-December. This column keeps its historical name, even though it is a sum and
      even when ``variable_name`` is not NIRv, because downstream cells use it.
    * one column per time step (keyed by timestamp) holding the raw values.

    Parameters
    ----------
    pad_year : DataFrame
        Cleaned paddock-year table with at least ``Year`` and ``Paddock`` columns.
        It is not modified.
    variable_name : str
        Value of the dataset's ``variable`` coordinate to summarise (e.g. 'NIRv').
    year : int
        Year to select in ``pad_year`` and whose pickle is loaded.

    Returns
    -------
    DataFrame
        The rows of ``pad_year`` for ``year`` (``Paddock`` as str), left-merged with the
        computed columns. Unmatched paddocks get NaN in the new columns.
    """
    # Work on a copy so the caller's frame is not mutated. Paddock must be str to
    # match the dataset's `name` coordinate.
    pad_year = pad_year.copy()
    pad_year['Paddock'] = pad_year['Paddock'].astype(str)

    pad_year_filtered = pad_year[pad_year['Year'] == year]

    # NOTE: pickle.load can execute code embedded in the file; only load pickles this project wrote.
    with open(outdir + stub + '_ds_weekly_paddocks_' + str(year) + '.pkl', 'rb') as handle:
        ds_weekly_paddocks = pickle.load(handle)

    time_coords = ds_weekly_paddocks.time.values

    results = []
    for paddock in pad_year_filtered['Paddock'].unique():
        matching_paddock = ds_weekly_paddocks.sel(paddock=ds_weekly_paddocks.name == paddock)
        if matching_paddock.paddock.size == 0:
            continue

        variable_data = matching_paddock.sel(variable=variable_name).pvt.values.flatten()
        # Skip matches whose flattened data does not line up with the time axis
        # (e.g. several paddocks sharing one name).
        if len(variable_data) != len(time_coords):
            continue

        time_series = pd.Series(variable_data, index=time_coords)
        may_december = time_series[time_series.index.month >= 5]

        row_data = {
            'Year': year,
            'Paddock': paddock,
            f'median_{variable_name}': np.nanmedian(variable_data),
            'median_growth_season_sumNIRv': may_december.sum(),
        }
        # One column per time step with the raw values.
        row_data.update(time_series.to_dict())
        results.append(row_data)

    if results:
        results_df = pd.DataFrame(results)
    else:
        # The merge keys must still exist (with a compatible Year dtype) for an empty result.
        results_df = pd.DataFrame(columns=['Year', 'Paddock']).astype(
            {'Year': pad_year_filtered['Year'].dtype}
        )

    # Left merge keeps every paddock-year for this year, matched or not.
    final_df = pad_year_filtered.merge(results_df, on=['Year', 'Paddock'], how='left')

    return final_df

#paddock_year_ts(pad_year,variable_name, 2019)


In [None]:
variable_name = 'NIRv'

# Build one frame per year and concatenate once, rather than re-copying inside the loop.
# The loop variable is `yr`, so the notebook-level `year` used by the earlier plotting
# cells is not overwritten.
all_years_df = pd.concat(
    [paddock_year_ts(pad_year, variable_name, yr) for yr in range(2018, 2024)],
    ignore_index=True,
)

# Year as str, so it is a discrete colour group in the scatter plots below.
all_years_df['Year'] = all_years_df['Year'].astype(str)

all_years_df

## NEXT: Once there is a ds_weekly_paddocks saved for every year, run this function through each year and concatenate the output. Make a plotly scatterplot and save it with the group.
# This is a strange dataset, because the year-paddock combinations are never consistent across years, so it is a patchy df with blocks of NaNs.

In [None]:

# Paddock-years with a positive reported yield and a growth-season NIRv sum.
df = all_years_df[all_years_df['Yield'] > 0]
df = df.dropna(subset=['Yield', 'median_growth_season_sumNIRv'])

# Despite its name, this column is a May-December SUM of NIRv, not a median.
x_col = 'median_growth_season_sumNIRv'
x_label = 'Growth-season (May-Dec) sum of NIRv'

# Fit Yield ~ const + x by ordinary least squares
X = sm.add_constant(df[x_col])
y = df['Yield']
model = sm.OLS(y, X).fit()
# Look the parameters up by name rather than relying on their order.
intercept = model.params['const']
slope = model.params[x_col]

fig = px.scatter(
    df,
    x=x_col,
    y='Yield',
    color='Year',
    hover_data=['Year', 'Paddock', 'Crop'],
    labels={x_col: x_label, 'Yield': 'Yield'},
    title=f'Scatter plot of {x_label} vs Yield'
)

# Add the line of best fit across the observed x range
x_vals = np.array([df[x_col].min(), df[x_col].max()])
y_vals = intercept + slope * x_vals
fig.add_trace(go.Scatter(x=x_vals, y=y_vals, mode='lines', name='Best Fit Line'))

# Classic styling
fig.update_layout(
    plot_bgcolor='white',
    paper_bgcolor='white',
    xaxis=dict(showgrid=False, linecolor='black'),
    yaxis=dict(showgrid=False, linecolor='black'),
    title=dict(font=dict(size=20, color='black')),
    legend=dict(font=dict(size=12)),
    margin=dict(l=40, r=40, t=40, b=40),
    autosize=True
)

fig.show()

In [None]:
# Linear model: Yield ~ const + median_growth_season_sumNIRv, fitted on `df` from the previous cell.
X = sm.add_constant(df['median_growth_season_sumNIRv'])
y = df['Yield']
model = sm.OLS(y, X).fit()

print(model.summary())

slope = model.params['median_growth_season_sumNIRv']
# statsmodels OLS p-values are two-sided. "Significantly positive" therefore needs
# both p < 0.05 AND a positive slope.
p_value = model.pvalues['median_growth_season_sumNIRv']
print(f"P-value for median_growth_season_sumNIRv coefficient: {p_value}")

if p_value < 0.05 and slope > 0:
    print("The relationship between median_growth_season_sumNIRv and Yield is significantly positive.")
elif p_value < 0.05:
    print("The relationship between median_growth_season_sumNIRv and Yield is significantly negative.")
else:
    print("The relationship between median_growth_season_sumNIRv and Yield is not statistically significant (p >= 0.05).")

In [None]:
# Scatter of per-paddock NIRv against yield, saved as HTML.
# `median_NIRv` exists only if the paddock_year_ts output includes a median column.
# Otherwise fall back to the growth-season sum, so the cell does not fail with a KeyError.
if 'median_NIRv' in df.columns:
    x_col, x_label = 'median_NIRv', 'Median NIRv'
else:
    x_col, x_label = 'median_growth_season_sumNIRv', 'Growth-season (May-Dec) sum of NIRv'

fig = px.scatter(
    df,
    x=x_col,
    y='Yield',
    color='Year',
    hover_data=['Year', 'Paddock', 'Crop'],
    labels={x_col: x_label, 'Yield': 'Yield'},
    title=f'Scatter plot of {x_label} vs Yield'
)

# Classic styling
fig.update_layout(
    plot_bgcolor='white',
    paper_bgcolor='white',
    xaxis=dict(showgrid=False, linecolor='black'),
    yaxis=dict(showgrid=False, linecolor='black'),
    title=dict(font=dict(size=20, color='black')),
    legend=dict(font=dict(size=12)),
    margin=dict(l=40, r=40, t=40, b=40),
    autosize=True
)

# Save the plot as an HTML file
fig.write_html(outdir+stub+"_NIRv-yield_scatter.html")

fig.show()

In [None]:
# Display the output directory prefix
outdir


In [None]:
# Display `ds` for inspection
ds