# In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import os
import pickle
import numpy as np
import xarray as xr
# Plot, per model, the mean MaximumMAE of surface air temperature against lead
# time over its heat-wave cases: gridded observations (left), point obs (right).
sns.set_theme(style="whitegrid")

# Results files whose model name is in this set are neither loaded nor plotted.
EXCLUDED_MODELS = {'FOUR_v100_GFS'}
# This model is drawn as a dashed black line to set it apart from the others.
HRES_MODEL = 'hres'
# Variable and per-case metric plotted in both panels.
VARIABLE = 'surface_air_temperature'
METRIC = 'MaximumMAE'


def _model_name(case_file):
    """Return the model label encoded in a heat-wave results filename.

    Underscore-separated fields 2-4 are kept and the extension is dropped,
    e.g. ``'heatwave_cases_FOUR_v100_GFS.pkl'`` -> ``'FOUR_v100_GFS'`` and
    ``'heatwave_cases_hres.pkl'`` -> ``'hres'``.
    """
    return '_'.join(case_file.split('_')[2:5]).split('.')[0]


def _has_any_metric(case_data, source):
    """Return True if any VARIABLE metric for ``source`` in one case is not None."""
    metrics = case_data.get(VARIABLE, {}).get(source, {})
    return any(metric is not None for metric in metrics.values())


def _collect_metric(heat_wave, source):
    """Return the non-None METRIC values for ``source``, in case order.

    ``source`` is ``'gridded'`` or ``'point'``. Cases missing the variable,
    the source or the metric are skipped rather than raising KeyError.
    """
    values = [
        case_data.get(VARIABLE, {}).get(source, {}).get(METRIC)
        for case_data in heat_wave.values()
    ]
    return [value for value in values if value is not None]


def _plot_mean(axis, values, name, **style):
    """Plot the mean of ``values`` over cases on ``axis``, labelled with the case count.

    Does nothing when ``values`` is empty, because ``xr.concat`` cannot
    concatenate an empty list.
    """
    if not values:
        return
    stacked = xr.concat(values, dim='concat_dim')
    label = f"{name}, n = {len(values)}"
    stacked.mean('concat_dim').plot(ax=axis, label=label, **style)


fig, ax = plt.subplots(1, 2, figsize=(15, 5))
# os.listdir() order is arbitrary; sorting makes colours and legend order reproducible.
heatwave_cases = sorted(
    n for n in os.listdir() if 'heatwave_cases' in n and n.endswith('.pkl')
)
for case_file in heatwave_cases:
    name = _model_name(case_file)
    if name in EXCLUDED_MODELS:
        continue
    # NOTE: pickle.load can execute arbitrary code; only load results files you trust.
    with open(case_file, 'rb') as f:
        case = pickle.load(f)
    heat_wave = case.get('heat_wave', {})

    # Skip models with no case that has metrics against both gridded and point obs.
    if not any(
        _has_any_metric(case_data, 'gridded') and _has_any_metric(case_data, 'point')
        for case_data in heat_wave.values()
    ):
        continue

    style = {'color': 'k', 'linestyle': '--'} if name == HRES_MODEL else {}
    # Each panel's n is the number of cases actually averaged in that panel.
    _plot_mean(ax[0], _collect_metric(heat_wave, 'gridded'), name, **style)
    _plot_mean(ax[1], _collect_metric(heat_wave, 'point'), name, **style)

for axis in ax:
    axis.set_ylim(0, 10)
ax[0].set_ylabel('MAE, Maximum Temp of Event (C)')
ax[1].set_ylabel('')
ax[0].set_xlabel('Lead Time (hours)')
ax[1].set_xlabel('Lead Time (hours)')
ax[0].set_title('Heat Wave Cases, Gridded Obs', loc='left')
ax[1].set_title('Heat Wave Cases, Point Obs', loc='left')
ax[0].legend()
ax[1].legend()
plt.show()