In [23]:
from torch import nn,save,load
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor,transforms

In [24]:
import torch
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Using device:", device)

Using device: mps


In [25]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2023, 0.1994, 0.2010])
])


train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [26]:
def evaluate(model, dataloader, device):
    model.eval()  
    correct = 0
    total = 0

    with torch.no_grad():  
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            predictions = output.argmax(dim=1)
            correct += (predictions == target).sum().item()
            total += target.size(0)

    accuracy = 100 * correct / total
    return accuracy

In [None]:
class ImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),    
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.MaxPool2d(2, 2),                          

            nn.Conv2d(32, 64, kernel_size=3, padding=1),   
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.MaxPool2d(2, 2),                            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),  
            nn.BatchNorm2d(128),                           
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.MaxPool2d(2, 2),                            

            nn.Flatten(),                                  
            nn.Linear(2048, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 10)                              
        )

    def forward(self,x):
        return self.model(x)


In [28]:
clf=ImageClassifier()
clf.to(device)

ImageClassifier(
  (model): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.5, inplace=False)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): ReLU()
    (8): Dropout(p=0.5, inplace=False)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU()
    (13): Dropout(p=0.5, inplace=False)
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Flatten(start_dim=1, end_dim=-1)
  

In [29]:
opt=Adam(clf.parameters(),lr=1e-3)
loss_fn=nn.CrossEntropyLoss()

clf.train() 
for epoch in range(10):
    for batch in train_loader:
        X,y=batch
        X,y=X.to(device),y.to(device)
        y_hat=clf(X)

        loss=loss_fn(y_hat,y)

            #backprop
        opt.zero_grad()
        loss.backward()
        opt.step()
    print(f"Epoch:{epoch} loss:{loss.item():.4f} accuracy:{evaluate(clf,test_loader,device)}")

Epoch:0 loss:1.5003 accuracy:22.45
Epoch:1 loss:1.8313 accuracy:64.83
Epoch:2 loss:0.3883 accuracy:68.88
Epoch:3 loss:0.3128 accuracy:73.52
Epoch:4 loss:0.9699 accuracy:75.32
Epoch:5 loss:0.4563 accuracy:75.54
Epoch:6 loss:0.1743 accuracy:76.65
Epoch:7 loss:0.2168 accuracy:75.86
Epoch:8 loss:0.4353 accuracy:76.08
Epoch:9 loss:0.3968 accuracy:75.39


In [31]:
with open('cifar10_cnn.pt','wb') as f:
    save(clf.state_dict(),f)