In [1]:
import scipy as sp
import scipy.optimize as optimize
import numpy as np
import pandas as pd

In [2]:
def get_f(
    data,
    num_cat,
    delta = 0.1,
    train = True,
):
    """ 
    :param data: each row is errors + Ns (total len 2 * num_cat )
    """
    def f(x):
        """ 
        :param x: alpha, rho_ij
        """
        loss = 0
        alpha = x[0]
        rho = x[1:].reshape((num_cat, num_cat))
        # rho = np.ones((num_cat, num_cat)) * rho[0,0]
        for d in data:
            Ns = rho @ d[num_cat:]
            log_error_hats = - alpha * np.log(Ns)
            err = log_error_hats - np.log(d[:num_cat])
            if train:
                loss += np.sum(sp.special.huber(delta, err,))
            else: 
                loss += np.sum(np.abs(np.exp(err) - 1))
        return loss
    return f

In [3]:
num_cat = 6
x0 = np.array([0.5] + [0.001] * (num_cat ** 2))
bounds = [(0, 10)] + [(0.00000001, None)] * (num_cat ** 2)
data = pd.read_csv("domainnet_transfer.csv").to_numpy()
f = get_f( 
    data, 
    num_cat,
    delta = 2,
)

In [4]:
f(x0)

1070.1566971383922

In [5]:
result = optimize.minimize( 
    fun=f, 
    x0 = x0,
    bounds = bounds,
    tol = 0.0001,
)
print(result)

      fun: 136.32765745072825
 hess_inv: <37x37 LbfgsInvHessProduct with dtype=float64>
      jac: array([ 6.20380774e+00, -1.51733776e+05,  6.37212812e+05,  5.55450113e+05,
        9.14484583e+05,  1.67602150e+05, -1.39786798e+05,  5.19907915e+05,
        2.47205100e+05, -4.50109733e+04,  7.84245353e+06,  8.44318806e+05,
        1.16975529e+06,  4.01848668e+05, -1.99089551e+05,  5.75655274e+04,
        1.00803836e+07,  2.84027708e+04,  6.06722602e+05, -1.37614192e+05,
        6.71986531e+06,  9.31790682e+06,  1.38063790e+05,  4.79464431e+06,
        4.95036044e+05,  1.32060534e+05,  1.36060941e+05,  1.33216225e+05,
        6.69114067e+06, -2.82267155e+05, -2.56748167e+05,  1.17335293e+05,
        8.31455788e+04, -9.45480084e+05, -4.20340705e+04, -1.84889371e+05,
        1.39782052e+05])
  message: b'ABNORMAL_TERMINATION_IN_LNSRCH'
     nfev: 3192
      nit: 35
     njev: 84
   status: 2
  success: False
        x: array([2.20272583e-01, 2.60692110e-06, 7.50644833e-08, 1.70157280e-07,


In [157]:
alpha = result.x[0]
rho = result.x[1:].reshape((num_cat, num_cat))
print(alpha)
print(rho / np.sum(rho) * num_cat)

0.22029656826013871
[[0.9442291  0.02721714 0.06163647 0.00672209 0.08916062 0.0726613 ]
 [0.00656354 0.0423537  0.00559854 0.00362038 0.00672576 0.00362038]
 [0.05046871 0.02059828 0.70689877 0.00362038 0.08197648 0.02204907]
 [0.00574867 0.00362038 0.00362038 0.9261102  0.00362038 0.00720927]
 [0.36581148 0.07496495 0.40150777 0.00362038 1.2069686  0.05867445]
 [0.05900094 0.01568705 0.01777999 0.00564163 0.02192233 0.66277045]]


In [160]:
print(rho / np.diagonal(rho).reshape((1, -1)))

[[1.         0.64261547 0.08719278 0.00725841 0.07387153 0.10963268]
 [0.00695121 1.         0.00791985 0.00390923 0.00557244 0.00546249]
 [0.05344965 0.4863397  1.         0.00390923 0.06791932 0.03326803]
 [0.00608822 0.08547961 0.00512149 1.         0.00299956 0.01087748]
 [0.38741814 1.76997422 0.56798482 0.00390923 1.         0.08852906]
 [0.06248583 0.37038215 0.0251521  0.00609175 0.01816313 1.        ]]


In [155]:
test_f = get_f( 
    data, 
    num_cat,
    train=False,
)
test_f(result.x) / data.shape[0] / data.shape[1]

0.046340293287311436

In [143]:
np.exp(np.sqrt(0.0947))

1.3603386159437398