In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.optim import AdamW

Helper functions

In [2]:
def buildContext(D, i, M, n_bins=100):
    """Return the M most recent entries of ``D`` that precede position ``i``.

    Only ``D[:i]`` is considered, so the entry at position ``i`` (and anything after
    it) is never part of its own context. When fewer than ``M`` entries precede
    ``i``, the context is left-padded with zero histograms so it always has length M.

    Parameters
    ----------
    D : list
        Buffer of past samples, oldest first.
    i : int
        Position in ``D`` (not a global sample index) at which the context ends,
        exclusive. Negative values are treated as 0.
    M : int
        Context length.
    n_bins : int, default 100
        Length of each ``np.zeros`` padding histogram.

    Returns
    -------
    list
        ``M`` items: zero padding first, then entries of ``D`` in their original order.
        NOTE(review): the training loop stores ``(histogram, index)`` tuples in ``D``
        while padding entries are bare arrays; callers must handle both forms.

    Raises
    ------
    Exception
        If a context of length ``M`` cannot be built (e.g. negative ``M``).
    """
    # Slicing with [-M:] when M == 0 would return the whole list, hence the guard.
    prior = list(D[:max(i, 0)])[-M:] if M > 0 else []
    padding = [np.zeros(n_bins) for _ in range(M - len(prior))]
    context = padding + prior
    if len(context) != M:
        raise Exception("buildContextError: Incorrect context length.")
    return context

def computeNLL(x, mu, sigma):
    """Gaussian negative log-likelihood of ``x`` under N(mu, sigma**2), averaged over bins.

        loss = 1 / (2 * nb) * sum_b [ ((x_b - mu_b) / sigma_b)**2 + log(sigma_b**2) ]

    i.e. the mean per-bin Gaussian NLL with the constant log(2*pi)/2 dropped. The mean
    is taken over all elements of the broadcast result (bins, plus any batch rows).

    Parameters
    ----------
    x : array-like or torch.Tensor
        Observed histogram; converted to ``mu``'s dtype and device.
    mu, sigma : torch.Tensor
        Predicted per-bin mean and standard deviation; must broadcast against ``x``
        (e.g. a ``(1, n_bins)`` model output with an ``(n_bins,)`` observation).
        ``sigma`` must be strictly positive.

    Returns
    -------
    torch.Tensor
        Scalar loss, differentiable w.r.t. ``mu`` and ``sigma``.
    """
    mu = torch.as_tensor(mu)
    sigma = torch.as_tensor(sigma)
    x = torch.as_tensor(x, dtype=mu.dtype, device=mu.device)
    z = (x - mu) / sigma
    # 2*log(sigma) == log(sigma**2) without squaring first (avoids underflow for tiny sigma)
    return 0.5 * torch.mean(z ** 2 + 2.0 * torch.log(sigma))

def chi2_stat(x, sigma_x, u, sigma_u):
    """Reduced chi-square between observation ``x`` and prediction ``u``.

        chi2 = 1/nb * sum_b (x_b - u_b)**2 / (sigma_x_b**2 + sigma_u_b**2)

    The mean runs over every element of the broadcast result, so ``nb`` is the number
    of bins even when inputs carry a leading batch dimension such as ``(1, n_bins)``.
    Inputs may be tensors, NumPy arrays or sequences (converted with ``torch.as_tensor``),
    which lets a NumPy observation be compared directly with a model output.

    Returns
    -------
    torch.Tensor
        Scalar statistic.
    """
    x, sigma_x, u, sigma_u = (torch.as_tensor(v) for v in (x, sigma_x, u, sigma_u))
    return torch.mean((x - u) ** 2 / (sigma_x ** 2 + sigma_u ** 2))

def pull_delta(x, sigma_x, u, sigma_u):
    """Per-bin pull between observation ``x`` and prediction ``u``.

        pull_b = (x_b - u_b) / sqrt(sigma_x_b**2 + sigma_u_b**2)

    Inputs may be tensors, NumPy arrays or sequences; they are converted with
    ``torch.as_tensor`` so a NumPy observation can be compared with a model output.

    Returns
    -------
    torch.Tensor
        Pulls with the broadcast shape of the inputs.
    """
    x, sigma_x, u, sigma_u = (torch.as_tensor(v) for v in (x, sigma_x, u, sigma_u))
    return (x - u) / torch.sqrt(sigma_x ** 2 + sigma_u ** 2)

def poisson_uncertainty(x, I):
    """Counting uncertainty of a normalized histogram.

    For bin fractions ``x`` of a histogram with ``I`` total entries, returns
    ``sqrt(x/I - x**2/I) == sqrt(x * (1 - x) / I)`` per bin, i.e. the binomial
    standard error of a fraction. Values of ``x`` outside [0, 1] make the radicand
    negative and yield NaN.

    ``x`` may be a tensor, NumPy array or sequence (converted with ``torch.as_tensor``);
    ``I`` may be a number or a tensor.
    """
    x = torch.as_tensor(x)
    return torch.sqrt((x / I) - (x ** 2 / I))

Inputs

In [13]:
# Data placeholders: populate x and y before running the training loop.
x = list()  # past histograms, presumably raw bin counts (each is normalized to unit sum in the loop)
y = list()  # ground-truth labels; y[i] == 0 marks a histogram that is added to the training buffer

M = 5  # buffer size: number of past histograms in each model context
K = 3  # batch size: number of most recent buffer entries used per training step

Model definition and initialization

In [None]:
class DinamoML(nn.Module):
    """Transformer encoder that maps a context of M histograms (plus a time value per
    histogram) to a predicted reference histogram ``mu`` and per-bin uncertainty ``sigma``.

    Parameters
    ----------
    n_bins : int
        Number of histogram bins (input width and output width).
    d_model, nhead, dim_feedforward, dropout, num_layers :
        Forwarded to ``nn.TransformerEncoderLayer`` / ``nn.TransformerEncoder``.
    sigma_min : float, default 1e-6
        Floor added to the Softplus output so ``sigma`` is strictly positive; the NLL
        divides by sigma and takes its log, and Softplus can underflow to exactly 0.
    """

    def __init__(self, n_bins=100, d_model=100, nhead=10, dim_feedforward=100, dropout=0.15, num_layers=3, sigma_min=1e-6):
        super().__init__()

        self.n_bins = n_bins
        self.sigma_min = sigma_min

        # Histogram embedding: (..., n_bins) -> (..., d_model)
        self.x_embed = nn.Sequential(
            nn.Linear(n_bins, d_model),
            nn.LayerNorm(d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )

        # Scalar time embedding: (..., 1) -> (..., d_model)
        self.t_embed = nn.Sequential(
            nn.Linear(1, d_model),
            nn.LayerNorm(d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="relu",
            batch_first=True
        )

        self.enc = nn.TransformerEncoder(
            encoder_layer=enc_layer,
            num_layers=num_layers
        )

        # Mean head: Softmax over bins, so each predicted histogram sums to 1
        self.mu_out = nn.Sequential(
            nn.Linear(d_model, n_bins),
            nn.LayerNorm(n_bins),
            nn.ReLU(),
            nn.Linear(n_bins, n_bins),
            nn.Softmax(dim=-1)
        )

        # Uncertainty head: Softplus keeps sigma non-negative (sigma_min is added in forward)
        self.sigma_out = nn.Sequential(
            nn.Linear(d_model, n_bins),
            nn.LayerNorm(n_bins),
            nn.ReLU(),
            nn.Linear(n_bins, n_bins),
            nn.Softplus()
        )

    def forward(self, x, t):
        """
        x: (K, M, n_bins) float tensor of context histograms
        t: (K, M) tensor of time values for each context entry (any numeric dtype;
           cast to x's dtype/device because nn.Linear rejects integer inputs)
        returns mu_out: (K, n_bins), sigma_out: (K, n_bins)
        """
        x_embed = self.x_embed(x)  # (K, M, n_bins) -> (K, M, d_model)
        t = t.unsqueeze(-1).to(x)  # (K, M) -> (K, M, 1)
        t_embed = self.t_embed(t)  # (K, M, 1) -> (K, M, d_model)

        xt = x_embed + t_embed  # (K, M, d_model)

        xt_e = self.enc(xt)  # (K, M, d_model)

        xt_e_pooled = xt_e.mean(dim=1)  # mean over the M context entries -> (K, d_model)

        mu_out = self.mu_out(xt_e_pooled)  # (K, d_model) -> (K, n_bins)
        sigma_out = self.sigma_out(xt_e_pooled) + self.sigma_min  # (K, d_model) -> (K, n_bins)

        return mu_out, sigma_out



In [None]:
SEED = 0  # seeds torch's RNG so weight initialization and dropout masks are reproducible
torch.manual_seed(SEED)

D = []  # training buffer, filled by the training loop

model = DinamoML()
optim = AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)
early_stopping_patience = 5  # consecutive non-improving training steps tolerated before stopping

In [None]:
def _context_to_tensors(context):
    """Stack a buildContext() result into model inputs ``x: (1, M, n_bins)`` and ``t: (1, M)``.

    Entries are either ``(histogram, index)`` tuples taken from ``D`` or bare zero
    histograms that buildContext uses as left padding.
    """
    hists, times = [], []
    for entry in context:
        if isinstance(entry, tuple):
            hist, idx = entry
        else:
            # NOTE(review): padded slots carry no index; -1 is a placeholder time value — confirm.
            hist, idx = entry, -1
        hists.append(torch.as_tensor(hist, dtype=torch.float32))
        times.append(float(idx))
    return torch.stack(hists).unsqueeze(0), torch.tensor(times).unsqueeze(0)


def _snapshot_state(module):
    """Return a detached copy of ``module.state_dict()``.

    ``state_dict()`` returns references to the live tensors, so without cloning a saved
    "best" state would silently follow every later ``optim.step()``.
    """
    return {name: tensor.detach().clone() for name, tensor in module.state_dict().items()}


# Reset all per-run state so re-running this cell starts from an empty buffer.
D = []  # training buffer of (normalized histogram, index) for samples with y == 0
mu_hist = list()
sigma_hist = list()
chi2_hist = list()
pulld_hist = list()
early_stopping_counter = 0
past_loss = None
best_model_state = None

for i in range(len(x)):
    # Normalize the new histogram to unit sum
    counts = torch.as_tensor(x[i], dtype=torch.float32)
    n_entries = counts.sum()
    x_t = counts / n_entries

    # Inference: every entry already in D precedes sample i, so the context ends at position len(D).
    model.eval()
    with torch.no_grad():
        x_ctx, t_ctx = _context_to_tensors(buildContext(D, len(D), M))
        mu_hat_i, sigma_hat_i = model(x_ctx, t_ctx)
    mu_hat_i, sigma_hat_i = mu_hat_i.squeeze(0), sigma_hat_i.squeeze(0)  # (1, n_bins) -> (n_bins,)

    mu_hist.append(mu_hat_i)
    sigma_hist.append(sigma_hat_i)

    # Compare the observation (with its counting uncertainty) against the prediction
    sigma_x = poisson_uncertainty(x_t, n_entries)
    chi2_hist.append(chi2_stat(x_t, sigma_x, mu_hat_i, sigma_hat_i))
    pulld_hist.append(pull_delta(x_t, sigma_x, mu_hat_i, sigma_hat_i))

    if y[i] == 0:
        D.append((x_t, i))

        # Training step on the (up to) K most recent buffer entries. Each entry's context is
        # built from its *position* in D, so it only sees the entries that precede it.
        start = max(len(D) - K, 0)
        model.train()
        optim.zero_grad()
        nll_batch = 0
        for pos in range(start, len(D)):
            x_k, _ = D[pos]
            x_ctx, t_ctx = _context_to_tensors(buildContext(D, pos, M))
            mu_hat, sigma_hat = model(x_ctx, t_ctx)
            nll_batch = nll_batch + computeNLL(x_k, mu_hat.squeeze(0), sigma_hat.squeeze(0))
        nll_batch = nll_batch / (len(D) - start)
        nll_batch.backward()
        batch_loss = nll_batch.item()  # plain float: does not keep the autograd graph alive

        # Early stopping. batch_loss was measured with the parameters as they are *before*
        # optim.step(), so snapshot them now to keep the saved state aligned with its loss.
        if past_loss is None or batch_loss < past_loss:
            past_loss = batch_loss
            best_model_state = _snapshot_state(model)
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= early_stopping_patience:
                break

        optim.step()

# Restore the best parameters seen (skipped if no training step ran, e.g. no y == 0 samples).
if best_model_state is not None:
    model.load_state_dict(best_model_state)