# Load Best Checkpoint from S3

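This notebook downloads the best trained model checkpoint from the S3 bucket `4k-woody-btt` and saves it under `data/checkpoint/`. The save path is relative to the kernel's working directory, so run the notebook from the repository root for the file to land inside this repository.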

S3 path structure:
- `s3://4k-woody-btt/4k/data/t15_pretrained_rnn_baseline/checkpoint/best_checkpoint`

Local save path:
- `data/checkpoint/best_checkpoint.pth`


In [None]:
# Install the packages this notebook needs into the running kernel's environment
# (`%pip`, unlike `!pip`, targets the kernel's own interpreter).
# boto3 and torch are imported by the cells below. omegaconf is not imported in any
# visible cell — presumably it is needed so torch can unpickle config objects stored
# inside the checkpoint; TODO confirm.
# NOTE(review): versions are unpinned, so re-runs may pick up different releases.
%pip install boto3 torch omegaconf


In [None]:
import os
import boto3
import torch
import tempfile
from pathlib import Path

# --- Remote source -----------------------------------------------------------
# Bucket and object key of the checkpoint; the key is used verbatim as provided.
S3_BUCKET = "4k-woody-btt"
S3_KEY = "4k/data/t15_pretrained_rnn_baseline/checkpoint/best_checkpoint"

# --- Local destination -------------------------------------------------------
# Relative path, so it resolves against the kernel's current working directory.
# The directory (and any missing parents) is created up front; re-running is a no-op.
LOCAL_DIR = Path("data/checkpoint")
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
LOCAL_CHECKPOINT_PATH = LOCAL_DIR.joinpath("best_checkpoint.pth")

# S3 client built with boto3's default settings (no explicit credentials or region).
s3 = boto3.client(service_name="s3")

# Echo the effective configuration so the cell output records what will be fetched.
print(
    f"S3 bucket: {S3_BUCKET}",
    f"S3 key: {S3_KEY}",
    f"Local save path: {LOCAL_CHECKPOINT_PATH}",
    sep="\n",
)


In [None]:
# Download from S3 and save locally.
#
# The object is downloaded to a temporary file first and only moved onto
# LOCAL_CHECKPOINT_PATH once the transfer has finished, so a failed or interrupted
# download never leaves a truncated checkpoint at the final path.
#
# The temporary file is created in the destination directory rather than the system
# temp directory: os.replace() cannot move a file across filesystems, and the system
# temp directory is often on a different mount than the repository checkout.
with tempfile.NamedTemporaryFile(
    dir=LOCAL_CHECKPOINT_PATH.parent, suffix='.part', delete=False
) as tmp:
    tmp_path = tmp.name

try:
    s3.download_file(S3_BUCKET, S3_KEY, tmp_path)
    # If source file lacks .pth extension, still save as .pth locally for consistency
    os.replace(tmp_path, LOCAL_CHECKPOINT_PATH)
    print(f"Downloaded and saved to {LOCAL_CHECKPOINT_PATH}")
except BaseException:
    # BaseException (not just Exception) so that interrupting the kernel during the
    # download also removes the partial file instead of leaking it.
    try:
        os.unlink(tmp_path)
    except OSError:
        # Best-effort cleanup: the temp file may already be gone (e.g. the replace
        # succeeded); never let cleanup mask the original error.
        pass
    # Bare `raise` re-raises the original exception with its traceback intact.
    raise


In [None]:
# Optional sanity check: confirm the saved file deserializes with torch (onto CPU)
# and report whether it looks like a conventional checkpoint dictionary, i.e. a dict
# carrying a 'model_state_dict' or 'state_dict' entry.
# NOTE(review): torch.load can unpickle arbitrary Python objects (subject to the
# installed torch version's `weights_only` default) — only run this against a
# checkpoint from a trusted source.
try:
    ckpt = torch.load(LOCAL_CHECKPOINT_PATH, map_location='cpu')
    if not (
        isinstance(ckpt, dict)
        and any(key in ckpt for key in ('model_state_dict', 'state_dict'))
    ):
        print("Loaded object is not a standard torch checkpoint dict. Displaying type:", type(ckpt))
    else:
        print("Checkpoint structure looks valid.")
        # Only the first few keys are shown to keep the cell output short.
        print("Keys:", list(ckpt.keys())[:5])
except Exception as err:
    # Deliberately non-fatal: this cell is a convenience check, so any failure is
    # reported as a warning instead of stopping the notebook.
    print("Warning: could not load checkpoint with torch:", str(err))
