# Tutorial on Training a Model with the Masked Autoencoder (MAE) Framework using Lightning (PyTorch)

## Training

In [None]:
import torch
from lightning import Trainer
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from mae import MaskedAutoencoderLIT

# Preprocessing pipeline: convert each image to a tensor, then normalize it
# with mean 0.5 and std 0.5 (a single value, matching in_chans=1 below).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split stored under ./data (download=True fetches it if needed).
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 512 samples; the same batch size is passed to the
# model constructor below, so keep the two values in sync.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=512,
    shuffle=True,
)

# Build the Lightning wrapper around the Masked Autoencoder.
# NOTE(review): the meaning of each argument is defined in mae.py, which is
# not shown here; the grouping below is by apparent purpose only.
model = MaskedAutoencoderLIT(
    # Architecture: model size preset and number of input channels.
    size='base',
    in_chans=1,
    # Optimizer settings.
    base_lr=3e-5,
    weight_decay=0.05,
    betas=(0.9, 0.95),
    # Schedule / hardware hints. batch_size and num_gpus should match the
    # DataLoader and Trainer configuration used in this notebook
    # (presumably they are used to scale the learning rate; confirm in mae.py).
    warmup_epochs=1,
    batch_size=512,
    num_gpus=1,
)

# Train the model for 10 epochs on a single GPU.
# The `gpus=` Trainer argument was removed in Lightning 2.0 (the version that
# provides `from lightning import Trainer`); the device selection is now
# expressed with `accelerator` and `devices`.
trainer = Trainer(max_epochs=10, accelerator="gpu", devices=1)
trainer.fit(model, train_dataloader)

# Save only the learned weights (the state dict), not the whole module, so
# they can be reloaded into a freshly constructed model later.
torch.save(model.state_dict(), 'masked_autoencoder.pt')

## Convert to ViT

In [None]:
from mae_to_vit import get_vit_from_mae

# Convert the trained MAE weights into a ViT for linear probing
# (global pooling disabled).
vit_for_linear_probe = get_vit_from_mae(pretrained_model=model.state_dict(), global_pool=False)

# Convert the trained MAE weights into a ViT for fine-tuning
# (global pooling enabled). Stored under its own name so it does not
# overwrite the linear-probing model created above.
vit_for_fine_tuning = get_vit_from_mae(pretrained_model=model.state_dict(), global_pool=True)