# CNN on CIFAR10 with different optimizers

## Check the GPU we got

In [1]:
%pip install wandb
!wandb login

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
[34m[1mwandb[0m: Currently logged in as: [33mkwang126[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
%nvidia-smi

Mon Nov 21 20:38:46 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   58C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Import libraries

In [3]:
import torch
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
from tqdm.auto import tqdm
import time
import matplotlib.pyplot as plt

import wandb

device = torch.device("mps" if getattr(torch,'has_mps',False) else "cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

BATCH_SIZE = 256

cuda:0


## Load dataset

In [4]:
def load_data():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    train_data = datasets.CIFAR10(root = 'data', train = True, download = True, transform = transform)
    test_data = datasets.CIFAR10(root = 'data', train = False, download = True, transform = transform)
    print('Number of training data:', len(train_data))
    print('Number of testing data:', len(test_data))

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

    return train_data, test_data, train_loader, test_loader
  
train_data, test_data, train_loader, test_loader = load_data()

Files already downloaded and verified
Files already downloaded and verified
Number of training data: 50000
Number of testing data: 10000


## Build model

In [5]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(         
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(),                      
            nn.MaxPool2d(2),    
        )
        self.conv2 = nn.Sequential(         
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(),                      
            nn.MaxPool2d(2),                
        )
        self.out = nn.Linear(4096, 100)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)       
        output = self.out(x)
        return output

## Define training and testing loop

In [6]:
def train(model, train_loader, optimizer, opt_name, loss_func, epochs=30):
    accuracy_lst = []
    loss_lst = []
    model.train()
    for epoch in tqdm(range(epochs), desc=f"Training progress "+opt_name, colour="#00ff00"):
        total_loss = 0
        correct = 0
        num_labels = 0
        counter = 0
        start_time = time.time()
        for batch_idx, (X, y) in enumerate(tqdm(train_loader, leave=False, desc=f"Epoch {epoch + 1}/{epochs}", colour="#005500")):
            X = X.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = model(X)
            loss = loss_func(output, y)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

            predicted = torch.max(output.data, 1)[1]
            correct += (predicted == y).sum()
            num_labels += len(y)
            counter += 1
        accuracy_lst.append((correct/num_labels).cpu().item())
        loss_lst.append(total_loss/counter)

        end_time = time.time()

        wandb.log({'Accuracy': accuracy_lst[-1], 'Loss': loss_lst[-1], 'Time': end_time-start_time})

        print('Epoch %d, Loss %4f, Accuracy %4f, finished in %.4f seconds' % (epoch+1, total_loss/counter, correct/num_labels, end_time-start_time))
    
    return accuracy_lst, loss_lst

In [7]:
def evaluate(model, test_loader, opt_name, loss_func):
    total_loss = 0
    correct = 0
    num_labels = 0
    counter = 0
    model.eval()
    for batch_idx, (X, y) in enumerate(train_loader):
        X = X.to(device)
        y = y.to(device)

        output = model(X)

        loss = loss_func(output, y)
        total_loss += loss.item()

        predicted = torch.max(output,1)[1]
        correct += (predicted == y).sum()
        num_labels += len(y)
        counter += 1
    print('Test Loss %4f, Test Accuracy %4f' % (total_loss/counter, correct/num_labels))

## Train with different optimizers

In [8]:
lr = 0.001
# NAdam
NAdam_run = wandb.init(project="CSI 5340 Project", entity="kwang126", name='NAdam-0.001')
model = CNN().to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.NAdam(model.parameters(), lr = lr)

accuracy_lst_NAdam, loss_lst_NAdam = train(model, train_loader, optimizer, 'NAdam', loss_func, 30)
evaluate(model, test_loader, 'NAdam', loss_func)
NAdam_run.finish()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkwang126[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training progress NAdam:   0%|          | 0/30 [00:00<?, ?it/s]

Epoch 1/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 1, Loss 1.498732, Accuracy 0.471000, finished in 16.7414 seconds


Epoch 2/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 2, Loss 1.072439, Accuracy 0.623820, finished in 11.7665 seconds


Epoch 3/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 3, Loss 0.913082, Accuracy 0.681700, finished in 11.5887 seconds


Epoch 4/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 4, Loss 0.822887, Accuracy 0.714980, finished in 11.5682 seconds


Epoch 5/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 5, Loss 0.754297, Accuracy 0.740820, finished in 11.9058 seconds


Epoch 6/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 6, Loss 0.700864, Accuracy 0.757700, finished in 12.9045 seconds


Epoch 7/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 7, Loss 0.652224, Accuracy 0.775620, finished in 11.7853 seconds


Epoch 8/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 8, Loss 0.614407, Accuracy 0.788500, finished in 11.5888 seconds


Epoch 9/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 9, Loss 0.566742, Accuracy 0.806640, finished in 11.6849 seconds


Epoch 10/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 10, Loss 0.535085, Accuracy 0.817980, finished in 11.8652 seconds


Epoch 11/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 11, Loss 0.506884, Accuracy 0.826460, finished in 11.8723 seconds


Epoch 12/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 12, Loss 0.472713, Accuracy 0.838600, finished in 11.9014 seconds


Epoch 13/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 13, Loss 0.438640, Accuracy 0.852280, finished in 12.0184 seconds


Epoch 14/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 14, Loss 0.416985, Accuracy 0.859980, finished in 11.7528 seconds


Epoch 15/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 15, Loss 0.393296, Accuracy 0.867360, finished in 11.9239 seconds


Epoch 16/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 16, Loss 0.367161, Accuracy 0.877900, finished in 11.6219 seconds


Epoch 17/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 17, Loss 0.339917, Accuracy 0.886480, finished in 11.6909 seconds


Epoch 18/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 18, Loss 0.323716, Accuracy 0.892240, finished in 11.8014 seconds


Epoch 19/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 19, Loss 0.298850, Accuracy 0.902760, finished in 11.7814 seconds


Epoch 20/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 20, Loss 0.279749, Accuracy 0.908680, finished in 11.9813 seconds


Epoch 21/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 21, Loss 0.262542, Accuracy 0.914340, finished in 11.4837 seconds


Epoch 22/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 22, Loss 0.239555, Accuracy 0.923300, finished in 11.7138 seconds


Epoch 23/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 23, Loss 0.221349, Accuracy 0.930340, finished in 11.8196 seconds


Epoch 24/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 24, Loss 0.211381, Accuracy 0.933820, finished in 11.7347 seconds


Epoch 25/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 25, Loss 0.194515, Accuracy 0.940200, finished in 11.5161 seconds


Epoch 26/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 26, Loss 0.179377, Accuracy 0.945340, finished in 11.5030 seconds


Epoch 27/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 27, Loss 0.168221, Accuracy 0.950020, finished in 11.8486 seconds


Epoch 28/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 28, Loss 0.149900, Accuracy 0.956780, finished in 11.8811 seconds


Epoch 29/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 29, Loss 0.142823, Accuracy 0.958640, finished in 11.6871 seconds


Epoch 30/30:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 30, Loss 0.136464, Accuracy 0.961640, finished in 11.7562 seconds
Test Loss 0.549158, Test Accuracy 0.824080


0,1
Accuracy,▁▃▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇████████
Loss,█▆▅▅▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
Time,█▁▁▁▂▃▁▁▁▂▂▂▂▁▂▁▁▁▁▂▁▁▁▁▁▁▁▂▁▁

0,1
Accuracy,0.96164
Loss,0.13646
Time,11.75621


## Show results

In [None]:
epoch_lst = range(30)
plt.plot(epoch_lst, accuracy_lst_SGD, label = "SGD")
#plt.plot(epoch_lst, accuracy_lst_Adam, label = "Adam")
#plt.plot(epoch_lst, accuracy_lst_NAdam, label = "NAdam")
#plt.plot(epoch_lst, accuracy_lst_RMSprop, label = "RMSprop")
#plt.plot(epoch_lst, accuracy_lst_LBFGS, label = "LBFGS")

from matplotlib.pyplot import MultipleLocator
y = MultipleLocator(0.01)    # x轴每10一个刻度
# 设置刻度间隔
ax = plt.gca()
ax.yaxis.set_major_locator(y)

plt.legend()
plt.title('Training Accuracy on CIFAR10')
# plt.savefig(FILE_PATH + 'Training Accuracy on CIFAR10.png')
plt.show()

In [None]:
epoch_lst = range(30)
plt.plot(epoch_lst, loss_lst_SGD, label = "SGD")
plt.plot(epoch_lst, loss_lst_Adam, label = "Adam")
plt.plot(epoch_lst, loss_lst_NAdam, label = "NAdam")
plt.plot(epoch_lst, loss_lst_RMSprop, label = "RMSprop")
#plt.plot(epoch_lst, loss_lst_LBFGS, label = "LBFGS")
plt.legend()
plt.title('Training Loss on CIFAR10')
# plt.savefig(FILE_PATH + 'Training Loss on CIFAR10.png')
plt.show()