In [1]:
import itertools
import functools

import torch

import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

from torchvision import datasets, transforms
import torchvision.transforms as transforms

import os
import torch
import numpy as np
import matplotlib.pyplot as plt

import random

import copy
import math
import pickle
import argparse

from torch.optim.lr_scheduler import StepLR
import torchvision
import imageio.v2 as imageio


from mup_nets import ReLUResNetMUP_manual
from train_utils import get_random_data_unif_binary, train_mse, test_mse, train_class_erm, test_class, test_class_acc, ERMDatasetFromFuncBinary
from decision_tree_utils import DecisionTree, AndOfLiterals, comp_and_list

from tqdm import tqdm

# Training/testing code

## Decision tree generation

In [2]:
# Default device for the notebook: use the GPU when PyTorch can see one, otherwise fall back
# to the CPU so that later cells calling `.to(device)` still run on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
# File-name suffixes for the generated tree sets; only the un-suffixed set is active.
suffixes = [''] # , '_a', '_b', '_c', '_d'

# Create the output directory up front so the dump below cannot fail on a fresh checkout.
os.makedirs('dectrees', exist_ok=True)

print('Generating random decision trees without overwriting')
for suffix in tqdm(suffixes):
    for d in [100,1000]:
        for r in [2,3,4,5]:
            dtreefile = f'dectrees/dectrees_{d}_{r}{suffix}.pkl'
            if os.path.exists(dtreefile):
                # Never overwrite an existing tree: trained models are tied to it by file name.
                continue
            print('creating')
            dtree = DecisionTree(d=d,r=r)
            # Context manager guarantees the pickle is flushed and the handle closed.
            with open(dtreefile, 'wb') as tree_file:
                pickle.dump(dtree, tree_file)

Generating random decision trees without overwriting


100%|██████████| 1/1 [00:00<00:00, 3141.80it/s]


## Main code

In [8]:
# File stems of the decision trees, enumerated suffix-major, then d, then r
treefilenames = [
    f'dectrees_{d}_{r}{suffix}'
    for suffix, d, r in itertools.product(suffixes, [100,1000], [2,3,4,5])
]

# Loss used for training: MSE or CLASS
loss_types = ['mse', 'class']

# Number of ERM samples each model is trained on
sample_nums = [100000] # , 1000000

# One experiment = (tree file stem, loss type, number of ERM samples); take the full grid
exp_configs = [
    (treefile, loss, num_samples)
    for treefile in treefilenames
    for loss in loss_types
    for num_samples in sample_nums
]

# Plus one larger-sample classification run per suffix.
# NOTE(review): this targets dectrees_100_5, while the hyperparameter notes at the end of the
# notebook mention 1,000,000 samples for d = 1000, r = 4, 5 -- confirm which one is intended.
exp_configs.extend((f'dectrees_100_5{suffix}', 'class', 1000000) for suffix in suffixes)

print(len(exp_configs))

17


In [9]:

# Train one network per experiment configuration and save it under saved_models/.
# A configuration whose checkpoint file already exists is skipped, so the cell can be re-run
# without retraining finished models.
for exp_config in tqdm(exp_configs):

    print('-'*10)
    print(*exp_config)

    # Each configuration is (tree file stem, loss type, number of ERM training samples).
    dtreefile, loss_type, erm_num_samples = exp_config

    # Training settings
    dataset_type = 'eval_fn_binary'
    optimizer_type = 'sgd'

    sgd_type = 'erm'
    # sgd_type = 'online'

    test_num_samples = 10000

    log_first_weight_rank = False
    weight_vis_fps = 20
    log_interval = 10
    batch_size = 100
    test_batch_size = 1000
    epochs = 10
    lr = 0.05
    gamma = 1.0

    print(dataset_type)
    print(optimizer_type,sgd_type)
    print('loss',loss_type)

    # Network architecture / initialisation
    num_layers = 5
    width = 1000
    weight_std = 0.001
    bias_std = 0.001
    net_type = ReLUResNetMUP_manual

    # The checkpoint name encodes the configuration; the evaluation cell rebuilds the same name.
    save_model_name = 'saved_models/' + dtreefile + '_' + loss_type + '_erm' + str(erm_num_samples) + '_' + optimizer_type + str(lr) + '_epochs' + str(epochs)
    print(dtreefile)
    print(save_model_name)
    if os.path.exists(save_model_name + '.pt'):
        print(save_model_name + '.pt', 'EXISTS; skipping')
        continue

    # Make sure the checkpoint directory exists before spending time on training.
    os.makedirs(os.path.dirname(save_model_name), exist_ok=True)

    # NOTE: pickle.load can execute arbitrary code; only load tree files produced by this project.
    with open(f'dectrees/{dtreefile}.pkl', 'rb') as tree_file:
        dtree = pickle.load(tree_file)
    d = len(dtree.ands[0].get_tup())
    # A bound method pins the labelling function to this iteration's tree. The previous lambda
    # looked `dtree` up by name at call time, so a lazily evaluated dataset could have picked up
    # a tree loaded later.
    eval_fn = dtree.compute
    gen_batch = 1024

    # Pick the train/test routines that match the loss and the sampling regime.
    if dataset_type[:7] == 'eval_fn':

        input_length = d
        output_width = 1

        if sgd_type == 'erm':
            if loss_type == 'mse':
                train_fn = train_mse
                test_fn = test_mse
            elif loss_type == 'class':
                train_fn = train_class_erm
                test_fn = test_class
            else:
                raise ValueError(f'unknown loss_type: {loss_type!r}')
        elif sgd_type == 'online':
            # NOTE(review): the *_online trainers are not imported in this notebook's import
            # cell; they must be made available before this branch can run.
            if loss_type == 'mse':
                train_fn = train_mse_online
                test_fn = test_mse
            elif loss_type == 'class':
                train_fn = train_class_online
                test_fn = test_class
            else:
                raise ValueError(f'unknown loss_type: {loss_type!r}')
        else:
            raise ValueError(f'unknown sgd_type: {sgd_type!r}')
    else:
        # NOTE(review): `train` / `test` are not imported in this notebook's import cell.
        train_fn = train
        test_fn = test


    no_cuda = False
    # seed=2
    load_model_name = None
    train_model = True

    dry_run = False # 'quickly check a single pass (NOT SURE IF IMPLEMENTED)


    ########### BOILERPLATE
    use_cuda = not no_cuda and torch.cuda.is_available()

    # Settings bundle handed to train_fn.
    args = {'batch_size' : batch_size, 'test_batch_size': test_batch_size, 'epochs' : epochs, \
            'lr' : lr, 'gamma': gamma, 'no_cuda' : no_cuda, 'dry_run' : dry_run, \
            'log_interval' : log_interval, 'save_model' : save_model_name, 'use_cuda' : use_cuda} #  'seed' : seed,
    print(args)

    # torch.manual_seed(seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': batch_size}
    test_kwargs = {'batch_size': test_batch_size}
    if sgd_type == 'erm':
        # Shuffle the fixed ERM training set on every device. This used to be enabled only
        # when CUDA was available, so CPU runs silently trained in a fixed order.
        # The test loader is deliberately left unshuffled.
        train_kwargs['shuffle'] = True
    if use_cuda:
        cuda_kwargs = {'num_workers': 0,
                       'pin_memory': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)



    ########### DATASET LOADING

    if dataset_type == 'eval_fn_binary':

        print('Loading datasets')
        if sgd_type == 'erm':
            ## eval_fn dataset: separate train and test sets labelled by eval_fn
            train_ds = ERMDatasetFromFuncBinary(d,eval_fn,erm_num_samples,batch=gen_batch)
            test_ds = ERMDatasetFromFuncBinary(d,eval_fn,test_num_samples,batch=gen_batch)
        elif sgd_type == 'online':
            # NOTE(review): DatasetFromFuncBinary is not imported in this notebook's import cell.
            train_ds = DatasetFromFuncBinary(d,eval_fn)
            test_ds = ERMDatasetFromFuncBinary(d,eval_fn,test_num_samples,batch=gen_batch)
        else:
            raise ValueError(f'unknown sgd_type: {sgd_type!r}')

        print('Loaded datasets')

    elif dataset_type == 'eval_fn_gaussian':
        # NOTE(review): the Gaussian dataset classes are not imported in this notebook's import cell.
        print('Loading datasets')
        if sgd_type == 'erm':
            ## eval_fn dataset
            train_ds = ERMDatasetFromFuncGaussian(d,eval_fn,erm_num_samples)
            test_ds = ERMDatasetFromFuncGaussian(d,eval_fn,test_num_samples)
        elif sgd_type == 'online':
            train_ds = DatasetFromFuncGaussian(d,eval_fn)
            test_ds = ERMDatasetFromFuncGaussian(d,eval_fn,test_num_samples)
        else:
            raise ValueError(f'unknown sgd_type: {sgd_type!r}')
        print('Loaded datasets')

    else:
        print(dataset_type,'dataset not implemented')
        raise NotImplementedError(f'dataset_type {dataset_type!r} is not implemented')


    ########### DATA LOADING

    train_loader = torch.utils.data.DataLoader(train_ds,**train_kwargs)
    test_loader = torch.utils.data.DataLoader(test_ds, **test_kwargs)


    if load_model_name is not None:
        raise NotImplementedError('loading a pre-trained model is not implemented')

    print('Training')
    model_args = [input_length, num_layers, width]
    model_kwargs = {'output_width' : output_width, 'weight_std' : weight_std, 'bias_std' : bias_std}
    model = net_type(*model_args, **model_kwargs).to(device)

    # Test loss recorded before training and after every epoch.
    test_loss_list = []

    if train_model:
        # Biases and weights live in separate parameter groups (currently with identical settings).
        ps = [(n,p) for (n,p) in model.named_parameters()]
        biases = [p for (n,p) in ps if 'bias' in n]
        other_params = [p for (n,p) in ps if 'bias' not in n]
        if optimizer_type != 'sgd':
            raise ValueError(f'unsupported optimizer_type: {optimizer_type!r}')
        optimizer = optim.SGD([{'params' : other_params, 'lr' : lr, 'weight_decay' : 0},
                               {'params' : biases, 'lr' : lr, 'weight_decay' : 0}])


        if log_first_weight_rank:
            weight_vis_writer = imageio.get_writer('weight_vis.mp4', fps=weight_vis_fps)

        scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
        # Iterations 1..epochs evaluate and then train; the extra iteration (epochs + 1) only
        # records the final evaluation. Previously it also ran a training pass, so checkpoints
        # named '_epochs{epochs}' had seen epochs + 1 passes over the data. Checkpoints saved
        # by the old loop are skipped above and keep that extra pass.
        for epoch in range(1, epochs + 2):
            curr_test_loss = test_fn(model, device, test_loader)
            test_loss_list.append(curr_test_loss)
            if log_first_weight_rank:
                # NOTE(review): visualize_first_weight_rank is not imported in this notebook's import cell.
                visualize_first_weight_rank(model,test_loss_list,epochs,log_interval*batch_size)
                weight_vis_img = imageio.imread('curr_weight_vis.jpg')
                weight_vis_writer.append_data(weight_vis_img)
            if epoch > epochs:
                break
            train_fn(args, model, device, train_loader, optimizer, epoch)

            scheduler.step()

        if log_first_weight_rank:
            weight_vis_writer.close()

    if save_model_name is not None:
        # Saves the whole pickled module (not just a state_dict); the evaluation cell loads it back.
        torch.save(model, save_model_name + '.pt')

  0%|          | 0/17 [00:00<?, ?it/s]

----------
dectrees_100_2 mse 100000
eval_fn_binary
sgd erm
loss mse
dectrees_100_2
saved_models/dectrees_100_2_mse_erm100000_sgd0.05_epochs10
saved_models/dectrees_100_2_mse_erm100000_sgd0.05_epochs10.pt EXISTS; skipping
----------
dectrees_100_2 class 100000
eval_fn_binary
sgd erm
loss class
dectrees_100_2
saved_models/dectrees_100_2_class_erm100000_sgd0.05_epochs10
saved_models/dectrees_100_2_class_erm100000_sgd0.05_epochs10.pt EXISTS; skipping
----------
dectrees_100_3 mse 100000
eval_fn_binary
sgd erm
loss mse
dectrees_100_3
saved_models/dectrees_100_3_mse_erm100000_sgd0.05_epochs10
saved_models/dectrees_100_3_mse_erm100000_sgd0.05_epochs10.pt EXISTS; skipping
----------
dectrees_100_3 class 100000
eval_fn_binary
sgd erm
loss class
dectrees_100_3
saved_models/dectrees_100_3_class_erm100000_sgd0.05_epochs10
saved_models/dectrees_100_3_class_erm100000_sgd0.05_epochs10.pt EXISTS; skipping
----------
dectrees_100_4 mse 100000
eval_fn_binary
sgd erm
loss mse
dectrees_100_4
saved_models

 65%|██████▍   | 11/17 [00:32<00:17,  2.93s/it]

----------
dectrees_1000_3 class 100000
eval_fn_binary
sgd erm
loss class
dectrees_1000_3
saved_models/dectrees_1000_3_class_erm100000_sgd0.05_epochs10
{'batch_size': 100, 'test_batch_size': 1000, 'epochs': 10, 'lr': 0.05, 'gamma': 1.0, 'no_cuda': False, 'dry_run': False, 'log_interval': 10, 'save_model': 'saved_models/dectrees_1000_3_class_erm100000_sgd0.05_epochs10', 'use_cuda': True}
Loading datasets
Loaded datasets
Training

Test set: Accuracy 50.3% (5029/10000), Average loss: 0.6931


Test set: Accuracy 84.1% (8414/10000), Average loss: 0.1921


Test set: Accuracy 100.0% (9999/10000), Average loss: 0.0016


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0003


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0002


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0001


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0001


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0001


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.000

 71%|███████   | 12/17 [01:05<00:31,  6.39s/it]

----------
dectrees_1000_4 mse 100000
eval_fn_binary
sgd erm
loss mse
dectrees_1000_4
saved_models/dectrees_1000_4_mse_erm100000_sgd0.05_epochs10
{'batch_size': 100, 'test_batch_size': 1000, 'epochs': 10, 'lr': 0.05, 'gamma': 1.0, 'no_cuda': False, 'dry_run': False, 'log_interval': 10, 'save_model': 'saved_models/dectrees_1000_4_mse_erm100000_sgd0.05_epochs10', 'use_cuda': True}
Loading datasets
Loaded datasets
Training

Test set: Average loss: 1.0002


Test set: Average loss: 0.7280


Test set: Average loss: 0.6637


Test set: Average loss: 0.3919


Test set: Average loss: 0.0627


Test set: Average loss: 0.0099


Test set: Average loss: 0.0075


Test set: Average loss: 0.0053


Test set: Average loss: 0.0052


Test set: Average loss: 0.0043


Test set: Average loss: 0.0040



 76%|███████▋  | 13/17 [01:36<00:39,  9.87s/it]

----------
dectrees_1000_4 class 100000
eval_fn_binary
sgd erm
loss class
dectrees_1000_4
saved_models/dectrees_1000_4_class_erm100000_sgd0.05_epochs10
{'batch_size': 100, 'test_batch_size': 1000, 'epochs': 10, 'lr': 0.05, 'gamma': 1.0, 'no_cuda': False, 'dry_run': False, 'log_interval': 10, 'save_model': 'saved_models/dectrees_1000_4_class_erm100000_sgd0.05_epochs10', 'use_cuda': True}
Loading datasets
Loaded datasets
Training

Test set: Accuracy 63.1% (6308/10000), Average loss: 0.6930


Test set: Accuracy 63.1% (6308/10000), Average loss: 0.6067


Test set: Accuracy 70.7% (7074/10000), Average loss: 0.5340


Test set: Accuracy 71.7% (7168/10000), Average loss: 0.5163


Test set: Accuracy 79.7% (7965/10000), Average loss: 0.4200


Test set: Accuracy 87.9% (8793/10000), Average loss: 0.2477


Test set: Accuracy 96.5% (9649/10000), Average loss: 0.0871


Test set: Accuracy 97.9% (9793/10000), Average loss: 0.0576


Test set: Accuracy 99.1% (9909/10000), Average loss: 0.0268


Test set:

 82%|████████▏ | 14/17 [02:08<00:41, 13.69s/it]

----------
dectrees_1000_5 mse 100000
eval_fn_binary
sgd erm
loss mse
dectrees_1000_5
saved_models/dectrees_1000_5_mse_erm100000_sgd0.05_epochs10
{'batch_size': 100, 'test_batch_size': 1000, 'epochs': 10, 'lr': 0.05, 'gamma': 1.0, 'no_cuda': False, 'dry_run': False, 'log_interval': 10, 'save_model': 'saved_models/dectrees_1000_5_mse_erm100000_sgd0.05_epochs10', 'use_cuda': True}
Loading datasets
Loaded datasets
Training

Test set: Average loss: 1.0001


Test set: Average loss: 0.7382


Test set: Average loss: 0.5675


Test set: Average loss: 0.4999


Test set: Average loss: 0.3678


Test set: Average loss: 0.3187


Test set: Average loss: 0.2451


Test set: Average loss: 0.1740


Test set: Average loss: 0.1635


Test set: Average loss: 0.1398


Test set: Average loss: 0.1380



 88%|████████▊ | 15/17 [02:41<00:34, 17.28s/it]

----------
dectrees_1000_5 class 100000
eval_fn_binary
sgd erm
loss class
dectrees_1000_5
saved_models/dectrees_1000_5_class_erm100000_sgd0.05_epochs10
{'batch_size': 100, 'test_batch_size': 1000, 'epochs': 10, 'lr': 0.05, 'gamma': 1.0, 'no_cuda': False, 'dry_run': False, 'log_interval': 10, 'save_model': 'saved_models/dectrees_1000_5_class_erm100000_sgd0.05_epochs10', 'use_cuda': True}
Loading datasets
Loaded datasets
Training

Test set: Accuracy 56.2% (5623/10000), Average loss: 0.6931


Test set: Accuracy 67.1% (6713/10000), Average loss: 0.6384


Test set: Accuracy 71.8% (7180/10000), Average loss: 0.5393


Test set: Accuracy 78.2% (7823/10000), Average loss: 0.4465


Test set: Accuracy 80.8% (8081/10000), Average loss: 0.3814


Test set: Accuracy 84.7% (8468/10000), Average loss: 0.3068


Test set: Accuracy 87.2% (8721/10000), Average loss: 0.2567


Test set: Accuracy 88.1% (8815/10000), Average loss: 0.2646


Test set: Accuracy 89.3% (8931/10000), Average loss: 0.2608


Test set:

 94%|█████████▍| 16/17 [03:15<00:20, 20.87s/it]

----------
dectrees_100_5 class 1000000
eval_fn_binary
sgd erm
loss class
dectrees_100_5
saved_models/dectrees_100_5_class_erm1000000_sgd0.05_epochs10
{'batch_size': 100, 'test_batch_size': 1000, 'epochs': 10, 'lr': 0.05, 'gamma': 1.0, 'no_cuda': False, 'dry_run': False, 'log_interval': 10, 'save_model': 'saved_models/dectrees_100_5_class_erm1000000_sgd0.05_epochs10', 'use_cuda': True}
Loading datasets
Loaded datasets
Training

Test set: Accuracy 46.7% (4671/10000), Average loss: 0.6932


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0002


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0000


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0000


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0000


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0000


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0000


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0000


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0

100%|██████████| 17/17 [07:15<00:00, 25.61s/it]


In [10]:
# Compute the accuracy of each of the models and save in testaccs.csv

# Choose the evaluation device here so this cell does not depend on whatever value `device`
# was left with by earlier cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Settings that are part of the checkpoint file name; they must match the training cell.
optimizer_type = 'sgd'
epochs = 10
lr = 0.05

# Evaluation settings
test_num_samples = 10000
test_batch_size = 1000
gen_batch = 1024

# One CSV row per model: checkpoint name, tree file stem, test loss, test accuracy
# (accuracy is 'None' for models trained with the MSE loss).
csvlines = ['model,treefile,loss,acc\n']
for exp_config in tqdm(exp_configs):

    # Each configuration is (tree file stem, loss type, number of ERM training samples).
    dtreefile, loss_type, erm_num_samples = exp_config

    save_model_name = 'saved_models/' + dtreefile + '_' + loss_type + '_erm' + str(erm_num_samples) + '_' + optimizer_type + str(lr) + '_epochs' + str(epochs)

    # The checkpoints hold a whole pickled module written with torch.save(model), so they need
    # weights_only=False (the default became True in PyTorch 2.6; the keyword needs torch >= 1.13).
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    # NOTE: this unpickles arbitrary objects; only load checkpoints produced by this project.
    model = torch.load(save_model_name + '.pt', map_location=device, weights_only=False)

    # NOTE: pickle.load can execute arbitrary code; only load tree files produced by this project.
    with open(f'dectrees/{dtreefile}.pkl', 'rb') as tree_file:
        dtree = pickle.load(tree_file)
    d = len(dtree.ands[0].get_tup())
    # Bound method instead of a lambda so the labelling function stays tied to this tree.
    eval_fn = dtree.compute

    # A new test set is built here on every run.
    test_ds = ERMDatasetFromFuncBinary(d,eval_fn,test_num_samples,batch=gen_batch)
    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=test_batch_size)

    model.to(device)
    if loss_type == 'mse':
        curr_test_loss = test_mse(model, device, test_loader)
        curr_test_acc = None
    elif loss_type == 'class':
        curr_test_loss, curr_test_acc = test_class_acc(model, device, test_loader)
    else:
        raise ValueError(f'unknown loss_type: {loss_type!r}')
    print(curr_test_loss)
    csvlines.append(save_model_name + ',' + dtreefile + ',' + str(curr_test_loss) + ',' + str(curr_test_acc) + '\n')

with open('testaccs.csv', 'w') as f:
    f.writelines(csvlines)

 12%|█▏        | 2/17 [00:00<00:01,  9.67it/s]


Test set: Average loss: 0.0000

1.4176428919654428e-13

Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0000

2.6895394176244734e-05


 24%|██▎       | 4/17 [00:00<00:01,  9.46it/s]


Test set: Average loss: 0.0000

2.2343339631333946e-06

Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0000

4.413553886115551e-05


 35%|███▌      | 6/17 [00:00<00:01,  8.80it/s]


Test set: Average loss: 0.0000

3.169366512447596e-05

Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0001

7.858491763472558e-05


 41%|████      | 7/17 [00:00<00:01,  8.07it/s]


Test set: Average loss: 0.0002

0.00015133364349603654


 47%|████▋     | 8/17 [00:01<00:01,  5.97it/s]


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0001

0.000134759996086359


 53%|█████▎    | 9/17 [00:01<00:01,  5.54it/s]


Test set: Average loss: 0.0000

4.213514057482826e-09


 65%|██████▍   | 11/17 [00:01<00:01,  5.05it/s]


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0000

2.2176150791347025e-05

Test set: Average loss: 0.0000

2.1884818375110625e-05


 71%|███████   | 12/17 [00:01<00:00,  5.10it/s]


Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0000

3.7036661617457864e-05


 76%|███████▋  | 13/17 [00:02<00:00,  5.06it/s]


Test set: Average loss: 0.0040

0.004032966470718384


 82%|████████▏ | 14/17 [00:02<00:00,  4.41it/s]


Test set: Accuracy 99.9% (9990/10000), Average loss: 0.0037

0.003674157166481018


 88%|████████▊ | 15/17 [00:02<00:00,  4.43it/s]


Test set: Average loss: 0.1491

0.14910233306884765


100%|██████████| 17/17 [00:02<00:00,  5.68it/s]


Test set: Accuracy 90.2% (9016/10000), Average loss: 0.3362

0.3361600311279297

Test set: Accuracy 100.0% (10000/10000), Average loss: 0.0000

4.633553768508136e-06





In [None]:
# Hyperparameters
# 10 epochs
# SGD with learning rate 0.05
# ResNet (ReLUResNetMUP_manual, i.e. MuP) with inner dimension 1000, depth 5, init std 0.001 for both biases and weights

# For the random dtrees:
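# d = 100, 2 <= r <= 5, trained with 100,000 ERM samples
# d = 1000, 2 <= r <= 3, also 100,000 ERM samples
# d = 1000, r = 4, 5, trained with 1,000,000 ERM samples (1e5 did not work)
# NOTE: exp_configs as built in this notebook trains d = 1000, r = 4, 5 with 100,000 samples,
# and its only 1,000,000-sample run is dectrees_100_5 with the class loss -- reconcile this
# with the line above before citing these numbers.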