In [None]:

class BoosterTrainer(TrainerModule):
    '''Run the whole fit / test process for a booster model.

    There is no epoch loop: ``fit_model`` walks the train and valid dataloaders
    in lockstep, and the booster is fitted inside ``batch_forward`` whenever the
    current split is 'train'. Results are stacked and dumped through
    ``self.deposition`` once per entry of ``self.model_types``.
    '''
    def init_data(self , **kwargs): 
        '''Build the booster data module from ``self.config`` (kwargs are ignored).'''
        self.data : BoosterDataModule = BoosterDataModule(self.config)

    def batch_forward(self) -> None: 
        '''Fit the booster when on the train split, then predict the current split.'''
        # only the train split triggers fitting; valid/test splits just predict
        if self.status.dataset == 'train': self.booster.fit(silence=True)
        self.batch_output = self.booster.predict(self.status.dataset)

    def batch_metrics(self) -> None:
        '''Score ``batch_output.pred`` against ``batch_output.other['label']`` and collect the batch metric.'''
        self.metrics.calculate_from_tensor(self.status.dataset , self.batch_output.other['label'] , self.batch_output.pred, assert_nan = True)
        self.metrics.collect_batch_metric()

    def batch_backward(self) -> None:
        '''No-op: the booster is fitted inside ``batch_forward``.'''
        ...

    def fit_model(self):
        '''Fit on paired (train, valid) batches, then save the model.

        ``zip`` stops at the shorter of the two dataloaders.
        '''
        self.on_fit_model_start()
        for self.batch_idx , self.batch_data in enumerate(zip(self.data.train_dataloader() , self.data.val_dataloader())):
            self.status.dataset_train()
            self.on_train_batch_start()
            self.on_train_batch()
            self.on_train_batch_end()

            self.status.dataset_validation()
            self.on_validation_batch_start()
            self.on_validation_batch()
            self.on_validation_batch_end()

        self.on_fit_model_end()

    def test_model(self):
        '''Evaluate every model type in ``self.model_types`` on the test dataloader.'''
        self.on_test_model_start()
        self.status.dataset_test()
        # the loop target is the status attribute, so hooks see the current model type
        for self.status.model_type in self.model_types:
            self.on_test_model_type_start()
            for self.batch_idx , self.batch_data in enumerate(self.data.test_dataloader()):
                self.on_test_batch_start()
                self.on_test_batch()
                self.on_test_batch_end()
            self.on_test_model_type_end()
        self.on_test_model_end()

    def on_configure_model(self):
        '''Apply the config's environment settings.'''
        self.config.set_config_environment()

    def on_fit_model_start(self):
        '''Set up fit data for the current model param/date and load a model for training.'''
        self.data.setup('fit' , self.model_param , self.model_date)
        self.load_model(True)

    def on_fit_model_end(self):
        '''Persist the fitted model.'''
        self.save_model()
    
    def on_train_batch_start(self):
        '''Open a new 'train' metric record.'''
        self.metrics.new_epoch_metric('train' , self.status)

    def on_train_batch_end(self): 
        '''Close the 'train' metric record.'''
        self.metrics.collect_epoch_metric('train')

    def on_validation_batch_start(self):
        '''Open a new 'valid' metric record.'''
        self.metrics.new_epoch_metric('valid' , self.status)

    def on_validation_batch_end(self): 
        '''Close the 'valid' metric record.'''
        self.metrics.collect_epoch_metric('valid')

    def on_test_model_type_start(self):
        '''Load the deposited model for the current model type and open a 'test' metric record.'''
        self.load_model(False , self.model_type)
        self.metrics.new_epoch_metric('test' , self.status)

    def on_test_model_type_end(self): 
        '''Close the 'test' metric record.'''
        self.metrics.collect_epoch_metric('test')

    def on_test_model_start(self):
        '''Fit first if no deposited model exists for the current date/num/type, then set up test data.'''
        if not self.deposition.exists(self.model_date , self.model_num , self.model_type): self.fit_model()
        self.data.setup('test' , self.model_param , self.model_date)

    def on_test_batch(self):
        '''Predict and score one test batch; batches before ``batch_warm_up`` are skipped entirely.'''
        if self.batch_idx < self.batch_warm_up: return
        self.batch_forward()
        self.batch_metrics()
    
    def load_model(self , training : bool , model_type = 'best' , *args , **kwargs):
        '''Load the deposited booster for ``model_type`` and reset metrics for a new model.

        Sets ``self.booster``; returns nothing.

        Args:
            training: True when loading for fitting, False when loading for testing.
            model_type: deposited model type to load. ``on_test_model_type_start``
                passes the current ``self.model_type`` so each type is evaluated
                with its own file; training uses the default.
            *args, **kwargs: accepted for interface compatibility and ignored.
        '''
        model_file = self.deposition.load_model(self.model_date , self.model_num , model_type)
        self.booster : BoosterModel = self.model.new_model(training , model_file).model()
        self.metrics.new_model(self.model_param)

    def stack_model(self):
        '''Run the ``on_before_save_model`` hook, then stack every model type into the deposition.'''
        self.on_before_save_model()
        for model_type in self.model_types:
            model_dict = self.model.collect(model_type)
            self.deposition.stack_model(model_dict , self.model_date , self.model_num , model_type) 

    def save_model(self):
        '''Stack the current model, then dump every model type for this date/num.'''
        self.stack_model()
        for model_type in self.model_types:
            self.deposition.dump_model(self.model_date , self.model_num , model_type) 
    
    def __call__(self , input : BoosterInput):
        '''Not supported for boosters; always raises.'''
        raise Exception('Undefined call')



In [None]:

class NetTrainer(TrainerModule):
    '''Run the whole fit / test process for a torch network.

    Fitting is an epoch loop (train pass, then validation pass) driven by
    ``self.status`` until ``status.end_of_loop``. Testing evaluates every entry
    of ``self.model_types`` on the test dataloader.
    '''
    def init_data(self , **kwargs): 
        '''Build the net data module from ``self.config`` (kwargs are ignored).'''
        self.data : NetDataModule = NetDataModule(self.config)

    def batch_forward(self) -> None: 
        '''Run the module on the current batch and keep the output.'''
        self.batch_output = self(self.batch_data)

    def batch_metrics(self) -> None:
        '''Compute and collect metrics for the current batch; empty BatchData is skipped.'''
        if isinstance(self.batch_data , BatchData) and self.batch_data.is_empty: return
        self.metrics.calculate(self.status.dataset , self.batch_data, self.batch_output, self.net, assert_nan = True)
        self.metrics.collect_batch_metric()

    def batch_backward(self) -> None:
        '''Backpropagate ``metrics.output`` through the optimizer (train split only); empty BatchData is skipped.'''
        if isinstance(self.batch_data , BatchData) and self.batch_data.is_empty: return
        assert self.status.dataset == 'train' , self.status.dataset
        self.on_before_backward()
        self.optimizer.backward(self.metrics.output)
        self.on_after_backward()

    def fit_model(self):
        '''Fit the network: per epoch, one train pass then one validation pass, until ``status.end_of_loop``.

        The model is loaded once by ``on_fit_model_start`` and saved once by
        ``on_fit_model_end``.
        '''
        self.status.fit_model_start()
        self.on_fit_model_start()
        while not self.status.end_of_loop:
            self.status.fit_epoch_start()
            self.on_fit_epoch_start()

            self.status.dataset_train()
            # on_train_epoch_start puts the net in train mode and sets self.dataloader
            self.on_train_epoch_start()
            for self.batch_idx , self.batch_data in enumerate(self.dataloader):
                self.on_train_batch_start()
                self.on_train_batch()
                self.on_train_batch_end()
            self.on_train_epoch_end()

            self.status.dataset_validation()
            # on_validation_epoch_start switches to eval mode / no grad and swaps self.dataloader
            self.on_validation_epoch_start()
            for self.batch_idx , self.batch_data in enumerate(self.dataloader):
                self.on_validation_batch_start()
                self.on_validation_batch()
                self.on_validation_batch_end()
            self.on_validation_epoch_end()

            self.on_before_fit_epoch_end()
            self.status.fit_epoch_end()
            self.on_fit_epoch_end()
        self.on_fit_model_end()

    def test_model(self):
        '''Evaluate every model type in ``self.model_types`` on the test dataloader.'''
        self.on_test_model_start()
        self.status.dataset_test()
        # the loop target is the status attribute, so hooks see the current model type
        for self.status.model_type in self.model_types:
            self.on_test_model_type_start()
            for self.batch_idx , self.batch_data in enumerate(self.dataloader):
                self.on_test_batch_start()
                self.on_test_batch()
                self.on_test_batch_end()
            self.on_test_model_type_end()
        self.on_test_model_end()

    def on_configure_model(self):
        '''Apply the config's environment settings.'''
        self.config.set_config_environment()

    def on_fit_model_start(self):
        '''Set up fit data for the current model param/date and load a model for training.'''
        self.data.setup('fit' , self.model_param , self.model_date)
        self.load_model(True)

    def on_fit_model_end(self):
        '''Persist the fitted model.'''
        self.save_model()
    
    def on_train_epoch_start(self):
        '''Enter train mode with gradients on, select the train dataloader and open a 'train' metric record.'''
        self.net.train()
        torch.set_grad_enabled(True)
        self.dataloader = self.data.train_dataloader()
        self.metrics.new_epoch_metric('train' , self.status)
    
    def on_train_epoch_end(self): 
        '''Close the 'train' metric record and step the LR scheduler for this epoch.'''
        self.metrics.collect_epoch_metric('train')
        self.optimizer.scheduler_step(self.status.epoch)
    
    def on_validation_epoch_start(self):
        '''Enter eval mode with gradients off, select the validation dataloader and open a 'valid' metric record.'''
        self.net.eval()
        torch.set_grad_enabled(False)
        self.dataloader = self.data.val_dataloader()
        self.metrics.new_epoch_metric('valid' , self.status)
    
    def on_validation_epoch_end(self):
        '''Close the 'valid' metric record, let the model assess the epoch, and re-enable gradients.'''
        self.metrics.collect_epoch_metric('valid')
        self.model.assess(self.status.epoch , self.metrics)
        torch.set_grad_enabled(True)
    
    def on_test_model_start(self):
        '''Fit first if no deposited model exists for the current date/num/type, then set up test data with gradients off.'''
        if not self.deposition.exists(self.model_date , self.model_num , self.model_type): self.fit_model()
        self.data.setup('test' , self.model_param , self.model_date)
        torch.set_grad_enabled(False)

    def on_test_model_end(self):
        '''Re-enable gradients after testing.'''
        torch.set_grad_enabled(True)
    
    def on_test_model_type_start(self):
        '''Load the model for the current type, select the test dataloader and open a 'test' metric record.'''
        self.load_model(False , self.model_type)
        self.dataloader = self.data.test_dataloader()
        # one test batch per entry of self.batch_dates
        self.assert_equity(len(self.dataloader) , len(self.batch_dates))
        self.metrics.new_epoch_metric('test' , self.status)
    
    def on_test_model_type_end(self): 
        '''Close the 'test' metric record.'''
        self.metrics.collect_epoch_metric('test')

    def on_test_batch(self):
        '''Forward one test batch; metrics are computed only from ``batch_warm_up`` onward.'''
        self.assert_equity(self.batch_dates[self.batch_idx] , self.data.y_date[self.batch_data.i[0,1]]) 
        self.batch_forward()
        self.model.override()
        # before this is the warm-up stage: forward only, no metrics
        if self.batch_idx < self.batch_warm_up: return
        self.batch_metrics()

    def on_before_save_model(self):
        '''Move the net to CPU before it is collected and stacked.'''
        self.net = self.net.cpu()
    
    def load_model(self , training : bool , model_type = 'best' , lr_multiplier = 1.):
        '''Load a deposited model state and rebuild the net (and optimizer when training).

        Sets ``self.model``, ``self.net`` and ``self.transferred``; returns nothing.

        Args:
            training: when True, loads from ``prev_model_date`` if ``if_transfer`` is set
                (otherwise model date ``0``), builds a new ``Optimizer`` and resets the
                checkpoint. When False, loads ``model_type`` for ``model_date``, asserts
                the file exists and puts the net in eval mode.
            model_type: deposited model type to load.
            lr_multiplier: forwarded to the ``Optimizer`` when training.
        '''
        model_date = (self.prev_model_date if self.if_transfer else 0) if training else self.model_date
        model_file = self.deposition.load_model(model_date , self.model_num , model_type)
        # transferred only if transfer is requested and a previous model file actually exists
        self.transferred = training and self.if_transfer and model_file.exists()
        self.model = self.model.new_model(training , model_file)
        self.net : torch.nn.Module = self.model.model(model_file['state_dict'])
        self.metrics.new_model(self.model_param)
        if training:
            self.optimizer : Optimizer = Optimizer(self.net , self.config , self.transferred , lr_multiplier ,
                                                   model_module = self)
            self.checkpoint.new_model(self.model_param , self.model_date)
        else:
            assert model_file.exists() , str(model_file)
            self.net.eval()

    def stack_model(self):
        '''Run the ``on_before_save_model`` hook, then stack every model type into the deposition.'''
        self.on_before_save_model()
        for model_type in self.model_types:
            model_dict = self.model.collect(model_type)
            self.deposition.stack_model(model_dict , self.model_date , self.model_num , model_type) 

    def save_model(self):
        '''Stack this attempt if ``metrics.better_attempt`` accepts it, then dump every model type.'''
        if self.metrics.better_attempt(self.status.best_attempt_metric): self.stack_model()
        for model_type in self.model_types:
            self.deposition.dump_model(self.model_date , self.model_num , model_type)