In [1]:
import numpy as np
import optuna
import torch

from image_classification.models.mlp import MLPClassifier
from image_classification.models.resnet import ResNetClassifier_fastai, Resnet
import lightning.pytorch as pl
import lightning.pytorch.loggers as pl_loggers
import lightning.pytorch.callbacks as pl_callbacks
from image_classification.models.classification_model import ClassificationModel

from util import set_seed
from weight_init import init_for_relu
from sklearn.model_selection import ParameterGrid
from datasets import FashionMNIST
from experiment import train
import sklearn.metrics as skm
import torch
from image_classification.models.cnn import CNN
from weight_init import generic_init


  warn(


In [2]:
def show_classification_results(test_dl, model, classnames):
    ypred, ytrue = [], []

    model.eval()
    with torch.inference_mode():
        for x, y in test_dl:
            pred = model(x.to(model.device))
            ypred.extend(pred.argmax(-1).cpu().numpy().flatten())
            ytrue.extend(y.cpu().numpy().flatten())

    print(skm.classification_report(ytrue, ypred, target_names=classnames))
    print(skm.confusion_matrix(ytrue, ypred))

# Param Grid using sklearn

In [3]:
torch.set_float32_matmul_precision("medium")

In [4]:
param_grid = {
    'n_features': [
        # (4, 8, 16, 32, 64),
        # (8, 16, 32, 64, 128),
        (16, 32, 64, 128, 256),
    ],
    'lr': [1e-3],
    'wd': [1e-5],
    'batch_sz': [32]
}

for params in ParameterGrid(param_grid):
    pl.seed_everything(42, True)

    train_dl, valid_dl, test_dl = FashionMNIST.get_dataloaders(batch_size=params['batch_sz'],
                                                               pin_memory=True,
                                                               num_workers=4,
                                                               persistent_workers=True)
    n_classes = len(train_dl.dataset.classes)

    # Select classification model:

    # classifier = MLPClassifier(
    #     input_sz=28 * 28,
    #     n_classes=n_classes,
    #     n_features=(16, 32,),
    # ).apply(init_for_relu)

    # classifier = ResNetClassifier_fastai(
    #     n_classes=n_classes,
    #     n_features=params["n_features"]
    # ).apply(generic_init)

    # classifier = CNN(
    #     n_input_channels=1,
    #     n_classes=n_classes,
    #     n_features=params["n_features"],
    #     n_hidden_layers=128,
    #     init_weights=True,
    #     use_sepconv=False,
    # )

    classifier = Resnet(
        n_input_channels=1,
        n_classes=n_classes,
        n_features=params["n_features"],
        n_hidden_layers=128,
        use_fft_input=False,
        init_weights=True,
    )

    model = ClassificationModel(
        model=classifier,
        n_classes=n_classes,
        opt="AdamW",
        lr=params["lr"],
        wd=params["wd"]
    )

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_top_k=1,
        monitor="accuracy/val",
        mode="max",
        filename="best-{epoch:02d}",
        save_last=True,
    )

    # experiment_name = f"Resnet-batch_sz-{params['batch_sz']}-layers"
    experiment_name = f"Testing-Refactor"
    tb_logger = pl_loggers.TensorBoardLogger(save_dir='./results',
                                             name=experiment_name,
                                             sub_dir=model.classifier.__class__.__name__)
    trainer = pl.Trainer(max_epochs=10,
                         # limit_train_batches=10,
                         # limit_val_batches=10,
                         callbacks=[
                             checkpoint_callback,
                         ],
                         logger=tb_logger,
                         )
    trainer.fit(model=model,
                train_dataloaders=train_dl,
                val_dataloaders=valid_dl,
                # ckpt_path='./results/fashion_mnist/version_5/checkpoints/best-epoch=04.ckpt'
                )

    show_classification_results(test_dl, model, test_dl.dataset.classes)


Global seed set to 42
  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name       | Type                | Params
---------------------------------------------------
0 | classifier | Resnet              | 1.3 M 
1 | accuracy   | MulticlassAccuracy  | 0     
2 | precision  | MulticlassPrecision | 0     
3 | recall     | MulticlassRecall    | 0     
---------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.003     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


              precision    recall  f1-score   support

 T-shirt/top       0.86      0.83      0.85      1000
     Trouser       0.99      0.98      0.98      1000
    Pullover       0.83      0.87      0.85      1000
       Dress       0.87      0.90      0.88      1000
        Coat       0.80      0.85      0.83      1000
      Sandal       0.97      0.96      0.97      1000
       Shirt       0.73      0.66      0.69      1000
     Sneaker       0.94      0.94      0.94      1000
         Bag       0.99      0.97      0.98      1000
  Ankle boot       0.95      0.96      0.95      1000

    accuracy                           0.89     10000
   macro avg       0.89      0.89      0.89     10000
weighted avg       0.89      0.89      0.89     10000

[[831   1  17  39   5   1 100   0   6   0]
 [  1 982   1  13   2   0   0   0   1   0]
 [ 15   0 872  10  57   0  45   0   1   0]
 [  5  10   5 896  49   0  34   0   1   0]
 [  1   0  66  25 854   0  53   0   1   0]
 [  0   0   0   0   0 956 

In [None]:
def objective(trial: optuna.trial.Trial) -> float:
    pl.seed_everything(42, True)

    params = {
        "dropout": trial.suggest_float("dropout", 0.1, 0.5, log=True),
        "lr": trial.suggest_float("lr", 1e-5, 1e-1, log=True),
        "wd": trial.suggest_float("wd", 1e-5, 1e-1, log=True),
        "opt": trial.suggest_categorical("opt", ["SGD", "RAdam"]),
    }

    data_module = FashionMNIST.FashionMNISTDataModule()
    n_classes = data_module.n_classes
    model = ResNetClassifier(
        n_classes=n_classes,
        **params).apply(init_for_relu)

    trainer = train(model,
                    data_module,
                    monitored_metric='accuracy/val',
                    mode='max',
                    max_epochs=10,
                    limit_train_batches=200)

    return trainer.callback_metrics["accuracy/val"].item()


pruner = optuna.pruners.HyperbandPruner()
study = optuna.create_study(direction="maximize", pruner=pruner)
study.optimize(objective, n_trials=100, timeout=60 * 60, n_jobs=1)

print("Number of finished trials: {}".format(len(study.trials)))

print("Best trial:")
trial = study.best_trial

print("  Value: {}".format(trial.value))

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

In [None]:
data_module = FashionMNIST.FashionMNISTDataModule()
ckpt_path = "results/FashionMNIST/version_2/checkpoints/last.ckpt"
model = ResNetClassifier.load_from_checkpoint(ckpt_path)

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=1,
    monitor="accuracy/val",
    mode="max",
    filename="best-{epoch:02d}",
    save_last=True,
)

tb_logger = pl_loggers.TensorBoardLogger(save_dir='./results',
                                         name=data_module.dataset_name,
                                         sub_dir=model.__class__.__name__)

trainer = pl.Trainer(max_epochs=40,
                     callbacks=[
                         checkpoint_callback,
                     ],
                     logger=tb_logger,
                     )

trainer.fit(model,
            data_module,
            ckpt_path=ckpt_path,
            )