In [52]:
# %load_ext autoreload
# %autoreload 2
import os
import sys
from pathlib import Path
import uuid
import torch
import collections
from nnunet.training.model_restore import restore_model
from batchgenerators.utilities.file_and_folder_operations import join

sys.path.insert(0, '../')
sys.path.insert(0, '/executables')
from kaapana_federated.kaapana_federated import KaapanaFederatedTrainingBase, requests_retry_session


In [57]:
class nnUNetFederatedTraining(KaapanaFederatedTrainingBase):
    """Federated nnU-Net training driver that averages site checkpoints per round.

    After every federated round, ``update_data`` collects all
    ``model_final_checkpoint.model`` files below the round's working directory,
    averages their weights (and AMP grad-scaler state, if present) and writes
    the averaged values back into every checkpoint file.
    """

    @staticmethod
    def get_network_trainer(folder):
        """Restore the trainer saved as ``model_final_checkpoint.model`` in *folder*.

        The matching ``.pkl`` file is expected next to the checkpoint. ``False``
        is passed as the third positional argument of ``restore_model``
        (presumably the ``train`` flag -- confirm against the installed nnU-Net).
        """
        checkpoint = join(folder, "model_final_checkpoint.model")
        pkl_file = checkpoint + ".pkl"
        return restore_model(pkl_file, checkpoint, False)

    def __init__(self, run_id=None, workflow_dir=None, federated_operators=None, skip_operators=None):
        """Build the run configuration and derive the number of epochs per round.

        Args:
            run_id: Run identifier. Falls back to the ``RUN_ID`` environment
                variable, then to the dag id followed by a random UUID.
            workflow_dir: Working directory. Falls back to the ``WORKFLOW_DIR``
                environment variable, then to a path under
                ``/appdata/dev/federated-local-workspace``.
            federated_operators: Operator names handed to ``get_conf``;
                defaults to ``['nnunet-training']``.
            skip_operators: Operator names handed to ``get_conf`` as operators
                to skip; defaults to the list below.

        Raises:
            ValueError: If ``train_max_epochs`` is not a multiple of
                ``federated_total_rounds``.
        """
        dag_id = 'nnunet-training'
        run_id = run_id or os.getenv("RUN_ID", dag_id + str(uuid.uuid4()))
        workflow_dir = workflow_dir or os.getenv('WORKFLOW_DIR', f'/appdata/dev/federated-local-workspace/{run_id}')
        federated_operators = federated_operators or ['nnunet-training']
        skip_operators = skip_operators or ["zip-unzip-training", "model2dicom", "dcmsend", "pdf2dcm-training", "dcmsend-pdf", "workflow-cleaner"]
        conf_data = KaapanaFederatedTrainingBase.get_conf(dag_id, run_id, workflow_dir, federated_operators, skip_operators)

        super().__init__(dag_id, conf_data, workflow_dir)

        # `remote_conf_data` is not set in this class -- presumably provided by
        # the base class constructor.
        workflow_form = self.remote_conf_data['workflow_form']
        total_rounds = self.remote_conf_data['federated_form']['federated_total_rounds']
        max_epochs = workflow_form['train_max_epochs']

        if max_epochs % total_rounds != 0:
            raise ValueError('train_max_epochs has to be multiple of federated_total_rounds')
        # Divisibility was checked above, so floor division is exact.
        workflow_form['epochs_per_round'] = int(max_epochs // total_rounds)

        print(f"Epochs per round {workflow_form['epochs_per_round']}")

    @staticmethod
    def _add_into(totals, values):
        """Add every entry of *values* onto the running sums in *totals* (in place)."""
        if not totals:
            # First contribution: adopt the values as they are.
            totals.update(values)
            return
        for key, value in values.items():
            # Out-of-place addition so the loaded checkpoint objects stay untouched.
            totals[key] = totals[key] + value

    @staticmethod
    def _to_mean(totals, count):
        """Turn the sums in *totals* into arithmetic means over *count* contributions."""
        if count > 1:
            # With zero or one contribution there is nothing to average; skipping
            # the division keeps the original values (and their dtypes) as they are.
            for key in totals:
                totals[key] = totals[key] / float(count)
        return totals

    def update_data(self, tmp_federated_site_info, federated_round):
        """Average all site checkpoints of *federated_round* and write them back.

        Every ``model_final_checkpoint.model`` found recursively below
        ``<fl_working_dir>/<federated_round>`` contributes with equal weight to
        the mean of ``state_dict`` and, where present, ``amp_grad_scaler``.
        Each file is then rewritten with the averaged entries while all of its
        other entries are kept. Finally ``train_continue`` is set to ``True``
        in the remote workflow form.

        Args:
            tmp_federated_site_info: Not used by this implementation.
            federated_round: Index of the round whose checkpoints are averaged.
        """
        models_path = Path(os.path.join(self.fl_working_dir, str(federated_round)))
        print(models_path)
        # List once and sort so the load and save passes see the same files in a
        # deterministic order.
        checkpoint_files = sorted(models_path.rglob('model_final_checkpoint.model'))

        averaged_state_dict = collections.OrderedDict()
        averaged_amp_grad_scaler = dict()
        scaler_count = 0

        print('Loading averaged checkpoints')
        for fname in checkpoint_files:
            print(fname)
            # NOTE(security): torch.load unpickles the file. The checkpoints
            # originate from remote sites, so they must come from trusted peers.
            checkpoint = torch.load(fname, map_location=torch.device('cpu'))
            self._add_into(averaged_state_dict, checkpoint['state_dict'])
            if 'amp_grad_scaler' in checkpoint:
                self._add_into(averaged_amp_grad_scaler, checkpoint['amp_grad_scaler'])
                scaler_count += 1

        # Sum first, divide once: a running "(avg + new) / 2" would weight later
        # sites more heavily as soon as more than two sites take part.
        self._to_mean(averaged_state_dict, len(checkpoint_files))
        self._to_mean(averaged_amp_grad_scaler, scaler_count)

        print('Saving averaged checkpoints')
        for fname in checkpoint_files:
            print(fname)
            # Reload the site's own checkpoint so that only the averaged entries
            # are replaced and everything else in the file stays site-specific.
            checkpoint = torch.load(fname, map_location=torch.device('cpu'))
            checkpoint['state_dict'] = averaged_state_dict
            if 'amp_grad_scaler' in checkpoint:
                checkpoint['amp_grad_scaler'] = averaged_amp_grad_scaler
            torch.save(checkpoint, fname)

        self.remote_conf_data['workflow_form']['train_continue'] = True
        print(federated_round, self.remote_conf_data['federated_form']['federated_total_rounds'])
            
# Build the trainer with all defaults: the run id and workflow directory are
# taken from the RUN_ID / WORKFLOW_DIR environment variables when set,
# otherwise generated in the constructor.
kaapana_ft = nnUNetFederatedTraining()
# `train` is not defined in nnUNetFederatedTraining itself -- presumably
# inherited from KaapanaFederatedTrainingBase; it starts the federated run.
kaapana_ft.train()

['dkfz', 'hamburg']
