In [1]:
# dev of adjust_sigma.py

In [2]:
from __future__ import print_function

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import numpy as np

In [4]:
import sys
sys.path.append('/home/felix/Research/Adversarial Research/FGN---Research/')
import Finite_Gaussian_Network_lib as fgnl
import Finite_Gaussian_Network_lib.fgn_helper_lib as fgnh

In [5]:
# MNIST training set and its dataloader.
# ToTensor converts pixel values from 0-255 to 0-1; Normalize then
# standardizes them with the precomputed mean and std.

batch_size = 2048

mnist_train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        '../../MNIST-dataset',
        train=True,
        download=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    batch_size=batch_size,
    shuffle=True,
)

In [6]:
def adjust_sigma(fgn_model, dataloader=None, loss_func=None, pred_func=None, verbose=False):
    """Validate the arguments for a sigma adjustment of an FGN model.

    Intended behavior (not implemented yet, this is still a stub):
      * no data given: adjust based on mean 0 / var 1 input to get as close
        as possible to mean 0 / var 1 output
      * data given: adjust based on it to get as close as possible to
        mean 0 / var 1 output
      * data + loss_func given: minimize the loss
      * data + pred_func given: maximize the accuracy

    Currently the function only type-checks its arguments and returns.

    Args:
        fgn_model: an fgnl.Feedforward_FGN_net instance.
        dataloader: optional torch DataLoader.
        loss_func: optional callable.
        pred_func: optional callable.
        verbose: bool.

    Returns:
        None.

    Raises:
        TypeError: if any argument has the wrong type.
    """

    ### type checks
    # fgn_model: an FGN network (the original check read an undefined name
    # `model`, which made every call fail with NameError)
    if not isinstance(fgn_model, fgnl.Feedforward_FGN_net):
        raise TypeError("model is not pytorch  FGN module")
    # dataloader: a pytorch data loader
    if (dataloader is not None) and (not isinstance(dataloader, torch.utils.data.dataloader.DataLoader)):
        raise TypeError("dataloader is not pytorch dataloader")
    # loss_func: a pytorch loss function (can be any function)
    if (loss_func is not None) and (not callable(loss_func)):
        raise TypeError("loss_func is not a function")
    # pred_func: a pred function (can be any function)
    if (pred_func is not None) and (not callable(pred_func)):
        raise TypeError("pred_func is not a function")
    # verbose: bool
    if not isinstance(verbose, bool):
        raise TypeError("verbose is not a boolean")

    # if no dataloader is given, adjust sigmas to get mean=0 variance=1
    if dataloader is None:
        # TODO: not implemented yet
        pass

    # TODO: if a dataset (and either loss_func or pred_func) is provided,
    # adjust the sigmas to fit the data best

    # return nothing
    return None

In [7]:
def _fgn_layers_of(fgn_model):
    """Return the list of fgnl.FGN_layer sub-modules found in fgn_model."""
    return [m for m in fgn_model.modules() if isinstance(m, fgnl.FGN_layer)]


def _write_scaled_sigmas(layers, base_sigmas, mult):
    """Overwrite, in place, each layer's sigmas with (saved base sigmas * mult)."""
    # in-place write keeps the same Parameter object (and its requires_grad),
    # so anything holding a reference to it (e.g. an optimizer) stays valid
    with torch.no_grad():
        for layer, base in zip(layers, base_sigmas):
            layer.sigmas.copy_(base * mult)


def _accuracy_at_sigma_mult(fgn_model, dataloader, pred_func, verbose, layers, base_sigmas, mult):
    """Return fgnh.test's 'test_accuracy' with the sigmas temporarily scaled by mult."""
    _write_scaled_sigmas(layers, base_sigmas, mult)
    try:
        # the loss is irrelevant here, so a dummy constant loss is passed
        res = fgnh.test(fgn_model, dataloader,
                        (lambda model, output, target: torch.tensor(0)),
                        verbose=verbose, pred_func=pred_func)
    finally:
        # restore the exact saved values (no multiply/divide round trip)
        _write_scaled_sigmas(layers, base_sigmas, 1.0)
    return res['test_accuracy']


def adjust_sigma_pred_func(fgn_model, dataloader, pred_func, verbose):
    """Rescale all sigmas of an FGN model by one multiplier chosen to maximize accuracy.

    The multiplier is searched in three phases: repeated doubling while the
    accuracy strictly improves, repeated halving while the accuracy stays
    within a 0.1% relative tolerance of the best one, then a dichotomy
    between the bounds found. The best multiplier is finally applied, in
    place, to the `sigmas` of every fgnl.FGN_layer of the model.

    Args:
        fgn_model: model whose fgnl.FGN_layer modules carry a `sigmas` parameter.
        dataloader: data passed to fgnh.test to measure the accuracy.
        pred_func: prediction function passed to fgnh.test.
        verbose: if true, prints progress (also forwarded to fgnh.test).

    Returns:
        None. The model's sigmas are modified in place.
    """
    # layers to rescale, and a copy of their original sigmas
    layers = _fgn_layers_of(fgn_model)
    base_sigmas = [layer.sigmas.detach().clone() for layer in layers]

    def accuracy_at(mult):
        return _accuracy_at_sigma_mult(fgn_model, dataloader, pred_func, verbose,
                                       layers, base_sigmas, mult)

    # relative tolerance under which a smaller accuracy still counts as "as good"
    rel_tol = 1e-3
    # max number of values to test in each phase
    max_iter = 25

    # best pred acc yet (unscaled model)
    best_pred = accuracy_at(1.0)
    # best sigma multiplier yet
    best_sig_mult = 1.0
    # lower bound for sigma mult
    lower_bound = 0.0
    # upper bound for sigma mult
    upper_bound = float('Inf')

    # first double sigmas until performance decreases
    for ite in range(max_iter):
        cur_sig_mult = 2.0*best_sig_mult
        if verbose: print(ite, "testing", cur_sig_mult)
        cur_pred = accuracy_at(cur_sig_mult)

        if cur_pred > best_pred:
            if verbose: print("new best during doubling")
            # new best
            best_pred = cur_pred
            best_sig_mult = cur_sig_mult
            # increase lower bound
            lower_bound = cur_sig_mult
        else:
            # new upper bound
            upper_bound = cur_sig_mult
            # and exit loop
            break

    # doubling never failed: use the next untested value as a finite upper
    # bound, otherwise the dichotomy would multiply the sigmas by infinity
    if upper_bound == float('Inf'):
        upper_bound = 2.0*best_sig_mult

    # next half sigmas until performance decreases
    for ite in range(max_iter):
        cur_sig_mult = 0.5*best_sig_mult
        if verbose: print(ite, "testing", cur_sig_mult)
        cur_pred = accuracy_at(cur_sig_mult)

        if cur_pred >= (1.0-rel_tol)*best_pred:
            if verbose: print("new best during halfing")
            # new best
            best_pred = cur_pred
            best_sig_mult = cur_sig_mult
            # new upper bound
            upper_bound = cur_sig_mult
        else:
            # increase lower bound
            lower_bound = cur_sig_mult
            # and exit loop
            break

    # halving never failed after a successful doubling: the lower bound left
    # by the doubling phase is stale (above the upper bound), so reset it
    if lower_bound > upper_bound:
        lower_bound = 0.0

    # now that we have real bounds, search by dichotomy
    for ite in range(max_iter):
        cur_sig_mult = 0.5*(upper_bound+lower_bound)
        if verbose: print(ite, "testing", cur_sig_mult)
        cur_pred = accuracy_at(cur_sig_mult)

        if cur_pred >= (1.0-rel_tol)*best_pred:
            if verbose: print("new best during dicho")
            # new low bound
            if cur_sig_mult > best_sig_mult:
                lower_bound = cur_sig_mult
            # new upper bound
            else:
                upper_bound = cur_sig_mult
            # new best
            best_pred = cur_pred
            best_sig_mult = cur_sig_mult
        else:
            # new low bound
            if cur_sig_mult < best_sig_mult:
                lower_bound = cur_sig_mult
            # new upper bound
            else:
                upper_bound = cur_sig_mult

    # apply best mult
    if verbose: print("best multiplier:", best_sig_mult)
    _write_scaled_sigmas(layers, base_sigmas, best_sig_mult)

    return None

In [8]:
# build the FGN network used to exercise the sigma adjustment:
# three hidden layers of 16 units, no dropout
hidden_l_nums = [16, 16, 16]
drop_p = 0.0
fgn_model = fgnl.Feedforward_FGN_net(
    in_feats=28*28,
    out_feats=10,
    hidden_l_nums=hidden_l_nums,
    drop_p=drop_p,
)

In [9]:
# train the FGN model
# lambda coefficients handed to the FGN cross-entropy loss builder
lmbda_l2 = (4.0 * 0.1) / len(mnist_train_loader.dataset)
print(lmbda_l2)
lmbda_sigs = lmbda_l2 * 1.01
print(lmbda_sigs)
loss_func = fgnl.def_fgn_cross_ent_loss(lmbda_l2, lmbda_sigs)

# only optimize the parameters that require a gradient
fgn_optimizer = optim.RMSprop(
    (p for p in fgn_model.parameters() if p.requires_grad),
    momentum=0.5,
)

epochs = 3

fgn_train_res = fgnh.train(
    fgn_model,
    mnist_train_loader,
    loss_func,
    fgn_optimizer,
    epochs,
    pred_func=fgnh.cross_ent_pred_accuracy,
    save_hist=2,
    verbose=True,
)

6.66666666667e-06
6.73333333333e-06
Epoch 0 Train set - Average loss: 1.0402, Accuracy: 41953/60000 (70%)
Epoch 1 Train set - Average loss: 0.3339, Accuracy: 54701/60000 (91%)
Epoch 2 Train set - Average loss: 0.2543, Accuracy: 55954/60000 (93%)


In [10]:
# sanity check: accuracy before the sigma adjustment should be reasonable
fgn_test_res_pre = fgnh.test(
    fgn_model,
    mnist_train_loader,
    loss_func,
    verbose=True,
    pred_func=fgnh.cross_ent_pred_accuracy,
)

Test set - Average loss: 0.3423, Accuracy: 54457/60000 (91%)


In [11]:
# search for the sigma multiplier giving the best accuracy and apply it
adjust_sigma_pred_func(
    fgn_model,
    mnist_train_loader,
    pred_func=fgnh.cross_ent_pred_accuracy,
    verbose=True,
)

Test set - Average loss: 0.0000, Accuracy: 54457/60000 (91%)
0 testing 2.0
Test set - Average loss: 0.0000, Accuracy: 5959/60000 (10%)
0 testing 0.5
Test set - Average loss: 0.0000, Accuracy: 5851/60000 (10%)
0 testing 1.25
Test set - Average loss: 0.0000, Accuracy: 5933/60000 (10%)
1 testing 0.875
Test set - Average loss: 0.0000, Accuracy: 5958/60000 (10%)
2 testing 1.0625
Test set - Average loss: 0.0000, Accuracy: 9350/60000 (16%)
3 testing 0.96875
Test set - Average loss: 0.0000, Accuracy: 39387/60000 (66%)
4 testing 1.015625
Test set - Average loss: 0.0000, Accuracy: 49501/60000 (83%)
5 testing 0.9921875
Test set - Average loss: 0.0000, Accuracy: 53030/60000 (88%)
6 testing 1.00390625
Test set - Average loss: 0.0000, Accuracy: 54407/60000 (91%)
new best during dicho
7 testing 1.009765625
Test set - Average loss: 0.0000, Accuracy: 53109/60000 (89%)
8 testing 1.0068359375
Test set - Average loss: 0.0000, Accuracy: 53970/60000 (90%)
9 testing 1.00537109375
Test set - Average loss: 0.0

In [12]:
# sanity check: accuracy after the sigma adjustment should still be reasonable
fgn_test_res_post = fgnh.test(
    fgn_model,
    mnist_train_loader,
    loss_func,
    verbose=True,
    pred_func=fgnh.cross_ent_pred_accuracy,
)

Test set - Average loss: 0.3577, Accuracy: 54286/60000 (90%)
