In [None]:
import numpy as np

class Regularisation:
    """Common state shared by the regularisation helpers in this module.

    Holds the ``network`` whose layer weights are penalised and a ``weight``
    value that is stored unchanged (``ApplyReg`` forwards it to the gradient
    helpers).
    """

    def __init__(self, network, weight=0):
        # Both values are kept by reference; nothing is copied or validated.
        self.network, self.weight = network, weight

class L2_regularisation(Regularisation):
    """L2 (squared-magnitude) penalty over every layer of ``self.network``.

    ``self.network`` is iterated as a sequence of layers, each exposing a
    ``weight`` array.
    """

    def Apply_L2(self):
        """Return the L2 penalty ``sum(w ** 2)`` accumulated over all layers.

        If any layer's weights contain NaN, a warning is printed and 0 is
        returned instead of a partial sum.
        """
        total = 0
        for j, layer in enumerate(self.network):
            w = layer.weight
            if np.isnan(w).any():
                print(f"Warning: NaN detected in network weights at layer {j}")
                return 0
            # No 0.5 factor: Apply_L2_grad returns 2 * w, the derivative of
            # w ** 2, so the reported penalty must be sum(w ** 2) to match the
            # objective that the gradient actually optimises.
            total += np.sum(w ** 2)
        return total

    def Apply_L2_grad(self, weight):
        """Return the gradient of ``sum(weight ** 2)`` w.r.t. ``weight``, i.e. ``2 * weight``."""
        return 2 * weight

class L1_regularisation(Regularisation):
    """L1 (absolute-value) penalty over every layer of ``self.network``.

    ``self.network`` is iterated as a sequence of layers, each exposing a
    ``weight`` array.
    """

    def Apply_L1(self):
        """Return the L1 penalty ``sum(|w|)`` accumulated over all layers.

        If any layer's weights contain NaN, a warning is printed and 0 is
        returned instead of a partial sum.
        """
        total = 0
        for j, layer in enumerate(self.network):
            w = layer.weight
            if np.isnan(w).any():
                print(f"Warning: NaN detected in network weights at layer {j}")
                return 0
            # No 1/2 factor: Apply_L1_grad returns sign(w), the (sub)gradient of
            # |w| rather than |w| / 2, so the reported penalty must be sum(|w|).
            total += np.sum(np.abs(w))
        return total

    def Apply_L1_grad(self, weight):
        """Return the subgradient of ``sum(|weight|)``: ``np.sign(weight)`` (0 where weight is 0)."""
        return np.sign(weight)

class ApplyReg(Regularisation):
    """Select a regularisation penalty or gradient by name.

    Supported ``reg_function`` values:
      * ``'L2'`` / ``'L1'``     -- penalty summed over the layers of ``self.network``;
      * ``'L2_d'`` / ``'L1_d'`` -- gradient of that penalty evaluated at ``self.weight``.
    """

    def __init__(self, reg_function, network, weight=0):
        self.reg_function = reg_function
        super().__init__(network, weight)

    def do_reg(self):
        """Return the penalty or gradient selected by ``self.reg_function``.

        Raises:
            ValueError: if ``reg_function`` is not one of the supported names.
        """
        if self.reg_function == 'L2':
            return L2_regularisation(self.network).Apply_L2()
        if self.reg_function == 'L1':
            return L1_regularisation(self.network).Apply_L1()
        if self.reg_function == 'L2_d':
            return L2_regularisation(self.network).Apply_L2_grad(self.weight)
        if self.reg_function == 'L1_d':
            return L1_regularisation(self.network).Apply_L1_grad(self.weight)
        # Fail loudly here. Returning None would only surface later as an opaque
        # TypeError when a caller multiplies the result by a weight decay.
        raise ValueError(
            f"Unknown regularisation function {self.reg_function!r}; "
            "expected one of 'L2', 'L1', 'L2_d', 'L1_d'"
        )


class CalculateAllLoss:
    """Combine a primary loss with an optional weight penalty and compute accuracy.

    This code assumes samples lie along axis 1 and classes along axis 0: the
    batch size is ``X_train.shape[1]`` and predictions/labels are compared via
    ``argmax(axis=0)``.
    """

    def __init__(self, X_train, y_predicted, network, y_train, primary_loss, weight_decay=0, regularisation_fn=None):
        """Store the batch data and compute accuracy and total loss immediately.

        Args:
            X_train: input batch; ``X_train.shape[1]`` is taken as the batch size.
            y_predicted: network outputs for the batch.
            network: layers handed to ``ApplyReg`` when regularisation is enabled.
            y_train: target labels; must have the same ``shape[1]`` as ``X_train``.
            primary_loss: precomputed data loss for the batch.
            weight_decay: multiplier for the penalty; values ``<= 0`` disable it.
            regularisation_fn: name understood by ``ApplyReg`` (e.g. ``'L2'``,
                ``'L1'``); a falsy value disables regularisation.

        Raises:
            ValueError: if ``X_train`` and ``y_train`` disagree on batch size.
        """
        self.y_predicted = y_predicted
        self.y_true = y_train
        self.network = network
        self.X_train = X_train
        self.loss_value = primary_loss
        self.weight_decay = weight_decay
        self.regularisation_fn = regularisation_fn
        # Keep the result rather than discarding it, so callers can read it
        # without recomputing (and re-printing) the regularisation term.
        self.accuracy, self.total_loss = self.calc_accuracy_loss()

    def _regularised_loss(self):
        """Return the primary loss plus ``weight_decay`` times the penalty, if enabled."""
        total_loss = self.loss_value
        if self.weight_decay > 0 and self.regularisation_fn:
            regularized_val = ApplyReg(self.regularisation_fn, self.network).do_reg()
            print(f"Reg value: {regularized_val}")
            # Build a new object instead of using `+=`. If loss_value is an ndarray,
            # `+=` would mutate self.loss_value in place, and every later call
            # would add the penalty again.
            total_loss = total_loss + self.weight_decay * regularized_val
        return total_loss

    def overall_loss(self):
        """Return the total loss: primary loss plus the weighted regularisation penalty."""
        return self._regularised_loss()

    def calc_accuracy_loss(self):
        """Return ``(accuracy, total_loss)`` for the stored batch.

        ``accuracy`` is the fraction of samples whose ``argmax`` over axis 0
        matches between ``y_predicted`` and ``y_true``.

        Raises:
            ValueError: if ``X_train`` and ``y_true`` disagree on batch size.
        """
        total_loss = self._regularised_loss()

        batch_size = self.X_train.shape[1]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if batch_size != self.y_true.shape[1]:
            raise ValueError("Mismatch in batch size between inputs and labels")

        correct_predictions = np.sum(
            np.argmax(self.y_predicted, axis=0) == np.argmax(self.y_true, axis=0)
        )
        accuracy = correct_predictions / batch_size
        return accuracy, total_loss

