In [None]:
# Print the NVIDIA driver's report of the GPUs visible to this runtime.
!nvidia-smi

In [None]:
pip -q install einops

In [None]:
import numpy as np
from matplotlib import pyplot as plt
import seaborn
seaborn.set()  # apply seaborn's default theme to every matplotlib figure

from tqdm.notebook import trange, tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import einops
import pickle
import os

# Run on the GPU when one is available, otherwise fall back to the CPU.
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Mount Google Drive at /content/drive (this import only exists on Colab).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Working directory for the artifacts this notebook produces. Resolve it to an
# absolute path *before* changing into it: later cells build paths with
# os.path.join(root_folder, ...), and a relative root_folder would then be
# resolved a second time against the new working directory.
root_folder = os.path.abspath("...")
os.makedirs(root_folder, exist_ok=True)
os.chdir(root_folder)

In [None]:
# Download the autograder inputs and the reference tensors used by the checks
# in the cells below.
!wget -O '/content/autograder_student.pt' '...'
!wget -O '/content/test_reference.pt' '...'

# NOTE(review): torch.load deserializes via pickle unless weights-only loading
# is in effect — only load these files from a source you trust.
test_data = torch.load('/content/test_reference.pt')
auto_grader_data = torch.load('/content/autograder_student.pt')
# Start from an empty output dict; the check cells fill it in and persist it
# through save_auto_grader_data().
auto_grader_data['output'] = {}

In [None]:
def save_auto_grader_data():
  """Persist the outputs collected so far to ``autograder.pt``.

  Only the ``'output'`` entry of the module-level ``auto_grader_data`` dict is
  written; the file is created relative to the current working directory.
  """
  payload = {'output': auto_grader_data['output']}
  torch.save(payload, 'autograder.pt')

def rel_error(x, y):
  """Return the largest element-wise relative error between two tensors.

  The error of each element is ``|x - y| / max(1e-8, |x| + |y|)``; the floor
  of ``1e-8`` keeps the division defined where both values are zero.

  Args:
    x: First tensor.
    y: Second tensor, broadcastable against ``x``.

  Returns:
    The maximum relative error as a Python float.
  """
  abs_diff = (x - y).abs()
  scale = torch.maximum(torch.tensor(1e-8), x.abs() + y.abs())
  worst = (abs_diff / scale).max()
  return worst.item()

def check_error(name, x, y, tol=1e-3):
  """Print the relative error between ``x`` and ``y`` for the check ``name``.

  Args:
    name: Label used in the printed message.
    x: Computed tensor.
    y: Reference tensor.
    tol: Errors strictly above this value are reported as too large.
  """
  error = rel_error(x, y)
  too_large = error > tol
  report = (
      f'The relative error for {name} is {error}, should be smaller than {tol}'
      if too_large else
      f'The relative error for {name} is {error}'
  )
  print(report)

def check_acc(acc, threshold):
  """Print whether ``acc`` reaches the required ``threshold``.

  Args:
    acc: Measured accuracy.
    threshold: Minimum accuracy that counts as passing (inclusive, matching
      the ">=" wording of the failure message).
  """
  # Only an accuracy strictly below the threshold is a failure.
  if acc < threshold:
    print(f'The accuracy {acc} should >= threshold accuracy {threshold}')
  else:
    print(f'The accuracy is {acc} is better than threshold accuracy {threshold}')

def patchify(images, patch_size=4):
  """Split images into patches (exercise stub — not implemented yet).

  NOTE(review): the exact input and output layouts are not visible in this
  notebook; they are fixed by the reference tensors stored under
  ``test_data['output']['patchify']`` — confirm against those.

  Args:
    images: Batch of images to split (assumed from the name — TODO confirm).
    patch_size: Patch size parameter; defaults to 4.

  Raises:
    NotImplementedError: Always, until the body is filled in.
  """
  #...
  raise NotImplementedError

def unpatchify(patches, patch_size=4):
  """Reassemble patches into images (exercise stub — not implemented yet).

  NOTE(review): presumably the inverse of ``patchify``; the exact layouts are
  fixed by ``test_data['output']['unpatchify']`` — confirm against those.

  Args:
    patches: Patches to reassemble (assumed from the name — TODO confirm).
    patch_size: Patch size parameter; defaults to 4.

  Raises:
    NotImplementedError: Always, until the body is filled in.
  """
  #...
  raise NotImplementedError

In [None]:
# For each function: compare against the reference tensors, then record the
# result on the autograder inputs and persist it.
for fn_name, fn in (('patchify', patchify), ('unpatchify', unpatchify)):
  reference_input = test_data['input'][fn_name]
  reference_output = test_data['output'][fn_name]
  check_error(fn_name, fn(reference_input), reference_output)

  grader_input = auto_grader_data['input'][fn_name]
  auto_grader_data['output'][fn_name] = fn(grader_input)
  save_auto_grader_data()

In [None]:
class Transformer(nn.Module):
  """Stack of ``nn.TransformerEncoderLayer`` blocks.

  The layers use GELU activations, ``batch_first=True`` and no dropout.

  Args:
    embedding_dim: Model width (``d_model``) of every layer.
    n_heads: Number of attention heads per layer.
    n_layers: Number of encoder layers in the stack.
    feedforward_dim: Hidden width of each layer's feed-forward sub-block.
  """

  def __init__(self, embedding_dim=256, n_heads=4, n_layers=4, feedforward_dim=1024):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.n_layers = n_layers
    self.n_heads = n_heads
    self.feedforward_dim = feedforward_dim
    self.transformer = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=self.n_heads,
            dim_feedforward=self.feedforward_dim,
            activation=F.gelu,
            batch_first=True,
            dropout=0.0,
        ),
        num_layers=n_layers,
    )

  # `forward` must be a method of the class (not a function nested inside
  # __init__), otherwise nn.Module.__call__ has nothing to dispatch to.
  def forward(self, x):
    """Run ``x`` through the encoder stack and return the result.

    Because the layers are built with ``batch_first=True``, the batch
    dimension of ``x`` comes first.
    """
    return self.transformer(x)

class ClassificationViT(nn.Module):
  """Vision-Transformer classifier skeleton; ``forward`` is an exercise stub.

  The constructor builds a ``Transformer`` backbone, a learnable class token,
  learnable position encodings, a linear patch projection and a
  LayerNorm + Linear output head.
  """
  def __init__(self, n_classes, embedding_dim=256, patch_size=4, num_patches=8):
    """Create the sub-modules and learnable parameters.

    Args:
      n_classes: Output width of the classification head.
      embedding_dim: Token width used by the backbone and the head.
      patch_size: Patch size parameter; the projection expects
        ``patch_size * patch_size * 3`` input features per patch.
      num_patches: The position encoding holds ``num_patches * num_patches + 1``
        entries (presumably patches per image side, plus one slot for the
        class token — TODO confirm).
    """
    super().__init__()
    self.patch_size = patch_size
    self.num_patches = num_patches
    self.embedding_dim = embedding_dim

    # Backbone with default heads/layers/feed-forward width.
    self.transformer = Transformer(embedding_dim)
    # Learnable parameters, initialised from a normal distribution scaled by 0.02.
    self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim) * 0.02)
    self.position_encoding = nn.Parameter(
        torch.randn(1, num_patches * num_patches + 1, embedding_dim) * 0.02
        )
    # Maps one flattened patch to the embedding width (3 is presumably the
    # number of colour channels).
    self.patch_projection = nn.Linear(patch_size * patch_size * 3, embedding_dim)

    self.output_head = nn.Sequential(
        nn.LayerNorm(embedding_dim), nn.Linear(embedding_dim, n_classes)
    )

  def forward(self, images):
    """Compute class logits for ``images`` (not implemented yet).

    NOTE(review): the expected output is fixed by the reference tensor
    ``test_data['output']['ClassificationViT.forward']``.

    Raises:
      NotImplementedError: Always, until the body is filled in.
    """
    # ...
    raise NotImplementedError



In [None]:
# Sanity check: reference weights + reference input must reproduce the
# reference output.
model = ClassificationViT(10)
model.load_state_dict(test_data['weights']['ClassificationViT'])
x = test_data['input']['ClassificationViT.forward']
y = model(x)
check_error('ClassificationViT.forward', y, test_data['output']['ClassificationViT.forward'])

# Autograder submission: run the autograder weights on the autograder input
# and store the result. There is no local reference for these outputs, so no
# check_error call here — comparing them with the test_data reference would
# mix two unrelated weight/input sets.
model.load_state_dict(auto_grader_data['weights']['ClassificationViT'])
x = auto_grader_data['input']['ClassificationViT.forward']
y = model(x)
auto_grader_data['output']['ClassificationViT.forward'] = y
save_auto_grader_data()

In [None]:
# Per-channel statistics handed to Normalize as (mean, std).
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: padded random crop and random horizontal flip for
# augmentation, then tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Evaluation pipeline: no augmentation.
transform_test = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

batch_size = 128
data_root = '/content/data'
loader_kwargs = dict(batch_size=batch_size, num_workers=2)

trainset = torchvision.datasets.CIFAR10(
    root=data_root, train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)

testset = torchvision.datasets.CIFAR10(
    root=data_root, train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

In [None]:
# Train ClassificationViT from scratch on CIFAR-10 and track the best
# validation accuracy across epochs.
model = ClassificationViT(10)
model.to(torch_device)

optimizer = optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.95), weight_decay=1e-9)

total_steps = 0
num_epochs = 10
train_logfreq = 100  # record train loss/accuracy every this many steps
losses = []
train_acc = []
all_val_acc = []
best_val_acc = 0

# The progress bar must be bound to the same name that set_postfix is called
# on at the end of every epoch.
epoch_iterator = trange(num_epochs)
for epoch in epoch_iterator:
  data_iterator = tqdm(trainloader)
  for x, y in data_iterator:
    total_steps += 1
    x, y = x.to(torch_device), y.to(torch_device)
    logits = model(x)
    loss = torch.mean(F.cross_entropy(logits, y))
    accuracy = torch.mean((torch.argmax(logits, dim=-1) == y).float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    data_iterator.set_postfix(loss=loss.item(), train_acc=accuracy.item())

    if total_steps % train_logfreq == 0:
      losses.append(loss.item())
      train_acc.append(accuracy.item())

  # Validation pass: mean of the per-batch accuracies over the test loader.
  val_acc = []
  model.eval()
  for x, y in testloader:
    x, y = x.to(torch_device), y.to(torch_device)
    with torch.no_grad():
      logits = model(x)
    accuracy = torch.mean((torch.argmax(logits, dim=-1) == y).float())
    val_acc.append(accuracy.item())
  model.train()

  all_val_acc.append(np.mean(val_acc))
  if np.mean(val_acc) > best_val_acc:
    best_val_acc = np.mean(val_acc)

  epoch_iterator.set_postfix(val_acc=np.mean(val_acc), best_val_acc=best_val_acc)

plt.plot(losses)
plt.title('Train Loss')
plt.figure()
plt.plot(train_acc)
plt.title('Train Accuracy')
plt.figure()
plt.plot(all_val_acc)
plt.title('Val Accuracy')

In [None]:
# Record the best validation accuracy of the ViT run for the autograder,
# persist it, and compare it with the 0.65 threshold.
auto_grader_data['output']['vit_acc'] = best_val_acc
save_auto_grader_data()
check_acc(best_val_acc, threshold=0.65)

In [None]:
def index_sequence(x, ids):
  """Select entries of ``x`` along dimension 1 at the positions in ``ids``.

  Args:
    x: Tensor to index. When it is 3-dimensional, ``ids`` is expanded across
      the last dimension so that whole feature vectors are selected.
    ids: Integer index tensor with two dimensions.

  Returns:
    ``torch.take_along_dim(x, ids, dim=1)`` with the (possibly expanded) ids.
  """
  if x.ndim == 3:
    # Repeat every index across the trailing feature axis of x.
    feature_dim = x.size(-1)
    ids = ids[..., None].expand(-1, -1, feature_dim)
  return x.take_along_dim(ids, dim=1)

def random_masking(x, keep_length, ids_shuffle):
  """Split a sequence into kept and masked parts (exercise stub).

  NOTE(review): the check cell unpacks the result as
  ``(kept, mask, ids_restore)``; the exact tensors are fixed by
  ``test_data['output']['random_masking']`` — confirm against those.

  Args:
    x: Input sequence tensor.
    keep_length: Number of entries to keep (assumed from the name — TODO confirm).
    ids_shuffle: Shuffled indices that drive the selection (assumed from the
      name — TODO confirm).

  Raises:
    NotImplementedError: Always, until the body is filled in.
  """
  # ...
  raise NotImplementedError

def restore_masked(kept_x, masked_x, ids_restore):
  """Recombine kept and masked entries into one sequence (exercise stub).

  NOTE(review): presumably the inverse of ``random_masking``; the expected
  result is fixed by ``test_data['output']['restore_masked']``.

  Args:
    kept_x: Entries that were kept.
    masked_x: Entries standing in for the masked positions.
    ids_restore: Indices used to put the entries back in order (assumed from
      the name — TODO confirm).

  Raises:
    NotImplementedError: Always, until the body is filled in.
  """
  # ...
  raise NotImplementedError

In [None]:
# --- random_masking: compare against the reference tensors ---
x, ids_shuffle = test_data['input']['random_masking']
kept, mask, ids_restore = random_masking(x, 4, ids_shuffle)
kept_t, mask_t, ids_restore_t = test_data['output']['random_masking']
check_error('random_masking: kept', kept, kept_t)
check_error('random_masking: mask', mask, mask_t)
check_error('random_masking: ids_restore', ids_restore, ids_restore_t)

# --- random_masking: record the autograder output ---
x, ids_shuffle = auto_grader_data['input']['random_masking']
kept, mask, ids_restore = random_masking(x, 4, ids_shuffle)
auto_grader_data['output']['random_masking'] = (kept, mask, ids_restore)
save_auto_grader_data()

# --- restore_masked: compare against the reference tensor ---
kept_x, masked_x, ids_restore = test_data['input']['restore_masked']
restored = restore_masked(kept_x, masked_x, ids_restore)
check_error('restore_masked', restored, test_data['output']['restore_masked'])

# --- restore_masked: record the autograder output ---
# Store the value restore_masked actually returned, not the leftover
# random_masking results from above.
kept_x, masked_x, ids_restore = auto_grader_data['input']['restore_masked']
restored = restore_masked(kept_x, masked_x, ids_restore)
auto_grader_data['output']['restore_masked'] = restored
save_auto_grader_data()

In [None]:
class MaskedAutoEncoder(nn.Module):
  """Masked-autoencoder skeleton around a given encoder and decoder.

  The constructor creates the input/output projections, class token, position
  encodings and mask token. ``forward_encoder``, ``forward_decoder`` and
  ``forward_encoder_representation`` are exercise stubs.

  Args:
    encoder: Module applied to the encoder token sequence.
    decoder: Module applied to the decoder token sequence.
    encoder_embedding_dim: Token width on the encoder side.
    decoder_embedding_dim: Token width on the decoder side.
    patch_size: Patch size parameter; the patch projections use
      ``patch_size * patch_size * 3`` features.
    num_patches: The position encodings hold ``num_patches * num_patches``
      entries.
    mask_ratio: Fraction of the ``num_patches * num_patches`` entries that is
      masked.
  """

  def __init__(self, encoder, decoder, encoder_embedding_dim=256,
               decoder_embedding_dim=128, patch_size=4, num_patches=8,
               mask_ratio=0.75):
    super().__init__()
    self.encoder_embedding_dim = encoder_embedding_dim
    self.decoder_embedding_dim = decoder_embedding_dim
    self.patch_size = patch_size
    self.num_patches = num_patches
    self.mask_ratio = mask_ratio

    # The masked count is truncated to an int; the kept count is the remainder.
    self.masked_length = int(num_patches * num_patches * mask_ratio)
    self.keep_length = num_patches * num_patches - self.masked_length

    self.encoder = encoder
    self.decoder = decoder

    self.encoder_input_projection = nn.Linear(patch_size * patch_size * 3, encoder_embedding_dim)
    self.decoder_input_projection = nn.Linear(encoder_embedding_dim, decoder_embedding_dim)
    self.decoder_output_projection = nn.Linear(decoder_embedding_dim, patch_size * patch_size * 3)
    # Learnable tokens/encodings, initialised from a normal distribution scaled by 0.02.
    self.cls_token = nn.Parameter(torch.randn(1, 1, encoder_embedding_dim) * 0.02)
    self.encoder_position_encoding = nn.Parameter(torch.randn(1, num_patches * num_patches, encoder_embedding_dim) * 0.02)
    self.decoder_position_encoding = nn.Parameter(torch.randn(1, num_patches * num_patches, decoder_embedding_dim) * 0.02)
    self.masked_tokens = nn.Parameter(torch.randn(1, 1, decoder_embedding_dim) * 0.02)

  def forward_encoder(self, images, ids_shuffle=None):
    """Encode ``images`` (body not implemented yet).

    Args:
      images: Input batch; its first dimension is the batch size.
      ids_shuffle: Optional per-sample permutation of the
        ``num_patches * num_patches`` positions. A random one is drawn on the
        device of ``images`` when omitted.

    Returns:
      ``forward`` unpacks the result as ``(encoder_output, mask, ids_restore)``.

    Raises:
      NotImplementedError: Always, until the body is filled in.
    """
    batch_size = images.shape[0]
    if ids_shuffle is None:
      ids_shuffle = torch.argsort(
          torch.rand(
              (batch_size, self.num_patches * self.num_patches),
              device=images.device
          ),
          dim=1
      )
    # The rest of the method runs whether or not ids_shuffle was supplied, so
    # it lives outside the `if` above.
    # ...
    raise NotImplementedError

  def forward_decoder(self, encoder_embeddings, ids_restore):
    """Decode ``encoder_embeddings`` (not implemented yet).

    Raises:
      NotImplementedError: Always, until the body is filled in.
    """
    # ...
    raise NotImplementedError

  def forward(self, images):
    """Run encoder then decoder; return ``(decoder_output, mask)``."""
    encoder_output, mask, ids_restore = self.forward_encoder(images)
    decoder_output = self.forward_decoder(encoder_output, ids_restore)
    return decoder_output, mask

  def forward_encoder_representation(self, images):
    """Return an encoder representation of ``images`` (not implemented yet).

    Raises:
      NotImplementedError: Always, until the body is filled in.
    """
    #...
    raise NotImplementedError

In [None]:
# --- Checks against the reference tensors ---
model = MaskedAutoEncoder(
    Transformer(embedding_dim = 256, n_layers= 4),
    Transformer(embedding_dim = 128, n_layers= 2),
)

model.load_state_dict(test_data['weights']['MaskedAutoEncoder'])
images, ids_shuffle = test_data['input']['MaskedAutoEncoder.forward_encoder']
encoder_embeddings_t, mask_t, ids_restore_t = test_data['output']['MaskedAutoEncoder.forward_encoder']
encoder_embeddings, mask, ids_restore = model.forward_encoder(
    images, ids_shuffle
    )

check_error(
    'MaskedAutoEncoder.forward_encoder: encoder_embeddings',
    encoder_embeddings, encoder_embeddings_t
    )
check_error(
    'MaskedAutoEncoder.forward_encoder: mask',
    mask, mask_t
    )
check_error(
    'MaskedAutoEncoder.forward_encoder: ids_restore',
    ids_restore, ids_restore_t
    )

encoder_embeddings, ids_restore = test_data['input']['MaskedAutoEncoder.forward_decoder']
decoder_output_t = test_data['output']['MaskedAutoEncoder.forward_decoder']
decoder_output = model.forward_decoder(encoder_embeddings, ids_restore)
check_error(
    'MaskedAutoEncoder.forward_decoder',
    decoder_output,
    decoder_output_t
)

images = test_data['input']['MaskedAutoEncoder.forward_encoder_representation']
encoder_representations_t = test_data['output']['MaskedAutoEncoder.forward_encoder_representation']
encoder_representations = model.forward_encoder_representation(images)
check_error(
    'MaskedAutoEncoder.forward_encoder_representation',
    encoder_representations,
    encoder_representations_t
)

# --- Autograder outputs ---
model = MaskedAutoEncoder(
    Transformer(embedding_dim = 256, n_layers= 4),
    Transformer(embedding_dim = 128, n_layers= 2),
)
model.load_state_dict(auto_grader_data['weights']['MaskedAutoEncoder'])
images, ids_shuffle = auto_grader_data['input']['MaskedAutoEncoder.forward_encoder']
auto_grader_data['output']['MaskedAutoEncoder.forward_encoder'] = model.forward_encoder(
    images, ids_shuffle
    )

# Each method's result goes under its own key; the decoder output must not
# overwrite the encoder entry stored just above.
encoder_embeddings, ids_restore = auto_grader_data['input']['MaskedAutoEncoder.forward_decoder']
auto_grader_data['output']['MaskedAutoEncoder.forward_decoder'] = model.forward_decoder(encoder_embeddings, ids_restore)

images = auto_grader_data['input']['MaskedAutoEncoder.forward_encoder_representation']
auto_grader_data['output']['MaskedAutoEncoder.forward_encoder_representation'] = model.forward_encoder_representation(images)
save_auto_grader_data()

In [None]:
# Pre-train the masked autoencoder with a reconstruction loss that is
# averaged over the entries selected by `mask`.
model = MaskedAutoEncoder(
    Transformer(embedding_dim=256, n_layers=4),
    Transformer(embedding_dim=128, n_layers=2),
)
model.to(torch_device)

optimizer = optim.AdamW(
    model.parameters(), lr=1e-4, betas=(0.9, 0.95), weight_decay=0.05
)

total_steps = 0
num_epochs = 20
train_logfreq = 100  # record the loss every this many steps

losses = []

epoch_iterator = trange(num_epochs)
for epoch in epoch_iterator:
  data_iterator = tqdm(trainloader)
  for batch_images, _ in data_iterator:
    total_steps += 1
    batch_images = batch_images.to(torch_device)
    target_patches = patchify(batch_images)
    predicted_patches, mask = model(batch_images)

    # Mean squared error over the last dimension, then a mask-weighted average.
    per_patch_mse = (target_patches - predicted_patches).square().mean(dim=-1)
    loss = (per_patch_mse * mask).sum() / mask.sum()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    data_iterator.set_postfix(loss=loss_value)
    if total_steps % train_logfreq == 0:
      losses.append(loss_value)

  # Checkpoint after every epoch (the file is overwritten each time).
  torch.save(model.state_dict(), os.path.join(root_folder, "mae_pretrained.pt"))

plt.plot(losses)
plt.title('MAE Train Loss')

In [None]:
class ClassificationMAE(nn.Module):
  """Classifier built on a masked autoencoder; ``forward`` is an exercise stub.

  Holds the given ``mae`` module and a LayerNorm + Linear output head.
  """
  def __init__(self, n_classes, mae, embedding_dim=256, detach=False):
    """Store the autoencoder and create the classification head.

    Args:
      n_classes: Output width of the classification head.
      mae: Autoencoder module to build on.
      embedding_dim: Input width of the head.
      detach: Stored flag; presumably selects whether gradients are blocked
        from flowing into ``mae`` (linear probing vs. fine-tuning) — TODO
        confirm once ``forward`` is implemented.
    """
    super().__init__()
    self.embedding_dim = embedding_dim
    self.mae = mae
    self.output_head = nn.Sequential(
        nn.LayerNorm(embedding_dim), nn.Linear(embedding_dim, n_classes)
    )
    self.detach = detach

  def forward(self, images):
    """Compute class logits for ``images`` (not implemented yet).

    NOTE(review): the expected output is fixed by the reference tensor
    ``test_data['output']['ClassificationMAE.forward']``.

    Raises:
      NotImplementedError: Always, until the body is filled in.
    """
    #...
    raise NotImplementedError

In [None]:
# --- Check against the reference tensors ---
reference_mae = MaskedAutoEncoder(
    Transformer(embedding_dim=256, n_layers=4),
    Transformer(embedding_dim=128, n_layers=2),
)
model = ClassificationMAE(10, reference_mae)
model.load_state_dict(test_data['weights']['ClassificationMAE'])

predicted_logits = model(test_data['input']['ClassificationMAE.forward'])
reference_logits = test_data['output']['ClassificationMAE.forward']
check_error('ClassificationMAE.forward', predicted_logits, reference_logits)

# --- Autograder output, computed with a freshly built model ---
grader_mae = MaskedAutoEncoder(
    Transformer(embedding_dim=256, n_layers=4),
    Transformer(embedding_dim=128, n_layers=2),
)
model = ClassificationMAE(10, grader_mae)
model.load_state_dict(auto_grader_data['weights']['ClassificationMAE'])

grader_input = auto_grader_data['input']['ClassificationMAE.forward']
auto_grader_data['output']['ClassificationMAE.forward'] = model(grader_input)
save_auto_grader_data()

In [None]:
# Rebuild the autoencoder and restore the pre-trained weights.
mae = MaskedAutoEncoder(
    Transformer(embedding_dim = 256, n_layers= 4),
    Transformer(embedding_dim = 128, n_layers= 2),
)
# The checkpoint is written from a model that may live on the GPU. Map it to
# the CPU while loading so this cell also works on a CPU-only runtime; the
# classifier cells move the model to `torch_device` afterwards.
mae.load_state_dict(
    torch.load(os.path.join(root_folder, "mae_pretrained.pt"), map_location='cpu')
)

In [None]:
# Linear-probe training: ClassificationMAE built with detach=True on top of
# the pre-trained autoencoder.
model = ClassificationMAE(10, mae, detach=True)
model.to(torch_device)

optimizer = optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-9)

total_steps = 0
num_epochs = 20
train_logfreq = 100  # record train loss/accuracy every this many steps
losses = []
train_acc = []
all_val_acc = []
best_val_acc = 0

epoch_iterator = trange(num_epochs)
for epoch in epoch_iterator:
  data_iterator = tqdm(trainloader)
  for x, y in data_iterator:
    # Increment the counter (`=+ 1` would reset it to 1 on every step, so the
    # periodic logging below would never fire).
    total_steps += 1
    x, y = x.to(torch_device), y.to(torch_device)
    logits = model(x)
    loss = torch.mean(F.cross_entropy(logits, y))
    accuracy = torch.mean((torch.argmax(logits, dim=-1) == y).float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    data_iterator.set_postfix(loss=loss.item(), train_acc=accuracy.item())

    if total_steps % train_logfreq == 0:
      losses.append(loss.item())
      train_acc.append(accuracy.item())

  # Validation pass: mean of the per-batch accuracies over the test loader.
  val_acc = []
  model.eval()
  for x, y in testloader:
    x, y = x.to(torch_device), y.to(torch_device)
    with torch.no_grad():
      logits = model(x)
    accuracy = torch.mean((torch.argmax(logits, dim=-1) == y).float())
    val_acc.append(accuracy.item())

  model.train()

  all_val_acc.append(np.mean(val_acc))

  if np.mean(val_acc) > best_val_acc:
    best_val_acc = np.mean(val_acc)

  epoch_iterator.set_postfix(val_acc=np.mean(val_acc), best_val_acc=best_val_acc)

plt.plot(losses)
plt.title('Linear Classification Train Loss')
plt.figure()
plt.plot(train_acc)
plt.title('Linear Classification Train Accuracy')
plt.figure()
plt.plot(all_val_acc)
plt.title('Linear Classification Val Accuracy')

In [None]:
# Record the best validation accuracy of the linear-probe run for the
# autograder, persist it, and compare it with the 0.30 threshold.
auto_grader_data['output']['mae_linear_acc'] = best_val_acc
save_auto_grader_data()
check_acc(best_val_acc, threshold=0.30)

In [None]:
# Fine-tuning: ClassificationMAE built with detach=False on top of the
# pre-trained autoencoder.
model = ClassificationMAE(10, mae, detach=False)
model.to(torch_device)

optimizer = optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.95), weight_decay=1e-9)

total_steps = 0
num_epochs = 20
train_logfreq = 100  # record train loss/accuracy every this many steps
losses = []
train_acc = []
all_val_acc = []
best_val_acc = 0

epoch_iterator = trange(num_epochs)
for epoch in epoch_iterator:
  data_iterator = tqdm(trainloader)
  for x, y in data_iterator:
    total_steps += 1
    x, y = x.to(torch_device), y.to(torch_device)
    logits = model(x)
    loss = torch.mean(F.cross_entropy(logits, y))
    accuracy = torch.mean((torch.argmax(logits, dim=-1) == y).float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    data_iterator.set_postfix(loss=loss.item(), train_acc=accuracy.item())

    if total_steps % train_logfreq == 0:
      losses.append(loss.item())
      train_acc.append(accuracy.item())

  # Reset the per-epoch list under the same name the loop appends to;
  # otherwise accuracies from earlier epochs (or an earlier cell) would be
  # averaged into this epoch's validation score.
  val_acc = []
  model.eval()
  for x, y in testloader:
    x, y = x.to(torch_device), y.to(torch_device)
    with torch.no_grad():
      logits = model(x)
    accuracy = torch.mean((torch.argmax(logits, dim=-1) == y).float())
    val_acc.append(accuracy.item())
  model.train()

  all_val_acc.append(np.mean(val_acc))

  if np.mean(val_acc) > best_val_acc:
    best_val_acc = np.mean(val_acc)

  epoch_iterator.set_postfix(val_acc=np.mean(val_acc), best_val_acc=best_val_acc)

plt.plot(losses)
plt.title('Finetune Classification Train Loss')
plt.figure()
plt.plot(train_acc)
plt.title('Finetune Classification Train Accuracy')
plt.figure()
plt.plot(all_val_acc)
plt.title('Finetune Classification Val Accuracy')

In [None]:
# Record the best validation accuracy of the fine-tuning run for the
# autograder, persist it, and compare it with the 0.70 threshold.
auto_grader_data['output']['mae_finetune_acc'] = best_val_acc
save_auto_grader_data()
check_acc(best_val_acc, threshold=0.70)