In [None]:
import xarray as xr
import numpy as np
import importlib
from matplotlib import pyplot as plt
import DA_core as DA
from glob import glob
import torch.utils.data as Data
from torch import optim
import torch
from torchsummary import summary
import ML_core as ML
from numpy.random import default_rng
import os

# Reproducibility: the DataLoader shuffles (shuffle=True) and the U-Net weights are randomly
# initialised, so seed both numpy and torch before anything stochastic runs.
SEED = 0
rng = default_rng(SEED)
torch.manual_seed(SEED)

# Root directory for DA experiment output; DA_core reads from and saves to the same place.
data_dir = './data'
# data_dir='/net2/fnl/PyQG/data'  # alternative (absolute) location used on another machine
DA.read_data_dir = data_dir
DA.save_data_dir = data_dir

# contourf keyword presets (colour map, symmetric contour levels) for B and q plots
B_ens_kws = {'cmap': 'bwr', 'levels': np.linspace(-2.5E-11, 2.5E-11, 26), 'extend': 'both'}
B_ens_kws1 = {'cmap': 'bwr', 'levels': np.linspace(-0.25E-11, 0.25E-11, 26), 'extend': 'both'}
q_kws = {'cmap': 'bwr', 'levels': np.linspace(-3.4E-5, 3.4E-5, 18), 'extend': 'both'}

In [None]:
# Pick up edits to ML_core.py without restarting the kernel (reload returns the same module object).
ML = importlib.reload(ML)

In [None]:
# Show GPU status; the shell exit status is displayed as the cell output (non-zero if nvidia-smi is missing).
nvidia_smi_status = os.system('nvidia-smi')
nvidia_smi_status

In [None]:
# Report whether PyTorch can see a CUDA device in this kernel.
cuda_available = torch.cuda.is_available()
cuda_available

In [None]:
# Channel indices handed to ML.Dataset and ML.Unet: in_ch selects the input channels,
# out_ch the target channels.
in_ch = [0, 1]
out_ch = [0, 1, 2]

# Configuration of the DA experiment whose output is used as training data;
# the keys are passed straight through as keyword arguments to DA.DA_exp.
DA_paras = dict(
    nens=80,
    DA_method='EnKF',
    Nx_DA=32,
    Nx_truth=128,
    obs_freq=10,
    obs_err=[1, -5, 5, -7],
    DA_freq=10,
    save_B=False,
    nobs=[50, 50],
    R_W=100,
    inflate=[1, 0.5],
)
DA_exp = DA.DA_exp(**DA_paras)
print(DA_exp.file_name())
# Observations are not needed for training: obs_ds = DA_exp.read_obs()

In [None]:
# Read ensemble-mean q and ensemble covariance samples B for the selected DA experiment.
mean_ds = DA_exp.read_mean().load()
print(mean_ds.q.shape)
# if DA_exp.nens>1:
#     std_ds=DA_exp.read_std()

# Covariance samples for all analysis times and model grid points (opened lazily).
B_ens_ds = xr.open_dataset('{}/training/{}/B_sample.nc'.format(data_dir, DA_exp.file_name()))
print(B_ens_ds.B_ens.shape)

# Normalisation statistics consumed by ML.Dataset.
# NOTE(review): the std cell further down writes its statistics to ./ML/<exp>/std_<exp>.nc,
# not to this file; confirm which one training is meant to use.
ml_std_ds = xr.open_dataset('{}/std_{}.nc'.format(data_dir, DA_exp.file_name()))
print(ml_std_ds)

# Half-width of the covariance stencil stored along x_d (x_d has length 2*B_R+1).
B_R = int((len(B_ens_ds.x_d) - 1) / 2)
B_size = 16
B_start = 0

# Days (indices into mean_ds.time) at which analyses are used for training.
DA_days = slice(9, 365, DA_exp.DA_freq)
# Index of the first analysis in B_sample.nc: day d maps to (d - DA_freq + 1) / DA_freq.
DA_it_start = int((DA_days.start - DA_exp.DA_freq + 1) / DA_exp.DA_freq)
# Size the B slice by the number of days DA_days actually selects, so q and B stay the same
# length even when DA_days.stop is not one past an analysis day.
n_DA = len(range(DA_days.start, DA_days.stop, DA_days.step))
DA_it = slice(DA_it_start, DA_it_start + n_DA)
print(DA_days, DA_it)
i_x = np.arange(0, DA_exp.Nx_DA)
i_y = np.arange(0, DA_exp.Nx_DA)

In [None]:
# Select the training period once and reuse the selections for shapes and both datasets.
q_sel = mean_ds.q.isel(time=DA_days)
B_sel = B_ens_ds.B_ens.isel(time=DA_it, y=i_y, x=i_x)
B_shape = B_sel.shape
print(B_shape)
print(q_sel.shape)
# q (inputs) and B (targets) must describe the same analysis times, otherwise samples are misaligned.
assert q_sel.sizes['time'] == B_sel.sizes['time'], (q_sel.sizes['time'], B_sel.sizes['time'])

# Use named dimensions rather than positional shape entries so the counts do not depend on dim order.
B_nt = B_sel.sizes['time']
B_ny = B_sel.sizes['y']
B_nx = B_sel.sizes['x']
B_total = B_nt * B_ny * B_nx
print(B_total)

# 80/20 split over contiguous sample indices (how an index maps to (time, y, x) is defined
# by ML.Dataset). Swap in the permutation below for a random split.
n_train = int(B_total * 0.8)
# rngs=rng.permutation(B_total)
rngs = np.arange(B_total)
partition = {'train': rngs[0:n_train], 'valid': rngs[n_train:]}

train_ds = ML.Dataset(q_sel, DA_exp.Nx_DA, B_sel, i_y, i_x,
                      partition['train'], ml_std_ds.q_std.data, ml_std_ds.B_std.data,
                      in_ch, out_ch, B_size=B_size, B_start=B_start)
valid_ds = ML.Dataset(q_sel, DA_exp.Nx_DA, B_sel, i_y, i_x,
                      partition['valid'], ml_std_ds.q_std.data, ml_std_ds.B_std.data,
                      in_ch, out_ch, B_size=B_size, B_start=B_start)

params = {'batch_size': 16, 'num_workers': 1, 'shuffle': True}
training_generator = torch.utils.data.DataLoader(train_ds, **params)
validation_generator = torch.utils.data.DataLoader(valid_ds, **params)

In [None]:
# Level-wise standard deviations of the covariance samples B and of the ensemble-mean q over the
# training period, saved next to the model checkpoints for normalising network inputs/targets.
# NOTE(review): training above reads ./data/std_<exp>.nc, not the file written here.
nlev = mean_ds.sizes['lev']

B_train = B_ens_ds.B_ens.isel(time=DA_it)
B_std = np.empty((nlev, nlev))
for i in range(nlev):
    for j in range(i, nlev):
        B_std[i, j] = np.std(B_train.isel(lev=i, lev_d=j))
        # Lower triangle is filled by symmetry rather than recomputed (assumes the
        # (lev=j, lev_d=i) block has the same spread as (lev=i, lev_d=j) — TODO confirm).
        B_std[j, i] = B_std[i, j]
print(B_std)

q_train = mean_ds.q.isel(time=DA_days)
q_std = np.zeros((nlev, 1))
for i in range(nlev):
    q_std[i] = np.std(q_train.isel(lev=i))
print(q_std)

# NOTE(review): B_std uses 'lev' for both dimensions; xarray warns about duplicate dim names.
B_da = xr.DataArray(B_std, coords=[mean_ds.lev, mean_ds.lev])
q_da = xr.DataArray(q_std[:, 0], coords=[mean_ds.lev])
std_ds = xr.Dataset({'B_std': B_da, 'q_std': q_da})

# to_netcdf does not create parent directories.
std_dir = './ML/{}'.format(DA_exp.file_name())
os.makedirs(std_dir, exist_ok=True)
std_ds.to_netcdf('{0}/std_{1}.nc'.format(std_dir, DA_exp.file_name()))

In [None]:
# Build the U-Net on the training device.
# For GPU training use: device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
model = ML.Unet(in_ch=len(in_ch), out_ch=len(out_ch))
model = model.to(device)
print(device)

# Keras-like model summary. torchsummary's `device` argument defaults to "cuda" and then builds
# CUDA inputs whenever CUDA is available, which fails against a CPU model, so pass it explicitly.
# NOTE(review): torchsummary only distinguishes "cuda"/"cpu", so a non-default GPU such as
# cuda:1 may still mismatch the summary's input tensors.
summary(model, input_size=(len(in_ch), train_ds.B_size, train_ds.B_size), device=device.type)

In [None]:
# Training objective: mean squared error (MSELoss default reduction='mean'), minimised with Adam.
optimizer = optim.Adam(params=model.parameters(), lr=2e-3)
criterion = torch.nn.MSELoss()

In [None]:
# Train the U-Net in double precision and checkpoint the weights after every epoch into
# ./ML/<experiment>/. Set start_epoch > 0 to resume from the checkpoint saved after that epoch.
model = model.double()
n_epochs = 30  # last epoch to train; epochs run from start_epoch+1 to n_epochs inclusive
validation_loss = list()
train_loss = list()
start_epoch = 0


def checkpoint_path(epoch):
    """Return the path of the U-Net weights saved after `epoch` for this experiment/config."""
    return './ML/{}/unet_epoch{}_in{}_out{}_B{}_{}.pt'.format(
        DA_exp.file_name(), epoch, ''.join(map(str, in_ch)), ''.join(map(str, out_ch)),
        B_size, DA_exp.file_name())


# torch.save does not create parent directories.
os.makedirs('./ML/{}'.format(DA_exp.file_name()), exist_ok=True)

if start_epoch > 0:
    # Resume from the same per-experiment directory the checkpoints are written to.
    model_file = checkpoint_path(start_epoch)
    print(model_file)
    model.load_state_dict(torch.load(model_file, map_location=device))
    # NOTE(review): only the weights are restored; the optimizer state is not checkpointed.

for epoch in range(start_epoch + 1, n_epochs + 1):
    train_loss.append(ML.train_model(model, criterion, training_generator, optimizer, device))
    # validation_loss.append(ML.test_model(model, criterion, validation_generator, optimizer, device))
    torch.save(model.state_dict(), checkpoint_path(epoch))

In [None]:
# Loss curves against epoch number; the validation curve is drawn only if it was recorded.
fig, ax = plt.subplots()
epochs = np.arange(start_epoch + 1, start_epoch + 1 + len(train_loss))
ax.plot(epochs, train_loss, 'b', label='training loss')
if validation_loss:
    ax.plot(epochs[:len(validation_loss)], validation_loss, 'r', label='validation loss')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.legend();