# Invariant risk minimization

Minimal PyTorch implementation from Appendix D of the paper [Invariant Risk Minimization](https://arxiv.org/abs/1907.02893)

In [1]:
import torch
from torch.autograd import grad

In [2]:
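# Split the per-sample losses into two mini-batches of size b (even- and odd-indexed
# rows). The rows are shuffled in the training iteration, so the two halves are random.
# The product of the two half-batch gradients is an unbiased estimate of the squared
# gradient norm.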

def compute_penalty(losses, dummy_w):
    """Return the dot product of the gradients of two interleaved half-batches.

    Args:
        losses: tensor of per-sample losses that depends on ``dummy_w``; it is
            split along its first dimension into even- and odd-indexed entries.
        dummy_w: tensor (requiring grad) with respect to which the mean loss of
            each half is differentiated.

    Returns:
        A scalar tensor: the sum of the element-wise product of the two
        gradients. Both gradients are built with ``create_graph=True``, so the
        result can itself be back-propagated through.
    """
    half_grads = []
    for start in (0, 1):
        # start=0 picks the even-indexed losses, start=1 the odd-indexed ones.
        half_mean = losses[start::2].mean()
        half_grads.append(grad(half_mean, dummy_w, create_graph=True)[0])
    grad_even, grad_odd = half_grads
    return torch.sum(grad_even * grad_odd)

In [3]:
def example(n=10000, d=2, env=1):
    """Draw one synthetic data set of ``n`` samples.

    Three independent standard-normal tensors of shape ``(n, d)`` are sampled,
    in this order:

    * ``x``: noise scaled by ``env``;
    * ``y``: ``x`` plus fresh noise scaled by ``env``;
    * ``z``: ``y`` plus fresh unit-scale noise (not scaled by ``env``).

    Args:
        n: number of samples (rows).
        d: width of each of ``x``, ``y`` and ``z``.
        env: scale applied to the noise of ``x`` and ``y``.

    Returns:
        A tuple ``(features, target)`` where ``features`` is ``x`` and ``z``
        concatenated along dim 1, shape ``(n, 2 * d)``, and ``target`` is
        ``y`` summed over dim 1 with the dimension kept, shape ``(n, 1)``.
    """
    base = torch.randn(n, d) * env
    middle = base + torch.randn(n, d) * env
    noisy = middle + torch.randn(n, d)

    features = torch.cat([base, noisy], dim=1)
    target = middle.sum(dim=1, keepdim=True)
    return features, target

In [4]:
# phi: trainable (4, 1) weight initialised to ones; 4 matches the 2 * d feature
# columns produced by example() with its default d=2.
phi = torch.nn.Parameter(torch.ones(4, 1))
# dummy_w: scalar fixed at 1.0, used only as the variable the penalty gradient is
# taken with respect to. Built with torch.tensor (the recommended factory for
# creating a tensor from data) rather than the legacy torch.Tensor constructor.
dummy_w = torch.nn.Parameter(torch.tensor([1.0]))  # a dummy predictor

In [5]:
# Per-sample squared errors (no reduction), so that compute_penalty can split
# them into two half-batches.
mse = torch.nn.MSELoss(reduction="none")
# Plain SGD over phi only; dummy_w is not handed to the optimizer.
opt = torch.optim.SGD([phi], lr=1e-3)

In [6]:
# Two training environments that differ only in their noise scale (0.1 and 1.0);
# each one is an (inputs, targets) pair drawn with example()'s default n and d.
environments = [example(env=scale) for scale in (0.1, 1.0)]

In [7]:
for iteration in range(50000):
    # Running sums, over environments, of the mean loss and of the penalty.
    error, penalty = 0, 0
    for x_e, y_e in environments:
        # Re-shuffle the rows on every pass so that the even/odd split inside
        # compute_penalty yields two random mini-batches; this gives an
        # unbiased estimate of the squared gradient norm.
        perm = torch.randperm(len(x_e))
        losses_e = mse(x_e[perm] @ phi * dummy_w, y_e[perm])
        penalty += compute_penalty(losses_e, dummy_w)
        error += losses_e.mean()

    # Objective: the summed mean loss, down-weighted by 1e-5, plus the penalty.
    objective = 1e-5 * error + penalty
    opt.zero_grad()
    objective.backward()
    opt.step()

    # Report the current phi every 1000 iterations.
    if iteration % 1000 == 0:
        print(phi)

Parameter containing:
tensor([[0.8598],
        [0.8562],
        [0.6880],
        [0.6795]], requires_grad=True)
Parameter containing:
tensor([[0.9306],
        [0.9261],
        [0.1654],
        [0.1612]], requires_grad=True)
Parameter containing:
tensor([[0.9655],
        [0.9609],
        [0.1134],
        [0.1104]], requires_grad=True)
Parameter containing:
tensor([[0.9778],
        [0.9729],
        [0.0900],
        [0.0871]], requires_grad=True)
Parameter containing:
tensor([[0.9836],
        [0.9787],
        [0.0764],
        [0.0737]], requires_grad=True)
Parameter containing:
tensor([[0.9869],
        [0.9823],
        [0.0672],
        [0.0648]], requires_grad=True)
Parameter containing:
tensor([[0.9891],
        [0.9845],
        [0.0609],
        [0.0586]], requires_grad=True)
Parameter containing:
tensor([[0.9906],
        [0.9860],
        [0.0558],
        [0.0540]], requires_grad=True)
Parameter containing:
tensor([[0.9917],
        [0.9873],
        [0.0517],
    

In [8]:
# Show phi as the cell output, as it stands after the training loop above has run.
phi

Parameter containing:
tensor([[0.9982],
        [0.9936],
        [0.0207],
        [0.0214]], requires_grad=True)