# In [None]:
import os, time
from datetime import datetime

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import model_torch as mm
import supportingFunctions as sf

# cuDNN settings: keep cuDNN enabled and switch on its benchmark mode.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Set these parameters carefully
nLayers = 5  # forwarded to mm.makeModel; also part of the save-directory name
epochs = 50  # multiplied by the number of batches to get the total step count
batchSize = 1  # batch size handed to the training DataLoader
gradientMethod = 'AG'  # forwarded to mm.makeModel; also part of the save-directory name
K = 1  # forwarded to mm.makeModel; also selects the 'dc' + str(K) entry of its output
sigma = 0.01  # NOTE(review): not referenced anywhere in the visible part of this script
restoreWeights = False

# To train the model with higher K values (K > 1), such as K = 5 or 10, it is better
# to initialize with a pre-trained model with K = 1.
# NOTE(review): restoreWeights and restoreFromModel are only assigned here; nothing in
# the visible part of the script reads them, so no checkpoint is actually loaded.
# restoreFromModel is left undefined when K == 1.
if K > 1:
    restoreWeights = True
    restoreFromModel = '04Jun_0243pm_5L_1K_100E_AG'

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#%% Generate a meaningful filename to save the trained models for testing
print('*************************************************')
start_time = time.time()
saveDir = 'savedModels/'
# The directory name encodes the start time and the main hyper-parameters, e.g.
# '04Jun_0243pm_5L_1K_100E_AG'.
# "%P" (lower-case am/pm) is a glibc-only strftime extension: it is rejected on
# Windows and not supported on BSD/macOS. Build the same text from the portable
# "%p" directive instead; only the am/pm marker is lower-cased so the month
# abbreviation keeps its capitalisation.
_startedAt = datetime.now()
_timeTag = _startedAt.strftime("%d%b_%I%M") + _startedAt.strftime("%p").lower() + '_'
directory = saveDir + _timeTag + f'{nLayers}L_{K}K_{epochs}E_{gradientMethod}'

# exist_ok avoids the check-then-create race of os.path.exists() + os.makedirs().
os.makedirs(directory, exist_ok=True)
sessFileName = directory + '/model'

#%% Save test model
# Random stand-in inputs used to run the model once in non-training mode
# (fourth makeModel argument is False). The three tensors are drawn in the same
# order as before so the global RNG stream is consumed identically.
csmT = torch.randn(1, 12, 256, 232, dtype=torch.complex64, device=device)
maskT = torch.randn(1, 256, 232, dtype=torch.complex64, device=device)
atbT = torch.randn(1, 256, 232, 2, dtype=torch.float32, device=device)

out = mm.makeModel(atbT, csmT, maskT, False, nLayers, K, gradientMethod)
# Keep a gradient-free copy of the 'dc<K>' entry of the model output.
predTst = out[f'dc{K}'].clone().detach().requires_grad_(False)
sessFileNameTst = f'{directory}/modelTst'

# torch.save returns None; the binding is kept only so the name stays defined.
saver = torch.save(predTst, sessFileNameTst)
print('Testing model saved')

#%% Read multi-channel dataset
trnOrg, trnAtb, trnCsm, trnMask = sf.getData('training')
trnOrg, trnAtb = sf.c2r(trnOrg), sf.c2r(trnAtb)

#%% Wrap the loaded arrays as tensors
# These names used to be all-zero placeholder tensors, and the dataset below was
# built from them, so the data returned by sf.getData never reached the loader.
# They now hold the loaded training data, with the dtypes the placeholders declared.
# NOTE(review): assumes sf.getData / sf.c2r return array-likes accepted by
# torch.as_tensor, with the sample index on axis 0 — confirm against
# supportingFunctions.
# The tensors are not moved to `device` here; batches can be moved to the device
# when they are consumed, which avoids holding the whole set in device memory.
csmP = torch.as_tensor(trnCsm, dtype=torch.complex64)
maskP = torch.as_tensor(trnMask, dtype=torch.complex64)
atbP = torch.as_tensor(trnAtb, dtype=torch.float32)
orgP = torch.as_tensor(trnOrg, dtype=torch.float32)

#%% Creating the dataset
nTrn = trnOrg.shape[0]
nBatch = int(nTrn // batchSize)  # whole batches per pass over the data
nSteps = nBatch * epochs

# Each item is an (org, atb, csm, mask) tuple. drop_last=True makes one pass of the
# loader yield exactly nBatch batches, matching the floor division used for nSteps.
trnData = torch.utils.data.TensorDataset(orgP, atbP, csmP, maskP)
trnData = torch.utils.data.DataLoader(trnData, batch_size=batchSize, shuffle=True,
                                      drop_last=True)
iterator = iter(trnData)

#%% Make training model
out = mm.makeModel(atbT, csmT, maskT, True, nLayers, K, gradientMethod)
predT = out['dc' + str(K)]
# NOTE(review): predT is a detached copy of the model output, so the optimizer below
# only updates this tensor; no gradient flows back into anything created inside
# mm.makeModel. Training real model weights would require passing the model's
# parameters to the optimizer and re-running the forward pass per batch, which
# depends on the model_torch API that is not visible here.
predT = predT.clone().detach().requires_grad_(True)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam([predT])


def _timestamp():
    """Return the current time as e.g. '04-Jun-2024 02:43 pm'.

    Uses the portable "%p" directive and lower-cases it, instead of the
    glibc-only "%P" directive, which is not available on every platform.
    """
    now = datetime.now()
    return now.strftime("%d-%b-%Y %I:%M ") + now.strftime("%p").lower()


print('training started at', _timestamp())
print('parameters are: Epochs:', epochs, ' BS:', batchSize, 'nSteps:', nSteps, 'nSamples:', nTrn)

# NOTE(review): model_params holds the values of the dict returned by mm.makeModel;
# whether those are weights or only outputs depends on model_torch — confirm.
model_params = list(out.values())
torch.save(model_params, sessFileName)
totalLoss, ep = [], 0
writer = SummaryWriter(directory)
try:
    for step in tqdm(range(nSteps)):
        # Draw the next batch; one pass of the loader is shorter than nSteps, so
        # start a new pass whenever the current one is exhausted.
        try:
            batch = next(iterator)
        except StopIteration:
            iterator = iter(trnData)
            batch = next(iterator)
        # The dataset yields (org, atb, csm, mask) tuples; the first entry is used
        # as the regression target. The loop previously referenced an undefined
        # name here and a bare `except` hid the resulting NameError, ending
        # training at the first step.
        # NOTE(review): assumes the batch shape matches predT — confirm.
        target = batch[0].to(device)

        optimizer.zero_grad()
        loss = loss_fn(predT, target)
        loss.backward()
        optimizer.step()
        totalLoss.append(loss.item())
        if (step + 1) % nBatch == 0:
            ep = ep + 1
            avgTrnLoss = np.mean(totalLoss)
            writer.add_scalar("TrnLoss", avgTrnLoss, ep)
            totalLoss = []  # after each epoch empty the list of total loss
finally:
    # Always flush and close the event file; genuine errors still propagate.
    writer.close()
torch.save(model_params, sessFileName)

end_time = time.time()
print('Training completed in minutes ', ((end_time - start_time) / 60))
print('training completed at', _timestamp())
print('*************************************************')

#%%