In [19]:
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, mean_absolute_error
import torch
from torch.autograd import grad

import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, mean_absolute_error
import torch
from torch.autograd import grad

In [20]:
def generate_sem_data(n, mu_x, sigma_x, sigma_y, sigma_x2, rho=0):
    X1 = np.random.normal(mu_x, sigma_x, n)
    # If rho is not 0, we introduce correlation between X1 and noise in Y
    noise_y = np.random.normal(0, sigma_y * np.sqrt(1 - rho**2), n) + rho * X1
    Y = X1 + noise_y
    X2 = Y + np.random.normal(0, sigma_x2, n)
    return np.column_stack((Y, X1, X2))

class InvariantRiskMinimization(object):
    def __init__(self, environments, args):
        self.best_reg = 0
        self.best_err = float('inf')

        # Assumes the last environment is the validation set
        x_val, y_val = environments[-1]

        for reg in [0, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]:
            self.train(environments[:-1], args, reg=reg)
            err = torch.mean((x_val @ self.solution() - y_val) ** 2).item()

            if args["verbose"]:
                print(f"IRM (reg={reg:.3f}) has {err:.3f} validation error.")

            if err < self.best_err:
                self.best_err = err
                self.best_reg = reg
                self.best_phi = self.phi.clone()

        self.phi = self.best_phi

    def train(self, environments, args, reg=0):
        dim_x = environments[0][0].shape[1]

        self.phi = torch.nn.Parameter(torch.eye(dim_x, dim_x))
        self.w = torch.ones((dim_x, 1), requires_grad=True)

        opt = torch.optim.Adam([self.phi], lr=args["lr"])
        mse_loss = torch.nn.MSELoss()

        for iteration in range(args["n_iterations"]):
            penalty = 0
            error = 0
            for x_e, y_e in environments:
                preds = x_e @ self.phi @ self.w
                error_e = mse_loss(preds, y_e)
                penalty += grad(error_e, self.w, create_graph=True)[0].pow(2).mean()
                error += error_e.item()

            opt.zero_grad()
            loss = reg * error + (1 - reg) * penalty
            loss.backward()
            opt.step()

            if args["verbose"] and iteration % 1000 == 0:
                w_str = ' '.join(f'{w:.2f}' for w in self.solution().view(-1))
                print(f"{iteration:05d} | {reg:.5f} | {error:.5f} | {penalty:.5f} | {w_str}")

    def solution(self):
        return self.phi @ self.w

def to_tensor(data):
    return torch.tensor(data, dtype=torch.float32)

In [21]:
# Training dataset with variances as described in training environments
n_train = 1000  # Number of samples in training

train_env1 = generate_sem_data(n_train, 0, np.sqrt(10), np.sqrt(10), 1)
train_env2 = generate_sem_data(n_train, 0, np.sqrt(20), np.sqrt(20), 1)
train_data = np.vstack((train_env1, train_env2))

train_env3 = generate_sem_data(n_train, 0, np.sqrt(10), np.sqrt(10), 1)

# Test dataset with variance shift
test_var_shift = generate_sem_data(n_train, 0, np.sqrt(50), np.sqrt(50), 1)
test_data_var_shift = np.vstack(test_var_shift)

# Test dataset with mean shift
test_mean_shift = generate_sem_data(n_train, 15, np.sqrt(20), np.sqrt(20), 1)
test_data_mean_shift = np.vstack(test_mean_shift)

# Test dataset with correlation shift
test_corr_shift = generate_sem_data(n_train, 0, np.sqrt(20), np.sqrt(20), 1, rho=0.7)
test_data_corr_shift = np.vstack(test_corr_shift)

# Combine into dataframes for easier handling
columns = [ 'X1', 'X2', 'Y']
train_df = pd.DataFrame(train_data, columns=columns)
train_df2 = pd.DataFrame(train_env3, columns=columns)
test_var_shift_df = pd.DataFrame(test_data_var_shift, columns=columns)
test_mean_shift_df = pd.DataFrame(test_data_mean_shift, columns=columns)
test_corr_shift_df = pd.DataFrame(test_data_corr_shift, columns=columns)

In [25]:
# Initialize the Linear Regression model
model = LinearRegression()

# Fit the model on the training dataset
model.fit(train_df[['X1', 'X2']], train_df['Y'])

# Create a dictionary to store the MSE and MAE for each test set
metrics = {}

# Predict and evaluate on a similar training data
predictions_standard = model.predict(train_df2[['X1', 'X2']])
metrics['standard'] = {
    'MSE': mean_squared_error(train_df2['Y'], predictions_standard),
    'MAE': mean_absolute_error(train_df2['Y'], predictions_standard)
}

# Predict and evaluate on the test dataset with variance shift
predictions_var_shift = model.predict(test_var_shift_df[['X1', 'X2']])
metrics['variance_shift'] = {
    'MSE': mean_squared_error(test_var_shift_df['Y'], predictions_var_shift),
    'MAE': mean_absolute_error(test_var_shift_df['Y'], predictions_var_shift)
}

# Predict and evaluate on the test dataset with mean shift
predictions_mean_shift = model.predict(test_mean_shift_df[['X1', 'X2']])
metrics['mean_shift'] = {
    'MSE': mean_squared_error(test_mean_shift_df['Y'], predictions_mean_shift),
    'MAE': mean_absolute_error(test_mean_shift_df['Y'], predictions_mean_shift)
}

# Predict and evaluate on the test dataset with correlation shift
predictions_corr_shift = model.predict(test_corr_shift_df[['X1', 'X2']])
metrics['correlation_shift'] = {
    'MSE': mean_squared_error(test_corr_shift_df['Y'], predictions_corr_shift),
    'MAE': mean_absolute_error(test_corr_shift_df['Y'], predictions_corr_shift)
}

In [23]:
# Convert the datasets to PyTorch tensors
train_tensors1 = [[train_env1[:, 0:2], train_env1[:, 2:]], [train_env2[:, 0:2], train_env2[:, 2:]]]
train_tensors2 = [train_env3[:, 0:2], train_env3[:, 2:]]
test_var_shift_tensors = [test_var_shift[:, 0:2], test_var_shift[:, 2:]]
test_mean_shift_tensors = [test_mean_shift[:, 0:2], test_mean_shift[:, 2:]]
test_corr_shift_tensors = [test_corr_shift[:, 0:2], test_corr_shift[:, 2:]]

train_tensors1 = [(to_tensor(env[0]), to_tensor(env[1])) for env in train_tensors1]
train_tensors2 = [to_tensor(tensor_obj) for tensor_obj in train_tensors2]
test_var_shift_tensors = [to_tensor(tensor_obj) for tensor_obj in test_var_shift_tensors]
test_mean_shift_tensors = [to_tensor(tensor_obj) for tensor_obj in test_mean_shift_tensors]
test_corr_shift_tensors = [to_tensor(tensor_obj) for tensor_obj in test_corr_shift_tensors]


# Define arguments for training
args = {
    "lr": 1e-3,
    "n_iterations": 500000,
    "verbose": False
}

# Initialize the IRM model and train on the training environments
irm_model = InvariantRiskMinimization(train_tensors1, args)

# Evaluate on the test datasets
names = ["standard", 'variance_shift', 'mean_shift', 'correlation_shift']
test_datasets = [train_tensors2, test_var_shift_tensors, test_mean_shift_tensors, test_corr_shift_tensors]
metrics2 = {}
for i, (x_test, y_test) in enumerate(test_datasets):
    y_pred = x_test @ irm_model.solution()
    mse =  mean_squared_error(y_test, y_pred.detach())
    mae = mean_absolute_error(y_test, y_pred.detach())

    print(f"Test {names[i]}, MSE: {mse.item()}, MAE: {mae.item()}")

    metrics2[names[i]] = {
        'MSE': mse.item(),
        'MAE': mae.item()
    }

Test standard, MSE: 0.9973065257072449, MAE: 0.7931180596351624
Test variance_shift, MSE: 0.9421562552452087, MAE: 0.776894748210907
Test mean_shift, MSE: 1.0360878705978394, MAE: 0.8153470754623413
Test correlation_shift, MSE: 0.9364076852798462, MAE: 0.7630943655967712
