In [104]:
# Auto-reload edited modules (such as margin_calibration from ../src) before each
# cell executes, so changes to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [105]:
import numpy as np
import pandas as pd
import warnings
import sys

sys.path.append("../src")
from margin_calibration import MarginCalibration

In [106]:
warnings.simplefilter("ignore")

# Dataset Generation

In [107]:
# Fix the seed so that Restart & Run All reproduces the same dataset (and outputs).
SEED = 42
rng = np.random.default_rng(SEED)

# Create 100 random observations with their respective sampling probabilities,
# rescaled so that those probabilities sum to SAMPLING_FRACTION (20 %).
n_obs = 100
SAMPLING_FRACTION = .2
sampling_probabilities = rng.random(n_obs)
sampling_probabilities = sampling_probabilities / sampling_probabilities.sum() * SAMPLING_FRACTION

# Create a matrix of size n_obs * n_margins, with values drawn uniformly in [0, 1000)
n_margins = 2  # Let's say we have two variables
calibration_matrix = 1000 * rng.random((n_obs, n_margins))

# Now we create the calibration target: one entry per margin, holding the
# column sums of the sample scaled up by TARGET_SCALE.
# NOTE(review): the factor presumably mimics population totals — confirm its value.
TARGET_SCALE = 100
calibration_target = calibration_matrix.sum(axis=0) * TARGET_SCALE

# One cost per margin, derived from n_margins so it stays consistent if that changes.
# Every variable gets the same cost here.
costs = (1,) * n_margins

# Margin Calibration

In [111]:
# Build one unpenalized calibrator per method. The two numbers passed after the
# method name are presumably the lower/upper bounds on the weight ratios
# (TODO confirm against the MarginCalibration signature); they are shared by
# the bounded methods, so they are named once here.
lower_bound, upper_bound = .5, 1.5
mc = MarginCalibration()  # MarginCalibration with its default settings
mc_logit = MarginCalibration("logit", lower_bound, upper_bound)
mc_rr = MarginCalibration("raking_ratio")
mc_lt = MarginCalibration("truncated_linear", lower_bound, upper_bound)

In [112]:
%%time
mc.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 1.4 s, sys: 3.99 ms, total: 1.4 s
Wall time: 206 ms


array([ 6.41993384e+02, -6.26411876e+01,  6.29861122e+01,  9.04233130e+01,
        1.38880822e+02,  1.30876907e+02,  1.66642258e+02, -6.82007061e+01,
        6.72556071e+01,  9.02531876e+01, -2.93658653e+01,  2.47540207e+02,
       -2.29235151e+00,  1.23162780e+01,  4.70909172e+01,  1.94769341e+02,
        1.48597255e+02,  2.70992960e+03, -1.66489334e+01,  1.23109464e+02,
       -3.53442141e+01, -3.02547074e+02,  1.92399334e+02,  1.72954844e+02,
        9.84236657e+02,  6.72711492e+02,  2.89791050e+01, -9.08261903e+00,
        3.44584961e+02,  8.39170668e+00,  4.81061807e+02, -3.05862089e+02,
        4.76139920e+01, -3.66502695e+01,  3.05543549e+02,  1.76347570e+02,
        5.34485465e+02, -1.44834218e+00, -1.59841552e+02, -1.27166061e+02,
        6.46487338e+01, -3.55393195e+02,  6.99995218e+01, -1.04608590e+03,
       -3.68531293e+02,  8.05957120e+01,  5.08287478e+02, -2.07646677e+02,
        9.96291241e+00,  2.91250487e+02,  3.90439131e+01,  2.12951503e+02,
        1.47067283e+02,  

In [113]:
%%time
mc_logit.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 1h 56min 3s, sys: 3.64 s, total: 1h 56min 6s
Wall time: 3min 52s


array([  682.54946855,  -733.37695001,  -152.80896794,  -495.27711301,
          47.56328237,  -323.91692863,    52.66145658, -1285.92776754,
        -160.33727359,   213.62910182,  -536.28069437,  -236.48018788,
        -323.36554819,  -818.91189128,  -661.48149874,   256.31565643,
        -264.94134165, 18789.48441823,  -821.08312918,  -168.95061671,
       -1123.67524075,   626.57597723,   762.20439626,   155.80162486,
        1335.48114788,  1869.34350174,  -203.22363837, -1048.39582431,
         175.72540171,  -285.93856641,   412.32817853,  3535.19031689,
        -329.90609515,  -841.47432283,   -26.72785974,   -98.52874271,
        1354.75327251,  -824.85745865, -1154.24964735, -1288.22930918,
         -41.94934404,  -607.89079691,   349.4302462 ,  1042.9736612 ,
        5272.74445884,  -163.88831399,  1403.62622408, -1283.44601342,
         709.81978825,   498.10727035,  -514.81306692,   -40.2656144 ,
        -181.24478925,  1418.39853492, -1097.59377471,  -106.04287365,
      

In [114]:
%%time
mc_rr.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 31min 35s, sys: 1.1 s, total: 31min 36s
Wall time: 1min 3s


array([ 214.26287207,   24.27600319,   31.16406382,   29.48799086,
         60.18863029,   41.02290592,   71.35509641,    5.98789362,
         35.79043153,   52.31751776,   28.74658774,  100.54208239,
         37.78136139,   14.03177503,   17.94160365,   87.45108986,
         63.13443553, 1879.01584606,   10.4733563 ,   51.85373749,
          7.24571442,  130.02527344,  109.22215281,   89.38928005,
        468.37019151,  391.43113674,   24.14071308,   10.01975354,
        125.65705581,   20.5450812 ,  316.36661407,  316.37764822,
         23.58636319,    8.68944889,  169.6046775 ,   90.75888184,
        350.85527905,   11.78196675,   16.69091495,    7.16899152,
         31.07243426,   42.62687274,   78.42398323,   87.33804233,
        321.9892762 ,   40.36318402,  197.5471232 ,    6.16670039,
        117.21886478,  151.94919993,   23.6501358 ,   81.82294579,
         50.08612881,  874.75632674,    7.83197373,   43.10581016,
         17.31441782,    9.80998054,  288.87248322,   64.01307

In [115]:
%%time
mc_lt.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 1h 23min 9s, sys: 4.85 s, total: 1h 23min 14s
Wall time: 2min 46s


array([ 2.08586817e+01, -4.49273388e+02, -1.07029486e+02, -4.41258681e+02,
       -3.46187783e+02, -4.16877909e+02, -1.82821895e+02, -7.64176330e+02,
       -1.92534901e+02,  1.43628491e+02, -4.72274759e+02, -1.72668666e+02,
       -4.19313385e+02, -5.74741994e+02, -8.45597263e+02, -8.53046686e+00,
       -7.97038673e+01,  1.89133081e+04, -5.53061084e+02, -3.46991158e+01,
       -8.16960466e+02,  8.70442460e+02,  2.82660208e+02,  1.11261986e+02,
        1.33026964e+03,  1.23828076e+03, -1.72483484e+02, -7.79714471e+02,
       -1.21441689e+02, -1.78915786e+02,  4.53272684e+02,  3.95141901e+03,
       -2.52888669e+02, -4.91226771e+02,  8.16278972e+01,  5.96795822e+01,
        3.71176732e+02, -6.83327431e+02, -6.65956708e+02, -7.69682136e+02,
       -1.42226075e+02, -6.40629508e+02, -4.93288133e+02,  1.04903117e+03,
        5.51787356e+03, -2.16925474e+02,  7.27128037e+02, -6.36007391e+02,
        1.97839029e+02,  3.30279950e+02, -3.42079515e+02, -1.39158594e+02,
       -3.27028448e+02,  

# Penalized Margin Calibration

In [127]:
# Build the penalized variant of each calibrator. All four share the same
# penalty weight and per-margin costs; presumably the penalty relaxes exact
# margin matching, with `costs` weighting each margin (TODO confirm).
# The bounded methods reuse the same (.5, 1.5) bounds as the unpenalized ones.
penalty_kwargs = dict(penalty=.1, costs=costs)
lower_bound, upper_bound = .5, 1.5
mc_pen = MarginCalibration(**penalty_kwargs)
mc_logit_pen = MarginCalibration("logit", lower_bound, upper_bound, **penalty_kwargs)
mc_rr_pen = MarginCalibration("raking_ratio", **penalty_kwargs)
mc_lt_pen = MarginCalibration("truncated_linear", lower_bound, upper_bound, **penalty_kwargs)

In [128]:
%%time
mc_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 21.5 ms, sys: 6 μs, total: 21.5 ms
Wall time: 20.5 ms


array([-2.01058342e+02, -1.32154687e+03, -2.48694878e+02, -6.11241431e+02,
       -1.12795179e+03, -5.38575598e+02, -6.81466944e+02, -5.57309921e+02,
       -5.17295540e+02,  5.98718406e+01, -1.41867714e+03, -4.95033089e+02,
       -1.19866417e+03, -4.97594929e+02, -1.65677517e+03, -2.20981814e+02,
       -7.68499696e+01,  3.46644572e+04, -3.36959723e+02, -1.07167525e+02,
       -4.03233120e+02,  8.33863254e+02,  2.63435053e+02,  6.16863216e+01,
        2.13703318e+03,  1.45336714e+03, -3.78031708e+02, -5.60546365e+02,
       -6.29681339e+02, -3.76777463e+02,  4.05192924e+02,  4.92734403e+03,
       -3.34002388e+02, -2.25234310e+02, -1.03779101e+02,  3.83827772e+01,
        2.72304313e+02, -6.17101551e+02, -1.89109905e+03, -1.70714219e+03,
       -3.20067795e+02, -1.58484655e+03, -1.21933079e+03,  7.94773128e+02,
        6.73837010e+03, -6.86796663e+02,  3.68021065e+02, -1.93475556e+03,
        1.28838171e+01,  3.95250262e+02, -6.84010683e+02, -5.06501799e+02,
       -4.02004704e+02,  

In [129]:
%%time
mc_logit_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 596 ms, sys: 11.7 ms, total: 608 ms
Wall time: 604 ms


array([ 1018.7949904 ,   338.35844878,   160.00629474,   161.03319896,
         425.61288589,   157.80723684,   373.26137727,   138.83191405,
         212.14050834,   196.06288646,   364.15400587,   200.76478558,
         405.7576084 ,   143.69820745,   210.26928743,   301.3289038 ,
         138.09619639, 18913.27169843,   133.93433377,   147.85114548,
         130.26115985,  1771.49009669,   192.82841705,   152.59743591,
        1557.05656896,   765.66594436,   170.76719754,   139.4089952 ,
         460.06359786,   168.63211718,   328.21985306,  3867.31039226,
         149.47366841,   128.59834853,   208.51135901,   139.14654431,
         341.42746523,   146.24614851,   375.20294763,   218.45425073,
         165.34277423,   955.79921606,   911.54138272,  2087.36495761,
        5471.60267102,   248.25999089,  1837.26315849,   252.680389  ,
        1220.57543907,   497.1922944 ,   184.06022057,   222.03612175,
         152.96698265,  1686.11143517,   142.93572669,   184.49883891,
      

In [130]:
%%time
mc_rr_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 851 ms, sys: 3.95 ms, total: 855 ms
Wall time: 853 ms


array([3.12191099e-05, 2.00550214e-05, 2.78359078e-05, 2.91609289e-05,
       2.57386341e-05, 3.50307156e-05, 2.83485449e-05, 1.52590256e-05,
       2.64448601e-05, 3.04411661e-05, 2.08781953e-05, 5.17144734e-05,
       2.18302332e-05, 2.24285424e-05, 2.28562407e-05, 3.34194471e-05,
       4.63570800e-05, 4.62450963e+03, 2.05139153e-05, 3.78651234e-05,
       1.88883120e-05, 2.05904627e-05, 4.76606789e-05, 5.23134807e-05,
       3.42439481e-05, 4.37855953e-05, 2.44389894e-05, 2.05752383e-05,
       3.43681971e-05, 2.30190993e-05, 8.76089529e-05, 2.14796871e-05,
       2.60816749e-05, 1.93736127e-05, 7.80636829e-05, 5.97525228e-05,
       9.87404383e-05, 2.12265353e-05, 1.78131524e-05, 1.65652665e-05,
       2.76229597e-05, 1.78183531e-05, 2.23625476e-05, 1.77570687e-05,
       2.08057430e-05, 2.63532444e-05, 2.46839890e-05, 1.53014974e-05,
       2.21053533e-05, 3.31771392e-05, 2.43606874e-05, 4.08635128e-05,
       3.95095623e-05, 5.17583630e-05, 1.85043214e-05, 3.09044234e-05,
      

In [132]:
%%time
mc_lt_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 560 ms, sys: 3.94 ms, total: 564 ms
Wall time: 561 ms


array([ 1018.7949904 ,   338.35844878,   160.00629474,   161.03319896,
         425.61288589,   157.80723684,   373.26137727,   138.83191405,
         212.14050834,   196.06288646,   364.15400587,   200.76478558,
         405.7576084 ,   143.69820745,   210.26928743,   301.3289038 ,
         138.09619639, 18913.27169843,   133.93433377,   147.85114548,
         130.26115985,  1771.49009669,   192.82841705,   152.59743591,
        1557.05656896,   765.66594436,   170.76719754,   139.4089952 ,
         460.06359786,   168.63211718,   328.21985306,  3867.31039226,
         149.47366841,   128.59834853,   208.51135901,   139.14654431,
         341.42746523,   146.24614851,   375.20294763,   218.45425073,
         165.34277423,   955.79921606,   911.54138272,  2087.36495761,
        5471.60267102,   248.25999089,  1837.26315849,   252.680389  ,
        1220.57543907,   497.1922944 ,   184.06022057,   222.03612175,
         152.96698265,  1686.11143517,   142.93572669,   184.49883891,
      