In [52]:
# %load_ext autoreload
# %autoreload 2
import collections
import os
import sys
import uuid
from pathlib import Path

import torch
from batchgenerators.utilities.file_and_folder_operations import join
from nnunet.training.model_restore import restore_model

# Put the parent directory and '/executables' at the front of the module
# search path *before* importing kaapana_federated below.
# NOTE(review): presumably one of these locations contains the
# kaapana_federated package (local checkout vs. container) -- confirm.
sys.path.insert(0, '../')
sys.path.insert(0, '/executables')
from kaapana_federated.kaapana_federated import KaapanaFederatedTrainingBase, requests_retry_session


In [57]:
class nnUNetFederatedTraining(KaapanaFederatedTrainingBase):
    """Federated nnU-Net training that averages the site models each round.

    After a federated round, ``update_data`` collects every
    ``model_final_checkpoint.model`` found below the round's working
    directory, replaces the network weights of each of them with the
    element-wise mean over all of them, writes the checkpoints back and raises
    the epoch limit in the remote workflow configuration.
    """

    @staticmethod
    def get_network_trainer(folder):
        """Restore the network trainer whose final checkpoint lives in *folder*.

        Args:
            folder: Directory containing ``model_final_checkpoint.model`` and
                the matching ``model_final_checkpoint.model.pkl``.

        Returns:
            Whatever ``restore_model`` returns for that checkpoint. The third
            positional argument is passed as ``False``.
            NOTE(review): in nnU-Net v1 this is the ``train`` flag -- confirm.
        """
        checkpoint = join(folder, "model_final_checkpoint.model")
        pkl_file = checkpoint + ".pkl"
        return restore_model(pkl_file, checkpoint, False)

    @staticmethod
    def _average_state_dicts(state_dicts):
        """Return the element-wise arithmetic mean of *state_dicts*.

        Args:
            state_dicts: Non-empty sequence of mappings that all provide the
                keys of the first mapping.

        Returns:
            An ``OrderedDict`` with the keys (and key order) of the first
            mapping, each value being ``sum(values) / len(state_dicts)``.
        """
        site_count = len(state_dicts)
        averaged = collections.OrderedDict()
        for key, first_value in state_dicts[0].items():
            # Out-of-place addition: the first site's values are never mutated.
            total = first_value
            for other in state_dicts[1:]:
                total = total + other[key]
            # Divide once by the number of sites so every site gets the same
            # weight; a running "(avg + next) / 2" would over-weight later sites.
            averaged[key] = total / float(site_count)
        return averaged

    def __init__(self, run_id=None, workflow_dir=None, federated_operators=None, skip_operators=None):
        """Build the configuration for the 'nnunet-training' DAG and initialise the base class.

        Args:
            run_id: Run identifier. Falls back to the ``RUN_ID`` environment
                variable, then to the DAG id followed by a random UUID.
            workflow_dir: Workflow directory. Falls back to the
                ``WORKFLOW_DIR`` environment variable, then to
                ``/appdata/dev/federated-local-workspace/<run_id>``.
            federated_operators: Operators handled in a federated way;
                defaults to ``['nnunet-training']``.
            skip_operators: Operators to skip; defaults to the zip/unzip,
                DICOM export/send and cleanup operators listed below.
        """
        dag_id = 'nnunet-training'
        run_id = run_id or os.getenv("RUN_ID", dag_id + str(uuid.uuid4()))
        workflow_dir = workflow_dir or os.getenv('WORKFLOW_DIR', f'/appdata/dev/federated-local-workspace/{run_id}')
        federated_operators = federated_operators or ['nnunet-training']
        skip_operators = skip_operators or ["zip-unzip-training", "model2dicom", "dcmsend", "pdf2dcm-training", "dcmsend-pdf", "workflow-cleaner"]
        conf_data = KaapanaFederatedTrainingBase.get_conf(dag_id, run_id, workflow_dir, federated_operators, skip_operators)

        super().__init__(dag_id, conf_data, workflow_dir)

        # only for nnunet training!
        # Remember the configured epoch count; update_data uses it as the
        # per-round increment when raising 'train_max_epochs'.
        self.train_max_epochs_increment = self.remote_conf_data['workflow_form']['train_max_epochs']
        print(f'Training increment {self.train_max_epochs_increment}')

    def update_data(self, tmp_federated_site_info, federated_round):
        """Average the site models of *federated_round* and prepare the next round.

        Args:
            tmp_federated_site_info: Not used by this implementation.
            federated_round: Index of the round that just finished; used as
                the sub-directory name below ``self.fl_working_dir``.
                NOTE(review): the epoch formula below assumes a 0-based
                index -- confirm against the base class.

        Raises:
            FileNotFoundError: If no ``model_final_checkpoint.model`` exists
                below the round directory.
        """
        round_dir = Path(os.path.join(self.fl_working_dir, str(federated_round)))
        print(round_dir)

        trainer_dirs = []
        network_trainers = []
        for path in round_dir.rglob('model_final_checkpoint.model'):
            print(path)
            trainer_dir = path.parent
            trainer_dirs.append(trainer_dir)
            network_trainers.append(nnUNetFederatedTraining.get_network_trainer(trainer_dir))

        if not network_trainers:
            # Fail with an explicit message instead of an IndexError further down.
            raise FileNotFoundError(
                f"No 'model_final_checkpoint.model' found below {round_dir}"
            )

        # Build every state dict exactly once, then average all parameters.
        averaged_state_dict = self._average_state_dicts(
            [network_trainer.network.state_dict() for network_trainer in network_trainers]
        )

        # Write the averaged weights back into every site's checkpoint.
        for trainer_dir, network_trainer in zip(trainer_dirs, network_trainers):
            network_trainer.network.load_state_dict(averaged_state_dict)
            print(f"Updating model {str(trainer_dir / 'model_final_checkpoint.model')}")
            network_trainer.save_checkpoint(str(trainer_dir / 'model_final_checkpoint.model'))

        # Next round continues from the saved checkpoints with a raised epoch limit.
        self.remote_conf_data['workflow_form']['train_continue'] = True
        self.remote_conf_data['workflow_form']['train_max_epochs'] = self.train_max_epochs_increment * (federated_round + 2)

        print(federated_round, self.remote_conf_data['federated_form']['federated_total_rounds'])
        print(type(federated_round), type(self.remote_conf_data['federated_form']['federated_total_rounds']))
            
# Script entry: build the trainer with all defaults (run id and workflow
# directory then come from the RUN_ID / WORKFLOW_DIR environment variables or
# the built-in fallbacks) and start training.
# NOTE(review): train() is not defined in this file; presumably inherited from
# KaapanaFederatedTrainingBase -- confirm.
kaapana_ft = nnUNetFederatedTraining()
kaapana_ft.train()

['dkfz', 'hamburg']
