In [None]:
import time
import os
import importlib
from tqdm.auto import tqdm #from functools import cache

import numpy as np
import xarray as xr
import dask
from dask.distributed import Client, LocalCluster
from dask_jobqueue import SLURMCluster
import pickle
import json
import datetime

import pandas as pd
import geopandas as gpd
from shapely.geometry import Polygon

import cartopy.crs as ccrs
import cartopy.feature as cf
import matplotlib.pyplot as plt
from matplotlib import ticker, cm

import matplotlib.colors as colors
import matplotlib.animation as animation
from matplotlib.ticker import LogFormatter, MaxNLocator

from IPython.display import display

direc = os.getcwd()
# open_data_utils lives in a separate project directory, so it is imported while
# that directory is the working directory. try/finally restores the notebook's
# working directory even if the import or reload fails; otherwise later relative
# paths ("data/...", "figures/...") would silently resolve against the wrong place.
os.chdir("/home/b/b381737/python_scripts/master/open_data")
try:
    import open_data_utils  # import the module here, so that it can be reloaded.
    importlib.reload(open_data_utils)  # pick up edits to the module on re-run
    from open_data_utils import (
        xr_to_gdf,
        add_country_names,
        country_intersections,
        select_extent,
        FlexDataset,
        FlexDataCollection,
        load_nc_partposit,
        FlexDataset2,
        calc_enhancement,
    )
finally:
    os.chdir(direc)



In [None]:
# Start a Dask cluster through SLURM (queue "prepost", project bb1170) and
# attach a client to it; the client is shown as the cell output.
slurm_options = dict(
    name='dask-cluster',
    cores=8,
    processes=8,
    n_workers=8,
    memory='10GB',
    interface='ib0',
    queue='prepost',
    project='bb1170',
    walltime='12:00:00',
    asynchronous=0,
)
cluster = SLURMCluster(**slurm_options)
client = Client(cluster)
client

In [None]:
# Shut the Dask client (and the cluster it is connected to) down.
# NOTE(review): placed directly after the cluster cell, "Run All" stops the
# cluster before any later cell can use it — run this cell manually at the end.
client.shutdown()

In [None]:
# CarbonTracker CT2019B inputs: directory and file-name prefix of the 3-hourly
# mole fractions on the 3x2 grid, and path prefix of the 3-hourly 1x1 flux files.
dir_name, file_dummy, ct_file = (
    "/mnt/lustre02/work/bb1170/static/CT2019/Conc3hour_3x2/",
    "CT2019B.molefrac_glb3x2_",
    "/work/bb1170/static/CT2019/Flux3hour_1x1/CT2019B.flux1x1.",
)

In [None]:
# Darwin run with 200k particles (wu setup): load it and map its footprint.
file = os.path.join(
    "/work/bb1170/RUN/b381737/software/flexpart_v10.4_3d7eebf/output/EI20091123-20091207",
    "sensitivity_part_nums/Darwin/wu200k",
    "grid_time_20091207210000.nc",
)
fd = FlexDataset2(
    file,
    extent=[0, 180, -80, 30],
    ct_dir=dir_name,
    ct_name_dummy=file_dummy,
    chunks=dict(time=20, pointspec=4),
)
fp = fd.footprint
fig, ax = fp.plot()
_ = fd.add_map(ax)

In [None]:
# Enhancement of the FlexDataset2 run loaded above, using the CarbonTracker flux
# files under ct_file. allow_read=False presumably forces recomputation instead of
# reading a stored result — TODO confirm in open_data_utils.
fd.enhancement(ct_file, allow_read=False)

In [None]:
# Trajectory test run: load it and run the end-point processing steps in order;
# the result of load_endpoints() is the cell output.
file = os.path.join(
    "/work/bb1170/RUN/b381737/software/flexpart_v10.4_3d7eebf/test_traj0",
    "grid_time_20091127230000.nc",
)
fd = FlexDataset2(
    file,
    extent=[0, 180, -80, 30],
    ct_dir=dir_name,
    ct_name_dummy=file_dummy,
    chunks=dict(time=20, pointspec=4),
)
tr = fd.trajectories
for step_name in ("ct_endpoints", "co2_from_endpoints"):
    getattr(tr, step_name)()
tr.load_endpoints()

In [None]:
# Background for the trajectory run loaded above. allow_read=False presumably
# bypasses any stored result — TODO confirm in open_data_utils.
fd.background(allow_read=False)

In [None]:
# Paths of the FLEXPART output files compared in this notebook.
FLEXPART_DIR = "/work/bb1170/RUN/b381737/software/flexpart_v10.4_3d7eebf"
FLEXPART_OUT = os.path.join(FLEXPART_DIR, "output")
GRID_FILE = "grid_time_20091207210000.nc"

dir_path = os.path.join(FLEXPART_OUT, "EI20091123-20091207")
setups = ["wu_setup", "pressure_setup", "pressure_wu_setup", "pressure_wu_fewer_setup", "test_setup", "wu_setup_2", "wu_fewer_setup", "wu_more_setup", "wu_setup_CT-grid", "wu_setup_CT-grid_mixing_ratio"]

# Column-setup comparison runs.
wu_file = os.path.join(dir_path, setups[0], GRID_FILE)
pr_file = os.path.join(dir_path, setups[1], GRID_FILE)
pw_file = os.path.join(dir_path, setups[2], GRID_FILE)
pwf_file = os.path.join(dir_path, setups[3], GRID_FILE)
test_file = os.path.join(dir_path, setups[4], GRID_FILE)
wu_33_file = os.path.join(FLEXPART_OUT, "EI_33-60_20091123-20091207", "wu", GRID_FILE)
wu2_file = os.path.join(dir_path, setups[5], GRID_FILE)
wu_few_file = os.path.join(dir_path, setups[6], GRID_FILE)
wu_more_file = os.path.join(dir_path, setups[7], GRID_FILE)
wu_ct_file = os.path.join(dir_path, setups[8], GRID_FILE)
wu_ct_mr_file = os.path.join(dir_path, setups[9], GRID_FILE)

# Particle-number sensitivity (Wollongong, 100k particles).
wu_100_file = os.path.join(dir_path, "sensitivity_part_nums", "Wollongong", "wu100k", GRID_FILE)
pr_100_file = os.path.join(dir_path, "sensitivity_part_nums", "Wollongong", "pressure100k", GRID_FILE)

# Release-height sensitivity runs. These used to be assigned to unit_w_file /
# unit_d_file and were then silently overwritten by the trajectory runs below,
# so the commented-out "Unit (100k parts)" constructors would have loaded the
# 40k trajectory runs. They now have names of their own.
unit_w_height_file = os.path.join(dir_path, "sensitivity_rel_height", "unit_wollongong", GRID_FILE)
unit_d_height_file = os.path.join(dir_path, "sensitivity_rel_height", "unit_darwin", GRID_FILE)

# Runs over the longer period 2009-11-16 .. 2009-12-07.
long_dir = os.path.join(FLEXPART_OUT, "EI20091116-20091207")
unit_d_long_file = os.path.join(long_dir, "sensitivity_rel_time", "wu40k_darwin", GRID_FILE)
unit_w_long_file = os.path.join(long_dir, "sensitivity_rel_time", "wu40k_wollongong", GRID_FILE)
unit_d_long_shift_file = os.path.join(long_dir, "sensitivity_rel_time", "wu40k_darwin_shift", GRID_FILE)
unit_w_long_shift_file = os.path.join(long_dir, "sensitivity_rel_time", "wu40k_wollongong_shift", GRID_FILE)
unit_d_file = os.path.join(long_dir, "trajectories", "unit_darwin040k", GRID_FILE)
unit_w_file = os.path.join(long_dir, "trajectories", "unit_wollongong040k", GRID_FILE)

# ERA5 test runs (different output time stamp).
ea_file = os.path.join(FLEXPART_DIR, "test", "grid_time_20091127230000.nc")
ea_single_file = os.path.join(FLEXPART_OUT, "EA_test_box", "grid_time_20091127230000.nc")

# Build FlexDataset objects for the setups above. Most constructors are kept
# commented out for reference; only the two ERA5 test runs are loaded.
#fd = FlexDataset(dir_path + file_path, extent=[100,180,-60,0], chunks=dict(time=15, pointspec=4))
extent = [30, 180, -80, 0]
chunks = dict(time=15, pointspec=4)
#chunks=None
#fd_wu = FlexDataset(wu_file, extent, chunks=chunks, name="Wu Setup", cmaps="Reds", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_pr = FlexDataset(pr_file, extent, chunks=chunks, name="Pressure Setup", cmaps="Blues", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_pw = FlexDataset(pw_file, extent, chunks=chunks, name="Pressure Wu Setup", cmaps="Greens", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_pwf = FlexDataset(pwf_file, extent, chunks=chunks, name="Pressure Wu (fewer) Setup", cmaps="Purples", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_test = FlexDataset(test_file, extent, chunks=chunks, name="Test Setup", cmaps="Greens", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_wu_33 = FlexDataset(wu_33_file, extent, chunks=chunks, name="Wu Setup 33-60", cmaps="Greens", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_wu2 = FlexDataset(wu2_file, extent, chunks=chunks, name="Wu Setup 2", cmaps="Greens", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_wu_few = FlexDataset(wu_few_file, extent, chunks=chunks, name="Wu (less particles)", cmaps="Reds", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_wu_more = FlexDataset(wu_more_file, extent, chunks=chunks, name="Wu (more particles)", cmaps="Greens", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_wu_ct = FlexDataset(wu_ct_file, extent, chunks=chunks, name="Wu (Carbon Tracker grid)", cmaps="Greens", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_wu_ct_mr = FlexDataset(wu_ct_mr_file, extent, chunks=chunks, name="Wu (Carbon Tracker grid, mr)", cmaps="Greens", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_wu_100 = FlexDataset(wu_100_file, extent, chunks=chunks, name="Wu (100k parts)", cmaps="Greens", norm=colors.LogNorm(), station_names=["Wollongong"])
#fd_pr_100 = FlexDataset(pr_100_file, extent, chunks=chunks, name="Pressure (100k parts)", cmaps="Greens", norm=colors.LogNorm(), station_names=["Wollongong"])
#fd_unit_w = FlexDataset(unit_w_file, extent, chunks=chunks, name="Unit (100k parts)", cmaps="Greens", norm=colors.LogNorm(), station_names=["Wollongong"])
#fd_unit_d = FlexDataset(unit_d_file, extent, chunks=chunks, name="Unit (100k parts)", cmaps="Greens", norm=colors.LogNorm(), station_names=["Darwin"])
#fd_unit_d_long = FlexDataset(unit_d_long_file, chunks=chunks, name="Unit (40k parts)", cmaps="jet", norm=colors.LogNorm(), station_names=["Darwin"])
#fd_unit_w_long = FlexDataset(unit_w_long_file, chunks=chunks, name="Unit (40k parts)", cmaps="jet", norm=colors.LogNorm(), station_names=["Wollongong"])
#fd_unit_d_long_shift = FlexDataset(unit_d_long_shift_file, chunks=chunks, name="Unit (40k parts, shifted)", cmaps="jet", norm=colors.LogNorm(), station_names=["Darwin"])
#fd_unit_w_long_shift = FlexDataset(unit_w_long_shift_file, chunks=chunks, name="Unit (40k parts, shifted)", cmaps="jet", norm=colors.LogNorm(), station_names=["Wollongong"])
#fd_unit_d = FlexDataset(unit_d_file, chunks=chunks, name="Unit Darwin (40k parts)", cmaps="jet", norm=colors.LogNorm(), station_names=["Darwin"])
#fd_unit_w = FlexDataset(unit_w_file, chunks=chunks, name="Unit Wollongong (40k parts)", cmaps="jet", norm=colors.LogNorm(), station_names=["Wollongong"])

# ERA5 test runs, each with its own LogNorm instance.
fd_ea = FlexDataset(
    ea_file,
    chunks=chunks,
    name="ERA 5 data",
    cmaps="jet",
    norm=colors.LogNorm(),
    station_names=["Wollongong"],
)
# The single-box run is chunked with much larger pointspec blocks (400).
fd_single_ea = FlexDataset(
    ea_single_file,
    chunks=dict(time=10, pointspec=400),
    name="ERA 5 data",
    cmaps="jet",
    norm=colors.LogNorm(),
    station_names=["Wollongong"],
)

In [None]:
# Runs of the three column setups on 0.5°, 1° and 2° output grids.
extent = [30, 180, -80, 0]
chunks = dict(time=15, pointspec=4)
dir_path = "/work/bb1170/RUN/b381737/software/flexpart_v10.4_3d7eebf/output/EI20091123-20091207"
set_path = ["05grid", "10grid", "20grid"]
setups = ["wu_setup", "pressure_setup", "pressure_wu_setup"]


def _grid_run_file(grid_idx, setup_idx):
    """Return the grid_time file of setup ``setups[setup_idx]`` on grid ``set_path[grid_idx]``."""
    return os.path.join(dir_path, set_path[grid_idx], setups[setup_idx], "grid_time_20091207210000.nc")


# Per grid: (wu, pressure, pressure-wu) in the order of `setups`.
fd_wu05_file, fd_pr05_file, fd_pw05_file = (_grid_run_file(0, s) for s in range(3))
fd_wu10_file, fd_pr10_file, fd_pw10_file = (_grid_run_file(1, s) for s in range(3))
fd_wu20_file, fd_pr20_file, fd_pw20_file = (_grid_run_file(2, s) for s in range(3))


#fd_wu05 = FlexDataset(fd_wu05_file, extent, chunks=chunks, name="Wu 0.5°", cmaps="Greens", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_pr05 = FlexDataset(fd_pr05_file, extent, chunks=chunks, name="Pressure 0.5°", cmaps="Blues", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_pw05 = FlexDataset(fd_pw05_file, extent, chunks=chunks, name="Pressure Wu 0.5°", cmaps="Reds", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_wu10 = FlexDataset(fd_wu10_file, extent, chunks=chunks, name="Wu 1°", cmaps="Greens", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_pr10 = FlexDataset(fd_pr10_file, extent, chunks=chunks, name="Pressure 1°", cmaps="Blues", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_pw10 = FlexDataset(fd_pw10_file, extent, chunks=chunks, name="Pressure Wu 1°", cmaps="Reds", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_wu20 = FlexDataset(fd_wu20_file, extent, chunks=chunks, name="Wu 2°", cmaps="Greens", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_pr20 = FlexDataset(fd_pr20_file, extent, chunks=chunks, name="Pressure 2°", cmaps="Blues", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])
#fd_pw20 = FlexDataset(fd_pw20_file, extent, chunks=chunks, name="Pressure Wu 2°", cmaps="Reds", norm=colors.LogNorm(), station_names=["Darwin", "Wollongong"])

In [None]:
# Plot one footprint: the ERA5 test run for its first station.
fd = fd_ea
fd.extent = [60, 180, -60, 0]
for i, station in enumerate([0]):
    fig, ax = fd.subplots(figsize=(12, 4))
    station_label = fd.station_names[station]
    fd.plot_footprint(
        ax,
        station=station,
        plot_station=True,
        cmap="jet",
        station_kwargs=dict(color="black", label=station_label),
        cbar_kwargs=dict(label="s*m^3/kg"),
    )
    fd.add_map(ax)
    ax.set_title(f"Footprint with {fd.name}")
    plt.legend(loc="upper right")
    plt.show()

In [None]:
# Plot one footprint per day, for days 5..9 counted back from the last time step.
dtype = "datetime64[D]"
# Loop-invariant setup, done once instead of once per day.
fd = fd_unit_w
fd.extent = [-180, 180, -80, 40]
steps_per_day = 8  # NOTE(review): assumes 3-hourly output (8 steps per day) — TODO confirm
all_pointspecs = np.arange(len(fd.DataSet.pointspec))
station = 0

for day in range(5, 10):
    # Negative indices count from the end of the time axis; `stop` is exclusive
    # in the plotted window and is used as the end label of the title.
    first, stop = -(day + 1) * steps_per_day, -day * steps_per_day
    fig, ax = fd.subplots(figsize=(12, 4))
    fd.plot(ax, station=station, time=np.arange(first, stop), pointspec=all_pointspecs,
            plot_station=True, cmap="jet",
            station_kwargs=dict(color="black", label=fd.station_names[station]),
            cbar_kwargs=dict(label="s*m^3/kg"))
    fd.add_map(ax)
    ax.set_title(f"Footprint {fd.DataSet.time[first].values.astype(dtype)} - {fd.DataSet.time[stop].values.astype(dtype)}")
    plt.legend(loc="upper right")
    plt.show()

In [None]:
# Plot differences: first footprint of the shifted run minus the unshifted one.
fd1, fd2 = fd_unit_w_long, fd_unit_w_long_shift
diff = fd2.Footprints[0] - fd1.Footprints[0]

fig, ax = fd1.subplots(figsize=(10, 4.5))
diff.plot(ax=ax, cmap="seismic", cbar_kwargs=dict(label="Footprint[s]"))
fd1.add_map(ax=ax)
ax.set_title(f"{fd2.name} and {fd1.name}")
#plt.savefig("figures/australia/column_setups/diff_wu_pr.png", dpi=300)

In [None]:
# Compare the Wu footprint with the Wu 33-60 run for one station on a shared color scale.
fds = [fd_wu, fd_wu_33]

# Shared color range over all datasets. Taking min/max directly replaces the old
# seeds (vmin=1, vmax=0), which capped vmin at 1 whenever every dataset's minimum
# exceeded 1; each vmin()/vmax() is now also evaluated only once.
vmin = min(fd.vmin() for fd in fds)
vmax = max(fd.vmax() for fd in fds)

station = 1


def _overlay_footprints(datasets, station, vmin, vmax):
    """Draw the footprints of ``datasets`` for ``station`` on one map axis.

    Map features are added once, right after the first footprint (same order as
    before); the remaining footprints are drawn afterwards.

    Returns
    -------
    (fig, ax) of the new figure.
    """
    fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(projection=ccrs.PlateCarree()))
    for i, fd in enumerate(datasets):
        fd.plot_footprint(ax, station, plot_station=True, vmin=vmin, vmax=vmax,
                          cbar_kwargs=dict(label=f"Footprint [s] ({fd.name})", orientation="horizontal", pad=0.035))
        if i == 0:
            fd.add_map(ax)
    ax.set_title("Comparison of Footprints")
    return fig, ax


fig, ax = _overlay_footprints(fds, station, vmin, vmax)
plt.savefig("figures/australia/column_setups/33_vs_normal.png", dpi=300)
plt.show()

# Same comparison with the drawing order reversed; `fds` itself is not mutated.
fig, ax = _overlay_footprints(fds[::-1], station, vmin, vmax)
plt.show()

In [None]:
def _pairwise_footprint_differences(fds, station):
    """Compare every dataset in ``fds`` against each of the others for one station.

    For each reference dataset one table is built. Each row describes one other
    dataset: its footprint sum, the total absolute difference to the reference,
    and that difference divided by the reference's footprint sum.

    Parameters
    ----------
    fds : list of FlexDataset
    station : int
        Index into each dataset's ``Footprints``.

    Returns
    -------
    tables : list of pandas.DataFrame
        One table per reference dataset, in the order of ``fds``.
    rel : list
        All relative differences, in table order and row order within a table.
    """
    # Load each footprint once instead of once per comparison.
    footprints = [fd.Footprints[station].values for fd in fds]
    tables, rel = [], []
    for ref_idx, ref in enumerate(fds):
        ref_vals = footprints[ref_idx]
        ref_sum = ref_vals.sum()
        columns = ["Setup", "Footprint sum", f"Total difference to {ref.name}", f"Relative difference to {ref.name}"]
        rows = []
        for idx, fd in enumerate(fds):
            if idx == ref_idx:
                continue
            vals = footprints[idx]
            total_diff = np.sum(np.abs(ref_vals - vals))
            rows.append([fd.name, vals.sum(), total_diff, total_diff / ref_sum])
            rel.append(total_diff / ref_sum)
        tables.append(pd.DataFrame(rows, columns=columns))
    return tables, rel


# One group of three setups per grid resolution (0.5°, 1°, 2°); `rel` collects the
# relative differences of all groups in that order for the next cell.
station = 0
rel = []
for fds in ([fd_wu05, fd_pr05, fd_pw05], [fd_wu10, fd_pr10, fd_pw10], [fd_wu20, fd_pr20, fd_pw20]):
    tables, grid_rel = _pairwise_footprint_differences(fds, station)
    for df in tables:
        display(df)
    rel.extend(grid_rel)

In [None]:
# Average relative footprint difference between setups, per grid resolution.
grid_sizes = [0.5, 1, 2]  # degrees; same order as the groups collected in `rel`
# One row per grid size, one column per (reference, other) setup pair. Derived
# from grid_sizes instead of a hard-coded (3, 6); `rel` itself is left untouched
# so re-running this cell is safe.
rel_by_grid = np.asarray(rel).reshape(len(grid_sizes), -1)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(grid_sizes, rel_by_grid.mean(axis=1))
ax.set_xlabel("grid size [°]")
ax.set_ylabel("average relative error")
# The metric is sum|reference - other| / sum(reference) over the whole grid.
ax.set_title("sum|setup2 - setup1| / sum(setup2)")
ax.grid()

In [None]:
# Enhancements (ppm) per setup at the two stations, relative to the Wu setup.
setups = ["Wu Setup", "Pressure Setup", "Pressure Wu Setup"]
ppm1 = [.312, .283, .318]
ppm2 = [.74, .85, .748]
ppm = [ppm1, ppm2]

records = [
    [setup_name, station_idx, value, (station_ppm[0] - value) / station_ppm[0]]
    for station_idx, station_ppm in enumerate(ppm)
    for setup_name, value in zip(setups, station_ppm)
]
df = pd.DataFrame(records, columns=["Setup", "station", "Enhancement (ppm)", "Relative difference to Wu Setup"])
df
    

In [None]:
# Relative difference between two runs per time step.
# FIXME(review): `times` and `rel_diffs` are not defined in any cell of this
# notebook, so this cell raises NameError on a fresh kernel (hidden state from a
# deleted cell). fd1/fd2 are the ones set in the "Plot differences" cell.
plt.plot(times,rel_diffs)
plt.title(f"({fd2.name} - {fd1.name})/sum of {fd2.name}")
plt.xlabel("timestep")
plt.ylabel("relative difference")

# Calculate emissions

In [None]:
# Calculate emissions from multiple files: one enhancement per setup ("wu",
# "pressure") and released particle number, read from
# sensitivity_part_nums/<site>/<name><num>k/.
end = '2009-12-07'
start = '2009-11-23'

site = "Darwin"

gosat_file = '/work/bb1170/static/REMOTECv240_L2_CO2_GOSAT/REMOTEC_L2_CO2_GOSAT_'
tccon_file = "/work/bb1170/static/TCCON/data/gosat/netcdf/wg*.public.nc"
ct_file = "/work/bb1170/static/CT2019/Flux3hour_1x1/CT2019B.flux1x1."

part_num_dir = os.path.join(
    "/work/bb1170/RUN/b381737/software/flexpart_v10.4_3d7eebf/output/EI20091123-20091207/sensitivity_part_nums",
    site,
)
# Particle counts in thousands. Earlier (removed) sweeps covered 1k-100k.
part_nums = ["200", "300", "400", "500", "600", "700", "800", "900", "1000"]

enh_dict = dict()
for name in ["wu", "pressure"]:
    enh_dict[name] = dict(values=[], number=[])
    for num in part_nums:
        fp_file = os.path.join(part_num_dir, f"{name}{num}k", "grid_time_20091207210000.nc")
        if not os.path.exists(fp_file):
            # Name the missing run instead of printing a bare "skipped".
            print(f"skipped {name}{num}k (missing {fp_file})")
            continue
        # NOTE(review): ds is left open on purpose — the returned enhancement may
        # still be lazy (dask) and reference it; confirm before closing.
        ds = xr.open_dataset(fp_file, chunks=dict(time=20, pointspec=4))
        # 1013: presumably a reference surface pressure in hPa — TODO confirm.
        enh = calc_enhancement(ds, ct_file, 1013, start, end)
        enh_dict[name]["values"].append(enh)
        enh_dict[name]["number"].append(int(num))
        print(name + num)

In [None]:
# Show the collected enhancements per setup ("values") and particle counts in thousands ("number").
print(enh_dict)

In [None]:
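# Save the particle-number enhancements computed above (enh_dict, not `data`,
# which is only loaded further below); name the file after the `site` of that run.
# Values may need converting to plain floats before json.dump.
#with open("data/enhancements_darwin.json", "w") as f:
#    json.dump(enh_dict, f)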

In [None]:
# plot emission results with error bands: enhancement vs. particle number, with
# tolerance bands around the value of the largest run (last entry).
station = "wollongong"

xlim = (-5, 1005)
percentages = [0.2, 0.1]  # outer (orange) and inner (red) relative tolerance

with open(f'data/enhancements_{station}.json', 'r') as f:
    data = json.load(f)

outer, inner = percentages
for name in ["wu", "pressure"]:
    vals = np.array(data[name]["values"])
    nums = np.array(data[name]["number"])
    ref = vals[-1]

    plt.figure(figsize=(10,5))
    plt.grid()
    plt.xlabel(r"Released particles [$10^3$ particles]")
    plt.ylabel("Enhancement")
    plt.xlim(*xlim)
    plt.title(f"Effect of particle number ({station})")
    plt.errorbar(nums, vals, fmt="o", label=name, color="black")
    plt.hlines(ref, *xlim, color="grey", alpha=0.7, linestyle="dashed")
    # Orange: between the inner and the outer tolerance, above and below the
    # reference. It reaches out to the outer tolerance, so it carries that label
    # (the labels of the two band colors used to be swapped).
    plt.fill_between([*xlim], (1+outer)*ref, (1+inner)*ref, color="orange", alpha=0.3, linestyle="dashed", label=f"{outer*1e2}% deviation")
    plt.fill_between([*xlim], (1-inner)*ref, (1-outer)*ref, color="orange", alpha=0.3, linestyle="dashed")
    # Red: within +-inner tolerance of the reference.
    plt.fill_between([*xlim], (1+inner)*ref, (1-inner)*ref, color="red", alpha=0.3, linestyle="dashed", label=f"{inner*1e2}% deviation")
    plt.legend()
    #plt.savefig(f"/mnt/lustre01/pf/b/b381737/python_scripts/figures/sensitivity/part_num_{station}_{name}.png", dpi=300)
    plt.show()

In [None]:
# plot multiple results: both setups of one station in a single figure.
station = "wollongong"
with open(f'data/enhancements_{station}.json', 'r') as handle:
    data = json.load(handle)

plt.figure(figsize=(10,5))
plt.grid()
plt.xlabel(r"Released particles [$10^3$ particles]")
plt.ylabel("Enhancement")
plt.title(f"Effect of particle number ({station})")
for name in ["wu", "pressure"]:
    entry = data[name]
    plt.errorbar(np.array(entry["number"]), np.array(entry["values"]), fmt="o", label=name)
plt.legend()
#plt.savefig(f"/mnt/lustre01/pf/b/b381737/python_scripts/figures/sensitivity/part_num_{station}_all.png", dpi=300)
plt.show()

In [None]:
# plot difference to last value: relative deviation of every run from the largest
# run (last entry, labelled the 10^6 run) for both stations and setups.
plt.figure(figsize=(10,5))
plt.grid()
plt.xlabel(r"Released particles [$10^3$ particles]")
plt.ylabel(r"Relative difference to $10^6$ measurement")
plt.title("Effect of particle number")

marker = ["o", "x"]  # per setup
# Named station_colors so the matplotlib.colors module (used as colors.LogNorm()
# when building datasets) is not shadowed by a list after this cell runs.
station_colors = ["b", "r"]  # per station

for i, station in enumerate(["darwin", "wollongong"]):
    with open(f"data/enhancements_{station}.json", 'r') as f:
        data = json.load(f)
    for j, name in enumerate(["wu", "pressure"]):
        vals = np.array(data[name]["values"])
        nums = np.array(data[name]["number"])
        rel_diff = (vals[-1] - vals) / vals[-1]
        plt.errorbar(nums, rel_diff, fmt=marker[j], color=station_colors[i], label=f"{name} ({station})")
plt.legend()
plt.savefig("/mnt/lustre01/pf/b/b381737/python_scripts/figures/sensitivity/part_num_total_1m.png", dpi=300)
plt.show()

# Effect of release height

In [None]:
# effect on sum of footprints: cumulative footprint sum when only the lowest
# release points are included, against the height reached by the last of them.
fd = fd_unit_w

arr = fd.DataArrays[0]
arr = arr.sum(dim=["time", "longitude", "latitude"]).compute()
# Entry k of the cumulative sum is the footprint sum over release points 0..k.
# One cumsum replaces the old loop, which re-summed every prefix separately.
cum_sum = arr.cumsum(dim="pointspec").transpose("pointspec", ...)
# As before, keep element [0, 0] of the remaining dimensions.
# NOTE(review): assumes exactly two dimensions remain besides pointspec — TODO confirm.
fp_arr = cum_sum.values[:, 0, 0]
# Release points 0..k reach up to RELZZ2[k]. The previous version plotted the
# prefix 0..k-1 against RELZZ2[k] (shifted by one release) and left out the full
# column; this matches the heights used in the enhancement-vs-height cell.
heights = fd.DataSet.RELZZ2.values

plt.figure(figsize=(10,5))
plt.plot(heights, fp_arr)
plt.title(f"Effect of maximum height in {fd.station_names[0]} (footprint sum)")
plt.xlabel("Maximum Release height [m]")
plt.ylabel("Footprint sum [s m^3/kg]")
plt.grid()
plt.savefig(f"/mnt/lustre01/pf/b/b381737/python_scripts/figures/sensitivity/rel_height_{fd.station_names[0]}_fp.png")

In [None]:
# Calculation of enhancements with different HEIGHTS included: release points
# 0..i-1 are evaluated for i = 3, 6, 9, ... and saved as JSON next to the run.
fd = fd_unit_d

end = '2009-12-07'
start = '2009-11-23'

gosat_file = '/work/bb1170/static/REMOTECv240_L2_CO2_GOSAT/REMOTEC_L2_CO2_GOSAT_'
tccon_file = "/work/bb1170/static/TCCON/data/gosat/netcdf/wg*.public.nc"
ct_file = "/work/bb1170/static/CT2019/Flux3hour_1x1/CT2019B.flux1x1."
#fp_file = "/work/bb1170/RUN/b381737/software/flexpart_v10.4_3d7eebf/output/EI20091123-20091207/sensitivity_rel_height/unit/grid_time_20091207210000.nc"
fp_data = fd.DataSet

# NOTE(review): calc_emission is not among the names imported from
# open_data_utils in the import cell — confirm where it is defined.
enh_list, height_list = [], []
for i in range(3, len(fp_data.pointspec.values), 3):
    fp = fp_data.isel(pointspec=slice(0, i), numpoint=slice(0, i))
    enh, _ = calc_emission(fp, tccon_file, ct_file, gosat_file, start, end)
    enh_list.append(enh)
    # upper release height of the last release point in this subset
    height_list.append(fp.RELZZ2.values[-1])

enh_dict = {
    "values": list(np.array(enh_list).astype(float)),
    "heights": list(np.array(height_list).astype(float)),
}
with open(f"{fd.directory}/enhancements_heights.json", "w") as f:
    json.dump(enh_dict, f)

In [None]:
# Plot the saved enhancement-vs-maximum-height curve of one run.
fd = fd_unit_w
with open(f'{fd.directory}/enhancements_heights.json', 'r') as handle:
    enh_dict = json.load(handle)
station_name = fd.station_names[0]

plt.figure(figsize=(10,5))
plt.title(f"Effect of maximum height in {station_name} (enhancement)")
plt.ylabel("Enhancement")
plt.xlabel("Maximal height [m]")
plt.plot(enh_dict["heights"], enh_dict["values"])
plt.grid()
plt.savefig(f"/mnt/lustre01/pf/b/b381737/python_scripts/figures/sensitivity/rel_height_{station_name}_enh.png")

# Effect of run time

In [None]:
# Calculation of enhancements with different TIMES included: time steps 0..i-1
# are evaluated for i = 8, 16, 24, ... and saved as JSON next to the run.
fd = fd_unit_d_long_shift

end = '2009-12-07'
start = '2009-11-16'

gosat_file = '/work/bb1170/static/REMOTECv240_L2_CO2_GOSAT/REMOTEC_L2_CO2_GOSAT_'
tccon_file = "/work/bb1170/static/TCCON/data/gosat/netcdf/wg*.public.nc"
ct_file = "/work/bb1170/static/CT2019/Flux3hour_1x1/CT2019B.flux1x1."
#fp_file = "/work/bb1170/RUN/b381737/software/flexpart_v10.4_3d7eebf/output/EI20091123-20091207/sensitivity_rel_height/unit/grid_time_20091207210000.nc"
fp_data = fd.DataSet

# NOTE(review): calc_emission is not among the names imported from
# open_data_utils in the import cell — confirm where it is defined.
enh_list, time_list = [], []
for i in range(8, len(fp_data.time.values), 8):
    fp = fp_data.isel(time=slice(0, i))
    enh, _ = calc_emission(fp, tccon_file, ct_file, gosat_file, start, end)
    enh_list.append(enh)
    time_list.append(fp.time.values[-1])  # last time step included

enh_dict = {
    "values": list(np.array(enh_list).astype(float)),
    "times": list(np.array(time_list)),
}
with open(f"{fd.directory}/enhancements_times.json", "w") as f:
    # default=str serialises the numpy datetime64 entries as strings
    json.dump(enh_dict, f, default=str)

In [None]:
# Display the run-time enhancements just written to enhancements_times.json.
enh_dict

In [None]:
# Enhancement vs. run time for the shifted and unshifted Wollongong runs, one figure each.
for fd in [fd_unit_w_long_shift, fd_unit_w_long]:
    with open(f'{fd.directory}/enhancements_times.json', 'r') as fp:
        enh_dict = json.load(fp)
    # JSON stores the times as strings; convert back for a datetime axis.
    enh_dict["times"] = np.array(enh_dict["times"]).astype("datetime64")

    plt.figure(figsize=(10,5))
    ax = plt.gca()
    ax.set_title(f"Effect of run duration in {fd.station_names[0]} (enhancement)")
    ax.set_ylabel("Enhancement")
    ax.set_xlabel("Run time")
    # Inverted x axis: later times on the left.
    ax.invert_xaxis()
    ax.plot(enh_dict["times"], enh_dict["values"])
    ax.grid()
    #plt.savefig(f"/mnt/lustre01/pf/b/b381737/python_scripts/figures/sensitivity/run_time_{fd.station_names[0]}.png")

In [None]:
# Compare enhancement vs. run time of the shifted and unshifted Darwin runs.
fig, ax = plt.subplots(1, 1, figsize=(7, 4))
ax.set_title("Enhancement comparison 1 hour shift (Darwin)")  # fixed "Enhamcement" typo
ax.set_ylabel("Enhancement")
ax.set_xlabel("Run time")
ax.invert_xaxis()
ax.grid()

for fd in [fd_unit_d_long_shift, fd_unit_d_long]:
    with open(f'{fd.directory}/enhancements_times.json', 'r') as fp:
        enh_dict = json.load(fp)
    # JSON stores the times as strings; convert back for a datetime axis.
    enh_dict["times"] = np.array(enh_dict["times"]).astype("datetime64")
    ax.plot(enh_dict["times"], enh_dict["values"], label=fd.name)

# At most 6 major ticks so the date labels do not overlap; set once after plotting.
ax.xaxis.set_major_locator(MaxNLocator(6))
ax.legend()
# File named after the station of the last dataset in the loop (as before).
plt.savefig(f"/mnt/lustre01/pf/b/b381737/python_scripts/figures/sensitivity/rel_time_40k_{fd.station_names[0]}.png", dpi=500)

In [None]:
#.set_major_locator(MaxNLocator(6))


# Carbon Tracker plot

In [None]:
# CarbonTracker fluxes summed over 2009-12-01 .. 2009-12-07, followed by the
# Wollongong footprint for the time steps in `inds`.
end = '2009-12-07'
start = '2009-12-01'

ct_file = "/work/bb1170/static/CT2019/Flux3hour_1x1/CT2019B.flux1x1."
# One flux file per day (<prefix>YYYYMMDD.nc). Open all of them and concatenate
# once along time instead of re-concatenating the growing dataset every day.
daily_fluxes = [
    xr.open_mfdataset(
        f"{ct_file}{date:%Y%m%d}.nc",
        combine='by_coords',
        drop_variables='time_components',
        chunks="auto",
    )
    for date in pd.date_range(start, end)
]
DSCTflux = xr.concat(daily_fluxes, dim='time')

# Total flux: sum of the biosphere, ocean, fossil and fire components.
DSCT_totalflux = (DSCTflux.bio_flux_opt + DSCTflux.ocn_flux_opt
                  + DSCTflux.fossil_flux_imp + DSCTflux.fire_flux_imp)
DSCT_totalflux.name = 'total_flux'

#can be deleted when footprint has -179.5 coordinate
# NOTE(review): positional slice, assumes longitude is the last dimension — TODO confirm.
DSCT_totalflux = DSCT_totalflux[:, :, 1:]
# Mask zero fluxes (set to NaN) before summing.
DSCT_totalflux = DSCT_totalflux.where(DSCT_totalflux != 0)
print(DSCT_totalflux.time.values[[0, -1]])
flux = DSCT_totalflux.sum(dim="time").compute()

# Use the Wollongong dataset explicitly instead of whatever `fd` a previous cell
# left behind (the map was created from that leftover before).
fd = fd_unit_w
fd.extent = [-180, 180, -80, 40]
fig, ax = fd.subplots(figsize=(12, 4))
fd.add_map(ax)
flux.plot(ax=ax, vmin=-0.00005, vmax=0.00005, cmap="bwr")
plt.title("Carbon Tracker fluxes")  # fixed "Coarbon" typo
plt.show()

# Plot one footprint
inds = [0, 8*6+5]
for i in [0]:
    station = i
    fig, ax = fd.subplots(figsize=(12, 4))
    fd.plot(ax, time=np.arange(*inds), pointspec=np.arange(len(fd.DataSet.pointspec)), station=station, plot_station=True, cmap="jet", station_kwargs=dict(color="black", label=fd.station_names[station]), cbar_kwargs=dict(label="s*m^3/kg"))
    fd.add_map(ax)
    ax.set_title(f"Footprint with {fd.name}")
    plt.legend(loc="upper right")
    plt.show()

In [None]:
# CarbonTracker fluxes summed over 2009-11-24 .. 2009-12-01, followed by the
# Wollongong footprint for the time steps in `inds`.
end = '2009-12-01'
start = '2009-11-24'

ct_file = "/work/bb1170/static/CT2019/Flux3hour_1x1/CT2019B.flux1x1."
# One flux file per day (<prefix>YYYYMMDD.nc). Open all of them and concatenate
# once along time instead of re-concatenating the growing dataset every day.
daily_fluxes = [
    xr.open_mfdataset(
        f"{ct_file}{date:%Y%m%d}.nc",
        combine='by_coords',
        drop_variables='time_components',
        chunks="auto",
    )
    for date in pd.date_range(start, end)
]
DSCTflux = xr.concat(daily_fluxes, dim='time')

# Total flux: sum of the biosphere, ocean, fossil and fire components.
DSCT_totalflux = (DSCTflux.bio_flux_opt + DSCTflux.ocn_flux_opt
                  + DSCTflux.fossil_flux_imp + DSCTflux.fire_flux_imp)
DSCT_totalflux.name = 'total_flux'

#can be deleted when footprint has -179.5 coordinate
# NOTE(review): positional slice, assumes longitude is the last dimension — TODO confirm.
DSCT_totalflux = DSCT_totalflux[:, :, 1:]
# Mask zero fluxes (set to NaN) before summing.
DSCT_totalflux = DSCT_totalflux.where(DSCT_totalflux != 0)
print(DSCT_totalflux.time.values[[0, -1]])
flux = DSCT_totalflux.sum(dim="time").compute()

# Use the Wollongong dataset explicitly instead of whatever `fd` a previous cell left behind.
fd = fd_unit_w
fd.extent = [-180, 180, -80, 40]
fig, ax = fd.subplots(figsize=(12, 4))
fd.add_map(ax)
flux.plot(ax=ax, vmin=-0.00005, vmax=0.00005, cmap="bwr")
plt.title("Carbon Tracker fluxes")  # fixed "Coarbon" typo
plt.show()

# Plot one footprint
inds = [8*6-1, 9*13-8]
for i in [0]:
    station = i
    fig, ax = fd.subplots(figsize=(12, 4))
    fd.plot(ax, time=np.arange(*inds), pointspec=np.arange(len(fd.DataSet.pointspec)), station=station, plot_station=True, cmap="jet", station_kwargs=dict(color="black", label=fd.station_names[station]), cbar_kwargs=dict(label="s*m^3/kg"))
    # Window boundaries, later time first.
    print(fd.DataSet.time.values[inds][::-1])
    fd.add_map(ax)
    ax.set_title(f"Footprint with {fd.name}")
    plt.legend(loc="upper right")
    plt.show()

In [None]:
# Plot one footprint: Wollongong run restricted to the time steps in `inds`.
fd = fd_unit_w
fd.extent = [-180, 180, -80, 40]

inds = [8*6+1, 9*13-10]
time_window = np.arange(*inds)
for i, station in enumerate([0]):
    fig, ax = fd.subplots(figsize=(12, 4))
    fd.plot(
        ax,
        time=time_window,
        pointspec=np.arange(len(fd.DataSet.pointspec)),
        station=station,
        plot_station=True,
        cmap="jet",
        station_kwargs=dict(color="black", label=fd.station_names[station]),
        cbar_kwargs=dict(label="s*m^3/kg"),
    )
    print(fd.DataSet.time.values[inds])
    fd.add_map(ax)
    ax.set_title(f"Footprint with {fd.name}")
    plt.legend(loc="upper right")
    plt.show()

# Animation time series

In [None]:
class Animation():
    """Two-panel animation: footprint map (left) and station enhancement series (right).

    Frame ``index`` uses the time indices ``[STEP*index, STEP*(index+1))`` of
    ``fd.DataSet``. The footprint for that window is drawn on a PlateCarree map,
    and the same window is shaded red on the enhancement series read from
    ``<fd.directory>/enhancements_times.json``.

    Parameters
    ----------
    fd : FlexDataset-like
        Object providing ``DataSet``, ``directory``, ``station_names``,
        ``extent``, ``plot`` and ``add_map``. Its ``extent`` is overwritten.
    """

    # Fixed colour range of the footprint panel, shared by all frames.
    VMIN = 1e-3
    VMAX = 1e2
    # Number of time steps per animation frame.
    STEP = 8

    def __init__(self, fd):
        self.fig = plt.figure(figsize=(20,5))
        # Left: map panel; right: time-series panel.
        self.ax0 = self.fig.add_subplot(121, projection=ccrs.PlateCarree())
        self.ax1 = self.fig.add_subplot(122)
        self.fd = fd
        self.fd.extent = [-120,180,-80,40]
        with open(f'{fd.directory}/enhancements_times.json', 'r') as fp:
            enh_dict = json.load(fp)
        # JSON has no datetime type; the times are presumably ISO strings -> datetime64.
        enh_dict["times"] = np.array(enh_dict["times"]).astype("datetime64")
        self.enh_dict = enh_dict

    def index_wrapper(self, index):
        """Return ``([start, stop], [time[start], time[stop]])`` for frame ``index``.

        NOTE: ``time[stop]`` raises IndexError if ``stop`` equals the number of
        time steps, so the frame count passed to ``run`` must respect that.
        """
        start = index * self.STEP
        stop = start + self.STEP
        times = [self.fd.DataSet.time[start], self.fd.DataSet.time[stop]]
        return [start, stop], times

    def _draw(self, index, **colorbar_kwargs):
        """Clear both panels and redraw them for frame ``index``.

        ``colorbar_kwargs`` is forwarded to ``fd.plot``; it is either
        ``cbar_kwargs=...`` on the first frame or ``add_colorbar=False`` after.
        """
        station = 0
        dtype = "datetime64[h]"
        inds, times = self.index_wrapper(index)
        ax0 = self.fig.axes[0]
        ax1 = self.fig.axes[1]
        ax0.clear()
        ax1.clear()

        # Footprint of `station` for this frame's time window and all release points.
        self.fd.plot(ax0, station=station, time=np.arange(*inds),
                     pointspec=np.arange(len(self.fd.DataSet.pointspec)),
                     plot_station=True, cmap="jet", vmin=self.VMIN, vmax=self.VMAX,
                     station_kwargs=dict(color="black", label=self.fd.station_names[station]),
                     **colorbar_kwargs)
        self.fd.add_map(ax0)
        time = self.fd.DataSet.time
        ax0.set_title(f"Footprint {time[inds[0]].values.astype(dtype)} - {time[inds[1]].values.astype(dtype)}")

        values = self.enh_dict["values"]
        ax1.set_title(f"Effect of run duration in {self.fd.station_names[station]} (enhancement)")
        ax1.set_ylabel("Enhancement")
        ax1.set_xlabel("Time")
        ax1.plot(self.enh_dict["times"], values, color="black")
        # Reverse the x-axis so later times are on the left.
        ax1.set_xlim(ax1.get_xlim()[::-1])
        # Shade the time window shown in the map panel.
        ax1.fill_betweenx([min(values), max(values)], times[0], times[1], color="r", alpha=0.3)
        ax1.grid()

    def init(self):
        """``init_func`` for FuncAnimation: draw frame 0 and create the horizontal colorbar."""
        self._draw(0, cbar_kwargs=dict(label="s*m^3/kg", orientation="horizontal"))

    def animate(self, index):
        """Frame callback: redraw frame ``index`` without adding another colorbar."""
        self._draw(index, add_colorbar=False)

    def run(self, frames=13):
        """Create the FuncAnimation with ``frames`` frames and store it in ``self.ani``."""
        self.ani = animation.FuncAnimation(self.fig, self.animate, frames, init_func=self.init)

    def save(self, path="/mnt/lustre01/pf/b/b381737/python_scripts/animations/test.gif", fps=2, dpi=300):
        """Write the animation created by ``run`` to ``path`` as a GIF.

        Uses Pillow because ffmpeg (the better writer) is not available by
        default on the DKRZ.
        """
        writergif = animation.PillowWriter(fps=fps)
        self.ani.save(path, writer=writergif, dpi=dpi)


In [None]:
# Build the two-panel footprint/enhancement animation for `fd_unit_w`,
# create the FuncAnimation (stored on `ani.ani`) and write it to disk as a GIF.
ani = Animation(fd_unit_w)
ani.run()
ani.save()

# Assign countries to footprint cells and compute per-country contributions

In [None]:
# Footprint of release points 0-2 summed over time and release point; the trailing [0]
# takes the first entry of the leading remaining dimension (presumably nageclass — TODO confirm).
mass = data.spec001_mr[dict(pointspec=[0,1,2])].sum(dim=["time", "pointspec"])[0]
gdf = xr_to_gdf(mass, "spec001_mr")
gdf = add_country_names(gdf)
# Was `country_inersections` (typo -> NameError); the imported helper is `country_intersections`.
ci = country_intersections(gdf, "Australia")

In [None]:
# Persist the country-annotated GeoDataFrame; the summary table reads gdf_<station>_<yy>.pkl
# files from this directory. The file name hard-codes station "woll" and year "12".
gdf_pickle_path = "/work/bb1170/RUN/b381737/software/flexpart_v10.4_3d7eebf/output/EI20121123-20121207/gdf_woll_12.pkl"
gdf.to_pickle(gdf_pickle_path)

In [None]:
# Per station and year: total footprint and the percentage falling on Australia,
# other countries and the ocean (as returned by country_intersections).
columns = ["Station", "Year", "Total [s]", "Australia [%]", "Other countries [%]", "Ocean [%]"]
direc = "/work/bb1170/RUN/b381737/software/flexpart_v10.4_3d7eebf/output/EI20121123-20121207/"
station_labels = {"woll": "Wollongong", "dar": "Darwin"}
records = []
for y in ["09", "12"]:
    for stat in ["woll", "dar"]:
        path = f"gdf_{stat}_{y}.pkl"
        print(path)
        # NOTE: read_pickle unpickles arbitrary objects -- only load trusted files.
        gdf = pd.read_pickle(direc + path)
        ret = country_intersections(gdf, "Australia")
        tot = gdf.spec001_mr.sum()
        row = [station_labels[stat], "20" + y, round(tot, 0)]
        # Percentages in the order of ret's keys, skipping the "rest" entry.
        row += [round(ret[key].spec001_mr.sum() / tot * 100, 2) for key in ret.keys() if key != "rest"]
        records.append(dict(zip(columns, row)))
# DataFrame.append was removed in pandas 2.0 (and was quadratic); build the frame once.
df = pd.DataFrame(records, columns=columns)

# Animation of the footprint over time

In [None]:
# Zeros -> NaN so empty cells stay transparent and are ignored by min/max.
mass = data.spec001_mr.where(data.spec001_mr != 0)

# Height level drawn in the animation.
h_ind = 0
# Colour range taken from release point 0 over all of its remaining dimensions.
first_release = mass.isel(pointspec=[0])
vmin = np.min(first_release).values
vmax = np.max(first_release).values
print(vmin)
print(vmax)

# Map box, presumably [lon_min, lon_max, lat_min, lat_max].
# TODO: prevent error if not standard projection
extent = [0, 180, -85, 0]

cmap = "jet"
norm = colors.LogNorm()

fig, ax = plt.subplots(1, 1, figsize=(10, 5), subplot_kw=dict(projection=ccrs.PlateCarree()))

frames = len(data.time.values)  # one frame per output time step
interval = 100  # milliseconds between frames

def index_mod(index):
    """Map a FuncAnimation frame number to a time index (currently the identity)."""
    frame_index = index
    return frame_index

def add_to_ax(ax):
    """Overlay the "hko" station marker on ``ax`` and return ``ax``.

    NOTE(review): the marker sits at latitude 22.3, outside the current
    extent [0, 180, -85, 0] -- likely left over from another station.
    """
    station_lon, station_lat = 114.1742, 22.3025
    marker_style = dict(s=200, label="hko", color="yellow", edgecolor="black", marker="*")
    ax.scatter(station_lon, station_lat, **marker_style)
    return ax

def animate(index):
    """Redraw ``ax`` for frame ``index``: release point 0 at height level ``h_ind``.

    The colorbar is created only while the figure still holds a single axes,
    i.e. on the first drawn frame.
    """
    frame = index_mod(index)
    ax.clear()
    frame_mass = data.spec001_mr.isel(time=[frame], height=[h_ind], pointspec=[0])
    frame_mass = frame_mass.where(frame_mass != 0)
    plot_kwargs = dict(ax=ax, cmap=cmap, vmin=vmin, vmax=vmax, norm=norm,
                       transform=ccrs.PlateCarree())
    if len(fig.axes) == 1:
        plot_kwargs["cbar_kwargs"] = dict(label="Footprint [s]")
    else:
        plot_kwargs["add_colorbar"] = False
    frame_mass.plot(**plot_kwargs)
    ax.set_title(data.time.values[frame].astype('datetime64[h]'))
    add_map(ax, extent=extent)
    add_to_ax(ax)

# Use an explicit Pillow writer instead of the rcParams default (ffmpeg, which is not
# available by default on the DKRZ); fps is derived from the frame interval in ms.
ani = animation.FuncAnimation(fig, animate, frames, interval=interval, blit=False)
ani.save("/mnt/lustre01/pf/b/b381737/python_scripts/master/open_data/animations/australia/wu.gif",
         writer=animation.PillowWriter(fps=1000 / interval))

# Compare footprints from two meteorological datasets (IFS vs. Interim)

In [None]:
# Footprints for the same release from two runs; per the labels below the
# OD run is "IFS" and the EI run is "Interim".
file_path = "/work/bb1170/RUN/b381737/software/flexpart_v10.4_3d7eebf/output/OD20180905/grid_conc_20180905000000.nc"
data1 = xr.open_dataset(file_path)
file_path = "/work/bb1170/RUN/b381737/software/flexpart_v10.4_3d7eebf/output/EI20180905/grid_conc_20180905000000.nc"
data2 = xr.open_dataset(file_path)

# Only the second (EI) field is cut to the select_extent box.
mass1 = data1.spec001_mr
mass2 = select_extent(data2.spec001_mr, 5, 25, 40, 55)

# Zeros -> NaN: transparent in the plot and ignored by min/max.
mass1 = mass1.where(mass1 != 0)
mass2 = mass2.where(mass2 != 0)

# Height index; shared colour range over release point 0 of both datasets.
h_ind = 0
release0 = [m.isel(pointspec=[0]) for m in (mass1, mass2)]
vmin = min(np.min(m).values for m in release0)
vmax = max(np.max(m).values for m in release0)
print(vmin)
print(vmax)

# Map box.  TODO: prevent error if not standard projection
extent = [15, 19, 46, 50]

cmaps = ["Blues", "Reds"]
labels = ["Footprint [s] (IFS)", "Footprint [s] (Interim)"]
norm = colors.LogNorm()

fig, ax = plt.subplots(1, 1, figsize=(10, 5), subplot_kw=dict(projection=ccrs.PlateCarree()))

frames = len(data1.time.values)  # one frame per time step of the first dataset
interval = 100  # milliseconds between frames

def index_mod(index):
    """Translate a frame number into a time index; identity mapping for now."""
    time_index = index
    return time_index

def add_to_ax(ax):
    """Hook for extra overlays on ``ax``; currently draws nothing and returns ``ax``."""
    result = ax
    return result

def animate(index):
    """Draw frame ``index``: IFS (Blues) and Interim (Reds) footprints overlaid on ``ax``.

    Shows release point 0 at height level ``h_ind`` for both datasets. Each
    dataset gets its own colorbar, created only on the first frame.
    """
    ind = index_mod(index)
    ax.clear()
    for i, mass in enumerate([mass1, mass2]):
        # Select pointspec 0 explicitly: vmin/vmax were computed from it, and a leftover
        # pointspec dimension of size > 1 would make DataArray.plot draw a histogram.
        mass = mass.isel(dict(time=ind, height=h_ind, pointspec=0))
        mass = mass.where(mass != 0)
        common = dict(ax=ax, cmap=cmaps[i], vmin=vmin, vmax=vmax, norm=norm,
                      transform=ccrs.PlateCarree())
        # The i-th colorbar is still missing exactly when the figure holds the
        # map axes plus i colorbar axes.
        if len(fig.axes) == (i + 1):
            mass.plot(**common, cbar_kwargs=dict(label=labels[i]))
        else:
            mass.plot(**common, add_colorbar=False)
    ax.set_title(data1.time.values[ind].astype('datetime64[h]'))
    add_map(ax, extent=extent)
    add_to_ax(ax)

# Use an explicit Pillow writer instead of the rcParams default (ffmpeg, which is not
# available by default on the DKRZ); fps is derived from the frame interval in ms.
ani = animation.FuncAnimation(fig, animate, frames, interval=interval, blit=False)
ani.save("/mnt/lustre01/pf/b/b381737/python_scripts/master/open_data/animations/tutorial/vienna_comparision2.gif",
         writer=animation.PillowWriter(fps=1000 / interval))