In [None]:
import numpy as np
import torch
from torch.nn import TransformerEncoderLayer, TransformerEncoder
from torch.optim import AdamW

Helper functions

In [7]:
def buildContext(D, i, M, dim=100):
    """Return the ``M`` samples of ``D`` that precede position ``i``, zero-padded.

    Only ``D[:i]`` is ever used, so the sample at position ``i`` (and anything
    after it) never appears in its own context. When fewer than ``M`` samples
    are available, the context is left-padded with zero vectors so the result
    always has exactly ``M`` entries.

    Parameters
    ----------
    D : list
        Buffer of past samples, oldest first.
        Assumes each entry is a 1-D vector of length ``dim`` -- TODO confirm.
    i : int
        Position in ``D`` of the sample the context is built for. Passing
        ``len(D)`` yields the most recent ``M`` samples of the whole buffer.
    M : int
        Context length.
    dim : int, optional
        Length of each zero padding vector (default 100).

    Returns
    -------
    list
        ``M`` entries: zero vectors first, then up to ``M`` of the most recent
        entries of ``D[:i]`` in their original order.
    """
    history = D[:i]
    # max(..., 0) keeps the slice start non-negative: a negative start would
    # silently wrap around, and M == 0 must yield an empty context.
    recent = list(history[max(len(history) - M, 0):])
    padding = [np.zeros(dim) for _ in range(M - len(recent))]
    return padding + recent

def computeNLL(x, mu, sigma):
    """Mean per-bin Gaussian negative log-likelihood of ``x`` under N(mu, sigma^2).

    For each bin the (constant-free) Gaussian NLL is

        0.5 * ((x - mu) / sigma) ** 2 + log(sigma)

    and the result is the average over the ``len(x)`` bins. The log term is
    *not* halved: ``0.5 * log(sigma**2) == log(sigma)``. Halving it as well
    would move the optimum to ``sigma**2 == 2 * (x - mu)**2``.

    Parameters
    ----------
    x : array-like or torch.Tensor
        Observed values, one per bin.
    mu : torch.Tensor
        Predicted means, same shape as ``x``.
    sigma : torch.Tensor
        Predicted standard deviations (must be positive), same shape as ``x``.

    Returns
    -------
    torch.Tensor
        The averaged NLL (a 0-d tensor for 1-D inputs); differentiable with
        respect to ``mu`` and ``sigma``.

    Raises
    ------
    ValueError
        If ``x`` has no bins.
    """
    sigma = torch.as_tensor(sigma)
    mu = torch.as_tensor(mu, dtype=sigma.dtype, device=sigma.device)
    # x may arrive as a numpy array; bring it to the prediction's dtype/device.
    x = torch.as_tensor(x, dtype=sigma.dtype, device=sigma.device)

    nb = len(x)
    if nb == 0:
        raise ValueError("computeNLL: x must contain at least one bin.")

    z = (x - mu) / sigma
    per_bin_nll = 0.5 * z ** 2 + torch.log(sigma)
    # Reduce over the bin axis only, as the original per-bin loop did.
    return torch.sum(per_bin_nll, dim=0) / nb

def chi2_stat(x, sigma_x, u, sigma_u):
    """Chi-square-style distance between ``x`` and ``u``, averaged over bins.

    Each squared residual ``(x - u)**2`` is divided by the summed variances
    ``sigma_x**2 + sigma_u**2``; the terms are summed with ``torch.sum`` and
    scaled by ``1 / len(x)``.
    """
    n_bins = len(x)
    squared_residual = (x - u) ** 2
    combined_variance = sigma_x**2 + sigma_u**2
    return (1 / n_bins) * torch.sum(squared_residual / combined_variance)

def pull_delta(x, sigma_x, u, sigma_u):
    """Element-wise pull: ``(x - u)`` in units of the combined uncertainty.

    The combined uncertainty is ``sqrt(sigma_x**2 + sigma_u**2)``; no reduction
    is applied, so the result has one value per input element.
    """
    combined_sigma = torch.sqrt(sigma_x**2 + sigma_u**2)
    return (x - u) / combined_sigma

def poisson_uncertainty(x, I):
    """Return ``sqrt(x / I - x**2 / I)`` element-wise.

    Algebraically this is ``sqrt(x * (1 - x) / I)``.
    Presumably ``x`` holds normalized bin contents and ``I`` the total number
    of entries of the histogram -- verify against the caller.
    """
    linear_term = x / I
    quadratic_term = x**2 / I
    return torch.sqrt(linear_term - quadratic_term)

Inputs

In [13]:
# Input stream, to be filled in before running the loop below:
# x[i] is the i-th histogram and y[i] its ground-truth label. Only entries
# with y[i] == 0 are added to the buffer and used for training.
x = list()  # Past histograms
y = list()  # Ground truth

# Hyper-parameters of the online procedure.
M = 5  # Buffer size: number of past samples handed to buildContext
K = 3  # Batch size: at most this many of the newest buffered samples per step

Initialize the sample buffer, model and optimizer

In [None]:
# Buffer of normalized histograms; the loop appends to it when y[i] == 0.
D = list()

# Number of consecutive non-improving training steps tolerated before stopping.
early_stopping_patience = 5

# Three identical encoder layers working on 100-dimensional tokens
# (10 attention heads, feed-forward width 100, dropout 0.15).
enc_layer = TransformerEncoderLayer(d_model=100, nhead=10, dim_feedforward=100, dropout=0.15)
enc = TransformerEncoder(encoder_layer=enc_layer, num_layers=3)

# AdamW over every encoder parameter.
optim = AdamW(enc.parameters(), lr=5e-4, weight_decay=1e-4)



In [None]:
# Online monitoring loop. For every incoming histogram:
#   1. predict a reference (mu, sigma) from the buffered context,
#   2. score the histogram against it (chi2 and per-bin pull),
#   3. if it is labelled y[i] == 0, add it to the buffer D and take one
#      optimisation step on the newest buffered samples.
mu_hist = list()
sigma_hist = list()
chi2_hist = list()
pulld_hist = list()
early_stopping_counter = 0
past_loss = None          # lowest batch loss seen so far, stored as a float
best_model_state = None   # independent copy of the weights saved at that point


def _to_sequence(context):
    """Stack a list of per-histogram vectors into a single float32 tensor."""
    return torch.stack([torch.as_tensor(h, dtype=torch.float32) for h in context])


def _snapshot(module):
    """Return a copy of ``module.state_dict()`` that later updates cannot alter."""
    # state_dict() hands out references to the live parameters, so each tensor
    # is cloned; otherwise the "best" weights would follow every optimizer step.
    return {name: value.detach().clone() for name, value in module.state_dict().items()}


for i in range(len(x)):
    # Normalize new histogram. Converting once to a float32 tensor lets the
    # same object feed both the torch-based statistics and the model loss.
    counts = torch.as_tensor(x[i], dtype=torch.float32)
    n_entries = torch.sum(counts)
    x_t = counts / n_entries

    # Perform inference to obtain reference prediction. Dropout is switched
    # off and no graph is recorded, so the stored histories stay lightweight.
    enc.eval()
    with torch.no_grad():
        # Every sample already in D precedes the current histogram, so the
        # context is taken from the end of the whole buffer.
        context = _to_sequence(buildContext(D, len(D), M))

        # NOTE(review): a TransformerEncoder returns one tensor shaped like its
        # input; unpacking it into (mu, sigma) presumes an output head that is
        # not visible here -- confirm how mu and a positive sigma are produced.
        mu_hat_i, sigma_hat_i = enc(context)

        mu_hist.append(mu_hat_i)
        sigma_hist.append(sigma_hat_i)

        # Compute chi2 and pull delta (the data uncertainty is shared by both)
        sigma_x = poisson_uncertainty(x_t, n_entries)
        chi2_hist.append(chi2_stat(x_t, sigma_x, mu_hat_i, sigma_hat_i))
        pulld_hist.append(pull_delta(x_t, sigma_x, mu_hat_i, sigma_hat_i))

    if y[i] == 0:
        D.append(x_t)

        # Training step on the (at most) K newest buffered samples. D stores
        # plain histograms, so the batch is addressed by position in D.
        # NOTE(review): for small k fewer than M samples precede the target;
        # buildContext has to zero-pad in that situation.
        batch_positions = range(max(len(D) - K, 0), len(D))
        enc.train()
        optim.zero_grad()
        nll_batch = 0
        for k in batch_positions:
            context = _to_sequence(buildContext(D, k, M))
            mu_hat, sigma_hat = enc(context)
            nll_batch += computeNLL(D[k], mu_hat, sigma_hat)
        nll_batch = nll_batch / len(batch_positions)

        # Backprop
        nll_batch.backward()
        optim.step()

        # Check early stopping. The loss is compared as a plain float so that
        # no autograd graph is kept alive between iterations.
        loss_value = nll_batch.item()
        if past_loss is None or loss_value < past_loss:
            past_loss = loss_value
            best_model_state = _snapshot(enc)
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= early_stopping_patience:
                # Leaves the whole loop: later histograms are not scored.
                break

# Restore the best weights, if any training step was taken at all.
if best_model_state is not None:
    enc.load_state_dict(best_model_state)