In [4]:
from image_classification.models.mlp import MLPClassifier
from image_classification.models.resnet import ResNetClassifier
from layers import ResNetBlock
import lightning.pytorch as pl
import lightning.pytorch.loggers as pl_loggers
import lightning.pytorch.callbacks as pl_callbacks

from util import set_seed
from weight_init import init_for_relu

from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)

from ray import tune


In [2]:
from datasets import FashionMNIST

In [3]:
# Fix the global seed (and opt in to seeded dataloader workers) before anything
# random happens: loader construction and weight init both follow this call.
pl.seed_everything(42, workers=True)

# Build the three FashionMNIST loaders; `test_dl` is not used in this cell.
train_dl, valid_dl, test_dl = FashionMNIST.get_dataloaders(
    batch_size=16,
    pin_memory=True,
    num_workers=4,
    persistent_workers=True,
)

# The classifier head is sized from the class list exposed by the training dataset.
n_classes = len(train_dl.dataset.classes)

# Alternative architecture that was tried here (currently disabled):
#   ResNetClassifier(n_classes=n_classes, opt='AdamW', lr=1e-1, wd=1e-4).apply(init_for_relu)
model = MLPClassifier(
    input_sz=28 * 28,
    n_classes=n_classes,
    n_features=(16, 32, 64, 128),
)
model = model.apply(init_for_relu)

# Keep the single best checkpoint by validation accuracy, plus the most recent one.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="accuracy/val",
    mode="max",
    save_top_k=1,
    save_last=True,
    filename="best-{epoch:02d}",
)

# TensorBoard logs go under ./results/fashion_mnist, in a sub-directory named
# after the model class.
tb_logger = pl_loggers.TensorBoardLogger(
    save_dir='./results',
    name='fashion_mnist',
    sub_dir=model.__class__.__name__,
)

# Short local run: 5 epochs, capped at 10 train and 10 validation batches each.
# The Ray pieces (RayTrainReportCallback, RayDDPStrategy, RayLightningEnvironment)
# and LearningRateFinder are intentionally left out of this cell.
trainer = pl.Trainer(
    max_epochs=5,
    limit_train_batches=10,
    limit_val_batches=10,
    callbacks=[checkpoint_callback],
    logger=tb_logger,
)

# To resume from a saved run, additionally pass e.g.
#   ckpt_path='./results/fashion_mnist/version_5/checkpoints/best-epoch=04.ckpt'
trainer.fit(
    model,
    train_dataloaders=train_dl,
    val_dataloaders=valid_dl,
)

Seed set to 42
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

KeyboardInterrupt



In [5]:
# Seed the driver process (and opt in to seeded dataloader workers).
pl.seed_everything(42, workers=True)

# Hyperparameter search space for Ray Tune: two hidden-layer widths, a
# log-uniformly sampled learning rate, and the batch size.
search_space = dict(
    layer_1_size=tune.choice([32, 64, 128]),
    layer_2_size=tune.choice([64, 128, 256]),
    lr=tune.loguniform(1e-4, 1e-1),
    batch_size=tune.choice([16, 32]),
)


def train_func(config):
    """Train an ``MLPClassifier`` on FashionMNIST for one hyperparameter configuration.

    Intended to be used as a Ray Train / Ray Tune training function: the
    Lightning trainer is built with Ray's DDP strategy, cluster environment
    plugin and report callback, then passed through ``prepare_trainer``.

    Args:
        config: Dict of hyperparameters. Required keys:
            ``batch_size`` (dataloader batch size),
            ``layer_1_size`` and ``layer_2_size`` (hidden layer widths),
            ``lr`` (learning rate forwarded to the model).
            Optional key: ``seed`` (defaults to 42).

    Returns:
        None. Results are produced through the trainer's callbacks and logger.
    """
    # Seed here rather than relying on the notebook cell: Ray runs this function
    # in separate worker processes, which do not inherit the driver's RNG state.
    pl.seed_everything(config.get('seed', 42), workers=True)

    # The third loader (test split) is not needed for fitting.
    train_dl, valid_dl, _ = FashionMNIST.get_dataloaders(batch_size=config['batch_size'],
                                                         pin_memory=True,
                                                         num_workers=4,
                                                         persistent_workers=True)

    n_classes = len(train_dl.dataset.classes)

    model = MLPClassifier(
        input_sz=28 * 28,
        n_classes=n_classes,
        n_features=(config['layer_1_size'], config['layer_2_size'],),
        lr=config['lr'],
    ).apply(init_for_relu)

    # Keep the best checkpoint by "accuracy/val" (presumably logged by the
    # model -- confirm in MLPClassifier) plus the most recent one.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=1,
        monitor="accuracy/val",
        mode="max",
        filename="best-{epoch:02d}",
        save_last=True,
    )

    tb_logger = pl_loggers.TensorBoardLogger(save_dir='./results',
                                             name='fashion_mnist',
                                             sub_dir=model.__class__.__name__)

    # Deliberately tiny budget (5 epochs x 10 batches) so each trial is cheap.
    trainer = pl.Trainer(max_epochs=5,
                         limit_train_batches=10,
                         limit_val_batches=10,
                         callbacks=[
                             RayTrainReportCallback(),
                             checkpoint_callback,
                         ],
                         logger=tb_logger,
                         devices="auto",
                         accelerator="auto",
                         strategy=RayDDPStrategy(),
                         plugins=[RayLightningEnvironment()],
                         )

    # Let Ray validate/adjust the trainer configuration before fitting.
    trainer = prepare_trainer(trainer)
    trainer.fit(model=model,
                train_dataloaders=train_dl,
                val_dataloaders=valid_dl,
                )

Seed set to 42
