# In [None]:
import time

import numpy as np
import torch
import torch.nn as nn

# Building an LSTM autoencoder
class Encoder(nn.Module):
    """LSTM encoder that compresses a sequence into a fixed-size embedding.

    Three stacked single-layer LSTMs narrow the feature width
    ``n_features -> 4*embedding_dim -> 2*embedding_dim -> embedding_dim``.
    The final hidden state of the last LSTM is returned as the embedding.

    Args:
        seq_len: Number of time steps per sequence.
        n_features: Number of features per time step.
        batch_size: Nominal batch size. Stored for backward compatibility;
            ``forward`` infers the actual batch size from its input, so a
            smaller final batch is accepted.
        embedding_dim: Size of the output embedding.
    """

    def __init__(self, seq_len, n_features, batch_size, embedding_dim):
        super(Encoder, self).__init__()
        self.seq_len, self.n_features = seq_len, n_features
        self.embedding_dim = embedding_dim
        self.hidden_dim2 = 2 * embedding_dim
        self.hidden_dim1 = 4 * embedding_dim

        self.batch_size = batch_size

        self.rnn1 = nn.LSTM(
            input_size=self.n_features,
            hidden_size=self.hidden_dim1,  # 4 * embedding_dim
            num_layers=1,
            batch_first=True
        )

        self.rnn2 = nn.LSTM(
            input_size=self.hidden_dim1,  # 4 * embedding_dim
            hidden_size=self.hidden_dim2,  # 2 * embedding_dim
            num_layers=1,
            batch_first=True
        )

        self.rnn3 = nn.LSTM(
            input_size=self.hidden_dim2,  # 2 * embedding_dim
            hidden_size=self.embedding_dim,  # embedding_dim
            num_layers=1,
            batch_first=True
        )

    def forward(self, x):
        """Encode ``x`` into a ``(batch, embedding_dim)`` tensor.

        ``x`` is reshaped to ``(batch, seq_len, n_features)``. The batch size
        is inferred, so its element count must be a multiple of
        ``seq_len * n_features``.
        """
        # -1 instead of self.batch_size lets a smaller (e.g. last) batch through.
        x = x.reshape((-1, self.seq_len, self.n_features))
        x, _ = self.rnn1(x)
        x, _ = self.rnn2(x)
        _, (hidden_n, _) = self.rnn3(x)
        # hidden_n is (num_layers=1, batch, embedding_dim); keep the last layer.
        return hidden_n[-1]

    
class Decoder(nn.Module):
    """LSTM decoder that expands an embedding back into a sequence.

    The embedding is repeated at each of ``seq_len`` time steps. It is then
    fed through three stacked single-layer LSTMs of hidden widths
    ``input_dim -> 4*input_dim -> 2*input_dim``. A linear layer maps every
    step to ``n_features`` outputs.

    Args:
        seq_len: Number of time steps to reconstruct.
        input_dim: Size of the incoming embedding.
        batch_size: Nominal batch size. Stored for backward compatibility;
            ``forward`` infers the actual batch size from its input.
        n_features: Number of features per reconstructed time step.
    """

    def __init__(self, seq_len, input_dim, batch_size, n_features):
        super(Decoder, self).__init__()
        self.seq_len = seq_len
        self.input_dim = input_dim
        self.n_features = n_features
        self.batch_size = batch_size
        self.hidden_dim1 = 2 * input_dim
        self.hidden_dim2 = 4 * input_dim

        self.rnn1 = nn.LSTM(
            input_size=input_dim,
            hidden_size=input_dim,  # input_dim
            num_layers=1,
            batch_first=True
        )

        self.rnn2 = nn.LSTM(
            input_size=input_dim,  # input_dim
            hidden_size=self.hidden_dim2,  # 4 * input_dim
            num_layers=1,
            batch_first=True
        )
        self.rnn3 = nn.LSTM(
            input_size=self.hidden_dim2,  # 4 * input_dim
            hidden_size=self.hidden_dim1,  # 2 * input_dim
            num_layers=1,
            batch_first=True
        )
        self.output_layer = nn.Linear(self.hidden_dim1, n_features)

    def forward(self, x):
        """Decode embeddings ``x`` into a ``(batch, seq_len, n_features)`` tensor.

        ``x`` is reshaped to ``(batch, input_dim)``; the batch size is inferred.
        """
        # Add a time axis and repeat along it only, so every step of sample b
        # is exactly x[b]. Tiling the flattened batch would interleave
        # different samples' embeddings inside one sequence.
        x = x.reshape(-1, 1, self.input_dim).repeat(1, self.seq_len, 1)

        x, _ = self.rnn1(x)
        x, _ = self.rnn2(x)
        # batch_first=True, so x is already (batch, seq_len, hidden_dim1).
        x, _ = self.rnn3(x)

        return self.output_layer(x)

class RAE(nn.Module):
    """Recurrent autoencoder: an ``Encoder`` followed by a ``Decoder``.

    Both halves are moved to the module-level ``device`` (defined outside
    this class) as soon as they are constructed.
    """

    def __init__(self,seq_len, n_features, embedding_dim, batch_size):
        super().__init__()

        self.seq_len = seq_len
        self.n_features = n_features
        self.embedding_dim = embedding_dim

        # Registration order (encoder first, then decoder) is kept as-is.
        encoder = Encoder(
            seq_len=seq_len,
            n_features=n_features,
            batch_size=batch_size,
            embedding_dim=embedding_dim,
        )
        self.encoder = encoder.to(device)
        decoder = Decoder(
            seq_len=seq_len,
            input_dim=embedding_dim,
            batch_size=batch_size,
            n_features=n_features,
        )
        self.decoder = decoder.to(device)

    def forward(self,x):
        """Encode ``x`` to an embedding, then decode it back to a sequence."""
        embedding = self.encoder(x)
        return self.decoder(embedding)


### TRAINING 
def train_model(model,train_dataset,val_dataset, n_epochs, optimizer, criterion):
    """Train ``model`` as an autoencoder and record per-epoch losses.

    Every batch is its own reconstruction target:
    ``loss = criterion(model(batch), batch)``.

    Args:
        model: Module mapping a batch tensor to its reconstruction.
        train_dataset: Iterable yielding input tensors (e.g. a DataLoader).
        val_dataset: Iterable yielding input tensors. These are evaluated
            without gradients after each epoch.
        n_epochs: Number of passes over ``train_dataset``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        criterion: Loss callable ``(prediction, target) -> scalar tensor``.

    Returns:
        ``(model, history)``. ``model`` is in eval mode. ``history`` has lists
        ``'epochs'``, ``'train'`` and ``'val'``; the latter two hold the mean
        per-batch loss of each epoch.
    """
    since = time.perf_counter()
    history = dict(epochs=[], train=[], val=[])

    # Send each batch to wherever the model's parameters live.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in range(n_epochs):
        model = model.train()
        train_losses = []

        for data in train_dataset:
            data = data.to(target_device)

            optimizer.zero_grad()
            # Reshape to the target's shape so the criterion compares
            # element-wise instead of broadcasting e.g. (B, T, 1) vs (B, T).
            data_pred = model(data).reshape_as(data)
            loss = criterion(data_pred, data)

            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())

        val_losses = []
        model = model.eval()
        with torch.no_grad():
            for data in val_dataset:
                data = data.to(target_device)
                data_pred = model(data).reshape_as(data)
                val_losses.append(criterion(data_pred, data).item())

        train_loss = np.mean(train_losses)
        val_loss = np.mean(val_losses)
        history['epochs'].append(epoch)
        history['train'].append(train_loss)
        history['val'].append(val_loss)

        print(f'Epoch {epoch}: train loss {train_loss} val loss {val_loss}')

    time_elapsed = time.perf_counter() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')

    return model.eval(),history
        

## PREDICTIONS
def predict(model, dataset):
    """Run ``model`` over every batch in ``dataset`` and score it with MSE.

    Puts ``model`` in eval mode and disables gradients.

    Args:
        model: Module mapping a batch tensor to its reconstruction.
        dataset: Iterable yielding input tensors.

    Returns:
        ``(predictions, losses)``. ``predictions`` holds the raw model output
        of each batch as a NumPy array, in the model's output shape.
        ``losses`` holds the mean-squared reconstruction error of each batch
        as a float.
    """
    predictions, losses = [], []
    model = model.eval()
    criterion = nn.MSELoss(reduction='mean')

    # Send each batch to wherever the model's parameters live.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        for data in dataset:
            data = data.to(target_device)
            data_pred = model(data)
            # Compare in the target's shape so MSE is element-wise, not broadcast.
            loss = criterion(data_pred.reshape_as(data), data)

            predictions.append(data_pred.cpu().numpy())
            losses.append(loss.item())
    return predictions, losses

def plot_prediction(data,model,title,ax):
    """Plot ``data`` and the model's reconstruction of it on ``ax``.

    ``data`` is passed to ``predict`` as a one-element dataset, so it must be
    a tensor the model accepts as one batch. The title shows that batch's
    reconstruction loss rounded to two decimals.
    """
    predictions, pred_losses = predict(model, [data])

    # matplotlib needs CPU arrays of at most 2 dimensions.
    true = data.detach().cpu().numpy()
    if true.ndim > 2:
        true = true.reshape(-1, true.shape[-1])
    # predict() returns the raw (possibly 3-D) model output; align it with data.
    reconstructed = np.asarray(predictions[0]).reshape(true.shape)

    ax.plot(true, label = 'true')
    ax.plot(reconstructed,label = 'predicted')
    ax.set_title(f'{title} (loss: {np.around(pred_losses[0],2)})')
    ax.legend()
    
