# ResMLP
#### Feedforward networks for image classification written in PyTorch

### Import and install extra libraries



In [None]:
import torch
import torch.backends.cudnn as cudnn
import csv

!pip install timm einops
from timm import optim
from timm import models

import dataset
from model import ResMLP
import learning_utils

### Setting device (CPU or GPU)

In [None]:
# Pick the compute device: use the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device} device')

### Build model

In [None]:
# Alternative architectures from timm are kept below for quick experiments.
# NOTE: unlike the local ResMLP line, these alternatives are not moved to
# `device`; add `.to(device)` when enabling one of them.

# RegNet (timm implementation)
#model = models.regnet.RegNet(cfg=models.regnet.model_cfgs['regnetx_002'], num_classes=10, output_stride=16)

# CaiT (timm implementation)
#model = models.cait.Cait(img_size=96, num_classes=10)

# ResMLP (timm implementation)
#model = models.mlp_mixer.MlpMixer(num_classes=10, img_size=96, patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=models.mlp_mixer.ResBlock, norm_layer=models.mlp_mixer.Affine)

# ResMLP (timm pretrained implementation)
#model = models.mlp_mixer.resmlp_12_224(pretrained = True)

# ResMLP (local implementation): 96x96 RGB inputs, 16x16 patches, 10 classes.
model = ResMLP(
    in_channels=3,
    image_size=96,
    patch_size=16,
    num_classes=10,
    dim=384,
    depth=12,
    mlp_dim=384*4,
)
model = model.to(device)

# GPU-only setup: turn on cuDNN benchmark mode and wrap the model in
# DataParallel. After wrapping, the underlying network is `model.module`.
if device == 'cuda':
    cudnn.benchmark = True
    model = torch.nn.DataParallel(model)

### Loss and optimizer

In [None]:
# Loss function: cross-entropy over the class logits.
loss_fn = torch.nn.CrossEntropyLoss()

# Optimizer hyperparameters
learning_rate = 5e-3
weight_decay = 0.2

# LAMB optimizer (from timm) over all model parameters.
optimizer = optim.Lamb(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# AdamW optimizer (alternative)
#optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

### Execution

In [None]:
# Train for `num_epochs` epochs, evaluating on the test set after each one and
# logging the per-epoch metrics to 'performance.csv'.
num_epochs = 100

# Start a fresh log file that contains only the header row.
with open('performance.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['epoch', 'train loss', 'test loss',
                     'train accuracy', 'test accuracy'])

for t in range(num_epochs):
    print(f'Epoch {t+1}\n-------------------------------')
    train_loss, train_accuracy = learning_utils.train(dataset.train_dataloader, model, device, loss_fn, optimizer)
    test_loss, test_accuracy = learning_utils.test(dataset.test_dataloader, model, device, loss_fn)
    # The file is reopened in append mode every epoch, so rows for finished
    # epochs are already written out if a later epoch is interrupted.
    # NOTE: the 'epoch' column is 0-based (`t`), while the console banner
    # above is 1-based (`t+1`).
    with open('performance.csv', 'a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([t, f'{train_loss:f}', f'{test_loss:f}',
                         f'{train_accuracy:f}', f'{test_accuracy:f}'])

### Save the model


In [None]:
# Save only the weights of the underlying network. If the model was wrapped in
# torch.nn.DataParallel, its state_dict keys are prefixed with 'module.', which
# would make the checkpoint fail to load into a plain (unwrapped) model.
# Unwrapping first keeps the checkpoint format the same for CPU and GPU runs.
model_to_save = model.module if isinstance(model, torch.nn.DataParallel) else model
torch.save(model_to_save.state_dict(), 'saved_model.pth')