### Multi-Task Learning Model for Battery Health Prediction

This notebook implements a multi-task learning approach for battery health prediction,
including State of Health (SOH), Remaining Useful Life (RUL), and Discharge Capacity Change Rate (DCCR) 
prediction using gradient reversal and auxiliary networks.

This code performs hyperparameter optimization with Optuna (Bayesian optimization).

Author: Rong ZHU, Weiwen Peng*

Date: 2025/07/16

In [None]:
import torch
import numpy as np
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import scipy.io
import matplotlib.pyplot as plt
import optuna
from torch.nn.utils import clip_grad_norm_

# Pick the compute device: the GPU when CUDA is usable, the CPU otherwise
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# =============================================================================
# Data Loading and Preprocessing
# =============================================================================

# Names of the 21 battery batches: 'batch01' through 'batch21'
batch_list = [f'batch{index:02d}' for index in range(1, 22)]

# Read the two MATLAB files used by the notebook
CM_data = scipy.io.loadmat('CM_Data1126.mat')
CM_data_aux = scipy.io.loadmat('CM_data_aux_label.mat')

In [None]:
# Build the pooled training arrays from every batch/battery combination.
#
# The per-battery arrays are collected in lists and concatenated once at the
# end. This avoids relying on a bare `except:` to detect the first iteration,
# which would also hide genuine errors (e.g. a shape mismatch halfway through)
# and silently discard everything accumulated so far.
# Note: extract_feature must be defined before this cell is run.
feature_parts = []
SOH_label_parts = []
RUL_label_parts = []
RE_label_parts = []

for batch_num in range(21):
    for battery_num in range(4):
        # Battery 0 of every batch is left out of the training pool, so its
        # features are not extracted at all.
        if battery_num == 0:
            continue

        # Extract features and labels for the current batch and battery
        feature, SOH_label, RUL_label, RE_label = extract_feature(batch_num, battery_num)

        feature_parts.append(feature)
        SOH_label_parts.append(SOH_label)
        RUL_label_parts.append(RUL_label)
        RE_label_parts.append(RE_label)

# Stack all samples along the first axis; raises if the pieces are incompatible
train_feature_all = np.concatenate(feature_parts, axis=0)
train_SOH_label_all = np.concatenate(SOH_label_parts, axis=0)
train_RUL_label_all = np.concatenate(RUL_label_parts, axis=0)
train_RE_label_all = np.concatenate(RE_label_parts, axis=0)


In [None]:
# Normalize features column-wise using min-max normalization
max_train_feature = np.max(train_feature_all, axis=0)
min_train_feature = np.min(train_feature_all, axis=0)
feature_range = max_train_feature - min_train_feature
# A constant column has zero range; divide by 1 there so it maps to 0
# instead of producing NaN/inf that would poison training.
feature_range = np.where(feature_range == 0, 1.0, feature_range)
train_feature_all = (train_feature_all - min_train_feature) / feature_range

# Split features and all three label sets in ONE call so every array is cut
# with exactly the same shuffled indices (and mismatched lengths are rejected).
(train_feature, val_feature,
 train_SOH_label, val_SOH_label,
 train_RUL_label, val_RUL_label,
 train_RE_label, val_RE_label) = train_test_split(
    train_feature_all, train_SOH_label_all, train_RUL_label_all, train_RE_label_all,
    test_size=0.15, random_state=42)

# Convert numpy arrays to float32 PyTorch tensors
train_x = torch.from_numpy(train_feature).float()
train_SOH = torch.from_numpy(train_SOH_label).float()
train_RUL = torch.from_numpy(train_RUL_label).float()
train_RE = torch.from_numpy(train_RE_label).float()

val_x = torch.from_numpy(val_feature).float()
val_SOH = torch.from_numpy(val_SOH_label).float()
val_RUL = torch.from_numpy(val_RUL_label).float()
val_RE = torch.from_numpy(val_RE_label).float()

# Create datasets and data loaders (only the training loader is shuffled)
train_dataset = TensorDataset(train_x, train_SOH, train_RUL, train_RE)
val_dataset = TensorDataset(val_x, val_SOH, val_RUL, val_RE)

train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=128, shuffle=False)


In [None]:
# =============================================================================
# Model Architecture
# =============================================================================

class GradReverse(torch.autograd.Function):
    """
    Gradient reversal as a custom autograd function.

    The forward pass returns the input unchanged (as a view); the backward
    pass negates the incoming gradient and scales it by ``lambd``.
    """
    @staticmethod
    def forward(ctx, x, lambd=1.0):
        # Remember the scaling factor for the backward pass
        ctx.lambd = lambd
        identity_view = x.view_as(x)
        return identity_view

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign and scale; ``lambd`` itself receives no gradient
        reversed_grad = -grad_output * ctx.lambd
        return reversed_grad, None

def grad_reverse(x, lambd=1.0):
    """Apply GradReverse to ``x`` with scaling factor ``lambd`` and return the result"""
    reversed_x = GradReverse.apply(x, lambd)
    return reversed_x


class MLP(nn.Module):
    """
    Multi-Layer Perceptron with configurable architecture

    Layout: ``Linear -> ReLU`` for the input layer, ``Linear -> ReLU -> Dropout``
    for each of the ``layers_num - 2`` hidden layers, and a final ``Linear``
    with no activation.

    Args:
        input_dim: Input feature dimension
        output_dim: Output dimension
        layers_num: Number of linear layers (must be >= 2)
        hidden_dim: Hidden layer dimension
        dropout: Dropout probability applied after each hidden layer

    Raises:
        ValueError: If ``layers_num`` is smaller than 2.
    """
    def __init__(self, input_dim=59, output_dim=1, layers_num=4, hidden_dim=50, dropout=0.01):
        super(MLP, self).__init__()

        # Raise explicitly rather than assert: the check must still run under
        # ``python -O``, where assert statements are stripped.
        if layers_num < 2:
            raise ValueError("Number of layers must be greater than or equal to 2")

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layers_num = layers_num
        self.hidden_dim = hidden_dim

        # Build network layers in order: input, hidden..., output
        self.layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(layers_num - 2):
            self.layers.extend([
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(p=dropout),
            ])
        self.layers.append(nn.Linear(hidden_dim, output_dim))

        # nn.Sequential registers the modules; ``self.layers`` is a plain list
        self.net = nn.Sequential(*self.layers)
        self._init_weights()

    def _init_weights(self):
        """Initialize the weights of every linear layer with Xavier normal initialization"""
        for layer in self.net:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output"""
        return self.net(x)


class Solution_u(nn.Module):
    """
    Multi-Task Learning Model for Battery Health Prediction

    This model consists of:
    - A shared feature extractor (encoder) that maps the input to 32 features
    - Three task-specific predictors for SOH, RUL, and RE, each mapping the
      32 shared features to a single output

    Every sub-network is an ``MLP``, which requires at least 2 layers; the
    predictor defaults are therefore 2 (a default of 1 could never be built).

    Args:
        feature_extractor_layer: Number of layers in feature extractor
        feature_extractor_hidden_num: Hidden units in feature extractor
        SOH_predictor_layer: Number of layers in SOH predictor
        SOH_predictor_hidden_num: Hidden units in SOH predictor
        RUL_predictor_layer: Number of layers in RUL predictor
        RUL_predictor_hidden_num: Hidden units in RUL predictor
        RE_predictor_layer: Number of layers in RE predictor
        RE_predictor_hidden_num: Hidden units in RE predictor
        input_dim: Number of input features per sample (default 59)
    """
    def __init__(self,
                 feature_extractor_layer=3,
                 feature_extractor_hidden_num=60,
                 SOH_predictor_layer=2,
                 SOH_predictor_hidden_num=60,
                 RUL_predictor_layer=2,
                 RUL_predictor_hidden_num=60,
                 RE_predictor_layer=2,
                 RE_predictor_hidden_num=60,
                 input_dim=59):
        super(Solution_u, self).__init__()

        # Width of the shared representation fed to every predictor
        shared_dim = 32

        # Shared feature extractor
        self.encoder = MLP(
            input_dim=input_dim,
            output_dim=shared_dim,
            layers_num=feature_extractor_layer,
            hidden_dim=feature_extractor_hidden_num,
            dropout=0.01
        )

        # Task-specific predictors
        self.SOH_predictor = MLP(
            input_dim=shared_dim,
            output_dim=1,
            layers_num=SOH_predictor_layer,
            hidden_dim=SOH_predictor_hidden_num,
            dropout=0.01
        )

        self.RUL_predictor = MLP(
            input_dim=shared_dim,
            output_dim=1,
            layers_num=RUL_predictor_layer,
            hidden_dim=RUL_predictor_hidden_num,
            dropout=0.01
        )

        self.RE_predictor = MLP(
            input_dim=shared_dim,
            output_dim=1,
            layers_num=RE_predictor_layer,
            hidden_dim=RE_predictor_hidden_num,
            dropout=0.01
        )

        self._init_weights()

    def forward(self, x):
        """
        Forward pass through the multi-task model

        Args:
            x: Input features

        Returns:
            tuple: (SOH_prediction, RUL_prediction, RE_prediction)
        """
        # Extract shared features
        features = self.encoder(x)

        # Make task-specific predictions from the same shared features
        SOH = self.SOH_predictor(features)
        RUL = self.RUL_predictor(features)
        RE = self.RE_predictor(features)

        return SOH, RUL, RE

    def _init_weights(self):
        """Re-initialize all Linear/Conv1d layers: Xavier normal weights, zero biases"""
        for layer in self.modules():
            if isinstance(layer, (nn.Linear, nn.Conv1d)):
                nn.init.xavier_normal_(layer.weight)
                nn.init.constant_(layer.bias, 0)


class MAPELoss(nn.Module):
    """
    Mean Absolute Percentage Error Loss

    MAPE = mean(|y_true - y_pred| / max(|y_true|, epsilon))

    Args:
        epsilon: Lower bound on the denominator, to avoid division by zero
    """
    def __init__(self, epsilon=1e-8):
        super(MAPELoss, self).__init__()
        self.epsilon = epsilon

    def forward(self, y_pred, y_true):
        """
        Compute the MAPE between predictions and targets

        Args:
            y_pred: Predicted values
            y_true: True values; reshaped to ``y_pred``'s shape when both hold
                the same number of elements

        Returns:
            Scalar tensor with the mean absolute percentage error
        """
        # A (B,) target against a (B, 1) prediction would silently broadcast
        # to (B, B) and average over every pairing; align the shapes first.
        if y_true.shape != y_pred.shape and y_true.numel() == y_pred.numel():
            y_true = y_true.reshape(y_pred.shape)

        # Bound the magnitude of the denominator so targets at (or just
        # below) zero cannot cause a division by zero.
        denominator = torch.abs(y_true).clamp_min(self.epsilon)
        loss = torch.abs(y_true - y_pred) / denominator
        return torch.mean(loss)


class AuxiliaryNet(nn.Module):
    """
    Small classifier over flattened gradient vectors

    Maps a gradient vector of length ``grad_dim`` through one hidden layer of
    64 ReLU units to ``num_decoders`` scores, one per decoder (task).

    Args:
        grad_dim: Dimension of gradient vectors
        num_decoders: Number of decoders (tasks)
    """
    def __init__(self, grad_dim, num_decoders):
        super().__init__()
        hidden_width = 64
        classifier_layers = [
            nn.Linear(grad_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, num_decoders),
        ]
        self.net = nn.Sequential(*classifier_layers)

    def forward(self, grad_vec):
        """Return the per-decoder scores for ``grad_vec``"""
        scores = self.net(grad_vec)
        return scores


def get_decoder_flatdim(decoder):
    """
    Count the scalar parameters of a module

    Args:
        decoder: PyTorch module

    Returns:
        int: Sum of ``numel()`` over all of the module's parameters
    """
    return sum(param.numel() for param in decoder.parameters())


def accuracy_cal(y_pred, y_true):
    """
    Calculate various accuracy metrics

    Args:
        y_pred: Predicted values (tensor)
        y_true: True values (tensor); reshaped to ``y_pred``'s shape when both
            hold the same number of elements

    Returns:
        tuple: (RMSE, MAE, MAPE, R2) as Python floats. R2 is not finite when
        ``y_true`` is constant, because its total sum of squares is zero.
    """
    # A (N,) target against a (N, 1) prediction would silently broadcast to
    # (N, N) and corrupt every metric; align the shapes first.
    if y_true.shape != y_pred.shape and y_true.numel() == y_pred.numel():
        y_true = y_true.reshape(y_pred.shape)

    residual = y_true - y_pred

    rmse = torch.sqrt(torch.mean(residual ** 2))
    mae = torch.mean(torch.abs(residual))

    # MAPE: bound the denominator's magnitude so zero (or slightly negative)
    # targets cannot cause a division by zero
    epsilon = 1e-8
    mape = torch.mean(torch.abs(residual) / torch.abs(y_true).clamp_min(epsilon))

    # R2 = 1 - SS_res / SS_tot
    ss_res = torch.sum(residual ** 2)
    ss_tot = torch.sum((y_true - torch.mean(y_true)) ** 2)
    r2 = 1 - ss_res / ss_tot

    return rmse.item(), mae.item(), mape.item(), r2.item()


In [None]:

# =============================================================================
# Hyperparameter Optimization with Optuna
# =============================================================================

def _align_target(target, pred):
    """Reshape ``target`` to ``pred``'s shape when both hold the same number of elements."""
    # Guards against a (B,) label silently broadcasting against a (B, 1) output.
    if target.shape != pred.shape and target.numel() == pred.numel():
        return target.reshape(pred.shape)
    return target


def _snapshot_grads(module):
    """Return flattened copies of the gradients currently stored on ``module``'s parameters."""
    # clone() so later in-place changes to ``p.grad`` cannot alter the snapshot
    return [p.grad.detach().clone().reshape(-1)
            for p in module.parameters() if p.grad is not None]


def _pad_zero(vec, target_dim):
    """Right-pad the columns of a 2-D tensor with zeros up to ``target_dim`` columns."""
    missing = target_dim - vec.shape[1]
    if missing <= 0:
        return vec
    zero_pad = torch.zeros(vec.shape[0], missing, device=vec.device)
    return torch.cat([vec, zero_pad], dim=1)


def _aux_backward(auxnet, aux_criterion, grad_list, decoder_id, max_dim, lambda_aux):
    """Back-propagate the auxiliary classification loss for one decoder's gradient snapshot."""
    if not grad_list:
        return None

    # Concatenate gradients into one row vector and pad to the common dimension
    grad_vec = torch.cat(grad_list, dim=0).unsqueeze(0).to(device)
    grad_vec = _pad_zero(grad_vec, max_dim)

    # NOTE(review): grad_vec is built from detached gradients, so the reversed
    # gradient produced here does not reach the main model; only the auxiliary
    # network's own parameters receive gradients from this loss.
    rev_feat = grad_reverse(grad_vec, lambda_aux)

    # Classify which decoder the gradient vector came from
    aux_out = auxnet(rev_feat)
    label_t = torch.tensor([decoder_id], dtype=torch.long, device=device)
    aux_loss_i = aux_criterion(aux_out, label_t)
    aux_loss_i.backward()

    return aux_loss_i.item()


def objective(trial):
    """
    Objective function for Optuna hyperparameter optimization

    Builds a ``Solution_u`` model and an ``AuxiliaryNet`` from the suggested
    hyperparameters, trains them on ``train_loader`` and scores the model on
    the validation tensors.

    Args:
        trial: Optuna trial object

    Returns:
        float: Objective value to maximize (mean of the SOH and RUL validation R2 scores)
    """
    # Suggest hyperparameters
    feature_extractor_layer = trial.suggest_int('feature_extractor_layer', 3, 10)
    feature_extractor_hidden_num = trial.suggest_int('feature_extractor_hidden_num', 10, 100)

    SOH_predictor_layer = trial.suggest_int('SOH_predictor_layer', 2, 5)
    SOH_predictor_hidden_num = trial.suggest_int('SOH_predictor_hidden_num', 10, 100)

    RUL_predictor_layer = trial.suggest_int('RUL_predictor_layer', 2, 5)
    RUL_predictor_hidden_num = trial.suggest_int('RUL_predictor_hidden_num', 10, 100)

    RE_predictor_layer = trial.suggest_int('RE_predictor_layer', 2, 5)
    RE_predictor_hidden_num = trial.suggest_int('RE_predictor_hidden_num', 10, 100)

    epochs = trial.suggest_int('epochs', 100, 1800)
    miu = trial.suggest_float('miu', 1e-3, 1, log=True)            # weight of the RUL loss
    gamma = trial.suggest_float('gamma', 1e-3, 1, log=True)        # weight of the RE loss
    lambda_aux = trial.suggest_float('lambda_aux', 1e-3, 1, log=True)  # gradient-reversal scale

    # Create model with suggested hyperparameters
    model = Solution_u(
        feature_extractor_layer=feature_extractor_layer,
        feature_extractor_hidden_num=feature_extractor_hidden_num,
        SOH_predictor_layer=SOH_predictor_layer,
        SOH_predictor_hidden_num=SOH_predictor_hidden_num,
        RUL_predictor_layer=RUL_predictor_layer,
        RUL_predictor_hidden_num=RUL_predictor_hidden_num,
        RE_predictor_layer=RE_predictor_layer,
        RE_predictor_hidden_num=RE_predictor_hidden_num
    )
    model.to(device)

    # The auxiliary network's input size is the largest decoder parameter
    # count; smaller gradient vectors are zero-padded up to it.
    decoders = (model.SOH_predictor, model.RUL_predictor, model.RE_predictor)
    max_dim = max(get_decoder_flatdim(decoder) for decoder in decoders)

    # Create auxiliary network (3 classes: SOH, RUL, RE decoder)
    auxnet = AuxiliaryNet(max_dim, 3)
    auxnet.to(device)

    # Define loss functions and optimizers
    lr = 0.001
    loss_data = MAPELoss()
    loss_fn = nn.MSELoss()
    aux_criterion = nn.CrossEntropyLoss()

    # Separate optimizers for different components
    optimizer_shared = torch.optim.Adam(model.encoder.parameters(), lr=lr)
    optimizer_SOH = torch.optim.Adam(model.SOH_predictor.parameters(), lr=lr)
    optimizer_RUL = torch.optim.Adam(model.RUL_predictor.parameters(), lr=lr)
    optimizer_RE = torch.optim.Adam(model.RE_predictor.parameters(), lr=lr)
    optimizer_aux = torch.optim.Adam(auxnet.parameters(), lr=lr)
    main_optimizers = (optimizer_shared, optimizer_SOH, optimizer_RUL, optimizer_RE)

    # Training loop
    for epoch in range(epochs):
        model.train()

        for x, y, rul, re_target in train_loader:
            x = x.to(device)
            y = y.to(device)
            rul = rul.to(device)
            re_target = re_target.to(device)

            # Clear gradients
            for optimizer in main_optimizers:
                optimizer.zero_grad()
            optimizer_aux.zero_grad()

            # Forward pass
            u, RUL, RE = model(x)

            # Calculate task-specific losses (labels aligned to output shape)
            loss_soh = loss_data(u, _align_target(y, u))
            loss_rul = miu * loss_data(RUL, _align_target(rul, RUL))
            loss_re = gamma * loss_fn(RE, _align_target(re_target, RE))

            # Back-propagate each task loss and snapshot its decoder's
            # gradients. Each loss only reaches its own predictor and the
            # shared encoder, so the predictor gradients stay task-specific
            # while the encoder accumulates the sum of all three.
            #
            # The predictor gradients must NOT be cleared before step():
            # clearing them here would leave the optimizers nothing to apply
            # and the three predictors would never be trained.
            loss_soh.backward(retain_graph=True)
            grad_soh_list = _snapshot_grads(model.SOH_predictor)

            loss_rul.backward(retain_graph=True)
            grad_rul_list = _snapshot_grads(model.RUL_predictor)

            loss_re.backward()
            grad_re_list = _snapshot_grads(model.RE_predictor)

            # Update main model parameters
            for optimizer in main_optimizers:
                optimizer.step()

            # Train the auxiliary network on the three gradient snapshots;
            # its gradients accumulate over the three calls before one step.
            optimizer_aux.zero_grad()
            _aux_backward(auxnet, aux_criterion, grad_soh_list, 0, max_dim, lambda_aux)  # SOH decoder
            _aux_backward(auxnet, aux_criterion, grad_rul_list, 1, max_dim, lambda_aux)  # RUL decoder
            _aux_backward(auxnet, aux_criterion, grad_re_list, 2, max_dim, lambda_aux)   # RE decoder
            optimizer_aux.step()

    # Evaluate model on validation set
    model.eval()
    with torch.no_grad():
        val_soh_pred, val_rul_pred, _ = model(val_x.to(device))

        soh_target = _align_target(val_SOH.to(device), val_soh_pred)
        rul_target = _align_target(val_RUL.to(device), val_rul_pred)

        # Only the R2 scores of SOH and RUL enter the objective
        _, _, _, val_r2 = accuracy_cal(val_soh_pred, soh_target)
        _, _, _, val_r21 = accuracy_cal(val_rul_pred, rul_target)

    # Return average R2 score as objective
    return (val_r2 + val_r21) / 2


def print_progress(study, trial):
    """
    Callback function to print optimization progress

    Args:
        study: Optuna study being optimized
        trial: The trial that just finished

    Prints one line with the trial's value and parameters and the best trial
    so far. A trial without a value (e.g. failed or pruned) and a study with
    no completed trial are reported instead of raising, so one bad trial
    cannot abort the whole optimization from inside the callback.
    """
    # trial.value is None when the trial produced no objective value
    value_text = "None" if trial.value is None else f"{trial.value:.4f}"

    try:
        best_text = (f"Best is trial {study.best_trial.number} "
                     f"with value: {study.best_value:.4f}.")
    except ValueError:
        # Optuna raises ValueError while no trial has completed yet
        best_text = "No completed trial yet."

    result = (f"Trial {trial.number} finished with value: {value_text} "
              f"and parameters: {trial.params}. "
              f"{best_text}")
    print(result)

In [None]:
# =============================================================================
# Main Optimization Loop
# =============================================================================

if __name__ == "__main__":
    # Run the search: 200 trials, maximizing the value returned by objective
    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=200, callbacks=[print_progress])

    # Report the best trial: its value, then one line per hyperparameter
    print('\nBest trial:')
    trial = study.best_trial
    print(f'  Value: {trial.value:.4f}')
    print('  Params: ')
    for param_name, param_value in trial.params.items():
        print(f'    {param_name}: {param_value}')