In [None]:
import logging

import torch
from torch import nn, optim
import torch.nn.functional as F
from cloudpickle import loads, dumps

def prep_model(device, layer_size, num_classes=2):
    """Build a ResNet-50 transfer-learning classifier.

    Loads the ImageNet-pretrained ResNet-50 published at
    ``pytorch/vision:v0.6.0`` through ``torch.hub``, freezes every parameter
    whose name does not contain ``"bn"``, and swaps the final ``fc`` layer for
    a small trainable head: Linear -> ReLU -> Dropout -> Linear.

    Args:
        device: Target device (e.g. ``"cpu"`` or ``"cuda"``) for the model.
        layer_size: Width of the hidden layer in the new head.
        num_classes: Number of output logits produced by the head.

    Returns:
        The modified ResNet-50, moved to ``device``.
    """
    backbone = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)

    # Keep only parameters with "bn" in their name trainable; everything else
    # in the pretrained backbone is frozen.
    # NOTE(review): torchvision's ResNet appears to name its shortcut
    # BatchNorm layers "downsample.1", so those remain frozen — confirm that
    # is intended.
    frozen_params = (p for n, p in backbone.named_parameters() if "bn" not in n)
    for param in frozen_params:
        param.requires_grad = False

    # The head is created after freezing, so its parameters stay trainable.
    hidden_in = backbone.fc.in_features
    head_layers = [
        nn.Linear(hidden_in, layer_size),
        nn.ReLU(),
        nn.Dropout(),
        nn.Linear(layer_size, num_classes),
    ]
    backbone.fc = nn.Sequential(*head_layers)

    return backbone.to(device)

def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero (nothing was seen)."""
    return total / count if count else float("nan")


def _run_training_epoch(model, optimizer, loss_fn, loader, device, logger):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss."""
    model.train()
    total_loss = 0.0
    num_samples = 0
    for batch_idx, (inputs, targets) in enumerate(loader):
        logger.debug(f"Batch: {batch_idx + 1}")
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # Assumes loss_fn averages over the batch (nn.CrossEntropyLoss default);
        # re-weight by batch size so the epoch value is a per-sample mean.
        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size
    # Divide by the samples actually seen, not len(loader.dataset): the two
    # differ with drop_last=True or a sampler that covers only part of the dataset.
    return _mean_or_nan(total_loss, num_samples)


def _evaluate(model, loss_fn, loader, device):
    """Return ``(mean_loss, accuracy)`` of ``model`` on ``loader`` without training it."""
    model.eval()
    total_loss = 0.0
    num_samples = 0
    num_correct = 0
    num_examples = 0
    # No gradients are needed for validation; no_grad avoids building the
    # autograd graph and keeping activations alive.
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            batch_size = inputs.size(0)
            total_loss += loss_fn(output, targets).item() * batch_size
            num_samples += batch_size
            # Softmax is monotonic, so the argmax of the raw logits gives the
            # same predicted class without computing it.
            correct = (output.argmax(dim=1) == targets).view(-1)
            num_correct += correct.sum().item()
            num_examples += correct.numel()
    return _mean_or_nan(total_loss, num_samples), _mean_or_nan(num_correct, num_examples)


def train(model, optimizer, loss_fn, train_loader, val_loader, epochs, device, context=None):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Args:
        model: The ``nn.Module`` to optimise.
        optimizer: Optimiser over ``model``'s parameters.
        loss_fn: Callable ``(output, targets) -> scalar tensor``.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        epochs: Number of training passes (0 is allowed).
        device: Device the batches are moved to before the forward pass.
        context: Optional object exposing a ``.logger`` (e.g. the handler's
            run context). When omitted, the module's ``logging`` logger is used.

    Returns:
        ``(training_loss, validation_loss, validation_accuracy)`` from the last
        epoch. Each is a per-sample mean over the samples actually seen. A value
        is NaN when no epoch ran or the corresponding loader yielded nothing.
    """
    logger = context.logger if context is not None else logging.getLogger(__name__)

    # Initialised up front so the return value is defined even when epochs == 0.
    training_loss = valid_loss = accuracy = float("nan")
    for epoch in range(1, epochs + 1):
        logger.info(f"Epoch: {epoch}")
        training_loss = _run_training_epoch(model, optimizer, loss_fn, train_loader, device, logger)
        valid_loss, accuracy = _evaluate(model, loss_fn, val_loader, device)
        logger.info('Epoch: {}, Training Loss: {:.4f}, Validation Loss: {:.4f}, accuracy = {:.4f}'.format(
            epoch, training_loss, valid_loss, accuracy))
    return training_loss, valid_loss, accuracy


def handler(context, event):
    """Train a ResNet-50 classifier and log its metrics and weights to ``context``.

    Reads the pickled train/validation DataLoaders plus ``epochs``, ``device``
    and ``batch_size`` from ``context.inputs``, and ``lr`` and ``layer_size``
    from ``context.parameters``. It then builds the model with ``prep_model``
    and trains it with Adam and cross-entropy loss. Finally it logs the final
    losses and accuracy as results, and saves the ``state_dict`` to
    ``model.pth`` (current working directory), which it logs as a model
    artifact. ``event`` is accepted but not used.
    """
    context.logger.info("Loading DataLoaders")
    # SECURITY: cloudpickle.loads can execute arbitrary code embedded in the
    # payload. Only feed it artifacts produced by a trusted pipeline step.
    train_data_loader = loads(context.inputs['train_data_loader'].get())
    validation_data_loader = loads(context.inputs['validation_data_loader'].get())

    # NOTE(review): epochs/device/batch_size come from context.inputs while
    # lr/layer_size come from context.parameters — confirm this split is intended.
    epochs = int(str(context.inputs['epochs']))
    device = str(context.inputs['device'])
    batch_size = int(str(context.inputs['batch_size']))
    lr = float(str(context.parameters['lr']))
    layer_size = int(str(context.parameters['layer_size']))

    context.logger.info("Initializing Model")
    model = prep_model(device=device, layer_size=layer_size)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    context.logger.info("Training Model")
    # Pass our context explicitly: train() has no other way to reach it.
    training_loss, validation_loss, validation_accuracy = train(model=model,
                                                                optimizer=optimizer,
                                                                loss_fn=nn.CrossEntropyLoss(),
                                                                train_loader=train_data_loader,
                                                                val_loader=validation_data_loader,
                                                                epochs=epochs,
                                                                device=device,
                                                                context=context)

    context.logger.info("Logging Results")
    results = {"training_loss" : training_loss,
               "validation_loss" : validation_loss,
               "validation_accuracy" : validation_accuracy}
    context.log_results(results)

    context.logger.info("Logging Model")
    torch.save(model.state_dict(), "model.pth")
    context.log_model(key="model",
                      model_file="model.pth",
                      artifact_path=context.artifact_path,
                      labels={'framework': 'pytorch',
                              'category': 'cv',
                              'action': 'dogs_vs_cats'},
                      metrics=context.results,
                      parameters={'batch_size': batch_size,
                                  'epochs': epochs,
                                  'lr' : lr,
                                  'layer_size' : layer_size})