# In [95]:
# %load modules.py
#!/usr/bin/env python
# @Project      : DeepLearningProject2
# @Author       : Xiaoyu LIN
# @File         : modules.py
# @Description  : Layers, activations and MSE loss with hand-written forward/backward passes
import math
import torch
from torch import FloatTensor, LongTensor, Tensor

# In [146]:
class Linear():
    """
    Fully connected layer computing ``output = input @ weights.T + biases``.

    Gradients are derived by hand in :meth:`backward` (no autograd) and
    applied to the parameters by :meth:`gradient_descent`.
    """
    def __init__(self, input_units, hidden_units):
        """
        :param input_units: number of features of each input sample
        :param hidden_units: number of hidden units (output features) in the layer
        :raises ValueError: if ``input_units`` is not positive
        """
        if input_units <= 0:
            raise ValueError("input_units must be a positive integer")

        self.input_units = input_units
        self.hidden_units = hidden_units

        # Zero-mean normal initialisation with std sqrt(2 / input_units).
        self.weights = torch.empty(
            hidden_units, input_units, dtype=torch.float32
        ).normal_(0, math.sqrt(2 / input_units))
        self.biases = torch.zeros(hidden_units, dtype=torch.float32)

        # Gradients are plain tensors shaped like their parameters, so that
        # gradient_descent() and param() work even before the first backward().
        self.grad_wrt_weights = torch.zeros(
            hidden_units, input_units, dtype=torch.float32
        )
        self.grad_wrt_biases = torch.zeros(hidden_units, dtype=torch.float32)
        self.grad_wrt_input = None

        # Cached by forward() for use in backward().
        self.input = None
        self.output = None

    def forward(self, input):
        """
        Apply the affine transformation to a batch of samples.

        :param input: tensor whose last dimension has size ``input_units``
        :return: tensor whose last dimension has size ``hidden_units``
        """
        self.input = input
        self.output = input.matmul(self.weights.t()) + self.biases
        return self.output

    def backward(self, grad_wrt_output):
        """
        Compute the gradients of the loss w.r.t. weights, biases and input.

        The formulas below assume 2-D tensors: ``grad_wrt_output`` of shape
        (batch, hidden_units) and a cached input of shape (batch, input_units).

        :param grad_wrt_output: gradient of the loss w.r.t. this layer's output
        :return: gradient of the loss w.r.t. this layer's input
        :raises RuntimeError: if called before :meth:`forward`
        """
        if self.input is None:
            raise RuntimeError("backward() called before forward()")

        self.grad_wrt_weights = grad_wrt_output.t().matmul(self.input)
        # Every sample in the batch contributes to the shared bias gradient.
        self.grad_wrt_biases = grad_wrt_output.sum(0)
        self.grad_wrt_input = grad_wrt_output.matmul(self.weights)
        return self.grad_wrt_input

    def gradient_descent(self, learning_rate):
        """
        Performs the weights and biases updates
        :param learning_rate: step size of the updates
        """
        # In-place updates keep the identity of the parameter tensors stable.
        self.weights -= learning_rate * self.grad_wrt_weights
        self.biases -= learning_rate * self.grad_wrt_biases

    def param(self):
        """
        Expose the parameters and the most recently computed gradients.

        :return: tuple ``(parameters, gradients)`` of dictionaries;
                 ``gradients["grad_wrt_input"]`` is None until backward() has run
        """
        parameters = {"weights": self.weights, "biases": self.biases}
        gradients = {
            "grad_wrt_weights": self.grad_wrt_weights,
            "grad_wrt_biases": self.grad_wrt_biases,
            "grad_wrt_input": self.grad_wrt_input,
        }

        return parameters, gradients

class Tanh():
    """
    Tanh function
    """
    def __init__(self, input_units=None):
        """
        :param input_units: unused; accepted so the layer can be constructed
                            with the same call shape as the other modules
        """
        pass

    def forward(self, input):
        """
        Apply the hyperbolic tangent element-wise.

        :param input: input tensor (left unmodified)
        :return: tensor of the same shape holding ``tanh(input)``
        """
        self.input = input
        # Out-of-place on purpose: an in-place tanh_() would overwrite the
        # caller's tensor (typically the previous layer's cached output).
        self.output = input.tanh()

        return self.output

    def backward(self, grad_wrt_output):
        """
        Back-propagate through tanh using ``d tanh(x)/dx = 1 - tanh(x)^2``.

        :param grad_wrt_output: gradient of the loss w.r.t. this module's output
        :return: gradient of the loss w.r.t. this module's input
        """
        self.grad_wrt_input = grad_wrt_output * (1 - self.output * self.output)

        return self.grad_wrt_input


class ReLU(object):
    """
    ReLu function
    """
    def __init__(self, input_units=None):
        """
        :param input_units: unused; accepted so the layer can be constructed
                            with the same call shape as the other modules
        """
        pass

    def forward(self, input):
        """
        Apply ``max(0, x)`` element-wise.

        :param input: input tensor (left unmodified)
        :return: tensor of the same shape with negative entries set to zero
        """
        self.input = input
        # Out-of-place on purpose: an in-place relu_() would overwrite the
        # caller's tensor (typically the previous layer's cached output).
        self.output = input.clamp(min=0)

        return self.output

    def backward(self, grad_wrt_output):
        """
        Back-propagate through ReLU: the gradient passes where the input was
        strictly positive and is zeroed elsewhere.

        :param grad_wrt_output: gradient of the loss w.r.t. this module's output
        :return: gradient of the loss w.r.t. this module's input
        """
        # Build the 0/1 mask as a new tensor so the cached input and output
        # are not clobbered.
        derivative = (self.input > 0).to(grad_wrt_output.dtype)
        self.grad_wrt_input = grad_wrt_output * derivative

        return self.grad_wrt_input


class Sequential(object):
    """
    Placeholder container module: the forward and backward passes are not
    implemented yet and it exposes no parameters.
    """

    def forward(self, *input):
        """
        Forward pass (not implemented).

        :param input: input tensors
        :raises NotImplementedError: always
        """
        raise NotImplementedError()

    def backward(self, *gradwrtoutput):
        """
        Backward pass (not implemented).

        :param gradwrtoutput: gradients of the loss w.r.t. the output
        :raises NotImplementedError: always
        """
        raise NotImplementedError()

    def param(self):
        """
        :return: an empty list, as this module holds no parameters
        """
        no_parameters = list()
        return no_parameters


class LossMSE(object):
    """
    compute the MSE loss.

    The loss is the squared error summed over all entries and divided by the
    number of samples (size of the first dimension), so that
    :meth:`compute_grad` is its exact derivative w.r.t. the predictions.
    """
    def __init__(self):
        # Filled in by compute_loss() and read back by compute_grad().
        self.predictions = None
        self.targets = None
        self.samples_num = 0
        self.grad_wrt_pred = None

    def compute_loss(self, predictions, targets):
        """
        Compute the loss and cache the operands for :meth:`compute_grad`.

        :param predictions: tensor with at least one dimension; the first
                            dimension indexes the samples
        :param targets: tensor with as many elements as ``predictions``; it is
                        reshaped to the shape of ``predictions`` if needed
        :return: the loss as a Python float
        :raises ValueError: if the batch is empty or the element counts differ
        """
        samples_num = predictions.size(0)
        if samples_num == 0:
            raise ValueError("cannot compute the loss of an empty batch")

        # Make sure shapes are the same: subtracting e.g. (N,) from (N, 1)
        # would silently broadcast to (N, N).
        if targets.shape != predictions.shape:
            if targets.numel() != predictions.numel():
                raise ValueError(
                    "predictions and targets must hold the same number of elements"
                )
            targets = targets.reshape(predictions.shape)

        self.predictions = predictions
        self.targets = targets
        self.samples_num = samples_num

        return ((predictions - targets) ** 2).sum().item() / samples_num

    def compute_grad(self):
        """
        Gradient of the last computed loss w.r.t. the predictions.

        :return: tensor shaped like the predictions passed to compute_loss()
        :raises RuntimeError: if compute_loss() has not been called yet
        """
        if self.predictions is None:
            raise RuntimeError("compute_grad() called before compute_loss()")

        self.grad_wrt_pred = ((self.predictions - self.targets) * 2) / self.samples_num

        return self.grad_wrt_pred