# CNN Part 2: Building a CNN Classifier with PyTorch

In [None]:
import os

import torch
import torch.nn as nn
from torchvision import datasets, models, transforms

from datetime import datetime

import matplotlib.pyplot as plt
%matplotlib inline

# Cache torch.hub downloads (such as the pretrained ResNet weights fetched later) under $WORK
torch.hub.set_dir(
    os.environ['WORK']
)

### Downloading dataset

Copy the dataset archive to `$WORK`, extract it there, and delete the archive.

In [None]:
# Copy the dataset archive into $WORK.
# NOTE(review): hard-coded, user-specific absolute path — only valid on the original cluster; adjust if it moves.
! cp /work2/10000/zw427/data.tar.gz $WORK
# Extract the archive into $WORK; the `ls` below checks that Dataset_2 is present
! tar zxf $WORK/data.tar.gz -C $WORK
! ls $WORK/Dataset_2
# Remove the archive (the extracted files stay in $WORK)
! rm $WORK/data.tar.gz

### Hyperparameters

This notebook will use the following hyperparameters:

In [None]:
# Training hyperparameters: Adam learning rate, mini-batch size, and number of training epochs
hp = dict(lr=1e-4, batch_size=16, epochs=5)

## Dataset Loaders and Transforms

Define the paths to the training and validation sets.

In [None]:
# Split folders of the extracted Dataset_2 archive under $WORK; no test split is used while training
train_path, val_path, test_path = (
    os.path.join(os.environ['WORK'], "Dataset_2/Train/"),
    os.path.join(os.environ['WORK'], "Dataset_2/Validation/"),
    None,
)

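Define the dataset loader. Unlike Part 1, the training images receive an additional transformation (data augmentation); the validation images are only resized and converted to tensors.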

In [None]:
def load_datasets(train_path, val_path, test_path):
    """Build ImageFolder datasets for the train, validation and (optional) test splits.

    Each split folder is read with ``datasets.ImageFolder`` (one sub-folder per class).
    Validation/test images are only resized and converted to tensors; the training
    transform is the exercise below.

    Args:
        train_path: root folder of the training images.
        val_path: root folder of the validation images.
        test_path: root folder of the test images, or None to skip the test split.

    Returns:
        (train_dataset, val_dataset, test_dataset); test_dataset is None when test_path is None.
    """
    # NOTE(review): ImageNet-pretrained ResNets are usually fed 224x224 inputs; 244 may be a
    # typo for 224 — confirm (the network itself accepts either size).
    val_img_transform = transforms.Compose([transforms.Resize((244,244)),
                                         transforms.ToTensor()])
    
    #################################################
    ##################### TODO ######################  

    #  Main Modification: Additional transformation
    train_img_transform =
 
    #################################################
    #################### ANSWER #####################
    
    # train_img_transform = transforms.Compose([transforms.AutoAugment(),
    #                                            transforms.Resize((244,244)),
    #                                            transforms.ToTensor()])
    
    #################################################

    # Only the training split uses train_img_transform; validation and test share the plain transform
    train_dataset = datasets.ImageFolder(train_path, transform=train_img_transform)
    val_dataset = datasets.ImageFolder(val_path, transform=val_img_transform) 
    test_dataset = datasets.ImageFolder(test_path, transform=val_img_transform) if test_path is not None else None
    print(f"Train set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
    return train_dataset, val_dataset, test_dataset

In [None]:
# Build the three datasets (test_set is None because test_path is None at this point)
train_set, val_set, test_set = load_datasets(
    train_path=train_path,
    val_path=val_path,
    test_path=test_path,
)

## Construct Dataloaders 
Define a DataLoader constructor, as in Part 1.

In [None]:
def construct_dataloaders(train_set, val_set, test_set, batch_size, shuffle=True):
    """Wrap the datasets in ``torch.utils.data.DataLoader`` objects.

    Args:
        train_set: training dataset; its loader is shuffled when ``shuffle`` is True.
        val_set: validation dataset (never shuffled).
        test_set: test dataset, or None when there is no test split.
        batch_size: mini-batch size used by all loaders.
        shuffle: whether the training loader reshuffles every epoch.

    Returns:
        (train_dataloader, val_dataloader, test_dataloader); the last one is None when
        ``test_set`` is None.
    """
    train_dataloader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=shuffle)
    val_dataloader = torch.utils.data.DataLoader(val_set, batch_size=batch_size)
    # Decide on the test_set argument itself, not on a global path variable: the global may be
    # rebound later in the notebook while the dataset passed in is still None.
    test_dataloader = (
        torch.utils.data.DataLoader(test_set, batch_size=batch_size)
        if test_set is not None
        else None
    )
    return train_dataloader, val_dataloader, test_dataloader

In [None]:
# Training loader is shuffled; batch size comes from the hyperparameter dict
train_dataloader, val_dataloader, test_dataloader = construct_dataloaders(
    train_set, val_set, test_set,
    batch_size=hp["batch_size"],
    shuffle=True,
)

## Visualizing the DesignSafe Dataset

Before moving on to building CNN models, visualize the dataset first.

In [None]:
# Class index -> human-readable damage level.
# NOTE(review): assumes ImageFolder's (alphabetical) class order maps 0/1/2 to low/medium/high —
# confirm against train_set.classes.
label_map = {0: 'low damage', 1: 'medium damage', 2: 'high damage'}

fig, axs = plt.subplots(3, 3, figsize=(8, 8))
for ax in axs.flat:
    # One random training image per panel, with the training transform (incl. any augmentation) applied
    sample_idx = int(torch.randint(len(train_set), size=(1,)))
    img, label = train_set[sample_idx]
    # ToTensor yields channel-first (C, H, W); imshow expects (H, W, C)
    ax.imshow(img.permute(1, 2, 0))
    ax.set_title(label_map[label])
fig.tight_layout()

## Building the Neural Network
### ResNet and Transfer Learning
Instantiate a ResNet-34 model with ImageNet-pretrained weights and replace its final fully connected layer with a new classifier head for the three damage classes.

In [None]:
def getResNet(num_classes=3, freeze_backbone=True):
    """Build a ResNet-34 transfer-learning model with a new classification head.

    The ImageNet-pretrained network is loaded from torchvision. By default every pretrained
    parameter is frozen, so only the newly created head is trained.

    Args:
        num_classes: number of output logits of the new head (3 damage levels by default).
        freeze_backbone: if True, disable gradients for all pretrained parameters.

    Returns:
        A ResNet-34 whose ``fc`` is ``Linear -> ReLU -> Linear(num_classes)``.
    """
    resnet = models.resnet34(weights='IMAGENET1K_V1')

    if freeze_backbone:
        # The attribute is ``requires_grad``; assigning a misspelled name (e.g. ``require_grad``)
        # silently creates an unused attribute and leaves the parameter trainable.
        for param in resnet.parameters():
            param.requires_grad = False

    # input dimension of the original final layer
    num_ftrs = resnet.fc.in_features

    # The new head is created after freezing, so its parameters keep requires_grad=True
    resnet.fc = nn.Sequential(
        nn.Linear(num_ftrs, num_ftrs),
        nn.ReLU(),
        nn.Linear(num_ftrs, num_classes),
    )
    return resnet

In [None]:
# Instantiate the ResNet-34 transfer-learning model defined in the previous cell
resnet = getResNet()

### Check for GPU and move model to correct device 

In [None]:
# Prefer the first CUDA GPU; fall back to the CPU when none is available
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

Move the ResNet model to the GPU (or to the CPU if no GPU is found).

In [None]:
# nn.Module.to moves the parameters in place and returns the same module
resnet = resnet.to(device)

## Training the Neural Network

### Define Loss Function, Optimizer, and Label smoothing
Use the same optimizer and loss function as in Part 1, but add label smoothing to the loss.

In [None]:
# Adam over all model parameters, learning rate taken from hp
opt = torch.optim.Adam(params=resnet.parameters(), lr=hp["lr"])
# Cross-entropy with label smoothing of 0.1
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

### Reduce the Learning Rate on Plateau

In [None]:
# The object created in this cell is passed to train() below as its `scheduler` argument.
#################################################
##################### TODO ######################
    
# Add a learning rate scheduler so that the learning rate can change throughout the optimization procedure

scheduler = 
    
#################################################
############### POSSIBLE ANSWER #################

# ReduceLROnPlateau(mode='min') multiplies the lr by `factor` once the monitored value
# (validation loss, if stepped with val_loss) has not improved for `patience` epochs,
# never going below `min_lr`.
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, 
#                                                        mode='min',
#                                                        factor=0.1, 
#                                                        patience=2,
#                                                        min_lr=1e-8)
    
#################################################

### Setting up Checkpoints

In [None]:
def load_checkpoint(checkpoint_path, DEVICE):
    """Load a checkpoint saved with ``torch.save``.

    Args:
        checkpoint_path: path of the checkpoint file.
        DEVICE: device name or ``torch.device`` that stored tensors are remapped to.

    Returns:
        Whatever object was saved (here, a dict of training state).
    """
    target_device = torch.device(DEVICE)
    return torch.load(checkpoint_path, map_location=target_device)

Create a directory to store models and define a file name for the best model.

In [None]:
# For saving the trained model
model_folder_path = os.path.join(os.environ['WORK'], "cnn2_output_model") 
os.makedirs(model_folder_path,exist_ok=True)

# filename for the best model
checkpoint_file = os.path.join(model_folder_path, "best_model.pt")

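If a checkpoint from a previous run exists, run this cell to load its best validation accuracy. A new run will then only overwrite the checkpoint when it beats that accuracy. Note that this loads only the accuracy; it does not restore the model weights.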

In [None]:
# load the checkpoint that has the best performance in previous experiments
prev_best_val_acc = None
checkpoint_file = os.path.join(model_folder_path, "best_model.pt")
if os.path.exists(checkpoint_file):
    checkpoint = load_checkpoint(checkpoint_file, device)
    prev_best_val_acc = checkpoint['accuracy']

### Train and Model Evaluation Functions

In [None]:
@torch.no_grad()
def eval_model(data_loader, model, loss_fn, DEVICE):
    """Evaluate ``model`` on every batch of ``data_loader`` without tracking gradients.

    Loss and accuracy are true per-sample averages: each batch is weighted by its size,
    so a short final batch does not skew the result.

    Args:
        data_loader: iterable of (inputs, integer class labels) batches.
        model: network producing logits of shape (batch, num_classes); left in eval mode.
        loss_fn: criterion called as ``loss_fn(logits, labels)``. Assumed to average over the
            batch (``reduction='mean'``, the CrossEntropyLoss default) — TODO confirm if another
            criterion is used.
        DEVICE: device the batches are moved to.

    Returns:
        (mean_loss, accuracy) as 0-dim float tensors on ``DEVICE``.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    total_loss = torch.tensor(0.0, device=DEVICE)
    total_correct = torch.tensor(0.0, device=DEVICE)
    n_samples = 0

    for x, y in data_loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        pred = model(x)
        batch_size = len(x)
        # loss_fn already returns the batch mean; re-weight by batch size instead of dividing again
        total_loss += loss_fn(pred, y) * batch_size
        total_correct += (torch.argmax(pred, dim=1) == y).sum()
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("eval_model received an empty data_loader")
    return total_loss / n_samples, total_correct / n_samples

def train(train_loader, val_loader, model, opt, scheduler, loss_fn, epochs, DEVICE, checkpoint_file, prev_best_val_acc):
    """Train ``model`` for ``epochs`` epochs, validating after every epoch.

    Training metrics are per-sample averages over the epoch: each batch's mean loss is
    weighted by its size, so a short final batch does not skew the result.

    Args:
        train_loader: DataLoader yielding (inputs, integer labels) training batches.
        val_loader: DataLoader evaluated with ``eval_model`` after each epoch.
        model: network optimized in place.
        opt: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler (stepping it is the TODO below).
        loss_fn: criterion called as ``loss_fn(logits, labels)``. Assumed to average over the
            batch (``reduction='mean'``, the CrossEntropyLoss default) — TODO confirm if another
            criterion is passed.
        epochs: number of passes over ``train_loader``.
        DEVICE: device the batches are moved to.
        checkpoint_file: path the best model is saved to (saving is the TODO below).
        prev_best_val_acc: best validation accuracy of an earlier run (tensor or number), or
            None to start from 0. A new checkpoint is only written when it is beaten.
    """
    best_val_acc = torch.tensor(0.0).to(DEVICE) if prev_best_val_acc is None else prev_best_val_acc

    for epoch in range(epochs):
        model.train(True)

        loss_sum, n_correct, n_seen = 0.0, 0, 0
        start_time = datetime.now()

        for x, y in train_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            pred = model(x)
            loss = loss_fn(pred, y)

            opt.zero_grad()
            loss.backward()
            opt.step()

            # loss is already the batch mean: weight it by batch size (do not divide again)
            batch_size = len(x)
            loss_sum += loss.item() * batch_size
            n_correct += (torch.argmax(pred, dim=1) == y).sum().item()
            n_seen += batch_size

        val_loss, val_acc = eval_model(val_loader, model, loss_fn, DEVICE)

        # Wall-clock seconds for training + validation; total_seconds() also counts whole days
        total_time = (datetime.now() - start_time).total_seconds()

        #################################################
        ##################### TODO ######################

        # Learning rate reducer takes action

        #################################################
        #################### ANSWER #####################

        # scheduler.step(val_loss)

        #################################################

        # Read the lr from the optimizer: works for every scheduler type (ReduceLROnPlateau
        # lacks get_last_lr() in older PyTorch releases) and reflects any step taken above.
        print(f'lr for this epoch is {[group["lr"] for group in opt.param_groups]}')

        avg_loss = loss_sum / max(n_seen, 1)
        avg_acc = n_correct / max(n_seen, 1)

        # float() accepts both tensors and plain numbers (e.g. a float stored in an old checkpoint)
        if float(val_acc) > float(best_val_acc):
            print(f"\nPrev Best Val Acc: {float(best_val_acc):.4f} < Cur Val Acc: {float(val_acc):.4f}")
            print("Saving the new best model...")
            
            #################################################
            ##################### TODO ######################

            # Save the best model that has the highest val accuracy

            torch.save({})
            
            #################################################
            #################### ANSWER #####################

            # torch.save({
            #    'epoch':epoch,
            #    'model_state_dict':model.state_dict(),
            #    'accuracy':val_acc,
            #    'loss':val_loss
            # }, checkpoint_file)

            #################################################

            best_val_acc = val_acc
            print("Finished saving model\n")

        # Print the metrics (should be same on all machines)
        print(f"\n(Epoch {epoch+1}/{epochs}) Time: {total_time:.1f}s")
        print(f"(Epoch {epoch+1}/{epochs}) Average train loss: {avg_loss}, Average train accuracy: {avg_acc}")
        print(f"(Epoch {epoch+1}/{epochs}) Val loss: {float(val_loss)}, Val accuracy: {float(val_acc)}")
        print(f"(Epoch {epoch+1}/{epochs}) Current best val acc: {float(best_val_acc)}\n")

### Train Model 
Task: Compare validation accuracy with training accuracy across epochs and check whether the model overfits.

In [None]:
# Run the training loop; a checkpoint is only written when validation accuracy beats prev_best_val_acc
train(
    train_dataloader,
    val_dataloader,
    resnet,
    opt,
    scheduler,
    loss_fn,
    epochs=hp["epochs"],
    DEVICE=device,
    checkpoint_file=checkpoint_file,
    prev_best_val_acc=prev_best_val_acc,
)

###  Optional Exercise
Above, you trained a ResNet-34 model with data augmentation, label smoothing, and a learning-rate scheduler. Try training the model without these techniques and compare the training speed and performance.

In [None]:
#################################################
##################### TODO ######################
    
# Do not use torch.save
# It may override your previous model 
    
#################################################
############### POSSIBLE ANSWER #################

# def load_datasets(train_path, val_path, test_path):
#     val_img_transform = transforms.Compose([transforms.Resize((244, 244)),
#                                             transforms.ToTensor()])
#     train_img_transform = transforms.Compose([transforms.Resize((244, 244)),
#                                               transforms.ToTensor()])
#     train_dataset = datasets.ImageFolder(train_path, transform=train_img_transform)
#     val_dataset = datasets.ImageFolder(val_path, transform=val_img_transform)
#     test_dataset = datasets.ImageFolder(test_path, transform=val_img_transform) if test_path is not None else None
#     print(f"Train set size: {len(train_dataset)}, Validation set size: {len(val_dataset)}")
#     return train_dataset, val_dataset, test_dataset


# train_set, val_set, test_set = load_datasets(train_path, val_path, test_path)

# def construct_dataloaders(train_set, val_set, test_set, batch_size, shuffle=True):
#     train_dataloader = torch.utils.data.DataLoader(train_set, batch_size, shuffle)
#     val_dataloader = torch.utils.data.DataLoader(val_set, batch_size)
#     test_dataloader = torch.utils.data.DataLoader(test_set, batch_size) if test_set is not None else None
#     return train_dataloader, val_dataloader, test_dataloader


# train_dataloader, val_dataloader, test_dataloader = construct_dataloaders(train_set, val_set, test_set, hp["batch_size"], True)

# resnet = getResNet()
# resnet.to(device)

# opt = torch.optim.Adam(resnet.parameters(), lr=hp["lr"])

# loss_fn = nn.CrossEntropyLoss()

# def train(train_loader, val_loader, model, opt, loss_fn, epochs, DEVICE):
#     n = len(train_loader)
    
#     best_val_acc = torch.tensor(0.0).to(DEVICE)
    
#     for epoch in range(epochs):
#         model.train(True)
        
#         avg_loss, val_loss, val_acc, avg_acc = 0.0, 0.0, 0.0, 0.0
        
#         start_time = datetime.now()
        
#         for x, y in train_loader:
#             x, y = x.to(DEVICE), y.to(DEVICE)
#             pred = model(x)
#             loss = loss_fn(pred, y)
            
#             opt.zero_grad()
#             loss.backward()
#             opt.step()
            
#             avg_loss += loss.item() / len(x)
#             pred_label = torch.argmax(pred, axis=1)
#             avg_acc += torch.sum(pred_label == y) / len(x)
        
#         val_loss, val_acc = eval_model(val_loader, model, loss_fn, DEVICE)
        
#         end_time = datetime.now()
        
#         total_time = torch.tensor((end_time - start_time).seconds).to(DEVICE)
        
#         avg_loss, avg_acc = avg_loss / n, avg_acc / n
        
#         # Print the metrics (should be same on all machines)
#         print(f"\n(Epoch {epoch + 1}/{epochs}) Time: {total_time}s")
#         print(f"(Epoch {epoch + 1}/{epochs}) Average train loss: {avg_loss}, Average train accuracy: {avg_acc}")
#         print(f"(Epoch {epoch + 1}/{epochs}) Val loss: {val_loss}, Val accuracy: {val_acc}")

# train(train_dataloader, val_dataloader, resnet, opt, loss_fn, hp["epochs"], device)
    
#################################################

## Load the Best Model and Explore Performance

### Read the model checkpoint

In [None]:
def load_checkpoint(checkpoint_path, DEVICE):
    """Load a ``torch.save`` checkpoint, remapping stored tensors onto ``DEVICE``.

    Args:
        checkpoint_path: path of the checkpoint file.
        DEVICE: device name (e.g. 'cpu' or 'cuda') or ``torch.device`` used as ``map_location``.

    Returns:
        The saved object (here, a dict of training state).
    """
    return torch.load(checkpoint_path, map_location=DEVICE)

def load_model_fm_checkpoint(checkpoint, primitive_model, DEVICE):
    """Copy the weights stored under ``checkpoint['model_state_dict']`` into ``primitive_model``.

    Args:
        checkpoint: dict returned by ``load_checkpoint``.
        primitive_model: freshly built module with the same architecture as the saved one.
        DEVICE: device the loaded model is moved to.

    Returns:
        ``primitive_model`` itself, with the loaded weights, on ``DEVICE``.
    """
    saved_weights = checkpoint['model_state_dict']
    primitive_model.load_state_dict(saved_weights)
    model_on_device = primitive_model.to(DEVICE)
    return model_on_device

def getResNet(num_classes=3, freeze_backbone=True):
    """Build a ResNet-34 transfer-learning model with a new classification head.

    The ImageNet-pretrained network is loaded from torchvision. By default every pretrained
    parameter is frozen, so only the newly created head is trained.

    Args:
        num_classes: number of output logits of the new head (3 damage levels by default).
        freeze_backbone: if True, disable gradients for all pretrained parameters.

    Returns:
        A ResNet-34 whose ``fc`` is ``Linear -> ReLU -> Linear(num_classes)``.
    """
    resnet = models.resnet34(weights='IMAGENET1K_V1')

    if freeze_backbone:
        # The attribute is ``requires_grad``; assigning a misspelled name (e.g. ``require_grad``)
        # silently creates an unused attribute and leaves the parameter trainable.
        for param in resnet.parameters():
            param.requires_grad = False

    # input dimension of the original final layer
    num_ftrs = resnet.fc.in_features

    # The new head is created after freezing, so its parameters keep requires_grad=True
    resnet.fc = nn.Sequential(
        nn.Linear(num_ftrs, num_ftrs),
        nn.ReLU(),
        nn.Linear(num_ftrs, num_classes),
    )
    return resnet

In [None]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Path of the best-model checkpoint file written during training (a file, despite the name)
model_dump_dir = checkpoint_file
model = None  # stays None when no checkpoint is found

try:
    ckpt = load_checkpoint(model_dump_dir, DEVICE)
    model = load_model_fm_checkpoint(ckpt, getResNet(), DEVICE)
    # Switch to inference mode: a freshly built module is in training mode, where BatchNorm
    # would normalize with the statistics of the (single-image) batch instead of its running stats.
    model.eval()
except FileNotFoundError: 
    print(f"{model_dump_dir} does not exist, please first train the model before performing inference!") 

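### Load the Test Dataset

No separate test folder is used in this notebook, so the validation images are loaded again as the test set.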


In [None]:
def load_test_datasets(test_path):
    """Load an ImageFolder dataset for inference, resized to 244x244 and converted to tensors.

    Args:
        test_path: root folder with one sub-folder per class.

    Returns:
        The ``datasets.ImageFolder`` dataset.

    Raises:
        OSError / RuntimeError: if ``test_path`` cannot be read as an image folder (e.g. it does
            not exist). A message naming the path is printed first, then the error is re-raised.
    """
    img_transform = transforms.Compose([transforms.Resize((244,244)),transforms.ToTensor()])
    try:
        test_dataset = datasets.ImageFolder(test_path, transform=img_transform)
    except (OSError, RuntimeError):
        # Report the path, then propagate the real error: continuing would only fail later on
        # the unbound ``test_dataset`` with a confusing UnboundLocalError.
        print(f"test_path: {test_path} does not exist!")
        raise
    print(f"Test set size: {len(test_dataset)}")
    return test_dataset

In [None]:
# The validation folder is reused as the test set here (no separate test folder is used in this notebook)
test_path = os.path.join(os.environ['WORK'], "Dataset_2/Validation/")
test_set = load_test_datasets(test_path=test_path)

### Perform Inference on a Random Image

Tasks:
1. See if predictions match labels
2. Randomly choose images and run predictions

In [None]:
# Draw a random index; size=(1,1) gives a 1x1 integer tensor that is used directly as the index
random_idx = torch.randint(0, len(test_set), size=(1,1))
sample_image, label = test_set[random_idx]
# ToTensor produced (C, H, W); permute to (H, W, C) for imshow
plt.imshow(sample_image.permute(1,2,0))
plt.show()
print(f"label: {label} for image_idx: {random_idx}")

# Add a leading batch dimension -> (1, C, H, W) and move the image to the model's device
sample = sample_image.unsqueeze(0).to(DEVICE)

#################################################
##################### TODO ######################
    
# Make predictions with the model
# NOTE(review): `model` is None if no checkpoint was found when loading. For inference the model
# should be in eval mode (model.eval()), and the forward pass can be wrapped in torch.no_grad().

prediction = 
    
#################################################
#################### ANSWER #####################
    
# prediction = torch.argmax(model(sample))
    
#################################################

print(f"prediction result: {prediction} actual result: {label}")