In [None]:
!pip install datasets



In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
import numpy as np
import torchvision
from einops import einsum
from datasets import load_dataset
from accelerate.test_utils.testing import get_backend

DEVICE, _, _ = get_backend()
DEVICE

'cuda'

In [None]:
ds = load_dataset("ylecun/mnist")

In [None]:
def image_to_tensor(image):
    return torchvision.transforms.functional.to_tensor(image).to(DEVICE)

In [None]:
im = image_to_tensor(ds["train"][0]["image"])
im.shape

torch.Size([1, 28, 28])

In [None]:
class LinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, W, b):
        ctx.save_for_backward(X, W)
        out_features = einsum(X, W, "b h_in, h_in h_out -> b h_out")
        return out_features + b.unsqueeze(0)

    @staticmethod
    def backward(ctx, grad_output):
        (X, W) = ctx.saved_tensors
        grad_input = einsum(W, grad_output, "h_in h_out, b h_out -> b h_in")
        grad_W = einsum(X, grad_output, "b h_in, b h_out -> h_in h_out")
        grad_b = einsum(grad_output, "b h_out -> h_out")
        return grad_input, grad_W, grad_b

class LinearModule(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.W = torch.nn.Parameter(torch.empty((in_features, out_features)))
        self.b = torch.nn.Parameter(torch.empty(out_features))

        torch.nn.init.normal_(self.W)
        torch.nn.init.normal_(self.b)

    def forward(self, X):
        return LinearFunction.apply(X, self.W, self.b)

In [None]:
model = nn.Sequential(
    nn.Conv2d(1, 32, 5),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 5),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    # nn.Linear(1024, 512),
    LinearModule(1024, 512),
    nn.ReLU(),
    nn.Dropout(0.25),
    # nn.Linear(512, 10)
    LinearModule(512, 10)
)
model.to(DEVICE)
model

Sequential(
  (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
  (4): ReLU()
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Flatten(start_dim=1, end_dim=-1)
  (7): LinearModule()
  (8): ReLU()
  (9): Dropout(p=0.25, inplace=False)
  (10): LinearModule()
)

In [None]:
for name, param in model.named_parameters():
    print(name, param.shape)

0.weight torch.Size([32, 1, 5, 5])
0.bias torch.Size([32])
3.weight torch.Size([64, 32, 5, 5])
3.bias torch.Size([64])
7.W torch.Size([1024, 512])
7.b torch.Size([512])
10.W torch.Size([512, 10])
10.b torch.Size([10])


In [None]:
num_params = 0
for parameter in model.parameters():
    num_params += np.prod(parameter.shape)
num_params

np.int64(582026)

In [None]:
class DatasetWrapper:
    def __init__(self, dataset, batch_size):
        self.ds = dataset.to_iterable_dataset()
        self.batch_size = batch_size
        self.dataset_length = len(dataset)
        self.iterations = (len(dataset) + batch_size - 1) // batch_size

    def shuffle(self):
        self.ds = self.ds.shuffle()

In [None]:
def softmax_reference(logits):
    return F.softmax(logits, dim=-1)

def softmax(logits):
    return logits.exp() / (logits.exp().sum(dim=-1, keepdim=True))

In [None]:
class AdamW:
    def __init__(self,
                 params,
                 lr=0.001,
                 betas=(0.9, 0.999),
                 eps=1e-08,
                 weight_decay=0.01):
        self.params = [param for param in params]
        self.lr = lr
        self.b1, self.b2 = betas
        self.eps = eps
        self.weight_decay = weight_decay

        self.ms = [torch.zeros_like(param) for param in self.params]
        self.vs = [torch.zeros_like(param) for param in self.params]

    def step(self):
        for param, m, v in zip(self.params, self.ms, self.vs):
            param.data.sub_(self.lr * self.weight_decay * param.data)
            m.mul_(self.b1).add_((1 - self.b1) * param.grad)
            v.mul_(self.b2).add_((1 - self.b2) * param.grad**2)
            m_corr = m / (1 - self.b1)
            v_corr = v / (1 - self.b2)
            param.data.sub_(self.lr * m_corr / (v_corr**0.5 + self.eps))

    def zero_grad(self):
        for param in self.params:
            param.grad = None


In [None]:
def cross_entropy_reference(logits, labels):
    return F.cross_entropy(logits, labels, reduction="mean")

def cross_entropy(logits, labels):
    pred_logits = logits[torch.arange(logits.shape[0]), labels]
    logprobs = pred_logits - torch.logsumexp(logits, dim=-1)
    return -logprobs.mean()

def images_to_tensor_batch(images):
    return torch.cat([
        image_to_tensor(image)
        for image in images
    ]).unsqueeze(1)

def labels_to_tensor_batch(labels):
    return torch.LongTensor(labels).to(DEVICE)

def train_one_batch(model, optimizer, batch):
    images = images_to_tensor_batch(batch["image"])
    labels = labels_to_tensor_batch(batch["label"])

    logits = model(images)

    loss = cross_entropy(logits, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    return loss.item()

def train_one_epoch(model, optimizer, wrapped_ds):
    model.train()
    losses = []
    for batch in tqdm(wrapped_ds.ds.iter(batch_size=wrapped_ds.batch_size),
                      total=wrapped_ds.iterations):
        loss = train_one_batch(model, optimizer, batch)
        losses.append(loss)
    return losses

@torch.no_grad()
def evaluate(model, wrapped_ds):
    model.eval()
    losses = []
    accuracies = []
    for batch in tqdm(wrapped_ds.ds.iter(batch_size=wrapped_ds.batch_size),
                    total=wrapped_ds.iterations):
        images = images_to_tensor_batch(batch["image"])
        labels = labels_to_tensor_batch(batch["label"])

        logits = model(images)
        pred = logits.argmax(dim=-1)

        loss = cross_entropy(logits, labels)
        accuracy = (pred == labels).sum().item() / len(labels)
        losses.append(loss.item())
        accuracies.append(accuracy)
    return losses, accuracies

wrapped_train = DatasetWrapper(ds["train"], batch_size=64)
wrapped_test = DatasetWrapper(ds["test"], batch_size=64)
# optimizer = torch.optim.AdamW(model.parameters())
optimizer = AdamW(model.parameters())
for epoch in range(2):
    print(f"Epoch {epoch + 1}")
    losses = train_one_epoch(model, optimizer, wrapped_train)
    print("Average train loss: ", np.mean(losses))
    test_losses, test_accuracies = evaluate(model, wrapped_test)
    print("Average eval loss: ", np.mean(test_losses))
    print("Average eval accuracy: ", np.mean(test_accuracies))
    wrapped_train.shuffle()
    print()

Epoch 1


100%|██████████| 938/938 [00:13<00:00, 67.43it/s]


Average train loss:  2.6954755883640065


100%|██████████| 157/157 [00:02<00:00, 73.60it/s]


Average eval loss:  0.3678508244312493
Average eval accuracy:  0.9670581210191083

Epoch 2


100%|██████████| 938/938 [00:16<00:00, 58.01it/s]


Average train loss:  0.7219443577550241


100%|██████████| 157/157 [00:01<00:00, 98.97it/s] 

Average eval loss:  0.21402439593699327
Average eval accuracy:  0.9783041401273885




