In [2]:
import random
import wandb
import numpy         as np

import mlx
import mlx.core       as mx
import mlx.nn         as nn
import mlx.optimizers as optim
import mlx.data       as dx

In [2]:
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    mx.random.seed(seed)

set_seed(42)

### Data

In [52]:
train_images = '../quick-draw-challenge/train_images.npy'
train_labels = '../quick-draw-challenge/train_labels.npy'

val_images = '../quick-draw-challenge/val_images.npy'
val_labels = '../quick-draw-challenge/val_labels.npy'

In [53]:
def load_images_and_labels(images_path: str, labels_path: str):
    train_images = np.load(images_path)
    train_labels = np.load(labels_path)

    images_labels = []
    for image, label in zip(train_images, train_labels):
        images_labels.append(dict(image=image, label=label))

    return images_labels

In [138]:
train_dataset = (
    dx.buffer_from_vector(load_images_and_labels(images_path=train_images, 
                                                 labels_path=train_labels))
        .shuffle()
        .to_stream()
        .key_transform("image", lambda x: x.reshape(28, 28)) # transform flatten array of size 729 to 2-dim array of size 28x28
        .key_transform("image", lambda x: np.expand_dims(x, axis=-1)) # Transform HxW image to HxWxC image with one color channel
        .image_resize("image", w=224, h=224)
        .key_transform("image", lambda x: x.astype("float32"))
        .batch(256)
        .prefetch(4, 2)
)

In [141]:
validation_dataset = (
    dx.buffer_from_vector(load_images_and_labels(images_path=val_images,
                                                 labels_path=val_labels))
        .to_stream()
        .key_transform("image", lambda x: x.reshape(28, 28))
        .key_transform("image", lambda x: np.expand_dims(x, axis=-1))  
        .image_resize("image", w=224, h=224) 
        .key_transform("image", lambda x: x.astype("float32"))
        .batch(256)
        .prefetch(4, 2)
)

### ResNet18 Implementation

<img src="https://raw.githubusercontent.com/mikheevshow/mlx-convolutional-classifier/refs/heads/master/resources/resnet18_arc.png" alt="resnet18-arch" width="500"/>

In [180]:
class ResNetBuildingBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, decrease_dim: bool=False):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), padding=1)

        self.decrease_dim = decrease_dim
        if decrease_dim:
            self.conv2         = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=2, padding=1)
            self.decrease_conv = nn.Conv2d(in_channels=in_channels,  out_channels=out_channels, kernel_size=(1, 1), stride=2, padding=0)
        else: 
            self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), padding=1) 

    def __call__(self, x: mx.array) -> mx.array:
        
        out = self.conv1(x)
        out = nn.relu(out)
        out = self.conv2(out)

        if self.decrease_dim:
            x = self.decrease_conv(x)
            
        out = out + x

        return out


class ResNet18(nn.Module):
    def __init__(self,in_channels:int, num_classes: int):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=(7, 7), stride=2, padding=1)
        self.max_pool = nn.MaxPool2d(kernel_size=(2, 2), stride=2, padding=1)

        self.base_blocks = [
            ResNetBuildingBlock(in_channels=64, out_channels=64),
            ResNetBuildingBlock(in_channels=64, out_channels=64),
            ResNetBuildingBlock(in_channels=64, out_channels=128, decrease_dim=True),
            ResNetBuildingBlock(in_channels=128, out_channels=128),
            ResNetBuildingBlock(in_channels=128, out_channels=256, decrease_dim=True),
            ResNetBuildingBlock(in_channels=256, out_channels=256),
            ResNetBuildingBlock(in_channels=256, out_channels=512, decrease_dim=True),
            ResNetBuildingBlock(in_channels=512, out_channels=512)
        ]

        self.average_pooling = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        self.classifier      = nn.Linear(input_dims=1179648, output_dims=num_classes)

    def __call__(self, x: mx.array):
        
        out = self.conv1(x)
        out = self.max_pool(out)
        out = nn.relu(out)

        for block in self.base_blocks[:-1]:
            out = block(out)
            out = nn.relu(out)
            
        out = self.base_blocks[-1](out)    

        print(out.shape)

        out = self.average_pooling(out)
        out = mx.flatten(out)

        print(out.shape)
        
        logits = self.classifier(out)

        return logits

### W&B Setup

Great tutorial from W&B [here](https://wandb.ai/byyoung3/ML_NEWS3/reports/Getting-started-with-Apple-MLX--Vmlldzo5Njk5MTk1)

In [181]:
use_wandb = False

if use_wandb:
    wandb.init(project="MLX_QUICK_AND_DRAW")

### Training Loop

Let's initialize model, optimizer and loss function.

In [182]:
device = mx.gpu

In [183]:
model = ResNet18(in_channels=1, num_classes=345)
mx.eval(model.parameters())

In [184]:
batch_size = 256
optimizer = optim.SGD(learning_rate=0.1, momentum=0.9, weight_decay=0.0001)

def loss_fn(model, X: mx.array, y: mx.array):
    return nn.losses.cross_entropy(model(X), y, reduction="mean")

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

In [185]:
epoches = 20

In [186]:
for epoch in range(epoches):
    for batch in train_dataset:

        X = mx.array(batch['image'])
        y = mx.array(batch['label'])

        print(X.shape)
        print(model(X))

        loss, grads = loss_and_grad_fn(model, X, y)
        # optimizer.update(model, grads)

        # mx.eval(model.parameters(), optimizer.state)

        # print(model.parameters())

        break
        
    for batch in validation_dataset:
        break
        X = batch   

    break    
        

(256, 224, 224, 1)
(256, 7, 7, 512)
(1179648,)
array([-0.42911, 0.288466, 0.77963, ..., 2.19505, -2.78398, 1.33222], dtype=float32)
(256, 7, 7, 512)
(1179648,)


ValueError: Targets shape (256,) does not match logits shape (345,).