In [36]:
import pandas as pd
import torch
import torch.nn.functional as F

# -----------------------------
# 1️⃣ Load LightGBM predictions
# -----------------------------
# Expected columns: record_id, y_pred, y_proba.
# NOTE(review): y_proba is presumably the LightGBM probability of class 1 —
# confirm against the cell that wrote this CSV.
df_lgb = pd.read_csv("lightgbm_predictions.csv")

# Fail fast with a clear message. Otherwise a missing column only surfaces
# later as a confusing KeyError on a suffixed name such as 'y_proba_lgb'.
_missing_lgb_cols = {'record_id', 'y_pred', 'y_proba'} - set(df_lgb.columns)
if _missing_lgb_cols:
    raise ValueError(
        f"lightgbm_predictions.csv is missing columns: {sorted(_missing_lgb_cols)}"
    )

# -----------------------------
# 2️⃣ Get DistilBERT predictions
# -----------------------------
# Run the fine-tuned model over the validation split. `trainer` and
# `val_dataset` are defined in earlier cells (not visible here).
predictions = trainer.predict(val_dataset)

# The model emits raw logits; softmax over the class axis turns each row into
# a probability distribution (valid for both binary and multi-class heads).
logits = torch.tensor(predictions.predictions)
probas = F.softmax(logits, dim=1).numpy()  # shape = (num_samples, num_classes)

# Fusion below averages this with a single LightGBM probability per record,
# which only makes sense for a two-logit binary head. Enforce that instead of
# silently taking column 1 of a multi-class output.
if probas.ndim != 2 or probas.shape[1] != 2:
    raise ValueError(
        f"Expected binary logits of shape (n_samples, 2), got {probas.shape}"
    )
distilbert_proba = probas[:, 1]  # probability of class 1

# NOTE(review): assumes `val_dataset.labels` is a pandas object whose index
# holds the same identifiers as LightGBM's `record_id` column, in the same
# order the Trainer iterated the dataset — confirm against the cell that
# builds val_dataset.
val_record_ids = val_dataset.labels.index
if len(val_record_ids) != len(distilbert_proba):
    raise ValueError(
        f"{len(val_record_ids)} record ids but {len(distilbert_proba)} predictions"
    )

df_distilbert = pd.DataFrame({
    'record_id': val_record_ids,
    'y_pred': (distilbert_proba > 0.5).astype(int),
    'y_proba': distilbert_proba
})

# -----------------------------
# 3️⃣ Merge predictions
# -----------------------------
# Inner join on record_id: only records scored by both models are fused.
# validate='one_to_one' makes pandas raise if either side has duplicate
# record_ids, which would otherwise silently multiply rows.
df_merged = df_lgb.merge(
    df_distilbert,
    on='record_id',
    how='inner',
    suffixes=('_lgb', '_distilbert'),
    validate='one_to_one',
)

# An inner join drops unmatched records without complaint. Surface the counts
# so an id mismatch (e.g. dataset index vs. record_id column) is visible.
if df_merged.empty:
    raise ValueError(
        "No record_id overlap between LightGBM and DistilBERT predictions"
    )
n_dropped_lgb = len(df_lgb) - len(df_merged)
n_dropped_distilbert = len(df_distilbert) - len(df_merged)
if n_dropped_lgb or n_dropped_distilbert:
    print(
        f"Warning: merge dropped {n_dropped_lgb} LightGBM and "
        f"{n_dropped_distilbert} DistilBERT records with no matching record_id"
    )

# -----------------------------
# 4️⃣ Weighted late fusion
# -----------------------------
# Since DistilBERT is weak, give more weight to LightGBM
weight_lgb = 0.7
weight_distilbert = 0.3
fusion_threshold = 0.5  # decision cutoff on the fused probability

# Normalise so the fused score stays a convex combination in [0, 1] even if
# the weights above are edited and no longer sum to 1. This is a no-op for
# 0.7 / 0.3.
weight_total = weight_lgb + weight_distilbert
if weight_total <= 0:
    raise ValueError("Fusion weights must sum to a positive value")

df_merged['y_proba_fused'] = (
    df_merged['y_proba_lgb'] * (weight_lgb / weight_total) +
    df_merged['y_proba_distilbert'] * (weight_distilbert / weight_total)
)

# NOTE(review): in the sample output, DistilBERT probabilities sit around
# 0.01-0.1 while LightGBM's are around 0.55. The fused score therefore falls
# below this cutoff and every LightGBM positive in the sample is flipped to 0.
# Averaging scores with very different calibration behaves like a veto, not a
# blend. Consider calibrating each model, or tuning weights/threshold on a
# held-out split (not the evaluation split).
df_merged['y_pred_fused'] = (
    df_merged['y_proba_fused'] > fusion_threshold
).astype(int)

# -----------------------------
# 5️⃣ Save fused predictions
# -----------------------------
# Write every merged column (per-model and fused outputs) to disk; the row
# index is not written.
df_merged.to_csv(path_or_buf="fused_predictions.csv", index=False)

# One print call, newline-separated: same console text as two prints.
print("Late fusion completed. Sample output:", df_merged.head(), sep="\n")




Late fusion completed. Sample output:
   record_id  y_pred_lgb  y_proba_lgb  y_pred_distilbert  y_proba_distilbert  \
0        255           1     0.593284                  0            0.010886   
1        267           1     0.562147                  0            0.063890   
2        268           1     0.552658                  0            0.081270   
3        222           0     0.416508                  0            0.097016   
4        246           1     0.568745                  0            0.094964   

   y_proba_fused  y_pred_fused  
0       0.418564             0  
1       0.412670             0  
2       0.411242             0  
3       0.320661             0  
4       0.426610             0  
