In [5]:
import pandas as pd
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

def get_day_array(date_string, days):
    """Return `days` consecutive calendar dates as 'YYYY-MM-DD' strings.

    Parameters
    ----------
    date_string : anything accepted by ``pd.to_datetime``
        The first day of the sequence.
    days : int
        How many consecutive days to generate, the first day included.

    Returns
    -------
    list of str
        One formatted date per day, in chronological order.
    """
    first_day = pd.to_datetime(date_string)
    formatted_days = []
    for offset in range(days):
        current_day = first_day + pd.Timedelta(days=offset)
        formatted_days.append(current_day.strftime('%Y-%m-%d'))
    return formatted_days

# Set path to data and coordinate files.
# data_path is a relative path (resolved against the working directory) and
# keeps its trailing slash because per-file paths are built from it by plain
# string concatenation.
# NOTE(review): lat_file / lon_file are absolute paths on one specific machine;
# they must be edited before the notebook can run anywhere else.
data_path = 'processed_airs_sftp/40_km_grid/'
lat_file = '/Users/joaojesus/Desktop/final_year_proj/coordinates_40km_grid/Airs_nh_lat_40_grid.csv'
lon_file = '/Users/joaojesus/Desktop/final_year_proj/coordinates_40km_grid/Airs_nh_lon_40_grid.csv'

In [36]:
import pandas as pd

# Load latitude and longitude arrays from the comma-delimited grid files
lat = np.genfromtxt(lat_file, delimiter=',')
lon = np.genfromtxt(lon_file, delimiter=',')

# Flatten the grids so every cell becomes one entry along a 'points' dimension
lat_flat = lat.flatten()
lon_flat = lon.flatten()

# Create an empty xarray dataset
ds = xr.Dataset()

# The requested days are identical for every altitude, so build them once.
dates = get_day_array(date_string="2009-01-01", days=30)

# Loop over altitude levels and load one .npz file per day
for alt in [30, 36]:
    data = []
    for date in dates:
        file_name = f'{date}_mfx.npz'
        file_path = data_path + f'{alt}km/mfx/{file_name}'
        try:
            # The context manager closes the .npz archive once the array has
            # been read into memory.
            with np.load(file_path) as npz_file:
                mfx_data_flat = npz_file['arr_0'].flatten()
        except FileNotFoundError:
            # Only a missing file is tolerated; any other failure (corrupt
            # archive, missing 'arr_0' key, ...) is a real error and propagates.
            print(f'File {file_name} not found in {file_path}')
            # Keep a NaN placeholder instead of skipping the day, so that row i
            # of the stacked array always corresponds to dates[i] and both
            # altitude variables end up with the same 'time' length.
            mfx_data_flat = np.full(lat_flat.shape, np.nan)
        data.append(mfx_data_flat)
    # Stack data along a new time dimension and add to dataset
    ds[f'mfx_{alt}km'] = xr.DataArray(np.stack(data), dims=('time', 'points'))

# Attach the flattened coordinates to the 'points' dimension
ds = ds.assign_coords(lat=('points', lat_flat), lon=('points', lon_flat))



In [37]:
# Show the assembled dataset as this cell's output.
ds

In [None]:
# The time labels below are only meaningful if there is exactly one data row
# per requested date; fail early with a clear message otherwise.
if ds.sizes['time'] != len(dates):
    raise ValueError(
        f"ds has {ds.sizes['time']} time steps but {len(dates)} dates were "
        "requested; some daily files were probably skipped while loading."
    )

# Add time coordinate and attributes. Real datetimes (rather than strings) give
# a proper time axis when plotting.
# NOTE(review): no per-(lat, lon) averaging is applied here. This assumes each
# (lat, lon) pair occurs only once in the grid files, in which case a mean per
# pair would be the identity -- confirm. The name `ds_grouped` is kept so that
# any later cell referring to it still finds it.
ds_grouped = ds.assign_coords(time=('time', pd.to_datetime(dates)))
ds_grouped.attrs['units'] = 'm^2/s^2'
ds_grouped.attrs['description'] = 'Momentum Flux'

# Find the grid point closest to the requested location. 'points' is an
# unstructured list of cells, so the nearest cell is found directly by
# great-circle distance: the largest cosine of the central angle is the
# closest point. nanargmax ignores cells whose coordinates are NaN.
target_lat, target_lon = 60, 60
lat_rad = np.radians(ds_grouped['lat'].values)
lon_rad = np.radians(ds_grouped['lon'].values)
target_lat_rad = np.radians(target_lat)
target_lon_rad = np.radians(target_lon)
cos_central_angle = (
    np.sin(lat_rad) * np.sin(target_lat_rad)
    + np.cos(lat_rad) * np.cos(target_lat_rad) * np.cos(lon_rad - target_lon_rad)
)
nearest_point = int(np.nanargmax(cos_central_angle))
point_series = ds_grouped.isel(points=nearest_point)

# Plot the time series at that point, one line per altitude level.
# Dataset objects have no `.plot.line`, so each variable is drawn separately.
fig, ax = plt.subplots(figsize=(10, 6))
for var_name in ['mfx_30km', 'mfx_36km']:
    point_series[var_name].plot.line(ax=ax, label=var_name)
ax.set_title(
    f"Momentum flux near lat={point_series['lat'].item():.2f}, "
    f"lon={point_series['lon'].item():.2f}"
)
ax.set_xlabel('Time')
ax.set_ylabel('Momentum Flux [m$^2$/s$^2$]')
ax.legend()
plt.show()
