In [5]:
import pandas as pd
import numpy as np
from collections import Counter




In [6]:
# Simulate a skewed "transactions" dataset
np.random.seed(42)
user_ids = np.random.choice(
    [1, 2, 3, 4, 5, 999],  # user 999 will be our "skewed" key
    size=10000,
    p=[0.05, 0.05, 0.05, 0.05, 0.05, 0.75]  # 75% of rows are user 999
)


In [7]:
transactions = pd.DataFrame({
    'user_id': user_ids,
    'amount': np.random.randint(10, 500, size=10000)
})

In [11]:
df_counts = transactions.groupby('user_id').size().reset_index(name='n_rows').sort_values('n_rows', ascending=False)
df_counts

Unnamed: 0,user_id,n_rows
5,999,7453
1,2,538
3,4,509
4,5,504
2,3,503
0,1,493


In [15]:
# correct way to sum the 'n_rows' column
df_counts['n_rows'].sum()
df_counts['skew'] =( df_counts['n_rows'] / df_counts['n_rows'].sum())*100
df_counts

Unnamed: 0,user_id,n_rows,skew
5,999,7453,74.53
1,2,538,5.38
3,4,509,5.09
4,5,504,5.04
2,3,503,5.03
0,1,493,4.93


In [10]:


# Small users dataset
users = pd.DataFrame({
    'user_id': [1, 2, 3, 4, 5, 999],
    'country': ['US', 'CA', 'FR', 'DE', 'UK', 'IN']
})

print("Transaction distribution by user:")
print(Counter(transactions['user_id']))

Transaction distribution by user:
Counter({999: 7453, 2: 538, 4: 509, 5: 504, 3: 503, 1: 493})


In [17]:
users.head()

Unnamed: 0,user_id,country
0,1,US
1,2,CA
2,3,FR
3,4,DE
4,5,UK


In [18]:
transactions.head()

Unnamed: 0,user_id,amount
0,999,192
1,999,361
2,999,45
3,999,379
4,4,46


In [20]:
joined = transactions.merge(users, on = 'user_id', how = 'left')

In [23]:
partition_load =  joined['country'].value_counts()

In [24]:
print("\nSimulated partition load:")
print(partition_load)


Simulated partition load:
country
IN    7453
CA     538
DE     509
UK     504
FR     503
US     493
Name: count, dtype: int64


In [26]:
# Step 3 â€” Apply Salting

k = 5  # number of salt buckets

# Add random salt to large (skewed) table
transactions['salt'] = np.where(
    transactions['user_id'] == 999,  # only salt the heavy key
    np.random.randint(0, k, size=len(transactions)),  # assign random salt 0..k-1
    0  # others stay at 0
)

In [27]:
transactions.head()

Unnamed: 0,user_id,amount,salt
0,999,192,3
1,999,361,4
2,999,45,2
3,999,379,2
4,4,46,0


In [28]:
# Duplicate the small-side (users) table for those salt keys
users_salted = (
    users.assign(key=1)
    .merge(pd.DataFrame({'salt': range(k)}).assign(key=1), on='key')
    .drop('key', axis=1)
)

In [30]:
users_salted.loc[users_salted['user_id'] != 999, 'salt'] = 0  # only skewed key gets multiple salts
users_salted

Unnamed: 0,user_id,country,salt
0,1,US,0
1,1,US,0
2,1,US,0
3,1,US,0
4,1,US,0
5,2,CA,0
6,2,CA,0
7,2,CA,0
8,2,CA,0
9,2,CA,0


In [31]:
joined_salted = transactions.merge(users_salted, on=['user_id', 'salt'], how='left')

In [32]:
joined_salted.head()

Unnamed: 0,user_id,amount,salt,country
0,999,192,3,IN
1,999,361,4,IN
2,999,45,2,IN
3,999,379,2,IN
4,4,46,0,DE


In [33]:
print("Partition load after salting:")
print(joined_salted['country'].value_counts())

Partition load after salting:
country
IN    7453
CA    2690
DE    2545
UK    2520
FR    2515
US    2465
Name: count, dtype: int64
