**Check Usage of GPUs**

In [None]:
# List the visible GPUs with their current memory use and utilization (the notebook later selects device 0).
!nvidia-smi

## 0. Imports

In [None]:
import utils

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchrs.datasets import RESISC45, AID

from sklearn import metrics
from sklearn import decomposition
from sklearn import manifold
from tqdm.notebook import trange, tqdm
import matplotlib.pyplot as plt
import numpy as np

import copy
import random
import time

import sys
import os
import requests

sys.path.append('./mae')
sys.path.append('/data/ek58_data/')
!pip3 install timm==0.4.5  # 0.3.2 does not work in Colab
    
import models_mae
import cls_mlp

In [None]:
# Make GPU 0 the default CUDA device. Skipped on CPU-only machines, where the training
# setup falls back to torch.device('cpu') instead of crashing here.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

In [None]:
# Display the active CUDA device index, or 'cpu' when CUDA is unavailable (the bare
# torch.cuda.current_device() call raises on CPU-only machines).
torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'

## 1. Data Processing

In [None]:
# Dataset to train on — one of 'MNIST', 'CIFAR10', 'RESISC45', 'AID'.
dataset = 'RESISC45'

In [None]:
SEED = 1234

# Seed the Python, NumPy and PyTorch (CPU and current CUDA device) RNGs, in that order,
# and request deterministic cuDNN kernels so repeated runs behave the same.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(SEED)
del _seed_rng
torch.backends.cudnn.deterministic = True

ROOT = '.data'

# Shared preprocessing pipeline, applied in order:
#   1. ToTensor  — image -> float tensor (8-bit images are scaled to [0, 1]);
#   2. Normalize — per-channel (v - 0.5) / 0.5 over three channels, i.e. [0, 1] -> [-1, 1];
#   3. Resize    — every image becomes 224 x 224.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    transforms.Resize([224, 224]),
])

# Build train_data / test_data and the class-name list `classes` for the chosen dataset.
if dataset == 'MNIST':
    # MNIST images have one channel, but `transform` normalizes three channels
    # (Normalize with a 3-element mean fails on a 1-channel tensor) and the MAE model
    # presumably expects RGB input — replicate the grey channel before `transform`.
    mnist_transform = transforms.Compose([transforms.Grayscale(num_output_channels=3),
                                          transform])
    train_data = datasets.MNIST(root=ROOT,
                                train=True,
                                download=True,
                                transform=mnist_transform)

    test_data = datasets.MNIST(root=ROOT,
                               train=False,
                               download=True,
                               transform=mnist_transform)
    # `classes` is required later to size the classifier output.
    classes = train_data.classes
elif dataset == 'CIFAR10':
    train_data = datasets.CIFAR10(root=ROOT,
                                  train=True,
                                  download=True,
                                  transform=transform)

    test_data = datasets.CIFAR10(root=ROOT,
                                 train=False,
                                 download=True,
                                 transform=transform)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
elif dataset == 'RESISC45':
    ROOT = '/data/ek58_data/data/NWPU-RESISC45'
    resisc45_data = RESISC45(root=ROOT,
                             transform=transform)
    classes = resisc45_data.classes
    # Fixed-size random train/test split; a dedicated seeded generator keeps the split
    # identical on every re-run, independent of any RNG use since the seeding above.
    train_data, test_data = torch.utils.data.random_split(
        resisc45_data, [27000, 4500], generator=torch.Generator().manual_seed(SEED))
elif dataset == 'AID':
    ROOT = '/data/ek58_data/data/AID'
    aid_data = AID(root=ROOT,
                   transform=transform)
    classes = aid_data.classes
    # Same reproducible split strategy as RESISC45.
    train_data, test_data = torch.utils.data.random_split(
        aid_data, [8500, 1500], generator=torch.Generator().manual_seed(SEED))
else:
    # Fail here instead of with a NameError on train_data much later.
    raise ValueError(f"Unknown dataset {dataset!r}; expected 'MNIST', 'CIFAR10', 'RESISC45' or 'AID'")

# Mini-batch size shared by both loaders (16 and 1 were also tried).
BATCH_SIZE = 64

# Training batches are reshuffled on every pass; the test loader keeps dataset order.
train_iterator = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_iterator = data.DataLoader(test_data, batch_size=BATCH_SIZE)

## 2. Load a Pre-trained MAE Model

In [None]:
def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    """Instantiate an MAE architecture and load checkpoint weights into it.

    Args:
        chkpt_dir: path to the checkpoint *file* (despite the name); it must contain
            a 'model' entry holding a state dict.
        arch: name of a constructor in ``models_mae`` (e.g. 'mae_vit_base_patch16').

    Returns:
        The model with the checkpoint weights loaded (tensors mapped to CPU).

    Loading is non-strict: missing/unexpected keys do not raise, they are only
    reported by printing the result of ``load_state_dict``.
    """
    build = getattr(models_mae, arch)
    mae = build()
    # NOTE: torch.load unpickles the file — only load trusted checkpoints.
    weights = torch.load(chkpt_dir, map_location='cpu')['model']
    print(mae.load_state_dict(weights, strict=False))
    return mae

In [None]:
# # download checkpoint if not exist
# !wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_base.pth

In [None]:
# Pre-trained MAE checkpoint; the file must match the architecture passed to
# prepare_model. Current choice: ViT-Base/16 pre-training weights.
# Other checkpoints tried, with the matching arch:
#   pre_trained/mae_visualize_vit_large.pth     -> mae_vit_large_patch16
#   pre_trained/mae_visualize_vit_base.pth      -> mae_vit_base_patch16
#   pre_trained/mae_pretrain_vit_large_full.pth -> mae_vit_large_patch16
chkpt_dir = 'pre_trained/mae_pretrain_vit_base_full.pth'

model_mae = prepare_model(chkpt_dir, arch='mae_vit_base_patch16')
print('Model loaded.')

In [None]:
# Masking ratio passed to cls_mlp for the MAE forward pass.
# Other ratios tried: 0, 0.1, 0.2, 0.25, 0.4, 0.5, 0.7, 0.75, 0.9, 0.99.
mask_ratio = 0.8

# Size the classifier from a dry run: push one all-zeros sample (batch of 1, on CPU)
# through cls_mlp.get_decoders_input and count the elements it produces.
# Assumes the returned x keeps the batch of 1, so numel() is the per-sample size — confirm.
input_shape = train_data[0][0].shape
x, ids_restore, num_unmasked = cls_mlp.get_decoders_input(
    model_mae, torch.zeros(1, *input_shape), mask_ratio, 'cpu')
input_dim = x.numel()
output_dim = len(classes)

model = cls_mlp.CLS_MLP(input_dim, output_dim, 'cpu')

In [None]:
# Summary of the classifier head: flattened input size, number of classes, layers.
print(f'input dim: {input_dim}')
print(f'output dim: {output_dim}')
print(f'MLP Model: {model}')

## 3. Training the Model

In [None]:
# Place both networks on the target device (GPU when available) *before* building the
# optimizer, as the torch.optim docs recommend, so it is constructed over the final
# parameters.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model_mae = model_mae.to(device)

# Only the MLP head is optimized; model_mae's parameters are not given to the optimizer.
optimizer = optim.Adam(model.parameters())
# optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss().to(device)

In [None]:
# Number of passes over the training set (300, 100, 10 and 3 were also tried).
EPOCHS = 30

# Run tag used in the checkpoint file name and in the log file names.
model_str = f'CLS_MLP_{dataset}_mask_{mask_ratio}_ep_{EPOCHS}'
save_model_name = f'models/CLS_MLP/model_{model_str}.pt'
# Initial value for best-loss tracking: +inf, so any real loss compares lower.
best_valid_loss = float('inf')

In [None]:
# Train for EPOCHS epochs, evaluating on the held-out split after each one, recording
# the per-epoch curves, and checkpointing the weights with the lowest held-out loss.
# NOTE(review): there is no separate validation split, so checkpoint selection uses the
# test split — the best checkpoint's test accuracy is an optimistic estimate.
os.makedirs(os.path.dirname(save_model_name) or '.', exist_ok=True)

train_loss_vec = []; train_acc_vec = []
test_loss_vec = []; test_acc_vec = []
for epoch in trange(EPOCHS):

    start_time = time.monotonic()

    train_loss, train_acc = cls_mlp.train(model, model_mae, train_iterator, optimizer, criterion, mask_ratio, device)
    test_loss, test_acc = cls_mlp.evaluate(model, model_mae, test_iterator, criterion, mask_ratio, device)

    train_loss_vec.append(train_loss); train_acc_vec.append(train_acc)
    test_loss_vec.append(test_loss); test_acc_vec.append(test_acc)

    end_time = time.monotonic()

    # Save the MLP head whenever the held-out loss improves (previously best_valid_loss
    # and save_model_name were set up but never used, so nothing was ever saved).
    if test_loss < best_valid_loss:
        best_valid_loss = test_loss
        torch.save(model.state_dict(), save_model_name)

    epoch_mins, epoch_secs = utils.epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\tTest  Loss: {test_loss:.3f} | Test  Acc: {test_acc*100:.2f}%')

## 4. Examining the Model

### Save Model Logs

In [None]:
# Write each per-epoch metric curve to its own text file (one value per line, 4 decimals)
# under logs/<mask_ratio>/, creating that directory first — np.savetxt fails if it is missing.
log_dir = 'logs/' + str(mask_ratio)
os.makedirs(log_dir, exist_ok=True)
for metric_name, metric_values in (('train_loss', train_loss_vec),
                                   ('test_loss', test_loss_vec),
                                   ('train_acc', train_acc_vec),
                                   ('test_acc', test_acc_vec)):
    np.savetxt(log_dir + '/' + model_str + '_' + metric_name + '.log', metric_values, fmt='%1.4f')