In [183]:
import mlx
import mlx.core       as mx
import mlx.nn         as nn
import mlx.optimizers as optim
import mlx.data       as data

In [174]:
class ResNetBuildingBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut connection.

    The main path is ``conv1 -> ReLU -> conv2``; its result is added to the
    shortcut. No activation is applied after the addition, so the caller is
    responsible for it.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by both convolutions.
        decrease_dim: When ``True``, ``conv2`` uses ``stride=2`` and the input
            goes through a 1x1, stride-2 convolution (``decrease_conv``) before
            the addition so the shortcut matches the main path. When ``False``
            the input is added unchanged, so ``in_channels`` is expected to
            equal ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, decrease_dim: bool = False):
        # `super(ResNetBuildingBlock).__init__()` only builds an unbound `super`
        # object and never runs `nn.Module.__init__`; the zero-argument form
        # actually initialises the base class.
        super().__init__()

        self.decrease_dim = decrease_dim

        # Layers are created in the order conv1, conv2, decrease_conv.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            padding=1,
        )
        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=2 if decrease_dim else 1,
            padding=1,
        )
        if decrease_dim:
            # 1x1 projection that brings the shortcut to the main path's
            # channel count and strided spatial size.
            self.decrease_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=2,
            )

    def __call__(self, x: mx.array) -> mx.array:
        """Return ``conv2(relu(conv1(x)))`` plus the (optionally projected) input."""
        out = self.conv2(nn.relu(self.conv1(x)))

        shortcut = self.decrease_conv(x) if self.decrease_dim else x
        return out + shortcut


class ResNet18(nn.Module):
    """ResNet-18-style classifier built from eight ``ResNetBuildingBlock`` s.

    Layout: 7x7 stride-2 convolution, 2x2 max-pool, eight residual blocks
    (64, 64, 128, 128, 256, 256, 512, 512 channels, each followed by a ReLU),
    2x2 average pool, flatten, and a linear classifier.

    The classifier expects 4608 input features, i.e. the 3x3x512 feature map
    obtained for 224x224 inputs; other spatial sizes will not match
    ``self.linear``.

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Number of output logits.
    """

    def __init__(self, in_channels: int, num_classes: int):
        # `super(ResNet18).__init__()` never runs `nn.Module.__init__`; the
        # zero-argument form initialises the base class properly.
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=(7, 7), stride=2, padding=1)
        self.max_pool = nn.MaxPool2d(kernel_size=(2, 2), stride=2, padding=1)

        self.bb1 = ResNetBuildingBlock(in_channels=64, out_channels=64)
        self.bb2 = ResNetBuildingBlock(in_channels=64, out_channels=64)

        self.bb3 = ResNetBuildingBlock(in_channels=64, out_channels=128, decrease_dim=True)
        self.bb4 = ResNetBuildingBlock(in_channels=128, out_channels=128)

        self.bb5 = ResNetBuildingBlock(in_channels=128, out_channels=256, decrease_dim=True)
        self.bb6 = ResNetBuildingBlock(in_channels=256, out_channels=256)

        self.bb7 = ResNetBuildingBlock(in_channels=256, out_channels=512, decrease_dim=True)
        self.bb8 = ResNetBuildingBlock(in_channels=512, out_channels=512)

        self.ap = nn.AvgPool2d(kernel_size=(2, 2))
        self.linear = nn.Linear(input_dims=4608, output_dims=num_classes)

    def __call__(self, x: mx.array) -> mx.array:
        """Compute class logits for a batch of images.

        Args:
            x: Input batch in channels-last layout
                ``(batch, height, width, in_channels)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        out = self.conv1(x)
        out = self.max_pool(out)

        # Every residual block is followed by a ReLU.
        residual_blocks = (
            self.bb1, self.bb2, self.bb3, self.bb4,
            self.bb5, self.bb6, self.bb7, self.bb8,
        )
        for block in residual_blocks:
            out = nn.relu(block(out))

        out = self.ap(out)
        # Keep the batch axis: flattening from axis 0 would merge all samples
        # into one vector and only works for a batch of exactly one image.
        out = mx.flatten(out, start_axis=1)

        return self.linear(out)

In [175]:
# Number of samples per mini-batch.
batch_size = 256

# SGD with momentum and weight decay.
optimizer = optim.SGD(
    learning_rate=0.1,
    momentum=0.9,
    weight_decay=1e-4,
)

In [176]:
import numpy as np

# Dummy single-channel 224x224 image, laid out as (batch, height, width,
# channels), used to smoke-test the forward pass.
# A fixed seed keeps the input identical across kernel restarts, and float32
# is requested explicitly rather than converting a float64 array afterwards.
arr = np.random.default_rng(0).random((1, 224, 224, 1), dtype=np.float32)
pic = mx.array(arr)
pic.shape

(1, 224, 224, 1)

In [177]:
# Build the classifier for single-channel inputs and 345 output classes.
model = ResNet18(
    num_classes=345,
    in_channels=1,
)
# MLX evaluates lazily; force the parameters to be computed now.
mx.eval(model.parameters())

In [179]:
# Smoke test: run one forward pass of the model on the dummy input `pic`.
logits = model(pic)

### Training Loop

In [184]:
# Upper bound on the number of iterations of the training loop below.
epoches = 20

In [None]:
# Placeholder training loop: the body only contains `break`, so it exits on
# the first iteration and no training is performed yet.
# TODO: implement the per-epoch training step.
for epoch in range(epoches):
    break