In [1]:
import sys
sys.path.append('..')

In [42]:
# Utility Functions for Inference in a Jupyter Notebook

from typing import Literal, Tuple
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
import numpy as np
import uncertainty_baselines as ub
from tqdm import tqdm
from src.cifar.wide_resnet_factors import wide_resnet
from src.cifar.label_corrupted_dataset import make_label_corrupted_dataset

tfb = tfp.bijectors

In [31]:
def load_model_checkpoint(model, checkpoint_dir):
    """
    Restores the model weights from the newest checkpoint in a directory.

    The newest checkpoint is located with `tf.train.latest_checkpoint` and
    restored through a `tf.train.Checkpoint(model=model)`. A status line is
    printed in both the success and the failure case.

    Args:
        model (tf.keras.Model): The model to restore.
        checkpoint_dir (str): Directory containing the saved checkpoint.

    Returns:
        str | None: The path of the checkpoint that was restored, or None when
        no checkpoint was found. In the None case `model` is not modified, so
        callers should check the return value before trusting predictions.
    """
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if not latest_checkpoint:
        print(f"No checkpoint found in {checkpoint_dir}")
        return None

    checkpoint = tf.train.Checkpoint(model=model)
    # expect_partial(): the checkpoint may hold values that are not matched to
    # `model`; do not warn about the unused ones.
    checkpoint.restore(latest_checkpoint).expect_partial()
    print(f"Model weights restored from {latest_checkpoint}")
    return latest_checkpoint
   
# Load and Preprocess CIFAR-10/100 Dataset
def load_dataset(
    dataset_name: Literal['cifar10', 'cifar100'],
    batch_size=32, 
    data_dir=None,
    split=None,
):
    """
    Loads a CIFAR-10/100 split through `ub.datasets.get`.

    Args:
        dataset_name (str): The dataset to load, e.g., 'cifar10' or 'cifar100'.
        batch_size (int): Batch size passed to the builder's `load` call.
        data_dir (str): Directory for dataset storage (forwarded unchanged).
        split: Dataset split to load. Defaults to `tfds.Split.TEST` when left
            as None, which is the behaviour this function always had.

    Returns:
        Whatever `builder.load(batch_size=...)` returns — presumably a batched
        `tf.data.Dataset` whose elements are dicts with 'features' and
        'labels' keys (that is how `perform_inference` consumes it; confirm
        against the uncertainty_baselines version in use).
    """
    if split is None:
        split = tfds.Split.TEST

    builder = ub.datasets.get(
        dataset_name,
        data_dir=data_dir,
        split=split,
    )
    return builder.load(batch_size=batch_size)

# Build the Wide ResNet Model
def build_wide_resnet(input_shape=(32, 32, 3), num_classes=10, depth=28, width_multiplier=2, l2=0.0):
    """
    Instantiates the project's factor-capable Wide ResNet (`wide_resnet`).

    The architecture options that are not exposed as arguments are pinned
    here: `version=2`, `num_factors=1`, `no_scale=False`.

    Args:
        input_shape (tuple): Input shape of the data.
        num_classes (int): Number of output classes.
        depth (int): Depth of the ResNet.
        width_multiplier (int): Width multiplier for the ResNet.
        l2 (float): L2 regularization parameter.

    Returns:
        The object returned by `wide_resnet` (a model; it is not compiled or
        otherwise post-processed here).
    """
    # Options fixed for every model built through this helper.
    pinned_options = {
        'version': 2,
        'num_factors': 1,
        'no_scale': False,
    }
    caller_options = {
        'input_shape': input_shape,
        'depth': depth,
        'width_multiplier': width_multiplier,
        'num_classes': num_classes,
        'l2': l2,
    }
    return wide_resnet(**caller_options, **pinned_options)

def build_ub_wide_resnet(input_shape=(32, 32, 3), 
                         depth=28, 
                         width_multiplier=2, 
                         num_classes=10, 
                         l2=0.0, 
                         hps=None, 
                         seed=None):
    """
    Instantiates the stock Wide ResNet from `ub.models.wide_resnet`.

    Every argument is forwarded by keyword, unchanged.

    Args:
        input_shape (tuple): Shape of the input images.
        depth (int): Depth of the Wide ResNet.
        width_multiplier (int): Width multiplier for the ResNet.
        num_classes (int): Number of output classes.
        l2 (float): L2 regularization factor.
        hps (dict): Additional hyperparameters.
        seed (int): Random seed for model initialization.

    Returns:
        The object returned by `ub.models.wide_resnet`.
    """
    forwarded = dict(
        input_shape=input_shape,
        depth=depth,
        width_multiplier=width_multiplier,
        num_classes=num_classes,
        l2=l2,
        hps=hps,
        seed=seed,
    )
    return ub.models.wide_resnet(**forwarded)

def perform_inference(model, dataset, deterministic=False):
    """
    Runs the model over every batch of the dataset and stacks the results.

    Each batch must be a mapping with a 'features' entry (fed to the model
    with `training=False`) and a 'labels' entry. Model outputs and labels are
    converted with `.numpy()` per batch and concatenated along axis 0.

    Args:
        model (tf.keras.Model): The trained model. With `deterministic=False`
            it must return a `(loc, scale)` pair; with `deterministic=True` it
            must return a single tensor.
        dataset (tf.data.Dataset): The dataset to perform inference on.
        deterministic (bool): Selects which of the two model output
            conventions above is expected.

    Returns:
        `(locs, scales, labels)` when `deterministic` is False, otherwise
        `(locs, labels)`. Every element is a NumPy array covering the whole
        dataset.

    Raises:
        ValueError: If the dataset yields no batches.
    """
    # One accumulator per returned array; labels always go last.
    n_heads = 1 if deterministic else 2
    columns = [[] for _ in range(n_heads + 1)]

    for batch in tqdm(dataset):
        images = batch['features']
        labels = batch['labels']
        # Perform a forward pass
        outputs = model(images, training=False)
        if deterministic:
            heads = (outputs,)
        else:
            # Unpack explicitly so a model with the wrong output arity fails here.
            loc, scale = outputs
            heads = (loc, scale)
        for column, tensor in zip(columns, (*heads, labels)):
            column.append(tensor.numpy())

    if not columns[-1]:
        raise ValueError("dataset yielded no batches; nothing to run inference on")

    # The per-batch pieces are already NumPy arrays, so concatenate them
    # directly instead of round-tripping through tf.concat(...).numpy().
    return tuple(np.concatenate(column, axis=0) for column in columns)


# Utility to Display Results
def display_results(predictions, class_names=None):
    """
    Prints the first five prediction/label pairs, one line each.

    Args:
        predictions (List[Dict]): Records with a 'predicted' and a 'true' entry.
        class_names (List[str]): Optional lookup table; when given (and
            non-empty) both entries are used as indices into it and the names
            are printed instead of the raw values.
    """
    def _shown(value):
        # Fall back to the raw value when no class names were supplied.
        return class_names[value] if class_names else value

    for position, record in enumerate(predictions[:5], start=1):  # first 5 only
        predicted, actual = record['predicted'], record['true']
        print(f"Example {position}: Predicted={_shown(predicted)}, True={_shown(actual)}")

In [43]:
def evaluate_sgn(
    dataset_name: str,
    checkpoint_dir: str,
    num_classes: int,
    batch_size: int = 32,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Evaluates a two-output (loc, scale) checkpoint on the test split.

    Loads the dataset, builds the project Wide ResNet, restores the newest
    checkpoint from `checkpoint_dir`, and runs non-deterministic inference.

    Args:
        dataset_name: Dataset to evaluate on, e.g. 'cifar10' or 'cifar100'.
        checkpoint_dir: Directory holding the checkpoint to restore.
        num_classes: Number of classes in the dataset.
        batch_size: Batch size used while iterating the dataset.

    Returns:
        `(locs, scales, labels)` as produced by
        `perform_inference(..., deterministic=False)`.
    """
    dataset = load_dataset(dataset_name=dataset_name, batch_size=batch_size)

    # NOTE(review): the network is built with num_classes - 1 outputs —
    # presumably because the model predicts in a (K-1)-dimensional
    # log-ratio space; confirm against the training configuration.
    model = build_wide_resnet(num_classes=num_classes - 1)
    load_model_checkpoint(model, checkpoint_dir)

    locs, scales, labels = perform_inference(model, dataset, deterministic=False)
    return locs, scales, labels

def evaluate_ls(
    dataset_name: str,
    checkpoint_dir: str,
    num_classes: int,
    batch_size: int = 32,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Evaluates a single-output checkpoint on the test split.

    Loads the dataset, builds the `ub.models` Wide ResNet (seeded with 42),
    restores the newest checkpoint from `checkpoint_dir`, and runs
    deterministic inference.

    Args:
        dataset_name: Dataset to evaluate on, e.g. 'cifar10' or 'cifar100'.
        checkpoint_dir: Directory holding the checkpoint to restore.
        num_classes: Number of output classes of the model.
        batch_size: Batch size used while iterating the dataset.

    Returns:
        `(locs, labels)` as produced by
        `perform_inference(..., deterministic=True)`.
    """
    dataset = load_dataset(dataset_name=dataset_name, batch_size=batch_size)

    # The seed only affects initialisation; the weights come from the
    # checkpoint restored on the next line.
    model = build_ub_wide_resnet(num_classes=num_classes, seed=42)
    load_model_checkpoint(model, checkpoint_dir)

    locs, labels = perform_inference(model, dataset, deterministic=True)
    return locs, labels

In [None]:
# Run the single-output ("ls") model on the CIFAR-10 test split.
# NOTE(review): checkpoint_dir is an absolute, machine-specific path — point it
# at your own checkpoint directory before re-running this cell.
locs_ls, labels_ls = evaluate_ls(
    dataset_name='cifar10', 
    checkpoint_dir='/home/baumana1/work/data/sgn_results_wrong/cifar10ls/no_noise',
    num_classes=10,
)

  0%|          | 0/313 [00:00<?, ?it/s]

Model weights restored from /home/baumana1/work/data/sgn_results_wrong/cifar10ls/no_noise/checkpoint-7


 36%|███▋      | 114/313 [00:41<01:21,  2.45it/s]

In [None]:
# Run the two-output ("sgn") model on the CIFAR-10 test split.
# NOTE(review): checkpoint_dir is an absolute, machine-specific path. The
# recorded output below reports a restore from a .../cifar10ls/... checkpoint
# although this cell requests .../cifar10sgn/... — the output looks stale;
# re-run and confirm the SGN checkpoint is the one actually restored.
locs_sgn, scales_sgn, labels_sgn = evaluate_sgn(
    dataset_name='cifar10', 
    checkpoint_dir='/home/baumana1/work/data/sgn_results_wrong/cifar10sgn/no_noise',
    num_classes=10,
)

  0%|          | 0/313 [00:00<?, ?it/s]

Model weights restored from /home/baumana1/work/data/sgn_results_wrong/cifar10ls/no_noise/checkpoint-7


100%|██████████| 313/313 [01:32<00:00,  3.86it/s]


In [38]:
# Shape check on the LS outputs. The bare names `locs` / `labels` are not
# defined by any cell, so use the variables produced by `evaluate_ls`; the
# recorded (10000, 10) shape matches that 10-output model.
locs_ls.shape, labels_ls.shape

((10000, 10), (10000,))

In [39]:
# Predicted class = index of the largest model output per example. Uses the
# LS outputs from `evaluate_ls`, since a bare `locs` is not defined anywhere.
preds = np.argmax(locs_ls, axis=-1)

In [40]:
# Accuracy = fraction of predictions equal to the label. Uses `labels_ls`
# from `evaluate_ls`, since a bare `labels` is not defined anywhere.
acc = (preds == labels_ls).mean()

In [41]:
# Display the accuracy computed in the previous cell.
acc

0.9124

In [12]:
import tensorflow as tf

def clr_inv(p, axis=1):
    """Centred log-ratio coordinates of `p`: log(p) minus its mean along `axis`.

    Called "inverse" because it undoes `clr_forward` (a softmax) for inputs
    that sum to one along `axis`.

    Args:
        p: Tensor of strictly positive values; log(0) would give -inf.
        axis: Axis holding the parts of each composition. Defaults to 1, the
            axis this function always used.

    Returns:
        Tensor of the same shape as `p` whose entries sum to zero along `axis`.
    """
    z = tf.math.log(p)
    return z - tf.reduce_mean(z, axis=axis, keepdims=True)

def clr_forward(z, axis=1):
    """Map real-valued coordinates to compositions with a softmax over `axis`."""
    composition = tf.nn.softmax(z, axis=axis)
    return composition

def helmert_tf(n):
    """Build the (n-1) x n Helmert sub-matrix as a float32 tensor.

    Row k (k = 1 .. n-1) is k ones followed by -k and then zeros, scaled by
    1 / sqrt(k * (k + 1)).

    Only the returned rows are normalised. The discarded first row used to be
    divided by sqrt(0), which produced a NaN row before it was sliced away.

    Args:
        n: Number of columns (parts of the composition).

    Returns:
        Tensor of shape (n - 1, n).
    """
    steps = tf.range(1, n + 1, dtype=tf.float32)            # 1, 2, ..., n
    lower_ones = tf.linalg.band_part(tf.ones((n, n)), -1, 0)
    # Diagonal becomes 0, -1, ..., -(n-1); everything below it stays 1.
    H = tf.linalg.set_diag(lower_ones, 1 - steps)
    k = steps[:-1]                                          # 1, ..., n-1
    row_norms = tf.math.sqrt(k * (k + 1))
    return H[1:] / row_norms[:, tf.newaxis]


def ilr_forward(z, axis=-1):
    """Map ILR coordinates to compositions.

    `z` is multiplied by the Helmert sub-matrix built for one more part than
    `z` has in its last dimension, and the result is passed to `clr_forward`.

    Args:
        z: Tensor whose last dimension holds the ILR coordinates.
        axis: Axis handed to `clr_forward`. The matrix product always acts on
            the last axis, so values other than -1 (or the equivalent
            positive index) are presumably not meaningful — TODO confirm.

    Returns:
        The result of `clr_forward`, with one more entry along the last axis
        than `z`.
    """
    n_parts = tf.shape(z)[-1] + 1
    basis = helmert_tf(n_parts)
    clr_coords = z @ basis
    return clr_forward(clr_coords, axis=axis)


def ilr_inv(p):
    """Map compositions to ILR coordinates.

    Applies `clr_inv` to `p`, then multiplies by the transpose of the Helmert
    sub-matrix sized from the last dimension of `p`.

    Args:
        p: Tensor of compositions. `clr_inv` centres over axis 1 while the
            matrix product acts on the last axis, so `p` is presumably rank 2
            (batch, parts) — TODO confirm.

    Returns:
        Tensor with one fewer entry along the last axis than `p`.
    """
    clr_coords = clr_inv(p)
    n_parts = tf.shape(p)[-1]
    basis_t = tf.linalg.matrix_transpose(helmert_tf(n_parts))
    return clr_coords @ basis_t

# Round-trip check for the ILR transforms: two 3-part compositions (each row sums to 1).
test_input = tf.constant([[0.2, 0.3, 0.5], [0.1, 0.1, 0.8]], dtype=tf.float32)

# Map to ILR coordinates (one fewer column) and back; backtr_result should
# reproduce test_input up to float32 rounding.
gaussian_result = ilr_inv(test_input)
backtr_result = ilr_forward(gaussian_result)

print("gaussian_result:", gaussian_result.numpy())
print("backtr_result:", backtr_result.numpy())

gaussian_result: [[-2.8670713e-01 -5.8261782e-01]
 [ 1.1100184e-08 -1.6978569e+00]]
backtr_result: [[0.2 0.3 0.5]
 [0.1 0.1 0.8]]
