In [24]:
from torch import nn,save,load
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor,transforms

In [25]:
import torch
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Using device:", device)

Using device: mps


In [26]:
transform = transforms.ToTensor()

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [27]:
def evaluate(model, dataloader, device):
    model.eval()  
    correct = 0
    total = 0

    with torch.no_grad():  
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            predictions = output.argmax(dim=1)
            correct += (predictions == target).sum().item()
            total += target.size(0)

    accuracy = 100 * correct / total
    return accuracy

In [28]:
class ImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
           nn.Flatten(),
           nn.Linear(784,256),
           nn.ReLU(),
           nn.Linear(256,128),
           nn.ReLU(),
           nn.Linear(128,64),
           nn.ReLU(),
           nn.Linear(64,10)
        )
    def forward(self,x):
        return self.model(x)


In [29]:
clf=ImageClassifier()
clf.to(device)

ImageClassifier(
  (model): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=256, bias=True)
    (2): ReLU()
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Linear(in_features=128, out_features=64, bias=True)
    (6): ReLU()
    (7): Linear(in_features=64, out_features=10, bias=True)
  )
)

In [30]:
opt=Adam(clf.parameters(),lr=1e-3)
loss_fn=nn.CrossEntropyLoss()

clf.train() 
for epoch in range(10):
    for batch in train_loader:
        X,y=batch
        X,y=X.to(device),y.to(device)
        y_hat=clf(X)

        loss=loss_fn(y_hat,y)

            #backprop
        opt.zero_grad()
        loss.backward()
        opt.step()
    print(f"Epoch:{epoch} loss:{loss.item():.4f} accuracy:{evaluate(clf,test_loader,device)}")

Epoch:0 loss:0.0521 accuracy:95.81
Epoch:1 loss:0.2613 accuracy:97.24
Epoch:2 loss:0.0798 accuracy:97.57
Epoch:3 loss:0.0014 accuracy:97.88
Epoch:4 loss:0.0344 accuracy:97.81
Epoch:5 loss:0.0150 accuracy:97.8
Epoch:6 loss:0.0064 accuracy:97.71
Epoch:7 loss:0.0850 accuracy:97.83
Epoch:8 loss:0.0488 accuracy:98.26
Epoch:9 loss:0.0142 accuracy:97.75


In [31]:
with open('handwritten.pt','wb') as f:
    save(clf.state_dict(),f)