In [None]:
import xarray as xr
from pkg_resources import resource_filename
from sklearn.decomposition import PCA
from DoWnGAN.losses import content_MSELoss, content_loss, SSIM_Loss
from DoWnGAN.losses import content_loss, content_MSELoss, SSIM_Loss

from DoWnGAN.dataloader import xr_standardize_field
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error
from DoWnGAN.prep_gan import xr_standardize_field, dt_index, filter_times
from torch.multiprocessing import Pool, Process, set_start_method
import torch.multiprocessing as mp

import multiprocessing

import pandas as pd
import glob
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress
from matplotlib import colors
import os
import torch.nn as nn


import mlflow
import pandas as pd
import torch
from tqdm.notebook import tqdm_notebook
from dask.distributed import Client

# Fail fast: everything below is pinned to cuda:0, so verify the GPU *before*
# creating the device handle and before spawning a Dask cluster that would
# otherwise be left running when the check fails.  An explicit raise is used
# instead of `assert` because assertions are stripped under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError(
        "CUDA is required for this evaluation but torch.cuda.is_available() is False."
    )

# Release cached GPU memory that earlier cells in the session may still hold.
torch.cuda.empty_cache()
device = torch.device('cuda:0')

plt.style.use(['science','ieee','no-latex'])

# Sub-directory of `mlflow_experiments/` that the checkpoints are read from
# (interpolated into the Generator/Critic artifact paths further down).
experiment_number = 0

# Local Dask cluster: one worker per CPU core, each limited to 6 GB.
print("Cores:", multiprocessing.cpu_count())
client = Client(n_workers=multiprocessing.cpu_count(), memory_limit="6GB")

# Region -> {run label -> run hash}.  The hash is interpolated into the
# `mlflow_experiments/<experiment_number>/<hash>/artifacts/...` checkpoint paths
# in the evaluation loop; the label is also the key used to look up the
# matching entry in `filters` / `padding`.
# Disabled runs are kept as comments so their hashes are not lost.
runs = {
    "florida": {
        # "hash_code_CNN_L1": "e25c6b40324643c3afc1cf42981b11b5",
        # "hash_code_CNN_MSSSIM": "d7d73e34e60c46ff96c8a49e5e9e973b",
        "hash_code_13x13": "e1d15a0615ca489aa6a17ec60247d0af",
        "hash_code_9x9": "3f48868c52404eb0a833897aa4642871",
        "hash_code_5x5": "1824682ae27c48669665cf042052d584",
        "hash_code_nf": "feda42500d2b45549be96f1bf62b0b03",
    },
    "central": {
        # "hash_code_CNN_L1": "fbe44b0423204805bc6af4d7d6ac562e",
        "hash_code_13x13": "bcf7e7cfa8ab4c4196ad6a2bb18e8601",
        "hash_code_9x9": "079a94c41ad3482996cc2b9f95adba8d",
        "hash_code_5x5": "202ea9f8a73b401fa22e62c24d9ab2d0",
        "hash_code_nf": "0c5ee480663f4f9eb7200f8879aa1244",
    },
    "west": {
        # "hash_code_CNN_L1": "f76c0170818244629de4544805f93a59",
        "hash_code_13x13": "c4ec13e65fe74b399fc9e325a9966fef",
        "hash_code_9x9": "6abe7a9940c04b47819689070100e5e6",
        "hash_code_5x5": "70f5be887eff42e8a216780752644b2f",
        "hash_code_nf": "db9f0fae83c949eaad5d1176a43dae47",
    },
}
# Box-filter width for every run label that uses a filter.
_kernel_widths = {
    "hash_code_13x13": 13,
    "hash_code_9x9": 9,
    "hash_code_5x5": 5,
}

# Stride-1 average pooling per run label.  With no padding of its own, each
# pool shrinks the spatial size by (width - 1); `padding` below compensates.
filters = {
    label: nn.AvgPool2d(width, stride=1, padding=0)
    for label, width in _kernel_widths.items()
}
# Pass-through entry: this label's tensors are returned unchanged.
filters["hash_code_nf"] = lambda x: x

# Replication padding of half the kernel width on every side, so that
# filters[label](padding[label](x)) has the same spatial size as x.
padding = {
    label: nn.ReplicationPad2d(width // 2)
    for label, width in _kernel_widths.items()
}
padding["hash_code_nf"] = lambda x: x



# Initial values only: both names are re-bound by the `for set in ...` and
# `for region in ...` loops of the evaluation block before they are read.
# NOTE(review): `set` shadows the built-in `set` type for the rest of the session.
region = "central"
set = "train"

def standardize_field(x, m, s):
    """Return ``x`` centred on ``m`` and scaled by ``s``, i.e. ``(x - m) / s``.

    No type checks are performed: anything supporting ``-`` and ``/`` works
    (the callers in this file pass xarray objects with float or xarray
    statistics).  ``s`` is not checked for zero.
    """
    centred = x - m
    return centred / s

# Offline evaluation of logged checkpoints.
#
# For every data split ("validation", "train") and every region in `runs`:
#   1. load the pre-written fine-resolution ground truth and standardize it with
#      the *training* split's mean/std,
#   2. load and standardize the coarse wind fields and covariates,
#   3. for each run of that region, load Generator/Critic checkpoints 0..999 and
#      evaluate them chunk by chunk,
#   4. write the per-checkpoint min/mean/max of each metric to
#      "<region>_<set>_offline_metrics.csv".
# Gradients are disabled for the whole block.
with torch.no_grad():
    # NOTE(review): the loop variable `set` shadows the built-in `set`.
    for set in ["validation", "train"]:
        for region in runs.keys():
            print(region)
            # NOTE(review): `sf` is not read anywhere in the visible code.
            sf = 8
            # Index bounds (low, up, l, r) per region; used below as
            # `[..., low:up, l:r]` slices of the coarse fields.
            region_area = {
                "florida": (4, 20, 70, 86),
                "central": (30, 46, 50, 66),
                "west": (30, 46, 15, 31)
            }

            # NOTE(review): `hash_list` is not read anywhere in the visible code.
            hash_list = [runs[region][key] for key in runs[region].keys()]

            # NOTE: `r` is re-bound by the `for r in aranged` loop further down; this
            # is harmless only because it is re-assigned here on every region iteration.
            low, up, l, r = region_area[region]
            fine_paths = {
                "U": resource_filename("DoWnGAN", "data/wrf/U10_regrid_16/regrid_16_6hrly_wrf2d_d01_ctrl_U10*.nc"),
                "V": resource_filename("DoWnGAN", "data/wrf/V10_regrid_16/regrid_16_6hrly_wrf2d_d01_ctrl_V10*.nc"),
            }

            # NOTE(review): only `u10.Times` is read from these datasets below; the
            # call does not depend on `set` or `region`, so it is repeated unchanged
            # on every iteration.
            u10 = xr.open_mfdataset(glob.glob(fine_paths["U"])).U10
            v10 = xr.open_mfdataset(glob.glob(fine_paths["V"])).V10


            # Extract times in datetime format
            times = dt_index(u10.Times)

            # Apply filter to times for months you'd like
            time_mask = filter_times(times, mask_years=[2000, 2006, 2010])

            # The validation split is the complement of the training mask.
            if set == "validation":
                time_mask = ~time_mask.copy()
    #             time_mask = ~time_mask
            # The first time step is excluded from whichever split is active.
            time_mask[0] = False


            # The fine ground truth must already exist on disk for both the current
            # split and the training split (the latter supplies the statistics).
            path = f"ground_truth/{region}_{set}_fine_written.nc"
            path_train = f"ground_truth/{region}_train_fine_written.nc"
            if os.path.exists(path) and os.path.exists(path_train):
                dsgt = xr.open_dataset(f"ground_truth/{region}_{set}_fine_written.nc")
                u10_patch = dsgt["u10"]
                v10_patch = dsgt["v10"]
                dsgt_train = xr.open_dataset(f"ground_truth/{region}_train_fine_written.nc")
            else:
                raise ValueError("No processed ground truth data found!")

            # Training-split mean/std of the fine fields.  They standardize the fine
            # data of either split here and undo that scaling before the metrics.
            mu = float(dsgt_train.u10.mean())
            mv = float(dsgt_train.v10.mean())
            su = float(dsgt_train.u10.std())
            sv = float(dsgt_train.v10.std())

            print("U10 Mean, std", mu, su)
            print("v10 Mean, std", mv, sv)


            # Rebuild `dsgt` so that it holds the standardized fine fields.
            dsgt = xr.Dataset()
            u10_patch = standardize_field(u10_patch, mu, su)
            v10_patch = standardize_field(v10_patch, mv, sv)
            dsgt["u10"] = u10_patch
            dsgt["v10"] = v10_patch

            # Sanity print: ~0 / ~1 for the training split, slightly off for validation.
            print("U10 Mean, std", dsgt["u10"].mean(), dsgt["u10"].std())
            print("v10 Mean, std", dsgt["v10"].mean(), dsgt["v10"].std())

            coarse_paths = {
                "UV": resource_filename("DoWnGAN", "./data/interim_2000-10-01_to_2013-09-30.nc")
            }

            # Load ERA Interim
            coarse = xr.open_dataset(coarse_paths["UV"], engine="scipy").astype("float")
            # Organize lats in increasing order:
            coarse = coarse.sortby("latitude", ascending=True).rename({"longitude":"lon", "latitude":"lat"})

            if set == "validation":
                # Validation wind fields are standardized with statistics of the
                # complementary (training) time steps.
                # NOTE(review): `time_mask[0]` was forced to False above, so
                # `~time_mask` is True at index 0; these "train" statistics therefore
                # include the first time step, which the training branch excludes.
                coarse_u10 = coarse.u10[time_mask, low:up, l:r]
                coarse_v10 = coarse.v10[time_mask, low:up, l:r]
                coarse_u10_train = coarse.u10[~time_mask, low:up, l:r]
                coarse_v10_train = coarse.v10[~time_mask, low:up, l:r]
                cmu, cmv = coarse_u10_train.mean(), coarse_v10_train.mean()
                csu, csv = coarse_u10_train.std(), coarse_v10_train.std()
                coarse_u10 = standardize_field(coarse_u10, cmu, csu)
                coarse_v10 = standardize_field(coarse_v10, cmv, csv)
            else:
                coarse_u10 = coarse.u10[time_mask, low:up, l:r]
                coarse_v10 = coarse.v10[time_mask, low:up, l:r]

                # `xr_standardize_field` is an imported helper; presumably it
                # standardizes with the field's own statistics -- TODO confirm.
                coarse_u10 = xr_standardize_field(coarse_u10)#.chunk({"time": 250})
                coarse_v10 = xr_standardize_field(coarse_v10)#.chunk({"time": 250})

            # Random time indices drawn *with replacement* (np.random.randint), one per
            # available time step: some steps repeat and others are left out.  The same
            # indices are applied to the fine and the coarse tensors to keep them paired.
            randoms = np.random.randint(0, dsgt.Times.shape[0], dsgt.Times.shape[0])

            # Fine target tensor: u10 and v10 stacked along dim 1.
            fine_t = torch.stack([
                torch.from_numpy(dsgt.u10[randoms, ...].values).to("cpu"),
                torch.from_numpy(dsgt.v10[randoms, ...].values).to("cpu")
            ], dim=1)


            # Training-split covariates, used only to derive the statistics below.
            # NOTE(review): `train_coarse_mask` is opened but never read.
            train_coarse_sp = xr.open_dataset(f"/home/nannau/msc/netcdf_regions_gt/organized/covariates/train/{region}_surface_pressure.nc")
            train_coarse_mask = xr.open_dataset(f"/home/nannau/msc/netcdf_regions_gt/organized/covariates/train/{region}_land_sea_mask.nc")
            train_coarse_sf = xr.open_dataset(f"/home/nannau/msc/netcdf_regions_gt/organized/covariates/train/{region}_surface_friction.nc")
            train_coarse_cape = xr.open_dataset(f"/home/nannau/msc/netcdf_regions_gt/organized/covariates/train/{region}_cape.nc")
            train_coarse_geo = xr.open_dataset(f"/home/nannau/msc/netcdf_regions_gt/organized/covariates/train/{region}_geopotential_height.nc")
            # NOTE(review): the first two entries are `coarse_u10` / `coarse_v10` of the
            # *current* split (already standardized above), not of the training split.
            # For "validation", means[0:2] / std[0:2] are therefore validation statistics.
            train_cov_list = [coarse_u10, coarse_v10, train_coarse_sp.to_array()[0, ...], train_coarse_sf.to_array()[0, ...], train_coarse_geo.to_array()[0, ...], train_coarse_cape.to_array()[0, ...]]

            means = [x.mean() for x in train_cov_list]
            std = [x.std() for x in train_cov_list]

            # NOTE: the original author flagged the normalization of these covariates
            # as the suspected source of a problem.
            coarse_sp = xr.open_dataset(f"/home/nannau/msc/netcdf_regions_gt/organized/covariates/{set}/{region}_surface_pressure.nc")
            coarse_mask = xr.open_dataset(f"/home/nannau/msc/netcdf_regions_gt/organized/covariates/{set}/{region}_land_sea_mask.nc")
            coarse_sf = xr.open_dataset(f"/home/nannau/msc/netcdf_regions_gt/organized/covariates/{set}/{region}_surface_friction.nc")
            coarse_cape = xr.open_dataset(f"/home/nannau/msc/netcdf_regions_gt/organized/covariates/{set}/{region}_cape.nc")
            coarse_geo = xr.open_dataset(f"/home/nannau/msc/netcdf_regions_gt/organized/covariates/{set}/{region}_geopotential_height.nc")

            cov_list = [coarse_u10, coarse_v10, coarse_sp.to_array()[0, ...], coarse_sf.to_array()[0, ...], coarse_geo.to_array()[0, ...], coarse_cape.to_array()[0, ...]]

            # Original comment: "Do not add in this extra normalization step".
            # NOTE(review): the line below nevertheless applies it, so the wind fields
            # are standardized a second time (see the note on `train_cov_list`).
            cov_list = [standardize_field(x, means[i], std[i]) for i, x in enumerate(cov_list)]
            torch_list = [torch.from_numpy(ds.values) for ds in cov_list]

            # The land-sea mask is inserted unstandardized at list position 2, giving
            # seven stacked inputs along dim 1.
            torch_list.insert(2, torch.from_numpy(coarse_mask.to_array()[0, ...].values))
            coarse_t_loaded = torch.stack(torch_list, dim=1).float()[randoms, ...]

            # Sample indices split into chunks of int(N / 64) elements each (the last
            # chunk may be shorter); every chunk is one forward pass below.
            aranged = torch.split(torch.arange(0, coarse_t_loaded.size(0)), int(coarse_t_loaded.size(0)/(64)), dim=0)
            print("Chunk size: ", aranged[0].size())


            print("Coarse: ", coarse_t_loaded.size())
            print("Fine: ", fine_t.size())

            # run hash -> {metric name -> list of (min, mean, max), one per checkpoint}
            complete_hash_dict = {}
    #         for hc in hash_list:
            # `f` is the run label (key into `filters` / `padding`), `hc` its run hash.
            for f in runs[region].keys():
                hc = runs[region][f]
                long_mae = []
                long_mse = []
                long_wass =[]
                long_msssim = []
                hash_dict = {}
                # Checkpoints Generator_0 .. Generator_999 (and the matching Critic_i).
                for i in tqdm_notebook(range(1000)):
                    # NOTE: `i % 1 == 0` is always true, so every checkpoint is evaluated;
                    # raising the modulus would sub-sample checkpoints.
                    if i % 1 == 0:
                        # The model is loaded from the artifact directory, then the state
                        # dict stored in the same directory is loaded into it.
                        logged_model_g = f'/home/nannau/msc/Fall_2021/DoWnGAN/DoWnGAN/mlflow_experiments/{experiment_number}/{hc}/artifacts/Generator/Generator_{i}/'
                        G = mlflow.pytorch.load_model(logged_model_g)
                        state_dict = mlflow.pytorch.load_state_dict(logged_model_g)
                        G.load_state_dict(state_dict)

                        logged_model_c = f'/home/nannau/msc/Fall_2021/DoWnGAN/DoWnGAN/mlflow_experiments/{experiment_number}/{hc}/artifacts/Critic/Critic_{i}/'
                        C = mlflow.pytorch.load_model(logged_model_c)
                        state_dict = mlflow.pytorch.load_state_dict(logged_model_c)
                        C.load_state_dict(state_dict)


                        # Per-chunk metric values for this checkpoint.
                        lmae = []
                        lmse = []
                        lmsssim = []
                        lwass = []
            #             for fchunk, cchunk in zip(torch.split(fine_t, N_chunks, dim=0), torch.split(coarse_t_loaded, N_chunks, dim=0)):
                        for r in aranged:
            #             for _ in range(100):
            #                 r = np.random.randint(0, coarse_t_loaded.size(0), bsize)
            #                 Y = G(cchunk.to(device).float())
                            # Y: generator output for this chunk; X: matching fine truth.
                            Y = G(coarse_t_loaded[r, ...].to(device).float())
                            X = fine_t[r, ...].to(device)

                            # Critic score difference, mean(C(real)) - mean(C(fake)).
                            # For filtered runs the critic is fed each field minus its
                            # replication-padded box average; "hash_code_nf" runs feed
                            # the fields unchanged.
                            if f != "hash_code_nf":
                                C_real = torch.mean(C(X - filters[f](padding[f](X)).to(device)))
                                C_fake = torch.mean(C(Y - filters[f](padding[f](Y)).to(device)))
                                wass = C_real - C_fake

                            else:
                                C_real = torch.mean(C(X))
                                C_fake = torch.mean(C(Y))
                                wass = C_real - C_fake
                            # Stored as a 0-d CPU tensor (the other metrics store floats).
                            lwass.append(wass.detach().cpu())

                            # Undo the standardization in place (training-split fine
                            # statistics), so the content metrics below are computed on
                            # the un-standardized fields.  This must stay after the
                            # critic evaluation, which uses the standardized tensors.
                            Y[:, 0, ...] = Y[:, 0, ...]*su + mu
                            Y[:, 1, ...] = Y[:, 1, ...]*sv + mv

                            X[:, 0, ...] = X[:, 0, ...]*su + mu
                            X[:, 1, ...] = X[:, 1, ...]*sv + mv

            #                 assert Y.size() == fchunk.size()
                            # `content_loss`, `content_MSELoss` and `SSIM_Loss` are imported
                            # from DoWnGAN.losses; their results are recorded under the
                            # "mae", "mse" and "msssim" keys respectively.
                            mae = content_loss(
                                X.to(device),
            #                     fchunk.to(device),
                                Y.to(device),
                                device
                            ).item()
                            lmae.append(mae)

                            mse = content_MSELoss(
                                X.to(device),
            #                     fchunk.to(device),
                                Y.to(device),
                                device
                            ).item()
                            lmse.append(mse)

                            msssim = SSIM_Loss(
                                X.to(device),
            #                     fchunk.to(device),
                                Y.to(device),
                                device
                            ).item()
                            lmsssim.append(msssim)
                        # Drop the references before the next checkpoint is loaded.
                        # NOTE(review): `C` is not deleted here; it is only released when
                        # re-bound on the next iteration.
                        del G
                        del state_dict
                        # Summarize this checkpoint as (min, mean, max) over the chunks.
                        long_mae.append((np.min(lmae), np.mean(lmae), np.max(lmae)))
                        long_mse.append((np.min(lmse), np.mean(lmse), np.max(lmse)))
                        long_msssim.append((np.min(lmsssim), np.mean(lmsssim), np.max(lmsssim)))
                        long_wass.append((np.min(lwass), np.mean(lwass), np.max(lwass)))


                hash_dict["mae"] = long_mae
                hash_dict["mse"] = long_mse
                hash_dict["msssim"] = long_msssim
                hash_dict["wasserstein"] = long_wass

                complete_hash_dict[hc] = hash_dict

            # One row per checkpoint; columns are "<run hash>_<metric>_<min|mean|max>".
            df = pd.DataFrame()
            metric = ["mae", "mse", "msssim", "wasserstein"]
            for m in metric:
                for key in complete_hash_dict.keys():
                    df[f"{key}_{m}_min"] = [t[0] for t in complete_hash_dict[key][m]]
                    df[f"{key}_{m}_mean"] = [t[1] for t in complete_hash_dict[key][m]]
                    df[f"{key}_{m}_max"] = [t[2] for t in complete_hash_dict[key][m]]

            # for key in complete_hash_dict.keys():
            #     print(key)
            # Written to the current working directory, one file per region and split.
            df.to_csv(f"{region}_{set}_offline_metrics.csv")
            # Release the large tensors before the next region/split is loaded.
            del fine_t
            del coarse_t_loaded
            torch.cuda.empty_cache()


Cores: 16
florida
Masking these years: [2000, 2006, 2010]
U10 Mean, std -1.7083028554916382 4.247805118560791
v10 Mean, std -0.440981388092041 4.205788612365723
U10 Mean, std <xarray.DataArray 'u10' ()>
array(0.08260679, dtype=float32) <xarray.DataArray 'u10' ()>
array(1.0378276, dtype=float32)
v10 Mean, std <xarray.DataArray 'v10' ()>
array(-0.05707378, dtype=float32) <xarray.DataArray 'v10' ()>
array(1.0098454, dtype=float32)
Chunk size:  torch.Size([51])
Coarse:  torch.Size([3287, 7, 16, 16])
Fine:  torch.Size([3287, 2, 128, 128])


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

central
Masking these years: [2000, 2006, 2010]
U10 Mean, std 0.505122721195221 2.814021110534668
v10 Mean, std 0.07797596603631973 3.2136151790618896
U10 Mean, std <xarray.DataArray 'u10' ()>
array(-0.01178362, dtype=float32) <xarray.DataArray 'u10' ()>
array(1.009194, dtype=float32)
v10 Mean, std <xarray.DataArray 'v10' ()>
array(-0.0005574, dtype=float32) <xarray.DataArray 'v10' ()>
array(0.9753815, dtype=float32)
Chunk size:  torch.Size([51])
Coarse:  torch.Size([3287, 7, 16, 16])
Fine:  torch.Size([3287, 2, 128, 128])


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

west
Masking these years: [2000, 2006, 2010]
U10 Mean, std 0.9174613952636719 3.011970281600952
v10 Mean, std -0.1184549406170845 3.958988666534424
U10 Mean, std <xarray.DataArray 'u10' ()>
array(-0.0471669, dtype=float32) <xarray.DataArray 'u10' ()>
array(1.029438, dtype=float32)
v10 Mean, std <xarray.DataArray 'v10' ()>
array(0.06888004, dtype=float32) <xarray.DataArray 'v10' ()>
array(1.0693603, dtype=float32)
Chunk size:  torch.Size([51])
Coarse:  torch.Size([3287, 7, 16, 16])
Fine:  torch.Size([3287, 2, 128, 128])


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

florida
Masking these years: [2000, 2006, 2010]
U10 Mean, std -1.7083028554916382 4.247805118560791
v10 Mean, std -0.440981388092041 4.205788612365723
U10 Mean, std <xarray.DataArray 'u10' ()>
array(7.249134e-08, dtype=float32) <xarray.DataArray 'u10' ()>
array(1.0000006, dtype=float32)
v10 Mean, std <xarray.DataArray 'v10' ()>
array(-2.3660812e-07, dtype=float32) <xarray.DataArray 'v10' ()>
array(1.0000015, dtype=float32)
Chunk size:  torch.Size([245])
Coarse:  torch.Size([15704, 7, 16, 16])
Fine:  torch.Size([15704, 2, 128, 128])


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

central
Masking these years: [2000, 2006, 2010]
U10 Mean, std 0.505122721195221 2.814021110534668
v10 Mean, std 0.07797596603631973 3.2136151790618896
U10 Mean, std <xarray.DataArray 'u10' ()>
array(1.2905484e-07, dtype=float32) <xarray.DataArray 'u10' ()>
array(0.99999875, dtype=float32)
v10 Mean, std <xarray.DataArray 'v10' ()>
array(1.6700231e-09, dtype=float32) <xarray.DataArray 'v10' ()>
array(0.99999726, dtype=float32)
Chunk size:  torch.Size([245])
Coarse:  torch.Size([15704, 7, 16, 16])
Fine:  torch.Size([15704, 2, 128, 128])


  0%|          | 0/1000 [00:00<?, ?it/s]