In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F

class DifferentiableDecisionTree(nn.Module):
    """Additive stack of per-feature threshold rules over ``max_depth`` levels.

    At every level ``d`` each input feature is compared with a learnable
    threshold. Features below their threshold contribute through
    ``left_weights[d]``, the others through ``right_weights[d]``; the
    contributions of all levels are summed into ``output_dim`` scores
    (used as logits for ``nn.CrossEntropyLoss`` in this notebook).

    Args:
        input_dim: number of features per sample (last dimension of ``x``).
        output_dim: number of output scores per sample.
        max_depth: number of stacked rule levels.
        temperature: ``None`` (default) uses the hard rule ``x < threshold``.
            A comparison is not differentiable, so in that mode
            ``feature_threshold`` never receives a gradient. A positive float
            replaces the rule with ``sigmoid((threshold - x) / temperature)``
            so the thresholds are trained as well; smaller values approach
            the hard rule.

    Raises:
        ValueError: if ``temperature`` is given but not positive.
    """

    def __init__(self, input_dim, output_dim, max_depth=3, temperature=None):
        super().__init__()
        if temperature is not None and temperature <= 0:
            raise ValueError(f"temperature must be positive or None, got {temperature}")
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.max_depth = max_depth
        self.temperature = temperature

        # One threshold per (level, feature); one (output_dim, input_dim)
        # weight matrix per level for each branch.
        self.feature_threshold = nn.Parameter(torch.randn(max_depth, input_dim))
        self.left_weights = nn.Parameter(torch.randn(max_depth, output_dim, input_dim))
        self.right_weights = nn.Parameter(torch.randn(max_depth, output_dim, input_dim))

    def forward(self, x):
        """Return summed rule scores for ``x`` of shape ``(batch, input_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim)`` with the parameters'
            device and dtype. With ``max_depth == 0`` it is all zeros.
        """
        # Allocate on the parameters' device/dtype (torch.zeros would always
        # be CPU float32) so the model still works after .to(device)/.double().
        leaf_values = self.left_weights.new_zeros(x.shape[0], self.output_dim)

        for d in range(self.max_depth):
            # gate[b, i] is 1 (or close to 1) when feature i of sample b
            # falls below its level-d threshold, i.e. takes the left branch.
            gate = self._gate(x, self.feature_threshold[d])
            # (batch, input_dim) @ (input_dim, output_dim) -> (batch, output_dim);
            # equivalent to broadcasting gate against the weights and summing
            # over the feature axis, without the (batch, out, in) intermediate.
            left_values = gate @ self.left_weights[d].T
            right_values = (1 - gate) @ self.right_weights[d].T
            leaf_values = leaf_values + left_values + right_values

        return leaf_values

    def _gate(self, x, threshold):
        """Left-branch membership per feature: hard 0/1, or sigmoid-soft."""
        if self.temperature is None:
            gate = x < threshold
        else:
            gate = torch.sigmoid((threshold - x) / self.temperature)
        # Match the weights' dtype so the matmuls never mix dtypes.
        return gate.to(self.left_weights.dtype)



In [21]:
# Synthetic binary-classification data: uniform features, random labels.
# Labels are drawn independently of X, so there is no real signal to learn;
# accuracy near chance (~0.5) is the expected result.
SEED = 42
N_SAMPLES, N_FEATURES, N_CLASSES = 1000, 4, 2
LEARNING_RATE = 0.01

# Seed both RNGs: numpy draws the data, torch draws the model's initial
# parameters (torch.randn in DifferentiableDecisionTree.__init__).
np.random.seed(SEED)
torch.manual_seed(SEED)

X = np.random.rand(N_SAMPLES, N_FEATURES)
y = np.random.randint(0, N_CLASSES, size=N_SAMPLES)

# Convert data to PyTorch tensors: float32 features to match the model's
# parameters, int64 (torch.long) class indices as CrossEntropyLoss targets.
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)

# Define model, loss function, and optimizer
model = DifferentiableDecisionTree(input_dim=N_FEATURES, output_dim=N_CLASSES)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)



In [24]:

# Training loop: full-batch gradient descent on (X_tensor, y_tensor).
# NOTE: re-running this cell keeps training the existing `model`; re-run the
# setup cell first to start from fresh parameters.
num_epochs = 1000
LOG_EVERY = 100

for epoch in range(num_epochs):
    # Forward pass
    outputs = model(X_tensor)
    loss = criterion(outputs, y_tensor)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % LOG_EVERY == 0:
        # Score the same logits the loss came from, so the logged loss and
        # accuracy describe the same (pre-step) parameters and no second
        # forward pass is needed.
        predicted = outputs.detach().argmax(dim=1)
        accuracy = (predicted == y_tensor).float().mean().item()
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f} Accuracy: {accuracy:.2f}')

Epoch [100/1000], Loss: 0.7491 Accuracy: 0.49
Epoch [200/1000], Loss: 0.7268 Accuracy: 0.49
Epoch [300/1000], Loss: 0.7139 Accuracy: 0.53
Epoch [400/1000], Loss: 0.7065 Accuracy: 0.53
Epoch [500/1000], Loss: 0.7022 Accuracy: 0.52
Epoch [600/1000], Loss: 0.6995 Accuracy: 0.52
Epoch [700/1000], Loss: 0.6978 Accuracy: 0.52
Epoch [800/1000], Loss: 0.6967 Accuracy: 0.52
Epoch [900/1000], Loss: 0.6958 Accuracy: 0.52
Epoch [1000/1000], Loss: 0.6951 Accuracy: 0.52
