# Classe MLP

## Importando os módulos

In [9]:
import torch
import torch.nn as nn

## Treinamento

In [11]:
class MLP(nn.Module):
    """Fully connected multilayer perceptron for binary or multiclass classification.

    Architecture: ``Linear(n_features, n_hidden[0])``, then one ``Linear`` per
    consecutive pair of widths in ``n_hidden``, then
    ``Linear(n_hidden[-1], n_classes)``. The chosen activation is applied after
    every layer except the output layer, so ``forward`` returns raw logits.

    Args:
        n_epochs: Maximum number of full-batch training epochs.
        n_features: Number of input features per sample.
        n_hidden: Sequence with the width of each hidden layer (at least one).
        n_classes: Number of output units. ``1`` selects binary classification
            with ``BCEWithLogitsLoss``; any other value uses ``CrossEntropyLoss``.
        activation: ``"relu"`` or ``"tanh"``.
        learning_rate: Learning rate given to the optimizer in ``train_model``.
        patience: Consecutive epochs without validation-loss improvement that
            are tolerated before early stopping.
        seed: If not ``None``, passed to ``torch.manual_seed`` before the layers
            are created (this seeds torch's global RNG).

    Raises:
        ValueError: If ``n_hidden`` is empty or ``activation`` is not recognized.
    """

    def __init__(self, n_epochs, n_features, n_hidden, n_classes, activation, learning_rate, patience, seed=None):
        super().__init__()

        if seed is not None:
            torch.manual_seed(seed)

        self.n_epochs = n_epochs
        self.lr = learning_rate
        self.n_classes = n_classes
        self.patience = patience

        # Early-stopping bookkeeping; also reset at the start of every train_model call.
        self._reset_early_stopping()

        if len(n_hidden) == 0:
            raise ValueError("n_hidden deve conter ao menos uma camada escondida.")

        # Input layer
        self.input_layer = nn.Linear(n_features, n_hidden[0])

        # Hidden layers: one Linear per consecutive pair of widths (same creation
        # order as a plain index loop, so seeded initialization is unchanged).
        self.hidden_layers = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(n_hidden[:-1], n_hidden[1:])
        )

        # Output layer (raw logits)
        self.output_layer = nn.Linear(n_hidden[-1], n_classes)

        # Activation applied after the input and hidden layers
        self.activation_name = activation

        if activation == "relu":
            self.activation = nn.ReLU()
        elif activation == "tanh":
            self.activation = nn.Tanh()
        else:
            raise ValueError("Função de ativação não reconhecida. Escolha entre 'relu' ou 'tanh'.")

    def _reset_early_stopping(self):
        """Clear early-stopping state so a new training run starts fresh."""
        self.best_loss = float('inf')
        self.no_improvement_count = 0
        self.stop_training = False

    def forward(self, x):
        """Propagate ``x`` through the network and return the output-layer logits."""
        x = self.activation(self.input_layer(x))

        for layer in self.hidden_layers:
            x = self.activation(layer(x))

        return self.output_layer(x)

    def check_early_stop(self, val_loss):
        """Update early-stopping state with this epoch's validation loss.

        A strictly lower loss becomes the new best and resets the counter.
        Otherwise the counter is incremented, and once it reaches ``patience``
        ``stop_training`` is set to True.
        """
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.no_improvement_count = 0
        else:
            self.no_improvement_count += 1
            if self.no_improvement_count >= self.patience:
                self.stop_training = True
                print("Parando treinamento por early stopping.")

    def compute_accuracy(self, output, y_true, n_classes):
        """Return the fraction of correct predictions as a Python float.

        For ``n_classes == 1`` a prediction is positive when its logit is > 0,
        i.e. sigmoid probability > 0.5. Otherwise the predicted class is the
        argmax over dim 1. Targets with more than one dimension (e.g. one-hot)
        are reduced with argmax.
        """
        if n_classes == 1:
            # Flatten both sides so an (N, 1) prediction is never broadcast
            # against an (N,) target into an (N, N) comparison.
            preds = (output > 0).float().reshape(-1)
            y_true = y_true.reshape(-1)
        else:
            preds = torch.argmax(output, dim=1)
            if y_true.ndim > 1:
                y_true = torch.argmax(y_true, dim=1)

        correct = (preds == y_true).float().sum()
        total = y_true.shape[0]

        return (correct / total).item()

    def train_model(self, X_train, X_val, y_train, y_val, optimizer):
        """Train with full-batch gradient descent and per-epoch validation.

        Args:
            X_train, y_train: Training inputs and targets.
            X_val, y_val: Validation inputs and targets.
            optimizer: ``"sgd"`` or ``"adam"``.

        Returns:
            Tuple ``(loss_train_list, loss_valid_list, accuracy_train_list,
            accuracy_valid_list)`` with one float per completed epoch.

        Raises:
            ValueError: If ``optimizer`` is not recognized.
        """
        if optimizer == "sgd":
            self.optimizer = torch.optim.SGD(self.parameters(), lr=self.lr)
        elif optimizer == "adam":
            self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        else:
            raise ValueError("Otimizador não reconhecido. Escolha entre 'sgd' ou 'adam'.")

        # Single logit -> BCEWithLogitsLoss; multiple classes -> CrossEntropyLoss.
        criterion = nn.BCEWithLogitsLoss() if self.n_classes == 1 else nn.CrossEntropyLoss()

        # Without this, a previous run that stopped early would leave
        # stop_training=True and a stale best_loss, ending this run after one epoch.
        self._reset_early_stopping()

        loss_train_list = []
        loss_valid_list = []
        accuracy_train_list = []
        accuracy_valid_list = []

        for epoch in range(self.n_epochs):
            # Full-batch training step
            self.train()
            output = self(X_train)
            loss = criterion(output, y_train)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Validation pass without building an autograd graph
            self.eval()
            with torch.no_grad():
                val_output = self(X_val)
                val_loss = criterion(val_output, y_val)

            train_loss_value = loss.item()
            val_loss_value = val_loss.item()
            loss_train_list.append(train_loss_value)
            loss_valid_list.append(val_loss_value)

            # Train accuracy uses the logits computed before this epoch's optimizer step.
            accuracy_train = self.compute_accuracy(output.detach(), y_train, self.n_classes)
            accuracy_valid = self.compute_accuracy(val_output, y_val, self.n_classes)
            accuracy_train_list.append(accuracy_train)
            accuracy_valid_list.append(accuracy_valid)

            print(f"Epoch {epoch+1}/{self.n_epochs} - Train Loss: {train_loss_value:.4f} - Val Loss: {val_loss_value:.4f} - Train Accuracy: {accuracy_train:.4f} - Val Accuracy: {accuracy_valid:.4f}")

            self.check_early_stop(val_loss_value)

            if self.stop_training:
                break

        return loss_train_list, loss_valid_list, accuracy_train_list, accuracy_valid_list
