In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import torch

# Pick the compute backend once for the whole notebook.
#   device      -> torch.device for the selected backend
#   accelerator -> matching accelerator name passed to pl.Trainer later on
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    accelerator = "gpu"
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    accelerator = "mps"
else:
    device = torch.device("cpu")
    # The CPU fallback must define `accelerator` too; otherwise the Trainer
    # cell raises NameError on machines without CUDA or MPS.
    accelerator = "cpu"


from main_code.dataset import FashionDataset
from main_code.nn_definition import FashionAutoEncoder
from main_code.evaluation import evaluate_model
from main_code.visualization import visualizer, visualize_dataset

from utils.save_load_model import save_model, load_model

# for auto-reloading external modules, so edits to the imported project code
# are picked up without restarting the kernel
# (mode 2 = reload all modules before executing code, per the IPython autoreload docs)
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [None]:
# Training dataset and dataloader.
# `batch_size` and `num_batches` are reused by the hyper-parameter cell below.
batch_size = 1
train_data = FashionDataset(
    'dataset/woman_25_34_caucasian_frontal_standing/trainset.txt'
)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=False,  # samples are served in dataset order
)
print(f"Trainset Size: {len(train_data)}")
print(train_data.image_names)

# Number of batches the loader yields per pass over the dataset.
num_batches = len(train_dataloader)
print(f"Number of Batches: {num_batches}")

In [None]:
# Sanity check: pull a single batch from the loader and display the shapes of
# its two parts as the cell output.
first_batch = next(iter(train_dataloader))
inputs, targets = first_batch
(inputs.shape, targets.shape)

In [None]:
# Visual sanity check of the training data (one image requested).
visualize_dataset(train_data, num_imgs=1)

In [None]:
# Model initialization.
# Hyper-parameters are bundled in one dict and handed to the model constructor;
# `batch_size` and `num_batches` come from the dataloader cell above.
hparams = dict(
    batch_size=batch_size,
    num_batches=num_batches,
    num_resnet_trainable=0,
    optimizer='Adam',
    learning_rate=0.005,
    momentum=0,
)
model = FashionAutoEncoder(hparams=hparams)

In [None]:
# Initialize the TensorBoard logger.
logger = pl.loggers.TensorBoardLogger(save_dir='lightning_logs', name="logs")

# Stop early once the monitored 'val_loss' has gone 50 checks without improving.
early_stopping = EarlyStopping(monitor='val_loss', patience=50, mode='min')

# Initialize the model trainer.
trainer = pl.Trainer(
    max_epochs=1000,
    logger=logger,
    log_every_n_steps=1,
    callbacks=[early_stopping],
    accelerator=accelerator,
    # MPS is pinned to a single device. Everywhere else use "auto" rather than
    # None: None is only the Lightning 1.x default, whereas "auto" is accepted
    # by 1.7+ and 2.x and resolves the same way None did in 1.x.
    devices=1 if accelerator == "mps" else "auto",
)

In [None]:
# Train the model.
# NOTE(review): the training loader is also passed as the validation loader, so
# the 'val_loss' monitored by EarlyStopping is presumably computed on the
# training samples — confirm this is intended.
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    val_dataloaders=train_dataloader,
)

In [None]:
# Evaluate the trained model on the same loader that was used for training.
evaluate_model(model, train_dataloader) #latest best for max num train images: 99.07% for 120 images

In [None]:
# Visualize the trained model on training data (five images requested).
visualizer(model, train_data, num_imgs=5)

In [None]:
# Persist the trained model, then load it back under the same name.
checkpoint_name = 'model'
save_model(model, checkpoint_name)

model_loaded = load_model(FashionAutoEncoder, file_name=checkpoint_name)

In [None]:
# Load the TensorBoard notebook extension (the logger above uses
# save_dir='lightning_logs').
%load_ext tensorboard