# Install dependencies

In [None]:
! pip install git+https://github.com/karthikrangasai/efficient_face.git@master
! pip install torchtext==0.14

# Imports

In [None]:
from pytorch_lightning import Trainer, seed_everything
import pytorch_lightning.callbacks as plcb
from pytorch_lightning.loggers.wandb import WandbLogger
from efficient_face.data import ciFAIRDataModule
from efficient_face.models import SoftmaxBasedModel, TripletLossBasedModel

In [None]:
from torch.optim import Adam, Adadelta, Adagrad, RMSprop
from torch_optimizer import Ranger, Lookahead, SGDW
from torch.optim.lr_scheduler import (
    ConstantLR,
    CosineAnnealingWarmRestarts,
    CyclicLR,
    StepLR,
)

# Configure the parameters

In [None]:
RANDOM_SEED = 1234

# --- Optimizer -------------------------------------------------------------
LEARNING_RATE = 1e-3
OPTIMIZER_CLS = Adam
# Extra keyword arguments for OPTIMIZER_CLS only: `params` and `lr` are
# supplied separately, so don't add them here.
OPTIMIZER_KWARGS = {}
# Optional LR scheduler class (e.g. one of the schedulers imported above).
LR_SCHEDULER_CLS = None
LR_SCHEDULER_KWARGS = {
    # Presumably the name of the scheduler argument that receives a step
    # count — update it whenever LR_SCHEDULER_CLS changes.
    "num_steps_arg": None,
    "num_steps_factor": 1.0,
}

# --- DataModule ------------------------------------------------------------
BATCH_SIZE = 16
NUM_EPOCHS = 2
NUM_WORKERS = 2

# --- Model -----------------------------------------------------------------
MODEL_NAME = "mobilenetv3_small_100"
EMBEDDING_SIZE = 128

# --- Loss function ---------------------------------------------------------
DISTANCE_METRIC = "L2"
TRIPLET_STRATEGY = "VANILLA"
MINER_KWARGS = {}
LOSS_FUNC_KWARGS = {}

# --- Trainer ---------------------------------------------------------------
ACCELERATOR = "gpu"  # switch to "cpu" when no GPU is available
NUM_DEVICES = 1

# Set up the seed for random number generators

In [None]:
# Seed the Python, NumPy and torch RNGs. `workers=True` additionally makes
# Lightning configure reproducible per-worker seeding for the DataLoaders,
# which matters because the DataModule uses NUM_WORKERS worker processes.
seed_everything(RANDOM_SEED, workers=True)

# Set up the DataModule

In [None]:
# ciFAIR data pipeline.
# NOTE(review): `model_name` is forwarded to the DataModule — presumably to
# pick backbone-specific preprocessing; confirm in efficient_face.data.
datamodule = ciFAIRDataModule(
    model_name=MODEL_NAME,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# Set up the model

In [None]:
# NOTE(review): the loss settings passed below (TRIPLET_STRATEGY, MINER_KWARGS,
# ...) look triplet-specific, yet SoftmaxBasedModel is instantiated while the
# imported TripletLossBasedModel is never used — confirm the intended class.
model = SoftmaxBasedModel(
    # Optimization
    learning_rate=LEARNING_RATE,
    optimizer=OPTIMIZER_CLS,
    optimizer_kwargs=OPTIMIZER_KWARGS,
    lr_scheduler=LR_SCHEDULER_CLS,
    lr_scheduler_kwargs=LR_SCHEDULER_KWARGS,
    # Architecture
    model_name=MODEL_NAME,
    embedding_size=EMBEDDING_SIZE,
    # Loss function and mining
    distance_metric=DISTANCE_METRIC,
    triplet_strategy=TRIPLET_STRATEGY,
    miner_kwargs=MINER_KWARGS,
    loss_func_kwargs=LOSS_FUNC_KWARGS,
)

# Training

In [None]:
# Console output: rich model summary and progress bar.
model_summary = plcb.RichModelSummary()
progress_bar = plcb.RichProgressBar()

# Log the learning rate at every optimizer step.
lr_monitor = plcb.LearningRateMonitor(logging_interval="step")

# Keep the two checkpoints with the lowest `val_loss`, plus the most recent
# one (save_last). `every_n_epochs=2` mirrors the Trainer's
# `check_val_every_n_epoch=2`, so `val_loss` exists whenever a checkpoint is
# considered — keep the two values in step.
# NOTE(review): `dirpath=""` presumably writes checkpoints relative to the
# current working directory; pass None to let Lightning/the logger choose.
checkpoint = plcb.ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=2,
    save_last=True,
    every_n_epochs=2,
    dirpath="",
    filename="{epoch}--{val_loss:.3f}",
    auto_insert_metric_name=True,
)

# Order is preserved from the original configuration.
CALLBACKS = [
    model_summary,
    progress_bar,
    lr_monitor,
    checkpoint,
]

In [None]:
# Weights & Biases logging; runs are grouped by backbone name and
# `log_model=True` uploads the saved checkpoints to W&B.
LOGGER = WandbLogger(
    project="efficient_face",
    group=MODEL_NAME,
    log_model=True,
    # Set to the id of a failed run to resume that run instead of starting a new one.
    id=None,
)

In [None]:
trainer = Trainer(
    # Hardware
    accelerator=ACCELERATOR,
    devices=NUM_DEVICES,
    # Training length and validation cadence (must agree with the
    # checkpoint callback's `every_n_epochs`).
    max_epochs=NUM_EPOCHS,
    check_val_every_n_epoch=2,
    num_sanity_val_steps=0,
    # Autograd anomaly detection helps locate NaN/Inf sources but slows
    # training; NOTE(review): consider disabling it for long runs.
    detect_anomaly=True,
    # Logging and callbacks
    logger=LOGGER,
    callbacks=CALLBACKS,
)

In [None]:
# Run the training loop; data comes from the DataModule's dataloaders.
trainer.fit(model=model, datamodule=datamodule)

In [None]:
# Path of the best checkpoint by the monitored metric (Lightning leaves this
# as an empty string if no checkpoint was saved).
print(f"{checkpoint.best_model_path}")