In [1]:
import tcellmatch.api as tm
import torch as tc
import torch.nn as nn
from math import exp, log

In [2]:
saved_model_path = '../tutorial_data/saved_model-8-2--18:56'
ffn = tm.models.EstimatorFfn()
ffn.load_model_full(fn=saved_model_path, load_train_data=False)

device = 'cuda' if tc.cuda.is_available() else 'cpu'
# device='cpu'
ffn.model = ffn.model.to(device) 
for param in ffn.model.parameters():
    param = param.to(device) 

In [3]:
ffn.predict()

In [6]:
def mmd_pois(tau, y_hat, y):
    y_hat, y = tc.Tensor(y_hat).to(device), tc.Tensor(y).to(device)
    num_samples = 1000
    samples_idx = tc.randperm(y_hat.size(0))[:num_samples]
    y_hat, y = y_hat[samples_idx], y[samples_idx]
    # acct for log lmd
    y_hat = tc.exp(y_hat)
    # broadcast to vectorize pairwise output
    y_hat_i = y_hat[:, None, :]  # Add a new dimension to treat i and j differently
    y_hat_j = y_hat[None, :, :]  # Add a new dimension to treat i and j differently
    p_lmd = tc.nn.PoissonNLLLoss(log_input=False, full=True)
    
    dist_vals = 2 * tc.square(y_hat_i - y_hat_j)

    def k(y, y_):
        # take gss to be y
        return tc.exp(-1/(2 * y) * (tc.log(y + 1) - tc.log(y_ + 1)) ** 2)

    def calc_norm(y_hat, y):
        n, n_antigens = y_hat.shape
        max_lambdas = y_hat.max(dim=0)[0]
        max_lambdas[max_lambdas < 5] = 5
        max_lambdas = max_lambdas**2.
        out = tc.zeros(n_antigens, device=device)

        for col in range(n_antigens):
            max_lambda = int(max_lambdas[col])
            # Prepare a tensor of λ values
            lambda_values = tc.arange(1, max_lambda+1).float().to(device)

            # Compute the loss for all combinations of y and λ
            
            losses = tc.zeros((n, max_lambda))
            for i in range(n):
                for j in range(max_lambda):
                    losses[i,j] = p_lmd(lambda_values[j], y[i, col])

            # Reshape to (n, 1, Λ, 1) and (1, n, 1, Λ) to use broadcasting
            # i_mtx = losses[:, None, :, None]
            i_mtx = losses.reshape(n, 1, max_lambda, 1)
            # j_mtx = losses[None, :, None, :]
            j_mtx = losses.reshape(1, n, 1, max_lambda)

            # the ijklth element is L(y_i | λ = λ̂_k) * L(y_j | λ = λ̂_l)
            A = (i_mtx * j_mtx).to(device)

            # k(λ̂_i, λ̂_j) for all (i, j)
            K = k(lambda_values.reshape(max_lambda, 1), lambda_values.reshape(1, max_lambda))

            # broadcast K to shape n x n x lmd x lmd
            K_broadcasted = K[None, None, :, :].to(device=device)
            E_both = (A * K_broadcasted).sum()

            # p_lmd(y_i | λ = j)
            losses = p_lmd(lambda_values[None, :], y[:, col, None])

            # one-d E_{y\sim\mu(X_i)}k(y, y_j)
            k_Pois = k(y[:, col, None], lambda_values[None, :])
            one_d_E = (losses * k_Pois).sum()

            last_term = K.sum()
            # one_d_E looks at each pair twice
            # and we are looking at the same column, so this is the same as 
            # -(E_{y\sim\mu(X_i)}k(y, y_j)+E_{y'\sim\mu(X_j)}k(y_i, y'))
            out[col] = E_both - one_d_E + last_term
        return out.sum()
    return exp(-1/(2 * tau ** 2)) * dist_vals.sum() + calc_norm(y_hat, y)

In [7]:
mmd_pois(1, ffn.predictions, ffn.y_test)

tensor(7.2482e+12, device='cuda:0')