In [1]:
import os
import sys

# NOTE(review): hard-coded, machine-specific absolute path; consider reading it
# from an environment variable or using a path relative to the notebook.
sys.path.append('C:/Users/nilso/Documents/EPFL/PDM/PDM_PINN/SciANN/DNN_TEST/sys/')

# NOTE(review): `torch`, `nn` and `DataLoader` used below only reach this
# namespace through these star imports.
from loss import *
from unet import UNet
from dataloader import *
from BaseModel import BaseModel

# Kept after the star imports so they can never shadow these names.
import torch.optim as optim
import logging

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

##################
# # # Config # # #
##################

epochs = 500
batch_size = 2

# Seed torch's global RNG so that network initialisation and DataLoader
# shuffling repeat on Restart & Run All. Other RNGs the project modules may
# use (numpy, random) are not covered by this call.
seed = 0
torch.manual_seed(seed)

# Data
data_dir = '../Moseley_EARTH/'
data_csv = '../Moseley_Earth_Event0000_Continuous.csv'
event = 'Event0000'
velocity_field = '../Velocity_Field_1.npy'

In [2]:
# Instantiate the project's `dataset` from the data paths and event set in the
# config cell.
training_data = dataset(
    data_dir,
    data_csv,
    event,
    velocity_field=velocity_field,
)

In [3]:
# Peek at one sample: take the 'wave_input' tensor of item `sample_idx`, swap
# its first two axes, and detach/copy it to the CPU for inspection.
sample_idx = 150
inputs = training_data[sample_idx]['wave_input'].transpose(1, 0).detach().cpu()
inputs.shape

In [5]:
# Paths. Every artefact shares one run tag (loss terms + epoch count), so the
# file names cannot drift apart when the configuration changes.
run_tag = f'L2_GDL_MAE_E{epochs}'
save_dir = '../results/'
save_pt_best = f'Best_{run_tag}.pt'
save_pt = f'{run_tag}.pt'
save_txt = f'{run_tag}.yml'

checkpoint_path = f'checkpoint_{run_tag}.pt'

# # # Data
# `training_data` is built once, in the dataset cell earlier in the notebook.
# Re-instantiating it here would repeat that construction and silently replace
# the object the inspection cell looked at.
train_loader = DataLoader(training_data, batch_size=batch_size, shuffle=True)

# UNet with 5 input channels and 1 output channel.
net = UNet(in_channels=5, out_channels=1)

# Optimizer: Adam with weight decay. No learning-rate scheduler is set up.
optimizer = optim.Adam(
    net.parameters(),
    lr=1e-4,
    weight_decay=1e-6,
)

# Logger: the root logger prints to the notebook (via basicConfig) and also
# writes to log.txt. Re-running this cell must not stack duplicate file
# handlers, so any FileHandler already attached for the same file is detached
# and closed first. Other handlers are left alone.
log_path = os.path.abspath("log.txt")  # FileHandler stores baseFilename as an absolute path
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
for old_handler in list(logger.handlers):  # iterate over a copy: we mutate the list
    if isinstance(old_handler, logging.FileHandler) and old_handler.baseFilename == log_path:
        logger.removeHandler(old_handler)
        old_handler.close()
logger.setLevel(logging.INFO)
file_handler = logging.FileHandler(log_path)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter('%(asctime)s | %(levelname)s | %(message)s'))
logger.addHandler(file_handler)

class UNetModel(BaseModel):
    """BaseModel subclass that trains a network on 'wave_input' -> 'wave_output' pairs.

    The objective is the project's ``MSLoss`` combining three terms:
    ``nn.MSELoss`` (reported as 'Loss MSE'), ``GDLLoss`` (reported as
    'Loss GDL'; presumably a gradient-difference loss, TODO confirm) and
    ``nn.L1Loss`` (reported as 'Loss MAE').
    """

    def __init__(self, net, opt=None, sched=None, logger=None, print_progress=True, device='cuda:0'):
        """Wrap ``net`` and attach the combined MSE + GDL + MAE loss.

        Args:
            net: network to train (a UNet in this notebook).
            opt: optimizer, forwarded to ``BaseModel``.
            sched: learning-rate scheduler or None, forwarded to ``BaseModel``.
            logger: logger or None, forwarded to ``BaseModel``.
            print_progress: forwarded to ``BaseModel``.
            device: device string, forwarded to ``BaseModel``; ``forward_loss``
                moves batches to ``self.device`` (presumably set from this by
                ``BaseModel``; TODO confirm).
        """
        super().__init__(net, opt, sched, logger, print_progress, device)

        self.loss_fn = MSLoss(nn.MSELoss(), GDLLoss(), nn.L1Loss())

    def forward_loss(self, data):
        """Run one batch through the network and compute the training loss.

        Args:
            data: batch mapping holding tensors under 'wave_input' and
                'wave_output'. Axes 1 and 2 of both are swapped before use,
                presumably (batch, length, channels) -> (batch, channels, length);
                TODO confirm against the dataset.

        Returns:
            ``(loss, terms)``. ``loss`` is the first value returned by
            ``self.loss_fn`` (the quantity to optimise). ``terms`` maps display
            names to the total and the three component losses and is used for
            progress printing only. The logged epoch values add up, which
            suggests the total is the sum of the components (TODO confirm in
            ``MSLoss``).
        """
        inputs = data['wave_input'].transpose(2, 1).to(self.device)
        labels = data['wave_output'].transpose(2, 1).to(self.device)

        outputs = self.net(inputs)

        # Assumed order: (total, mse, gdl, mae), matching the keys below.
        losses = self.loss_fn(outputs, labels)
        total = losses[0]

        return total, {'Loss': total, 'Loss MSE': losses[1], 'Loss GDL': losses[2], 'Loss MAE': losses[3]}

# Create the model and hand it the root logger configured in the logging
# section (the one with the log.txt FileHandler); with logger=None that handler
# would never be used. NOTE(review): confirm how BaseModel uses `logger`
# alongside print_progress, to avoid duplicated console output.
model = UNetModel(net, opt=optimizer, sched=None, logger=logger, print_progress=True, device=device)

# Create the output directory *before* the long training run, so the final
# saves cannot fail on a missing folder.
os.makedirs(save_dir, exist_ok=True)

# Train the model. save_best='Loss' refers to the 'Loss' key of the dict
# returned by UNetModel.forward_loss.
model.train(epochs, train_loader, checkpoint_path=checkpoint_path, checkpoint_freq=5, save_best='Loss')

# Save
model.save_best(export_path=os.path.join(save_dir, save_pt_best))
model.save(export_path=os.path.join(save_dir, save_pt))
model.save_outputs(export_path=os.path.join(save_dir, save_txt))

No Checkpoint found. Training from beginning.
Epoch 0001/0500 | Time 0:00:50.336703 | Loss 60.96793 | Loss MSE 52.31500 | Loss GDL 6.42697 | Loss MAE 2.22596 | 
