# Reproduction
The following notebook aims to reproduce the __empirical results__ of the paper ["Feature learning in neural networks and kernel machines that recursively learn features"](https://arxiv.org/abs/2212.13881). These experiments aim to show that Recursive Feature Machines (RFMs) learn features extremely similar to those learned by neural networks. The reproduction focuses on 3 experiments:
1. __Key result__: RFMs and neural networks learn similar features
2. __Tabular data__: RFMs outperform most models on tabular data
3. __Special phenomena__: Both neural networks and RFMs exhibit grokking and simplicity bias

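```python
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

from rfm import rfm, eval_rfm  # local module
from nn import MLP  # local module
import trainer as t  # assumption: local module that provides train_network
```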

## 1. Key result: RFMs and neural networks learn similar features
The original paper shows that RFMs and neural networks learn similar features on the [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset. The code below reproduces only this experiment, since it is the main empirical result motivating the use of RFMs to study neural networks.
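
For reference, the RFM side of this comparison can be sketched in a few lines of NumPy. As we read the paper, an RFM alternates two steps: fit kernel ridge regression with a Laplace kernel `K_M(x, z) = exp(-||x - z||_M / L)`, where `||v||_M = sqrt(v^T M v)`, and then replace `M` by the average gradient outer product (AGOP) of the fitted predictor over the training points. The function names, the bandwidth `L = 10`, and dropping the non-differentiable `x = z` term from the gradient are our assumptions; this is a sketch, not the API of the `rfm` module imported above.

```python
import numpy as np


def mahalanobis_laplace(X, Z, M, L=10.0):
    """Return K(x, z) = exp(-||x - z||_M / L) and the distances ||x - z||_M."""
    XM, ZM = X @ M, Z @ M
    # ||x - z||_M^2 = x^T M x + z^T M z - 2 x^T M z  (M symmetric)
    sq = (np.sum(XM * X, axis=1)[:, None]
          + np.sum(ZM * Z, axis=1)[None, :]
          - 2.0 * XM @ Z.T)
    dist = np.sqrt(np.clip(sq, 0.0, None))  # clip tiny negatives from round-off
    return np.exp(-dist / L), dist


def agop(X, alpha, M, L=10.0):
    """Average gradient outer product of f(x) = sum_i alpha_i K(x, x_i) over the rows of X."""
    K, dist = mahalanobis_laplace(X, X, M, L)
    # grad_x K(x, z) = -(K / (L * ||x - z||_M)) * M (x - z); assumption: the x == z term is dropped
    W = np.where(dist > 1e-6, K / np.maximum(dist, 1e-6), 0.0)
    n, c = alpha.shape
    # term_self[j, k] = (sum_i W[j, i] alpha[i, k]) * x_j
    term_self = (W @ alpha)[:, :, None] * X[:, None, :]
    # term_other[j, k] = sum_i W[j, i] alpha[i, k] * x_i
    term_other = np.stack([(W * alpha[:, k]) @ X for k in range(c)], axis=1)
    G = -(term_self - term_other) @ M / L  # gradients, shape (n, c, d)
    return np.einsum("jcd,jce->de", G, G) / n


def rfm_fit(X, y, iters=3, reg=1e-3, L=10.0):
    """Alternate kernel ridge regression with K_M and the update M <- AGOP."""
    n, d = X.shape
    M = np.eye(d)  # start from the plain Laplace kernel
    for _ in range(iters):
        K, _ = mahalanobis_laplace(X, X, M, L)
        alpha = np.linalg.solve(K + reg * np.eye(n), y)
        M = agop(X, alpha, M, L)
    # Final kernel ridge fit with the last learned feature matrix
    K, _ = mahalanobis_laplace(X, X, M, L)
    alpha = np.linalg.solve(K + reg * np.eye(n), y)
    return alpha, M


def rfm_predict(X_test, X_train, alpha, M, L=10.0):
    K, _ = mahalanobis_laplace(X_test, X_train, M, L)
    return K @ alpha
```

On a toy target that depends only on the first input coordinate (e.g. `y = X[:, [0]]`), one would expect the diagonal of the learned `M` to be largest at that coordinate, which is the sense in which an RFM "learns features".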

```python
# Fix the random seeds for reproducibility
SEED = 5636
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.cuda.manual_seed(SEED)
cudnn.benchmark = True  # faster convolutions, but cuDNN may then pick non-deterministic kernels

# Resize images to SIZE x SIZE and convert them to tensors with values in [0, 1]
SIZE = 96
transform = transforms.Compose([transforms.Resize([SIZE, SIZE]), transforms.ToTensor()])

celeba_path = "~/datasets/"
trainset = torchvision.datasets.CelebA(
    root=celeba_path, split="train", transform=transform, download=True
)

# `get_balanced_data` and `split` are helper functions defined elsewhere in this
# repository (not shown here); `split` presumably holds out 20% for validation
trainset = get_balanced_data(trainset)
trainset, valset = split(trainset, p=0.8)

print("Train Size: ", len(trainset), "Val Size: ", len(valset))

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=1
)

valloader = torch.utils.data.DataLoader(
    valset, batch_size=100, shuffle=False, num_workers=1
)

testset = torchvision.datasets.CelebA(
    root=celeba_path, split="test", transform=transform, download=True
)

testset = get_balanced_data(testset)
print("Test Size: ", len(testset))

testloader = torch.utils.data.DataLoader(
    testset, batch_size=512, shuffle=False, num_workers=1
)

# Optional name for saving model
name = "glasses"

# Train the RFM (`rfm` is the function imported from the local `rfm` module)
rfm(trainloader, valloader, testloader, name=name, reg=1e-3, iters=1)

# Train the neural network (`t` must be the local module that defines `train_network`)
t.train_network(trainloader, valloader, testloader, name=name)
```
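
As we read the paper, the key result compares the first-layer neural feature matrix `W1^T W1` of a trained network with the feature matrix `M` learned by an RFM. The sketch below shows one way to score that similarity once both matrices have been extracted as NumPy arrays (e.g. the weight of the network's first fully connected layer, whatever attribute name the local `MLP`/trainer code uses). The helper names, the use of Pearson correlation over matrix entries, and the `(3, 96, 96)` image layout are our assumptions.

```python
import numpy as np


def neural_feature_matrix(W1):
    """Gram matrix W1^T W1 for a first-layer weight of shape (hidden, d), as in torch.nn.Linear.weight."""
    return W1.T @ W1


def pearson_similarity(A, B):
    """Pearson correlation between the entries of two matrices of the same shape."""
    a = np.ravel(A) - np.mean(A)
    b = np.ravel(B) - np.mean(B)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


def diagonal_as_image(M, shape=(3, 96, 96)):
    """Reshape diag(M) to (channels, height, width).

    Assumes inputs were flattened from ToTensor's C x H x W layout.
    """
    return np.diag(M).reshape(shape)
```

With `SIZE = 96` RGB inputs, `d = 3 * 96 * 96 = 27,648`, so a full `d x d` float32 matrix takes roughly 3 GB; comparing the diagonals, viewed as images, is a much cheaper first check than correlating the full matrices.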

## 2. Tabular data: RFMs outperform most models on tabular data
The paper reports that RFMs achieve higher accuracy on tabular benchmark datasets than other models, such as XGBoost and MLPs. This section reproduces those results on a couple of benchmarks.
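
A comparison like this needs a common evaluation loop into which each model (RFM, XGBoost, MLP) is plugged. The sketch below is a hypothetical harness: features are z-scored with training statistics, classification is done by regressing on one-hot labels and taking the argmax, and any model is passed in as a `fit_predict(X_train, Y_train, X_test)` callable. These preprocessing choices are our assumptions, not necessarily the paper's protocol, and the ridge model is only a placeholder baseline.

```python
import numpy as np


def standardize(X_train, X_test):
    """Z-score features using training-set statistics only."""
    mu = X_train.mean(axis=0)
    sd = X_train.std(axis=0)
    sd[sd == 0] = 1.0  # leave constant columns unscaled
    return (X_train - mu) / sd, (X_test - mu) / sd


def one_hot(y, num_classes):
    return np.eye(num_classes)[y]


def accuracy_from_scores(scores, y):
    """Predicted class = argmax over per-class scores."""
    return float(np.mean(np.argmax(scores, axis=1) == y))


def evaluate(fit_predict, X_train, y_train, X_test, y_test):
    num_classes = int(max(y_train.max(), y_test.max())) + 1
    Xtr, Xte = standardize(X_train, X_test)
    scores = fit_predict(Xtr, one_hot(y_train, num_classes), Xte)
    return accuracy_from_scores(scores, y_test)


def ridge_fit_predict(X_train, Y_train, X_test, reg=1e-3):
    """Placeholder linear baseline; X_train is centered, so the intercept is mean(Y_train)."""
    Y_mean = Y_train.mean(axis=0)
    d = X_train.shape[1]
    w = np.linalg.solve(X_train.T @ X_train + reg * np.eye(d), X_train.T @ (Y_train - Y_mean))
    return X_test @ w + Y_mean
```

Accuracies from `evaluate` can then be collected per dataset and per model, so that each model is scored with exactly the same splits and preprocessing.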

## 3. Special phenomena: Both neural networks and RFMs exhibit grokking and simplicity bias
Finally, the paper shows that RFMs, like neural networks, exhibit grokking and simplicity bias. Briefly:
- __Grokking__ is a dramatic increase in test accuracy that occurs when training continues well past the point of overfitting. Arguably, it happens when the model learns a latent algorithm that solves the task (e.g. when the data are generated algorithmically).
- __Simplicity bias__ is the tendency of networks to rely on simpler features to solve the task, even when more complex features are available. One possible explanation is that neural networks tend to learn functions that are as close to linear as possible.

This last section reproduces the grokking and simplicity-bias results for both RFMs and neural networks.
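
Grokking is usually studied on small algorithmic datasets such as modular arithmetic, where it was first reported. The sketch below builds a modular-addition dataset with one-hot encoded operands and measures the delay between fitting the training set and generalizing. Whether the paper uses this exact task is not stated here; the modulus `p = 59`, the 50% split, the 0.99 threshold, and both function names are our assumptions.

```python
import numpy as np


def modular_addition_data(p=59, train_frac=0.5, seed=0):
    """All pairs (a, b) labelled (a + b) mod p; inputs are concatenated one-hot vectors."""
    a, b = np.meshgrid(np.arange(p), np.arange(p), indexing="ij")
    a, b = a.ravel(), b.ravel()
    eye = np.eye(p)
    X = np.concatenate([eye[a], eye[b]], axis=1)  # shape (p * p, 2 * p)
    y = (a + b) % p
    perm = np.random.default_rng(seed).permutation(p * p)
    n_train = int(train_frac * p * p)
    tr, te = perm[:n_train], perm[n_train:]
    return X[tr], y[tr], X[te], y[te]


def grokking_gap(train_acc, test_acc, threshold=0.99):
    """Steps between first reaching `threshold` train accuracy and `threshold` test accuracy.

    Returns None if either curve never reaches the threshold.
    """
    def first_hit(curve):
        hits = np.flatnonzero(np.asarray(curve) >= threshold)
        return int(hits[0]) if hits.size else None

    t_train, t_test = first_hit(train_acc), first_hit(test_acc)
    if t_train is None or t_test is None:
        return None
    return t_test - t_train
```

A large positive gap from `grokking_gap`, computed on accuracy curves logged per epoch (for the network) or per iteration (for the RFM), is the signature of grokking.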