In [1]:
import csv
import getpass
import os
from argparse import ArgumentParser
from pathlib import Path

import mysql.connector
import numpy as np
import pandas as pd

In [2]:
# Recording folder and the sorter output directory inside it.
root = Path(r"D:\recording\2024-05-02-13-37-29")
results_dir = root / 'kilosort4'


def _read_cluster_column(filename, column):
    """Return one column of a tab-separated table in `results_dir` as an array."""
    table = pd.read_csv(results_dir / filename, sep='\t')
    return table[column].values


# One value per cluster
camps = _read_cluster_column('cluster_Amplitude.tsv', 'Amplitude')
clu_labels = _read_cluster_column('cluster_KSLabel.tsv', 'KSLabel')
contam_pct = _read_cluster_column('cluster_ContamPct.tsv', 'ContamPct')
chan_map = np.load(results_dir / 'channel_map.npy')
templates = np.load(results_dir / 'templates.npy')
# Sum of squares over axis 1, then the position of the largest value along the
# last axis; that position is translated through channel_map.
# NOTE(review): assumes templates is (n_templates, n_timepoints, n_channels) — confirm.
template_power = (templates ** 2).sum(axis=1)
chan_best = chan_map[template_power.argmax(axis=-1)]

# One value per spike
amplitudes = np.load(results_dir / 'amplitudes.npy')
st = np.load(results_dir / 'spike_times.npy')
clu = np.load(results_dir / 'spike_clusters.npy')

# Divisor that write_recording uses to turn sample offsets into times;
# presumably the sampling rate in Hz — TODO confirm.
sr = 20_000

# concatenated.csv: a header line followed by one (subfolder, sample count) row
# per recording folder; rows are kept in file order.
subfolders, n_samples = [], []
with open(root / 'concatenated.csv', newline='') as f:
    csvfile = csv.reader(f)
    next(csvfile, None)  # skip header (tolerates a completely empty file)
    for row in csvfile:
        if not row:
            # csv.reader yields [] for blank lines (e.g. a trailing newline);
            # unpacking those into two names would raise.
            continue
        subfolder, n = row
        subfolders.append(subfolder)
        n_samples.append(int(n))

In [3]:
# Per-folder sample counts as an array, plus their running totals.
n_samples = np.array(n_samples)
cum_samples = n_samples.cumsum()
# Same running totals with a leading 0, i.e. one entry more than cum_samples.
cum_samples0 = np.concatenate(([0], cum_samples))

In [4]:
# Connection settings are read from the environment so that the password is
# not stored in the notebook. User, host and database fall back to the values
# this notebook has always used; the password has no fallback and is prompted
# for when SPIKEDATA_DB_PASSWORD is unset.
db = mysql.connector.connect(
    user=os.environ.get('SPIKEDATA_DB_USER', 'xper_rw'),
    password=os.environ.get('SPIKEDATA_DB_PASSWORD') or getpass.getpass('SpikeData DB password: '),
    host=os.environ.get('SPIKEDATA_DB_HOST', '172.30.6.54'),
    database=os.environ.get('SPIKEDATA_DB_NAME', 'SpikeData'),
    # Both settings below are needed for the LOAD DATA LOCAL INFILE statement
    # issued by write_recording.
    client_flags=[mysql.connector.constants.ClientFlag.LOCAL_FILES],
    allow_local_infile=True,
)

In [5]:
def _notefile_for(subfolder):
    """Derive a notefile name: 'experiment' -> 'notes', then append '.txt'."""
    return subfolder.replace('experiment', 'notes') + '.txt'


# One notefile name per subfolder, in the same order as `subfolders`.
notefiles = list(map(_notefile_for, subfolders))
# The same names, each wrapped in double quotes and comma-separated.
nf_str = ', '.join(['"' + name + '"' for name in notefiles])

In [6]:
# Create a new ClusteredSessions row (all columns left at their defaults) and
# remember its id for the inserts that follow.
curs = db.cursor()
curs.execute('INSERT INTO ClusteredSessions () VALUES ()')
# Prefer the id generated by this cursor's own INSERT: MAX(session_id) can
# return a row created by another client between the two statements.
session_id = curs.lastrowid
if not session_id:
    # lastrowid is empty when the table has no AUTO_INCREMENT column; keep the
    # previous lookup as a fallback in that case.
    curs.execute('SELECT MAX(session_id) FROM ClusteredSessions')
    session_id, = curs.fetchone()

In [7]:
# Look up the recordings that belong to this session's notefiles and drop any
# ClusteredRuns rows that already point at them.
if not notefiles:
    raise ValueError("no notefiles to look up: concatenated.csv listed no subfolders")

# One %s placeholder per notefile: the connector does the quoting, so names
# containing quotes or backslashes cannot break (or alter) the statement.
placeholders = ', '.join(['%s'] * len(notefiles))
curs.execute(f"SELECT recording_id FROM Recordings WHERE notefile IN ({placeholders})", notefiles)
recording_ids = [i for i, in curs.fetchall()]
if not recording_ids:
    # An empty id list would render as `IN ()`, which is a SQL syntax error.
    raise ValueError(f"no Recordings rows match notefiles {notefiles}")

rstr = ', '.join(str(i) for i in recording_ids)
curs.execute(f"DELETE FROM ClusteredRuns WHERE recording_id IN ({rstr})")

In [8]:
# Link every matched recording to the session created above.
q = (
    "INSERT INTO ClusteredRuns (session_id, recording_id) "
    "VALUES (%s, %s)"
)
run_rows = []
for rid in recording_ids:
    run_rows.append((session_id, rid))
curs.executemany(q, run_rows)

In [9]:
# One Clusters row per entry of the per-cluster arrays.
# NOTE(review): the cluster number is the row position (enumerate), not a
# cluster_id read from the TSV files — confirm the rows are ordered 0..n-1.
q = "INSERT INTO Clusters (session_id, cluster, amplitude, contam_pct, best_channel, label) VALUES (%s, %s, %s, %s, %s, %s)"
data = [
    # Hand the connector native Python numbers: best_ch was already converted,
    # but amp/cptc were passed as numpy scalars, which are not convertible when
    # pandas infers an integer dtype for the column.
    (session_id, i, float(amp), float(cptc), int(best_ch), lbl)
    for i, (amp, cptc, best_ch, lbl) in enumerate(zip(camps, contam_pct, chan_best, clu_labels))
]
curs.executemany(q, data)

In [10]:
# Fetch every trial of the recordings behind `notefiles`, ordered by recording
# and start sample, and tag each with the position of its notefile in `notefiles`.
trial_columns = ['notefile', 'trial_id', 'start_sample', 'stop_sample']
if notefiles:
    # One %s placeholder per notefile so the connector quotes the names;
    # string-built lists break on quotes/backslashes in a name.
    placeholders = ', '.join(['%s'] * len(notefiles))
    curs.execute(f"""
        SELECT    notefile, trial_id, start_sample, stop_sample
        FROM      Trials INNER JOIN Recordings USING (recording_id)
        WHERE     notefile IN ({placeholders})
        ORDER BY  recording_id, start_sample
    """, notefiles)
    trial_rows = curs.fetchall()
else:
    # `IN ()` is not valid SQL; with nothing to look up there are no trials.
    trial_rows = []

trials_df = pd.DataFrame.from_records(trial_rows, columns=trial_columns)
trials_df['folder_idx'] = trials_df.notefile.map(notefiles.index)
# Reassign instead of dropping in place so the cell stays safe to re-run.
trials_df = trials_df.drop(columns='notefile')

In [11]:
# Map every spike to the folder it came from and to its sample offset inside
# that folder.
# cum_samples[k] is the total length of folders 0..k, i.e. the first
# concatenated sample that belongs to folder k+1. side='right' therefore sends
# a spike exactly on that boundary to folder k+1 (offset 0); the default
# side='left' would leave it in folder k with an offset equal to the folder length.
# NOTE(review): assumes st holds 0-based sample indices into the concatenated
# recording — confirm.
folder_idx = np.searchsorted(cum_samples, st, side='right')
spikes_df = pd.DataFrame(dict(
    spike_time=st - cum_samples0[folder_idx],  # offset from the start of the folder
    folder_idx=folder_idx,
    cluster_idx=clu,
    amplitude=amplitudes,
))

In [14]:
def write_recording(spks, trials, temp_csv_file = r"data.csv"):
    """Assign spikes to trials and bulk-load them into ``SortedSpikeTimes``.

    Each spike is matched to the last trial whose ``start_sample`` is strictly
    below its ``spike_time``. Spikes that lie after that trial's ``stop_sample``
    or before the first trial are discarded. The surviving spikes are written to
    ``temp_csv_file`` with their time expressed relative to the trial start and
    divided by the module-level ``sr``. Existing ``SortedSpikeTimes`` rows of
    the given trials are then deleted, the CSV is loaded, and the transaction
    is committed.

    Parameters
    ----------
    spks : pandas.DataFrame
        Must contain the columns ``spike_time``, ``cluster_idx`` and
        ``amplitude``. ``spike_time`` is compared directly with the trials'
        start/stop samples.
    trials : pandas.DataFrame
        Must contain the columns ``trial_id``, ``start_sample`` and
        ``stop_sample``. Its index and row order are irrelevant.
    temp_csv_file : str
        Path of the intermediate CSV handed to ``LOAD DATA LOCAL INFILE``.

    Notes
    -----
    Relies on the module-level ``curs``, ``db`` and ``sr``. Neither input frame
    is modified. Nothing is written when ``trials`` is empty.
    """
    if len(trials) == 0:
        # No trial can own a spike, and an empty id list would render the
        # DELETE below as `IN ()`, which is a SQL syntax error.
        return

    # searchsorted needs ascending start samples, and DataFrame.join(on=...)
    # matches `trial_idx` against the *index* of `trials`, so that index has to
    # be the 0..n-1 position. Callers pass slices of a larger frame whose index
    # does not start at 0.
    trials = trials.sort_values('start_sample').reset_index(drop=True)
    spks = spks.copy()  # do not add columns to the caller's frame
    # Position of the last trial with start_sample < spike_time; -1 means the
    # spike precedes every trial and is removed by the inner join.
    spks['trial_idx'] = np.searchsorted(trials.start_sample.values, spks.spike_time.values) - 1
    trial_spks = spks.join(trials, how='inner', on='trial_idx', lsuffix='l_', rsuffix='r_')

    # Keep only spikes inside their trial's (start, stop] window.
    in_trial = (trial_spks.start_sample < trial_spks.spike_time) & (trial_spks.spike_time <= trial_spks.stop_sample)
    trial_spks = trial_spks[in_trial].copy()

    trial_spks['spike_time'] = (trial_spks.spike_time - trial_spks.start_sample) / sr
    trial_spks = trial_spks.loc[:, ['trial_id', 'cluster_idx', 'spike_time', 'amplitude']]
    trial_spks = trial_spks.rename(columns=dict(cluster_idx='channel', spike_time='time'))
    trial_spks.to_csv(temp_csv_file, index=False)

    # Replace whatever is stored for these trials.
    trial_id_str = ','.join(str(i) for i in trials.trial_id.unique())
    q = f'DELETE FROM SortedSpikeTimes WHERE trial_id IN ({trial_id_str})'
    curs.execute(q)

    # The header line is skipped, so CSV columns are loaded by position.
    # NOTE(review): '\r\n' matches pandas' default line ending on Windows only —
    # confirm before running elsewhere.
    q = rf"""
        LOAD DATA LOCAL INFILE '{temp_csv_file}' 
        INTO TABLE SortedSpikeTimes
        FIELDS TERMINATED BY ','
        LINES TERMINATED by '\r\n'
        IGNORE 1 LINES
    """
    curs.execute(q)
    db.commit()

In [15]:
# Write each folder's spikes against that folder's trials, one folder at a time.
for folder_i, _folder in enumerate(subfolders):
    folder_spikes = spikes_df.loc[spikes_df.folder_idx == folder_i].drop(columns='folder_idx')
    folder_trials = trials_df.loc[trials_df.folder_idx == folder_i].drop(columns='folder_idx')
    write_recording(spks=folder_spikes, trials=folder_trials)