In [3]:
import torch
import torch.nn as nn
from torchvision.transforms import transforms, ToTensor
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import torchvision.models as models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
    transforms.Resize((224,224)),
    ToTensor()
])

train_dataset = CIFAR10(
    root="../data",
    train = True,
    transform=transform,
)

test_dataset = CIFAR10(
    root = "../data",
    train =  True,
    transform= transform
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size = 32, shuffle=False, num_workers=0)



In [11]:
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

for param in model.parameters():
    param.requires_grad = False

num_classes = 10
model.fc =  nn.Linear(model.fc.in_features, num_classes)

model = model.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.fc.parameters(), lr=0.001)

In [None]:
epochs = 5

def train(dataloader, model, loss_fn, optimizer):
    epoch_loss = 0
    total_correct = 0
    total_samples = 0
    
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()

        logits = model(X)
        loss = loss_fn(logits, y)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * y.size(0)
        preds = (torch.argmax(logits, dim = 1))
        total_correct += (preds == y).sum().item()
        total_samples += y.size(0)

        if batch % 100 == 0:
            print("---------------------------------")
            print(f"Batch : {batch}")
            print(f"Loss : {epoch_loss/total_samples}, Accuracy={total_correct/total_samples:.4f}")

    avg_loss = epoch_loss / total_samples
    accuracy = total_correct / total_samples
    return avg_loss, accuracy

model.train()
num_batches = len(train_loader)

for epoch in range(epochs):
    avg_loss, avg_accuracy = train(dataloader=train_loader, model=model, loss_fn=loss_fn, optimizer=optimizer)
    print(f"avg loss: {avg_loss}, avg accurracy: {avg_accuracy}")




---------------------------------
Batch : 0
Loss : 2.575059413909912, Accuracy=0.0625
---------------------------------
Batch : 100
Loss : 1.6056577852456877, Accuracy=0.4824
