# Use Adam until the convergence point, then switch to SGDM (SGD with momentum)

In [13]:
import torch
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
from tqdm.auto import tqdm
import time
import matplotlib.pyplot as plt
from PIL import Image

# import wandb

# Pick the compute device, preferring Apple MPS, then the first CUDA GPU, then CPU.
# `torch.has_mps` is deprecated and only reports that PyTorch was *built* with MPS,
# so query runtime availability through `torch.backends.mps` instead. The getattr
# guard keeps this working on PyTorch builds that predate the MPS backend.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Mini-batch size used when building the DataLoaders.
BATCH_SIZE = 256

mps


## Import dataset

In [14]:
def load_data():
    """Download (if needed) CIFAR-10 and build the train/test datasets and loaders.

    The training pipeline resizes to 32x32, applies a padded random crop and a
    random horizontal flip, then converts to a tensor. The test pipeline only
    resizes and converts to a tensor.

    Returns:
        ``(train_data, test_data, train_loader, test_loader)``. The training
        loader shuffles every epoch; the test loader keeps dataset order. Both
        use the module-level ``BATCH_SIZE``.
    """
    img_shape = (32, 32)
    # Pass the torchvision enum rather than the PIL integer constant: Resize
    # expects an InterpolationMode and warns when given `Image.BILINEAR`.
    bilinear = transforms.InterpolationMode.BILINEAR

    transform_train = transforms.Compose([
        transforms.Resize(img_shape, bilinear),
        transforms.RandomCrop(img_shape, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    transform_test = transforms.Compose([
        transforms.Resize(img_shape, bilinear),
        transforms.ToTensor(),
    ])

    train_data = datasets.CIFAR10(root='data', train=True, download=True, transform=transform_train)
    test_data = datasets.CIFAR10(root='data', train=False, download=True, transform=transform_test)
    print('Number of training data:', len(train_data))
    print('Number of testing data:', len(test_data))

    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

    return train_data, test_data, train_loader, test_loader
  
# Build the datasets and loaders once at module level; the training cell reads
# `train_loader` and `test_loader` from here.
train_data, test_data, train_loader, test_loader = load_data()

  transforms.Resize(img_shape, Image.BILINEAR),
  transforms.Resize(img_shape, Image.BILINEAR),


Files already downloaded and verified
Files already downloaded and verified
Number of training data: 50000
Number of testing data: 10000


## Build model

In [15]:
class CNN(nn.Module):
    """Small two-stage convolutional classifier for 3-channel 32x32 images.

    Each stage is Conv(5x5, stride 1, padding 2) -> BatchNorm -> ReLU -> MaxPool(2).
    The spatial size therefore halves twice (32 -> 16 -> 8), and the flattened
    feature vector has 64 * 8 * 8 = 4096 entries feeding one linear layer.

    Args:
        num_classes: Number of logits produced by the final layer. Defaults to
            100 so existing ``CNN()`` call sites behave exactly as before.
            CIFAR-10 has only 10 classes, so ``CNN(num_classes=10)`` avoids
            carrying 90 logits that no label ever selects.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # 64 channels * 8 * 8 spatial positions after two 2x poolings of a 32x32 input.
        self.out = nn.Linear(64 * 8 * 8, num_classes)

    def forward(self, x):
        """Return raw (unnormalised) logits with one row per input sample."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Collapse every non-batch dimension into a single feature axis.
        x = torch.flatten(x, 1)
        return self.out(x)

In [16]:
def train(model, train_loader, optimizer, loss_func, epochs=30):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    Args:
        model: Network to optimise; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        optimizer: Optimizer already bound to ``model``'s parameters.
        loss_func: Callable mapping ``(logits, labels)`` to a scalar loss tensor.
        epochs: Number of full passes over ``train_loader``.

    Returns:
        ``(accuracy_lst, loss_lst)``: per-epoch training accuracy (fraction of
        correctly classified samples) and per-epoch loss averaged over batches,
        both as lists of Python floats.

    Raises:
        ValueError: If ``train_loader`` yields no batches during an epoch.
    """
    # Send batches to wherever the model lives instead of relying on a
    # notebook-global `device`, so the function works for any model placement.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    accuracy_lst = []
    loss_lst = []
    model.train()
    for epoch in tqdm(range(epochs), desc="Training progress", colour="#00ff00"):
        total_loss = 0.0
        correct = 0
        num_labels = 0
        counter = 0
        start_time = time.time()
        for X, y in tqdm(train_loader, leave=False, desc=f"Epoch {epoch + 1}/{epochs}", colour="#005500"):
            X = X.to(target_device)
            y = y.to(target_device)

            optimizer.zero_grad()
            output = model(X)
            loss = loss_func(output, y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            # detach() keeps the metric computation out of the autograd graph.
            predicted = output.detach().argmax(dim=1)
            # Accumulate as a plain int so the epoch accuracy is a Python float.
            correct += (predicted == y).sum().item()
            num_labels += len(y)
            counter += 1

        if counter == 0:
            raise ValueError("train_loader yielded no batches")

        epoch_loss = total_loss / counter
        epoch_accuracy = correct / num_labels
        accuracy_lst.append(epoch_accuracy)
        loss_lst.append(epoch_loss)

        end_time = time.time()

        # wandb.log({'Accuracy': accuracy_lst[-1], 'Loss': loss_lst[-1], 'Time': end_time-start_time})

        print('Epoch %d, Loss %4f, Accuracy %4f, finished in %.4f seconds' % (epoch+1, epoch_loss, epoch_accuracy, end_time-start_time))

    return accuracy_lst, loss_lst

In [17]:
def evaluate(model, test_loader, loss_func):
    """Evaluate ``model`` on ``test_loader`` and print the loss and accuracy.

    The model is switched to evaluation mode and left in it afterwards.

    Args:
        model: Network to evaluate.
        test_loader: Iterable yielding ``(inputs, labels)`` batches.
        loss_func: Callable mapping ``(logits, labels)`` to a scalar loss tensor.

    Returns:
        ``(avg_loss, accuracy)`` as Python floats: the loss averaged over
        batches and the fraction of correctly classified samples.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    # Send batches to wherever the model lives instead of relying on a
    # notebook-global `device`.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    total_loss = 0.0
    correct = 0
    num_labels = 0
    counter = 0
    model.eval()
    # No gradients are needed for evaluation; disabling autograd avoids
    # building (and holding memory for) a graph on every batch.
    with torch.no_grad():
        for X, y in test_loader:
            X = X.to(target_device)
            y = y.to(target_device)

            output = model(X)
            total_loss += loss_func(output, y).item()

            predicted = output.argmax(dim=1)
            correct += (predicted == y).sum().item()
            num_labels += len(y)
            counter += 1

    if counter == 0:
        raise ValueError("test_loader yielded no batches")

    avg_loss = total_loss / counter
    accuracy = correct / num_labels
    print('Test Loss %4f, Test Accuracy %4f' % (avg_loss, accuracy))
    return avg_loss, accuracy

## Training

In [18]:
# Baseline run: train a fresh CNN with Adam for 100 epochs, then report test metrics.
lr = 0.001  # optimizer learning rate
# NOTE(review): the "NAdam" labels in this cell look stale — the optimizer built below is
# torch.optim.Adam (and the disabled wandb run is named 'Adam'). Confirm before comparing runs.
# NAdam_run = wandb.init(project="CSI 5340 Project", entity="kwang126", name='Adam')
model = CNN().to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = lr, weight_decay=5e-4)

# Per-epoch training accuracy and loss histories returned by train().
accuracy_lst_NAdam, loss_lst_NAdam = train(model, train_loader, optimizer, loss_func, 100)
evaluate(model, test_loader, loss_func)
# NAdam_run.finish()

Training progress:   0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1/100:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 1, Loss 1.571526, Accuracy 0.436940, finished in 9.5242 seconds


Epoch 2/100:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 2, Loss 1.252128, Accuracy 0.555360, finished in 9.3219 seconds


Epoch 3/100:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 3, Loss 1.117868, Accuracy 0.605340, finished in 9.6804 seconds


Epoch 4/100:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 4, Loss 1.041975, Accuracy 0.636120, finished in 10.3401 seconds


Epoch 5/100:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 5, Loss 1.001487, Accuracy 0.648540, finished in 9.9824 seconds


Epoch 6/100:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 6, Loss 0.957515, Accuracy 0.667720, finished in 10.7398 seconds


Epoch 7/100:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 7, Loss 0.926388, Accuracy 0.679940, finished in 10.1472 seconds


Epoch 8/100:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 8, Loss 0.903981, Accuracy 0.688820, finished in 9.6908 seconds


Epoch 9/100:   0%|          | 0/196 [00:00<?, ?it/s]

Epoch 9, Loss 0.885048, Accuracy 0.692600, finished in 9.6198 seconds


Epoch 10/100:   0%|          | 0/196 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [8]:
from tqdm.notebook import tqdm_notebook
import time

In [2]:
# Renders a coloured tqdm progress bar over ten one-second steps, printing each index.
# NOTE(review): looks like a leftover scratch cell for checking the progress-bar display;
# consider deleting it so the notebook runs top to bottom without a 10 s pause.
for i in tqdm(range(10), desc=f"Training progress", colour="#00ff00"):
    print(i)
    time.sleep(1)

Training progress:   0%|          | 0/10 [00:00<?, ?it/s]

0
1
2
3
4
5
6
7
8
9
