# jane_Array-MTR
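Written by Dr. Joachim Wassermann.

Hourly f-k (frequency-wavenumber) array analysis of the vertical channels of stations XG.UP1–UP6. The results (relative power, absolute power, backazimuth, slowness) are stored as miniSEED day files, and a summary figure is saved for each hour.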

In [None]:
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from obspy import *
from obspy.core.inventory.inventory import Inventory
from obspy.core import AttribDict

import matplotlib.dates as mdates
import matplotlib.cm as cm


from obspy.clients.fdsn import Client
from obspy.signal.invsim import cosine_taper
import obspy_arraytools as AA
import os
import sys



# Hourly f-k analysis of the XG UP1-UP6 vertical-component array.
#
# For every one-hour window between `ts` and `e` the loop:
#   1. fetches waveforms and response-level station metadata from the FDSN server,
#   2. runs obspy_arraytools' fk_analysis on the vertical components,
#   3. appends the four result series (ZGC, ZGI, ZGS, ZGA) to day files laid out as
#      <output_path>/<year>/<net>/<sta>/<chan>.D/<net>.<sta>.<loc>.<chan>.D.<year>.<jday>,
#   4. saves a summary figure to <figure_path>/FK-<YYYY-MM-DDTHH>.png.
# A failing window is reported on stderr (with traceback) and skipped.
import traceback  # used by the error handlers below; not imported at the top of this cell

client = Client("http://tarzan")
arraystats = ["XG.UP1..GLZ", "XG.UP2..GLZ", "XG.UP3..GLZ", "XG.UP4..GLZ", "XG.UP5..GLZ", "XG.UP6..GLZ"]

ts = t = UTCDateTime("2024-03-22T05:00:00")  # start of the first window
e = UTCDateTime("2024-03-22T08:00:00")        # end of the last window
output_path = "./Grenzgletscher_fk"
figure_path = "./Grenzgletscher_fk/figure"

# f-k parameters: frequency band [Hz], sliding-window length [s] and step
# fraction, slowness grid limits and step (s/km, because sec_km=True below).
fl = 1
fh = 20.00
win_len = 2.0
win_frac = 0.1
sll_x = -0.5
slm_x = 0.5
sll_y = -0.5
slm_y = 0.5
sl_s = 0.025
# Windows with a lower relative power are hidden in the summary figure.
thres_rel = 0.5

# (output channel code, fk_analysis result attribute) in the row order of `out`.
fk_channels = [
    ("ZGC", "max_rel_power"),
    ("ZGI", "max_abs_power"),
    ("ZGS", "max_pow_baz"),
    ("ZGA", "max_pow_slow"),
]

# makedirs (not mkdir) so that missing parent directories are created too.
os.makedirs(figure_path, exist_ok=True)

# `<=` so that the last window, which ends exactly at `e`, is processed as well.
while ts + 3600 <= e:
    start = ts
    end = ts + 3600
    ts += 3600
    try:
        # --- Data: waveforms and metadata of every array station
        sz = Stream()
        inv = Inventory()
        for station in arraystats:
            net, stat, loc, chan = station.split('.')
            tr = client.get_waveforms(network=net, station=stat, location=loc, channel=chan,
                                      starttime=start, endtime=end)
            ii = client.get_stations(network=net, station=stat, location='', channel=chan,
                                     starttime=start, endtime=end, level="response")
            print(tr)
            sz += tr
            inv += ii
        sz.merge()
        sz.detrend("linear")
        sz.attach_response(inv)
        vc = sz.select(component="Z")

        # --- f-k analysis
        array = AA.SeismicArray("", inv)
        array.inventory_cull(vc)
        print(array.center_of_gravity)
        outray = array.fk_analysis(vc, frqlow=fl, frqhigh=fh, prefilter=True,
                                   static3d=False, array_response=False, vel_corr=4.8, wlen=win_len,
                                   wfrac=win_frac, sec_km=True,
                                   slx=(sll_x, slm_x), sly=(sll_y, slm_y),
                                   sls=sl_s)

        # Rows: relative power, absolute power, backazimuth, slowness.
        out = np.vstack([getattr(outray, attr) for _, attr in fk_channels])

        # --- Wrap the results as traces labelled with the array's first station
        net_code = outray.inventory.networks[0].code
        sta_code = outray.inventory.networks[0][0].code
        fk = Stream()
        for chan_code, attr in fk_channels:
            tr = Trace(data=getattr(outray, attr))
            tr.stats.network = net_code
            tr.stats.station = sta_code
            tr.stats.channel = chan_code
            tr.stats.location = ""
            tr.stats.starttime = outray.starttime
            tr.stats.delta = outray.timestep
            fk += tr

        # --- Append each trace to its day file
        myday = "%03d" % fk[0].stats.starttime.julday
        pathyear = str(fk[0].stats.starttime.year)
        mydatapath = os.path.join(output_path, pathyear, fk[0].stats.network, fk[0].stats.station)
        for tr in fk:
            print("saving to " + mydatapath)
            print(tr)
            mydatapathchannel = os.path.join(mydatapath, tr.stats.channel + ".D")
            os.makedirs(mydatapathchannel, exist_ok=True)
            netFile = (tr.stats.network + "." + tr.stats.station + "." + tr.stats.location + "."
                       + tr.stats.channel + ".D." + pathyear + "." + myday)
            netFileout = os.path.join(mydatapathchannel, netFile)
            print(netFileout)
            # Binary append mode creates a missing file, so successive hours
            # of the same day accumulate in one file; `with` always closes it.
            with open(netFileout, 'ab') as day_file:
                tr.write(day_file, format='MSEED', encoding="FLOAT64")

        # --- Summary figure
        labels = ['ref', 'rel.power', 'abs.power', 'baz', 'slow']
        fig = plt.figure()
        alphas = out[0, :]
        # Hide windows with low relative power or with slowness above 0.4 s/km.
        hidden = (out[0, :] < thres_rel) | (out[3, :] > 0.4)
        tt = np.ma.masked_array(fk[0].times("matplotlib"), mask=hidden)
        axis = []

        for i, lab in enumerate(labels):
            try:
                if i == 0:
                    # Reference panel: first vertical-component waveform
                    ax = fig.add_subplot(5, 1, i + 1, sharex=None)
                    ax.plot(vc[0].times("matplotlib"), vc[0].data)
                else:
                    ax = fig.add_subplot(5, 1, i + 1, sharex=axis[0])
                    mask_v = np.ma.masked_array(out[i - 1, :], mask=hidden)
                    ax.scatter(tt, mask_v, c=out[0, :], alpha=alphas,
                               edgecolors='none', cmap=cm.viridis_r)
                    ax.set_ylabel(lab)
                    # min()/max() of a fully masked array are undefined.
                    if mask_v.count() > 0:
                        ax.set_ylim(mask_v.min() - 0.1, mask_v.max() + 0.1)
                    # One locator per axis: matplotlib tickers must not be shared.
                    xlocator = mdates.AutoDateLocator()
                    ax.xaxis.set_major_locator(xlocator)
                    ax.xaxis.set_major_formatter(mdates.AutoDateFormatter(xlocator))
                axis.append(ax)
            except Exception as er:
                sys.stderr.write("Error:" + str(er) + "\n")
                traceback.print_exc()
        fig.suptitle('jane-fk %s' % (start))
        fig.autofmt_xdate()
        fig.subplots_adjust(left=0.15, top=0.95, right=0.95, bottom=0.2, hspace=0)
        fig.savefig("%s/FK-%s.png" % (figure_path, start.strftime('%Y-%m-%dT%H')))
        plt.close("all")
    except Exception as err:
        # Deliberately not named `e`: `except ... as e` would delete the
        # end-time variable `e` used by the loop condition.
        sys.stderr.write("Skipping window %s - %s: %s\n" % (start, end, err))
        traceback.print_exc()
        plt.close("all")

# FK_trig
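Written by Dr. Joachim Wassermann, adapted by Nicole Katrin Richels.

Scans the stored f-k results for triggers (relative power above a threshold while slowness is below a threshold). Writes the trigger times to a CSV file and plots a multi-panel figure for each triggered event.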

In [None]:
from obspy import *
from obspy.clients.filesystem import sds
from obspy.clients.fdsn import Client
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import Normalize
import numpy as np
import logging
import os
from datetime import timedelta
import matplotlib as mpl
from matplotlib.gridspec import GridSpec

# ---- Plot styling ----
# Start from the seaborn look, then force white backgrounds and a sans-serif
# font. The overrides come after style.use(), which replaces rcParams.
plt.style.use('seaborn-v0_8')
mpl.rcParams.update({
    'axes.facecolor': 'white',
    'figure.facecolor': 'white',
    'font.family': 'sans-serif',
})

# ---- Logging ----
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# ---- Paths ----
# The f-k results are read from an SDS archive under root_dir. Trigger times
# go to out_file and per-event figures to save_path.
out_file = "./Grenzgletscher_fk/fk_trigger.csv"
root_dir = "./Grenzgletscher_fk/"
save_path = "./Grenzgletscher_fk/trig_figs/"
os.makedirs(save_path, exist_ok=True)

cl = sds.Client(sds_root=root_dir)

# ---- Scan period and accumulators ----
start = UTCDateTime(2024, 3, 19)
ostart = start          # fixed reference; trigger times are stored relative to it
end = UTCDateTime(2024, 3, 23)
trig_times = []

logger.info(f"Starting trigger detection from {ostart} to {end}")
trigger_count = 0

# ---- Detection thresholds ----
relp_threshold = 0.6           # relative power (ZGC) must exceed this
slow_threshold = 0.5           # slowness (ZGA) must stay below this
min_trigger_separation = 2     # minimum seconds between accepted triggers

last_trigger_abs_time = None   # absolute time of the most recent accepted trigger

# Scan the f-k archive hour by hour. A trigger is an onset where the relative
# power (ZGC) rises above relp_threshold while the slowness (ZGA) is below
# slow_threshold, and more than min_trigger_separation seconds have passed
# since the previously accepted trigger. Trigger times are stored as seconds
# after `ostart` in `trig_times`.
# `<=` so that the final hour, which ends exactly at `end`, is scanned as well.
while start + 1*3600 <= end:
    endy = start + 1*3600
    try:
        logger.info(f"Processing window: {start} - {endy}")
        st = cl.get_waveforms(network="XG", station="UP1", location="", channel="ZG?", starttime=start, endtime=endy)
        st.merge()

        # Check for required channels
        channels = [tr.stats.channel for tr in st]
        if "ZGC" not in channels or "ZGA" not in channels:
            logger.warning(f"Missing required channels. Available: {channels}")
            continue  # the `finally` clause advances the window

        relp = st.select(channel="ZGC")[0]
        slow = st.select(channel="ZGA")[0]
        # Seconds after `ostart` for every sample; computed once per window
        # instead of once per trigger.
        rel_times = relp.times(reftime=ostart)
        # Never index past the end of the shorter trace.
        n_samples = min(relp.stats.npts, slow.stats.npts)

        window_triggers = 0

        for i in range(1, n_samples):
            # Both conditions are met at this sample ...
            if relp.data[i] > relp_threshold and slow.data[i] < slow_threshold:
                # ... and were not met at the previous one (state transition).
                # Samples with i <= 5 are never used as onsets.
                if i > 5 and (
                    relp.data[i - 1] < relp_threshold or
                    slow.data[i - 1] > slow_threshold
                ):
                    trig_time = rel_times[i]
                    abs_time = ostart + trig_time

                    # Enforce the minimum separation between triggers
                    if last_trigger_abs_time is None or (abs_time - last_trigger_abs_time) > min_trigger_separation:
                        trig_times.append(trig_time)
                        logger.info(f"Trigger detected at index {i}: relative time={trig_time}, absolute time={abs_time}")
                        window_triggers += 1
                        last_trigger_abs_time = abs_time
                    else:
                        logger.info(f"Skipping close trigger at {abs_time} (too close to previous)")

        logger.info(f"Found {window_triggers} triggers in this window")
        trigger_count += window_triggers
    except Exception as e:
        logger.error(f"Error processing window {start}-{endy}: {e}")
    finally:
        # Advance exactly once per iteration, whichever path was taken.
        start += 1*3600

logger.info(f"Total triggers detected: {trigger_count}")
logger.info(f"Writing triggers to {out_file}")

# Convert the trigger offsets (seconds after `ostart`) back to absolute times.
# No filtering happens here; every detected trigger is kept.
valid_triggers = [ostart + offset for offset in trig_times]

# Write one timestamp per line.
with open(out_file, "w") as fo:
    fo.writelines("%s\n" % trigger_time for trigger_time in valid_triggers)

logger.info(f"Written {len(valid_triggers)} triggers to CSV file")

# ---- Per-event processing setup ----
# f-k series collected from every processed event.
all_baz, all_slow, all_rp = [], [], []
# FDSN client for the raw waveforms that are plotted next to the f-k results.
clw = Client("http://tarzan.geophysik.uni-muenchen.de")

# Colormap shared by all event plots.
global_cmap = plt.cm.viridis

# Velocity model [km/s] used to convert slowness into incidence angle.
p_velocity, s_velocity = 3.8, 1.8

logger.info(f"Processing {len(trig_times)} individual events")

# Counters for the final summary.
successful_plots = failed_plots = 0

def _event_incidence_angles(slowness):
    """Convert slowness values [s/km] into incidence angles [deg] and wave types.

    Uses the module-level ``p_velocity`` / ``s_velocity``. Slowness below 0.3
    is labelled "P", 0.3-0.6 is labelled "S", and anything else (including NaN)
    is "Unknown" with a NaN angle. Returns two numpy arrays with one entry per
    input value.
    """
    angles = []
    types = []
    for s in slowness:
        if s < 0.3:  # P-wave region
            angle = np.degrees(np.arcsin(min(p_velocity * s, 0.99)))
            wave_type = "P"
        elif s <= 0.6:  # S-wave region
            if s_velocity * s > 0.99:
                # Outside the arcsin domain: use a linear ramp, clipped to the
                # intended 60-85 deg range. Unclipped, it produced ~89-93 deg,
                # i.e. impossible incidence angles above 90 deg.
                angle = min(60 + 25 * (s - 0.2) / 0.3, 85.0)
            else:
                angle = np.degrees(np.arcsin(min(s_velocity * s, 0.99)))
            wave_type = "S"
        else:  # outside the velocity-model assumptions
            angle = np.nan
            wave_type = "Unknown"
        angles.append(angle)
        types.append(wave_type)
    return np.array(angles, dtype=float), np.array(types, dtype=str)


def _format_time_axis(ax, limits):
    """Give `ax` 1-second HH:MM:SS date ticks, x-limits `limits` and rotated labels."""
    # Fresh locator/formatter per axis: matplotlib ticker objects must not be
    # shared between axes, because a shared locator only tracks the last axis.
    ax.xaxis.set_major_locator(mdates.SecondLocator(interval=1))
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S'))
    ax.set_xlim(limits)
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')


for j_idx, j in enumerate(trig_times):
    # Computed before the try, so the outer error log always names this event.
    event_time = ostart + j
    try:
        logger.info(f"Processing event {j_idx+1}/{len(trig_times)} at {event_time}")

        # Raw waveforms (FDSN) and f-k results (local SDS), 1 s before to 10 s after the trigger
        try:
            st = clw.get_waveforms(network="XG", station="UP1", location="", channel="??Z",
                                   starttime=(event_time-1), endtime=event_time+10)
            ar = cl.get_waveforms(network="XG", station="UP1", location="", channel="ZG?",
                                  starttime=(event_time-1), endtime=event_time+10)
        except Exception as data_error:
            logger.error(f"Failed to get waveform data for event at {event_time}: {data_error}")
            failed_plots += 1
            continue

        if len(st) == 0 or len(ar) == 0:
            logger.warning(f"No data found for event at {event_time}, skipping")
            failed_plots += 1
            continue

        ar_channels = [tr.stats.channel for tr in ar]
        if "ZGC" not in ar_channels or "ZGA" not in ar_channels or "ZGS" not in ar_channels:
            logger.warning(f"Missing required array channels for event at {event_time}, skipping. Available: {ar_channels}")
            failed_plots += 1
            continue

        if not st.select(component="Z"):
            logger.warning(f"No vertical component found for event at {event_time}, skipping")
            failed_plots += 1
            continue

        # Process waveforms (in place; the selected trace below sees the result)
        st.detrend("linear")
        st.taper(type='cosine', max_percentage=0.05)
        st.filter("bandpass", freqmin=1, freqmax=20)
        vert_tr = st.select(component="Z")[0]

        # f-k traces; all three channels were verified above
        relp_tr = ar.select(channel="ZGC")[0]
        baz_tr = ar.select(channel="ZGS")[0]
        slow_tr = ar.select(channel="ZGA")[0]
        rel_power = relp_tr.data
        baz = baz_tr.data
        slow = slow_tr.data
        all_rp.append(rel_power)
        all_baz.append(baz)
        all_slow.append(slow)

        incidence_angles, wave_types = _event_incidence_angles(slow)

        # Marker sizes scale with relative power. They are computed here, for
        # every event, so no plot can pick up a previous event's array.
        max_rp = np.max(rel_power) if len(rel_power) > 0 else 0
        rel_power_norm = rel_power / max_rp if max_rp > 0 else np.zeros_like(rel_power)
        sizes = 20 + 100 * rel_power_norm

        # Colour mesh of the polar plot; stays None when that plot is skipped,
        # so the colorbar never refers to an earlier event's figure.
        pcm = None

        plt.close('all')
        fig = plt.figure(figsize=(12, 16), dpi=100)
        fig.suptitle(f"Seismic Event Analysis - {event_time.strftime('%Y-%m-%d %H:%M:%S')}",
                     fontsize=16, fontweight='bold', y=0.98)

        # 3 x 2 grid of panels
        gs = plt.GridSpec(3, 2, figure=fig, height_ratios=[1, 1, 1], width_ratios=[1, 1],
                          hspace=0.35, wspace=0.35)

        time_data = vert_tr.times("matplotlib")
        # x-range shared by every time-based panel
        common_time_limits = [time_data.min(), time_data.max()]

        # PLOT 1: Seismogram - Row 1, Col 1 (Upper Left)
        axtrace = fig.add_subplot(gs[0, 0])
        axtrace.plot(time_data, vert_tr.data, 'k', linewidth=1.2)
        axtrace.ticklabel_format(axis='y', style='sci', scilimits=(-2, 2))
        # NOTE(review): no instrument response is removed in this loop, so the
        # data are presumably filtered counts rather than mm/s; confirm the label.
        axtrace.set_ylabel('Amplitude [mm/s]', color='k', fontsize=11, fontweight='bold')
        axtrace.set_title('Vertical Component Waveform', fontsize=12, fontweight='bold', pad=10)
        _format_time_axis(axtrace, common_time_limits)
        axtrace.grid(True, which='both', axis='x', color='gray', alpha=0.5, linestyle='-')
        axtrace.grid(True, which='major', axis='y', color='gray', alpha=0.5, linestyle='-')

        # PLOT 2: Polar Plot - Row 1, Col 2 (Upper Right)
        polar_ax = fig.add_subplot(gs[0, 1], projection='polar')
        # histogram2d needs equally long inputs
        if len(baz) > 0 and len(baz) == len(slow) == len(rel_power):
            # Backazimuth [deg] -> radians wrapped into [0, 2*pi)
            baz_rad = np.mod(np.radians(baz), 2 * np.pi)

            N = int(360. / 5.)  # 5-degree bins
            abins = np.arange(N + 1) * 2 * np.pi / N
            sbins = np.linspace(0, 0.4, 20)

            # Relative-power-weighted histogram; slowness beyond 0.4 is outside the bins
            hist, _, _ = np.histogram2d(baz_rad, slow, bins=[abins, sbins], weights=rel_power)
            A, S = np.meshgrid(abins, sbins)

            polar_ax.set_theta_zero_location("N")
            polar_ax.set_theta_direction(-1)
            pcm = polar_ax.pcolormesh(A, S, hist.T, cmap=global_cmap, alpha=0.7, shading='auto')
            polar_ax.grid(True, linewidth=1.5)
            polar_ax.set_rticks([0.1, 0.2, 0.3, 0.4])
            polar_ax.set_rlabel_position(135)
            polar_ax.set_rmax(0.4)
            polar_ax.set_title('Polar Plot: Backazimuth vs. Slowness', fontsize=12, fontweight='bold', pad=15)
        else:
            logger.warning("Missing data for polar plot")

        # PLOT 3: Spectrogram - Row 2, Col 1 (Middle Left)
        axspec = fig.add_subplot(gs[1, 0])
        try:
            vert_tr.spectrogram(wlen=0.5, per_lap=0.9, show=False, axes=axspec)
            axspec.set_ylim(1, 25)  # Limit frequency to 1-25 Hz
            axspec.set_ylabel('Frequency [Hz]', fontsize=11, fontweight='bold')

            # The spectrogram's own x-axis is in seconds. Hide it and show
            # wall-clock time on a twin axis aligned with the seismogram.
            axspec.set_xticks([])
            ax2 = axspec.twiny()
            # Move ticks to the bottom before rotating, so the visible labels are rotated.
            ax2.xaxis.tick_bottom()
            ax2.xaxis.set_label_position('bottom')
            ax2.tick_params(axis='x', pad=10)
            _format_time_axis(ax2, common_time_limits)

            axspec.grid(True, which='major', axis='y', color='gray', alpha=0.01, linestyle='-')
            # Vertical grid lines come from the time axis itself. The former
            # axvline loop put date numbers on the seconds axis, where they
            # were never visible.
            ax2.grid(True, which='major', axis='x', color='gray', alpha=0.01, linestyle='-')

            axspec.set_title('Spectrogram', fontsize=12, fontweight='bold', pad=10)
        except Exception as e:
            logger.error(f"Error creating spectrogram: {e}")

        # PLOT 4: Incidence Angle - Row 2, Col 2 (Middle Right)
        axangle = fig.add_subplot(gs[1, 1])
        valid_mask = ~np.isnan(incidence_angles)
        if not valid_mask.any():
            logger.warning("No valid incidence angle data for plot")
        elif len(rel_power) != len(slow):
            logger.warning("Data length mismatch in incidence angle plot")
        else:
            angle_times = slow_tr.times("matplotlib")
            axangle.scatter(
                angle_times[valid_mask],
                incidence_angles[valid_mask],
                c=rel_power[valid_mask],
                cmap=global_cmap,
                s=sizes[valid_mask],
                alpha=0.7
            )
            # Outline markers by wave type
            for wave_type, marker, colour, label in (("P", 'o', 'blue', 'P-wave'),
                                                     ("S", 's', 'red', 'S-wave')):
                type_mask = valid_mask & (wave_types == wave_type)
                if type_mask.any():
                    axangle.scatter(
                        angle_times[type_mask],
                        incidence_angles[type_mask],
                        s=30, alpha=0.7, facecolors='none', edgecolors=colour,
                        linewidth=1.5, marker=marker, label=label
                    )
            axangle.set_ylabel('Incidence Angle [deg]', fontsize=10, fontweight='bold')
            axangle.set_title('Incidence Angle vs. Time', fontsize=11, fontweight='bold', pad=10)
            _format_time_axis(axangle, common_time_limits)
            axangle.set_ylim(0, 90)
            axangle.grid(True, which='major', axis='both', color='gray', alpha=0.5, linestyle='-')
            axangle.legend(loc='upper right', fontsize=9)

        # PLOT 5: Backazimuth - Row 3, Col 1 (Lower Left)
        axbaz = fig.add_subplot(gs[2, 0])
        if len(baz) == 0:
            logger.warning("No backazimuth data available for plot")
        elif len(baz) != len(rel_power):
            logger.warning("Data length mismatch in backazimuth plot")
        else:
            baz_times = baz_tr.times("matplotlib")
            axbaz.scatter(baz_times, baz, c=rel_power, cmap=global_cmap, s=sizes, alpha=0.7)
            axbaz.set_ylabel('Backazimuth [deg]', fontsize=11, fontweight='bold')
            axbaz.set_xlabel('Time (UTC)', fontsize=11, fontweight='bold')
            axbaz.set_ylim(0, 360)
            axbaz.set_yticks([0, 90, 180, 270, 360])
            axbaz.set_title('Backazimuth vs. Time', fontsize=12, fontweight='bold', pad=10)
            _format_time_axis(axbaz, common_time_limits)
            axbaz.grid(True, which='major', axis='both', color='gray', alpha=0.5, linestyle='-')

        # PLOT 6: Slowness - Row 3, Col 2 (Lower Right)
        axslow = fig.add_subplot(gs[2, 1])
        axslow.axhline(y=slow_threshold, color='r', linestyle='--', alpha=0.7,
                       label=f'Threshold ({slow_threshold})')
        if len(slow) == 0:
            logger.warning("No slowness data available for plot")
        elif len(slow) != len(rel_power):
            logger.warning("Data length mismatch in slowness plot")
        else:
            slow_times = slow_tr.times("matplotlib")
            axslow.scatter(slow_times, slow, c=rel_power, cmap=global_cmap, s=sizes, alpha=0.7)
            axslow.set_ylabel('Slowness [s/km]', fontsize=11, fontweight='bold')
            axslow.set_xlabel('Time (UTC)', fontsize=11, fontweight='bold')
            axslow.set_title('Slowness vs. Time', fontsize=12, fontweight='bold', pad=10)
            _format_time_axis(axslow, common_time_limits)
            axslow.set_ylim(0, max(1.0, np.max(slow) * 1.1))
            axslow.grid(True, which='major', axis='both', color='gray', alpha=0.5, linestyle='-')
            axslow.legend(loc='upper right', fontsize=9)

        # NOTE(review): this colorbar maps only the polar histogram (relative
        # power summed per bin); the scatter panels autoscale their own colours.
        if pcm is not None:
            cax = fig.add_axes([0.93, 0.3, 0.02, 0.4])  # Position for vertical colorbar
            cbar = fig.colorbar(pcm, cax=cax)
            cbar.set_label('Relative Power', fontsize=10, fontweight='bold')

        # Save figure
        fmt = "png"
        filename = f'{save_path}UP1_{event_time.strftime("%Y%m%d_%H%M%S")}_array.{fmt}'
        try:
            fig.savefig(filename, format=fmt, dpi=300)
            logger.info(f"Saved figure: {filename}")
            successful_plots += 1
        except Exception as save_error:
            logger.error(f"Failed to save figure {filename}: {save_error}")
            failed_plots += 1

        plt.close("all")

    except Exception as e:
        logger.error(f"Error processing event at {event_time}: {e}")
        failed_plots += 1
        plt.close("all")  # Make sure to close all figures even in case of error

logger.info(f"Processing complete! Successfully created {successful_plots} plots, {failed_plots} failed")
logger.info(f"Summary: {len(trig_times)} triggers detected, {len(valid_triggers)} written to CSV, {successful_plots} plots created")

# Feature Extraction

In [None]:
from obspy import UTCDateTime
from obspy.clients.filesystem import sds
import numpy as np
import logging
import os
import pandas as pd
from scipy import signal
from scipy.interpolate import interp1d
from scipy.stats import kurtosis, skew
from scipy.signal import hilbert
import warnings
from obspy.clients.fdsn import Client
warnings.filterwarnings("ignore")

# ---- Logging ----
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# ---- Paths ----
# SDS archive root of the f-k results, plus the CSV files named input_file /
# output_file (presumably trigger list in, feature table out).
# NOTE(review): the trigger-detection cell writes fk_trigger.csv, not
# 2025_fk_trigger.csv; confirm which trigger file is intended here.
root_dir = "./Grenzgletscher_fk/"
input_file = "./Grenzgletscher_fk/2025_fk_trigger.csv"
output_file = "./Grenzgletscher_fk/2025_features.csv"

# ---- Analysis window around each trigger [s] ----
window_before, window_after = 1, 5

# ---- Velocity model [km/s] for slowness -> incidence-angle conversion ----
p_velocity, s_velocity = 3.8, 1.8

# ---- Data sources: local SDS archive and the FDSN server ----
cl = sds.Client(sds_root=root_dir)
clw = Client("http://tarzan.geophysik.uni-muenchen.de")

def calculate_incidence_angles_from_slowness(slowness_data, p_vel=None, s_vel=None):
    """Calculate incidence angles from slowness values using a two-velocity model.

    Each slowness value ``s`` [s/km] is classified and converted as follows:

    * ``s < 0.3``: "P", angle = arcsin(min(p_vel * s, 0.99)).
    * ``0.3 <= s <= 0.6``: "S", angle = arcsin(s_vel * s) while
      ``s_vel * s <= 0.99``. Beyond that, a linear ramp
      ``60 + 25 * (s - 0.2) / 0.3`` clipped to at most 85 deg.
    * Anything else, or NaN: "Unknown", angle = NaN.

    Parameters
    ----------
    slowness_data : sequence of float or None
        Slowness values [s/km].
    p_vel, s_vel : float, optional
        P- and S-wave velocities [km/s]. Default to the module-level
        ``p_velocity`` / ``s_velocity`` at call time.

    Returns
    -------
    tuple of numpy.ndarray, or (None, None)
        Incidence angles in degrees and wave-type labels, one per input
        value; ``(None, None)`` for ``None`` or empty input.
    """
    if slowness_data is None or len(slowness_data) == 0:
        return None, None

    if p_vel is None:
        p_vel = p_velocity
    if s_vel is None:
        s_vel = s_velocity

    incidence_angles = []
    wave_types = []

    for s in slowness_data:
        if np.isnan(s):
            incidence_angles.append(np.nan)
            wave_types.append("Unknown")
            continue

        if s < 0.3:  # P-wave region
            sin_i = min(p_vel * s, 0.99)
            angle = np.degrees(np.arcsin(sin_i))
            wave_type = "P"
        elif s <= 0.6:  # S-wave region
            if s_vel * s > 0.99:
                # Outside the arcsin domain: linear ramp, clipped to the
                # intended 60-85 deg range. Unclipped, it gave ~89-93 deg for
                # s > 0.55, i.e. impossible incidence angles above 90 deg.
                angle = min(60 + 25 * (s - 0.2) / 0.3, 85.0)
            else:
                sin_i = min(s_vel * s, 0.99)
                angle = np.degrees(np.arcsin(sin_i))
            wave_type = "S"
        else:
            # For values outside velocity model assumptions
            angle = np.nan
            wave_type = "Unknown"

        incidence_angles.append(angle)
        wave_types.append(wave_type)

    return np.array(incidence_angles), np.array(wave_types)

def map_array_to_waveform_indices(array_data, array_sr, waveform_sr, waveform_length):
    """Resample an array-processing series onto the waveform's sample grid.

    Both series are treated as starting at t = 0 at their own first sample.
    NOTE(review): the traces' real start times are not compared here; confirm
    the caller passes time-aligned data.

    Parameters
    ----------
    array_data : sequence of float
        Values sampled at ``array_sr`` Hz.
    array_sr, waveform_sr : float
        Sampling rates (Hz) of the source series and of the target waveform.
    waveform_length : int
        Number of target samples.

    Returns
    -------
    numpy.ndarray or None
        ``waveform_length`` values, linearly interpolated (and extrapolated)
        from ``array_data``. NaNs in the result are replaced by the first
        valid value (leading run) or the previous value (elsewhere).
        ``None`` when ``array_data`` is ``None`` or empty. If interpolation
        raises, a constant array holding the mean of ``array_data`` is
        returned.
    """
    if array_data is None or len(array_data) == 0:
        return None

    try:
        source_t = np.arange(len(array_data)) / array_sr
        target_t = np.arange(waveform_length) / waveform_sr

        # A single value cannot be interpolated; broadcast it instead.
        if len(array_data) == 1:
            return np.full(waveform_length, array_data[0])

        interpolator = interp1d(source_t, array_data, kind='linear',
                                fill_value='extrapolate', bounds_error=False)
        resampled = interpolator(target_t)

        valid = ~np.isnan(resampled)
        if valid.any() and not valid.all():
            positions = np.arange(resampled.size)
            first_ok = positions[valid][0]
            # For each position, the index of the most recent valid sample;
            # positions before the first valid sample borrow that first one.
            source_idx = np.maximum.accumulate(np.where(valid, positions, first_ok))
            resampled = resampled[source_idx]

        return resampled

    except Exception as e:
        logger.warning(f"Error in array data mapping: {e}")
        return np.full(waveform_length, np.mean(array_data))

def analyze_slowness_during_signal(ar_slowness, ar_back_azimuth, ar_incidence_angle,
                                   envelope, trigger_sample, sample_rate,
                                   threshold_fraction=0.02):
    """Summarise slowness while the signal envelope is active.

    The active window runs from the first to the last envelope sample above
    ``threshold_fraction`` (default 2 %) of the envelope maximum. NaN slowness
    samples in that window are ignored. Back azimuth and incidence angle are
    read at the sample with the minimum slowness.

    Parameters
    ----------
    ar_slowness, ar_back_azimuth, ar_incidence_angle : array-like or None
        Per-sample series on the same index grid as ``envelope``.
    envelope : array-like
        Signal envelope used to locate the active window.
    trigger_sample, sample_rate :
        Accepted for interface compatibility; not used by the computation.
    threshold_fraction : float, optional
        Fraction of the envelope maximum that counts as "active".

    Returns
    -------
    dict
        Keys ``slowness_min_during_signal``, ``slowness_mean_during_signal``,
        ``slowness_std_during_signal``, ``slowness_change_during_signal``,
        ``back_azimuth_at_min_slowness`` and
        ``incidence_angle_at_min_slowness``. All values are NaN when there is
        no slowness data, no envelope, no active window, or no valid slowness
        inside it.
    """
    feature_keys = ('slowness_min_during_signal', 'slowness_mean_during_signal',
                    'slowness_std_during_signal', 'slowness_change_during_signal',
                    'back_azimuth_at_min_slowness', 'incidence_angle_at_min_slowness')
    no_result = {key: np.nan for key in feature_keys}

    if ar_slowness is None or len(ar_slowness) == 0:
        return no_result

    envelope = np.asarray(envelope, dtype=float)
    if envelope.size == 0:
        return no_result

    # Active window: first to last sample above threshold_fraction * max(envelope).
    active = np.flatnonzero(envelope > np.max(envelope) * threshold_fraction)
    if active.size == 0:
        return no_result

    # asarray so boolean-mask indexing below also works for list inputs.
    slowness = np.asarray(ar_slowness, dtype=float)
    start = active[0]
    end = min(slowness.size, active[-1] + 1)
    window = slowness[start:end]

    valid = ~np.isnan(window)
    if not np.any(valid):
        return no_result

    valid_slowness = window[valid]
    valid_indices = np.flatnonzero(valid) + start

    features = {
        'slowness_min_during_signal': np.min(valid_slowness),
        'slowness_mean_during_signal': np.mean(valid_slowness),
        'slowness_std_during_signal': np.std(valid_slowness),
        'slowness_change_during_signal': np.max(valid_slowness) - np.min(valid_slowness),
    }

    # Index (on the shared sample grid) where slowness is smallest.
    min_idx = valid_indices[np.argmin(valid_slowness)]

    features['back_azimuth_at_min_slowness'] = (
        ar_back_azimuth[min_idx]
        if ar_back_azimuth is not None and len(ar_back_azimuth) > min_idx
        else np.nan
    )
    features['incidence_angle_at_min_slowness'] = (
        ar_incidence_angle[min_idx]
        if ar_incidence_angle is not None and len(ar_incidence_angle) > min_idx
        else np.nan
    )
    return features

def _waveform_features(data, sample_rate):
    """Time-domain amplitude statistics and zero-crossing rate of *data*."""
    abs_data = np.abs(data)
    features = {
        'waveform_rms': np.sqrt(np.mean(np.square(data))),
        'waveform_peak_to_peak': np.max(data) - np.min(data),
        # Sum of absolute amplitudes (not of squared amplitudes).
        'waveform_abs_energy': np.sum(abs_data),
        'waveform_max_abs': np.max(abs_data),
        'waveform_mean_abs': np.mean(abs_data),
        'waveform_skewness': skew(data),
        'waveform_kurtosis': kurtosis(data),
        'waveform_std': np.std(data),
    }
    # Sign changes per second of signal.
    zero_crossings = np.where(np.diff(np.signbit(data)))[0]
    features['zero_crossing_rate'] = len(zero_crossings) / (len(data) / sample_rate)
    return features


def _spectral_features(data, sample_rate):
    """Spectral shape and band-ratio features from a Hann-windowed amplitude spectrum.

    The DC bin is dropped. Band ratios use squared amplitudes in the
    1-5, 5-10, 10-15 and 15-20 Hz bands (inclusive limits). Every feature is
    0.0 when the spectrum is empty or (near-)zero.
    """
    bands = ((1, 5), (5, 10), (10, 15), (15, 20))
    spectrum = np.abs(np.fft.rfft(data * signal.windows.hann(len(data))))
    freqs = np.fft.rfftfreq(len(data), d=1.0 / sample_rate)
    if len(spectrum) > 1:
        spectrum = spectrum[1:]  # drop DC
        freqs = freqs[1:]

    if not (len(spectrum) > 0 and np.sum(spectrum) > 1e-10):
        features = dict.fromkeys(('spectral_centroid', 'dominant_freq', 'spectral_spread',
                                  'spectral_rolloff', 'spectral_flatness'), 0.0)
        for low, high in bands:
            features[f'spec_band_{low}_{high}_ratio'] = 0.0
        return features

    total = np.sum(spectrum)
    centroid = np.sum(freqs * spectrum) / total
    features = {
        'spectral_centroid': centroid,
        'dominant_freq': freqs[np.argmax(spectrum)],
        'spectral_spread': np.sqrt(np.sum(((freqs - centroid) ** 2) * spectrum) / total),
    }

    # Lowest frequency at which the cumulative amplitude reaches 85 % of the total.
    rolloff_idx = np.where(np.cumsum(spectrum) >= 0.85 * total)[0]
    features['spectral_rolloff'] = freqs[rolloff_idx[0]] if len(rolloff_idx) > 0 else freqs[-1]

    # Geometric mean / arithmetic mean of the amplitude spectrum.
    features['spectral_flatness'] = np.exp(np.mean(np.log(spectrum + 1e-10))) / np.mean(spectrum)

    power = spectrum ** 2
    total_power = np.sum(power)
    for low, high in bands:
        in_band = (freqs >= low) & (freqs <= high)
        features[f'spec_band_{low}_{high}_ratio'] = (
            np.sum(power[in_band]) / total_power if np.any(in_band) else 0.0
        )
    return features


def _smoothed_envelope(data, sample_rate):
    """Hilbert envelope of *data* smoothed with a 0.05 s moving average."""
    from scipy.ndimage import uniform_filter1d
    envelope = np.abs(hilbert(data))
    smooth_window = max(1, int(sample_rate * 0.05))
    return uniform_filter1d(envelope, size=smooth_window)


def _default_envelope_features():
    """Fallback envelope features for a flat/zero envelope or a failed computation."""
    return {
        'duration_5_percent': 0.0,
        'duration_2_percent': 0.0,
        'envelope_rise_time': 0.0,
        'peak_position_normalized': 0.5,
        'envelope_skewness': 0.0,
        'envelope_kurtosis': 0.0,
    }


def _envelope_shape_features(smoothed_envelope, sample_rate):
    """Duration, rise-time and shape features of a smoothed envelope.

    Each duration spans the first to the last sample above 5 % / 2 % of the
    envelope maximum. The rise time runs from the first sample above 5 % to
    the first sample above 80 %. A non-positive (or NaN) maximum gives the
    defaults.
    """
    envelope_max = np.max(smoothed_envelope)
    if not envelope_max > 0:
        return _default_envelope_features()

    features = {}
    for threshold, name in ((0.05, '5_percent'), (0.02, '2_percent')):
        above = np.where(smoothed_envelope > envelope_max * threshold)[0]
        features[f'duration_{name}'] = (above[-1] - above[0]) / sample_rate if len(above) > 0 else 0.0

    above_start = np.where(smoothed_envelope > envelope_max * 0.05)[0]
    above_end = np.where(smoothed_envelope > envelope_max * 0.8)[0]
    if len(above_start) > 0 and len(above_end) > 0:
        features['envelope_rise_time'] = max(0, (above_end[0] - above_start[0]) / sample_rate)
    else:
        features['envelope_rise_time'] = 0.0

    features['peak_position_normalized'] = np.argmax(smoothed_envelope) / len(smoothed_envelope)
    features['envelope_skewness'] = skew(smoothed_envelope)
    features['envelope_kurtosis'] = kurtosis(smoothed_envelope)
    return features


def extract_features(trace, ar_back_azimuth, ar_incidence_angle, ar_slowness, trigger_time):
    """Compute the waveform, spectral, envelope and slowness features of one event.

    Parameters
    ----------
    trace : obspy Trace-like
        Must provide ``.data`` and ``.stats.sampling_rate``.
    ar_back_azimuth, ar_incidence_angle, ar_slowness : array-like or None
        Array-processing series already mapped onto the trace's samples.
    trigger_time :
        Unused; kept for interface compatibility.

    Returns
    -------
    dict
        Feature name -> value. Empty if the trace is missing or has no data.
        Spectral features are only present for traces of at least 4 samples.
    """
    if trace is None or not hasattr(trace, 'data') or len(trace.data) == 0:
        return {}

    data = trace.data.astype(np.float64)
    data = data - np.mean(data)  # remove DC offset
    sample_rate = trace.stats.sampling_rate

    features = _waveform_features(data, sample_rate)
    if len(data) >= 4:
        features.update(_spectral_features(data, sample_rate))

    smoothed_envelope = None
    try:
        smoothed_envelope = _smoothed_envelope(data, sample_rate)
        features.update(_envelope_shape_features(smoothed_envelope, sample_rate))
    except Exception as e:
        logger.warning(f"Error computing envelope features: {e}")
        features.update(_default_envelope_features())

    # The slowness analysis has its own guard, so a failure here no longer
    # overwrites the envelope features computed above with defaults.
    slowness_keys = ('slowness_min_during_signal', 'slowness_mean_during_signal',
                     'slowness_std_during_signal', 'slowness_change_during_signal',
                     'back_azimuth_at_min_slowness', 'incidence_angle_at_min_slowness')
    slowness_features = {key: np.nan for key in slowness_keys}
    if smoothed_envelope is not None:
        try:
            trigger_sample = int(window_before * sample_rate)
            slowness_features = analyze_slowness_during_signal(
                ar_slowness, ar_back_azimuth, ar_incidence_angle,
                smoothed_envelope, trigger_sample, sample_rate
            )
        except Exception as e:
            logger.warning(f"Error computing slowness features: {e}")
            slowness_features = {key: np.nan for key in slowness_keys}
    features.update(slowness_features)

    return features

def main():
    """Extract features for every trigger time and write them to ``output_file``.

    Trigger times come from ``input_file``: the first column of a CSV, or one
    timestamp per line for other files. For each trigger ``t``:

    * The vertical waveform is fetched via ``clw``, and the array-processing
      channels ``ZG?`` via ``cl``, for ``[t - window_before, t + window_after]``.
    * The waveform is detrended, tapered and band-passed to 1-20 Hz.
    * The array series are mapped onto the waveform's samples.
    * Features are computed with ``extract_features``.

    Events that fail or yield no features are logged/skipped. One CSV row is
    written per successful event.
    """
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    # Load trigger times
    try:
        if input_file.endswith('.csv'):
            df = pd.read_csv(input_file)
            # Only the first column is used, whatever its header is.
            time_column = df.columns[0]
            trigger_times = [UTCDateTime(ts) for ts in df[time_column]]
        else:
            with open(input_file, 'r') as f:
                trigger_times = [UTCDateTime(line.strip()) for line in f if line.strip()]
        logger.info(f"Loaded {len(trigger_times)} trigger times")
    except Exception as e:
        logger.error(f"Error loading trigger file: {e}")
        return

    all_features = []

    # Process each event
    for i, event_time in enumerate(trigger_times):
        try:
            logger.info(f"Processing event {i+1}/{len(trigger_times)} at {event_time}")
            event_features = {"event_time": str(event_time)}

            # Get waveform data (a failure skips the event)
            try:
                st = clw.get_waveforms(network="4D", station="A2P1", location="*", channel="??Z", 
                                       starttime=(event_time-window_before), 
                                       endtime=event_time+window_after)
            except Exception as e:
                logger.error(f"Error getting waveform data: {e}")
                continue

            # Get array data (a failure is tolerated; array features become NaN)
            try:
                ar = cl.get_waveforms(network="4D", station="A2P1", location="*", channel="ZG?", 
                                     starttime=(event_time-window_before), 
                                     endtime=event_time+window_after)
            except Exception as e:
                logger.warning(f"Error getting array data: {e}")
                ar = []

            if len(st) == 0:
                continue

            # Preprocess seismic waveforms only (array channels are left unfiltered);
            # 1-20 Hz matches the fl/fh band used in the array-processing step.
            st.detrend("linear")
            st.taper(type='cosine', max_percentage=0.05)
            st.filter("bandpass", freqmin=1, freqmax=20)

            vert_traces = st.select(component="Z")
            if len(vert_traces) == 0:
                continue

            # Only the first vertical trace is used (e.g. if gaps split the data).
            vert_tr = vert_traces[0]

            # Process array data
            back_azimuth = incidence_angle = slowness = None
            if len(ar) > 0:
                # NOTE(review): ZGS is read as back azimuth and ZGA as slowness.
                # Confirm this matches the channel codes written by the
                # array-processing step (A = azimuth, S = slowness would also be
                # a natural reading).
                for channel_code, var_name in [("ZGS", "back_azimuth"), ("ZGA", "slowness")]:
                    traces = [tr for tr in ar if tr.stats.channel == channel_code]
                    if traces:
                        raw_data = traces[0].data
                        raw_sr = traces[0].stats.sampling_rate
                        # Resample onto the waveform's sample grid so indices line up.
                        mapped_data = map_array_to_waveform_indices(
                            raw_data, raw_sr, vert_tr.stats.sampling_rate, len(vert_tr.data)
                        )
                        if var_name == "back_azimuth":
                            back_azimuth = mapped_data
                        elif var_name == "slowness":
                            slowness = mapped_data

                # Calculate incidence angles from slowness instead of reading from ZGI
                # (wave_types is currently not used further).
                if slowness is not None:
                    incidence_angle, wave_types = calculate_incidence_angles_from_slowness(slowness)

            # Extract features
            extracted_features = extract_features(vert_tr, back_azimuth, incidence_angle, slowness, event_time)

            # An empty dict means the trace had no data; the event is dropped.
            if extracted_features:
                event_features.update(extracted_features)
                all_features.append(event_features)

        except Exception as e:
            logger.error(f"Error processing event at {event_time}: {e}")
            continue

    # Save results
    if all_features:
        df = pd.DataFrame(all_features)
        df.to_csv(output_file, index=False)
        logger.info(f"Saved {len(all_features)} feature sets to {output_file}")
        logger.info(f"Total features: {len(df.columns)-1}")  # excludes event_time
    else:
        logger.error("No features extracted")

if __name__ == "__main__":
    main()

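# Feature Plots
## Before running this cell, assign class labels manually in a column named `class_name`
### The labeled file is called features_labeled.csv (the box-plot `main` below reads ./Grenzgletscher_fk/known_features_labeled.csv — adjust `input_file` there if your file is named differently).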

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import os
from matplotlib.colors import to_rgba
import warnings
warnings.filterwarnings('ignore')  # suppresses ALL warnings for the rest of the session

# Global matplotlib defaults for every figure created below.
plt.style.use('classic')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12
plt.rcParams['axes.grid'] = True
# Destination of all saved plots and summary tables (created if missing).
output_dir = "./Grenzgletscher_fk/plots"
os.makedirs(output_dir, exist_ok=True)

def clean_numeric_value(value):
    """Convert a possibly locale-formatted cell value to ``float``.

    String rules:

    * Scientific notation (contains ``E``/``e``): a comma is read as the
      decimal separator, e.g. ``"1,5E-3"`` -> ``0.0015``.
    * Contains a comma: periods are dropped as thousand separators and the
      comma becomes the decimal point, e.g. ``"1.234,5"`` -> ``1234.5`` and
      ``"0,25"`` -> ``0.25``.
    * Several periods: all but the last are dropped, e.g. ``"1.234.5"`` -> ``1234.5``.
    * Otherwise the string is parsed as-is.

    Unparseable strings, ``None``/NaN and ``''`` give ``NaN``. Non-string
    values are passed to ``float()``.
    """
    if pd.isna(value) or value == '':
        return np.nan

    if not isinstance(value, str):
        return float(value)

    if 'E' in value.upper():
        candidate = value.replace(',', '.')
    elif ',' in value:
        # Comma as decimal separator, consistent with the scientific branch above.
        candidate = value.replace('.', '').replace(',', '.')
    elif value.count('.') > 1:
        # Keep only the last period as decimal point.
        head, _, tail = value.rpartition('.')
        candidate = f"{head.replace('.', '')}.{tail}"
    else:
        candidate = value

    try:
        return float(candidate)
    except ValueError:
        return np.nan

def _read_feature_table(filepath):
    """Read a delimited text file, trying ';', ',' and tab as separators.

    With a wrong separator pandas usually returns a single column instead of
    raising, so a separator is only accepted if it yields more than one column.
    If none does, the first successful read is returned. If every read fails,
    the last error is re-raised.
    """
    first_result = None
    last_error = None
    for sep in (';', ',', '\t'):
        try:
            table = pd.read_csv(filepath, sep=sep)
        except Exception as e:  # ParserError, UnicodeDecodeError, ...
            last_error = e
            continue
        if table.shape[1] > 1:
            return table
        if first_result is None:
            first_result = table
    if first_result is None:
        raise last_error
    return first_result


def _add_synthetic_time_columns(df):
    """Fallback time axis: one event per hour starting 2024-03-19."""
    n_events = len(df)
    df['event_time'] = pd.date_range(start='2024-03-19', periods=n_events, freq='H')
    # Deltas agree with the hourly spacing of event_time above.
    df['time_delta_min'] = np.arange(n_events) * 60
    df['time_delta_sec'] = np.arange(n_events) * 3600
    df['time_delta_days'] = np.arange(n_events) / 24
    df['day_of_year'] = df['event_time'].dt.dayofyear
    df['hour_of_day'] = df['event_time'].dt.hour
    return df


def load_data(filepath):
    """Load a labeled feature table and prepare it for plotting.

    * ``.csv`` files are read with ';', ',' or tab as separator (the first
      one that gives more than one column). Other files use pandas defaults.
    * A ``class_name`` column is renamed to ``class`` if ``class`` is absent.
    * Columns other than ``event_time``/``class``/``class_name`` whose first
      ten non-null values contain a parseable number are cleaned with
      ``clean_numeric_value``.
    * ``event_time`` is parsed and sorted, and elapsed-time, day-of-year and
      hour columns are added. If it is missing or unparseable, a synthetic
      hourly time axis is used instead.

    Returns the DataFrame, or ``None`` if the file cannot be read.
    """
    print(f"Loading data from {filepath}")

    try:
        if filepath.endswith('.csv'):
            df = _read_feature_table(filepath)
        else:
            df = pd.read_csv(filepath)
    except Exception as e:
        print(f"Error reading file: {e}")
        return None

    print(f"Initial shape: {df.shape}")
    print(f"Columns: {list(df.columns)}")

    # Handle class column naming
    if 'class_name' in df.columns and 'class' not in df.columns:
        df = df.rename(columns={'class_name': 'class'})

    # Clean numeric columns
    numeric_columns = []
    for col in df.columns:
        if col in ['event_time', 'class', 'class_name']:
            continue
        try:
            # Decide from the first few non-null values whether the column is numeric.
            sample_values = df[col].dropna().head(10)
            cleaned_sample = [clean_numeric_value(val) for val in sample_values]
            if any(not pd.isna(val) for val in cleaned_sample):
                print(f"Cleaning numeric column: {col}")
                df[col] = df[col].apply(clean_numeric_value)
                numeric_columns.append(col)
        except Exception as e:
            print(f"Could not convert column {col} to numeric: {e}")

    print(f"Successfully converted {len(numeric_columns)} numeric columns")

    # Handle time column
    if 'event_time' in df.columns:
        try:
            df['event_time'] = pd.to_datetime(df['event_time'])
            df = df.sort_values('event_time')

            elapsed_sec = (df['event_time'] - df['event_time'].min()).dt.total_seconds()
            df['time_delta_min'] = elapsed_sec / 60
            df['time_delta_sec'] = elapsed_sec
            df['time_delta_days'] = elapsed_sec / (60 * 60 * 24)

            df['day_of_year'] = df['event_time'].dt.dayofyear
            df['hour_of_day'] = df['event_time'].dt.hour
        except Exception as e:
            print(f"Could not process event_time: {e}")
            df = _add_synthetic_time_columns(df)
    else:
        df = _add_synthetic_time_columns(df)

    print(f"Final shape: {df.shape}")
    print(f"Classes found: {sorted(df['class'].unique()) if 'class' in df.columns else 'No class column'}")
    print(f"Loaded {len(df)} events with {len(df.columns)} features")

    return df

def get_class_colors(classes):
    """Map each class label to a matplotlib 'tab' colour.

    Labels are sorted and take colours from the ten-colour tab cycle in that
    order, wrapping after ten. This works for string and numeric labels alike.
    An empty input gives an empty mapping (previously an IndexError).
    """
    palette = ('tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple',
               'tab:brown', 'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan')
    return {cls: palette[i % len(palette)] for i, cls in enumerate(sorted(classes))}

def save_plot(fig, filename):
    """Write *fig* into ``output_dir`` at 300 dpi, close it, and report the path."""
    destination = os.path.join(output_dir, filename)
    fig.savefig(destination, bbox_inches='tight', dpi=300)
    plt.close(fig)  # free the figure; many plots are generated in a loop
    print(f"Saved: {destination}")

def plot_boxplots_for_all_features(df):
    """Save one box plot per numeric feature, grouped by ``df['class']``.

    Absolute values are plotted, with jittered sample points overlaid (at most
    1000 per class). Time/bookkeeping columns and the class columns are
    excluded. Per-class summary statistics of the absolute values are written
    to ``summary_statistics_by_class_absolute.csv`` in ``output_dir``.
    """
    exclude_cols = ['event_time', 'time_delta_min', 'time_delta_sec', 'time_delta_days',
                    'day_of_year', 'hour_of_day', 'class', 'class_name']
    feature_cols = [col for col in df.columns
                    if col not in exclude_cols and pd.api.types.is_numeric_dtype(df[col])]

    if not feature_cols:
        print("No numeric feature columns found for box plots!")
        return

    print(f"Creating box plots for {len(feature_cols)} features")

    classes = sorted(df['class'].unique())
    class_colors = get_class_colors(classes)
    rng = np.random.default_rng(42)  # reproducible jitter and subsampling

    # Create box plots for each feature across all classes
    for i, col in enumerate(feature_cols):
        print(f"Processing feature {i+1}/{len(feature_cols)}: {col}")

        fig, ax = plt.subplots(figsize=(14, 8))

        # Only classes with values for this feature get a box. Their labels are
        # kept alongside the data so colours and points stay on the right box.
        plotted_classes = []
        class_data = []
        for cls in classes:
            values = df.loc[df['class'] == cls, col].dropna()
            if len(values) > 0:
                plotted_classes.append(cls)
                class_data.append(np.abs(values.values))

        if not class_data:
            print(f"No data found for feature {col}")
            plt.close(fig)
            continue

        try:
            bplot = ax.boxplot(class_data,
                               patch_artist=True,
                               notch=False,
                               showfliers=True,
                               widths=0.6,
                               medianprops={'color': 'red', 'linewidth': 1.5},
                               whiskerprops={'linewidth': 1.2},
                               capprops={'linewidth': 1.2},
                               boxprops={'linewidth': 1.2})

            # Color the boxes according to class colors
            for box, cls in zip(bplot['boxes'], plotted_classes):
                box.set(facecolor=class_colors[cls], alpha=0.6)

            # Jittered sample points; boxes sit at x = 1, 2, ...
            for position, (cls, y) in enumerate(zip(plotted_classes, class_data), start=1):
                if len(y) > 1000:  # limit number of points for readability
                    y = y[rng.choice(len(y), 1000, replace=False)]
                x = rng.normal(position, 0.08, size=len(y))
                ax.scatter(x, y, alpha=0.4, s=8, c=class_colors[cls], edgecolor='none')

            ax.set_title(f'Distribution of |{col}| by Class', fontsize=14, fontweight='bold')
            ax.set_ylabel(f'|{col}|', fontsize=12)
            ax.set_xlabel('Class', fontsize=12)
            ax.set_xticklabels([str(cls).capitalize() for cls in plotted_classes])
            ax.grid(True, linestyle='--', alpha=0.7)

            # Add statistics text (sample counts over the whole table)
            stats_text = f"Total samples: {len(df)}\n"
            for cls in classes:
                count = len(df[df['class'] == cls])
                stats_text += f"{str(cls).capitalize()}: {count} samples\n"

            ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
                    verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

            plt.tight_layout()

            # Clean filename
            clean_col_name = col.replace('/', '_').replace(' ', '_').replace('(', '').replace(')', '')
            save_plot(fig, f'boxplot_{clean_col_name}.png')

        except Exception as e:
            print(f"Error creating box plot for {col}: {e}")
            plt.close(fig)

    # Create summary statistics (using absolute values)
    try:
        print("\nCreating summary statistics...")
        abs_df = df.copy()
        for col in feature_cols:
            abs_df[col] = np.abs(df[col])

        summary_stats = abs_df[feature_cols + ['class']].groupby('class').agg(['mean', 'std', 'median', 'min', 'max'])
        summary_stats.to_csv(os.path.join(output_dir, 'summary_statistics_by_class_absolute.csv'))
        print("Summary statistics (absolute values) saved to summary_statistics_by_class_absolute.csv")
    except Exception as e:
        print(f"Error creating summary statistics: {e}")

def main():
    """Load the labeled feature table and write one box plot per feature.

    Prints an error and returns early if the file is missing, cannot be
    loaded, or has no 'class' column.
    """
    labeled_csv = "./Grenzgletscher_fk/known_features_labeled.csv"

    if not os.path.exists(labeled_csv):
        print(f"Error: File {labeled_csv} not found!")
        return

    table = load_data(labeled_csv)
    if table is None:
        print("Failed to load data!")
        return

    if 'class' not in table.columns:
        print("Error: No 'class' column found in the data!")
        return

    print("Starting box plot creation...")
    plot_boxplots_for_all_features(table)

    print("Box plot visualizations completed!")
    print(f"All plots saved to: {output_dir}")

if __name__ == "__main__":
    main()

# Random Forest
The Random Forest tutorial with the original code was written by Michaela Wenner and presented at the 9th Munich Earth Skience School (MESS 2019). The material is available online at https://github.com/krischer/mess_2019/blob/master/3_wednesday/random_forest.ipynb.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.utils.class_weight import compute_class_weight
import joblib
import warnings

# Load the labeled feature table (semicolon-separated)
df = pd.read_csv("./Grenzgletscher_fk/features_labeled.csv", sep=";")

# Split features and labels
label_column = "class_name"

# Drop the event_time column - we don't want time as a feature
X = df.drop(columns=[label_column, "event_time"])
y = df[label_column]

# Analyze class distribution
print(f"\nClass distribution:")
class_counts = pd.Series(y).value_counts()
print(class_counts)
print(f"\nClass percentages:")
class_percentages = (class_counts / len(y) * 100).round(2)
print(class_percentages)

# Identify classes with very few samples (less than 2% of data or less than 10 samples)
min_samples_threshold = max(10, len(y) * 0.02)
rare_classes = class_counts[class_counts < min_samples_threshold].index.tolist()
if rare_classes:
    print(f"\nWarning: Classes with very few samples (< {min_samples_threshold:.0f}): {rare_classes}")
    print("Consider combining these with similar classes or collecting more data.")

# Encode labels if they're words
le = LabelEncoder()
y_encoded = le.fit_transform(y)
class_names = le.classes_
# Every encoded label, so reports/matrices always have one row per class.
all_labels = np.arange(len(class_names))

print(f"\nClasses found: {class_names}")

# Class weights, printed for information (the classifier below uses class_weight='balanced')
class_weights = compute_class_weight('balanced', classes=np.unique(y_encoded), y=y_encoded)
class_weight_dict = dict(zip(np.unique(y_encoded), class_weights))
print(f"\nClass weights for balancing:")
for i, weight in class_weight_dict.items():
    print(f"  {class_names[i]}: {weight:.3f}")

# Split into training and testing sets.
# NOTE: stratify raises a ValueError if any class has fewer than 2 samples.
X_train, X_test, y_train, y_test = train_test_split(
    X, y_encoded, test_size=0.3, random_state=42, stratify=y_encoded
)

print(f"\nTraining set size: {X_train.shape}")
print(f"Test set size: {X_test.shape}")

# Check class distribution in train/test sets
print(f"\nTraining set class distribution:")
train_class_counts = pd.Series(y_train).value_counts()
for i, count in train_class_counts.items():
    print(f"  {class_names[i]}: {count}")

print(f"\nTest set class distribution:")
test_class_counts = pd.Series(y_test).value_counts()
for i, count in test_class_counts.items():
    print(f"  {class_names[i]}: {count}")

# Train the Random Forest classifier with class balancing
clf = RandomForestClassifier(
    n_estimators=300,
    max_depth=25,
    min_samples_split=5,
    min_samples_leaf=2,
    class_weight='balanced',  # Handle class imbalance
    oob_score=True,
    random_state=42,
    bootstrap=True,
    n_jobs=-1  # Use all available cores
)

print("\nTraining Random Forest model with balanced class weights...")
clf.fit(X_train, y_train)

# Cross-validation to get more robust performance estimate
cv_scores = cross_val_score(clf, X_train, y_train, cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=42), scoring='accuracy')
print(f"\nCross-validation scores: {cv_scores}")
print(f"Mean CV accuracy: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")

# Evaluate the model
y_pred = clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)

print(f"\nModel Performance:")
print(f"Test Accuracy: {accuracy:.4f}")
print(f"OOB Score: {clf.oob_score_:.4f}")

# Classification report; explicit labels keep target_names aligned even if a
# class is absent from both y_test and y_pred.
print("\nClassification Report:")
print(classification_report(
    y_test, y_pred,
    labels=all_labels,
    target_names=class_names,
    zero_division=0
))

# Detailed per-class analysis
print("\nDetailed per-class analysis:")
for i, class_name in enumerate(class_names):
    true_count = np.sum(y_test == i)
    pred_count = np.sum(y_pred == i)
    correct_count = np.sum((y_test == i) & (y_pred == i))

    recall = correct_count / true_count if true_count > 0 else 0
    precision = correct_count / pred_count if pred_count > 0 else 0

    print(f"  {class_name}:")
    print(f"    True samples: {true_count}, Predicted: {pred_count}, Correct: {correct_count}")
    print(f"    Precision: {precision:.3f}, Recall: {recall:.3f}")

# Confusion Matrix in percentages (row-normalised: each row sums to 100 %).
# Named conf_mat so it does not shadow the notebook's `matplotlib.cm as cm` import.
cm_filename = "confusion_matrix_percent_0.2.png"
plt.figure(figsize=(12, 10))
conf_mat = confusion_matrix(y_test, y_pred, labels=all_labels)
row_totals = conf_mat.sum(axis=1, keepdims=True)
# Rows without true samples would divide by zero; report them as 0 %.
cm_percent = np.divide(conf_mat.astype(float) * 100, row_totals,
                       out=np.zeros(conf_mat.shape, dtype=float),
                       where=row_totals > 0)
sns.heatmap(cm_percent, annot=True, fmt='.1f', xticklabels=class_names, yticklabels=class_names,
            cmap="Blues", cbar_kws={'label': 'Percentage (%)'})
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix (in percent)")
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(cm_filename, dpi=300, bbox_inches='tight')
plt.show()

# Feature Importance
importances = pd.Series(clf.feature_importances_, index=X.columns)

# Most Important Features
plt.figure(figsize=(12, 8))
importances.nlargest(15).plot(kind='barh', color='teal')
plt.title("Top 15 Most Important Features")
plt.xlabel("Relative Importance")
plt.tight_layout()
plt.show()

# Least Important Features
plt.figure(figsize=(12, 8))
importances.nsmallest(15).plot(kind='barh', color='coral')
plt.title("Top 15 Least Important Features")
plt.xlabel("Relative Importance")
plt.tight_layout()
plt.show()

# Print top features
print(f"\nTop 10 Most Important Features:")
for i, (feature, importance) in enumerate(importances.nlargest(10).items(), 1):
    print(f"{i:2d}. {feature}: {importance:.4f}")

# Print least important features
print(f"\nTop 10 Least Important Features:")
for i, (feature, importance) in enumerate(importances.nsmallest(10).items(), 1):
    print(f"{i:2d}. {feature}: {importance:.4f}")

# Analyze prediction confidence
y_pred_proba = clf.predict_proba(X_test)
max_probabilities = np.max(y_pred_proba, axis=1)

print(f"\nPrediction confidence analysis:")
print(f"Mean prediction confidence: {max_probabilities.mean():.3f}")
print(f"Min prediction confidence: {max_probabilities.min():.3f}")
print(f"Max prediction confidence: {max_probabilities.max():.3f}")

# Find low-confidence predictions
low_confidence_threshold = 0.5
low_confidence_mask = max_probabilities < low_confidence_threshold
if np.any(low_confidence_mask):
    print(f"\nFound {np.sum(low_confidence_mask)} predictions with confidence < {low_confidence_threshold}")
    print("These might be misclassified or difficult cases.")

# Save Model and encoders
print("\nSaving model and encoders...")
joblib.dump(clf, "random_forest_model.pkl")
joblib.dump(le, "label_encoder.pkl")

# Save feature names so prediction can select the same columns in the same order
feature_names = X.columns.tolist()
joblib.dump(feature_names, "feature_names.pkl")

print("Model training and evaluation complete!")
print(f"Files saved:")
print(f"  - Model: random_forest_model.pkl")
print(f"  - Label encoder: label_encoder.pkl")
print(f"  - Feature names: feature_names.pkl")
print(f"  - Confusion matrix: {cm_filename}")

# Additional recommendations
if rare_classes:
    print(f"\nRecommendations:")
    print(f"Consider collecting more data for classes: {rare_classes}")

# Apply the trained Random Forest to a new dataset
## The dataset must not contain a class-name column
### The file is called 2025_features.csv

In [None]:
# Load Model again

import joblib
import pandas as pd

import os

# Load model, label encoder and the training feature list.
# joblib uses pickle: only load model files you created yourself.
clf = joblib.load("random_forest_model.pkl")
le = joblib.load("label_encoder.pkl")
feature_names = joblib.load("feature_names.pkl")

# Load new data to classify (must contain the training feature columns; no 'class_name').
# The feature-extraction cell writes ./Grenzgletscher_fk/2025_features.csv; a copy in
# the working directory takes precedence.
new_data_file = "2025_features.csv"
if not os.path.exists(new_data_file):
    new_data_file = "./Grenzgletscher_fk/2025_features.csv"
new_data = pd.read_csv(new_data_file)

# Predict on exactly the training columns, in training order. This excludes
# event_time, a string column the model never saw; passing it makes predict fail.
predictions = clf.predict(new_data[feature_names])

# Decode predicted labels
predicted_labels = le.inverse_transform(predictions)
print(predicted_labels)

# Save predictions to csv
output = pd.DataFrame({
    "event_time": new_data["event_time"],
    "predicted_event": predicted_labels
})
output.to_csv("2025_predictions.csv", index=False)