In [1]:
import copy
import random
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt

from DataPreprocess import create_dataloader
from InitializingModule import InitModel
from TrainingAlgorithm import LTS_module, multiclass_weight_tuning, reorganize_module, find_cram_index, evaluate

In [17]:
def learning_mechanism(datapath:str,
                       num_data:int,
                       hidden_size:int,
                       criterion,
                       n:float,
                       epochs:int,
                       loss_threshold:float,
                       eta_threshold:float,
                       l_reg_params:dict,
                       s_reg_params:dict,
                       l_weight_params:dict,
                       s_weight_params:dict,
                       seed=None):
    '''
    Run the LTS / weight-tuning / cramming / reorganizing learning loop on
    `num_data` random train/validation splits and summarize each run.

    ### Args:
        datapath: Where the data is.
        num_data: The number of different ways separating training and validation data
            (one full training run is performed per split).
        hidden_size: Init hidden layer size.
        criterion: A pytorch loss function.
        n: Currently ignored. For every split it is overwritten by 0.7 times the loss
            returned by `evaluate` on the training set with the freshly initialized model.
        epochs: Epochs for weight-tuning module.
        loss_threshold: loss_threshold for weight-tuning module.
        eta_threshold: eta_threshold for weight-tuning module.
        l_reg_params / s_reg_params: Parameters for regularizing module in reorganizing module
            (l_* is used after the parameter-count and weight-tuning branches, s_* after cramming).
        l_weight_params / s_weight_params: Parameters for weight-tuning module in reorganizing module
            (same l_* / s_* split as above).
        seed: Optional seed for the RNG that draws each split's `random_state`. When None
            (default) the global `random` module is used, exactly as before. Model
            initialization randomness is NOT controlled by this argument.

    ### Returns:
        pandas.DataFrame with one column per split ('0_data', '1_data', ...) and the rows
        'cram times', 'parameter bigger than n', 'weight tune', 'hidden node',
        'train acc', 'train loss', 'val acc', 'val_loss'.
    '''
    
    total_cram = []
    total_param_bigger_than_n = []
    total_weight_tuning = []
    total_n_hidden_node = []
    val_accuracy = []
    train_accuracy = []
    val_loss = []
    train_loss = []
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Only build a private RNG when a seed is given, so the default keeps drawing
    # from the global `random` state (callers may already seed it themselves).
    rng = random.Random(seed) if seed is not None else random
    
    for _ in range(num_data):
        
        cramming_times = 0
        param_bigger_than_n_times = 0
        weight_tuning_times = 0
        
        # Separate training set and validation set with random_state
        train_loader, val_loader = create_dataloader(datapath, random_state=rng.randint(0, 1440))
        
        N_data = train_loader.dataset.X.shape[0]
        input_size = train_loader.dataset.X.shape[1]
        
        # initialize model
        init_model = InitModel(input_size, hidden_size, 1, device)
        model = init_model.init_module_multi_ReLU_AE(train_loader)
        
        # Setting n for the LTS module: 70% of the initial training loss.
        # NOTE: this overrides the `n` argument passed by the caller.
        n, _ = evaluate(train_loader, model, criterion, device)
        n *= 0.7
        
        # n_data = number of samples selected by LTS_module in the latest round.
        # Start at 0 so the loop is entered; it previously was read before assignment.
        # NOTE(review): termination relies on LTS_module eventually returning
        # n_data >= N_data — confirm in TrainingAlgorithm.
        n_data = 0
        while N_data > n_data:
            n_train_loader, n_data = LTS_module(train_loader=train_loader, model=model, criterion=criterion, n=n, device=device)
            
            param_num = sum(p.numel() for p in model.parameters())
            
            if n_data < param_num:
                # Fewer selected samples than parameters: reorganize (longer settings)
                param_bigger_than_n_times += 1
                model = reorganize_module(model, n_train_loader, val_loader, criterion, l_reg_params, l_weight_params)
                continue
            
            # Keep a copy so an unsuccessful weight-tuning attempt can be rolled back.
            saved_model = copy.deepcopy(model)
            situation, model = multiclass_weight_tuning(n_train_loader, val_loader, epochs, model, criterion, loss_threshold, eta_threshold)
            
            if situation == 'Acceptable':
                # Weight tuning succeeded: reorganize (longer settings)
                weight_tuning_times += 1
                model = reorganize_module(model, n_train_loader, val_loader, criterion, l_reg_params, l_weight_params)
                continue
            
            # Weight tuning failed: discard its changes before cramming.
            model = saved_model
            
            # cramming
            cram_index = find_cram_index(model, n_train_loader, criterion, device)
            model.add_neuron(n_train_loader, cram_index)
            cramming_times += 1
            
            # Reorganize after cramming (shorter settings)
            model = reorganize_module(model, n_train_loader, val_loader, criterion, s_reg_params, s_weight_params)
        
        t_loss, t_accs = evaluate(train_loader, model, criterion, device)
        v_loss, v_accs = evaluate(val_loader, model, criterion, device)
        
        total_cram.append(cramming_times)
        total_param_bigger_than_n.append(param_bigger_than_n_times)
        total_weight_tuning.append(weight_tuning_times)
        # With a single output unit, layer_out.weight has one entry per hidden node.
        total_n_hidden_node.append(model.layer_out.weight.numel())
        train_accuracy.append(t_accs)
        train_loss.append(t_loss)
        val_accuracy.append(v_accs)
        val_loss.append(v_loss)
    
    df = pd.DataFrame([total_cram, total_param_bigger_than_n, total_weight_tuning, total_n_hidden_node, train_accuracy, train_loss, val_accuracy, val_loss],
                      columns=[f'{i}_data' for i in range(num_data)],
                      index=['cram times', 'parameter bigger than n', 'weight tune', 'hidden node', 'train acc', 'train loss', 'val acc', 'val_loss'])
    
    return df