In [1]:
import os
import sys
from pathlib import Path
import requests
import gzip
import pickle

from typing import Any, Callable, Sequence, Union

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
# Append python path - needed to import text_recognizer
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [3]:
def download_mnist(path):
    url = "https://github.com/pytorch/tutorials/raw/master/_static/"
    filename = "mnist.pkl.gz"

    if not (path / filename).exists():
        content = requests.get(url + filename).content
        (path / filename).open("wb").write(content)

    return path / filename


data_path = Path("data") if Path("data").exists() else Path("../data")
path = data_path / "downloaded" / "vector-mnist"
path.mkdir(parents=True, exist_ok=True)

datafile = download_mnist(path)

In [4]:
def read_mnist(path):
    with gzip.open(path, "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
    return x_train, y_train, x_valid, y_valid

x_train, y_train, x_valid, y_valid = read_mnist(datafile)

In [5]:
x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)

In [6]:
class BaseDataset(torch.utils.data.Dataset):
    """Base Dataset class that simply processes data and targets through optional transforms.

    More info: https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset

    Attributes:
        data: commonly these are torch tensors, numpy arrays, or PIL Images
        targets: commonly these are torch tensors or numpy arrays
        transform: function that takes a datum and returns the same
        target_transform: function that takes a target and returns the same
    """

    def __init__(
        self,
        data: Union[Sequence, torch.Tensor],
        targets: Union[Sequence, torch.Tensor],
        transform: Callable = None,
        target_transform: Callable = None,
    ) -> None:
        if len(data) != len(targets):
            raise ValueError("Data and targets must be of equal length")
        super().__init__()
        self.data = data
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        """Return length of the dataset."""
        return len(self.data)

    def __getitem__(self, index: int) -> tuple[Any, Any]:
        """Return a datum and its target, after processing by transforms.

        Args:
            index (int): _description_

        Returns:
            tuple[Any, Any]: (datum, target)
        """

        datum, target = self.data[index], self.targets[index]

        if self.transform is not None:
            datum = self.transform(datum)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return datum, target

In [7]:
train_ds = BaseDataset(x_train, y_train)

train_ds.data.shape

torch.Size([50000, 784])

### DataModule

In [8]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def push_to_device(tensor):
    return tensor.to(device)      

    
class MNISTDataModule:
    url = "https://github.com/pytorch/tutorials/raw/master/_static/"
    filename = "mnist.pkl.gz"
    
    def __init__(self, dir, bs=32):
        self.dir = dir
        self.bs = bs
        self.path = self.dir / self.filename

    def prepare_data(self):
        if not (self.path).exists():
            content = requests.get(self.url + self.filename).content
            self.path.open("wb").write(content)

    def setup(self):
        with gzip.open(self.path, "rb") as f:
            ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

        x_train, y_train, x_valid, y_valid = map(
            torch.tensor, (x_train, y_train, x_valid, y_valid)
            )
        
        self.train_ds = BaseDataset(x_train, y_train, transform=push_to_device, target_transform=push_to_device)
        self.valid_ds = BaseDataset(x_valid, y_valid, transform=push_to_device, target_transform=push_to_device)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_ds, batch_size=self.bs, shuffle=True)
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.valid_ds, batch_size=2 * self.bs, shuffle=False)  

### Model - Basic Logistic

In [9]:
class MNISTLogistic(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(784, 10) 

    def forward(self, x):
        return self.lin(x)  

### Optimizer

In [10]:

def configure_optimizer(model: nn.Module) -> optim.Optimizer:
    return optim.Adam(model.parameters(), lr=3e-4)

### Loss function

In [11]:
def cross_entropy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return -output[range(target.shape[0]), target].mean()

loss_func = cross_entropy

### Fit

In [12]:
def fit(self: nn.Module, datamodule):
    datamodule.prepare_data()
    datamodule.setup()

    val_dataloader = datamodule.val_dataloader()
    
    self.eval()
    with torch.no_grad():
        valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)

    print("before start of training:", valid_loss / len(val_dataloader))

    opt = configure_optimizer(self)
    train_dataloader = datamodule.train_dataloader()
    for epoch in range(epochs):
        self.train()
        for xb, yb in train_dataloader:
            pred = self(xb)
            loss = loss_func(pred, yb) 

            loss.backward()
            opt.step()
            opt.zero_grad()

        self.eval()
        with torch.no_grad():
            valid_loss = sum(loss_func(self(xb), yb) for xb, yb in val_dataloader)

        # track stats
        print(f"{epoch:7d}/{epochs:7d}: {loss.item():.4f}")


### Logistic


In [13]:
MNISTLogistic.fit = fit

model = MNISTLogistic()
model.to(device)

datamodule = MNISTDataModule(dir=path, bs=32)

epochs = 2
model.fit(datamodule=datamodule)

before start of training: tensor(0.0022)
      0/      2: -41.6332
      1/      2: -63.2852


### MLP

In [14]:
digits_to_9 = list(range(10))
data_config = {"input_dims": (784,), "mapping": {digit: str(digit) for digit in digits_to_9}}
data_config

{'input_dims': (784,),
 'mapping': {0: '0',
  1: '1',
  2: '2',
  3: '3',
  4: '4',
  5: '5',
  6: '6',
  7: '7',
  8: '8',
  9: '9'}}

In [15]:
from text_recognizer.models.mlp import MLP

MLP.fit = fit

model = MLP(data_config)
model.to(device)

datamodule = MNISTDataModule(dir=path, bs=32)

epochs = 20
model.fit(datamodule=datamodule)

before start of training: tensor(0.0005)
      0/     20: -4569717.0000
      1/     20: -29427486.0000
      2/     20: -68096296.0000
      3/     20: -141538064.0000
      4/     20: -277902656.0000
      5/     20: -393276928.0000
      6/     20: -609995392.0000
      7/     20: -693799616.0000
      8/     20: -1103068544.0000
      9/     20: -1296659200.0000
     10/     20: -1626500608.0000
     11/     20: -2141411584.0000
     12/     20: -2263724032.0000
     13/     20: -3361170688.0000
     14/     20: -4291849728.0000
     15/     20: -5203449856.0000
     16/     20: -5750444544.0000
     17/     20: -5686204416.0000
     18/     20: -7228507136.0000
     19/     20: -9068530688.0000
