# Linear models - scalability

In this notebook, we briefly demonstrate the `partial_fit` method offered by some estimators, which can be used to train a model incrementally, one batch of data at a time.

In [10]:
import pandas as pd

# Load the adult census dataset from disk and preview its first rows.
csv_path = "../datasets/adult-census-numeric-all.csv"
data = pd.read_csv(csv_path)
data.head()

Unnamed: 0,age,education-num,capital-gain,capital-loss,hours-per-week,class
0,25,7,0,0,40,<=50K
1,38,9,0,0,50,<=50K
2,28,12,0,0,40,>50K
3,44,10,7688,0,40,>50K
4,18,10,0,0,30,<=50K


In [11]:
# Separate the target column from the remaining feature columns.
target_name = "class"
y = data[target_name]
X = data.drop(columns=[target_name])

In [12]:
from sklearn.model_selection import train_test_split

# Hold out a test set; the fixed seed keeps the shuffled split repeatable.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

In [4]:
from sklearn.preprocessing import StandardScaler

# Create an unfitted scaler; its statistics are learned incrementally in the
# next cell through repeated `partial_fit` calls.
scaler = StandardScaler()

In [5]:
# Fit the scaler incrementally, feeding it one slice of training rows at a time.
batch_size = 100
for start in range(0, y_train.size, batch_size):
    stop = start + batch_size
    # The final slice may hold fewer than `batch_size` rows; positional
    # slicing simply stops at the end of the frame.
    scaler.partial_fit(X_train.iloc[start:stop])

In [6]:
# Show the per-feature mean and variance accumulated over the `partial_fit`
# batches, to compare with a scaler fitted on the full training set.
scaler.mean_, scaler.var_

(array([  38.68545767,   10.07327127, 1063.20692856,   86.77983129,
          40.42248369]),
 array([1.88511311e+02, 6.61175889e+00, 5.43824675e+07, 1.61245256e+05,
        1.54386985e+02]))

In [7]:
# Fit a fresh scaler on the whole training set in a single call and display
# its statistics, for comparison with the batch-by-batch result.
scaler = StandardScaler()
scaler.fit(X_train)
scaler.mean_, scaler.var_

(array([  38.68545767,   10.07327127, 1063.20692856,   86.77983129,
          40.42248369]),
 array([1.88511311e+02, 6.61175889e+00, 5.43824675e+07, 1.61245256e+05,
        1.54386985e+02]))

In [8]:
from sklearn.linear_model import SGDClassifier

# Linear classifier trained by stochastic gradient descent with a hinge loss.
# NOTE: per the scikit-learn documentation, `max_iter` only affects `fit()`;
# it is not used by `partial_fit`, which is how this model is trained below.
# `random_state` is fixed so that the sample shuffling done by SGD, and hence
# the printed weights, are reproducible across runs (the train/test split is
# seeded the same way).
model = SGDClassifier(
    loss="hinge", alpha=0.01, max_iter=200, random_state=0
)

In [9]:
import numpy as np

# Train the classifier incrementally on successive slices of the training set,
# printing the learned weights after every batch.
batch_size = 4_000
batch_starts = range(0, y_train.size, batch_size)
for iteration, start in enumerate(batch_starts, start=1):
    stop = start + batch_size
    X_scaled = scaler.transform(X_train.iloc[start:stop])
    # Only the first call is given the class labels; later calls pass the
    # default value (None), exactly as if the argument were omitted.
    classes = np.unique(y) if start == 0 else None
    model.partial_fit(X_scaled, y_train.iloc[start:stop], classes=classes)
    print(
        f"Iteration #{iteration}: Weights:\n"
        f"{model.coef_}"
    )

Iteration #1: Weights:
[[0.09907963 0.26380117 1.52122195 0.35878544 0.01237512]]
Iteration #2: Weights:
[[0.01640582 0.07287119 1.16940378 0.27487187 0.09471367]]
Iteration #3: Weights:
[[0.15245868 0.25970974 1.31388277 0.36729879 0.08375345]]
Iteration #4: Weights:
[[0.15590466 0.2335446  1.22708766 0.3044234  0.21978365]]
Iteration #5: Weights:
[[0.13269746 0.19178599 1.11354771 0.27223249 0.08386686]]
Iteration #6: Weights:
[[0.1194259  0.17126778 1.13674605 0.32037163 0.07275619]]
Iteration #7: Weights:
[[0.06842697 0.08706685 1.18157248 0.40481137 0.06110402]]
Iteration #8: Weights:
[[0.06268194 0.09917399 1.20545059 0.23858605 0.09608195]]
Iteration #9: Weights:
[[0.13817491 0.21181688 1.22910588 0.25993019 0.18156781]]
Iteration #10: Weights:
[[0.13333257 0.18698275 1.23286096 0.16778199 0.12090864]]
