In [1]:
# Built-in imports:
from typing import List
from itertools import islice

# Import GravyFlow:
import gravyflow as gf

# Dependency imports: 
from bokeh.io import show, output_notebook
from bokeh.layouts import gridplot, column

## Obtaining Transient Events

To acquire data from specific gravitational wave events (transients), use `gf.TransientObtainer`. This works similarly to `gf.NoiseObtainer` but is specifically designed for acquiring data around known event times.

### TransientObtainer

**Parameters:**

- `ifo_data_obtainer` : `gf.IFODataObtainer` (**required**):
  > The IFODataObtainer configured for transient acquisition. Unlike `NoiseObtainer`, this parameter is mandatory. The `data_labels` should include `gf.DataLabel.EVENTS` or `gf.DataLabel.GLITCHES` (not `gf.DataLabel.NOISE`).

- `ifos` : Union[`gf.IFO`, List[`gf.IFO`]] = `[gf.IFO.L1]`:
  > List of interferometers to acquire data from.

- `event_names` : Union[str, List[str]] = None:
  > Optional event name(s) to fetch (e.g., `"GW150914"` or `["GW150914", "GW170817"]`). If set, only data for these specific events will be returned, superseding the default behavior of returning all events. Event names must match those in GWTC catalogs.

- `event_types` : `List[gf.EventType]` = `[gf.EventType.CONFIDENT]`:
  > Filter by event confidence.
  > **Options:**
  > - `gf.EventType.CONFIDENT`: Confirmed detections (Default).
  > - `gf.EventType.MARGINAL`: Marginal triggers/candidates.

- `data_labels` : List[`gf.DataLabel`] = `[gf.DataLabel.EVENTS]`:
  > Specifies which transient types to include. Must NOT include `gf.DataLabel.NOISE` (raises `ValueError`). For noise acquisition, use `gf.NoiseObtainer` instead.

- `groups` : dict = `{"all": 1.0}`:
  > Group splits for data partitioning. Defaults to a single "all" group (no train/val/test split), which is typical for transient evaluation.

In [None]:
# Diagnostic cell: report which gravyflow installation is loaded and which
# arguments the IFODataObtainer factory accepts.
import inspect

import gravyflow as gf

print(gf.__file__)
print(inspect.signature(gf.IFODataObtainer))


In [None]:
# The new unified IFODataObtainer factory handles everything.
# It returns a TransientDataObtainer when DataLabel.NOISE is not present.
#
# GW150914 was observed in O1 and GW170817 in O2, so the obtainer must not be
# restricted to O3. `observing_runs=None` matches the cell further down that
# actually fetches these two events.
transient_obtainer = gf.IFODataObtainer(
    observing_runs=None,                  # No run restriction; events are selected by name.
    data_quality=gf.DataQuality.BEST,
    data_labels=[gf.DataLabel.EVENTS],
    force_acquisition=True,               # Force the acquisition of new data.
    cache_segments=False,                 # Choose not to cache the segments.
    event_names=["GW150914", "GW170817"]  # Event names are passed directly to the factory.
)

## Searching for Events

GravyFlow provides a powerful `search_events` function to filter gravitational wave events from GWTC catalogs based on astrophysical properties, observing runs, and more.

**Parameters:**

- `source_type` : `Union[gf.SourceType, str]` = `None`:
  > Filter by astrophysical source type. 
  > **Enums (Recommended):**
  > - `gf.SourceType.BBH`: Binary Black Hole (both masses ≥ 3 M☉)
  > - `gf.SourceType.BNS`: Binary Neutron Star (both masses < 3 M☉)
  > - `gf.SourceType.NSBH`: Neutron Star - Black Hole (one < 3 M☉, one ≥ 3 M☉)
  >
  > **Strings (Supported):** `"BBH"`, `"BNS"`, `"NSBH"` (case-insensitive).

- `observing_runs` : `List[gf.ObservingRun]` = `None`:
  > Filter by specific observing runs (e.g., `[gf.ObservingRun.O3]`).

- `mass1_range` : `tuple` = `None`:
  > (min, max) range for primary mass in solar masses. Use `None` for unbounded limits. 
  > *Example:* `(30, None)` finds events with m1 > 30 M☉.

- `mass2_range` : `tuple` = `None`:
  > (min, max) range for secondary mass in solar masses.

- `total_mass_range` : `tuple` = `None`:
  > (min, max) range for total system mass (m1 + m2).

- `distance_range` : `tuple` = `None`:
  > (min, max) range for luminosity distance in Mpc.
  > *Example:* `(None, 500)` finds events closer than 500 Mpc.

- `name_contains` : `str` = `None`:
  > Substring to search for in the event name (case-insensitive).
  > *Example:* `"GW17"` matches all 2017 events.

**Returns:**
- `List[str]`: A list of event names matching all specified conditions.

In [None]:
# Examples of gf.search_events

# 1. Filter by source type (using enums):
#    every Binary Neutron Star event.
bns_events = gf.search_events(source_type=gf.SourceType.BNS)
print(bns_events)
# Output: ['GW170817', 'GW190425']

# 2. Filter by observing run:
#    every Binary Black Hole in O3.
o3_bbh = gf.search_events(
    observing_runs=[gf.ObservingRun.O3],
    source_type=gf.SourceType.BBH,
)

# 3. Complex physical queries:
#    heavy BBHs (total mass > 80 M☉) that are relatively close (< 1000 Mpc).
heavy_nearby = gf.search_events(
    source_type=gf.SourceType.BBH,
    distance_range=(None, 1000),
    total_mass_range=(80, None),
)

# 4. Search by name:
#    every event from 2017.
events_2017 = gf.search_events(name_contains="GW17")


In [None]:
# Print how many catalogue events fall in each observing run, O1 through O4.
for run_name in ("O1", "O2", "O3", "O4"):
    # Resolve each enum member only when its count is about to be printed.
    observing_run = getattr(gf.ObservingRun, run_name)
    print(len(gf.search_events(observing_runs=[observing_run])))

In [None]:
# Step 1: build the obtainer. The unified factory takes the specific event
# names directly.
event_obtainer = gf.IFODataObtainer(
    observing_runs=None,
    data_quality=gf.DataQuality.BEST,
    data_labels=[gf.DataLabel.EVENTS],
    force_acquisition=True,
    cache_segments=False,
    event_names=["GW150914", "GW170817"],
)

# Step 2: call the obtainer instance to create the generator. The detector
# selection (`ifos`) is passed here rather than to the factory.
event_generator = event_obtainer(
    ifos=[gf.IFO.H1, gf.IFO.L1],
    scale_factor=1,
    whiten=True,
    crop=True,
)

# Step 3: draw one batch and pull the arrays of interest out of the returned dict.
batch = next(event_generator)
onsource, offsource, gps_times = (
    batch[gf.ReturnVariables.ONSOURCE],
    batch[gf.ReturnVariables.OFFSOURCE],
    batch[gf.ReturnVariables.TRANSIENT_GPS_TIME],
)

In [None]:
# Plot the onsource strain of both events, one panel per detector.
#
# The batch was requested with ifos=[H1, L1], so the per-detector titles must
# be listed in that same order; listing L1 first would label each panel with
# the wrong detector.
# NOTE(review): assumes the detector axis of `onsource` follows the `ifos`
# order and that the batch follows the `event_names` order — confirm against
# `gps_times` if in doubt.
ifo_names = ("H1", "L1")

gw150914_plot = gf.generate_strain_plot(
    {"Onsource Noise": onsource[0]},
    title=[f"{ifo_name} Onsource GW150914" for ifo_name in ifo_names],
)

gw170817_plot = gf.generate_strain_plot(
    {"Onsource Noise": onsource[1]},
    title=[f"{ifo_name} Onsource GW170817" for ifo_name in ifo_names],
)

# Stack the two event plots vertically and render them inline.
grid = gridplot([[gw150914_plot], [gw170817_plot]])
output_notebook()
show(grid)


## Glitch Acquisition

In [None]:
# Enumerate every member of gf.GlitchType; the count is reused by later cells
# to size a batch with one example per type.
glitch_types = list(gf.GlitchType)
num_glitch_types = len(glitch_types)

print("Available Glitch Types:")
for member in glitch_types:
    print(f"  - {member.name}: '{member.value}'")

print(f"\nTotal: {num_glitch_types} glitch types")

In [None]:
import traceback

# Configure for glitch acquisition.
# IFODataObtainer acts as a factory and returns a TransientDataObtainer
# because valid data_labels (GLITCHES) are provided.
glitch_obtainer = gf.IFODataObtainer(
    data_quality=gf.DataQuality.BEST,
    data_labels=[gf.DataLabel.GLITCHES],
    observing_runs=[gf.ObservingRun.O3],
    saturation=1.0
)

# Create generator - get enough samples for all glitch types.
# `num_glitch_types` comes from the cell that lists gf.GlitchType.
glitch_generator = glitch_obtainer(
    sample_rate_hertz=2048.0,
    onsource_duration_seconds=1.0,
    offsource_duration_seconds=16.0,
    num_examples_per_batch=num_glitch_types,  # One for each type
    ifos=[gf.IFO.L1],     # IFOs must be passed here
    scale_factor=1.0,     # No pre-scaling needed, whitening handles it
    seed=42,
    crop=True,            # Remove padding from onsource
    whiten=True           # Apply whitening
)

# Get batch of glitches
print("\nAcquiring glitches...")
try:
    # Glitch generator returns dict
    batch = next(glitch_generator)
    onsource = batch[gf.ReturnVariables.ONSOURCE]
    offsource = batch[gf.ReturnVariables.OFFSOURCE]
    gps_times = batch[gf.ReturnVariables.TRANSIENT_GPS_TIME]
    # `.get` is used because the glitch-type entry may be absent from the dict.
    label = batch.get(gf.ReturnVariables.GLITCH_TYPE)
    print(f"Acquired {onsource.shape[0]} glitch samples")
    print(f"Onsource shape: {onsource.shape}")
    print(f"Offsource shape: {offsource.shape}")
except Exception as e:
    print(f"Error acquiring glitches: {e}")
    # Print the traceback so the failure can be diagnosed.
    traceback.print_exc()
    # Reset every result, not just `onsource`: otherwise the other names would
    # keep stale arrays from the events cell (or be undefined) and a later
    # cell could silently mix event data with glitch data.
    onsource = None
    offsource = None
    gps_times = None
    label = None

Generating One-of-Each Glitch Plots (Optimized)...
DEBUG: Found 6609 Air_Compressor glitches for L1
DEBUG: Found 25013 Blip glitches for L1
DEBUG: Found 7291 Extremely_Loud glitches for L1
DEBUG: Found 758 Helix glitches for L1
DEBUG: Found 14052 Koi_Fish glitches for L1
DEBUG: Found 905 Light_Modulation glitches for L1
DEBUG: Found 19829 Low_Frequency_Burst glitches for L1
DEBUG: Found 14931 Low_Frequency_Lines glitches for L1
DEBUG: Found 26778 None_of_the_Above glitches for L1
DEBUG: Found 5584 Paired_Doves glitches for L1
DEBUG: Found 2669 Power_Line glitches for L1
DEBUG: Found 2362 Repeating_Blips glitches for L1
DEBUG: Found 89715 Scattered_Light glitches for L1
DEBUG: Found 294 Scratchy glitches for L1
DEBUG: Found 28412 Tomte glitches for L1
DEBUG: Found 2171 Violin_Mode glitches for L1
DEBUG: Found 6596 Whistle glitches for L1


KeyboardInterrupt: 

In [None]:
import numpy as np
import gravyflow as gf
from bokeh.layouts import column
from bokeh.io import output_notebook, show

# Values shared by the cache initialisation, the cache lookup and the plots.
# NOTE(review): the names mirror the keyword arguments used with these same
# values in the glitch acquisition cell; the private helpers below receive
# them positionally — confirm the positional meaning against their signatures.
SAMPLE_RATE_HERTZ = 2048.0
ONSOURCE_DURATION_SECONDS = 1.0
OFFSOURCE_DURATION_SECONDS = 16.0

print("Generating One-of-Each Glitch Plots (Optimized)...")

# 1. Build a single obtainer that serves every glitch type.
obtainer = gf.IFODataObtainer(
    data_quality=gf.DataQuality.BEST,
    data_labels=[gf.DataLabel.GLITCHES],
    observing_runs=[gf.ObservingRun.O3],
    saturation=1.0,
)

# Set up the index and cache by hand instead of iterating the full generator.
index = obtainer.build_feature_index(ifos=[gf.IFO.L1], seed=42)
cache, _ = obtainer._initialize_transient_cache(
    [gf.IFO.L1],
    SAMPLE_RATE_HERTZ,
    ONSOURCE_DURATION_SECONDS,
    OFFSOURCE_DURATION_SECONDS,
)

# 2. Pick the first shuffled segment of each wanted glitch type.
excluded_types = [
    gf.GlitchType.CHIRP,
    gf.GlitchType.NO_GLITCH,
    gf.GlitchType.WANDERING_LINE,
]
target_types = [gt for gt in gf.GlitchType if gt not in excluded_types]

examples = {}
for segment in index.iter(shuffle=True, seed=42):
    kind = segment.kind
    if kind in target_types:
        # Keep only the first segment seen for each type.
        examples.setdefault(kind, segment)
    if len(examples) == len(target_types):
        # Every wanted type has an example; stop scanning the index.
        break

# 3. Fetch each chosen segment from the cache and plot it.
plots = []
for glitch_type, segment in examples.items():
    print(f"Processing {glitch_type.name}...")

    onsource, _, source = obtainer._get_sample_from_cache(
        cache,
        segment.transient_gps_time,
        SAMPLE_RATE_HERTZ,
        ONSOURCE_DURATION_SECONDS,
        OFFSOURCE_DURATION_SECONDS,
        gps_key=segment.gps_key,
        target_ifos=["L1"],
    )

    # Skip types whose sample the lookup returned as None.
    if onsource is None:
        continue

    plots.append(
        gf.generate_strain_plot(
            strain={"L1": np.asarray(onsource[0])},
            sample_rate_hertz=SAMPLE_RATE_HERTZ,
            title=f"{glitch_type.name} (GPS: {segment.transient_gps_time:.1f})",
            has_legend=False,
            height=150, width=800,
        )
    )

# Render only when at least one plot was produced.
if plots:
    layout = column(*plots)
    output_notebook()
    show(layout)

Generating One-of-Each Glitch Plots (Optimized)...


AttributeError: 'TransientDataObtainer' object has no attribute '_build_transient_index'