# In [4]:
class Splitter:
    """Partition a dataset into train/validation/test parts and expose one DataLoader per part.

    The dataset indices are (optionally) shuffled once at construction time
    and then cut into three consecutive ranges. Each range is wrapped in a
    ``SequentialValueSampler``, so every loader always walks its indices in
    the same fixed order.

    Attributes set on the instance:
        dataset: the dataset passed in.
        batch_size: the batch size used by all three loaders.
        lengths: ``(train_length, valid_length, test_length)`` computed from
            the rates over the whole dataset, i.e. *before* ``dataset_share``
            is applied.
        train_data_loader, valid_data_loader, test_data_loader: the loaders.
    """

    def __init__(self, dataset, shuffle=True,
                 train_rate=0.8, valid_rate=0.1, dataset_share=1,
                 batch_size=2048, dataloader_workers_count=0):
        """Build the three index ranges and their data loaders.

        Args:
            dataset: object supporting ``len()`` that is handed to the
                DataLoaders.
            shuffle: when true, the indices are shuffled in place with
                ``np.random.shuffle`` (global NumPy random state) before
                they are partitioned.
            train_rate: fraction of the dataset assigned to training.
            valid_rate: fraction of the dataset assigned to validation; the
                remainder goes to the test part.
            dataset_share: fraction (0..1) of every part that is actually
                served by the loaders.
            batch_size: batch size for all loaders.
            dataloader_workers_count: ``num_workers`` for all loaders.

        Raises:
            ValueError: if a rate is negative, if the train and validation
                parts together exceed the dataset, or if ``dataset_share``
                is outside ``[0, 1]``.
        """
        dataset_length = len(dataset)

        if train_rate < 0 or valid_rate < 0:
            raise ValueError(
                f"train_rate and valid_rate must be non-negative, "
                f"got {train_rate} and {valid_rate}")

        if not 0 <= dataset_share <= 1:
            raise ValueError(
                f"dataset_share must be within [0, 1], got {dataset_share}")

        train_length = int(dataset_length * train_rate)
        valid_length = int(dataset_length * valid_rate)

        # Whatever is left after train and validation becomes the test part.
        test_length = dataset_length - train_length - valid_length

        # Checked on the integer lengths (not on the float rates) so that
        # sums such as 0.9 + 0.1 are never rejected by rounding noise.
        if test_length < 0:
            raise ValueError(
                f"train_rate + valid_rate ({train_rate} + {valid_rate}) "
                f"require more records than the dataset contains")

        dataset_indices = np.arange(dataset_length, dtype=int)

        if shuffle:
            np.random.shuffle(dataset_indices)

        # Cut points of the three consecutive ranges, each part scaled down
        # by dataset_share.
        train_end = int(train_length * dataset_share)
        valid_end = int(train_end + valid_length * dataset_share)
        test_end = int(valid_end + test_length * dataset_share)

        # np.split returns four pieces for three cut points; the last piece
        # holds the indices left out by dataset_share and is discarded.
        index_ranges = np.split(dataset_indices, (train_end, valid_end, test_end))[:3]

        train_sampler, valid_sampler, test_sampler = map(SequentialValueSampler, index_ranges)

        self.dataset = dataset
        self.batch_size = batch_size
        self.lengths = (train_length, valid_length, test_length)

        dataloader = torch.utils.data.DataLoader

        self.train_data_loader = dataloader(dataset,
                                            batch_size=batch_size,
                                            sampler=train_sampler,
                                            num_workers=dataloader_workers_count)

        self.valid_data_loader = dataloader(dataset,
                                            batch_size=batch_size,
                                            sampler=valid_sampler,
                                            num_workers=dataloader_workers_count)

        self.test_data_loader = dataloader(dataset,
                                           batch_size=batch_size,
                                           sampler=test_sampler,
                                           num_workers=dataloader_workers_count)

        
class SequentialValueSampler(torch.utils.data.Sampler):
    """Sampler that yields a fixed collection of values in their stored order."""

    def __init__(self, values):
        """Remember ``values``; they are served as-is on every iteration."""
        self.values = values

    def __len__(self):
        """Return how many values one pass over the sampler produces."""
        return len(self.values)

    def __iter__(self):
        """Iterate over the stored values from first to last."""
        yield from self.values

# In [1]:
class Trainer:
    """Fit a model on a splitter's train loader and score it on its other loaders.

    The splitter is expected to expose ``batch_size`` plus
    ``train_data_loader``, ``valid_data_loader`` and ``test_data_loader``;
    each loader must yield ``(fields, targets)`` pairs.
    """

    def __init__(self,
                 splitter,
                 model=None,
                 metric=None,
                 optimizer=None,
                 criterion=None,
                 learning_rate=0.001,
                 weight_decay=1e-6,
                 embedding_dimensions=16,
                 device='cpu'):
        """Assemble model, metric, optimizer and loss, filling in defaults.

        Args:
            splitter: provider of the data loaders (see class docstring).
            model: model to train. When omitted, an
                ``OneHotFactorizationMachine`` is built from
                ``splitter.dataset.field_dimensions`` and
                ``embedding_dimensions``.
            metric: callable ``metric(targets, predictions)`` returning a
                score; defaults to ``sklearn.metrics.r2_score``.
            optimizer: optimizer over the model parameters; defaults to
                Adam configured with ``learning_rate`` and ``weight_decay``.
            criterion: loss callable ``criterion(predictions, targets)``;
                defaults to ``torch.nn.MSELoss()``.
            learning_rate: learning rate of the default optimizer.
            weight_decay: weight decay of the default optimizer.
            embedding_dimensions: embedding size of the default model.
            device: device the model and every batch are moved to.
        """
        # Defaults are selected with explicit ``is None`` checks: container
        # modules such as torch.nn.Sequential define __len__, so relying on
        # truthiness could discard a model the caller passed in.
        if model is None:
            # field_dimensions is only needed for the default model, so a
            # custom model works with datasets that do not provide it.
            model = OneHotFactorizationMachine(field_dimensions=splitter.dataset.field_dimensions,
                                               embedding_dimensions=embedding_dimensions)

        if metric is None:
            metric = sklearn.metrics.r2_score

        if optimizer is None:
            optimizer = torch.optim.Adam(params=model.parameters(),
                                         lr=learning_rate,
                                         weight_decay=weight_decay)

        if criterion is None:
            criterion = torch.nn.MSELoss()

        self.epochs_passed = 0

        self.device = torch.device(device)

        self.model = model.to(self.device)
        self.metric = metric
        self.splitter = splitter
        self.optimizer = optimizer
        self.criterion = criterion

        # One validation score per epoch trained with validate=True.
        self.validation_scores = []

        # TODO: EarlyStopper

    def train(self,
              epochs=10,
              validate=True,
              disable_progressbar_printout=False):
        """Train for ``epochs`` epochs, then print train and test scores.

        Args:
            epochs: number of additional epochs to run; ``epochs_passed``
                keeps counting across calls.
            validate: when true, the validation loader is scored after every
                epoch and the score is appended to ``validation_scores``.
            disable_progressbar_printout: when true, the tqdm progress bar
                is not rendered.
        """
        batch_size = self.splitter.batch_size

        train_batch_count = len(self.splitter.train_data_loader)

        # The bar counts every batch as a full one, so the total (and the
        # per-batch update below) over-counts when a batch is smaller than
        # batch_size.
        records_count = epochs * train_batch_count * batch_size

        epochs_total = self.epochs_passed + epochs

        # The context manager closes the bar even if training raises, and
        # guarantees it is finished before the final scores are printed.
        with tqdm.trange(
            records_count,
            unit=' records',
            unit_scale=True,
            ncols=110,
            mininterval=1,
            disable=disable_progressbar_printout,
        ) as fit_batch_tracker:

            for _ in range(epochs):

                fit_batch_tracker.set_description(f"Epoch: {self.epochs_passed + 1}/{epochs_total}")

                # Train part
                for fields, targets in self.splitter.train_data_loader:
                    self.fit(fields, targets)
                    fit_batch_tracker.update(batch_size)

                # Validation part
                if validate:
                    _, validation_score = self.batch_predict(self.splitter.valid_data_loader)

                    fit_batch_tracker.set_postfix(v=f"{validation_score:.02f}")

                    self.validation_scores.append(validation_score)

                self.epochs_passed += 1

        _, train_score = self.batch_predict(self.splitter.train_data_loader)
        _, test_score = self.batch_predict(self.splitter.test_data_loader)

        # Not every callable has __name__ (e.g. functools.partial objects);
        # fall back to the type name instead of failing after training.
        metric_name = getattr(self.metric, '__name__', type(self.metric).__name__)

        print(f"Train {metric_name}: {train_score:.02f}")
        print(f"Test  {metric_name}: {test_score:.02f}")

    def fit(self, fields, targets):
        """Run one optimization step on a single batch.

        Args:
            fields: model input for the batch.
            targets: expected outputs; converted with ``.float()`` before
                the loss is computed.

        Returns:
            The batch loss as a Python float.
        """
        self.model.train()

        # The model lives on self.device, so the batch has to follow it.
        fields = self._to_device(fields)
        targets = self._to_device(targets)

        predictions = self.model(fields)

        loss = self.criterion(predictions, targets.float())

        self.model.zero_grad()

        loss.backward()

        self.optimizer.step()

        return loss.item()

    def batch_predict(self, data_loader):
        """Predict every batch of ``data_loader`` and score the result.

        Args:
            data_loader: iterable of ``(fields, target)`` pairs where
                ``target`` supports ``.tolist()``.

        Returns:
            Tuple ``(predictions, score)``: the predictions of all batches
            concatenated into one list, and ``metric(targets, predictions)``.
        """
        targets = []
        predictions = []

        for fields, target in data_loader:

            prediction = self.predict(fields)

            targets.extend(target.tolist())
            predictions.extend(prediction.tolist())

        score = self.metric(targets, predictions)

        return predictions, score

    def predict(self, fields):
        """Return the model output for ``fields`` as a NumPy array.

        Args:
            fields: a tensor, or anything ``torch.tensor`` accepts.

        Returns:
            ``numpy.ndarray`` with the model predictions, computed in eval
            mode with gradients disabled.
        """
        self.model.eval()

        with torch.no_grad():

            if not isinstance(fields, torch.Tensor):
                fields = torch.tensor(fields)

            outputs = self.model(fields.to(self.device))

            # .numpy() only works on CPU tensors, so bring the result back
            # from the compute device first.
            predictions = outputs.cpu().numpy()

        return predictions

    def _to_device(self, value):
        """Move ``value`` to the trainer device when it is a tensor; otherwise return it unchanged."""
        if isinstance(value, torch.Tensor):
            return value.to(self.device)

        return value