In [None]:
%load_ext lab_black

In [None]:
from pathlib import Path
from typing import Optional

from liltab.data.datasets import PandasDataset
from liltab.data.dataloaders import FewShotDataLoader, ComposedDataLoader

In [None]:
# Episode sources: only ICU is active, the other OpenML files are kept as
# commented-out alternatives.
paths = [
    Path("data") / "openml_original" / "ICU.csv",
    # Path("data/openml_original/autoPrice.csv"),
    # Path("data/openml_original/fri_c3_250_10.csv"),
]
pd_datasets = list(map(PandasDataset, paths))
# NOTE(review): 5 / 27 / 100 presumably are support size, query size and episode
# count (later cells reshape by 5 support rows and 27 query rows) — confirm
# against FewShotDataLoader's signature.
few_shot_datasets = [FewShotDataLoader(dataset, 5, 27, 100) for dataset in pd_datasets]
composed_few_shot_dataset = ComposedDataLoader(
    dataloaders=few_shot_datasets,
    n_episodes=100,
)

In [None]:
from torch import nn, Tensor
from typing import Callable


class NetworkBlock(nn.Module):
    """One dense layer: ``Linear -> Dropout -> activation``, applied in that order.

    Args:
        input_size: Size of the last dimension of the input.
        output_size: Size of the last dimension of the output.
        activation_function: Module appended after the dropout (placed in an
            ``nn.Sequential``, so it must be an ``nn.Module``).
        dropput_rate: Dropout probability (parameter name kept as-is for callers).
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        activation_function: Callable,
        dropput_rate: float,
    ):
        super().__init__()
        linear = nn.Linear(input_size, output_size)
        dropout = nn.Dropout(dropput_rate)
        # Kept as a single Sequential named ``block`` so state_dict keys stay
        # ``block.0.weight`` / ``block.0.bias``.
        self.block = nn.Sequential(linear, dropout, activation_function)

    def forward(self, X: Tensor) -> Tensor:
        """Run the input through the block and return the activated output."""
        out = self.block(X)
        return out


class FeedForwardNetwork(nn.Module):
    """Multi-layer perceptron assembled from ``NetworkBlock`` layers.

    Layout: an input block (``input_size -> hidden_size``), ``n_hidden_layers``
    blocks (``hidden_size -> hidden_size``) and an output block
    (``hidden_size -> output_size``), i.e. ``n_hidden_layers + 2`` linear layers.

    Args:
        input_size: Size of the last dimension of the input.
        output_size: Size of the last dimension of the output.
        n_hidden_layers: Number of ``hidden_size -> hidden_size`` blocks between
            the input and output blocks.
        hidden_size: Width of the input and hidden blocks' outputs.
        dropout_rate: Dropout probability of the input and hidden blocks.
        inner_activation_function: Activation module for the input and hidden
            blocks (one instance reused by all of them). A fresh ``nn.ReLU()``
            is created when ``None``.
        output_activation_function: Activation module of the output block. A
            fresh ``nn.Identity()`` is created when ``None``.
        dropout_on_output: If True, the output block also applies dropout with
            ``dropout_rate`` (the former behaviour). Defaults to False so the
            network's outputs are not randomly zeroed/rescaled in training.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        n_hidden_layers: int,
        hidden_size: int,
        dropout_rate: float,
        inner_activation_function: Optional[Callable] = None,
        output_activation_function: Optional[Callable] = None,
        dropout_on_output: bool = False,
    ):
        super().__init__()
        # Build default activations per instance: a module used as a default
        # argument is created once at definition time and would be shared by
        # (and registered as a submodule of) every FeedForwardNetwork.
        if inner_activation_function is None:
            inner_activation_function = nn.ReLU()
        if output_activation_function is None:
            output_activation_function = nn.Identity()

        self.input_size = input_size
        self.output_size = output_size
        self.n_hidden_layers = n_hidden_layers
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate
        self.inner_activation_function = inner_activation_function
        self.output_activation_function = output_activation_function
        self.dropout_on_output = dropout_on_output

        self.input_layer = NetworkBlock(
            self.input_size,
            self.hidden_size,
            self.inner_activation_function,
            self.dropout_rate,
        )
        self.hidden_layers = nn.ModuleList(
            [
                NetworkBlock(
                    self.hidden_size,
                    self.hidden_size,
                    self.inner_activation_function,
                    self.dropout_rate,
                )
                for _ in range(self.n_hidden_layers)
            ]
        )
        # Dropout after the final linear layer would perturb the predictions
        # themselves, so it is disabled (p=0) unless explicitly requested.
        self.output_layer = NetworkBlock(
            self.hidden_size,
            self.output_size,
            self.output_activation_function,
            self.dropout_rate if self.dropout_on_output else 0.0,
        )

    def forward(self, X: Tensor) -> Tensor:
        """Apply the input block, every hidden block, then the output block."""
        X = self.input_layer(X)
        for layer in self.hidden_layers:
            X = layer(X)
        return self.output_layer(X)

In [None]:
# Peek at the support attributes (first element) of one sampled episode.
X_support = next(composed_few_shot_dataset)[0]
X_support.size()

In [None]:
# Smoke-test a single block. The input width is read from the sampled data
# instead of the hard-coded 19, so other datasets work unchanged.
block = NetworkBlock(X_support.shape[-1], 32, nn.ReLU(), 0.1)
block(X_support).shape

In [None]:
# Smoke-test the MLP (20 outputs, 16 hidden layers of width 32). The input width
# is read from the sampled data instead of the hard-coded 19.
network = FeedForwardNetwork(X_support.shape[-1], 20, 16, 32, 0.1)
network(X_support).shape

In [None]:
import torch

torch.manual_seed(0)


class HeterogenousInferenceNetwork(nn.Module):
    """Support-set encoder for few-shot tasks whose attribute spaces may differ.

    Each column of the support set (an attribute of ``X`` or a response of
    ``y``) is processed one scalar at a time, so the same weights handle tasks
    with any number of attributes or responses. Every sub-network is a
    ``FeedForwardNetwork`` with 3 hidden layers of width 32 and dropout 0.1.
    """

    def __init__(self):
        super().__init__()
        # Scalar value (input size 1) -> 32-dim embedding; pooled per column later.
        self.initial_support_encoding_network = FeedForwardNetwork(1, 32, 3, 32, 0.1)
        self.initial_support_representation_network = FeedForwardNetwork(
            32, 32, 3, 32, 0.1
        )

        # Encoding networks take a 32-dim representation concatenated with one
        # scalar value, hence the input size 32 + 1 = 33.
        self.interaction_encoding_network = FeedForwardNetwork(33, 32, 3, 32, 0.1)
        self.interaction_representation_network = FeedForwardNetwork(32, 32, 3, 32, 0.1)

        self.attributes_encoding_network = FeedForwardNetwork(33, 32, 3, 32, 0.1)
        self.attributes_representation_network = FeedForwardNetwork(32, 32, 3, 32, 0.1)

        self.responses_encoding_network = FeedForwardNetwork(33, 32, 3, 32, 0.1)
        self.responses_representation_network = FeedForwardNetwork(32, 32, 3, 32, 0.1)

    def forward(self, X_support: Tensor, y_support: Tensor, X_query: Tensor) -> Tensor:
        """Encode the support set into one vector per support example.

        Args:
            X_support: Support attributes, 2-D ``(n_examples, n_attributes)``.
            y_support: Support responses, 2-D ``(n_examples, n_responses)``.
            X_query: Query attributes. Accepted for the prediction step but not
                used by this method yet.

        Returns:
            Tensor of shape ``(n_examples, 32)``.
        """
        # Previously referenced non-existent ``initial_encoding_network`` /
        # ``initial_representation_network`` attributes (AttributeError).
        attributes_initial_representation = self.create_initial_features_representation(
            self.initial_support_encoding_network,
            self.initial_support_representation_network,
            X_support,
        )
        responses_initial_representation = self.create_initial_features_representation(
            self.initial_support_encoding_network,
            self.initial_support_representation_network,
            y_support,
        )
        return self.create_support_set_representation(
            self.interaction_encoding_network,
            self.interaction_representation_network,
            X_support,
            attributes_initial_representation,
            y_support,
            responses_initial_representation,
        )

    def create_initial_features_representation(
        self,
        encoder_network: "FeedForwardNetwork",
        representation_network: "FeedForwardNetwork",
        X: Tensor,
    ) -> Tensor:
        """Build one representation per column of ``X``.

        Every entry of ``X`` is embedded independently by ``encoder_network``.
        The embeddings are averaged over the first axis (rows) and passed
        through ``representation_network``.

        Returns:
            For a 2-D ``X``, a tensor ``(n_columns, representation_size)``.
        """
        initial_tensor_shape = X.shape
        # Treat every entry as its own 1-feature sample.
        encoded = encoder_network(X.reshape(-1, 1))
        encoded = encoded.reshape(*initial_tensor_shape, -1)
        return representation_network(encoded.mean(dim=0))

    def create_support_set_representation(
        self,
        interaction_encoding_network: "FeedForwardNetwork",
        interaction_representation_network: "FeedForwardNetwork",
        X: Tensor,
        attributes_initial_representation: Tensor,
        y: Tensor,
        responses_initial_representation: Tensor,
    ) -> Tensor:
        """Combine attribute and response interactions into one vector per example.

        Args:
            X: Support attributes ``(n_examples, n_attributes)``.
            attributes_initial_representation: ``(n_attributes, repr_size)``.
            y: Support responses ``(n_examples, n_responses)``.
            responses_initial_representation: ``(n_responses, repr_size)``.

        Returns:
            ``interaction_representation_network`` applied to the sum of the
            attribute and response encodings, one row per example.
        """
        # Sizes come from the inputs (formerly hard-coded to 5 examples and
        # 19 attributes, and only one response column worked).
        attributes_encoded = self._encode_column_interactions(
            interaction_encoding_network, X, attributes_initial_representation
        )
        responses_encoded = self._encode_column_interactions(
            interaction_encoding_network, y, responses_initial_representation
        )
        return interaction_representation_network(
            attributes_encoded + responses_encoded
        )

    def create_features_representation(
        self,
        features_encoding_network: "FeedForwardNetwork",
        features_representation_network: "FeedForwardNetwork",
        set_: Tensor,
        set_representation: Tensor,
    ):
        """Build one representation per column of ``set_`` from the example vectors.

        Args:
            set_: 2-D ``(n_examples, n_features)``.
            set_representation: ``(n_examples, repr_size)``, one vector per example.

        Returns:
            Tensor ``(n_features, output_size)``: for each feature, the pairs
            (example vector, value) are encoded, averaged over examples, and
            passed through ``features_representation_network``.
        """
        # Transposed so each feature is a row and the examples are the columns.
        encoded = self._encode_column_interactions(
            features_encoding_network, set_.T, set_representation
        )
        return features_representation_network(encoded)

    @staticmethod
    def _encode_column_interactions(
        encoding_network: nn.Module, values: Tensor, column_representation: Tensor
    ) -> Tensor:
        """Encode every (column representation, value) pair and average over columns.

        Args:
            encoding_network: Maps ``repr_size + 1`` features to an embedding.
            values: 2-D ``(n_rows, n_columns)``.
            column_representation: ``(n_columns, repr_size)``, one vector per
                column of ``values``.

        Returns:
            Tensor ``(n_rows, embedding_size)``.
        """
        n_rows, n_columns = values.shape
        pairs = torch.cat(
            [
                # Same column representations for every row.
                column_representation.unsqueeze(0).expand(n_rows, -1, -1),
                values.unsqueeze(2),
            ],
            dim=2,
        )
        encoded = encoding_network(pairs.reshape(n_rows * n_columns, -1))
        return encoded.reshape(n_rows, n_columns, -1).mean(dim=1)

In [None]:
# Draw one episode and run the encoder stages one at a time so that every
# intermediate tensor stays inspectable.
X_support, y_support, X_query, y_query = next(composed_few_shot_dataset)

network = HeterogenousInferenceNetwork()

# Stage 1: initial representation per attribute column and per response column.
attributes_initial_representation = network.create_initial_features_representation(
    encoder_network=network.initial_support_encoding_network,
    representation_network=network.initial_support_representation_network,
    X=X_support,
)

responses_initial_representation = network.create_initial_features_representation(
    encoder_network=network.initial_support_encoding_network,
    representation_network=network.initial_support_representation_network,
    X=y_support,
)

# Stage 2: one vector per support example.
support_set_representation = network.create_support_set_representation(
    interaction_encoding_network=network.interaction_encoding_network,
    interaction_representation_network=network.interaction_representation_network,
    X=X_support,
    attributes_initial_representation=attributes_initial_representation,
    y=y_support,
    responses_initial_representation=responses_initial_representation,
)

# Stage 3: final representation per attribute column and per response column.
attributes_representation = network.create_features_representation(
    features_encoding_network=network.attributes_encoding_network,
    features_representation_network=network.attributes_representation_network,
    set_=X_support,
    set_representation=support_set_representation,
)
responses_representation = network.create_features_representation(
    features_encoding_network=network.responses_encoding_network,
    features_representation_network=network.responses_representation_network,
    set_=y_support,
    set_representation=support_set_representation,
)

In [None]:
# Query-side networks, defined outside HeterogenousInferenceNetwork.
# Encoder input: 32-dim attribute representation + 1 query value.
inference_encoding_network = FeedForwardNetwork(
    input_size=32 + 1,
    output_size=32,
    n_hidden_layers=3,
    hidden_size=32,
    dropout_rate=0.1,
)
inference_representation_network = FeedForwardNetwork(
    input_size=32,
    output_size=32,
    n_hidden_layers=3,
    hidden_size=32,
    dropout_rate=0.1,
)
# Predictor input: 32-dim query embedding + 32-dim response representation.
inference_network = FeedForwardNetwork(
    input_size=32 + 32,
    output_size=1,
    n_hidden_layers=3,
    hidden_size=32,
    dropout_rate=0.1,
)

In [None]:
# Pair every query value with the representation of its attribute, encode each
# pair, and average over attributes -> one embedding per query row. Sizes come
# from the tensors instead of the hard-coded 19 attributes / 32 dims.
n_query, n_attributes = X_query.shape
network_input = torch.cat(
    [
        attributes_representation.unsqueeze(0).expand(n_query, -1, -1),
        X_query.unsqueeze(2),
    ],
    dim=2,
).reshape(n_query * n_attributes, -1)
X_query_inference_embedding = (
    inference_encoding_network(network_input)
    .reshape(n_query, n_attributes, -1)
    .mean(dim=1)
)

In [None]:
# Score every (response column, query row) pair: each query embedding is
# concatenated with each response representation.
n_responses = responses_representation.shape[0]
n_query = X_query_inference_embedding.shape[0]
inference_network_input = torch.cat(
    [
        X_query_inference_embedding.unsqueeze(0).expand(n_responses, -1, -1),
        responses_representation.unsqueeze(1).expand(-1, n_query, -1),
    ],
    dim=2,
)
inference_network_input = inference_network_input.reshape(
    n_responses * n_query, inference_network_input.shape[-1]
)
response = inference_network(inference_network_input)
# Rows are ordered response-major, so restore (n_responses, n_query, out) and
# move the query axis first; a direct reshape to (n_query, -1) mixes rows as
# soon as there is more than one response column.
response = response.reshape(n_responses, n_query, -1).transpose(0, 1).reshape(n_query, -1)

In [None]:
# Shape of the predictions for the query rows.
response.size()