In [83]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [84]:
import numpy as np
import pandas as pd
import warnings
import sys

sys.path.append("../src")
from margin_calibration import MarginCalibration

In [85]:
# Silence every warning for the rest of the session. Note that this is a
# blanket filter: warnings raised during the calibration runs are hidden too.
warnings.simplefilter(action="ignore")

# Dataset Generation

In [86]:
# Fix the seed so that "Restart Kernel -> Run All" regenerates exactly the
# same synthetic dataset (and therefore comparable calibration results).
SEED = 42
rng = np.random.default_rng(SEED)

# Draw one random sampling probability per observation, then rescale them
# so that they sum to 0.2 (i.e. 20 %).
n_obs = 100
sampling_probabilities = rng.random(n_obs)
sampling_probabilities = sampling_probabilities / sampling_probabilities.sum() * 0.2

# Calibration matrix of shape (n_obs, n_margins): one column per margin
# variable, with values drawn uniformly from [0, 1000).
n_margins = 2  # let's say we have two variables
calibration_matrix = 1000 * rng.random((n_obs, n_margins))

# Calibration target: one total per margin (length n_margins). It stands in
# for the population totals and is taken here as 100 times the column sums
# of the calibration matrix.
calibration_target = calibration_matrix.sum(axis=0) * 100

# One cost per margin; here every margin gets the same cost.
costs = (1,) * n_margins

# Margin Calibration

In [87]:
# One unpenalized calibrator per method under comparison.
# NOTE(review): 0.5 and 1.5 look like lower/upper bounds for the bounded
# methods -- confirm against the MarginCalibration signature.
mc = MarginCalibration()  # default arguments
mc_logit = MarginCalibration("logit", 0.5, 1.5)
mc_rr = MarginCalibration("raking_ratio")
mc_lt = MarginCalibration("truncated_linear", 0.5, 1.5)

In [88]:
%%time
mc.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 27.3 s, sys: 173 ms, total: 27.5 s
Wall time: 17.9 s


array([-9.88407057e+01,  8.65676777e+01,  6.84494297e+02, -8.02749095e+01,
        1.13362526e+02, -2.05318967e+02, -1.10744576e+02,  5.24532205e+02,
        2.15043556e+02, -9.46270213e+01,  4.43366095e+02,  6.58744384e+01,
       -6.63548539e+01,  3.15863266e+02,  2.84206483e+02, -2.38275285e+01,
        5.22459599e+01, -6.47286621e+01, -1.24396622e+02, -2.29585762e+02,
       -2.64969843e+02,  2.80545117e+02,  3.56421840e+02,  2.07093707e+02,
        2.00939275e+02,  1.08154313e+02,  1.90849260e+02,  5.40685887e+03,
        5.95198100e+00,  1.66950049e+01,  1.62960203e+01, -2.01923755e+02,
        1.97796563e+02,  1.26956847e+02,  6.12037607e+01, -4.81428666e+00,
       -1.14616196e+02,  1.85676869e+02, -6.73737615e+01,  1.00755848e+02,
       -2.88637907e+02, -2.66238303e+01,  3.36289420e+02,  1.02697070e+03,
        1.05558326e+02,  2.74124370e+02,  1.80920550e+02,  6.55185116e+02,
        1.11236988e+02,  9.04300302e+03,  2.16956919e+02,  1.45866747e+02,
        1.15476239e+03, -

In [89]:
%%time
mc_logit.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 8h 47min 36s, sys: 3min 33s, total: 8h 51min 10s
Wall time: 33min 17s


array([ 1.59579727e+02,  5.41619696e+02,  2.71946366e+02, -8.12477364e+02,
        4.06449769e+02,  2.24884865e+02, -2.28627916e+02,  4.16813236e+02,
        3.56076963e+01, -2.52597452e+02,  2.94727732e+02, -6.20698617e+01,
       -2.56531753e+02,  7.43373156e+02,  6.62221529e+02, -4.17218885e+02,
       -9.27099911e+01, -1.39819448e+02, -8.05849829e+02, -7.23126462e+02,
       -7.63588696e+02,  1.95917211e+02,  2.48282050e+02,  1.23630660e+02,
        1.66056559e+02,  5.05805558e+00,  2.22314467e+02,  4.43853014e+03,
        1.06342648e+02, -4.85651956e+01, -2.10480900e+02, -8.20842921e+02,
       -5.10207794e+01, -1.47084001e+02, -3.21954577e+01, -3.65302355e+02,
        1.08512659e+02,  6.01950638e+01, -5.19500509e+02, -2.02349795e+00,
       -4.45950731e+02, -4.32184642e+02,  1.78602923e+02,  6.22526748e+02,
        9.08864293e+01,  1.44391032e+02,  1.81580051e+02,  8.25750565e+02,
        1.26715305e+02,  1.97096549e+04,  9.15521299e+00,  5.41137856e+01,
        1.42824973e+03, -

In [90]:
%%time
mc_rr.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 5min 20s, sys: 1.14 s, total: 5min 21s
Wall time: 3min 24s


array([ 8.42811028e-01, -3.15573161e-01,  4.51178094e+01, -3.30933536e+00,
       -2.02598051e-01, -1.12196119e+02, -1.93178963e+01,  3.74863176e+00,
        3.20428298e+00,  1.16238968e+00,  2.61550031e+00,  6.75848902e-01,
       -2.85651460e+00,  4.36372572e+02,  3.33666410e+02, -1.15666231e+00,
       -2.59292024e-01, -2.31410399e+00, -4.96984808e+02, -4.39365515e+02,
       -1.67136704e+03,  2.16747439e+02, -1.77254537e+00,  7.80076342e+00,
        4.30417883e+01,  3.18081216e+00,  1.87784849e+00,  3.05980811e+03,
        1.02001945e-02,  8.96696036e-01, -7.83903807e-01, -8.92462385e+02,
        8.65816733e+00,  1.98119340e+00,  3.56561214e+00,  1.85647506e-01,
        9.20169979e-01,  5.70287369e+00, -2.09079562e+02,  4.10384081e-01,
        5.00218868e-01, -1.14461514e-01,  1.05236133e+00,  4.96701468e+02,
        2.56589529e+00,  6.67387404e+00,  8.89269833e+01, -1.94743806e-01,
        5.43449847e-01,  2.69894372e+04,  2.35227563e+01,  4.49948644e-01,
        1.78154367e+00, -

In [None]:
%%time
mc_lt.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 16h 25min 59s, sys: 47.5 s, total: 16h 26min 46s
Wall time: 1h


array([-2.39550968e+02, -3.12849070e+01,  3.45850361e+02, -4.72431144e+02,
        8.05146043e+01, -3.40519732e+02, -2.93461264e+02,  2.68965034e+02,
        3.94412330e-01, -2.68325775e+02,  1.92660244e+02, -2.24652111e+02,
       -2.62098590e+02,  6.80992377e+02,  6.05619130e+02, -4.21402030e+02,
       -1.69392417e+02, -2.30447159e+02, -4.95235395e+02, -3.81964865e+02,
       -4.14765574e+02,  2.74118890e+02,  1.09446731e+02,  8.41735618e+01,
        1.88146339e+02, -2.39047455e+01, -1.33173335e+01,  4.62237147e+03,
       -1.86926491e+02, -1.59415617e+02, -1.91443391e+02, -3.67746491e+02,
        2.62921031e+01, -1.18656196e+02, -8.15514990e+01, -3.57378004e+02,
       -2.80084037e+02,  3.16084679e+01, -4.33706971e+02, -5.15943160e+01,
       -3.50748397e+02, -3.80305974e+02,  9.38623365e+01,  9.27413069e+02,
       -6.99986083e+01,  1.13022870e+02,  2.10353580e+02,  4.56522763e+02,
       -8.54255351e+01,  1.97379693e+04,  5.47835128e+01, -3.61797215e+01,
        1.13377119e+03, -

# Penalized Margin Calibration

In [None]:
# Penalized counterparts of the four calibrators: same methods and positional
# arguments, plus a penalty of 0.1 and the per-margin `costs`.
mc_pen = MarginCalibration(penalty=0.1, costs=costs)
mc_logit_pen = MarginCalibration("logit", 0.5, 1.5, penalty=0.1, costs=costs)
mc_rr_pen = MarginCalibration("raking_ratio", penalty=0.1, costs=costs)
mc_lt_pen = MarginCalibration("truncated_linear", 0.5, 1.5, penalty=0.1, costs=costs)

In [None]:
%%time
mc_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 1min 53s, sys: 141 ms, total: 1min 53s
Wall time: 16.3 s


array([-4.10499544e+02, -7.81233389e+02,  4.34438148e+02, -7.08219895e+01,
       -1.36403601e+02, -1.09667235e+03, -5.85888792e+02,  4.91255758e+01,
        2.00025345e+02, -2.62862804e+02,  4.68519657e+01, -8.58608419e+01,
       -2.87980684e+02,  2.23328318e+02,  3.16646129e+02, -1.52356361e+02,
       -1.27256822e+02, -2.25808926e+02, -9.59553648e+01, -1.95135169e+02,
       -1.49458992e+02,  4.54384439e+02, -3.45444957e+02,  2.93857331e+02,
        7.46933860e+00, -1.52546249e+02, -1.30765423e+02,  7.32717255e+03,
       -6.12757840e+02, -1.29183411e+02, -5.40856966e+01, -3.52975396e+02,
        5.03151143e+02,  2.25485416e+02,  2.27589632e+02, -1.15528740e+02,
       -7.03222155e+02, -3.42745341e+02, -1.30934364e+02,  2.06021641e+02,
       -8.70462173e+02, -1.04513783e+02, -3.13011814e+02,  1.03583335e+03,
       -3.61232769e+02,  2.38593707e+02,  2.58650218e+02,  2.18291687e+02,
       -5.53942998e+02,  2.74867268e+04,  3.14956095e+02,  1.48928462e+02,
        8.59723904e+02, -

In [None]:
%%time
mc_logit_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 42.5 s, sys: 11.2 ms, total: 42.5 s
Wall time: 6.2 s


array([  445.32352216,   774.36484701,   423.87962075,   141.11418411,
         635.46756758,   644.34383339,   294.02737088,   616.31253123,
         217.66893085,   236.79952166,   487.92761072,   153.9430304 ,
         178.58382774,   249.64590107,   220.84982905,   152.78403658,
         208.5287406 ,   233.25229627,   167.67560815,   217.31282088,
         216.14384497,   149.69884798,   448.15810861,   316.27194736,
         216.96742086,   170.02941577,   432.71502898,  4622.3925948 ,
         350.4826602 ,   223.88309221,   179.10404752,   231.76930042,
         158.31746156,   158.58787223,   151.66206207,   154.51667271,
         406.65913765,   218.36473678,   164.55101957,   172.18489522,
         568.77413756,   131.23281427,   363.36314425,   599.50366936,
         305.90931856,   323.22161613,   133.6316683 ,  1033.45902849,
         347.70653486, 19737.95871048,   152.01913985,   240.85503705,
        1644.00315675,   248.36640116,   180.81448961,   246.78852307,
      

In [None]:
%%time
mc_rr_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 4min 45s, sys: 77.3 ms, total: 4min 45s
Wall time: 40.8 s


array([-1.08268954e+03, -7.80527512e+02,  2.64764047e+01, -1.07332626e+03,
        2.16036702e+02,  2.15997201e-01, -3.82836507e+02, -5.65800939e+01,
        5.16961843e+02, -2.63872796e+02,  6.36132654e+02, -3.99558648e+01,
        1.46378496e+02,  3.41822212e+01, -5.09823997e-02, -6.62669137e+02,
       -1.98610768e+00, -2.27617863e+02, -1.13456462e+03, -1.27542484e+03,
       -1.46652457e+03,  1.08056833e-01,  9.53299680e+01, -7.73869257e+01,
        2.11496110e+02,  1.11802085e+00,  4.83505711e-01,  7.17574384e+03,
        6.11798743e+02,  3.43511398e+02,  5.52770119e-02, -1.19920632e+03,
        4.81566873e+02,  2.21085913e-02,  8.98535538e+01, -4.50051771e+02,
       -1.11649188e+02,  1.59837796e+03, -2.81973891e+01,  4.04707947e+02,
       -1.62746829e+02, -1.48464866e+02,  9.46982380e-02,  7.16276510e+02,
       -8.56599110e+01,  1.76313542e+02,  5.13205130e+01, -1.23001545e+02,
       -1.01619425e+02,  3.18775167e+04,  9.15673493e+02,  3.34417780e+02,
        4.24570949e+02, -

In [None]:
%%time
mc_lt_pen.calibration(sampling_probabilities, calibration_matrix, calibration_target).x

CPU times: user 9.46 s, sys: 8.56 ms, total: 9.47 s
Wall time: 1.36 s


array([  445.32352216,   774.36484701,   423.87962075,   141.11418411,
         635.46756758,   644.34383339,   294.02737088,   616.31253123,
         217.66893085,   236.79952166,   487.92761072,   153.9430304 ,
         178.58382774,   249.64590107,   220.84982905,   152.78403658,
         208.5287406 ,   233.25229627,   167.67560815,   217.31282088,
         216.14384497,   149.69884798,   448.15810861,   316.27194736,
         216.96742086,   170.02941577,   432.71502898,  4622.3925948 ,
         350.4826602 ,   223.88309221,   179.10404752,   231.76930042,
         158.31746156,   158.58787223,   151.66206207,   154.51667271,
         406.65913765,   218.36473678,   164.55101957,   172.18489522,
         568.77413756,   131.23281427,   363.36314425,   599.50366936,
         305.90931856,   323.22161613,   133.6316683 ,  1033.45902849,
         347.70653486, 19737.95871048,   152.01913985,   240.85503705,
        1644.00315675,   248.36640116,   180.81448961,   246.78852307,
      