# Meta-Learning with PyTorch Lightning

## Imports

In [None]:
from data_loaders.mnist_loaders import get_mnist_loaders
from models.meta.MetaModel import MetaModel

## Load the training and validation sets

In [None]:
# Build the MNIST loaders and keep only the first two (training, validation);
# any further loaders returned by the helper are discarded.
train_loader, val_loader, *_ = get_mnist_loaders(batch_size=64, num_workers=0)

## Define input/output shapes and search configuration

In [None]:
# Shape of one input sample and size of the model output.
# Presumably (channels, height, width) for MNIST and 10 digit classes — confirm.
input_shape, output_shape = (1, 28, 28), 10

# Search budget forwarded to MetaModel; exact semantics live in MetaModel (TODO confirm).
max_layers, n_calls = 5, 5

## Instantiate MetaModel

In [None]:
# Bundle the I/O shapes, data loaders and search budget into a MetaModel.
model = MetaModel(
    input_shape,
    output_shape,
    train_loader,
    val_loader,
    max_layers=max_layers,
    n_calls=n_calls,
)

## Optimize the model architecture

In [None]:
# Run MetaModel's optimization; the returned parameters are later passed to
# model.generate_model to build a concrete model.
best_params = model.optimize()

## Build the best model from `best_params`

In [None]:
# Build a concrete model from the parameters returned by model.optimize().
best_model = model.generate_model(best_params)

## Train the best model

In [None]:
# Train the generated model via MetaModel; any return value is not captured here.
model.train_model(best_model)

## Save the best MetaModel checkpoint

In [None]:
import os

# Directory that holds saved checkpoints.
checkpoint_dir = "checkpoints"
# exist_ok=True creates the directory only if needed and, unlike a separate
# os.path.exists() check, cannot race with another process creating it.
os.makedirs(checkpoint_dir, exist_ok=True)

# NOTE(review): this saves the MetaModel `model`, not `best_model` trained above;
# confirm that MetaModel.save persists the trained best architecture.
model.save("checkpoints/best_metamodel.pth")