<a href="https://colab.research.google.com/gist/lampinen-dm/b6541019ef4cf2988669ab44aa82460b/easy_vs_hard_feature_bias_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Copyright 2024 Google LLC.

SPDX-License-Identifier: Apache-2.0

# Representations are biased by feature complexity

This colab provides a basic demonstration of the feature-complexity bias we describe in "Learned feature representations are biased by complexity, learning order, position, and more" (https://arxiv.org/abs/2405.05847). It contains a simple implementation that captures most of the key properties of the paper's first experiments (binary input features learned by an MLP). It does not attempt to reproduce those experiments exactly. For example, it uses the Adam optimizer for quick experimentation, whereas the first experiments in the paper used SGD with early stopping. Rather, it is intended as a simple illustrative example and a starting point for further investigation.

Some notes:

*   The Sum(W,X,Y,Z) % 2 function we describe in the paper (i.e., the parity of the four inputs) is equivalent to the xor_xor_xor = XOR(XOR(W,X), XOR(Y,Z)) function used here.
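*   To plot the representation variance curves comparably to Fig. 2 in the paper, multiply the R^2 for each feature at step *t* by the ratio `total_variance[t] / total_variance[final]` (from the `total-variance` column of results.csv). This expresses the variance explained at step *t* relative to the total representation variance at the end of training.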





In [None]:
import flax.linen as nn
from flax.training import train_state
import jax, jax.numpy as jnp
import numpy as np
import optax
from sklearn import linear_model

In [None]:
# Experiment configuration.
NUM_SEEDS = 5  # independent runs (network init + datasets) per experiment

# Each entry of FEATURES is either 'linear' or a 'top_left_right' spec built
# from exactly three logical ops in ('and', 'or', 'xor'), e.g. 'and_and_xor' or
# 'xor_or_or', meaning top(left(A, B), right(C, D)) for boolean inputs A-D.
FEATURES = ('linear', 'xor_xor_xor')

TRAIN_DATASET_SIZE = 8_192
TEST_DATASET_SIZE = 1_024
BATCH_SIZE = 1_024

# datasets

In [None]:
OP_FNS = {'and': np.logical_and, 'or': np.logical_or, 'xor': np.logical_xor}


def _make_op_label_fn(feature_type):
  """Return a labeling function for a 'top_left_right' logical-op feature spec.

  The returned function maps a 1-D vector of 0/1 units `bits` to
  top(left(bits[0], bits[1]), right(bits[2], bits[3])), where top/left/right
  are looked up in OP_FNS.

  Raises:
    ValueError: if `feature_type` is not exactly three '_'-separated op names
      from OP_FNS.
  """
  op_names = feature_type.split('_')
  if len(op_names) != 3 or not all(name in OP_FNS for name in op_names):
    raise ValueError(
        f"Invalid feature type {feature_type!r}: expected 'linear' or exactly "
        f"three '_'-separated ops from {sorted(OP_FNS)}, e.g. 'xor_or_and'.")
  top_op, left_op, right_op = (OP_FNS[name] for name in op_names)

  def label_fn(bits):
    return top_op(left_op(bits[0], bits[1]), right_op(bits[2], bits[3]))

  return label_fn


def make_easy_hard_multi_feature_dataset(features=('linear', 'xor_xor_xor'), input_units_per=16, num_examples=128, seed=123):
  """Generate a random binary-input dataset with one label column per feature.

  Each feature owns a contiguous block of `input_units_per` random 0/1 input
  units.
  - A 'linear' feature copies the first unit of its block into the block's
    first 4 units and uses that unit as its label.
  - Any other feature is a 'top_left_right' op spec (see OP_FNS) evaluated on
    the first 4 units of its block. Those rows are rejection-sampled so that,
    visited in a random order, the labels alternate 0/1. This makes the labels
    balanced: exactly half each, or off by one for odd `num_examples`.

  Note: this seeds and draws from NumPy's global RNG (np.random), so it
  changes the global random state seen by later np.random calls.

  Args:
    features: sequence of feature specs ('linear' or 'top_left_right').
    input_units_per: number of input units allocated to each feature.
    num_examples: number of rows to generate.
    seed: seed passed to np.random.seed.

  Returns:
    dict with 'inputs' of shape (num_examples, len(features) * input_units_per)
    and 'labels' of shape (num_examples, len(features)); both hold integer 0/1
    values.

  Raises:
    ValueError: if a non-'linear' feature spec is malformed.
  """
  np.random.seed(seed)
  inputs = np.random.binomial(1, 0.5, (num_examples, len(features) * input_units_per))
  outputs = []
  for feature_i, feature_type in enumerate(features):
    start = feature_i * input_units_per
    # Basic slicing returns a view, so the writes below edit `inputs` in place.
    these_inputs = inputs[:, start:start + input_units_per]
    if feature_type == 'linear':
      # Copy the easy feature across 4 input units to match the number of
      # inputs the logical-op features depend on.
      these_inputs[:, :4] = these_inputs[:, :1]
      these_outputs = these_inputs[:, :1]
    else:
      # Parse the spec once per feature rather than on every rejection sample.
      label_fn = _make_op_label_fn(feature_type)
      these_outputs = np.zeros_like(these_inputs[:, :1])
      for ex_number, ex_index in enumerate(np.random.permutation(num_examples)):
        target = ex_number % 2
        while label_fn(these_inputs[ex_index]) != target:  # rejection sample
          these_inputs[ex_index] = np.random.binomial(1, 0.5, input_units_per)
        these_outputs[ex_index] = label_fn(these_inputs[ex_index])
    outputs.append(these_outputs)
  return {'inputs': inputs, 'labels': np.concatenate(outputs, axis=-1)}

# network

In [None]:
class MLP(nn.Module):
  """Fully connected network that also returns its penultimate activations.

  Attributes:
    layer_sizes: widths of the Dense layers in order. Every layer except the
      last is followed by a leaky ReLU; the last layer produces the raw
      outputs (logits) with no nonlinearity.
  """
  layer_sizes: list[int]

  @nn.compact
  def __call__(self, x):
    """Return (logits, activations feeding the final Dense layer).

    The second element is the raw input `x` when there are no hidden layers.
    """
    # variance_scaling(1, 'fan_avg', 'truncated_normal') is the
    # Glorot/Xavier-normal scheme; every layer shares the same initializer.
    kernel_init = nn.initializers.variance_scaling(scale=1., mode='fan_avg', distribution='truncated_normal')
    hidden_widths = self.layer_sizes[:-1]
    output_width = self.layer_sizes[-1]

    activations = x
    for width in hidden_widths:
      pre_activation = nn.Dense(width, kernel_init=kernel_init)(activations)
      activations = nn.leaky_relu(pre_activation)

    logits = nn.Dense(output_width, kernel_init=kernel_init)(activations)
    return logits, activations

In [None]:
@jax.jit
def apply_model(state, inputs, labels):
  """Compute the loss, its gradients, and per-feature diagnostics for a batch.

  Args:
    state: a TrainState-like object whose `apply_fn({'params': p}, inputs)`
      returns (logits, penultimate_reps), e.g. the MLP module's apply.
    inputs: batch of input vectors.
    labels: binary targets with one column per output feature (presumably
      shaped like the logits, (batch, num_features) — TODO confirm).

  Returns:
    grads: gradients of the mean loss w.r.t. `state.params`.
    loss: sigmoid cross-entropy averaged over examples and features.
    accuracies: per-feature fraction of examples classified correctly.
    penultimate_reps: activations feeding the final layer.
    feature_losses: per-feature cross-entropy averaged over the batch.
  """
  def loss_fn(params):
    logits, penultimate_reps = state.apply_fn({'params': params}, inputs)
    # Use jnp (not np) on traced values; average over the batch axis only so
    # that one loss per output feature is kept.
    loss_array = jnp.mean(optax.sigmoid_binary_cross_entropy(logits=logits, labels=labels), axis=0)
    return jnp.mean(loss_array), (logits, penultimate_reps, loss_array)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (logits, penultimate_reps, feature_losses)), grads = grad_fn(state.params)
  # logit > 0 <=> sigmoid(logit) > 0.5, i.e. a predicted label of 1.
  accuracies = jnp.mean((logits > 0) * 1. == labels, axis=0)
  return grads, loss, accuracies, penultimate_reps, feature_losses


@jax.jit
def update_model(state, grads):
  """Apply one optimizer step and return the resulting train state.

  A jit-compiled wrapper around `state.apply_gradients`.
  """
  updated_state = state.apply_gradients(grads=grads)
  return updated_state


def train_epoch(state, train_ds, batch_size=None):
  """Run one epoch of minibatch updates over a freshly shuffled copy of train_ds.

  Args:
    state: current train state, passed to apply_model / update_model.
    train_ds: dict of arrays sharing the same leading (example) axis; must
      contain 'inputs' and 'labels'. The caller's dict is not modified.
    batch_size: examples per update; defaults to the module-level BATCH_SIZE.
      A final partial batch is used when the dataset size is not a multiple.

  Returns:
    The updated train state.
  """
  if batch_size is None:
    batch_size = BATCH_SIZE
  # Iterate over the dataset's real size rather than TRAIN_DATASET_SIZE, so
  # datasets of any size are covered exactly once per epoch.
  num_examples = train_ds['inputs'].shape[0]
  # Shuffling draws from NumPy's global RNG (np.random).
  order = np.random.permutation(num_examples)
  shuffled = {k: v[order] for k, v in train_ds.items()}
  for start in range(0, num_examples, batch_size):
    batch = slice(start, start + batch_size)
    grads, *_ = apply_model(state, shuffled['inputs'][batch], shuffled['labels'][batch])
    state = update_model(state, grads)
  return state


def create_train_state(rng, num_features, num_input_units_per_feature=16,
                       hidden_layer_sizes=(256, 128, 64, 64), learning_rate=1e-3):
  """Build an MLP for `num_features` binary outputs and wrap it in an Adam TrainState.

  Args:
    rng: JAX PRNG key used to initialize the parameters.
    num_features: number of output logits (one per label feature).
    num_input_units_per_feature: input units per feature; the network input
      width is this times `num_features`.
    hidden_layer_sizes: widths of the hidden Dense layers, in order.
    learning_rate: Adam learning rate.

  Returns:
    A flax TrainState holding the initialized params, the MLP's apply
    function, and an optax Adam optimizer.
  """
  network = MLP(layer_sizes=[*hidden_layer_sizes, num_features])
  # Parameter shapes depend only on the input width, so a dummy batch suffices.
  dummy_inputs = jnp.ones([1, num_input_units_per_feature * num_features])
  params = network.init(rng, dummy_inputs)['params']
  return train_state.TrainState.create(apply_fn=network.apply, params=params, tx=optax.adam(learning_rate))

# analyze

In [None]:
def analyze_rep_var_explained(fit_reps, fit_labels, test_reps, test_labels):
  """Measure how well each single label feature linearly predicts the representations.

  For every label column, fits an ordinary least-squares map from that one
  feature to all representation units on (fit_labels, fit_reps). It then
  scores the map on (test_labels, test_reps) with sklearn's `score`, which for
  multi-output targets is R^2 uniformly averaged over representation units.

  Returns:
    scores: list with one held-out R^2 per label column.
    total_variance: sum over units of the per-unit variance of `test_reps`.
  """
  total_variance = np.sum(np.var(test_reps, axis=0))  # useful to normalize

  def _feature_r2(column):
    single_feature = slice(column, column + 1)  # keep X two-dimensional
    model = linear_model.LinearRegression().fit(fit_labels[:, single_feature], fit_reps)
    return model.score(test_labels[:, single_feature], test_reps)

  scores = [_feature_r2(column) for column in range(fit_labels.shape[-1])]
  return scores, total_variance

# run

In [None]:
# Train NUM_SEEDS networks on FEATURES. Evaluate before training and after each
# of NUM_EPOCHS epochs; print each evaluation and log it as one row of
# ./results.csv.
import csv  # csv.writer quotes fields that contain commas, such as features_str

NUM_EPOCHS = 100  # evaluations are logged for epochs 0..NUM_EPOCHS inclusive

features_str = '(' + ', '.join([str(x) for x in FEATURES]) + ')'
output_labels = ['seed', 'epoch', 'features', 'train-loss', 'test-loss', 'total-variance'] + [
    '%s_feature%i-%s' % (stat, i, str(f))
    for stat in ('train-acc', 'test-acc', 'train-loss', 'test-loss', 'rep-R2')
    for i, f in enumerate(FEATURES)]
output_formats = ['%d', '%d', '%s'] + ['%.8f'] * (len(output_labels) - 3)
print_format = ', '.join([x + ': ' + y for x, y in zip(output_labels, output_formats)]).replace('%.8f', '%.4f')


def _do_eval(writer, seed, epoch, state, train_ds, val_rep_ds, test_ds):
  """Evaluate `state`, print and log one results row, and return per-feature rep R^2 scores."""
  _, train_loss, train_accuracies, train_reps, train_feature_losses = apply_model(state, train_ds['inputs'], train_ds['labels'])
  # Representations of val_rep_ds are used only to fit the label -> representation regressions.
  _, _, _, extra_reps, _ = apply_model(state, val_rep_ds['inputs'], val_rep_ds['labels'])
  _, test_loss, test_accuracies, test_reps, test_feature_losses = apply_model(state, test_ds['inputs'], test_ds['labels'])
  variance_scores, total_variance = analyze_rep_var_explained(extra_reps, val_rep_ds['labels'], test_reps, test_ds['labels'])
  these_results = (seed, epoch, features_str, train_loss, test_loss, total_variance, *train_accuracies, *test_accuracies, *train_feature_losses, *test_feature_losses, *variance_scores)
  print(print_format % these_results, flush=True)
  writer.writerow([fmt % value for fmt, value in zip(output_formats, these_results)])
  return variance_scores


with open('./results.csv', 'w', newline='') as outfile:
  writer = csv.writer(outfile)
  writer.writerow(output_labels)
  for seed in range(NUM_SEEDS):
    state = create_train_state(jax.random.key(123 + seed), num_features=len(FEATURES))
    # Dataset construction reseeds NumPy's global RNG, which train_epoch then uses for shuffling.
    train_ds = make_easy_hard_multi_feature_dataset(features=FEATURES, num_examples=TRAIN_DATASET_SIZE, seed=123 + seed)
    val_rep_ds = make_easy_hard_multi_feature_dataset(features=FEATURES, num_examples=TEST_DATASET_SIZE, seed=1234 + seed)
    test_ds = make_easy_hard_multi_feature_dataset(features=FEATURES, num_examples=TEST_DATASET_SIZE, seed=12345 + seed)
    variance_scores = _do_eval(writer, seed, 0, state, train_ds, val_rep_ds, test_ds)
    for epoch in range(1, NUM_EPOCHS + 1):
      state = train_epoch(state, train_ds)
      # Evaluate after every epoch, including the last, so the final report
      # describes the fully trained state and no training epoch goes unevaluated.
      variance_scores = _do_eval(writer, seed, epoch, state, train_ds, val_rep_ds, test_ds)
    print(f"Seed {seed} representation variance at end of training:\n" + "\n".join([f'{fe}: {sc}' for fe, sc in zip(FEATURES, variance_scores)]))
