# In [None]:
import tensorflow as tf
from train_model import train_model
from Generator_v1 import Patient_data_generator
import lung_extraction_funcs_13_09 as le
import os

# Filesystem locations used by this training script (all relative to the CWD).
_dataset_root = '../Software for qualitative assesment'  # shared parent of train/validation sets

data_path = f'{_dataset_root}/train_data'                  # training images
validation_data_path = f'{_dataset_root}/validation_data'  # validation images
model_path = './model_files/'      # directory holding the architecture JSON and initial weights
best_model_path = 'best_model.h5'  # checkpoint from an earlier run; used to resume if present

# Parse the training and validation datasets, then wrap each in a batch generator.
train_patient_dict = le.parse_dataset(data_path, img_only=True)
validation_patient_dict = le.parse_dataset(validation_data_path, img_only=True)


def _generator_kwargs():
    """Return the options shared by the training and validation generators.

    Defining these once keeps the validation data preprocessed exactly like the
    training data. A brand-new dict (including a new ``window_params`` list) is
    built on every call, so the two generators never share a mutable object.
    """
    return dict(
        predict=False, batch_size=8, image_size=512,
        # NOTE(review): shuffle=True is also applied to the validation generator;
        # shuffling is usually unnecessary there — confirm this is intended.
        shuffle=True,
        # presumably a CT intensity window (width, level) — confirm against Patient_data_generator
        use_window=True, window_params=[1500, -600],
        resample_int_val=True, resampling_step=25,
        extract_lungs=True, size_eval=False, verbosity=True, reshape=True, img_only=True,
    )


train_generator = Patient_data_generator(train_patient_dict, **_generator_kwargs())
validation_generator = Patient_data_generator(validation_patient_dict, **_generator_kwargs())

# Rebuild the model architecture from its JSON description. The context manager
# guarantees the file handle is closed even if reading fails.
with open(os.path.join(model_path, 'model_v7.json'), 'r') as json_file:
    loaded_model_json = json_file.read()
model = tf.keras.models.model_from_json(loaded_model_json)

# Resume from the best checkpoint of a previous run when it exists; otherwise
# start from the initial weights stored alongside the architecture file.
if os.path.exists(best_model_path):
    print(f"Loading model weights from {best_model_path}...")
    model.load_weights(best_model_path)
else:
    initial_weights_path = os.path.join(model_path, 'weights_v7.hdf5')
    # The weights are not random here, so the message names the file actually loaded.
    print(f"No previous model weights found. Starting training from initial weights in {initial_weights_path}.")
    model.load_weights(initial_weights_path)

# Train the model with the generators built above. `history` keeps whatever
# train_model returns (presumably a Keras History object — confirm).
history = train_model(model, train_generator, validation_generator)

# Save only the final weights (not the architecture) — optional.
# NOTE(review): Keras 3 requires save_weights paths to end in '.weights.h5';
# under tf.keras 2.x a plain '.h5' suffix is accepted. Confirm the installed version.
final_weights_path = 'final_model_weights.h5'
model.save_weights(final_weights_path)