# Automatic artifact detection

In [None]:
import mne
import numpy as np
import yasa
import os
import logging

# Setup logging: every run appends to artifact_detection.log in the working directory.
logging.basicConfig(filename='artifact_detection.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Define your directory paths
data_dir = r'D:\EDF'                    # input recordings (*.edf)
hypno_dir = r'D:\hypno'                 # hypnograms, one '<participant>.hyp.txt' per recording
output_dir = r'D:\automatic_artefacts'  # results, one '<participant>.art.txt' per recording
# Window length (seconds) handed to the artifact detector.
window = 2

# On a first run the output directory may not exist yet; without this the
# listing below raises FileNotFoundError before anything is processed.
os.makedirs(output_dir, exist_ok=True)

# Get a list of all EDF files in the data directory. The extension is matched
# case-insensitively ('.EDF' is common on Windows) and the list is sorted so
# the processing order is reproducible between runs.
edf_files = sorted(f for f in os.listdir(data_dir) if f.lower().endswith('.edf'))

# Get a list of already processed files to skip them. Only the trailing
# '.art.txt' is stripped, so a name that merely contains that text elsewhere
# is left intact.
art_suffix = '.art.txt'
processed_files = {f[:-len(art_suffix)] for f in os.listdir(output_dir) if f.endswith(art_suffix)}

# Process each file

# Conventions shared by the hypnogram input and the artifact output.
HYPNO_EPOCH_SECS = 20   # one hypnogram row covers 20 s of recording
OUTPUT_EPOCH_SECS = 4   # one output row covers 4 s of recording
REM_STAGE = 2           # hypnogram code that is treated as REM sleep
CODE_CLEAN = 1          # REM, no artifact detected
CODE_ARTIFACT = 2       # REM, artifact detected
CODE_NOT_REM = 3        # outside REM: artifact detection is not run there


def find_stage_runs(hypno, stage=REM_STAGE, epoch_secs=HYPNO_EPOCH_SECS):
    """Return the contiguous runs of `stage` in a hypnogram, in seconds.

    Parameters
    ----------
    hypno : array-like of int
        One stage code per hypnogram epoch. A single value (what
        ``np.loadtxt`` yields for a one-line file) is accepted as well.
    stage : int
        Stage code to look for.
    epoch_secs : int
        Duration of one hypnogram epoch in seconds.

    Returns
    -------
    list of (int, int)
        ``(start, end)`` pairs in seconds, ``end`` exclusive, in
        chronological order. Empty when `stage` never occurs.
    """
    is_stage = np.asarray(hypno).ravel() == stage
    # Pad with False on both sides so that runs touching either end of the
    # hypnogram still produce a rising and a falling edge.
    padded = np.concatenate(([False], is_stage, [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])
    # Edges alternate rising/falling: even positions are run starts, odd
    # positions are the (exclusive) run ends, both as epoch indices.
    return [(int(first) * epoch_secs, int(stop) * epoch_secs)
            for first, stop in zip(edges[::2], edges[1::2])]


def mark_epochs(codes, window_flags, start_time, window, epoch_secs=OUTPUT_EPOCH_SECS):
    """Project per-window artifact flags onto the output epoch grid, in place.

    The detector scores windows of `window` seconds, while `codes` has one
    entry per `epoch_secs` seconds. Writing one flag per entry is only correct
    when both lengths are equal; otherwise the flags drift away from the time
    they describe. Here each window is mapped to the output epoch(s) it
    overlaps in time: an epoch becomes CODE_ARTIFACT if any overlapping window
    is flagged, and CODE_CLEAN if it is covered only by unflagged windows.

    Parameters
    ----------
    codes : numpy.ndarray of int
        Output array, modified in place. Entries beyond its end are ignored.
    window_flags : array-like
        One truthy/falsy value per consecutive detector window.
    start_time : int or float
        Time in seconds (from recording start) of the first window.
    window : int or float
        Detector window length in seconds.
    epoch_secs : int
        Length in seconds of one entry of `codes`.

    Returns
    -------
    numpy.ndarray
        The same `codes` array, for convenience.
    """
    for k, flagged in enumerate(np.asarray(window_flags, dtype=bool)):
        # Rounding keeps fractional window lengths from nudging a boundary
        # into the neighbouring epoch through floating-point noise.
        t0 = round(start_time + k * window, 9)
        t1 = round(t0 + window, 9)
        first = int(t0 // epoch_secs)
        stop = min(len(codes), int(np.ceil(t1 / epoch_secs)))
        covered = codes[first:stop]  # a view: assignments write into `codes`
        if flagged:
            covered[:] = CODE_ARTIFACT
        else:
            # Never downgrade an epoch already marked by a flagged window.
            covered[covered == CODE_NOT_REM] = CODE_CLEAN
    return codes


def detect_artifacts(raw_path, hypnogram_path, window):
    """Run artifact detection on the REM periods of one recording.

    Parameters
    ----------
    raw_path : str
        Path of the EDF recording.
    hypnogram_path : str
        Path of the text hypnogram (one integer stage code per row).
    window : int or float
        Detector window length in seconds, passed to ``yasa.art_detect``.

    Returns
    -------
    numpy.ndarray of int
        One code per OUTPUT_EPOCH_SECS seconds of recording: CODE_NOT_REM,
        CODE_CLEAN or CODE_ARTIFACT.
    """
    # Load the EDF file without preloading data to get metadata; the signal
    # itself is loaded per REM segment below.
    raw = mne.io.read_raw_edf(raw_path, preload=False)
    sf = raw.info['sfreq']

    # Total duration in whole seconds and the matching (rounded-up) number of
    # output epochs.
    total_duration_secs = int(raw.times[-1])
    n_epochs = (total_duration_secs + OUTPUT_EPOCH_SECS - 1) // OUTPUT_EPOCH_SECS
    codes = np.full(n_epochs, CODE_NOT_REM, dtype=int)

    hypno = np.loadtxt(hypnogram_path, dtype=int)
    for start_time, end_time in find_stage_runs(hypno):
        # Adjust end_time if it's beyond the recording's limit
        if end_time > total_duration_secs:
            end_time = total_duration_secs - 1  # Reduce by one second to avoid cropping error

        # A hypnogram that outlasts the recording can leave a segment that is
        # empty or shorter than one detector window. Cropping it would raise
        # and discard the whole recording, so it is skipped instead.
        if end_time - start_time < window:
            logging.warning('Skipping REM segment %s-%s s of %s: shorter than one %s s window',
                            start_time, end_time, raw_path, window)
            continue

        segment = raw.copy().crop(tmin=start_time, tmax=end_time).load_data()
        # art_detect returns (flags, z-scores); only the per-window flags are used.
        window_flags = yasa.art_detect(segment.get_data(), sf, window=window)[0]
        mark_epochs(codes, window_flags, start_time, window)

    return codes


def main():
    """Process every not-yet-processed EDF file and write its artifact codes.

    A failure on one recording is logged with its traceback and does not stop
    the batch.
    """
    for edf_file in edf_files:
        participant = edf_file.split('.')[0]
        if participant in processed_files:
            continue  # Skip already processed files

        try:
            raw_path = os.path.join(data_dir, edf_file)
            hypnogram_path = os.path.join(hypno_dir, f'{participant}.hyp.txt')
            output_path = os.path.join(output_dir, f'{participant}.art.txt')

            codes = detect_artifacts(raw_path, hypnogram_path, window)
            np.savetxt(output_path, codes, fmt='%d', newline='\r')
            logging.info(f'Successfully processed and saved {participant}')
        except Exception as e:
            # Batch boundary: record the full traceback and move on.
            logging.exception(f'Failed to process {participant}: {str(e)}')

    logging.info('All files processed.')


if __name__ == '__main__':
    main()
