In [26]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Target covariance between the 4 positions of a sequence. It must be
# symmetric positive semi-definite so that `factorise` can take its square
# root. The diagonal is 1 and the entries shrink away from the diagonal.
B = np.array(
    [
        [1.0, 0.8, 0.2, 0.0],
        [0.8, 1.0, 0.3, 0.1],
        [0.2, 0.3, 1.0, 0.6],
        [0.0, 0.1, 0.6, 1.0],
    ],
    dtype=float,
)

In [4]:
T = len(B)   # sequence length = number of positions (size of B)
D = 50       # dimension of each per-position vector X^t (rows of one sample)
N = 1_000    # number of samples

In [5]:
def factorise(B, tol=1e-8):
    """Return a square factor C such that C @ C.T reproduces B.

    Uses the spectral decomposition B = U diag(lam) U^T and builds
    C = U diag(sqrt(lam)). Eigenvalues in [-tol, 0) are treated as
    round-off and clipped to zero before the square root.

    Parameters
    ----------
    B : ndarray, shape (T, T)
        Candidate covariance matrix.
    tol : float
        Absolute tolerance for both the symmetry test and the
        negative-eigenvalue test.

    Returns
    -------
    ndarray, shape (T, T)

    Raises
    ------
    ValueError
        If B is not symmetric within `tol`, or has an eigenvalue below -tol.
    """
    is_symmetric = np.allclose(B, B.T, atol=tol)
    if not is_symmetric:
        raise ValueError("La matrice B n'est pas symétrique.")

    lam, U = np.linalg.eigh(B)

    if (lam < -tol).any():
        raise ValueError("La matrice B n'est pas semi-définie positive.")

    root_lam = np.sqrt(np.maximum(lam, 0.0))
    # Scaling column j of U by root_lam[j] is exactly U @ diag(root_lam).
    return U * root_lam

In [15]:

def generate_X(C, N, D, rng=None):
    """Draw N samples, each made of D i.i.d. Gaussian rows with covariance C @ C.T.

    Row d of sample n is x = C z with z ~ N(0, I). Hence
    E[X^T X] / D = C C^T, which equals B when C = factorise(B).

    Parameters
    ----------
    C : ndarray, shape (T, r)
        Factor of the target covariance. `factorise` returns a square one
        (r = T); a thin factor (r < T) is also accepted.
    N : int
        Number of samples.
    D : int
        Number of rows per sample.
    rng : numpy.random.Generator, optional
        Source of randomness. If None, NumPy's legacy global RNG is used, so
        `np.random.seed` still controls the draw, as before.

    Returns
    -------
    ndarray, shape (N, D, T)
    """
    r = C.shape[1]
    if rng is None:
        Z = np.random.randn(N, D, r)              # Z ~ N(0, I)
    else:
        Z = rng.standard_normal((N, D, r))
    # X[n, d, :] = C @ Z[n, d, :], i.e. the batched product Z @ C.T
    X = np.einsum('ndr,rs->nds', Z, C.T)
    return X

In [None]:
C = factorise(B)
D_CHECK = 10  # rows per sample for this check; also the normaliser below
X_samples = generate_X(C, N=1000, D=D_CHECK)

# Monte-Carlo estimate of E[X^T X] / D, which should approximate B.
# A single einsum sums X^T X over samples (n) and rows (d) without building a list.
B_est = np.einsum('ndt,nds->ts', X_samples, X_samples) / (X_samples.shape[0] * D_CHECK)
print(B_est)
# 10 000 rows give a standard error of about 0.015 per entry, so 0.1 is a loose bound.
assert np.allclose(B_est, B, atol=0.1), "empirical covariance is far from B"

[[ 0.98324317  0.77676622  0.19439341 -0.01276158]
 [ 0.77676622  0.97135208  0.28575321  0.09112401]
 [ 0.19439341  0.28575321  0.98336761  0.58428362]
 [-0.01276158  0.09112401  0.58428362  0.99345742]]


In [23]:
def compute_moments(B, T_max):
    """Stack the first T_max matrix powers of B.

    Parameters
    ----------
    B : array_like, shape (T, T)
        Square matrix.
    T_max : int
        Number of powers to compute; must be >= 1.

    Returns
    -------
    ndarray, shape (T_max, T, T)
        out[k] = B^(k+1), i.e. the stack (B, B^2, ..., B^T_max).

    Raises
    ------
    ValueError
        If B is not a square 2-D matrix, or if T_max < 1. Previously, T_max < 1
        silently returned a single moment, which violated the documented shape.
    """
    B = np.asarray(B)
    if B.ndim != 2 or B.shape[0] != B.shape[1]:
        raise ValueError(f"B must be a square matrix, got shape {B.shape}.")
    if T_max < 1:
        raise ValueError(f"T_max must be >= 1, got {T_max}.")

    moments = [B]
    for _ in range(T_max - 1):
        # Each new power is the previous one times B (T_max - 1 products in total).
        moments.append(moments[-1] @ B)
    return np.stack(moments, axis=0)

In [24]:
print(compute_moments(B, T_max=T))  # T powers of B, i.e. the same layout as one target y[i]

[[[1.     0.8    0.2    0.    ]
  [0.8    1.     0.3    0.1   ]
  [0.2    0.3    1.     0.6   ]
  [0.     0.1    0.6    1.    ]]

 [[1.68   1.66   0.64   0.2   ]
  [1.66   1.74   0.82   0.38  ]
  [0.64   0.82   1.49   1.23  ]
  [0.2    0.38   1.23   1.37  ]]

 [[3.136  3.216  1.594  0.75  ]
  [3.216  3.352  1.902  1.046 ]
  [1.594  1.902  2.602  2.206 ]
  [0.75   1.046  2.206  2.146 ]]

 [[6.0276 6.278  3.636  2.028 ]
  [6.278  6.6    4.1784 2.5224]
  [3.636  4.1784 4.815  3.9574]
  [2.028  2.5224 3.9574 3.5742]]]


In [19]:

def generate_dataset(N, D, T, seed=None):
    """Build a synthetic (inputs, targets) dataset for moment regression.

    Steps:
    1. Draw a random symmetric PSD matrix B_true = A A^T, scaled to unit trace.
    2. Generate N samples of shape (D, T) whose rows have covariance B_true.
    3. For each sample, the target stacks the powers (G, G^2, ..., G^T) of that
       sample's empirical Gram matrix G = X^T X / D. Targets come from G, not
       from B_true.

    Parameters
    ----------
    N : int
        Number of samples.
    D : int
        Rows per sample.
    T : int
        Sequence length (size of B_true).
    seed : int, optional
        If given, reseeds NumPy's legacy global RNG before any draw.

    Returns
    -------
    X_data : ndarray, shape (N, D, T)
    y_data : ndarray, shape (N, T, T, T)
    """
    if seed is not None:
        np.random.seed(seed)

    # Random symmetric PSD matrix, normalised to trace 1.
    A = np.random.randn(T, T)
    gram = A @ A.T
    B_true = gram / np.trace(gram)

    samples = generate_X(factorise(B_true), N, D)

    targets = [compute_moments(x.T @ x / D, T) for x in samples]
    return samples, np.array(targets)

In [25]:

X, y = generate_dataset(N, D, T, seed=0)  # fixed seed so Restart & Run All is reproducible

print("X.shape :", X.shape)  # (N, D, T)
print("y.shape :", y.shape)  # (N, T, T, T): T moments, each a T x T matrix


X.shape : (1000, 50, 4)
y.shape : (1000, 4, 4, 4)


In [27]:
class MomentTransformer(nn.Module):
    """Transformer encoder that regresses a stack of T matrix moments from a sample.

    An input of shape (batch, D, T) is read as T tokens, one per position, each
    with a D-dimensional feature vector. The forward pass:
    1. Projects the tokens linearly to d_model and adds a learned positional
       embedding.
    2. Runs them through a TransformerEncoder and mean-pools over the T positions.
    3. Maps the result to T*T*T values, reshaped as (T, T, T): one T x T matrix
       per moment, T moments in total.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward, T, input_dim=None):
        """
        Parameters
        ----------
        d_model : int
            Hidden size of the Transformer.
        nhead : int
            Number of attention heads per encoder layer.
        num_layers : int
            Number of stacked encoder layers.
        dim_feedforward : int
            Hidden size of each encoder layer's feed-forward block.
        T : int
            Sequence length. It is fixed because the positional embedding and
            the output head are sized by it.
        input_dim : int, optional
            Feature size D of each token, i.e. the middle axis of the input.
            Defaults to d_model so existing calls keep working. Set it whenever
            D != d_model; otherwise the input projection fails on shape mismatch.
        """
        super().__init__()
        self.T = T
        in_features = d_model if input_dim is None else input_dim

        # Per-token linear embedding: D input features -> d_model.
        self.input_proj = nn.Linear(in_features, d_model)

        # Learned positional embedding, broadcast over the batch in forward().
        self.pos_embedding = nn.Parameter(torch.randn(1, T, d_model))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output head: T moments, each a T x T matrix, i.e. T*T*T values.
        self.output_proj = nn.Linear(d_model, T * T * T)

    def forward(self, x):
        """Predict the moment tensor for a batch.

        Parameters
        ----------
        x : Tensor, shape (batch_size, D, T)
            D must equal `input_dim` (or d_model if input_dim was not given).

        Returns
        -------
        Tensor, shape (batch_size, T, T, T)
        """
        tokens = x.permute(0, 2, 1)                        # (batch, T, D): one token per position
        h = self.input_proj(tokens) + self.pos_embedding   # (batch, T, d_model)
        h = self.transformer(h)                            # (batch, T, d_model)
        pooled = h.mean(dim=1)                             # average over the T positions
        y_hat = self.output_proj(pooled)                   # (batch, T*T*T)
        return y_hat.view(-1, self.T, self.T, self.T)      # reshape into the moment tensor
