In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
import torchvision

In [2]:
# Check if a GPU is available and use it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Define data preprocessing and augmentation for CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # EfficientNet expects 224x224 images
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [4]:
# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Load pre-trained EfficientNetB0 from torchvision
model = models.efficientnet_b0(pretrained=True)



In [6]:
# Modify the classifier to fit the number of classes in CIFAR-10 (10 classes)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10)

In [7]:
# Move the model to the GPU (if available)
model = model.to(device)

In [8]:
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [13]:
# Save training state
def save_checkpoint(model, optimizer, epoch, batch_idx, file_path="checkpoint.pth"):
    state = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'batch_idx': batch_idx
    }
    torch.save(state, file_path)
    print(f"Checkpoint saved at epoch {epoch}, batch {batch_idx}")

In [14]:
# Load training state
def load_checkpoint(model, optimizer, file_path="checkpoint.pth"):
    if os.path.exists(file_path):
        checkpoint = torch.load(file_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        batch_idx = checkpoint['batch_idx']
        print(f"Checkpoint loaded: epoch {epoch}, batch {batch_idx}")
        return epoch, batch_idx
    else:
        print("No checkpoint found. Starting training from scratch.")
        return 0, 0

In [17]:
# Load checkpoint if it exists
start_epoch, start_batch = load_checkpoint(model, optimizer)

  checkpoint = torch.load(file_path)


Checkpoint loaded: epoch 0, batch 294


In [16]:
# Train the model and display progress for each batch
num_epochs = 10
for epoch in range(start_epoch, num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Set model to training mode
    model.train()
    running_loss = 0.0
    total_correct = 0
    total = 0

    # Training loop
    for batch_idx, (images, labels) in enumerate(train_loader):
        # Skip already completed batches if resuming
        if epoch == start_epoch and batch_idx < start_batch:
            continue

        images, labels = images.to(device), labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Accumulate metrics
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total_correct += (predicted == labels).sum().item()
        total += labels.size(0)

        # Save checkpoint after each batch
        save_checkpoint(model, optimizer, epoch, batch_idx)

        # Print statistics every 100 batches
        if (batch_idx + 1) % 100 == 0:
            print(f'Batch [{batch_idx + 1}/{len(train_loader)}], '
                  f'Loss: {running_loss / (batch_idx + 1):.4f}, '
                  f'Accuracy: {total_correct / total:.4f}')

    # After each epoch, evaluate on test data
    model.eval()
    test_loss = 0.0
    total_test_correct = 0
    total_test = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total_test_correct += (predicted == labels).sum().item()
            total_test += labels.size(0)

    train_loss = running_loss / len(train_loader)
    train_acc = total_correct / total
    test_loss = test_loss / len(test_loader)
    test_acc = total_test_correct / total_test

    # Print epoch summary
    print(f"Epoch [{epoch + 1}/{num_epochs}], "
          f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}, "
          f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

    # Save checkpoint at the end of the epoch
    save_checkpoint(model, optimizer, epoch + 1, 0)  # Set batch_idx to 0 for next epoch


Epoch 1/10
Checkpoint saved at epoch 0, batch 151
Checkpoint saved at epoch 0, batch 152
Checkpoint saved at epoch 0, batch 153
Checkpoint saved at epoch 0, batch 154
Checkpoint saved at epoch 0, batch 155
Checkpoint saved at epoch 0, batch 156
Checkpoint saved at epoch 0, batch 157
Checkpoint saved at epoch 0, batch 158
Checkpoint saved at epoch 0, batch 159
Checkpoint saved at epoch 0, batch 160
Checkpoint saved at epoch 0, batch 161
Checkpoint saved at epoch 0, batch 162
Checkpoint saved at epoch 0, batch 163
Checkpoint saved at epoch 0, batch 164
Checkpoint saved at epoch 0, batch 165
Checkpoint saved at epoch 0, batch 166
Checkpoint saved at epoch 0, batch 167
Checkpoint saved at epoch 0, batch 168
Checkpoint saved at epoch 0, batch 169
Checkpoint saved at epoch 0, batch 170
Checkpoint saved at epoch 0, batch 171
Checkpoint saved at epoch 0, batch 172
Checkpoint saved at epoch 0, batch 173
Checkpoint saved at epoch 0, batch 174
Checkpoint saved at epoch 0, batch 175
Checkpoint sa

KeyboardInterrupt: 