# IRM on churn data

The news popularity dataset seems hard to model. Let's try our old favourite churn dataset.

For IRM to work, we need to be able to define several _environments_ on which to train. An environment is the result of an intervention: something that changed the data generating process. The environments need to be different enough that environment-specific (spurious) correlations vary between them, yet similar enough that the relationships we care about still hold in all of them. IRM then aims to return an invariant representation: one that has learned the correlations that hold across environments and ignored spurious correlations specific to a single environment.
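For reference, the practical objective usually associated with IRM (often called IRMv1) can be sketched as follows. This is the commonly quoted formulation, not something stated in this notebook: $\Phi$ is the predictor, $R^e$ the risk in training environment $e$, $\lambda$ the penalty weight, and $w$ a dummy scalar classifier fixed at 1. The code further down averages over environments rather than summing, which only changes the objective by a constant factor.

```latex
\min_{\Phi} \; \sum_{e \in \mathcal{E}_{\mathrm{tr}}}
  \left[ R^e(\Phi) + \lambda \,
  \bigl\| \nabla_{w \mid w = 1.0} \, R^e(w \cdot \Phi) \bigr\|^2 \right]
```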

For the churn dataset, it's not clear what an environment could be. Let's construct a plausible business story: our telco company has lots of data for single people with no dependents (it has previously marketed to that demographic), but is launching a family-oriented brand. It needs to do churn modeling on the family brand, but has little data. As such, we'll construct four environments from the features "Partner" and "Dependents", reserving the case where both Partner and Dependents are true as the test environment.
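As a minimal sketch of the idea, two binary columns split a table into four environments. The toy rows below are made up for illustration, and the `roles` mapping is a hypothetical assignment; the story above only fixes the test role (both true).

```python
import pandas as pd

# Toy stand-in for the churn data; values are invented for illustration.
toy = pd.DataFrame({
    "Partner":    ["No", "No", "Yes", "Yes", "No"],
    "Dependents": ["No", "Yes", "No", "Yes", "No"],
    "Churn":      ["Yes", "No", "No", "No", "Yes"],
})

# Each (Partner, Dependents) combination becomes one environment.
environments = {
    key: group for key, group in toy.groupby(["Partner", "Dependents"])
}

# Hypothetical role assignment; only the test role is fixed by the story.
roles = {
    ("No", "No"): "train",
    ("No", "Yes"): "train",
    ("Yes", "No"): "validation",
    ("Yes", "Yes"): "test",
}

for key, group in environments.items():
    print(key, roles[key], len(group))
```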

In [1]:
import pandas as pd

import torch

from torch import nn, optim, autograd
from torch.nn import functional as F

In [2]:
df = pd.read_csv('../../data/churn.csv').drop(['customerID','TotalCharges'], axis='columns')

One-hot encode the categorical columns to get a numeric, machine-learnable dataset. This produces some awkward column names (such as `Churn_Yes`), but this is exploration, so we'll live with it.

In [3]:
df_ = pd.get_dummies(df, drop_first=True)
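To see where names like `Partner_Yes` and `Churn_Yes` come from, here is a small sketch on a made-up frame (the rows are illustrative, not taken from the churn data). With `drop_first=True`, the first category of each column (here "No") is dropped and the remaining one is encoded as `<column>_<category>`.

```python
import pandas as pd

# Made-up rows that mimic the shape of the churn data.
toy = pd.DataFrame({
    "tenure":  [1, 34, 2],
    "Partner": ["Yes", "No", "No"],
    "Churn":   ["No", "No", "Yes"],
})

# Numeric columns pass through; string columns become indicator columns.
encoded = pd.get_dummies(toy, drop_first=True)
print(list(encoded.columns))
```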

Since during IRM we'll be training on several environments, we wrap each in a dict for easy management.

In [4]:
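def construct_env(df):
    # Partner and Dependents define the environments, so they are dropped from
    # the features along with the target. Casting to float32 keeps mixed
    # bool/numeric dummy columns from turning into an object array.
    features = df.drop(['Churn_Yes', 'Partner_Yes', 'Dependents_Yes'],
                       axis='columns').to_numpy(dtype='float32')
    target = df['Churn_Yes'].to_numpy(dtype='float32')
    return {
        'features': torch.tensor(features),
        'target': torch.tensor(target).unsqueeze(dim=1)
    }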

Define our neural net architecture. We're starting with a straightforward MLP with ReLU nonlinearities and a sigmoid output, since it's a classification problem.

In [5]:
class NN(nn.Module):
    def __init__(self, n_features, hidden_dim):
        super().__init__()
        self.layer1 = nn.Linear(n_features, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = nn.Linear(hidden_dim, 1)
        
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = torch.sigmoid(self.layer3(x))
        return x

Define a bunch of utility functions for calculating errors and such to report during training.

In [6]:
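def error(logits, target):
    # note: despite the name, `logits` here are the sigmoid outputs of the
    # network (probabilities), which is what binary_cross_entropy expects
    loss = F.binary_cross_entropy(logits, target)
    return loss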

In [7]:
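def penalty(logits, target):
    # IRM-style penalty: squared gradient of the loss with respect to a
    # dummy scale (fixed at 1) multiplying the network output
    dummy = torch.tensor(1., requires_grad=True)
    loss = error(logits*dummy, target)
    # create_graph=True so the penalty itself can be backpropagated through
    grad = autograd.grad(loss, [dummy], create_graph=True)[0]
    squared_grad_norm = (grad**2).sum()
    return squared_grad_norm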
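The penalty above scales the sigmoid outputs. A variant that scales raw (pre-sigmoid) logits instead can be sketched as below; this is an illustration under the assumption that the network returned logits, not the method used in this notebook, and `penalty_from_logits` is a hypothetical name. The closed-form check follows from the derivative of the logistic loss: the gradient with respect to the scale at 1 is the mean of `(sigmoid(z) - y) * z`.

```python
import torch
from torch import autograd
from torch.nn import functional as F

def penalty_from_logits(logits, target):
    # Dummy scale fixed at 1.0; only its gradient is of interest.
    scale = torch.tensor(1.0, requires_grad=True)
    loss = F.binary_cross_entropy_with_logits(logits * scale, target)
    grad = autograd.grad(loss, [scale], create_graph=True)[0]
    return grad ** 2

torch.manual_seed(0)
logits = torch.randn(8, 1)                   # made-up pre-sigmoid outputs
target = (torch.rand(8, 1) > 0.5).float()    # made-up binary labels

# Closed form: d/ds BCE(s * z, y) at s = 1 equals mean((sigmoid(z) - y) * z).
analytic = ((torch.sigmoid(logits) - target) * logits).mean() ** 2
numeric = penalty_from_logits(logits, target)
print(numeric.item(), analytic.item())
```

The two printed values should agree up to floating point error, which is a quick sanity check that the autograd-based penalty computes what it is meant to.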

In [8]:
def accuracy(predictions, target):
    n_preds = torch.tensor(len(predictions)).float()
    acc = ((predictions == target).sum() / n_preds)
    return acc

In [9]:
def precision(predictions, target):
    tp = ((predictions == 1) & (target == 1)).sum().float()
    fp = ((predictions == 1) & (target == 0)).sum().float()
    # NaN if there are no positive predictions (tp + fp == 0)
    prec = tp / (tp + fp)
    return prec

In [10]:
def recall(predictions, target):
    tp = ((predictions == 1) & (target == 1)).sum().float()
    fn = ((predictions == 0) & (target == 1)).sum().float()
    # NaN if there are no positive targets (tp + fn == 0)
    rec = tp / (tp + fn)
    return rec
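Precision and recall computed as plain ratios become NaN when the denominator is zero, for instance when the network predicts no positives early in training. A minimal sketch of a guarded version follows; `safe_ratio` and `safe_precision` are hypothetical helpers, and returning 0 in the degenerate case is a convention chosen here, not the only reasonable one.

```python
import torch

def safe_ratio(numerator, denominator):
    # Return 0 instead of NaN when the denominator is zero.
    if denominator == 0:
        return torch.tensor(0.0)
    return numerator / denominator

def safe_precision(predictions, target):
    tp = ((predictions == 1) & (target == 1)).sum().float()
    fp = ((predictions == 1) & (target == 0)).sum().float()
    return safe_ratio(tp, tp + fp)

target = torch.tensor([[1.0], [0.0], [1.0]])
no_positives = torch.tensor([[False], [False], [False]])
some_positives = torch.tensor([[True], [True], [False]])

print(safe_precision(no_positives, target))    # guarded degenerate case
print(safe_precision(some_positives, target))  # one true and one false positive
```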

Construct environments. We hold out a final test set of customers with both dependents and partners (the naming of validation and test is arguably the wrong way around here). We train on two environments, both made up of customers without partners; the two are distinguished by whether the customer has dependents. The validation set is the remaining combination (with partner, without dependents), which we can use to choose an early stopping time.

To start a new training procedure, we need to run all the code below here, since the environments are mutable dictionaries that pick up entries during training.

In [11]:
env_test = construct_env(df_[(df_.Partner_Yes == 1) & (df_.Dependents_Yes == 1)])
env_valid = construct_env(df_[(df_.Partner_Yes == 1) & (df_.Dependents_Yes == 0)])
env_1 = construct_env(df_[(df_.Partner_Yes == 0) & (df_.Dependents_Yes == 1)])
env_2 = construct_env(df_[(df_.Partner_Yes == 0) & (df_.Dependents_Yes == 0)])

In [12]:
N_FEATURES = env_1['features'].shape[1]
HIDDEN_DIM = 10

In [13]:
net = NN(N_FEATURES, HIDDEN_DIM)

In [14]:
opt = optim.Adam(net.parameters(), lr=1e-3)

In [15]:
for iteration in range(20001):
    for env in [env_1, env_2]:
        logits = net(env['features'])
        env['error'] = error(logits, env['target'])
        env['penalty'] = penalty(logits, env['target'])
    
    train_error = torch.stack([env_1['error'], env_2['error']]).mean()
    train_penalty = torch.stack([env_1['penalty'], env_2['penalty']]).mean()
    
    # IRM penalty deactivated to begin with (plain ERM baseline);
    # the IRM variant would be: (train_error + 1e6 * train_penalty) / 1e6
    total_loss = train_error
        
    opt.zero_grad()
    total_loss.backward()
    opt.step()
    
    # evaluation only: no gradients are needed for the metric predictions
    with torch.no_grad():
        valid_preds = net(env_valid['features']) > 0.5
        env_1_preds = net(env_1['features']) > 0.5
        env_2_preds = net(env_2['features']) > 0.5
        test_preds = net(env_test['features']) > 0.5
    
    # ## train environment metrics
    env_1['accuracy'] = accuracy(env_1_preds, env_1['target'])
    env_1['precision'] = precision(env_1_preds, env_1['target'])
    env_1['recall'] = recall(env_1_preds, env_1['target'])
    
    env_2['accuracy'] = accuracy(env_2_preds, env_2['target'])
    env_2['precision'] = precision(env_2_preds, env_2['target'])
    env_2['recall'] = recall(env_2_preds, env_2['target'])
    
    # ## validation set metrics
    env_valid['accuracy'] = accuracy(valid_preds, env_valid['target'])
    env_valid['precision'] = precision(valid_preds, env_valid['target'])
    env_valid['recall'] = recall(valid_preds, env_valid['target'])
    
    if iteration % 1000 == 0:
        print('---')
        print('iteration: {}, train_loss: {:.8f}'.format(iteration, total_loss.item()))
        print('env_1 accuracy: {:.3f}, precision: {:.3f}, and recall: {:.3f}'.format(
            env_1['accuracy'], env_1['precision'], env_1['recall']
        ))
        print('env_2 accuracy: {:.3f}, precision: {:.3f}, and recall: {:.3f}'.format(
            env_2['accuracy'], env_2['precision'], env_2['recall']
        ))
        print('validation accuracy: {:.3f}, precision: {:.3f}, and recall: {:.3f}'.format(
            env_valid['accuracy'], env_valid['precision'], env_valid['recall']
        ))

---
iteration: 0, train_loss: 0.64038801
env_1 accuracy: 0.729, precision: 0.182, and recall: 0.078
env_2 accuracy: 0.628, precision: 0.303, and recall: 0.066
validation accuracy: 0.742, precision: 0.364, and recall: 0.019
---
iteration: 1000, train_loss: 0.42862311
env_1 accuracy: 0.834, precision: 0.660, and recall: 0.455
env_2 accuracy: 0.753, precision: 0.655, and recall: 0.589
validation accuracy: 0.806, precision: 0.653, and recall: 0.502
---
iteration: 2000, train_loss: 0.41983685
env_1 accuracy: 0.848, precision: 0.704, and recall: 0.494
env_2 accuracy: 0.757, precision: 0.657, and recall: 0.608
validation accuracy: 0.804, precision: 0.651, and recall: 0.493
---
iteration: 3000, train_loss: 0.41280931
env_1 accuracy: 0.859, precision: 0.717, and recall: 0.558
env_2 accuracy: 0.755, precision: 0.654, and recall: 0.602
validation accuracy: 0.805, precision: 0.654, and recall: 0.490
---
iteration: 4000, train_loss: 0.40702993
env_1 accuracy: 0.861, precision: 0.729, and recall: 0.

In [17]:
# baseline accuracy (majority class predictor)
(1-env_valid['target']).sum() / len(env_valid['target'])

tensor(0.7459)
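The training loop above runs with the IRM penalty switched off, and its commented-out expression suggests how the penalty would be folded in: add the weighted penalty, then divide by the weight so the loss stays at a comparable scale. A minimal sketch of that combination, with a warm-up schedule, might look like the following. The function names, the `anneal_iters` value, and the step-shaped schedule are all assumptions for illustration; only the weight of `1e6` appears in the notebook.

```python
def irm_total_loss(train_error, train_penalty, penalty_weight):
    """Combine error and penalty, rescaling when the weight is large."""
    loss = train_error + penalty_weight * train_penalty
    if penalty_weight > 1.0:
        # keeps the loss magnitude comparable once the weight jumps
        loss = loss / penalty_weight
    return loss

def penalty_weight_at(iteration, anneal_iters=1000, final_weight=1e6):
    # Hypothetical schedule: weight 1.0 during warm-up, then the full weight.
    return final_weight if iteration >= anneal_iters else 1.0

# Works on plain floats as well as on 0-dim tensors.
print(irm_total_loss(0.5, 0.25, penalty_weight_at(0)))
print(irm_total_loss(0.5, 0.0, penalty_weight_at(1000)))
```

Note that with a weight of 1.0 the penalty still contributes to the loss, so this warm-up is not identical to the pure error-only baseline trained above.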