In [46]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_unflatten
import numpy as np
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor, Compose, Normalize
from torch.utils.data import DataLoader, Subset
import time
import matplotlib.pyplot as plt
import random
import os

In [48]:


BATCH_SIZE = 512

# 1. 数据集基本信息
def create_dataset(sample_ratio=0.1):
# 加载数据集:训练集数据和测试数据
    transform = Compose([
    ToTensor(),
    Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])
    # 确保数据目录存在
    os.makedirs('data', exist_ok=True)
    train_data = CIFAR10(root='data', train=True, transform=transform, download=True)
    test_data = CIFAR10(root='data', train=False, transform=transform, download=True)
    total_len = len(train_data)
    sample_size = int(total_len * sample_ratio)
    indices = random.sample(range(total_len), sample_size)
    train_subset = Subset(train_data, indices)
# 返回数据集结果
    return train_subset,test_data
class ImageClassification(nn.Module):
    def __init__(self):
        super().__init__()
        # 简化版ResNet18结构适配MLX
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.conv5 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, 10)

    def __call__(self, x):
        x = nn.relu(self.conv1(x))
        x = nn.relu(self.conv2(x))
        x = nn.relu(self.conv3(x))
        x = nn.relu(self.conv4(x))
        x = nn.relu(self.conv5(x))
        x = self.pool(x)
        return self.fc(x.reshape(-1, 512))

In [50]:
def train(model, train_data):
    # MLX适配的训练函数
    loss_fn = nn.losses.CrossEntropy()
    optimizer = optim.AdamW(learning_rate=0.001)

    # 转换数据为MLX数组
    def convert_batch(batch):
        images, labels = batch
        return mx.array(images.numpy()), mx.array(labels.numpy())

    epochs = 10
    for epoch in range(epochs):
        data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
        total_loss = 0
        start_time = time.time()

        for batch in data_loader:
            x, y = convert_batch(batch)

            def loss_fn(model, x, y):
                return mx.mean(nn.losses.cross_entropy(model(x), y))

            loss, grads = nn.value_and_grad(model, loss_fn)(x, y)
            optimizer.update(model, grads)
            mx.eval(model.parameters(), optimizer.state)

            total_loss += loss.item()

        avg_loss = total_loss / len(data_loader)
        print(f'epoch:{epoch+1:2d} loss:{avg_loss:.5f} time:{time.time()-start_time:.2f}s')

    # 保存模型
    mx.savez("model.npz", **tree_unflatten(list(model.parameters().items())))
if __name__ == '__main__':
    # 初始化模型
    model = ImageClassification()
    print("MLX模型结构:")
    for name, param in model.parameters().items():
        print(f"{name}: {param.shape}")

    # 加载数据
    train_data, test_data = create_dataset(sample_ratio=1.0)

    # 完整训练测试流程
    train(model, train_data)

AttributeError: module 'mlx.nn' has no attribute 'AdaptiveAvgPool2d'

In [54]:
def test(test_data):
    # 加载模型
    model = ImageClassification()
    model.load_weights("model.npz")

    # 数据转换函数
    def convert_batch(batch):
        images, labels = batch
        return mx.array(images.numpy()), mx.array(labels.numpy())

    data_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)
    correct = 0
    total = 0

    for batch in data_loader:
        x, y = convert_batch(batch)
        y_pred = model(x)
        correct += (mx.argmax(y_pred, axis=1) == y).sum().item()
        total += len(y)

    print(f'Test Accuracy: {100 * correct / total:.2f}%')