In [None]:
import torch
from sklearn.model_selection import train_test_split
from sklearn.calibration import CalibratedClassifierCV
import torch.optim as optim
import torch.nn as nn

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Get Data and Base Model

In [None]:
# NOTE(review): this cell looks truncated — the string literal below is
# unterminated (a SyntaxError as written), and the code that should define
# `base_model`, `X_cal` and `y_cal` (used by the calibration cells below) is
# missing. TODO restore the full dataset name and the data/model loading code.
dataset_name = 'CLE

## Platt Scaling

In [None]:
# Platt scaling: learn a sigmoid mapping on top of the already-trained
# `base_model`. With cv='prefit' the base model is not refit; only the
# calibrator is fit, on the calibration split.
# NOTE(review): assumes `base_model` is a fitted scikit-learn-compatible
# classifier (predict_proba / decision_function) — confirm.
# NOTE(review): recent scikit-learn releases deprecate cv='prefit' in favour of
# wrapping the model in `sklearn.frozen.FrozenEstimator` — revisit on upgrade.
platt_calibrated_model = CalibratedClassifierCV(
    base_model,
    method='sigmoid',
    cv='prefit',
)
platt_calibrated_model.fit(X_cal, y_cal)

## Isotonic Regression

In [None]:
# Isotonic regression: a non-parametric, monotone calibration mapping fit on
# the calibration split; `base_model` itself is not refit (cv='prefit').
# NOTE(review): isotonic calibration can overfit small calibration sets —
# compare against the Platt result if X_cal is small.
isotonic_calibrated_model = CalibratedClassifierCV(
    base_model,
    method='isotonic',
    cv='prefit',
)
isotonic_calibrated_model.fit(X_cal, y_cal)

## Temperature Scaling

In [None]:
class TemperatureScaling(nn.Module):
    """Wrap a classifier and divide its logits by one learnable temperature.

    Only ``temperature`` is meant to be optimised (see
    ``train_temperature_scaling``); the wrapped ``base_model`` is used as-is.

    Args:
        base_model: Module whose ``forward`` accepts ``return_logits=True`` and
            returns raw class scores. Presumably shaped (batch, n_classes) —
            the softmax in ``forward`` is taken over dim 1; TODO confirm.
        init_temperature: Starting temperature. The default 1.0 makes the
            wrapper initially reproduce the base model's logits.
    """

    def __init__(self, base_model, init_temperature=1.0):
        super().__init__()
        # Shape (1,) so it broadcasts against the base model's logit tensor.
        self.temperature = nn.Parameter(torch.full((1,), float(init_temperature)))
        self.base_model = base_model

    def forward(self, x, return_logits=True):
        """Return temperature-scaled logits, or softmax probabilities.

        Args:
            x: Input batch for ``base_model``; it is moved to the device that
                holds ``self.temperature``.
            return_logits: If True (default) return ``logits / temperature``;
                otherwise return ``softmax(logits / temperature, dim=1)``.
        """
        # Follow the module's own device instead of a notebook-global, so the
        # wrapper works wherever it has been moved with ``.to(...)``.
        x = x.to(self.temperature.device)
        logits = self.base_model(x, return_logits=True)
        scaled_logits = logits / self.temperature
        if return_logits:
            return scaled_logits
        # Probabilities are the softmax of the *scaled* logits.
        return torch.softmax(scaled_logits, dim=1)

def train_temperature_scaling(base_model, X_cal, y_cal, device=None, max_iter=50):
    """Fit a ``TemperatureScaling`` wrapper by minimising cross-entropy (NLL).

    Args:
        base_model: Classifier to calibrate; called as
            ``base_model(x, return_logits=True)``. If it is an ``nn.Module`` it
            is moved to ``device`` and put in eval mode (both act in place).
        X_cal: Calibration inputs — a tensor or anything ``np.asarray`` accepts.
        y_cal: Integer class labels for ``X_cal``.
        device: Target device. Defaults to CUDA when available, else CPU
            (the same rule as the notebook's setup cell).
        max_iter: Maximum L-BFGS iterations.

    Returns:
        The fitted ``TemperatureScaling`` module.
    """
    import numpy as np

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    def _as_tensor(data, dtype):
        # Accept tensors and array-likes alike, landing on the target device.
        if isinstance(data, torch.Tensor):
            return data.to(device=device, dtype=dtype)
        return torch.as_tensor(np.asarray(data), dtype=dtype, device=device)

    temperature_model = TemperatureScaling(base_model).to(device)
    # Eval mode: dropout / batch-norm must not perturb the logits being
    # calibrated, and L-BFGS needs a deterministic closure.
    temperature_model.eval()

    inputs = _as_tensor(X_cal, torch.float32)
    labels = _as_tensor(y_cal, torch.long)

    # Only the temperature is optimised, so compute the base logits once
    # rather than re-running (and back-propagating through) the base model on
    # every L-BFGS function evaluation.
    with torch.no_grad():
        logits = temperature_model.base_model(inputs, return_logits=True)

    optimizer = optim.LBFGS(
        [temperature_model.temperature], max_iter=max_iter, line_search_fn="strong_wolfe"
    )
    criterion = nn.CrossEntropyLoss()

    def closure():
        # L-BFGS may evaluate the loss several times per step.
        optimizer.zero_grad()
        loss = criterion(logits / temperature_model.temperature, labels)
        loss.backward()
        return loss

    optimizer.step(closure)

    print(f"Optimal temperature: {temperature_model.temperature.item():.4f}")

    return temperature_model


# Fit the temperature on the calibration split; the call prints the learned value.
temperature_calibrated_model = train_temperature_scaling(
    base_model=base_model,
    X_cal=X_cal,
    y_cal=y_cal,
)