In [4]:
import random
import wandb
import numpy         as np

import mlx
import mlx.core       as mx
import mlx.nn         as nn
import mlx.optimizers as optim
import mlx.data       as data

In [2]:
def set_seed(seed: int):
    """Seed every random number generator used in this notebook.

    The same ``seed`` is handed, in order, to Python's ``random`` module,
    NumPy's global generator and MLX's global generator.

    Args:
        seed: Integer seed shared by all three generators.
    """
    for seed_generator in (random.seed, np.random.seed, mx.random.seed):
        seed_generator(seed)


set_seed(42)

### Data

### ResNet18 Implementation

In [23]:
class ResNetBuildingBlock(nn.Module):
    """Residual block: two 3x3 convolutions plus a skip connection.

    The block contains no normalization layers and applies no activation
    after the residual sum; the caller decides whether to add one.

    With ``decrease_dim=False`` the input is added back unchanged, so
    ``in_channels`` has to equal ``out_channels``. With ``decrease_dim=True``
    the second convolution uses stride 2 and the skip path goes through a
    1x1, stride-2 convolution that also maps ``in_channels`` to
    ``out_channels``, so both branches agree in shape before the sum.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        decrease_dim: Whether to downsample spatially by a factor of 2 and
            project the skip connection.
    """

    def __init__(self, in_channels: int, out_channels: int, decrease_dim: bool=False):
        super().__init__()

        # Layers are created in the order conv1 -> conv2 -> decrease_conv so
        # that seeded weight initialization stays reproducible.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1,
        )

        self.decrease_dim = decrease_dim

        # Only the second convolution changes the spatial resolution.
        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2 if decrease_dim else 1,
            padding=1,
        )

        if decrease_dim:
            # 1x1 projection that brings the skip path to the main path's shape.
            self.decrease_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=2,
                padding=0,
            )

    def __call__(self, x: mx.array) -> mx.array:
        """Return ``conv2(relu(conv1(x)))`` plus the (optionally projected) input."""
        residual = self.conv2(nn.relu(self.conv1(x)))
        shortcut = self.decrease_conv(x) if self.decrease_dim else x
        return residual + shortcut


class ResNet18(nn.Module):
    """ResNet18-style classifier built from ``ResNetBuildingBlock`` layers.

    Layout: a 7x7 stride-2 stem convolution, a 2x2 stride-2 max pooling,
    eight residual blocks (64, 64, 128, 128, 256, 256, 512, 512 channels,
    downsampling whenever the channel count grows), a 2x2 average pooling and
    a linear classifier.

    The classifier expects 4608 = 3 * 3 * 512 features per sample, which is
    what these layers produce for e.g. 224x224 inputs. Inputs follow MLX's
    channels-last convention: ``(batch, height, width, channels)``.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Number of output logits.
    """

    def __init__(self,in_channels:int, num_classes: int):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=(7, 7), stride=2, padding=1)
        self.max_pool = nn.MaxPool2d(kernel_size=(2, 2), stride=2, padding=1)

        self.base_blocks = [
            ResNetBuildingBlock(in_channels=64, out_channels=64),
            ResNetBuildingBlock(in_channels=64, out_channels=64),
            ResNetBuildingBlock(in_channels=64, out_channels=128, decrease_dim=True),
            ResNetBuildingBlock(in_channels=128, out_channels=128),
            ResNetBuildingBlock(in_channels=128, out_channels=256, decrease_dim=True),
            ResNetBuildingBlock(in_channels=256, out_channels=256),
            ResNetBuildingBlock(in_channels=256, out_channels=512, decrease_dim=True),
            ResNetBuildingBlock(in_channels=512, out_channels=512)
        ]

        self.average_pooling = nn.AvgPool2d(kernel_size=(2, 2))
        # 4608 = 3 * 3 * 512: spatial size after the average pooling times the
        # channel count of the last block.
        self.classifier      = nn.Linear(input_dims=4608, output_dims=num_classes)

    def __call__(self, x: mx.array):
        """Compute class logits for a batch of images.

        Args:
            x: Batch of images, ``(batch, height, width, in_channels)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        out = self.conv1(x)
        out = self.max_pool(out)
        out = nn.relu(out)

        # Every block except the last one is followed by a ReLU.
        for block in self.base_blocks[:-1]:
            out = block(out)
            out = nn.relu(out)

        # The final block's output goes straight into the pooling layer.
        out = self.base_blocks[-1](out)
        out = self.average_pooling(out)

        # Flatten the feature axes only. mx.flatten defaults to start_axis=0,
        # which would also merge the batch axis and break the classifier for
        # any batch size other than 1.
        out = mx.flatten(out, start_axis=1)

        logits = self.classifier(out)

        return logits

### W&B Setup

W&B has a great tutorial on getting started with Apple MLX: [Getting started with Apple MLX](https://wandb.ai/byyoung3/ML_NEWS3/reports/Getting-started-with-Apple-MLX--Vmlldzo5Njk5MTk1)

In [5]:
# Toggle for Weights & Biases experiment tracking; leave False to skip wandb.init.
use_wandb = False

if use_wandb:
    # Start a W&B run in the "MLX_QUICK_AND_DRAW" project.
    wandb.init(project="MLX_QUICK_AND_DRAW")

### Training Loop

Let's initialize the model, the optimizer, and the loss function.

In [8]:
# Device on which MLX should run the computations of this notebook.
device = mx.gpu
# Assigning the variable alone has no effect; register it as MLX's default
# device so the choice actually applies to the operations that follow.
mx.set_default_device(device)

In [20]:
# One input channel and 345 output classes — presumably grayscale Quick, Draw!
# sketches and their categories; confirm against the data-loading cell.
model = ResNet18(in_channels=1, num_classes=345)
# MLX evaluates lazily: force the freshly created parameters to be materialized now.
mx.eval(model.parameters())

In [21]:
# Number of samples per optimization step.
batch_size = 256
# SGD with momentum and L2 weight decay.
optimizer = optim.SGD(learning_rate=0.1, momentum=0.9, weight_decay=0.0001)

def loss_fn(model, X: mx.array, y: mx.array):
    """Return the mean cross-entropy between ``model(X)`` and the targets ``y``.

    Args:
        model: Callable module mapping a batch of inputs to logits.
        X: Batch of inputs passed to ``model``.
        y: Targets for the batch (assumed to be integer class indices —
            TODO confirm against the data pipeline).

    Returns:
        Scalar array holding the loss averaged over the batch.
    """
    return nn.losses.cross_entropy(model(X), y, reduction="mean")

# Use nn.value_and_grad so gradients are taken w.r.t. the trainable parameters
# of `model`. mx.value_and_grad(loss_fn) would differentiate w.r.t. the first
# positional argument and hand the loss a plain parameter tree instead of the
# callable module. The call signature stays loss_and_grad_fn(model, X, y).
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

In [7]:
# Number of training epochs (the spelling "epoches" is what the loop cell below reads).
epoches = 20

In [8]:
# Placeholder for the training loop: prints the first epoch index and exits
# immediately, so no training step is executed here yet.
for epoch in range(epoches):
    print(epoch)
    break

0
