In [None]:
import torch
import pickle
import matplotlib.pyplot as plt
from data_loader import unsupervised_dataloaders
from models.mlp import MLP, BernoulliMLP
from models.train_models import MSELoss
from models.spinn import BernoulliSPINN, make_schedule, train_model_sequence

In [None]:
# Load data
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the notebook still runs end-to-end on machines without a usable GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Build the three data loaders; the same call also yields the values bound
# to `mean`, `std` and `total_variance`. Mean adjustment and normalization
# are both switched on, and `device` is forwarded to the loader factory.
(train_loader, val_loader, test_loader,
 mean, std, total_variance) = unsupervised_dataloaders(
    mean_adjustment=True, normalization=True, device=device)

# Only the second entry of the dataset shape is kept; it is used as
# `original_dim` when sizing the models.
_, original_dim = train_loader.dataset.get_shape()

print('Total variance = {:.4f}'.format(total_variance))

In [None]:
# Subset sizes (number of selected variables) to identify.
# Kept as a tuple: it is passed to the ranking step, used as the x-axis of
# the result plots, and pickled alongside the results.
num_variables_list = (
    5, 10, 20, 30, 40,
    50, 60, 80, 100,
)

# Perform feature selection (ranking)

In [None]:
# Number of independent ranking trials to run
num_trials = 1
# One entry per trial; each entry is a list of {'inds': subset} dicts
trial_results = []

for trial in range(num_trials):
    # Seed torch's global RNGs with the trial index so that re-running the
    # notebook repeats each trial, while different trials still get
    # different seeds.
    # NOTE(review): this only covers torch's generators; if the project code
    # draws from numpy / random as well, those need seeding too -- confirm.
    torch.manual_seed(trial)

    # Create model (same width on both ends: original_dim in, original_dim out
    # -- presumably (input_size, output_size); verify against BernoulliMLP)
    model = BernoulliMLP(original_dim,
                         original_dim,
                         hidden=[100, 100, 100, 100],
                         activation='elu',
                         p=0.01,
                         reference=0,
                         penalty='log').to(device=device)

    # Learn SPINN: run the ranking procedure for the requested subset sizes
    spinn = BernoulliSPINN(model)
    spinn.train_ranking(num_variables_list,
                        train_loader,
                        val_loader,
                        lr=1e-3,
                        mbsize=256,
                        nepochs=250,
                        lam=1.0,
                        check_every=250)

    # Record subsets: wrap every entry of spinn.subsets in its own dict so
    # later steps can attach further results next to the 'inds' key.
    trial_results.append([{'inds': subset} for subset in spinn.subsets])
    print('Done with trial {}'.format(trial))

# Train debiased models

In [None]:
# The debiased model is sized by the largest requested subset.
largest_subset_size = max(num_variables_list)

# Training hyperparameters shared by every trial.
training_options = dict(lr=1e-3,
                        mbsize=256,
                        nepochs=250,
                        check_every=250,
                        lookback=20,
                        task_name='reconstruction')

for subset_records in trial_results:
    # Fresh MLP for each trial. The two positional sizes are
    # largest_subset_size and the remaining original_dim - largest_subset_size
    # (presumably input and output widths -- verify against MLP).
    model = MLP(largest_subset_size,
                original_dim - largest_subset_size,
                hidden=[100, 100, 100, 100],
                activation='elu').to(device=device)

    # The plotting cell later reads result['reconstruction'] from these same
    # dicts, so train_model_sequence is expected to store its scores on them
    # in place under the task name.
    train_model_sequence(subset_records,
                         model,
                         train_loader,
                         val_loader,
                         test_loader,
                         **training_options)

In [None]:
# Plot results: one panel per data split, one line per trial.
fig, axarr = plt.subplots(1, 2, figsize=(16, 6), sharey=True)

# Both panels are drawn the same way; only the split key and title differ.
for ax, split, title in zip(axarr, ('train', 'val'), ('Train', 'Val')):
    for results in trial_results:
        # Each point is the subset's 'reconstruction' value for this split,
        # divided by the dataset's total variance.
        ax.plot(num_variables_list,
                [result['reconstruction'][split] / total_variance for result in results],
                color='C0', alpha=0.5)
    ax.set_title(title)
    ax.set_xlabel('Number of variables')

# The y-axis is shared (sharey=True), so the label goes on the left panel only
# and setting the limits on one panel applies to both.
axarr[0].set_ylabel('Reconstruction error / total variance')
# Fixed viewing window; widen it if the curves fall outside this range.
axarr[0].set_ylim(0.62, 0.75)

plt.tight_layout()
plt.show()

In [None]:
# Collect the values to persist: the normalization constant, the subset
# sizes, and the per-trial results.
save_dict = {}
save_dict['total_variance'] = total_variance
save_dict['num_variables'] = num_variables_list
save_dict['trial_results'] = trial_results

# Written to the current working directory (the file name contains spaces).
with open('bernoulli spinn ranking results.pkl', 'wb') as results_file:
    pickle.dump(save_dict, results_file)