# ResMLP
#### Feedforward networks for image classification written in PyTorch

### Import and install extra libraries



In [None]:
import torch
import torch.backends.cudnn as cudnn
import csv

!pip install timm einops
from timm import optim
from timm import models

import dataset
from model import ResMLP
import learning_utils

### Setting device (CPU or GPU)

In [None]:
# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU. The chosen name is reused by the cells that follow.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device} device')

### Build model

In [None]:
# Teacher: ResNet18 assembled from timm's ResNet class (BasicBlock, [2, 2, 2, 2]).
# It is moved to the selected device explicitly, like the student below:
# torch.nn.DataParallel expects the wrapped module's parameters and buffers to
# already live on device_ids[0] before the first forward pass.
teacher_model = models.resnet.ResNet(block=models.resnet.BasicBlock, layers=[2, 2, 2, 2], num_classes=10).to(device)

# CaiT Timm implementation (alternative transformer model, kept for reference)
#transformer_model = models.cait.Cait(img_size=96, num_classes=10)

# ResMLP12 Timm implementation (alternative to the local implementation below)
#student_model = models.mlp_mixer.MlpMixer(num_classes=10, img_size=96, patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=models.mlp_mixer.ResBlock, norm_layer=models.mlp_mixer.Affine)

# Student: ResMLP12 from the local `model` module.
student_model = ResMLP(in_channels=3, image_size=96, patch_size=16, num_classes=10, dim=384, depth=12).to(device)

if device == 'cuda':
    # Wrap both networks so batches can be split across the visible GPUs.
    student_model = torch.nn.DataParallel(student_model)
    teacher_model = torch.nn.DataParallel(teacher_model)
    # Enable cuDNN's benchmark mode (algorithm auto-tuning).
    cudnn.benchmark = True

### Loss and optimizer

In [None]:
# Cross-entropy loss; the same object is handed to every train/test call
# in the training cells of this notebook.
loss_fn = torch.nn.CrossEntropyLoss()

# AdamW optimizer for the convolutional teacher.
teacher_optimizer = torch.optim.AdamW(
    teacher_model.parameters(),
    lr=1e-3,
    weight_decay=0.05,
)

# LAMB optimizer (timm implementation) for the ResMLP12 student.
student_optimizer = optim.Lamb(
    student_model.parameters(),
    lr=5e-3,
    weight_decay=0.2,
)

In [None]:
num_epochs = 50

# Baseline run: the student is trained and evaluated without the teacher, and
# the per-epoch metrics are logged to a CSV file.
# NOTE(review): the student is trained through `learning_utils.train_teacher`
# here, while the teacher itself is later trained through `learning_utils.train`
# -- confirm this is the intended helper.
baseline_log = 'performance_without_distillation.csv'

# Start a fresh log that contains only the header row.
with open(baseline_log, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['epoch', 'train loss', 'test loss',
                     'train accuracy', 'test accuracy'])

train_loss = train_accuracy = test_loss = test_accuracy = 0
for t in range(num_epochs):
    print(f'Epoch {t+1}\n-------------------------------')
    train_loss, train_accuracy = learning_utils.train_teacher(
        dataset.train_dataloader, student_model, device, loss_fn, student_optimizer)
    test_loss, test_accuracy = learning_utils.test(
        dataset.test_dataloader, student_model, device, loss_fn)

    # The log is re-opened in append mode every epoch, so rows from finished
    # epochs are already written out if a later epoch fails.
    # The epoch column holds the 0-based index `t` (the console shows t+1).
    with open(baseline_log, 'a', newline='') as f:
        writer = csv.writer(f)
        metrics = (train_loss, test_loss, train_accuracy, test_accuracy)
        writer.writerow([t] + [f'{value:f}' for value in metrics])

### Train the teacher, then train the student with distillation

In [None]:
num_epochs = 50

# Stage 1: train the teacher. Its per-epoch metrics are computed but not
# written to any file.
for t in range(num_epochs):
    print(f'Epoch {t+1}\n-------------------------------')
    train_loss, train_accuracy = learning_utils.train(
        dataset.train_dataloader, teacher_model, device, loss_fn, teacher_optimizer)
    test_loss, test_accuracy = learning_utils.test(
        dataset.test_dataloader, teacher_model, device, loss_fn)

# Stage 2: train the student with the teacher passed to `train_student`, and
# log the per-epoch metrics to a CSV file.
# NOTE(review): `student_model` and `student_optimizer` are the same objects
# used by the run without distillation; if both cells run in one session the
# distillation starts from already-trained weights -- confirm this is intended.
distillation_log = 'performance_with_distillation.csv'

# Start a fresh log that contains only the header row.
with open(distillation_log, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['epoch', 'train loss', 'test loss',
                     'train accuracy', 'test accuracy'])

train_loss = train_accuracy = test_loss = test_accuracy = 0
for t in range(num_epochs):
    print(f'Epoch {t+1}\n-------------------------------')
    train_loss, train_accuracy = learning_utils.train_student(
        dataset.train_dataloader, student_model, teacher_model, device, loss_fn, student_optimizer)
    test_loss, test_accuracy = learning_utils.test(
        dataset.test_dataloader, student_model, device, loss_fn)

    # Append one row per epoch; the epoch column holds the 0-based index `t`.
    with open(distillation_log, 'a', newline='') as f:
        writer = csv.writer(f)
        metrics = (train_loss, test_loss, train_accuracy, test_accuracy)
        writer.writerow([t] + [f'{value:f}' for value in metrics])

### Save the model


In [None]:
# When the student is wrapped in torch.nn.DataParallel, every state_dict key
# carries a "module." prefix. Save the underlying network's weights instead so
# the checkpoint loads into a plain ResMLP regardless of the training device.
if isinstance(student_model, torch.nn.DataParallel):
    model_to_save = student_model.module
else:
    model_to_save = student_model
torch.save(model_to_save.state_dict(), 'saved_model.pth')