In [2]:
import numpy as np 
import torch
import matplotlib.pyplot as plt

In [4]:
class Layer(torch.nn.Module):
    """Fully connected layer computing ``activation_func(x @ w + b)``.

    Parameters
    ----------
    size_in : int
        Number of input features (columns of ``x``).
    size_out : int
        Number of output features.
    activation_func : callable, optional
        Function applied to the affine output. Defaults to ``torch.sigmoid``.

    Attributes
    ----------
    w : torch.nn.Parameter
        Weight matrix of shape ``(size_in, size_out)``, drawn from a standard
        normal distribution.
    b : torch.nn.Parameter
        Bias row of shape ``(1, size_out)``, drawn from a standard normal
        distribution.
    """

    def __init__(self, size_in, size_out, activation_func = torch.sigmoid):
        super(Layer, self).__init__()

        # nn.Parameter enables requires_grad by default, so it does not need
        # to be requested on the wrapped tensor.
        self.w = torch.nn.Parameter(torch.randn(size_in, size_out))
        self.b = torch.nn.Parameter(torch.randn(1, size_out))

        self.activation_func = activation_func

    def Forward(self, x):
        """Apply the layer to ``x`` and return the activated output.

        ``x`` must have ``size_in`` columns so that ``x @ w`` is defined; the
        bias row is broadcast over the rows of the product.
        """
        return self.activation_func(x @ self.w + self.b)

    def forward(self, x):
        """Standard ``nn.Module`` entry point, so ``layer(x)`` works.

        Delegates to ``Forward`` so subclasses overriding ``Forward`` keep
        working through the call syntax as well.
        """
        return self.Forward(x)
        

In [None]:
class LSTM(torch.nn.Module):
    """LSTM-style recurrent network assembled from ``Layer`` blocks.

    The cell carries a long-term memory ``mem_lt`` (``size_mem_lt`` columns)
    and a short-term memory ``mem_st`` (``size_mem_st`` columns). For every
    time step, with ``z = [x_t, mem_st]`` concatenated along the columns:

        mem_lt = forget_gate(z) * mem_lt + memory_gate(z) * memory_layer(z)
        mem_st = st_recall_layer(z) * lt_recall_layer(mem_lt)
        out_t  = Layer_out(mem_st)

    Parameters
    ----------
    size_in : int
        Number of features of each input row.
    size_out : int
        Number of features of each output row.
    size_mem_lt : int
        Width of the long-term memory.
    size_mem_st : int
        Width of the short-term memory.
    """

    def __init__(self, size_in, size_out, size_mem_lt, size_mem_st ):
        super(LSTM, self).__init__()

        # Every gate sees the current input concatenated with the short-term memory.
        size_z = size_in + size_mem_st

        self.forget_gate     = Layer(size_z, size_mem_lt, activation_func = torch.sigmoid)
        self.memory_gate     = Layer(size_z, size_mem_lt, activation_func = torch.sigmoid)
        self.memory_layer    = Layer(size_z, size_mem_lt, activation_func = torch.tanh)
        # Maps the long-term memory onto the short-term memory width.
        self.lt_recall_layer = Layer(size_mem_lt, size_mem_st, activation_func = torch.tanh)
        self.st_recall_layer = Layer(size_z, size_mem_st, activation_func = torch.sigmoid)
        self.Layer_out       = Layer(size_mem_st, size_out)
        self.size_mem_lt     = size_mem_lt
        self.size_mem_st     = size_mem_st

    def _step(self, x_t, mem_lt, mem_st):
        """Advance the cell by one time step.

        ``x_t`` is a single input row of shape ``(1, size_in)``. Returns the
        tuple ``(out_t, mem_lt, mem_st)`` holding the output row and the
        updated long-term and short-term memories.
        """
        z = torch.cat([x_t, mem_st], dim = 1)

        # Keep part of the old long-term memory and add the gated candidate.
        keep      = self.forget_gate.Forward(z)
        write     = self.memory_gate.Forward(z)
        candidate = self.memory_layer.Forward(z)
        mem_lt = keep * mem_lt + write * candidate

        # The long-term memory must go through lt_recall_layer (it expects
        # size_mem_lt columns, not z); the result is gated by st_recall_layer.
        mem_st = self.st_recall_layer.Forward(z) * self.lt_recall_layer.Forward(mem_lt)

        return self.Layer_out.Forward(mem_st), mem_lt, mem_st

    def Forward(self, x):
        """Run the network over a sequence, starting from zeroed memories.

        ``x`` is a 2-D tensor with one time step per row and ``size_in``
        columns. Returns a tensor with one output row per time step and
        ``size_out`` columns; an empty sequence yields a ``(0, size_out)``
        tensor.
        """
        # Create the state with the dtype/device of the parameters.
        ref = self.Layer_out.w
        mem_lt = torch.zeros(1, self.size_mem_lt, dtype = ref.dtype, device = ref.device)
        mem_st = torch.zeros(1, self.size_mem_st, dtype = ref.dtype, device = ref.device)
        out = []

        for i in range(x.shape[0]):
            # x[[i], :] keeps the row two-dimensional for the concatenation.
            out_t, mem_lt, mem_st = self._step(x[[i], :], mem_lt, mem_st)
            out.append(out_t)

        if not out:
            # torch.cat rejects an empty list, so build the empty result directly.
            return torch.zeros(0, ref.shape[1], dtype = ref.dtype, device = ref.device)
        return torch.cat(out, dim = 0)

    def forward(self, x):
        """Standard ``nn.Module`` entry point, so ``model(x)`` works."""
        return self.Forward(x)

    def Generate(self, start, iterations):
        """Generate a sequence by feeding each output back in as the next input.

        ``start`` is the first input row, of shape ``(1, size_in)``. Because
        outputs are reused as inputs, the network must have been built with
        ``size_out == size_in``. Both memories start from standard-normal
        noise, so repeated calls differ unless the torch RNG is seeded.

        Returns a tensor of ``iterations + 1`` rows: ``start`` followed by the
        generated rows.
        """
        ref = self.Layer_out.w
        mem_lt = torch.randn(1, self.size_mem_lt, dtype = ref.dtype, device = ref.device)
        mem_st = torch.randn(1, self.size_mem_st, dtype = ref.dtype, device = ref.device)
        out = [start]

        for _ in range(iterations):
            out_t, mem_lt, mem_st = self._step(out[-1], mem_lt, mem_st)
            out.append(out_t)

        return torch.cat(out, dim = 0)
        
        
            