In [1]:
import fastbook

# fastbook's notebook setup helper; presumably configures fastai/plotting for the book — TODO confirm
fastbook.setup_book()


In [2]:
import torch

# Seed every device's RNG so parameter initialization (and thus training) is
# reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

gpu = torch.device("mps")
cpu = torch.device("cpu")
# Prefer Apple's MPS backend, then CUDA, and fall back to the CPU so the
# notebook still runs on machines without Apple silicon.
if torch.backends.mps.is_available():
    device = gpu
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = cpu


## Gather MNIST data

### Download dataset

In [None]:
from pathlib import Path

from fastai.data.external import URLs, untar_data

path = untar_data(URLs.MNIST)  # dataset root; get_digit_file_paths reads path/<split>/<digit>
print(f"MNIST data downloaded to {path}")
# Path.BASE_PATH = path  # disabled; presumably makes fastai display paths relative to `path`


### Define how to access images from training and testing datasets

In [4]:
from enum import Enum
from typing import List, Literal

Digit = Literal[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


class DataSplit(Enum):
    """The two MNIST splits; each value is the name of its folder under `path`."""

    TRAINING = "training"
    TESTING = "testing"


def get_digit_file_paths(digit: Digit, datasplit: DataSplit) -> List[Path]:
    """Returns the image files for `digit` in `datasplit`, sorted by path.

    Files are read from `path / <datasplit.value> / <digit>`. This uses the
    standard library (`Path.iterdir` + `sorted`) instead of fastcore's
    `Path.ls()`, which only exists after fastai has patched `pathlib.Path`.
    """
    digit_dir = path / datasplit.value / str(digit)
    return sorted(digit_dir.iterdir())


#### Peek into the images before turning them into tensors

In [None]:
from PIL import Image

# `.shape` is not part of PIL's Image API (it only works thanks to fastai's
# patching); report (height, width) from PIL's own attributes instead.
im4 = Image.open(get_digit_file_paths(4, DataSplit.TRAINING)[0])
print(f"Image shape is {(im4.height, im4.width)}")
im4


#### Images as tensors

In [6]:
from functools import cache

from fastai.torch_core import tensor


@cache
def get_digit_tensors(digit: Digit, datasplit: DataSplit) -> List[torch.Tensor]:
    """Loads every image of `digit` in `datasplit` as a tensor on `device`.

    Memoized with functools.cache: each (digit, datasplit) pair is read from
    disk once per kernel session and the same list is returned on later calls,
    so callers should not mutate it.
    """
    images = []
    # `file_path` (not `path`) so the global dataset root is not shadowed.
    for file_path in get_digit_file_paths(digit, datasplit):
        # The context manager closes the file once the pixels have been copied.
        with Image.open(file_path) as im:
            images.append(tensor(im, device=device))
    return images


In [None]:
from fastai.torch_core import show_image

# Display the second (index 1) training image of a 3.
show_image(get_digit_tensors(3, DataSplit.TRAINING)[1])


## Utilities

### Ensuring tensor shape

In [8]:
from typing import Final, Tuple

from torch import Tensor

IMAGE_SHAPE: Final[Tuple[int, int]] = (28, 28)


class EmptyInputError(Exception):
    """Raised when a function receives an empty collection it cannot work with."""


class TensorShapeError(Exception):
    """Raised when a tensor does not have the expected shape.

    The offending tensor is kept on the exception for post-mortem inspection.
    """

    def __init__(self, message: str, offending_tensor: Tensor):
        self.offending_tensor = offending_tensor
        super().__init__(message)


def ensure_shape(tensor: Tensor, expected: Tuple[int, ...]) -> Tensor:
    """Validates that `tensor` has the `expected` shape and returns it unchanged.

    Args:
        tensor: The tensor to check.
        expected: One entry per dimension; -1 matches any size in that dimension.

    Returns:
        The same tensor object, so calls can be chained.

    Raises:
        TensorShapeError: If the rank differs, or any dimension not marked -1
            has a different size.
    """
    if len(tensor.shape) != len(expected):
        raise TensorShapeError(
            "Tensor shape and expected shape must have the "
            "same number of dimensions. "
            f"Got {len(tensor.shape)} and {len(expected)}.",
            tensor,
        )
    # Distinct loop names keep `expected` bound to the whole tuple, so the
    # error message reports the full expected shape, not a single dimension.
    for actual_dim, expected_dim in zip(tensor.shape, expected):
        if expected_dim != -1 and actual_dim != expected_dim:
            raise TensorShapeError(
                f"Expected shape {expected}, got {tensor.shape}", tensor
            )
    return tensor


### Adjusting pixel values to 0...1 floats

In [9]:
def normalize_pixel_data(input: Tensor) -> Tensor:
    """Rescales 8-bit pixel intensities (0..255) to floats in the 0..1 range.

    The input is converted to float32 first, then divided by the maximum
    8-bit value.
    """
    as_float = input.float()
    return as_float / 255


### Stacking tensor images

In [10]:
from typing import Sequence


def stack_image_tensors(image_tensors: Sequence[Tensor]) -> Tensor:
    """Stacks a sequence of MNIST digit image tensors into a single tensor.

    Args:
        image_tensors: Non-empty sequence of IMAGE_SHAPE tensors.

    Returns:
        A tensor whose first dimension indexes the images (in input order),
        followed by the IMAGE_SHAPE dimensions.

    Raises:
        EmptyInputError: If `image_tensors` is empty.
        TensorShapeError: If any image does not have IMAGE_SHAPE.
    """
    if len(image_tensors) == 0:
        raise EmptyInputError(
            "The input sequence of tensors must have at least one element."
        )

    for t in image_tensors:
        ensure_shape(t, IMAGE_SHAPE)

    # A thin wrapper over torch.stack, but it names the operation and validates
    # the input shapes first.
    stacked = torch.stack(image_tensors)

    # Should never raise; documents the expected output shape.
    ensure_shape(stacked, (-1,) + IMAGE_SHAPE)
    return stacked


@cache
def get_stacked_preprocessed_digits(digit: Digit, datasplit: DataSplit) -> Tensor:
    """Returns every image of `digit` in `datasplit` as one stacked tensor with
    pixel values rescaled to 0..1.

    Memoized, so the progress line below is printed only on the first call for
    each (digit, datasplit) pair.
    """
    print(
        f"get_stacked_preprocessed_digits is running for digit: {digit}, split: {datasplit}"
    )
    raw_images = get_digit_tensors(digit, datasplit)
    stacked = stack_image_tensors(raw_images)
    return normalize_pixel_data(stacked)


In [11]:
# | test
def test_stacked_threes():
    # The training split holds 6131 images of the digit 3. Stacking them must
    # give a rank-3 tensor: one 28 x 28 image per entry of the first axis.
    threes = get_digit_tensors(3, DataSplit.TRAINING)
    stacked = stack_image_tensors(threes)
    assert stacked.shape == (6131,) + IMAGE_SHAPE


In [None]:
# Run the shape test, then display the second stacked 3.
test_stacked_threes()
show_image(stack_image_tensors(get_digit_tensors(3, DataSplit.TRAINING))[1])


## Baseline: Pixel Similarity

### Calculating the mean digit images

In [13]:
def mean_image(stacked_images: Tensor) -> Tensor:
    """Builds the "ideal" image: the per-pixel mean over the first (0)
    dimension of `stacked_images`, which must be a stack of IMAGE_SHAPE images."""
    ensure_shape(stacked_images, (-1,) + IMAGE_SHAPE)
    # torch.mean only accepts floating-point/complex input; the explicit dtype
    # makes integer stacks work and is harmless for normalized float32 stacks.
    averaged = torch.mean(stacked_images, 0, dtype=torch.float32)
    return ensure_shape(averaged, IMAGE_SHAPE)


@cache
def mean_digit_image(digit: Digit, datasplit: DataSplit) -> Tensor:
    """Ideal (per-pixel mean) image of `digit` in `datasplit`; memoized."""
    stacked = get_stacked_preprocessed_digits(digit, datasplit)
    return mean_image(stacked)


def mean_digit_images(datasplit: DataSplit) -> Tensor:
    """Ideal images of all ten digits for `datasplit`, stacked so that index d
    holds the ideal image of digit d."""
    ideals = [mean_digit_image(d, datasplit) for d in range(10)]
    return torch.stack(ideals)


In [None]:
# Compute the stacked ideals once instead of once per loop iteration.
_ideal_digits = mean_digit_images(DataSplit.TRAINING)
for digit, ideal in enumerate(_ideal_digits):
    show_image(ideal, title=str(digit))


### Calculating how close a given digit image is to the "ideal" (mean) image for that digit

* Mean absolute difference, or L1 norm, is the mean of the absolute value of the differences between pixels.
* Root mean square error, RMSE, or L2 norm, takes the mean of the squares of the differences and then the square root of that mean.
  * this one penalizes larger mistakes more and smaller mistakes less

In [15]:
def l1_norm(candidate: Tensor, ideal: Tensor) -> Tensor:
    """Mean absolute error between two IMAGE_SHAPE images.

    Takes the pixel-wise difference `candidate - ideal`, drops its sign and
    averages over every pixel. Returns a rank-0 tensor.
    """
    for image in (candidate, ideal):
        ensure_shape(image, IMAGE_SHAPE)
    abs_diff = torch.abs(candidate - ideal)
    return ensure_shape(abs_diff.mean(), ())


def l2_norm(candidate: Tensor, ideal: Tensor) -> Tensor:
    """Root-mean-square error between two IMAGE_SHAPE images.

    Squares the pixel-wise difference, averages it to a single value and takes
    the square root (diff -> square -> mean -> sqrt). Returns a rank-0 tensor.
    """
    for image in (candidate, ideal):
        ensure_shape(image, IMAGE_SHAPE)
    squared_diff = (candidate - ideal).pow(2)
    return ensure_shape(squared_diff.mean().sqrt(), ())


In [None]:
# L1 and L2 distances between the first training 3 and the mean 3.
[
    norm_function(
        get_stacked_preprocessed_digits(3, DataSplit.TRAINING)[0],
        mean_digit_image(3, DataSplit.TRAINING),
    )
    for norm_function in [l1_norm, l2_norm]
]


In [17]:
# TODO: Consider removing check_shape
def mnist_distance(
    candidate_s: Tensor, ideal: Tensor, check_shape: bool = True
) -> Tensor:
    """Mean absolute pixel difference between candidate image(s) and an ideal.

    `candidate_s` may be a single IMAGE_SHAPE image, which yields a rank-0
    tensor, or a stack of images, which yields a rank-1 tensor with one
    distance per image (the ideal is broadcast against every candidate).
    With check_shape=False any broadcastable shapes are accepted and no shape
    validation happens.
    """
    is_single_image = candidate_s.ndim == 2
    if check_shape:
        ensure_shape(
            candidate_s, IMAGE_SHAPE if is_single_image else (-1,) + IMAGE_SHAPE
        )

    # Averaging over the last two axes (-1 = columns, -2 = rows) collapses each
    # difference image into a single distance, whatever leading axes exist.
    distance_s = (candidate_s - ideal).abs().mean((-1, -2))

    if check_shape:
        expected = () if is_single_image else (candidate_s.shape[0],)
        ensure_shape(distance_s, expected)
    return distance_s


#### Distance between a three and the ideal three

In [None]:
# Distance (a rank-0 tensor) between the first training 3 and the mean 3.
_distance = mnist_distance(
    get_stacked_preprocessed_digits(3, DataSplit.TRAINING)[0],
    mean_digit_image(3, DataSplit.TRAINING),
)
print(f"shape: {_distance.shape}")
_distance


#### Distances between each of the threes and the ideal three

In [None]:
# One distance per training 3 (a rank-1 tensor): the mean 3 is broadcast
# against every image in the stack.
_distances = mnist_distance(
    get_stacked_preprocessed_digits(3, DataSplit.TRAINING),
    mean_digit_image(3, DataSplit.TRAINING),
)
print(f"shape: {_distances.shape}")
_distances


#### Distance between a concrete three and each of the ideal digits

In [None]:
# One distance per ideal digit. check_shape=False because, for a single
# candidate, mnist_distance's checks expect a scalar result, while broadcasting
# against the 10 stacked ideals yields 10 distances.
_distances = mnist_distance(
    get_stacked_preprocessed_digits(3, DataSplit.TRAINING)[0],
    mean_digit_images(DataSplit.TRAINING),
    check_shape=False,
)
print(f"shape: {_distances.shape}")
_distances


#### Distance between each of the threes and each of the ideal digits

In [21]:
# I'm going to use this later so I'll wrap it in a function
def calculate_distances(
    digit: Digit, from_datasplit: DataSplit, to_ideal_digits_datasplit: DataSplit
) -> Tensor:
    """Distances from every image of `digit` in `from_datasplit` to each of the
    ten ideal (mean) digit images computed from `to_ideal_digits_datasplit`.

    Returns:
        A [N, 10] tensor, N being the number of `digit` images in
        `from_datasplit`; entry [i, d] is the mean absolute pixel difference
        between image i and the ideal image of digit d.
    """
    expected_digit_count = len(get_digit_file_paths(digit, from_datasplit))
    # [N, 1, 28, 28]: the singleton axis broadcasts against the 10 ideals.
    shaped_concrete_digits = get_stacked_preprocessed_digits(
        digit, from_datasplit
    ).unsqueeze(1)
    ensure_shape(shaped_concrete_digits, (expected_digit_count, 1) + IMAGE_SHAPE)
    # [1, 10, 28, 28]: ideals come from the split the caller asked for.
    shaped_means = mean_digit_images(to_ideal_digits_datasplit).unsqueeze(0)
    ensure_shape(shaped_means, (1, 10) + IMAGE_SHAPE)
    distances = mnist_distance(shaped_concrete_digits, shaped_means, check_shape=False)
    ensure_shape(distances, (expected_digit_count, 10))
    return distances


In [None]:
# [N, 10] distances from every training 3 to every ideal digit.
_distances = calculate_distances(
    digit=3,
    from_datasplit=DataSplit.TRAINING,
    to_ideal_digits_datasplit=DataSplit.TRAINING,
)
print(f"shape: {_distances.shape}")
_distances


#### Distance between each concrete digit and each of the ideal digits

In [None]:
# Will use a loop instead of broadcasting because the training sets of different
# digits have different lengths. Alternatively I could cap all of them to the
# same length.
# Result: a list of ten [N_digit, 10] distance tensors, one per true digit.
[
    calculate_distances(
        digit,
        from_datasplit=DataSplit.TRAINING,
        to_ideal_digits_datasplit=DataSplit.TRAINING,
    )
    for digit in range(0, 10)
]


In [24]:
# given a concrete digit, which ideal digit is it closer to?
def match(image: Tensor, ideals: Tensor) -> Digit:
    """Returns the digit (0..9) whose ideal image is closest to `image`."""
    ensure_shape(image, IMAGE_SHAPE)
    ensure_shape(ideals, (10,) + IMAGE_SHAPE)
    # Broadcasting the single image against all ten ideals gives ten distances.
    distances = ensure_shape(mnist_distance(image, ideals, check_shape=False), (10,))
    # On ties argmin reports the first minimal index.
    closest = torch.argmin(distances).item()
    assert closest in range(0, 10)
    return closest


In [None]:
# Which ideal digit is the fourth (index 3) training 8 closest to?
match(
    image=get_stacked_preprocessed_digits(8, DataSplit.TRAINING)[3],
    ideals=mean_digit_images(DataSplit.TRAINING),
)


In [26]:
# given a tensor of stacked digit images, which ideal digit is each closer to?
def matches_for_images(images: Tensor, ideals: Tensor) -> Tensor:
    """For each image in an [N, 28, 28] stack, returns the digit (0..9) whose
    ideal image is closest, as a length-N tensor."""
    ensure_shape(images, (-1,) + IMAGE_SHAPE)
    n_images = images.shape[0]
    images_column = ensure_shape(images.unsqueeze(1), (n_images, 1) + IMAGE_SHAPE)
    ensure_shape(ideals, (10,) + IMAGE_SHAPE)
    ideals_row = ensure_shape(ideals.unsqueeze(0), (1, 10) + IMAGE_SHAPE)
    # [N, 1, 28, 28] against [1, 10, 28, 28] broadcasts to N x 10 image pairs.
    distances = ensure_shape(
        mnist_distance(images_column, ideals_row, check_shape=False), (n_images, 10)
    )
    return ensure_shape(torch.argmin(distances, dim=1), (n_images,))


def matches_for_digit(
    digit: Digit,
    from_datasplit: DataSplit,
    ideals_datasplit: DataSplit,
) -> Tensor:
    """Classifies every image of `digit` in `from_datasplit` against the ideal
    digits of `ideals_datasplit`; returns the predicted digit per image."""
    # Images are fetched before ideals so the memoized loaders run (and print)
    # in the same order as before.
    images = get_stacked_preprocessed_digits(digit, from_datasplit)
    ideals = mean_digit_images(ideals_datasplit)
    return matches_for_images(images=images, ideals=ideals)


# Predicted digit for each training 3 (the enum is DataSplit; `DataSet` does
# not exist and would raise NameError on Restart & Run All).
matches_for_digit(
    3,
    from_datasplit=DataSplit.TRAINING,
    ideals_datasplit=DataSplit.TRAINING,
)

In [None]:
# Boolean mask: True where a training 3 is classified as a 3.
_matches_3s = matches_for_images(
    images=get_stacked_preprocessed_digits(3, DataSplit.TRAINING),
    ideals=mean_digit_images(DataSplit.TRAINING),
)
_correct_3s = _matches_3s == 3
_correct_3s


In [None]:
# Fraction of training 3s classified correctly.
_correct_3s.float().mean()


In [None]:
# Predicted digits for every training image, one tensor per true digit.
[
    matches_for_digit(digit, DataSplit.TRAINING, DataSplit.TRAINING)
    for digit in range(0, 10)
]


In [30]:
def baseline_accuracy(
    digit: Digit,
    from_datasplit: DataSplit,
    ideals_datasplit: DataSplit,
) -> Tensor:
    """Fraction of `digit` images in `from_datasplit` that the pixel-similarity
    baseline classifies correctly, as a rank-0 float tensor."""
    predicted = matches_for_digit(digit, from_datasplit, ideals_datasplit)
    accuracy = (predicted == digit).float().mean()
    return ensure_shape(accuracy, ())


def baseline_accuracies(
    from_datasplit: DataSplit,
    ideals_datasplit: DataSplit,
) -> List[float]:
    """Baseline accuracy for each digit 0..9, as Python floats indexed by digit."""
    # Digits have different sample counts, so accuracies are computed per digit
    # in a loop rather than with a single broadcast operation.
    accuracies = []
    for digit in range(10):
        accuracies.append(
            baseline_accuracy(digit, from_datasplit, ideals_datasplit).item()
        )
    return accuracies


In [None]:
# Loop variable is `accuracy`: naming it after the function
# `calculate_accuracy_plain_labels` (defined later) would clobber that function
# whenever this cell is re-run.
for digit, accuracy in enumerate(
    baseline_accuracies(DataSplit.TRAINING, DataSplit.TRAINING)
):
    print(f"{digit}: {accuracy:.3f}")


## As a linear model

In [None]:
from fastbook import gv

# Diagram of the training loop: init -> predict -> loss -> gradient -> step,
# looping back to predict until stopping.
gv("""
init->predict->loss->gradient->step->stop
step->predict[label=repeat]
""")


In [None]:
from functools import reduce


def labeled_data(datasplit: DataSplit) -> Tuple[Tensor, Tensor]:
    """Builds a flattened, labeled dataset from all ten digits of `datasplit`.

    Returns:
        (x, y): x holds every normalized image of digits 0..9 (grouped in digit
        order) flattened to rows of IMAGE_SHAPE[0] * IMAGE_SHAPE[1] pixels;
        y is an int64 vector on `device` with the digit label of each row.
    """
    all_digits = range(0, 10)
    stacked_per_digit = [
        get_stacked_preprocessed_digits(digit, datasplit) for digit in all_digits
    ]

    x = torch.cat(stacked_per_digit).view(-1, IMAGE_SHAPE[0] * IMAGE_SHAPE[1])
    # One label per stacked image. Taking the count from each stacked tensor
    # keeps x and y aligned by construction (no second directory listing), and
    # avoids building a long Python list through repeated list concatenation.
    y = torch.cat(
        [
            torch.full((len(stacked),), digit, dtype=torch.int64, device=device)
            for digit, stacked in zip(all_digits, stacked_per_digit)
        ]
    )
    ensure_shape(y, (x.shape[0],))
    return (x, y)


# Flattened images (x) and digit labels (y); the TESTING split is used as the
# validation set.
train_x, train_y = labeled_data(DataSplit.TRAINING)
valid_x, valid_y = labeled_data(DataSplit.TESTING)
print(f"Training: {train_x.shape, train_y.shape}")
print(f"Validation - testing: {valid_x.shape, valid_y.shape}")


A Dataset in PyTorch is required to return a tuple of (x,y) when indexed. 
It looks like this: $[(x_1,y_1), (x_2,y_2), ... (x_n,y_n)]$

In [34]:
# (image, label) pairs; same contents as `train_dataset` built in the DataLoader section.
training_dset = list(zip(train_x, train_y))


In [None]:
# A single (image, label) example despite the `batch_` names: a flattened image
# vector and a rank-0 label.
batch_x, batch_y = training_dset[0]
batch_x.shape, batch_y.shape


### Functions to initialize parameters

In [36]:
from typing import Union

from torch import SymInt


def init_params(
    shape: Union[int, SymInt, Tuple[int, ...]], std: float = 1.0
) -> Tensor:
    """Creates a trainable tensor of the given `shape` on `device`.

    Values are drawn from a standard normal distribution and scaled by `std`.

    Args:
        shape: An int for a vector (e.g. biases) or a tuple of ints for a
            higher-rank tensor (e.g. the (in_features, out_features) weights).
        std: Standard deviation of the initial values.

    Returns:
        A leaf tensor with requires_grad=True.
    """
    # Scale first, then flag for grad, so the result is a leaf tensor that
    # autograd accumulates .grad into.
    return (torch.randn(shape, device=device) * std).requires_grad_()


# TODO: I could add a different kind of param initializer here such as Kaiming


In the book the model has to differentiate between 2 digits, so the model has a single output. 0.0 is used as some kind of threshold: any values greater than 0 represent a prediction for one digit, the others represent the other digit. For this 2-digit model, `train_y` is a vector of 1s for one digit and 0s for the other digit. (And then unsqueezed to a second dimension of size 1).

0 is at the center of the distribution of raw model outputs. The book then squashes the outputs into the range 0 to 1 with **sigmoid**, which moves the threshold to 0.5.

I want to switch the model from having one output $y_1$ (the probability of the input being digit a, with $y_2=1-y_1$) to 10 outputs. Originally I was thinking of 9, calculating the 10th as $1-p(\text{others})$. It makes more sense to apply Softmax over 10 outputs, since the probabilities of the 10 digits should add up to 1.

In [37]:
# One output (logit) per digit class.
count_outputs = 10


In [38]:
# Standalone parameters for the manual forward-pass demos below:
# [28*28, count_outputs] weights and [count_outputs] biases.
weights = init_params((IMAGE_SHAPE[0] * IMAGE_SHAPE[1], count_outputs))
biases = init_params(count_outputs)


In [None]:
# Rows of train_x (flattened images) must have as many columns as weights has rows.
print(f"Shape of the full training data: {train_x.shape}")
print(f"Shape of the weights: {weights.shape}")


### Run the model "manually"

To better understand the forward computation.

How I think the `batch @ weights + bias` equation operates in a linear model with several outputs: it has 10 sets of weights and 10 biases (one per output). So I'm imagining it's just that equation once per output probability, with the same input data, independently and in parallel.

#### Run it once for a single image

In [None]:
first_image = train_x[0]
print(f"Shape of the first image: {first_image.shape}")

# the T is transpose, flipping weights from [784, count_outputs] to [count_outputs, 784]
transposed_weights = weights.T
print(f"Shape of transposed weights: {transposed_weights.shape}")

# The image broadcasts against every row: one weighted copy per output.
element_wise_product = first_image * transposed_weights
print(f"Shape of element-wise multiplication: {element_wise_product.shape}")

# Sum over the pixel axis only (dim=1), giving one weighted sum per output.
# A bare .sum() would add up all outputs' products into one number, giving
# every output the same value.
model_result = element_wise_product.sum(dim=1) + biases
print(f"Shape of model result: {model_result.shape}")
print(f"Model result: {model_result}")

# The manual computation must match the matrix-multiply form.
assert torch.allclose(
    model_result, first_image @ weights + biases, rtol=1e-4, atol=1e-4
)


#### Run it once for a batch of data using broadcasting (no loops!)

In [None]:
# The whole dataset as one batch -- it still feels instantaneous. Named `batch`
# rather than `input` so the builtin input() is not shadowed for every later cell.
batch = train_x

print(f"Shape of batch: {batch.shape}")
print(f"Shape of weights: {weights.shape}")
print(f"Shape of bias: {biases.shape}")


# Here is one of the two magic equations (the other one is the activation
# function), equivalent to doing the run above but for each image in the batch
# (in this case the full training set): [N, 784] @ [784, 10] + [10] -> [N, 10].
model_result = batch @ weights + biases

print(f"Shape of results: {model_result.shape}")


### Model wrapped in a class

In [42]:
# Wrapped in a class, together with its weights and biases, like models
# in the pytorch nn package:

from abc import ABC, abstractmethod
from itertools import chain
from typing import Callable


class Module(ABC):  # e.g. a model or layer
    """Minimal stand-in for a PyTorch nn module: a callable mapping an input
    tensor to an output tensor that also exposes its trainable parameters."""

    @abstractmethod
    def run(self, input: Tensor) -> Tensor:
        """Forward computation; implemented by subclasses."""

    @abstractmethod
    def params(self) -> Sequence[Tensor]:
        """The trainable tensors owned by this module."""

    def __call__(self, *args, **kwargs):
        # Calling a module is shorthand for running it.
        return self.run(*args, **kwargs)

    def verify_input(self, input: Tensor, in_features: int) -> Tensor:
        """Checks `input` is a [batch, in_features] matrix; returns it unchanged."""
        return ensure_shape(input, (-1, in_features))

    def verify_output(self, output: Tensor, out_features: int) -> Tensor:
        """Checks `output` is a [batch, out_features] matrix; returns it unchanged."""
        return ensure_shape(output, (-1, out_features))


class Linear(Module):
    """Fully connected layer computing `input @ weights + biases`.

    Args:
        in_features: Number of columns expected in the input matrix.
        out_features: Number of columns (unnormalized logits) produced.
        init_params_function: Takes a shape and returns a freshly initialized
            tensor of that shape. It is called with a tuple for the weights and
            with an int for the biases.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        init_params_function: Callable[
            [Union[int, SymInt, Tuple[int, ...]]], Tensor
        ],  # takes shape, returns tensor
    ):
        self.in_features = in_features
        self.out_features = out_features
        self.weights = init_params_function((in_features, out_features))
        self.biases = init_params_function(out_features)

    def run(self, input: Tensor) -> Tensor:
        """Maps a [batch, in_features] input to [batch, out_features] logits."""
        self.verify_input(input, self.in_features)
        # The biases broadcast across the batch dimension.
        linear = input @ self.weights + self.biases
        return self.verify_output(
            linear, self.out_features
        )  # return unnormalized logits

    def params(self) -> Sequence[Tensor]:
        return [self.weights, self.biases]


class ReLU(Module):
    """Rectified linear unit: replaces negative values with 0, elementwise."""

    def run(self, input: Tensor) -> Tensor:
        # torch.relu runs on whatever device/dtype `input` lives on. It does not
        # allocate a fresh 0-d tensor on the global `device` each call, which
        # would break for inputs living on a different device.
        return torch.relu(input)

    def params(self) -> Sequence[Tensor]:
        # No trainable parameters.
        return []


class Sequential(Module):
    """Runs submodules in order, feeding each one's output into the next.

    Args:
        submodules: The modules to chain; must not be empty.
    """

    def __init__(self, submodules: Sequence[Module]):
        self.submodules = submodules

    def run(self, input: Tensor) -> Tensor:
        """Returns the output of the last submodule."""
        assert len(self.submodules) > 0
        activations = input
        for submodule in self.submodules:
            # Each output is passed straight on. A None output is not skipped
            # (which would silently re-feed the previous input); it surfaces as
            # an error in the next submodule.
            activations = submodule(activations)
        return activations

    def params(self) -> Sequence[Tensor]:
        """All submodules' params, flattened in submodule order."""
        return list(chain.from_iterable(sub.params() for sub in self.submodules))


This is a linear model to test some functions below:

In [43]:
# Untrained single-layer model used to smoke-test the helpers below.
_linear_model = Linear(
    in_features=IMAGE_SHAPE[0] * IMAGE_SHAPE[1],
    out_features=count_outputs,
    init_params_function=init_params,
)


In [None]:
# Smoke test: one forward pass over the validation set gives one row of
# `count_outputs` logits per image (shown instead of the huge tensor itself).
_linear_model(valid_x).shape


### Normalizers of logits to predictions

In [45]:
def normalize_softmax(logits: Tensor) -> Tensor:
    """Normalizes logits applying softmax over dim=1 (the class dimension for
    [batch, classes] logits).

    Returns:
        Tensor: A tensor of predictions with the same dimensions as the input.
        Each row holds values in [0, 1] that sum to 1: the probability of each
        class match.
    """
    return torch.softmax(logits, dim=1)


def normalize_log_softmax(logits: Tensor) -> Tensor:
    """Normalizes logits applying log-softmax over dim=1.

    Returns:
        Tensor: Same dimensions as the input, holding log-probabilities: every
        value is <= 0 and exp() of each row sums to 1. The values themselves do
        NOT sum to 1. NOTE(review): presumably unsuitable for the `1 - p` losses
        in this notebook, which expect probabilities.
    """
    return torch.log_softmax(logits, dim=1)


### Functions to turn labels into targets

In [46]:
import torch.nn.functional as F


def encode_one_hot_targets(labels: Tensor, num_classes=10) -> Tensor:
    """Turns integer class labels into float one-hot rows of width `num_classes`."""
    one_hot = F.one_hot(labels, num_classes)
    return one_hot.float()


def pass_through_labels(labels: Tensor) -> Tensor:
    """Identity target encoder: returns the integer labels unchanged."""
    return labels


In [None]:
# A label and its one-hot encoding.
print(valid_y[0])
print(encode_one_hot_targets(valid_y)[0])


### Functions to validate the model

In [48]:
def calculate_accuracy_plain_labels(preds: Tensor, labels: Tensor) -> float:
    """Calculates the accuracy of a matrix of predictions given a vector of labels.

    Args:
        preds (Tensor): A 2D matrix of image index to digit match likelihoods.
        These can be logits or normalized. Just the maximum value of each row is used.
        labels (Tensor): A 1D vector of labels corresponding to each image

    Returns:
        float: A scalar indicating the average accuracy
    """
    assert preds.ndim == 2
    ensure_shape(labels, (preds.shape[0],))
    preds_as_digits = preds.argmax(dim=1)
    correct = preds_as_digits == labels
    return correct.float().mean().item()


def validate_model_plain_labels(
    model: Module, valid_x: Tensor, valid_y: Tensor
) -> float:
    """Accuracy of `model` on (valid_x, valid_y).

    The forward pass runs under torch.no_grad(): validation never calls
    backward(), so recording an autograd graph would only waste time and memory.
    """
    with torch.no_grad():
        logits = model(valid_x)
    return calculate_accuracy_plain_labels(logits, valid_y)


In [None]:
# Validation accuracy of the untrained linear model (presumably near chance, ~0.1).
calculate_accuracy_plain_labels(_linear_model(valid_x), valid_y)


### Functions to calculate the loss

Even though the normalization step (such as Softmax) conceptually seems to belong to the model, I see it being applied in the loss function. I've extracted it into a separate normalizer.

In [50]:
def calculate_loss_plain_labels(preds: Tensor, labels: Tensor) -> Tensor:
    """Sum over the batch of (1 - probability assigned to the correct class).

    Args:
        preds: [batch, classes] normalized predictions (e.g. softmax output).
        labels: [batch] integer class indices.

    Returns:
        A rank-0 tensor. It is a sum, not a mean, so its magnitude grows with
        the batch size.
    """
    ensure_shape(labels, (preds.shape[0],))
    # gather needs an index with the same rank as `preds` ([batch] -> [batch, 1]);
    # it picks preds[i, labels[i]] for every row i.
    probs_right_guesses = torch.gather(preds, dim=1, index=labels.unsqueeze(dim=1))
    losses = 1 - probs_right_guesses
    return ensure_shape(losses.sum(), ())


def calculate_loss_one_host_targets(preds: Tensor, targets: Tensor) -> Tensor:
    """Same loss as calculate_loss_plain_labels, but with one-hot `targets`.

    Args:
        preds: [batch, classes] normalized predictions.
        targets: Same shape as `preds`; assumed to be exactly one-hot (0/1).

    Returns:
        A rank-0 tensor: sum over the batch of (1 - probability of the correct class).
    """
    assert preds.shape == targets.shape
    # With one-hot targets, preds * targets zeroes every entry except the correct
    # class, so a row sum extracts its probability without a zeros_like buffer.
    probability_right_guesses = (preds * targets).sum(dim=1)
    losses = 1 - probability_right_guesses
    return losses.sum()


# Correctly spelled alias; the original name is kept for existing callers.
calculate_loss_one_hot_targets = calculate_loss_one_host_targets


These two currently return the same value; the second one just encodes the labels as one-hot targets:

In [None]:
# Loss of the untrained model on the training set, from integer labels.
calculate_loss_plain_labels(normalize_softmax(_linear_model(train_x)), train_y)


In [None]:
# Same loss computed from one-hot targets instead of integer labels.
calculate_loss_one_host_targets(
    normalize_softmax(_linear_model(train_x)),
    encode_one_hot_targets(train_y),
)


### Datasets and DataLoaders

A collection that contains tuples of independent and dependent variables is a dataset.

In [53]:
# A dataset here is just a list of (image, label) tuples.
train_dataset = list(zip(train_x, train_y))
valid_dataset = list(zip(valid_x, valid_y))


Take a look somewhere in a dataset:

In [None]:
# Show one training image (reshaped back to IMAGE_SHAPE) and its label.
_index = 40000
show_image(train_dataset[_index][0].view((IMAGE_SHAPE)))
train_dataset[_index][1]


When we pass a dataset to a DataLoader we will get back many batches that are themselves tuples of tensors representing batches of independent and dependent variables.

In [55]:
from fastai.data.load import DataLoader

# A dataloader with a tiny batch size for playing around
# (note: the training-loop cell later binds `data` again, to a tuple)
data = DataLoader(train_dataset, batch_size=2, shuffle=True)


In [56]:
from fastai.basics import first

# Take just the first batch from the loader.
first_training_batch = first(data)


Each training batch is a tuple of two tensors.

In [57]:
# A batch is an (images, labels) pair.
assert isinstance(first_training_batch, tuple), "first_training_batch should be a tuple"
assert len(first_training_batch) == 2, "first_training_batch should have length 2"


The first element is a tensor of images and the second is a tensor of labels. Both have the same length, which corresponds to the `batch_size` passed to the `DataLoader` constructor above.

In [58]:
# Split the batch into its images and labels.
_images, _labels = first_training_batch


In [None]:
# [batch_size, 28*28]: the images are flattened.
_images.shape


In [None]:
# [batch_size]: one label per image.
_labels.shape


Take a look at the first image and label of the first DataLoader batch:

In [None]:
# Reshape the flattened first image back to IMAGE_SHAPE for display.
show_image(first(_images).view(IMAGE_SHAPE))
first(_labels)


### Training

#### Forward pass + calc grads

In [62]:
def forward_pass_calc_grad(
    batch_x: Tensor,
    batch_y: Tensor,
    model: Module,
    normalize: Callable[[Tensor], Tensor],
    loss_function: Callable[[Tensor, Tensor], Tensor],
) -> float:
    """Runs one forward pass on a batch and backpropagates the loss.

    Gradients accumulate into the `.grad` of the model's parameters; stepping
    and zeroing them is left to the caller.

    Returns:
        The value returned by `loss_function`, as a Python float. For the loss
        functions in this notebook that is a sum over the batch, not an average.
    """
    logits = model(batch_x)
    preds = normalize(logits)
    loss = loss_function(preds, batch_y)
    loss.backward()
    return loss.item()


#### Optimizer

In [63]:
class Optimizer:
    """Plain gradient descent: p <- p - learning_rate * p.grad for every param."""

    def __init__(self, learning_rate: float, params: Sequence[Tensor]):
        self.learning_rate = learning_rate
        self.params = params

    def step(self):
        """Moves each parameter one step against its gradient.

        Parameters without a gradient (e.g. unused in the last forward pass)
        are skipped instead of raising on `None`.
        """
        # The update itself must not be recorded by autograd.
        with torch.no_grad():
            for p in self.params:
                if p.grad is not None:
                    p -= p.grad * self.learning_rate

    def zero_grad(self):
        """Resets accumulated gradients; safe to call before any backward()."""
        for p in self.params:
            if p.grad is not None:
                p.grad.zero_()


#### Cycle of training

In [64]:
from typing import Generator

import matplotlib.pyplot as plt

# TODO: untied end: training data with labels or targets


def train_model(
    data: Union[DataLoader, Tuple[Tensor, Tensor]],
    model,
    *,
    normalizer,
    loss_function,
    optimizer,
    epochs: int,
    validation_data: Union[Tuple[Tensor, Tensor], None] = None,
) -> Generator[Tuple[int, float], None, None]:
    """Trains `model`, yielding (epoch, validation accuracy) after every epoch.

    Args:
        data: A DataLoader yielding (x, y) batches, or a single (x, y) tuple
            that is used as one full batch per epoch.
        model: A callable Module whose params the optimizer updates.
        normalizer: Maps logits to normalized predictions.
        loss_function: Maps (predictions, labels) to a scalar loss tensor.
        optimizer: Object with step() and zero_grad().
        epochs: Number of passes over `data`.
        validation_data: (x, y) used for the per-epoch accuracy. Defaults to the
            notebook-level `valid_x`, `valid_y`.
    """

    def train_step(x: Tensor, y: Tensor) -> None:
        # One optimization step on a single batch.
        forward_pass_calc_grad(x, y, model, normalizer, loss_function)
        optimizer.step()
        optimizer.zero_grad()

    if validation_data is None:
        validation_data = (valid_x, valid_y)
    val_x, val_y = validation_data

    for epoch in range(epochs):
        if isinstance(data, DataLoader):
            for batch_x, batch_y in data:
                train_step(batch_x, batch_y)
        else:
            train_step(data[0], data[1])

        # Validation is forward-only: backward() is never called on it, so the
        # parameter grads are untouched, and no_grad() avoids building a graph.
        # The yield stays outside the `with` so grad mode is restored before
        # control returns to the caller.
        with torch.no_grad():
            accuracy = validate_model_plain_labels(model, val_x, val_y)
        yield epoch, accuracy


def plot_accuracies(accuracies: List[float]) -> None:
    """Plots per-epoch validation accuracies, as collected from train_model."""
    fig, ax = plt.subplots()
    ax.plot(accuracies)
    # train_model yields one validation accuracy per epoch, not per batch.
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Validation accuracy")
    ax.set_ylim(0, 1)
    ax.set_title("Validation Accuracy Over Epochs")
    plt.show()


In [65]:
intermediate_features = 20
model = Sequential(
    [
        # 28*28 pixels -> intermediate_features hidden activations
        Linear(IMAGE_SHAPE[0] * IMAGE_SHAPE[1], intermediate_features, init_params),
        # the nonlinearity between the two linear layers
        ReLU(),
        # hidden activations -> one logit per digit
        Linear(intermediate_features, count_outputs, init_params),
    ]
)


In [None]:
import time

LEARNING_RATE = 1e-4
EPOCHS = 4000
LOG_EVERY = 100

# Full-batch training: the whole training set is a single batch each epoch.
# train_model also accepts a DataLoader, e.g.
# DataLoader(train_dataset, batch_size=60000, shuffle=True, device=device),
# which was found to be slow (see the questions at the end).
# Named `full_batch` so it does not clobber the demo DataLoader `data`.
full_batch = (train_x, train_y)

accuracies: List[float] = []
start_time = time.time()

for epoch, acc in train_model(
    full_batch,
    model,
    normalizer=normalize_softmax,
    loss_function=calculate_loss_plain_labels,
    optimizer=Optimizer(
        learning_rate=LEARNING_RATE,
        params=model.params(),
    ),
    epochs=EPOCHS,
):
    accuracies.append(acc)
    if epoch % LOG_EVERY == 0:
        elapsed = time.time() - start_time
        print(f"Time={elapsed:.2f}s, epoch: {epoch}, accuracy: {acc:.4f}")


In [None]:
# Validation accuracy after each epoch of the run above.
plot_accuracies(accuracies)


In [68]:
import os
import platform


def notify_complete(success=True):
    """Notify when cell execution completes.

    On macOS plays a system sound ("Glass" on success, "Basso" otherwise);
    elsewhere, or if `afplay` is unavailable, rings the ASCII bell.

    Args:
        success: Whether the finished work succeeded; selects the macOS sound.
    """
    import subprocess  # local import keeps this cell self-contained

    if platform.system() == "Darwin":  # macOS
        sound = "Glass" if success else "Basso"
        try:
            # Argument list rather than a shell string: no shell is involved.
            subprocess.run(
                ["afplay", f"/System/Library/Sounds/{sound}.aiff"], check=False
            )
            return
        except FileNotFoundError:
            pass  # afplay missing; deliberately fall back to the bell below
    print("\a")  # ASCII bell


notify_complete()


Questions after all of this:
- How to choose the number of intermediate features?
- How to play with LRs? Should the LR vary within a training run? With which shape?
- How to saturate the GPU? How to profile bottlenecks?
- Speaking of which, what's the deal with `DataLoader`? Why so slow?
- How does batching help with learning (not only with speed)?