In [1]:
# Case study 
# python version == 3.10.14
# torch.__version__ == 2.2.2
# numpy.__version__ == 1.26.4

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import pickle
import os
import argparse
import torch
from torch import distributions as td

In [2]:
def example1(e, m1, m2, N=10000): 
    #independent causes 
    beta = np.array([3.0, 2, 0])
    x2_mean = m2
    x2_e = torch.normal(x2_mean, e,[N,1])   
    x1_mean = m1
    x1_e = torch.normal(x1_mean, e,[N,1])
    y_e = beta[0] * x1_e + beta[1] * x2_e + torch.randn(N,1)
    z = e*y_e +  torch.normal(0, e, [N,1])
    return (torch.cat((x1_e, x2_e, z), 1), y_e, beta)


def IRMv1(environments, args, lmbd, init_random = False):
    estimate_r = []
    error_r = []
    penalty_r = []
    if init_random:
        phi = torch.nn.Parameter(torch.normal(1,0.2,[environments[0][0].shape[1],1]))
    else:
        phi = torch.nn.Parameter(torch.Tensor([[3.0, 2.0, 0.0]]).T)
    dummy_w = torch.nn.Parameter(torch.Tensor([1.0])) 
    opt1 = torch.optim.Adam([phi], lr=args.lrs) 
    phi_old = 0
    for iteration in range(args.max_iter):
        error = 0
        penalty = 0
        for i in range(len(environments)):
            x_e, y_e, beta = environments[i]
            error_e = 0.5*mse(x_e @ phi * dummy_w, y_e).mean()   
            error += error_e
            
            phi_grad_out = torch.autograd.grad(error_e, dummy_w, create_graph=True)
            penalty += torch.square(phi_grad_out[0]) 
  
        opt1.zero_grad()
        total_loss =  (error + lmbd * penalty)
        total_loss.backward()     
        opt1.step()
        
        estimate = phi.view(-1).detach().numpy()
        estimate_r.append(estimate)
        error_r.append(error.item())
        penalty_r.append(penalty.item())
        
        if iteration % 2000 == 0: 
            phi_new = np.mean(estimate_r[-100:],axis=0)
            print(phi_new)
            if ((np.sum(np.abs(phi_new - phi_old))<0.001) & (iteration>=10000)):
                break
            else:
                phi_old = phi_new
    
    return [np.mean(estimate_r[-100:],axis=0), np.mean(error_r[-100:]), np.mean(penalty_r[-100:])]



def IRM_test(environments, coef):
    phi = torch.nn.Parameter(torch.Tensor([coef])).T
    dummy_w = torch.nn.Parameter(torch.Tensor([1.0])) 
    error = 0
    penalty = 0
    for i in range(len(environments)):
        x_e, y_e, beta = environments[i]
        error_e = 0.5*mse(x_e @ phi * dummy_w, y_e).mean()   
        error += error_e
        
        phi_grad_out = torch.autograd.grad(error_e, dummy_w, create_graph=True)
        pe1 = torch.square(phi_grad_out[0]) 
        penalty += pe1
        
        # phi_grad_out = torch.autograd.grad(error_e, phi,create_graph=True)
        # pe2 = torch.square(torch.sum(phi_grad_out[0]*phi)) 
        # penalty += pe2
     
    return [error.item(), penalty.item()]


def CoCo(environments, args):
    estimate_r = []
    phi = torch.nn.Parameter(torch.normal(1,0.2,[environments[0][0].shape[1],1]))
    opt1 = torch.optim.Adam([phi], lr=args.lrs)  
    scheduler = torch.optim.lr_scheduler.StepLR(opt1, step_size=2000, gamma=0.8)

    phi_old = 0
    for iteration in range(args.max_iter):
        error = 0
        penalty = 0
        for i in range(len(environments)):
            x_e, y_e, beta = environments[i]           
            error_e = 0.5*mse(x_e @ phi, y_e).mean()  
            error += error_e

            phi_grad_out = torch.autograd.grad(error_e, phi,create_graph=True)
            penalty += torch.square(phi_grad_out[0][0]) + \
                torch.sum(torch.square(phi_grad_out[0][1:]*phi[1:])) 
     
        opt1.zero_grad()
        total_loss =  torch.sqrt(penalty)
        total_loss.backward()     
        opt1.step()
        scheduler.step()
        
        estimate = phi.view(-1).detach().numpy()
        estimate_r.append(estimate)
        if iteration % 2000 == 0: 
            phi_new = np.mean(estimate_r[-100:],axis=0)
            print(phi_new)
            if ((np.sum(np.abs(phi_new - phi_old))<0.001) & (iteration>=10000)):
                break
            else:
                phi_old = phi_new  
                          
    return np.mean(estimate_r[-100:],axis=0)

def ERM(environments, args):
    estimate_r = []
    phi = torch.nn.Parameter(torch.normal(1,0.2,[environments[0][0].shape[1],1]))
    opt1 = torch.optim.SGD([phi], lr=0.002) 
    phi_old = 0
    for iteration in range(args.max_iter):
        error = 0
        for i in range(len(environments)):
            x_e, y_e, beta = environments[i]              
            error_e = 0.5*mse(x_e @ phi , y_e).mean()  
            error += error_e           
        opt1.zero_grad()
        error.backward()     
        opt1.step()   
        
        estimate = phi.view(-1).detach().numpy()
        estimate_r.append(estimate)
        
        if iteration % 2000 == 0:
            phi_new = np.mean(estimate_r[-100:],axis=0)
            print(phi_new)
            if ((np.sum(np.abs(phi_new - phi_old))<0.001) & (iteration>=10000)):
                break
            else:
                phi_old = phi_new                
    return np.mean(estimate_r[-100:],axis=0)


In [5]:
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=2, help='Random seed')
parser.add_argument('--max_iter', type=int, default=100000, help='max iteration.')
parser.add_argument('--N', type=int, default=100000, help='number of data per env.')
parser.add_argument('--path', default='results/', help='The path results to be saved.')
args = parser.parse_args(args=[])


np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

mse = torch.nn.MSELoss(reduction="none")  
environments = [example1(0.5, 3, -1, N=args.N), 
            example1(2, 2, 0.5, N=args.N)]  

args.lrs = 0.3
print('##############')
print('CoCo', CoCo(environments, args))
print('##############')
args.lrs = 0.01
result = IRMv1(environments, args, lmbd=1, init_random=True)
print('IRMv1', result)
result = IRMv1(environments, args, lmbd=1, init_random=False)
print('IRMv1', result)
print('##############')
print('IRM_test', IRM_test(environments, [3,2,0]))
print('##############')
print('ERM', ERM(environments, args))


##############


  from .autonotebook import tqdm as notebook_tqdm


[0.6405374 0.7824335 0.5727433]
[2.9818997  1.9819907  0.01841489]
[2.990405   1.98872    0.01360961]
[2.995853   1.994091   0.01003939]
[2.9989142  1.9980304  0.00747031]
[3.0002408  2.0010564  0.00559988]
[3.0009122  2.0028532  0.00422782]
[3.0018182  2.002843   0.00310581]
[3.0012109e+00 2.0040286e+00 2.1859666e-03]
[3.001315e+00 2.005207e+00 2.194642e-03]
[3.0011315e+00 2.0037997e+00 9.3487598e-04]
[3.0017123e+00 2.0044672e+00 1.4314966e-03]
[3.0010374e+00 2.0038605e+00 5.4875284e-04]
[3.0011694e+00 2.0038195e+00 5.3218758e-04]
CoCo [3.0011694e+00 2.0038195e+00 5.3218758e-04]
##############
[1.0656513 1.088206  0.9634695]
[-0.03507781  1.0945516   0.41847622]
[-0.11566354  1.1363164   0.4240916 ]
[-0.15687944  1.0552808   0.4357452 ]
[-0.2579613   0.85396653  0.46408054]
[-0.4652736   0.42912257  0.5210888 ]
[-0.7320835  -0.14375661  0.592129  ]
[-0.83901703 -0.38267514  0.6197817 ]
[-0.84236217 -0.390255    0.62064666]
[-0.8424558  -0.39035785  0.62056047]
IRMv1 [array([-0.8424558