In [None]:
import os
import boto3
import botocore
import logging
from functools import lru_cache
import multiprocess as mp
from datetime import datetime

from botocore.credentials import RefreshableCredentials
from botocore.session import get_session

In [None]:
def aws_session(aws_profile='sg_stage'):
    """Fetch the current AWS credentials as botocore refresh metadata.

    A default boto3 session is created first. If its credentials were
    resolved from the shared credentials file (the session is running on a
    local machine), a second session is created with ``aws_profile`` and its
    credentials are used instead.

    Params:
        aws_profile (string): credentials profile name, used only when the
            default credentials come from the shared credentials file.
    Returns:
        (dict): ``access_key``, ``secret_key``, ``token`` and an ISO-8601
            ``expiry_time`` string.
    Raises:
        RuntimeError: if no credentials can be resolved, or if the resolved
            credentials carry no expiry time (e.g. static long-lived keys).
    """
    session = boto3.Session()
    creds = session.get_credentials()
    # If the session is run on a local machine, with AWS credentials fetched
    # from a shared file, use the DataScience role profile.
    if creds is not None and creds.method == 'shared-credentials-file':
        session = boto3.Session(profile_name=aws_profile)
        creds = session.get_credentials()
    if creds is None:
        raise RuntimeError('No AWS credentials could be resolved')

    # One frozen snapshot keeps key, secret and token from the same refresh
    # cycle instead of reading three attributes separately.
    frozen = creds.get_frozen_credentials()
    # `_expiry_time` is a private botocore attribute; it is read after the
    # snapshot so that any refresh triggered above has already happened.
    expiry_time = getattr(creds, '_expiry_time', None)
    if expiry_time is None:
        raise RuntimeError(
            'AWS credentials resolved via %r have no expiry time; '
            'temporary (refreshable) credentials are required' % creds.method
        )
    return {
        'access_key': frozen.access_key,
        'secret_key': frozen.secret_key,
        'token': frozen.token,
        'expiry_time': expiry_time.isoformat(),
    }

# Wrap the metadata returned by aws_session() in botocore's refreshable
# credentials type; aws_session itself is registered as the refresh callback.
# NOTE: aws_session() is called right here, when this cell runs.
CREDS = RefreshableCredentials.create_from_metadata(
    metadata=aws_session(),
    refresh_using=aws_session,
    method="sts-assume-role",
)

# Build a botocore session that uses the refreshable credentials above.
# `_credentials` is a private botocore attribute.
# NOTE(review): confirm this assignment still works after a botocore upgrade.
SESSION = get_session()
SESSION._credentials = CREDS
SESSION.set_config_variable("region", 'ap-southeast-1')
# boto3 session backed by SESSION; s3_client() creates its client from it.
AUTO_SESSION = boto3.Session(botocore_session=SESSION)

@lru_cache()
def s3_client():
    """Return the S3 client built from ``AUTO_SESSION``.

    The client is created on the first call; later calls in the same
    process get the same cached object back.
    """
    client = AUTO_SESSION.client('s3')
    return client

In [None]:
class Check:
    """Verify that listed objects exist in a destination bucket and repair gaps.

    Each key is looked up in ``dest`` with a conditional HEAD request. Keys
    that are missing, not modified since ``modified_since``, or whose lookup
    fails are copied again from ``src`` via a local download and upload.

    Defining the class also reconfigures root logging: existing root handlers
    are removed and records are written to ``check.log`` at INFO level.
    """

    # Cut-off used when the constructor is not given ``modified_since``.
    DEFAULT_MODIFIED_SINCE = datetime(2021, 6, 1)

    # basicConfig does nothing when the root logger already has handlers,
    # so drop any existing ones first.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    logging.basicConfig(filename='check.log', level=logging.INFO)

    def __init__(self, src, dest, modified_since=None):
        """Store the bucket names and the modification cut-off.

        Params:
            src (string): name of the bucket objects are copied from.
            dest (string): name of the bucket that is checked and repaired.
            modified_since (datetime): objects in ``dest`` not modified after
                this moment are reuploaded. Defaults to
                ``DEFAULT_MODIFIED_SINCE``.
        """
        self.src = src
        self.dest = dest
        self.modified_since = (
            self.DEFAULT_MODIFIED_SINCE if modified_since is None
            else modified_since
        )

    def multi(self, function, arguments):
        """Run a given function across multiple processes.

        Params:
            function (callable): called once per argument in a worker.
            arguments (iterable): one positional argument per call.
        Returns:
            (list): one async result per argument, in input order. Call
                ``.get()`` on a result to obtain the return value or re-raise
                the worker's exception.
        """
        # Leave one core free, but never ask for zero workers: on a
        # single-core host cpu_count() - 1 is 0 and Pool rejects it.
        worker_count = max(1, mp.cpu_count() - 1)
        pool = mp.Pool(processes=worker_count)
        try:
            pool_results = [
                pool.apply_async(function, args=(argument, ))
                for argument in arguments
            ]
        except BaseException:
            # Submission failed part-way: stop the workers instead of
            # leaving them running.
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()
        return pool_results

    def reupload(self, key):
        '''
        Takes in object as argument and copies it over from src bucket to dest bucket.

        The object is downloaded to a local path derived from ``key``
        (relative to the current working directory) and uploaded from there.
        The local copy is left in place.
        '''
        logging.info("%s is being reuploaded", key)

        # NOTE(review): the local path comes straight from the key, so a key
        # starting with "/" or containing ".." lands outside the working
        # directory -- confirm keys.txt only holds trusted keys.
        local_file = os.path.abspath(key)
        # exist_ok avoids the check-then-create race between worker
        # processes that handle keys under the same prefix.
        os.makedirs(os.path.dirname(local_file), exist_ok=True)

        s3 = s3_client()
        s3.download_file(self.src, key, local_file)
        s3.upload_file(local_file, self.dest, key)

        logging.info("%s has been reuploaded", key)

    def is_uploaded(self, key):
        '''
        Takes in an object key as argument and checks if it exists in dest bucket.
        If not, it copies the object from src bucket over to dest bucket.

        Returns True when the object is present and was modified since the
        cut-off; returns False when a reupload was performed instead.
        '''
        try:
            s3_client().head_object(Bucket=self.dest,
                                    Key=key,
                                    IfModifiedSince=self.modified_since)
        except botocore.exceptions.ClientError as err:
            code = err.response.get('Error', {}).get('Code')
            if code == '404':
                logging.info("%s not present", key)
            elif code == '304':
                # 304 means the object exists but is older than the cut-off.
                logging.info("%s not modified since %s",
                             key, self.modified_since)
            else:
                logging.info("%s encountered error", key)
                logging.error(err)
            # Best effort: any failed lookup is answered with a fresh copy.
            self.reupload(key)
            return False
        return True

    def check_all(self, keys_file='keys.txt'):
        '''
        Checks that all files in src bucket exist in dest bucket using multiprocessing.
        Uses keys.txt generated from the transfer_multi notebook.

        Params:
            keys_file (string): path of the text file with one key per line.
        '''
        logging.info('Checking...')

        with open(keys_file, 'r') as key_file:
            stripped = (line.rstrip('\n') for line in key_file)
            # Blank lines would otherwise be sent to S3 as empty keys.
            keys = [key for key in stripped if key]

        results = self.multi(self.is_uploaded, keys)

        # Read every result so that exceptions raised in the workers are
        # reported instead of being silently dropped.
        failed = 0
        for key, result in zip(keys, results):
            try:
                result.get()
            except Exception:
                failed += 1
                logging.exception("%s could not be checked or reuploaded", key)
        if failed:
            logging.error("%d of %d keys failed", failed, len(keys))

        logging.info('Check completed')

In [None]:
# Source bucket to copy from and destination bucket to verify.
src, dest = "yara-sh-dads-scd", "yara-sh-dads-scd-stage"
check_files = Check(src=src, dest=dest)
# Run the full check; progress is written to the log configured by Check.
check_files.check_all()