In [None]:
import os

# Project root. The relative 'src/configuration/*.yaml' paths below (and, presumably,
# the `src.*` imports) resolve against it. Set the CNN_PELS_VAE_DIR environment variable
# to point at your checkout instead of editing this machine-specific default.
new_directory = os.environ.get(
    'CNN_PELS_VAE_DIR',
    '/home/franciscoperez/Documents/GitHub/CNN-PELSVAE2/cnn-pels-vae/',
)
os.chdir(new_directory)

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F


from src.cnn.focalloss import  FocalLossMultiClass as focal_loss
import src.utils as utils 
from src.cnn.training_cnn import initialize_masks, train_one_epoch_alternative, create_dataloader, setup_torch_environment, initialize_optimizers
from src.sampler.getbatch import SyntheticDataBatcher

from src.utils import get_data
import yaml 
import numpy as np
from typing import Union, Tuple, Optional, Any, Dict, List

from sklearn.utils.class_weight import compute_class_weight
import torch.optim as optim

# ---- Configuration ------------------------------------------------------------
# Regressor configuration: provides the VAE model ID and the `sufix_path` entry.
with open('src/configuration/regressor.yaml', 'r') as file:
    config_file: Dict[str, Any] = yaml.safe_load(file)
vae_model: str = config_file['model_parameters']['ID']
data_sufix: str = config_file['model_parameters']['sufix_path']

# CNN training configuration.
with open('src/configuration/nn_config.yaml', 'r') as file:
    nn_config = yaml.safe_load(file)

# List returned by utils.load_pp_list for the selected VAE model.
PP = utils.load_pp_list(vae_model)

# Run flags.
prior, create_samples, wandb_active = False, True, False
# NOTE(review): N_LAYERS and opt_method are not read in the visible cells; training
# uses nn_config['training']['layers'] and nn_config['training']['opt_method'].
N_LAYERS, opt_method = 3, "oneloss"

In [None]:
class CNN(nn.Module):
    """
    Convolutional Neural Network (CNN) for processing light curves.

    The network stacks ``layers`` blocks of Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(3),
    flattens the result and applies two fully connected layers.

    Attributes:
        layers (int): Number of convolutional blocks in the network (2, 3 or 4).
        loss_function (str): Loss name; decides whether ``forward`` returns
            log-probabilities or raw logits.
        conv1, ..., conv4 (nn.Conv1d): Convolutional layers with 16/32/64/128 output
            channels (conv3/conv4 only exist when ``layers`` is at least 3/4).
        bn1, ..., bn4 (nn.BatchNorm1d): Batch normalization layers.
        pool1, ..., pool4 (nn.MaxPool1d): Pooling layers (kernel 3, stride 3).
        fc1 (nn.Linear): Maps the flattened features to a 200-dim representation.
        fc2 (nn.Linear): Maps the 200-dim representation to class scores.

    Parameters:
        num_classes (int): Number of classes in the output prediction. Default is 2.
        layers (int): Number of convolutional blocks to use (2 to 4). Default is 2.
        kernel_size (int): Size of the convolutional kernel. Default is 6.
        stride (int): Stride of the convolution operation. Default is 1.
        loss_function (str): 'NLLLoss' or 'focalLoss' make ``forward`` return
            log-softmax outputs; any other value returns raw logits.
            Default is 'focalLoss'.
        seq_length (int, optional): Length of the input sequences. When given, the input
            size of ``fc1`` is derived from it. When None (default), the original
            hard-coded sizes (1056/704/512 for 2/3/4 layers) are used. With the default
            kernel_size=6 and stride=1 those correspond to inputs of ~300 samples.

    Raises:
        ValueError: If ``layers`` is not 2, 3 or 4, or ``seq_length`` is too short to
            survive all pooling stages.

    Methods:
        forward(x): Defines the forward pass of the CNN.
    """

    # Output channels of conv1..conv4; each block feeds the next one.
    _CHANNELS = (16, 32, 64, 128)
    # fc1 input sizes used when seq_length is not given (original hard-coded values).
    _DEFAULT_FC1_IN = {2: 1056, 3: 704, 4: 512}

    def __init__(self, num_classes: int = 2, layers = 2,
                 kernel_size = 6, stride = 1, loss_function='focalLoss',
                 seq_length: Optional[int] = None) -> None:
        """
        Initialize the CNN model with the given parameters.
        """
        super(CNN, self).__init__()

        # Only these depths have an fc1 size; anything else used to fail later in forward.
        if layers not in self._DEFAULT_FC1_IN:
            raise ValueError(f"layers must be 2, 3 or 4, got {layers!r}")

        self.layers = layers
        self.loss_function = loss_function
        padding = kernel_size // 2

        # Register conv{i}, bn{i}, pool{i} in the same order as before, so parameter
        # names, state_dict keys and RNG consumption during init are unchanged.
        in_channels = 2
        for idx in range(1, layers + 1):
            out_channels = self._CHANNELS[idx - 1]
            conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=kernel_size, stride=stride, padding=padding,
                             padding_mode='replicate', groups=2)
            init.xavier_uniform_(conv.weight)
            setattr(self, f'conv{idx}', conv)
            setattr(self, f'bn{idx}', nn.BatchNorm1d(out_channels))
            setattr(self, f'pool{idx}', nn.MaxPool1d(3))
            in_channels = out_channels

        if seq_length is None:
            fc1_in = self._DEFAULT_FC1_IN[layers]
        else:
            fc1_in = in_channels * self._feature_length(seq_length, layers,
                                                        kernel_size, stride, padding)

        self.fc1 = nn.Linear(fc1_in, 200)
        init.xavier_uniform_(self.fc1.weight)

        self.fc2 = nn.Linear(200, num_classes)
        init.xavier_uniform_(self.fc2.weight)

    @staticmethod
    def _feature_length(seq_length, layers, kernel_size, stride, padding):
        """Return the sequence length left after ``layers`` conv + MaxPool1d(3) blocks."""
        length = seq_length
        for _ in range(layers):
            # Conv1d output length (dilation 1); replicate padding adds `padding` per side.
            length = (length + 2 * padding - kernel_size) // stride + 1
            # MaxPool1d(3): kernel 3, stride 3, no padding.
            length = (length - 3) // 3 + 1
            if length < 1:
                raise ValueError(
                    f"seq_length={seq_length} is too short for {layers} conv layers")
        return length

    def forward(self, x):
        """
        Forward pass of the CNN.

        Parameters:
            x (Tensor): The input data tensor with shape (batch_size, channels, length).

        Returns:
            Tensor: The output tensor with shape (batch_size, num_classes); log-softmax
            values for 'NLLLoss'/'focalLoss', raw logits otherwise.
        """
        for idx in range(1, self.layers + 1):
            x = getattr(self, f'conv{idx}')(x)
            x = getattr(self, f'bn{idx}')(x)
            x = F.relu(x)
            x = getattr(self, f'pool{idx}')(x)

        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        # These losses expect log-probabilities rather than logits.
        if self.loss_function in ('NLLLoss', 'focalLoss'):
            return F.log_softmax(x, dim=1)
        return x
        
def setup_model(num_classes: int, show_architecture: bool = True,
                device: Optional[torch.device] = None,
                nn_config_dict: Optional[Dict[str, Any]] = None) -> nn.Module:
    """
    Build the CNN described by the neural-network configuration and place it on a device.

    Parameters:
        num_classes (int): Number of classes for the final output layer of the CNN.
        show_architecture (bool): If True, print the architecture of the model.
            Default is True.
        device (torch.device or str, optional): Device to move the model to. When None
            (default), 'cuda' is used if available, otherwise 'cpu'.
        nn_config_dict (dict, optional): Already-loaded neural-network configuration.
            When None (default), it is read with
            ``load_yaml_files(nn_config=True, regressor=False)``. Must provide
            ``['training']['layers']`` and ``['training']['loss']``.

    Returns:
        nn.Module: The initialized CNN model, wrapped in ``nn.DataParallel`` when it lives
        on CUDA and more than one GPU is available.
    """
    if nn_config_dict is None:
        nn_config_dict = load_yaml_files(nn_config=True, regressor=False)

    # Previously this read an undefined global `device`; resolve it explicitly instead.
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    print('----- model setup --------')
    model = CNN(num_classes=num_classes,
                layers=nn_config_dict['training']['layers'],
                loss_function=nn_config_dict['training']['loss'])
    model = model.to(device)

    # DataParallel requires the module's parameters on a CUDA device.
    if device.type == 'cuda' and torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    if show_architecture:
        print("Model Architecture:")
        print(model)

    return model

def load_yaml_files(nn_config: bool = True, regressor: bool = True,
                    config_dir: str = 'src/configuration'):
    """
    Load the project's YAML configuration files.

    Parameters:
        nn_config (bool): Whether to load ``<config_dir>/nn_config.yaml``. Must be True;
            the regressor configuration is never loaded on its own. Default is True.
        regressor (bool): Whether to also load ``<config_dir>/regressor.yaml``. Only
            honoured when ``nn_config`` is True. Default is True.
        config_dir (str): Directory containing the YAML files; relative paths resolve
            against the current working directory. Default is 'src/configuration'.

    Returns:
        - ``(nn_config_dict, regressor_dict)`` when both flags are True. This also prints
          the data mode and sample size taken from the NN configuration.
        - ``nn_config_dict`` (a plain dict, not a tuple) when only ``nn_config`` is True.

    Raises:
        ValueError: If ``nn_config`` is False. ValueError subclasses Exception, so
            existing ``except Exception`` handlers still catch it.
        FileNotFoundError: If a requested file does not exist.
    """
    if not nn_config:
        raise ValueError("Files were not loaded, please check function arguments")

    def _read_yaml(file_name):
        # safe_load: configuration files only need plain YAML types.
        with open(os.path.join(config_dir, file_name), 'r') as file:
            return yaml.safe_load(file)

    nn_config_dict = _read_yaml('nn_config.yaml')
    if not regressor:
        return nn_config_dict

    regressor_dict = _read_yaml('regressor.yaml')

    print('------ Data loading -------------------')
    print('mode: ', nn_config_dict['data']['mode_running'], nn_config_dict['data']['sample_size'])

    return nn_config_dict, regressor_dict

In [None]:
# `device` is used below (and setup_model may read it as a global) but was never defined
# in this notebook, so a fresh kernel failed with NameError.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

nn_config, config_file = load_yaml_files(nn_config=True, regressor=True)

vae_model: str = config_file['model_parameters']['ID']

x_train, x_test, y_train, y_test, x_val, y_val, \
label_encoder, y_train_labeled, y_test_labeled = utils.get_data(nn_config['data']['sample_size'],
                                                                nn_config['data']['mode_running'])

# NOTE(review): get_counts_and_weights_by_class, wset, hyperparam_opt, get_criterion and
# get_dict_class_priorization (used below) are not defined or imported anywhere in this
# notebook. TODO: import them from their source modules before running on a fresh kernel.
class_weights, num_classes, _ = get_counts_and_weights_by_class(y_train_labeled,
                                                                y_test_labeled, x_train)

# setup_model's second positional parameter is `show_architecture`, not a device.
# The device used to be passed there by mistake.
model = setup_model(num_classes)
wset.setup_gradients(wandb_active, model)

training_data = utils.move_data_to_device((x_train, y_train), device)
val_data = utils.move_data_to_device((x_val, y_val), device)
testing_data = utils.move_data_to_device((x_test, y_test), device)

best_val = np.iinfo(np.int64).max
harder_samples = True
no_improvement_count, counter, weight_f1_score_hyperparameter_search = 0, 0, 0
train_loss_values, val_loss_values, train_accuracy_values, \
                                    val_accuracy_values = [], [], [], []

nn_config, config_file = wset.cnn_hyperparameters(wandb_active, hyperparam_opt,
                                                  nn_config, config_file)

train_dataloader = create_dataloader(training_data, nn_config['training']['batch_size'])
val_dataloader = create_dataloader(val_data, nn_config['training']['batch_size'])
test_dataloader = create_dataloader(testing_data, nn_config['training']['batch_size'])

criterion, criterion_synthetic_samples = get_criterion(nn_config, class_weights)

beta_actual = nn_config['training']['beta_initial']

optimizer1, optimizer2, locked_masks, \
            locked_masks2 = initialize_optimizers(model, nn_config_dict = nn_config)

# The list loaded in the config cell is bound to `PP`; the lowercase `pp` was undefined.
batcher = SyntheticDataBatcher(pp = PP, vae_model=vae_model,
                               n_samples=nn_config['training']['synthetic_samples_by_class'],
                               seq_length = x_train.size(-1), prior=prior)

# Main training loop: one iteration per epoch, for nn_config['training']['epochs'] epochs.
# NOTE(review): best_val / no_improvement_count / counter are initialised above but not
# used in this visible part; the loop body presumably continues further down.
for epoch in range(nn_config['training']['epochs']):
    print(nn_config['training']['opt_method'], create_samples, harder_samples, 
        counter, nn_config['training']['ranking_method'])

    # In 'twolosses' mode, (re)generate a synthetic batch only while `harder_samples` is
    # True. It is cleared at the end of this branch, so regeneration happens again only
    # if something later sets it back to True.
    if (nn_config['training']['opt_method']=='twolosses' 
        and create_samples and harder_samples):
        # Maps label_encoder[class] -> number of synthetic samples requested for it.
        dict_priorization = {}

        # No per-class priorities (explicitly requested, or during the first 2 epochs):
        # let the batcher use its default n_samples per class.
        if (nn_config['training']['ranking_method']=='no_priority') or (epoch < 2):
            synthetic_data_loader = batcher.create_synthetic_batch(b=beta_actual, 
                                            wandb_active=wandb_active, 
                                            n_oversampling=nn_config['training']['n_oversampling'])

        elif nn_config['training']['ranking_method']=='proportion':
            ranking, proportions = get_dict_class_priorization(model, 
                                                    train_dataloader, 
                                                    ranking_method = 
                                                    nn_config['training']['ranking_method'])


            # Min-max rescale the proportions into [8, 24] samples per class.
            # NOTE(review): if all proportions are equal the denominator is zero.
            # TODO: guard this case.
            proportions = ((proportions - np.min(proportions))/
                          (np.max(proportions) - np.min(proportions))*16 + 8)

            counter2 = 0

            # Assumes proportions[i] belongs to the i-th class in `ranking`. TODO confirm
            # against get_dict_class_priorization.
            for o in ranking:
                dict_priorization[label_encoder[o]] =  int(proportions[counter2])
                counter2 = counter2 + 1

            synthetic_data_loader = batcher.create_synthetic_batch(b=beta_actual, 
                                            wandb_active=wandb_active, 
                                            samples_dict = dict_priorization, 
                                            n_oversampling=nn_config['training']['n_oversampling'])            
        else:
            # Any other ranking method: only the class order is used.
            ranking, _ = get_dict_class_priorization(model, train_dataloader, 
                                                    ranking_method = 
                                                    nn_config['training']['ranking_method'])

            # Classes in rank order get synthetic_samples_by_class * factor samples. The
            # factor starts at 1.25 and is divided by 1.25 after each class while it is
            # still > 0.5. Sequence: 1.25, 1.0, 0.8, 0.64, 0.512, then stays at 0.4096.
            ranking_penalization = 1.25
            for o in ranking:
                objects = nn_config['training']['synthetic_samples_by_class']*ranking_penalization
                dict_priorization[label_encoder[o]] =  int(objects)
                if ranking_penalization>0.5:
                    ranking_penalization = ranking_penalization/1.25

            synthetic_data_loader = batcher.create_synthetic_batch(b=beta_actual, 
                                            wandb_active=wandb_active, 
                                            samples_dict = dict_priorization,
                                            n_oversampling = nn_config['training']['n_oversampling'])

        # Beta schedule for the next generation: 1.0 at epoch 0, decaying toward 0.85.
        beta_actual = 0.85 + 0.15 * np.exp(-0.1 * epoch)
        harder_samples = False

    elif  nn_config['training']['opt_method']=='twolosses' and create_samples: 
        # Keep the synthetic_data_loader built in an earlier epoch.
        print("Using available synthetic data")
    else:
        print("Skipping synthetic sample creation")
        synthetic_data_loader = None

    # One epoch of training. Returns the running loss, the (possibly updated) model
    # and the validation loss.
    running_loss, model, val_loss = train_one_epoch_alternative(model, criterion, 
                                    optimizer1, train_dataloader, val_dataloader, device,
                                    mode = nn_config['training']['opt_method'], 
                                    criterion_2= criterion_synthetic_samples, 
                                    dataloader_2 = synthetic_data_loader,
                                    optimizer_2 = optimizer2, locked_masks2 = locked_masks2,
                                    locked_masks = locked_masks, 
                                    repetitions = nn_config['training']['repetitions'])