In [1]:
import tcellmatch.api as tm
import torch as tc
import torch.nn as nn
from math import exp, log

In [2]:
# Saved EstimatorFfn checkpoint from the tutorial (path relative to this notebook).
saved_model_path = '../tutorial_data/saved_model-8-2--18:56'
ffn = tm.models.EstimatorFfn()
# load_train_data=False: restore the model without loading the training data.
ffn.load_model_full(fn=saved_model_path, load_train_data=False)

device = 'cuda' if tc.cuda.is_available() else 'cpu'
# Module.to() moves all registered parameters and buffers and returns the module,
# so no per-parameter loop is needed. Rebinding a loop variable to
# `param.to(device)` would not have moved anything anyway.
ffn.model = ffn.model.to(device)

In [3]:
# Run inference with the loaded model. This presumably populates `ffn.predictions`,
# which the scoring cell below reads (TODO confirm against EstimatorFfn.predict).
ffn.predict()

In [None]:
# NOTE(review): dead, commented-out loop-based draft of the `summand` / `p_lmd` /
# `calc_norm` helpers that `mmd_pois` (next cell) defines in vectorized form.
# It differs from that version:
#   - it clamps with max(..., 1) instead of 1e-6;
#   - it never applies `p_lmd`;
#   - it only visits pairs j >= i.
# As written it would likely fail if uncommented, because `range` needs ints
# and `max(mu ** 2.5, 5)` can be a float. Consider deleting this cell.
#     def summand(lmd, y_hat_j, y_j):
#         return exp(-1/(2 * lmd) * (log(max(y_hat_j - y_j + 1, 1)) - log(lmd + 1)) ** 2)
#     def p_lmd(lmd, obs):
#         return nn.PoissonNLLLoss(log_input=False, full=True)(lmd, obs)
#     def calc_norm(y_hat, y):
#         n, n_antigens = y_hat.shape
#         expectations = tc.zeros(n_antigens)
#         for col in range(n_antigens):
#             for i in range(n):
#                 for j in range(i, n):
#                     E = 0
#                     mu = y_hat[i, col]
#                     for lmd in range(1, max(mu ** 2.5, 5)):
#                         E += summand(lmd, y_hat[j, col], y[j, col])
#                     expectations[col] += E
#         return expectations

In [7]:
def mmd_pois(tau, y_hat, y, device=None):
    """Poisson-flavoured MMD-style score between predictions and observed targets.

    The score is the sum of two pieces:

    * a pairwise term over all row pairs ``(i, j)`` of the exponentiated
      predictions ``r = exp(y_hat)``:
      ``exp(-1 / (2 * tau**2)) * sum_{i,j,k} 2 * (r_ik - r_jk)**2``.
      This is a single scalar, broadcast onto every column.
    * a per-column term. For each row and each candidate rate
      ``lambda = 1, 2, ..., max(max(r)**2.5, 5)``, it computes
      ``exp(-(log(r - y + 1) - log(lambda + 1))**2 / (2 * lambda))``.
      Each value is scaled by a Poisson NLL value and the results are summed.

    Parameters
    ----------
    tau : float
        Bandwidth parameter. It only enters via ``exp(-1 / (2 * tau**2))``.
    y_hat : array-like or torch.Tensor, shape (n, n_antigens)
        Model outputs. They are exponentiated before use, so they are
        presumably log-rates (TODO confirm against the model head).
    y : array-like or torch.Tensor, shape (n, n_antigens)
        Observed targets aligned row/column-wise with ``y_hat``.
    device : str or torch.device, optional
        Where to compute. Defaults to the notebook-level ``device`` if one is
        defined, otherwise CUDA when available, else CPU.

    Returns
    -------
    torch.Tensor, shape (n_antigens,), dtype float64

    Raises
    ------
    ValueError
        If ``y_hat`` and ``y`` are not non-empty 2-D inputs of the same shape.
    """
    if device is None:
        # The `device` parameter shadows the notebook global, so look that up explicitly.
        device = globals().get("device") or ("cuda" if tc.cuda.is_available() else "cpu")

    # Convert once, as float32 (what the legacy `tc.Tensor` constructor produced
    # for arrays). The inputs are not re-wrapped a second time later.
    rates = tc.exp(tc.as_tensor(y_hat, dtype=tc.float32, device=device))
    counts = tc.as_tensor(y, dtype=tc.float32, device=device)
    if rates.ndim != 2 or rates.shape != counts.shape or rates.shape[0] == 0:
        raise ValueError(
            "y_hat and y must be non-empty 2-D inputs with identical shapes, got "
            f"{tuple(rates.shape)} and {tuple(counts.shape)}"
        )

    def _calc_norm(rates, counts):
        """Per-column sum over rows and candidate lambdas of summand * Poisson NLL."""
        n_rows, n_cols = rates.shape
        # Candidate rates 1, 2, ... as a column vector, so they vary along dim 1.
        max_lambda = float(max(rates.max() ** 2.5, 5))
        lmd = tc.arange(1, max_lambda + 1, dtype=rates.dtype, device=device).unsqueeze(-1)
        n_lmd = lmd.size(0)

        # log(r - y + 1), with non-positive arguments replaced by 1e-6. It does not
        # depend on lambda, so compute it once and broadcast instead of repeating.
        shifted = rates - counts + 1
        log_shifted = tc.log(shifted.masked_fill(shifted <= 0, 1e-6)).unsqueeze(1)  # (n, 1, n_cols)
        summand = tc.exp(-1 / (2 * lmd) * (log_shifted - tc.log(lmd + 1)) ** 2)    # (n, L, n_cols)

        # FIXME(review): PoissonNLLLoss keeps its default reduction='mean', so this
        # is ONE scalar that scales every summand equally. It is a negative
        # log-likelihood, not a probability. If a per-lambda weight was intended,
        # use reduction='none' and exp(-nll). Kept as-is to preserve results.
        obs = counts.unsqueeze(1).expand(n_rows, n_lmd, n_cols)
        p_val = tc.nn.PoissonNLLLoss(log_input=False, full=True)(lmd.expand_as(obs), obs)

        # Sum over lambda in float32 (as before), then over rows in float64.
        return (summand * p_val).sum(dim=1).to(dtype=tc.float64).sum(dim=0)

    n_rows = rates.shape[0]
    # Uses the identity sum_{i,j,k} 2 * (r_ik - r_jk)**2 == 4 * n * sum_{i,k} (r_ik - mean_k)**2.
    # This avoids materialising the (n, n, n_antigens) pairwise tensor, which is
    # ~16 GB in float32 for 8979 x 50 and several times that with temporaries.
    # Centring in float64 keeps the identity numerically stable.
    rates64 = rates.double()
    centred = rates64 - rates64.mean(dim=0, keepdim=True)
    pair_sq_sum = 4 * n_rows * centred.square().sum()
    # FIXME(review): exp(-1/(2*tau**2)) is a constant factor. It is not a Gaussian
    # kernel exp(-d**2 / (2*tau**2)) evaluated per pair, so tau only rescales this
    # term. Kept as-is to preserve results; confirm the intended MMD kernel.
    pair_term = exp(-1 / (2 * tau ** 2)) * pair_sq_sum

    return pair_term + _calc_norm(rates, counts)

In [8]:
# Score the test-set predictions against the observed test targets with tau = 1
# (one value per target column).
mmd_pois(tau=1, y_hat=ffn.predictions, y=ffn.y_test)

tensor([1.1806e+08, 1.1997e+08, 1.2093e+08, 1.2189e+08, 1.2065e+08, 1.2057e+08,
        1.2116e+08, 1.1806e+08, 1.2145e+08, 1.1861e+08, 1.2212e+08, 1.2221e+08,
        1.2112e+08, 1.2149e+08, 1.2635e+08, 1.2296e+08, 1.1958e+08, 1.2068e+08,
        1.2043e+08, 1.1996e+08, 1.2150e+08, 1.2060e+08, 1.2027e+08, 1.2081e+08,
        1.2175e+08, 1.2116e+08, 1.2113e+08, 1.2094e+08, 1.2078e+08, 1.0743e+08,
        1.0283e+08, 1.1026e+08, 1.0644e+08, 1.0544e+08, 1.2049e+08, 1.2000e+08,
        1.2009e+08, 1.2029e+08, 1.2096e+08, 1.2053e+08, 1.2121e+08, 1.2420e+08,
        1.2026e+08, 1.2087e+08, 1.2122e+08, 1.1769e+08, 1.2033e+08, 1.2035e+08,
        1.2068e+08, 1.2068e+08], device='cuda:0', dtype=torch.float64)

In [6]:
# Shape of the test targets as (rows, target columns). The column count should
# equal the length of the mmd_pois result shown above.
ffn.y_test.shape

(8979, 50)