In [None]:

%matplotlib inline

import os
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import random_split
from torch.autograd import Variable
from torchvision import transforms
import warnings
warnings.filterwarnings('ignore')

from dataset import DatasetFromFolderPy
from NDMnet import Generator



In [None]:
# Output directory for checkpoints, saved index files and figures.
DIR_RESULTS = 'results/'
# exist_ok makes the cell safe to re-run; makedirs also creates missing parents.
os.makedirs(DIR_RESULTS, exist_ok=True)

N_sources = 171  # number of sources in the dataset
DIR_TRAINING_DATASET_X = 'Data/Model5M/mar5z_5/'  # X directory (first dataset argument)
DIR_TRAINING_DATASET_Y = 'Data/Model1M/mar1z_1/'  # Y directory (second dataset argument)

# Training hyper-parameters.
lrG = 0.0002     # Adam learning rate for the generator
ngf = 64         # width argument passed as Generator(2, ngf, 2)
beta1 = 0.5      # Adam beta1
beta2 = 0.999    # Adam beta2
num_epochs = 750
batch_size = 5

In [None]:
# Select the training device. The setup targets the second GPU ("cuda:1");
# fall back to the first GPU when only one is present, and to the CPU when
# CUDA is unavailable (torch.cuda.set_device rejects non-CUDA devices).
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{gpu_index}")
    # Make plain `.cuda()` calls elsewhere in the notebook target this GPU.
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Create the training dataset from an evenly spaced ("equidistant") subset of
# source indices, saved to DIR_RESULTS + 'jdx.npy'.
# NOTE(review): presumably DatasetFromFolderPy reads this file from DIR_RESULTS
# to restrict which sources it loads -- confirm in dataset.py.

percent = 60  # share of the N_sources sources to select, in percent
# The previous stride form, range(0, N_sources, int(100 / percent)), truncates the
# stride to 1 for every percent above 50 and therefore silently selected *all*
# sources. Spread the requested number of indices evenly over [0, N_sources - 1]
# instead (the exact indices may differ slightly from a plain integer stride).
n_selected = max(1, round(N_sources * percent / 100))
train_idx = np.unique(
    np.linspace(0, N_sources - 1, num=n_selected).round().astype(int)
)
np.save(DIR_RESULTS + 'jdx', train_idx)
train_idx_dir = DIR_RESULTS + 'jdx.npy'

# Input and target share the same pipeline: ToTensor, then
# Normalize(mean=0, std=alpha), i.e. x -> (x - 0) / alpha.
alpha = 10e-9  # normalization coefficient. NOTE(review): 10e-9 == 1e-8; confirm 1e-9 was not intended.


def _make_transform():
    """Return a fresh ToTensor + Normalize(mean=0, std=alpha) pipeline."""
    return transforms.Compose([
        transforms.ToTensor(),
        # Normalize documents mean/std as sequences; 1-tuples broadcast over all channels.
        transforms.Normalize((0.0,), (alpha,)),
    ])


transTorch = _make_transform()
trans_target = _make_transform()

train_data = DatasetFromFolderPy(DIR_TRAINING_DATASET_X, DIR_TRAINING_DATASET_Y, DIR_RESULTS,
                                 transform=transTorch, transform_target=trans_target,
                                 direction='AtoB')

# Hold out 10% for validation; the split is seeded so it is identical on every run.
val_percent = 0.1
n_val = int(len(train_data) * val_percent)
n_train = len(train_data) - n_val
train_set, val_set = random_split(train_data, [n_train, n_val],
                                  generator=torch.Generator().manual_seed(0))

train_data_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps the mean validation loss reproducible.
val_data_loader = torch.utils.data.DataLoader(dataset=val_set, batch_size=batch_size, shuffle=False)


In [None]:

# Model, losses and optimizer.
G = Generator(2, ngf, 2)
# `.to(device)` places the model on the device selected earlier
# (same as `.cuda()` once torch.cuda.set_device(device) has run).
G.to(device)
G.normal_weight_init(mean=0.0, std=0.02)
BCE_loss = torch.nn.BCELoss().to(device)  # NOTE(review): not used in the cells shown
L1_loss = torch.nn.L1Loss().to(device)    # training / validation objective
L2_loss = torch.nn.MSELoss().to(device)   # NOTE(review): not used in the cells shown

G_optimizer = torch.optim.Adam(G.parameters(), lr=lrG, betas=(beta1, beta2))
# Halve the learning rate once the metric passed to step() has not improved for more than 10 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(G_optimizer, factor=0.5, patience=10, verbose=True)

load = False  # set True to resume from a saved checkpoint
if load:
    # map_location lets a checkpoint saved on another device load onto `device`.
    G.load_state_dict(torch.load(DIR_RESULTS + "G_740.pkl", map_location=device))


def _run_epoch(model, loader, loss_fn, device, optimizer=None):
    """Run `model` over every batch of `loader` and return the mean batch loss.

    If `optimizer` is given, the model is put in train mode and updated with one
    optimizer step per batch. Otherwise it is put in eval mode and run under
    disabled autograd, so no graph is built and no gradients are touched.
    Returns NaN when `loader` yields no batches.
    """
    training = optimizer is not None
    model.train(training)
    batch_losses = []
    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            if training:
                optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            batch_losses.append(loss.item())
            if training:
                loss.backward()
                optimizer.step()
    if not batch_losses:
        return float('nan')
    return sum(batch_losses) / len(batch_losses)


TrainingLoss = []
ValidationLoss = []
for epoch in range(num_epochs):
    mean_loss = _run_epoch(G, train_data_loader, L1_loss, device, optimizer=G_optimizer)
    # NOTE(review): the LR schedule follows the *training* loss; it is more often driven by validation loss.
    scheduler.step(mean_loss)
    TrainingLoss.append(mean_loss)

    # Validation: no backward pass, no gradient tracking.
    mean_loss_val = _run_epoch(G, val_data_loader, L1_loss, device)
    ValidationLoss.append(mean_loss_val)

    if epoch != 0 and epoch % 20 == 0:
        torch.save(G.state_dict(), DIR_RESULTS + 'G_' + str(epoch) + '.pkl')
        print(epoch, " save.")

# Final checkpoint, named after the last completed epoch (skipped if no epoch ran).
if num_epochs > 0:
    torch.save(G.state_dict(), DIR_RESULTS + 'G_' + str(epoch) + '.pkl')

In [None]:
# Training vs. validation loss per epoch, saved as EPS and PNG.
# The x-axis follows the recorded loss lists, not a hard-coded epoch count,
# so the plot stays valid if num_epochs changes or training was interrupted.
train_epochs = np.arange(len(TrainingLoss))
val_epochs = np.arange(len(ValidationLoss))

fig, ax = plt.subplots(figsize=(25, 10))
ax.plot(train_epochs, TrainingLoss, label='Training loss')
ax.plot(val_epochs, ValidationLoss, label='Validation loss')
ax.legend(fontsize="25")
ax.grid(True)
ax.set_ylabel('Loss', size=25)
ax.set_xlabel('Epochs', size=25)
fig.savefig(DIR_RESULTS + 'TrValLoss.eps', format='eps')
fig.savefig(DIR_RESULTS + 'TrValLoss.png')
