In [1]:
from pathlib import Path
import os
import sys
import shutil
import numpy as np
import time
import datetime
import torch
import torch.nn as nn
import torchvision
import cfg
from datetime import datetime

from torchvision.utils import save_image, make_grid
# from tensorboardX import SummaryWriter

from types import SimpleNamespace

from models.ResNet import Res12_Quadratic, Res18_Quadratic, Res34_Quadratic

In [2]:
# Run configuration, built in one call; every later cell reads its settings from `args`.
args = SimpleNamespace(
    n_iter=1000,                  # number of training iterations performed by this run
    batch_size=256,               # samples per batch; also the number of noise levels
    n_chan=64,                    # channel-width argument passed to the network constructor
    n_gpus=1,
    max_lr=1e-4,                  # learning rate handed to Adam
    min_noise=0.1,                # smallest noise sigma
    max_noise=3,                  # largest noise sigma
    noise_distribution='exp',     # 'exp' -> log-spaced sigmas, 'lin' -> linearly spaced
    save_every=500,               # checkpoint interval (iterations); logging runs 4x as often
    dataset='fmnist',             # one of 'cifar', 'mnist', 'fmnist', 'celeba'
    cont=True,                    # True -> resume from an existing checkpoint
    log='fmnist_EBM',             # prefix of the log directory under logs/
    time='2024_Apr28_12_10',      # timestamp to resume training
    lr_schedule='cosine',         # 'exp', 'cosine' or 'const'
    rand_seed=42,
    net_indx=1000,                # iteration index of the checkpoint to load / start from
    file_name='mdsm_ebm',         # checkpoint file-name prefix
)

# args = SimpleNamespace()
# args.n_iter = 10000
# args.batch_size = 128
# args.n_chan = 128
# args.n_gpus = 1
# args.max_lr = 5e-5
# args.min_noise = 0.05
# args.max_noise = 1.2
# args.noise_distribution = 'lin'
# args.save_every = 5000
# args.dataset = 'celeba'
# args.log = 'celeba_EBM'
# args.lr_schedule = 'cosine'
# args.rand_seed = 42
# args.cont = False
# args.net_indx = 0
# args.file_name = 'mdsm_ebm'

In [3]:
# Release unused cached CUDA memory left over from earlier runs in this kernel.
torch.cuda.empty_cache()
# Seed the default generators on all devices (CPU and CUDA). The network below
# is constructed on the CPU, so seeding only the CUDA generator would leave the
# weight initialisation unseeded.
torch.manual_seed(args.rand_seed)

# Pick the infinite batch iterator (`itr`) and the energy network (`netE`)
# matching the requested dataset. The data modules are imported lazily so that
# only the selected dataset's loader is needed.
if args.dataset == 'cifar':
    from data.cifar import inf_train_gen
    itr = inf_train_gen(args.batch_size, flip=False)
    netE = Res18_Quadratic(3, args.n_chan, 32, normalize=False, AF=nn.ELU())

elif args.dataset == 'mnist':
    from data.mnist_32 import inf_train_gen
    itr = inf_train_gen(args.batch_size)
    netE = Res12_Quadratic(1, args.n_chan, 32, normalize=False, AF=nn.ELU())

elif args.dataset == 'fmnist':
    from data.fashion_mnist_32 import inf_train_gen
    itr = inf_train_gen(args.batch_size)
    netE = Res12_Quadratic(1, args.n_chan, 32, normalize=False, AF=nn.ELU())

elif args.dataset == 'celeba':
    from data.celeba import inf_train_gen
    itr = inf_train_gen(args.batch_size)
    netE = Res18_Quadratic(3, args.n_chan, 32, normalize=False, AF=nn.ELU())

else:
    # Fail loudly: without `raise` the exception object was created and dropped,
    # leaving `itr`/`netE` undefined (or stale from a previous kernel run).
    raise NotImplementedError('{} unknown dataset'.format(args.dataset))

In [4]:
# Run on the first CUDA device when one is available, otherwise on the CPU,
# and move the energy network there.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
netE = netE.to(device)

In [5]:
# Timestamp used to name a fresh run's log directory (same format as args.time).
now = datetime.now()
timestamp = now.strftime('%Y_%b%d_%H_%M')
print(str(args.cont))
if args.cont:
    # Resume: reuse the log directory of the run identified by args.time.
    root = 'logs/' + args.log + '_' + args.time
    # Load the network weights saved at iteration args.net_indx.
    file_name = args.file_name + str(args.net_indx) + '.pt'
    print("Preloading from ",root+ '/models/' +file_name)
    # map_location makes a checkpoint written on a GPU loadable when this
    # session fell back to the CPU (and vice versa).
    netE.load_state_dict(torch.load(root + '/models/' + file_name, map_location=device))
    print("Preloading successful ")
else:
    # Fresh run: create a new, timestamped logging folder.
    root = 'logs/' + args.log + '_' + timestamp
    # Overwrite the folder if it already exists; unlikely, since the name
    # carries a minute-resolution timestamp.
    if os.path.isdir(root):
        shutil.rmtree(root)
    os.makedirs(root)
    os.makedirs(root+'/models')
    os.makedirs(root+'/samples')

# Text log of the run; print_log appends to this file.
log_fname = root + "/log.txt"
def print_log(string, mute=False):
    """Append `string` as one line to the run's log file and optionally echo it.

    Args:
        string: Text to record; a newline is appended before writing.
        mute: When true, write to the log file only and skip the console echo.
    """
    line = string + "\n"
    # Append mode, so repeated calls (and resumed runs) extend the same file.
    with open(log_fname, mode="a", encoding='utf8') as log_file:
        log_file.write(line)
    if mute:
        return
    print(string)

True
Preloading from  logs/fmnist_EBM_2024_Apr28_12_10/models/mdsm_ebm1000.pt
Preloading successful 


In [6]:
# writer = SummaryWriter(root)

# setup optimizer and lr scheduler
# NOTE: a fresh optimizer and scheduler are created here even when resuming;
# the checkpoint load earlier in the notebook restores only the network weights.
params = {'lr': args.max_lr, 'betas': (0.9, 0.95)}
optimizerE = torch.optim.Adam(netE.parameters(), **params)
if args.lr_schedule == 'exp':
    # Step decay (StepLR's default factor) every sixth of the run; the step
    # size is clamped to 1 because StepLR cannot work with a step size of 0.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizerE, max(1, int(args.n_iter/6)))

elif args.lr_schedule == 'cosine':
    # Cosine annealing from max_lr down to 1e-6 over the whole run.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerE, args.n_iter, eta_min=1e-6, last_epoch=-1)

elif args.lr_schedule == 'const':
    # The first decay step only falls due after n_iter scheduler steps, so the
    # learning rate stays at max_lr for every optimizer update of the run.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizerE, max(1, int(args.n_iter)))

else:
    # Without this branch `scheduler` would be undefined (or stale from an
    # earlier kernel run) and the failure would surface inside the training loop.
    raise ValueError('{} unknown lr_schedule'.format(args.lr_schedule))

#train
# Log four times per checkpoint interval; clamp to 1 so a save_every below 4
# cannot produce a modulo-by-zero in the training loop.
print_interval = max(1, args.save_every // 4)
# Final iteration index, counting the iterations already done before a resume.
max_iter = args.n_iter+args.net_indx
batchSize = args.batch_size
# Reference noise scale used in the loss and in the logged "Normalized Loss".
sigma0 = 0.1
sigma02 = sigma0**2

# One noise level per sample position in the batch, spanning
# [min_noise, max_noise] either geometrically ('exp') or linearly ('lin').
if args.noise_distribution == 'exp':
    sigmas_np = np.logspace(np.log10(args.min_noise), np.log10(args.max_noise), batchSize)
elif args.noise_distribution == 'lin':
    sigmas_np = np.linspace(args.min_noise, args.max_noise, batchSize)
else:
    # Avoid silently reusing a stale `sigmas_np` (or hitting a NameError below).
    raise ValueError('{} unknown noise_distribution'.format(args.noise_distribution))

# Shape (batchSize, 1, 1, 1) so each sample's sigma broadcasts over its
# channel and spatial dimensions. torch.tensor replaces the legacy
# torch.Tensor constructor; the dtype matches what that constructor produced.
sigmas = torch.tensor(sigmas_np, dtype=torch.get_default_dtype()).view((batchSize,1,1,1)).to(device)

start_time = time.time()

print_log('Training started')
# Iteration indices continue from the loaded checkpoint (args.net_indx) so that
# log lines and checkpoint names stay consistent across resumed runs.
for i in range(args.net_indx, max_iter):
    x_real = next(itr).to(device)
    # Corrupt each sample with Gaussian noise at its own sigma.
    # assumes every batch holds exactly batchSize samples so `sigmas` broadcasts -- TODO confirm
    x_noisy = x_real + sigmas*torch.randn_like(x_real)

    # Gradient of the summed energy w.r.t. the noisy input; create_graph=True
    # keeps it differentiable so the loss can backpropagate into the weights.
    x_noisy = x_noisy.requires_grad_()
    E = netE(x_noisy).sum()
    grad_x = torch.autograd.grad(E,x_noisy,create_graph=True)[0]
    # (The former bare `x_noisy.detach()` here discarded its result and had no
    # effect, so it was removed.)

    optimizerE.zero_grad()

    # Squared residual between the scaled denoising direction and the energy
    # gradient, summed over all elements and divided by the batch size.
    LS_loss = ((((x_real-x_noisy)/sigmas/sigma02+grad_x/sigmas)**2)/batchSize).sum()

    LS_loss.backward()
    optimizerE.step()
    scheduler.step()

    if (i+1)%print_interval == 0:
        time_spent = time.time() - start_time
        start_time = time.time()
        netE.eval()
        # Monitoring values only: skip autograd bookkeeping for these passes.
        with torch.no_grad():
            E_real = netE(x_real).mean()
            # Mean energy of uniform-noise inputs, logged next to E_real.
            E_noise = netE(torch.rand_like(x_real)).mean()
        netE.train()

        print_log('Iteration {}/{} ({:.0f}%), E_real {:e}, E_noise {:e}, Normalized Loss {:e}, time {:4.1f}'.format(i+1,max_iter,100*((i+1)/max_iter),E_real.item(),E_noise.item(),(sigma02**2)*(LS_loss.item()),time_spent))

        # writer.add_scalar('E_real',E_real.item(),i+1)
        # writer.add_scalar('E_noise',E_noise.item(),i+1)
        # writer.add_scalar('loss',(sigma02**2)*LS_loss.item(),i+1)
        del E_real, E_noise, x_real, x_noisy

    if (i+1)%args.save_every == 0:
        # Separator line in the log, then checkpoint the weights under the
        # 1-based iteration count.
        print_log("-"*50)
        file_name = args.file_name+str(i+1)+'.pt'
        torch.save(netE.state_dict(),root+'/models/'+file_name)

Training started
Iteration 1125/2000 (56%), E_real -3.851346e+05, E_noise -3.820644e+05, Normalized Loss 3.239212e+02, time 20.8
Iteration 1250/2000 (62%), E_real -7.013970e+05, E_noise -6.975259e+05, Normalized Loss 1.382604e+02, time 19.2
Iteration 1375/2000 (69%), E_real -7.155488e+05, E_noise -7.107834e+05, Normalized Loss 1.162074e+02, time 19.0
Iteration 1500/2000 (75%), E_real -7.408106e+05, E_noise -7.355999e+05, Normalized Loss 1.071651e+02, time 19.3
--------------------------------------------------
Iteration 1625/2000 (81%), E_real -7.759752e+05, E_noise -7.707541e+05, Normalized Loss 1.016989e+02, time 19.2
Iteration 1750/2000 (88%), E_real -8.069706e+05, E_noise -8.018905e+05, Normalized Loss 9.919466e+01, time 19.3
Iteration 1875/2000 (94%), E_real -8.315058e+05, E_noise -8.260849e+05, Normalized Loss 9.319047e+01, time 19.0
Iteration 2000/2000 (100%), E_real -8.403926e+05, E_noise -8.349966e+05, Normalized Loss 9.585524e+01, time 19.3
-----------------------------------