In [None]:
import os

import boto3
# The code below assumes that boto3 is configured with your AWS account (credentials + region):
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html
# `ec2` is the high-level, object-oriented resource interface (instances, collections, waiters);
# `client` is the low-level interface, used here for the spot-request calls.
ec2 = boto3.resource(service_name='ec2')
client = boto3.client(service_name='ec2')

In [None]:
# Experiment names (chosen per run further below) must be unique per experiment: every instance
# and spot request is tagged with its experiment name, and cleanup / worker counting select by it.
use_internal_routing = True  # peers address each other by private (VPC) IPs instead of public IPs

coordinator_type = "r5.large"  # instance type of the coordinator ("first peer")
dht_port = 31337  # port the coordinator's DHT listens on
num_workers = 16  # number of GPU workers kept alive per experiment
# Exported to every worker as CHECK_PROBA; scaled by 1 / num_workers (0.5 / 16 with 16 workers).
# NOTE(review): assumes the divisor is meant to track num_workers — confirm against the BTARD setup.
check_proba = 0.5 / num_workers
print(f'check_proba = {check_proba}')

image_id = "ami-0db67995cd75f5a9f"  # AMI used for both the coordinator and the workers
aws_key_name = ""  ## update with your AWS key pair name
subnet = ""  ## update with your subnet ID (subnet-...) or leave empty to skip
security_group = ""  ## update with your security group ID (sg-...)
data_path = ""  ## URL of a .tar.gz archive with the wikitext103 dataset (downloaded with wget)
state_url = "http://"  ## prefix for URLs where you will upload states for steps 950 and 4950

In [None]:
def kill_instances(experiment_name):
    """Shut down everything that was launched for one experiment.

    Cancels the experiment's spot instance requests, then terminates its instances. Both are
    selected by the ``experiment`` tag that the launch helpers attach.

    Args:
        experiment_name: value of the ``experiment`` tag whose resources should be removed.
    """
    experiment_filter = {'Name': 'tag:experiment', 'Values': [experiment_name]}

    # Cancel the spot requests first so nothing can be relaunched for them. Filtering server-side
    # avoids a KeyError on untagged requests (their description has no 'Tags' key), also catches
    # not-yet-fulfilled ("open") requests, and the paginator walks every result page.
    requests_to_shutdown = []
    paginator = client.get_paginator('describe_spot_instance_requests')
    spot_filters = [experiment_filter, {'Name': 'state', 'Values': ['open', 'active']}]
    for page in paginator.paginate(Filters=spot_filters):
        requests_to_shutdown.extend(
            request['SpotInstanceRequestId'] for request in page['SpotInstanceRequests'])
    if requests_to_shutdown:
        client.cancel_spot_instance_requests(
            SpotInstanceRequestIds=requests_to_shutdown)

    # Include "pending" so instances that are still booting (e.g. a launch interrupted while
    # waiting for it to start) are terminated too instead of being leaked.
    existing_instances = ec2.instances.filter(Filters=[
        {'Name': 'instance-state-name', 'Values': ['pending', 'running']},
        experiment_filter,
    ])
    instances = list(existing_instances)
    if instances:
        print(f"Already running {experiment_name}: {instances}")
        print(len(instances))
        for instance in instances:
            print(instance.public_ip_address, instance.private_ip_address)

    # Batch action over the filtered collection; it issues no call when nothing matches.
    existing_instances.terminate()
    print('Instances stopped')

### Stage 1: run coordinator

The coordinator is an instance that welcomes new peers into a decentralized training run. If the coordinator goes down, new peers can still join by using any of the existing peers as their initial peer.

In [None]:
# Weights & Biases API key, written into ~/.netrc on every instance by the user-data scripts.
# Read from the WANDB_API_KEY environment variable instead of being hard-coded in the notebook;
# falls back to "" when the variable is unset. Note that it still ends up in plain text in the
# instances' user data.
WandB_API_key = os.environ.get("WANDB_API_KEY", "")  ## Your key (export WANDB_API_KEY=...)

In [None]:
# Shell snippet inserted into the user-data scripts when use_internal_routing is on: exports the
# first IPv4 address reported for eth0 as $IP. A raw string keeps "\." as a grep regex escape
# without Python's invalid-escape-sequence warning (the resulting text is unchanged).
# NOTE(review): assumes the primary network interface is named eth0 on the chosen AMI — confirm.
get_ip_cmd = r"export IP=$(ifconfig eth0 | grep -Eo 'inet (addr:)?([0-9]*\.){3}[0-9]*' | grep -Eo '([0-9]*\.){3}[0-9]*')"

In [None]:
# User-data script for the coordinator ("first peer") instance, run by bash (-e: stop on the first
# failing command, -x: trace) at boot. It:
#   * mirrors its own output into /var/log/user-command.log and syslog;
#   * configures rsyslog to accept syslog over UDP and TCP on port 514 and to write each remote
#     host's messages to /var/log/rsyslog/<hostname>.log (%HOSTNAME% is rsyslog's own template
#     property, not one of the placeholders below);
#   * optionally exports $IP via get_ip_cmd, installs BTARD's albert package and transformers,
#     and writes the W&B key into ~/.netrc;
#   * downloads the initial averager state and starts run_first_peer.py on dht_port.
#
# Two kinds of substitution happen:
#   * {...} fields are filled by the f-string when this cell runs, so re-run the cell after
#     changing use_internal_routing, get_ip_cmd, WandB_API_key, state_url or dht_port;
#   * %tau%, %initial_state%, %initial_step% and %experiment_name% are plain-text placeholders
#     replaced per experiment by create_coordinator.
# The backslashes before the line breaks of the run_first_peer.py command are Python escapes in
# this string, so the command reaches bash as a single line.
# NOTE(review): because of `-e`, a failed wget aborts the script, so
# {state_url}/<initial_state>/averager_state_<initial_step>.pickle must exist — including for the
# baseline run with initial_step=0; confirm that such a state is uploaded.
coordinator_script = f'''#!/bin/bash -ex
exec > >(tee /var/log/user-command.log|logger -t user-data -s 2>/dev/console) 2>&1

# note: we configure rsyslog to forward logs from all trainers
sudo sh -c 'cat <<"EOF" >> /etc/rsyslog.conf
$ModLoad imudp
$UDPServerRun 514

$ModLoad imtcp
$InputTCPServerRun 514

$FileCreateMode 0644
$DirCreateMode 0755
$Umask 0022

$template RemoteLogs,"/var/log/rsyslog/%HOSTNAME%.log"
*.*  ?RemoteLogs
& ~
EOF'
sudo systemctl restart rsyslog

{get_ip_cmd if use_internal_routing else ''}
git clone https://github.com/neurips-paper/BTARD
cd BTARD/albert

pip install -e .
pip install transformers==4.5.1
cd experiments


ulimit -n 4096


sh -c 'cat <<"EOF" >> ~/.netrc
machine api.wandb.ai
  login user
  password {WandB_API_key}
EOF'

wget -q {state_url}/%initial_state%/averager_state_%initial_step%.pickle -O initial_state.pickle

CCLIP_TAU=%tau% TOTAL_THREADS=256 python ./run_first_peer.py --dht_listen_on [::]:{dht_port} {'--address $IP' if use_internal_routing else ''} \
 --experiment_prefix %experiment_name% --wandb_project Runs \
 --compression NONE --metadata_expiration 180 --averaging_timeout 60 --averaging_expiration 10 \
 --initial_state_path initial_state.pickle
'''

In [None]:
def create_coordinator(tau, experiment_name, initial_step):
    """Launch the regular (non-spot) coordinator ("first peer") instance for one experiment.

    The instance boots with ``coordinator_script`` as user data after its ``%tau%``,
    ``%initial_state%``, ``%initial_step%`` and ``%experiment_name%`` placeholders are filled in,
    and is tagged with ``experiment=<experiment_name>`` and ``role=first_peer``.

    Args:
        tau: exported as CCLIP_TAU; also selects the ``tau_<tau with '.' -> '_'>`` state folder.
        experiment_name: tag value and experiment prefix; must be unique per experiment.
        initial_step: step number of the averager state downloaded at boot.

    Returns:
        ``{'ip': ..., 'endpoint': '<ip>:<dht_port>'}``. The IP is the private IP when
        ``use_internal_routing`` is set, otherwise the public IP.
    """
    user_data = (coordinator_script
                 .replace('%tau%', str(tau))
                 .replace('%initial_state%', 'tau_' + str(tau).replace('.', '_'))
                 .replace('%initial_step%', str(initial_step))
                 .replace('%experiment_name%', experiment_name))

    # Leave out empty settings instead of sending '' to EC2, so the subnet, security group and
    # key pair can be left blank in the config cell.
    optional_args = {}
    if security_group:
        optional_args['SecurityGroupIds'] = [security_group]
    if subnet:
        optional_args['SubnetId'] = subnet
    if aws_key_name:
        optional_args['KeyName'] = aws_key_name

    coordinator, = ec2.create_instances(
        ImageId=image_id, InstanceType=coordinator_type,
        MinCount=1, MaxCount=1,
        UserData=user_data,
        TagSpecifications=[{'ResourceType': 'instance', 'Tags': [
            {'Key': 'experiment', 'Value': experiment_name},
            {'Key': 'role', 'Value': 'first_peer'},
        ]}],
        **optional_args,
    )
    coordinator.wait_until_running()
    # Re-read the instance description so the IP addresses assigned during launch are populated.
    coordinator.reload()

    print('Created coordinator:', coordinator.private_ip_address, coordinator.public_ip_address)

    if use_internal_routing:
        coordinator_ip = coordinator.private_ip_address
    else:
        coordinator_ip = coordinator.public_ip_address

    coordinator_endpoint = f"{coordinator_ip}:{dht_port}"
    print('coordinator_endpoint =', coordinator_endpoint)

    return {'ip': coordinator_ip, 'endpoint': coordinator_endpoint}

### Stage 2: run workers

Workers are preemptible (spot) GPU instances that compute gradients and perform averaging. In this example, each worker is an instance with a single NVIDIA Tesla T4 GPU.

In [None]:
# User-data script for each GPU worker instance, run by bash at boot with `set -euxo pipefail`
# (any failing command, unset variable or failing pipe stage aborts the script). It:
#   * mirrors its own output into /var/log/user-command.log and syslog, and forwards user.*
#     syslog messages over TCP (@@) to port 514 on the coordinator;
#   * optionally exports $IP via get_ip_cmd, installs BTARD's albert package and transformers;
#   * downloads the dataset archive from data_path and untars it into the current directory
#     (BTARD/albert/experiments);
#   * writes the W&B key into ~/.netrc and starts run_trainer.py, joining via the coordinator.
#
# Two kinds of substitution happen:
#   * {...} fields (get_ip_cmd, use_internal_routing, data_path, WandB_API_key, check_proba) are
#     frozen by the f-string when this cell runs, so re-run it after changing any of them;
#   * %attack_type%, %attack_start%, %tau%, %seed%, %experiment_name%, %coordinator_ip% and
#     %coordinator_endpoint% are plain-text placeholders replaced per worker by create_instance.
# The backslashes before the line breaks of the run_trainer.py command are Python escapes in this
# string, so the command reaches bash as a single line.
# NOTE(review): `mkdir -p ~/data` creates a directory that the archive is not extracted into —
# confirm where run_trainer.py expects the dataset.
worker_script = f'''#!/bin/bash -ex
exec > >(tee /var/log/user-command.log|logger -t user-data -s 2>/dev/console) 2>&1

set -euxo pipefail
cd ~

sudo sh -c 'cat <<"EOF" >> /etc/rsyslog.conf

user.* @@%coordinator_ip%:514

EOF'
sudo systemctl restart rsyslog


{get_ip_cmd if use_internal_routing else ''}
git clone https://github.com/neurips-paper/BTARD
cd BTARD/albert

pip install -e .
pip install transformers==4.5.1
cd experiments


mkdir -p ~/data
wget -qO- {data_path} | tar xzf -


sh -c 'cat <<"EOF" >> ~/.netrc
machine api.wandb.ai
  login user
  password {WandB_API_key}
EOF'


ulimit -n 4096

CCLIP_TAU=%tau% ATTACK_TYPE=%attack_type% ATTACK_START=%attack_start% CHECK_PROBA={check_proba} \
  DIRECTION_SEED=%seed% \
  WANDB_PROJECT=%experiment_name% WANDB_WATCH=false \
  TOTAL_THREADS=256 python run_trainer.py \
  --output_dir ./outputs --overwrite_output_dir \
  {'--endpoint $IP'+':*' if use_internal_routing else ''} \
  --logging_dir ./logs --logging_first_step --logging_steps 100 \
  --initial_peers %coordinator_endpoint%  --run_name aws_worker \
  --experiment_prefix %experiment_name% --seed %seed% --compression NONE --metadata_expiration 180 \
  --averaging_timeout 60 --averaging_expiration 10 --statistics_expiration 60
'''

In [None]:
def create_instance(worker_type, attack_type, tau, experiment_name, coordinator_ip, coordinator_endpoint, seed,
                    attack_start):
    """Launch one one-time spot GPU worker and return its boto3 Instance (not waited on).

    The instance boots with ``worker_script`` as user data after its ``%...%`` placeholders are
    filled in. Both the instance and its spot request are tagged with
    ``experiment=<experiment_name>`` and ``role=gpu_worker``.

    Args:
        worker_type: EC2 instance type, e.g. ``'g4dn.2xlarge'``.
        attack_type: exported as ATTACK_TYPE (``'NONE'`` for an honest worker).
        tau: exported as CCLIP_TAU.
        experiment_name: tag value, WANDB_PROJECT and experiment prefix.
        coordinator_ip: host the worker forwards its syslog to (port 514).
        coordinator_endpoint: ``'<ip>:<port>'`` passed as ``--initial_peers``.
        seed: passed as ``--seed`` and exported as DIRECTION_SEED.
        attack_start: exported as ATTACK_START.

    Returns:
        The newly created instance.
    """
    user_data = (worker_script
                 .replace('%attack_type%', attack_type)
                 .replace('%attack_start%', str(attack_start))
                 .replace('%tau%', str(tau))
                 # not referenced by the current worker_script; kept for templates that use it
                 .replace('%tau_underscore%', str(tau).replace('.', '_'))
                 .replace('%experiment_name%', experiment_name)
                 .replace('%coordinator_ip%', coordinator_ip)
                 .replace('%coordinator_endpoint%', coordinator_endpoint)
                 .replace('%seed%', str(seed)))

    # Leave out empty settings instead of sending '' to EC2, so the subnet, security group and
    # key pair can be left blank in the config cell.
    optional_args = {}
    if security_group:
        optional_args['SecurityGroupIds'] = [security_group]
    if subnet:
        optional_args['SubnetId'] = subnet
    if aws_key_name:
        optional_args['KeyName'] = aws_key_name

    worker_tags = [
        {'Key': 'experiment', 'Value': experiment_name},
        {'Key': 'role', 'Value': 'gpu_worker'},
    ]
    new_worker, = ec2.create_instances(
        ImageId=image_id, InstanceType=worker_type,
        MinCount=1, MaxCount=1,
        UserData=user_data,
        InstanceMarketOptions={
            "MarketType": "spot",
            "SpotOptions": {
                "SpotInstanceType": "one-time",
                "InstanceInterruptionBehavior": "terminate",
            },
        },
        TagSpecifications=[
            {'ResourceType': 'instance', 'Tags': worker_tags},
            {'ResourceType': 'spot-instances-request', 'Tags': worker_tags},
        ],
        **optional_args,
    )
    return new_worker

In [None]:
import time
import traceback

def run_workers(tau, experiment_name, coordinator_ip, coordinator_endpoint,
                n_attackers, time_limit=3 * 3600, intended_attack='NONE',
                instance_types=('g4dn.2xlarge',), **kwargs):
    """Keep ``num_workers`` GPU spot workers alive for ``time_limit`` seconds.

    Every 30 seconds the experiment's worker instances (tag ``role=gpu_worker``, pending or
    running) are counted. If fewer than ``num_workers`` exist, one more is launched, trying
    ``instance_types`` in order until one launch succeeds. Preempted spot workers are therefore
    replaced on the following iterations.

    A newly launched worker runs ``intended_attack`` while fewer than ``n_attackers`` workers
    exist. Otherwise it runs ``'NONE'``.

    Args:
        tau, experiment_name, coordinator_ip, coordinator_endpoint: forwarded to ``create_instance``.
        n_attackers: how many of the workers should run the attack.
        time_limit: seconds to keep the fleet alive.
        intended_attack: ATTACK_TYPE for attacking workers.
        instance_types: instance types to try, in order, for each launch.
        **kwargs: forwarded to ``create_instance`` (e.g. ``seed``, ``attack_start``).
    """
    stop_time = time.time() + time_limit
    while time.time() < stop_time:
        # Count only this experiment's workers (not the coordinator), including ones still booting,
        # so a dead coordinator or a slow boot does not trigger extra launches.
        workers = list(ec2.instances.filter(Filters=[
            {'Name': 'instance-state-name', 'Values': ['pending', 'running']},
            {'Name': 'tag:experiment', 'Values': [experiment_name]},
            {'Name': 'tag:role', 'Values': ['gpu_worker']},
        ]))

        count_needed = num_workers - len(workers)
        if count_needed > 0:
            # Equivalent to len(workers) < n_attackers.
            attack_type = intended_attack if count_needed > num_workers - n_attackers else 'NONE'

            print(f"Need {count_needed} more workers. Trying to spawn one")
            for worker_type in instance_types:
                try:
                    new_worker = create_instance(
                        worker_type, attack_type, tau, experiment_name, coordinator_ip, coordinator_endpoint, **kwargs)
                    new_worker.wait_until_running()
                    # Refresh the description so the assigned IP addresses are available.
                    new_worker.reload()
                    print("CREATED ONE WORKER!", worker_type, attack_type,
                          new_worker.public_ip_address, new_worker.private_ip_address)
                    break
                except Exception as e:
                    # Best effort: report the failure and try the next type / next iteration.
                    print('Failed:', worker_type, e)
                    traceback.print_exc()

        time.sleep(30)

Run the training without attacks to collect initial states for steps 950 and 4950:

In [None]:
# Baseline run: no attack, starting from step 0, used to produce the states for steps 950 / 4950.
seed = 0
tau = 0.125
initial_step = 0
attack_start = 0
intended_attack = 'NONE'

experiment_name = "baseline"

try:
    print(f'\n[*] {experiment_name}: Creating coordinator...')
    # Retry forever, pausing 30 s after each failed launch attempt.
    while True:
        try:
            coordinator_info = create_coordinator(tau, experiment_name, initial_step=initial_step)
            break
        except Exception as e:
            print('[-] Failed to create coordinator:', e)
            traceback.print_exc()
            time.sleep(30)
    # Presumably gives the coordinator time to finish its boot script before workers join — confirm.
    time.sleep(5 * 60)

    print(f'\n[*] {experiment_name}: Running workers...')
    # The baseline has no attackers. Passing 0 explicitly also keeps this cell from depending on
    # `n_attackers`, which is only defined in a later cell (NameError on a fresh kernel).
    run_workers(tau, experiment_name, coordinator_info['ip'], coordinator_info['endpoint'], n_attackers=0,
                time_limit=4 * 3600, intended_attack=intended_attack, seed=seed, attack_start=attack_start)
finally:
    # Always clean up the experiment's instances and spot requests, even on errors / interrupts.
    print(f'\n[*] {experiment_name}: Stopping instances...')
    kill_instances(experiment_name)

Collect the states by manually connecting to a worker machine.

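Upload the collected states so that they are reachable at `{state_url}/tau_<tau>/averager_state_<step>.pickle`, where the `.` in tau is replaced by `_` (e.g. `tau_0_125/averager_state_950.pickle`). Then proceed to testing the attacks: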

In [None]:
def _launch_coordinator_until_success(tau, experiment_name, initial_step):
    """Call create_coordinator repeatedly, pausing 30 s after each failure, until it succeeds."""
    while True:
        try:
            return create_coordinator(tau, experiment_name, initial_step=initial_step)
        except Exception as error:
            print('[-] Failed to create coordinator:', error)
            traceback.print_exc()
            time.sleep(30)


n_attackers = 7

# Full experiment grid, in the same order as nested loops over seed -> tau -> start step -> attack.
attack_grid = [
    (seed, tau, initial_step, intended_attack)
    for seed in range(3)
    for tau in [0.125, 0.5]
    for initial_step in [950, 4950]
    for intended_attack in ['NONE', 'SIGN_FLIPPING', 'LABEL_FLIPPING', 'CONSTANT_DIRECTION']
]

for seed, tau, initial_step, intended_attack in attack_grid:
    # ATTACK_START is set 50 steps after the step of the downloaded initial state.
    attack_start = initial_step + 50
    experiment_name = f"{intended_attack.lower()}_tau_{str(tau).replace('.', '_')}_step_{initial_step}_seed_{seed}"

    try:
        print(f'\n[*] {experiment_name}: Creating coordinator...')
        coordinator_info = _launch_coordinator_until_success(tau, experiment_name, initial_step)
        time.sleep(5 * 60)

        print(f'\n[*] {experiment_name}: Running workers...')
        run_workers(tau, experiment_name, coordinator_info['ip'], coordinator_info['endpoint'], n_attackers,
                    time_limit=4 * 3600, intended_attack=intended_attack, seed=seed, attack_start=attack_start)
    finally:
        # Clean up this experiment before moving on (also on errors / interrupts).
        print(f'\n[*] {experiment_name}: Stopping instances...')
        kill_instances(experiment_name)