In [1]:
from pathlib import Path
import os
import sys
import shutil
import numpy as np
import time
import datetime
import torch
import torch.nn as nn
import torchvision
import cfg
from datetime import datetime

from torchvision.utils import save_image, make_grid
# from tensorboardX import SummaryWriter

from types import SimpleNamespace
import json

from models.ResNet import Res12_Quadratic, Res18_Quadratic, Res34_Quadratic

In [2]:
args = SimpleNamespace()
args.n_iter = 25000
args.num_sub = 50 # number of subjects per contrast
args.res = 64
args.classes = [0]
args.batch_size = 128
args.in_chan = 1
args.n_chan = 128
args.n_gpus = 1
args.max_lr = 1e-5
args.min_noise = 0.1
args.max_noise = 3.0
args.noise_distribution = 'exp'
args.save_every = 1000
args.dataset = 'fastmri'
args.log = 'fastmri_EBM'

# resume from previous checkpoint
args.cont = False 
# timestamp to resume training
# args.time = '2024_May15_08_38'
args.time = ''
args.net_indx = 100000

args.lr_schedule = 'cosine'
args.rand_seed = 42
args.file_name = 'mdsm_ebm'
args.net = 'res34'

In [3]:
torch.cuda.empty_cache()
torch.cuda.manual_seed(args.rand_seed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# DATA LOADER
from data.fastmri_brain import inf_train_gen, inf_train_gen_downsampled
# itr = inf_train_gen(args.batch_size, num_sub=args.num_sub)
itr = inf_train_gen_downsampled(args.batch_size, num_sub=args.num_sub, device=device, res=args.res, complex_in=False, classes=args.classes)

if args.net == 'res18':
    netE = Res18_Quadratic(args.in_chan,args.n_chan,32,normalize=False,AF=nn.ELU())
elif args.net == 'res34':
    netE = Res34_Quadratic(args.in_chan,args.n_chan,32,normalize=False,AF=nn.ELU())

netE = netE.to(device)

In [4]:
x_real = itr.__next__()

FLAIR_big_pickle - loading 50 of 108 subjects
min/max: tensor(0., dtype=torch.float64) / tensor(1.0000, dtype=torch.float64)
ksp: torch.Size([10, 12, 320, 320]) 	csm: torch.Size([10, 12, 320, 320])
min/max: tensor(0., dtype=torch.float64) / tensor(1.0000, dtype=torch.float64)
ksp: torch.Size([10, 12, 320, 320]) 	csm: torch.Size([10, 12, 320, 320])
min/max: tensor(0., dtype=torch.float64) / tensor(1., dtype=torch.float64)
ksp: torch.Size([10, 12, 320, 320]) 	csm: torch.Size([10, 12, 320, 320])
min/max: tensor(0., dtype=torch.float64) / tensor(1., dtype=torch.float64)
ksp: torch.Size([10, 12, 320, 320]) 	csm: torch.Size([10, 12, 320, 320])
min/max: tensor(0., dtype=torch.float64) / tensor(1., dtype=torch.float64)
ksp: torch.Size([10, 12, 320, 320]) 	csm: torch.Size([10, 12, 320, 320])
min/max: tensor(0., dtype=torch.float64) / tensor(1., dtype=torch.float64)
ksp: torch.Size([10, 12, 320, 320]) 	csm: torch.Size([10, 12, 320, 320])
min/max: tensor(0., dtype=torch.float64) / tensor(1., dtyp

In [5]:
x_real.shape

torch.Size([128, 1, 64, 64])

In [6]:
now = datetime.now()
timestamp = now.strftime('%Y_%b%d_%H_%M')
print(str(args.cont))
# print(str(args.time))
if args.cont==True:
    root = 'logs/' + args.log + '_'+ args.time #compose string for loading
    #load network
    file_name = args.file_name + str(args.net_indx) + '.pt'
    print("Preloading from ",root+ '/models/' +file_name)
    netE.load_state_dict(torch.load(root + '/models/' +file_name))
    print("Preloading successful ")
else: # start new will create logging folder
    root = 'logs/'+ args.log + '_' + timestamp #add timestemp
    #over write if folder already exist, not likely to happen as timestamp is used
    if os.path.isdir(root):
        shutil.rmtree(root)
    os.makedirs(root)
    os.makedirs(root+'/models')
    os.makedirs(root+'/samples')

log_fname = root + "/log.txt"
def print_log(string, mute=False):
    with open(log_fname, "a", encoding='utf8') as log: # append mode 
        log.write(string+"\n")
    if not mute: print(string)
    
settings_fname = root + f"/settings_{timestamp}.json"
with open(settings_fname, "w", encoding='utf8') as sf:
    json.dump(args.__dict__, sf, indent=4) 

True
Preloading from  logs/fastmri_EBM_2024_May29_16_13/models/mdsm_ebm100000.pt
Preloading successful 


In [None]:
# writer = SummaryWriter(root)

# setup optimizer and lr scheduler
params = {'lr':args.max_lr,'betas':(0.9,0.95)}
optimizerE = torch.optim.Adam(netE.parameters(),**params)
if args.lr_schedule == 'exp':
    scheduler = torch.optim.lr_scheduler.StepLR(optimizerE,int(args.n_iter/6))

elif args.lr_schedule == 'cosine':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerE,args.n_iter,eta_min=1e-6,last_epoch=-1)

elif args.lr_schedule == 'const':
    scheduler = torch.optim.lr_scheduler.StepLR(optimizerE,int(args.n_iter))

#train
print_interval = args.save_every // 4
max_iter = args.n_iter+args.net_indx
batchSize = args.batch_size
sigma0 = 0.1
sigma02 = sigma0**2

if args.noise_distribution == 'exp':
    sigmas_np = np.logspace(np.log10(args.min_noise),np.log10(args.max_noise),batchSize)
elif args.noise_distribution == 'lin':
    sigmas_np = np.linspace(args.min_noise,args.max_noise,batchSize)

sigmas = torch.Tensor(sigmas_np).view((batchSize,1,1,1)).to(device)

start_time = time.time()

print_log(f'Training started -- {root}')
for i in range(args.net_indx,args.net_indx + args.n_iter):
    x_real = itr.__next__().to(device)
    x_noisy = x_real + sigmas*torch.randn_like(x_real)

    x_noisy = x_noisy.requires_grad_()
    E = netE(x_noisy).sum()
    grad_x = torch.autograd.grad(E,x_noisy,create_graph=True)[0]
    x_noisy.detach()

    optimizerE.zero_grad()

    LS_loss = ((((x_real-x_noisy)/sigmas/sigma02+grad_x/sigmas)**2)/batchSize).sum()

    LS_loss.backward()
    optimizerE.step()
    scheduler.step()

    if (i+1)%print_interval == 0:
        time_spent = time.time() - start_time
        start_time = time.time()
        netE.eval()
        E_real = netE(x_real).mean()
        E_noise = netE(torch.rand_like(x_real)).mean()
        netE.train()

        print_log('Iteration {}/{} ({:.0f}%), E_real {:e}, E_noise {:e}, Normalized Loss {:e}, time {:4.1f}'.format(i+1,max_iter,100*((i+1)/max_iter),E_real.item(),E_noise.item(),(sigma02**2)*(LS_loss.item()),time_spent))

        # writer.add_scalar('E_real',E_real.item(),i+1)
        # writer.add_scalar('E_noise',E_noise.item(),i+1)
        # writer.add_scalar('loss',(sigma02**2)*LS_loss.item(),i+1)
        del E_real, E_noise, x_real, x_noisy

    if (i+1)%args.save_every == 0:
        print_log("-"*50)
        file_name = args.file_name+str(i+1)+'.pt'
        torch.save(netE.state_dict(),root+'/models/'+file_name)

Training started -- logs/fastmri_EBM_2024_May29_16_13
Iteration 100250/125000 (80%), E_real -8.387317e+04, E_noise -7.389418e+04, Normalized Loss 8.024816e+00, time 308.8
Iteration 100500/125000 (80%), E_real -8.281145e+04, E_noise -7.282845e+04, Normalized Loss 7.606600e+00, time 308.7
Iteration 100750/125000 (81%), E_real -8.196445e+04, E_noise -7.203902e+04, Normalized Loss 8.008731e+00, time 308.7
Iteration 101000/125000 (81%), E_real -8.116862e+04, E_noise -7.104477e+04, Normalized Loss 7.690745e+00, time 308.7
--------------------------------------------------
Iteration 101250/125000 (81%), E_real -8.069720e+04, E_noise -7.068544e+04, Normalized Loss 8.024630e+00, time 313.6
Iteration 101500/125000 (81%), E_real -8.008139e+04, E_noise -7.016610e+04, Normalized Loss 7.410288e+00, time 308.7
Iteration 101750/125000 (81%), E_real -7.981743e+04, E_noise -6.987573e+04, Normalized Loss 8.181366e+00, time 308.7
Iteration 102000/125000 (82%), E_real -7.916834e+04, E_noise -6.924633e+04, 