In [32]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import random_split

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2023, 0.1994, 0.2010])
])

train_dataset = CIFAR10(
    root="../data",
    train=True,
    transform=transform
)

test_dataset = CIFAR10(
    root="../data",
    train=False,
    transform= transform
)

train_size = int(0.8 * len(train_dataset))
valid_size = len(train_dataset) - train_size
train_dataset, valid_dataset = random_split(train_dataset, [train_size, valid_size])

train_dataloader = DataLoader(train_dataset, 32, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, 32, shuffle=False)
test_dataloader = DataLoader(test_dataset, 32)


In [33]:
import torch.nn as nn
class SimpleCNN(nn.Module):

    def __init__(self, in_channels, num_classes=10):
        super().__init__()
        self.layers = nn.Sequential(
            # Block 1
            nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, padding=1, stride=2), 
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels= 32, kernel_size=3 , padding=1),
            nn.MaxPool2d(2, 2),

            # Block 2
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size= 3, padding=1),
            nn.ReLU(),
            # nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.MaxPool2d(2, 2),

            # FC
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 256), 
            nn.ReLU(),
            nn.Dropout2d(p=0.2),
            nn.Linear(256, num_classes)
        )

    def forward(self, X):
        preds = self.layers(X)
        return preds




In [None]:
from  torch.optim import AdamW
from torchmetrics import Precision, Recall

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(3, 10)
model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=0.001, weight_decay=0.01)

precision = Precision(task="multiclass", num_classes=10, average='macro').to(device)
recall = Recall(task="multiclass", num_classes=10, average='macro').to(device)

model.train()
def train(model, dataloader, loss_fn, optimizer):
    expected_loss, samples, num_correct = 0,0,0

    for batch, (X,y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()

        preds = model(X)
        loss = loss_fn(preds, y)

        loss.backward()
        optimizer.step()

        samples += y.size(0)
        expected_loss += loss.item() * y.size(0)

        pred_classes = torch.argmax(preds, dim=1)
        num_correct += (pred_classes == y).sum().item()

        precision.update(pred_classes, y)
        recall.update(pred_classes, y)
        # if batch % 500 == 0:
        #     print(f"batch : {batch} loss: {expected_loss/ samples}")

    avg_loss = expected_loss/ samples
    accuracy = num_correct / samples
    return avg_loss, accuracy

def validate(model, dataloader, loss_fn):
    model.eval()
    with torch.no_grad():
        total_loss = 0
        num_correct = 0
        num_samples = 0
        for X,y in dataloader:
            preds = model(X)
            loss = loss_fn(preds, y)

            num_samples += y.size(0)
            total_loss += loss.item() * y.size(0)
            pred_classes = torch.argmax(preds, dim=1)
            num_correct += (pred_classes == y).sum().item()
        print(f"Validation Loss: {total_loss/num_samples} Validation Accuracy: {num_correct/num_samples}")


epochs= 10
for i in range(epochs):
    avg_loss, accuracy = train(model, train_dataloader, loss_fn, optimizer)
    validate(model, valid_dataloader, loss_fn)
    print(f"Epoch {i}: avg_loss :{ avg_loss}, accuracy :{accuracy}")

    print(f"Final Precision: {precision.compute()}")
    print(f"Final Recall: {recall.compute()}")
    precision.reset()
    recall.reset()





Validation Loss: 1.1138532507896424 Validation Accuracy: 0.6046
Epoch 0: avg_loss :1.4060361347198487, accuracy :0.492
Final Precision: 0.48588529229164124
Final Recall: 0.49191686511039734
Validation Loss: 0.9558962268829345 Validation Accuracy: 0.6611
Epoch 1: avg_loss :0.9967478182554245, accuracy :0.646125
Final Precision: 0.6434494256973267
Final Recall: 0.6460368037223816
Validation Loss: 0.8907347256660462 Validation Accuracy: 0.6885
Epoch 2: avg_loss :0.8186102046489716, accuracy :0.71195
Final Precision: 0.7099288702011108
Final Recall: 0.7118694186210632
Validation Loss: 0.8594109941482544 Validation Accuracy: 0.699
Epoch 3: avg_loss :0.6973984870672226, accuracy :0.75245
Final Precision: 0.7510395646095276
Final Recall: 0.7523576021194458
Validation Loss: 0.8623089943885803 Validation Accuracy: 0.711
Epoch 4: avg_loss :0.5841926406264305, accuracy :0.793125
Final Precision: 0.79215407371521
Final Recall: 0.7930319905281067
Validation Loss: 0.9418686904907226 Validation Accur

In [28]:

model.eval()
with torch.no_grad():
    total_loss = 0
    num_correct = 0
    num_samples = 0
    for X,y in test_dataloader:
        preds = model(X)
        loss = loss_fn(preds, y)

        num_samples += y.size(0)
        total_loss += loss.item() * y.size(0)
        pred_classes = torch.argmax(preds, dim=1)
        num_correct += (pred_classes == y).sum().item()

    print(precision(pred_classes, y).item())
    print(recall(pred_classes,y).item())
    
    print(f"Loss : {total_loss/num_samples}, Accuracy : {num_correct/num_samples}")

        






0.6388888955116272
0.5925925970077515
Loss : 0.9322456899166107, Accuracy : 0.7251


In [29]:
torch.save(model.state_dict(), "simple_cnn.pth")