# Code in PyTorch:

In [None]:
# Basics
from matplotlib import pyplot as plt
import matplotlib.path as mpath
import os
import sys
from os import listdir
from os.path import isfile, join
import numpy as np
import pandas as pd
from datetime import datetime
from datetime import timedelta
from tqdm import tqdm
from re import search
from math import cos,sin,pi
import random as rn

# xarray and cartopy plots
import xarray as xr
#import cf_units

import cartopy
import cartopy.crs as ccrs
import pyproj
from pyproj import Transformer
#import rasterio

# Google file system
#import gcsfs
#from google.cloud import storage

# ML
from scipy import ndimage
import torch
import tensorflow as tf 

"""
import keras
import tensorflow as tf 
from keras import backend as K

from tensorflow.python.keras.backend import set_session

from keras.models import load_model, Model
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, UpSampling2D, Conv2DTranspose, Reshape, concatenate, BatchNormalization, Activation
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from keras.models import Sequential
from netCDF4 import Dataset"""
#from netCDF4 import Dataset
#import h5netcdf
#import netCDF4

from sklearn.model_selection import train_test_split
import setGPU

# Import custom scripts
sys.path.append('../')
#from process_pangeo import *
#from GC_scripts import *
#from processRCM import *
#from reprojectionFunctions import *
#from MakeInputFunctions import *
#from model import *

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Report the CUDA setup. The device queries below raise when no CUDA device is
# usable, so they are only issued when CUDA is available; on a CPU-only
# machine `cuda_id` is set to None instead of aborting the cell.
print(f"Is CUDA supported by this system? {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")

if torch.cuda.is_available():
    # Storing ID of current CUDA device
    cuda_id = torch.cuda.current_device()
    print(f"ID of current CUDA device:{cuda_id}")
    print(f"Name of current CUDA device:{torch.cuda.get_device_name(cuda_id)}")
else:
    cuda_id = None
    print("No CUDA device available: running on CPU.")

In [None]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Load configuration file:
from config import *
from helperFunctions import *

%load_ext autoreload
%autoreload 2

## Load data:

In [None]:
# Open the GCM-like dataset, either after downloading it or directly from the
# cluster path. `downloadFromGC`, `pathGC`, `fileGCMLike` and `pathCluster`
# are not defined in this cell (presumably provided by `config` — confirm).
if downloadFromGC:
    # Download, open, then delete the local copy.
    # NOTE(review): the file is removed right after open_dataset(); confirm
    # the data remains readable afterwards on the target platform.
    downloadFileFromGC(pathGC, '', fileGCMLike)
    GCMLike = xr.open_dataset(fileGCMLike)
    os.remove(fileGCMLike)
else:
    # Plain string concatenation: presumably pathCluster already ends with a
    # path separator — verify in config.
    GCMLike = xr.open_dataset(pathCluster+fileGCMLike)
print(GCMLike.dims)
# Bare expression so the notebook displays the dataset summary.
GCMLike

In [None]:
# Plot the first time step of SMB twice: once on a south-polar stereographic
# map (top) and once on a plate-carree map (bottom). In both cases the data
# are declared to be in south-polar stereographic coordinates via `transform`.
fig = plt.figure(figsize=(20, 10))
ax = plt.subplot(2, 1, 1, projection=ccrs.SouthPolarStereo())
GCMLike.SMB.isel(time=0).plot(
    x="x", ax=ax, transform=ccrs.SouthPolarStereo(), add_colorbar=False
)

ax.coastlines("10m", color="black")
ax.gridlines(draw_labels=True)


# Same field, drawn on a PlateCarree axis.
ax = plt.subplot(2, 1, 2, projection=ccrs.PlateCarree())
GCMLike.SMB.isel(time=0).plot(
    x="x", ax=ax, transform=ccrs.SouthPolarStereo(), add_colorbar=False
)
ax.coastlines("10m", color="black")
ax.gridlines(draw_labels=True)

In [None]:
# Open the target dataset, either after downloading it or directly from the
# cluster path (same pattern as for the GCM-like input).
if downloadFromGC:
    # Download, open, then delete the local copy.
    downloadFileFromGC(pathGC, '', fileTarget)
    target_dataset = xr.open_dataset(fileTarget)
    os.remove(fileTarget)
else:
    target_dataset = xr.open_dataset(pathCluster+fileTarget)
print(target_dataset.dims)

# Cut a small part on the right that is not too important.
# max_x and max_y are both (N / 2) * 35 * 1000; presumably N grid cells of
# 35 km expressed in metres — TODO confirm against cutBoundaries.
N = 160
max_x = (N / 2) * 35 * 1000
max_y = (N / 2) * 35 * 1000

target_dataset = cutBoundaries(target_dataset, max_x, max_y)
print("New target dimensions:", target_dataset.dims)

## Create input and target:

**Z**:
- (ignored for now) External forcings also given to the RCM → total concentration of greenhouse gases, and solar and ozone forcings
- Cosine and sine vectors encoding the time of year (the code below uses the month index, 12 steps per year)
- Daily (here: one value per time step) time series of the spatial mean and standard deviation of each $X_{i,j,m}$ (needed because normalising the 2D variables removes temporal information)

**X**: 
SHAPE `[nbmonths, x, y, nb_vars]`

For PyTorch, X needs to be transposed to channels-first: `[nbmonths, nb_vars, x, y]`

In [None]:
def input_maker(
    GCMLike,
    size_input_domain=16,  # size of domain, format: 8,16,32, must be defined in advance
    stand=True,  # standardization
    seas=True,  # put a cos, sin vector to control the season, format : bool
    means=True,  # add the mean of the variables raw or stdz, format : bool
    stds=True,  # add the std of the variables raw or stdz, format : bool
    resize_input=True,  # resize input to size_input_domain
    region="Larsen",  # region of interest
    print_=True
):
    """Build the 2D input X and the 1D input Z from a GCM-like dataset.

    Parameters
    ----------
    GCMLike : dataset exposing ``drop_vars``, ``data_vars`` and item access
        (an xarray Dataset in this notebook). Must contain an ``SMB``
        variable, which is removed because it is the target.
    size_input_domain : int or None
        Side length passed to the ``resize`` helper. When None, the larger of
        ``highestPowerof2`` applied to the two spatial sizes is used.
    stand : bool
        Pass the 2D input through the ``standardize`` helper.
    seas : bool
        Append a cosine and a sine channel with a 12-step period to Z.
    means, stds : bool
        Append the per-time-step spatial mean / standard deviation of every
        variable to Z. The two flags are independent (earlier versions only
        added the statistics when both were set).
    resize_input : bool
        Resize the 2D input to ``size_input_domain`` with the ``resize`` helper.
    region : str
        Region handed to ``createLowerInput``; ``"whole antarctica"`` (any
        letter case) keeps the full domain.
    print_ : bool
        Print a summary of the arrays that are built.

    Returns
    -------
    INPUT_2D_ARRAY : array laid out as [time, x, y, nb_vars]
    INPUT_1D_ARRAY : array of shape [time, 1, 1, nb_1d_features]; the last
        axis has length 0 when ``means``, ``stds`` and ``seas`` are all False.
    VAR_LIST : list of the variable names, in the order of the last axis of
        INPUT_2D_ARRAY.
    """
    # Case-insensitive so the sentinel is spelled the same way for inputs and
    # targets ("whole antarctica" / "Whole antarctica").
    if str(region).lower() != "whole antarctica":
        DATASET = createLowerInput(GCMLike, region=region, Nx=48, Ny=25, print_=False)
    else:
        DATASET = GCMLike

    # ---- 2D input array, shape [nbmonths, x, y, nb_vars] ----

    # Remove target variable from DATASET (drop_vars replaces the deprecated
    # Dataset.drop):
    DATASET = DATASET.drop_vars(["SMB"])

    VAR_LIST = list(DATASET.data_vars)
    nb_vars = len(VAR_LIST)

    # Stack the variables on a leading axis, then move that axis to the end.
    INPUT_2D_bf = np.transpose(
        np.asarray([DATASET[name].values for name in VAR_LIST]), [1, 2, 3, 0]
    )

    # If no size is given, take the larger of the two values returned by
    # highestPowerof2 for the spatial dimensions.
    if size_input_domain is None:
        size_input_domain = np.max(
            [
                highestPowerof2(INPUT_2D_bf.shape[1]),
                highestPowerof2(INPUT_2D_bf.shape[2]),
            ]
        )

    if resize_input:
        # resize to size_input_domain
        INPUT_2D = resize(INPUT_2D_bf, size_input_domain, size_input_domain, print_)
    else:
        INPUT_2D = INPUT_2D_bf

    # Standardisation only affects the returned 2D array; the statistics in Z
    # below are computed from the non-standardised INPUT_2D.
    if stand:
        INPUT_2D_ARRAY = standardize(INPUT_2D)
    else:
        INPUT_2D_ARRAY = INPUT_2D

    if print_:
        print("Parameters:\n -------------------")
        print("Size of input domain:", size_input_domain)
        print("Region:", region)
        print("\nCreating 2D input X:\n -------------------")
        print(f"Number of variables: {nb_vars}")
        print(f"Variables: {VAR_LIST}")
        print(f"INPUT_2D shape: {INPUT_2D_ARRAY.shape}")
        print("\nCreating 1D input Z:\n -------------------")

    # ---- 1D input array: means, stds and season encoding, as requested ----

    n_time = INPUT_2D.shape[0]
    n_channels = INPUT_2D.shape[3]
    stat_shape = (n_time, 1, 1, n_channels)

    INPUT_1D = []
    if means:
        INPUT_1D.append(INPUT_2D.mean(axis=(1, 2)).reshape(stat_shape))
    if stds:
        INPUT_1D.append(INPUT_2D.std(axis=(1, 2)).reshape(stat_shape))
    if (means or stds) and print_:
        print(f"SpatialMean/std shape: {stat_shape}")

    if seas:
        months = 12
        # One cos/sin value per position in the 12-step cycle, indexed by
        # (time step modulo 12). Unlike tiling whole years, this also works
        # when the number of time steps is not a multiple of 12.
        # Assumes the first time step is position 0 of the cycle — TODO confirm.
        cycle_pos = np.arange(n_time) % months
        cos_cycle = np.array([cos(2 * i * pi / months) for i in range(months)])
        sin_cycle = np.array([sin(2 * i * pi / months) for i in range(months)])
        cosvect = cos_cycle[cycle_pos].reshape(n_time, 1, 1, 1)
        sinvect = sin_cycle[cycle_pos].reshape(n_time, 1, 1, 1)

        INPUT_1D.append(cosvect)
        INPUT_1D.append(sinvect)
        if print_:
            print(f"Cos/sin encoding shape: {cosvect.shape}")

    if INPUT_1D:
        INPUT_1D_ARRAY = np.concatenate(INPUT_1D, axis=3)
    else:
        # np.concatenate rejects an empty list; return an explicit
        # zero-feature array instead of failing.
        INPUT_1D_ARRAY = np.empty((n_time, 1, 1, 0))
    if print_:
        print(f"INPUT_1D shape: {INPUT_1D_ARRAY.shape}")

    DATASET.close()
    return INPUT_2D_ARRAY, INPUT_1D_ARRAY, VAR_LIST

In [None]:
# Download GMC like RCM input data:
# Load the GCM-like RCM input data (cluster copy unless downloadFromGC is set):
downloadFromGC = False
fileGCMLike = "MAR(ACCESS1-3)-stereographic_monthly_GCM_like.nc"
if not downloadFromGC:
    GCMLike = xr.open_dataset(pathCluster+fileGCMLike)
else:
    downloadFileFromGC(pathGC, "", fileGCMLike)
    GCMLike = xr.open_dataset(fileGCMLike)
    os.remove(fileGCMLike)

region = "Larsen"
size_input_domain = 32

# Standardised and resized inputs: these are the ones fed to the model.
i2D, i1D, VAR_LIST = input_maker(
    GCMLike,
    size_input_domain=size_input_domain,
    stand=True,
    seas=True,
    means=True,
    stds=True,
    resize_input=True,
    region=region,
)
inputs_2D = [i2D]
inputs_1D = [i1D]

# Raw (non-standardised, non-resized) inputs, kept for plotting only.
i2D_ns, i1D_ns, var_list = input_maker(
    GCMLike,
    size_input_domain=size_input_domain,
    stand=False,
    seas=True,
    means=True,
    stds=True,
    resize_input=False,
    region=region,
    print_=False,
)
inputs_2D_ns = [i2D_ns]
inputs_1D_ns = [i1D_ns]

In [None]:
def target_maker(
    target_dataset, 
    region="Larsen", # region of interest
    resize=True, # resize to target_size
    target_size=None # if none provided and resize true, set to min highest power of 2
):
    """Build the target array (SMB) and its time axis.

    Parameters
    ----------
    target_dataset : dataset exposing ``SMB.values`` and ``time.values``
        (an xarray Dataset in this notebook).
    region : str
        Region handed to ``createLowerTarget``; ``"whole antarctica"`` (any
        letter case) keeps the full domain.
    resize : bool
        Resize the target to ``target_size`` x ``target_size`` with the
        module-level ``resize`` helper.
    target_size : int or None
        Side length used when resizing. When None, the smaller of
        ``highestPowerof2`` applied to the two spatial sizes is used.

    Returns
    -------
    full_target : array of shape [time, x, y, 1] (spatial sizes as produced
        by the ``resize`` helper when resizing is requested).
    target_times : one-element list holding ``target_dataset.time.values``.
    """
    # Case-insensitive so the sentinel is spelled the same way for inputs and
    # targets ("whole antarctica" / "Whole antarctica").
    if str(region).lower() != "whole antarctica":
        lowerTarget = createLowerTarget(target_dataset, region=region, Nx=64, Ny=64, print_=False)
        targetArray = lowerTarget.SMB.values
    else:
        targetArray = target_dataset.SMB.values

    # Add a trailing channel axis: [time, x, y] -> [time, x, y, 1].
    targetArray = targetArray.reshape(
        targetArray.shape[0], targetArray.shape[1], targetArray.shape[2], 1
    )

    if resize:
        if target_size is None:
            # Smaller of the two values returned by highestPowerof2.
            target_size = np.min(
                [
                    highestPowerof2(targetArray.shape[1]),
                    highestPowerof2(targetArray.shape[2]),
                ]
            )
        # The boolean parameter `resize` shadows the module-level `resize`
        # helper inside this function, so calling `resize(...)` here would
        # call a bool. Fetch the helper from the module namespace instead.
        resize_helper = globals()["resize"]
        target_SMB = resize_helper(targetArray, target_size, target_size)
    else:
        target_SMB = targetArray

    # Kept as one-element lists / concatenate so the return types match what
    # callers already consume (np.concatenate over target_times downstream).
    target_times = [target_dataset.time.values]
    full_target = np.concatenate([target_SMB], axis=0)

    return full_target, target_times

In [None]:
# Build the target for the Larsen region without resizing.
full_target, target_times = target_maker(
    target_dataset, region="Larsen", resize=False
)
# Full target to model
target_time = np.concatenate(target_times, axis=0)
# NOTE(review): these are the dataset's "x" / "y" coordinates; the lon/lat
# names suggest geographic degrees, which the visible code does not establish
# — confirm the units.
target_lon = target_dataset["x"]
target_lat = target_dataset["y"]

In [None]:
# Full input to model
full_input=[np.concatenate(inputs_2D,axis=0),np.concatenate(inputs_1D,axis=0)]
full_input_ns=[np.concatenate(inputs_2D_ns,axis=0),np.concatenate(inputs_1D_ns,axis=0)]

print("Shapes of targets and inputs:\n---------------------------")
print("Target:", full_target.shape)
print("Input 2D:", full_input[0].shape)
print("Input 1D:", full_input[1].shape)

### Change to tensors:

In [None]:
from torch.utils.data import Dataset
from torch.utils.data import TensorDataset

from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, random_split

In [None]:
# Convert inputs and target to tensors, moving the last (variable) axis to
# position 1: [time, x, y, vars] -> [time, vars, x, y].
# NOTE(review): torch.tensor keeps the NumPy dtype; if the arrays are float64
# the tensors will be double precision — confirm this matches the model.
X = torch.tensor(full_input[0].transpose(0, 3,1,2))
Z = torch.tensor(full_input[1].transpose(0, 3,1,2))
Y = torch.tensor(full_target.transpose(0, 3,1,2))
# Bare expression so the notebook displays the three shapes.
X.shape, Z.shape, Y.shape

In [None]:
# Seed for the train/validation split below.
seed = 123

# Each sample is the tuple (X[i], Z[i], Y[i]) — in that order.
dataset = TensorDataset(X, Z, Y)
loader = DataLoader(
    dataset,
    batch_size=32
)

# Example:
(x, z, y) = next(iter(loader))
print(x.shape, z.shape, y.shape)
    
# 2. Split into train / validation partitions
val_percent = 0.2
n_val = int(len(dataset) * val_percent)
# The training size is the remainder, so the two sizes always sum to len(dataset).
n_train = len(dataset) - n_val
# A seeded generator makes the random split reproducible.
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(seed))

In [None]:
# 3. Create data loaders
batch_size = 32
loader_args = dict(batch_size=batch_size)
train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_set, shuffle=False, batch_size=batch_size)

# Display image and label.
train_X, train_Y, train_Z = next(iter(train_loader))
print(f"2D Feature batch shape: {train_X.size()}")
print(f"1D Feature batch shape: {train_Z.size()}")
print(f"Labels batch shape: {train_Y.size()}")

In [None]:
# Prepare one batch of inputs / targets as NumPy arrays for plotting.
# The original cell mixed 8-space and 4-space indentation at top level, which
# is an IndentationError; everything is now at a consistent level.
sample2dtrain_ = train_X.cpu().detach().numpy()
sampletarget_ = train_Y.cpu().detach().numpy()

# NOTE(review): `randTime` is not defined in the visible cells — confirm where
# it comes from before running.
dt = pd.to_datetime([GCMLike.time.isel(time=randTime).values])
time = str(dt.date[0])

# NOTE(review): `REGION` is not defined in the visible cells (other cells use
# lowercase `region`); presumably it comes from `config` — confirm.
if REGION == "Larsen":
    sample2dtrain_ = resize(sample2dtrain_, 25, 48, print_=False)
else:
    sample2dtrain_ = resize(sample2dtrain_, 25, 90, print_=False)

# NOTE(review): this reshape only succeeds when the array holds exactly
# shape[1] * shape[0] elements; for a 4-D batch that is not the case unless
# the two trailing dimensions have size 1 — confirm whether a single sample
# should be selected first.
sampletarget_ = sampletarget_.reshape(
    sampletarget_.shape[1], sampletarget_.shape[0], 1
)
M = 2
# Colour range used by the target plot.
vmin = np.min(sampletarget_)
vmax = np.max(sampletarget_)

In [None]:
# Draw the prepared input sample on a south-polar stereographic axis.
# NOTE(review): the positional argument 4 is passed straight to plotTrain;
# its meaning is not visible here — confirm against the helper.
ax = plt.subplot(1, 1, 1, projection=ccrs.SouthPolarStereo())
plotTrain(GCMLike, sample2dtrain_, 4, ax, time, VAR_LIST, region='Larsen')

In [None]:
# Draw the prepared target sample on a south-polar stereographic axis, passing
# the min / max of the sample as vmin / vmax.
ax = plt.subplot(1, 1, 1, projection=ccrs.SouthPolarStereo())
plotTarget(target_dataset, sampletarget_, ax, vmin, vmax, region='Larsen')

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset that returns each sample as a dictionary.

    Holds three tensors indexed along their first axis and returns, for an
    index, the keys ``'X'`` (2D input), ``'Z'`` (1D input) and ``'target'``.
    Each tensor is indexed with three trailing full slices, so all three are
    expected to be 4-dimensional.
    """

    def __init__(self, X, Z, Y):
        self.X, self.Z, self.target = X, Z, Y

    def __len__(self):
        # The sample count is taken from the target's first dimension.
        return len(self.target)

    def __getitem__(self, idx):
        fields = (('X', self.X), ('Z', self.Z), ('target', self.target))
        return {key: tensor[idx, :, :, :] for key, tensor in fields}
    
# Wrap the tensors in the dict-returning dataset (rebinds `dataset`).
# NOTE(review): `train_set` / `val_set` were split from the earlier
# TensorDataset and no visible cell re-splits this object — confirm whether it
# is meant to replace them.
dataset = CustomDataset(X,Z,Y)

In [None]:
# 3. Create data loaders
batch_size = 32
loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

In [None]:
# Display image and label.
train_X, train_Z, train_Y = next(iter(train_loader))
print(f"Feature batch shape: {train_X.size()}")
print(f"Labels batch shape: {train_Z.size()}")
print(f"Labels batch shape: {train_Y.size()}")