In [2]:
class PortfolioDataset(torch.utils.data.Dataset):
    """Map-style dataset backed by a CSV file of portfolio selections.

    The CSV is loaded into ``self.data`` (a pandas DataFrame). Every column
    except the last four is cast to ``int`` and exposed as ``self.fields``;
    the last column is exposed as ``self.targets``.

    ``self.field_dimensions`` holds, per field column, the maximum observed
    value plus one, with a lower bound of 2.

    ``plot`` additionally expects the columns ``return``, ``risk`` and
    ``sharpe`` to be present in the CSV.
    """

    def __init__(self, dataset_path):
        """Load the CSV at ``dataset_path`` (anything ``pd.read_csv`` accepts)."""
        self.data = pd.read_csv(dataset_path)

        data = self.data.to_numpy()

        # The trailing four columns are not categorical inputs; the very
        # last one is the regression target.
        self.fields = data[:, :-4].astype(int)
        self.targets = data[:, -1]

        # Cardinality of each field: largest index seen + 1.
        self.field_dimensions = np.max(self.fields, axis=0).astype(int) + 1

        # Constant-zero columns would yield a cardinality of 1; force at
        # least 2 (presumably so every field can encode a binary choice --
        # confirm against the model that consumes these dimensions).
        self.field_dimensions[self.field_dimensions < 2] = 2

    def __len__(self):
        """Return the number of rows (samples) in the dataset."""
        return self.fields.shape[0]

    def __getitem__(self, index):
        """Return the ``(fields, target)`` pair stored at ``index``."""
        fields = self.fields[index]
        target = self.targets[index].squeeze()

        return fields, target

    def _top_sharpe_row(self):
        """Return the row of ``self.data`` with the highest ``sharpe`` value."""
        # idxmax() returns an index *label*, so it must be resolved with
        # .loc; .iloc would only be correct for a default RangeIndex.
        return self.data.loc[self.data['sharpe'].idxmax()]

    def plot(self):
        """Scatter-plot return against risk, coloured by Sharpe ratio.

        The row with the highest Sharpe ratio is highlighted in red.
        """
        top_sharpe_row = self._top_sharpe_row()

        top_return = top_sharpe_row['return']
        top_risk = top_sharpe_row['risk']

        plt.figure(figsize=(12, 8))

        plt.title('Markowitz portfolio (Combinations of all portfolio selections)')
        plt.xlabel('Volatility - standard deviation')
        plt.ylabel('Return')

        plt.scatter(self.data.risk, self.data['return'], c=self.data.sharpe, cmap='viridis')
        plt.scatter(top_risk, top_return, c='red', s=50, marker=5)
        plt.colorbar(label='Sharpe Ratio')

        plt.show()

In [4]:
class Splitter:
    """Partition a dataset into train / validation / test data loaders.

    Row indices are (optionally) shuffled once with ``np.random.shuffle``
    and then cut into three consecutive ranges. ``train_rate`` and
    ``valid_rate`` give the fractions of the dataset used for training and
    validation; whatever remains is the test share. ``dataset_share``
    scales all three partition sizes by the same factor.

    Attributes set on the instance:
        dataset: the dataset that was passed in.
        lengths: tuple ``(train_end, valid_end, test_end)`` of cumulative
            end offsets into the index array (not per-partition sizes).
        train_data_loader, valid_data_loader, test_data_loader:
            ``DataLoader`` objects that iterate each partition in its fixed
            index order.
    """

    def __init__(self, dataset, shuffle=True,
                 train_rate=0.8, valid_rate=0.1, dataset_share=1,
                 batch_size=2048, dataloader_workers_count=8):

        self.dataset = dataset

        total = len(dataset)

        # Full-size partition lengths before dataset_share is applied.
        full_train = int(total * train_rate)
        full_valid = int(total * valid_rate)
        full_test = total - full_train - full_valid

        order = np.arange(total, dtype=int)
        if shuffle:
            np.random.shuffle(order)

        # Cumulative cut points after scaling each partition.
        train_end = int(full_train * dataset_share)
        valid_end = int(train_end + full_valid * dataset_share)
        test_end = int(valid_end + full_test * dataset_share)

        self.lengths = (train_end, valid_end, test_end)

        # np.split yields four pieces; the trailing one holds the indices
        # left over when dataset_share < 1 and is discarded.
        pieces = np.split(order, self.lengths)
        train_ids, valid_ids, test_ids = pieces[0], pieces[1], pieces[2]

        def build_loader(indices):
            # Every partition shares the same loader configuration.
            return DataLoader(dataset,
                              batch_size=batch_size,
                              sampler=SequentialValueSampler(indices),
                              num_workers=dataloader_workers_count)

        self.train_data_loader = build_loader(train_ids)
        self.valid_data_loader = build_loader(valid_ids)
        self.test_data_loader = build_loader(test_ids)

        
class SequentialValueSampler(torch.utils.data.Sampler):
    """Sampler that yields a fixed collection of indices in stored order."""

    def __init__(self, values):
        # Kept by reference; no copy is made.
        self.values = values

    def __len__(self):
        """Return how many indices this sampler yields."""
        return len(self.values)

    def __iter__(self):
        """Iterate over the stored indices, first to last."""
        yield from self.values

In [8]:
class Trainer:
    """Train and evaluate a model on the loaders provided by a splitter.

    Args:
        splitter: object exposing ``dataset``, ``train_data_loader``,
            ``valid_data_loader`` and ``test_data_loader``.
        learning_rate: learning rate of the default Adam optimizer.
        weight_decay: weight decay of the default Adam optimizer.
        embedding_dimensions: forwarded to the default model.
        model: model to train; when ``None`` a
            ``CustomFactorizationMachine`` is built from
            ``splitter.dataset.field_dimensions``.
        criterion: loss callable; defaults to ``torch.nn.MSELoss()``.
        metric: ``metric(targets, predictions)`` scoring callable;
            defaults to ``sklearn.metrics.r2_score``.
        optimizer: optimizer to use; when ``None`` an Adam optimizer over
            the model parameters is created from ``learning_rate`` and
            ``weight_decay``.
        device: device the model and every batch are moved to.
        batch_logging_interval: number of batches between progress-bar
            refreshes.
    """

    def __init__(self, splitter,
                 learning_rate=0.001, weight_decay=1e-6,
                 embedding_dimensions=16,
                 model=None, criterion=None, metric=None, optimizer=None,
                 device='cpu', batch_logging_interval=100):

        self.device = torch.device(device)

        # Explicit None checks: truthiness is unreliable for modules that
        # define __len__ (an empty nn.Sequential is falsy).
        if criterion is None:
            criterion = torch.nn.MSELoss()

        if metric is None:
            metric = sklearn.metrics.r2_score

        if model is None:
            model = CustomFactorizationMachine(
                splitter.dataset.field_dimensions,
                embedding_dimensions=embedding_dimensions)

        # Move the model first so a default optimizer is built over the
        # parameters as they live on the target device.
        model = model.to(self.device)

        # Respect a caller-supplied optimizer instead of overwriting it.
        if optimizer is None:
            optimizer = torch.optim.Adam(params=model.parameters(),
                                         lr=learning_rate,
                                         weight_decay=weight_decay)

        self.model = model
        self.criterion = criterion
        self.metric = metric
        self.optimizer = optimizer
        self.splitter = splitter

        self.batch_logging_interval = batch_logging_interval

        self.train_data_loader = self.splitter.train_data_loader
        self.valid_data_loader = self.splitter.valid_data_loader
        self.test_data_loader = self.splitter.test_data_loader

        # One entry per completed epoch, accumulated across fit() calls.
        self.epoch_scores = []

        # TODO: EarlyStopper

    def fit(self, epochs=10,
            validate=True, test=True,
            disable_progressbar_printout=False):
        """Run ``epochs`` training epochs, continuing from earlier calls.

        Args:
            epochs: number of additional epochs to train.
            validate: score the validation loader after every epoch and
                record the result in ``self.epoch_scores``. When false, 0
                is recorded for the epoch.
            test: score the test loader once training ends and print it.
            disable_progressbar_printout: suppress the tqdm progress bar.
        """
        epoch_offset = len(self.epoch_scores)
        max_epochs = epoch_offset + epochs

        batches_count = epochs * len(self.train_data_loader)

        fit_batch_tracker = tqdm.trange(
            batches_count,
            unit=' batches',
            ncols=110,
            mininterval=1,
            disable=disable_progressbar_printout,
        )

        try:
            for epoch in range(epochs):

                current_epoch = epoch_offset + epoch + 1

                fit_batch_tracker.set_description(f"Epoch: {current_epoch}/{max_epochs}")

                self.train(fit_batch_tracker)

                validation_score = 0

                if validate:
                    validation_score, _ = self.test(self.valid_data_loader)
                    fit_batch_tracker.set_postfix(r2=f"{validation_score:.02f}")

                self.epoch_scores.append(validation_score)
        finally:
            # Always release the bar, even if training is interrupted.
            fit_batch_tracker.close()

        if test:
            total_score, _ = self.test(self.test_data_loader)
            # Not every callable carries __name__ (e.g. functools.partial).
            metric_name = getattr(self.metric, '__name__', repr(self.metric))
            print(f"Test {metric_name}: {total_score:.05f}")

    def train(self, fit_batch_tracker):
        """Train for one pass over the training loader.

        Args:
            fit_batch_tracker: progress object whose ``update(n)`` is
                called with the number of batches processed since the
                previous call.

        Returns:
            Mean training loss over the processed batches, or ``0.0`` when
            the loader produced no batches.
        """
        self.model.train()

        total_loss = 0.0
        batch_count = 0
        pending_updates = 0

        for fields, target in self.train_data_loader:

            fields = fields.to(self.device)
            target = target.to(self.device)

            predictions = self.model(fields)

            loss = self.criterion(predictions, target.float())

            self.optimizer.zero_grad()

            loss.backward()

            self.optimizer.step()

            total_loss += loss.item()
            batch_count += 1
            pending_updates += 1

            # Refresh the bar in chunks to keep per-batch overhead low.
            if pending_updates >= self.batch_logging_interval:
                fit_batch_tracker.update(pending_updates)
                pending_updates = 0

        # Account for the batches left over after the last full interval
        # so the bar reaches its total.
        if pending_updates:
            fit_batch_tracker.update(pending_updates)

        return total_loss / batch_count if batch_count else 0.0

    def test(self, data_loader):
        """Score the model on every batch of ``data_loader``.

        Returns:
            Tuple ``(score, predictions)`` where ``score`` is
            ``self.metric(targets, predictions)`` and ``predictions`` is
            the list of model outputs collected over all batches.
        """
        self.model.eval()

        targets = []
        predictions = []

        with torch.no_grad():

            for fields, target in data_loader:

                prediction = self.model(fields.to(self.device))

                targets.extend(target.tolist())
                predictions.extend(prediction.tolist())

        score = self.metric(targets, predictions)

        return score, predictions

    def predict(self, indices):
        """Predict for the dataset rows selected by ``indices``.

        Returns:
            Tuple ``(targets, predictions)``: the stored targets for those
            rows and the model output tensor (on ``self.device``).
        """
        encoded_fields = self.splitter.dataset.fields[indices]
        targets = self.splitter.dataset.targets[indices]

        self.model.eval()

        with torch.no_grad():

            predictions = self.model(torch.tensor(encoded_fields, device=self.device))

        return targets, predictions