## Imports

In [15]:
%load_ext autoreload
%autoreload 2
import torch
from torch import Tensor

from Datasets.Dataset_Loaders.breast_cancer_dataset_loader import prepare_wisc_breast_cancer_dataset
from Datasets.Dataset_Loaders.kbit_parity_datset_loader import generate_k_bit_parity_dataset
from Utils.neural_network import initialize_FC_neural_net, add_neuron_to_network, compute_accuracy, forward_pass
from Utils.newton_update import compute_minimizer

# Notebook-wide configuration
# USE_KBIT_DATASET gates the two dataset-loading example cells:
# True runs the k-bit parity cell, False runs the breast cancer cell.
USE_KBIT_DATASET = True
# Run on the GPU when CUDA is available, otherwise stay on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Dataset Loading Examples

In [3]:
# Example: generate the k-bit parity dataset (k=4) and report tensor shapes.
# Only runs when the k-bit dataset is the one selected in the globals cell.
if USE_KBIT_DATASET:
    X, y = generate_k_bit_parity_dataset(k=4)
    # One print call, newline-separated, emits the same two lines as before.
    print(f"X shape: {X.shape}", f"Y shape: {y.shape}", sep="\n")

X shape: torch.Size([16, 4])
Y shape: torch.Size([16, 1])


In [4]:
# Example: load the Wisconsin breast cancer dataset as tensors and report shapes.
# Only runs when the k-bit parity dataset is NOT the one selected.
if not USE_KBIT_DATASET:
    X, y = prepare_wisc_breast_cancer_dataset(convert_to_tensor=True)
    # One print call, newline-separated, emits the same two lines as before.
    print(f"X shape: {X.shape}", f"Y shape: {y.shape}", sep="\n")

## Feedforward Neural Network Construction Algorithm

In [36]:
def run_FNNCA(X: Tensor, y_true: Tensor, 
              acc_threshold: float, max_hidden_units: int, device):
    """ Description: Performs Feedforward Neural Network Construction 
                    Algorithm with BFGS/SR1 update method. Starting from 2
                    hidden units, the network is optimized, scored, and grown
                    by one hidden unit at a time until a stopping rule fires.
        Args:
            X (Tensor): Input tensor of feature vectors; X.shape[1] is taken
                        as the number of input features
            y_true (Tensor): Input tensor of true labels
            acc_threshold (float): Accuracy threshold for the given
                                classification problem, on the same scale as
                                the value returned by compute_accuracy.
                                Growth stops once accuracy >= acc_threshold.
            max_hidden_units (int): Max number of hidden units that can be added
                                    before termination if acc_threshold not reached.
                                    Values below the initial size (2) terminate
                                    after the first optimization.
            device: Device handle forwarded unchanged to the network helpers
        Returns:
            The weights W produced by the last compute_minimizer call.
    """
    # Initialize neural network with 2 initial hidden units
    n_hidden = 2
    n_feats = X.shape[1]
    W = initialize_FC_neural_net(n_feats, n_hidden, device)
    while True:
        # Optimize the current architecture, then score it on the same data
        W = compute_minimizer(W, X, y_true, n_hidden, device)
        _, _, _, y_hat = forward_pass(X, W, n_feats, n_hidden)
        accuracy = compute_accuracy(y_hat, y_true)
        print(f"Number of hidden units: {n_hidden} | Accuracy after optimization: {accuracy}")
        # If accuracy is sufficient or max network size reached, terminate.
        # ">=" on accuracy so that a threshold equal to the maximum attainable
        # accuracy can actually be met; ">=" on the size so that a budget below
        # the initial size cannot cause an endless loop.
        if accuracy >= acc_threshold or n_hidden >= max_hidden_units:
            break
        # Otherwise, add another neuron and reoptimize
        W, n_hidden = add_neuron_to_network(W, n_feats, 
                                n_output_neurons=1, n_hidden=n_hidden, device=device)
    return W

## Reproduce Study Results

In [41]:
## kbit parity problem (k=4)
# Generate the dataset, then move each tensor to the selected device.
X, y_true = generate_k_bit_parity_dataset(k=4)
X = X.to(DEVICE)
y_true = y_true.to(DEVICE)
print(f"device that we are running on: {DEVICE}")
# Grow the network with an accuracy threshold of 80 and a cap of 10 hidden units.
W_final = run_FNNCA(X, y_true, acc_threshold=80, max_hidden_units=10, device=DEVICE)


device that we are running on: cuda


Iteration 9999 | grad norm: 1.74e-04: 100%|██████████| 10000/10000 [01:09<00:00, 144.71it/s]


Number of hidden units: 2 | Accuracy after optimization: 62.5


Iteration 9999 | grad norm: 1.13e-07: 100%|██████████| 10000/10000 [01:09<00:00, 144.38it/s]


Number of hidden units: 3 | Accuracy after optimization: 75.0


Iteration 9999 | grad norm: 4.27e-05: 100%|██████████| 10000/10000 [00:53<00:00, 187.28it/s]

Number of hidden units: 4 | Accuracy after optimization: 81.25





In [49]:
## kbit parity problem (k=7)
# Generate the dataset, then move each tensor to the selected device.
X, y_true = generate_k_bit_parity_dataset(k=7)
X = X.to(DEVICE)
y_true = y_true.to(DEVICE)
print(f"device that we are running on: {DEVICE}")
# Grow the network with an accuracy threshold of 100 and a cap of 10 hidden units.
W_final = run_FNNCA(X, y_true, acc_threshold=100, max_hidden_units=10, device=DEVICE)

device that we are running on: cuda


Iteration 9 | grad norm: 1.83e-03:   0%|          | 0/10000 [00:00<?, ?it/s]

Iteration 9999 | grad norm: 6.07e-06: 100%|██████████| 10000/10000 [01:10<00:00, 141.19it/s]


Number of hidden units: 2 | Accuracy after optimization: 51.5625


Iteration 9999 | grad norm: 1.92e-04: 100%|██████████| 10000/10000 [01:10<00:00, 140.90it/s]


Number of hidden units: 3 | Accuracy after optimization: 50.0


Iteration 9999 | grad norm: 8.44e-06: 100%|██████████| 10000/10000 [01:11<00:00, 140.77it/s]


Number of hidden units: 4 | Accuracy after optimization: 41.40625


Iteration 9999 | grad norm: 5.38e-06: 100%|██████████| 10000/10000 [01:11<00:00, 140.51it/s]


Number of hidden units: 5 | Accuracy after optimization: 49.21875


Iteration 9999 | grad norm: 6.10e-02: 100%|██████████| 10000/10000 [01:09<00:00, 143.85it/s]


Number of hidden units: 6 | Accuracy after optimization: 82.03125


Iteration 9999 | grad norm: 5.27e-02: 100%|██████████| 10000/10000 [01:08<00:00, 145.69it/s]


Number of hidden units: 7 | Accuracy after optimization: 87.5


Iteration 9999 | grad norm: 9.42e-02: 100%|██████████| 10000/10000 [01:08<00:00, 145.36it/s]


Number of hidden units: 8 | Accuracy after optimization: 92.96875


Iteration 168 | grad norm: 4.64e-03:   2%|▏         | 168/10000 [00:00<00:16, 582.56it/s]


Number of hidden units: 9 | Accuracy after optimization: 97.65625


Iteration 11 | grad norm: 4.84e-03:   0%|          | 11/10000 [00:00<00:19, 503.27it/s]


Number of hidden units: 10 | Accuracy after optimization: 97.65625
