In [1]:
%load_ext autoreload
%autoreload 2
from argparse import Namespace
from cryptography.fernet import Fernet
import requests
import time
import datetime as datetime
from datetime import timedelta
import numpy as np
import json
import os
import tarfile
import glob
from minio import Minio
from minio.error import (InvalidResponseError, S3Error)
import sys
from abc import ABC, abstractmethod

# Make packages under /appdata/dev importable ahead of installed ones.
# Guarded so that re-running this cell does not stack duplicate sys.path entries.
if '/appdata/dev' not in sys.path:
    sys.path.insert(0, '/appdata/dev')

# from kaapana_federated.utils import get_auth_headers, get_minio_client, get_presigend_url, get_remote_header

In [6]:
class KaapanaFederatedTrainingBase(ABC):
    """Orchestrates federated training rounds across remote Kaapana instances.

    Each round (see ``train``):
      1. queue one job per remote site through the federated backend,
      2. poll the backend until every job reports ``finished``,
      3. download each site's files from MinIO, decrypt and untar them,
      4. hand over to the subclass hook ``update_data(fl_round)``,
      5. re-tar, re-encrypt and upload them under the next round's directory.

    Configuration read by this class:
      * ``job_data['federated']`` may contain ``node_ids`` (remote sites to
        train with); ``fl_round`` and ``from_previous_dag_run`` are written
        into it before each job is queued.
      * ``local_data['federated']`` must contain ``federated_bucket`` and
        ``federated_dir`` (MinIO bucket and top-level object prefix). Objects
        of a round live under ``<federated_dir>/<fl_round>/<node_id>/``.
    """

    # Number of federated rounds executed by ``train``; subclasses may override.
    fl_rounds = 2
    # Seconds to sleep between two polls of the remote job states.
    poll_interval_seconds = 8
    # Maximum number of polls per round before the round is declared failed.
    max_poll_iterations = 10000

    @staticmethod
    def fernet_encryptfile(filepath, key):
        """Encrypt ``filepath`` in place with the Fernet ``key`` (a str).

        The special key ``'deactivated'`` disables encryption (no-op).
        """
        if key == 'deactivated':
            return
        fernet = Fernet(key.encode())
        with open(filepath, 'rb') as file:
            original = file.read()
        encrypted = fernet.encrypt(original)
        with open(filepath, 'wb') as encrypted_file:
            encrypted_file.write(encrypted)

    @staticmethod
    def fernet_decryptfile(filepath, key):
        """Decrypt ``filepath`` in place with the Fernet ``key`` (a str).

        The special key ``'deactivated'`` disables decryption (no-op).
        """
        if key == 'deactivated':
            return
        fernet = Fernet(key.encode())
        with open(filepath, 'rb') as enc_file:
            encrypted = enc_file.read()
        decrypted = fernet.decrypt(encrypted)
        with open(filepath, 'wb') as dec_file:
            dec_file.write(decrypted)

    @staticmethod
    def apply_tar_action(dst_filename, src_dir):
        """Write ``src_dir`` into the gzip tarball ``dst_filename``.

        The directory is stored under its basename inside the archive.
        """
        print(f'Tar {src_dir} to {dst_filename}')
        with tarfile.open(dst_filename, "w:gz") as tar:
            tar.add(src_dir, arcname=os.path.basename(src_dir))

    @staticmethod
    def apply_untar_action(src_filename, dst_dir):
        """Extract the gzip tarball ``src_filename`` into ``dst_dir``.

        Raises:
            ValueError: if a member name would land outside ``dst_dir``
                (``../x``, absolute paths). Archives come from remote sites,
                so member paths are validated before anything is written.
        """
        print(f'Untar {src_filename} to {dst_dir}')
        dst_root = os.path.realpath(dst_dir)
        with tarfile.open(src_filename, "r:gz") as tar:
            for member in tar.getmembers():
                target = os.path.realpath(os.path.join(dst_root, member.name))
                if os.path.commonpath([dst_root, target]) != dst_root:
                    raise ValueError(
                        f'Refusing to extract {member.name!r} outside of {dst_dir}')
            # NOTE(review): link targets are not validated here; on Python >= 3.12
            # prefer tar.extractall(dst_dir, filter='data') for full protection.
            tar.extractall(dst_dir)

    @staticmethod
    def raise_kaapana_connection_error(r):
        """Raise if the ``requests`` response ``r`` indicates a failed call.

        Raises:
            ConnectionError: the request was redirected (treated as a
                redirect to the auth page, i.e. an invalid token).
            ValueError: ``r.raise_for_status()`` failed; the original error
                is chained as ``__cause__``.
        """
        if r.history:
            raise ConnectionError('You were redirect to the auth page. Your token is not valid!')
        try:
            r.raise_for_status()
        except Exception as e:
            raise ValueError(f'Something was not okay with your request code {r}: {r.text}!') from e

    def __init__(self, meta_data, job_data, local_data, fl_working_dir,
                 access_key='kaapanaminio',
                 secret_key='Kaapana2020',
                 minio_host='minio-service.store.svc',
                 minio_port='9000',
                 dry_run=False
                ):
        """Fetch local/remote instance info from the backend and open MinIO.

        Args:
            meta_data: sent as ``conf_data`` with every queued job.
            job_data: sent as ``job_data``; see the class docstring.
            local_data: sent as ``local_data``; see the class docstring.
            fl_working_dir: local directory round artifacts are downloaded to.
            access_key, secret_key, minio_host, minio_port: MinIO connection.
            dry_run: forwarded with each job; ``train`` stops after queuing
                the first round's jobs.

        Raises:
            ConnectionError, ValueError: backend calls failed
                (see ``raise_kaapana_connection_error``).
        """
        # NOTE(review): default MinIO credentials are hardcoded; prefer passing
        # them explicitly (e.g. read from environment variables by the caller).
        self.meta_data = meta_data
        self.job_data = job_data
        self.local_data = local_data
        self.fl_working_dir = fl_working_dir
        self.client_url = 'http://federated-backend-service.base.svc:5000/client'

        r = requests.get(f'{self.client_url}/client-kaapana-instance')
        KaapanaFederatedTrainingBase.raise_kaapana_connection_error(r)
        self.client_network = r.json()

        node_ids = self.job_data['federated'].get('node_ids', [])
        r = requests.post(f'{self.client_url}/get-remote-kaapana-instances', json={'node_ids': node_ids})
        KaapanaFederatedTrainingBase.raise_kaapana_connection_error(r)
        self.remote_sites = r.json()

        # f-string so an int port works as well as a str port.
        self.minioClient = Minio(
            f'{minio_host}:{minio_port}',
            access_key=access_key,
            secret_key=secret_key,
            secure=False)
        self.dry_run = dry_run

    @abstractmethod
    def update_data(self, fl_round):
        """Subclass hook run after all site files of ``fl_round`` are downloaded."""
        pass

    def train(self):
        """Run ``fl_rounds`` federated rounds (see the class docstring).

        Raises:
            ValueError: a site's job did not finish within
                ``max_poll_iterations`` polls, a site uploaded no files, or
                a backend call failed.
        """
        for fl_round in range(self.fl_rounds):
            self.job_data['federated']['fl_round'] = fl_round
            last_response_text = self._submit_jobs(fl_round)
            if self.dry_run is True:
                print(last_response_text)
                break

            self._wait_for_updates()
            downloaded = self._download_updates(fl_round)

            # Working with downloaded objects
            self.update_data(fl_round)

            self._upload_updates(downloaded)
            print('Finished round', fl_round)

    def _submit_jobs(self, fl_round):
        """Queue one job per remote site; return the last response text (or None)."""
        last_response_text = None
        for site_info in self.remote_sites:
            if fl_round == 0:
                self.job_data['federated']['from_previous_dag_run'] = None
            else:
                # Set while polling the previous round (see _wait_for_updates).
                self.job_data['federated']['from_previous_dag_run'] = site_info['parking']['from_previous_dag_run']

            r = requests.post(f'{self.client_url}/job', json={
                "dry_run": self.dry_run,
                "conf_data": self.meta_data,
                "job_data": self.job_data,
                "local_data": self.local_data,
                "status": "queued",
                "addressed_kaapana_node_id": self.client_network['node_id'],
                "kaapana_instance_id": site_info['id']}, verify=self.client_network['ssl_check'])
            KaapanaFederatedTrainingBase.raise_kaapana_connection_error(r)
            job = r.json()
            print('Created Job')
            print(job)
            site_info['parking'] = {'job_id': job['id']}
            last_response_text = r.text
        return last_response_text

    def _wait_for_updates(self):
        """Poll every site's job until all report ``finished``; raise on timeout."""
        updated = {site['node_id']: False for site in self.remote_sites}
        print('Waiting for updates')
        for idx in range(self.max_poll_iterations):
            if idx % 60 == 0:
                print(f'{idx * self.poll_interval_seconds} seconds')
            time.sleep(self.poll_interval_seconds)

            for site_info in self.remote_sites:
                if updated[site_info['node_id']]:
                    continue  # already finished, no need to poll again
                r = requests.get(f'{self.client_url}/job', params={
                    "job_id": site_info['parking']["job_id"]
                }, verify=self.client_network['ssl_check'])
                KaapanaFederatedTrainingBase.raise_kaapana_connection_error(r)
                job = r.json()
                print(site_info['node_id'], job['status'], job['run_id'], job['description'])
                if job['status'] == 'finished':
                    updated[site_info['node_id']] = True
                    site_info['parking']['from_previous_dag_run'] = job['run_id']

            if all(updated.values()):
                return

        if not all(updated.values()):
            print('Update list')
            for node_id, done in updated.items():
                print(node_id, done)
            raise ValueError('There are lacking updates, please check what is going on!')

    @staticmethod
    def _next_round_object_name(object_name, current_round_dir, next_round_dir):
        """Map an object of the current round to the same path in the next round."""
        if object_name.startswith(current_round_dir):
            return next_round_dir + object_name[len(current_round_dir):]
        return object_name.replace(current_round_dir, next_round_dir)

    def _download_updates(self, fl_round):
        """Download, decrypt and untar every site's files of ``fl_round``.

        Returns:
            list of ``(file_path, file_dir, next_object_name, fernet_key)``
            per downloaded file, everything needed to push it back later.

        Raises:
            ValueError: a site has no objects for this round.
        """
        federated = self.local_data['federated']
        bucket = federated['federated_bucket']
        current_round_dir = os.path.join(federated['federated_dir'], str(fl_round))
        next_round_dir = os.path.join(federated['federated_dir'], str(fl_round + 1))

        downloaded = []
        for site_info in self.remote_sites:
            # Trailing '/' so node 'berlin' does not also match 'berlin2/...'.
            prefix = os.path.join(current_round_dir, site_info['node_id'], '')
            # list_objects returns a one-shot iterator; materialize it so it can
            # be checked for emptiness and still be iterated.
            objects = list(self.minioClient.list_objects(bucket, prefix, recursive=True))
            if not objects:
                raise ValueError(
                    f'There seems to be an error somewhere, we can not find the files on minio under {prefix}!')
            for obj in objects:
                print(obj)
                # https://github.com/minio/minio-py/blob/master/minio/datatypes.py#L103
                if obj.is_dir:
                    continue
                file_path = os.path.join(self.fl_working_dir, obj.object_name)
                file_dir = os.path.dirname(file_path)
                os.makedirs(file_dir, exist_ok=True)
                self.minioClient.fget_object(bucket, obj.object_name, file_path)
                KaapanaFederatedTrainingBase.fernet_decryptfile(file_path, site_info['fernet_key'])
                KaapanaFederatedTrainingBase.apply_untar_action(file_path, file_dir)
                next_object_name = KaapanaFederatedTrainingBase._next_round_object_name(
                    obj.object_name, current_round_dir, next_round_dir)
                downloaded.append((file_path, file_dir, next_object_name, site_info['fernet_key']))
        return downloaded

    def _upload_updates(self, downloaded):
        """Re-tar, re-encrypt and upload each file returned by ``_download_updates``."""
        bucket = self.local_data['federated']['federated_bucket']
        for file_path, file_dir, next_object_name, fernet_key in downloaded:
            # NOTE(review): this archives the whole per-site directory (tarfile
            # skips the archive being written) rather than only the directory
            # extracted from this file; confirm that is the intended layout.
            KaapanaFederatedTrainingBase.apply_tar_action(file_path, file_dir)
            KaapanaFederatedTrainingBase.fernet_encryptfile(file_path, fernet_key)
            self.minioClient.fput_object(bucket, next_object_name, file_path)


In [3]:
def get_init_conf(dag="dev-federated",
                  cohort_limit=2,
                  task="Task136_RACOON_260122-Federated",
                  train_max_epochs=50,
                  protocol_id="satori-ana"):
    """Build the DAG-trigger configuration for the federated nnU-Net run.

    The query selects SEG series whose ClinicalTrialProtocolID is
    ``protocol_id``. The defaults reproduce the configuration used so far;
    override them instead of copy-pasting this function.

    Args:
        dag: name of the DAG to trigger.
        cohort_limit: value passed as ``cohort_limit``.
        task: nnU-Net task name placed in ``form_data``.
        train_max_epochs: ``form_data['train_max_epochs']``.
        protocol_id: value matched against the ClinicalTrialProtocolID keyword.

    Returns:
        A new dict on every call, so callers may mutate it freely.
    """
    return {
        "query": {
            "bool": {
                "must": [
                    {
                        "match_all": {}
                    },
                    {
                        "match_all": {}
                    },
                    {
                        "match_phrase": {
                            "00120020 ClinicalTrialProtocolID_keyword.keyword": {
                                "query": protocol_id
                            }
                        }
                    },
                    {
                        "match_phrase": {
                            "00080060 Modality_keyword.keyword": {
                                "query": "SEG"
                            }
                        }
                    }
                ],
                "filter": [],
                "should": [],
                "must_not": []
            }
        },
        "index": "meta-index",
        "dag": dag,
        "cohort_limit": cohort_limit,
        "form_data": {
            "task": task,
            "model": "3d_lowres",
            "train_network_trainer": "nnUNetTrainerV2",
            "prep_modalities": "CT",
            "node_uid": "node_uid_125497293966312",
            "shuffle_seed": 0,
            "test_percentage": 0,
            "training_description": "nnUnet Segmentation",
            "body_part": "N/A",
            "train_max_epochs": train_max_epochs,
            "input": "SEG"
        }
    }


In [None]:
# NOTE(review): this cell redefines get_init_conf from the earlier cell. On
# Restart & Run All this definition wins, so the training cell would submit the
# 'nnunet-training' DAG instead of 'dev-federated' shown in the saved output.
# TODO: keep a single definition and pass the differing values as arguments.
def get_init_conf(dag="nnunet-training",
                  cohort_limit=4,
                  task="Task136_RACOON_310122-Federated",
                  train_max_epochs=2,
                  protocol_id="satori-ana"):
    """Build the DAG-trigger configuration for the nnU-Net training run.

    The query selects SEG series whose ClinicalTrialProtocolID is
    ``protocol_id``. The defaults reproduce this cell's configuration.

    Args:
        dag: name of the DAG to trigger.
        cohort_limit: value passed as ``cohort_limit``.
        task: nnU-Net task name placed in ``form_data``.
        train_max_epochs: ``form_data['train_max_epochs']``.
        protocol_id: value matched against the ClinicalTrialProtocolID keyword.

    Returns:
        A new dict on every call, so callers may mutate it freely.
    """
    return {
        "query": {
            "bool": {
                "must": [
                    {
                        "match_all": {}
                    },
                    {
                        "match_all": {}
                    },
                    {
                        "match_phrase": {
                            "00120020 ClinicalTrialProtocolID_keyword.keyword": {
                                "query": protocol_id
                            }
                        }
                    },
                    {
                        "match_phrase": {
                            "00080060 Modality_keyword.keyword": {
                                "query": "SEG"
                            }
                        }
                    }
                ],
                "filter": [],
                "should": [],
                "must_not": []
            }
        },
        "index": "meta-index",
        "dag": dag,
        "cohort_limit": cohort_limit,
        "form_data": {
            "task": task,
            "model": "3d_lowres",
            "train_network_trainer": "nnUNetTrainerV2",
            "prep_modalities": "CT",
            "node_uid": "node_uid_125497293966312",
            "shuffle_seed": 0,
            "test_percentage": 0,
            "training_description": "nnUnet Segmentation",
            "body_part": "N/A",
            "train_max_epochs": train_max_epochs,
            "input": "SEG"
        }
    }


In [7]:
class KaapanaFederatedTraining(KaapanaFederatedTrainingBase):
    """Minimal concrete trainer used for development.

    Construction is inherited unchanged from the base class (the former
    pass-through ``__init__`` was redundant). ``update_data`` only prints the
    round number and leaves the downloaded files untouched.
    """

    def update_data(self, fl_round):
        """Hook called once per round after all site files are downloaded."""
        print(fl_round)

# Run configuration: local workspace, participating operators/sites, where the
# federated artifacts live in MinIO, and the DAG request itself.
conf = dict(
    fl_working_dir='/appdata/dev/federated-local-workspace',
    job_data={
        "federated": {
            "federated_operators": ['unet-ct', 'local-dev'],
            "skip_operators": ["workflow-cleaner"],
            "node_ids": ['berlin'],
        },
    },
    local_data={
        "federated": {
            "federated_bucket": "federateddata",
            "federated_dir": "123",
        },
    },
    meta_data={'conf': get_init_conf()},
)

kaapana_ft = KaapanaFederatedTraining(**conf)
kaapana_ft.train()

<class 'list'>
Created Job
{'dry_run': False, 'status': 'queued', 'run_id': None, 'description': None, 'external_job_id': None, 'addressed_kaapana_node_id': 'dkfz', 'id': 29, 'conf_data': {'conf': {'query': {'bool': {'must': [{'match_all': {}}, {'match_all': {}}, {'match_phrase': {'00120020 ClinicalTrialProtocolID_keyword.keyword': {'query': 'satori-ana'}}}, {'match_phrase': {'00080060 Modality_keyword.keyword': {'query': 'SEG'}}}], 'filter': [], 'should': [], 'must_not': []}}, 'index': 'meta-index', 'dag': 'dev-federated', 'cohort_limit': 2, 'form_data': {'task': 'Task136_RACOON_260122-Federated', 'model': '3d_lowres', 'train_network_trainer': 'nnUNetTrainerV2', 'prep_modalities': 'CT', 'node_uid': 'node_uid_125497293966312', 'shuffle_seed': 0, 'test_percentage': 0, 'training_description': 'nnUnet Segmentation', 'body_part': 'N/A', 'train_max_epochs': 50, 'input': 'SEG'}}}, 'job_data': {'federated': {'federated_operators': ['unet-ct', 'local-dev'], 'skip_operators': ['workflow-cleaner

KeyboardInterrupt: 