# Sanity checks

In [None]:
from math import isclose
from pathlib import Path
from warnings import filterwarnings

import matplotlib.pyplot as plt
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning

from shipsnet.data import ShipsDataModule
from shipsnet.models import MLPClassifier
from shipsnet.viz import array_to_rgb_image

# Suppress every warning of category PossibleUserWarning for the rest of the
# notebook so the cell outputs below stay readable.
filterwarnings("ignore", category=PossibleUserWarning)

## Check the data is being loaded correctly

In [None]:
# Build the data module and invoke its preparation hooks by hand, since no
# Trainer is driving it in this cell.
datamodule = ShipsDataModule()
datamodule.prepare_data()
datamodule.setup()

# Pull a single batch out of the training dataloader.
train_batches = iter(datamodule.train_dataloader())
inputs, labels = next(train_batches)

# A 3 x 4 grid: at most the first 12 inputs of the batch are displayed.
fig, axes = plt.subplots(nrows=3, ncols=4)

for tensor, ax in zip(inputs, axes.ravel()):
    # NOTE: if the data module standardises its inputs, that transformation
    # may need to be undone here before the image is displayed.
    image = array_to_rgb_image(tensor)
    ax.imshow(image)
    ax.set_axis_off()

fig.tight_layout()
plt.show()

## Reproducibility check

In [None]:
def train_and_eval():
    """Fit a small MLP classifier briefly and report its validation metrics.

    A fresh ``ShipsDataModule`` and ``MLPClassifier`` are constructed on every
    call. Training runs for 5 epochs with logging, checkpointing, the model
    summary and the progress bar all switched off.

    Returns:
        The sole entry of the list returned by ``Trainer.validate``.
    """
    # Trainer switches that keep the run quiet and free of on-disk artefacts.
    quiet_options = {
        "logger": False,
        "enable_checkpointing": False,
        "enable_model_summary": False,
        "enable_progress_bar": False,
    }

    data = ShipsDataModule()
    classifier = MLPClassifier([10], "relu")
    runner = Trainer(max_epochs=5, **quiet_options)

    runner.fit(classifier, data)

    # Exactly one result is expected; the unpacking fails loudly otherwise.
    validation_results = runner.validate(classifier, data)
    (metrics,) = validation_results
    return metrics


# Seed all RNGs and keep the returned seed so it can be re-applied below.
seed = seed_everything()
metrics_1 = train_and_eval()

# Second run under the same seed: every metric should match the first run.
seed_everything(seed)
metrics_2 = train_and_eval()
assert all(isclose(metrics_1[k], metrics_2[k]) for k in metrics_1)

# Third run without re-seeding: at least one metric should now differ.
metrics_3 = train_and_eval()
assert any(not isclose(metrics_1[k], metrics_3[k]) for k in metrics_1)

print("Reproducibility check passed!")