In [None]:
import boto3
from mlrun.artifacts import get_model
import pandas as pd
import torch

def get_metrics_model_info_df(model_spec):
    """Flatten a model spec's metadata into a single-row DataFrame.

    Merges ``model_spec.metrics``, ``model_spec.parameters`` and
    ``model_spec.labels`` (in that order) and then adds a fixed
    ``"model_state_dict"`` entry naming the weights file. On key collisions
    the later mapping wins.

    Args:
        model_spec: Object exposing ``metrics``, ``parameters`` and
            ``labels`` attributes, each a mapping or ``None``.

    Returns:
        A ``pandas.DataFrame`` with exactly one row and one column per
        merged key.
    """
    model_file = {"model_state_dict": "model_state_dict.pth"}
    sections = (
        model_spec.metrics,
        model_spec.parameters,
        model_spec.labels,
        model_file,
    )

    results_dict = {}
    for section in sections:
        # An unset section (None) would make dict.update raise TypeError;
        # treat it as contributing no columns instead.
        results_dict.update(section or {})
    return pd.DataFrame([results_dict])
    
def get_prep_model(layer_size):
    """Return the source text of a standalone ``prep_model.py`` script.

    The generated script defines ``prep_model(model_state_dict, device,
    num_classes=2)``; ``layer_size`` is substituted into the two
    ``nn.Linear`` calls of the replacement ``fc`` head.

    Args:
        layer_size: Value formatted into the generated source as the hidden
            width of the classifier head.

    Returns:
        The script as a single string with no trailing newline.
    """
    # The whitespace-only entries are intentional: they are part of the
    # emitted script text and are spelled out so they cannot be stripped.
    script_lines = [
        "import torch",
        "from torch import nn",
        "",
        "def prep_model(model_state_dict, device, num_classes=2):",
        "    model_resnet50 = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)",
        "    ",
        "    for name, param in model_resnet50.named_parameters():",
        '        if "bn" not in name:',
        "            param.requires_grad = False",
        "        ",
        f"    model_resnet50.fc = nn.Sequential(nn.Linear(model_resnet50.fc.in_features, {layer_size}),",
        "                                      nn.ReLU(),",
        "                                      nn.Dropout(),",
        f"                                      nn.Linear({layer_size}, num_classes))",
        "    model_resnet50.to(device)",
        "    model_resnet50.load_state_dict(model_state_dict)",
        "    return model_resnet50",
    ]
    return "\n".join(script_lines)
    
def handler(context, event):
    """Export a model's metadata, loader script and weights, then upload to S3.

    Writes three files into the current working directory and uploads each
    to the configured bucket under ``<results_upload_path>/<file name>``:

    * ``results.csv`` - one-row table built by ``get_metrics_model_info_df``.
    * ``prep_model.py`` - script text produced by ``get_prep_model``.
    * ``model_state_dict.pth`` - the loaded ``.pth`` object re-saved with
      ``torch.save``.

    Args:
        context: Runtime context providing ``logger`` and ``inputs``. The
            inputs ``model`` (an object with a ``url`` attribute),
            ``bucket_name`` and ``results_upload_path`` are read; the latter
            two are converted with ``str()``.
        event: Unused.

    Raises:
        KeyError: If a required input or the ``layer_size`` model parameter
            is missing.
    """
    context.logger.info("Loading Model Metrics")
    model_file, model_spec, _ = get_model(context.inputs['model'].url, suffix='.pth')

    context.logger.info("Creating Metrics / Model Info CSV")
    df = get_metrics_model_info_df(model_spec)
    df.to_csv("results.csv", index=False)

    context.logger.info("Creating prep_model.py script")
    layer_size = model_spec.parameters['layer_size']
    with open("prep_model.py", "w") as f:
        f.write(get_prep_model(layer_size))

    context.logger.info("Creating model_state_dict.pth file")
    # Open via a context manager so the handle is closed deterministically;
    # the original passed a bare open() result to torch.load and leaked it.
    # NOTE(review): torch.load unpickles the file - only use with trusted
    # model artifacts.
    with open(model_file, "rb") as weights_fp:
        model_state_dict = torch.load(weights_fp)
    torch.save(model_state_dict, "model_state_dict.pth")

    context.logger.info("Initializing S3")
    s3 = boto3.resource('s3')
    bucket = s3.Bucket(str(context.inputs['bucket_name']))
    results_upload_path = str(context.inputs['results_upload_path'])

    context.logger.info("Uploading Files to S3")
    upload_files = ["results.csv", "prep_model.py", "model_state_dict.pth"]
    for file_name in upload_files:
        # The object key keeps the local file name under the upload prefix.
        bucket.upload_file(file_name, f"{results_upload_path}/{file_name}")