In [None]:
import time
import warnings

from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split

from src.data_loader.DataLoader import DataLoader
from src.model.LogisticRegression import LogisticRegression
from src.optim.ADAM import ADAM
from src.optim.IWLS import IWLS
from src.optim.SGD import SGD
from src.optim.conditions import NoLogLikOrMaxIterCondition

# NOTE(review): blanket suppression also hides convergence/numerical warnings
# raised while the optimizers run; consider narrowing the filter.
warnings.filterwarnings('ignore')

# Benchmark configuration: every optimizer is fitted N_TRIALS times on every
# dataset reported by DataLoader.get_supported_datasets().
N_TRIALS = 3
TEST_SIZE = 0.3
SEED = 42
# Arguments forwarded to NoLogLikOrMaxIterCondition; presumably
# (max epochs, log-likelihood change tolerance) — TODO confirm against src.optim.conditions.
MAX_ITER = 50
LOGLIK_TOL = 1e-3
# Optimizers to compare, in the order they are run for each trial.
OPTIMIZERS = {"ADAM": ADAM, "IWLS": IWLS, "SGD": SGD}

dl = DataLoader(product=True)
sd = dl.get_supported_datasets()

rows = []


def run_optimizer(dataset, name, optim_cls, train_x, train_y, test_x, test_y, t):
    """Fit a fresh LogisticRegression with ``optim_cls`` and return one result row.

    Parameters
    ----------
    dataset : name of the dataset being benchmarked (stored in the row).
    name : label for the optimizer (stored in the row and printed).
    optim_cls : optimizer class, constructed as ``optim_cls(model, stop_condition)``.
    train_x, train_y, test_x, test_y : train/test split produced by train_test_split.
    t : trial index (stored in the row).

    Returns
    -------
    dict with keys ``dataset``, ``optim``, ``time`` (wall-clock seconds spent
    building and fitting the model), ``accuracy`` (balanced accuracy on the test
    split), ``t`` and ``iters`` (epoch count read from the optimizer's stop condition).
    """
    print(name)
    start = time.perf_counter()
    model = LogisticRegression()
    optim = optim_cls(model, NoLogLikOrMaxIterCondition(MAX_ITER, LOGLIK_TOL))
    model = optim.optimize(train_x, train_y)
    # Duration is end - start so it is positive; prediction is deliberately
    # excluded from the timed region.
    elapsed = time.perf_counter() - start
    return {
        "dataset": dataset,
        "optim": name,
        "time": elapsed,
        "accuracy": balanced_accuracy_score(test_y, model.predict(test_x)),
        "t": t,
        "iters": optim.stop_condition.epoch,
    }


# Benchmark every optimizer on every supported dataset, collecting one row per run.
for d in sd:
    print(d)
    x, y = dl[d]

    for t in range(N_TRIALS):
        # Seed varies per trial so repetitions use different (but reproducible)
        # stratified splits; a constant seed would make every trial identical.
        train_x, test_x, train_y, test_y = train_test_split(
            x, y, test_size=TEST_SIZE, stratify=y, random_state=SEED + t)
        for name, optim_cls in OPTIMIZERS.items():
            rows.append(run_optimizer(d, name, optim_cls, train_x, train_y, test_x, test_y, t))


banknote
(1372, 2)
ADAM
IWLS
SGD
ADAM
IWLS
SGD
ADAM
IWLS
SGD
kin8nm
(8192, 8)
ADAM
IWLS
SGD
ADAM
IWLS
SGD
ADAM
IWLS


In [None]:
import pandas as pd

# Tabulate the benchmark: one row per optimizer run appended to `rows`
# by the benchmark cell above.
df = pd.DataFrame.from_records(rows)
df