In [None]:
from src.model.classifier_model import ClassifierModel
from src.data.datamodule import AnimalDataModule

In [None]:
# Data-module settings, unpacked into AnimalDataModule(**data_config) below.
data_config = {
    'image_dir': 'dataset/',
    # Three split fractions; they sum to 1.0.
    'train_val_test_split': (0.75, 0.15, 0.10),
    'batch_size': 32,
    'num_workers': 0,
    'pin_memory': False,
}

# Model / optimizer settings, unpacked into ClassifierModel(**model_config) below.
model_config = {
    'pretrained': True,
    'freeze_features': True,
    'num_classes': 3,
    'learning_rate': 0.001,
    'optimizer': 'adam',
    # NOTE(review): beta_1/beta_2 presumably apply to 'adam' and momentum to an
    # SGD-style optimizer — confirm how ClassifierModel consumes them.
    'beta_1': 0.9,
    'beta_2': 0.999,
    'momentum': 0.9,
    'weight_decay': 0.0001,
}

# Name passed to the MLflow experiment lookup in the training cell.
experiment_name = 'three_animals_test'

In [None]:
# Build the data module and the classifier from the config dicts defined above.
# NOTE(review): every dict key is passed as a keyword argument, so the keys must
# match the constructor signatures of AnimalDataModule / ClassifierModel — confirm.
data = AnimalDataModule(**data_config)
model = ClassifierModel(**model_config)

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from src.misc.misc import datetime_now

# Keep only the single best checkpoint (lowest 'val_loss') for this run, written
# to a per-run directory under logs/ named by datetime_now().
callbacks = ModelCheckpoint(
    dirpath = 'logs/' + datetime_now(),
    mode = 'min',
    monitor = 'val_loss',
    save_weights_only = True,   # model weights only, not optimizer/trainer state
    save_top_k = 1,
)

# Passing a ModelCheckpoint in `callbacks` is what enables checkpointing, and
# checkpointing is on by default. The former `checkpoint_callback=True` flag was
# therefore redundant. It was deprecated in Lightning 1.5 and removed in 1.7,
# where passing it makes Trainer(...) fail, so it is no longer passed.
trainer = Trainer(
    max_epochs = 1,   # NOTE(review): single epoch — presumably a smoke-test run
    callbacks = [callbacks],
)

In [None]:
import mlflow
from src.logger.utils import experiment_id

# Turn on MLflow autologging for the PyTorch Lightning fit below.
mlflow.pytorch.autolog()
experiment_id_ = experiment_id(experiment_name = experiment_name)
with mlflow.start_run(experiment_id = experiment_id_) as run:
    # These hyperparameters exist only as keys of `model_config`. The previous
    # bare names `pretrained`, `num_classes` and `freeze_features` were never
    # defined, so they raised NameError after training had finished. Logging
    # before fit() also records them if training itself fails.
    mlflow.log_params({
        key: model_config[key]
        for key in ('pretrained', 'num_classes', 'freeze_features')
    })
    trainer.fit(model, data)