From 21b3f09c6947265ce403705ce0722f9271ca8553 Mon Sep 17 00:00:00 2001
From: jmaczan
Date: Wed, 21 Feb 2024 14:52:16 +0100
Subject: [PATCH] Adapting AlexNet to handle MNIST dataset

---
 configs/data/mnist_alexnet.yaml       |   6 +
 configs/experiment/mnist_alexnet.yaml |  32 ++++
 configs/model/mnist_alexnet.yaml      |  22 +++
 src/data/mnist_alexnet_datamodule.py  | 205 ++++++++++++++++++++++++++
 src/models/components/alexnet.py      |  15 +-
 5 files changed, 277 insertions(+), 3 deletions(-)
 create mode 100644 configs/data/mnist_alexnet.yaml
 create mode 100644 configs/experiment/mnist_alexnet.yaml
 create mode 100644 configs/model/mnist_alexnet.yaml
 create mode 100644 src/data/mnist_alexnet_datamodule.py

diff --git a/configs/data/mnist_alexnet.yaml b/configs/data/mnist_alexnet.yaml
new file mode 100644
index 0000000..786e4fd
--- /dev/null
+++ b/configs/data/mnist_alexnet.yaml
@@ -0,0 +1,6 @@
+_target_: src.data.mnist_alexnet_datamodule.MNISTAlexNetDataModule
+data_dir: ${paths.data_dir}
+batch_size: 64 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
+train_val_test_split: [55_000, 5_000, 10_000]
+num_workers: 0
+pin_memory: False
diff --git a/configs/experiment/mnist_alexnet.yaml b/configs/experiment/mnist_alexnet.yaml
new file mode 100644
index 0000000..a23ea93
--- /dev/null
+++ b/configs/experiment/mnist_alexnet.yaml
@@ -0,0 +1,32 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=mnist_alexnet
+
+defaults:
+  - override /data: mnist_alexnet
+  - override /model: mnist_alexnet
+  - override /callbacks: default
+  - override /trainer: default
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["mnist", "alexnet"]
+
+seed: 12345
+
+trainer:
+  min_epochs: 10
+  max_epochs: 10
+  gradient_clip_val: 0.5
+
+data:
+  batch_size: 64
+
+logger:
+  wandb:
+    tags: ${tags}
+    group: "mnist"
+  aim:
+    experiment: "mnist"
diff --git a/configs/model/mnist_alexnet.yaml b/configs/model/mnist_alexnet.yaml
new file mode 100644
index 0000000..653c505
--- /dev/null
+++ b/configs/model/mnist_alexnet.yaml
@@ -0,0 +1,22 @@
+_target_: src.models.mnist_module.MNISTLitModule
+
+optimizer:
+  _target_: torch.optim.Adam
+  _partial_: true
+  lr: 0.001
+  weight_decay: 0.0
+
+scheduler:
+  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+  _partial_: true
+  mode: min
+  factor: 0.1
+  patience: 10
+
+net:
+  _target_: src.models.components.alexnet.AlexNet
+  channels: 1
+  first_fc_in_features: 1024
+
+# compile model for faster training with pytorch 2.0
+compile: false
diff --git a/src/data/mnist_alexnet_datamodule.py b/src/data/mnist_alexnet_datamodule.py
new file mode 100644
index 0000000..d270685
--- /dev/null
+++ b/src/data/mnist_alexnet_datamodule.py
@@ -0,0 +1,205 @@
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+from lightning import LightningDataModule
+from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
+from torchvision.datasets import MNIST
+from torchvision.transforms import transforms
+
+
+class MNISTAlexNetDataModule(LightningDataModule):
+    """`LightningDataModule` for the MNIST dataset, adapted for the original AlexNet.
+
+    The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples.
+    It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a
+    fixed-size image.
+    The original black and white images from NIST were size-normalized to fit in a 20x20 pixel box
+    while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing
+    technique used by the normalization algorithm. The images were centered in a 28x28 image by computing the center
+    of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.
+
+    A `LightningDataModule` implements 7 key methods:
+
+    ```python
+        def prepare_data(self):
+            # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
+            # Download data, pre-process, split, save to disk, etc...
+
+        def setup(self, stage):
+            # Things to do on every process in DDP.
+            # Load data, set variables, etc...
+
+        def train_dataloader(self):
+            # return train dataloader
+
+        def val_dataloader(self):
+            # return validation dataloader
+
+        def test_dataloader(self):
+            # return test dataloader
+
+        def predict_dataloader(self):
+            # return predict dataloader
+
+        def teardown(self, stage):
+            # Called on every process in DDP.
+            # Clean up after fit or test.
+    ```
+
+    This allows you to share a full dataset without explaining how to download,
+    split, transform and process the data.
+
+    Read the docs:
+        https://lightning.ai/docs/pytorch/latest/data/datamodule.html
+    """
+
+    def __init__(
+        self,
+        data_dir: str = "data/",
+        train_val_test_split: Tuple[int, int, int] = (55_000, 5_000, 10_000),
+        batch_size: int = 64,
+        num_workers: int = 0,
+        pin_memory: bool = False,
+    ) -> None:
+        """Initialize a `MNISTAlexNetDataModule`.
+
+        :param data_dir: The data directory. Defaults to `"data/"`.
+        :param train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`.
+        :param batch_size: The batch size. Defaults to `64`.
+        :param num_workers: The number of workers. Defaults to `0`.
+        :param pin_memory: Whether to pin memory. Defaults to `False`.
+        """
+        super().__init__()
+
+        # this line allows accessing init params with 'self.hparams' attribute
+        # also ensures init params will be stored in ckpt
+        self.save_hyperparameters(logger=False)
+
+        # data transformations
+        self.transforms = transforms.Compose(
+            [
+                transforms.Resize((64, 64)),
+                transforms.ToTensor(),
+                transforms.Normalize((0.5,), (0.5,)),
+            ]
+        )
+
+        self.data_train: Optional[Dataset] = None
+        self.data_val: Optional[Dataset] = None
+        self.data_test: Optional[Dataset] = None
+
+        self.batch_size_per_device = batch_size
+
+    @property
+    def num_classes(self) -> int:
+        """Get the number of classes.
+
+        :return: The number of MNIST classes (10).
+        """
+        return 10
+
+    def prepare_data(self) -> None:
+        """Download data if needed. Lightning ensures that `self.prepare_data()` is called only
+        within a single process on CPU, so you can safely add your downloading logic within. In
+        case of multi-node training, the execution of this hook depends upon
+        `self.prepare_data_per_node()`.
+
+        Do not use it to assign state (self.x = y).
+        """
+        MNIST(self.hparams.data_dir, train=True, download=True)
+        MNIST(self.hparams.data_dir, train=False, download=True)
+
+    def setup(self, stage: Optional[str] = None) -> None:
+        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
+
+        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
+        `trainer.predict()`, so be careful not to execute things like random split twice!
+        Also, it is called after `self.prepare_data()` and there is a barrier in between which ensures that all the
+        processes proceed to `self.setup()` once the data is prepared and available for use.
+
+        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
+        """
+        # Divide batch size by the number of devices.
+        if self.trainer is not None:
+            if self.hparams.batch_size % self.trainer.world_size != 0:
+                raise RuntimeError(
+                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
+                )
+            self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size
+
+        # load and split datasets only if not loaded already
+        if not self.data_train and not self.data_val and not self.data_test:
+            trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)
+            testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms)
+            dataset = ConcatDataset(datasets=[trainset, testset])
+            self.data_train, self.data_val, self.data_test = random_split(
+                dataset=dataset,
+                lengths=self.hparams.train_val_test_split,
+                generator=torch.Generator().manual_seed(42),
+            )
+
+    def train_dataloader(self) -> DataLoader[Any]:
+        """Create and return the train dataloader.
+
+        :return: The train dataloader.
+        """
+        return DataLoader(
+            dataset=self.data_train,
+            batch_size=self.batch_size_per_device,
+            num_workers=self.hparams.num_workers,
+            pin_memory=self.hparams.pin_memory,
+            shuffle=True,
+        )
+
+    def val_dataloader(self) -> DataLoader[Any]:
+        """Create and return the validation dataloader.
+
+        :return: The validation dataloader.
+        """
+        return DataLoader(
+            dataset=self.data_val,
+            batch_size=self.batch_size_per_device,
+            num_workers=self.hparams.num_workers,
+            pin_memory=self.hparams.pin_memory,
+            shuffle=False,
+        )
+
+    def test_dataloader(self) -> DataLoader[Any]:
+        """Create and return the test dataloader.
+
+        :return: The test dataloader.
+        """
+        return DataLoader(
+            dataset=self.data_test,
+            batch_size=self.batch_size_per_device,
+            num_workers=self.hparams.num_workers,
+            pin_memory=self.hparams.pin_memory,
+            shuffle=False,
+        )
+
+    def teardown(self, stage: Optional[str] = None) -> None:
+        """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
+        `trainer.test()`, and `trainer.predict()`.
+
+        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
+            Defaults to ``None``.
+        """
+        pass
+
+    def state_dict(self) -> Dict[Any, Any]:
+        """Called when saving a checkpoint. Implement to generate and save the datamodule state.
+
+        :return: A dictionary containing the datamodule state that you want to save.
+        """
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """Called when loading a checkpoint. Implement to reload datamodule state given datamodule
+        `state_dict()`.
+
+        :param state_dict: The datamodule state returned by `self.state_dict()`.
+        """
+        pass
+
+
+if __name__ == "__main__":
+    _ = MNISTAlexNetDataModule()
diff --git a/src/models/components/alexnet.py b/src/models/components/alexnet.py
index 0b1a56c..76d1300 100644
--- a/src/models/components/alexnet.py
+++ b/src/models/components/alexnet.py
@@ -1,3 +1,4 @@
+import torch
 from torch import nn
 
 
@@ -6,13 +7,13 @@ class AlexNet(nn.Module):
     Paper: https://proceedings.neurips.cc/paper_files/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf
     """
 
-    def __init__(self):
+    def __init__(self, channels=3, first_fc_in_features=9216):
         super().__init__()
 
         self.model = nn.Sequential(
             # 1st conv layer
             nn.Conv2d(
-                in_channels=3,
+                in_channels=channels,
                 out_channels=96,
                 kernel_size=(11, 11),
                 stride=4,
@@ -37,7 +38,7 @@ def __init__(self):
             nn.ReLU(),
             nn.MaxPool2d(kernel_size=(3, 3), stride=2),
             # 1st fc layer with dropout
-            nn.Linear(in_features=9216, out_features=4096),
+            nn.Linear(in_features=first_fc_in_features, out_features=4096),
             nn.Dropout(p=0.5),
             nn.ReLU(),
             # 2nd fc layer with dropout
@@ -47,3 +48,11 @@ def __init__(self):
             # 3rd fc layer
             nn.Linear(in_features=4096, out_features=1000),
         )
+
+    def forward(self, x):
+        for layer in self.model:
+            # Flatten the conv feature maps before the first fully connected layer
+            if isinstance(layer, nn.Linear) and x.dim() > 2:
+                x = torch.flatten(x, start_dim=1)
+            x = layer(x)
+        return x
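With the files above in place, the experiment is launched the way the config comment describes, e.g. `python train.py experiment=mnist_alexnet`. The snippet below is a minimal smoke test for the new datamodule rather than part of the patch; it assumes the repository layout used in the diff and that it is run from the project root so that `src.data.mnist_alexnet_datamodule` is importable. It only exercises code added above and confirms that MNIST batches arrive as 64x64 single-channel tensors, which is what `channels: 1` in `configs/model/mnist_alexnet.yaml` expects on the model side.

```python
# Smoke test for the datamodule added in this patch (run from the repository root).
from src.data.mnist_alexnet_datamodule import MNISTAlexNetDataModule

dm = MNISTAlexNetDataModule(data_dir="data/", batch_size=64)
dm.prepare_data()  # downloads MNIST into data/ on first use
dm.setup()         # splits train+test into 55_000/5_000/10_000

images, labels = next(iter(dm.train_dataloader()))
# Resize((64, 64)) + ToTensor + Normalize -> batches of shape (64, 1, 64, 64)
print(images.shape, labels.shape)
```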