In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
from tabulate import tabulate
from astropy.time import Time
from tqdm import tqdm
from scipy.spatial.distance import pdist, squareform
import numpy as np

In [3]:
df_vlass = pd.read_pickle('/datax/scratch/ellambishop/new_hits_organized/vlass_other.pkl')
#df_vlass.groupby('tstart').size()
df_vlass.columns


Index(['id', 'beam_id', 'observation_id', 'tuning', 'subband_offset',
       'file_uri', 'file_local_enumeration', 'signal_frequency',
       'signal_index', 'signal_drift_steps', 'signal_drift_rate', 'signal_snr',
       'signal_coarse_channel', 'signal_beam', 'signal_num_timesteps',
       'signal_power', 'signal_incoherent_power', 'source_name', 'fch1_mhz',
       'foff_mhz', 'tstart', 'tsamp', 'ra_hours', 'dec_degrees',
       'telescope_id', 'num_timesteps', 'num_channels', 'coarse_channel',
       'start_channel'],
      dtype='object')

In [4]:
#clean out initial drift rates=0 and snr > 16
def set_filter(file, drift_max=0):
    # Load and slice dataset
    global df_new
    df = file
   # print(df.columns)
    #small_df = df[start:stop]

    # Select relevant columns
    columns = ['file_uri', 'observation_id','source_name', 'beam_id', 'ra_hours', 'dec_degrees', 'tstart',
               'signal_frequency', 'signal_beam', 'signal_drift_rate', 'signal_snr',
               'signal_power', 'signal_incoherent_power', 'signal_num_timesteps']
    df_new = df[columns]

    # Apply filtering thresholds
    df_new = df_new[(df_new['signal_drift_rate'] != drift_max)]
    df_new = df_new[(df_new['signal_snr'])<=16]
    df_new['signal_frequency'] = df_new['signal_frequency'].round(3)    
    
    print(df_new.groupby('signal_frequency').size())
    return df_new

#set_filter(df_vlass)


In [5]:
#rfi filtering based on time step- snr ranges and frequency overlap flagging
def processing(df):
    """
    Processes filtered DataFrame by grouping on (file_uri, observation_id),
    flags RFI based on conditions, and returns the flagged DataFrame.
    """
    # Apply initial filtering
    df_new = set_filter(df)

    # Group by file_uri and observation_id
    grouped = df_new.groupby(['file_uri', 'observation_id'])
    flagged_dfs = []

    for (file_uri, obs_id), group_df in grouped:
        group_df = group_df.copy()
        group_df['rfi_flag'] = False
        group_df['flag_strength'] = 0

        # Conditions to drop hits
        cond1 = (
            (group_df['signal_num_timesteps'] >= 16) &
            (group_df['signal_num_timesteps'] <= 64) &
            (group_df['signal_snr'] > 10)
        )
        cond2 = (
            (group_df['signal_num_timesteps'] < 16) &
            (group_df['signal_snr'] > 15)
        )

        drop_mask = cond1 | cond2
        group_df = group_df.loc[~drop_mask]

        # Flag overlapping frequencies at same tstart
        overlap_indices = []
        for t_value, t_df in group_df.groupby('tstart'):
            freq_counts = t_df['signal_frequency'].value_counts()
            overlapping_freqs = freq_counts[freq_counts > 1].index
            for freq in overlapping_freqs:
                hits = t_df[t_df['signal_frequency'] == freq]
                overlap_indices.extend(hits.index.tolist())

        group_df.loc[overlap_indices, 'rfi_flag'] = True
        group_df.loc[overlap_indices, 'flag_strength'] += 1

        
        flagged_dfs.append(group_df)

    full_df = pd.concat(flagged_dfs, ignore_index=True)

    return full_df
#processing()


In [6]:
full_df= processing(df_vlass)

signal_frequency
2290.822     3
2290.876     1
2290.877     2
2290.878     8
2290.880     3
            ..
3699.913     1
3699.915    30
3699.917     4
3699.919    49
3699.963    24
Length: 21234, dtype: int64


In [7]:
#convert time to dates and flag the different NRAO configurations 
from astropy.time import Time
import datetime

# Extended VLA configuration date ranges to fill gaps
vla_config_ranges = [
    (datetime.date(2023, 1, 19), datetime.date(2023, 5, 29), 'B'),
    (datetime.date(2023, 6, 2),  datetime.date(2023, 6, 19), 'BnA'),
    (datetime.date(2023, 6, 30), datetime.date(2023, 10, 2), 'A'),
    (datetime.date(2023, 10, 20), datetime.date(2024, 1, 15), 'D'),
    (datetime.date(2024, 1, 25), datetime.date(2024, 5, 7), 'C'),   # Gap filled
    (datetime.date(2024, 5, 8),  datetime.date(2024, 9, 16), 'B'),
    (datetime.date(2024, 9, 17), datetime.date(2024, 10, 7), 'BnA'), # Gap filled
    (datetime.date(2024, 10, 8), datetime.date(2025, 2, 3), 'A'),   # Gap filled
]

dates = []
configs = []

for t in full_df['tstart']:
    try:
        time = Time(t, format='mjd')
        date_str = time.to_value('iso', subfmt='date')
        date_obj = datetime.datetime.strptime(date_str, "%Y-%m-%d").date()
        dates.append(date_str)

        matched_config = 'Unknown'
        for start, end, config in vla_config_ranges:
            if start <= date_obj <= end:
                matched_config = config
                break
        configs.append(matched_config)

    except Exception as e:
        print(f"[!] Error parsing tstart={t}: {e}")
        dates.append('Invalid')
        configs.append('Unknown')

full_df['date'] = dates
full_df['config'] = configs

print(full_df['config'].value_counts())

config
B    163298
C     13754
Name: count, dtype: int64


In [8]:
# calculate beam size for each hit to use in beam overlap calculations

band_ranges_mhz = {
    'S':  (2000, 4000),   # 3.0 GHz
    'C':  (4000, 8000),   # 6.0 GHz
    'X':  (8000, 12000),  # 10 GHz
}

# Beam size per band and configuration (arcsec)
beam_resolutions = {
    'S':  {'D': 23,   'C': 7.0,  'B': 2.1,  'A': 0.65},
    'C':  {'D': 12,   'C': 3.5,  'B': 1.0,  'A': 0.33},
    'X':  {'D': 7.2,  'C': 2.1,  'B': 0.6,  'A': 0.20},
}

def get_band_mhz(freq_mhz):
    for band, (low, high) in band_ranges_mhz.items():
        if low <= freq_mhz < high:
            return band
    return None

def lookup_beam_size_mhz(freq_mhz, config):
    band = get_band_mhz(freq_mhz)
    if band and config in beam_resolutions[band]:
        return beam_resolutions[band][config]
    return None  # or np.nan if you're using pandas

full_df['beam_size_arcsec'] = full_df.apply(
    lambda row: lookup_beam_size_mhz(row['signal_frequency'], row['config']), axis=1
)



In [9]:
#flag overlapping signals in spacially separated beams as rfi, same file uri/fov 
def flag_local_rfi(df):
    from scipy.spatial import ConvexHull
    import numpy as np
    from astropy.coordinates import SkyCoord
    import astropy.units as u

    local_flag_count = 0
    unique_freqs = sorted(df['signal_frequency'].unique())

    for freq in unique_freqs:
        freq_subset = df[df['signal_frequency'] == freq]
        for fov in freq_subset['file_uri'].unique():
            subset = freq_subset[freq_subset['file_uri'] == fov]

            if len(subset['signal_beam'].unique()) <= 1 or len(subset) < 3:
                continue

            ra0, dec0 = np.mean(subset['ra_hours']), np.mean(subset['dec_degrees'])
            ra_offsets = (subset['ra_hours'] - ra0) * 3600 * np.cos(np.deg2rad(dec0))
            dec_offsets = (subset['dec_degrees'] - dec0) * 3600
            points = np.vstack([ra_offsets, dec_offsets]).T

            try:
                hull = ConvexHull(points)
                hit_area = hull.area
            except:
                hit_area = 0

            avg_beam_size = subset['beam_size_arcsec'].mean()
            beam_area = np.pi * (avg_beam_size / 2)**2
            threshold_area = 3 * beam_area

            if hit_area > threshold_area:
                df.loc[subset.index, 'rfi_flag_local'] = True
                local_flag_count += len(subset)

    return df, local_flag_count

full_df['rfi_flag_local'] = False
full_df, local_count = flag_local_rfi(full_df)


In [11]:
#global overlap beam frequency rfi 
from scipy.spatial import cKDTree
import numpy as np
import pandas as pd

from scipy.spatial import cKDTree
import numpy as np
import pandas as pd

def flag_temporal_persistence_rfi(df,
                                   frequency_tol=1.0,
                                   drift_tol=1.0,
                                   position_match_radius_arcsec=10,
                                   min_time_gap=10,  # seconds
                                   min_repeat_count=2,
                                   debug=False):
    """
    Flags signals as RFI if they appear at the same sky location (within arcsec),
    with same freq/drift (within tolerance), at different times (with ≥ min_time_gap),
    and occur at least min_repeat_count times.
    """
    df = df.copy()
    df['rfi_flag_global'] = False

    # Bin frequency and drift
    df['freq_bin'] = (df['signal_frequency'] / frequency_tol).round().astype(int)
    df['drift_bin'] = (df['signal_drift_rate'] / drift_tol).round().astype(int)

    # Convert time to datetime and positions to arcsec
    df['tstart'] = pd.to_datetime(df['tstart'])
    dec_mean = df['dec_degrees'].mean()
    df['ra_arcsec'] = df['ra_hours'] * 3600 * np.cos(np.deg2rad(dec_mean))
    df['dec_arcsec'] = df['dec_degrees'] * 3600

    flagged_indices = set()

    grouped = df.groupby(['freq_bin', 'drift_bin'])
    total_groups = len(grouped)
    if debug:
        print(f"Total freq/drift groups: {total_groups}")

    for count, ((f_bin, d_bin), group) in enumerate(grouped, 1):
        group_size = len(group)
        if debug and count % 10 == 0:
            print(f"[{count}/{total_groups}] Processing group freq_bin={f_bin}, drift_bin={d_bin}, size={group_size}")

        if group_size < min_repeat_count:
            continue

        coords = group[['ra_arcsec', 'dec_arcsec']].values
        times = group['tstart'].dt.to_pydatetime()
        tree = cKDTree(coords)

        for i in range(group_size):
            time_i = times[i]
            idx_neighbors = tree.query_ball_point(coords[i], r=position_match_radius_arcsec)

            repeat_times = []
            for j in idx_neighbors:
                if i == j:
                    continue
                time_j = times[j]
                time_diff = abs((time_i - time_j).total_seconds())
                if time_diff >= min_time_gap:
                    repeat_times.append(group.index[j])

            if len(set(repeat_times)) >= (min_repeat_count - 1):
                flagged_indices.add(group.index[i])
                flagged_indices.update(repeat_times)
                if debug:
                    print(f"  Flagging group idx {group.index[i]} with {len(set(repeat_times))} repeats (time gap >= {min_time_gap}s)")

    df.loc[list(flagged_indices), 'rfi_flag_global'] = True
    if debug:
        print(f"Total flagged hits: {len(flagged_indices)}")

    return df, len(flagged_indices)



full_df['rfi_flag_global'] = False


full_df, global_count = flag_temporal_persistence_rfi(full_df)


  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = group['tstart'].dt.to_pydatetime()
  times = 

In [12]:

# Combine flags into one final flag column:
full_df['rfi_flaging'] = full_df['rfi_flag_local'] | full_df['rfi_flag_global']

print(f"Local RFI flagged: {local_count}")
print(f"Global RFI flagged: {global_count}")
print(f"Total flagged: {full_df['rfi_flaging'].sum()}")


Local RFI flagged: 53914
Global RFI flagged: 0
Total flagged: 53914


In [14]:
# dividing up df based on flags 
full_df.groupby('flag_strength').size()
clean_df = full_df[full_df['flag_strength'] == 0]         # High-confidence real
maybe_rfi = full_df[full_df['flag_strength'] == 1]        # Weak RFI flag
strong_rfi = full_df[full_df['flag_strength'] >= 2]       # Remove from science set

clean_df.to_pickle('data/clean_df.pkl')
maybe_rfi.to_pickle('data/maybe_rfi.pkl')
strong_rfi.to_pickle('data/strong_rfi.pkl')