In [77]:
# Automatically reload imported modules before each cell runs, so that edits
# to the project code under ../src (e.g. margin_calibration) are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [78]:
import sys
import warnings
from pathlib import Path

import numpy as np
import pandas as pd

# Make the project's ../src directory importable. It is resolved to an
# absolute path so the entry stays valid if the working directory changes,
# and it is added only once so re-running this cell does not keep appending
# duplicates to sys.path.
src_dir = str(Path("../src").resolve())
if src_dir not in sys.path:
    sys.path.append(src_dir)
from margin_calibration import MarginCalibration

In [79]:
warnings.simplefilter("ignore")

# Dataset Generation

In [80]:
# Seed the generator so that re-running the notebook reproduces the same
# synthetic dataset.
SEED = 42
rng = np.random.default_rng(SEED)

# Create 100 random observations with their respective sampling
# probabilities, rescaled so that they sum to 0.2 (20 %).
n_obs = 100
sampling_probabilities = rng.random(n_obs)
sampling_probabilities = sampling_probabilities / sampling_probabilities.sum() * 0.2

# Create a matrix of size n_obs x n_margins with values drawn uniformly
# in [0, 1000).
n_margins = 2  # Let's say we have two variables
calibration_matrix = 1000 * rng.random((n_obs, n_margins))

# Calibration target, one value per margin: the total of each margin over the
# whole population, simulated here by scaling the sample column sums.
TARGET_SCALE = 100
calibration_target = calibration_matrix.sum(axis=0) * TARGET_SCALE

# One cost per margin; here every variable gets the same cost.
costs = (1,) * n_margins

# Let's say we want to work with pandas DataFrames instead of numpy arrays.
# Column names are derived from n_margins so changing it does not break them.
sampling_probabilities = pd.DataFrame(sampling_probabilities, columns=["weights"])
calibration_matrix = pd.DataFrame(
    calibration_matrix,
    columns=[f"margin{i}" for i in range(1, n_margins + 1)],
)
calibration_target = pd.DataFrame(calibration_target, columns=["margin_sums"])

# Margin Calibration

In [84]:
# Extra positional arguments shared by the two bounded methods; presumably the
# (lower, upper) bounds of the calibration — TODO confirm against
# MarginCalibration's signature.
LOWER_BOUND, UPPER_BOUND = 0.5, 1.5

mc = MarginCalibration()  # library default method
mc_logit = MarginCalibration("logit", LOWER_BOUND, UPPER_BOUND)
mc_rr = MarginCalibration("raking_ratio")
mc_lt = MarginCalibration("truncated_linear", LOWER_BOUND, UPPER_BOUND)

In [82]:
%%time
mc.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 16.3 s, sys: 16.4 ms, total: 16.4 s
Wall time: 2.36 s


array([ 1.46538412e+02,  2.44840497e+01, -2.27143360e+02,  3.37677433e+02,
       -1.68444216e+01,  5.63995911e+01, -5.47284279e+01,  3.08495608e+02,
        3.54428940e+02,  2.22505460e+02,  1.43553669e+02, -6.78948529e+01,
        3.80153957e+02, -1.00475805e+02, -1.36923193e+02,  3.32614259e+02,
        3.71267545e+01,  9.09656289e+01, -3.54908153e+02,  1.38891971e+02,
        3.59528374e+00,  7.67050968e+01,  6.50140772e+01,  2.57844732e+02,
        3.26935789e+02,  5.28916874e+01,  5.74176837e+01, -5.75228318e+02,
        1.06446528e+02,  1.66390469e+02, -1.00338639e+02, -2.63550660e+01,
       -5.21683481e+00, -5.55190798e+01,  3.49629914e+01,  3.40342075e+02,
        1.57197528e+03, -1.63711861e+01,  8.01922631e+01, -1.24919533e+01,
        4.32402910e+02,  5.15360786e+02,  1.10167032e+02,  4.99572615e+02,
       -8.41231738e+01,  1.24088108e+02,  1.82120573e+03,  1.20207836e+02,
        2.80012960e+02,  1.36951460e+02, -1.18184465e+02,  6.96246037e+00,
        1.36410124e+02,  

In [51]:
%%time
mc_logit.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 7h 57min 27s, sys: 3min 11s, total: 8h 39s
Wall time: 31min 18s


array([ 4.91190636e+03, -3.74358928e+02,  1.67848488e+03,  2.99720746e+04,
       -8.27062289e+02,  1.17437908e+02,  2.12617212e+02,  1.72016302e+02,
        3.94493665e+02,  4.06994883e+02,  1.15840960e+02, -2.36344722e+03,
       -1.10726556e+03, -1.89942445e+03,  3.48288143e+02, -1.60931263e+03,
       -1.83742318e+02,  3.73759538e+02,  7.24158749e+02,  2.97092985e+02,
        2.52821583e+02,  2.97130319e+02, -9.80696220e+02,  3.13648492e+02,
        4.38680160e+02,  2.28862448e+02, -1.69982223e+03, -2.44735193e+03,
        3.40823026e+02,  2.22686086e+02,  9.36811585e+01, -1.84186868e+03,
       -2.58520025e+03,  2.28507485e+02, -7.88093279e+02,  2.79165003e+02,
       -1.54695065e+03,  1.86142021e+02, -3.80202500e+02, -1.92393439e+02,
       -8.20903606e+02, -1.39077675e+03, -2.98881300e+02,  2.78415062e+03,
        1.76267131e+02,  2.93549601e+02, -2.13170688e+03,  3.56186678e+02,
        1.58058421e+02,  1.57532497e+02,  1.59371481e+02,  4.60886094e+02,
        4.79517955e+02, -

In [85]:
%%time
mc_rr.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

ValueError: array must not contain infs or NaNs

In [None]:
%%time
mc_lt.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 15h 47min 40s, sys: 50.7 s, total: 15h 48min 31s
Wall time: 1h 18s


array([ 6.80730765e+03, -4.18204837e+02,  8.11482024e+02,  3.03671473e+04,
       -1.19086009e+03, -8.77316733e+02, -3.95039542e+02, -2.52948843e+02,
        5.90013121e+02,  1.17874564e+02, -6.78324983e+02, -1.09449555e+03,
       -5.71132142e+02, -5.54867365e+02, -3.28391263e+02, -1.09238972e+03,
       -8.34336356e+02, -3.19801908e+02, -2.12283057e+02,  2.22348989e+02,
        4.00663020e+02,  1.13384786e+02, -1.14391312e+03,  5.83799009e+02,
        3.45088480e+01, -9.12692121e+01, -5.76682186e+02, -5.90627508e+02,
        1.97084269e+02, -4.16988728e+02, -1.97589720e+02, -6.69607271e+02,
       -7.72179308e+02,  6.33674105e+02, -1.02778059e+03,  4.24139106e+02,
       -5.45921697e+02, -6.70816719e+01, -5.71783600e+02, -1.58422790e+02,
       -7.22773598e+02, -1.00638688e+03, -1.30494909e+03,  4.80863593e+03,
       -3.10855902e+02, -2.60608328e+02, -9.58916922e+02, -2.80738477e+02,
        1.05445288e+02, -3.62865664e+02, -3.88468826e+02, -3.20801872e+02,
       -2.75232095e+01, -

# Penalized Margin Calibration

In [None]:
# The same four methods, each with a penalty of 0.1 and the per-margin costs
# from the dataset cell. The (.5, 1.5) arguments match the unpenalized logit
# and truncated_linear calibrators.
penalty_kwargs = {"penalty": .1, "costs": costs}

mc_pen = MarginCalibration(**penalty_kwargs)
mc_logit_pen = MarginCalibration("logit", .5, 1.5, **penalty_kwargs)
mc_rr_pen = MarginCalibration("raking_ratio", **penalty_kwargs)
mc_lt_pen = MarginCalibration("truncated_linear", .5, 1.5, **penalty_kwargs)

In [None]:
%%time
mc_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 2min 25s, sys: 170 ms, total: 2min 25s
Wall time: 1min 8s


array([ 8.74921073e+03,  2.35664135e+02,  7.53999287e+01,  3.29623027e+04,
       -4.32340219e+02, -1.02892193e+03, -9.09114259e+02, -1.73700063e+02,
        6.82001754e+02, -3.16844208e+02, -3.08586205e+03, -1.11812696e+03,
       -1.40310306e+03, -9.75123230e+02, -9.27518944e+02, -1.73472149e+02,
        2.34213809e+02, -9.63868221e+02, -1.16609860e+03,  2.19716554e+02,
        3.73281171e+02,  1.40162515e+02, -2.33421815e+02,  7.61804318e+02,
        5.16475245e+01, -3.62506694e+02, -1.42782672e+03, -1.81415169e+03,
        2.29878837e+01, -9.52666108e+02, -2.44145381e+02, -2.19417562e+03,
       -1.95458518e+03,  2.02399644e+02, -2.36414539e+02,  5.45398804e+02,
       -1.99159695e+03, -6.16484693e+01,  3.92046027e+02, -1.16722482e+02,
       -1.81819697e+03,  4.09757397e+02,  2.73863924e+02,  4.31430902e+03,
       -1.19837397e+03, -1.16854575e+03, -1.08147611e+03, -1.23596337e+03,
        1.99006028e+01, -3.90786992e+02, -6.82447176e+01, -1.04059340e+03,
       -2.56733755e+02, -

In [None]:
%%time
mc_logit_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 27.3 s, sys: 37.7 ms, total: 27.3 s
Wall time: 19.1 s


array([ 4730.13743753,   126.3159581 ,  1694.90014335, 30366.2316765 ,
         178.61299218,   225.00647065,   252.11591023,   198.66538876,
         328.52641718,   406.9842149 ,   943.25909324,   238.54110517,
         286.25653019,   286.46154002,   378.51799534,   173.96541391,
         141.09033139,   419.92746326,   742.38071838,   296.92077387,
         251.15104302,   296.93881718,   176.95264691,   297.55636233,
         438.5765216 ,   228.69316591,   314.09334209,   291.053651  ,
         340.67607427,   253.35699849,   194.40083657,   453.08769289,
         276.50999776,   210.19925402,   184.64539791,   256.70914464,
         334.03595586,   186.00159935,   144.55824709,   147.56897117,
         932.15020764,   150.1464805 ,   134.71524059,  1935.76202912,
         307.26093522,   299.15393329,   240.7412601 ,   379.3686566 ,
         157.80958771,   185.422735  ,   159.14549008,   507.03659444,
         479.59936621,   219.0241214 ,   159.28870286,   244.47344115,
      

In [None]:
%%time
mc_rr_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 4min 55s, sys: 196 ms, total: 4min 56s
Wall time: 2min 33s


array([ 6.43249228e+03,  4.06414647e+03, -8.02219790e+02,  1.96648781e+04,
       -1.18198960e+03, -5.12103302e+02,  1.14552327e+01,  4.52122393e+02,
        1.63367934e+03,  3.34518265e+03, -6.32726897e+02, -2.11262350e+03,
       -1.13171780e+03, -1.78065942e+03, -5.81045275e+02, -7.62337167e+02,
        1.76488226e+03, -5.39607894e+00,  1.43082293e+01, -4.46612279e+02,
        9.75739445e+00,  3.20289560e+01, -1.66622185e+02,  4.57258478e+02,
        1.76485755e+03,  1.90154765e+03, -1.98937633e+02, -2.76140609e+03,
       -1.20099617e+00,  1.70066079e+00,  1.60615652e+03, -5.60209300e+02,
       -1.79945599e+03,  5.24713271e+02, -1.00178019e+03,  5.61842324e+02,
       -1.20074002e+03,  2.72315075e+02, -2.46888450e+01,  6.01715865e+03,
       -3.26568849e+03, -1.33153035e+03, -9.59846100e+02,  3.50277308e+03,
       -6.02988966e+02,  1.82212136e+03, -1.42035104e+03, -4.20305554e+02,
        4.03572311e+01, -3.16668079e+00,  2.25281083e+03, -7.94640347e+01,
       -1.06624269e+03, -

In [None]:
%%time
mc_lt_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 10.7 s, sys: 18.9 ms, total: 10.7 s
Wall time: 5.2 s


array([ 4730.13743753,   126.3159581 ,  1694.90014335, 30366.2316765 ,
         178.61299218,   225.00647065,   252.11591023,   198.66538876,
         328.52641718,   406.9842149 ,   943.25909324,   238.54110517,
         286.25653019,   286.46154002,   378.51799534,   173.96541391,
         141.09033139,   419.92746326,   742.38071838,   296.92077387,
         251.15104302,   296.93881718,   176.95264691,   297.55636233,
         438.5765216 ,   228.69316591,   314.09334209,   291.053651  ,
         340.67607427,   253.35699849,   194.40083657,   453.08769289,
         276.50999776,   210.19925402,   184.64539791,   256.70914464,
         334.03595586,   186.00159935,   144.55824709,   147.56897117,
         932.15020764,   150.1464805 ,   134.71524059,  1935.76202912,
         307.26093522,   299.15393329,   240.7412601 ,   379.3686566 ,
         157.80958771,   185.422735  ,   159.14549008,   507.03659444,
         479.59936621,   219.0241214 ,   159.28870286,   244.47344115,
      