In [None]:
#This takes the output from the ConvLSTM model, applies a randomly assigned "track" to the data, and then recursively calculates the output until there is a string of 30 days

In [None]:
# These are the versions used in Scott's paper.
# Every package is pinned to an exact version so the install is repeatable.
# NOTE(review): `!pip3` runs whichever pip3 is first on PATH, which may not be the
# environment of the running kernel -- confirm, or consider the `%pip` magic instead.
!pip3 install keras==2.6
!pip3 install TensorFlow==2.6.1
!pip3 install TensorFlow-Addons==0.14.0
!pip3 install numpy==1.21.6


In [None]:
#This randomly picks an observation track from the training inputs and uses it to mask the output
import numpy as np
def faketrack(datadir,output,n_train=2000,batch_size=25,n_t=30,rng=None):
    """Mask model output with randomly chosen observation tracks.

    For every sample in ``output`` a random training batch file, a random
    sample inside that batch and a random day are drawn. The pixels where
    channel 0 of that input frame is non-zero form a boolean mask, and the
    sample is multiplied by this mask so that only "observed" pixels keep
    their value.

    Parameters
    ----------
    datadir : str
        Path prefix. Files are read from
        ``datadir + 'training/batch{i}_invar.npy'``, so the prefix must
        already end with a path separator.
    output : np.ndarray
        Array whose first axis indexes samples. After ``np.squeeze`` it must
        broadcast against the stacked 2-D masks (one mask per sample).
    n_train : int
        Exclusive upper bound for the random batch index.
    batch_size : int
        Exclusive upper bound for the random sample index within a batch.
    n_t : int
        Exclusive upper bound for the random day index.
    rng : np.random.Generator, optional
        Source of randomness. When omitted the global ``np.random`` state is
        used (seed it with ``np.random.seed`` for reproducible masks).

    Returns
    -------
    np.ndarray
        ``np.squeeze(output)`` multiplied by the masks, with a new axis of
        length 1 inserted at position 1.
    """
    # Same three draws, in the same order, whichever source is used.
    draw = np.random.randint if rng is None else rng.integers

    tracks=[]
    for _ in range(output.shape[0]):
        batchno=draw(n_train)
        sampleno=draw(batch_size)
        day=draw(n_t)
        # Memory-map the file: only one 2-D slice of the batch is needed, so
        # there is no reason to read the whole array from disk.
        batch=np.load(datadir+f'training/batch{batchno}_invar.npy',mmap_mode='r')
        tracks.append(np.asarray(batch[sampleno,day,:,:,0])!=0)
    tracks=np.stack(tracks)
    return np.expand_dims(np.squeeze(output)*tracks,1)

In [None]:
from src.generators import *
from src.models import *
from src.losses import *
# --- Experiment configuration -------------------------------------------------
# n_t is passed to the model constructor and, together with n_train, forms the
# experiment name that selects the weights file.
n_t = 30
n_train = 1000
experiment_name = 'convlstm_sla_future_' + f'{n_t}days_{n_train}samples'

# --- Normalisation statistics (read from the current working directory) -------
stats = tuple(
    np.load(stat_file)
    for stat_file in (
        "gs_mean_ssh_future.npy",
        "gs_std_ssh_future.npy",
        "gs_mean_sst_future.npy",
        "gs_std_sst_future.npy",
    )
)
mean_ssh, std_ssh, mean_sst, std_sst = stats

# --- Data and weight locations ------------------------------------------------
datadir = '/home/jovyan/pre-processed-future-fixed/'
val = 'validation/'
model_weights_dir = '/home/jovyan/deep-learning-ssh-mapping-JAMES-paper/src/model_weights_future/' + experiment_name + '.h5'

# --- Build the model and restore the trained weights --------------------------
model = create_ConvLSTM_SLA(n_t, one_output=True)
model.load_weights(model_weights_dir)

In [None]:
#This repeatedly runs the model on the data
n_val=10

# Read channel 0 of every validation input batch and join the batches along the
# sample axis.
invar=np.concatenate(
    [np.load(datadir+val+f'batch{ID}_invar.npy')[:,:,:,:,0] for ID in range(n_val)],
    axis=0,
)

# Roll the model forward n_t times: each new prediction is masked with a random
# track and appended to the input window while the oldest frame is dropped.
prediction=[]
for day in range(n_t):
    latest=model.predict(invar)
    prediction.append(latest)
    masked_frame=faketrack(datadir,latest,n_train=1000)
    invar=np.concatenate([invar[:,1:,:],masked_frame],axis=1)

# One entry per roll-out step, stacked on a new axis 1.
prediction=np.stack(prediction,axis=1)

In [None]:
#This loads the output data
batch_size=25

# Zero-padded container: each batch file may hold fewer than 10000 points per
# day, so only the first out.shape[2] slots are filled.
outvar=np.zeros((n_val,batch_size,n_t,10000,3))
for ID in range(n_val):
    out=np.load(datadir+f'validation/batch{ID}_outvar.npy')
    outvar[ID,:,:,:out.shape[2],:]=out

# Merge the (file, sample) axes into a single sample axis.
outvar=outvar.reshape((n_val*batch_size,n_t,10000,3))
track_array = np.zeros(outvar.shape)

# Rescale columns 0 and 1 in place for every (sample, day) pair; column 1 is
# negated before rescaling. Entries equal to 0 (the padding) are left at 0.
for batch,t in np.ndindex(*outvar.shape[:2]):
    x_col = outvar[batch,t,:,0]   # view into outvar, written through below
    x_set = x_col!=0
    x_col[x_set] = ((x_col[x_set]+0.5*960e3)/960e3)*(128-1)

    y_col = outvar[batch,t,:,1]   # view into outvar, written through below
    y_set = y_col!=0
    y_col[y_set] = ((-y_col[y_set]+0.5*960e3)/960e3)*(128-1)
        

In [None]:
def bilinear_interp(x,y,x_grid,y_grid,z_grid):
    """Bilinearly interpolate ``z_grid`` at the points ``(x, y)``.

    Parameters
    ----------
    x, y : array_like
        Coordinates of the query points, in the units of ``x_grid`` and
        ``y_grid``. Scalars are accepted as well as arrays.
    x_grid, y_grid : np.ndarray
        1-D, uniformly spaced, ascending grid coordinates with at least two
        entries each (the spacing is taken from the first two entries).
    z_grid : np.ndarray
        2-D field indexed ``[row, column]``. Rows are stored in reverse
        y order: row ``len(y_grid) - 1 - j`` holds the values at
        ``y_grid[j]``.

    Returns
    -------
    np.ndarray
        Interpolated value for every query point.

    Notes
    -----
    The index of the enclosing grid cell is clipped to the valid range, so a
    point lying exactly on the upper edge of the grid is interpolated inside
    the last cell instead of indexing one past the end of the grid, and
    points outside the grid are linearly extrapolated from the nearest edge
    cell rather than wrapping around through negative indices.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    dx = x_grid[1]-x_grid[0]
    dy = y_grid[1]-y_grid[0]
    Nx = x_grid.shape[0]-1
    Ny = y_grid.shape[0]-1

    # Lower-left corner of the cell containing each point.
    x0_idx = (((x-x_grid[0])/np.abs(x_grid[-1]-x_grid[0]))*Nx).astype('int')
    y0_idx = (((y-y_grid[0])/np.abs(y_grid[-1]-y_grid[0]))*Ny).astype('int')

    # Keep both the corner and its +1 neighbour inside the grid.
    x0_idx = np.clip(x0_idx,0,Nx-1)
    y0_idx = np.clip(y0_idx,0,Ny-1)
    x1_idx = x0_idx + 1
    y1_idx = y0_idx + 1

    # Fractional position inside the cell (0 at the lower corner, 1 at the upper).
    x_n = (x-x_grid[x0_idx])/dx
    y_n = (y-y_grid[y0_idx])/dy

    # Row index is flipped because z_grid stores rows in reverse y order.
    z00 = z_grid[Ny-y0_idx,x0_idx]
    z10 = z_grid[Ny-y0_idx,x1_idx]
    z01 = z_grid[Ny-y1_idx,x0_idx]
    z11 = z_grid[Ny-y1_idx,x1_idx]

    z_interp = z00*(1-x_n)*(1-y_n)+z10*x_n*(1-y_n)+z01*y_n*(1-x_n)+z11*x_n*y_n
    return z_interp

In [None]:
# This calculates the MSE loss
# Undo the SSH normalisation on the recursive predictions.
per_pred = np.squeeze(prediction)*std_ssh+mean_ssh

tracks_persistence = outvar
x_grid = np.arange(128)
y_grid = np.arange(128)

# One mean-squared error per (sample, lead time).
loss = np.zeros(outvar.shape[:2])
for t,l_t in np.ndindex(*loss.shape):
    observations = tracks_persistence[t,l_t]
    # there were some SSH = NaN observations that were zero padded in the
    # pre-processing; drop every point whose SSH is exactly 0
    valid = observations[:,-1]!=0
    x = observations[valid,0]
    y = 127-observations[valid,1]
    ssh_true = observations[valid,-1]
    ssh_pred = bilinear_interp(x,y,x_grid,y_grid,per_pred[t,l_t,:,:])
    loss[t,l_t] = np.mean((ssh_true-ssh_pred)**2)

# An exactly-zero loss is treated as "no data" and excluded from later means.
loss[loss==0] = np.nan

In [None]:
#plot
import matplotlib.pyplot as plt
# RMSE per lead time: NaN-aware mean over samples, then square root.
dailyloss=np.nanmean(loss,axis=0)**0.5

# Draw on the current axes through the explicit Axes interface.
ax=plt.gca()
ax.plot(dailyloss)
ax.set_xlabel("Lead time")
ax.set_ylabel("RMSE loss")
ax.set_title("Recursively calculated loss")
n_val

In [None]:
#This creates images of the results which can then be turned into a video using ffmpeg
import os

import matplotlib.pyplot as plt
import matplotlib as mpl
import scipy
import scipy.stats  # `import scipy` alone does not guarantee the stats submodule is loaded


sampleno=0 #which sample we will inspect
tmax=30 #days per sample
n=128 #pixels per region
L_x = 960e3 # size of domain
L_y = 960e3  # size of domain


savedir='/home/jovyan/images/oneoutput/' #where you want to save the images to
os.makedirs(savedir,exist_ok=True) # savefig does not create missing directories

index=0
mode2='Validation'

# Colour scale shared by both panels and the colorbar of every frame.
cmap = mpl.cm.viridis
norm = mpl.colors.Normalize(vmin=-1, vmax=1)

# The loop variable selects the batch to render; range(1) renders batch 0 only.
for batchno in range(1):
    output_data=np.load(datadir+val+"batch"+str(batchno)+"_outvar.npy") #output data from batch to compare to
    outputsample=output_data[sampleno,:,:,:] #specific sample
    # Not used by the plotting below; read once per batch instead of once per frame.
    input_=np.load(datadir+val+"batch"+str(batchno)+"_invar.npy")
    for day in range(tmax):

        #make the axes and the colorbar
        fig, axs = plt.subplot_mosaic([['ax1', 'ax1','ax2','ax2'],
         ['ax1', 'ax1','ax2','ax2'],['colorbar', 'colorbar','colorbar','colorbar'],['BLANK', 'BLANK', 'BLANK','BLANK'],],empty_sentinel="BLANK",figsize=(10,10))
        fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
            cax=axs['colorbar'], orientation='horizontal', label='SSH')

        #plot the prediction (squeezed first, as in the loss calculation, so a
        #trailing length-1 axis does not take part in the de-normalisation)
        predicted_frame=np.squeeze(prediction[batchno*25+sampleno,day,:,:])*std_ssh+mean_ssh
        axs['ax1'].imshow(predicted_frame,norm=norm,cmap=cmap)

        #bin and plot the output data
        outputgrid, _,_,_ = scipy.stats.binned_statistic_2d(outputsample[day,:,0].flatten(), outputsample[day,:,1].flatten(), outputsample[day,:,2].flatten(), statistic = 'mean', bins=n, range = [[-L_x/2, L_x/2],[-L_y/2, L_y/2]])
        outputgrid = np.rot90(outputgrid)
        outputgrid[np.isnan(outputgrid)] = 0 #empty bins are drawn as 0
        axs['ax2'].imshow(outputgrid,norm=norm,cmap=cmap)

        #figure and axis titles
        fig.suptitle("Mode:"+ mode2 + ", Batch number: " + str(batchno) + ", Sample number: " + str(sampleno) + ", Day:" + str(day), fontsize=16)
        axs['ax1'].set_title("Model prediction")
        axs['ax2'].set_title("Binned output data")

        #save the figure, then release it so the frames do not pile up in memory
        fig.savefig(savedir+(str(index).zfill(3))+'.jpg')
        plt.close(fig)
        index+=1