In [24]:
import torch, sys, time, os

print("")
print("Python version :", sys.version)
print("torch version : ", torch.__version__)
print("cuda available : ", torch.cuda.is_available())
print(torch.cuda.get_device_name(0))
print("cudnn ver : ", torch.backends.cudnn.version())
print("cudnn enabled:", torch.backends.cudnn.enabled)
device = str(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
print(f"Using device: {device}")
print("")


Python version : 3.12.1 | packaged by Anaconda, Inc. | (main, Jan 19 2024, 15:51:05) [GCC 11.2.0]
torch version :  2.2.1
cuda available :  True
NVIDIA GeForce RTX 3090
cudnn ver :  8902
cudnn enabled: True
Using device: cuda:0



In [25]:
NUM_EPOCHS = 100
BATCH = 128
SAVE_MODEL_EVERY_N_EPOCH = 5
CONT_FROM_SAVE = "latest"  # 0, "latest", "epoch_100

In [46]:
# use save?

if CONT_FROM_SAVE == 0:
    print("Starting from scratch")
elif CONT_FROM_SAVE == "latest":

    folder_path = "resnet18_cifar10"
    file_extension = ".pth"

    # Get a list of all pth files in the folder
    pth_files = [file for file in os.listdir(folder_path) if file.endswith(file_extension)]
    # Sort the pth files based on the epoch number
    sorted_files = sorted(
        pth_files, key=lambda x: int(x.split("_epoch")[1].split(".pth")[0])
    )

    # Get the file with the highest epoch number
    latest_file = sorted_files[-1]

    # Extract the epoch number from the file name
    latest_epoch = int(latest_file.split("_epoch")[1].split(".pth")[0])

    print(f"Continuing from latest save {latest_epoch} epochs")

else:
    print(f"Continuing from {CONT_FROM_SAVE} epochs")

Continuing from latest save 15 epochs


In [27]:
import torch
import torchvision
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
import torchvision.models.resnet as resnet
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms


# Load the ResNet-50 model
model = resnet18(weights=resnet.ResNet18_Weights.DEFAULT).to(device)

# Set up training and evaluation processes
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Load CIFAR-10 dataset
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    pin_memory_device=device,
)

testset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
    pin_memory_device=device,
)

classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)

Files already downloaded and verified
Files already downloaded and verified


In [28]:
def SingleEpochTrain(
    model, trainloader, criterion, optimizer, device, verb=True
) -> tuple:
    # Training loop
    model.train()

    running_loss = 0.0
    correct = 0
    total = 0
    start_time = time.time()
    for i, (images, labels) in enumerate(trainloader, start=0):

        images, labels = images.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        if i % 100 == 0 and verb == True:
            now_time = time.time()
            print(
                f"Batch {str(i).rjust(len(str(len(trainloader))))}/{len(trainloader)} ({now_time - start_time:.2f}s) | train_loss: {loss.item():.4f} | train_acc: {correct/total*100:.2f}%"
            )

    train_loss = running_loss / len(trainloader)
    train_acc = 100.0 * correct / total

    return train_loss, train_acc

In [29]:
def SingleEpochEval(
    model, testloader, criterion, device
) -> tuple:
    # Evaluation loop
    model.eval()

    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for i, (images, labels) in enumerate(trainloader, start=0):

            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Print statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    eval_loss = running_loss / len(testloader)
    eval_acc = 100.0 * correct / total

    return eval_loss, eval_acc

In [30]:
# Training loop
model.train()
for epoch in range(NUM_EPOCHS):  # Change the number of epochs as needed
    start_time = time.time()
    train_loss, train_acc = SingleEpochTrain(
        model=model,
        trainloader=trainloader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        verb=False,
    )
    eval_loss, eval_acc = SingleEpochEval(model=model, testloader=testloader, criterion=criterion, device=device)
    end_time = time.time()
    print(
        f"Epoch {str(epoch + 1).rjust(len(str(NUM_EPOCHS)))}/{NUM_EPOCHS} ({end_time - start_time:.2f}s) | train_loss: {train_loss:.4f} | train_acc: {train_acc:.2f}% | eval_loss: {eval_loss:.4f} | eval_acc: {eval_acc:.2f}%"
    )
    if (epoch + 1) % SAVE_MODEL_EVERY_N_EPOCH == 0:
        torch.save(
            model.state_dict(), f"resnet18_cifar10/resnet18_cifar10_epoch{epoch+1}.pth"
        )

print("Finished training")

Epoch   1/100 (4.19s) | train_loss: 1.5430 | train_acc: 56.61% | eval_loss: 3.7743 | eval_acc: 73.74%
Epoch   2/100 (4.17s) | train_loss: 0.7532 | train_acc: 73.88% | eval_loss: 2.7610 | eval_acc: 80.70%
Epoch   3/100 (4.24s) | train_loss: 0.5934 | train_acc: 79.29% | eval_loss: 2.1294 | eval_acc: 85.30%
Epoch   4/100 (4.37s) | train_loss: 0.4801 | train_acc: 83.25% | eval_loss: 1.6696 | eval_acc: 88.78%
Epoch   5/100 (4.23s) | train_loss: 0.3911 | train_acc: 86.37% | eval_loss: 1.1826 | eval_acc: 92.45%
Epoch   6/100 (4.23s) | train_loss: 0.3082 | train_acc: 89.30% | eval_loss: 0.8549 | eval_acc: 94.81%
Epoch   7/100 (4.21s) | train_loss: 0.2395 | train_acc: 91.73% | eval_loss: 0.6339 | eval_acc: 96.13%
Epoch   8/100 (4.15s) | train_loss: 0.1816 | train_acc: 93.72% | eval_loss: 0.4057 | eval_acc: 97.68%
Epoch   9/100 (4.23s) | train_loss: 0.1420 | train_acc: 95.03% | eval_loss: 0.2725 | eval_acc: 98.53%
Epoch  10/100 (4.13s) | train_loss: 0.1161 | train_acc: 96.04% | eval_loss: 0.1974

KeyboardInterrupt: 