# Invariant risk minimization

Minimal pytorch implementation from appendix D of the paper [Invariant Risk Minimization](https://arxiv.org/abs/1907.02893)

In [3]:
import torch
from torch.autograd import grad

In [4]:
# split into two random mini-batches of size b (random happens in the iteration)
# an unbiased estimate of the squared gradient norm

def compute_penalty(losses, dummy_w):
    g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]
    g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]
    return (g1 * g2).sum()

In [5]:
def example(n=10000, d=2, env=1):
    x = torch.randn(n, d) * env
    y = x + torch.randn(n, d) * env
    z = y + torch.randn(n, d)
    return torch.cat((x, z), 1), y.sum(1, keepdim=True)

In [6]:
phi = torch.nn.Parameter(torch.ones(4, 1))
dummy_w = torch.nn.Parameter(torch.Tensor([1.0])) # a dummy predictor

In [7]:
opt = torch.optim.SGD([phi], lr=1e-3)
mse = torch.nn.MSELoss(reduction="none")

In [8]:
environments = [
    example(env=0.1),
    example(env=1.0)
]

In [9]:
x = torch.randn(10000, 2) * 0.1
y = x + torch.randn(10000, 2) * 0.1
z = y + torch.randn(10000, 2)

In [10]:
torch.cat((x, z), 1)

tensor([[ 0.1351,  0.0705, -0.1392, -0.2783],
        [-0.1587, -0.3021, -1.4708,  1.3177],
        [-0.0370,  0.0737,  1.0947,  1.0362],
        ...,
        [-0.0604,  0.0324,  0.1619, -0.4181],
        [-0.1433,  0.0794, -0.2072,  1.3918],
        [ 0.0451,  0.0160,  0.7566, -0.5391]])

In [11]:
y

tensor([[ 0.0324,  0.2193],
        [-0.1265, -0.2113],
        [ 0.0225,  0.3116],
        ...,
        [-0.0636,  0.1220],
        [-0.1166,  0.1059],
        [ 0.0513, -0.1420]])

In [12]:
phi.size()

torch.Size([4, 1])

In [19]:
losses = mse(torch.cat((x, z), 1) @ phi * dummy_w, y.sum(1, keepdim=True))
losses.size()

torch.Size([10000, 1])

In [15]:
dummy_w.size()

torch.Size([1])

In [20]:
losses[0::2].size()

torch.Size([5000, 1])

In [21]:
grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]

tensor([4.0774], grad_fn=<SumBackward1>)

In [7]:
for iteration in range(50000):
    error = 0
    penalty = 0
    for x_e, y_e in environments:
        # here we permute data so that we get random minibatches when we compute the penalty.
        # this gives us an unbiased estimate for the squared gradient norm.
        p = torch.randperm(len(x_e))
        error_e = mse(x_e[p] @ phi * dummy_w, y_e[p])
        penalty += compute_penalty(error_e, dummy_w)
        error += error_e.mean()
        
    opt.zero_grad()
    (1e-5 * error + penalty).backward()
    opt.step()
    
    if iteration % 1000 == 0:
        print(phi)

Parameter containing:
tensor([[0.8582],
        [0.8600],
        [0.6869],
        [0.6908]], requires_grad=True)
Parameter containing:
tensor([[0.9308],
        [0.9278],
        [0.1568],
        [0.1556]], requires_grad=True)
Parameter containing:
tensor([[0.9609],
        [0.9567],
        [0.1072],
        [0.1056]], requires_grad=True)
Parameter containing:
tensor([[0.9707],
        [0.9659],
        [0.0853],
        [0.0835]], requires_grad=True)
Parameter containing:
tensor([[0.9753],
        [0.9703],
        [0.0725],
        [0.0706]], requires_grad=True)
Parameter containing:
tensor([[0.9780],
        [0.9728],
        [0.0642],
        [0.0619]], requires_grad=True)
Parameter containing:
tensor([[0.9797],
        [0.9743],
        [0.0582],
        [0.0560]], requires_grad=True)
Parameter containing:
tensor([[0.9808],
        [0.9754],
        [0.0534],
        [0.0516]], requires_grad=True)
Parameter containing:
tensor([[0.9817],
        [0.9762],
        [0.0499],
    

In [8]:
phi

Parameter containing:
tensor([[0.9858],
        [0.9799],
        [0.0231],
        [0.0195]], requires_grad=True)