In [1]:
from typing import Tuple, List, Union, Any, Optional, Dict, Literal, Callable, Type
import abc

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torch.utils.data
from torch import Tensor
from sklearn.linear_model import RidgeClassifierCV
import xgboost as xgb

from models.ridge_ALOOCV import fit_ridge_ALOOCV
from models.sandwiched_least_squares import sandwiched_LS_dense, sandwiched_LS_diag, sandwiched_LS_scalar

In [3]:
import cProfile
import pstats
from models.random_feature_representation_boosting import GreedyRFRBoostRegressor
import torch
import torch.nn.functional as F

# Make regression data X, y: y = (X @ w_true)^2 + Gaussian noise.
# Seed torch's global RNG first so that a fresh "Restart & Run All" draws the
# same data (and the same subsequent global-RNG draws) every time.
SEED = 0
torch.manual_seed(SEED)

device = "cpu"
N = 5000         # number of training samples
N_test = 5000    # number of test samples
D = 10           # input dimension
d = 1            # output dimension
NOISE_STD = 0.1  # standard deviation of the additive Gaussian noise

# Tensors are drawn on the default device and then moved, exactly as before.
X = torch.randn(N, D).to(device)
X_test = torch.randn(N_test, D).to(device)
w_true = torch.randn(D, d).to(device)
y = (X @ w_true)**2 + torch.randn(N, d).to(device) * NOISE_STD  # Adding some noise
y_test = (X_test @ w_true)**2 + torch.randn(N_test, d).to(device) * NOISE_STD  # Adding some noise


def train_and_evaluate(n_runs: int = 10):
    """Fit a GreedyRFRBoostRegressor repeatedly and print train/test RMSE statistics.

    Reads the notebook-level ``X``, ``y``, ``X_test``, ``y_test``, ``D``, ``d``
    and ``device``. A single model object is constructed once and ``fit`` is
    called on it ``n_runs`` times; after each fit the RMSE on the training and
    test sets is recorded. The mean, standard deviation and per-run values are
    printed. Nothing is returned.

    Args:
        n_runs: Number of fit/evaluate repetitions. Defaults to 10, the value
            that was previously hard-coded. With a single run the printed
            standard deviation is not meaningful.
    """
    # Dense sandwiched-least-squares solver variant.
    model = GreedyRFRBoostRegressor(
        in_dim=D,
        out_dim=d,
        hidden_dim=128,
        randfeat_x0_dim=512,
        randfeat_xt_dim=512,
        n_layers=10,
        l2_reg=0.01,
        l2_ghat=0.1,
        feature_type="SWIM",
        upscale_type="iid",
        sandwich_solver="dense",
    ).to(device)

    results = []
    for _ in range(n_runs):
        model.fit(X, y)
        # Metrics only: no autograd graph is needed for the evaluation forwards.
        with torch.no_grad():
            rmse = torch.sqrt(F.mse_loss(model(X), y))
            rmse_test = torch.sqrt(F.mse_loss(model(X_test), y_test))
        # Keep the per-run pair on CPU, as the original torch.tensor([...]) did.
        results.append(torch.stack([rmse, rmse_test]).cpu())
    results = torch.stack(results)  # shape (n_runs, 2): columns are train, test
    print("train rmse", results[:, 0].mean(), "std", results[:, 0].std())
    print("test rmse", results[:, 1].mean(), "std", results[:, 1].std())
    print("train", results[:, 0])
    print("test", results[:, 1])

# Profile the training and evaluation.
# The context manager enables the profiler on entry and always disables it on
# exit, even if train_and_evaluate() raises, so the profile hook is never left
# installed in the kernel.
# NOTE(review): the recorded "Cannot install a profile function ..." message
# may stem from an earlier profiler object still alive when this cell was
# re-run - confirm on a fresh kernel.
with cProfile.Profile() as profiler:
    train_and_evaluate()

# Print the 20 most expensive entries by internal time. The trailing ';'
# keeps the notebook from echoing the <pstats.Stats ...> repr.
stats = pstats.Stats(profiler).sort_stats('tottime')
stats.print_stats(20);

Exception ignored When destroying _lsprof profiler:
Traceback (most recent call last):
  File "/tmp/ipykernel_15303/1612922258.py", line 52, in <module>
RuntimeError: Cannot install a profile function while another profile function is being installed


train rmse tensor(1.06544589996337890625) std tensor(0.01993291825056076050)
test rmse tensor(1.49411332607269287109) std tensor(0.02958944998681545258)
train tensor([1.05934667587280273438, 1.04710435867309570312, 1.04993116855621337891,
        1.11542403697967529297, 1.07662832736968994141, 1.06558179855346679688,
        1.05304908752441406250, 1.05845415592193603516, 1.05627787113189697266,
        1.07266211509704589844])
test tensor([1.48107147216796875000, 1.44302821159362792969, 1.49441707134246826172,
        1.54382896423339843750, 1.48942172527313232422, 1.54012680053710937500,
        1.50104653835296630859, 1.48752248287200927734, 1.48332262039184570312,
        1.47734713554382324219])
         65050 function calls (57802 primitive calls) in 19.827 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      100    4.470    0.045    8.983    0.090 /home/nikita/Code/random-feature-boosting/models/sandwiched_least_squ

<pstats.Stats at 0x7f9ffe6ead40>

In [3]:
import cProfile
import pstats
from models.random_feature_representation_boosting import GradientRFRBoostRegressor
import torch
import torch.nn.functional as F

# Make regression data X, y: y = (X @ w_true)^2 + Gaussian noise.
# Seed torch's global RNG first so that a fresh "Restart & Run All" draws the
# same data (and the same subsequent global-RNG draws) every time.
SEED = 0
torch.manual_seed(SEED)

device = "cpu"
N = 5000         # number of training samples
N_test = 5000    # number of test samples
D = 10           # input dimension
d = 1            # output dimension
NOISE_STD = 0.1  # standard deviation of the additive Gaussian noise

# Tensors are drawn on the default device and then moved, exactly as before.
X = torch.randn(N, D).to(device)
X_test = torch.randn(N_test, D).to(device)
w_true = torch.randn(D, d).to(device)
y = (X @ w_true)**2 + torch.randn(N, d).to(device) * NOISE_STD  # Adding some noise
y_test = (X_test @ w_true)**2 + torch.randn(N_test, d).to(device) * NOISE_STD  # Adding some noise


def train_and_evaluate(n_runs: int = 10):
    """Fit a GradientRFRBoostRegressor repeatedly and print train/test RMSE statistics.

    Reads the notebook-level ``X``, ``y``, ``X_test``, ``y_test``, ``D``, ``d``
    and ``device``. A single model object is constructed once and ``fit`` is
    called on it ``n_runs`` times; after each fit the RMSE on the training and
    test sets is recorded. The mean, standard deviation and per-run values are
    printed. Nothing is returned.

    Args:
        n_runs: Number of fit/evaluate repetitions. Defaults to 10, the value
            that was previously hard-coded. With a single run the printed
            standard deviation is not meaningful.
    """
    # NOTE(review): no sandwich_solver argument is passed here, so whatever
    # the constructor's default is applies - confirm that is intended.
    model = GradientRFRBoostRegressor(
        in_dim=D,
        out_dim=d,
        hidden_dim=128,
        randfeat_x0_dim=512,
        randfeat_xt_dim=512,
        n_layers=10,
        l2_reg=0.01,
        l2_ghat=0.001,
        feature_type="SWIM",
        upscale_type="iid",
    ).to(device)

    results = []
    for _ in range(n_runs):
        model.fit(X, y)
        # Metrics only: no autograd graph is needed for the evaluation forwards.
        with torch.no_grad():
            rmse = torch.sqrt(F.mse_loss(model(X), y))
            rmse_test = torch.sqrt(F.mse_loss(model(X_test), y_test))
        # Keep the per-run pair on CPU, as the original torch.tensor([...]) did.
        results.append(torch.stack([rmse, rmse_test]).cpu())
    results = torch.stack(results)  # shape (n_runs, 2): columns are train, test
    print("train rmse", results[:, 0].mean(), "std", results[:, 0].std())
    print("test rmse", results[:, 1].mean(), "std", results[:, 1].std())
    print("train", results[:, 0])
    print("test", results[:, 1])

# Profile the training and evaluation.
# The context manager enables the profiler on entry and always disables it on
# exit, even if train_and_evaluate() raises, so the profile hook is never left
# installed in the kernel.
with cProfile.Profile() as profiler:
    train_and_evaluate()

# Print the 20 most expensive entries by internal time. The trailing ';'
# keeps the notebook from echoing the <pstats.Stats ...> repr.
stats = pstats.Stats(profiler).sort_stats('tottime')
stats.print_stats(20);

train rmse tensor(1.21921765804290771484) std tensor(0.03444970399141311646)
test rmse tensor(1.46565508842468261719) std tensor(0.04728671163320541382)
train tensor([1.17456102371215820312, 1.22131717205047607422, 1.22359132766723632812,
        1.25396025180816650391, 1.23235726356506347656, 1.27447116374969482422,
        1.19557857513427734375, 1.18070495128631591797, 1.18476688861846923828,
        1.25086808204650878906])
test tensor([1.45425438880920410156, 1.45770776271820068359, 1.42601478099822998047,
        1.41118264198303222656, 1.51764202117919921875, 1.56326651573181152344,
        1.47702682018280029297, 1.43248617649078369141, 1.42823863029479980469,
        1.48872983455657958984])
         69228 function calls (61354 primitive calls) in 17.473 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      210    5.730    0.027    6.909    0.033 /home/nikita/Code/random-feature-boosting/models/base.py:114(fit)
   

<pstats.Stats at 0x7f7b5ccbaec0>