In [None]:
import os
import sys
from utils.data import create_dataloader, merge_dataloaders
from tree.tree import Node, grow_tree_from_root
import torch

class Args:
    """Mutable container for experiment hyper-parameters.

    Acts like a minimal ``argparse.Namespace``: settings are attached as plain
    attributes after construction (``args.lr_g = 0.0002``). Any keyword
    arguments passed to the constructor are stored as initial attributes,
    so ``Args()`` with no arguments behaves exactly like an empty container.
    """

    def __init__(self, **kwargs):
        # The previous body (``self = self``) only rebound a local name and did
        # nothing; instead, record any initial settings supplied by the caller.
        self.__dict__.update(kwargs)

    def __repr__(self):
        # Show every configured attribute, which helps when logging a run's config.
        items = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({items})"
args = Args()


# --- main config ---
args.dataset_path = 'data'  # Path to the folder that contains the dataset root folder
args.logs_path = 'experiment_logs_mnist_notebook'  # Folder for saving all logs (replaces previous logs in the folder, if any)
args.root_node_name = 'Z'  # Name for the root node of the tree
args.device = 0  # Index passed to torch.cuda.set_device below; change to select another device
args.amp_enable = False  # Enables automatic mixed precision if available (executes a lot faster)


# --- architecture/model parameters ---
args.nf_g = 128  # Number of feature maps for the generator
args.nf_d = 128  # Number of feature maps for the discriminator/classifier
args.kernel_size_g = 4  # Kernel size for generators
args.kernel_size_d = 5  # Kernel size for the discriminator/classifier
args.normalization_d = 'layer_norm'  # Type of normalization layer used for the discriminator/classifier
args.normalization_g = 'no_norm'  # Type of normalization layer used for the generator
args.architecture_d = 'cnn'  # Specific architecture choice for the discriminator/classifier
args.architecture_g = 'cnn'  # Specific architecture choice for the generator
args.img_channels = 1  # Number of channels of the images being modeled
args.latent_dim = 100  # Dimension of the generator's latent space
args.batch_size_real = 100  # Minibatch size for real images
args.batch_size_gen = 100  # Minibatch size for generated images
args.img_dim = 28  # Image dimensions
args.shared_features_across_ref = False  # Shares encoder features among parallel refinement groups


# --- training parameters ---
args.lr_d = 0.0001  # Learning rate for the discriminator
args.lr_c = 0.00002  # Learning rate for the classifier
args.lr_g = 0.0002  # Learning rate for the generator
args.b1 = 0.5  # Adam optimizer beta 1 parameter
args.b2 = 0.999  # Adam optimizer beta 2 parameter
# Initial image-noise intensity; it decays linearly over each GAN/MGAN training run.
args.noise_start = 1.0
args.epochs_raw_split = 100  # Number of epochs for raw split training
args.epochs_refinement = 100  # Number of epochs for refinement training
args.diversity_parameter_g = 1.0  # Weight of the generators' classification loss component
args.no_refinements = 8  # Number of refinements in each split
args.no_splits = 9  # Number of splits during tree growth
args.collapse_check_epoch = 40  # Epoch after which to check for generation collapse
args.sample_interval = 10  # Number of epochs between printing/saving training logs
# If the total probability-mass variation between two consecutive refinements
# is below this value, the remaining refinements for that node are skipped to
# save time.
args.min_prob_mass_variation = 150


# Select the CUDA device up front; this raises if CUDA is unavailable.
torch.cuda.set_device(args.device)

dataloader_train = create_dataloader(dataset='mnist', test=False, batch_size=args.batch_size_real, path=args.dataset_path)
dataloader_test = create_dataloader(dataset='mnist', test=True, batch_size=args.batch_size_real, path=args.dataset_path)
# The test loader is folded into the training loader via merge_dataloaders, so
# `dataloader_train` presumably covers both MNIST splits from here on
# (confirm in utils.data).
dataloader_train = merge_dataloaders(dataloader_train, dataloader_test)

In [None]:
# Create the tree's root from the training loader's sampler weights, then hand
# it to grow_tree_from_root together with the loader and the settings in `args`.
root_weights = dataloader_train.sampler.weights
root_node = Node(args.root_node_name, root_weights, args.logs_path)
grow_tree_from_root(root_node, dataloader_train, args)
                    
