# Analyze Job Result

## Download Job Result

In [None]:
# Id of the job whose outputs should be downloaded and analyzed; replace the placeholder.
JOB_ID = "<Replace with your job id>"

In [None]:
import boto3
import os
from pathlib import Path
from tqdm import tqdm

# Download every object stored under this job's S3 prefix into
# ./outputs/<JOB_ID>/, preserving the object layout below the prefix.

# Guard against running with an empty id or the unreplaced placeholder:
# that would list nothing (or, without the trailing slash below, every job)
# and still report success.
if not JOB_ID or JOB_ID.startswith('<'):
    raise ValueError('Set JOB_ID to your job id before downloading results.')

sts_client = boto3.client('sts')
account_info = sts_client.get_caller_identity()
account_id = account_info['Account']

bucket_name = f"flare-provision-bucket-{account_id}"

local_dir = Path('outputs') / JOB_ID

s3 = boto3.resource('s3')
bucket = s3.Bucket(bucket_name)

# Create the local output directory (no-op if it already exists).
local_dir.mkdir(parents=True, exist_ok=True)

# The trailing slash matters: without it, job "abc" would also match
# objects belonging to job "abc-2".
job_key = f'outputs/{JOB_ID}/'
downloaded = 0
for obj in tqdm(bucket.objects.filter(Prefix=job_key)):
    # Skip "folder" marker keys, which have no content to download.
    if obj.key.endswith('/'):
        continue

    # Place the file under local_dir using its path relative to the job
    # prefix, so the result does not depend on the current working directory.
    target = local_dir / obj.key[len(job_key):]
    target.parent.mkdir(parents=True, exist_ok=True)
    bucket.download_file(obj.key, str(target))
    downloaded += 1

if downloaded == 0:
    print(f'Warning: no objects found under s3://{bucket_name}/{job_key}')
print('Download Complete')

## Visualize Cross-Eval Result

In [None]:
import json
import os
from pprint import pprint

# Load the cross-site validation results from the job output directory
# created by the download step, then pretty-print them for inspection.
results_file = local_dir / 'workspace' / 'cross_site_val' / 'cross_val_results.json'
with open(results_file, 'r') as fh:
    cross_val_result = json.load(fh)

pprint(cross_val_result)

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Test-dataset sites shown as heatmap columns. Edit this if your job used
# different site names.
TEST_SITES = ['site-1', 'site-2', 'site-3']
# Model rows in display order: the two server models, then each site's
# model (labelled with the site name).
MODEL_ORDER = [
    'SRV_FL_global_model.pt',
    'SRV_best_FL_global_model.pt',
    *TEST_SITES,
]

# Each result is either a metrics dict (indexed by 'accuracy') or, if this
# cell already ran, the bare accuracy value. Unwrap only dicts so the cell
# can be re-run, and so int/None values do not crash on ['accuracy'].
for site_results in cross_val_result.values():
    for model_name, value in site_results.items():
        if isinstance(value, dict):
            site_results[model_name] = value['accuracy']

# Outer dict keys become columns (test datasets); inner keys become rows (models).
df = pd.DataFrame(cross_val_result)[TEST_SITES].reindex(MODEL_ORDER)

# Create the heatmap
fig = plt.figure(figsize=(10, 8))
sns.heatmap(
    df,
    annot=True,  # Show numbers in cells
    fmt='.2f',   # Format numbers to 2 decimal places
    cmap='Blues',  # Color scheme
    vmin=0,   # Minimum value for color scaling
    vmax=1,   # Maximum value for color scaling
    cbar_kws={'label': 'Accuracy'}
)

plt.title('Cross-validation Accuracy Heatmap')
plt.xlabel('Test Dataset')
plt.ylabel('Model Source')

# Adjust layout to prevent label cutoff
plt.tight_layout()
plt.show()