In [None]:
from utils import load_data, get_grid_edges, Publisher
from code_p_graph_fast import Pairwise_Graph

from ERM import *
from Ensemble import *
from GibbsNet import *

import random

### Entry point for our code

The code below combines all the other modules in this folder to reproduce our experiments. To keep the default experiment settings, change only the following:
* `mode` (choose between 'AGM', 'EGM' or 'GibbsNet')
* `data_name` (can be one of: 'MNIST', 'CALTECH', 'pumsb_star', 'accidents', 'adult', 'bnetflix', 'connect4', 'jester', 'mushrooms', 'nltcs', 'voting', 'c20ng')

Also, if you would like to save a model, change the `model_save_path` and `model_save_name` from `None` to your desired path and label for the saved file.  

**This should allow reproducing experiments I and II. 
See text after code for experiment III instructions.**

In [None]:
### Change which experiment to run here
mode = 'EGM' # can be 'AGM', 'EGM' or 'GibbsNet'
data_name='MNIST' # can be one of: ['MNIST', 'CALTECH', 'pumsb_star', 'accidents', 'adult', 'bnetflix', 'connect4', 'jester', 'mushrooms', 'nltcs', 'voting', 'c20ng']

### If you would like to save the model, add a path and name for the model
model_save_path = None
model_save_name = None

### Instead of train or test from data_name, if you would like to load custom data from a saved PyTorch tensor
custom_train_data_path = None
custom_test_data_path = None

#####
# Leave unchanged from here for default experiment settings

# Fail fast on a typo in `mode`, before any (potentially slow) data loading.
SUPPORTED_MODES = ('AGM', 'EGM', 'GibbsNet')
if mode not in SUPPORTED_MODES:
    raise NotImplementedError(f"mode must be one of {SUPPORTED_MODES}, got {mode!r}")

cap_train=None
cap_test=1000
train_batch_size=128
test_batch_size=128 if mode != 'AGM' else 16
print(f'data: {data_name}')

# Arguments shared by every load_data call below.
load_kwargs = dict(
    name=data_name,
    custom_train_data_path=custom_train_data_path,
    custom_test_data_path=custom_test_data_path,
    cap_train=cap_train,
    cap_test=cap_test,
)
train_loader, test_loader, variables = load_data(**load_kwargs, train_bs=train_batch_size, test_bs=test_batch_size)

# Batch sizes may not exceed the dataset sizes. The dataset sizes are only known
# after a first load, so the loaders are rebuilt only if clipping changed a batch size.
capped_train_bs = min(len(train_loader.dataset), train_batch_size)
capped_test_bs = min(len(test_loader.dataset), test_batch_size)
if (capped_train_bs, capped_test_bs) != (train_batch_size, test_batch_size):
    train_batch_size, test_batch_size = capped_train_bs, capped_test_bs
    train_loader, test_loader, variables = load_data(**load_kwargs, train_bs=train_batch_size, test_bs=test_batch_size)

alphabet = [0,1]
n_vars = len(variables)
publisher = Publisher()

if data_name in ['MNIST', 'CALTECH']:
    # Image datasets: edges from get_grid_edges over a width-28 grid
    # (presumably pixel-neighbour adjacency on a 28x28 image — confirm).
    chosen_edges = get_grid_edges(width=28)
    mode_images = True
else:
    # Tabular datasets: a random subset of at most 5 * n_vars of all variable pairs.
    from itertools import combinations
    all_pairs = list(combinations(range(n_vars), 2))
    # Use a dedicated, seeded RNG. The global seeds are only set further below,
    # so sampling from the global `random` module here would produce a different
    # graph on every run. A separate RNG also leaves the global RNG streams used
    # for training exactly as before.
    edge_rng = random.Random(0)
    chosen_edges = edge_rng.sample(all_pairs, min(len(all_pairs), n_vars * 5))
    mode_images = False

# NOTE(review): edges index nodes as 0..n_vars-1, while the alphabet dict is
# keyed by `variables`; this assumes `variables` are exactly those indices — confirm.
# Each node gets its own copy of the alphabet list so nothing is accidentally aliased.
G = Pairwise_Graph(edges=chosen_edges, dct_node_idx_to_alphabet={v: list(alphabet) for v in variables})
G.alphabet = list(alphabet)

# Seed all global RNGs for the training run that follows.
torch.random.manual_seed(0)
np.random.seed(0)
random.seed(0)

# Training hyper-parameters for each mode are set in the calls below.

# Experiment III: to save samples from a trained AGM or EGM model, set these
# (see the instructions after this cell). `sample_save_burnin` is used by EGM only.
sample_save_n = None
sample_save_path = None
sample_save_burnin = None

device = 'cuda:0'

# Keyword arguments shared by every training entry point.
common_kwargs = dict(
    data_name=data_name,
    train_loader=train_loader,
    test_loader=test_loader,
    G=G,
    publisher=publisher,
    mode_images=mode_images,
    device=device,
    test_frac=True,
    test_squares=mode_images,
    test_quads=mode_images,
    test_corrupt=mode_images,
    model_save_path=model_save_path,
    model_save_name=model_save_name,
)

if mode == 'AGM':
    main_AGM(
        **common_kwargs,
        z_dimension=None,
        lamb=10,
        M=1000,
        ratio_D_to_G=10,
        lr=1e-4,
        n_bp_steps=5,
        n_steps=10000,
        test_every=5000,
        sample_save_n=sample_save_n,
        sample_save_path=sample_save_path,
    )

elif mode == 'EGM':
    main_EGM(
        **common_kwargs,
        lr=1e-2,
        n_bp_steps=25,
        n_steps=1000,
        test_every=500,
        sample_save_n=sample_save_n,
        sample_save_burnin=sample_save_burnin,
        sample_save_path=sample_save_path,
    )

elif mode == 'GibbsNet':
    main_GibbsNet(
        **common_kwargs,
        z_dimension=None,
        lamb=10,
        sampling_count=5,
        ratio_D_to_G=10,
        lr=5e-5,
        n_steps=10000,
        test_every=5000,
    )

else:
    raise NotImplementedError(f"unknown mode {mode!r}; expected 'AGM', 'EGM' or 'GibbsNet'")
#####

**To reproduce Experiment III**, first train EGM or AGM on data from `data_name` and store samples from this model. Then train an EGM on the saved samples and test it on the test data of `data_name`.

To save samples from EGM or AGM, provide the following arguments to `main_EGM` / `main_AGM`:
* `sample_save_path`: the path at which the samples are saved as a PyTorch tensor
* `sample_save_n`: the number of samples to save, e.g. 1000
* `sample_save_burnin`: the burn-in for the Gibbs sampler (EGM only, since EGM draws its samples with a Gibbs sampler)

Hence, Experiment III requires two runs. In the first run, choose a `data_name` and save samples by providing a `sample_save_path` and `sample_save_n=1000`. If the model is EGM, also provide a `sample_save_burnin`.

For the second run, provide the same `data_name` again so that its test data can be fetched. To train on your sampled data, set `custom_train_data_path` to the `sample_save_path` used in the first run.

Side note: The data loader always loads both the train and test splits of `data_name`. It then discards the train split if `custom_train_data_path` is provided, and the test split if `custom_test_data_path` is provided.