In [None]:
from torch import empty
import torch
torch.set_grad_enabled(False)

### The `Module` base class and the `Tanh` activation

In [None]:
class Module:
    """Base class of the framework's layers; subclasses override the hooks.

    Both ``forward`` and ``backward`` take a variable number of positional
    tensors, so a tuple of tensors must be unpacked with ``*`` when passed in:

        tup = empty(2, 2).normal_(), empty(1, 2).normal_()
        m = SomeModule()
        m.forward(*tup)
        m.forward(empty(2, 2).normal_(), empty(1, 2).normal_())  # equivalent

    On the base class itself, ``forward`` and ``backward`` raise
    ``NotImplementedError``.
    """

    def forward(self, *input):
        """Forward pass over the given tensor(s); must be overridden."""
        raise NotImplementedError

    def backward(self, *gradwrtoutput):
        """Backward pass over the given output gradient(s); must be overridden."""
        raise NotImplementedError

    def param(self):
        """Return this module's parameters as a new list (empty for the base class)."""
        return list()


class Tanh(Module):
    """Element-wise hyperbolic-tangent activation.

    ``forward`` takes any number of tensors and returns a tuple with tanh
    applied to each one. The inputs are cached so that ``backward`` can
    evaluate the derivative at the same points (chain rule:
    grad_input = grad_output * tanh'(input)).
    """

    def __init__(self):
        # Inputs of the most recent ``forward`` call; None until forward runs.
        self.input = None

    def tanh(self, x):
        """Return tanh(x) element-wise, without overflow for large |x|.

        The textbook form (e^x - e^-x) / (e^x + e^-x) evaluates to
        inf / inf = nan once e^|x| overflows (|x| > ~88.7 in float32), and
        rounds to 0 for very small |x| because e^x rounds to 1.
        With m = e^{-2|x|} - 1 (computed by ``expm1``, so m lies in [-1, 0]
        and is accurate near 0), the equivalent expression
            tanh(x) = sign(x) * (-m) / (2 + m)
        never overflows and keeps full precision for small |x|.
        """
        m = (-2 * x.abs()).expm1()
        return x.sign() * (-m) / (2 + m)

    def d_tanh(self, x):
        """Return tanh'(x) = sech^2(x) = 4 / (e^x + e^-x)^2 element-wise.

        Numerator and denominator are multiplied by e^{-2|x|}, giving
        4 e / (1 + e)^2 with e = e^{-2|x|} <= 1, so no intermediate overflows.
        """
        e = (-2 * x.abs()).exp()
        return 4 * e / (1 + e).pow(2)

    def forward(self, *input):
        """Apply tanh to every input tensor and cache the inputs.

        Returns a tuple with one output per input tensor.
        """
        self.input = input
        return tuple(self.tanh(tensor) for tensor in input)

    def backward(self, *gradwrtoutput):
        """Propagate gradients through tanh using the inputs cached by ``forward``.

        Expects exactly one gradient per cached input, each multiplied
        element-wise by tanh' at the corresponding input.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
            ValueError: if the number of gradients differs from the number of
                inputs seen by the last ``forward`` call.
        """
        if self.input is None:
            raise RuntimeError("Tanh.backward called before Tanh.forward")
        if len(gradwrtoutput) != len(self.input):
            raise ValueError("expected {} gradient(s), got {}".format(
                len(self.input), len(gradwrtoutput)))
        return tuple(grad * self.d_tanh(x)
                     for grad, x in zip(gradwrtoutput, self.input))

### Small example: a forward and backward pass through `Tanh`

In [None]:
# Nothing below is random; the seed only keeps any random inputs added later reproducible.
torch.manual_seed(0)

# Two inputs of different shapes, both filled with ones.
inputs = empty(2, 2).fill_(1), empty(1, 2).fill_(1)
# One gradient per output, shaped like that output rather than relying on broadcasting.
grads_wrt_output = empty(2, 2).fill_(2e-4), empty(1, 2).fill_(0)
m = Tanh()

outputs = m.forward(*inputs)
print("forward: {}\n".format(outputs))

grads_wrt_input = m.backward(*grads_wrt_output)
print("backward: {}\n".format(grads_wrt_input))

# Sanity checks against torch's reference tanh and its derivative 1 - tanh^2.
assert all(out.shape == inp.shape for out, inp in zip(outputs, inputs))
assert all(torch.allclose(out, inp.tanh()) for out, inp in zip(outputs, inputs))
assert all(torch.allclose(g_in, g_out * (1 - inp.tanh().pow(2)))
           for g_in, g_out, inp in zip(grads_wrt_input, grads_wrt_output, inputs))