In [None]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
from keras.callbacks import ModelCheckpoint, LearningRateScheduler # type: ignore
from sklearn.model_selection import train_test_split

from mcvd_transformer.dataset.parser import DataParser
from mcvd_transformer.dataset.dataloader import BatchGenerator
from mcvd_transformer.utils.objects import CoordinateSystem
from mcvd_transformer.model.model import create_model
from mcvd_transformer.model.callbacks import lr_scheduler, AdaptiveLossWeight
from mcvd_transformer.utils.postprocessing import PostProcessing
from mcvd_transformer.utils.evaluator import PerformanceEvaluator

# Name of this run. The cells below use it as the sub-folder of 'experiments'
# that receives the model checkpoint and the evaluation JSON files.
EXPERIMENT_NAME = "MCvD_Transformer"

Read MCvD Simulations

In [None]:
# Point the parser at the "data" folder, list ".git" as an unwanted folder and
# enable include_prism, then load the whole data set in one call.
data_parser = DataParser(
    "data",
    unwanted_folders=[".git"],
    include_prism=True,
)
data_set = data_parser.parse_data()

Training Parameters

In [None]:
# Optimisation settings: batch_size is handed to the batch generators,
# num_epochs to model.fit.
batch_size, num_epochs = 64, 400

# Seeds, one per source of randomness used below:
#   data_seed       -> train_test_split(random_state=...)
#   numpy_seed      -> np.random.seed(...)
#   dataloader_seed -> BatchGenerator(random_seed=...)
data_seed, numpy_seed, dataloader_seed = 1, 10, 50

Split Dataset into Train, Validation and Test Parts

In [None]:
# Shuffle a shallow copy instead of `data_set` itself. np.random.shuffle works
# in place, so shuffling the original made this cell non-idempotent: running it
# a second time re-shuffled already-shuffled data and silently produced a
# different train/val/test split. With the copy, every run of this cell (fresh
# kernel or not) yields the same split.
np.random.seed(numpy_seed)
shuffled_data = copy.copy(data_set)
np.random.shuffle(shuffled_data)

# First cut: 70 % train / 30 % held out. Second cut: the held-out part is
# divided with test_size=0.33, i.e. roughly 20 % validation and 10 % test of
# the full data set.
train_set, test_val_set = train_test_split(shuffled_data, test_size=0.3, random_state=data_seed)
val_set, test_set = train_test_split(test_val_set, test_size=0.33, random_state=data_seed)

# The intermediates are no longer needed once the three splits exist.
del test_val_set, shuffled_data

# Sanity check: the three splits must partition the data set.
assert len(train_set) + len(val_set) + len(test_set) == len(data_set)

print(f"Size of Training Set = {len(train_set)}")
print(f"Size of Validation Set = {len(val_set)}")
print(f"Size of Test Set = {len(test_set)}")

Create Batch Generators

In [None]:
def _make_batch_generator(samples):
    """Build a BatchGenerator over `samples` with the settings shared by all splits.

    Args:
        samples: the data split (train, validation or test) to iterate over.

    Returns:
        A BatchGenerator configured with the notebook-level `batch_size` and
        `dataloader_seed` and the same fixed options for every split.
    """
    return BatchGenerator(
        samples,
        batch_size=batch_size,
        coordinate_system=CoordinateSystem.BOTH,
        random_rotate=True,
        entity_order="shuffle",
        zero_padding=10,
        max_shape=True,
        shuffle=True,
        max_spherical_entity=15,
        flatten=False,
        one_absorber_points=0,
        random_seed=dataloader_seed,
    )


training_batch_generator = _make_batch_generator(train_set)
validation_batch_generator = _make_batch_generator(val_set)

# Bug fix: this generator used to be built from `val_set`, so anything evaluated
# through it would have reported validation data as test data.
# NOTE(review): random_rotate/shuffle stay enabled for validation and test to
# match the original configuration — confirm that augmentation is wanted there.
test_batch_generator = _make_batch_generator(test_set)

Create Model

In [None]:
# Derive the model's input sizes from the first training batch: the first input
# tensor without its batch axis, and the second axis of the second input tensor.
# NOTE: the batch is deliberately fetched once per argument, exactly as before;
# the generator shuffles/augments, so the number of accesses may matter for its
# random state — TODO confirm before collapsing into a single fetch.
first_input_shape = training_batch_generator[0][0][0].shape[1:]
second_input_width = training_batch_generator[0][0][1].shape[1]

model, alpha, beta = create_model(first_input_shape, second_input_width)
model.summary()

Create Training Callbacks

In [None]:
# Folder that collects every artefact of this experiment (created on demand).
experiment_dir = os.path.join('experiments', EXPERIMENT_NAME)
os.makedirs(experiment_dir, exist_ok=True)

# Checkpoint target: only the best model is kept, judged by the lowest
# 'val_cir_max_loss'.
filepath = os.path.join(experiment_dir, 'model.keras')
checkpoint = ModelCheckpoint(
    filepath,
    monitor='val_cir_max_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)

# Learning-rate schedule and the loss-weight callback driven by the alpha/beta
# objects returned by create_model.
lr_callback = LearningRateScheduler(lr_scheduler, verbose=0)
loss_weight_callback = AdaptiveLossWeight(alpha, beta)

callbacks_list = [checkpoint, lr_callback, loss_weight_callback]

Run Training

In [None]:
# Fit on the training generator for `num_epochs` epochs, validating on the
# validation generator and running the callbacks defined above. The returned
# History object is kept for the loss-curve plot.
history = model.fit(
    training_batch_generator,
    validation_data=validation_batch_generator,
    epochs=num_epochs,
    callbacks=callbacks_list,
    verbose=1,
)

In [None]:
# Training vs. validation loss per epoch, drawn on the current axes.
ax = plt.gca()
for series_key in ('loss', 'val_loss'):
    ax.plot(history.history[series_key])

ax.set_title('Model Loss Curve')
ax.set_ylabel('loss')
ax.set_xlabel('epoch')

# Legend entries follow the plotting order above.
ax.legend(['train', 'val'], loc='upper left')

Calculate Performance Metrics

In [None]:
# Load the weights stored at `filepath`, the target configured for ModelCheckpoint.
model.load_weights(filepath)


def _evaluate_split(split, file_name):
    """Run PerformanceEvaluator.create_evaluation_results on one data split.

    Args:
        split: the data split to evaluate.
        file_name: file name placed under 'experiments/<EXPERIMENT_NAME>/' and
            passed as `path` (presumably where the results are written — TODO confirm).

    Returns:
        Whatever create_evaluation_results returns for this split.
    """
    return PerformanceEvaluator.create_evaluation_results(
        model,
        data_set=split,
        coordinate_system=CoordinateSystem.BOTH,
        max_number_of_spherical=15,
        order="shuffle",
        path=os.path.join('experiments', EXPERIMENT_NAME, file_name),
    )


train_error_values = _evaluate_split(train_set, "train_error.json")
val_error_values = _evaluate_split(val_set, "val_error.json")
test_error_values = _evaluate_split(test_set, "test_error.json")

Plot Raw Estimations

In [None]:
# Pick one validation sample and rotate it before running inference.
# NOTE(review): 277 is a hard-coded position; it presumably fails with an
# IndexError when val_set holds fewer than 278 samples — confirm for small data sets.
index = 277
input_topology = val_set[index].rotate(10, 45)

# NOTE(review): the positional arguments look like the coordinate system, the
# maximum number of spherical entities, the entity order and the flatten flag
# used for the batch generators — confirm against convert_numpy's signature.
input_top, input_num = input_topology.convert_numpy(CoordinateSystem.BOTH, 15, "shuffle", False)

# The model is fed a batch, so each input gets a leading axis of size 1.
batched_inputs = [np.expand_dims(array, axis=0) for array in (input_top, input_num)]
shape, max_value, _ = model.predict(batched_inputs)

# Combine the first (only) entry of the two used outputs into one prediction.
prediction = PostProcessing.postprocessing_separate(shape[0], max_value[0])

# Plot the sample's own time output against the prediction.
PerformanceEvaluator.time_graph(
    time_output_actual=input_topology.time_output,
    time_output_predicted_list=[prediction],
    legend=["Ground Truth", "Prediction"],
    image_loc=None,
    time_res=1,
    expension_ratio=1,
    path=None,
)

Visualize Topology

In [None]:
# Show the topology selected in the previous cell (the rotated validation sample).
input_topology.visualize()