# Multi-Layer Perceptron Binary Classifier using PyTorch

In [None]:
from math import isclose
from pathlib import Path
from warnings import filterwarnings

import matplotlib.pyplot as plt
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.model_summary import summarize
from pytorch_lightning.utilities.warnings import PossibleUserWarning


from shipsnet.data import ShipsDataModule
from shipsnet.models import MLPClassifier
from shipsnet.viz import array_to_rgb_image

# Hide Lightning's PossibleUserWarning hints so the training output in the
# cells below stays readable.
filterwarnings(action="ignore", category=PossibleUserWarning)

# Train an ensemble of classifiers

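Each run of the two cells below trains one MLP classifier. To build an ensemble, change `hidden_shape` and `activation` in the first cell and re-run both cells once per configuration; every run is logged separately under `lightning_logs`.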

In [None]:
# Data module shared by every model in the ensemble. The split is driven by a
# fixed random_split_seed, presumably so each model sees the same
# train/validation partition (assumption; confirm in ShipsDataModule).
datamodule = ShipsDataModule(
    batch_size=32,
    train_frac=0.75,
    random_split_seed=12345,
)

# Seed the global RNGs *after* building the data module and *before* building
# the model, so weight initialisation is reproducible. Called without an
# argument, seed_everything picks the seed itself and returns the value used;
# keep it in `seed` so the run can be reproduced later.
seed = seed_everything()

# Change hidden_shape / activation between runs to train different ensemble
# members.
model = MLPClassifier(
    hidden_shape=[20],
    activation="relu",
)

# NOTE: the seed is deliberately not stored via model.save_hyperparameters();
# calling that outside __init__ was found to break on older PyTorch Lightning
# installs.
summarize(model, max_depth=2)

In [None]:
# Stop training once the validation loss stops improving for 5 validation
# checks, and checkpoint on the same metric so the best model can be reloaded.
early_stopping = EarlyStopping(monitor="val/loss", patience=5, verbose=True)
checkpoints = ModelCheckpoint(monitor="val/loss", filename="{epoch:d}")

trainer = Trainer(
    logger=TensorBoardLogger(".", default_hp_metric=False),
    callbacks=[early_stopping, checkpoints],
    # The summary is already displayed by summarize() in the setup cell.
    enable_model_summary=False,
)
trainer.fit(model, datamodule)

experiment = model.logger.experiment

# Record the global seed so this ensemble member can be reproduced later.
experiment.add_text("seed", str(seed))

# So we can easily see where to load the checkpoint from later.
# best_model_path is empty when no checkpoint was written (e.g. training was
# interrupted before the first validation run); resolving "" would yield the
# current working directory, which is not a checkpoint, so skip it.
if checkpoints.best_model_path:
    experiment.add_text(
        "checkpoint_path", str(Path(checkpoints.best_model_path).resolve())
    )


# Evaluate the best model

Use TensorBoard to compare the trained models. Once you have found the best one, load its checkpoint (its path is logged as `checkpoint_path` in TensorBoard's Text tab) and evaluate it on the test set.

In [None]:
# Launch TensorBoard inside the notebook, reading the lightning_logs
# directory that the training runs above are logged under.
%load_ext tensorboard
%tensorboard --logdir lightning_logs