# model_training_o2mnist
This notebook is an example of loading data and building a model from a config file. It shows how a model is trained using similar code to `run.py`. There is no automatic logging. 


To train a model without a notebook, check the repo's README.

In [None]:
import sys
sys.path.append("../")
import importlib
import train_loops
import run
import torch
from utils import utils
import wandb
import logging
import os 
from pathlib import Path
from configs.config_o2mnist import config

# Run on the GPU when PyTorch reports one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Load dataset and view sample data 
**Note:** First you need to build the o2-vae dataset: navigate to `data/` and run `python generate_o2mnist.py`.

In [None]:
# Re-import `utils` so edits made to the module on disk are picked up
# without restarting the kernel.
importlib.reload(utils)

# Build the train/test datasets and their loaders as described by the config.
datasets_and_loaders = run.get_datasets_from_config(config)
dset, loader, dset_test, loader_test = datasets_and_loaders

print("sample train data")
# Plot sample data drawn from the training loader; the figure is left as the
# cell's final expression so the notebook renders it.
f, axs = utils.plot_sample_data(loader)
f

## Load the model from config parameters

In [None]:
# The encoder must know how many image channels to expect; read that from
# the leading dimension of the first training sample.
first_sample = dset[0][0]
config.model.encoder.n_channels = first_sample.shape[0]  # image channels

model = run.build_model_from_config(config)

# Adam optimizer with the configured learning rate (no lr scheduler by default).
learning_rate = config.optimizer.lr
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Print the model's own summary.
print(model.model_details())

## Model training
To load a pretrained model for O2-mnist, set `TRAIN_MODEL=True` in the next cell (which will only work if the model config is still the default).

In [None]:

# Train for the configured number of epochs. Interrupting the kernel stops
# training early but keeps the model in its current state, so the cells below
# can still be run.
try:
    for epoch in range(config.run.epochs):
        train_loops.train(
            epoch, model, loader, optimizer,
            do_progress_bar=config.logging.do_progress_bar,
            do_wandb=0,
            device=device,
        )

        # Validate only when enabled, and only every `valid_freq` epochs.
        validate_now = config.run.do_validation and epoch % config.run.valid_freq == 0
        if not validate_now:
            continue

        train_loops.valid(
            epoch, model, loader_test,
            do_progress_bar=config.logging.do_progress_bar,
            do_wandb=0,
            device=device,
        )
except KeyboardInterrupt:
    print("Keyboard interrupt")


## Save a model

In [None]:
# Set `fname_model` to a file path to write a checkpoint; leaving it as None
# (or any falsy value) skips saving.
fname_model=None
if fname_model:
    model.train()
    # Checkpoint holds the last epoch index from the training loop plus the
    # model and optimizer state dicts. The destination must be `fname_model`:
    # the original passed an undefined `fname`, raising NameError.
    torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),}, fname_model)

## Some sanity checks
#### Reconstruction quality and the effect of re-aligning the output

In [None]:
import matplotlib.pyplot as plt
from utils import eval_utils

# Pick up any on-disk edits to the evaluation helpers.
importlib.reload(eval_utils)

# Put the model in eval mode and move it to the CPU for the checks below.
model.eval().cpu() 

# Take one batch from the test loader.
test_batches = iter(loader_test)
x, y = next(test_batches)

# Build two reconstruction grids from the same batch: one with align=False
# and one with align=True.
reconstruct_grid = eval_utils.reconstruction_grid(model, x, align=False)
reconstruct_grid_aligned = eval_utils.reconstruction_grid(model, x, align=True)

# Show the two grids side by side with the axes hidden.
f, axs = plt.subplots(1, 2, figsize=(10, 10))
for ax, grid in zip(axs, (reconstruct_grid, reconstruct_grid_aligned)):
    ax.imshow(grid)
    ax.set_axis_off()

print("Left: reconstructions.")
print("Right: reconstructions where output is re-algined")

## Extract features to an array

In [None]:
# Select the GPU only when one is actually available. The original hard-coded
# 'cuda', which overrode the availability check made at the top of the
# notebook and left `device` wrong on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Collect embeddings and labels for the training and test loaders.
embeddings, labels = utils.get_model_embeddings_from_loader(model, loader, return_labels=True)
embeddings_test, labels_test = utils.get_model_embeddings_from_loader(model, loader_test, return_labels=True)