In [45]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import matplotlib.pyplot as plt

In [46]:
# Load dataset
dataset = pd.read_csv('battery_feature_extracted.csv')
X = dataset.drop(columns=['average_voltage'])
y = dataset['average_voltage']

In [47]:
# Train/val/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)


In [48]:
# Scaling
scaler = RobustScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled = scaler.transform(X_val)
X_test_scaled = scaler.transform(X_test)


In [49]:
# Tensors
X_train_tensor = torch.FloatTensor(X_train_scaled)
y_train_tensor = torch.FloatTensor(y_train.values).unsqueeze(1)
X_val_tensor = torch.FloatTensor(X_val_scaled)
y_val_tensor = torch.FloatTensor(y_val.values).unsqueeze(1)
X_test_tensor = torch.FloatTensor(X_test_scaled)
y_test_tensor = torch.FloatTensor(y_test.values).unsqueeze(1)


In [50]:
# Sample weights
ion_columns = [col for col in X.columns if col.startswith("working_ion_")]
ion_counts = X_train[ion_columns].sum()
ion_weights = 1.0 / ion_counts
ion_weights /= ion_weights.sum()
train_weights = X_train[ion_columns].dot(ion_weights.astype(np.float32))
train_weights_tensor = torch.tensor(train_weights.values.astype(np.float32)).unsqueeze(1)


In [51]:
# Model
class GRUNetwork(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1, dropout=0.5):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        _, h_n = self.gru(x)
        return self.fc(h_n[-1])

class TabTransformerWithGRU(nn.Module):
    def __init__(self, num_features, output_size=1, dim_embedding=128, num_heads=2, num_layers=2, gru_hidden_size=128, gru_num_layers=1, gru_dropout=0.5):
        super().__init__()
        self.embedding = nn.Linear(num_features, dim_embedding)
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim_embedding, nhead=num_heads, dropout=0.2, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.gru_network = GRUNetwork(dim_embedding, gru_hidden_size, output_size, gru_num_layers, gru_dropout)

    def forward(self, x):
        x = self.embedding(x).unsqueeze(1)
        x = self.transformer(x)
        return self.gru_network(x)


In [52]:
# Loss
class WeightedCompositeLoss(nn.Module):
    def forward(self, outputs, targets, weights):
        mse = (weights * (outputs - targets) ** 2).mean()
        mae = (weights * torch.abs(outputs - targets)).mean()
        return mse + 0.6 * mae


In [53]:
# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TabTransformerWithGRU(num_features=X_train_tensor.shape[1]).to(device)
criterion = WeightedCompositeLoss()
optimizer = optim.Adam(model.parameters(), lr=0.00075)




In [54]:
# Move tensors
X_train_tensor = X_train_tensor.to(device)
y_train_tensor = y_train_tensor.to(device)
train_weights_tensor = train_weights_tensor.to(device)
X_val_tensor = X_val_tensor.to(device)
y_val_tensor = y_val_tensor.to(device)
X_test_tensor = X_test_tensor.to(device)
y_test_tensor = y_test_tensor.to(device)


In [55]:
# Training loop
training_losses, validation_losses = [], []
for epoch in range(2000):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train_tensor)
    loss = criterion(outputs, y_train_tensor, train_weights_tensor)
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_outputs = model(X_val_tensor)
        val_loss = criterion(val_outputs, y_val_tensor, torch.ones_like(y_val_tensor))

    training_losses.append(loss.item())
    validation_losses.append(val_loss.item())
    if epoch % 200 == 0:
        print(f"Epoch {epoch} | Train Loss: {loss.item():.4f} | Val Loss: {val_loss.item():.4f}")


Epoch 0 | Train Loss: 0.2747 | Val Loss: 8.0351
Epoch 200 | Train Loss: 0.0116 | Val Loss: 1.0323
Epoch 400 | Train Loss: 0.0088 | Val Loss: 0.8077
Epoch 600 | Train Loss: 0.0078 | Val Loss: 0.7261
Epoch 800 | Train Loss: 0.0060 | Val Loss: 0.6960
Epoch 1000 | Train Loss: 0.0053 | Val Loss: 0.6796
Epoch 1200 | Train Loss: 0.0047 | Val Loss: 0.6560
Epoch 1400 | Train Loss: 0.0042 | Val Loss: 0.6597
Epoch 1600 | Train Loss: 0.0036 | Val Loss: 0.6556
Epoch 1800 | Train Loss: 0.0035 | Val Loss: 0.6387


In [56]:
# Evaluation
model.eval()
with torch.no_grad():
    preds = model(X_test_tensor)
    test_mse = mean_squared_error(y_test_tensor.cpu(), preds.cpu())
    test_mae = mean_absolute_error(y_test_tensor.cpu(), preds.cpu())
    ss_res = torch.sum((y_test_tensor - preds) ** 2)
    ss_tot = torch.sum((y_test_tensor - torch.mean(y_test_tensor)) ** 2)
    r2 = 1 - ss_res / ss_tot

print(f"\nTest MSE: {test_mse:.4f}")
print(f"Test MAE: {test_mae:.4f}")
print(f"Test R²: {r2.item():.4f}")


Test MSE: 0.5510
Test MAE: 0.3703
Test R²: 0.7897


In [57]:
# Add per-ion metrics
X_test_df = X_test.reset_index(drop=True).copy()
X_test_df['true'] = y_test.values
X_test_df['pred'] = preds.cpu().numpy().flatten()

print("\nPer-ion metrics on test set (weighted sample):")
for ion in ion_columns:
    subset = X_test_df[X_test_df[ion] == 1]
    if not subset.empty:
        y_true = subset['true'].values
        y_pred = subset['pred'].values
        mae = mean_absolute_error(y_true, y_pred)
        mse = mean_squared_error(y_true, y_pred)
        r2 = r2_score(y_true, y_pred)
        print(f"{ion.replace('working_ion_', '')}: MAE = {mae:.4f}, MSE = {mse:.4f}, R² = {r2:.4f}")



Per-ion metrics on test set (weighted sample):
Al: MAE = 0.3120, MSE = 0.1532, R² = 0.8668
Ca: MAE = 0.2706, MSE = 0.1560, R² = 0.8737
Cs: MAE = 0.7058, MSE = 0.7185, R² = -0.8962
K: MAE = 0.2122, MSE = 0.0643, R² = 0.9808
Li: MAE = 0.3874, MSE = 0.6845, R² = 0.6776
Mg: MAE = 0.5012, MSE = 1.1019, R² = 0.6874
Na: MAE = 0.3072, MSE = 0.1899, R² = 0.9052
Rb: MAE = 0.2933, MSE = 0.1159, R² = 0.9282
Y: MAE = 0.1726, MSE = 0.0338, R² = 0.9280
Zn: MAE = 0.3614, MSE = 0.2760, R² = 0.6912
