Adds type hints.
dantreiman committed Jun 22, 2022
1 parent 5872afe commit aaea9d8
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions ludwig/utils/calibration.py
@@ -19,6 +19,7 @@
 from dataclasses import dataclass
 from typing import List, Type, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 
@@ -86,7 +87,7 @@ def __init__(self, n_bins: int = 15):
         self.bin_lowers = bin_boundaries[:-1]
         self.bin_uppers = bin_boundaries[1:]
 
-    def forward(self, logits, one_hot_labels):
+    def forward(self, logits: torch.Tensor, one_hot_labels: torch.Tensor) -> torch.Tensor:
         softmaxes = nn.functional.softmax(logits, dim=1)
         confidences, predictions = torch.max(softmaxes, 1)
         labels = torch.argmax(one_hot_labels, 1)
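For reference, the middle of this forward method (elided by the hunk) is what accumulates expected calibration error over the bins set up in __init__. A minimal standalone sketch of that binned-ECE computation, following the standard Guo et al. formulation that the bin boundaries suggest — the function name and final accumulation details here are illustrative, not copied from the file:

    import torch

    def expected_calibration_error(logits: torch.Tensor, labels: torch.Tensor, n_bins: int = 15) -> torch.Tensor:
        """Binned ECE: |accuracy - confidence| per bin, weighted by bin occupancy."""
        softmaxes = torch.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(labels)
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        ece = torch.zeros(1)
        for lower, upper in zip(bin_boundaries[:-1], bin_boundaries[1:]):
            in_bin = confidences.gt(lower) & confidences.le(upper)
            prop_in_bin = in_bin.float().mean()  # fraction of samples falling in this bin
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
        return ece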
@@ -103,7 +104,7 @@ def forward(self, logits, one_hot_labels):
         return ece
 
 
-@dataclass()
+@dataclass
 class CalibrationResult:
     """Tracks results of probability calibration."""

@@ -115,7 +116,9 @@ class CalibrationResult:
 
 class CalibrationModule(nn.Module, ABC):
     @abstractmethod
-    def train_calibration(self, logits, labels) -> CalibrationResult:
+    def train_calibration(
+        self, logits: Union[torch.Tensor, np.ndarray], labels: Union[torch.Tensor, np.ndarray]
+    ) -> CalibrationResult:
         """Calibrate output probabilities using logits and labels from validation set."""
         return NotImplementedError()
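One review note, separate from the type-hint change: the abstract method returns NotImplementedError() rather than raising it, so a subclass that forgets to override train_calibration would receive an exception instance as a return value instead of an error. The conventional form is raise NotImplementedError; this commit leaves that line unchanged.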

@@ -129,7 +132,7 @@ class TemperatureScaling(CalibrationModule):
 
     Implementation inspired by https://github.com/gpleiss/temperature_scaling
 
-    Args:
+    Args:
         num_classes: The number of classes. Must be 2 if binary is True.
         binary: If binary is true, logits is expected to be a 1-dimensional array. If false, logits is a 2-dimensional
             array of shape (num_examples, num_classes).
@@ -142,7 +145,9 @@ def __init__(self, num_classes: int = 2, binary: bool = False):
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.temperature = nn.Parameter(torch.ones(1), requires_grad=False).to(self.device)
 
-    def train_calibration(self, logits, labels) -> CalibrationResult:
+    def train_calibration(
+        self, logits: Union[torch.Tensor, np.ndarray], labels: Union[torch.Tensor, np.ndarray]
+    ) -> CalibrationResult:
         logits = torch.as_tensor(logits, dtype=torch.float32, device=self.device)
         labels = torch.as_tensor(labels, dtype=torch.int64, device=self.device)
         one_hot_labels = nn.functional.one_hot(labels, self.num_classes).float()
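The widened Union[torch.Tensor, np.ndarray] hints document what the first two lines of the body already handle: torch.as_tensor accepts tensors and numpy arrays alike. A hedged usage sketch — the validation data here is randomly generated stand-in, not output of any Ludwig API:

    import numpy as np

    from ludwig.utils.calibration import TemperatureScaling

    calibration = TemperatureScaling(num_classes=4)

    # numpy inputs are now reflected in the signature; torch.as_tensor does the conversion.
    val_logits = np.random.randn(256, 4).astype(np.float32)  # stand-in validation logits
    val_labels = np.random.randint(0, 4, size=256)           # stand-in validation labels
    result = calibration.train_calibration(val_logits, val_labels)  # a CalibrationResult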
@@ -197,10 +202,10 @@ def eval():
             before_calibration_nll, before_calibration_ece, after_calibration_nll, after_calibration_ece
         )
 
-    def scale_logits(self, logits: torch.Tensor):
+    def scale_logits(self, logits: torch.Tensor) -> torch.Tensor:
         return torch.div(logits, self.temperature)
 
-    def forward(self, logits: torch.Tensor):
+    def forward(self, logits: torch.Tensor) -> torch.Tensor:
         """Converts logits to probabilities."""
         scaled_logits = self.scale_logits(logits)
         if self.binary:
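Since temperature scaling divides every logit by the same positive scalar, calibration changes confidences but never the argmax prediction. A small sketch of that effect with an assumed temperature of 1.5 (in the module, the value is learned on validation data by train_calibration):

    import torch

    logits = torch.tensor([[2.0, 1.0, 0.5]])
    temperature = 1.5  # assumed for illustration

    uncalibrated = torch.softmax(logits, dim=1)
    calibrated = torch.softmax(logits / temperature, dim=1)

    assert uncalibrated.argmax() == calibrated.argmax()  # predicted class unchanged
    print(uncalibrated.max().item(), calibrated.max().item())  # ~0.63 -> ~0.53: less overconfident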
@@ -235,12 +240,9 @@ def __init__(self, num_classes: int = 2, off_diagonal_l2: float = 0.01, mu: floa
         self.off_diagonal_l2 = off_diagonal_l2
         self.mu = off_diagonal_l2 if mu is None else mu
 
-    def _safe_ln(self, x):
-        """Safe natural log of x (tensor)"""
-        eps = torch.finfo(torch.float32).eps  # ~ 1e-16
-        return torch.log(torch.clamp(x, eps, 1 - eps))
-
-    def train_calibration(self, logits, labels) -> CalibrationResult:
+    def train_calibration(
+        self, logits: Union[torch.Tensor, np.ndarray], labels: Union[torch.Tensor, np.ndarray]
+    ) -> CalibrationResult:
         logits = torch.as_tensor(logits, dtype=torch.float32, device=self.device)
         labels = torch.as_tensor(labels, dtype=torch.int64, device=self.device)
         one_hot_labels = nn.functional.one_hot(labels, self.num_classes).float()
@@ -288,7 +290,7 @@ def eval():
             before_calibration_nll, before_calibration_ece, after_calibration_nll, after_calibration_ece
         )
 
-    def regularization_terms(self):
+    def regularization_terms(self) -> torch.Tensor:
         """Off-Diagonal and Intercept Regularisation (ODIR).
 
         Described in "Beyond temperature scaling: Obtaining well-calibrated multiclass probabilities with Dirichlet
@@ -300,9 +302,9 @@ def regularization_terms(self):
         bias_vector_loss = self.mu * torch.linalg.vector_norm(self.b, 2)
         return bias_vector_loss + weight_matrix_loss
 
-    def scale_logits(self, logits: torch.Tensor):
+    def scale_logits(self, logits: torch.Tensor) -> torch.Tensor:
         return torch.matmul(self.w, logits.T).T + self.b
 
-    def forward(self, logits: torch.Tensor):
+    def forward(self, logits: torch.Tensor) -> torch.Tensor:
         """Converts logits to probabilities."""
         return torch.softmax(self.scale_logits(logits), -1)
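Where TemperatureScaling learns one scalar, this class learns a full affine map over logits (matrix scaling), and the ODIR terms keep the weight matrix close to diagonal and the intercept small. A standalone sketch of both pieces under the shapes the code implies — w of shape (num_classes, num_classes) and b of shape (num_classes,); the off-diagonal masking follows the ODIR paper, since the weight_matrix_loss line itself is elided by the hunk:

    import torch

    num_classes, off_diagonal_l2, mu = 4, 0.01, 0.01
    w = torch.eye(num_classes) + 0.05 * torch.randn(num_classes, num_classes)
    b = 0.1 * torch.randn(num_classes)
    logits = torch.randn(8, num_classes)

    # Affine (matrix) scaling, as in scale_logits above.
    scaled = torch.matmul(w, logits.T).T + b
    probabilities = torch.softmax(scaled, -1)

    # ODIR penalty: L2 on off-diagonal weights plus L2 on the intercept.
    off_diagonal = w * (1 - torch.eye(num_classes))  # zero the diagonal
    weight_matrix_loss = off_diagonal_l2 * torch.linalg.matrix_norm(off_diagonal)  # Frobenius norm
    bias_vector_loss = mu * torch.linalg.vector_norm(b, 2)
    odir_penalty = weight_matrix_loss + bias_vector_loss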
