# Iterative Proportional Fitting, Higher dimensions

- race
  - white: 58%
  - other: 42%
- age
  - minor: 28%
  - adult: 72%
- gender
  - male: 49%
  - female: 51%


In [1]:
import pandas as pd
import numpy as np
import random
import itertools

np.random.seed(37)
random.seed(37)

height = [
    np.random.normal(5.5, 1.0, 100),
    np.random.normal(5.3, 1.0, 200),
    np.random.normal(5.9, 1.0, 300),
    np.random.normal(5.7, 1.0, 200),
    np.random.normal(5.3, 1.0, 400),
    np.random.normal(5.2, 1.0, 500),
    np.random.normal(5.8, 1.0, 300),
    np.random.normal(5.5, 1.0, 200)
]

demographic = [
    ['white', 'minor', 'male'],
    ['white', 'minor', 'female'],
    ['white', 'adult', 'male'],
    ['white', 'adult', 'female'],
    ['other', 'minor', 'male'],
    ['other', 'minor', 'female'],
    ['other', 'adult', 'male'],
    ['other', 'adult', 'female']
]

data = [[{'race': d[0], 'age': d[1], 'gender': d[2], 'height': h} for h in s] for d, s in zip(demographic, height)]
data = list(itertools.chain(*data))

df = pd.DataFrame(data)
df.head()

Unnamed: 0,race,age,gender,height
0,white,minor,male,5.445536
1,white,minor,male,6.174308
2,white,minor,male,5.846647
3,white,minor,male,4.199654
4,white,minor,male,7.018512


In [2]:
def get_target_marginals(d):
    factors = list(d.keys())
    targets = [sorted([(k2, v2) for k2, v2 in v.items()]) for k, v in d.items()]
    targets = np.array([[v for _, v in item] for item in targets])
    return factors, targets

def get_table(df, targets):
    factors, target_marginals = get_target_marginals(targets)
    
    cross_tab = pd.crosstab(df[factors[0]], [df[c] for c in factors[1:]])
    shape = tuple([df[c].unique().shape[0] for c in factors])
    table = cross_tab.values.reshape(shape)
    
    return factors, target_marginals, table

f, u, X = get_table(df, {
    'race': {'white': 5800, 'other': 4200},
    'age': {'minor': 2800, 'adult': 7200},
    'gender': {'male': 4900, 'female': 5100}
})

In [3]:
def get_coordinates(M):
    return list(itertools.product(*[list(range(n)) for n in M.shape]))

def get_marginals(M, i):
    coordinates = get_coordinates(M)
    
    key = lambda tup: tup[0]
    counts = [(c[i], M[c]) for c in coordinates]
    counts = sorted(counts, key=key)
    counts = itertools.groupby(counts, key=key)
    counts = {k: sum([v[1] for v in g]) for k, g in counts}
    
    return counts

def get_all_marginals(M):
    return np.array([[v for _, v in get_marginals(M, i).items()] 
                     for i in range(len(M.shape))])

def get_counts(M, i):
    coordinates = get_coordinates(M)
    
    key = lambda tup: tup[0]
    counts = [(c[i], M[c], c) for c in coordinates]
    counts = sorted(counts, key=key)
    counts = itertools.groupby(counts, key=key)
    counts = {k: [(tup[1], tup[2]) for tup in g] for k, g in counts}
    
    return counts

def update_values(M, i, u):
    marg = get_marginals(M, i)
    vals = get_counts(M, i)
    
    d = [[(c, n * u[k] / marg[k]) for n, c in v] for k, v in vals.items()]
    d = itertools.chain(*d)
    d = list(d)
    
    return d

def ipf_update(M, u):
    for i in range(len(M.shape)):
        values = update_values(M, i, u[i])
        for idx, v in values:
            M[idx] = v
    
    o = get_all_marginals(M)
    d = get_deltas(o, u)
    
    return M, d

def get_deltas(o, t):
    return np.array([np.linalg.norm(o[r] - t[r], 2) for r in range(o.shape[0])])

def get_weights(X, max_iters=50, zero_threshold=0.0001, convergence_threshold=3, debug=True):
    M = X.copy()
    
    d_prev = np.zeros(len(M.shape))
    count_zero = 0

    for _ in range(max_iters):
        M, d_next = ipf_update(M, u)
        d = np.linalg.norm(d_prev - d_next, 2)

        if d < zero_threshold:
            count_zero += 1

        if debug:
            print(','.join([f'{v:.5f}' for v in d_next]), d)
        d_prev = d_next

        if count_zero >= convergence_threshold:
            break

    w = M / M.sum()
    return w

In [5]:
w = get_weights(X)
w

758.02375,123.06909,3.16228 767.9557278906121
75.74299,7.28011,3.60555 692.0363565714724
6.70820,2.23607,2.23607 69.23235546955618
2.23607,2.23607,2.23607 4.47213595499958
2.23607,2.23607,2.23607 0.0
2.23607,2.23607,2.23607 0.0
2.23607,2.23607,2.23607 0.0


array([[[0.113334  , 0.13604081],
        [0.10423127, 0.0663199 ]],

       [[0.21406422, 0.256677  ],
        [0.0783235 , 0.0310093 ]]])

In [14]:
{k: v for k, v in zip(list(itertools.product(*[sorted(df[c].unique()) for c in f])), np.ravel(w))}

{('other', 'adult', 'female'): 0.11333400020006001,
 ('other', 'adult', 'male'): 0.1360408122436731,
 ('other', 'minor', 'female'): 0.10423126938081424,
 ('other', 'minor', 'male'): 0.06631989596879063,
 ('white', 'adult', 'female'): 0.21406421926577973,
 ('white', 'adult', 'male'): 0.2566770031009303,
 ('white', 'minor', 'female'): 0.07832349704911473,
 ('white', 'minor', 'male'): 0.031009302790837252}