In [1]:
import sys
# Install scikit-learn with the interpreter that runs this kernel
# (sys.executable), so the package lands in the kernel's own environment.
# NOTE(review): none of the cells visible here import scikit-learn — confirm
# it is still needed, and consider pinning a version for reproducibility.
!{sys.executable} -m pip install scikit-learn



In [13]:
import os
import numpy as np

# Folder intended for plot output; it is created here if it does not exist yet.
plot_folder = os.path.expanduser("~/plots")
os.makedirs(plot_folder, exist_ok=True)

# Recording to analyse: files are looked up under <data_folder>/<rat>-<day>/,
# one set of files per entry in `sessions`.
data_folder = "/dda2/enia"  # Set full path to your data folder
rat = "jc320"
day = "240924"
sessions = ["training1"]

# ===> Data loading (edit the parameters above as needed)
def load_data(data_folder, rat, day, sessions):
    """Load and concatenate the recording files of one rat/day across sessions.

    Every file is read from ``<data_folder>/<rat>-<day>/``. Per-session files
    are read in the order given by ``sessions`` and concatenated in that order.

    Parameters
    ----------
    data_folder : str
        Folder that contains the ``<rat>-<day>`` recording directory.
    rat, day : str
        Identifiers used to build the directory and file names.
    sessions : iterable of str
        Session names; at least one is required.

    Returns
    -------
    tuple
        ``(res, clu, putative_type, whl, speed, reward_arms, all_arms,
        trials_segments, lwhl_raw, lwhl, lwhl_speed)``:

        * ``res``, ``clu`` – 1-D int arrays from the ``.res`` / ``.clu`` files
          (the first entry of each ``.clu`` file is dropped).
        * ``putative_type`` – DataFrame with one ``"type"`` column: the rows
          ``"noise"`` and ``"multiunit"`` followed by the rows of ``.des``.
        * ``whl`` – DataFrame of the header-less ``.whl`` files.
        * ``speed`` – 1-D array, first numeric column of each ``.speed`` file.
        * ``reward_arms`` – one list of ints per session file.
        * ``all_arms``, ``trials_segments`` – one list of ints per non-empty
          line of the ``.all_arms`` / ``.trials`` files.
        * ``lwhl_raw``, ``lwhl``, ``lwhl_speed`` – DataFrames read with the
          first line as header.

    Raises
    ------
    ValueError
        If ``sessions`` is empty or no ``.speed`` file could be read.
    """
    import pandas as pd

    sessions = list(sessions)
    if not sessions:
        raise ValueError("load_data needs at least one session name")

    base_path = os.path.join(data_folder, f"{rat}-{day}")

    # ----------- Helpers -----------
    def session_file(session, extension, sep="-"):
        # NOTE(review): some per-session files are named "<rat>-<day>_..." and
        # others "<rat>_<day>_..."; the separators below reproduce the names
        # this notebook has always used — confirm against the files on disk.
        return os.path.join(base_path, f"{rat}{sep}{day}_{session}.{extension}")

    def load_des_with_noise(des_path):
        # Prepend the two extra cluster types ahead of the rows read from .des.
        des = pd.read_csv(des_path, header=None, names=["type"])
        noise_df = pd.DataFrame({"type": ["noise", "multiunit"]})
        return pd.concat([noise_df, des], ignore_index=True)

    def read_all_ints(path):
        # Whole file as one flat list of ints; the handle is closed on exit.
        with open(path) as handle:
            return [int(token) for token in handle.read().split()]

    def read_int_rows(paths):
        # One list of ints per non-blank line, over all files in order.
        rows = []
        for path in paths:
            with open(path, "r") as handle:
                for line in handle:
                    fields = line.split()
                    if fields:
                        rows.append([int(x) for x in fields])
        return rows

    def read_tables(extension, sep, header):
        frames = [
            pd.read_csv(session_file(s, extension, sep), sep=r"\s+", header=header)
            for s in sessions
        ]
        return pd.concat(frames, ignore_index=True)

    # ----------- .res and .clu -----------
    res_all = []
    clu_all = []
    for session in sessions:
        print(f"🔄 Loading session: {session}")
        res_all.append(np.loadtxt(session_file(session, "res"), dtype=int))
        # skip first element of every .clu file
        clu_all.append(np.loadtxt(session_file(session, "clu"), dtype=int)[1:])

    res = np.concatenate(res_all)
    clu = np.concatenate(clu_all)
    print(f"res length: {len(res)}")
    print(f"clu length: {len(clu)} (should match res)")

    # ----------- Load .des with noise/multiunit types prepended -----------
    des_path = os.path.join(base_path, f"{rat}-{day}.des")
    putative_type = load_des_with_noise(des_path)
    print(f"des loaded: {len(putative_type)} total cluster types")

    # ----------- Load and concat .whl (no header) -----------
    whl = read_tables("whl", "-", header=None)
    print(f"whl shape: {whl.shape}")

    # ----------- Load and concat .speed robustly (no header) -----------
    speed_all = []
    for session in sessions:
        speed_path = session_file(session, "speed", sep="_")
        try:
            speed_df = pd.read_csv(speed_path, sep=None, engine="python", header=None)
            speed_col = speed_df.select_dtypes(include=[np.number]).iloc[:, 0]
            speed_all.append(speed_col.values)
        except Exception as e:
            # Best effort, as before: report the file and keep going.
            # NOTE(review): a skipped session leaves `speed` shorter than `whl`.
            print(f"Error loading {speed_path}: {e}")

    if not speed_all:
        # Previously surfaced as numpy's "need at least one array to concatenate".
        raise ValueError("no .speed file could be loaded; see the errors printed above")
    speed = np.concatenate(speed_all)
    print(f"speed length: {len(speed)}")
    print(f"NaNs in speed: {np.isnan(speed).sum()} / {len(speed)}")

    # ----------- Load reward_arms and all_arms -----------
    reward_arms = [read_all_ints(session_file(s, "reward_arms")) for s in sessions]
    print(f"Loaded {len(reward_arms)} reward_arms files")

    all_arms = read_int_rows(session_file(s, "all_arms") for s in sessions)
    print(f"Loaded {len(all_arms)} trials from all_arms files")
    if all_arms:  # nothing to show when the files contain no trials
        print("Example trial arms:", all_arms[0])

    # ----------- Load and parse .trials (no header) -----------
    trials_segments = read_int_rows(session_file(s, "trials", sep="_") for s in sessions)
    print(f"Loaded {len(trials_segments)} trial segments from .trials")

    # ----------- Load .lwhl_raw, .lwhl and .lwhl_speed (WITH header) -----------
    lwhl_raw = read_tables("lwhl_raw", "_", header=0)
    lwhl = read_tables("lwhl", "_", header=0)
    lwhl_speed = read_tables("lwhl_speed", "_", header=0)

    print(f"lwhl shape: {lwhl.shape}")
    print(f"lwhl_speed shape: {lwhl_speed.shape}")

    return res, clu, putative_type, whl, speed, reward_arms, all_arms, trials_segments, lwhl_raw, lwhl, lwhl_speed

In [14]:
def clean_speed(speed):
    """Return a cleaned copy of ``speed``: NaNs become 0.0, values become absolute.

    The input is not modified. The number of cleaned values is printed.
    """
    cleaned = np.abs(np.nan_to_num(speed, nan=0.0))
    print(f"Cleaned speed: {len(cleaned)} values")
    return cleaned

In [15]:
def extract_trials_events_armids(trials_segments, all_arms=None):
    """Split trial timestamp rows into trial intervals, event intervals and arm ids.

    Parameters
    ----------
    trials_segments : list of sequences
        One row of timestamps per trial: ``[trial_start, <events...>, trial_end]``
        where every event contributes 4 timestamps ``(in, reward1, reward2, out)``.
    all_arms : list of sequences, optional
        Visited arm ids, one sequence per trial. When omitted, the notebook-level
        variable ``all_arms`` is used (the behaviour before this parameter existed).

    Returns
    -------
    trial_intervals : list of [start, end]
        First and last timestamp of every row with at least 2 entries.
    event_intervals : list of [t_in, t_out]
        One interval per complete 4-timestamp event; trailing incomplete
        events are skipped.
    event_arm_ids : list
        ``all_arms`` flattened in order.

    Raises
    ------
    AssertionError
        If the number of events differs from the number of arm ids.
    NameError
        If ``all_arms`` is neither passed nor defined at notebook level.
    """
    if all_arms is None:
        # Backward compatibility with the one-argument call used in this notebook.
        try:
            all_arms = globals()["all_arms"]
        except KeyError:
            raise NameError(
                "name 'all_arms' is not defined; pass it via the all_arms argument"
            ) from None

    trial_intervals = [
        [segment[0], segment[-1]] for segment in trials_segments if len(segment) >= 2
    ]
    print(f"Total trials: {len(trial_intervals)}")
    if trial_intervals:  # avoid IndexError when there are no trials
        print("First trial:", trial_intervals[0])

    event_intervals = []
    for trial in trials_segments:
        # A row needs start, end and at least one timestamp in between.
        if len(trial) < 3:
            continue

        # Drop trial start/end to isolate the event timestamps.
        events = trial[1:-1]

        # Each event has 4 timestamps: in, reward1, reward2, out
        for i in range(0, len(events), 4):
            chunk = events[i:i + 4]
            if len(chunk) < 4:
                continue  # skip incomplete events
            event_intervals.append([chunk[0], chunk[3]])

    # Flatten all_arms; it must line up one-to-one with the events.
    event_arm_ids = [arm for trial_arms in all_arms for arm in trial_arms]

    if len(event_intervals) != len(event_arm_ids):
        # Raised explicitly so the check also runs under `python -O`.
        raise AssertionError(
            f"Mismatch: {len(event_intervals)} events vs {len(event_arm_ids)} arm_ids"
        )
    print(f"Extracted {len(event_intervals)} events with arm IDs.")

    return trial_intervals, event_intervals, event_arm_ids

In [16]:
def convert_nested_list_to_ms(nested_list, sampling_rate_hz):
    """Convert sample-index timestamps to milliseconds.

    Every inner sequence of ``nested_list`` is turned into a NumPy array whose
    values are ``value / sampling_rate_hz * 1000``. The result is a list of
    those arrays, in the input order.
    """
    converted = []
    for samples in nested_list:
        seconds = np.array(samples) / sampling_rate_hz
        converted.append(seconds * 1000)
    return converted

In [17]:
# Load every file of the configured recording (see the parameters cell).
res, clu, putative_type, whl, speed, reward_arms, all_arms, trials_segments, lwhl_raw, lwhl, lwhl_speed = load_data(data_folder, rat, day, sessions)

# Replace NaNs in the speed trace with 0 and take absolute values.
# Note: this rebinds `speed`, so re-running only this line works on cleaned data.
speed = clean_speed(speed)

🔄 Loading session: training1
res length: 16382116
clu length: 16382116 (should match res)
des loaded: 508 total cluster types
whl shape: (295727, 2)
speed length: 295727
NaNs in speed: 39 / 295727
Loaded 1 reward_arms files
Loaded 39 trials from all_arms files
Example trial arms: [4, 7, 1, 5, 1, 2, 8]
Loaded 39 trial segments from .trials
lwhl shape: (295727, 7)
lwhl_speed shape: (295727, 4)
Cleaned speed: 295727 values


In [18]:
# Sampling rate used to convert timestamps to milliseconds (1000 / 39.0625 =
# 25.6 ms per sample).
# NOTE(review): assumes the .trials timestamps are expressed in these sample
# units — confirm against the acquisition pipeline.
WHEEL_SAMPLING_RATE_HZ = 39.0625

trial_intervals, event_intervals, event_arm_ids = extract_trials_events_armids(trials_segments)

trials_segments_ms = convert_nested_list_to_ms(trials_segments, sampling_rate_hz=WHEEL_SAMPLING_RATE_HZ)
event_intervals_ms = convert_nested_list_to_ms(event_intervals, sampling_rate_hz=WHEEL_SAMPLING_RATE_HZ)

Total trials: 39
First trial: [0, 5793]
Extracted 180 events with arm IDs.


In [19]:
import numpy as np
import pandas as pd
from typing import List, Tuple, Optional

def extract_event_trial_ids(
    lwhl: pd.DataFrame,
    events_ms: List[Tuple[float, float]],
    *,
    sampling_rate_wheel: float = 39.0625,
    trial_col: str = "trial_id",
    min_fraction: float = 0.5,
    fill_on_empty: bool = True,
) -> np.ndarray:
    """Assign to every event the most frequent trial id inside its time interval.

    Rows of ``lwhl`` are treated as consecutive samples taken at
    ``sampling_rate_wheel`` Hz (25.6 ms per sample at the default rate); row
    position, not the DataFrame index, is used to locate samples.

    Parameters
    ----------
    lwhl : pandas.DataFrame
        Table containing the column ``trial_col``.
    events_ms : sequence of (start_ms, end_ms)
        Event intervals in milliseconds, interpreted as ``[start_ms, end_ms)``.
    sampling_rate_wheel : float
        Sampling rate of the rows of ``lwhl`` in Hz.
    trial_col : str
        Name of the column holding the trial id; missing values are ignored
        when computing the mode.
    min_fraction : float
        Minimum share of the valid samples the mode must reach; below it the
        event is considered ambiguous and labelled -1.
    fill_on_empty : bool
        When the interval holds no valid sample: if True, use the value of the
        sample closest to ``start_ms`` (or -1 if that one is missing too); if
        False, label the event -1.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(len(events_ms),)``; -1 marks ambiguous or
        unknown events. Ties for the mode resolve to the smallest trial id.

    Raises
    ------
    AssertionError
        If ``trial_col`` is not a column of ``lwhl``.
    """
    if trial_col not in lwhl.columns:
        # Explicit raise keeps the AssertionError contract even under `python -O`.
        raise AssertionError(f"Columna '{trial_col}' no encontrada en lwhl")

    ms_per_sample = 1000.0 / float(sampling_rate_wheel)

    # Trial id per sample as a plain array (may contain missing values).
    col = lwhl[trial_col].to_numpy()
    n = len(col)

    trial_ids = np.full(len(events_ms), -1, dtype=int)
    if n == 0:
        # No samples at all: nothing can be looked up, every event stays -1.
        return trial_ids

    for i, (start_ms, end_ms) in enumerate(events_ms):
        # Sample range [start_idx, end_idx) covered by the event.
        start_idx = int(np.floor(start_ms / ms_per_sample))
        end_idx = int(np.ceil(end_ms / ms_per_sample))

        # Clamp into the table and keep at least one sample (end is exclusive).
        start_idx = max(0, min(start_idx, n - 1))
        end_idx = max(start_idx + 1, min(end_idx, n))

        seg = col[start_idx:end_idx]
        # pd.isna covers float NaN as well as None / pd.NA in object columns.
        seg_valid = seg[~pd.isna(seg)]

        if seg_valid.size == 0:
            if fill_on_empty:
                # Fall back to the sample closest to the event start.
                near_idx = int(np.clip(round(start_ms / ms_per_sample), 0, n - 1))
                val = col[near_idx]
                if not pd.isna(val):
                    trial_ids[i] = int(val)
            continue

        # Mode of the valid samples and the share of samples it accounts for.
        vals, counts = np.unique(seg_valid.astype(int), return_counts=True)
        j = int(np.argmax(counts))
        frac = int(counts[j]) / float(seg_valid.size)

        if frac >= float(min_fraction):
            trial_ids[i] = int(vals[j])

    return trial_ids

In [24]:
import numpy as np

# Trial label of every event (-1 = ambiguous / unknown).
trial_ids_all = extract_event_trial_ids(lwhl, event_intervals_ms, sampling_rate_wheel=39.0625)

# Events whose arm is one of the rewarded arms.
# NOTE(review): reward_arms holds one list per session; np.array(...) needs
# those lists to have equal length and np.isin pools the arms of all sessions —
# revisit before loading more than one session.
rewarded_arms_session = np.array(reward_arms, dtype=int)
rewarded_mask = np.isin(event_arm_ids, rewarded_arms_session)

idx_rewards = np.where(rewarded_mask)[0]
trial_ids_rewarded = trial_ids_all[idx_rewards]

# For every event except the last one: label of the trial of the NEXT event,
# i.e. the label sequence shifted by one position.
pre_trial_ids_all = list(trial_ids_all[1:])
pre_trial_ids_rewarded = list(trial_ids_rewarded[1:])

In [26]:
import pickle

# Persist the loaded arrays and the derived event / trial labels in a single
# pickle file, written to the current working directory.
with open("raw_data.pkl", "wb") as f:
    pickle.dump(
        dict(
            res=res,
            clu=clu,
            putative_type=putative_type,
            whl=whl,
            speed=speed,
            reward_arms=reward_arms,
            all_arms=all_arms,
            lwhl=lwhl,
            trials_segments_ms=trials_segments_ms,
            event_intervals_ms=event_intervals_ms,
            event_arm_ids=event_arm_ids,
            trial_ids_all=trial_ids_all,
            trial_ids_rewarded=trial_ids_rewarded,
        ),
        f,
    )

print("Saved as raw_data.pkl")

Saved as raw_data.pkl
