In [None]:
from pathlib import Path
import requests
import gzip
import pickle

import numpy
import torch
import torch.nn as nn
import torch.optim as optim

: 

In [None]:
def download_mnist(path):
    """Download the pickled MNIST archive into ``path`` unless it is already there.

    Parameters
    ----------
    path
        directory that holds (or will hold) the archive; ``str`` or ``pathlib.Path``

    Returns
    -------
    pathlib.Path
        location of ``mnist.pkl.gz`` inside ``path``

    Raises
    ------
    requests.HTTPError
        if the server answers the download request with an error status
    """
    url = "https://github.com/pytorch/tutorials/raw/master/_static/"
    filename = "mnist.pkl.gz"
    target = Path(path) / filename

    if not target.exists():
        response = requests.get(url + filename)
        # Fail loudly instead of caching an error page as if it were the archive;
        # the exists() check above would otherwise never re-download it.
        response.raise_for_status()
        target.parent.mkdir(parents=True, exist_ok=True)
        # write_bytes closes the file handle, unlike a bare open(...).write(...).
        target.write_bytes(response.content)

    return target


# Use ./data when it exists; otherwise fall back to ../data.
data_path = Path("data")
if not data_path.exists():
    data_path = Path("../data")

# Directory that holds the downloaded archive; created on demand.
path = data_path.joinpath("downloaded", "vector-mnist")
path.mkdir(parents=True, exist_ok=True)

datafile = download_mnist(path)

: 

In [None]:
def read_mnist(path):
    """Load the training and validation splits from a gzipped pickle archive.

    Parameters
    ----------
    path
        location of the gzip-compressed pickle file

    Returns
    -------
    (x_train, y_train, x_valid, y_valid)
        the first two pairs stored in the archive; its third element is discarded
    """
    # NOTE(review): pickle.load can execute arbitrary code while unpickling,
    # so this should only be pointed at archives from a trusted source.
    with gzip.open(path, "rb") as archive:
        splits = pickle.load(archive, encoding="latin-1")

    train_split, valid_split, _ = splits
    x_train, y_train = train_split
    x_valid, y_valid = valid_split
    return x_train, y_train, x_valid, y_valid

# Load the training and validation inputs/targets from the downloaded archive.
x_train, y_train, x_valid, y_valid = read_mnist(datafile)

: 

In [None]:
# Rebind each split to a torch tensor built from the value loaded above.
x_train, y_train, x_valid, y_valid = (
    torch.tensor(split) for split in (x_train, y_train, x_valid, y_valid)
)

: 

In [None]:
"""Base Dataset class."""
from typing import Any, Callable, Dict, Sequence, Tuple, Union

from PIL import Image
import torch


# Anything indexable with a length: a plain sequence or a torch tensor.
SequenceOrTensor = Union[Sequence, torch.Tensor]


class BaseDataset(torch.utils.data.Dataset):
    """Base Dataset class that simply processes data and targets through optional transforms.

    Read more: https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset

    Parameters
    ----------
    data
        commonly these are torch tensors, numpy arrays, or PIL Images
    targets
        commonly these are torch tensors or numpy arrays
    transform
        optional function that takes a datum and returns the same
    target_transform
        optional function that takes a target and returns the same

    Raises
    ------
    ValueError
        if ``data`` and ``targets`` differ in length
    """

    def __init__(
        self,
        data: SequenceOrTensor,
        targets: SequenceOrTensor,
        transform: Union[Callable, None] = None,
        target_transform: Union[Callable, None] = None,
    ) -> None:
        if len(data) != len(targets):
            raise ValueError("Data and targets must be of equal length")
        super().__init__()
        self.data = data
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        """Return length of the dataset."""
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Return a datum and its target, after processing by transforms.

        Parameters
        ----------
        index
            position of the requested item in ``data`` and ``targets``

        Returns
        -------
        (datum, target)
        """
        datum, target = self.data[index], self.targets[index]

        # Each transform is applied only when one was supplied.
        if self.transform is not None:
            datum = self.transform(datum)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return datum, target

: 

In [None]:
# Wrap the training tensors in a dataset; no transforms are passed here.
train_ds = BaseDataset(x_train, y_train)

# Display the shape of the wrapped training data.
train_ds.data.shape

: 

### DataModule

In [None]:
# Use CUDA when torch reports it as available; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class MNISTDataModule:
    """Download, load and batch the pickled MNIST archive.

    Parameters
    ----------
    dir
        directory that holds (or will hold) the archive; ``str`` or ``pathlib.Path``
    bs
        training batch size; validation batches are twice this size
    """

    url = "https://github.com/pytorch/tutorials/raw/master/_static/"
    filename = "mnist.pkl.gz"

    def __init__(self, dir, bs=32):
        # `dir` shadows the builtin, but the name is kept because callers pass it by keyword.
        self.dir = Path(dir)
        self.bs = bs
        self.path = self.dir / self.filename

    def prepare_data(self):
        """Download the archive to ``self.path`` unless it already exists.

        Raises
        ------
        requests.HTTPError
            if the server answers the download request with an error status
        """
        if self.path.exists():
            return

        response = requests.get(self.url + self.filename)
        # Fail loudly instead of caching an error page as if it were the archive;
        # the exists() check above would otherwise never re-download it.
        response.raise_for_status()
        self.dir.mkdir(parents=True, exist_ok=True)
        # write_bytes closes the file handle, unlike a bare open(...).write(...).
        self.path.write_bytes(response.content)

    def setup(self):
        """Read the archive and build ``self.train_ds`` and ``self.valid_ds``."""
        # NOTE(review): pickle.load can execute arbitrary code while unpickling,
        # so the archive should only come from a trusted source.
        with gzip.open(self.path, "rb") as f:
            ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

        x_train, y_train, x_valid, y_valid = map(
            torch.tensor, (x_train, y_train, x_valid, y_valid)
        )

        # Data and targets are moved to the device item by item, via the dataset transforms.
        self.train_ds = BaseDataset(
            x_train, y_train, transform=push_to_device, target_transform=push_to_device
        )
        self.valid_ds = BaseDataset(
            x_valid, y_valid, transform=push_to_device, target_transform=push_to_device
        )

    def train_dataloader(self):
        """Return a shuffling loader over the training set with batches of ``bs``."""
        return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)

    def val_dataloader(self):
        """Return a non-shuffling loader over the validation set with batches of ``2 * bs``."""
        return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)

def push_to_device(tensor):
    """Return ``tensor`` moved to the notebook-level ``device``.

    Used as a dataset transform, so it runs once per datum and once per
    target; it therefore stays silent instead of printing the device.
    """
    return tensor.to(device)

: 

### Model

In [None]:
class MNISTLogistic(nn.Module):
    """Single linear layer that maps flattened inputs to one score per class.

    Parameters
    ----------
    in_features
        size of each input vector (defaults to 784)
    num_classes
        number of output scores (defaults to 10)
    """

    def __init__(self, in_features: int = 784, num_classes: int = 10):
        super().__init__()
        self.lin = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return the raw output of the linear layer; no activation is applied."""
        return self.lin(x)

: 

### Loss Function

In [None]:
def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy between raw scores and integer class targets.

    Parameters
    ----------
    output
        unnormalised scores, indexed as ``[row, class]``
    target
        class index for each row of ``output``

    Returns
    -------
    torch.Tensor
        scalar mean negative log-likelihood of the target classes
    """
    # Normalise the scores to log-probabilities first. Picking raw scores
    # would give a loss with no lower bound that simply rewards large outputs.
    log_probs = torch.log_softmax(output, dim=-1)
    rows = torch.arange(target.shape[0], device=output.device)
    return -log_probs[rows, target].mean()


loss_func = cross_entropy

: 

### Optimizer

In [None]:

def configure_optimizer(model: nn.Module, lr: float = 3e-4) -> optim.Optimizer:
    """Build an Adam optimizer over all parameters of ``model``.

    Parameters
    ----------
    model
        module whose parameters will be optimised
    lr
        learning rate passed to Adam (defaults to 3e-4)

    Returns
    -------
    optim.Optimizer
        the configured Adam optimizer
    """
    return optim.Adam(model.parameters(), lr=lr)

: 

### Fit

In [None]:
def _validation_loss(model: nn.Module, dataloader) -> float:
    """Return the per-sample mean of ``loss_func`` over ``dataloader``, with gradients disabled."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for xb, yb in dataloader:
            batch_size = len(xb)
            # Assumes loss_func returns a batch mean; weighting by batch size
            # keeps a smaller final batch from being over-represented.
            total += loss_func(model(xb), yb).item() * batch_size
            count += batch_size
    return total / count if count else float("nan")


def fit(self: nn.Module, datamodule, epochs=None):
    """Train ``self`` on ``datamodule``, printing the validation loss before training and after each epoch.

    Parameters
    ----------
    self
        the module to train (this function is meant to be attached as a method)
    datamodule
        object providing ``prepare_data``, ``setup``, ``train_dataloader`` and
        ``val_dataloader``
    epochs
        number of passes over the training data; when omitted, the
        module-level ``epochs`` variable is used for backward compatibility

    Raises
    ------
    ValueError
        if ``epochs`` is neither passed nor defined at module level
    """
    if epochs is None:
        # The parameter shadows the global, so the global is looked up explicitly.
        epochs = globals().get("epochs")
    if epochs is None:
        raise ValueError("epochs must be passed to fit() or defined at module level")

    datamodule.prepare_data()
    datamodule.setup()

    val_dataloader = datamodule.val_dataloader()
    print("before start of training:", _validation_loss(self, val_dataloader))

    opt = configure_optimizer(self)
    train_dataloader = datamodule.train_dataloader()
    for epoch in range(epochs):
        self.train()
        for xb, yb in train_dataloader:
            # Clear gradients first so nothing stale leaks into this step.
            opt.zero_grad()
            loss = loss_func(self(xb), yb)
            loss.backward()
            opt.step()

        print(epoch, _validation_loss(self, val_dataloader))

: 

In [None]:

# Attach the module-level `fit` function to the class so it can be called as `model.fit(...)`.
MNISTLogistic.fit = fit

: 

In [None]:
# The ten digit classes, 0 through 9.
digits_to_9 = [*range(10)]

# Input dimensions plus a mapping from each digit to its string label.
data_config = {
    "input_dims": (784,),
    "mapping": dict(zip(digits_to_9, map(str, digits_to_9))),
}
data_config

: 

In [None]:
# Number of training epochs; `fit` reads this module-level name.
epochs = 2

datamodule = MNISTDataModule(dir=path, bs=32)

# Build the model and move its parameters to the selected device.
model = MNISTLogistic().to(device)

model.fit(datamodule=datamodule)

: 

: 