In [None]:
'''hRnnPtLitIni_v1b3.ipynb [++++] A minimalist definition of RNNs using PyTorch and Lightning.
''';
# AUTHOR: Hendrik Mandelkow

# Imports

In [None]:
import re

import lightning as pl
import matplotlib.pyplot as plt
import numpy as np
import torch


In [None]:
def hbarg(X,W=0.9):
    '''Grouped bar chart of one or several data series.

    X : a single series or several series stacked along axis 0; it is
        promoted to at least 2-D, and each row X[n] is drawn as one set of bars.
    W : total width shared by the bars of one group (each bar is W/len(X) wide).

    Groups are laid out around the integer positions 1..K (K = X[0].size),
    which are also used as x-ticks.
    Returns the list of objects returned by plt.bar, one per row of X.
    '''
    X = np.r_['0,2',X]  # force >= 2 dims so that iterating yields one series per row
    N = len(X)
    handles = []
    for k, row in enumerate(X):
        # Left edge of series k inside each group (bars are edge-aligned).
        left = np.arange(row.size) + k*W/N - W/2 + 1
        handles.append( plt.bar(left, row, W/N, align='edge') )
    plt.xticks(np.arange(1, X[0].size+1))
    return handles


# DataSet

# PyTorch Models RNN + RNN-AE
1. Declare Mdl as a subclass of nn.Module.
2. Declare all required layer objects as attributes: Mdl.Layer1 = ...
3. Define the graph linking these objects by implementing forward(x).

In [None]:
class hRnn(torch.nn.Module):
    '''Stack of recurrent and/or (bi)linear layers built from a compact parameter list.

    MdlPar = [Nin, item, item, ...] where Nin is the input feature size and
    every following item is one of:
      'lstm' | 'gru' | 'linear' | 'bilinear' : layer class used for the layers that follow
      'bidir'                                : make the following recurrent layers bidirectional
      'relu' | 'lelu' | 'tanh' | 'ident' | 'noact' : activation applied to the output of the layers that follow
      int N                                  : add one layer with N output units (Base**N if Base is given)
      str '<N>[p<P>][r<R>]'                  : add a recurrent layer with hidden size N,
                                               optional proj_size P (LSTM only) and R stacked layers,
                                               e.g. '32p16r2'. Base is not applied to this form.

    Example: hRnn([1, 'lstm', 'relu', 32, 16])
    '''
    # Keyword -> layer class.
    _LAYER_CLASSES = {
        'lstm': torch.nn.LSTM,
        'gru': torch.nn.GRU,
        'linear': torch.nn.Linear,
        'bilinear': torch.nn.Bilinear,
    }
    # Keyword -> factory creating a fresh activation module.
    _ACTIVATIONS = {
        'relu': torch.nn.ReLU,
        'lelu': lambda: torch.nn.LeakyReLU(0.01),
        'tanh': torch.nn.Tanh,
        'ident': torch.nn.Identity,
        'noact': torch.nn.Identity,
    }

    def __init__(self, MdlPar, Base=None):
        '''Build the layer stack described by MdlPar (see class docstring).

        Raises ValueError for an entry that is neither a known keyword nor a
        size specification, or for a size specification that the currently
        selected layer class cannot honour (e.g. proj_size on a GRU).
        '''
        super().__init__()
        self.MdlPar = list(MdlPar)  # private copy; the caller's list is not modified
        self.layers = torch.nn.ModuleList()
        self.lossfun = torch.nn.MSELoss() # default

        LayerCl = torch.nn.LSTM
        LayerAct = torch.nn.Identity()
        Bidir = False

        LayerNin = MdlPar[0]
        for L in MdlPar[1:]:
            IsStr = isinstance(L, str)
            if IsStr and L in self._LAYER_CLASSES:
                LayerCl = self._LAYER_CLASSES[L]
                continue
            if IsStr and L == 'bidir':
                Bidir = True
                continue
            if IsStr and L in self._ACTIVATIONS:
                LayerAct = self._ACTIVATIONS[L]()
                continue

            # Everything else must be a size specification.
            if isinstance(L, int):
                LayerNout = L if Base is None else Base**L
                Proj, Rep = 0, 1
            elif IsStr and re.match(r'\d+', L):
                LayerNout = int( re.match(r'\d+', L)[0] ) # NB: match[0] = full match
                ProjMatch = re.search(r'\dp(\d+)', L)
                RepMatch = re.search(r'\dr(\d+)', L)
                Proj = int( ProjMatch[1] ) if ProjMatch else 0
                Rep = int( RepMatch[1] ) if RepMatch else 1
            else:
                raise ValueError('Illegal value in MdlPar.')

            if LayerCl in (torch.nn.Linear, torch.nn.Bilinear):
                # Feed-forward layers have no direction; this also resets the
                # flag for any recurrent layer declared afterwards.
                Bidir = False
            layer = self._make_layer(LayerCl, LayerNin, LayerNout, Proj, Rep, Bidir)
            layer.activation = LayerAct
            self.layers.append( layer )
            # Output feature size seen by the next layer: an LSTM with proj_size > 0
            # emits proj_size features per direction instead of hidden_size.
            LayerNin = (Proj or LayerNout) * (2 if Bidir else 1)


    @staticmethod
    def _make_layer(LayerCl, Nin, Nout, Proj=0, Rep=1, Bidir=False):
        '''Instantiate one layer of class LayerCl with the options that class supports.'''
        if LayerCl in (torch.nn.Linear, torch.nn.Bilinear):
            if Proj or Rep != 1:
                raise ValueError('Illegal value in MdlPar: p/r options require a recurrent layer.')
            if LayerCl is torch.nn.Linear:
                return LayerCl( Nin, Nout )
            return LayerCl( Nin, Nin, Nout )  # bilinear form of the input with itself

        Opts = dict( num_layers=Rep, batch_first=True, bidirectional=Bidir)
        if Proj:
            # proj_size is an LSTM-only constructor argument.
            if LayerCl is not torch.nn.LSTM:
                raise ValueError('Illegal value in MdlPar: proj_size (p) requires an LSTM layer.')
            Opts['proj_size'] = Proj
        return LayerCl( Nin, Nout, **Opts)


    def forward( self, x, Nt=None):
        '''Run x through all layers.

        Nt is None (encoder mode): x is padded to 3 dims; a 1-D input gets a
            trailing feature axis and a leading axis, a 2-D input a leading axis.
        Nt given (decoder mode): x is reshaped to (-1, 1, x.shape[-1]) and
            repeated Nt times along axis 1 before entering the first layer.

        Returns a flat tuple: the activated output of the last layer followed by
        whatever extra tensors that layer returned (e.g. h_n, c_n for an LSTM,
        h_n for a GRU, nothing for Linear/Bilinear).
        '''
        if Nt is None: # encoder!
            if x.ndim < 2: x = x.unsqueeze(-1)
            if x.ndim < 3: x = x.unsqueeze(0)
        else: # decoder!
            x = x.reshape(-1,1,x.shape[-1]).repeat( 1, Nt, 1)

        xhc = (x,)
        for layer in self.layers:
            # NOTE: torch.nn.Sequential cannot be used because recurrent layers
            # return (output, state) rather than a single tensor.
            if isinstance( layer, torch.nn.Bilinear ):
                out = layer( xhc[0], xhc[0])
            else:
                out = layer( xhc[0]) # out = [h_1,h_2,...,h_n] (+ state for RNNs)
            if isinstance(out,tuple): # make output uniform: a tuple of non-tuples
                xhc = sum([ n if isinstance(n,tuple) else (n,) for n in out],())
            else:
                xhc = (out,)
            # Only the sequence output is passed through the activation.
            xhc = ( layer.activation(xhc[0]), ) + xhc[1:]

        return xhc


    @staticmethod
    def mdlpar2id(MdlPar):
        '''[+++] Create model ID (str) for logging.

        Uses MdlPar[1:-3]: strings contribute their first two characters
        (capitalized), numbers contribute int(log2(n)).
        E.g. [1,'lstm','relu',32,16,32,'linear','ident',1] -> 'LsRe545'.
        '''
        MdlId = ''.join( [ n[:2].capitalize() if isinstance(n,str) else str(int(np.log2(n))) for n in MdlPar[1:-3]])
        return MdlId


In [None]:
def hMdlPar2Id(MdlPar):
    '''Create a concise model ID (str) from MdlPar[1:-3].

    Strings contribute their first two characters (capitalized), numbers
    contribute int(log2(n)), e.g.
    [1,'lstm','relu',32,16,32,'linear','ident',1] -> 'LsRe545'.
    (A more verbose alternative is '-'.join(str(n) for n in MdlPar[1:-3]).)
    '''
    return ''.join( n[:2].capitalize() if isinstance(n,str) else str(int(np.log2(n))) for n in MdlPar[1:-3] )


In [None]:
class hRnnAE(torch.nn.Module):
    '''hRnnAE [1a1]
    Automatically split input MdlPar for encoder and decoder RNN:
    Mdl = hRnnAE( MdlPar = [1,'lstm','relu',32,16,32,'linear','ident',1] ) ->
    -> Mdl.encoder = hRnn( MdlPar = [1,'lstm','relu',32,16] )
    -> Mdl.decoder = hRnn( MdlPar = [16,'lstm','relu',32,'linear','ident',1] )

    The split point is the first occurrence of the smallest integer layer size
    strictly between the first (input) and last (output) entry of MdlPar.
    '''
    def __init__( self, MdlPar ):
        super().__init__()
        self.MdlPar = MdlPar # save for good measure
        self.MdlId = self.mdlpar2id(self.MdlPar)

        EncPar, DecPar = self._split_mdlpar(MdlPar)
        self.encoder = hRnn(EncPar)
        self.decoder = hRnn(DecPar)

        self.lossfun = torch.nn.MSELoss() # reduction="mean" (default)
        # Must be created after encoder/decoder so that self.parameters() is populated.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        self.logger = pl.pytorch.loggers.CSVLogger('LitLog',self.MdlId,None,'',10) # root, name, version, prefix, flush


    @staticmethod
    def _split_mdlpar(MdlPar):
        '''Split MdlPar at the bottleneck; return (EncPar, DecPar) as lists.

        EncPar is everything up to and including the bottleneck size. DecPar
        starts with the bottleneck size (its input width), repeats every string
        option that preceded the bottleneck, then continues with the remainder.

        Raises ValueError when there is no integer size between input and output.
        '''
        MdlPar = list(MdlPar)
        Sizes = [ p for p in MdlPar[1:-1] if isinstance(p,int) ]
        if not Sizes:
            raise ValueError('MdlPar needs at least one integer layer size between input and output.')
        # Search only MdlPar[1:-1]: otherwise an input width equal to the
        # bottleneck width would be found first and leave the encoder empty.
        n = MdlPar.index( min(Sizes), 1, len(MdlPar)-1 )
        StrPar = [ p for p in MdlPar[:n] if isinstance(p,str) ]
        return MdlPar[:n+1], MdlPar[n:n+1] + StrPar + MdlPar[n+1:]


    def forward(self, x):
        '''Encode x, then decode from element 1 of the encoder output back to the input length.'''
        xhc = self.encoder(x)
        Nt = xhc[0].shape[-2] # sequence length of the encoder output
        xhc = self.decoder(xhc[1], Nt) # xhc[1]: presumably h_n of the last recurrent encoder layer
        return xhc[0]


    def fit( self, train_set, valid_set=None, lossfun=None, optimizer=None, lrscheduler=None, device=None):
        '''Placeholder: not implemented, all arguments are ignored; returns None.'''
        return None


    @staticmethod
    def mdlpar2id(MdlPar):
        '''Create model ID (str) for logging from MdlPar[1:-3].

        Strings contribute their first two characters (capitalized), numbers
        contribute int(log2(n)).
        '''
        MdlId = ''.join( [ n[:2].capitalize() if isinstance(n,str) else str(int(np.log2(n))) for n in MdlPar[1:-3]])
        return MdlId


In [None]:
class hRnnAE2(torch.nn.Module):
    '''hRnnAE [2a1]
    RNN autoencoder with explicitly separate parameter lists:
    EncPar configures the encoder hRnn, DecPar the decoder hRnn.
    The last entry of EncPar must equal the first entry of DecPar.
    '''
    def __init__( self, EncPar, DecPar ):
        super().__init__()
        assert EncPar[-1] == DecPar[0], 'Oops! Encoder output must equal decoder input.'
        # Keep the raw parameter lists for reference.
        self.EncPar, self.DecPar = EncPar, DecPar
        # Encoder is built (and registered) before the decoder.
        self.encoder, self.decoder = hRnn( EncPar ), hRnn( DecPar )

        # The optimizer needs the sub-modules above to exist already.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        self.lossfun = torch.nn.MSELoss() # reduction="mean" (default)


    def forward(self, x):
        '''Encode x, then decode element 1 of the encoder output over the same sequence length.'''
        encoded = self.encoder(x)
        seq_len = encoded[0].shape[-2]
        return self.decoder(encoded[1], seq_len)[0]


    def fit( self, train_set, valid_set=None, optimizer=None, lossfun=None, device=None):
        '''Placeholder: not implemented, all arguments are ignored; returns None.'''


# Lightning wrapper

In [None]:
## Simpler: Just wrap hRnnAE
class hMdlLit(pl.LightningModule):
    '''Thin Lightning wrapper around a torch model.

    The wrapped model must provide `lossfun`; an optimizer (and optionally a
    scheduler) is looked up first on the wrapper, then on the wrapped model.
    '''
    def __init__( self, Mdl, **props ):
        '''Mdl: the model to wrap. **props: extra attributes stored on the wrapper
        (e.g. optimizer=..., scheduler=...).'''
        super().__init__()
        self.model = Mdl
        # NB: written straight into __dict__, i.e. bypassing Module.__setattr__.
        self.__dict__.update( **props )
        # NOTE(review): this also records Mdl itself as a hyperparameter;
        # consider save_hyperparameters(ignore=['Mdl']) - confirm against checkpoint usage.
        self.save_hyperparameters()


    def forward(self, x):
        return self.model(x)


    def _batch_loss(self, batch):
        '''Compute the loss for one batch (shared by training and validation).

        A tuple/list batch is unpacked as (x, y); any other batch is used as
        both input and target (autoencoder). If the model returns a tuple/list,
        its first element is taken as the prediction.
        '''
        if isinstance( batch, (tuple,list)):
            x, y = batch
        else:
            x = batch
            y = x # autoencoder

        yh = self(x)
        if isinstance( yh, (tuple,list)): yh = yh[0]
        return self.model.lossfun(yh, y)


    def training_step(self, batch, batch_idx):
        '''training_step returns the loss for one batch.
        The rest of the training loop is handled by Lightning implicitly.
        '''
        loss = self._batch_loss(batch)
        self.log("Tloss", loss, prog_bar=True)
        return loss


    def configure_optimizers(self):
        '''Return the optimizer, or ([optimizer], [scheduler]) when a scheduler is set.

        Both are looked up on self first, then on self.model.
        '''
        optimizer = getattr(self,'optimizer',None)
        if optimizer is None: optimizer = getattr(self.model,'optimizer',None)
        assert optimizer, 'Oops! No opimizer found in self.optimizer or self.model.optimizer?!?'
        scheduler = getattr(self,'scheduler',None)
        if scheduler is None: scheduler = getattr(self.model,'scheduler',None)
        if scheduler is None:
            return optimizer
        # Lightning accepts "two lists" (optimizers, schedulers); a bare
        # (optimizer, scheduler) tuple is not one of its supported formats.
        return [optimizer], [scheduler]


    def validation_step( self, batch, batch_idx):
        '''Compute and log the validation loss for one batch; handles the same
        batch formats as training_step.
        '''
        loss = self._batch_loss(batch)
        self.log("Vloss", loss, prog_bar=True)
        return loss
    
    
#     def test_step(self,...): # model.test() == model.test_step()
    
#     def predict_step(self,...): # if undefined == forward()
    
    

## Test with random data