In [1]:
import argparse
import contextlib
import json
import math
import os
import pdb
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch as t, torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision as tv, torchvision.transforms as tr
import yaml
from matplotlib import pyplot as plt
from numpy import genfromtxt
from tqdm import tqdm
from zntrack import ZnTrackProject, Node, config, dvc, zn
from zntrack.metadata import TimeIt

import wideresnet
#from jemsharedclasses import Base, JEMUtils


# zntrack needs the notebook's file name: the @Node classes below are exported
# from it to a .py script (see the NbConvertApp output under those cells).
config.nb_name = "ZnJEMProject.ipynb"
# ZnTrack project handle — presumably used to run/repro stages in later cells; verify.
project = ZnTrackProject()

In [2]:

# Setup parameters
# defaults for paper
# --lr .0001 --dataset cifar10 --optimizer adam --p_x_weight 1.0 --p_y_given_x_weight 1.0 
# --p_x_y_weight 0.0 --sigma .03 --width 10 --depth 28 --save_dir /YOUR/SAVE/DIR 
# --plot_uncond --warmup_iters 1000
#

#This class can be reused by passing a name attribute to its constructor, and then referencing
# that name in the @Node class dependencies
@Node()
class train_args():
    """Training hyper-parameters for one experiment, tracked as a zntrack Node.

    Every ``dvc.params()`` attribute is written to params.yaml. Several
    independent parameter sets can coexist by constructing the node with a
    different ``name`` (e.g. ``train_args(name="train_argsL1")``) and
    referencing that name in the dependent Node's ``dvc.deps``.
    """
    # define params
    # this will write them to params.yaml
    experiment = dvc.params()
    dataset = dvc.params()
    n_classes = dvc.params()
    n_steps = dvc.params()
    width = dvc.params()
    depth = dvc.params()
    sigma = dvc.params()
    data_root = dvc.params()
    seed = dvc.params()
    lr = dvc.params()
    clf_only = dvc.params()
    labels_per_class = dvc.params()
    batch_size = dvc.params()
    n_epochs = dvc.params()
    dropout_rate = dvc.params()
    weight_decay = dvc.params()
    norm = dvc.params()
    save_dir = dvc.params()
    ckpt_every = dvc.params()
    eval_every = dvc.params()
    print_every = dvc.params()
    load_path = dvc.params()
    print_to_log = dvc.params()
    n_valid = dvc.params()

    result = zn.metrics()

    def __call__(self, param_dict=None):
        """Set every parameter to its default, then apply ``param_dict``.

        Args:
            param_dict: optional mapping ``attribute name -> value``; each entry
                overrides (or adds) the attribute of the same name. ``None``
                keeps all defaults.
        """
        # experiment / data
        self.experiment = "energy_model"
        self.dataset = "cifar10"
        self.n_classes = 10
        self.n_steps = 20
        self.data_root = "./dataset"
        # top-level "seed"/"epochs" keys of params.yaml win over these defaults
        self.seed = JEMUtils.get_parameter("seed", 1)
        # network (wide-resnet)
        self.width = 10  # widen_factor
        self.depth = 28
        # norm forwarded to Wide_ResNet; the original CLI offered
        # None, "norm", "batch", "instance", "layer", "act" ("none works fine")
        self.norm = None
        # optimization
        self.lr = 1e-4
        self.clf_only = False  # if True, only the classifier head is optimized
        self.labels_per_class = -1  # labeled examples per class; <= 0 uses all labels
        self.batch_size = 64
        self.n_epochs = JEMUtils.get_parameter("epochs", 10)
        # regularization
        self.dropout_rate = 0.0
        # used as the GaussianBlur sigma range (sigma, 2 * sigma) in JEMUtils.get_data
        self.sigma = 3e-2
        self.weight_decay = 0.0
        # logging + evaluation
        self.save_dir = './experiment'
        self.ckpt_every = 1  # epochs between checkpoint saves
        self.eval_every = 1  # epochs between evaluations
        self.print_every = 100  # iterations between loss prints
        self.load_path = None  # directory holding <experiment>/ checkpoints to resume from
        self.print_to_log = False  # if True, stdout is redirected to <save_dir>/<experiment>/log.txt
        self.n_valid = 5000  # number of validation images held out of the training set

        # inline overrides
        for key, value in (param_dict or {}).items():
            setattr(self, key, value)

    def run(self):
        self.result = self.experiment
    

Submit issues to https://github.com/zincware/ZnTrack.


[NbConvertApp] Converting notebook ZnJEMProject.ipynb to script
[NbConvertApp] Writing 40669 bytes to ZnJEMProject.py


In [3]:
#this is a base for the Node compute functions, to split off the actual work from the dvc control flow
class Base:
    """Interface for the compute objects attached to the tracking Nodes.

    A Node stores an instance in a ``zn.Method()`` field and calls
    ``compute(inp)`` from ``run``; subclasses put the actual work there so it
    stays separate from the DVC control flow.
    """

    def compute(self, inp):
        """Do the stage's work on ``inp``; must be overridden by subclasses."""
        raise NotImplementedError()

In [4]:
# get random subset of data
class DataSubset(Dataset):
    """View of ``base_dataset`` restricted to a sequence of indices.

    Args:
        base_dataset: indexable dataset supporting ``len()``.
        inds: indices into ``base_dataset`` to expose, in order. When None,
            ``size`` indices are drawn at random (without replacement) from
            the global ``np.random`` state.
        size: number of indices to draw when ``inds`` is None. A negative value
            (the default) draws every index, i.e. a random permutation.
    """

    def __init__(self, base_dataset, inds=None, size=-1):
        self.base_dataset = base_dataset
        if inds is None:
            n_items = len(base_dataset)
            # np.random.choice rejects a negative sample size, so the default
            # size=-1 is mapped to "all items"
            n_draw = n_items if size < 0 else size
            # choice(n) samples from range(n) directly, without building a list
            inds = np.random.choice(n_items, n_draw, replace=False)
        self.inds = inds

    def __getitem__(self, index):
        return self.base_dataset[self.inds[index]]

    def __len__(self):
        return len(self.inds)

In [5]:
# setup Wide_ResNet
# Uses The Google Research Authors, file wideresnet.py
class FTrain(nn.Module):
    """Wide-ResNet feature extractor with a linear classification head.

    Args:
        depth: Wide_ResNet depth.
        width: Wide_ResNet widen factor.
        norm: normalization option forwarded to Wide_ResNet.
        dropout_rate: dropout rate forwarded to Wide_ResNet.
        n_classes: number of logits produced by ``classify``.
    """

    def __init__(self, depth=28, width=2, norm=None, dropout_rate=0.0, n_classes=10):
        super().__init__()
        self.f = wideresnet.Wide_ResNet(depth, width, norm=norm, dropout_rate=dropout_rate)
        # not used by classify(); kept so state_dict keys (and saved
        # checkpoints) stay the same
        self.energy_output = nn.Linear(self.f.last_dim, 1)
        self.class_output = nn.Linear(self.f.last_dim, n_classes)

    def classify(self, x):
        """Return class logits for ``x``.

        Assumes ``self.f`` yields ``(batch, last_dim)`` features, so the result
        is ``(batch, n_classes)``. No ``.squeeze()``: it used to turn a batch of
        one into ``(n_classes,)``, breaking ``CrossEntropyLoss`` and ``max(1)``.
        """
        penult_z = self.f(x)
        return self.class_output(penult_z)

In [6]:
class JEMUtils:
    """Static helpers shared by the training stages: data loading,
    classification evaluation, checkpointing and params.yaml lookup."""

    @staticmethod
    def cycle(loader):
        """Yield items from ``loader`` forever, restarting it whenever it is exhausted."""
        while True:
            yield from loader

    @staticmethod
    def eval_classification(f, dload, device):
        """Return ``(accuracy, mean cross-entropy)`` of ``f.classify`` over ``dload``.

        Both values are per-sample means over all batches. They are NaN when
        ``dload`` yields no samples. Callers wrap this in ``torch.no_grad()``
        and put ``f`` in eval mode.
        """
        criterion = nn.CrossEntropyLoss(reduction="none")  # per-sample losses
        corrects, losses = [], []
        for x_p_d, y_p_d in dload:
            x_p_d, y_p_d = x_p_d.to(device), y_p_d.to(device)
            logits = f.classify(x_p_d)
            losses.extend(criterion(logits, y_p_d).detach().cpu().numpy())
            corrects.extend((logits.argmax(1) == y_p_d).float().cpu().numpy())
        if not losses:
            return float("nan"), float("nan")
        return np.mean(corrects), np.mean(losses)

    @staticmethod
    def checkpoint(f, opt, epoch_no, tag, args, device):
        """Save model/optimizer state and ``epoch_no`` to ``<save_dir>/<experiment>/<tag>``.

        The model is moved to CPU for saving and back to ``device`` afterwards,
        even if saving fails. The optimizer is not moved, so its state is saved
        on whatever device it lives on.
        """
        f.cpu()
        try:
            ckpt_dict = {
                "model_state_dict": f.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'epoch': epoch_no,
            }
            t.save(ckpt_dict, os.path.join(args.save_dir, args.experiment, tag))
        finally:
            f.to(device)

    @staticmethod
    def get_parameter(name, default, path="params.yaml"):
        """Look up a dotted key (e.g. ``"a.b"``) in the YAML file at ``path``.

        Returns ``default`` when the file does not exist (e.g. before the
        first stage has written it), is empty, any key along the path is
        missing or null, or an intermediate value is not a mapping.
        """
        try:
            with open(path) as fh:
                params = yaml.safe_load(fh)
        except FileNotFoundError:
            return default
        node = params
        for part in name.split("."):
            if not isinstance(node, dict):
                return default
            node = node.get(part)
            if node is None:
                return default
        return node

    @staticmethod
    def get_data(args):
        """Build CIFAR-10 loaders ``(train, labeled_train_cycle, valid, test)``.

        The training indices are shuffled with ``np.random.seed(args.seed)``.
        The first ``args.n_valid`` indices form the validation split, which
        uses the test-time transform. If ``args.labels_per_class > 0``, only
        that many examples per class feed the labeled loader. The labeled
        loader is wrapped in ``JEMUtils.cycle`` and never ends.
        """
        im_sz = 32

        # GaussianBlur replaces the original additive-noise transform
        # ``lambda x: x + args.sigma * t.randn_like(x)``. Lambdas cannot be
        # pickled for multi-process DataLoader workers, and torchvision advises
        # against them. Note that a blur is not equivalent to additive noise.
        blur = tr.GaussianBlur(kernel_size=(5, 5), sigma=(args.sigma, args.sigma * 2))
        transform_train = tr.Compose([
            tr.Pad(4, padding_mode="reflect"),
            tr.RandomCrop(im_sz),
            tr.RandomHorizontalFlip(),
            tr.ToTensor(),
            tr.Normalize((.5, .5, .5), (.5, .5, .5)),
            blur,
        ])
        transform_test = tr.Compose([
            tr.ToTensor(),
            tr.Normalize((.5, .5, .5), (.5, .5, .5)),
            blur,
        ])

        data_root = getattr(args, "data_root", "./dataset")

        def dataset_fn(train, transform):
            return tv.datasets.CIFAR10(root=data_root, transform=transform, download=True, train=train)

        # deterministic shuffle of all training indices
        full_train = dataset_fn(train=True, transform=transform_train)
        all_inds = list(range(len(full_train)))
        np.random.seed(args.seed)
        np.random.shuffle(all_inds)
        # separate out the validation set
        if args.n_valid is not None:
            valid_inds, train_inds = all_inds[:args.n_valid], all_inds[args.n_valid:]
        else:
            valid_inds, train_inds = [], all_inds
        # explicit int dtype keeps indexing valid even if train_inds is empty
        train_inds = np.array(train_inds, dtype=np.int64)
        # read labels from the targets list; indexing full_train would load and
        # augment every image just to get its label
        train_labels = np.asarray(full_train.targets)[train_inds]
        if args.labels_per_class > 0:
            train_labeled_inds = []
            for cls in range(args.n_classes):
                class_inds = train_inds[train_labels == cls]
                train_labeled_inds.extend(class_inds[:args.labels_per_class])
        else:
            train_labeled_inds = train_inds

        # both training subsets share the augmented dataset instance
        dset_train = DataSubset(full_train, inds=train_inds)
        dset_train_labeled = DataSubset(full_train, inds=train_labeled_inds)
        dset_valid = DataSubset(dataset_fn(train=True, transform=transform_test), inds=valid_inds)
        dset_test = dataset_fn(train=False, transform=transform_test)

        dload_train = DataLoader(dset_train, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True)
        dload_train_labeled = JEMUtils.cycle(
            DataLoader(dset_train_labeled, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True))
        dload_valid = DataLoader(dset_valid, batch_size=100, shuffle=False, num_workers=4, drop_last=False)
        dload_test = DataLoader(dset_test, batch_size=100, shuffle=False, num_workers=4, drop_last=False)
        return dload_train, dload_train_labeled, dload_valid, dload_test
    

In [7]:
# basic training from train.ipynb
class Trainer(Base):
    """Compute object for the XEntropyAugmented stage: plain cross-entropy training.

    ``compute(args)`` trains an ``FTrain`` Wide-ResNet classifier on the
    labeled training data from ``JEMUtils.get_data``. It optionally resumes
    from a checkpoint, writes checkpoints every ``args.ckpt_every`` epochs
    plus ``last_ckpt.pt`` after every epoch, and returns the scores of the
    last evaluation.
    """

    # checkpoint inside <load_path>/<experiment> that a run resumes from, if present
    resume_ckpt = "ckpt_9.pt"

    def compute(self, inp):
        """Train with the hyper-parameters in ``inp`` and return the scores.

        Args:
            inp: the ``train_args`` node, or any object with the same attributes.

        Returns:
            dict ``{"train"|"test"|"validation": {"acc:": float, "loss": float}}``
            from the last evaluated epoch. It is also written to
            ``<save_dir>/<experiment>_scores.json``.

        Raises:
            RuntimeError: if the training loss becomes non-finite or exceeds 1e8.
        """
        args = inp
        exp_dir = os.path.join(args.save_dir, args.experiment)
        # creating the nested directory also creates save_dir
        os.makedirs(exp_dir, exist_ok=True)

        # optionally redirect prints to <exp_dir>/log.txt; stdout is restored and
        # the log closed when training ends, including on error
        with contextlib.ExitStack() as stack:
            if args.print_to_log:
                log_file = stack.enter_context(open(os.path.join(exp_dir, "log.txt"), "w"))
                stack.enter_context(contextlib.redirect_stdout(log_file))
            scores = self._train(args)

        with open(exp_dir + '_scores.json', 'w') as outfile:
            json.dump(scores, outfile)

        return scores

    def _train(self, args):
        """Seed, build data/model/optimizer, run the epoch loop and return the scores."""
        t.manual_seed(args.seed)
        if t.cuda.is_available():
            t.cuda.manual_seed_all(args.seed)

        dload_train, dload_train_labeled, dload_valid, dload_test = JEMUtils.get_data(args)
        device = t.device('cuda' if t.cuda.is_available() else 'cpu')

        f = FTrain(args.depth, args.width, args.norm,
                   dropout_rate=args.dropout_rate, n_classes=args.n_classes).to(device)
        # clf_only: optimize just the classification head
        params = f.class_output.parameters() if args.clf_only else f.parameters()
        optim = t.optim.Adam(params, lr=args.lr, betas=[.9, .999], weight_decay=args.weight_decay)

        epoch_start = self._load_checkpoint(args, f, optim, device)

        # train-set loss/accuracy of the starting (possibly reloaded) model
        f.eval()
        with t.no_grad():
            correct, loss = JEMUtils.eval_classification(f, dload_train, device)
        print("Epoch {}: Train Loss {}, Train Acc {}".format(epoch_start, loss, correct))
        f.train()

        eval_sets = (("train", "Train", dload_train),
                     ("test", "Test", dload_test),
                     ("validation", "Valid", dload_valid))
        cross_entropy = nn.CrossEntropyLoss()
        cur_iter = 0
        scores = {}
        for epoch in range(epoch_start, epoch_start + args.n_epochs):
            # one step per batch of the training split; the batches that are
            # trained on come from the endlessly cycling labeled loader, so the
            # training loader itself does not need to be iterated
            for _ in range(len(dload_train)):
                x_lab, y_lab = next(dload_train_labeled)
                x_lab, y_lab = x_lab.to(device), y_lab.to(device)

                # standard cross-entropy: maximize log p(y | x)
                logits = f.classify(x_lab)
                l_p_y_given_x = cross_entropy(logits, y_lab)
                if cur_iter % args.print_every == 0:
                    acc = (logits.argmax(1) == y_lab).float().mean()
                    print('P(y|x) {}:{:>d} loss={:>14.9f}, acc={:>14.9f}'.format(
                        epoch, cur_iter, l_p_y_given_x.item(), acc.item()))

                L = l_p_y_given_x
                self._check_divergence(L)

                optim.zero_grad()
                L.backward()
                optim.step()
                cur_iter += 1

            if epoch % args.ckpt_every == 0:
                JEMUtils.checkpoint(f, optim, epoch, f'ckpt_{epoch}.pt', args, device)

            if epoch % args.eval_every == 0:
                scores = self._evaluate(f, eval_sets, device, epoch)

            JEMUtils.checkpoint(f, optim, epoch, "last_ckpt.pt", args, device)

        return scores

    def _load_checkpoint(self, args, f, optim, device):
        """Restore ``f``/``optim`` from ``<load_path>/<experiment>/<resume_ckpt>``.

        Returns the first epoch to train: one past the stored epoch (which has
        already been trained), or 0 when there is nothing to resume.
        """
        if not args.load_path:
            return 0
        ckpt_dir = os.path.join(args.load_path, args.experiment)
        ckpt_path = os.path.join(ckpt_dir, self.resume_ckpt)
        if not os.path.exists(ckpt_path):
            return 0
        print(f"loading model from {ckpt_dir}")
        # map_location: optimizer state may have been saved from a GPU run
        ckpt_dict = t.load(ckpt_path, map_location=device)
        f.load_state_dict(ckpt_dict["model_state_dict"])
        optim.load_state_dict(ckpt_dict['optimizer_state_dict'])
        return ckpt_dict['epoch'] + 1

    @staticmethod
    def _check_divergence(loss):
        """Raise RuntimeError if ``loss`` is NaN/inf or its magnitude exceeds 1e8."""
        value = loss.item()
        # NaN compares False with '>', so finiteness is checked explicitly
        if not math.isfinite(value) or abs(value) > 1e8:
            raise RuntimeError(f"Divergence error: loss = {value}")

    @staticmethod
    def _evaluate(f, eval_sets, device, epoch):
        """Evaluate ``f`` on each ``(key, label, loader)`` and return the scores dict."""
        scores = {}
        f.eval()
        with t.no_grad():
            for key, label, dload in eval_sets:
                correct, loss = JEMUtils.eval_classification(f, dload, device)
                # key "acc:" (with colon) kept as-is for existing metrics consumers
                scores[key] = {"acc:": float(correct), "loss": float(loss)}
                print("Epoch {}: {} Loss {}, {} Acc {}".format(epoch, label, loss, label, correct))
        f.train()
        return scores

In [8]:
#Do the operations from train.ipynb and track in dvc
#dependency is train_args stage with default name
#outs is the path to the last_ckpt.pt model file, which serves as a dependency to the evaluation stage

@Node()
class XEntropyAugmented:
    """DVC stage: plain cross-entropy training (originally train.ipynb).

    Depends on the default-named ``train_args`` node. Its declared output is
    the ``last_ckpt.pt`` written by the attached trainer, which the evaluation
    stage consumes.
    """

    args: train_args = dvc.deps(train_args(load=True))
    trainer: Base = zn.Method()
    result = zn.metrics()
    # NOTE(review): DVC typically deletes a stage's outs before re-running it
    # (unless they are marked persistent) — likely why this file disappears; confirm.
    model: Path = dvc.outs()

    def __call__(self, operation):
        """Attach the compute object and declare the checkpoint output path."""
        self.trainer = operation
        experiment_dir = Path(self.args.save_dir) / self.args.experiment
        self.model = experiment_dir / "last_ckpt.pt"

    @TimeIt
    def run(self):
        """Train via the attached compute object and store its scores as metrics."""
        scores = self.trainer.compute(self.args)
        self.result = scores
        
    
        

Submit issues to https://github.com/zincware/ZnTrack.


[NbConvertApp] Converting notebook ZnJEMProject.ipynb to script
[NbConvertApp] Writing 40669 bytes to ZnJEMProject.py


In [9]:
# add/change parameters for this stage; with load_path equal to the default
# save_dir, a rerun resumes from ./experiment/x-entropy_augmented/ckpt_9.pt if it exists
inline_parms = {"lr": .0001, "experiment": 'x-entropy_augmented', "load_path": './experiment'}

# declare the default-named train_args stage and pass the modified/new params
params = train_args()
params(param_dict=inline_parms)

# declare the compute class for the XEntropyAugmented stage
trainer = Trainer()

# declare the stage and attach the compute class (writes the stage to dvc.yaml)
runner = XEntropyAugmented()
runner(operation=trainer)

2022-01-10 13:18:57,782 (INFO): Modifying stage 'train_args' in 'dvc.yaml'

To track the changes with git, run:

    git add dvc.yaml

To enable auto staging, run:

	dvc config core.autostage true

2022-01-10 13:18:57,793 (ERROR): Can not convert args!
2022-01-10 13:18:57,794 (ERROR): Can not convert kwargs!
2022-01-10 13:18:58,663 (INFO): Modifying stage 'XEntropyAugmented' in 'dvc.yaml'

To track the changes with git, run:

    git add dvc.yaml

To enable auto staging, run:

	dvc config core.autostage true



In [10]:
#stage MaxEntropyL1, originally from train_max_entropy_L1.ipynb
# dependency is train_args, named "train_argsL1"
#outs is the path to the last_ckpt.pt model file, which serves as a dependency to the evaluation stage

@Node()
class MaxEntropyL1:
    """DVC stage: cross-entropy plus L1 energy regularizer (train_max_entropy_L1.ipynb).

    Depends on the ``train_args`` node named "train_argsL1". Its declared
    output is the ``last_ckpt.pt`` written by the attached trainer.
    """
    args: train_args = dvc.deps(train_args(load=True, name="train_argsL1"))
    trainer: Base = zn.Method()
    result = zn.metrics()
    model: Path = dvc.outs()

    def __call__(self, operation):
        """Attach the compute object and declare the checkpoint output path."""
        self.trainer = operation
        experiment_dir = Path(self.args.save_dir) / self.args.experiment
        self.model = experiment_dir / "last_ckpt.pt"

    @TimeIt
    def run(self):
        """Train via the attached compute object and store its scores as metrics."""
        scores = self.trainer.compute(self.args)
        self.result = scores
        

Submit issues to https://github.com/zincware/ZnTrack.


[NbConvertApp] Converting notebook ZnJEMProject.ipynb to script
[NbConvertApp] Writing 40669 bytes to ZnJEMProject.py


In [11]:
#trainer class for MaxEntropyL1 stage's compute function

class TrainerL1(Base):
    """Compute object for the MaxEntropyL1 stage.

    Same training procedure as plain cross-entropy, plus an energy term:
    the per-sample logsumexp of the logits is pulled toward its batch mean
    with an L1 penalty.
    """

    # checkpoint inside <load_path>/<experiment> that a run resumes from, if present
    resume_ckpt = "ckpt_9.pt"

    def compute(self, inp):
        """Train with the hyper-parameters in ``inp`` and return the scores.

        Args:
            inp: the ``train_args`` node, or any object with the same attributes.

        Returns:
            dict ``{"train"|"test"|"validation": {"acc:": float, "loss": float}}``
            from the last evaluated epoch. It is also written to
            ``<save_dir>/<experiment>_scores.json``.

        Raises:
            RuntimeError: if the training loss becomes non-finite or exceeds 1e8.
        """
        args = inp
        exp_dir = os.path.join(args.save_dir, args.experiment)
        # creating the nested directory also creates save_dir
        os.makedirs(exp_dir, exist_ok=True)

        # optionally redirect prints to <exp_dir>/log.txt; stdout is restored and
        # the log closed when training ends, including on error
        with contextlib.ExitStack() as stack:
            if args.print_to_log:
                log_file = stack.enter_context(open(os.path.join(exp_dir, "log.txt"), "w"))
                stack.enter_context(contextlib.redirect_stdout(log_file))
            scores = self._train(args)

        with open(exp_dir + '_scores.json', 'w') as outfile:
            json.dump(scores, outfile)

        return scores

    def _train(self, args):
        """Seed, build data/model/optimizer, run the epoch loop and return the scores."""
        t.manual_seed(args.seed)
        if t.cuda.is_available():
            t.cuda.manual_seed_all(args.seed)

        dload_train, dload_train_labeled, dload_valid, dload_test = JEMUtils.get_data(args)
        device = t.device('cuda' if t.cuda.is_available() else 'cpu')

        f = FTrain(args.depth, args.width, args.norm,
                   dropout_rate=args.dropout_rate, n_classes=args.n_classes).to(device)
        # clf_only: optimize just the classification head
        params = f.class_output.parameters() if args.clf_only else f.parameters()
        optim = t.optim.Adam(params, lr=args.lr, betas=[.9, .999], weight_decay=args.weight_decay)

        epoch_start = self._load_checkpoint(args, f, optim, device)

        # train-set loss/accuracy of the starting (possibly reloaded) model
        f.eval()
        with t.no_grad():
            correct, loss = JEMUtils.eval_classification(f, dload_train, device)
        print("Epoch {}: Train Loss {}, Train Acc {}".format(epoch_start, loss, correct))
        f.train()

        eval_sets = (("train", "Train", dload_train),
                     ("test", "Test", dload_test),
                     ("validation", "Valid", dload_valid))
        cross_entropy = nn.CrossEntropyLoss()
        cur_iter = 0
        scores = {}
        for epoch in range(epoch_start, epoch_start + args.n_epochs):
            # one step per batch of the training split; the batches that are
            # trained on come from the endlessly cycling labeled loader, so the
            # training loader itself does not need to be iterated
            for _ in range(len(dload_train)):
                x_lab, y_lab = next(dload_train_labeled)
                x_lab, y_lab = x_lab.to(device), y_lab.to(device)
                logits = f.classify(x_lab)

                # energy term (intent: "maximize entropy by assuming equal
                # probabilities"): the L1 deviation of each sample's logsumexp
                # from the batch mean, summed over the batch
                energy = logits.logsumexp(dim=1)
                energy_loss = t.sum(t.abs(energy.mean() - energy))

                # standard cross-entropy: maximize log p(y | x)
                l_p_y_given_x = cross_entropy(logits, y_lab)
                if cur_iter % args.print_every == 0:
                    acc = (logits.argmax(1) == y_lab).float().mean()
                    print('P(y|x) {}:{:>d} loss={:>14.9f}, acc={:>14.9f}'.format(
                        epoch, cur_iter, l_p_y_given_x.item(), acc.item()))

                L = energy_loss + l_p_y_given_x
                self._check_divergence(L)

                optim.zero_grad()
                L.backward()
                optim.step()
                cur_iter += 1

            if epoch % args.ckpt_every == 0:
                JEMUtils.checkpoint(f, optim, epoch, f'ckpt_{epoch}.pt', args, device)

            if epoch % args.eval_every == 0:
                scores = self._evaluate(f, eval_sets, device, epoch)

            JEMUtils.checkpoint(f, optim, epoch, "last_ckpt.pt", args, device)

        return scores

    def _load_checkpoint(self, args, f, optim, device):
        """Restore ``f``/``optim`` from ``<load_path>/<experiment>/<resume_ckpt>``.

        Returns the first epoch to train: one past the stored epoch (which has
        already been trained), or 0 when there is nothing to resume.
        """
        if not args.load_path:
            return 0
        ckpt_dir = os.path.join(args.load_path, args.experiment)
        ckpt_path = os.path.join(ckpt_dir, self.resume_ckpt)
        if not os.path.exists(ckpt_path):
            return 0
        print(f"loading model from {ckpt_dir}")
        # map_location: optimizer state may have been saved from a GPU run
        ckpt_dict = t.load(ckpt_path, map_location=device)
        f.load_state_dict(ckpt_dict["model_state_dict"])
        optim.load_state_dict(ckpt_dict['optimizer_state_dict'])
        return ckpt_dict['epoch'] + 1

    @staticmethod
    def _check_divergence(loss):
        """Raise RuntimeError if ``loss`` is NaN/inf or its magnitude exceeds 1e8."""
        value = loss.item()
        # NaN compares False with '>', so finiteness is checked explicitly
        if not math.isfinite(value) or abs(value) > 1e8:
            raise RuntimeError(f"Divergence error: loss = {value}")

    @staticmethod
    def _evaluate(f, eval_sets, device, epoch):
        """Evaluate ``f`` on each ``(key, label, loader)`` and return the scores dict."""
        scores = {}
        f.eval()
        with t.no_grad():
            for key, label, dload in eval_sets:
                correct, loss = JEMUtils.eval_classification(f, dload, device)
                # key "acc:" (with colon) kept as-is for existing metrics consumers
                scores[key] = {"acc:": float(correct), "loss": float(loss)}
                print("Epoch {}: {} Loss {}, {} Acc {}".format(epoch, label, loss, label, correct))
        f.train()
        return scores

In [12]:
# parameters for the L1 stage; with load_path equal to the default save_dir, a
# rerun resumes from ./experiment/max-entropy-L1_augmented/ckpt_9.pt if it exists
inline_parms1 = {"lr": .0001, "experiment": 'max-entropy-L1_augmented', "load_path": './experiment'} 

# declare the train_args with name train_argsL1, which becomes its stage name in dvc.yaml
params1 = train_args(name="train_argsL1")
params1(param_dict=inline_parms1)

# compute class for the MaxEntropyL1 stage
trainer1 = TrainerL1()

# declare the stage and attach the compute class (writes the stage to dvc.yaml)
runner1 = MaxEntropyL1()
runner1(operation=trainer1)

2022-01-10 13:19:09,742 (INFO): Modifying stage 'train_argsL1' in 'dvc.yaml'

To track the changes with git, run:

    git add dvc.yaml

To enable auto staging, run:

	dvc config core.autostage true

2022-01-10 13:19:09,754 (ERROR): Can not convert args!
2022-01-10 13:19:09,755 (ERROR): Can not convert kwargs!
2022-01-10 13:19:10,572 (INFO): Modifying stage 'MaxEntropyL1' in 'dvc.yaml'

To track the changes with git, run:

    git add dvc.yaml

To enable auto staging, run:

	dvc config core.autostage true



In [13]:
#stage MaxEntropyL2 originally from train_max_entropy_L2.ipynb
#dependency is train_args with name train_argsL2

@Node()
class MaxEntropyL2:
    """DVC stage: cross-entropy plus squared (L2) energy regularizer
    (train_max_entropy_L2.ipynb).

    Depends on the ``train_args`` node named "train_argsL2". Its declared
    output is the ``last_ckpt.pt`` written by the attached trainer.
    """
    args: train_args = dvc.deps(train_args(load=True, name="train_argsL2"))
    trainer: Base = zn.Method()
    result = zn.metrics()
    model: Path = dvc.outs()

    def __call__(self, operation):
        """Attach the compute object and declare the checkpoint output path."""
        self.trainer = operation
        experiment_dir = Path(self.args.save_dir) / self.args.experiment
        self.model = experiment_dir / "last_ckpt.pt"

    @TimeIt
    def run(self):
        """Train via the attached compute object and store its scores as metrics."""
        scores = self.trainer.compute(self.args)
        self.result = scores
        

Submit issues to https://github.com/zincware/ZnTrack.


[NbConvertApp] Converting notebook ZnJEMProject.ipynb to script
[NbConvertApp] Writing 40669 bytes to ZnJEMProject.py


In [14]:
#compute class for the above stage

class TrainerL2(Base):
    """Compute object for the MaxEntropyL2 stage.

    Same training procedure as plain cross-entropy, plus an energy term:
    the per-sample logsumexp of the logits is pulled toward its batch mean
    with a squared (L2) penalty.
    """

    # checkpoint inside <load_path>/<experiment> that a run resumes from, if present
    resume_ckpt = "ckpt_9.pt"

    def compute(self, inp):
        """Train with the hyper-parameters in ``inp`` and return the scores.

        Args:
            inp: the ``train_args`` node, or any object with the same attributes.

        Returns:
            dict ``{"train"|"test"|"validation": {"acc:": float, "loss": float}}``
            from the last evaluated epoch. It is also written to
            ``<save_dir>/<experiment>_scores.json``.

        Raises:
            RuntimeError: if the training loss becomes non-finite or exceeds 1e8.
        """
        args = inp
        exp_dir = os.path.join(args.save_dir, args.experiment)
        # creating the nested directory also creates save_dir
        os.makedirs(exp_dir, exist_ok=True)

        # optionally redirect prints to <exp_dir>/log.txt; stdout is restored and
        # the log closed when training ends, including on error
        with contextlib.ExitStack() as stack:
            if args.print_to_log:
                log_file = stack.enter_context(open(os.path.join(exp_dir, "log.txt"), "w"))
                stack.enter_context(contextlib.redirect_stdout(log_file))
            scores = self._train(args)

        with open(exp_dir + '_scores.json', 'w') as outfile:
            json.dump(scores, outfile)

        return scores

    def _train(self, args):
        """Seed, build data/model/optimizer, run the epoch loop and return the scores."""
        t.manual_seed(args.seed)
        if t.cuda.is_available():
            t.cuda.manual_seed_all(args.seed)

        dload_train, dload_train_labeled, dload_valid, dload_test = JEMUtils.get_data(args)
        device = t.device('cuda' if t.cuda.is_available() else 'cpu')

        f = FTrain(args.depth, args.width, args.norm,
                   dropout_rate=args.dropout_rate, n_classes=args.n_classes).to(device)
        # clf_only: optimize just the classification head
        params = f.class_output.parameters() if args.clf_only else f.parameters()
        optim = t.optim.Adam(params, lr=args.lr, betas=[.9, .999], weight_decay=args.weight_decay)

        epoch_start = self._load_checkpoint(args, f, optim, device)

        # train-set loss/accuracy of the starting (possibly reloaded) model
        f.eval()
        with t.no_grad():
            correct, loss = JEMUtils.eval_classification(f, dload_train, device)
        print("Epoch {}: Train Loss {}, Train Acc {}".format(epoch_start, loss, correct))
        f.train()

        eval_sets = (("train", "Train", dload_train),
                     ("test", "Test", dload_test),
                     ("validation", "Valid", dload_valid))
        cross_entropy = nn.CrossEntropyLoss()
        cur_iter = 0
        scores = {}
        for epoch in range(epoch_start, epoch_start + args.n_epochs):
            # one step per batch of the training split; the batches that are
            # trained on come from the endlessly cycling labeled loader, so the
            # training loader itself does not need to be iterated
            for _ in range(len(dload_train)):
                x_lab, y_lab = next(dload_train_labeled)
                x_lab, y_lab = x_lab.to(device), y_lab.to(device)
                logits = f.classify(x_lab)

                # energy term (intent: "maximize entropy by assuming equal
                # probabilities"): the squared deviation of each sample's
                # logsumexp from the batch mean, summed over the batch
                energy = logits.logsumexp(dim=1)
                energy_loss = t.sum((energy.mean() - energy) ** 2)

                # standard cross-entropy: maximize log p(y | x)
                l_p_y_given_x = cross_entropy(logits, y_lab)
                if cur_iter % args.print_every == 0:
                    acc = (logits.argmax(1) == y_lab).float().mean()
                    print('P(y|x) {}:{:>d} loss={:>14.9f}, acc={:>14.9f}'.format(
                        epoch, cur_iter, l_p_y_given_x.item(), acc.item()))

                L = energy_loss + l_p_y_given_x
                self._check_divergence(L)

                optim.zero_grad()
                L.backward()
                optim.step()
                cur_iter += 1

            if epoch % args.ckpt_every == 0:
                JEMUtils.checkpoint(f, optim, epoch, f'ckpt_{epoch}.pt', args, device)

            if epoch % args.eval_every == 0:
                scores = self._evaluate(f, eval_sets, device, epoch)

            JEMUtils.checkpoint(f, optim, epoch, "last_ckpt.pt", args, device)

        return scores

    def _load_checkpoint(self, args, f, optim, device):
        """Restore ``f``/``optim`` from ``<load_path>/<experiment>/<resume_ckpt>``.

        Returns the first epoch to train: one past the stored epoch (which has
        already been trained), or 0 when there is nothing to resume.
        """
        if not args.load_path:
            return 0
        ckpt_dir = os.path.join(args.load_path, args.experiment)
        ckpt_path = os.path.join(ckpt_dir, self.resume_ckpt)
        if not os.path.exists(ckpt_path):
            return 0
        print(f"loading model from {ckpt_dir}")
        # map_location: optimizer state may have been saved from a GPU run
        ckpt_dict = t.load(ckpt_path, map_location=device)
        f.load_state_dict(ckpt_dict["model_state_dict"])
        optim.load_state_dict(ckpt_dict['optimizer_state_dict'])
        return ckpt_dict['epoch'] + 1

    @staticmethod
    def _check_divergence(loss):
        """Raise RuntimeError if ``loss`` is NaN/inf or its magnitude exceeds 1e8."""
        value = loss.item()
        # NaN compares False with '>', so finiteness is checked explicitly
        if not math.isfinite(value) or abs(value) > 1e8:
            raise RuntimeError(f"Divergence error: loss = {value}")

    @staticmethod
    def _evaluate(f, eval_sets, device, epoch):
        """Evaluate ``f`` on each ``(key, label, loader)`` and return the scores dict."""
        scores = {}
        f.eval()
        with t.no_grad():
            for key, label, dload in eval_sets:
                correct, loss = JEMUtils.eval_classification(f, dload, device)
                # key "acc:" (with colon) kept as-is for existing metrics consumers
                scores[key] = {"acc:": float(correct), "loss": float(loss)}
                print("Epoch {}: {} Loss {}, {} Acc {}".format(epoch, label, loss, label, correct))
        f.train()
        return scores

In [15]:
# parameters for the L2 stage; with load_path equal to the default save_dir, a
# rerun resumes from ./experiment/max-entropy-L2_augmented/ckpt_9.pt if it exists
inline_parms2 = {"lr": .0001, "experiment": 'max-entropy-L2_augmented', "load_path": './experiment'} 

# declare the train_args with name train_argsL2, which becomes its stage name in dvc.yaml
params2 = train_args(name="train_argsL2")
params2(param_dict=inline_parms2)
# compute class for the MaxEntropyL2 stage
trainer2 = TrainerL2()

# declare the stage and attach the compute class (writes the stage to dvc.yaml)
runner2 = MaxEntropyL2()
runner2(operation=trainer2)

2022-01-10 13:19:23,256 (INFO): Modifying stage 'train_argsL2' in 'dvc.yaml'

To track the changes with git, run:

    git add dvc.yaml

To enable auto staging, run:

	dvc config core.autostage true

2022-01-10 13:19:23,267 (ERROR): Can not convert args!
2022-01-10 13:19:23,268 (ERROR): Can not convert kwargs!
2022-01-10 13:19:24,088 (INFO): Modifying stage 'MaxEntropyL2' in 'dvc.yaml'

To track the changes with git, run:

    git add dvc.yaml

To enable auto staging, run:

	dvc config core.autostage true



In [16]:
class F(nn.Module):
    """Wide-ResNet backbone with two linear heads on the same penultimate features.

    ``forward`` returns the output of the single-unit ``energy_output`` head (one
    value per sample); ``classify`` returns the logits of the ``class_output`` head.

    Args:
        depth: Wide-ResNet depth.
        width: Wide-ResNet widen factor.
        norm: normalization option forwarded to ``wideresnet.Wide_ResNet``.
        n_classes: number of outputs of the classification head (default 10,
            the value that was previously hard-coded).
    """

    def __init__(self, depth=28, width=2, norm=None, n_classes=10):
        super().__init__()
        self.f = wideresnet.Wide_ResNet(depth, width, norm=norm)
        self.energy_output = nn.Linear(self.f.last_dim, 1)
        self.class_output = nn.Linear(self.f.last_dim, n_classes)

    def forward(self, x, y=None):
        """Return the energy-head output per sample; ``y`` is accepted but ignored."""
        penult_z = self.f(x)
        # Drop only the trailing size-1 dimension of the head: a bare squeeze() would
        # also drop the batch dimension when the batch holds a single sample.
        return self.energy_output(penult_z).squeeze(-1)

    def classify(self, x):
        """Return the class logits produced by the classification head."""
        penult_z = self.f(x)
        return self.class_output(penult_z)

In [17]:
class CCF(F):
    """Class-conditional variant of ``F`` whose output is derived from the class logits.

    ``forward`` does not use the ``energy_output`` head: without labels it returns
    the log-sum-exp of the class logits, with labels it returns the logit of each
    sample's labelled class.
    """

    def __init__(self, depth=28, width=2, norm=None):
        super().__init__(depth, width, norm=norm)

    def forward(self, x, y=None):
        """Return ``logsumexp`` over the class logits, or the labelled-class logit if ``y`` is given."""
        class_logits = self.classify(x)
        if y is not None:
            # gather with y[:, None] keeps a trailing dimension of size 1
            return t.gather(class_logits, 1, y[:, None])
        return class_logits.logsumexp(1)

In [18]:
# Class holding the parameters of one evaluation (calibration) stage.
@Node()
class eval_args():
    """Parameters of one evaluation stage.

    Every attribute declared with ``dvc.params()`` is a stage parameter. Calling
    the instance first sets the defaults below and then overrides them from
    ``param_dict``. ``run`` only records the experiment name as the stage metric,
    which gives downstream stages an output file to depend on.
    """

    experiment = dvc.params()
    dataset = dvc.params()
    n_steps = dvc.params()
    width = dvc.params()
    depth = dvc.params()
    sigma = dvc.params()
    data_root = dvc.params()
    seed = dvc.params()
    norm = dvc.params()
    save_dir = dvc.params()
    print_to_log = dvc.params()
    uncond = dvc.params()
    load_path = dvc.params()

    result = zn.metrics()

    def __call__(self, param_dict=None):
        """Set default parameters, then apply the overrides in ``param_dict``.

        Args:
            param_dict: mapping of parameter name -> value. Every key is set as an
                attribute, so keys should match the parameters declared above.
                ``None`` (or an empty mapping) keeps all defaults.
        """
        self.experiment = "energy_model"
        self.data_root = "./dataset"
        # dataset used for evaluation: "cifar_train", "cifar_test", "svhn_train" or "svhn_test"
        self.dataset = "cifar_test"
        self.seed = JEMUtils.get_parameter("seed", 1)
        # regularization: std of the Gaussian noise added to evaluation inputs
        self.sigma = 3e-2
        # network normalization: None, "norm", "batch", "instance", "layer" or "act"
        self.norm = None
        # EBM specific
        self.n_steps = 20  # SGLD steps per iteration (100 works for short-run, 20 for PCD)
        self.width = 10  # WRN width parameter
        self.depth = 28  # WRN depth parameter
        self.uncond = False  # if True, the EBM is unconditional
        # logging + evaluation
        self.save_dir = './experiment'
        self.print_to_log = False
        # checkpoints are read from <load_path>/<experiment>/last_ckpt.pt; this
        # parameter previously had no default and had to be passed explicitly
        self.load_path = './experiment'

        # apply inline overrides on top of the defaults
        for key, value in (param_dict or {}).items():
            setattr(self, key, value)

    def run(self):
        """Record the experiment name as this stage's metrics output."""
        self.result = {"experiment": self.experiment}

Submit issues to https://github.com/zincware/ZnTrack.


[NbConvertApp] Converting notebook ZnJEMProject.ipynb to script
[NbConvertApp] Writing 40669 bytes to ZnJEMProject.py


In [19]:
# compute class for the evaluation stage

class Calibration(Base):
    """Compute step of the evaluation stage: classifier calibration (ECE).

    ``compute`` rebuilds the model of one experiment, loads its ``last_ckpt.pt``
    checkpoint and calls ``calibration``. ``calibration`` bins the predictions by
    confidence (highest softmax probability), prints the per-bin accuracy and the
    expected calibration error (ECE), and writes them to
    ``<save_dir>/<experiment>_calibration.csv``.
    """

    @staticmethod
    def _load_dataset(args):
        """Return the evaluation dataset selected by ``args.dataset``.

        Images are converted to tensors, normalized with mean/std 0.5 per channel
        (mapping [0, 1] to [-1, 1]) and perturbed with Gaussian noise of std
        ``args.sigma``. Any value other than "cifar_train", "cifar_test" or
        "svhn_train" selects the SVHN test split.
        """
        # NOTE(review): DataLoader workers cannot pickle this lambda under the
        # 'spawn' start method (macOS/Windows) — confirm the target platform.
        transform_test = tr.Compose(
            [tr.ToTensor(),
             tr.Normalize((.5, .5, .5), (.5, .5, .5)),
             lambda x: x + t.randn_like(x) * args.sigma]
        )
        if args.dataset == "cifar_train":
            return tv.datasets.CIFAR10(root=args.data_root, transform=transform_test, download=True, train=True)
        if args.dataset == "cifar_test":
            return tv.datasets.CIFAR10(root=args.data_root, transform=transform_test, download=True, train=False)
        if args.dataset == "svhn_train":
            return tv.datasets.SVHN(root=args.data_root, transform=transform_test, download=True, split="train")
        return tv.datasets.SVHN(root=args.data_root, transform=transform_test, download=True, split="test")

    @staticmethod
    def _bin_statistics(confidences, correct, num_bins=20, bin_width=0.05):
        """Bin predictions by confidence and compute per-bin accuracy and the ECE.

        Bin ``i`` holds confidences in ``[i * bin_width, (i + 1) * bin_width)`` and
        is compared against its centre ``(i + 0.5) * bin_width``. Empty bins get
        accuracy 0 and contribute nothing to the ECE.

        Args:
            confidences: 1-D array of predicted-class probabilities.
            correct: 1-D boolean array, True where the prediction matched the label.
            num_bins: number of confidence bins.
            bin_width: width of each bin.

        Returns:
            ``(accuracy, counts, upper_edges, ece)``: per-bin accuracy, per-bin
            sample counts, the upper bin edges and the scalar ECE.
        """
        confidences = np.asarray(confidences, dtype=np.float64)
        correct = np.asarray(correct, dtype=bool)
        # upper bin edges (0.05, 0.10, ..., 1.0 by default); the +1e-10 nudges each
        # edge up so that a confidence of exactly 1.0 still lands in the last bin
        upper_edges = np.arange(num_bins) * bin_width + bin_width + 1e-10
        # digitize returns the index of the first edge strictly greater than the
        # value; clip so an out-of-range confidence can never index past the end
        bin_idx = np.minimum(np.digitize(confidences, upper_edges), num_bins - 1)
        counts = np.bincount(bin_idx, minlength=num_bins).astype(np.float64)
        hits = np.bincount(bin_idx, weights=correct.astype(np.float64), minlength=num_bins)
        accuracy = np.divide(hits, counts, out=np.zeros(num_bins), where=counts > 0)
        centres = (np.arange(num_bins) + 0.5) * bin_width
        n_total = counts.sum()
        ece = float(np.sum(counts / n_total * np.abs(accuracy - centres))) if n_total > 0 else 0.0
        return accuracy, counts, upper_edges, ece

    def calibration(self, f, args, device):
        """Measure the calibration of classifier ``f`` on ``args.dataset``.

        Each image is classified. Its confidence is the softmax probability of the
        arg-max class, and it counts as correct when that class equals the label.
        Per-bin accuracy and the ECE are printed and written to a CSV file.

        Args:
            f: model exposing ``classify(x)`` returning class logits (``F``/``CCF``).
            args: eval_args-like object (dataset, data_root, sigma, save_dir, experiment).
            device: torch device the model lives on.

        Returns:
            Path of the written ``<save_dir>/<experiment>_calibration.csv`` file.
        """
        dset = self._load_dataset(args)
        # Batched inference (the original used batch_size=1, ~100x slower). Each
        # sample is still scored on its own because the model runs in eval mode.
        dload = DataLoader(dset, batch_size=100, shuffle=False, num_workers=4, drop_last=False)

        confidences, correct = [], []
        was_training = f.training
        f.eval()  # inference mode for any dropout / batch-norm layers
        try:
            with t.no_grad():  # no autograd graph is needed for evaluation
                for x_p_d, y_p_d in tqdm(dload):
                    x_p_d, y_p_d = x_p_d.to(device), y_p_d.to(device)
                    logits = f.classify(x_p_d)
                    preds = logits.argmax(dim=1)
                    # confidence = softmax probability of the predicted class
                    confidences.append(logits.softmax(dim=1).max(dim=1).values.cpu().numpy())
                    correct.append((preds == y_p_d).cpu().numpy())
        finally:
            f.train(was_training)

        accu, bin_total, bins, ECE = self._bin_statistics(
            np.concatenate(confidences) if confidences else np.empty(0),
            np.concatenate(correct) if correct else np.empty(0, dtype=bool),
        )
        print("Bin data", np.sum(bin_total), accu, bins, bin_total)
        print("ECE", ECE)

        # save per-bin accuracy; the scalar ECE is repeated on every row
        outputcsv = os.path.join(args.save_dir, args.experiment) + "_calibration.csv"
        pd.DataFrame({'accuracy': accu, 'ECE': ECE}).to_csv(path_or_buf=outputcsv, index_label="index")
        return outputcsv

    def compute(self, inp):
        """Evaluate one experiment's last checkpoint and return the calibration CSV path.

        Args:
            inp: eval_args-like object describing the experiment (save_dir,
                experiment, load_path, seed, uncond, depth, width, norm,
                print_to_log, plus the fields ``calibration`` reads).

        Returns:
            Path of the CSV written by ``calibration``.
        """
        import contextlib

        args = inp
        exp_dir = os.path.join(args.save_dir, args.experiment)
        # create <save_dir>/<experiment> first: the log file below lives inside it
        os.makedirs(exp_dir, exist_ok=True)

        with contextlib.ExitStack() as stack:
            if args.print_to_log:
                # redirect stdout to the log only while this experiment runs and
                # close the file afterwards (compute is called once per experiment)
                log_file = stack.enter_context(open(os.path.join(exp_dir, 'log.txt'), 'w'))
                stack.enter_context(contextlib.redirect_stdout(log_file))
            return self._evaluate_checkpoint(args)

    def _evaluate_checkpoint(self, args):
        """Seed the RNGs, rebuild the model, load ``last_ckpt.pt`` and run ``calibration``."""
        t.manual_seed(args.seed)
        if t.cuda.is_available():
            t.cuda.manual_seed_all(args.seed)

        device = t.device('cuda' if t.cuda.is_available() else 'cpu')

        # F for an unconditional model, CCF otherwise
        model_cls = F if args.uncond else CCF
        f = model_cls(args.depth, args.width, args.norm)

        ckpt_path = os.path.join(args.load_path, args.experiment, 'last_ckpt.pt')
        print(f"loading model from {ckpt_path}")
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        # NOTE: t.load unpickles the file — only load checkpoints you trust.
        ckpt_dict = t.load(ckpt_path, map_location=device)
        f.load_state_dict(ckpt_dict["model_state_dict"])
        f = f.to(device)

        return self.calibration(f, args, device)

In [20]:
# Stage EvaluateX: runs the calibration evaluation for every trained experiment.
# It depends on one named eval_args stage per experiment and on the model
# (training) stages, so the dependency graph runs those stages before this one.

@Node()
class EvaluateX:
    """Evaluation stage: apply a compute object to each experiment's eval_args.

    ``__call__`` stores the compute object (e.g. a ``Calibration`` instance) in
    ``self.calibration``; ``run`` calls ``self.calibration.compute(arg)`` once for
    each entry of ``self.args``.
    """

    # From the DVC docs: "Stage dependencies can be any file or directory".
    # The eval_args stages therefore have to produce an output to be usable as
    # deps here; their metrics files are used,
    # e.g. nodes/x-entropy_augmented/metrics_no_cache.json
    args = dvc.deps([eval_args(load=True, name="x-entropy_augmented"),
                     eval_args(load=True, name="max-entropy-L1_augmented"),
                     eval_args(load=True, name="max-entropy-L2_augmented")])

    # depend on the training stages as well so they run before the evaluation
    models = dvc.deps([XEntropyAugmented(load=True), MaxEntropyL1(load=True), MaxEntropyL2(load=True)])

    # compute object supplied through __call__
    calibration: Base = zn.Method()

    # NOTE(review): never assigned in run() — confirm whether ZnTrack/DVC expect this output to exist
    result: Path = dvc.outs()

    # Calibration CSVs tracked as DVC plots. The paths are hard-coded and must match
    # <save_dir>/<experiment>_calibration.csv as written by the compute step;
    # deriving them from the eval_args would be preferable.
    plot0: Path = dvc.plots("./experiment/x-entropy_augmented_calibration.csv")
    plot1: Path = dvc.plots("./experiment/max-entropy-L1_augmented_calibration.csv")
    plot2: Path = dvc.plots("./experiment/max-entropy-L2_augmented_calibration.csv")

    def __call__(self, operation):
        """Register ``operation`` as this stage's compute object."""
        self.calibration = operation

    @TimeIt
    def run(self):
        """Call the compute object on each experiment's eval_args; return values are discarded."""
        for arg in self.args:
            self.calibration.compute(arg)
            
            

Submit issues to https://github.com/zincware/ZnTrack.
Submit issues to https://github.com/zincware/ZnTrack.


[NbConvertApp] Converting notebook ZnJEMProject.ipynb to script


Submit issues to https://github.com/zincware/ZnTrack.


[NbConvertApp] Writing 40669 bytes to ZnJEMProject.py
[NbConvertApp] Converting notebook ZnJEMProject.ipynb to script


Submit issues to https://github.com/zincware/ZnTrack.


[NbConvertApp] Writing 40669 bytes to ZnJEMProject.py
[NbConvertApp] Converting notebook ZnJEMProject.ipynb to script


Submit issues to https://github.com/zincware/ZnTrack.


[NbConvertApp] Writing 40669 bytes to ZnJEMProject.py
[NbConvertApp] Converting notebook ZnJEMProject.ipynb to script


Submit issues to https://github.com/zincware/ZnTrack.


[NbConvertApp] Writing 40669 bytes to ZnJEMProject.py
[NbConvertApp] Converting notebook ZnJEMProject.ipynb to script


Submit issues to https://github.com/zincware/ZnTrack.


[NbConvertApp] Writing 40669 bytes to ZnJEMProject.py
[NbConvertApp] Converting notebook ZnJEMProject.ipynb to script
[NbConvertApp] Writing 40669 bytes to ZnJEMProject.py
[NbConvertApp] Converting notebook ZnJEMProject.ipynb to script
[NbConvertApp] Writing 40669 bytes to ZnJEMProject.py


In [21]:
# Declare one eval_args parameter stage per trained experiment. Each stage is named
# after its experiment (matching the eval_args(load=True, name=...) dependencies
# used by EvaluateX) and loads that experiment's checkpoint from EVAL_LOAD_PATH.

EVAL_LOAD_PATH = "./experiment"


def declare_eval_stage(experiment, load_path=EVAL_LOAD_PATH):
    """Create and configure the eval_args stage for ``experiment``.

    Returns:
        ``(stage_params, stage)``: the parameter dict that was applied and the
        configured eval_args instance.
    """
    stage_params = {"load_path": load_path, "experiment": experiment}
    stage = eval_args(name=experiment)
    stage(stage_params)
    return stage_params, stage


# the original module-level names are kept so other cells can still reference them
inline_train_args, args_train = declare_eval_stage("x-entropy_augmented")
inline_L1_args, args_L1 = declare_eval_stage("max-entropy-L1_augmented")
inline_L2_args, args_L2 = declare_eval_stage("max-entropy-L2_augmented")

2022-01-10 13:19:53,787 (INFO): Modifying stage 'x-entropy_augmented' in 'dvc.yaml'

To track the changes with git, run:

    git add dvc.yaml

To enable auto staging, run:

	dvc config core.autostage true

2022-01-10 13:19:54,612 (INFO): Modifying stage 'max-entropy-L1_augmented' in 'dvc.yaml'

To track the changes with git, run:

    git add dvc.yaml

To enable auto staging, run:

	dvc config core.autostage true

2022-01-10 13:19:55,443 (INFO): Modifying stage 'max-entropy-L2_augmented' in 'dvc.yaml'

To track the changes with git, run:

    git add dvc.yaml

To enable auto staging, run:

	dvc config core.autostage true



In [22]:
# Create the calibration compute object and hand it to the EvaluateX stage;
# calling the stage writes/updates 'EvaluateX' in dvc.yaml (see the cell output).

cal = Calibration()
eva = EvaluateX()
eva(cal)

2022-01-10 13:20:00,391 (ERROR): Can not convert args!
2022-01-10 13:20:00,392 (ERROR): Can not convert kwargs!
2022-01-10 13:20:01,202 (INFO): Modifying stage 'EvaluateX' in 'dvc.yaml'

To track the changes with git, run:

    git add dvc.yaml

To enable auto staging, run:

	dvc config core.autostage true



In [117]:
# Reproduce the whole DVC pipeline inside this notebook kernel.
# See cluster-script.sh for enqueueing the same pipeline on the CRC cluster instead.
# NOTE(review): this cell's execution count is out of sequence with the cells above;
# re-run the notebook top to bottom before trusting the saved outputs.

project.repro()

Running stage 'max-entropy-L1_augmented':
> python3 -c "from src.eval_args import eval_args; eval_args(load=True, name='max-entropy-L1_augmented').run()" 
Updating lock file 'dvc.lock'

Running stage 'train_argsL1':
> python3 -c "from src.train_args import train_args; train_args(load=True, name='train_argsL1').run()" 
Updating lock file 'dvc.lock'

Running stage 'MaxEntropyL1':
> python3 -c "from src.MaxEntropyL1 import MaxEntropyL1; MaxEntropyL1(load=True, name='MaxEntropyL1').run()" 
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
| Wide-Resnet 28x10




KeyboardInterrupt: 