In [8]:
import torch
import torchvision
import torchvision.transforms as transforms

In [9]:
config = {
    "EPOCH": 300,
    "BATCH_SIZE": 50,
    # Batch size of the evaluation loader (currently the test-set loader).
    "VALIDATION_BATCH_SIZE": 5000,
    "DEVICE": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    # NOTE(review): not referenced by any visible cell.
    "NUM_STEPS": 1000,
    # Seed for torch's RNGs (weight initialization, DataLoader shuffling).
    "SEED": 42,
}

# Seed before the model and loaders are created so that weight initialization
# and shuffling repeat across "Restart & Run All". GPU kernels may still
# introduce small run-to-run differences unless deterministic algorithms are
# also enabled.
torch.manual_seed(config["SEED"])

In [10]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    """Small LeNet-style CNN producing class logits for 3x32x32 images.

    Spatial shape trace for a 32x32 input:
    conv1 (5x5, no padding) -> 28x28, max-pool 2 -> 14x14,
    conv2 (5x5, no padding) -> 10x10, max-pool 2 -> 5x5,
    so the flattened feature vector has 64 * 5 * 5 elements.

    Args:
        num_classes: number of output logits (default 10, as for CIFAR-10).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.conv2 = nn.Conv2d(64, 64, 5)
        self.fc1 = nn.Linear(64 * 5 * 5, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, num_classes)

    def forward(self, x):
        """Map a batch of images to unnormalized logits of shape (batch, num_classes)."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Build the model and move its parameters to the configured device
# (Module.to returns the module itself).
net = Net().to(config["DEVICE"])

In [11]:
# Convert PIL images to tensors, then standardize each RGB channel with
# fixed per-channel mean/std constants.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.491, 0.482, 0.447],
            std=[0.247, 0.243, 0.262],
        ),
    ]
)

# Training split, reshuffled by the loader every epoch.
# (No separate validation split is carved out of the training data.)
train_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=config["BATCH_SIZE"],
    shuffle=True,
    num_workers=1,
)

# Test split, read in large fixed-order batches.
test_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=config["VALIDATION_BATCH_SIZE"],
    shuffle=False,
    num_workers=1,
)


In [12]:
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-3)

# Divide the learning rate by 10 once the monitored loss has not improved for
# 10 epochs (ReduceLROnPlateau's defaults, written out for readability).
# `verbose=True` was removed because it is deprecated since PyTorch 2.2.
# Log `optimizer.param_groups[0]["lr"]` instead if LR changes need to be seen.
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=10)

In [13]:
def compute_loss_accuracy(net, dataloader, loss_fn=None, device=None):
    """Evaluate ``net`` over every batch of ``dataloader`` without gradients.

    Args:
        net: model mapping a batch of inputs to class logits.
        dataloader: iterable of ``(inputs, labels)`` batches.
        loss_fn: loss applied to ``(outputs, labels)``. Defaults to the
            notebook-level ``criterion``. It is assumed to return the batch
            mean (``reduction="mean"``, the ``nn.CrossEntropyLoss`` default).
        device: device each batch is moved to. Defaults to ``config["DEVICE"]``.

    Returns:
        ``(mean_loss, accuracy_percent)`` as Python floats. ``mean_loss`` is
        the per-sample average over the whole dataset, so it does not grow
        with the number of batches.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if loss_fn is None:
        loss_fn = criterion
    if device is None:
        device = config["DEVICE"]

    # Evaluate in eval mode, but hand the model back in whatever mode it was in.
    was_training = net.training
    net.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                images = images.to(device)
                labels = labels.to(device)

                outputs = net(images)
                batch_size = labels.size(0)
                # loss_fn yields the batch mean. Re-weight it by batch size so a
                # short final batch is not over-counted.
                loss_sum += loss_fn(outputs, labels).item() * batch_size
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                total += batch_size
    finally:
        net.train(was_training)

    if total == 0:
        raise ValueError("dataloader yielded no samples")
    return loss_sum / total, 100 * correct / total

In [14]:
from tqdm.notebook import tqdm

for j in tqdm(range(config["EPOCH"])):
    net.train()  # ensure training mode regardless of what ran before
    # Sample-weighted loss sum, kept on-device to avoid a host sync per batch.
    epoch_loss_sum = torch.zeros((), device=config["DEVICE"])
    epoch_samples = 0

    for inputs, labels in train_loader:
        inputs = inputs.to(config["DEVICE"])
        labels = labels.to(config["DEVICE"])

        optimizer.zero_grad()
        outputs = net(inputs)
        batch_loss = criterion(outputs, labels)
        batch_loss.backward()
        optimizer.step()

        epoch_loss_sum += batch_loss.detach() * labels.size(0)
        epoch_samples += labels.size(0)

    # Mean training loss over the whole epoch. Printing only the final
    # mini-batch's loss gives a very noisy signal.
    train_loss = (epoch_loss_sum / epoch_samples).item() if epoch_samples else float("nan")

    test_loss, test_acc = compute_loss_accuracy(net, test_loader)

    # NOTE(review): the LR schedule is driven by the *test* loss, so the test
    # set influences training and the reported test accuracy is optimistic.
    # A held-out validation split should drive the scheduler instead.
    scheduler.step(test_loss)
    print(f"Epoch {j:3d}: train_loss={train_loss:.4f}, test_loss={test_loss:.4f}, test_acc={test_acc:.2f}")

In [None]:
# Final top-1 accuracy on the full test set, evaluated in eval mode. The
# model's previous train/eval mode is restored afterwards.
was_training = net.training
net.eval()
correct = 0
total = 0
try:
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(config["DEVICE"])
            labels = labels.to(config["DEVICE"])

            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
finally:
    net.train(was_training)

# True division: `100 * correct // total` truncated the percentage to an integer.
print(f"Accuracy on {total} images = {100 * correct / total:.2f} % ({correct} images classified correctly)")