In [None]:
# --- Input pipelines --------------------------------------------------------
# One batch size drives both datasets and the step counts derived from it.
batch_size = 4
train_dataset = create_dataset(train_indices, batch_size)
test_dataset = create_dataset(test_indices, batch_size)

# Number of whole batches per pass; floor division leaves any partial batch
# out of the step count.
steps_per_epoch, validation_steps = (
    len(train_indices) // batch_size,
    len(test_indices) // batch_size,
)

# --- Model ------------------------------------------------------------------
# Adam with a small learning rate and gradient-norm clipping; the model is
# optimised on mean squared error and additionally tracks mean absolute error.
model = SWETransUNet()
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, clipnorm=1.0)
model.compile(loss='mse', metrics=['mae'], optimizer=optimizer)

# --- Callbacks --------------------------------------------------------------
# Neither callback names a metric, so both use the Keras default monitor.
# NOTE(review): the test split is passed as validation data below, so early
# stopping and the learning-rate schedule are driven by the test set --
# confirm this is intended.
early_stopping = tf.keras.callbacks.EarlyStopping(
    patience=3,
    restore_best_weights=True,
)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(patience=2, factor=0.2)

# --- Training ---------------------------------------------------------------
fit_kwargs = {
    'epochs': 5,
    'validation_data': test_dataset,
    'steps_per_epoch': steps_per_epoch,
    'validation_steps': validation_steps,
    'callbacks': [early_stopping, reduce_lr],
    'verbose': 1,
}
history = model.fit(train_dataset, **fit_kwargs)

# Report the final metrics and plot the training curves.
def _report_final_metrics(history_dict):
    """Print the last recorded value of each tracked metric to four decimals.

    Args:
        history_dict: Mapping with the keys 'loss', 'mae', 'val_loss' and
            'val_mae', each holding a per-epoch sequence of values.
    """
    for label, key in (
        ('Training Loss', 'loss'),
        ('Training MAE', 'mae'),
        ('Validation Loss', 'val_loss'),
        ('Validation MAE', 'val_mae'),
    ):
        print(f"Final {label}: {history_dict[key][-1]:.4f}")


def _plot_training_history(history_dict, save_path):
    """Plot training vs. validation curves for loss and MAE, save, then show.

    Args:
        history_dict: Same mapping as accepted by `_report_final_metrics`.
        save_path: File path the figure is written to before it is displayed.
    """
    plt.figure(figsize=(12, 4))
    # One panel per metric; the validation series lives under the 'val_' key.
    for position, (key, name) in enumerate((('loss', 'Loss'), ('mae', 'MAE')), start=1):
        plt.subplot(1, 2, position)
        plt.plot(history_dict[key], label=f'Training {name}')
        plt.plot(history_dict[f'val_{key}'], label=f'Validation {name}')
        plt.title(f'Model {name}')
        plt.xlabel('Epoch')
        plt.ylabel(name)
        plt.legend()

    plt.tight_layout()
    # Save before show(): showing first can leave an empty figure to save.
    plt.savefig(save_path)
    plt.show()


# The report used to be duplicated before and after the plot; it is now
# printed once.
# NOTE(review): these are last-epoch values. Training uses
# EarlyStopping(restore_best_weights=True), so the weights kept on the model
# may come from an earlier epoch -- confirm which numbers should be reported.
_report_final_metrics(history.history)

_plot_training_history(
    history.history,
    '/content/drive/MyDrive/Independent Environmental Science Project/Plots/training_history.png',
)

# Persist the training artefacts: the trained model, the per-epoch metric
# history, and a diagram of the model architecture.
import pickle

model_path = '/content/drive/MyDrive/Independent Environmental Science Project/Model/SWETransUNet_model.keras'
history_path = '/content/drive/MyDrive/Independent Environmental Science Project/Model/training_history.pkl'
architecture_path = '/content/drive/MyDrive/Independent Environmental Science Project/Plots/model_architecture.png'

# Trained model
model.save(model_path)

# Only the plain metrics dict is pickled, not the History callback object.
with open(history_path, 'wb') as history_file:
    pickle.dump(history.history, history_file)

# Architecture diagram, annotated with layer names and tensor shapes
tf.keras.utils.plot_model(
    model,
    to_file=architecture_path,
    show_shapes=True,
    show_layer_names=True,
)

# Generate and plot sample predictions
def plot_sample_predictions(
    model,
    dataset,
    num_samples: int = 5,
    num_channels: int = 3,
    save_path: str = '/content/drive/MyDrive/Independent Environmental Science Project/Plots/sample_predictions.png',
) -> None:
    """Plot input, ground truth and prediction side by side for a few batches.

    One figure row is drawn per batch taken from `dataset`; each row holds an
    (input, true, predicted) triple of images for every channel. Only the
    element at index 0 of each batch is drawn. The figure is written to
    `save_path` and then displayed.

    Args:
        model: Object exposing `predict(X, verbose=0)`; its output is indexed
            as `y_pred[0, :, :, channel]`.
        dataset: Object exposing `take(n)` that yields `(X, y_true)` pairs.
            `X` is indexed as `X[0, -1, 0, :, :, channel]` and `y_true` as
            `y_true[0, :, :, channel]`.
            NOTE(review): presumably X is (batch, time, 1, H, W, channels),
            so the input panel shows the most recent time step -- confirm
            against create_dataset.
        num_samples: Number of batches to take, and number of rows in the
            figure grid.
        num_channels: Number of trailing-axis channels to draw per row.
        save_path: Destination file for the saved figure.
    """
    # Each channel contributes three panels: input, ground truth, prediction.
    panels_per_row = 3 * num_channels

    plt.figure(figsize=(20, 4 * num_samples))
    for row, (X, y_true) in enumerate(dataset.take(num_samples)):
        # verbose=0 keeps predict() from printing a progress bar per batch.
        y_pred = model.predict(X, verbose=0)

        for channel in range(num_channels):
            # Subplot indices are 1-based and run row-major across the grid.
            first_panel = row * panels_per_row + channel * 3 + 1
            panels = (
                ('Input', X[0, -1, 0, :, :, channel]),
                ('True', y_true[0, :, :, channel]),
                ('Pred', y_pred[0, :, :, channel]),
            )
            for offset, (label, image) in enumerate(panels):
                plt.subplot(num_samples, panels_per_row, first_panel + offset)
                plt.imshow(image, cmap='viridis')
                plt.title(f'{label} Ch{channel}')
                plt.axis('off')

    plt.tight_layout()
    # Save before show(): showing first can leave an empty figure to save.
    plt.savefig(save_path)
    plt.show()

# Draw the prediction grid for the test dataset, then confirm completion.
plot_sample_predictions(model=model, dataset=test_dataset)

print("All visualizations have been saved in the 'Plots' folder.")