In [None]:
"""
Notebook for 3D trans-dimensional inversion and for exploring model space by introducing
user-defined perturbations to petrophysical properties or rock-unit geometry.
"""
# Standard-library and third-party libs.
import time
import numpy as np
from argparse import ArgumentParser
from importlib import reload
from pathlib import Path

# Project-local libs.
import src.transd_solver as ts
import src.transd_runner as tr
import src.forward_solver.forward_calculation as fcu
import src.input_output.output_manager as om
import src.utils.plot_utils as ptu
import src.input_output.output_manager as om
import src.input_output.input_params as input_params

# %load_ext line_profiler
# Initial value for `sensit`: the run cell passes it into tr.run_transd(...) and
# rebinds it to the value that call returns. Initialising it in this setup cell
# (rather than in the run cell) presumably lets a re-run of the run cell reuse a
# previously computed sensitivity — TODO confirm in tr.run_transd.
sensit = None

In [None]:
reload(ts)
reload(tr)
reload(om)
reload(ptu)
reload(fcu)
reload(input_params)

# Define expert parameter configuration.
run_config = input_params.RunBaseParams()

# Create log. No logging created if return_none=True.
# TODO: create the log name dynamically, e.g. from the run ID.
log_run = om.LogRun('log_file.log', verbose=True, return_none=False)

try:
    # Fixed seed for reproducibility.
    rng_main = np.random.default_rng(seed=123)

    # Read command-line arguments. parse_known_args() (rather than parse_args())
    # tolerates extra arguments the notebook kernel may have been launched with.
    parser = ArgumentParser()
    parser.add_argument("-p", "--parfile", dest="parfile_path",
                        # help="path to the parameters file", default="parfiles/Parfile_example_pyr_force_geom.txt")
                        help="path to the parameters file", default="parfiles/parfile_transd_synth1.txt")
    args, unknown = parser.parse_known_args()

    # Read input parameters from the parameter file.
    par = input_params.read_input_parameters(args.parfile_path, log_run)

    # Remove .vts files left by previous runs in the same output folder.
    om.remove_vtk_files(folder_path=par.path_output + '/', target_string="*", target_extension=".vts", verbose=True, log_run=log_run)

    # Record the start time with a monotonic clock (unaffected by system clock
    # adjustments) and run the modelling.
    start_time = time.perf_counter()
    petrovals, birth_params, mvars, metrics, shpars, gpars, phiper, geophy_data, sensit, spars, par = tr.run_transd(par, log_run, rng_main, run_config, sensit)

    # Save acceptance tracking.
    om.save_acceptance_history(metrics, run_config,
                               filename=spars.path_output + '/accepted_changes',
                               save=spars.save_plots,
                               log_run=log_run)
    # Save the data for the last model.
    om.save_data_to_vtk(geophy_data, datatype_to_save='data_calc', filename=par.path_output + '/data_calc', save=True, log_run=log_run)
    # Save the metrics to file.
    om.save_metrics_txt(metrics, filename=par.path_output + '/metrics', save=True, log_run=log_run)

    # Log the elapsed time. NOTE(review): the comment above LogRun suggests no log
    # object exists when return_none=True, hence the None guard.
    if log_run is not None:
        log_run.info(f'RUN TIME: {time.perf_counter() - start_time} sec')
finally:
    # Release the log even when parameter reading or the run raises or is
    # interrupted (e.g. KeyboardInterrupt on a long inversion), so the log file
    # is not left open when the cell is re-run.
    if log_run is not None:
        log_run.close()

In [12]:
# Visualisation: metrics recorded during the inversion, with summary statistics printed.
reload(ptu)
reload(ts)

# Presumably the plotted range for the data-misfit curve — confirm in ptu.plot_metrics.
misfit_lims = np.array([0, 6])
plot_path = Path(spars.path_output).joinpath('metrics_plot.png')
ptu.plot_metrics(
    metrics,
    run_config,
    data_misfit_lims=misfit_lims,
    print_stats=True,
    save=True,
    plot_path=plot_path,
)


Acceptance rate: 0.660
Mean accept ratio: 0.641

Accepted changes by type:
  geometrical: 28 (42.4%)
  petrophysical: 26 (39.4%)
  Birth of a unit: 6 (9.1%)
  Death of a unit: 6 (9.1%)
Figure saved in output_example_synth1\metrics_plot.png


In [None]:
# Scratch cell intentionally left empty: it previously only displayed np.eye(3),
# which is unrelated to the inversion workflow. Safe to delete this cell.

In [None]:
reload(om)

# Write the metrics summary (text) and the metrics data (CSV) next to the other outputs.
output_dir = spars.path_output
om.save_metrics_summary(metrics, '{}/metrics_summary.txt'.format(output_dir))
om.save_metrics_data(metrics, '{}/metrics_data.csv'.format(output_dir))

In [None]:
# Visualisation: GIF of 1 slice per saved model.
from IPython.display import Image
reload(ptu)

# Slice orientation for the animation; the slice is taken at the middle of the
# grid along this axis (the original code used gpars.dim[0] regardless of axis).
# NOTE(review): assumes gpars.dim is ordered (x, y, z) — an earlier version of this
# cell used dim[1] // 2 for axis='y'. Confirm against the grid definition.
gif_axis = 'y'
gif_slice_index = gpars.dim['xyz'.index(gif_axis)] // 2

# Create GIF for the selected slice. Tunable settings: filename, fps, decimate, dpi.
gif_path = ptu.create_inversion_animation(
    metrics=metrics,
    gpars=gpars,
    output_dir=par.path_output,
    filename='custom_animation.gif',
    axis=gif_axis,
    slice_index=gif_slice_index,  # Middle slice along gif_axis.
    fps=10,  # Frame rate, per the parameter name — confirm in ptu.
    clim=np.array((-110, 50)),  # Presumably the colour-scale limits of the slice.
    padding_start=1,
    padding_end=-1,
    decimate=4,  # Presumably keeps every 4th saved model — confirm in ptu.
    dpi=150,
    filename_pattern='m_curr_',
    max_iteration=np.max(metrics.it_accepted_model),
)

# Show GIF.
Image(filename=gif_path)

In [None]:
# Visualisation: GIF of a slice moving through the last model (mvars.m_curr).
from IPython.display import Image
reload(ptu)

flythrough_file = par.path_output + '/flythrough_y.gif'
gif_path = ptu.create_model_flythrough_gif(
    mvars.m_curr,
    gpars,
    flythrough_file,
    axis='y',
    decimate=2,
    fps=8,
    dpi=150,
    title_prefix='Last model, slice',
)

display(Image(filename=gif_path))

In [None]:
# TODO: plot the last model.
import src.utils.transd_utils as tu

# Pick up any edits made to the utility and solver modules since they were imported.
for module in (tu, ts):
    reload(module)

# Profiling helpers (disabled):
# reload(ns)
# %load_ext line_profiler
# %lprun -f solve -f ts.pertubate_scal_fields -f tu.calc_signed_distances_opti -f om.save_model_to_vtk solve(par, log_run, sensit)
# Dependency graph: pydeps transd_solver.py --noshow -o graph.svg

In [None]:
# reload(ptu)
# from IPython.display import Image

# pth_output = par.path_output
# # pth_output = "<path/to/_output_example_test_workshop>"  # avoid hard-coding absolute local paths
# file_name_gif = '_test_petro_force_decrease.gif'
# filename_rt = 'm_curr_'
# dim = gpars.dim
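# # index = dim[1]//2
# index = 20  # in theory the same as in the Pyr paper.
# fps = 200.  # NOTE(review): previously described as "seconds between images", which contradicts the name fps; confirm the unit expected by create_section_gif_from_vts.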
# padding_start = 9
# padding_end = -10
# decimate_files = 4

# x_misfit = metrics.it_accepted_model
# y_misfit = metrics.data_misfit[metrics.it_accepted_model]
# max_x_misfit = np.max(x_misfit)
# max_x_misfit = 450
# # max_x_misfit = np.max(len(x_misfit))
# accept_ratio = metrics.accept_ratio[metrics.it_accepted_model]  # TODO make it for all models! Not just the accepted ones!

# ptu.create_section_gif_from_vts(pth_output, file_name_gif, padding_start, padding_end, x_misfit, y_misfit, max_x_misfit, accept_ratio, gpars, axis='y',
#                             index=index, fps=fps, required_strings=filename_rt, files=decimate_files, dpi=150)

# Image(filename=file_name_gif)

In [None]:
# # Simple overview in one visualization
# def quick_project_view(root_dir='.'):
#     """Quick project dependency overview."""
    
#     # Scan
#     deps = {}
#     for f in Path(root_dir).rglob('*.py'):
#         if '__pycache__' not in str(f):
#             try:
#                 with open(f) as file:
#                     tree = ast.parse(file.read())
#                 imports = []
#                 for node in ast.walk(tree):
#                     if isinstance(node, ast.Import):
#                         imports.extend([n.name.split('.')[0] for n in node.names])
#                     elif isinstance(node, ast.ImportFrom) and node.module:
#                         imports.append(node.module.split('.')[0])
#                 deps[f.stem] = set(imports)
#             except:
#                 pass
    
#     # Graph
#     modules = set(deps.keys())
#     G = nx.DiGraph()
#     for m in modules:
#         for imp in deps[m]:
#             if imp in modules:
#                 G.add_edge(m, imp)
    
#     # Visualize
#     plt.figure(figsize=(16, 12))
#     pos = nx.spring_layout(G, k=2, iterations=50)
    
#     # Size by connections
#     sizes = [500 + 100 * (G.in_degree(n) + G.out_degree(n)) for n in G.nodes()]
    
#     nx.draw(G, pos, with_labels=True, node_size=sizes,
#             node_color='lightblue', font_size=10, font_weight='bold',
#             arrows=True, edge_color='gray', alpha=0.7, arrowsize=20)
    
#     plt.title("Project Dependency Graph", fontsize=16, fontweight='bold')
#     plt.tight_layout()
#     plt.show()
    
#     # Stats
#     print(f"üìä Modules: {G.number_of_nodes()}")
#     print(f"üìä Dependencies: {G.number_of_edges()}")
    
#     in_deg = dict(G.in_degree())
#     print(f"\nüèÜ Most imported: {max(in_deg, key=in_deg.get)} ({max(in_deg.values())} imports)")

# # Run
# quick_project_view('.')




# # Cell 1: Complete Project Analyzer
# import matplotlib.pyplot as plt
# import networkx as nx
# import ast
# from pathlib import Path
# from collections import defaultdict
# import numpy as np

# class ProjectDependencyAnalyzer:
#     """Analyze and visualize entire project structure."""
    
#     def __init__(self, root_dir='.'):
#         self.root_dir = root_dir
#         self.deps = {}
#         self.all_modules = set()
#         self.module_categories = {}
#         self._scan_files()
#         self._categorize_modules()
    
#     def _scan_files(self):
#         """Scan all Python files."""
#         class ImportVisitor(ast.NodeVisitor):
#             def __init__(self):
#                 self.imports = []
            
#             def visit_Import(self, node):
#                 for alias in node.names:
#                     self.imports.append(alias.name.split('.')[0])
            
#             def visit_ImportFrom(self, node):
#                 if node.module:
#                     self.imports.append(node.module.split('.')[0])
        
#         print(f"üîç Scanning {self.root_dir}...")
        
#         for py_file in Path(self.root_dir).rglob('*.py'):
#             if '__pycache__' in str(py_file) or '.ipynb_checkpoints' in str(py_file):
#                 continue
            
#             try:
#                 with open(py_file, 'r', encoding='utf-8') as f:
#                     tree = ast.parse(f.read())
                
#                 visitor = ImportVisitor()
#                 visitor.visit(tree)
                
#                 module_name = py_file.stem
#                 self.deps[module_name] = {
#                     'imports': set(visitor.imports),
#                     'path': py_file,
#                     'folder': py_file.parent.name
#                 }
                
#             except Exception as e:
#                 print(f"  ⚠️  Error: {py_file.name}")
        
#         self.all_modules = set(self.deps.keys())
#         print(f"✓ Found {len(self.all_modules)} modules\n")
    
#     def _categorize_modules(self):
#         """Categorize modules by folder or naming pattern."""
#         for module in self.all_modules:
#             folder = self.deps[module]['folder']
            
#             # Categorize by folder or file pattern
#             if 'core' in folder or 'inversion' in folder:
#                 category = 'Core Algorithm'
#             elif 'forward' in folder:
#                 category = 'Forward Modeling'
#             elif 'sampling' in folder or 'noise' in module:
#                 category = 'Sampling'
#             elif 'utils' in folder or 'quality' in module or 'plot' in module:
#                 category = 'Utilities'
#             elif 'test' in folder:
#                 category = 'Tests'
#             else:
#                 # Categorize by content
#                 if 'solver' in module:
#                     category = 'Core Algorithm'
#                 elif 'state' in module or 'metrics' in module:
#                     category = 'Core Algorithm'
#                 elif 'forward' in module or 'sensit' in module:
#                     category = 'Forward Modeling'
#                 elif 'sampling' in module or 'noise' in module:
#                     category = 'Sampling'
#                 else:
#                     category = 'Utilities'
            
#             self.module_categories[module] = category
    
#     def show_summary(self):
#         """Show project summary."""
#         print("="*70)
#         print("üìä PROJECT STRUCTURE SUMMARY")
#         print("="*70)
        
#         # Count by category
#         category_counts = defaultdict(int)
#         for cat in self.module_categories.values():
#             category_counts[cat] += 1
        
#         print(f"\nTotal modules: {len(self.all_modules)}")
#         print("\nBy category:")
#         for cat in sorted(category_counts.keys()):
#             print(f"  • {cat}: {category_counts[cat]} modules")
        
#         # Show internal dependencies
#         internal_edges = 0
#         external_imports = set()
        
#         for module in self.all_modules:
#             imports = self.deps[module]['imports']
#             for imp in imports:
#                 if imp in self.all_modules:
#                     internal_edges += 1
#                 else:
#                     external_imports.add(imp)
        
#         print(f"\nInternal dependencies: {internal_edges}")
#         print(f"External libraries used: {len(external_imports)}")
#         print(f"\nTop external libraries:")
        
#         # Count external imports
#         ext_counts = defaultdict(int)
#         for module in self.all_modules:
#             for imp in self.deps[module]['imports']:
#                 if imp not in self.all_modules:
#                     ext_counts[imp] += 1
        
#         for lib, count in sorted(ext_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
#             print(f"  • {lib}: {count} imports")
    
#     def visualize_full_project(self, figsize=(20, 16), show_external=False):
#         """Visualize entire project structure."""
        
#         # Build graph
#         G = nx.DiGraph()
        
#         for module in self.all_modules:
#             G.add_node(module, category=self.module_categories[module])
            
#             imports = self.deps[module]['imports']
#             for imp in imports:
#                 if imp in self.all_modules:  # Internal only
#                     G.add_edge(module, imp)
#                 elif show_external:
#                     G.add_node(imp, category='External')
#                     G.add_edge(module, imp)
        
#         print("\n" + "="*70)
#         print("üé® GENERATING VISUALIZATION")
#         print("="*70)
#         print(f"Nodes: {G.number_of_nodes()}")
#         print(f"Edges: {G.number_of_edges()}")
        
#         # Create figure with multiple views
#         fig = plt.figure(figsize=figsize)
#         gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)
        
#         # Color map by category
#         categories = list(set(self.module_categories.values()))
#         colors = plt.cm.Set3(np.linspace(0, 1, len(categories)))
#         color_map = dict(zip(categories, colors))
#         if show_external:
#             color_map['External'] = [0.9, 0.9, 0.9, 0.5]
        
#         node_colors = [color_map[G.nodes[node].get('category', 'Utilities')] 
#                       for node in G.nodes()]
        
#         # Node sizes based on connections
#         node_sizes = [300 + 50 * (G.in_degree(node) + G.out_degree(node)) 
#                      for node in G.nodes()]
        
#         # ========== PLOT 1: Spring Layout ==========
#         ax1 = fig.add_subplot(gs[0, 0])
#         ax1.set_title("Spring Layout (Force-Directed)", fontsize=12, fontweight='bold')
        
#         pos1 = nx.spring_layout(G, k=2, iterations=50, seed=42)
#         nx.draw_networkx_nodes(G, pos1, node_color=node_colors, 
#                               node_size=node_sizes, alpha=0.8, ax=ax1)
#         nx.draw_networkx_edges(G, pos1, edge_color='gray', 
#                               arrows=True, arrowsize=10, 
#                               alpha=0.3, width=1, ax=ax1)
#         nx.draw_networkx_labels(G, pos1, font_size=7, ax=ax1)
#         ax1.axis('off')
        
#         # ========== PLOT 2: Circular Layout ==========
#         ax2 = fig.add_subplot(gs[0, 1])
#         ax2.set_title("Circular Layout (By Category)", fontsize=12, fontweight='bold')
        
#         # Group by category
#         category_groups = defaultdict(list)
#         for node in G.nodes():
#             cat = G.nodes[node].get('category', 'Utilities')
#             category_groups[cat].append(node)
        
#         pos2 = nx.shell_layout(G, nlist=list(category_groups.values()))
#         nx.draw_networkx_nodes(G, pos2, node_color=node_colors, 
#                               node_size=node_sizes, alpha=0.8, ax=ax2)
#         nx.draw_networkx_edges(G, pos2, edge_color='gray', 
#                               arrows=True, arrowsize=10, 
#                               alpha=0.3, width=1, ax=ax2)
#         nx.draw_networkx_labels(G, pos2, font_size=7, ax=ax2)
#         ax2.axis('off')
        
#         # ========== PLOT 3: Hierarchical Layout ==========
#         ax3 = fig.add_subplot(gs[1, 0])
#         ax3.set_title("Hierarchical Layout (Flow)", fontsize=12, fontweight='bold')
        
#         try:
#             # Try hierarchical layout
#             pos3 = nx.nx_agraph.graphviz_layout(G, prog='dot')
#         except:
#             # Fallback to kamada_kawai
#             try:
#                 pos3 = nx.kamada_kawai_layout(G)
#             except:
#                 pos3 = nx.spring_layout(G, seed=42)
        
#         nx.draw_networkx_nodes(G, pos3, node_color=node_colors, 
#                               node_size=node_sizes, alpha=0.8, ax=ax3)
#         nx.draw_networkx_edges(G, pos3, edge_color='gray', 
#                               arrows=True, arrowsize=10, 
#                               alpha=0.3, width=1, ax=ax3)
#         nx.draw_networkx_labels(G, pos3, font_size=7, ax=ax3)
#         ax3.axis('off')
        
#         # ========== PLOT 4: Category Subgraphs ==========
#         ax4 = fig.add_subplot(gs[1, 1])
#         ax4.set_title("By Category (Grouped)", fontsize=12, fontweight='bold')
        
#         # Create layout with categories as subgraphs
#         pos4 = {}
#         y_offset = 0
#         for i, (cat, nodes) in enumerate(sorted(category_groups.items())):
#             subG = G.subgraph(nodes)
#             if len(nodes) > 0:
#                 sub_pos = nx.spring_layout(subG, k=1, iterations=30, seed=42)
#                 # Offset each category
#                 for node, (x, y) in sub_pos.items():
#                     pos4[node] = (x, y + y_offset)
#                 y_offset += 2
        
#         nx.draw_networkx_nodes(G, pos4, node_color=node_colors, 
#                               node_size=node_sizes, alpha=0.8, ax=ax4)
#         nx.draw_networkx_edges(G, pos4, edge_color='gray', 
#                               arrows=True, arrowsize=10, 
#                               alpha=0.3, width=1, ax=ax4)
#         nx.draw_networkx_labels(G, pos4, font_size=7, ax=ax4)
#         ax4.axis('off')
        
#         # ========== LEGEND ==========
#         from matplotlib.patches import Patch
#         legend_elements = [Patch(facecolor=color_map[cat], label=cat) 
#                           for cat in sorted(color_map.keys())]
#         fig.legend(handles=legend_elements, loc='upper center', 
#                   bbox_to_anchor=(0.5, 0.98), ncol=len(categories), 
#                   frameon=True, fontsize=10)
        
#         plt.suptitle("Project Dependency Architecture", 
#                     fontsize=16, fontweight='bold', y=0.99)
        
#         plt.tight_layout(rect=[0, 0, 1, 0.97])
#         plt.show()
        
#         return G
    
#     def show_dependency_matrix(self):
#         """Show dependency matrix as heatmap."""
#         import pandas as pd
        
#         # Create matrix
#         modules = sorted(self.all_modules)
#         matrix = np.zeros((len(modules), len(modules)))
        
#         module_to_idx = {m: i for i, m in enumerate(modules)}
        
#         for i, module in enumerate(modules):
#             imports = self.deps[module]['imports']
#             for imp in imports:
#                 if imp in module_to_idx:
#                     j = module_to_idx[imp]
#                     matrix[i, j] = 1
        
#         # Plot
#         fig, ax = plt.subplots(figsize=(16, 14))
        
#         im = ax.imshow(matrix, cmap='Blues', aspect='auto')
        
#         ax.set_xticks(range(len(modules)))
#         ax.set_yticks(range(len(modules)))
#         ax.set_xticklabels(modules, rotation=90, ha='right', fontsize=8)
#         ax.set_yticklabels(modules, fontsize=8)
        
#         ax.set_xlabel('Imported Module', fontsize=12, fontweight='bold')
#         ax.set_ylabel('Importing Module', fontsize=12, fontweight='bold')
#         ax.set_title('Dependency Matrix\n(Row imports Column)', 
#                     fontsize=14, fontweight='bold', pad=20)
        
#         # Add colorbar
#         cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
#         cbar.set_label('Import Relationship', rotation=270, labelpad=20)
        
#         # Add grid
#         ax.set_xticks(np.arange(len(modules))-0.5, minor=True)
#         ax.set_yticks(np.arange(len(modules))-0.5, minor=True)
#         ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)
        
#         plt.tight_layout()
#         plt.show()
    
#     def analyze_module_importance(self):
#         """Analyze which modules are most important."""
        
#         # Build graph
#         G = nx.DiGraph()
#         for module in self.all_modules:
#             for imp in self.deps[module]['imports']:
#                 if imp in self.all_modules:
#                     G.add_edge(module, imp)
        
#         print("\n" + "="*70)
#         print("üìà MODULE IMPORTANCE ANALYSIS")
#         print("="*70)
        
#         # PageRank (importance)
#         try:
#             pagerank = nx.pagerank(G)
#             print("\nüèÜ Most Important Modules (PageRank):")
#             for module, score in sorted(pagerank.items(), key=lambda x: x[1], reverse=True)[:10]:
#                 print(f"  {module}: {score:.4f}")
#         except:
#             print("\n⚠️  Could not calculate PageRank")
        
#         # In-degree (most imported)
#         in_deg = dict(G.in_degree())
#         print("\nüì• Most Imported Modules:")
#         for module, count in sorted(in_deg.items(), key=lambda x: x[1], reverse=True)[:10]:
#             if count > 0:
#                 print(f"  {module}: imported by {count} modules")
        
#         # Out-degree (imports most)
#         out_deg = dict(G.out_degree())
#         print("\nüì§ Modules with Most Dependencies:")
#         for module, count in sorted(out_deg.items(), key=lambda x: x[1], reverse=True)[:10]:
#             if count > 0:
#                 print(f"  {module}: imports {count} modules")
        
#         # Betweenness centrality (bridges)
#         try:
#             betweenness = nx.betweenness_centrality(G)
#             print("\nüåâ Bridge Modules (High Betweenness):")
#             for module, score in sorted(betweenness.items(), key=lambda x: x[1], reverse=True)[:10]:
#                 if score > 0:
#                     print(f"  {module}: {score:.4f}")
#         except:
#             print("\n⚠️  Could not calculate betweenness")
    
#     def find_circular_dependencies(self):
#         """Find circular dependencies."""
#         G = nx.DiGraph()
#         for module in self.all_modules:
#             for imp in self.deps[module]['imports']:
#                 if imp in self.all_modules:
#                     G.add_edge(module, imp)
        
#         print("\n" + "="*70)
#         print("üîÑ CIRCULAR DEPENDENCIES")
#         print("="*70)
        
#         try:
#             cycles = list(nx.simple_cycles(G))
#             if cycles:
#                 print(f"\n⚠️  Found {len(cycles)} circular dependencies:\n")
#                 for i, cycle in enumerate(cycles[:10], 1):
#                     cycle_str = ' → '.join(cycle + [cycle[0]])
#                     print(f"  {i}. {cycle_str}")
#                 if len(cycles) > 10:
#                     print(f"\n  ... and {len(cycles)-10} more")
#             else:
#                 print("\n✅ No circular dependencies found!")
#         except Exception as e:
#             print(f"\n⚠️  Error checking cycles: {e}")

# # Cell 2: Create analyzer
# analyzer = ProjectDependencyAnalyzer('.')

# # Cell 3: Show summary
# analyzer.show_summary()

# # Cell 4: Visualize full project
# G = analyzer.visualize_full_project(figsize=(20, 16))

# # Cell 5: Show dependency matrix
# analyzer.show_dependency_matrix()

# # Cell 6: Analyze importance
# analyzer.analyze_module_importance()

# # Cell 7: Check for circular dependencies
# analyzer.find_circular_dependencies()