# Notebook to debug Hyperparameter Optimization

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

import argparse

In [None]:
# Keys used to index the per-split dictionaries (data directories, transforms, data loaders).
train_suffix ='train'
val_suffix = 'val'
# Fraction of a dataset that the train/validate/test loops process before stopping early.
MAX_SAMPLES_PROPORTION = 0.01

def log_metrics(loss, running_corrects, running_samples, total_samples):
    '''Print a one-line progress report for the current pass over a dataset.

    Args:
        loss: Loss of the most recent batch; only its `.item()` value is printed.
        running_corrects: Number of correct predictions so far.
        running_samples: Number of samples processed so far (must be non-zero).
        total_samples: Size of the whole dataset, used for the progress percentage.
    '''
    template = "Images [{}/{} ({:.0f}%)] Loss: {:.2f} Accuracy: {}/{} ({:.2f}%)"
    progress_pct = 100.0 * (running_samples / total_samples)
    accuracy_pct = 100.0 * (running_corrects / running_samples)
    message = template.format(
        running_samples,
        total_samples,
        progress_pct,
        loss.item(),
        running_corrects,
        running_samples,
        accuracy_pct,
    )
    print(message)
    
def test(model, test_loader, criterion):
    '''Evaluate `model` on the test loader and print its accuracy and mean loss.

    The model is switched to eval mode and gradients are disabled. Note that, despite
    the banner printed below, the loop stops once more than MAX_SAMPLES_PROPORTION of
    the dataset has been processed (see the NOTE inside the loop), so the reported
    metrics are averaged over the samples actually seen.

    Args:
        model: Callable model mapping a batch of inputs to per-class scores.
        test_loader: Iterable of `(inputs, labels)` batches exposing a `.dataset`.
        criterion: Loss function called as `criterion(outputs, labels)`.
    '''
    print("Testing Model on Whole Testing Dataset")
    model.eval()
    running_loss = 0
    running_corrects = 0
    running_samples = 0
    total_samples = len(test_loader.dataset)

    # No backward pass happens here, so skip building the autograd graph.
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            # Batch loss times batch size, so the final division yields a per-sample value.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data).item()
            running_samples += len(inputs)

            #NOTE: Comment lines below to train and test on whole dataset
            if running_samples > (MAX_SAMPLES_PROPORTION * total_samples):
                break

    # Average over the samples actually processed: because of the early break above,
    # dividing by the full dataset size would badly under-report both metrics.
    total_loss = running_loss / running_samples
    total_acc = running_corrects / running_samples
    print(f"Testing Accuracy: {100*total_acc}, Testing Loss: {total_loss}")


def validate(model, validation_loader, criterion):
    '''Evaluate `model` on the validation loader and return the mean loss.

    The model is switched to eval mode and gradients are disabled. Iteration stops
    once more than MAX_SAMPLES_PROPORTION of the dataset has been processed.

    Args:
        model: Callable model mapping a batch of inputs to per-class scores.
        validation_loader: Iterable of `(inputs, labels)` batches exposing a `.dataset`.
        criterion: Loss function called as `criterion(outputs, labels)`.

    Returns:
        The accumulated loss divided by the number of samples actually processed.
    '''
    model.eval()
    running_loss = 0
    running_corrects = 0
    running_samples = 0
    total_samples = len(validation_loader.dataset)

    # Validation never calls backward(), so avoid tracking gradients at all.
    with torch.no_grad():
        for inputs, labels in validation_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)
            # Batch loss times batch size, so the final division yields a per-sample value.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data).item()
            running_samples += len(inputs)
            if running_samples % 2000 == 0:
                log_metrics(loss, running_corrects, running_samples, total_samples)

            #NOTE: Comment lines below to train and test on whole dataset
            if running_samples > (MAX_SAMPLES_PROPORTION * total_samples):
                break

    epoch_loss = running_loss / running_samples
    epoch_acc = running_corrects / running_samples
    print(f"Phase validation, Epoc loss {epoch_loss}, Epoc accuracy {epoch_acc}")
    return epoch_loss


def train(model, train_loader, criterion, optimizer):
    '''Run one training pass over the training loader and return the mean loss.

    Iteration stops once more than MAX_SAMPLES_PROPORTION of the dataset has been
    processed; the returned loss is averaged over the samples actually seen.

    Args:
        model: Model to optimise; it is put in train mode.
        train_loader: Iterable of `(inputs, labels)` batches exposing a `.dataset`.
        criterion: Loss function called as `criterion(outputs, labels)`.
        optimizer: Optimizer stepped once per batch.

    Returns:
        The accumulated loss divided by the number of processed samples.
    '''
    model.train()

    dataset_size = len(train_loader.dataset)
    loss_sum = 0.0
    hits = 0
    seen = 0

    for batch_inputs, batch_labels in train_loader:
        # Forward pass.
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_labels)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Bookkeeping for the epoch-level metrics.
        predictions = torch.max(logits, 1)[1]
        loss_sum += batch_loss.item() * batch_inputs.size(0)
        hits += (predictions == batch_labels.data).sum().item()
        seen += len(batch_inputs)

        if seen % 2000 == 0:
            log_metrics(batch_loss, hits, seen, dataset_size)

        #NOTE: Comment lines below to train and test on whole dataset
        if seen > (MAX_SAMPLES_PROPORTION * dataset_size):
            break

    epoch_loss = loss_sum / seen
    epoch_acc = hits / seen
    print(f"Phase training, Epoc loss {epoch_loss}, Epoc accuracy {epoch_acc}")
    return epoch_loss

def train_with_early_stopping(model, datasets_loader, epochs, loss_criterion, optimizer):
    '''Train for up to `epochs` epochs, stopping when the validation loss stops improving.

    Each epoch runs one training pass followed by one validation pass. The best
    validation loss seen so far is remembered across epochs; the first epoch whose
    validation loss does not beat it ends training.

    Args:
        model: Model to train.
        datasets_loader: Mapping with a training loader under `train_suffix` and a
            validation loader under `val_suffix`.
        epochs: Maximum number of epochs to run.
        loss_criterion: Loss function passed through to `train` and `validate`.
        optimizer: Optimizer passed through to `train`.
    '''
    # Must live outside the loop, otherwise every epoch compares against a fresh sentinel.
    best_loss = float("inf")
    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}")
        train(model, datasets_loader[train_suffix], loss_criterion, optimizer)
        validate_epoch_loss = validate(model, datasets_loader[val_suffix], loss_criterion)
        if validate_epoch_loss >= best_loss:
            print(f"Validation loss did not improve on {best_loss}; stopping early")
            break
        best_loss = validate_epoch_loss

In [None]:
def net(num_classes: int):
    '''Build a pretrained ResNet-50 whose final layer outputs `num_classes` scores.

    All pretrained parameters are frozen; only the replacement final layer, created
    here, keeps trainable parameters.
    '''
    # NOTE(review): newer torchvision releases appear to prefer a `weights=` argument
    # over `pretrained=` - confirm against the installed version.
    backbone = models.resnet50(pretrained=True)

    # Freeze every existing parameter so the optimizer only updates the new head.
    for weight in backbone.parameters():
        weight.requires_grad_(False)

    # Swap the original classifier for a linear layer sized to our label set.
    head = nn.Linear(backbone.fc.in_features, num_classes)
    backbone.fc = nn.Sequential(head)

    return backbone

def create_data_loaders(train_data_dir: str, valid_data_dir: str, batch_size: int):
    '''Create PyTorch data loaders for the training and validation image folders.

    Args:
        train_data_dir: Directory of training images, one sub-folder per class.
        valid_data_dir: Directory of validation images, one sub-folder per class.
        batch_size: Number of images per batch for both loaders.

    Returns:
        A dict mapping `train_suffix` and `val_suffix` to their `DataLoader`.
    '''
    # Imported here because the file-level imports do not bring `datasets` into scope.
    from torchvision import datasets

    data_dir = {train_suffix: train_data_dir, val_suffix: valid_data_dir}

    # Same normalisation for both splits; only the training split is randomly augmented.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    data_transforms = {
        train_suffix: transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        val_suffix: transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    dataloaders = {}
    for split in [train_suffix, val_suffix]:
        image_dataset = datasets.ImageFolder(data_dir[split], data_transforms[split])
        # Both splits are shuffled: the debug loops only consume the first
        # MAX_SAMPLES_PROPORTION of each loader, and ImageFolder orders samples by class.
        dataloaders[split] = torch.utils.data.DataLoader(
            image_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    return dataloaders

def get_num_classes(dataloader) -> int:
    '''Return how many classes the training dataset behind `dataloader` exposes.

    `dataloader` is the mapping of loaders keyed by split; the count is the length of
    the training dataset's `classes` attribute.
    '''
    training_dataset = dataloader[train_suffix].dataset
    return len(training_dataset.classes)

In [None]:
def main(args):
    '''Build the data loaders and model, train with early stopping, and save the weights.

    Args:
        args: Parsed command-line namespace. Reads `data_dir` (training images),
            `batch_size`, `lr` and `epochs`; optionally `valid_dir` and `model_dir`.
            When those two are absent, the validation directory falls back to the
            SM_CHANNEL_VALID environment variable and the model directory to
            SM_MODEL_DIR, then SM_CHANNEL_DIR, then the current directory.

    Raises:
        ValueError: If no validation directory can be determined.
    '''
    # Local import: the notebook only imports `os` in a later cell.
    import os

    valid_dir = getattr(args, "valid_dir", None) or os.environ.get("SM_CHANNEL_VALID")
    if not valid_dir:
        raise ValueError(
            "No validation directory: pass --valid-dir or set SM_CHANNEL_VALID"
        )

    dataset_loaders = create_data_loaders(args.data_dir, valid_dir, args.batch_size)
    num_classes = get_num_classes(dataset_loaders)

    model = net(num_classes)

    loss_criterion = nn.CrossEntropyLoss()
    # Only the replacement head is handed to the optimizer; the backbone is frozen in net().
    optimizer = optim.Adam(model.fc.parameters(), lr=args.lr)

    train_with_early_stopping(model, dataset_loaders, args.epochs, loss_criterion, optimizer)

    model_dir = (
        getattr(args, "model_dir", None)
        or os.environ.get("SM_MODEL_DIR")
        or os.environ.get("SM_CHANNEL_DIR")
        or "."
    )
    torch.save(model.state_dict(), os.path.join(model_dir, "model.pth"))
    
# def model_fn(model_dir):
#     model = Net()
#     with open(os.path.join(model_dir, "model.pth"), "rb") as f:
#         model.load_state_dict(torch.load(f))
#     return model


# def save_model(model, model_dir):
#     logger.info("Saving the model.")
#     path = os.path.join(model_dir, "model.pth")
#     torch.save(model.cpu().state_dict(), path)

In [None]:
if __name__=='__main__':
    # Local import: the environment-variable defaults below need `os`, which the
    # notebook only imports in a later cell.
    import os

    parser=argparse.ArgumentParser(description="Training Job for Hyperparameter tuning")

    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        metavar="N",
        help="input batch size for training (default: 64)",
    )

    parser.add_argument(
        "--epochs",
        type=int,
        default=2,
        metavar="N",
        help="number of epochs to train (default: 2)",
    )
    # `--learning-rate` is accepted as an alias because the test cell of this notebook
    # passes that spelling.
    # NOTE(review): a default of 1.0 looks high for Adam - confirm it is intended.
    parser.add_argument(
        "--lr",
        "--learning-rate",
        dest="lr",
        type=float,
        default=1.0,
        metavar="LR",
        help="learning rate (default: 1.0)",
    )


    # Container environment
#     parser.add_argument("--hosts", type=list, default=json.loads(os.environ["SM_HOSTS"]))
#     parser.add_argument("--current-host", type=str, default=os.environ["SM_CURRENT_HOST"])
#     parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
    # Fall back to the environment variable, and demand the flag explicitly when the
    # variable is missing instead of failing with a KeyError while building the parser.
    parser.add_argument(
        "--data-dir",
        type=str,
        default=os.environ.get("SM_CHANNEL_TRAINING"),
        required="SM_CHANNEL_TRAINING" not in os.environ,
    )
    parser.add_argument(
        "--valid-dir",
        type=str,
        default=os.environ.get("SM_CHANNEL_VALID"),
    )
#     parser.add_argument("--num-gpus", type=int, default=os.environ["SM_NUM_GPUS"])

    args = parser.parse_args()

    main(args)

## TEST

### Arguments

In [None]:
import sys
import os

sys.argv = ["hpo.py", "--batch-size", "32", "--learning-rate", "0.001"]

os.environ['SM_CHANNEL_TRAINING'] = "./dogImages/train"
os.environ['SM_CHANNEL_VALID'] = "./dogImages/valid"
os.environ['SM_CHANNEL_TEST'] = "./dogImages/test"
os.environ['SM_CHANNEL_DIR'] = "/opt/ml/model"

! sudo mkdir /opt/ml/model
! sudo chown -R ec2-user:ec2-user /opt/ml/model # give ownership to the current EC2 user

### Data loaders

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt

data_loader, _ = create_data_loaders("./dogImages/test", 30)

In [None]:
# Fetch a single batch without writing a loop: take the first item of the loader's iterator.
batch = next(iter(data_loader))

In [None]:
# Shapes of the two elements of the batch (unpacked as `inputs, labels` in the loops above).
batch[0].shape, batch[1].shape

In [None]:
# Shape of the first image before and after permute(1, 2, 0), which moves its first axis last.
batch[0][0].shape, batch[0][0].permute(1, 2, 0).shape

In [None]:
# The loader normalises each channel with these statistics (see create_data_loaders).
# Undo that before plotting, otherwise imshow clips the out-of-range values and the
# colours come out wrong.
channel_mean = torch.tensor([0.485, 0.456, 0.406])
channel_std = torch.tensor([0.229, 0.224, 0.225])

# Move the channel axis last, as imshow expects, then de-normalise and clamp to [0, 1].
image = batch[0][0].permute(1, 2, 0)
image = (image * channel_std + channel_mean).clamp(0, 1)
breed = batch[1][0]

plt.imshow(image);

In [None]:
!ls './dogImages/test/{breed + 1}*'

### Model 

In [None]:
# Build the model to inspect it. NOTE(review): 42 looks like a placeholder class count;
# main() derives the real count from the training dataset via get_num_classes().
model = net(42) #If it complains check the weights part on the webpage
type(model)

In [None]:
# Display the model's repr as the cell output.
model