In [1]:
from typing import Tuple, Callable
import torch as t
from NewTaylorAnalysisBase import model_extension
from tayloranalysis.cls import TaylorAnalysis
import numpy as np
from itertools import product as prod
from pprint import pprint

In [2]:
class SimpleModel(t.nn.Module):
    """Fully connected feed-forward network with a configurable stack of hidden layers.

    The network is a chain of ``t.nn.Linear`` layers. Every layer except the last
    is followed by ``layer_activation``; the last one is followed by
    ``final_activation``.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features.
        hidden_nodes: Width of each hidden layer, in order. May be empty, in which
            case the model is a single linear layer followed by ``final_activation``.
        layer_activation: Zero-argument callable (e.g. a module class) that builds
            the activation applied after every hidden layer.
        final_activation: Zero-argument callable (e.g. a module class) that builds
            the activation applied after the output layer.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_nodes: Tuple[int, ...] = (100,),
        layer_activation: Callable = t.nn.ReLU,
        final_activation: Callable = t.nn.Sigmoid,
    ) -> None:
        super().__init__()

        # Consecutive pairs of this sequence are the (in, out) sizes of each layer.
        # Layers are created front to back, so parameter initialisation consumes the
        # random generator in the same order regardless of the number of hidden layers.
        layer_sizes = (input_dim, *hidden_nodes, output_dim)
        self.layers = t.nn.ModuleList(
            [t.nn.Linear(n_in, n_out) for n_in, n_out in zip(layer_sizes, layer_sizes[1:])]
        )

        self.layer_activation = layer_activation()
        self.final_activation = final_activation()

    def forward(self, x: t.Tensor) -> t.Tensor:
        """Run ``x`` through all layers and return the activated output.

        Args:
            x: Input tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor whose last dimension has size ``output_dim``.
        """
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            x = self.layer_activation(layer(x))
        return self.final_activation(output_layer(x))

In [3]:
# Seed before construction so the randomly initialised weights are reproducible.
t.manual_seed(19)
# Three input features, one output, default hidden layout.
model = SimpleModel(input_dim=3, output_dim=1)

In [4]:
# Reference implementation: the wrapper currently used for the analysis.
benchmark_model = TaylorAnalysis(model, apply_abs=True, reduction="mean")

# New approach: extend `model` itself; this adds only tc_nth_order,
# tc_{first, second, third}_order and tc_reduce to it.
# NOTE(review): "abs mean" presumably mirrors apply_abs=True + reduction="mean" above — confirm.
model_extension(model, reduction="abs mean")

In [5]:
# Inputs shared by the comparison cells: a reproducible uniform random batch of
# 10 samples with 3 features, plus the list of feature indices to analyse.
t.manual_seed(19)
X = t.rand(10, 3)
X_idx = list(range(3))

In [6]:
# First-order coefficients from the benchmark wrapper and from the extended model.
first_order_benchmark = benchmark_model.first_order(X)
first_order = model.tc_first_order(X, X_idx)

# The benchmark result is indexed by the plain feature index, the extension's
# result by a one-element index tuple; both must agree for every feature.
for feature in X_idx:
    expected = first_order_benchmark[feature]
    actual = first_order[(feature,)]
    assert np.allclose(expected, actual)

pprint(first_order)

{(0,): array([0.03742455], dtype=float32),
 (1,): array([0.03745043], dtype=float32),
 (2,): array([0.04474961], dtype=float32)}


In [7]:
second_order = model.tc_second_order(X, X_idx, X_idx)

# Sorting each index pair maps (i, j) and (j, i) onto the same key, so every
# unordered pair is compared against the benchmark exactly once.
unique_pairs = {tuple(sorted(pair)) for pair in prod(X_idx, X_idx)}
for i, j in unique_pairs:
    second_order_benchmark = benchmark_model.second_order(X, i)[j]
    assert np.allclose(second_order_benchmark, second_order[(i, j)])

pprint(second_order)

{(0, 0): array([0.00013891], dtype=float32),
 (0, 1): array([0.00023917], dtype=float32),
 (0, 2): array([0.00030635], dtype=float32),
 (1, 1): array([0.00012863], dtype=float32),
 (1, 2): array([0.00033262], dtype=float32),
 (2, 2): array([0.00022999], dtype=float32)}


In [8]:
third_order = model.tc_third_order(X, X_idx, X_idx, X_idx)

# Sorting each index triple collapses all permutations onto one key; iterating the
# keys in sorted order keeps the mismatch report stable between runs.
unique_triples = sorted({tuple(sorted(triple)) for triple in prod(X_idx, X_idx, X_idx)})

# Collect every disagreeing index combination instead of stopping at the first one.
# An explicit condition is used rather than catching AssertionError, because
# assert statements are removed when Python runs with -O, which would silently
# report "no mismatches".
not_equal_idx = []
for i, j, k in unique_triples:
    third_order_benchmark = benchmark_model.third_order(X, i, j)[k]
    if not np.allclose(third_order_benchmark, third_order[(i, j, k)]):
        not_equal_idx.append((i, j, k))

print(f"Not equal results in: {not_equal_idx}")
pprint(third_order)

Not equal results in: [(0, 1, 2)]
{(0, 0, 0): array([8.924161e-05], dtype=float32),
 (0, 0, 1): array([0.00023364], dtype=float32),
 (0, 0, 2): array([0.00022128], dtype=float32),
 (0, 1, 1): array([0.00024527], dtype=float32),
 (0, 1, 2): array([0.00046847], dtype=float32),
 (0, 2, 2): array([0.00029219], dtype=float32),
 (1, 1, 1): array([9.5188596e-05], dtype=float32),
 (1, 1, 2): array([0.00027656], dtype=float32),
 (1, 2, 2): array([0.00034676], dtype=float32),
 (2, 2, 2): array([0.00016722], dtype=float32)}


In [9]:
# Third-order coefficients for features 0, 1 and 2 with permutations kept separate.
per_permutation = model.tc_third_order(X, 0, 1, 2, combine_permutations=False)

# Show only the entries whose index tuple contains all three features.
required_features = (0, 1, 2)
mixed_terms = {}
for idx, tc in per_permutation.items():
    if all(feature in idx for feature in required_features):
        mixed_terms[idx] = tc

pprint(mixed_terms)

{(0, 1, 2): array([7.807805e-05], dtype=float32),
 (0, 2, 1): array([7.807803e-05], dtype=float32),
 (1, 0, 2): array([7.807805e-05], dtype=float32),
 (1, 2, 0): array([7.807803e-05], dtype=float32),
 (2, 0, 1): array([7.807803e-05], dtype=float32),
 (2, 1, 0): array([7.807803e-05], dtype=float32)}


In [10]:
# With permutations combined, the (0, 1, 2) entry is expected to be 6 times the
# benchmark value (6 = 3!, the number of orderings of three distinct indices).
combined_share = model.tc_third_order(X, 0, 1, 2, combine_permutations=True)[(0, 1, 2)] / 6
benchmark_combined = benchmark_model.third_order(X, 0, 1)[2]
assert np.isclose(combined_share, benchmark_combined)

# Without combining, a single permutation's entry should match the benchmark directly.
single_permutation = model.tc_third_order(X, 0, 1, 2, combine_permutations=False)[(0, 1, 2)]
benchmark_single = benchmark_model.third_order(X, 0, 1)[2]
assert np.isclose(single_permutation, benchmark_single)