In [30]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [31]:
import numpy as np
import pandas as pd
import warnings
import sys

sys.path.append("../src")
from margin_calibration import MarginCalibration

In [32]:
warnings.simplefilter("ignore")

# Dataset Generation

In [33]:
# Build a synthetic sample of `n_obs` observations.
# A seeded generator makes every value below reproducible under
# "Restart Kernel -> Run All".
SEED = 42
rng = np.random.default_rng(SEED)

n_obs = 100
n_margins = 2  # Let's say we have two calibration variables

# One sampling probability per observation, rescaled so that they sum to 0.2 (20 %)
sampling_probabilities = rng.random(n_obs)
sampling_probabilities = sampling_probabilities / sampling_probabilities.sum() * 0.2

# Calibration matrix of shape (n_obs, n_margins), with values ranging from 0 to 1000
calibration_matrix = 1000 * rng.random((n_obs, n_margins))

# Calibration target: one total per margin (n_margins values), taken here as
# 100 times the column sums observed in the sample
calibration_target = calibration_matrix.sum(axis=0) * 100

# One cost per margin; here every margin gets the same cost
costs = (1,) * n_margins

# Let's say we want to work with pandas DataFrames instead of numpy arrays.
# Column names are derived from n_margins so that changing it does not break this cell.
margin_names = [f"margin{i + 1}" for i in range(n_margins)]
sampling_probabilities = pd.DataFrame(sampling_probabilities, columns=["weights"])
calibration_matrix = pd.DataFrame(calibration_matrix, columns=margin_names)
calibration_target = pd.DataFrame(calibration_target, columns=["margin_sums"])

# Margin Calibration

In [34]:
# Four calibrators without penalty: a default-constructed one plus three that
# differ in their first argument.
# NOTE(review): 0.5 / 1.5 are presumably lower / upper bounds used by the
# "logit" and "truncated_linear" variants — confirm against MarginCalibration.
mc = MarginCalibration()
mc_logit = MarginCalibration("logit", 0.5, 1.5)
mc_rr = MarginCalibration("raking_ratio")
mc_lt = MarginCalibration("truncated_linear", 0.5, 1.5)

In [35]:
%%time
# Calibrate with the default-constructed `mc` and display the `.x` attribute
# of the returned object (presumably the optimizer's solution vector — TODO confirm).
mc_result = mc.calibration(sampling_probabilities, calibration_matrix, calibration_target)
mc_result.x

CPU times: user 1.62 s, sys: 3.98 ms, total: 1.63 s
Wall time: 238 ms


array([-1.25572742e+02,  2.21626049e+02,  3.22360845e+01, -4.47394883e+00,
        2.36097273e+02,  2.32785362e+02, -6.83061549e+01, -3.12691515e+01,
        1.10267598e+03,  8.25296362e+02, -7.10453722e+01, -7.93847287e+01,
       -9.36758783e+02,  8.04953353e+01,  2.23875118e+02,  4.30768011e+01,
       -1.01928969e+02,  2.77164176e+02,  1.96849892e+02, -7.42232124e+02,
        8.26093700e+02,  6.52426430e+01, -1.07564782e+02,  7.84984583e+01,
        2.48716902e+02,  3.21891205e+02,  5.10655135e+03, -8.20524244e+02,
        5.21529581e+01, -2.21915387e+02,  4.95868563e+02,  1.64332957e+02,
       -1.21129743e+02, -2.91032901e+02, -2.27172815e+02, -8.80735931e+01,
       -6.35823473e+02, -3.67872358e+02,  1.42122462e+03,  4.02054723e+01,
       -1.19977034e+01,  2.70704331e+02, -1.95796379e+01,  1.69042217e+02,
        1.80725213e+02,  3.20752492e+02,  3.42268947e+02,  1.12008283e+04,
       -9.29952287e+01,  6.67570775e+02,  1.32868597e+03, -3.20570805e+01,
        1.77176774e+02,  

In [36]:
%%time
# Calibrate with `mc_logit` and display the `.x` attribute of the returned
# object (presumably the optimizer's solution vector — TODO confirm).
mc_logit_result = mc_logit.calibration(sampling_probabilities, calibration_matrix, calibration_target)
mc_logit_result.x

CPU times: user 2h 8min 52s, sys: 13.4 s, total: 2h 9min 6s
Wall time: 4min 18s


array([ 5.81285693e+02,  3.95872510e+02, -7.51582235e+01, -2.13354856e+02,
        8.35270818e+01,  5.72230263e+01, -4.33345857e+02,  4.48408775e+02,
        9.66810748e+02,  8.61482631e+02, -3.85711412e+02, -1.69128486e+01,
       -2.69173806e+02,  1.41109382e+02,  2.17626532e+02,  2.15232678e+02,
       -4.25307525e+02,  2.75482097e+02,  2.68548407e+01,  3.07953616e+02,
        1.02484273e+03, -6.20800352e+00, -3.86379812e+02, -6.07811879e+01,
        1.28599674e+02,  1.13864270e+02,  4.40213895e+03, -3.02774075e+02,
        1.08383754e+02, -7.22991787e+02,  5.40490518e+02,  1.14973677e+02,
       -8.30266411e+01,  3.95384339e+02, -5.81790277e+02, -4.30710230e+02,
       -2.89695820e+02, -5.87289949e+02,  1.03656358e+03,  3.82635361e+01,
       -1.73005197e+01,  8.33793979e+01, -1.61426525e+02,  6.79353522e+01,
        1.42149483e+02,  1.34185113e+02,  1.10215022e+02,  1.04309660e+04,
       -2.42907050e+02,  6.15374626e+02,  7.80148626e+02,  7.70542183e+02,
        1.04052935e+01,  

In [37]:
%%time
# Calibrate with `mc_rr` and display the `.x` attribute of the returned
# object (presumably the optimizer's solution vector — TODO confirm).
mc_rr_result = mc_rr.calibration(sampling_probabilities, calibration_matrix, calibration_target)
mc_rr_result.x

CPU times: user 35min 30s, sys: 2.58 s, total: 35min 33s
Wall time: 1min 11s


array([-1.45919611e-02,  1.11015913e-02, -1.55864013e+01, -2.00169494e+02,
        5.66329507e-02,  6.93919782e+00, -5.38866796e+01, -1.69821145e-03,
        9.65484222e-02,  6.66884420e-01, -2.41717538e+00, -5.78175389e-04,
       -5.59988606e+02,  2.06619702e-02,  3.58452071e-03, -1.38291407e-02,
       -1.94016023e+01, -3.25892234e-03,  2.80379968e+00, -4.21179637e+01,
        4.57605334e-03,  1.94201207e-02, -1.42447140e+02,  1.05503222e-02,
        3.69955823e-02,  7.77077291e+01,  6.12944726e+03, -5.32560176e+02,
        1.16703886e-03, -3.23141429e+02,  1.13506677e-01,  8.23653172e-03,
       -1.10236512e-02, -1.49442065e+02, -3.51902389e+02, -2.48620493e+02,
       -4.10007388e+02, -2.62266736e+02,  3.60557213e+02,  2.24464363e-03,
       -5.14446281e-01,  2.95706459e+00, -2.37572381e+02,  2.42357250e+01,
        7.46220477e-03,  1.72027755e+00,  3.88491335e+01,  1.66431804e+04,
       -2.64054867e+02,  1.99205930e-02,  3.55197798e+02,  2.19897498e-02,
        2.67654137e+00, -

In [38]:
%%time
# Calibrate with `mc_lt` and display the `.x` attribute of the returned
# object (presumably the optimizer's solution vector — TODO confirm).
mc_lt_result = mc_lt.calibration(sampling_probabilities, calibration_matrix, calibration_target)
mc_lt_result.x

CPU times: user 1h 37min 58s, sys: 8.05 s, total: 1h 38min 6s
Wall time: 3min 16s


array([ 1.88773781e+02,  8.23650451e+01, -2.20074958e+01, -5.65406199e+01,
        1.98493053e+02,  2.23642156e+02, -2.13113863e+02,  5.66475045e+01,
        9.14800170e+02,  7.58473568e+02, -1.73417013e+02, -2.11988884e+02,
       -3.65793793e+02, -5.52729156e+01,  1.02173147e+02, -1.14877554e+02,
       -2.11603094e+02,  1.95120389e+02,  1.71412216e+02,  2.07892291e+02,
        8.66027053e+02, -1.93327350e+01, -2.14288676e+02,  3.20105749e+01,
        1.81807437e+02,  3.15514462e+02,  4.30549985e+03, -3.52626053e+02,
       -8.87780380e+01, -3.59995419e+02,  3.45160250e+02,  6.76725582e+01,
       -2.10234629e+02,  1.10994552e+02, -3.59745367e+02, -1.70204575e+02,
       -2.63118973e+02, -4.84409691e+02,  1.21159731e+03, -8.55318940e+01,
       -1.17218539e+02,  2.29507082e+02, -7.89253446e+01,  1.38913206e+02,
        9.32843884e+01,  2.63205844e+02,  3.16612497e+02,  1.04541445e+04,
       -2.02779085e+02,  5.14309547e+02,  1.29533667e+03,  4.14477845e+02,
        1.42655785e+02,  

# Penalized Margin Calibration

In [39]:
# Same four calibrator configurations as in the unpenalized section, now built
# with penalty=0.1 and the per-margin `costs` defined in the dataset cell.
mc_pen = MarginCalibration(penalty=0.1, costs=costs)
mc_logit_pen = MarginCalibration("logit", 0.5, 1.5, penalty=0.1, costs=costs)
mc_rr_pen = MarginCalibration("raking_ratio", penalty=0.1, costs=costs)
mc_lt_pen = MarginCalibration("truncated_linear", 0.5, 1.5, penalty=0.1, costs=costs)

In [40]:
%%time
# Calibrate with the penalized `mc_pen` and display the `.x` attribute of the
# returned object (presumably the optimizer's solution vector — TODO confirm).
mc_pen_result = mc_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target)
mc_pen_result.x

CPU times: user 288 ms, sys: 3.94 ms, total: 292 ms
Wall time: 44.1 ms


array([ 7.55973101e+02,  4.47750834e+02, -5.29368963e+02, -6.57985392e+02,
        3.06746977e+01,  1.22786756e+02, -8.63986792e+02,  4.70610086e+02,
        1.77777998e+03,  1.53524566e+03, -7.24466602e+02, -1.92347618e+02,
       -2.09439192e+02, -9.56769190e+01,  1.44501763e+02,  2.74867180e+01,
       -7.91724424e+02,  2.40698635e+02,  8.32753350e+00,  9.01536287e+02,
        1.78632910e+03, -3.48031123e+02, -7.44126867e+02, -3.44732959e+02,
        6.91443814e+01,  2.70851318e+02,  8.50167778e+03, -2.06316941e+02,
       -1.71652909e+02, -9.95664086e+02,  8.48063519e+02, -5.44143762e+01,
       -1.26937134e+02,  6.69755409e+02, -1.15275537e+03, -8.69835274e+02,
       -5.17793312e+01, -6.54408027e+02,  2.01036460e+03, -3.04030850e+02,
       -4.25692595e+02,  1.10011136e+02, -6.36581546e+02, -4.80455601e+01,
       -1.61416110e+01,  1.87688155e+02,  2.44689891e+02,  2.04911322e+04,
       -5.70190100e+02,  1.03091570e+03,  1.44539679e+03,  1.13185874e+03,
       -4.98626993e+01,  

In [41]:
%%time
# Calibrate with the penalized `mc_logit_pen` and display the `.x` attribute of
# the returned object (presumably the optimizer's solution vector — TODO confirm).
mc_logit_pen_result = mc_logit_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target)
mc_logit_pen_result.x

CPU times: user 3.44 s, sys: 4.06 ms, total: 3.45 s
Wall time: 494 ms


array([  846.41124513,   580.34401223,   124.1242419 ,   125.63062522,
         214.24487101,   142.13115919,   125.70276675,   681.98181888,
        1112.33518013,  1026.54514349,   166.34106944,   385.04001228,
         648.7858291 ,   337.85720183,   382.00060676,   431.90459159,
         170.71899633,   421.62896074,   146.35515361,  1042.49495151,
        1181.55810245,   176.82032439,   191.64084249,   124.1406335 ,
         275.75159635,   178.00121451,  4435.6152767 ,   623.64539839,
         313.44617673,   193.06416747,   706.10903658,   282.16226543,
         431.97757421,   852.84571734,   164.10870059,   142.52758681,
         632.54805149,   344.53845334,  1178.08848795,   243.51304181,
         238.43779656,   212.73001   ,   155.33539632,   133.67571687,
         299.92738867,   270.28294783,   221.82924068, 10452.21313829,
         243.3169085 ,   765.32368802,   798.25100334,  1006.93599619,
         147.98174673,   508.18165433,   296.36615351,   122.93803724,
      

In [42]:
%%time
# Calibrate with the penalized `mc_rr_pen` and display the `.x` attribute of
# the returned object (presumably the optimizer's solution vector — TODO confirm).
mc_rr_pen_result = mc_rr_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target)
mc_rr_pen_result.x

CPU times: user 6.19 s, sys: 4.15 ms, total: 6.19 s
Wall time: 884 ms


array([7.71287982e-06, 3.97927616e-05, 3.39263967e-05, 2.90671974e-05,
       6.93549520e-05, 5.17606011e-05, 2.60813239e-05, 3.33367636e-05,
       6.17647489e-05, 6.02624564e-05, 2.85923343e-05, 3.06220373e-05,
       1.95545038e-05, 3.80163260e-05, 4.15467828e-05, 3.41604814e-05,
       2.64095666e-05, 4.24865578e-05, 1.07662552e-04, 7.10379794e-06,
       4.76525017e-05, 4.63394456e-05, 2.52913581e-05, 5.23204238e-05,
       5.80460749e-05, 3.19214397e-04, 7.37656621e+02, 1.96998717e-05,
       3.79283965e-05, 2.10026586e-05, 5.44685301e-05, 4.82473384e-05,
       3.09059787e-05, 8.00008283e-06, 1.98768919e-05, 2.42644017e-05,
       2.19701647e-05, 2.04917794e-05, 3.48032425e-05, 4.04309363e-05,
       3.15804155e-05, 9.85190913e-05, 2.82569747e-05, 9.81624201e-05,
       4.59556983e-05, 5.65579385e-05, 1.50826119e-04, 1.09499782e+04,
       2.58912314e-05, 5.51822656e-05, 6.90331317e-06, 3.51294001e-05,
       8.96278758e-05, 4.82354109e-05, 3.73678383e-05, 2.75220730e-05,
      

In [43]:
%%time
# Calibrate with the penalized `mc_lt_pen` and display the `.x` attribute of
# the returned object (presumably the optimizer's solution vector — TODO confirm).
# NOTE(review): the visible part of the recorded output below is identical to
# the recorded output of the `mc_logit_pen` cell — confirm this is expected
# and not two configurations resolving to the same method.
mc_lt_pen_result = mc_lt_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target)
mc_lt_pen_result.x

CPU times: user 3.05 s, sys: 105 μs, total: 3.05 s
Wall time: 433 ms


array([  846.41124513,   580.34401223,   124.1242419 ,   125.63062522,
         214.24487101,   142.13115919,   125.70276675,   681.98181888,
        1112.33518013,  1026.54514349,   166.34106944,   385.04001228,
         648.7858291 ,   337.85720183,   382.00060676,   431.90459159,
         170.71899633,   421.62896074,   146.35515361,  1042.49495151,
        1181.55810245,   176.82032439,   191.64084249,   124.1406335 ,
         275.75159635,   178.00121451,  4435.6152767 ,   623.64539839,
         313.44617673,   193.06416747,   706.10903658,   282.16226543,
         431.97757421,   852.84571734,   164.10870059,   142.52758681,
         632.54805149,   344.53845334,  1178.08848795,   243.51304181,
         238.43779656,   212.73001   ,   155.33539632,   133.67571687,
         299.92738867,   270.28294783,   221.82924068, 10452.21313829,
         243.3169085 ,   765.32368802,   798.25100334,  1006.93599619,
         147.98174673,   508.18165433,   296.36615351,   122.93803724,
      