In [None]:
%load_ext autoreload
%autoreload 2
from argparse import Namespace
from cryptography.fernet import Fernet
import requests
import time
import datetime as datetime
from datetime import timedelta
import numpy as np
import json
import os
import tarfile
import glob
from minio import Minio
from minio.error import (InvalidResponseError, S3Error)
import sys

sys.path.insert(0, '/appdata/dev')

from kaapana_federated.utils import get_auth_headers, get_minio_client, get_presigend_url, get_remote_header

In [None]:
def get_init_conf():
    """Build the initial workflow configuration for the ``dev-federated`` DAG.

    A new dictionary (including all nested containers) is created on every
    call, so callers can safely mutate the result.

    NOTE: a later notebook cell defines ``get_init_conf`` again; when the
    cells are executed top to bottom that later definition replaces this one.

    Returns:
        dict: configuration with the keys ``query``, ``index``, ``dag``,
        ``cohort_limit``, ``form_data``, ``rest_call`` and ``federated``.
    """
    # Cohort selection: two catch-all clauses plus filters on the clinical
    # trial protocol id and on the modality.
    must_clauses = [
        {"match_all": {}},
        {"match_all": {}},
        {
            "match_phrase": {
                "00120020 ClinicalTrialProtocolID_keyword.keyword": {"query": "satori-ana"},
            },
        },
        {
            "match_phrase": {
                "00080060 Modality_keyword.keyword": {"query": "SEG"},
            },
        },
    ]
    cohort_query = {
        "bool": {
            "must": must_clauses,
            "filter": [],
            "should": [],
            "must_not": [],
        },
    }

    # Parameters handed to the workflow form.
    workflow_form = {
        "task": "Task136_RACOON_260122-Federated",
        "model": "3d_lowres",
        "train_network_trainer": "nnUNetTrainerV2",
        "prep_modalities": "CT",
        "node_uid": "node_uid_125497293966312",
        "shuffle_seed": 0,
        "test_percentage": 0,
        "training_description": "nnUnet Segmentation",
        "body_part": "N/A",
        "train_max_epochs": 50,
        "input": "SEG",
    }

    # Operators exchanged between sites and operators to skip.
    federated_settings = {
        "federated_operators": ['unet-ct', 'local-dev'],
        "skip_operators": ["workflow-cleaner"],
    }

    return {
        "query": cohort_query,
        "index": "meta-index",
        "dag": "dev-federated",
        "cohort_limit": 2,
        "form_data": workflow_form,
        "rest_call": {},
        "federated": federated_settings,
    }


In [None]:
def get_init_conf():
    """Build the initial workflow configuration for the ``nnunet-training`` DAG.

    Every call assembles a brand-new dictionary, so the caller may add or
    change entries without affecting later calls.

    Returns:
        dict: configuration with the keys ``query``, ``index``, ``dag``,
        ``cohort_limit``, ``form_data``, ``rest_call`` and ``federated``.
    """

    def phrase(field, value):
        # One ``match_phrase`` clause restricting ``field`` to ``value``.
        return {"match_phrase": {field: {"query": value}}}

    conf = {}

    # Cohort selection query.
    conf["query"] = {
        "bool": {
            "must": [
                {"match_all": {}},
                {"match_all": {}},
                phrase("00120020 ClinicalTrialProtocolID_keyword.keyword", "satori-ana"),
                phrase("00080060 Modality_keyword.keyword", "SEG"),
            ],
            "filter": [],
            "should": [],
            "must_not": [],
        },
    }

    # Which index to query, which DAG to trigger and how many items to use.
    conf["index"] = "meta-index"
    conf["dag"] = "nnunet-training"
    conf["cohort_limit"] = 4

    # Values for the workflow form.
    conf["form_data"] = {
        "task": "Task136_RACOON_310122-Federated",
        "model": "3d_lowres",
        "train_network_trainer": "nnUNetTrainerV2",
        "prep_modalities": "CT",
        "node_uid": "node_uid_125497293966312",
        "shuffle_seed": 0,
        "test_percentage": 0,
        "training_description": "nnUnet Segmentation",
        "body_part": "N/A",
        "train_max_epochs": 2,
        "input": "SEG",
    }
    conf["rest_call"] = {}

    # Operators exchanged between sites and operators to skip.
    conf["federated"] = {
        "federated_operators": ['nnunet-training'],
        "skip_operators": [
            "model2dicom",
            "dcmsend",
            "dcmsend-pdf",
            "zip-unzip-training",
            "workflow-cleaner",
        ],
    }
    return conf


In [None]:
def fernet_encryptfile(filepath, key):
    """Encrypt the file at ``filepath`` in place with a Fernet key.

    Args:
        filepath: Path of the file whose content is replaced by its
            encrypted form.
        key: Fernet key as ``str``. The special value ``'deactivated'``
            turns the call into a no-op and leaves the file untouched.
    """
    if key != 'deactivated':
        cipher = Fernet(key.encode())
        # Read everything first; the same path is reopened for writing below.
        with open(filepath, 'rb') as plain_file:
            token = cipher.encrypt(plain_file.read())
        with open(filepath, 'wb') as target_file:
            target_file.write(token)
        
def fernet_decryptfile(filepath, key):
    """Decrypt the file at ``filepath`` in place with a Fernet key.

    Args:
        filepath: Path of the encrypted file; its content is replaced by the
            decrypted bytes.
        key: Fernet key as ``str``. The special value ``'deactivated'``
            turns the call into a no-op and leaves the file untouched.
    """
    if key != 'deactivated':
        cipher = Fernet(key.encode())
        # Read everything first; the same path is reopened for writing below.
        with open(filepath, 'rb') as token_file:
            plain = cipher.decrypt(token_file.read())
        with open(filepath, 'wb') as target_file:
            target_file.write(plain)

def apply_tar_action(dst_filename, src_dir):
    """Pack ``src_dir`` into a gzip-compressed tar archive.

    Args:
        dst_filename: Path of the archive to create.
        src_dir: Directory (or file) to add. Inside the archive it is stored
            under the last component of this path.
    """
    print(f'Tar {src_dir} to {dst_filename}')
    with tarfile.open(dst_filename, mode="w:gz") as archive:
        # Store only the final path component so the archive holds no
        # leading directories of the source location.
        top_level_name = os.path.basename(src_dir)
        archive.add(src_dir, arcname=top_level_name)

def apply_untar_action(src_filename, dst_dir):
    """Extract a gzip-compressed tar archive into ``dst_dir``.

    Every member is checked before anything is written: a member name, or
    the target of a symbolic/hard link, that would resolve to a location
    outside ``dst_dir`` (``..`` components, absolute paths) aborts the
    extraction.

    Args:
        src_filename: Path of the ``.tar.gz`` archive to read.
        dst_dir: Directory the archive is extracted into.

    Raises:
        ValueError: if a member or link target would end up outside
            ``dst_dir``. Nothing is extracted in that case.
    """
    print(f'Untar {src_filename} to {dst_dir}')
    dst_root = os.path.realpath(dst_dir)

    def _is_inside_dst(path):
        # True when ``path`` resolves to dst_root itself or something below it.
        return os.path.commonpath([dst_root, os.path.realpath(path)]) == dst_root

    with tarfile.open(src_filename, "r:gz") as tar:
        members = tar.getmembers()
        for member in members:
            target = os.path.join(dst_root, member.name)
            if not _is_inside_dst(target):
                raise ValueError(
                    f'Refusing to extract {member.name!r} from {src_filename}: '
                    f'it would be written outside {dst_dir}'
                )
            if member.issym() or member.islnk():
                # Symlink targets are relative to the link's own directory,
                # hard-link targets are relative to the archive root.
                link_base = os.path.dirname(target) if member.issym() else dst_root
                if not _is_inside_dst(os.path.join(link_base, member.linkname)):
                    raise ValueError(
                        f'Refusing to extract link {member.name!r} from {src_filename}: '
                        f'its target {member.linkname!r} is outside {dst_dir}'
                    )
        tar.extractall(dst_dir, members=members)

def raise_kaapana_connection_error(r):
    """Check an HTTP response and raise if the request did not succeed.

    Args:
        r: Response object providing ``history``, ``raise_for_status()`` and
            ``text`` (a ``requests`` response at the call sites in this file).

    Raises:
        ConnectionError: if the response has a non-empty redirect history,
            which is treated as having been sent to the auth page.
        ValueError: if ``raise_for_status()`` fails; the original error is
            attached as ``__cause__``.
    """
    if r.history:
        raise ConnectionError('You were redirect to the auth page. Your token is not valid!')
    try:
        r.raise_for_status()
    except Exception as err:
        # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt and
        # SystemExit are not converted into a ValueError.
        raise ValueError(f'Something was not okay with your request code {r}: {r.text}!') from err
    print('Response', r)

def update_conf(data, site, fl_run_id, minioClient, federated_bucket):
    """Add the site name and the presigned MinIO URLs to ``data`` in place.

    For the ``conf`` archive and for every entry of
    ``data['conf']['federated']['federated_operators']`` a ``GET`` and a
    ``PUT`` URL is requested for the object
    ``<fl_run_id>/<site>/<name>.tar.gz`` in ``federated_bucket``.

    Args:
        data: Dictionary with a ``conf`` entry that contains a ``federated``
            dictionary; it is modified in place.
        site: Name of the site, stored under ``federated['site']``.
        fl_run_id: Identifier of the run, used as first object path component.
        minioClient: Client object passed through to ``get_presigend_url``.
        federated_bucket: Name of the bucket the URLs refer to.
    """
    federated = data['conf']['federated']
    federated['site'] = site

    def presigned_pair(archive_stem):
        # GET and PUT URL for one ``<stem>.tar.gz`` object of this site.
        object_name = os.path.join(fl_run_id, site, f'{archive_stem}.tar.gz')
        return {
            'GET': get_presigend_url(minioClient, 'GET', federated_bucket, object_name),
            'PUT': get_presigend_url(minioClient, 'PUT', federated_bucket, object_name),
        }

    minio_urls = {'conf': presigned_pair('conf')}
    for federated_operator in federated['federated_operators']:
        minio_urls[federated_operator] = presigned_pair(federated_operator)
    federated['minio_urls'] = minio_urls
        
# NOTE(review): the MinIO credentials, the site tokens and the Fernet keys are
# hard-coded in this cell; consider loading them from the environment or a
# secret store before sharing the notebook.
minioClient = get_minio_client('kaapanaminio', 'Kaapana2020')

# Run-level settings.
fl_working_dir = '/appdata/dev/federated-local-workspace'  # local download root
FEDERATED_BUCKET = 'january'  # MinIO bucket used for the exchange
fl_run_id = '123'  # first path component of every object of this run
dry_run = False  # forwarded as ``dry_run`` parameter to the remote trigger call

# Both sites currently use the same connection settings, token and key.
sites = {
    site_name: {
        'token': '7bfb0941-b3d6-466a-9c17-b5a6e3fe8b5e',
        'protocol': 'https',
        'host': '10.133.28.53',
        'port': '443',
        'ssl_check': False,
        'fernet_key': 'LL9zxylY0AvcFsYIvIEWtRszKrEIcjXlrVsL7HGW7-8='
        # Disabled alternative credentials:
        #         'username': 'kaapana',
        #         'password': 'admin',
        #         'client_id': 'kaapana',
        #         'client_secret': '1c4645f0-e654-45a1-a8b6-cf28790104ea'
    }
    for site_name in ('dkfz', 'essen')
}

# Fetch the client network from the local federated backend; fail early if
# the backend does not answer properly.
r = requests.get('http://federated-backend-service.base.svc:5000/federated-backend/get-client-network')
raise_kaapana_connection_error(r)
client_network = r.json()

# Create the exchange bucket on first use.
if not minioClient.bucket_exists(FEDERATED_BUCKET):
    minioClient.make_bucket(FEDERATED_BUCKET)


# Run the federated rounds. Each round:
#   1. triggers the workflow on every site (round 0 starts from
#      get_init_conf(), later rounds reuse the conf.json downloaded in the
#      previous round),
#   2. polls MinIO until every site has written a conf object newer than the
#      start of the round,
#   3. downloads, decrypts and unpacks all objects of every site into the
#      round's working directory.
for rn in range(0, 2):
    updated = {site: False for site in sites}
    fl_working_dir_round = os.path.join(fl_working_dir, str(rn))
    # Reference time of the round: only objects modified after this point
    # count as an update from a site.
    now = datetime.datetime.now(tz=datetime.timezone.utc)
    for site, site_info in sites.items():
        if rn == 0:
            conf = get_init_conf()
            conf['federated']['from_previous_dag_run'] = None
            conf['federated']['rounds'] = [0]
        else:
            previous_conf_path = os.path.join(
                fl_working_dir, str(rn - 1), fl_run_id, site, 'conf', 'conf.json')
            with open(previous_conf_path, "r", encoding='utf-8') as jsonData:
                conf = json.load(jsonData)
        meta_data = {
            'conf': conf
        }
        update_conf(meta_data, site, fl_run_id, minioClient, FEDERATED_BUCKET)
        print(f'Sending data to {site_info["protocol"]}://{site_info["host"]}:{site_info["port"]}')
        remote_backend_url = f'{site_info["protocol"]}://{site_info["host"]}:{site_info["port"]}/federated-backend/remote'
        # Honour the per-site TLS setting instead of always disabling
        # verification; sites without the key keep the previous behaviour.
        verify_tls = site_info.get('ssl_check', False)
        r = requests.get(f'{remote_backend_url}/health-check', verify=verify_tls, headers=get_remote_header(site_info["token"]))
        raise_kaapana_connection_error(r)
        r = requests.post(f'{remote_backend_url}/trigger-workflow', params={'dry_run': dry_run}, json=meta_data, verify=verify_tls, headers=get_remote_header(site_info["token"]))
        raise_kaapana_connection_error(r)
    if dry_run is True:
        print(r.text)
        break

    # Waiting for updated files
    print('Waiting for updates')
    for idx in range(10000):
        if idx % 60 == 0:
            # idx advances once per one-second sleep, so it is the elapsed time.
            print(f'{idx} seconds')

        time.sleep(1)
        for site, site_info in sites.items():
            for obj in minioClient.list_objects(FEDERATED_BUCKET, os.path.join(fl_run_id, site, 'conf')):
                if now < obj.last_modified:
                    updated[site] = True
        if all(updated.values()):
            break

    if not all(updated.values()):
        print('Update list')
        for k, v in updated.items():
            print(k, v)
        raise ValueError('There are lacking updates, please check what is going on!')

    # Downloading all objects
    for site, site_info in sites.items():
        # list_objects yields its results only once, so keep them in a list:
        # the listing is needed for the emptiness check and for the download.
        objects = list(minioClient.list_objects(FEDERATED_BUCKET, os.path.join(fl_run_id, site), recursive=True))
        if not objects:
            raise ValueError('There seems to be an error somewhere, we can not find the files on minio!')
        for obj in objects:
            # https://github.com/minio/minio-py/blob/master/minio/datatypes.py#L103
            if obj.is_dir:
                continue
            file_path = os.path.join(fl_working_dir_round, obj.object_name)
            file_dir = os.path.dirname(file_path)
            os.makedirs(file_dir, exist_ok=True)
            minioClient.fget_object(FEDERATED_BUCKET, obj.object_name, file_path)
            # The downloaded archive is decrypted in place and unpacked next
            # to itself.
            fernet_decryptfile(file_path, site_info['fernet_key'])
            apply_untar_action(file_path, file_dir)
    # Working with downloaded objects

    print('Finished round', rn)