In [1]:
import numpy as np
import os
import sys
import PIL
import time
import copy
import scipy
import sklearn
import math

import torchvision
import torchvision.transforms as transforms
import torch

import torch.nn as nn
from torch.autograd import Variable

from utils.data_io import *

from opts import *

In [2]:
def init_weights(m):
    """Initialise the parameters of a single module in place.

    Intended for ``model.apply(init_weights)``.

    * ``nn.Conv2d`` / ``nn.Linear``: Kaiming-uniform weight, zero bias.
    * ``nn.BatchNorm2d``: weight (scale) 1, bias (shift) 0.

    Modules of any other type are left untouched; missing (``None``) weight/bias
    tensors, e.g. ``bias=False`` or ``affine=False``, are skipped.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        if m.weight is not None:
            nn.init.kaiming_uniform_(m.weight)
        if m.bias is not None:
            # `init` is not imported in this notebook and `init.constant` is the
            # deprecated non-underscore alias; use the in-place nn.init variant.
            nn.init.constant_(m.bias, 0.0)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)

In [3]:
class swish(nn.Module):
    """Swish activation, ``f(z) = z * sigmoid(z)`` (also known as SiLU)."""

    def __init__(self):
        super(swish, self).__init__()

    def forward(self, z):
        # The previous body called `nn.Relu`, which does not exist (AttributeError);
        # swish gates its input with a sigmoid.
        return z * torch.sigmoid(z)

In [4]:
class LayerNorm(nn.Module):
    """Per-sample normalisation over all non-batch dimensions with a per-channel affine.

    Each sample ``x[n]`` is standardised with the mean and (unbiased) standard
    deviation of all of its elements, then scaled by ``gamma`` and shifted by
    ``beta``, both broadcast along dimension 1. Unlike ``torch.nn.LayerNorm`` the
    spatial size need not be known at construction time.

    Args:
        features: size of dimension 1 (channels) of the input.
        eps: added to the standard deviation to avoid division by zero.
    """

    def __init__(self, features, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(features))
        self.beta = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        # reshape (not view) also accepts non-contiguous input, e.g. a permuted tensor.
        flat = x.reshape(x.size(0), -1)
        stat_shape = [-1] + [1] * (x.dim() - 1)
        mean = flat.mean(1).view(*stat_shape)
        std = flat.std(1).view(*stat_shape)
        y = (x - mean) / (std + self.eps)
        # gamma/beta hold one value per channel and broadcast over dimension 1.
        affine_shape = [1, -1] + [1] * (x.dim() - 2)
        return self.gamma.view(*affine_shape) * y + self.beta.view(*affine_shape)

In [5]:
#"same padding" padding = int((kernel_size-1)/2)
class ConvPadding(nn.Module):
    """2-D convolution with "same" padding.

    Each side is padded by ``dilation * (kernel_size - 1) // 2``, so for an odd
    ``kernel_size`` and ``stride=1`` the output has the same height and width as
    the input (for ``dilation=1`` this is ``(kernel_size - 1) // 2``).

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, groups, bias:
            forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        super(ConvPadding, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        # (the old `self.padding = padding` referenced an undefined name)
        self.padding = dilation * (kernel_size - 1) // 2
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        # Build the convolution once so its weights are registered, optimised and saved;
        # the old forward created a new, randomly initialised nn.Conv2d on every call.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, self.padding,
                              dilation, groups, bias)

    def forward(self, x):
        return self.conv(x)

In [6]:
class ConvMeanPool(nn.Module):
    """Convolution with "same" padding followed by 2x2 mean pooling.

    The output has ``out_channels`` channels and half the height and width of the
    convolution output (which must be even for the four strided slices to line up).

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, groups, bias:
            forwarded to ``ConvPadding``.
        padding: stored only; ``ConvPadding`` always applies "same" padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(ConvMeanPool, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.conv = ConvPadding(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.dilation,
                                self.groups, self.bias)

    def forward(self, x):
        # Run the registered convolution (the old code called ConvPadding(x), which
        # constructs a new module instead of convolving x).
        conv = self.conv(x)
        # t[:, :, a::2, b::2] selects one corner of every non-overlapping 2x2 block, so
        # averaging the four corners is 2x2 mean pooling. (The sum used to be wrapped in
        # a list, making the division raise TypeError.)
        out = (conv[:, :, ::2, ::2] + conv[:, :, 1::2, ::2] + conv[:, :, ::2, 1::2] + conv[:, :, 1::2, 1::2]) / 4.
        return out

In [7]:
class MeanPoolConv(nn.Module):
    """2x2 mean pooling followed by a convolution with "same" padding.

    The input height and width must be even; the output has ``out_channels``
    channels at half the input resolution.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, groups, bias:
            forwarded to ``ConvPadding``.
        padding: stored only; ``ConvPadding`` always applies "same" padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(MeanPoolConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.conv = ConvPadding(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.dilation,
                                self.groups, self.bias)

    def forward(self, x):
        # 2x2 mean pooling: x[:, :, a::2, b::2] picks one corner of every 2x2 block.
        # (The sum used to be wrapped in a list, making the division raise TypeError.)
        pooled = (x[:, :, ::2, ::2] + x[:, :, 1::2, ::2] + x[:, :, ::2, 1::2] + x[:, :, 1::2, 1::2]) / 4.
        # Registered convolution (the old code called ConvPadding(x), building a new module).
        return self.conv(pooled)

In [8]:
# Rearranges data from the channel (depth) dimension into non-overlapping spatial blocks;
# the inverse of SpaceToDepth. Useful for resizing activations between convolutions
# without discarding data, e.g. instead of pooling or for upsampling.
# Element ordering follows TensorFlow's tf.nn.depth_to_space (channel index innermost),
# which differs from torch.nn.PixelShuffle (channel index outermost):
#   out[n, c, h*B + i, w*B + j] = x[n, (i*B + j) * (C / B**2) + c, h, w]
# Input N x C x H x W  ->  output N x C/B**2 x H*B x W*B, where B = block_size.
class DepthToSpace(nn.Module):
    """Depth-to-space rearrangement (TensorFlow ordering) on NCHW tensors.

    Args:
        block_size: spatial upscaling factor B; the input channel count must be
            divisible by B*B.

    Raises:
        ValueError: from ``forward`` if the channel count is not divisible by B*B.
    """

    def __init__(self, block_size):
        super(DepthToSpace, self).__init__()
        self.block_size = block_size

    def forward(self, x):
        b = self.block_size
        batch_size, in_channels, in_height, in_width = x.size()
        if in_channels % (b * b) != 0:
            raise ValueError('DepthToSpace: %d channels not divisible by block_size**2 = %d'
                             % (in_channels, b * b))
        out_channels = in_channels // (b * b)
        # NCHW -> NHWC, then split depth into (row offset i, column offset j, channel c).
        # (The old code reshaped with the undefined names input_height/input_width and
        # lacked the axis swap below, so output rows came out block-major.)
        out = x.permute(0, 2, 3, 1).reshape(batch_size, in_height, in_width, b, b, out_channels)
        # N, H, W, i, j, C' -> N, H, i, W, j, C' so (h, i) and (w, j) merge into rows/columns.
        out = out.permute(0, 1, 3, 2, 4, 5).reshape(batch_size, in_height * b, in_width * b, out_channels)
        # NHWC -> NCHW
        return out.permute(0, 3, 1, 2)

In [9]:
class UpsamplingConv(nn.Module):
    """2x upsampling (channel tiling + DepthToSpace) followed by a "same"-padded convolution.

    Input N x C x H x W -> output N x out_channels x 2H x 2W.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, groups, bias:
            forwarded to ``ConvPadding``.
        padding: stored only; ``ConvPadding`` always applies "same" padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(UpsamplingConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.depth_to_space = DepthToSpace(2)
        self.conv = ConvPadding(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.dilation,
                                self.groups, self.bias)

    def forward(self, x):
        # Four channel-wise copies; with TF-style depth_to_space ordering, block size 2
        # then copies every pixel into a 2x2 block (nearest-neighbour upsampling).
        x = torch.cat((x, x, x, x), 1)   # N, 4C, H, W
        # DepthToSpace consumes and returns NCHW; the old extra permute to NHWC here
        # scrambled the axes.
        x = self.depth_to_space(x)       # N, C, 2H, 2W
        # Registered convolution (the old code called ConvPadding(x), building a new module).
        return self.conv(x)

In [10]:
class ResidualBlock(nn.Module):
    """Pre-activation residual block.

    output = shortcut(input) + conv_2(relu(bn2(conv_1(relu(bn1(input))))))

    Args:
        in_channels: channels of the NCHW input.
        out_channels: channels of the output.
        kernel_size: kernel size of the two main-path convolutions (shortcut uses 1x1).
        stride, padding, dilation, groups, bias: forwarded to the convolution helpers.
        resample: 'down' halves H and W with 2x2 mean pooling, 'up' doubles them with
            UpsamplingConv, None keeps the resolution. Pass it by keyword -- the
            fourth positional parameter is `stride`.

    Raises:
        Exception: if `resample` is not 'down', 'up' or None.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, resample=None):
        super(ResidualBlock, self).__init__()
        if resample not in ('down', 'up', None):
            raise Exception('invalid resample value')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.resample = resample
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

        # bn1 normalises the block input, bn2 the output of conv_1. `LayerNorm` is this
        # notebook's per-sample (C, H, W) normaliser with a per-channel affine;
        # nn.LayerNorm(C) would instead normalise the last (width) axis of NCHW input.
        if resample == 'down':
            self.bn1 = LayerNorm(in_channels)
            self.bn2 = LayerNorm(in_channels)        # conv_1 maps in -> in
            self.conv_shortcut = MeanPoolConv(in_channels, out_channels, 1, stride,
                                              padding, dilation, groups, bias)
            self.conv_1 = ConvPadding(in_channels, in_channels, kernel_size, stride, dilation,
                                      groups, bias=False)
            self.conv_2 = ConvMeanPool(in_channels, out_channels, kernel_size, stride,
                                       padding, dilation, groups, bias)
        elif resample == 'up':
            self.bn1 = nn.BatchNorm2d(in_channels)
            self.bn2 = nn.BatchNorm2d(out_channels)  # conv_1 maps in -> out
            # (previously the undefined names UpSampleConv / MyConvo2d)
            self.conv_shortcut = UpsamplingConv(in_channels, out_channels, 1, stride,
                                                padding, dilation, groups, bias)
            self.conv_1 = UpsamplingConv(in_channels, out_channels, kernel_size, stride,
                                         padding, dilation, groups, bias=False)
            self.conv_2 = ConvPadding(out_channels, out_channels, kernel_size, stride, dilation,
                                      groups, bias)
        else:
            # Both norms see in_channels tensors (previously BatchNorm2d(out_channels) and
            # nn.LayerNorm(out_channels), which break whenever in_channels != out_channels).
            self.bn1 = nn.BatchNorm2d(in_channels)
            self.bn2 = LayerNorm(in_channels)        # conv_1 maps in -> in
            self.conv_shortcut = ConvPadding(in_channels, out_channels, 1, stride, dilation,
                                             groups, bias)
            self.conv_1 = ConvPadding(in_channels, in_channels, kernel_size, stride, dilation,
                                      groups, bias=False)
            self.conv_2 = ConvPadding(in_channels, out_channels, kernel_size, stride, dilation,
                                      groups, bias)

    def forward(self, input):
        # Identity shortcut only when shapes already match (the old code read the
        # non-existent attributes input_dim / output_dim).
        if self.in_channels == self.out_channels and self.resample is None:
            shortcut = input
        else:
            shortcut = self.conv_shortcut(input)

        output = self.conv_1(self.relu1(self.bn1(input)))
        output = self.conv_2(self.relu2(self.bn2(output)))
        return shortcut + output

In [11]:
class Good_Discriminator(nn.Module):
    """ResNet critic producing one score per image.

    A 3x3 "same" convolution lifts the input to `dim` channels. Four down-sampling
    residual blocks then each halve the resolution (64 -> 4 for 64x64 input) and widen
    to 8*dim channels. Finally, a linear layer maps the flattened features to a scalar.

    Args:
        in_channels: image channels.
        dim: base channel width.
        img_size: side length of the square input; must be divisible by 16. The
            default (64) reproduces the previous fixed 4*4*8*dim linear layer.
    """

    def __init__(self, in_channels=3, dim=64, img_size=64):
        super(Good_Discriminator, self).__init__()
        self.in_channels = in_channels
        self.dim = dim
        self.img_size = img_size

        self.conv1 = ConvPadding(self.in_channels, dim, 3, 1)
        # resample must be passed by keyword: the 4th positional parameter of
        # ResidualBlock is `stride`, so the old ResidualBlock(..., 3, 'down') set
        # stride='down' and left resample=None.
        self.resblock1 = ResidualBlock(self.dim, 2*self.dim, 3, resample='down')
        self.resblock2 = ResidualBlock(2*self.dim, 4*self.dim, 3, resample='down')
        self.resblock3 = ResidualBlock(4*self.dim, 8*self.dim, 3, resample='down')
        self.resblock4 = ResidualBlock(8*self.dim, 8*self.dim, 3, resample='down')

        side = img_size // 16  # four 2x downsamplings
        self.linear = nn.Linear(side*side*8*self.dim, 1)

    def forward(self, x, img_size=None):
        """Score a batch of images.

        Args:
            x: images laid out N x H x W x C (channels last).
            img_size: image side length; defaults to the value given at construction.
        Returns:
            Tensor of shape (N,) with one critic score per image.
        """
        if img_size is None:
            img_size = self.img_size
        # N H W C -> N C H W; reshape works whatever the memory layout after permute.
        x = x.permute(0, 3, 1, 2).reshape(-1, self.in_channels, img_size, img_size)

        # No activation after conv1: each residual block starts with norm + ReLU on its input.
        out = self.conv1(x)
        out = self.resblock1(out)
        out = self.resblock2(out)
        out = self.resblock3(out)
        out = self.resblock4(out)  # was resblock4(res4), i.e. res4 used before assignment

        side = img_size // 16
        out = out.reshape(-1, side*side*8*self.dim)  # was `res.view(...)`: undefined name
        return self.linear(out).view(-1)

In [12]:
class INN_Discriminator(nn.Module):
    """Critic used by WINN: a conv / mean-pool-conv stack with swish activations.

    Five "same" convolutions and three mean-pool convolutions take an
    img_size x img_size image to 8*dim channels at img_size/8. A final 2x2 mean pool
    reduces this to img_size/16, and a linear layer outputs one score per image.

    Args:
        in_channels: image channels.
        dim: base channel width (conv1 produces dim // 2 channels).
        batch_norm: if True, every layer after conv1 is normalised with the
            per-sample LayerNorm before its swish activation.
        img_size: side length of the square input; must be divisible by 16. The
            default (64) reproduces the previous fixed 4*4*8*dim linear layer.
    """

    def __init__(self, in_channels=3, dim=64, batch_norm=True, img_size=64):
        super(INN_Discriminator, self).__init__()
        self.in_channels = in_channels
        self.dim = dim
        self.batch_norm = batch_norm
        self.img_size = img_size

        # Integer channel counts: nn.Conv2d rejects the float produced by dim/2.
        self.conv1 = ConvPadding(self.in_channels, dim//2, 3, 1)
        self.swish = swish()

        # All norms use this notebook's LayerNorm (per-sample over C, H, W with a
        # per-channel affine). nn.LayerNorm(C) on an NCHW tensor normalises the last
        # (width) axis and fails unless W == C.
        self.conv2 = ConvPadding(dim//2, dim, 3, 1)
        self.layernorm1 = LayerNorm(dim)

        self.meanpoolconv1 = MeanPoolConv(dim, dim, 3, 1)
        self.layernorm2 = LayerNorm(dim)

        self.conv3 = ConvPadding(dim, dim*2, 3, 1)
        self.layernorm3 = LayerNorm(dim*2)

        self.meanpoolconv2 = MeanPoolConv(dim*2, dim*2, 3, 1)
        self.layernorm4 = LayerNorm(dim*2)

        self.conv4 = ConvPadding(dim*2, dim*4, 3, 1)
        self.layernorm5 = LayerNorm(dim*4)

        self.meanpoolconv3 = MeanPoolConv(dim*4, dim*4, 3, 1)
        self.layernorm6 = LayerNorm(dim*4)

        self.conv5 = ConvPadding(dim*4, dim*8, 3, 1)
        self.layernorm7 = LayerNorm(dim*8)
        self.swish8 = swish()  # not used by forward; kept so the attribute still exists

        side = img_size // 16
        self.linear = nn.Linear(side*side*8*dim, 1)

    def _norm_act(self, h, norm):
        """Apply `norm` (only when batch_norm is enabled), then swish."""
        if self.batch_norm:
            h = norm(h)
        return self.swish(h)

    def forward(self, x, img_size=None):
        """Score a batch of images.

        Args:
            x: images laid out N x H x W x C (channels last).
            img_size: image side length S; defaults to the value given at construction.
        Returns:
            Tensor of shape (N,) with one critic score per image.
        """
        if img_size is None:
            img_size = self.img_size
        # N H W C -> N C H W
        x = x.permute(0, 3, 1, 2).reshape(-1, self.in_channels, img_size, img_size)

        h = self.swish(self.conv1(x))                               # N, dim/2, S,   S
        h = self._norm_act(self.conv2(h), self.layernorm1)          # N, dim,   S,   S  (was conv2(x))
        h = self._norm_act(self.meanpoolconv1(h), self.layernorm2)  # N, dim,   S/2, S/2
        h = self._norm_act(self.conv3(h), self.layernorm3)          # N, 2dim,  S/2, S/2
        h = self._norm_act(self.meanpoolconv2(h), self.layernorm4)  # N, 2dim,  S/4, S/4
        h = self._norm_act(self.conv4(h), self.layernorm5)          # N, 4dim,  S/4, S/4
        h = self._norm_act(self.meanpoolconv3(h), self.layernorm6)  # N, 4dim,  S/8, S/8
        h = self._norm_act(self.conv5(h), self.layernorm7)          # N, 8dim,  S/8, S/8

        # Final 2x2 mean pool -> N, 8dim, S/16, S/16
        h = (h[:, :, ::2, ::2] + h[:, :, 1::2, ::2] + h[:, :, ::2, 1::2] + h[:, :, 1::2, 1::2]) / 4
        side = img_size // 16
        # Integer flatten size (img_size/16 is a float and made view() fail).
        out = h.reshape(-1, side*side*8*self.dim)
        return self.linear(out).view(-1)

In [13]:
# Noise samples: the initial pseudo-negatives
class Noise(nn.Module):
    """Decoder that turns its input into images used as initial pseudo-negatives.

    Four stages of ("same" 5x5 conv -> nearest-neighbour upsample to a fixed size ->
    LayerNorm) map an input with 8*dim channels to N x 3 x 64 x 64.
    NOTE(review): the input's spatial size is not fixed by this code (the first
    upsample resizes to 8x8 regardless) -- presumably a small random tensor; confirm
    with the caller.

    Args:
        n_samples: stored only; not used by forward.
        dim: base channel width.
    """

    def __init__(self, n_samples, dim = 64):
        super(Noise, self).__init__()

        self.n_samples = n_samples
        self.dim = dim
        # `ConvPaddings` (typo) was undefined. LayerNorm is this notebook's per-sample
        # normaliser; nn.LayerNorm(C) would normalise the width axis of NCHW tensors.
        self.conv1 = ConvPadding(8*self.dim, 4*self.dim, 5, 1)
        self.upsample1 = nn.Upsample(size=8, scale_factor=None, mode='nearest', align_corners=None)
        self.layernorm1 = LayerNorm(4*self.dim)

        self.conv2 = ConvPadding(4*self.dim, 2*self.dim, 5, 1)
        self.upsample2 = nn.Upsample(size=16, scale_factor=None, mode='nearest', align_corners=None)
        self.layernorm2 = LayerNorm(2*self.dim)

        self.conv3 = ConvPadding(2*self.dim, self.dim, 5, 1)
        self.upsample3 = nn.Upsample(size=32, scale_factor=None, mode='nearest', align_corners=None)
        self.layernorm3 = LayerNorm(self.dim)

        self.conv4 = ConvPadding(self.dim, 3, 5, 1)
        self.upsample4 = nn.Upsample(size=64, scale_factor=None, mode='nearest', align_corners=None)
        self.layernorm4 = LayerNorm(3)

    def forward(self, x):
        # Each stage consumes the previous stage's output (previously every conv was
        # applied to the raw input x, which fails on the channel count).
        h = self.layernorm1(self.upsample1(self.conv1(x)))  # N, 4*dim, 8, 8
        h = self.layernorm2(self.upsample2(self.conv2(h)))  # N, 2*dim, 16, 16
        h = self.layernorm3(self.upsample3(self.conv3(h)))  # N, dim, 32, 32
        h = self.layernorm4(self.upsample4(self.conv4(h)))  # N, 3, 64, 64
        return h
        

In [15]:
class WINN(nn.Module):
    """Wasserstein introspective neural network (WINN) trainer.

    Reads its configuration from the module-level `opts` object brought in by
    `from opts import *` (NOTE(review): assumed to be the parsed command-line options).
    It owns a critic `self.discriminator` and refines pseudo-negative images with
    Langevin dynamics on that critic. Once finished, it will train the critic with the
    WGAN-GP loss.

    Args:
        in_channels: image channels.
        dim: base channel width.
    """

    def __init__(self, in_channels=3, dim=64):
        super(WINN, self).__init__()
        # Apply the cifar override before any size is read. The old code read an
        # undefined local `img_size` (NameError) and only overrode opts.img_size afterwards.
        if opts.set == 'cifar':
            opts.img_size = 32
        self.batch_size = opts.batch_size
        self.in_channels = in_channels
        self.dim = dim
        self.img_size = opts.img_size
        self.num_chain = opts.nRow*opts.nCol  # each image in final result
        self.opts = opts

        if opts.with_noise:
            print("Langevin Dynamics with noise")
        else:
            print("Langevin Dynamics without noise")

        if opts.set == 'cifar':
            print("training on cifar with image size: %i" % (self.img_size))

    def load_Discriminator(self, file):
        """Load a saved discriminator module from `file` and put it in training mode."""
        # NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
        self.discriminator = torch.load(file).train()
        print('Loading Descriptor from ' + file + '...')

    def langevin_dynamics_discriminator(self, x):
        """Refine images by Langevin dynamics that ascend the critic's score.

        Runs ``opts.langevin_step_num_des`` updates of
            x <- x + 0.5 * s**2 * df(x)/dx   (+ s * N(0, I) when opts.with_noise)
        where f is ``self.discriminator`` and s is ``opts.langevin_step_size_dis``.

        Args:
            x: batch of images in the layout ``self.discriminator`` expects.
        Returns:
            Tensor with the same shape as `x`, detached from the autograd graph.
        """
        step_size = self.opts.langevin_step_size_dis
        for _ in range(self.opts.langevin_step_num_des):
            # Fresh leaf each step so the gradient is taken w.r.t. the current iterate only.
            x = x.detach().requires_grad_(True)
            # f(x;\theta) = ln(p(y=1|x,\theta) / p(y=0|x,\theta))
            x_feature = self.discriminator(x)
            # The gradient of sum(f) equals backward() with a ones tensor shaped like the
            # output. The old code passed ones of shape (num_chain, z_size), which does not
            # match an (N,) critic output. Its backward() also accumulated into the critic's
            # parameter .grad, whereas autograd.grad differentiates w.r.t. x only.
            grad, = torch.autograd.grad(x_feature.sum(), x)
            with torch.no_grad():
                x = x + 0.5 * step_size * step_size * grad
                if self.opts.with_noise:
                    # Noise shaped and placed like x (was a fixed NCHW CUDA tensor).
                    x = x + step_size * torch.randn_like(x)
        return x.detach()

    def _discriminator_loss(self, pos_images, neg_images, LAMBDA=10.0):
        """WGAN-GP critic loss for one batch of positive and pseudo-negative images.

        D_loss = mean(f(neg)) - mean(f(pos)) + LAMBDA * mean((||df(x_hat)/dx_hat||_2 - 1)**2)
        where x_hat is a random per-sample interpolation between pos and neg.

        Args:
            pos_images, neg_images: batches of equal shape in the critic's input layout.
            LAMBDA: gradient-penalty weight.
        Returns:
            (D_loss, D_pos_loss) with D_pos_loss = mean(f(pos)).
        """
        D_pos_logits = self.discriminator(pos_images)
        D_neg_logits = self.discriminator(neg_images)

        D_loss = torch.mean(D_neg_logits - D_pos_logits)
        D_pos_loss = torch.mean(D_pos_logits)

        # One interpolation coefficient per sample, broadcast over the remaining dims
        # (replaces the CPU-only Uniform sampler, the tiling hack and undefined half_b_size).
        eps_shape = (pos_images.size(0),) + (1,) * (pos_images.dim() - 1)
        epsilon = torch.rand(eps_shape, dtype=pos_images.dtype, device=pos_images.device)
        x_hat = (epsilon * pos_images.detach() + (1 - epsilon) * neg_images.detach()).requires_grad_(True)
        d_hat = self.discriminator(x_hat)  # (was the undefined self.descriptor)

        # PyTorch equivalent of tf.gradients; create_graph keeps the penalty differentiable.
        ddx, = torch.autograd.grad(d_hat.sum(), x_hat, create_graph=True)
        ddx = torch.sqrt(torch.sum(ddx ** 2, dim=tuple(range(1, ddx.dim()))))
        # WGAN-GP averages the penalty over the batch (the draft summed it).
        D_loss = D_loss + LAMBDA * torch.mean((ddx - 1.0) ** 2)
        return D_loss, D_pos_loss

    @staticmethod
    def _log(logfile, message):
        """Print `message` and append it, newline-terminated, to the open `logfile`."""
        print(message)
        logfile.write(message + '\n')
        logfile.flush()

    def train(self, discriminator_model=None, LAMBDA=10.0):
        """Prepare the critic, data, output directories and optimiser for training.

        This method shadows nn.Module.train(mode). nn.Module.eval() calls
        self.train(False), and a parent module's train(mode) passes a bool. A bool
        argument is therefore forwarded to nn.Module.train and only toggles training
        mode. Calling .train() with no arguments still starts training.

        Args:
            discriminator_model: path to a saved critic to resume from, or None to build
                a fresh INN_Discriminator on the GPU.
            LAMBDA: weight of the WGAN-GP gradient penalty.

        Raises:
            NotImplementedError: always, after setup -- the training loop is unfinished.
        """
        if isinstance(discriminator_model, bool):
            return super(WINN, self).train(discriminator_model)

        if discriminator_model is not None:
            # (was torch.load(file): `file` is undefined here)
            self.discriminator = torch.load(discriminator_model).train()
            print('Loading Discriminator from ' + discriminator_model + '...')
        else:
            # NOTE(review): built with INN_Discriminator's default configuration; confirm it
            # matches self.img_size (e.g. cifar uses 32x32 images).
            self.discriminator = INN_Discriminator().cuda().train()
            print('Loading Discriminator without initialization...')

        if self.opts.set == 'scene' or self.opts.set == 'cifar':
            train_data = DataSet(os.path.join(self.opts.data_path, self.opts.category),
                                 image_size=self.opts.img_size)
        else:
            train_data = torchvision.datasets.LSUN(root=self.opts.data_path,
                                                   classes=['bedroom_train'],
                                                   transform=transforms.Compose([transforms.Resize(self.img_size),
                                                   transforms.ToTensor(), ]))

        # (was the undefined global `batch_size`)
        num_batches = int(math.ceil(len(train_data) / self.batch_size))

        os.makedirs(self.opts.ckpt_dir, exist_ok=True)
        os.makedirs(self.opts.output_dir, exist_ok=True)
        # Root directories for intermediate and negative images.
        intermediate_image_root = os.path.join(self.opts.output_dir, "intermediate")
        os.makedirs(intermediate_image_root, exist_ok=True)
        neg_image_root = os.path.join(self.opts.output_dir, "negative")
        os.makedirs(neg_image_root, exist_ok=True)

        # `with` closes the log file even when training raises (it was never closed before).
        with open(self.opts.ckpt_dir + '/log', 'w+') as logfile:
            ######################################################################
            # Training stage 1: Load positive images.
            ######################################################################
            # (was log(log_file_path, ...): log_file_path is undefined)
            self._log(logfile, "Training stage 1: Load positive images...")

            dis_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.opts.lr_des,
                                             betas=[self.opts.beta1_des, 0.999])

            # TODO(review): the training loop was never written. The previous draft evaluated
            # the loss on undefined D_pos_images / D_neg_images / half_b_size. A loop over
            # num_batches batches of train_data should build pseudo-negatives (e.g. refined with
            # self.langevin_dynamics_discriminator), compute
            # D_loss, _ = self._discriminator_loss(pos, neg, LAMBDA) and step dis_optimizer;
            # images may be saved under intermediate_image_root / neg_image_root.
            raise NotImplementedError('WINN.train: the training loop is not implemented yet')
            
        
        
        

usage: ipykernel_launcher.py [-h] [-num_epoch NUM_EPOCH]
                             [-batch_size BATCH_SIZE] [-nRow NROW]
                             [-nCol NCOL] [-img_size IMG_SIZE]
                             [-test_size TEST_SIZE] [-test] [-score]
                             [-z_size Z_SIZE] [-category CATEGORY]
                             [-data_path DATA_PATH] [-output_dir OUTPUT_DIR]
                             [-log_dir LOG_DIR] [-ckpt_dir CKPT_DIR]
                             [-log_epoch LOG_EPOCH] [-set SET]
                             [-with_noise WITH_NOISE]
                             [-incep_interval INCEP_INTERVAL]
                             [-ckpt_des CKPT_DES] [-sigma_gen SIGMA_GEN]
                             [-langevin_step_num_gen LANGEVIN_STEP_NUM_GEN]
                             [-langevin_step_size_gen LANGEVIN_STEP_SIZE_GEN]
                             [-lr_gen LR_GEN] [-beta1_gen BETA1_GEN]
                             [-ckpt_gen CKPT_GEN] [-sigm

SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
