In [1]:
import numpy as np

from sklearn.metrics import r2_score, mean_squared_error
from sklearn.preprocessing import StandardScaler
import torch

from nldg.new.archive.utils import gen_data_v3, max_mse
from nldg.new.archive.train_nn import train_model, train_model_GDRO, predict_GDRO

from scipy.optimize import minimize


In [21]:
# Generate the train/test samples, then split each frame into features, the
# response column 'Y' and (train only) the environment label column 'E'.
# NOTE(review): assumes gen_data_v3 returns pandas DataFrames — confirm.
dtr, dts = gen_data_v3(
    n_train=1000,
    n_test=500,
    train_setting=2,
    test_setting=1,
    random_state=42,
)
Xtr, Xts = (np.array(frame.drop(columns=['E', 'Y'])) for frame in (dtr, dts))
Ytr, Yts = (np.array(frame['Y']) for frame in (dtr, dts))
Etr = np.array(dtr['E'])

In [22]:
# Standardise features using statistics estimated on the training split only,
# then build float32 tensors for the torch models.
scaler = StandardScaler().fit(Xtr)
X_train = scaler.transform(Xtr)
X_test = scaler.transform(Xts)
X_train_tensor, X_test_tensor = (
    torch.tensor(features, dtype=torch.float32) for features in (X_train, X_test)
)

### Maximin

In [23]:
# Fit the network for the "Maximin" section on the scaled training data, then
# score both splits with gradients disabled.
model = train_model(X_train, Ytr, Etr, verbose=False)
model.eval()
with torch.no_grad():
    mpreds, mfitted = (
        model(inputs).numpy() for inputs in (X_test_tensor, X_train_tensor)
    )

In [24]:
# Training-set (MSE, R^2) of the model fitted in the "Maximin" section.
mean_squared_error(Ytr, mfitted), r2_score(Ytr, mfitted)

(33.76689425277214, 0.33103179360802615)

In [25]:
# Test-set (MSE, R^2) of the model fitted in the "Maximin" section.
mean_squared_error(Yts, mpreds), r2_score(Yts, mpreds)

(11.338295102360032, 0.07453013900632988)

In [26]:
# max_mse of the "Maximin" fitted values on the training set, given the
# environment labels Etr.
# NOTE(review): presumably the largest per-environment MSE — confirm in nldg utils.
max_mse(Ytr, mfitted, Etr)

np.float64(82.89055054669906)

### Default

In [27]:
# Fit the baseline network (train_model with default=True) on the scaled
# training data, then score both splits with gradients disabled.
model = train_model(X_train, Ytr, Etr, verbose=False, default=True)
model.eval()
with torch.no_grad():
    preds, fitted = (
        model(inputs).numpy() for inputs in (X_test_tensor, X_train_tensor)
    )

In [28]:
# Training-set (MSE, R^2) of the model trained with default=True.
mean_squared_error(Ytr, fitted), r2_score(Ytr, fitted)

(29.58461025066445, 0.41388854100587535)

In [29]:
# Test-set (MSE, R^2) of the model trained with default=True.
mean_squared_error(Yts, preds), r2_score(Yts, preds)

(9.615159876381016, 0.21517824382838224)

In [30]:
# max_mse of the default model's fitted values on the training set, given the
# environment labels Etr.
# NOTE(review): presumably the largest per-environment MSE — confirm in nldg utils.
max_mse(Ytr, fitted, Etr)

np.float64(90.02221080560959)

### Group DRO

In [31]:
# Group DRO: train on the scaled training data, then predict on both splits.
# predict_GDRO is given the NumPy feature matrices, not the torch tensors.
model, bweights = train_model_GDRO(X_train, Ytr, Etr, lr_model=0.01)
preds_gdro, fitted_gdro = (
    predict_GDRO(model, features) for features in (X_test, X_train)
)

In [32]:
# Training-set (MSE, R^2) of the Group DRO model.
mean_squared_error(Ytr, fitted_gdro), r2_score(Ytr, fitted_gdro)

(39.764844164019216, 0.21220422942283468)

In [33]:
# Test-set (MSE, R^2) of the Group DRO model.
mean_squared_error(Yts, preds_gdro), r2_score(Yts, preds_gdro)

(6.087953233822545, 0.5030807381376967)

In [34]:
# max_mse of the Group DRO fitted values on the training set, given the
# environment labels Etr.
# NOTE(review): presumably the largest per-environment MSE — confirm in nldg utils.
max_mse(Ytr, fitted_gdro, Etr)

np.float64(60.78069394108514)

### Magging

In [35]:
def objective(w: np.ndarray, F: np.ndarray) -> float:
    """Quadratic form minimised to obtain the magging weights.

    Evaluates ``w^T F^T F w``, which equals the squared Euclidean norm of
    the weighted column combination ``F w``.

    Args:
        w: Weight vector with one entry per column of ``F``.
        F: Matrix whose columns are combined using ``w``.

    Returns:
        The scalar ``||F w||^2``.
    """
    # Form F w first instead of the Gram matrix F^T F: the optimiser calls
    # this function many times and F does not change between calls.
    aggregated = np.dot(F, w)
    return float(np.vdot(aggregated, aggregated))


# Magging: fit one default=True network per training environment, then
# combine the per-environment predictors with weights constrained to the
# probability simplex that minimise `objective` on the training fitted values.
envs = np.unique(Etr)
n_envs = len(envs)
winit = np.array([1 / n_envs] * n_envs)  # start from uniform weights
constraints = {"type": "eq", "fun": lambda w: np.sum(w) - 1}  # weights sum to 1
bounds = [[0, 1] for _ in range(n_envs)]  # every weight stays within [0, 1]

preds_by_env = []
fitted_by_env = []
for env in envs:
    env_mask = Etr == env
    Xtr_e = X_train[env_mask]
    Ytr_e = Ytr[env_mask]
    model = train_model(Xtr_e, Ytr_e, Etr[env_mask], verbose=False, default=True)
    # Put the network in evaluation mode before predicting, matching the
    # Maximin and Default cells of this notebook.
    model.eval()
    with torch.no_grad():
        preds_by_env.append(model(X_test_tensor).numpy())
        fitted_by_env.append(model(X_train_tensor).numpy())
# One column per environment-specific model.
preds_envs = np.column_stack(preds_by_env)
fitted_envs = np.column_stack(fitted_by_env)

wmag = minimize(objective, winit, args=(fitted_envs,), bounds=bounds, constraints=constraints).x
# Weighted aggregation of the per-environment predictions.
wpreds = np.dot(wmag, preds_envs.T)
wfitted = np.dot(wmag, fitted_envs.T)

In [36]:
# Training-set (MSE, R^2) of the magging-weighted aggregate.
mean_squared_error(Ytr, wfitted), r2_score(Ytr, wfitted)

(41.71906049748486, 0.17348853985677137)

In [37]:
# Test-set (MSE, R^2) of the magging-weighted aggregate.
mean_squared_error(Yts, wpreds), r2_score(Yts, wpreds)

(1.1304455392662764, 0.9077292250165624)

In [38]:
# max_mse of the magging fitted values on the training set, given the
# environment labels Etr.
# NOTE(review): presumably the largest per-environment MSE — confirm in nldg utils.
max_mse(Ytr, wfitted, Etr)

np.float64(74.87545876102443)

In [39]:
# Optimised magging weights, one entry per training environment.
wmag

array([0.47740634, 0.52105749, 0.00153617])