Adapting AlexNet to handle MNIST dataset
jmaczan committed Feb 21, 2024
1 parent c307d38 commit 21b3f09
Showing 5 changed files with 277 additions and 3 deletions.
6 changes: 6 additions & 0 deletions configs/data/mnist_alexnet.yaml
@@ -0,0 +1,6 @@
_target_: src.data.mnist_alexnet_datamodule.MNISTAlexNetDataModule
data_dir: ${paths.data_dir}
batch_size: 64 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
train_val_test_split: [55_000, 5_000, 10_000]
num_workers: 0
pin_memory: False
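
For reference, a minimal sketch of how Hydra would turn this file into a datamodule instance (the `${paths.data_dir}` interpolation is replaced with a literal path here, since the paths config is not part of this diff):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Load the data config and resolve the interpolated data_dir by hand
cfg = OmegaConf.load("configs/data/mnist_alexnet.yaml")
cfg.data_dir = "data/"  # stand-in for ${paths.data_dir}
dm = instantiate(cfg)   # -> MNISTAlexNetDataModule(data_dir="data/", batch_size=64, ...)
```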
32 changes: 32 additions & 0 deletions configs/experiment/mnist_alexnet.yaml
@@ -0,0 +1,32 @@
# @package _global_

# to execute this experiment run:
# python train.py experiment=mnist_alexnet

defaults:
- override /data: mnist_alexnet
- override /model: mnist_alexnet
- override /callbacks: default
- override /trainer: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["mnist", "alexnet"]

seed: 12345

trainer:
min_epochs: 10
max_epochs: 10
gradient_clip_val: 0.5

data:
batch_size: 64

logger:
wandb:
tags: ${tags}
group: "mnist"
aim:
experiment: "mnist"
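
The `defaults` list swaps in the new data and model configs, and everything below it is merged on top of them. A rough sketch of inspecting the merged result with Hydra's compose API (the root config name `train` and the config path are assumptions based on the template layout, not shown in this diff):

```python
from hydra import compose, initialize

with initialize(version_base="1.3", config_path="configs"):
    cfg = compose(config_name="train", overrides=["experiment=mnist_alexnet"])

print(cfg.trainer.max_epochs)  # 10, set by this experiment file
print(cfg.data.batch_size)     # 64
print(cfg.model.net.channels)  # 1, from configs/model/mnist_alexnet.yaml
```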
22 changes: 22 additions & 0 deletions configs/model/mnist_alexnet.yaml
@@ -0,0 +1,22 @@
_target_: src.models.mnist_module.MNISTLitModule

optimizer:
_target_: torch.optim.Adam
_partial_: true
lr: 0.001
weight_decay: 0.0

scheduler:
_target_: torch.optim.lr_scheduler.ReduceLROnPlateau
_partial_: true
mode: min
factor: 0.1
patience: 10

net:
_target_: src.models.components.alexnet.AlexNet
channels: 1
first_fc_in_features: 1024

# compile model for faster training with PyTorch 2.0
compile: false
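
The `_partial_: true` flags make Hydra return callables instead of constructed objects, so the LightningModule can supply `model.parameters()` (and the optimizer) later. A sketch of that mechanism, assuming the usual lightning-hydra-template wiring (MNISTLitModule itself is not part of this diff):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/model/mnist_alexnet.yaml")
net = instantiate(cfg.net)                 # AlexNet(channels=1, first_fc_in_features=1024)
optimizer_fn = instantiate(cfg.optimizer)  # functools.partial(torch.optim.Adam, lr=0.001, ...)
optimizer = optimizer_fn(params=net.parameters())
scheduler_fn = instantiate(cfg.scheduler)
scheduler = scheduler_fn(optimizer=optimizer)  # ReduceLROnPlateau bound to the optimizer
```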
205 changes: 205 additions & 0 deletions src/data/mnist_alexnet_datamodule.py
@@ -0,0 +1,205 @@
from typing import Any, Dict, Optional, Tuple

import torch
from lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import transforms


class MNISTAlexNetDataModule(LightningDataModule):
"""`LightningDataModule` for the MNIST dataset, adapted for original AlexNet.
The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples.
It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a
fixed-size image. The original black and white images from NIST were size-normalized to fit in a 20x20 pixel box
while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing
technique used by the normalization algorithm. The images were centered in a 28x28 image by computing the center of
mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.
A `LightningDataModule` implements 7 key methods:
```python
def prepare_data(self):
# Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
# Download data, pre-process, split, save to disk, etc...
def setup(self, stage):
# Things to do on every process in DDP.
# Load data, set variables, etc...
def train_dataloader(self):
# return train dataloader
def val_dataloader(self):
# return validation dataloader
def test_dataloader(self):
# return test dataloader
def predict_dataloader(self):
# return predict dataloader
def teardown(self, stage):
# Called on every process in DDP.
# Clean up after fit or test.
```
This allows you to share a full dataset without explaining how to download,
split, transform and process the data.
Read the docs:
https://lightning.ai/docs/pytorch/latest/data/datamodule.html
"""

def __init__(
self,
data_dir: str = "data/",
train_val_test_split: Tuple[int, int, int] = (55_000, 5_000, 10_000),
batch_size: int = 64,
num_workers: int = 0,
pin_memory: bool = False,
) -> None:
"""Initialize a `MNISTDataModule`.
:param data_dir: The data directory. Defaults to `"data/"`.
:param train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`.
:param batch_size: The batch size. Defaults to `64`.
:param num_workers: The number of workers. Defaults to `0`.
:param pin_memory: Whether to pin memory. Defaults to `False`.
"""
super().__init__()

# this line allows access to init params with the 'self.hparams' attribute
# also ensures init params will be stored in ckpt
self.save_hyperparameters(logger=False)

# data transformations: resize 28x28 MNIST digits to 64x64 for the AlexNet-style conv stack
# and normalize pixel values from [0, 1] to [-1, 1]
self.transforms = transforms.Compose(
[
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
]
)

self.data_train: Optional[Dataset] = None
self.data_val: Optional[Dataset] = None
self.data_test: Optional[Dataset] = None

self.batch_size_per_device = batch_size

@property
def num_classes(self) -> int:
"""Get the number of classes.
:return: The number of MNIST classes (10).
"""
return 10

def prepare_data(self) -> None:
"""Download data if needed. Lightning ensures that `self.prepare_data()` is called only
within a single process on CPU, so you can safely add your downloading logic within. In
case of multi-node training, the execution of this hook depends upon
`self.prepare_data_per_node`.
Do not use it to assign state (self.x = y).
"""
MNIST(self.hparams.data_dir, train=True, download=True)
MNIST(self.hparams.data_dir, train=False, download=True)

def setup(self, stage: Optional[str] = None) -> None:
"""Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
`trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
`self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
`self.setup()` once the data is prepared and available for use.
:param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
"""
# Divide batch size by the number of devices.
if self.trainer is not None:
if self.hparams.batch_size % self.trainer.world_size != 0:
raise RuntimeError(
f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
)
self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size
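# e.g. a global batch_size of 64 on 4 devices gives each process 16 samples per step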

# load and split datasets only if not loaded already
if not self.data_train and not self.data_val and not self.data_test:
trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)
testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms)
dataset = ConcatDataset(datasets=[trainset, testset])
self.data_train, self.data_val, self.data_test = random_split(
dataset=dataset,
lengths=self.hparams.train_val_test_split,
generator=torch.Generator().manual_seed(42),
)
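# note: the 60k train and 10k test sets are pooled into 70k images and re-split 55k/5k/10k;
# the fixed generator seed keeps the split reproducible across runs and processes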

def train_dataloader(self) -> DataLoader[Any]:
"""Create and return the train dataloader.
:return: The train dataloader.
"""
return DataLoader(
dataset=self.data_train,
batch_size=self.batch_size_per_device,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=True,
)

def val_dataloader(self) -> DataLoader[Any]:
"""Create and return the validation dataloader.
:return: The validation dataloader.
"""
return DataLoader(
dataset=self.data_val,
batch_size=self.batch_size_per_device,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
)

def test_dataloader(self) -> DataLoader[Any]:
"""Create and return the test dataloader.
:return: The test dataloader.
"""
return DataLoader(
dataset=self.data_test,
batch_size=self.batch_size_per_device,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
)

def teardown(self, stage: Optional[str] = None) -> None:
"""Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
`trainer.test()`, and `trainer.predict()`.
:param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
Defaults to ``None``.
"""
pass

def state_dict(self) -> Dict[Any, Any]:
"""Called when saving a checkpoint. Implement to generate and save the datamodule state.
:return: A dictionary containing the datamodule state that you want to save.
"""
return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""Called when loading a checkpoint. Implement to reload datamodule state given datamodule
`state_dict()`.
:param state_dict: The datamodule state returned by `self.state_dict()`.
"""
pass


if __name__ == "__main__":
_ = MNISTAlexNetDataModule()
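
A quick standalone check of the datamodule (not part of the commit; it downloads MNIST on first run and assumes the default split with a small batch size):

```python
from src.data.mnist_alexnet_datamodule import MNISTAlexNetDataModule

dm = MNISTAlexNetDataModule(data_dir="data/", batch_size=8)
dm.prepare_data()  # downloads MNIST if it is not cached yet
dm.setup()         # builds the 55k/5k/10k split
images, labels = next(iter(dm.train_dataloader()))
print(images.shape)  # expected: torch.Size([8, 1, 64, 64])
print(labels.shape)  # expected: torch.Size([8])
```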
15 changes: 12 additions & 3 deletions src/models/components/alexnet.py
@@ -1,3 +1,4 @@
import torch
from torch import nn


@@ -6,13 +7,13 @@ class AlexNet(nn.Module):
Paper: https://proceedings.neurips.cc/paper_files/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf
"""

def __init__(self):
def __init__(self, channels=3, first_fc_in_features=9216):
super().__init__()

self.model = nn.Sequential(
# 1st conv layer
nn.Conv2d(
in_channels=3,
in_channels=channels,
out_channels=96,
kernel_size=(11, 11),
stride=4,
@@ -37,7 +38,7 @@ def __init__(self):
nn.ReLU(),
nn.MaxPool2d(kernel_size=(3, 3), stride=2),
# 1st fc layer with dropout
nn.Linear(in_features=9216, out_features=4096),
nn.Linear(in_features=first_fc_in_features, out_features=4096),
nn.Dropout(p=0.5),
nn.ReLU(),
# 2nd fc layer with dropout
@@ -47,3 +48,11 @@ def __init__(self):
# 3rd fc layer
nn.Linear(in_features=4096, out_features=1000),
)

def forward(self, x):
    for i, layer in enumerate(self.model):
        # flatten the conv feature maps before the first fully connected layer
        if isinstance(layer, nn.Linear) and x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        x = layer(x)
        # print(f"Layer {i}: {x.size()}")  # uncomment to trace intermediate shapes
    return x
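
A hypothetical smoke test for the adapted network (it assumes the conv layers not shown in this diff really do yield 1024 features per sample for 64x64 inputs; the class and module path come from this repository):

```python
import torch

from src.models.components.alexnet import AlexNet

net = AlexNet(channels=1, first_fc_in_features=1024)
dummy = torch.randn(2, 1, 64, 64)  # two fake single-channel 64x64 "MNIST" images
logits = net(dummy)
print(logits.shape)  # torch.Size([2, 1000]) if the layer shapes line up
```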
