In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Subset
from torchvision import models

from sklearn.metrics import accuracy_score

from datetime import datetime
from time import time

  from .autonotebook import tqdm as notebook_tqdm


## Check CUDA

In [4]:
if torch.cuda.is_available():
    device = "cuda"
    print(torch.cuda.get_device_name())
else:
    device = "cpu"


Quadro RTX 3000 with Max-Q Design


## Load data

In [5]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (0.5, ))])

# Create dataset(use 100 data for my laptop)
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
valid_set = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
train_set = Subset(train_set, list(range(100)))
valid_set = Subset(valid_set, list(range(100, 110)))

# Create data loaders for our datasets
train_loader = DataLoader(train_set, batch_size=5, shuffle=False)
valid_loader = DataLoader(valid_set, batch_size=5, shuffle=False)

print(f'## Training set has {len(train_set)} instances.')
print(f'## Validation set has {len(valid_set)} instances.')

Files already downloaded and verified
Files already downloaded and verified
## Training set has 100 instances.
## Validation set has 10 instances.


## Build model

In [6]:
model = models.resnet18(weights="IMAGENET1K_V1", progress=True).to(device)

## Loss function(Criterion) & Optimizer

In [7]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

## Training

In [13]:
def train_one_epoch(epoch_index, tb_writer):
    sum_loss, sum_acc = 0.0, 0.0
    running_loss,running_acc = 0.0, 0.0
    last_loss, last_acc = 0.0, 0.0

    START_TIME = time()
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad()
        outputs = model(inputs)

        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()

        # Adjust learning weights
        optimizer.step()
        
        running_loss += loss.item()
        sum_loss += loss.item()
        running_acc += accuracy_score(labels.cpu(), outputs.argmax(dim=1).cpu())
        sum_acc += accuracy_score(labels.cpu(), outputs.argmax(dim=1).cpu())
        if i % 10 == 9:
            last_loss = running_loss/10
            last_acc = running_acc/10
            # print(f' - Batch {i+1} loss: {last_loss:.4f} / accuracy: {last_acc:.4f}')

            tb_x = epoch_index * len(train_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            tb_writer.add_scalar('Accuracy/train', last_acc, tb_x)
            running_loss = 0.0
            running_acc = 0.0
    END_TIME = time()

    count = len(train_loader.dataset)
    return sum_loss/count, sum_acc/count, (END_TIME-START_TIME)

In [15]:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/resnet18_trainer_{}'.format(timestamp))
epoch_number = 0

total_time = 0
EPOCHS = 5
best_vloss = 1_000_000.0

# warm up 
print(f'## start warm up')
dummy_data = torch.randn(1, 3, 32, 32).to(device)
for _ in range(1000):
    _ = model(dummy_data)
print(f'## finished warm up')

for epoch in range(EPOCHS):
    print(f'EPOCH {epoch_number+1}: ', end="")

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss, avg_acc, train_time = train_one_epoch(epoch_number, writer)
    total_time += train_time

    # Set the model to evaluation mode
    model.eval()
    running_vloss = 0.0
    running_vacc = 0.0

    with torch.no_grad():
        for i, vdata in enumerate(valid_loader):
            vinputs, vlabels = vdata
            vinputs, vlabels = vinputs.to(device), vlabels.to(device)
            voutputs = model(vinputs)
            vloss = loss_fn(voutputs, vlabels) # current batch valid loss
            vacc = accuracy_score(vlabels.cpu(), voutputs.argmax(dim=1).cpu()) # current batch valid accuracy
            running_vloss += vloss.item()
            running_vacc += vacc

    avg_vloss = running_vloss / (i + 1)
    avg_vacc = running_vacc / (i + 1)
    print(f'Train Loss: {avg_loss:.4f} / Valid Loss: {avg_vloss:.4f} / '
      f'Train Accuracy: {avg_acc:.4f} / Valid Accuracy: {avg_vacc:.4f} ({train_time:.4f} sec)')

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_model = model
        best_vloss = avg_vloss

    epoch_number += 1
print(f'== Total time: {total_time:.4f} sec ==')
model_path = f'models/model_renet18_{timestamp}_{epoch_number}.pth'
torch.save(best_model.state_dict(), model_path)

## start warm up
## finished warm up
EPOCH 1: Train Loss: 0.0014 / Valid Loss: 3.3157 / Train Accuracy: 0.2000 / Valid Accuracy: 0.4000 (0.3890 sec)
EPOCH 2: Train Loss: 0.0011 / Valid Loss: 3.3276 / Train Accuracy: 0.2000 / Valid Accuracy: 0.4000 (0.3820 sec)
EPOCH 3: Train Loss: 0.0009 / Valid Loss: 3.3384 / Train Accuracy: 0.2000 / Valid Accuracy: 0.4000 (0.3810 sec)
EPOCH 4: Train Loss: 0.0008 / Valid Loss: 3.3499 / Train Accuracy: 0.2000 / Valid Accuracy: 0.4000 (0.3820 sec)
EPOCH 5: Train Loss: 0.0007 / Valid Loss: 3.3612 / Train Accuracy: 0.2000 / Valid Accuracy: 0.4000 (0.3850 sec)
== Total time: 1.9190 sec ==


In [None]:
# 打不開阿!!!
# !tensorboard --logdir runs/resnet18_trainer_20241222_125827

## Load a saved version of the model

In [14]:
PATH = r"models\model_renet18_20241222_141355_4.pth"
saved_model = models.resnet18()
saved_model.load_state_dict(torch.load(PATH))

<All keys matched successfully>