# Deploy Model

## Download Job Result

In [None]:
# ID of the FL job whose outputs will be downloaded and deployed; replace the placeholder.
JOB_ID = "<Replace with your job id>"

You can skip this step if you already downloaded the job artifacts in `Analyze_Result.ipynb`.

In [None]:
import boto3
import os
from pathlib import Path
from tqdm import tqdm

# Resolve the AWS account ID; the provisioning bucket name is derived from it.
sts_client = boto3.client('sts')
account_info = sts_client.get_caller_identity()
account_id = account_info['Account']

bucket_name = f"flare-provision-bucket-{account_id}"

local_dir = Path('outputs') / JOB_ID

s3 = boto3.resource('s3')
bucket = s3.Bucket(bucket_name)

# Create the local job directory (no-op if it already exists).
local_dir.mkdir(parents=True, exist_ok=True)

# The trailing slash limits the listing to this job's "folder": without it,
# the prefix "outputs/abc" would also match objects of job "abc123".
job_key = f'outputs/{JOB_ID}/'

for obj in tqdm(bucket.objects.filter(Prefix=job_key)):
    # Skip zero-byte "directory" placeholder keys.
    if obj.key.endswith('/'):
        continue

    # Mirror the key's path below the job prefix into local_dir.
    dest = local_dir / obj.key[len(job_key):]
    dest.parent.mkdir(parents=True, exist_ok=True)
    bucket.download_file(obj.key, str(dest))
print('Download Complete')

## Upload Artifact to SageMaker Bucket

In [None]:
from torch import nn
from torch.functional import F


class Net(nn.Module):
    """Small CNN classifier over 1-channel images with 10 output classes.

    The attribute names (conv1, conv2, conv2_drop, fc1, fc2, activation) are the
    state-dict keys that ``load_state_dict`` matches against the FL checkpoint,
    so they must not be renamed. The module is compiled with
    ``torch.jit.script``, so ``forward`` must stay TorchScript-compatible.

    ``forward`` returns log-probabilities (LogSoftmax over dim 1).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)
        self.activation = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # Conv block 1 (for 1x28x28 input: 10x24x24 -> pooled to 10x12x12).
        out = self.conv1(x)
        out = F.relu(F.max_pool2d(out, 2))

        # Conv block 2 with channel dropout (20x8x8 -> pooled to 20x4x4).
        out = self.conv2_drop(self.conv2(out))
        out = F.relu(F.max_pool2d(out, 2))

        # Flatten to the 320 (= 20*4*4) features fc1 expects.
        out = out.view(-1, 320)

        hidden = F.relu(self.fc1(out))
        hidden = F.dropout(hidden, training=self.training)
        logits = self.fc2(hidden)
        return self.activation(logits)

In [None]:
import torch
import os
from pathlib import Path

local_dir = Path('outputs') / JOB_ID
model_path = local_dir / 'workspace' / 'app_server' / 'best_FL_global_model.pt'

model = Net()
# map_location="cpu": the checkpoint may have been written from GPU tensors on the
# FL server; without it, torch.load tries to restore them onto CUDA and fails on a
# CPU-only notebook instance. weights_only=True avoids unpickling arbitrary objects.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
model.load_state_dict(checkpoint['model'])
model.eval()

# Export a TorchScript version of the eval-mode model for the serving container.
torch.jit.script(model).save('model.pt')

In [None]:
import tarfile

# Bundle the TorchScript file into a gzip'd tarball, stored at the archive root.
with tarfile.open(name="model.tar.gz", mode="w:gz") as archive:
    archive.add(name="model.pt", arcname="model.pt")

In [None]:
import sagemaker
import boto3

# Stage the packaged model in the session's default SageMaker bucket under a
# per-job prefix, so artifacts from different FL jobs do not overwrite each other.
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
model_artifact = sagemaker_session.upload_data(
    path="model.tar.gz", bucket=bucket, key_prefix=f"flare-model/{JOB_ID}"
)

# IAM role passed to the SageMaker model below.
role = sagemaker.get_execution_role()

In [None]:
from sagemaker.pytorch import PyTorchModel
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# Wrap the uploaded archive in a SageMaker PyTorch serving container.
# NOTE(review): entry_point is a relative path, presumably to an inference.py
# next to this notebook that handles the request/response format — confirm.
pytorch_model = PyTorchModel(
    entry_point="inference.py",
    model_data=model_artifact,
    framework_version='2.3.0',
    py_version="py311",
    role=role,
)

# Stand up a single-instance real-time endpoint that exchanges JSON.
predictor = pytorch_model.deploy(
    instance_type="ml.m5.large",
    initial_instance_count=1,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

In [None]:
import torchvision

# MNIST test split as tensors.
# download=True makes the cell work on a fresh environment: torchvision only
# downloads when the files are not already present under ./data.
test_data = torchvision.datasets.MNIST(
    './data',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

In [None]:
import json
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns
from torch.utils.data import DataLoader


# Send the test set to the endpoint in batches of 100 images per request.
test_dataloader = DataLoader(test_data, batch_size=100, shuffle=False)
# Lists to store predictions and true labels
all_predictions = []
all_labels = []

for image, label in tqdm(test_dataloader):
    # NOTE(review): the {'inputs': nested list} payload presumably matches what
    # inference.py expects — confirm against the handler.
    response = predictor.predict({
        'inputs': image.numpy().tolist(),
    })

    # Assumes the response is one score vector per image (batch x classes);
    # adjust this based on your response format.
    batch_predictions = np.argmax(np.asarray(response), axis=1)
    # Fail fast, with a clear message, if the response shape does not match the batch.
    assert len(batch_predictions) == len(label), (
        f"endpoint returned {len(batch_predictions)} predictions for {len(label)} images"
    )

    all_predictions.extend(batch_predictions)
    all_labels.extend(label.numpy())

# Convert lists to numpy arrays
all_predictions = np.array(all_predictions)
all_labels = np.array(all_labels)

# Calculate accuracy
accuracy = accuracy_score(all_labels, all_predictions)
print(f"\nTest Accuracy: {accuracy:.4f}")

# Create confusion matrix
cm = confusion_matrix(all_labels, all_predictions)

# Plot confusion matrix on an explicit Axes
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax)
ax.set_title('Confusion Matrix')
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')
plt.show()

# Print classification report
print("\nClassification Report:")
print(classification_report(all_labels, all_predictions))

# Clean Up

In [None]:
# Tear down billable resources: the endpoint (plus its endpoint config, which is
# also the SDK's default), then the SageMaker model object.
predictor.delete_endpoint(delete_endpoint_config=True)
pytorch_model.delete_model()