In [4]:
# Define model
import torch
import torch.nn as nn

# Presumably the number of input features per sample; left at 0 in this cell
# and not referenced by the model classes below — TODO confirm where it is set/used.
NUM_FEATURES: int = 0


# Network that takes in all the input features and outputs a vector of qualities
class QualityPredictor(nn.Module):
    """Fully connected ReLU network mapping input features to quality scores.

    Layout: ``input_layer`` (input_size -> hidden_layer_sizes[0]), then one
    ``nn.Linear`` for every consecutive pair of widths in ``hidden_layer_sizes``,
    then ``output_layer`` (hidden_layer_sizes[-1] -> output_size). A ReLU is
    applied after every layer except ``output_layer``, whose raw result is
    returned.

    Args:
        input_size: Width of the last dimension of the input tensor.
        output_size: Number of quality values produced.
        hidden_layer_sizes: Non-empty sequence of hidden widths; an empty
            sequence raises ``IndexError``.
    """

    def __init__(self, input_size, output_size, hidden_layer_sizes):
        super().__init__()

        first_width = hidden_layer_sizes[0]
        self.input_layer = nn.Linear(input_size, first_width)

        # One Linear per adjacent (in_width, out_width) pair of hidden sizes;
        # a single hidden width yields no intermediate layers.
        self.hidden_layers = nn.ModuleList(
            [
                nn.Linear(width_in, width_out)
                for width_in, width_out in zip(hidden_layer_sizes[:-1], hidden_layer_sizes[1:])
            ]
        )

        last_width = hidden_layer_sizes[-1]
        self.output_layer = nn.Linear(last_width, output_size)

    def forward(self, x):
        activated = x
        for layer in (self.input_layer, *self.hidden_layers):
            activated = torch.relu(layer(activated))
        return self.output_layer(activated)


class SatisfactionPredictor(nn.Module):
    """Fully connected ReLU network producing a single satisfaction value.

    Layout: ``input_layer`` (input_size -> hidden_layer_sizes[0]), one
    ``nn.Linear`` per consecutive pair of hidden widths, then ``output_layer``
    (hidden_layer_sizes[-1] -> 1). A ReLU follows every layer except
    ``output_layer``, whose raw result is returned.

    Args:
        hidden_layer_sizes: Non-empty sequence of hidden widths; an empty
            sequence raises ``IndexError``.
        input_size: Width of the last input dimension (keyword-only). Defaults
            to 2, the value that used to be hard-coded, so existing callers
            build exactly the same network.
    """

    def __init__(self, hidden_layer_sizes, *, input_size=2):
        super().__init__()

        self.input_layer = nn.Linear(input_size, hidden_layer_sizes[0])

        # One Linear per adjacent pair of hidden widths (none for a single width).
        self.hidden_layers = nn.ModuleList(
            [
                nn.Linear(hidden_layer_sizes[i - 1], hidden_layer_sizes[i])
                for i in range(1, len(hidden_layer_sizes))
            ]
        )
        self.output_layer = nn.Linear(hidden_layer_sizes[-1], 1)

    def forward(self, x):
        x = torch.relu(self.input_layer(x))
        for hidden_layer in self.hidden_layers:
            x = torch.relu(hidden_layer(x))
        return self.output_layer(x)


class CausalModel(nn.Module):
    """Chain a quality predictor into one satisfaction predictor per quality score.

    ``forward`` runs ``x`` through ``quality_predictor_net``, applies
    ``activation_on_quality``, then feeds ``qualities[i]`` to the i-th
    satisfaction predictor and collects the results into a tensor of shape
    ``(num_quality_scores, 1)``.

    Args:
        input_size: Input width passed to the ``QualityPredictor``.
        num_quality_scores: Output width of the ``QualityPredictor`` and the
            number of ``SatisfactionPredictor`` sub-networks created.
        quality_predictor_hidden_layer_sizes: Hidden widths for the
            ``QualityPredictor``.
        satisfaction_predictor_hidden_layer_sizes: Hidden widths used for every
            ``SatisfactionPredictor``.
        activation_on_quality: Callable applied to the raw quality outputs.
    """

    def __init__(self, input_size, num_quality_scores, quality_predictor_hidden_layer_sizes,
                 satisfaction_predictor_hidden_layer_sizes, activation_on_quality):
        super().__init__()
        self.num_quality_scores = num_quality_scores

        self.quality_predictor_net = QualityPredictor(input_size, num_quality_scores,
                                                      quality_predictor_hidden_layer_sizes)

        self.activation_on_quality = activation_on_quality

        # Must be an nn.ModuleList, not a plain list: only registered submodules
        # show up in parameters()/state_dict() (so optimizers train them and
        # checkpoints save them) and are moved by .to()/.cuda().
        self.satisfaction_predictors = nn.ModuleList(
            [
                SatisfactionPredictor(satisfaction_predictor_hidden_layer_sizes)
                for _ in range(num_quality_scores)
            ]
        )

    def forward(self, x):
        qualities = self.quality_predictor_net(x)
        qualities = self.activation_on_quality(qualities)

        # Allocate the result on the same device and dtype as the predictions so
        # the model keeps working after .to(device) or .double().
        satisfactions = qualities.new_zeros(self.num_quality_scores, 1)
        # NOTE(review): qualities[i] indexes dim 0 of the activated predictions,
        # while each SatisfactionPredictor is built with a 2-feature input layer.
        # This therefore only runs when qualities[i] ends in a size-2 dimension
        # (e.g. 2-D x with num_quality_scores == 2, where i would walk the rows of
        # x rather than the quality columns). Confirm the intended indexing — TODO.
        for i, sat_predictor in enumerate(self.satisfaction_predictors):
            satisfactions[i] = sat_predictor(qualities[i])

        return satisfactions