In [None]:
import itertools
import logging
import os
import time
from functools import lru_cache

import boto3
import botocore
import multiprocess as mp
from botocore.credentials import RefreshableCredentials
from botocore.session import get_session

In [None]:
def aws_session(aws_profile='sg_stage'):
    """Fetch the current AWS credentials as a metadata dict.

    The dict has the shape expected by
    ``RefreshableCredentials.create_from_metadata``, so this function can be
    used both for the initial metadata and as the ``refresh_using`` callback.

    Params:
        aws_profile (str): profile name to load from the shared credentials
            file. It is only used when the default credential chain resolved
            to that file, i.e. when running on a local machine.
    Returns:
        dict: ``access_key``, ``secret_key``, ``token`` and ``expiry_time``
        (an ISO 8601 string).
    Raises:
        RuntimeError: if no credentials are found, or if the resolved
            credentials have no expiry time (static keys cannot be refreshed).
    """
    session = boto3.Session()
    creds = session.get_credentials()
    if creds is None:
        raise RuntimeError("No AWS credentials found")
    # If the session is run on a local machine, with AWS credentials fetched
    # from a shared file, use the DataScience role profile.
    if creds.method == 'shared-credentials-file':
        session = boto3.Session(profile_name=aws_profile)
        creds = session.get_credentials()
        if creds is None:
            raise RuntimeError(
                f"No AWS credentials found for profile {aws_profile!r}")
    # Take one consistent snapshot: reading access_key / secret_key / token
    # separately from refreshable credentials can mix values from two
    # different refreshes.
    frozen = creds.get_frozen_credentials()
    # Only refreshable credentials carry an expiry (private botocore attribute);
    # it is read after the snapshot so deferred credentials are populated.
    expiry_time = getattr(creds, '_expiry_time', None)
    if expiry_time is None:
        raise RuntimeError(
            f"AWS credentials from {creds.method!r} have no expiry time and "
            "cannot be used as refreshable credentials")
    return {
        'access_key': frozen.access_key,
        'secret_key': frozen.secret_key,
        'token': frozen.token,
        'expiry_time': expiry_time.isoformat(),
    }

In [None]:
# Self-refreshing credentials: `aws_session` supplies the initial metadata and
# is also registered as the refresh callback, so botocore can fetch a new set
# as expiry approaches instead of failing with an expired-session client error.
# NOTE: may not work on a local machine if the profile requires MFA.
CREDS = RefreshableCredentials.create_from_metadata(
    method="sts-assume-role",
    refresh_using=aws_session,
    metadata=aws_session(),
)

In [None]:
# botocore session pinned to ap-southeast-1 and backed by CREDS.
# NOTE(review): botocore's public `set_credentials` only accepts static keys,
# so the refreshable credentials object is assigned to the private
# `_credentials` attribute directly.
SESSION = get_session()
SESSION.set_config_variable("region", 'ap-southeast-1')
SESSION._credentials = CREDS

# boto3 wrapper over that botocore session; clients are created from it.
AUTO_SESSION = boto3.Session(botocore_session=SESSION)

In [None]:
@lru_cache(maxsize=128)
def s3_client():
    """Return a shared S3 client created from AUTO_SESSION.

    The first call builds the client; ``lru_cache`` returns that same object
    on every later call, so the auto-refreshing credentials are reused.
    """
    client = AUTO_SESSION.client('s3')
    return client

In [None]:
class Migration:
    """Copy every object from one S3 bucket to another via the local disk.

    Each object is downloaded from ``src`` to a path mirroring its key
    (relative to the current working directory), uploaded to ``dest`` under
    the same key, and then deleted locally. Keys are staged in ``keys.txt``
    by ``generate_all_keys``; progress is logged to ``s3_transfer.log``.
    """

    # Route all logging to s3_transfer.log. This runs once, when the class is
    # defined. Existing root handlers are removed because basicConfig does
    # nothing while any are attached; they are also closed so file handles
    # are released when this cell is re-run.
    handler = None
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        handler.close()
    del handler  # don't leak the loop variable as a class attribute
    logging.basicConfig(filename='s3_transfer.log', level=logging.INFO)

    def __init__(self, src, dest):
        '''
        Params:
            src (str): name of the bucket to copy from.
            dest (str): name of the bucket to copy to.
        '''
        self.src = src
        self.dest = dest

    def generate_all_keys(self):
        '''
        Generates and returns list of all object keys in the src bucket.
        Also writes the keys, one per line, to ``keys.txt`` locally.
        "Directory" placeholder keys (ending in '/') are skipped.
        NOTE(review): a key containing a newline would be split across lines.
        '''
        # A dedicated client rather than the cached ``s3_client()``: creating
        # the cached client (and its open connections) here, in the parent
        # process, would let worker processes started with the 'fork' method
        # inherit and share it.
        s3 = AUTO_SESSION.client('s3')
        paginator = s3.get_paginator('list_objects_v2')
        s3_object_keys = []

        with open('keys.txt', 'w') as key_file:
            for page in paginator.paginate(Bucket=self.src):
                for content in page.get('Contents', ()):
                    key = content['Key']
                    if not key.endswith('/'):
                        s3_object_keys.append(key)
                        key_file.write(f'{key}\n')
        return s3_object_keys

    def multi(self, function, arguments, processes=None):
        '''
        Run ``function(argument)`` for every item of ``arguments`` in a
        process pool and wait for all tasks to finish.

        Params:
            function: single-argument callable; it is sent to the worker
                processes, so it must be serializable by ``multiprocess``.
            arguments: iterable of arguments, one task per item.
            processes (int, optional): pool size; defaults to one less than
                the CPU count, but never fewer than 1.
        Returns:
            list of AsyncResult objects in the order of ``arguments``. Call
            ``.get()`` on one to obtain its return value or re-raise the
            exception its task raised.
        '''
        if processes is None:
            # cpu_count() - 1 leaves a core free, but is 0 on a single-core host.
            processes = max(1, mp.cpu_count() - 1)

        # The context manager terminates the pool even if submitting fails.
        with mp.Pool(processes=processes) as pool:
            pool_results = [
                pool.apply_async(function, args=(argument, ))
                for argument in arguments
            ]
            pool.close()
            pool.join()
        return pool_results

    def download(self, key, attempts=2):
        '''
        Downloads ``key`` from the src bucket to a local path mirroring the key.

        Params:
            key (str): object key; also used as the local path, relative to
                the current working directory.
            attempts (int): how many times to try before giving up.
        Returns:
            bool: True on success, False if every attempt failed (each
            failure is logged).
        '''
        # NOTE(review): keys such as '../x' or '/x' resolve outside the
        # working directory; acceptable for a trusted bucket only.
        local_file = os.path.abspath(key)
        for _ in range(attempts):
            try:
                logging.info("%s is being downloaded", key)
                # exist_ok: several worker processes may create the same
                # directory at the same time.
                os.makedirs(os.path.dirname(local_file), exist_ok=True)
                s3_client().download_file(self.src, key, local_file)
                logging.info("%s downloaded successfully", key)
            except Exception as ex:
                logging.exception(ex)
            else:
                return True
        return False

    def upload(self, key, attempts=2):
        '''
        Uploads the local copy of ``key`` to the dest bucket under the same key.

        Params:
            key (str): object key; the local file is read from the matching
                path relative to the current working directory.
            attempts (int): how many times to try before giving up.
        Returns:
            bool: True on success, False if every attempt failed (each
            failure is logged).
        '''
        local_file = os.path.abspath(key)
        for _ in range(attempts):
            try:
                logging.info("%s is being uploaded", key)
                s3_client().upload_file(local_file, self.dest, key)
                logging.info("%s uploaded successfully", key)
            except Exception as ex:
                logging.exception(ex)
            else:
                return True
        return False

    def transfer(self, key):
        '''
        Transfers one object from src bucket to dest bucket, then deletes the
        local copy (also when the download or upload failed).

        Returns:
            bool: True if the object was both downloaded and uploaded.
        '''
        local_file = os.path.abspath(key)
        try:
            # Skip the upload when the download failed: there is nothing to send.
            return self.download(key) and self.upload(key)
        finally:
            # isfile: another key may have created a directory at this path.
            if os.path.isfile(local_file):
                os.remove(local_file)

    def _transfer_keys(self, keys):
        '''Transfer ``keys`` in parallel; log failures and the elapsed time.

        Returns:
            list of keys whose transfer failed or raised.
        '''
        start_time = time.perf_counter()

        logging.info('Transferring...')

        results = self.multi(self.transfer, keys)
        failed = []
        for key, result in zip(keys, results):
            try:
                succeeded = result.get()
            except Exception:
                # AsyncResult.get() re-raises whatever the worker raised.
                logging.exception("Transfer of %s raised", key)
                succeeded = False
            if not succeeded:
                failed.append(key)
        if failed:
            logging.error("%d of %d keys failed to transfer",
                          len(failed), len(keys))

        logging.info('Transfer completed')
        time_taken = time.perf_counter() - start_time
        logging.info("took %s to run", time_taken)
        return failed

    def transfer_all(self):
        '''
        Copies every key listed in keys.txt from src bucket to dest bucket
        using multiprocessing. Logs failed keys and the total time taken.
        '''
        with open('keys.txt', 'r') as key_file:
            keys = [key.rstrip('\n') for key in key_file]
        self._transfer_keys(keys)

    def transfer_trial(self, n=1000):
        '''
        Transfers only the first ``n`` keys of keys.txt (fewer if the file is
        shorter) to test and time the transfer of a small sample.
        '''
        with open('keys.txt') as key_file:
            # rstrip('\n') rather than strip(): S3 keys may legitimately
            # start or end with spaces.
            head = [key.rstrip('\n') for key in itertools.islice(key_file, n)]
        self._transfer_keys(head)

In [None]:
# Source and destination buckets for the migration.
src, dest = "yara-sh-dads-scd", "yara-sh-dads-scd-stage"

migrate = Migration(src, dest)
# Step 1: list every key in the source bucket and save it to keys.txt.
migrate.generate_all_keys()
# Step 2: copy the objects. Uncomment the trial run to time a small sample
# before copying the whole bucket.
# migrate.transfer_trial()
migrate.transfer_all()