# Overview of the CrystFEL-based Processing Workflow

This notebook implements a complete workflow for processing crystallography data using CrystFEL tools alongside custom scripts. It is designed to:
  
1. **Run Indexamajig** (`gandalf_iterator`)  
   - Execute peak finding, indexing, and integration for each HDF5 file.
   - Vary beam center coordinates on a radial grid (using defined maximum radius and step size).

2. **Visualize Indexing Results**  
   - Generate 3D histograms and 2D heatmaps to assess indexing performance.

3. **Evaluate Indexing Metrics** (`process_indexing_metrics`)  
   - Parse stream files to compute indexing quality metrics (IQMs) for each frame.
   - Analyze metrics such as weighted RMSD, fraction of outliers, length and angle deviations, peak ratio, and percentage of unindexed peaks.

4. **Interactive IQM Dashboard**  
   - Use interactive sliders and weight inputs to filter and combine metrics.
   - Create/update a combined metric and display histograms for filtered results.

5. **CSV-to-Stream Conversion**  
   - Convert the filtered metrics CSV into a stream file using custom conversion scripts.

6. **Merge Best Results** (`merge`)  
   - Merge the best-result stream file with partialator in the chosen point group.

7. **Convert to SHELX-Compatible .hkl**  
   - Transform the merged output to a SHELX-compatible format for further crystallographic analysis.

8. **Convert to .mtz Format**  
   - Prepare data for downstream crystallographic software by converting .hkl to .mtz.

> **Prerequisites:**  
> Ensure that all preprocessing steps (peak finding, center refinement, etc.) have been completed and that the required tools and Python packages (CrystFEL, ipywidgets, matplotlib, etc.) are installed and properly configured before running the notebook.




# ==============================================
# Run indexamajig Iterations on a Circular Grid
### Options for Peakfinding, Indexing and Integration
# ==============================================


In [None]:
from gandalf_radial_iterator import gandalf_iterator

geomfile_path = "/home/bubl3932/files/UOX1/UOX.geom"       # .geom file
cellfile_path = "/home/bubl3932/files/UOX1/UOX.cell"       # .cell file

# Folder containing the input .h5 files. Results are written into a subfolder
# of it (e.g. "xgandalf_iterations_max_radius_1_step_0.5").
input_path = "/home/bubl3932/files/UOX1/UOX1_min_15_peak/icf/uox1_min15p_icf_centershifted"

# Prefix for the per-grid-point output files (named output_file_base_xcoord_ycoord).
# indexamajig writes its results as .stream files.
output_file_base = "UOX"

num_threads = 24             # number of CPU threads to use

"""Define the grid and maximum radius in pixels for the iterations.
For example, max_radius = 1 and step = 0.2 give 81 iterations
(every grid point that lies within the circle of radius max_radius).
Iterations start at the center and move radially outwards.
"""
max_radius = 1             # maximum radius in pixels
step = 0.5                 # grid granularity in pixels

extra_flags=[
# PEAKFINDING
"--no-revalidate",
"--no-half-pixel-shift",
"--peaks=cxi", 
"--min-peaks=15",
# INDEXING
"--indexing=xgandalf",
"--tolerance=10,10,10,5",
"--no-refine",
"--xgandalf-sampling-pitch=5",
"--xgandalf-grad-desc-iterations=1",
"--xgandalf-tolerance=0.02",
# INTEGRATION
"--integration=rings",
"--int-radius=2,5,10",
# "--fix-profile-radius=70000000",
# OUTPUT
"--no-non-hits-in-stream",
]

"""Examples of extra flags (see CrystFEL documentation https://www.desy.de/~twhite/crystfel/manual-indexamajig.html).
Each entry is a separate list element, so keep the trailing commas when copying.

    Peakfinding
    "--peaks=cxi",
    "--peak-radius=inner,middle,outer",
    "--min-peaks=n",
    "--median-filter=n",
    "--filter-noise",
    "--no-revalidate",
    "--no-half-pixel-shift",

    "--peaks=peakfinder9",
    "--min-snr=1",
    "--min-snr-peak-pix=6",
    "--min-snr-biggest-pix=1",
    "--min-sig=9",
    "--min-peak-over-neighbour=5",
    "--local-bg-radius=5",

    "--peaks=peakfinder8",
    "--threshold=45",
    "--min-snr=3",
    "--min-pix-count=3",
    "--max-pix-count=500",
    "--local-bg-radius=9",
    "--min-res=30",
    "--max-res=500",
    
    Indexing
    "--indexing=xgandalf",

    "--tolerance=tol",
    "--no-check-cell",
    "--no-check-peaks",
    "--multi",
    "--no-retry",
    "--no-refine",

    "--xgandalf-sampling-pitch=n",
    "--xgandalf-grad-desc-iterations=n",
    "--xgandalf-tolerance=n",
    "--xgandalf-no-deviation-from-provided-cell",
    "--xgandalf-max-lattice-vector-length=n",
    "--xgandalf-min-lattice-vector-length=n",
    "--xgandalf-max-peaks=n",

    Integration
    "--fix-profile-radius=n",
    "--integration=rings",
    "--int-radius=4,5,10",
    "--push-res=n",
    "--overpredict",

    Output
    "--no-non-hits-in-stream",
    "--no-peaks-in-stream",
    "--no-refls-in-stream",
"""

gandalf_iterator(
    geomfile_path,
    cellfile_path,
    input_path,
    output_file_base,
    num_threads,
    max_radius=max_radius,
    step=step,
    extra_flags=extra_flags,
)


# Visualize Indexing Results: 3D Histogram & 2D Heatmap

In [None]:
%matplotlib qt
from indexing_3d_histogram import plot3d_indexing_rate
from indexing_center import indexing_heatmap

output_folder = "/home/bubl3932/files/UOX1/UOX1_min_15_peak/icf/uox1_min15p_icf_centershifted/xgandalf_iterations_max_radius_1_step_0.5"
plot3d_indexing_rate(output_folder)
indexing_heatmap(output_folder)


# ==============================================
# Process Indexing Metrics Across All Stream Files
# ==============================================

In [None]:
from process_indexing_metrics import process_indexing_metrics

# Enter the folder with stream-file results from indexamajig.
# Note that ALL stream files in the folder will be processed.

output_folder = "/home/bubl3932/files/UOX1/UOX1_min_15_peak/icf/uox1_min15p_icf_centershifted/xgandalf_iterations_max_radius_1_step_0.5"
stream_file_folder = output_folder
wrmsd_tolerance = 2.0
indexing_tolerance = 4.0

"""
wrmsd_tolerance :
The number of standard deviations away from the mean weighted RMSD for a chunk to be considered an outlier. Default factor is 2.0.

indexing_tolerance :
The maximum deviation in pixels between observed and predicted peak positions for a peak to be considered indexed. Default is 1.0 pixel (this cell uses 4.0).

The following metrics will be evaluated for analysis in the next step:

- 'weighted_rmsd'
- 'fraction_outliers'
- 'length_deviation'
- 'angle_deviation'
- 'peak_ratio'
- 'percentage_unindexed'   (column name read by the interactive dashboard; verify against the CSV if in doubt)

"""

process_indexing_metrics(stream_file_folder, wrmsd_tolerance=wrmsd_tolerance, indexing_tolerance=indexing_tolerance)

# ==============================================
# Interactive Metrics Analysis and CSV-to-Stream Conversion
# ==============================================

In [None]:
# Interactive metric analysis tool

%matplotlib qt

import ipywidgets as widgets
from IPython.display import display, Markdown
import matplotlib.pyplot as plt
import os
import time  # Only if you want to show progress bar updates slowly

# 1) Attempt to import your custom modules. If there's an error,
#    we'll catch it and display a message in an output widget.
import_failed = False
import_error_msg = ""

module_import_out = widgets.Output(layout={'border': '1px solid black', 'padding': '5px'})
with module_import_out:
    try:
        import csv_to_stream
        from interactive_iqm import (
            read_metric_csv,
            select_best_results_by_event,
            get_metric_ranges,
            create_combined_metric,
            filter_rows,
            write_filtered_csv
        )
        print("Successfully imported 'csv_to_stream' and 'interactive_iqm' modules.")
    except Exception as e:
        import_failed = True
        import_error_msg = str(e)
        print("Error importing modules:", e)

# This Output widget will display all feedback (printed messages, progress bars, etc.)
feedback_out = widgets.Output(layout={
    'border': '1px solid black',
    'height': '300px',
    'overflow_y': 'auto',
    'padding': '5px'
})

if not import_failed:
    ########################################################
    # 2) PATHS & Data Loading
    ########################################################
    CSV_PATH = "/home/bubl3932/files/UOX1/UOX1_min_15_peak/icf/uox1_min15p_icf_centershifted/xgandalf_iterations_max_radius_1_step_0.5/normalized_metrics.csv"
    # Written by the combined-metric filter, read back by "Convert to Stream".
    FILTERED_CSV_PATH = os.path.join(os.path.dirname(CSV_PATH), 'filtered_metrics.csv')

    # Rows grouped by event, plus one flat list of every row for filtering.
    grouped_data = read_metric_csv(CSV_PATH, group_by_event=True)
    all_rows = [row for rows in grouped_data.values() for row in rows]

    # Metric columns shown in the dashboard (one slider and one weight field each).
    metrics_in_order = [
        'weighted_rmsd',
        'fraction_outliers',
        'length_deviation',
        'angle_deviation',
        'peak_ratio',
        'percentage_unindexed'
    ]

    ########################################################
    # 3) SECTION 1: Separate Metrics Filtering
    ########################################################
    ranges_dict = get_metric_ranges(all_rows, metrics=metrics_in_order)
    metric_sliders = {}

    def create_slider(metric_name, min_val, max_val):
        """Return a '<metric> ≤' FloatSlider spanning [min_val, max_val], initially at max_val."""
        default_val = max_val  # default is max to "include all"
        # 100 steps across the range; fall back to 0.01 for a degenerate (constant) range.
        step = (max_val - min_val) / 100.0 if max_val != min_val else 0.01
        slider = widgets.FloatSlider(
            value=default_val,
            min=min_val,
            max=max_val,
            step=step,
            description=f"{metric_name} ≤",
            layout=widgets.Layout(width='95%')  # wide enough for 2 columns
        )
        return slider

    for metric in metrics_in_order:
        mn, mx = ranges_dict[metric]
        metric_sliders[metric] = create_slider(metric, mn, mx)

    # Arrange the sliders in a 2-column grid
    slider_box = widgets.GridBox(
        children=[metric_sliders[m] for m in metrics_in_order],
        layout=widgets.Layout(
            grid_template_columns="repeat(2, 300px)",
            grid_gap="10px 20px"
        )
    )

    filter_separate_button = widgets.Button(
        description="Apply Separate Metrics Thresholds",
        button_style='info'  
    )

    @feedback_out.capture(clear_output=False)  # Capture prints in feedback_out
    def on_filter_separate_clicked(_):
        """Keep rows under every per-metric slider threshold and histogram each metric."""
        print("\n" + "="*50)
        print("SEPARATE METRICS FILTERING")
        print("="*50)
        
        thresholds = {m: metric_sliders[m].value for m in metrics_in_order}
        filtered_separate = filter_rows(all_rows, thresholds)

        print(f"Filtering: {len(all_rows)} total rows -> {len(filtered_separate)} pass thresholds.")
        if not filtered_separate:
            print("No rows passed the thresholds.")
            return

        # 2 columns x 3 rows => 6 subplots for 6 metrics
        fig, axes = plt.subplots(3, 2, figsize=(12, 12))
        axes = axes.flatten()
        for i, metric in enumerate(metrics_in_order):
            values = [r[metric] for r in filtered_separate if metric in r]
            axes[i].hist(values, bins=20)
            axes[i].set_title(f"Histogram of {metric}")
            axes[i].set_xlabel(metric)
            axes[i].set_ylabel("Count")
        plt.tight_layout()
        plt.show()

    filter_separate_button.on_click(on_filter_separate_clicked)

    ########################################################
    # 4) SECTION 2: Combined Metric Creation & Filtering (Best Rows)
    ########################################################
    weight_text_fields = {}
    for metric in metrics_in_order:
        weight_text_fields[metric] = widgets.FloatText(
            value=0.0,
            description=f"{metric}",
            style={"description_width": "60px"},
            layout=widgets.Layout(width='150px')
        )

    # Arrange weight fields in a 2-column grid
    weights_box = widgets.GridBox(
        children=[weight_text_fields[m] for m in metrics_in_order],
        layout=widgets.Layout(
            grid_template_columns="repeat(2, 200px)",
            grid_gap="10px 20px"
        )
    )

    # Placeholder range; rescaled to the real combined-metric range once it is created.
    combined_metric_slider = widgets.FloatSlider(
        value=0.0,
        min=0.0,
        max=1.0,
        step=0.01,
        description="threshold ≤",
        layout=widgets.Layout(width='300px')  # narrower slider
    )

    create_combined_button = widgets.Button(
        description="Create Combined Metric",
        button_style='primary'
    )

    @feedback_out.capture(clear_output=False)
    def create_or_update_combined_metric(_):
        """Compute 'combined_metric' on all_rows from the weights and rescale the threshold slider."""
        print("\n" + "="*50)
        print("COMBINED METRIC CREATION")
        print("="*50)

        selected_metrics = []
        weights_list = []
        for m in metrics_in_order:
            w = weight_text_fields[m].value
            selected_metrics.append(m)
            weights_list.append(w)

        # Compute combined_metric for all_rows
        create_combined_metric(
            rows=all_rows,
            metrics_to_combine=selected_metrics,
            weights=weights_list,
            new_metric_name="combined_metric"
        )

        combined_vals = [r["combined_metric"] for r in all_rows if "combined_metric" in r]
        if combined_vals:
            cmin, cmax = min(combined_vals), max(combined_vals)
            current_val = combined_metric_slider.value
            if current_val < cmin or current_val > cmax:
                current_val = cmax

            # Defer validation so min/max/value can change together without a
            # transient "min > max" or out-of-range error.
            with combined_metric_slider.hold_trait_notifications():
                combined_metric_slider.min = cmin
                combined_metric_slider.max = cmax
                combined_metric_slider.value = current_val

            print("Combined metric created successfully!")
            print(f"  * Min value: {cmin:.3f}")
            print(f"  * Max value: {cmax:.3f}")
            print("Adjust the slider below and click 'Apply Combined Metric Threshold (Best Rows)' to filter.")
        else:
            print("Failed to create combined metric. Check your weights.")

    create_combined_button.on_click(create_or_update_combined_metric)

    filter_combined_button = widgets.Button(
        description="Apply Combined Metric Threshold (Best Rows)",
        button_style='info'
    )
    
    @feedback_out.capture(clear_output=False)
    def on_filter_combined_clicked(_):
        """Filter by combined_metric, keep the best row per event, write FILTERED_CSV_PATH and plot it."""
        print("\n" + "="*50)
        print("COMBINED METRIC FILTERING")
        print("="*50)

        # Without this check the user would be told "No rows passed the combined
        # metric threshold" when the metric was simply never computed.
        if not any("combined_metric" in r for r in all_rows):
            print("No combined metric available yet. Click 'Create Combined Metric' first.")
            return

        threshold = combined_metric_slider.value
        filtered_combined = [r for r in all_rows if "combined_metric" in r and r["combined_metric"] <= threshold]

        print(f"Filtering rows by combined_metric ≤ {threshold:.3f}")
        if not filtered_combined:
            print("No rows passed the combined metric threshold.")
            return

        # Group the filtered rows by event
        grouped_filtered = {}
        for r in filtered_combined:
            grouped_filtered.setdefault(r.get("event_number"), []).append(r)

        best_filtered = select_best_results_by_event(grouped_filtered, sort_metric="combined_metric")

        print(f"{len(filtered_combined)} rows passed threshold, {len(best_filtered)} best rows selected per event.")

        # WRITE THE FILTERED CSV HERE:
        write_filtered_csv(best_filtered, FILTERED_CSV_PATH)
        print(f"Wrote {len(best_filtered)} best-filtered rows to {FILTERED_CSV_PATH}")

        plt.figure(figsize=(8, 6))
        values = [r["combined_metric"] for r in best_filtered]
        plt.hist(values, bins=20)
        plt.title("Histogram of Best Rows (combined_metric)")
        plt.xlabel("combined_metric")
        plt.ylabel("Count")
        plt.tight_layout()
        plt.show()

    filter_combined_button.on_click(on_filter_combined_clicked)

    ########################################################
    # 5) Convert to Stream (with progress bar & multi-line feedback)
    ########################################################
    convert_button = widgets.Button(
        description="Convert to Stream",
        button_style='success'
    )

    @feedback_out.capture(clear_output=False)
    def on_convert_clicked(_):
        """Convert FILTERED_CSV_PATH into 'filtered_metrics.stream' in the same folder."""
        print("\n" + "="*50)
        print("CONVERT TO STREAM")
        print("="*50)

        # The filtered CSV only exists once the combined threshold has been applied.
        if not os.path.isfile(FILTERED_CSV_PATH):
            print(f"Filtered CSV not found: {FILTERED_CSV_PATH}")
            print("Click 'Apply Combined Metric Threshold (Best Rows)' first.")
            return

        # Create a simple progress bar with 5 steps, plus final completion
        pb = widgets.IntProgress(
            value=0,
            min=0,
            max=5,
            step=1,
            description='Converting...',
            bar_style=''
        )
        display(pb)  # Show the progress bar inside feedback_out

        OUTPUT_STREAM_PATH = os.path.join(os.path.dirname(FILTERED_CSV_PATH), 'filtered_metrics.stream')

        print("\nStarting conversion...\n")
        time.sleep(0.2)  # short delay so you can see the progress step

        # Step 1: read filtered CSV
        pb.value = 1
        print("  * Step 1/5: Reading filtered CSV file...")
        filtered_grouped_data = read_metric_csv(FILTERED_CSV_PATH, group_by_event=True)
        time.sleep(0.2)

        # An empty CSV would otherwise make next(iter(...)) below raise StopIteration.
        if not filtered_grouped_data:
            pb.bar_style = 'danger'
            print("      - Filtered CSV contains no rows; nothing to convert.")
            return

        # Step 2: check if combined metric exists
        pb.value = 2
        print("  * Step 2/5: Checking for combined metric...")
        first_event = next(iter(filtered_grouped_data.values()))
        time.sleep(0.2)

        # Step 3: if combined_metric present, select best row and re-write CSV
        pb.value = 3
        print("  * Step 3/5: Picking best row per event (only if combined metric is present)...")
        if first_event and "combined_metric" in first_event[0]:
            best_filtered = select_best_results_by_event(filtered_grouped_data, sort_metric="combined_metric")
            write_filtered_csv(best_filtered, FILTERED_CSV_PATH)
            print("      - Best rows selected & CSV overwritten.")
        else:
            print("      - No combined_metric found, skipping best-row selection.")
        time.sleep(0.2)

        # Step 4: write the .stream file
        pb.value = 4
        print("  * Step 4/5: Writing the .stream file...")
        csv_to_stream.write_stream_from_filtered_csv(
            filtered_csv_path=FILTERED_CSV_PATH,
            output_stream_path=OUTPUT_STREAM_PATH,
            event_col="event_number",
            streamfile_col="stream_file"
        )
        time.sleep(0.2)

        # Step 5: done
        pb.value = 5
        pb.bar_style = 'success'
        print("  * Step 5/5: Conversion complete!\n")
        print(f"CSV has been successfully converted to:\n  {OUTPUT_STREAM_PATH}")

    convert_button.on_click(on_convert_clicked)

    ########################################################
    # 6) Lay Out All Widgets
    ########################################################

    # --- SECTION 1: Separate metrics filtering ---
    separate_control_panel = widgets.VBox([
        widgets.HTML("<h3>Separate Metrics Filtering</h3>"),
        slider_box,             # The 2-column grid of sliders
        filter_separate_button
    ])

    # --- SECTION 2: Combined metrics creation & filtering ---
    combined_control_panel = widgets.VBox([
        widgets.HTML("<h3>Combined Metric Creation & Filtering</h3>"),
        widgets.HTML("Enter weights for each metric:"),
        weights_box,            # The 2-column grid of weight text fields
        widgets.HTML("<hr style='margin:10px 0;'>"),
        create_combined_button,
        combined_metric_slider,
        widgets.HTML("<hr style='margin:10px 0;'>"),
        widgets.HBox([filter_combined_button, convert_button])
    ])

    final_layout = widgets.VBox([
        module_import_out,
        separate_control_panel,
        combined_control_panel,
        widgets.HTML("<h3>Feedback & Logs</h3>"),
        feedback_out  # The big output area for all text messages
    ])
    display(final_layout)

else:
    # If the imports failed, just display an error message.
    display(module_import_out)
    display(Markdown(f"**Could not load your custom modules.**\n```\n{import_error_msg}\n```"))
    display(Markdown("Please fix the import error above, then re-run the cell."))


VBox(children=(Output(layout=Layout(border_bottom='1px solid black', border_left='1px solid black', border_rig…

# ==============================================
# Interactive Merging, SHELX Conversion, and MTZ Conversion
# ==============================================

In [None]:
# Interactive Merging and Conversion tool

import ipywidgets as widgets
from IPython.display import display
import os
import time  # short sleeps so the progress-bar steps are visible

# Track import problems for both the file-chooser widget and the custom
# modules, so the UI below is only built when everything it uses is available.
import_failed = False
import_error_msg = ""

# Try to import file chooser widget
try:
    from ipyfilechooser import FileChooser
except ImportError as e:
    # FileChooser is used unconditionally when building the UI; without it the
    # cell would crash with NameError, so record it as an import failure.
    import_failed = True
    import_error_msg = f"{e} (install via: pip install ipyfilechooser)"
    print("ipyfilechooser not found. Install via: pip install ipyfilechooser")

# Try to import your custom modules
modules_import_out = widgets.Output(layout={'border': '1px solid black', 'padding': '5px'})
with modules_import_out:
    try:
        from merge import merge
        from convert_hkl_crystfel_to_shelx import convert_hkl_crystfel_to_shelx 
        from convert_hkl_to_mtz import convert_hkl_to_mtz
        print("Successfully imported merge, convert_hkl_crystfel_to_shelx, and convert_hkl_to_mtz modules.")
    except Exception as e:
        import_failed = True
        # Keep any earlier (file-chooser) error message as well.
        import_error_msg = "\n".join(msg for msg in (import_error_msg, str(e)) if msg)
        print("Error importing modules:", e)

# A single output widget to capture feedback from all operations.
# NOTE(review): ipywidgets 8 has no 'overflow_y' Layout property (the key is
# ignored), so 'overflow' is used to keep the fixed-height log scrollable.
feedback_out = widgets.Output(
    layout={
        'border': '1px solid black',
        'height': '300px',
        'overflow': 'auto',
        'padding': '5px'
    }
)

global_output_dir = None  # Merged output directory, set by a successful merge

if not import_failed:
    #################################
    # 1) Merging Section
    #################################
    # File chooser for selecting the stream file
    stream_file_chooser = FileChooser(os.getcwd())
    stream_file_chooser.title = 'Select .stream File'
    stream_file_chooser.filter_pattern = '*.stream'  # Only show .stream files

    pointgroup_widget = widgets.Text(
        value="",
        description="Pointgroup:",
        style={"description_width": "150px"}
    )
    num_threads_widget = widgets.IntText(
        value=24,
        description="Num Threads:",
        style={"description_width": "150px"}
    )
    iterations_widget = widgets.IntText(
        value=5,
        description="Iterations:",
        style={"description_width": "150px"}
    )

    merge_button = widgets.Button(
        description="Merge",
        button_style='warning'
    )

    @feedback_out.capture(clear_output=False)
    def on_merge_clicked(b):
        """Handles merging with Partialator, prints feedback & uses a simple progress bar.

        On success the result directory is stored in ``global_output_dir`` for the
        SHELX/MTZ conversion buttons. The progress bar turns red if ``merge``
        raises or returns None, and green on success.
        """
        global global_output_dir

        print("\n" + "="*50)
        print("MERGING SECTION")
        print("="*50)

        stream_file = stream_file_chooser.selected
        pointgroup = pointgroup_widget.value
        num_threads = num_threads_widget.value
        iterations = iterations_widget.value

        if not stream_file:
            print("Please select a .stream file first.")
            return

        # Optional progress bar for merging
        pb_merge = widgets.IntProgress(
            value=0,
            min=0,
            max=3,
            step=1,
            description='Merging...',
            bar_style=''
        )
        display(pb_merge)

        print("Merging in progress...")
        pb_merge.value = 1
        time.sleep(0.2)  # Delay so you can see the progress bar

        try:
            output_dir = merge(
                stream_file,
                pointgroup=pointgroup,
                num_threads=num_threads,
                iterations=iterations,
            )
        except Exception:
            # Flag the failure on the bar; re-raise so the traceback still
            # appears in the captured feedback area.
            pb_merge.bar_style = 'danger'
            raise
        pb_merge.value = 2
        time.sleep(0.2)

        if output_dir is not None:
            print("Merging done. Results are in:", output_dir)
            global_output_dir = output_dir
            pb_merge.bar_style = 'success'
        else:
            print("Merging failed. Please check the parameters and try again.")
            pb_merge.bar_style = 'danger'

        pb_merge.value = 3
        print("Done merging.")

    merge_button.on_click(on_merge_clicked)

    merge_controls = widgets.VBox([
        widgets.HTML("<h3>Merging Parameters</h3>"),
        stream_file_chooser,
        pointgroup_widget,
        num_threads_widget,
        iterations_widget,
        merge_button
    ])

    #################################
    # 2) SHELX Conversion Section
    #################################
    shelx_button = widgets.Button(
        description="Convert to SHELX",
        button_style='primary'
    )

    @feedback_out.capture(clear_output=False)
    def on_shelx_clicked(b):
        """Convert the most recent successful merge output to SHELX format."""
        print("\n" + "="*50)
        print("SHELX CONVERSION")
        print("="*50)

        if global_output_dir is None:
            print("No merged output available. Please run the merge step first.")
            return

        print("Converting to SHELX...")
        convert_hkl_crystfel_to_shelx(global_output_dir)
        print("Conversion to SHELX completed.")

    shelx_button.on_click(on_shelx_clicked)

    shelx_controls = widgets.VBox([
        widgets.HTML("<h3>SHELX Conversion</h3>"),
        shelx_button
    ])

    #################################
    # 3) MTZ Conversion Section
    #################################
    cell_file_chooser = FileChooser(os.getcwd())
    cell_file_chooser.title = 'Select Cell File'

    mtz_button = widgets.Button(
        description="Convert to MTZ",
        button_style='success'
    )

    @feedback_out.capture(clear_output=False)
    def on_mtz_clicked(b):
        """Convert the most recent successful merge output to MTZ using the chosen cell file."""
        print("\n" + "="*50)
        print("MTZ CONVERSION")
        print("="*50)

        if global_output_dir is None:
            print("No merged output available. Please run the merge step first.")
            return

        cellfile_path = cell_file_chooser.selected
        if not cellfile_path:
            print("Please select a cell file first.")
            return

        print("Converting to MTZ...")
        convert_hkl_to_mtz(global_output_dir, cellfile_path=cellfile_path)
        print("Conversion to MTZ completed.")

    mtz_button.on_click(on_mtz_clicked)

    mtz_controls = widgets.VBox([
        widgets.HTML("<h3>MTZ Conversion</h3>"),
        cell_file_chooser,
        mtz_button
    ])

    #################################
    # 4) Display All Controls
    #################################
    # We'll show:
    #  - Our modules import output
    #  - The three sections of controls
    #  - The feedback_out area for logs
    controls_layout = widgets.VBox([
        modules_import_out,
        widgets.HTML("<h2>Interactive Merging & Conversion Tool</h2>"),
        merge_controls,
        shelx_controls,
        mtz_controls,
        widgets.HTML("<h3>Feedback & Logs</h3>"),
        feedback_out
    ])

    display(controls_layout)

else:
    # If the imports failed, just display an error message.
    display(modules_import_out)
    print("Could not load your modules:")
    print(import_error_msg)


VBox(children=(Output(layout=Layout(border_bottom='1px solid black', border_left='1px solid black', border_rig…

# ==============================================
# Refine Merging Results Using REFMAC5
# ==============================================

In [3]:
# Cell: Interactive Refmac5 Refinement + Parsing & Plotting the Rf_used Table
%matplotlib qt

import ipywidgets as widgets
from IPython.display import display, Markdown
import matplotlib.pyplot as plt
import numpy as np
import os
import re
import time

# Try to import file chooser widget
try:
    from ipyfilechooser import FileChooser
except ImportError:
    print("ipyfilechooser is required. Install with: pip install ipyfilechooser")

# Try to import your refinement function
import_failed = False
import_error_msg = ""
refmac_import_out = widgets.Output(layout={'border': '1px solid black', 'padding': '5px'})
with refmac_import_out:
    try:
        from ctruncate_freerflag_refmac5 import ctruncate_freerflag_refmac5  # <-- update module name/path if needed
        print("Successfully imported ctruncate_freerflag_refmac5.")
    except Exception as e:
        import_failed = True
        import_error_msg = str(e)
        print("Error importing ctruncate_freerflag_refmac5:", e)

# Parse the Rf_used-vs-resolution table from a refmac5 log file.
def parse_refmac_log_for_table(log_path):
    """
    Extract resolution and Rf_used values from the last matching table in a refmac5 log.

    The last line containing both "M(4SSQ/LL)" and "Rf_used" is taken as the
    table header. The data rows are the lines between the first line after it
    that is exactly "$$" and the next such line.

    Parameters
    ----------
    log_path : str
        Path to refmac5.log.

    Returns
    -------
    (resolution_list, rf_used_list) : tuple of two lists of float
        resolution_list holds d = sqrt(1 / first column) in Å. The first
        column, M(4SSQ/LL), is treated as 1/d^2.
        rf_used_list holds the 6th column (Rf_used).
        Rows with fewer than 6 fields, non-numeric values, or a non-positive
        first column are skipped. Both lists are empty if the file, the
        header, or the "$$"-delimited data block is missing.
    """
    resolution_list = []
    rf_used_list = []
    if not os.path.isfile(log_path):
        return resolution_list, rf_used_list

    with open(log_path, 'r') as f:
        lines = f.readlines()

    # Keep only the LAST header occurrence so the final table in the log is used.
    header_index = None
    for i, line in enumerate(lines):
        if "M(4SSQ/LL)" in line and "Rf_used" in line:
            header_index = i
    if header_index is None:
        return resolution_list, rf_used_list

    # The header line itself ends with "$$" but is not exactly "$$", so the
    # first exact "$$" after it opens the data block and the second closes it.
    delimiter_indices = [
        j for j in range(header_index, len(lines)) if lines[j].strip() == "$$"
    ]
    if len(delimiter_indices) < 2:
        return resolution_list, rf_used_list
    start_index, end_index = delimiter_indices[0] + 1, delimiter_indices[1]

    for line in lines[start_index:end_index]:
        parts = line.split()
        if len(parts) < 6:
            continue
        try:
            inv_d_squared = float(parts[0])
            rf_used = float(parts[5])
        except ValueError:
            continue
        # Zero would divide by zero, and a negative value would make sqrt return NaN.
        if inv_d_squared <= 0:
            continue
        resolution_list.append(np.sqrt(1.0 / inv_d_squared))
        rf_used_list.append(rf_used)

    return resolution_list, rf_used_list

# Create an output widget for all feedback and logs.
# NOTE(review): ipywidgets 8 has no 'overflow_y' Layout property (the key is
# ignored), so 'overflow' is used to keep the fixed-height log scrollable.
refmac_feedback_out = widgets.Output(layout={
    'border': '1px solid black',
    'height': '350px',
    'overflow': 'auto',
    'padding': '5px'
})

if not import_failed:
    # File choosers for input files
    mtz_file_chooser = FileChooser(os.getcwd())
    mtz_file_chooser.title = 'Select .mtz File'
    mtz_file_chooser.filter_pattern = '*.mtz'
    
    pdb_file_chooser = FileChooser(os.getcwd())
    pdb_file_chooser.title = 'Select .pdb File'
    pdb_file_chooser.filter_pattern = '*.pdb'
    
    # Extra parameter widgets
    max_res_widget = widgets.FloatText(
        value=20.0,
        description="max_res:",
        style={"description_width": "80px"},
        layout=widgets.Layout(width='200px')
    )
    min_res_widget = widgets.FloatText(
        value=1.5,
        description="min_res:",
        style={"description_width": "80px"},
        layout=widgets.Layout(width='200px')
    )
    ncycles_widget = widgets.IntText(
        value=30,
        description="ncycles:",
        style={"description_width": "80px"},
        layout=widgets.Layout(width='200px')
    )
    bins_widget = widgets.IntText(
        value=10,
        description="bins:",
        style={"description_width": "80px"},
        layout=widgets.Layout(width='200px')
    )
    
    extra_params_box = widgets.HBox([max_res_widget, min_res_widget, ncycles_widget, bins_widget])
    
    # Refine button
    refine_button = widgets.Button(
        description="Refine with Refmac5 (and Plot)",
        button_style='info'
    )
    
    @refmac_feedback_out.capture(clear_output=False)
    def on_refine_clicked(b):
        """Run ctruncate/freerflag/refmac5 on the chosen files, then plot Rf_used vs resolution from refmac5.log."""
        print("\n" + "="*50)
        print("REFMAC5 REFINEMENT + TABLE PARSING & PLOTTING")
        print("="*50)
        
        mtz_file = mtz_file_chooser.selected
        pdb_file = pdb_file_chooser.selected
        
        if not mtz_file:
            print("Please select an MTZ file first.")
            return
        if not pdb_file:
            print("Please select a PDB file first.")
            return
        
        max_res = max_res_widget.value
        min_res = min_res_widget.value
        ncycles = ncycles_widget.value
        bins_ = bins_widget.value
        
        print(f"Running refinement with parameters:\n  MTZ: {mtz_file}\n  PDB: {pdb_file}\n  max_res: {max_res}\n  min_res: {min_res}\n  ncycles: {ncycles}\n  bins: {bins_}")
        
        # Run the refinement function; it should return a string output directory
        output_dir = ctruncate_freerflag_refmac5(mtz_file, pdb_file, max_res=max_res, min_res=min_res, ncycles=ncycles, bins=bins_)
        
        print("Refinement completed.")
        if output_dir is None:
            print("No output directory returned. Cannot locate refmac5.log for plotting.")
            return
        
        log_file_path = os.path.join(output_dir, "refmac5.log")
        if not os.path.isfile(log_file_path):
            print(f"refmac5.log not found at {log_file_path}. Skipping plot.")
            return
        
        resolution_list, rf_used_list = parse_refmac_log_for_table(log_file_path)
        if not resolution_list:
            print("No valid table found in refmac5.log, or columns didn't parse. Skipping plot.")
            return
        
        # Sort the data by resolution (ascending) so the line connects neighbouring shells
        sorted_pairs = sorted(zip(resolution_list, rf_used_list), key=lambda x: x[0])
        sorted_res, sorted_rf = zip(*sorted_pairs)
        
        plt.figure(figsize=(6, 4))
        plt.plot(sorted_res, sorted_rf, marker='o', linestyle='-')
        plt.xlabel("Resolution (Å)")
        plt.ylabel("Rf_used")
        plt.title("Rf_used vs. Resolution")
        plt.grid(True)
        # Invert x so d-spacing decreases left to right: low resolution (large Å)
        # appears on the left and high resolution (small Å) on the right.
        plt.gca().invert_xaxis()
        plt.tight_layout()
        plt.show()
    
    refine_button.on_click(on_refine_clicked)
    
    refine_controls = widgets.VBox([
        widgets.HTML("<h3>Refmac5 Refinement (with Table Parsing & Plot)</h3>"),
        mtz_file_chooser,
        pdb_file_chooser,
        widgets.HTML("<h4>Optional Parameters</h4>"),
        extra_params_box,
        refine_button
    ])
    
    final_layout = widgets.VBox([
        refmac_import_out,
        widgets.HTML("<h2>Refmac5 Refinement & Plot Tool</h2>"),
        refine_controls,
        widgets.HTML("<h3>Logs & Feedback</h3>"),
        refmac_feedback_out
    ])
    
    display(final_layout)
else:
    display(refmac_import_out)
    print("Could not load ctruncate_freerflag_refmac5:")
    print(import_error_msg)


VBox(children=(Output(layout=Layout(border_bottom='1px solid black', border_left='1px solid black', border_rig…

# Merging Filtered Stream File Using Partialator (Non-Interactive)

In [None]:
from merge import merge

# Filtered stream file to merge (e.g. the output of the "Convert to Stream" step).
stream_file = "/home/bubl3932/files/UOX1/UOX1_original_IQM/xgandalf_iterations_max_radius_1_step_0.5/filtered_metrics.stream"
pointgroup = "mmm"
num_threads = 24
iterations = 5

output_dir = merge(
    stream_file,
    pointgroup=pointgroup,
    num_threads=num_threads,
    iterations=iterations,
)

if output_dir is not None:
    print("Merging done. Results are in:", output_dir)
else:
    # merge() returns None on failure. The conversion cells below need output_dir,
    # so report the failure here instead of failing silently.
    print("Merging failed (no output directory returned). Please check the parameters and try again.")

# Convert to SHELX-Compatible .hkl (Non-Interactive)

In [None]:
from convert_hkl_crystfel_to_shelx import convert_hkl_crystfel_to_shelx 
# Uses output_dir from the merge cell above. If that cell was not run,
# uncomment the next line and set the merged output directory manually.
# output_dir = ""
if output_dir is None:
    # The merge cell sets output_dir to None when merging fails.
    raise ValueError("output_dir is None: the merge step produced no output directory.")
convert_hkl_crystfel_to_shelx(output_dir)

# Convert to mtz (Non-Interactive)

In [None]:
from convert_hkl_to_mtz import convert_hkl_to_mtz
# Uses output_dir from the merge cell above. If that cell was not run,
# uncomment the next line and set the merged output directory manually.
# output_dir = ""
# Cell file for the MTZ conversion; this overrides any cellfile_path set earlier.
cellfile_path = "/home/bubl3932/files/UOX1/UOX1_original/UOX.cell"
if output_dir is None:
    # The merge cell sets output_dir to None when merging fails.
    raise ValueError("output_dir is None: the merge step produced no output directory.")
convert_hkl_to_mtz(output_dir, cellfile_path=cellfile_path)