In [10]:
import torch
from torchvision.transforms import ToTensor
from torchvision import datasets
from torch.utils.data import DataLoader


In [11]:
# Number of samples per mini-batch produced by the loader below.
batch_size = 8

# CIFAR-10 training split read from a path relative to the notebook.
# No `download=True` is passed, so the files are presumably expected to already
# exist under `root` -- verify on a fresh checkout.
training_data = datasets.CIFAR10(
    root='../datasets/cifar_10',
    train=True,
    transform=ToTensor()
)

# Reshuffle on every pass so SGD does not see the samples in the same fixed
# order each epoch (DataLoader's default is shuffle=False).
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)

In [12]:
from matplotlib import pyplot as plt
# Get cpu or gpu device for training.
import torch.nn.functional as F
from torch import nn
import numpy as np

from densenet import DenseNet

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")


def matplotlib_imshow(img, one_channel=False, unnormalize=True):
    """Draw a single image tensor on the current matplotlib axes.

    Args:
        img: Image tensor laid out channel-first (the multi-channel branch
            transposes axes ``(1, 2, 0)`` before plotting).
        one_channel: If True, average over dim 0 and plot the result with the
            "Greys" colour map; otherwise plot all channels.
        unnormalize: If True (default, the historical behaviour), apply
            ``img / 2 + 0.5`` before plotting. That mapping presumably undoes a
            mean-0.5 / std-0.5 normalisation -- pass False for data that was
            never normalised that way (e.g. a plain ``ToTensor()`` pipeline).

    Returns:
        None. The image is drawn through ``plt.imshow``.
    """
    if one_channel:
        img = img.mean(dim=0)
    if unnormalize:
        img = img / 2 + 0.5     # unnormalize
    # Tensor.numpy() rejects CUDA tensors and tensors that require grad, so
    # detach from autograd and copy to host memory first.
    npimg = img.detach().cpu().numpy()
    if one_channel:
        plt.imshow(npimg, cmap="Greys")
    else:
        # Channel-first -> channel-last, the layout imshow expects.
        plt.imshow(np.transpose(npimg, (1, 2, 0)))


# Instantiate the project-local DenseNet with a 10-class output and place it
# on the selected device.
model = DenseNet(
    growthRate=12,
    depth=100,
    reduction=0.5,
    bottleneck=True,
    nClasses=10,
).to(device)

# Cross-entropy between the network outputs and the target labels.
loss_fn = nn.CrossEntropyLoss()

# SGD with Nesterov momentum and weight decay over all model parameters.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
    nesterov=True,
    weight_decay=5e-4,
)


Using cuda device


In [13]:
from torch.utils.tensorboard import SummaryWriter

# Run label; it becomes the sub-directory that holds this run's event files.
experiment_name = "cifar_densenet"
# Log under runs/<experiment_name> instead of SummaryWriter's default location.
writer = SummaryWriter(log_dir=f'runs/{experiment_name}')

In [14]:
def train(dataloader, model, loss_fn, optimizer, device=None, log_every=100):
    """Run one optimisation pass over ``dataloader``, updating ``model`` in place.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches. It must also expose
            ``.dataset``; its length is used only for the progress display.
        model: Module to optimise. It is switched to training mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        device: Device each batch is moved to. ``None`` (default) uses the
            device of the model's first parameter, or the CPU when the model
            has no parameters, so the function no longer depends on a
            notebook-global ``device`` variable.
        log_every: Print the loss every ``log_every`` batches (default 100).
            A falsy value (``0`` or ``None``) disables printing.

    Returns:
        None. Progress lines are printed to stdout.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    size = len(dataloader.dataset)
    # Samples consumed by earlier batches. Tracked explicitly because
    # `batch * len(X)` under-reports when the logged batch is a short final one.
    seen = 0
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and batch % log_every == 0:
            # Separate name so the loss tensor is not shadowed by a float.
            loss_value = loss.item()
            print(f"loss: {loss_value:>7f}  [{seen:>5d}/{size:>5d}]")

        seen += len(X)

# Single pass over the training loader.
train(train_dataloader, model, loss_fn, optimizer)

loss: 2.376384  [    0/50000]
loss: 2.417169  [  800/50000]
loss: 2.136257  [ 1600/50000]
loss: 2.156937  [ 2400/50000]
loss: 2.287895  [ 3200/50000]
loss: 2.095054  [ 4000/50000]
loss: 1.782100  [ 4800/50000]
loss: 1.629167  [ 5600/50000]
loss: 1.897258  [ 6400/50000]
loss: 2.181595  [ 7200/50000]
loss: 2.001462  [ 8000/50000]
loss: 2.183216  [ 8800/50000]
loss: 1.811096  [ 9600/50000]
loss: 1.528380  [10400/50000]
loss: 1.840214  [11200/50000]
loss: 1.811746  [12000/50000]
loss: 1.991841  [12800/50000]
loss: 2.019451  [13600/50000]
loss: 1.767434  [14400/50000]
loss: 2.031518  [15200/50000]
loss: 1.794912  [16000/50000]
loss: 1.809017  [16800/50000]
loss: 1.542620  [17600/50000]
loss: 1.568574  [18400/50000]
loss: 1.674643  [19200/50000]
loss: 2.338620  [20000/50000]
loss: 1.824275  [20800/50000]
loss: 1.665568  [21600/50000]
loss: 1.663399  [22400/50000]
loss: 1.746360  [23200/50000]
loss: 1.634168  [24000/50000]
loss: 1.827178  [24800/50000]
loss: 1.903072  [25600/50000]
loss: 1.57