# In [None]:
from typing import Union
import numpy as np
from time import time


class ShinyActivations:
    """Namespace of element-wise activation callables usable as network layers."""

    class Activation:
        """Base class for activations; concrete subclasses implement ``__call__``."""

        delta = None  # class-level attribute; not read by the visible code

        def __repr__(self):
            return "%s Activation" % type(self).__name__

    class Relu(Activation):
        """Rectified linear unit: keeps positive entries and zeroes the rest."""

        def __call__(self, *args, **kwargs):
            x = args[0]
            positive = x > 0
            # Multiplying by the boolean mask zeroes every non-positive entry.
            return positive * x

    relu = Relu  # lowercase alias for the same class


class ShinyDeactivations:
    """Namespace of activation derivatives, named to mirror ``ShinyActivations``."""

    class Deactivation:
        """Base class for derivative callables."""

        def __repr__(self):
            return "{} Deactivation".format(type(self).__name__)

    class Relu(Deactivation):
        """ReLU derivative expressed as a boolean mask (True where the input is positive)."""

        def __call__(self, *args, **kwargs):
            values = args[0]
            return values > 0

    relu = Relu  # lowercase alias for the same class


class ShinyLayers:
    """Namespace of network layers: the leading ``Input`` and weighted layers."""

    # Learning rate used when a layer is constructed without one.
    default_learning_rate = 1e-4

    class CLayer:
        """Marker base class shared by ``Input`` and every weighted ``Layer``."""

    class Layer(CLayer):
        """Base class for weighted layers.

        :param output_size: number of output units (columns of ``weights``).
        :param learning_rate: step size; falls back to
            ``ShinyLayers.default_learning_rate`` when None.
        """

        def __init__(self, output_size, learning_rate=None):
            self.output_size = output_size
            self.learning_rate = ShinyLayers.default_learning_rate if learning_rate is None else learning_rate
            # Left unset here; must be assigned (e.g. by the trainer) before the layer is called.
            self.weights: Union[np.ndarray, None] = None

        def get_output_size(self):
            return self.output_size

        def __repr__(self):
            return f"{self.__class__.__name__} Layer: output size = {self.output_size}"

    class Linear(Layer):
        """Dense layer without bias: returns ``x . weights``."""

        def __call__(self, *args, **kwargs):
            return np.dot(args[0], self.weights)

    class Input(CLayer):
        """Placeholder for the network input; passes its input through unchanged.

        :param size: number of input features.
        """

        def __init__(self, size):
            self.size = size

        def __call__(self, *args, **kwargs):
            # Identity, so code that forwards through every layer does not fail here.
            return args[0]

        def get_output_size(self):
            # Mirrors Layer.get_output_size: the input layer "outputs" its features.
            return self.size

        def __repr__(self):
            return f"Input Layer: size = {self.size}"


class ShinyErrorMeasure:
    """Namespace of error measures comparing a prediction with its target."""

    class ErrorMeasure:
        """Base error measure; ``measure(prediction, goal)`` dispatches to ``error_measurement``."""

        def __call__(self, *args, **kwargs):
            return self.error_measurement(args[0], args[1])

        def error_measurement(self, x, y):
            # Subclasses must override; fail loudly instead of silently returning None.
            raise NotImplementedError

    class SquaredError(ErrorMeasure):
        """Root of the summed squared differences, i.e. the Euclidean distance.

        Despite the name, the square root is taken (kept for backward compatibility).
        ``np.sum`` reduces over *all* axes so the result is always a scalar, even
        for a 2-D ``(1, k)`` prediction; the builtin ``sum`` would only reduce the
        first axis and return a length-``k`` array.
        """

        def error_measurement(self, x, y):
            return np.sqrt(np.sum((np.asarray(x) - y) ** 2))


class ShinyException:
    """Namespace of exceptions raised by the Shiny trainer."""

    class SetData(Exception):
        """Raised when training is requested before any training data was set."""

    class DataShape(Exception):
        """Raised when training or goal data has an unsupported shape."""


class ShinyInformation:
    """Namespace holding the training-progress record type."""

    class Information:
        """Mutable record of training progress, updated while a model trains."""

        def __init__(self):
            self.error = None      # error of the most recent sample
            self.weights = {}      # not populated by the visible trainer code
            self.delta = []        # output error signal of the most recent sample
            self.iteration = None  # index of the sample within the current epoch
            self.epoch = None      # index of the current epoch
            self.time = None       # seconds elapsed since training started


class ShinyRepresentation:
    """Namespace of progress reporters; each pairs a ``display`` hook with a ``stop_condition``."""

    class Representation:
        """Base reporter: ``display`` must be overridden; ``stop_condition`` never asks to stop."""

        @staticmethod
        def display(information: "ShinyInformation.Information" = None):
            # Takes the same argument subclasses receive, so calling it with an
            # Information object reaches NotImplementedError rather than a
            # TypeError about the signature.
            raise NotImplementedError

        @staticmethod
        def stop_condition(information: "ShinyInformation.Information"):
            return False

    class StandardRepresentation(Representation):
        """Prints progress after every sample; signals stop once the error drops below 1."""

        @staticmethod
        def display(information: "ShinyInformation.Information"):
            print(
                f"Time:{information.time}\n epoch: {information.epoch} iteration: {information.iteration} error: {information.error}, delta: {information.delta}")

        @staticmethod
        def stop_condition(information: "ShinyInformation.Information"):
            # NOTE(review): 1 is a loose threshold for 0/1 targets — confirm intended.
            return information.error < 1


class ShinyLayerOutput:
    """Scratch slot holding one layer's forward value and backward error signals.

    Calling the slot returns its stored ``value``; any arguments are ignored.
    """

    def __init__(self):
        # value: forward-pass output; delta: error at this layer's output;
        # sending_delta: error handed to the layer below. All start empty.
        self.value = self.delta = self.sending_delta = None

    def __call__(self, *args, **kwargs):
        return self.value


class Shiny:
    """Tiny feed-forward network trained with per-sample gradient descent.

    Layers are appended with ``add``, data is supplied with ``set_data`` and
    ``train`` fits the weights; these three return ``self`` so calls chain.
    """

    @staticmethod
    def imports():
        ...

    def __init__(self,
                 error_measurement=ShinyErrorMeasure.SquaredError,
                 representation=ShinyRepresentation.StandardRepresentation):
        """
        :param error_measurement: error-measure *class*; instantiated here.
        :param representation: reporter class whose static ``display`` and
            ``stop_condition`` are called after every training sample.
        """
        self.layers = []
        self.train_X: np.ndarray = None
        self.train_y: np.ndarray = None
        self.error_measurement = error_measurement()
        self.information = ShinyInformation.Information()
        self.representation = representation

    def add(self, *layers: Union[ShinyLayers.CLayer, ShinyActivations.Activation]):
        """Append layers (weighted layers and/or activations) in forward order."""
        self.layers.extend(layers)
        return self

    def set_data(self, train_x: np.ndarray, train_y: np.ndarray):
        """Store the training set.

        :param train_x: 2-D array-like, one row per sample, columns are features.
        :param train_y: 1-D array-like with one target per row of ``train_x``.
        :raises ShinyException.DataShape: wrong dimensionality or length mismatch.
        """
        train_x = np.asarray(train_x)
        train_y = np.asarray(train_y)
        if train_x.ndim != 2:
            raise ShinyException.DataShape(
                f"your train data is {train_x.ndim}D! it must be 2D bro! please be careful with your training data, second dimension must be FEATURES!!!")
        if train_y.ndim != 1:
            raise ShinyException.DataShape(f"your goal data is {train_y.ndim}D! it must be 1D bro!")
        if len(train_x) != len(train_y):
            raise ShinyException.DataShape(
                f"your train data has {len(train_x)} samples but your goal data has {len(train_y)}! they must match bro!")
        self.train_X = train_x
        self.train_y = train_y
        return self

    def represent(self):
        self.representation.display(self.information)

    def train(self, epoch_num=1):
        """Run up to ``epoch_num`` passes of per-sample gradient descent.

        Training ends early — across all epochs, not just the current one — as
        soon as ``representation.stop_condition`` returns a truthy value.
        Calling ``train`` again continues from the current weights.

        :raises ShinyException.SetData: if ``set_data`` was not called.
        """
        if self.train_X is None or self.train_y is None:
            raise ShinyException.SetData("you did not set train data! please do it by set_data method")
        self._prepare_layers()
        layer_outputs = [ShinyLayerOutput() for _ in self.layers]
        start_time = time()
        for epoch in range(epoch_num):
            self.information.epoch = epoch
            for iteration, (sample, goal) in enumerate(zip(self.train_X, self.train_y)):
                self.information.iteration = iteration
                layer_outputs[0].value = sample.reshape(1, -1)
                self._forward(layer_outputs)
                output = layer_outputs[-1].value
                self.information.error = self.error_measurement(output, goal)
                self.information.delta = goal - output
                self._backward(layer_outputs, self.information.delta)
                self._update_weights(layer_outputs)
                self.information.time = time() - start_time
                self.represent()
                if self.representation.stop_condition(self.information):
                    # Return (not break) so the outer epoch loop stops too.
                    return self
        return self

    def _prepare_layers(self):
        """Ensure a single leading Input layer and correctly shaped weights."""
        input_size = self.train_X.shape[1]
        if self.layers and isinstance(self.layers[0], ShinyLayers.Input):
            # Re-training: reuse the existing Input instead of stacking another.
            self.layers[0].size = input_size
        else:
            self.layers.insert(0, ShinyLayers.Input(input_size))
        for layer in self.layers:
            if isinstance(layer, ShinyLayers.Layer):
                shape = (input_size, layer.output_size)
                # Only (re)initialise when missing or incompatible, so repeated
                # train() calls keep learning instead of starting over.
                if layer.weights is None or layer.weights.shape != shape:
                    layer.weights = 2 * np.random.random(shape) - 1
                input_size = layer.output_size

    def _forward(self, layer_outputs):
        """Forward pass; ``layer_outputs[0].value`` must already hold the input row."""
        for index in range(1, len(self.layers)):
            layer_outputs[index].value = self.layers[index](layer_outputs[index - 1].value)

    def _backward(self, layer_outputs, output_delta):
        """Propagate ``output_delta`` from the last layer down to layer 1.

        Weighted layers receive ``delta`` (error at their output); every layer's
        ``sending_delta`` is the error handed to the layer below it.
        """
        incoming = output_delta
        for index in range(len(self.layers) - 1, 0, -1):
            layer, slot = self.layers[index], layer_outputs[index]
            if isinstance(layer, ShinyActivations.Activation):
                # The derivative is looked up by class name and evaluated on the
                # activation's *output*; this is exact for ReLU — verify for others.
                deactivation = getattr(ShinyDeactivations, layer.__class__.__name__.lstrip("_"))()
                incoming = incoming * deactivation(slot.value)
            elif isinstance(layer, ShinyLayers.Layer):
                slot.delta = incoming
                incoming = incoming.dot(layer.weights.T)
            slot.sending_delta = incoming

    def _update_weights(self, layer_outputs):
        """Gradient step ``W += lr * input.T . delta`` for every weighted layer.

        ``input`` is the value actually fed to the layer (the previous layer's
        output, including activations), not the previous weighted layer's output.
        """
        for index in range(1, len(self.layers)):
            layer = self.layers[index]
            if isinstance(layer, ShinyLayers.Layer):
                layer_input = layer_outputs[index - 1].value
                layer.weights += layer.learning_rate * layer_input.T.dot(layer_outputs[index].delta)

    def __predict(self, value):
        """Forward one sample through every non-Input layer."""
        for layer in self.layers:
            if not isinstance(layer, ShinyLayers.Input):
                value = layer(value)
        return value

    def predict(self, test_x):
        """Predict one value per row of ``test_x``.

        Assumes the model has been trained and the last layer has output
        size 1; otherwise the reshape below fails.
        """
        return np.array([self.__predict(data) for data in test_x]).reshape(len(test_x))

    def info(self):
        print(self.layers)


# Toy dataset: 4 samples x 3 features, one 0/1 target per sample.
train_X = np.array([[1, 0, 1],
                    [0, 1, 1],
                    [0, 0, 1],
                    [1, 1, 1]])
train_Y = np.array([1, 1, 0, 0])

# 3 -> 4 (ReLU) -> 1 network; set_data/add/train each return the same model object.
model = Shiny().set_data(train_X, train_Y)
model.add(
    ShinyLayers.Linear(4, learning_rate=0.2),
    ShinyActivations.Relu(),
    ShinyLayers.Linear(1, learning_rate=0.2),
)
model.train(epoch_num=60)
# To score unseen samples: predict_y = model.predict(test_X)