# Metrics
It is common practice in Lightning to inherit from
 torchmetrics' `Metric` class for computational efficiency and compatibility with Lightning (https://lightning.ai/docs/torchmetrics/stable/pages/implement.html#implement)

 TODO: Add metrics for other tasks

In [1]:
from typing import Any, List

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torchmetrics import Metric
from torchmetrics import MaxMetric, MinMetric
from torchmetrics.classification.accuracy import Accuracy
from torchmetrics.classification.calibration_error import CalibrationError

from typing import Any, List, Literal, Optional, Dict, Callable


# Metrics
class ShannonEntropyError(Metric):
    """Running mean of the Shannon entropy (natural log) of predicted distributions.

    The entropy is summed over every element of each update and divided by the
    accumulated size of the first (batch) dimension.
    """

    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        # Both states are summed across processes when synchronised.
        self.add_state("entropy_total", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, logits: torch.Tensor, is_dist=False):
        """Accumulate the entropy of one batch.

        Args:
            logits: unnormalised scores, or probabilities when ``is_dist`` is True.
                The last dimension is treated as the class dimension.
            is_dist: if True, ``logits`` is used as a probability distribution
                directly instead of being passed through a softmax.
        """
        if is_dist:
            # xlogy returns 0 where p == 0, whereas p * log(p) would give 0 * -inf = NaN.
            p_log_p = torch.xlogy(logits, logits)
        else:
            # log_softmax stays finite even when the softmax itself underflows to 0.
            log_p = F.log_softmax(logits, dim=-1)
            p_log_p = log_p.exp() * log_p

        self.entropy_total += -torch.sum(p_log_p)
        self.count += logits.shape[0]

    def compute(self):
        """Return the accumulated entropy divided by the accumulated batch count."""
        return self.entropy_total.float() / self.count.float()

class ClassificationKernelCalibrationError(Metric):
    """Running mean of the per-batch kernel calibration error.

    The error of each batch is computed by a ``ClassificationKernelLoss`` built
    from ``**kwargs``; ``compute`` returns the average over all updates.
    """

    def __init__(self, dist_sync_on_step=False, **kwargs):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state("kcal_total", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.kcal_func = ClassificationKernelLoss(**kwargs)

    def update(self, preds: torch.Tensor, target: torch.Tensor, inputs: torch.Tensor, verbose=False):
        """Accumulate the kernel calibration error of one batch.

        Args:
            preds: model outputs, forwarded to the kernel loss as its ``logits`` argument.
            target: labels, forwarded as ``y``.
            inputs: input features, forwarded as ``x``.
            verbose: forwarded unchanged to the kernel loss.
        """
        kcal = self.kcal_func(inputs, target, preds, verbose=verbose)

        # Detach so the metric state never keeps an autograd graph alive
        # when update() is called with gradients enabled.
        self.kcal_total += kcal.detach()
        self.count += 1

    def compute(self):
        """Return the mean per-batch error over all updates."""
        return self.kcal_total.float() / self.count

# Loss

TODO: add loss for other tasks and modify ClassificationMixedLoss to adopt DM losses to CE loss with initializing weights

In [2]:
# Kernels Utils

def rbf_kernel(u: torch.Tensor, v: torch.Tensor, bandwidth=1):
    """Pairwise RBF kernel ``exp(-||u_i - v_j||^2 / bandwidth)``.

    Args:
        u: tensor whose rows are compared against the rows of ``v``
            (used as ``u.unsqueeze(1) - v``, reduced over dimension 2).
        v: tensor broadcastable against ``u.unsqueeze(1)``.
        bandwidth: divisor applied to the squared distance.

    Returns:
        Tensor of kernel values, one per (row of ``u``, row of ``v``) pair.
    """
    # Sum the squared differences directly instead of taking the norm
    # (a square root) and squaring it again.
    sq_dist = (u.unsqueeze(1) - v).square().sum(dim=2)
    return torch.exp(-sq_dist / bandwidth)

def quadrant_partition_kernel(u: torch.Tensor, v: torch.Tensor):
    """Placeholder for the quadrant partition kernel; always raises.

    Raises:
        NotImplementedError: unconditionally, the kernel is not implemented.
    """
    raise NotImplementedError()


def norm_partition_kernel(u: torch.Tensor, v: torch.Tensor):
    """Placeholder for the norm partition kernel; always raises.

    Raises:
        NotImplementedError: unconditionally, the kernel is not implemented.
    """
    raise NotImplementedError()

def tanh_kernel(u: torch.Tensor, v: torch.Tensor, bandwidth=1):
    """Pairwise product of ``tanh(u)`` and ``tanh(v)``.

    Args:
        u: tensor; ``tanh(u)`` gets a new axis inserted at position 1 before
            being broadcast against ``tanh(v)``.
        v: tensor broadcast against ``tanh(u).unsqueeze(1)``.
        bandwidth: accepted so the signature matches the other kernels; not used.

    Returns:
        The broadcast product with dimension 2 squeezed out when it has size 1.
    """
    tanh_u = torch.tanh(u).unsqueeze(1)
    tanh_v = torch.tanh(v)
    pairwise = tanh_v * tanh_u  # N x N x 1 x num_samples
    return pairwise.squeeze(2)
# Registry mapping a kernel name to the function that implements it.
# The two "partition_*" entries point at stubs that raise NotImplementedError.
kernel_funs = {"rbf": rbf_kernel,
               "partition_quadrant": quadrant_partition_kernel,
               "partition_norm": norm_partition_kernel,
               "tanh": tanh_kernel}

# Operand names accepted when constructing a kernel loss.
VALID_OPERANDS = ['x', 'y', 'p', 'coords']

def mean_no_diag(A):
    """Return the mean of the off-diagonal entries of a square matrix.

    Args:
        A: 2-D square tensor of shape (n, n).

    Returns:
        Scalar tensor ``sum(off-diagonal) / (n * (n - 1))``. For ``n < 2`` the
        denominator is zero, so the result is NaN.

    Raises:
        AssertionError: if ``A`` is not a square 2-D tensor.
    """
    assert A.dim() == 2 and A.shape[0] == A.shape[1]
    n = A.shape[0]
    # Build the mask directly on A's device (no host-to-device copy) and zero the
    # diagonal by masking rather than subtraction, so a non-finite diagonal entry
    # cannot leak into the result through inf - inf.
    diag_mask = torch.eye(n, dtype=torch.bool, device=A.device)
    off_diag_sum = A.masked_fill(diag_mask, 0).sum()
    return off_diag_sum / (n * (n - 1))


# Loss
class ClassificationKernelLoss:
    """
        MMD loss function for classification tasks.
        Allows for distribution matching by specifying operands and kernel functions.
        `scalers` and `bandwidths` are the parameters of the kernel functions.

        Args:
            operands: maps an operand name (must be in ``VALID_OPERANDS``) to a kernel
                name (a key of ``kernel_funs``). Defaults to ``{'x': "rbf", 'y': "rbf"}``.
            scalers: per-operand multiplier applied to the kernel matrix.
                Defaults to 1 for every operand.
            bandwidths: per-operand bandwidth passed to the kernel function.
                Defaults to ``{'x': 0.01, 'y': 1.0}``.
    """
    def __init__(self,
                 operands: Optional[Dict[str, str]] = None,
                 scalers: Optional[Dict] = None,
                 bandwidths: Optional[Dict] = None):

        # Defaults are created per instance so they are never shared between
        # instances the way mutable default arguments would be.
        if operands is None:
            operands = {'x': "rbf", 'y': "rbf"}
        if bandwidths is None:
            bandwidths = {'x': 0.01, 'y': 1.0}

        assert len(operands) > 0, "At least one operand is required."
        assert all([op in VALID_OPERANDS for op in operands.keys()])

        if scalers is None:
            scalers = {op: 1. for op in operands.keys()}
        else:
            assert all(op in scalers for op in operands.keys())

        # Fail at construction time rather than with a KeyError on the first call.
        assert all(op in bandwidths for op in operands.keys())

        self.kernel_fun = {op: kernel_funs[kernel] for op, kernel in operands.items()}
        self.operands = list(operands.keys())
        self.scalers = dict(scalers)
        self.bandwidths = dict(bandwidths)

    def __call__(self, x, y, logits, verbose=False):
        """Compute the kernel loss for one batch.

        Args:
            x: input features; must be 2-D when 'x' is an operand.
            y: integer class labels, used to index the class kernel matrix and
                passed to ``F.one_hot``.
            logits: model outputs; the last dimension is the class dimension.
            verbose: accepted for interface compatibility; not used.

        Returns:
            Scalar tensor ``mean(M0) - 2 * mean(M1) + mean(M2)`` where each mean
            excludes the diagonal and each ``M`` is the element-wise product of
            the per-operand matrices.
        """
        # Running element-wise products of the per-operand matrices for the three terms.
        loss_mats = [None, None, None]

        for op in self.operands:
            scaler = self.scalers[op]
            bandwidth = self.bandwidths[op]
            kernel = self.kernel_fun[op]
            if op == 'x':
                # This is only true for tabular data. For example, multi-channel images will have 4D batches for x.
                assert x.dim() == 2
                k_xx = scaler * kernel(x, x, bandwidth)
                # The input kernel is identical in all three terms.
                terms = (k_xx, k_xx, k_xx)
            elif op == 'y':
                # Computes MMD loss for classification (See Section 4.1 of paper)
                num_classes = logits.shape[-1]
                y_all = torch.eye(num_classes, device=logits.device)
                k_yy = kernel(y_all, y_all, bandwidth)
                q_y = F.softmax(logits, dim=-1)

                # Term 1: kernel summed over all pairs of classes, weighted by both predictions.
                q_yy = torch.einsum('ic,jd->ijcd', q_y, q_y)
                total_yy = q_yy * k_yy.unsqueeze(0)

                # Term 2: predicted distribution of sample i against the true label of sample j.
                k_yj = k_yy[:, y].T
                total_yj = torch.einsum('ic,jc->ijc', q_y, k_yj)

                # Term 3: kernel between the one-hot encoded true labels.
                y_one_hot = F.one_hot(y, num_classes=num_classes).float()

                terms = (scaler * total_yy.sum(dim=(2, 3)),
                         scaler * total_yj.sum(-1),
                         scaler * kernel(y_one_hot, y_one_hot, bandwidth))
            else:
                # Raised explicitly (not via `assert`) so that running with -O cannot
                # fall through and silently reuse the previous operand's matrices.
                raise AssertionError(f"When running classification, operands must be x and y. Got operand {op} instead.")

            for i, value in enumerate(terms):
                loss_mats[i] = value if loss_mats[i] is None else loss_mats[i] * value

        return mean_no_diag(loss_mats[0]) - 2 * mean_no_diag(loss_mats[1]) + mean_no_diag(loss_mats[2])

class ClassificationCELoss:
    """
        Cross-entropy criterion sharing the ``(x, y, logits)`` call signature
        of the other classification losses.
    """
    def __init__(self, **kwargs):
        # All keyword arguments are handed to torch.nn.CrossEntropyLoss unchanged.
        self.loss = torch.nn.CrossEntropyLoss(**kwargs)

    def __call__(self, x, y, logits):
        """Return the cross-entropy between ``logits`` and labels ``y``; ``x`` is ignored."""
        cross_entropy = self.loss(logits, y)
        return cross_entropy

class ClassificationMixedLoss:
    """
        Weighted sum of a cross-entropy (NLL) term and a kernel (MMD) term.
        `loss_scalers` holds the weight of each term under the keys "nll" and "mmd";
        remaining keyword arguments configure the ClassificationKernelLoss.
    """
    def __init__(self, loss_scalers: Optional[Dict] = None, **kwargs):
        if loss_scalers is not None:
            # Exactly the two known terms must be weighted.
            assert set(loss_scalers.keys()) == {"nll", "mmd"}
        else:
            loss_scalers = {"nll": .01, "mmd": 1}
        self.loss_scalers = loss_scalers
        self.nll = torch.nn.CrossEntropyLoss()
        self.mmd = ClassificationKernelLoss(**kwargs)

    def __call__(self, x, y, logits):
        """Return ``w_nll * CE(logits, y) + w_mmd * MMD(x, y, logits)``."""
        nll_term = self.loss_scalers["nll"] * self.nll(logits, y)
        mmd_term = self.loss_scalers["mmd"] * self.mmd(x, y, logits)
        return nll_term + mmd_term



# Model
The model's input and output sizes are determined by the dataset, and can be obtained from the datamodule.


In [3]:
from torch import nn
# Model
class SimpleDenseNet(nn.Module):
    """
        Fully connected network for classification: three hidden Linear layers,
        each followed by an optional BatchNorm1d and a ReLU, then a Linear output layer.
        Forward() call returns logits.
    """
    def __init__(
        self,
        input_size: int = 784,
        lin1_size: int = 256,
        lin2_size: int = 256,
        lin3_size: int = 256,
        output_size: int = 10,
        use_batchnorm: bool = True,
    ):
        super().__init__()

        # Layers are created in network order, so the positions inside the
        # Sequential container are determined only by `use_batchnorm`.
        widths = (input_size, lin1_size, lin2_size, lin3_size)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            if use_batchnorm:
                layers.append(nn.BatchNorm1d(fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(lin3_size, output_size))

        self.model = nn.Sequential(*layers)
        self.output_size = output_size

    def forward(self, x):
        """Run ``x`` through the stack and return the logits."""
        logits = self.model(x)
        return logits


# Lightning Model

TODO:add metrics on line 44

In [4]:
from typing import Any, List

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torchmetrics import MaxMetric, MinMetric
from torchmetrics.classification.accuracy import Accuracy
from torchmetrics.classification.calibration_error import CalibrationError
import pdb
from typing import Any, List, Literal, Optional, Dict, Callable
# Core NN Module
class ClassificationLitModule(LightningModule):
    """ LightningModule for Classification tasks.

    A LightningModule organizes your PyTorch code into 5 sections:
        - Computations (init).
        - Train loop (training_step)
        - Validation loop (validation_step)
        - Test loop (test_step)
        - Optimizers (configure_optimizers)

    Args:
        net: network returning logits; must expose an ``output_size`` attribute (number of classes).
        criterion: callable ``criterion(x, y, logits)`` returning the loss.
        calibrator: optional post-hoc calibrator. It is trained on the validation split at the
            start of testing (``calibrator.train(probs, labels)``) and applied to the test
            predictions (``calibrator(probs)``).
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        kcal_kwargs: keyword arguments for ``ClassificationKernelCalibrationError``.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
    """

    def __init__(
        self,
        net: torch.nn.Module,
        criterion: Callable,
        calibrator: Callable = None,
        lr: float = 0.001,
        weight_decay: float = 0.0005,
        kcal_kwargs = None
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # it also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False, ignore=['net'])  # this is needed for efficiency
        self.net = net

        # loss function
        self.criterion = criterion
        self.calibrator = calibrator

        # Validate the class count before it is used to configure any metric.
        assert net.output_size >= 2, f"Must have >=2 classes for classification task. Model only has {net.output_size} classes."
        task = 'binary' if net.output_size == 2 else 'multiclass'

        # use separate metric instance for train, val and test step
        # to ensure a proper reduction over the epoch.
        # `num_classes` must be supplied for the multiclass task.
        acc_kwargs = {"task": task, "num_classes": net.output_size}
        self.train_acc = Accuracy(**acc_kwargs)
        self.val_acc = Accuracy(**acc_kwargs)
        self.test_acc = Accuracy(**acc_kwargs)

        # We are always using Multiclass ECE since we are considering binary case as multiclass by 0 as class1 and 1 as class 2
        ece_kwargs = {"task": 'multiclass', "n_bins": 20, "norm": 'l1', "num_classes": net.output_size}
        self.train_ece = CalibrationError(**ece_kwargs)
        self.val_ece = CalibrationError(**ece_kwargs)
        self.test_ece = CalibrationError(**ece_kwargs)

        self.train_entropy = ShannonEntropyError()
        self.val_entropy = ShannonEntropyError()
        self.test_entropy = ShannonEntropyError()

        if kcal_kwargs is None:
            kcal_kwargs = {}
        self.test_kcal = ClassificationKernelCalibrationError(**kcal_kwargs)

        # For logging best validation metrics
        self.val_acc_best = MaxMetric()
        self.val_ece_best = MinMetric()
        self.val_entropy_best = MinMetric()

        # Additional metrics for post-hoc calibration
        if self.calibrator:
            self.test_calibrated_acc = Accuracy(**acc_kwargs)
            self.test_calibrated_ece = CalibrationError(**ece_kwargs)
            self.test_calibrated_entropy = ShannonEntropyError()
            self.test_calibrated_kcal = ClassificationKernelCalibrationError(**kcal_kwargs)

    def forward(self, x: torch.Tensor):
        """Return the logits of the wrapped network."""
        return self.net(x)

    def on_train_start(self):
        # by default lightning executes validation step sanity checks before training starts,
        # so we need to make sure the `*_best` trackers don't store values from these checks
        self.val_acc_best.reset()
        self.val_ece_best.reset()
        self.val_entropy_best.reset()

    def step(self, batch: Any):
        """Shared forward pass: returns ``(loss, preds, logits, targets, probs)`` for a batch."""
        x, y = batch
        y = y.squeeze(-1)
        logits = self.forward(x)
        loss = self.criterion(x, y, logits)
        # Question should we disable taking gradient for preds and probs since it is only being used on computing metrics?
        preds = torch.argmax(logits, dim=-1)
        probs = F.softmax(logits, dim=-1)

        return loss, preds, logits, y, probs

    def training_step(self, batch: Any, batch_idx: int):
        loss, preds, logits, targets, probs = self.step(batch)

        # log train metrics
        acc = self.train_acc(preds, targets)
        ece = self.train_ece(probs, targets)
        entropy = self.train_entropy(logits)

        self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train/ece", ece, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train/entropy", entropy, on_step=False, on_epoch=True, prog_bar=True)

        # we can return here dict with any tensors
        # and then read it in some callback or in `on_train_epoch_end()` below
        # remember to always return loss from `training_step()` or else backpropagation will fail!
        return {"loss": loss, "preds": preds, "targets": targets}

    def on_train_epoch_end(self):
        # The per-batch values were already handed to `self.log`, so the accumulated
        # state is only cleared here to keep epochs independent of each other.
        self.train_acc.reset()
        self.train_ece.reset()
        self.train_entropy.reset()

    def validation_step(self, batch: Any, batch_idx: int):
        loss, preds, logits, targets, probs = self.step(batch)

        # log val metrics
        acc = self.val_acc(preds, targets)
        ece = self.val_ece(probs, targets)
        entropy = self.val_entropy(logits)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/ece", ece, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/entropy", entropy, on_step=False, on_epoch=True, prog_bar=True)

        return {"loss": loss, "preds": preds, "targets": targets}

    def on_validation_epoch_end(self):
        acc = self.val_acc.compute()  # get val accuracy from current epoch
        self.val_acc_best(acc)

        ece = self.val_ece.compute()  # get val ECE from current epoch
        self.val_ece_best(ece)

        entropy = self.val_entropy.compute()  # get val entropy from current epoch
        self.val_entropy_best(entropy)

        # log `*_best` metrics as a value through `.compute()` method, instead of as a metric object
        # otherwise metric would be reset by lightning after each epoch
        self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)
        self.log("val/ece_best", self.val_ece_best.compute(), sync_dist=True, prog_bar=True)
        self.log("val/entropy_best", self.val_entropy_best.compute(), sync_dist=True, prog_bar=True)

        # Clear the per-epoch state so the next epoch's `compute()` only sees its own batches.
        self.val_acc.reset()
        self.val_ece.reset()
        self.val_entropy.reset()

    def on_test_epoch_start(self):
        """Fit the post-hoc calibrator (if any) on the full validation split."""
        if self.calibrator is None:
            return

        val_x, val_y = self.trainer.datamodule.data_val[:]
        val_x, val_y = val_x.to(self.device), val_y.squeeze(-1).to(self.device)

        with torch.no_grad():
            logits = self.forward(val_x)
            val_pred = F.softmax(logits, dim=-1)

        # NOTE(review): recent Lightning versions may run the test loop under
        # torch.inference_mode; confirm the calibrator can still build a graph here.
        with torch.enable_grad():
            self.calibrator.train(val_pred, val_y)

    def test_step(self, batch: Any, batch_idx: int):
        loss, preds, logits, targets, probs = self.step(batch)

        # log test metrics
        acc = self.test_acc(preds, targets)
        ece = self.test_ece(probs, targets)
        entropy = self.test_entropy(logits)

        x, _ = batch
        kcal = self.test_kcal(logits, targets, x)

        self.log("test/loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test/ece", ece, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test/entropy", entropy, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test/kcal", kcal, on_step=False, on_epoch=True, prog_bar=True)

        # If post-hoc calibration method is chosen, apply it to model predictions
        if self.calibrator:
            pred_dist = F.softmax(logits, dim=-1)
            with torch.no_grad():
                calibrated_dists = self.calibrator(pred_dist)
                calibrated_preds = torch.argmax(calibrated_dists, dim=-1)

            calibrated_acc = self.test_calibrated_acc(calibrated_preds, targets)
            calibrated_ece = self.test_calibrated_ece(calibrated_dists, targets)
            calibrated_entropy = self.test_calibrated_entropy(calibrated_dists, is_dist=True)
            # The kernel metric applies a softmax to what it receives, so hand it
            # log-probabilities (softmax(log p) == p), and use the dedicated
            # calibrated metric so the uncalibrated `test_kcal` state is not polluted.
            calibrated_kcal = self.test_calibrated_kcal(torch.log(calibrated_dists + 1e-10), targets, x)

            # log post-hoc calibrated test metrics
            self.log("test/calibrated_acc", calibrated_acc, on_step=False, on_epoch=True, prog_bar=True)
            self.log("test/calibrated_ece", calibrated_ece, on_step=False, on_epoch=True, prog_bar=True)
            self.log("test/calibrated_entropy", calibrated_entropy, on_step=False, on_epoch=True, prog_bar=True)
            self.log("test/calibrated_kcal", calibrated_kcal, on_step=False, on_epoch=True, prog_bar=True)

        return {"loss": loss, "preds": preds, "targets": targets}

    def on_test_epoch_end(self):
        # Clear test-time state so a second `trainer.test()` call starts fresh.
        self.test_acc.reset()
        self.test_ece.reset()
        self.test_entropy.reset()
        self.test_kcal.reset()

        if self.calibrator:
            self.test_calibrated_acc.reset()
            self.test_calibrated_ece.reset()
            self.test_calibrated_entropy.reset()
            self.test_calibrated_kcal.reset()

    def on_epoch_end(self):
        # Reset every metric. Kept for Lightning versions that still call this hook;
        # the stage-specific `on_*_epoch_end` hooks above perform the same resets.
        self.train_acc.reset()
        self.test_acc.reset()
        self.val_acc.reset()

        self.train_ece.reset()
        self.test_ece.reset()
        self.val_ece.reset()

        self.train_entropy.reset()
        self.test_entropy.reset()
        self.val_entropy.reset()

        self.test_kcal.reset()

        if self.calibrator:
            self.test_calibrated_acc.reset()
            self.test_calibrated_ece.reset()
            self.test_calibrated_entropy.reset()
            self.test_calibrated_kcal.reset()

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        See examples here:
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        return torch.optim.Adam(
            params=self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
        )

# Calibration

In [5]:
# Calibration
def _get_prediction_device(predictions):
    """ Get the device of a prediction

    Args:
        predictions: a prediction of any type.

    Returns:
        device: the torch device that prediction is on.
    """
    if issubclass(type(predictions), torch.distributions.distribution.Distribution):
        with torch.no_grad():
            device = predictions.sample().device    # Trick to get the device of a torch Distribution class because there is no interface for this
    elif issubclass(type(predictions), dict):
        assert len(predictions.keys()) != 0, "Must have at least one element in the ensemble"
        device = _get_prediction_device(predictions[next(iter(predictions))])   # Return the device of the first element in the dictionary
    else:
        device = predictions.device
    return device

class Calibrator:
    """ Common interface shared by every calibrator.

    Sub-classes are expected to override ``to``, ``train``, ``update`` and ``__call__``;
    the versions defined here only fail with an assertion.

    Args:
        input_type (str): the input prediction type.
            If input_type is 'auto' then it is automatically induced when Calibrator.train() or update() is called, it cannot be changed after the first call to train() or update().
            Not all sub-classes support 'auto' input_type, so it is strongly recommended to explicitly specify the prediction type.
    """
    def __init__(self, input_type='auto'):
        self.device = None
        self.input_type = input_type

    def _change_device(self, predictions):
        """ Move this calibrator onto the device of ``predictions`` and return that device. """
        target = _get_prediction_device(predictions)
        self.to(target)
        self.device = target
        return target

    def to(self, device):
        """ Move this class and all the tensors it owns to a specified device.

        Args:
            device (torch.device): the device to move this class to.
        """
        assert False, "Calibrator.to has not been implemented"

    def train(self, predictions, labels, *args, **kwargs):
        """ Learn the recalibration map from labeled data (to be overridden).

        Training uses a set of predictions and their labels to fit whatever parameters are needed
        to turn a low quality (e.g. uncalibrated) prediction into a higher quality (e.g. calibrated) one.
        Some recalibration algorithms --- such as group calibration or multicalibration --- additionally
        take side features as input, and the learned transformation depends on them.

        Args:
            predictions (object): a batched prediction object, must match the input_type argument when calling __init__.
            labels (tensor): the labels with shape [batch_size]
            side_feature (tensor): some calibrator instantiations can use additional side feature, when used it should be a tensor of shape [batch_size, n_features]

        Returns:
            object: an optional log object that contains information about training history.
        """
        assert False, "Calibrator.train has not been implemented"

    # If half_life is not None, then it is the number of calls to this function where the sample is discounted to 1/2 weight
    # Not all calibration functions support half_life
    def update(self, predictions, labels, *args, **kwargs):
        """ Online counterpart of Calibrator.train: refine the calibrator with new data (train, by contrast, erases any existing data in the calibrator and learns from scratch).

        Args:
            predictions (object): a batched prediction object, must match the input_type argument when calling __init__.
            labels (tensor): the labels with shape [batch_size]
            side_feature (tensor): some calibrator instantiations can use additional side feature, when used it should be a tensor of shape [batch_size, n_features]

        Returns:
            object: an optional log object that contains information about training history.
        """
        assert False, "Calibrator.update has not been implemented"

    # Input an array of shape [batch_size, num_classes], output the recalibrated array
    # predictions should be in the same pytorch device
    # If side_feature is not None when calling train, it shouldn't be None here either.
    def __call__(self, predictions, *args, **kwargs):
        """ Apply the learned calibrator to new predictions (to be overridden).

        Args:
            predictions (prediction object): a batched prediction object, must match the input_type argument when calling __init__.
            side_feature (tensor): some calibrator instantiations can use additional side feature, when used it should be a tensor of shape [batch_size, n_features]

        Returns:
            prediction object: the transformed predictions
        """
        assert False, "Calibrator.__call__ has not been implemented"

    def check_type(self, predictions):
        """ Assert that ``predictions`` has the shape/interface implied by ``self.input_type``.

        Only the 'point', 'interval', 'quantile' and 'distribution' input types are checked;
        any other input type passes without inspection.

        Args:
            predictions (prediction object): a batched prediction object, must match the input_type argument when calling __init__.
        """
        kind = self.input_type
        if kind == 'point':
            rank = len(predictions.shape)
            assert rank == 1, "Point prediction should have shape [batch_size]"
        elif kind == 'interval':
            shape = predictions.shape
            assert len(shape) == 2 and shape[1] == 2, "interval predictions should have shape [batch_size, 2]"
        elif kind == 'quantile':
            shape = predictions.shape
            assert len(shape) == 2 or (len(shape) == 3 and shape[2] == 2), "quantile predictions should have shape [batch_size, num_quantile] or [batch_size, num_quantile, 2]"
        elif kind == 'distribution':
            # Only the cdf method is required; icdf is deliberately not checked.
            assert hasattr(predictions, 'cdf') , "Distribution predictions should have a cdf method"

    def assert_type(self, input_type, valid_types):
        """ Assert that ``input_type`` is one of ``valid_types``. """
        supported = " ".join(valid_types)
        msg = f"Input data type not supported, input data type is {input_type!s}, supported types are {supported}"
        assert input_type in valid_types, msg




class TemperatureScaling(Calibrator):
    """ The class to recalibrate a categorical prediction with temperature scaling

    Temperature scaling is often the algorithm of choice when calibrating predictions from deep neural networks.
    The only learnable parameter --- the temperature parameter $T$ --- is tuned to maximize the log-likelihood of the labels.
    Temperature scaling requires very few samples to train because it only learns a single parameter $T$, yet despite the simplicity,
    empirical results show that temperature scaling achieves low calibration error when applied to deep network predictions.

    Args:
        verbose (bool): if verbose=True print detailed messages
    """
    def __init__(self, verbose=False):
        super(TemperatureScaling, self).__init__(input_type='categorical')
        self.verbose = verbose
        self.temperature = None

    def train(self, predictions, labels, *args, **kwargs):
        """ Find the optimal temperature with gradient descent.

        Args:
            predictions (tensor): a batch of categorical predictions with shape [batch_size, num_classes]
            labels (tensor): a batch of labels with shape [batch_size]
        """
        # Use gradient descent to find the optimal temperature
        # Can add bisection option in the future, since it should be considerably faster
        self.to(predictions)

        # Training always restarts from T = 1 on the device of the predictions.
        self.temperature = torch.ones(1, 1, requires_grad=True, device=self.device)
        optim = torch.optim.Adam([self.temperature], lr=1e-3)
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, mode='min', patience=3, threshold=1e-6, factor=0.5)

        # Small epsilon keeps the log finite for zero probabilities.
        log_prediction = torch.log(predictions + 1e-10).detach()

        # Iterate at most 100k iterations, but expect to stop early
        for iteration in range(100000):
            optim.zero_grad()
            adjusted_predictions = log_prediction / self.temperature
            loss = F.cross_entropy(adjusted_predictions, labels)
            loss.backward()
            optim.step()
            lr_scheduler.step(loss)

            # Hitchhike the lr scheduler to terminate if no progress
            if optim.param_groups[0]['lr'] < 1e-6:
                break
            if self.verbose and iteration % 100 == 0:
                print("Iteration %d, lr=%.5f, NLL=%.3f" % (iteration, optim.param_groups[0]['lr'], loss.cpu().item()))

    def __call__(self, predictions, *args, **kwargs):
        """ Use the learned temperature to calibrate the predictions.

        Only use this after calling TemperatureScaling.train.

        Args:
            predictions (tensor): a batch of categorical predictions with shape [batch_size, num_classes]

        Returns:
            tensor: the calibrated categorical prediction, it should have the same shape as the input predictions

        Raises:
            RuntimeError: if called before ``train``.
        """
        if self.temperature is None:
            # Fail with a clear message instead of printing and then crashing
            # on the division by None below.
            raise RuntimeError("Error: need to first train before calling this function")
        self.to(predictions)
        log_prediction = torch.log(predictions + 1e-10)
        return torch.softmax(log_prediction / self.temperature, dim=1)

    def to(self, device):
        """ Move all assets of this class to a torch device.

        Args:
            device: either a torch device (such as torch.device('cpu') or the string 'cpu'),
                or a prediction object whose device should be adopted.

        Returns:
            TemperatureScaling: ``self``, to allow chaining.
        """
        if isinstance(device, (torch.device, str)):
            device = torch.device(device)
        else:
            device = _get_prediction_device(device)
        if self.temperature is not None:
            # Tensor.to is not in-place: the result must be stored back.
            self.temperature = self.temperature.to(device)
        self.device = device
        return self

# Data Module

We use a single datamodule and extend it to other datasets by defining a new loading function that preprocesses each dataset.

TODO: define a new load function, add it to classification_load_funs, and add the dataset's dimensions to classification_shapes

In [6]:
# Data Module
from typing import Optional, Tuple
import os
import pandas as pd
import numpy as np
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split, TensorDataset
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from functools import partial
class ClassificationDataModule(LightningDataModule):
    """Datamodule for classification datasets.

    A DataModule implements 5 key methods:
        - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode)
        - setup (things to do on every accelerator in distributed mode)
        - train_dataloader (the training dataloader)
        - val_dataloader (the validation dataloader(s))
        - test_dataloader (the test dataloader(s))

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Args:
        dataset_name: key into ``classification_shapes`` and ``classification_load_funs``.
        data_dir: root directory handed to the dataset loading function.
        random_seed: seed for the train/val/test split; ``None`` draws a random seed.
        train_val_test_split: fractions of the data for each split; must sum to 1.
        batch_size: batch size of the training dataloader.
        test_batch_size: batch size of the validation and test dataloaders.
        normalize: if True, standardise every feature column (constant columns become 0).
        num_workers: forwarded to every DataLoader.
        pin_memory: forwarded to every DataLoader.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html
    """

    def __init__(
        self,
        dataset_name: Optional[str] = None,
        data_dir: str = "data/",
        random_seed: Optional[int] = 1,
        train_val_test_split: Tuple[float, float, float] = (0.7, 0.1, 0.2),
        batch_size: int = 64,
        test_batch_size: int = 128,
        normalize: bool = True,
        num_workers: int = 0,
        pin_memory: bool = False
    ):
        super().__init__()

        assert np.isclose(sum(train_val_test_split), 1), f"Train_val_test_split must sum to 1. Got {train_val_test_split} with sum {sum(train_val_test_split):0.5f}."

        # this line allows to access init params with 'self.hparams' attribute
        self.save_hyperparameters(logger=False)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

        # (number of input features, number of classes) for the chosen dataset
        self.input_size: int = classification_shapes[dataset_name][0]
        self.label_size: int = classification_shapes[dataset_name][1]

        # Load eagerly so the splits are available right after construction.
        self.setup()

    @property
    def dataset_type(self) -> str:
        # This datamodule serves classification datasets.
        return "classification"

    def prepare_data(self):
        """Download data if needed.

        This method is called only from a single GPU.
        Do not use it to assign state (self.x = y).
        """
        pass

    def setup(self, stage: Optional[str] = None):
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by lightning when doing `trainer.fit()` and `trainer.test()`,
        so be careful not to execute the random split twice! The `stage` can be used to
        differentiate whether it's called before trainer.fit()` or `trainer.test()`.
        """
        # Load datasets only if they're not loaded already. Compare against None
        # explicitly: truthiness of a dataset falls back to its length, so an empty
        # split would otherwise look "not loaded" and trigger a second random split.
        if self.data_train is not None or self.data_val is not None or self.data_test is not None:
            return

        loader_fun = classification_load_funs[self.hparams.dataset_name]
        X, y = loader_fun(self.hparams.data_dir)
        if self.hparams.normalize:
            std = X.std(axis=0)
            constant = np.isclose(std, 0.)
            # Standardise the non-constant columns; constant columns carry no
            # information and would divide by ~0, so they are set to 0 instead.
            X[:, ~constant] = (X[:, ~constant] - X[:, ~constant].mean(axis=0)) / std[~constant]
            X[:, constant] = 0.
        if y.ndim == 1:
            y = y.reshape(-1, 1)

        # Split based on initialized ratio
        dataset = TensorDataset(torch.Tensor(X), torch.Tensor(y).long())
        lengths = [int(len(X) * p) for p in self.hparams.train_val_test_split]
        lengths[-1] += len(X) - sum(lengths)  # fix any rounding errors

        generator = torch.Generator()
        if self.hparams.random_seed is None:
            generator.seed()  # no seed requested: draw a non-deterministic one
        else:
            generator.manual_seed(self.hparams.random_seed)

        self.data_train, self.data_val, self.data_test = random_split(
            dataset=dataset,
            lengths=lengths,
            generator=generator,
        )

    def train_dataloader(self):
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
            drop_last=True
        )

    def val_dataloader(self):
        # NOTE(review): drop_last=True discards the final partial batch, so validation
        # metrics do not cover every sample; confirm this is intended.
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.hparams.test_batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
            drop_last=True
        )

    def test_dataloader(self):
        # NOTE(review): drop_last=True discards the final partial batch, so test
        # metrics do not cover every sample; confirm this is intended.
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.hparams.test_batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
            drop_last=True
        )

def _load_adult(data_dir):
    """Load the Adult income dataset as a feature matrix and label vector.

    Reads the header-less CSV ``classification/adult/adult.data`` under
    ``data_dir`` into 15 named columns: 14 attributes describing a person
    plus the target column ``income``.

    Processing steps:
    -- Rows containing a "?" entry are dropped.
    -- ``income`` is replaced by the index of its label among the sorted
       unique labels (0/1 when two classes are present).
    -- Each of the eight categorical columns is replaced by one-hot columns
       appended at the end of the frame.

    Returns:
        (X, y): ``X`` is a float array of the remaining columns without
        ``income``; ``y`` is the encoded ``income`` column.
    """
    column_names = ["age","workclass","fnlwgt","education","educational-num","marital-status","occupation","relationship","race","gender","capital-gain","capital-loss","hours-per-week","native-country","income"]
    categorical_columns = ['workclass', 'education', 'marital-status', 'occupation',
                           'relationship', 'race', 'gender', 'native-country']

    csv_path = os.path.join(data_dir, 'classification/adult/adult.data')
    frame = pd.read_csv(csv_path, header=None, names=column_names, skipinitialspace=True)
    frame = frame.replace("?", np.nan).dropna()

    # Encode the target labels as integer codes.
    _, label_codes = np.unique(frame['income'], return_inverse=True)
    frame['income'] = label_codes

    # Swap every categorical column for its indicator columns, one at a time,
    # so the indicators end up after the untouched columns in list order.
    for column in categorical_columns:
        indicators = pd.get_dummies(frame[[column]])
        frame = pd.concat([frame, indicators], axis=1).drop([column], axis=1)

    y = frame['income'].to_numpy()
    X = frame.drop('income', axis=1).to_numpy().astype(float)
    return X, y


# Registry: dataset name -> loader taking a data directory and returning (X, y).
# Only "adult" has a loader registered here.
classification_load_funs = {
    "adult": _load_adult,
}

# Per-dataset size pair, keyed by dataset name.
# NOTE(review): presumably (number of input features, number of classes) — confirm
# against the datamodule that reads this table.
classification_shapes = {
    "wdbc": (30, 2),
    "adult": (104, 2),
    "heart-disease": (23, 5),
    "online-shoppers": (28, 2),
    "dry-bean": (16, 7),
}


# Utils

In [7]:
import logging
import warnings
from typing import List, Sequence

import pytorch_lightning as pl
import rich.syntax
import rich.tree
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.utilities import rank_zero_only
import wandb
from pytorch_lightning.loggers.logger import Logger ## Used for storing multiple loggers


def get_logger(name=__name__) -> logging.Logger:
    """Return a multi-GPU-friendly python command line logger.

    Every level method of the logger is wrapped with ``rank_zero_only`` so that
    a message is not repeated by each GPU process in a multi-GPU setup.

    ``logging.getLogger`` hands back the same object for the same ``name``, so
    the logger is marked once it has been wrapped; calling this function again
    for that name (for example when a notebook cell is re-run) returns the
    logger unchanged instead of stacking another layer of wrappers.

    Args:
        name: Name passed to ``logging.getLogger``.

    Returns:
        logging.Logger: The logger with rank-zero-only level methods.
    """

    logger = logging.getLogger(name)

    # Already patched by an earlier call: do not wrap the wrappers again.
    if getattr(logger, "_rank_zero_wrapped", False):
        return logger

    # this ensures all logging levels get marked with the rank zero decorator
    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
    for level in (
        "debug",
        "info",
        "warning",
        "error",
        "exception",
        "fatal",
        "critical",
    ):
        setattr(logger, level, rank_zero_only(getattr(logger, level)))

    logger._rank_zero_wrapped = True

    return logger


# Module-level logger used by the utility functions in this cell.
log = get_logger(__name__)


def extras(config: DictConfig) -> None:
    """Run the optional start-up helpers that the config switches on.

    Two top-level flags are read with ``config.get``:
    - ``ignore_warnings``: when truthy, all python warnings are ignored.
    - ``print_config``: when truthy, the config tree is printed with Rich.

    Args:
        config (DictConfig): Configuration to read the flags from.
    """

    # <config.ignore_warnings=True> silences python warnings
    silence_warnings = config.get("ignore_warnings")
    if silence_warnings:
        log.info("Disabling python warnings! <config.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # <config.print_config=True> pretty prints the resolved config tree
    show_tree = config.get("print_config")
    if show_tree:
        log.info("Printing config tree with Rich! <config.print_config=True>")
        print_config(config, resolve=True)


@rank_zero_only
def print_config(
    config: DictConfig,
    print_order: Sequence[str] = (
        "datamodule",
        "model",
        "callbacks",
        "logger",
        "trainer",
    ),
    resolve: bool = True,
) -> None:
    """Render a DictConfig as a Rich tree, on screen and in ``config_tree.log``.

    Args:
        config (DictConfig): Configuration to display.
        print_order (Sequence[str], optional): Top-level fields to show first, in this
            order. A listed field missing from ``config`` is reported through the logger
            and skipped; fields not listed follow in the config's own order.
        resolve (bool, optional): Whether to resolve reference fields of DictConfig
            when a group is dumped to YAML.
    """

    dim = "dim"
    root = rich.tree.Tree("CONFIG", style=dim, guide_style=dim)

    # Requested fields first (when present) ...
    ordered_fields = []
    for name in print_order:
        if name in config:
            ordered_fields.append(name)
        else:
            log.info(f"Field '{name}' not found in config")

    # ... then whatever else the config holds.
    for name in config:
        if name in ordered_fields:
            continue
        ordered_fields.append(name)

    for name in ordered_fields:
        node = root.add(name, style=dim, guide_style=dim)

        group = config[name]
        # Nested groups are dumped as YAML; plain values are shown via str().
        if isinstance(group, DictConfig):
            rendered = OmegaConf.to_yaml(group, resolve=resolve)
        else:
            rendered = str(group)

        node.add(rich.syntax.Syntax(rendered, "yaml"))

    rich.print(root)

    # Keep a copy of the rendered tree in the current working directory.
    with open("config_tree.log", "w") as tree_file:
        rich.print(root, file=tree_file)


@rank_zero_only
def log_hyperparameters(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[Logger],
) -> None:
    """Controls which config parts are saved by Lightning loggers.

    Saved entries: the ``model``, ``datamodule`` and ``trainer`` config groups,
    plus ``seed`` and ``callbacks`` when present in ``config``.

    Additionally saves:
    - number of model parameters (total, trainable, non-trainable)

    The collected hyperparameters are sent to every logger attached to the
    trainer (``trainer.loggers``), not only to the first one.

    Args:
        config (DictConfig): Full run configuration.
        model (pl.LightningModule): Model whose parameters are counted.
        datamodule (pl.LightningDataModule): Unused; kept for a uniform signature.
        trainer (pl.Trainer): Trainer whose loggers receive the hyperparameters.
        callbacks (List[pl.Callback]): Unused; kept for a uniform signature.
        logger (List[Logger]): Unused; the trainer's own loggers are used.
    """

    # nothing to do when the trainer has no logger attached
    if not trainer.logger:
        return

    hparams = {}

    # choose which parts of hydra config will be saved to loggers
    hparams["model"] = config["model"]

    # save number of model parameters (single pass over the parameters)
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    hparams["model/params/total"] = total
    hparams["model/params/trainable"] = trainable
    hparams["model/params/non_trainable"] = total - trainable

    hparams["datamodule"] = config["datamodule"]
    hparams["trainer"] = config["trainer"]

    if "seed" in config:
        hparams["seed"] = config["seed"]
    if "callbacks" in config:
        hparams["callbacks"] = config["callbacks"]

    # send hparams to all loggers; `trainer.logger` alone would only reach the
    # first one when several loggers are configured
    for attached_logger in trainer.loggers:
        attached_logger.log_hyperparams(hparams)


def finish(
    config: DictConfig,
    model: pl.LightningModule,
    datamodule: pl.LightningDataModule,
    trainer: pl.Trainer,
    callbacks: List[pl.Callback],
    logger: List[Logger],
) -> None:
    """Close external logging sessions at the end of a run.

    Only ``logger`` is inspected: for each entry that is a ``WandbLogger``,
    ``wandb.finish()`` is called. The other arguments are accepted but unused.
    """

    # without this sweeps with wandb logger might crash!
    for candidate in logger:
        if not isinstance(candidate, pl.loggers.wandb.WandbLogger):
            continue
        # imported lazily, only once a wandb logger is actually present
        import wandb

        wandb.finish()


# Train

In [8]:
# Train
import os
from typing import List, Optional

import hydra
from omegaconf import DictConfig
from pytorch_lightning import (
    Callback,
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
from pytorch_lightning.loggers.logger import Logger ## Used for storing multiple loggers
from pytorch_lightning.utilities import rank_zero_only
import logging

# Logger for the training cell. The name is the same as in the utilities cell,
# and logging.getLogger returns the same logger object for the same name.
log = get_logger(__name__)


def train(config: DictConfig) -> Optional[float]:
    """Contains the training pipeline. Can additionally evaluate model on a testset, using best
    weights achieved during training.

    Steps, in order: seed, resolve checkpoint path, build datamodule and model,
    instantiate callbacks/loggers/trainer from the config, log hyperparameters,
    fit (if ``config.train``), read the optimized metric, test (if ``config.test``),
    and finalize loggers.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Returns:
        Optional[float]: Metric score for hyperparameter optimization.

    Raises:
        Exception: If ``config.optimized_metric`` is set but is not among
            ``trainer.callback_metrics`` after training.
    """
    # Set seed for random number generators in pytorch, numpy and python.random
    if config.get("seed"):
        seed_everything(config.seed, workers=True)

    # Convert relative ckpt path to absolute path if necessary
    ckpt_path = config.trainer.get("resume_from_checkpoint")
    if ckpt_path and not os.path.isabs(ckpt_path):
        config.trainer.resume_from_checkpoint = os.path.join(
            hydra.utils.get_original_cwd(), ckpt_path
        )
    # Init lightning datamodule
    # NOTE(review): `_target_` is only logged here; the datamodule is built directly
    # and receives just `dataset_name` and `data_dir`. Any other keys under
    # `config.datamodule` are not forwarded by this call — confirm that is intended.
    log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
    datamodule = ClassificationDataModule(dataset_name = config.datamodule.dataset_name, data_dir = config.datamodule.data_dir)


    # Init lightning model
    log.info(f"Instantiating model <{config.model._target_}>")
    # Network input/output sizes come from the datamodule and are written back into the config.
    config.model.net.input_size = datamodule.input_size
    config.model.net.output_size = datamodule.label_size
    # The criterion and calibrator `_target_` strings are looked up by name in this
    # module's globals() and called with no arguments.
    # NOTE(review): other keys under `config.model` (and under `net`, `criterion`)
    # are not passed to these constructors here — confirm that is intended.
    model = ClassificationLitModule(net = SimpleDenseNet(input_size =config.model.net.input_size, output_size = config.model.net.output_size),
                                                         criterion = globals()[config.model.criterion._target_](), calibrator = globals()[config.model.calibrator._target_]())


    # Init lightning callbacks
    # Entries without a `_target_` key are skipped.
    callbacks: List[Callback] = []
    if "callbacks" in config:
        for _, cb_conf in config.callbacks.items():
            if "_target_" in cb_conf:
                log.info(f"Instantiating callback <{cb_conf._target_}>")
                callbacks.append(hydra.utils.instantiate(cb_conf))

    # Init lightning loggers
    # Entries without a `_target_` key are skipped.
    logger: List[Logger] = []
    if "logger" in config:
        for _, lg_conf in config.logger.items():
            if "_target_" in lg_conf:
                log.info(f"Instantiating logger <{lg_conf._target_}>")
                logger.append(hydra.utils.instantiate(lg_conf))

    # Init lightning trainer
    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
    )

    # Send some parameters from config to all lightning loggers
    log.info("Logging hyperparameters!")
    log_hyperparameters(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )

    # Train the model
    if config.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule)

    # Get metric score for hyperparameter optimization
    # The score is read here, before the optional test run below.
    optimized_metric = config.get("optimized_metric")
    if optimized_metric and optimized_metric not in trainer.callback_metrics:
        raise Exception(
            "Metric for hyperparameter optimization not found! "
            "Make sure the `optimized_metric` in `hparams_search` config is correct!"
        )
    score = trainer.callback_metrics.get(optimized_metric)

    # Test the model
    if config.get("test"):
        # Ask for the "best" checkpoint unless training was skipped or this is a
        # fast_dev_run, in which case ckpt_path=None is passed instead.
        ckpt_path = "best"
        if not config.get("train") or config.trainer.get("fast_dev_run"):
            ckpt_path = None
        log.info("Starting testing!")
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    # Make sure everything closed properly
    log.info("Finalizing!")
    finish(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=logger,
    )

    # Print path to best checkpoint
    if not config.trainer.get("fast_dev_run") and config.get("train"):
        log.info(f"Best model ckpt at {trainer.checkpoint_callback.best_model_path}")

    # Return metric score for hyperparameter optimization
    return score



In [9]:
%pdb off

Automatic pdb calling has been turned OFF


In [10]:
# Run configuration written out as a plain dict. It uses Hydra-style `_target_`
# keys and OmegaConf `${...}` interpolations; several values are overridden
# after the dict is converted to a DictConfig further down.
config = {
    'task_name': 'train',
    'tags': ['dev'],
    # `train` / `test` switch the fit and test stages of train() on or off.
    'train': True,
    'test': True,
    'ckpt_path': None,
    'seed': 12345,
    # Data module settings.
    'datamodule': {
        '_target_': 'src.datamodules.classification_datamodule.ClassificationDataModule',
        'dataset_name': 'adult',
        'data_dir': '${paths.data_dir}',
        'batch_size': 64,
        'test_batch_size': 8,
        'train_val_test_split': [0.7, 0.1, 0.2],
        'num_workers': 0,
        'pin_memory': False
    },
    # Model, network, loss and calibrator settings.
    'model': {
        '_target_': 'src.models.classification_module.ClassificationLitModule',
        'lr': 0.001,
        'weight_decay': 0.0005,
        'net': {
            '_target_': 'src.models.components.simple_dense_net.SimpleDenseNet',
            # input_size / output_size are placeholders: train() overwrites them
            # with the sizes reported by the datamodule.
            'input_size': 1,
            'lin1_size': 256,
            'lin2_size': 256,
            'lin3_size': 256,
            'output_size': 1,
            'use_batchnorm': True
        },
        'criterion': {
            '_target_': 'src.metrics.losses.ClassificationMixedLoss',
            'loss_scalers': {'nll': 1, 'mmd': 0.2, 'sink': 6e-05},
            'operands': {'x': 'rbf', 'y': 'rbf'},
            'scalers': {'x': 1.0, 'y': 1.0},
            'bandwidths': {'x': 10.0, 'y': 0.01}
        },
        'calibrator': {'_target_': 'torchuq.transform.calibrate.HistogramBinning'},
        'kcal_kwargs': {
            'operands': {'x': 'rbf', 'y': 'rbf'},
            'scalers': {'x': 1.0, 'y': 1.0},
            'bandwidths': {'x': 10.0, 'y': 0.01}
        }
    },
    # Callbacks; train() instantiates every entry that has a `_target_`.
    'callbacks': {
        'model_checkpoint': {
            '_target_': 'pytorch_lightning.callbacks.ModelCheckpoint',
            'monitor': 'val/loss',
            'mode': 'min',
            'save_top_k': 1,
            'save_last': True,
            'verbose': False,
            'dirpath': 'checkpoints/',
            'filename': 'epoch_{epoch:03d}',
            'auto_insert_metric_name': False
        },
        'early_stopping': {
            '_target_': 'pytorch_lightning.callbacks.EarlyStopping',
            'monitor': 'val/loss',
            'mode': 'min',
            'patience': 30,
            'min_delta': 0
        },
        'model_summary': {
            '_target_': 'pytorch_lightning.callbacks.RichModelSummary',
            'max_depth': -1
        },
        'rich_progress_bar': {
            '_target_': 'pytorch_lightning.callbacks.RichProgressBar'
        }
    },
    # Loggers; train() instantiates every entry that has a `_target_`.
    'logger': {
        'wandb': {
            '_target_': 'pytorch_lightning.loggers.wandb.WandbLogger',
            'project': 'mmd_4_11',
            'name': '${name}',
            'save_dir': '.',
            'offline': False,
            'id': None,
            'log_model': False,
            'prefix': '',
            'job_type': 'train',
            'group': '',
            'tags': ['classification', '${name}', 'hparam', '${datamodule.dataset_name}']
        }
    },
    # Trainer arguments passed through hydra.utils.instantiate in train().
    'trainer': {
        '_target_': 'pytorch_lightning.Trainer',
        'min_epochs': 10,
        'max_epochs': 200,
        'gradient_clip_val': 0.5,
        'accelerator': 'cpu'

    },
    # NOTE(review): these values rely on the PROJECT_ROOT environment variable and
    # on `hydra:` interpolations — confirm they can be resolved when this cell is
    # run outside a Hydra application.
    'paths': {
        'root_dir': '${oc.env:PROJECT_ROOT}',
        'data_dir': '${paths.root_dir}/data/',
        'log_dir': '${paths.root_dir}/logs/',
        'output_dir': '${hydra:runtime.output_dir}',
        'work_dir': '${hydra:runtime.cwd}'
    },
    'extras': {
        'ignore_warnings': False,
        'enforce_tags': True,
        'print_config': True
    },
    'name': 'classification_mixed',
    'hparams_search': 'classification_mixed_optuna.yaml'
}

# Convert the dictionary to a DictConfig
config = OmegaConf.create(config)

# Modifying the configuration
# The criterion and calibrator `_target_` values are set to bare names because
# train() looks them up with globals(); `data_dir` is replaced by a literal
# relative path instead of the `${paths.data_dir}` interpolation.
config.datamodule._target_ = 'ClassificationDataModule'
config.datamodule.dataset_name = 'adult'
config.datamodule.data_dir = 'data'
config.model._target_ = 'ClassificationLitModule'
config.model.net._target_ = 'SimpleDenseNet'
config.model.criterion._target_ = 'ClassificationMixedLoss'
config.model.calibrator._target_ = 'TemperatureScaling'
config.logger.wandb.project = 'DM_Benchmark'

# Run the training pipeline with this configuration
train(config)


Seed set to 12345
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/local/scratch/a/ko120/miniconda3/envs/dm/lib/python3.12/site-packages/pytorch_lightning/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /local/scratch/a/ko120/.netrc


Output()

RecursionError: maximum recursion depth exceeded