In [None]:
# Import libraries
import sys
import os

import numpy as np
import pandas as pd
import scipy.io as sio
import h5py

import matplotlib.pyplot as plt
import seaborn as sns

import math
import random
from itertools import combinations_with_replacement

import torch
import torch.nn as nn
import torch.optim as optim 
from torch.utils.data import DataLoader, TensorDataset, random_split

from pathlib import Path

from tqdm import tqdm

In [None]:
#---- INITIALISE PATHS ----
class pathsBib:
    # Relative working directories used throughout the notebook.
    # NOTE: init_path() scans this class's __dict__ for str values that
    # contain a slash and creates each one as a directory, so keep a trailing
    # slash on every path and avoid other str attributes containing one.
    data_path: str = 'data/'    # input data, DMD library file and scaler files
    model_path: str = 'model/'  # model checkpoints
    res_path: str = 'res/'      # DMD result/case files


def init_env():
    """
    Create the working directories and build the path of the input data file.

    Returns:
        datafile    :   (str) path of the .mat data file inside pathsBib.data_path
                        (the process exits if directory initialisation fails)

    """
    # Guard clause: bail out immediately if the directories could not be set up
    if not init_path():
        print("ERROR: failed to initialize path!")
        sys.exit()

    datafile = pathsBib.data_path + 'DeltasOmegasAmpl_Re280.mat'
    print(f"Data file: {datafile}")
    return datafile
     

def init_path():
    """
    Create every directory declared as a path attribute of ``pathsBib``.

    A class attribute counts as a directory when it is a ``str`` containing
    a "/" (e.g. 'data/'). Missing directories, including missing parent
    directories, are created; existing ones are left untouched.

    Returns:
        is_init_path    :   (bool) True if initialisation is successful
                            (on a filesystem error the process exits instead)

    """
    is_init_path = False
    try:
        print("#"*30)
        print("Start initialization of paths")
        path_list = [val for val in vars(pathsBib).values() if isinstance(val, str) and "/" in val]
        print(path_list)
        for pth in path_list:
            # parents=True so nested paths such as 'out/res/' also work
            Path(pth).mkdir(parents=True, exist_ok=True)
            print(f"INIT:\t{pth}\tDONE")
        print("#"*30)
        is_init_path = True
    except OSError as err:
        # Only filesystem errors are expected here; anything else
        # (e.g. KeyboardInterrupt) now propagates instead of being swallowed.
        print(f"ERROR: failed to initialise path. Please, check setup for your path!")
        print(f"       {err}")
        sys.exit()

    return is_init_path



#---- CREATE LIBRARY OF DMD MODES ----
class DMD:
    """
    Build a library of temporal DMD signals (and optionally their non-linear
    products) from growth rates and frequencies read from a .mat file.
    """
    def __init__(self, datafile, n_total, n_test, f_modes, order=1):
        """
        Args:
            datafile    :   (str) path to the data
            n_total     :   (int) number of samples
            n_test      :   (int) number of test samples
            f_modes     :   (arr) frequencies of robust DMD modes
            order       :   (int) largest order of non-linearity 

        """
        self.datafile = datafile

        self.n_total = n_total
        self.n_test = n_test
        self.n_train = self.n_total - self.n_test # number of training samples

        self.f_modes = f_modes
        self.n_modes = len(self.f_modes) 

        self.order = order

        self.casename = f'DMD_ntest{self.n_test}_nmodes{self.n_modes}'
        self.filename = pathsBib.res_path + self.casename

        print(f"DMD file name:\n {self.filename}")


    def load_data(self):
        """
        Load the 'DeltasOmegasAmpl' array from the .mat file: column 0 is
        stored as growth rates (self.deltas), column 1 as frequencies
        (self.omegas).
        """
        f = sio.loadmat(self.datafile)

        data = f['DeltasOmegasAmpl']
        self.deltas = data[:, 0] 
        self.omegas = data[:, 1]


    def search_robust_modes(self, eps=1e-2):
        """
        Search robust frequencies in the data

        Keeps, for every requested frequency, the two entries whose |omega|
        lies strictly within `eps` of it (expected to be a conjugate pair).
        The two entries of each pair are stored next to each other.

        Args:
            eps     :   (float) tolerance for frequency matching

        Raises:
            ValueError  :   if none of the requested frequencies is found
        
        """
        abs_omegas = np.abs(self.omegas)
        n_ind = []
        for i in self.f_modes:
            ind = np.where((abs_omegas < i + eps) & (abs_omegas > i - eps))[0]
            if len(ind) == 2: # must have a conjugate pair
                n_ind.append(ind) 
            else:
                print(f'WARNING: frequency {i} was not found, check value and/or tolerance and/or data!')

        if not n_ind:
            raise ValueError('None of the requested frequencies was found; check f_modes, eps and the data.')
        n_ind = np.concatenate(n_ind) 

        # Truncate 'deltas' and 'omegas' 
        self.n_deltas = self.deltas[n_ind]
        self.n_omegas = self.omegas[n_ind]
        print(f'Found {len(self.n_omegas)} robust modes with frequencies: {self.n_omegas}')


    @staticmethod
    def reconst_exp(delta, omega, t, t0=0, delta_null=True):
        """
        Reconstruct exponential function for each DMD mode

        Static so it can be called both as ``DMD.reconst_exp(...)`` and on an
        instance; `t` may be a scalar or a numpy array.

        Args:
            delta       :   (float) growth rate
            omega       :   (float) frequency
            t           :   (float or arr) time instant(s)
            t0          :   (float) initial time 
            delta_null  :   (bool) if True, ignore the growth rate

        Returns:
            (complex or arr) exp((delta + i*omega) * (t - t0))

        """
        if delta_null:
            delta = 0

        return np.exp((delta + 1j * omega) * (t - t0))


    def make_library(self, dt, delta_null=True):
        """
        Create library of robust modes 

        Each row is the real part of the sum of a mode and its partner
        (consecutive entries of n_deltas / n_omegas) sampled on n_total steps.

        Args:
            dt          :   (float) time step 
            delta_null  :   (bool) if True, set delta to zero 
        
        """
        # Integer-based time axis: np.arange(0, n_total*dt, dt) may return
        # n_total + 1 points because of floating-point rounding of the stop.
        self.tt = np.arange(self.n_total) * dt

        n_pairs = len(self.n_omegas) // 2
        library = np.zeros((n_pairs, len(self.tt)), dtype='complex')
        for m in range(n_pairs):
            library[m, :] = self.reconst_exp(self.n_deltas[2 * m], self.n_omegas[2 * m], self.tt, t0=0, delta_null=delta_null) + \
                            self.reconst_exp(self.n_deltas[2 * m + 1], self.n_omegas[2 * m + 1], self.tt, t0=0, delta_null=delta_null) # add the complex conjugate

        self.library = library.real # take only the real part
        # Keep the linear library so compute_nonlinear can be re-run safely
        self._linear_library = self.library
        self.n_linear = n_pairs


    @staticmethod
    def _unique_mode_indices(library, tol=1e-1):
        """
        Return the row indices of `library` that are not duplicates of an
        earlier kept row.

        Rows are min-max scaled to [0, 1] (constant rows map to all zeros) and
        two rows are duplicates when the mean absolute error between their
        scaled versions is below `tol`. The first occurrence is kept.
        """
        library = np.asarray(library, dtype=float)
        lo = library.min(axis=1, keepdims=True) if library.size else library
        span = library.max(axis=1, keepdims=True) - lo if library.size else library
        scaled = np.divide(library - lo, span, out=np.zeros_like(library), where=span > 0)

        keep = []
        for j in range(library.shape[0]):
            duplicate_of = None
            for i in keep:
                # Compute mean absolute error (MAE)
                MAE = np.mean(np.abs(scaled[i] - scaled[j]))
                print(f'Comparing modes M{i+1} and M{j+1}, MAE = {MAE:.4f}')
                if MAE < tol:
                    duplicate_of = i
                    break
            if duplicate_of is None:
                keep.append(j)
            else:
                print(f'Duplicate found: M{duplicate_of+1} and M{j+1}, removing M{j+1}')
        return np.array(keep, dtype=int)


    def compute_nonlinear(self, remove_duplicates=True, tol=1e-1):
        """
        Add non-linear combinations of modes to the library (if applicable),
        optionally drop duplicated modes, split into train/test along time and
        write both to '<data_path>/DMD_library_Re280.h5py'.

        Args:
            remove_duplicates   :   (bool) if True, remove duplicated modes
            tol                 :   (float) MAE threshold on min-max scaled
                                    modes below which two modes are duplicates

        Returns:
            n_lib               :   (int) number of unique modes         
        
        """
        print('Largest order of non-linear interactions:', self.order)
        self.order_list = list(range(2, self.order + 1))

        linear = getattr(self, '_linear_library', self.library)
        self.n_linear = linear.shape[0]
        n_time = linear.shape[1]

        # Number of distinct products of `i` modes chosen with replacement
        nl_combs = 0
        for i in self.order_list:
            nl_combs += math.factorial(self.n_linear + i - 1) // (math.factorial(i) * math.factorial(self.n_linear - 1))
        print('Number of non-linear modes:', nl_combs)

        nl_library = np.zeros((nl_combs, n_time)) 
        count = 0
        for what_order in self.order_list:
            print(f'Computing non-linear modes of order {what_order}')
            for comb in combinations_with_replacement(range(self.n_linear), what_order):
                print(f'Mode combination {comb}')
                nl_library[count, :] = np.prod(linear[list(comb), :], axis=0)
                count += 1

        # Concatenate with linear modes
        self.library = np.concatenate((linear, nl_library), axis=0)
        print(f'Library shape: {self.library.shape}')

        if remove_duplicates:
            self.index_list = DMD._unique_mode_indices(self.library, tol=tol)
        else:
            self.index_list = np.arange(self.library.shape[0])

        self.library = self.library[self.index_list, :] 
        n_lib = self.library.shape[0] # number of unique modes
        print(f'Unique library shape: {self.library.shape}')

        # Split training and testing data
        train_data = self.library[:, :self.n_train]
        test_data = self.library[:, self.n_train:]
        print(f'Train data shape: {train_data.shape}')
        print(f'Test data shape: {test_data.shape}')
        
        # Save data (context manager guarantees the file is flushed and closed)
        with h5py.File(pathsBib.data_path + 'DMD_library_Re280.h5py', 'w') as f:
            f.create_dataset('train', data=train_data)
            f.create_dataset('test', data=test_data)

        return n_lib
            
    
    def plot_library(self):
        """
        Plot every mode kept in the library on its own row, labelled
        'Mi' for linear modes and 'MiMj...' for their products.
        Requires compute_nonlinear to have been run.
        """
        n_linear = self.n_linear
        # Linear labels first, then products in the same order as compute_nonlinear
        labels = [f'M{i + 1}' for i in range(n_linear)]
        for what_order in self.order_list:
            for comb in combinations_with_replacement(range(n_linear), what_order):
                labels.append('M' + 'M'.join(str(x + 1) for x in comb))
        labels = [labels[i] for i in self.index_list] # keep labels of retained modes
        print(f'Labels: {labels}')

        n_rows = self.library.shape[0]
        # squeeze=False keeps a 2-D array even for a single row
        fig, axs = plt.subplots(n_rows, 1, figsize=[20, n_rows], sharex=True, squeeze=False)
        axs = axs[:, 0]
        for i, ax in enumerate(axs):
            ax.plot(self.tt, self.library[i, :], color='black', linestyle='-', linewidth=1)
            ax.set_xlim([0, np.max(self.tt)])
            ax.tick_params(axis='both', which='major', labelsize=10)
            ax.set_ylabel(labels[i], fontsize=10)

        axs[-1].set_xlabel(r'$t$', fontsize=12)
        plt.tight_layout()
        


#---- MAKE DATALOADER ----
class dataclass:
    """Turn the saved DMD library and TKE target into windowed DataLoaders."""
    def __init__(self, input_len, output_len, batch_size, train_split, scaling):
        """
        Args:
            input_len       :   (int) length of input sequence
            output_len      :   (int) length of output sequence
            batch_size      :   (int) batch size
            train_split     :   (float) ratio of train and validation split (if 1, no validation)
            scaling         :   (str) type of scaling (minmax, standard); falsy means no scaling

        """
        self.input_len = input_len
        self.output_len = output_len
        self.batch_size = batch_size
        self.train_split = train_split
        self.scaling = scaling


    def get_data(self):
        """
        Create Dataloader for training and validation (if applicable)

        Returns:
            train_dl, val_dl    :   train and validation DataLoader (val_dl may be None)

        """
        try:
            with h5py.File(pathsBib.data_path + 'DMD_library_Re280.h5py', 'r') as f:
                data = np.array(f['train'])
            target = np.loadtxt(pathsBib.data_path + 'TKE_Re280.txt')
        except (OSError, KeyError, ValueError):
            print(f"ERROR: failed to find data. Please, check path or file!")
            sys.exit()

        target = target[:data.shape[1]] # length should match training data!
        # Without a scaling type the raw arrays are used (previously a NameError)
        data_norm = dataclass.normalise(data, self.scaling, 'enc') if self.scaling else data
        target_norm = dataclass.normalise(target, self.scaling, 'dec') if self.scaling else target

        X, Y = self.make_Sequence(data=data_norm, target=target_norm)
        self.train_dl, self.val_dl = dataclass.make_Dataloader(torch.from_numpy(X), torch.from_numpy(Y),
                                                    batch_size=self.batch_size,
                                                    drop_last=False,
                                                    train_split=self.train_split)
        print(f"INFO: DataLoader has been generated!")
        del data, data_norm, target, target_norm, X, Y
        return self.train_dl, self.val_dl


    def make_Sequence(self, data, target):
        """
        Generate time-delay sequence data

        Args:
            data    :   (arr) modes x time (a leading batch axis is added if missing)
            target  :   (arr) time series (a leading axis is added if missing)

        Returns:
            X   :   (arr) Encoder data, shape [n_windows, input_len, n_modes]
            Y   :   (arr) Decoder (labeled) data, shape [n_windows, output_len + 1, 1];
                    Y[:, 0] is the last target value inside the input window

        Raises:
            ValueError  :   if the series is shorter than input_len + output_len

        """
        if len(data.shape) <= 2:
            data = np.expand_dims(data, 0)
        if len(target.shape) <= 1:
            target = np.expand_dims(target, 0)
        nSamples = data.shape[-1] - self.input_len - self.output_len + 1
        if nSamples <= 0:
            raise ValueError("Time series is too short for the requested input_len and output_len.")
        # Initialise return arrays
        X = np.empty([nSamples, self.input_len, data.shape[1]])
        Y = np.empty([nSamples, self.output_len + 1, target.shape[0]]) # add initialization decoder
        k = 0
        for i in range(data.shape[0]):
            # Fill ALL nSamples windows (the last window was previously left
            # as uninitialised np.empty memory)
            for j in range(nSamples):
                X[k] = np.transpose(data[i, :, j:j+self.input_len]) # put sequence first for LSTM
                Y[k] = np.expand_dims(np.transpose(target[i, j+self.input_len-1:j+self.input_len+self.output_len]), axis=-1)
                k = k + 1

        print(f"The training data has been generated with shape of {X.shape, Y.shape}")

        return X, Y


    @staticmethod
    def make_Dataloader(X, y, batch_size, drop_last=False, train_split=1):
        """
        Args:
            X, y                :   (tensor) inputs and labels, same first dimension
            batch_size          :   (int) batch size
            drop_last           :   (bool) if True, drop the last batch if it does not have same number of samples
            train_split         :   (float) fraction of samples used for training

        Return:
            train_dl, val_dl    :   train and validation DataLoader (val_dl is None if no validation samples)

        """
        dataset = TensorDataset(X, y)

        len_d = len(dataset)
        train_size = int(train_split * len_d)
        valid_size = len_d - train_size

        train_d, val_d = random_split(dataset, [train_size, valid_size])

        train_dl = DataLoader(train_d, batch_size=batch_size, drop_last=drop_last, shuffle=True)
        if valid_size > 0:
            val_dl = DataLoader(val_d, batch_size=batch_size, drop_last=drop_last, shuffle=True)
        else:
            val_dl = None

        return train_dl, val_dl


    @staticmethod
    def normalise(data, scaling, encdec):
        """
        Scale `data` globally and save the scaling parameters to
        '<data_path>/<scaling>-scaling-<encdec>.npy'.

        Args:
            scaling     :   (str) 'minmax' or 'standard'
            encdec      :   (str) 'enc' for encoder data, 'dec' for decoder data

        Returns:
            data_norm   :   (arr) normalised data

        """
        if scaling == "minmax":
            minval = np.min(data)
            maxval = np.max(data)
            np.save(pathsBib.data_path + f'minmax-scaling-{encdec}.npy', [minval, maxval])
            data_norm = (data - minval) / (maxval - minval)
        elif scaling == "standard":
            meanval = np.mean(data)
            stdval = np.std(data)
            np.save(pathsBib.data_path + f'standard-scaling-{encdec}.npy', [meanval, stdval])
            data_norm = (data - meanval) / stdval
        else:
            print(f"ERROR: failed to normalise data. Please, check scaling type!")
            sys.exit()

        return data_norm


    @staticmethod
    def reverse_normalise(data, scaling, encdec):
        """
        Undo `normalise` using the parameters it saved.

        Args:
            scaling :   (str) 'minmax' or 'standard'
            encdec  :   (str) 'enc' for encoder data, 'dec' for decoder data

        Returns:
            data    :   (arr) non-normalised data

        """
        if scaling == "minmax":
            minval, maxval = np.load(pathsBib.data_path + f'minmax-scaling-{encdec}.npy')
            data = data * (maxval - minval) + minval
        elif scaling == "standard":
            meanval, stdval = np.load(pathsBib.data_path + f'standard-scaling-{encdec}.npy')
            data = data * stdval + meanval
        else:
            print(f"ERROR: failed to reverse normalise data. Please, check scaling type!")
            sys.exit()

        return data



# ---- NETWORKS ----
class EncoderLSTM(nn.Module):
    def __init__(self, input_dim, lstm_dim, num_layer=1, dropout=0.0):
        """ 
        Encoder Long-Short Term Memory (LSTM) network (batch-first).

        Args:
            input_dim   :   (int) number of features per time step
            lstm_dim    :   (int) hidden size of the LSTM
            num_layer   :   (int) number of stacked LSTM layers
            dropout     :   (float) dropout between LSTM layers (only active when num_layer > 1)

        """
        super().__init__()
        self.lstm_dim, self.num_layer = lstm_dim, num_layer

        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=lstm_dim,
            num_layers=num_layer,
            dropout=dropout,
            batch_first=True,
        )


    def init_hidden(self, batch_size, device):
        """Return zero-initialised (hidden, cell) states on `device`."""
        state_shape = (self.num_layer, batch_size, self.lstm_dim)
        return torch.zeros(state_shape).to(device), torch.zeros(state_shape).to(device)
        

    def forward(self, input_tensor):
        """
        Run the LSTM over the full input sequence from a zero state.

        Returns:
            output      :   (tensor) top-layer outputs, [batch_size, seq_len, hidden_dim]
            hidden      :   (tensor) final hidden state, [num_layers, batch_size, hidden_dim]
            cell        :   (tensor) final cell state, [num_layers, batch_size, hidden_dim]
         
        """
        h_0, c_0 = self.init_hidden(input_tensor.shape[0], device=input_tensor.device)
        # Fresh zero tensors carry no autograd history, so no detach is needed
        seq_out, (h_n, c_n) = self.lstm(input_tensor, (h_0, c_0))
        return seq_out, h_n, c_n


class Attention(nn.Module):
    # Supported scoring functions
    _METHODS = ('dot', 'general', 'concat')

    def __init__(self, method, hidden_dim):
        """
        Compute attention weights 

        Args:
            method          :   (str) type of attention ('dot', 'general', 'concat')
            hidden_dim      :   (int) hidden dimension of LSTM

        Raises:
            ValueError      :   if `method` is not one of the supported types
                                (previously this only failed later in forward)

        """
        super(Attention, self).__init__()
        if method not in self._METHODS:
            raise ValueError(f"Unknown attention method '{method}'; expected one of {self._METHODS}.")
        self.method = method
        self.hidden_dim = hidden_dim

        if self.method == 'general':
            self.attn = nn.Linear(hidden_dim, hidden_dim)
        elif self.method == 'concat':
            self.attn = nn.Linear(hidden_dim * 2, hidden_dim)
            self.other = nn.Parameter(torch.empty(1, hidden_dim))

        self.init_weights()


    def init_weights(self):
        """Initialise the learnable score vector of the 'concat' method."""
        if self.method == 'concat':
            nn.init.uniform_(self.other, -0.1, 0.1)


    def forward(self, hidden, encoder_output):
        """ 
        Args:
            hidden          :   (tensor) decoder hidden state, shape [batch_size, 1, hidden_dim]
            encoder_output  :   (tensor) encoder state, shape [batch_size, seq_len, hidden_dim]

        Returns:
            attn            :   (tensor) attention weights, shape [batch_size, 1, seq_len]
        
        """
        if self.method == 'dot':
            score = torch.bmm(hidden, encoder_output.transpose(1, 2))
        elif self.method == 'general':
            score = self.attn(encoder_output) 
            score = torch.bmm(hidden, score.transpose(1, 2))
        else: # 'concat' (method validated in __init__)
            score = torch.tanh(self.attn(torch.cat((hidden.expand(-1, encoder_output.size(1), -1), encoder_output), dim=2))) 
            other_exp = self.other.expand(score.size(0), -1, -1)
            score = torch.bmm(other_exp, score.transpose(1, 2))

        # Normalise attention scores to weights in range 0 to 1
        attn = nn.functional.softmax(score, dim=2)
        
        return attn
            
            
class DecoderAttentionLSTM(nn.Module):
    # Luong attention variants and the Attention scoring method each one uses
    _LUONG = {'L-DOT': 'dot', 'L-GEN': 'general', 'L-CON': 'concat'}

    def __init__(self, output_dim, lstm_dim, mlp_dim, num_layer=1, dropout=0.0, attn_model=None, output_len=1):
        """ 
        Decoder LSTM network with attention (if applicable)

        Args:
            output_dim      :   (int) output dimension of model
            lstm_dim        :   (int) hidden dimension of LSTM
            mlp_dim         :   (int) hidden dimension of MLP
            num_layer       :   (int) number of LSTM layers
            dropout         :   (float) dropout rate between LSTM layers (if num_layer > 1)
            attn_model      :   (str) type of attention model ('BAH', 'L-DOT', 'L-GEN', 'L-CON'), or None
            output_len      :   (int) length of output sequence
        
        """
        super(DecoderAttentionLSTM, self).__init__()
        self.lstm_hidden = lstm_dim
        self.num_layer = num_layer
        self.attn_model = attn_model
        self.output_len = output_len

        if attn_model is None:
            self.lstm = nn.LSTM(input_size=output_dim, hidden_size=lstm_dim, num_layers=num_layer, dropout=dropout, batch_first=True)
        else:
            # With attention the LSTM input is [projected input, context / h_tilde]
            self.proj = nn.Linear(output_dim, lstm_dim)
            self.lstm = nn.LSTM(input_size=lstm_dim * 2, hidden_size=lstm_dim, num_layers=num_layer, dropout=dropout, batch_first=True)
            if attn_model == 'BAH':
                # Bahdanau attention
                self.attn = Attention(method='concat', hidden_dim=lstm_dim) 
            elif attn_model in self._LUONG:
                # Luong attention
                self.attn = Attention(method=self._LUONG[attn_model], hidden_dim=lstm_dim)
                self.ffn_attn = nn.Linear(lstm_dim * 2, lstm_dim) 
            else:
                print(f"ERROR: attention model {attn_model} is not recognized!")
                sys.exit()

        self.mlp = nn.Sequential(
            nn.Linear(lstm_dim, mlp_dim),
            nn.ReLU(),
            nn.Linear(mlp_dim, lstm_dim),
        )
        self.out = nn.Linear(lstm_dim, output_dim)


    def forward(self, encoder_hidden, encoder_cell, encoder_outputs, true_tensor, device, eps):
        """ 
        Args:
            encoder_hidden  :   (tensor) encoder hidden state, shape [num_layers, batch_size, hidden_dim]
            encoder_cell    :   (tensor) encoder cell state, shape [num_layers, batch_size, hidden_dim]
            encoder_outputs :   (tensor) encoder outputs, shape [batch_size, input_len, hidden_dim]
            true_tensor     :   (tensor) true values, shape [batch_size, output_len + 1, output_dim];
                                index 0 initialises the decoder
            device          :   device for newly created tensors; None means the
                                device of `encoder_outputs`
            eps             :   (float) probability of teacher forcing (1.0 = use true data, 0.0 = use predicted data)

        Returns:
            outputs         :   (tensor) model output, shape [batch_size, output_len, output_dim]
            attentions      :   (tensor) attention weights, shape [batch_size, output_len, input_len]
                                (empty list when attn_model is None)
        
        """
        # Follow the encoder outputs' device; the former CUDA fallback could
        # place hidden_tilde on a different device than a CPU-resident model.
        device = device or encoder_outputs.device

        # Initialise decoder
        decoder_input = true_tensor[:, 0, :].unsqueeze(1) # initialise with last labeled data from encoder
        decoder_hidden = encoder_hidden[-self.num_layer:, :, :] # keep the top layers if the encoder is deeper
        decoder_cell = encoder_cell[-self.num_layer:, :, :]

        is_luong = self.attn_model in self._LUONG

        # Initialise attentional hidden state
        if is_luong:
            hidden_tilde = torch.zeros(encoder_outputs.size(0), 1, self.lstm_hidden,
                                       device=device, dtype=encoder_outputs.dtype)

        outputs = []
        attentions = []
        for i in range(self.output_len):
            if self.attn_model is None:
                output, decoder_hidden, decoder_cell = self.forward_step(
                        decoder_input, decoder_hidden, decoder_cell
                    )
            else:
                if self.attn_model == 'BAH':
                    output, decoder_hidden, decoder_cell, attn_weights = self.forward_step_bahdanau(
                        decoder_input, decoder_hidden, decoder_cell, encoder_outputs
                    )
                elif is_luong:
                    output, decoder_hidden, decoder_cell, hidden_tilde, attn_weights = self.forward_step_luong(
                        decoder_input, decoder_hidden, decoder_cell, encoder_outputs, hidden_tilde
                    )
                attentions.append(attn_weights)
            outputs.append(output)

            if random.random() < eps: 
                # Teacher forcing
                decoder_input = true_tensor[:, i+1, :].unsqueeze(1)
            else:
                # Scheduled sampling: feed back the (non-differentiated) prediction
                decoder_input = output.detach()

        outputs = torch.cat(outputs, dim=1) 
        if self.attn_model:
            attentions = torch.cat(attentions, dim=1) 

        return outputs, attentions

    
    def forward_step(self, input_tensor, hidden_lstm, cell_lstm):
        """One decoding step without attention."""
        out_lstm, (hidden_lstm, cell_lstm) = self.lstm(input_tensor, (hidden_lstm, cell_lstm))

        # Final output layer
        out_ffn = self.mlp(out_lstm)
        output = self.out(out_ffn)
        return output, hidden_lstm, cell_lstm


    def forward_step_bahdanau(self, input_tensor, hidden_lstm, cell_lstm, encoder_outputs):
        """One decoding step; attention uses the previous top-layer hidden state."""
        input_tensor_proj = self.proj(input_tensor)

        # Calculate attention weights from previous LSTM state and all encoder outputs
        attn_weights = self.attn(hidden_lstm[-1].unsqueeze(1), encoder_outputs)
        context = torch.bmm(attn_weights, encoder_outputs) # context vector

        # Combine LSTM input and context vector
        in_lstm = torch.cat((input_tensor_proj, context), dim=2)
        out_lstm, (hidden_lstm, cell_lstm) = self.lstm(in_lstm, (hidden_lstm, cell_lstm))

        # Final output layer
        out_ffn = self.mlp(out_lstm)
        output = self.out(out_ffn)  
        return output, hidden_lstm, cell_lstm, attn_weights


    def forward_step_luong(self, input_tensor, hidden_lstm, cell_lstm, encoder_outputs, hidden_tilde):
        """One decoding step; attention uses the current LSTM output (input-feeding)."""
        input_tensor_proj = self.proj(input_tensor)

        # Combine LSTM input and last attentional hidden state (input-feeding)
        in_lstm = torch.cat((input_tensor_proj, hidden_tilde), dim=2) 
        out_lstm, (hidden_lstm, cell_lstm) = self.lstm(in_lstm, (hidden_lstm, cell_lstm))

        # Calculate attention weights from current LSTM state and all encoder outputs
        attn_weights = self.attn(out_lstm, encoder_outputs)
        context = torch.bmm(attn_weights, encoder_outputs) # context vector

        # Compute the attentional hidden state
        in_htilde = torch.cat((out_lstm, context), dim=2)
        hidden_tilde = torch.tanh(self.ffn_attn(in_htilde))

        # Final output layer
        out_ffn = self.mlp(hidden_tilde)
        output = self.out(out_ffn)
        return output, hidden_lstm, cell_lstm, hidden_tilde, attn_weights
    

class Seq2Seq(nn.Module):
    # Keys read from the parameter dicts, in constructor-argument order
    _ENCODER_KEYS = ('input_dim', 'lstm_dim', 'num_layer', 'dropout')
    _DECODER_KEYS = ('output_dim', 'lstm_dim', 'mlp_dim', 'num_layer', 'dropout', 'attn_model', 'output_len')

    def __init__(self, encoder_params, decoder_params):
        """
        Sequence-to-sequence model: an EncoderLSTM followed by a
        DecoderAttentionLSTM.

        Args:
            encoder_params  :   (dict) keyword arguments for EncoderLSTM
            decoder_params  :   (dict) keyword arguments for DecoderAttentionLSTM

        """
        super().__init__()

        encoder_kwargs = {key: encoder_params[key] for key in self._ENCODER_KEYS}
        self.encoder = EncoderLSTM(**encoder_kwargs)

        decoder_kwargs = {key: decoder_params[key] for key in self._DECODER_KEYS}
        self.decoder = DecoderAttentionLSTM(**decoder_kwargs)
        

    def forward(self, input_tensor, true_tensor, device=None, eps=0.0):
        """Encode `input_tensor`, then decode; returns (outputs, attentions)."""
        enc_out, enc_hidden, enc_cell = self.encoder(input_tensor)
        return self.decoder(enc_hidden, enc_cell, enc_out, true_tensor, device, eps)
    


# ---- TRAINING LOOP ----
def train(device, model, train_dl, loss_fn, optimizer, scheduler=None, num_epoch=100, eps_huber=0.0, schsam=None, val_dl=None, max_grad_norm=5.0):
    """
    Train `model` and record per-epoch losses.

    Args: 
        device          :   the device for training (it should match the model's device!)
        model           :   the model to be trained
        train_dl        :   dataloader for training
        loss_fn         :   loss function     
        optimizer       :   optimizer function
        scheduler       :   scheduler function (stepped once per epoch)
        num_epoch       :   (int) number of epochs
        eps_huber       :   (float) Huber threshold; if non-zero, the fraction of
                            errors above it is printed on the last batch every 10 epochs
        schsam          :   (str) type of scheduled sampling 
        val_dl          :   Dataloader for validation (if applicable)
        max_grad_norm   :   (float) gradient-norm clipping threshold

    Returns:
        history         :   (dict) contains training and validation (if applicable) losses

    """
    history = {"train_loss": []}
    if val_dl:
        history["val_loss"] = []

    model.to(device)

    for epoch in range(num_epoch):
        model.train() # change to training mode

        loss_val = 0; num_batch = 0
        eps_schsam = scheduled_sampling(schsam) # get probability for teacher forcing
        for batch in tqdm(train_dl):
            x, y = batch
            x = x.to(device).float(); y = y.to(device).float()
            optimizer.zero_grad()

            pred, _ = model(x, y, device=device, eps=eps_schsam)

            # Diagnostic on the last batch (when there is more than one) every 10 epochs
            is_last_batch = num_batch > 0 and num_batch == len(train_dl) - 1
            if eps_huber != 0 and is_last_batch and epoch % 10 == 0:
                with torch.no_grad():
                    e = (pred - y[:, 1:, :]).abs()
                    frac_mae = (e > eps_huber).float().mean().item()
                print(f"INFO: fraction of samples in MAE regime: {frac_mae:.3f}")

            # y[:, 0] only initialises the decoder; targets start at index 1
            loss = loss_fn(pred, y[:, 1:, :])
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm) # gradient clipping

            optimizer.step()

            # NOTE(review): if loss_fn already averages over the batch, this
            # extra division makes the logged value depend on batch size — confirm.
            loss_val += loss.item()/x.shape[0]
            num_batch += 1
        
        if scheduler is not None:
            scheduler.step()

        history["train_loss"].append(loss_val/num_batch)

        if val_dl:
            model.eval() # change to evaluation mode

            loss_val = 0; num_batch = 0
            # No autograd graph is needed for validation
            with torch.no_grad():
                for batch in val_dl:
                    x, y = batch
                    x = x.to(device).float(); y = y.to(device).float()

                    pred, _ = model(x, y, device=device, eps=0.0) # no teacher 
                    loss = loss_fn(pred, y[:, 1:, :]) 

                    loss_val += loss.item()/x.shape[0]
                    num_batch += 1

            history["val_loss"].append(loss_val/num_batch)

        train_loss = history["train_loss"][-1]
        message = (f"At Epoch    = {epoch+1},\n"
                   f"Train_loss  = {train_loss},\n")
        if val_dl:
            val_loss = history["val_loss"][-1]
            message += f"Val_loss    = {val_loss},\n"
        message += f"eps_schsam  = {eps_schsam:.4f}"
        print(message)
        if scheduler is not None:
            print(f"LR      = {scheduler.get_last_lr()[0]:.4e}")

    return history


def scheduled_sampling(schsam):
    """ 
    Map a scheduled-sampling mode to a teacher-forcing probability.

    Args:
        schsam      :   (str) 'no_teacher' -> 0.0, 'scheduled_sampling' -> 0.5,
                        anything else (teacher forcing) -> 1.0

    Returns:
        eps         :   (float) probability for teacher forcing 
        
    """
    if schsam == 'no_teacher':
        return 0.0
    if schsam == 'scheduled_sampling':
        return 0.5
    # Default: full teacher forcing
    return 1.0



# ---- TEST ----
class testclass:
    """
    Evaluate a pre-trained Seq2Seq model.

    Intended call order: load_pretrain_model() -> test() -> plot_loss() / plot_pred()
    / plot_attn() (the last one only for models built with an attention model).

    """
    def __init__(self, device, input_len, n_test, input_dim, output_dim, scaling):
        """
        Args:
            device      :   (torch.device) device used for inference
            input_len   :   (int) number of trailing training samples fed to the encoder
            n_test      :   (int) number of autoregressive prediction steps
            input_dim   :   (int) encoder input dimension
            output_dim  :   (int) decoder output dimension
            scaling     :   (str) normalisation passed to dataclass ("minmax", "standard");
                            a falsy value disables (de-)normalisation

        """
        self.device = device
        self.input_len = input_len
        self.n_test = n_test
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.scaling = scaling

        # Filled in by load_pretrain_model() and test()
        self.model = None
        self.history = None
        self.attn_model = None
        self.output = None
        self.attn = None


    def load_pretrain_model(self, encoder_params, decoder_params):
        """
        Load state_dict and history of pre-trained model

        Args:
            encoder_params  :   (dict) encoder hyper-parameters passed to Seq2Seq
            decoder_params  :   (dict) decoder hyper-parameters passed to Seq2Seq; its
                                'attn_model' entry selects the checkpoint file. The dict
                                itself is not modified.

        Returns:
            model       :   (Seq2Seq) model rebuilt with output_len = n_test and the
                            pre-trained weights loaded (also stored in self.model);
                            the training history is stored in self.history

        """
        attn_model = decoder_params['attn_model']
        model_path = pathsBib.model_path + f"ADAN_statedict-history_Att-{attn_model}" + ".pt"

        try:
            # weights_only=False unpickles arbitrary objects: only load trusted checkpoints
            # (i.e. the ones written by this notebook).
            ckpoint = torch.load(model_path, weights_only=False, map_location=self.device)
        except Exception as err:
            print(f"ERROR: failed to load model from {model_path}: {err}")
            sys.exit()

        # Copy so that overriding output_len does not leak into the caller's dict
        # (typically config.decoder_params, which is reused by other cells).
        decoder_params = dict(decoder_params)
        decoder_params['output_len'] = self.n_test

        self.model = Seq2Seq(encoder_params, decoder_params)
        self.model.load_state_dict(state_dict=ckpoint['model'])
        self.history = ckpoint['history']
        self.attn_model = attn_model

        print(f'INFO: the state dict has been loaded!')
        print(self.model)

        return self.model


    def test(self):
        """
        Autoregressive rollout of n_test steps from the end of the training data.

        The encoder is initialised with the last input_len samples of the 'train'
        dataset of the DMD library, the decoder with the last training target value.
        Predictions are de-normalised when scaling is enabled, stored in self.output
        and saved to res_path/ADAN_Preds.npz (key 'out'). For attention models the
        attention weights are stored in self.attn.

        Raises:
            RuntimeError    :   if no model has been loaded yet

        """
        if self.model is None:
            raise RuntimeError("no model loaded: call load_pretrain_model() first")

        try:
            with h5py.File(pathsBib.data_path + 'DMD_library_Re280.h5py', 'r') as f:
                data = np.transpose(np.array(f['train']))
            target = np.loadtxt(pathsBib.data_path + 'TKE_Re280.txt')
        except (OSError, KeyError) as err:
            print(f"ERROR: failed to find data. Please, check path or file! ({err})")
            sys.exit()

        nsamples = data.shape[0] # number of training samples

        # Without scaling the raw data is used as-is
        data_norm = dataclass.normalise(data, self.scaling, 'enc') if self.scaling else data
        data_norm = data_norm[-self.input_len:, :] # get last input_len samples to initialise encoder

        target_train = target[:nsamples] # match length of training data
        target_norm = dataclass.normalise(target_train, self.scaling, 'dec') if self.scaling else target_train
        target_norm = target_norm[-1] # initialisation decoder

        self.model.eval()
        self.model.to(self.device)

        print(f"INFO: Testing model")

        input_tensor = torch.from_numpy(data_norm[None, :, :]).float().to(self.device) # [1, input_len, input_dim]
        # For a scalar target value this yields shape [1, 1, 1].
        # NOTE(review): a 1-D target_norm would become 4-D here; confirm output_dim == 1.
        true_tensor = torch.from_numpy(np.expand_dims(target_norm, axis=(0, 1, 2))).float().to(self.device)

        print(f"INFO: starting autoregressive rollout for {self.n_test} steps")
        # eps=0.0 -> no teacher forcing; no_grad skips building the autograd graph
        with torch.no_grad():
            pred, attn = self.model(input_tensor, true_tensor, device=self.device, eps=0.0)
        pred = pred.cpu().numpy()

        if self.attn_model is None:
            print(f"INFO: Testing completed, size of predictions = {pred.shape}")
        else:
            self.attn = attn.cpu().numpy()
            print(f"INFO: Testing completed, size of predictions = {pred.shape}, size of attention matrix = {self.attn.shape}")

        # Undo the target normalisation (identity when scaling is disabled)
        self.output = dataclass.reverse_normalise(pred, self.scaling, 'dec') if self.scaling else pred

        # Save results
        np.savez_compressed(
            file = pathsBib.res_path + 'ADAN_Preds.npz',
            out = self.output
        )


    def plot_loss(self):
        """
        Plot training and validation (if applicable) losses on a log scale

        Raises:
            RuntimeError    :   if no training history has been loaded yet

        """
        if self.history is None:
            raise RuntimeError("no training history: call load_pretrain_model() first")

        fig, axs = plt.subplots(1, 1, figsize=(10,4))

        axs.plot(self.history["train_loss"], color='blue', linestyle='-', linewidth=1.5, marker='o', markersize=5, label="Train")
        val_loss = self.history.get("val_loss", [])
        if len(val_loss) != 0:
            axs.plot(val_loss, color='red', linestyle='-', linewidth=1.5, marker='^', markersize=5, label="Validation")
        axs.set_yscale('log')
        axs.grid(True, which='both')
        axs.set_xlim([-1, len(self.history["train_loss"])])
        axs.tick_params(axis='both', which='major', labelsize=14)
        axs.set_xlabel("Training epoch", fontsize=14)
        axs.set_ylabel("Loss", fontsize=14)

        # Labels are attached to the curves, so the legend only lists what was plotted
        axs.legend(fontsize=14)


    def plot_pred(self):
        """
        Plot prediction against the matching tail of the reference TKE series

        Raises:
            RuntimeError    :   if test() has not produced predictions yet

        """
        if self.output is None:
            raise RuntimeError("no predictions available: call test() first")

        # atleast_1d keeps a length axis even for a single-step rollout
        output_toplot = np.atleast_1d(np.squeeze(self.output))

        try:
            target = np.loadtxt(pathsBib.data_path + 'TKE_Re280.txt')
        except OSError as err:
            print(f"ERROR: failed to find data. Please, check path or file! ({err})")
            sys.exit()
        # Compare with the last len(prediction) reference values
        target = target[-output_toplot.shape[0]:]

        fig, ax = plt.subplots(figsize=(20, 2))

        ax.plot(target, color='black', linestyle='-', linewidth=2, label='True')
        ax.plot(output_toplot, color='blue', linestyle='-', linewidth=2, label='Prediction')
        ax.tick_params(axis='both', which='major', labelsize=12)
        ax.set_xlabel(r"Prediction step",fontsize=14)
        ax.set_ylabel(r"TKE", fontsize=14)

        ax.legend(ncol=2, loc='upper center', fontsize=14)


    def plot_attn(self):
        """
        Plot attention matrix (if applicable)

        Raises:
            RuntimeError    :   if no attention weights are available

        """
        if self.attn is None:
            raise RuntimeError("no attention weights: run test() with an attention model first")

        # First batch element; columns (T=...) come from axis 1, rows (t=...) from axis 2
        attn_mat = self.attn[0, :, :]
        labelx = [f'T={i+1}' for i in range(attn_mat.shape[0])]
        labely = [f't={i+1}' for i in range(attn_mat.shape[1])]
        attn_df = pd.DataFrame(attn_mat.T, index=labely, columns=labelx)

        fig, axs = plt.subplots(1, 1, figsize=(12, 4))
        # Scalar vmax: np.max(DataFrame) returns a per-column Series on pandas < 2.0
        sns.heatmap(attn_df, ax=axs, cmap='Reds', vmin=0.0, vmax=float(attn_mat.max()), square=False, cbar_kws={"shrink": .5, 'label': 'Attention weights'})
        axs.tick_params(axis='both', which='major', labelsize=8)
        

In [None]:
# ---- CONFIGURATION ----
class config:
    # Number of test samples; also used as the decoder rollout length at test time
    n_test = 128

    # Data windows / loaders
    input_len = 64
    batch_size = 32
    train_split = 0.8
    scaling = "standard" # "minmax", "standard"

    # Encoder hyper-parameters
    encoder_params = dict(
        lstm_dim=128,
        mlp_dim=512,     # dimension MLP in ResNet encoder
        num_layer=3,     # number of LSTM layers or ResLSTM blocks
        dropout=0.2,     # dropout between LSTM layers (if num_layer > 1)
    )

    # Decoder hyper-parameters
    decoder_params = dict(
        lstm_dim=128,
        mlp_dim=512,
        num_layer=3,     # must be <= encoder num_layer!
        dropout=0.2,
        attn_model=None, # 'BAH', 'L-DOT', 'L-GEN', 'L-CON', None (no attention)
        output_len=64,
    )

    # Optimisation
    eps_huber = 5e-2 # threshold for Huber loss; if 0, use MSE loss
    lr = 1e-3
    schlr = True # if True, use learning rate scheduler (exponential decay)
    num_epoch = 1000
    schsam = 'scheduled_sampling' # 'no_teacher', 'scheduled_sampling', else teacher forcing


In [None]:
# Create environment
# Seed every RNG this notebook relies on (python `random`, numpy, torch) so that
# stochastic steps (e.g. data splitting, weight initialisation, dropout) are repeatable.
# NOTE: some CUDA kernels are non-deterministic, so GPU runs may still differ slightly.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

datafile = init_env()

In [None]:
# Create library of DMD modes
# Reference TKE series. Time runs along axis 0 of the file (the test code slices it as
# target[:nsamples]), so arrange it as (output_dim, n_total). This is done
# unconditionally: previously a 2-D file left output_dim / n_total undefined (or stale).
target_tensor = np.loadtxt(pathsBib.data_path + 'TKE_Re280.txt')
target_tensor = target_tensor.reshape(target_tensor.shape[0], -1).T
output_dim, n_total = target_tensor.shape

# NOTE: this rebinds the name `DMD` from the class to the instance, so re-running this
# cell requires re-running the cell that defines the DMD class first.
DMD = DMD(datafile, n_total, config.n_test, np.array([1.33, 2.29, 2.66]), order=1)

DMD.load_data() 

DMD.search_robust_modes(eps=1e-2) 

DMD.make_library(dt=1, delta_null=True)

n_lib = DMD.compute_nonlinear(remove_duplicates=True) 

DMD.plot_library()

In [None]:
# Build training / validation loaders from the configured window lengths and split
DL = dataclass(
    config.input_len,
    config.decoder_params["output_len"],
    config.batch_size,
    config.train_split,
    config.scaling,
)
train_dl, val_dl = DL.get_data()

print(f"INFO: number of training batches: {len(train_dl)}")
# The validation loader is only reported when it is truthy
if val_dl:
    print(f"INFO: number of validation batches: {len(val_dl)}")

In [None]:
# Complete the architecture with the data-dependent dimensions, then build the model
config.encoder_params['input_dim'] = n_lib
config.decoder_params['output_dim'] = output_dim
if config.decoder_params['num_layer'] > config.encoder_params['num_layer']:
    # config requires decoder num_layer <= encoder num_layer; stop here instead of
    # building a model that violates it
    raise ValueError(
        f"decoder num_layer ({config.decoder_params['num_layer']}) must be smaller than "
        f"or equal to encoder num_layer ({config.encoder_params['num_layer']})!"
    )

model = Seq2Seq(config.encoder_params, config.decoder_params)
print(model)
NumPara = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"INFO: the model has been generated, the number of parameter is {NumPara}")

In [None]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"INFO: the device has been assigned to {device} ")

In [None]:
# eps_huber == 0 selects plain MSE, otherwise Huber loss with threshold eps_huber
loss_fn = nn.MSELoss() if config.eps_huber == 0 else nn.HuberLoss(delta=config.eps_huber)

optimizer = optim.Adam(model.parameters(), lr=config.lr)

# Optional exponential LR decay: train() calls scheduler.step() once per epoch,
# so the learning rate is multiplied by SCHEDULER_GAMMA every epoch.
SCHEDULER_GAMMA = 0.99
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=SCHEDULER_GAMMA) if config.schlr else None

In [None]:
# Check if pre-trained model exists (avoid training)
# NOTE: the checkpoint name only encodes the attention model. After changing any other
# hyper-parameter, delete the old checkpoint; otherwise training is skipped and the
# stale checkpoint is loaded at test time.
model_path = pathsBib.model_path + f"ADAN_statedict-history_Att-{config.decoder_params['attn_model']}" + ".pt"
if not os.path.isfile(model_path):
    print(f"INFO: start training!")
    history = train(device, model, train_dl, loss_fn, optimizer, scheduler, num_epoch=config.num_epoch, eps_huber=config.eps_huber, schsam=config.schsam, val_dl=val_dl)
    print(f"INFO: training finished!")

    check_point = {"model":model.state_dict(),
                   "history":history,
                   }
    
    torch.save(check_point, model_path)
    print(f"INFO: the checkpoint has been saved!")
else:
    print(f"INFO: pre-trained model found at {model_path}, skipping training!")

In [None]:
print(f"INFO: start testing!")
TT = testclass(device, config.input_len, config.n_test, config.encoder_params['input_dim'], config.decoder_params['output_dim'], config.scaling)

model = TT.load_pretrain_model(config.encoder_params, config.decoder_params)
# Count the parameters of the model actually loaded (rebuilt with output_len = n_test),
# not the freshly initialised training model.
num_para_loaded = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"INFO: the model has been loaded, the number of parameter is {num_para_loaded}")

TT.test()

In [None]:
# Training (and, if recorded, validation) loss per epoch from the loaded checkpoint history
TT.plot_loss()

In [None]:
# Predicted TKE rollout vs. the tail of the reference series of the same length
TT.plot_pred()

In [None]:
# Attention weights only exist when the decoder was configured with an attention model
attn_model = config.decoder_params['attn_model']
if attn_model is not None:
    TT.plot_attn()