To do:
- include other sessions: 757216464, 750332458
- do dataset: neuron filtering (discard the faulty ones), filter for stimuli & brain regions, see how to join them 
- sanity checks - EDA
- train-test split
- run models (XGBoost, Logistic)
- accuracy checks

In [1]:
# libraries
# standard libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib

# allen data
from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# supress all user warnings
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

### 1. Include Other Sessions

In [28]:
# set up cache
data_dir = "/Users/emmamora/Documents/programming/neuroscience/allendata"
manifest_path = os.path.join(data_dir, "manifest.json")
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path)

# choose sessions and load data
## IDs we're working with
session_id_1 = 757216464
session_id_2 = 750332458
## load each session's data
session_1 = cache.get_session_data(session_id_1)
session_2 = cache.get_session_data(session_id_2)

### 2. Dataset Creation

#### 2.1. Data Filtering (Unit Quality, Brain Regions & Stimuli)

**General remarks:**
- one trial = one presentation of a visual stimulus 
- we have significantly less stimuli (actual images) than trials => mouse has been shown the same image multriple times

**Quality units remarks**
- we want to filter out the low quality neurons, because:
    - some neurons have **low firing rate** (very little signal)
    - **high noise** (unreliable spikes = the spike that we're seeing might not be the actual spike of the neuron)
    - **unstable waveform** (may drift (which means that the recorded signal from a neuron changes position over time, which happens when the brain or the probe might move slightly during a long experiment) or not be isolated well (many neurons are firing from the same brain region, so their electrical activity overlaps))
- so we apply the following threshold:
    - for **firing_rate** >1Hz which excludes silent or barely active neurons (likely noise or poor data)
    - for **isi_violations** <0.5 which ensures spikes aren’t overlapping unnaturally — cleaner spike sorting
    - for **d_prime** >2 which measures how distinguishable the neuron is from noise — 2+ is strong
    - for **amplitude_cutoff** <0.1 which helps remove unstable waveforms — good for reliability
    - for **isolation_distance** >20 which indicates well-isolated neurons (less noise from other units nearby)
    
**Brain region remarks**
- we'll focus on the **visual cortex** => we need to filter the trials that have the following regions:
    - VISp
    - VISam
    - VISal
    - VISl
    - VISrl

**Stimuli remarks:**
- we'll focus on **static_gratings** and **drifting_gratings**
- static gratings
    - still images of oriented stripes with varying orientation/spatial frequency
- drifting gratings 
    - moving gratings that test direction selectivity and temporal frequency

In [40]:
# define visual areas and stimuli of interest
visual_areas = ["VISp", "VISam", "VISal", "VISl", "VISrl"]
stimuli_of_interest = ["static_gratings", "drifting_gratings"]

##### **Quality and Brain Region Selection**

Filtering for good neurons in the visual cortex:

In [41]:
# session 1
units_1_filtered = session_1.units[
    (session_1.units["firing_rate"] > 1) &              
    (session_1.units["isi_violations"] < 0.5) &         
    (session_1.units["d_prime"] > 2) &                  
    (session_1.units["amplitude_cutoff"] < 0.1) &       
    (session_1.units["isolation_distance"] > 20) &      
    (session_1.units["ecephys_structure_acronym"].isin(visual_areas))
].copy()

# session 2
units_2_filtered = session_2.units[
    (session_2.units["firing_rate"] > 1) &
    (session_2.units["isi_violations"] < 0.5) &
    (session_2.units["d_prime"] > 2) &
    (session_2.units["amplitude_cutoff"] < 0.1) &
    (session_2.units["isolation_distance"] > 20) &
    (session_2.units["ecephys_structure_acronym"].isin(visual_areas))
].copy()

Checks to see if quality and brain region filtering worked fine:

In [42]:
# neuron filtering metrics

# total number of neurons before filtering (from the original session)
print("Session 1 - Total neurons before filtering:", len(session_1.units))
print("Session 2 - Total neurons before filtering:", len(session_2.units))

# number of neurons after filtering for "good" units and visual cortex regions
print("Session 1 - Filtered neurons (good units in visual cortex):", len(units_1_filtered))
print("Session 2 - Filtered neurons (good units in visual cortex):", len(units_2_filtered))

# print counts per visual region
print("\nSession 1 - Filtered Neurons per Visual Region:")
print(units_1_filtered["ecephys_structure_acronym"].value_counts())

print("\nSession 2 - Filtered Neurons per Visual Region:")
print(units_2_filtered["ecephys_structure_acronym"].value_counts())

Session 1 - Total neurons before filtering: 959
Session 2 - Total neurons before filtering: 902
Session 1 - Filtered neurons (good units in visual cortex): 249
Session 2 - Filtered neurons (good units in visual cortex): 233

Session 1 - Filtered Neurons per Visual Region:
VISp     70
VISam    59
VISal    45
VISl     44
VISrl    31
Name: ecephys_structure_acronym, dtype: int64

Session 2 - Filtered Neurons per Visual Region:
VISam    55
VISal    55
VISp     52
VISrl    37
VISl     34
Name: ecephys_structure_acronym, dtype: int64


##### **Stimuli Selection**

Filtering for **static_gratings** and **drifting_gratings** in each session:

In [47]:
# session 1
stimuli_1_full = session_1.stimulus_presentations
stimuli_1_filtered = stimuli_1_full[stimuli_1_full["stimulus_name"].isin(stimuli_of_interest)].copy()

# session 2
stimuli_2_full = session_2.stimulus_presentations
stimuli_2_filtered = stimuli_2_full[stimuli_2_full["stimulus_name"].isin(stimuli_of_interest)].copy()

Checks to see if it worked:

In [48]:
# stimuli filtering metrics
# total number of trials before stimuli filtering and after
print("\nSession 1 - Total trials before stimuli filtering:", len(stimuli_1_full))
print("Session 1 - Trials after filtering (static & drifting):", len(stimuli_1_filtered))

print("\nSession 2 - Total trials before stimuli filtering:", len(stimuli_2_full))
print("Session 2 - Trials after filtering (static & drifting):", len(stimuli_2_filtered))


Session 1 - Total trials before stimuli filtering: 70390
Session 1 - Trials after filtering (static & drifting): 6630

Session 2 - Total trials before stimuli filtering: 70390
Session 2 - Trials after filtering (static & drifting): 6630


Aligning static and drifting gratings: 

In [49]:
# align stimuli in session 1
stimuli_1_aligned = stimuli_1_filtered.copy()
stimuli_1_aligned["speed"] = stimuli_1_aligned.apply(
    lambda row: 0 if row["stimulus_name"] == "static_gratings" else row["temporal_frequency"],
    axis=1
)

# align stimuli in session 2
stimuli_2_aligned = stimuli_2_filtered.copy()
stimuli_2_aligned["speed"] = stimuli_2_aligned.apply(
    lambda row: 0 if row["stimulus_name"] == "static_gratings" else row["temporal_frequency"],
    axis=1
)

# display aligned results
print("Session 1 - Aligned Stimuli (head):")
display(stimuli_1_aligned.head())

print("Session 2 - Aligned Stimuli (head):")
display(stimuli_2_aligned.head())

Session 1 - Aligned Stimuli (head):


Unnamed: 0_level_0,stimulus_block,start_time,stop_time,y_position,frame,x_position,size,phase,color,temporal_frequency,contrast,orientation,stimulus_name,spatial_frequency,duration,stimulus_condition_id,speed
stimulus_presentation_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
3798,2.0,1586.113585,1588.115245,,,,"[250.0, 250.0]","[42471.86666667, 42471.86666667]",,1.0,0.8,90.0,drifting_gratings,0.04,2.00166,246,1.0
3799,2.0,1589.116095,1591.117775,,,,"[250.0, 250.0]","[42471.86666667, 42471.86666667]",,8.0,0.8,0.0,drifting_gratings,0.04,2.00168,247,8.0
3800,2.0,1592.118605,1594.120275,,,,"[250.0, 250.0]","[42471.86666667, 42471.86666667]",,8.0,0.8,90.0,drifting_gratings,0.04,2.00167,248,8.0
3801,2.0,1595.121125,1597.122785,,,,"[250.0, 250.0]","[42471.86666667, 42471.86666667]",,15.0,0.8,135.0,drifting_gratings,0.04,2.00166,249,15.0
3802,2.0,1598.123625,1600.125295,,,,"[250.0, 250.0]","[42471.86666667, 42471.86666667]",,1.0,0.8,315.0,drifting_gratings,0.04,2.00167,250,1.0


Session 2 - Aligned Stimuli (head):


Unnamed: 0_level_0,stimulus_block,start_time,stop_time,y_position,frame,x_position,size,phase,color,temporal_frequency,contrast,orientation,stimulus_name,spatial_frequency,duration,stimulus_condition_id,speed
stimulus_presentation_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
3798,2.0,1585.647748,1587.649398,,,,"[250.0, 250.0]","[21235.93333333, 21235.93333333]",,4.0,0.8,180.0,drifting_gratings,0.04,2.00165,246,4.0
3799,2.0,1588.650242,1590.651902,,,,"[250.0, 250.0]","[21235.93333333, 21235.93333333]",,4.0,0.8,225.0,drifting_gratings,0.04,2.00166,247,4.0
3800,2.0,1591.652728,1593.654418,,,,"[250.0, 250.0]","[21235.93333333, 21235.93333333]",,2.0,0.8,135.0,drifting_gratings,0.04,2.00169,248,2.0
3801,2.0,1594.655252,1596.656912,,,,"[250.0, 250.0]","[21235.93333333, 21235.93333333]",,2.0,0.8,0.0,drifting_gratings,0.04,2.00166,249,2.0
3802,2.0,1597.657758,1599.659418,,,,"[250.0, 250.0]","[21235.93333333, 21235.93333333]",,15.0,0.8,315.0,drifting_gratings,0.04,2.00166,250,15.0


##### **Build Filtered Session Objects**

In [50]:
# python object that will allow us to access data without overwriting the initial structure with a dataframe
from types import SimpleNamespace

# new object for session 1 that preserves all needed properties but only for the filtered data
filtered_session_1 = SimpleNamespace(
    stimulus_presentations = stimuli_1_filtered.copy(),  # already filtered for static and drifting gratings
    units = units_1_filtered.copy(),                     # filtered for good neurons & visual cortex regions (with thresholds)
    spike_times = {unit_id: session_1.spike_times[unit_id] 
                   for unit_id in units_1_filtered.index}  # keep spike times only for filtered neurons
)

# similar object for session 2
filtered_session_2 = SimpleNamespace(
    stimulus_presentations = stimuli_2_filtered.copy(),
    units = units_2_filtered.copy(),
    spike_times = {unit_id: session_2.spike_times[unit_id] 
                   for unit_id in units_2_filtered.index}
)

Checks:

In [51]:
print("Filtered Session 1 - Stimulus Presentations (head):")
display(filtered_session_1.stimulus_presentations.head())

print("\nFiltered Session 1 - Units (head):")
display(filtered_session_1.units.head())

print("\nFiltered Session 2 - Stimulus Presentations (head):")
display(filtered_session_2.stimulus_presentations.head())

print("\nFiltered Session 2 - Units (head):")
display(filtered_session_2.units.head())

Filtered Session 1 - Stimulus Presentations (head):


Unnamed: 0_level_0,stimulus_block,start_time,stop_time,y_position,frame,x_position,size,phase,color,temporal_frequency,contrast,orientation,stimulus_name,spatial_frequency,duration,stimulus_condition_id
stimulus_presentation_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
3798,2.0,1586.113585,1588.115245,,,,"[250.0, 250.0]","[42471.86666667, 42471.86666667]",,1.0,0.8,90.0,drifting_gratings,0.04,2.00166,246
3799,2.0,1589.116095,1591.117775,,,,"[250.0, 250.0]","[42471.86666667, 42471.86666667]",,8.0,0.8,0.0,drifting_gratings,0.04,2.00168,247
3800,2.0,1592.118605,1594.120275,,,,"[250.0, 250.0]","[42471.86666667, 42471.86666667]",,8.0,0.8,90.0,drifting_gratings,0.04,2.00167,248
3801,2.0,1595.121125,1597.122785,,,,"[250.0, 250.0]","[42471.86666667, 42471.86666667]",,15.0,0.8,135.0,drifting_gratings,0.04,2.00166,249
3802,2.0,1598.123625,1600.125295,,,,"[250.0, 250.0]","[42471.86666667, 42471.86666667]",,1.0,0.8,315.0,drifting_gratings,0.04,2.00167,250



Filtered Session 1 - Units (head):


Unnamed: 0_level_0,waveform_PT_ratio,waveform_amplitude,amplitude_cutoff,cluster_id,cumulative_drift,d_prime,firing_rate,isi_violations,isolation_distance,L_ratio,...,ecephys_structure_id,ecephys_structure_acronym,anterior_posterior_ccf_coordinate,dorsal_ventral_ccf_coordinate,left_right_ccf_coordinate,probe_description,location,probe_sampling_rate,probe_lfp_sampling_rate,probe_has_lfp_data
unit_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
951814834,0.335786,175.59282,0.000277,242,87.04,4.768015,3.403535,0.00336,65.048888,0.008656,...,394.0,VISam,7608.0,1233.0,7607.0,probeA,See electrode locations,29999.95775,1249.99824,True
951814827,0.400624,75.98175,0.055686,241,158.11,2.849739,6.547632,0.101696,76.272023,0.012051,...,394.0,VISam,7608.0,1233.0,7607.0,probeA,See electrode locations,29999.95775,1249.99824,True
951814874,0.765148,57.15762,0.069544,247,661.3,5.365121,3.97974,0.009831,113.397892,0.000187,...,394.0,VISam,7599.0,1202.0,7611.0,probeA,See electrode locations,29999.95775,1249.99824,True
951814839,0.241974,109.24524,0.000933,243,129.39,4.592035,6.95704,0.009651,88.91945,0.004392,...,394.0,VISam,7599.0,1202.0,7611.0,probeA,See electrode locations,29999.95775,1249.99824,True
951814898,0.367812,65.25753,0.050332,250,126.56,3.972706,12.579504,0.00615,125.368317,0.000282,...,394.0,VISam,7584.0,1145.0,7613.0,probeA,See electrode locations,29999.95775,1249.99824,True



Filtered Session 2 - Stimulus Presentations (head):


Unnamed: 0_level_0,stimulus_block,start_time,stop_time,y_position,frame,x_position,size,phase,color,temporal_frequency,contrast,orientation,stimulus_name,spatial_frequency,duration,stimulus_condition_id
stimulus_presentation_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
3798,2.0,1585.647748,1587.649398,,,,"[250.0, 250.0]","[21235.93333333, 21235.93333333]",,4.0,0.8,180.0,drifting_gratings,0.04,2.00165,246
3799,2.0,1588.650242,1590.651902,,,,"[250.0, 250.0]","[21235.93333333, 21235.93333333]",,4.0,0.8,225.0,drifting_gratings,0.04,2.00166,247
3800,2.0,1591.652728,1593.654418,,,,"[250.0, 250.0]","[21235.93333333, 21235.93333333]",,2.0,0.8,135.0,drifting_gratings,0.04,2.00169,248
3801,2.0,1594.655252,1596.656912,,,,"[250.0, 250.0]","[21235.93333333, 21235.93333333]",,2.0,0.8,0.0,drifting_gratings,0.04,2.00166,249
3802,2.0,1597.657758,1599.659418,,,,"[250.0, 250.0]","[21235.93333333, 21235.93333333]",,15.0,0.8,315.0,drifting_gratings,0.04,2.00166,250



Filtered Session 2 - Units (head):


Unnamed: 0_level_0,waveform_PT_ratio,waveform_amplitude,amplitude_cutoff,cluster_id,cumulative_drift,d_prime,firing_rate,isi_violations,isolation_distance,L_ratio,...,ecephys_structure_id,ecephys_structure_acronym,anterior_posterior_ccf_coordinate,dorsal_ventral_ccf_coordinate,left_right_ccf_coordinate,probe_description,location,probe_sampling_rate,probe_lfp_sampling_rate,probe_has_lfp_data
unit_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
951819541,0.533997,127.450245,0.021545,295,171.52,3.60235,4.160879,0.248779,44.124039,0.043728,...,394.0,VISam,,,,probeA,See electrode locations,29999.968724,1249.998697,True
951819532,0.755187,114.80742,0.004377,294,88.95,4.820975,6.775049,0.002513,81.055453,0.002081,...,394.0,VISam,,,,probeA,See electrode locations,29999.968724,1249.998697,True
951819523,0.323419,170.707485,0.000929,293,69.65,6.393375,1.705023,0.0,95.385822,8.9e-05,...,394.0,VISam,,,,probeA,See electrode locations,29999.968724,1249.998697,True
951819550,0.403938,84.302595,0.042555,296,213.93,3.192182,3.998565,0.158745,50.227661,0.023063,...,394.0,VISam,,,,probeA,See electrode locations,29999.968724,1249.998697,True
951819744,0.426593,106.70049,0.034348,318,293.19,3.834414,5.788746,0.218047,66.899268,0.006325,...,394.0,VISam,,,,probeA,See electrode locations,29999.968724,1249.998697,True


##### **Joining the Datasets**

Merge stimulus presentations:

In [None]:
# 'session_id' column to each filtered stimuli to keep track of their origin
stimuli_1_merged = filtered_session_1.stimulus_presentations.assign(session_id=session_id_1)
stimuli_2_merged = filtered_session_2.stimulus_presentations.assign(session_id=session_id_2)

# concatenate dfs
merged_stimuli = pd.concat([stimuli_1_merged, stimuli_2_merged], ignore_index=True)
print("\nMerged Stimuli DataFrame shape:", merged_stimuli.shape)
display(merged_stimuli.head())

# (B) Merge Units Data:
# Similarly, add a 'session_id' column to each filtered units DataFrame.
units_1_merged = filtered_session_1.units.assign(session_id=session_id_1)
units_2_merged = filtered_session_2.units.assign(session_id=session_id_2)

# Concatenate them into one DataFrame
merged_units = pd.concat([units_1_merged, units_2_merged], ignore_index=True)
print("\nMerged Units DataFrame shape:", merged_units.shape)
display(merged_units.head())

# (C) Merge Spike Times:
# Since spike_times are stored as dictionaries (keyed by unit ID), we want to combine them.
# To ensure uniqueness, we prefix each unit's key with its session id.
merged_spike_times = {
    f"{session_id_1}_{unit_id}": spikes 
        for unit_id, spikes in filtered_session_1.spike_times.items()
}
merged_spike_times.update({
    f"{session_id_2}_{unit_id}": spikes 
        for unit_id, spikes in filtered_session_2.spike_times.items()
})
print("\nTotal units in merged spike times dictionary:", len(merged_spike_times))

# ============================================
# Final Check: Summary of the Merged Dataset
# ============================================
print("\nSummary of Merged Dataset:")
print("Merged Stimuli: {} trials".format(merged_stimuli.shape[0]))
print("Merged Units: {} neurons".format(merged_units.shape[0]))
print("Merged Spike Times: {} units".format(len(merged_spike_times)))

Merge units data:

In [None]:
# add a 'session_id' column to each filtered units df
units_1_merged = filtered_session_1.units.assign(session_id=session_id_1)
units_2_merged = filtered_session_2.units.assign(session_id=session_id_2)

# concatenate them into one df
merged_units = pd.concat([units_1_merged, units_2_merged], ignore_index=True)
print("\nMerged Units DataFrame shape:", merged_units.shape)
display(merged_units.head())

Merge spike times:

In [None]:
# spike_times are stored as dictionaries (keyed by unit ID) => we want to combine
# To ensure uniqueness, we prefix each unit's key with its session id.
merged_spike_times = {
    f"{session_id_1}_{unit_id}": spikes 
        for unit_id, spikes in filtered_session_1.spike_times.items()
}
merged_spike_times.update({
    f"{session_id_2}_{unit_id}": spikes 
        for unit_id, spikes in filtered_session_2.spike_times.items()
})
print("\nTotal units in merged spike times dictionary:", len(merged_spike_times))


### 3. EDA (sanity checks)

### 4. Decoding Supervised Learning

#### 4.1. Train-Test Split

- we split the data so the model learns from a training set and we then evaluate it on a seperate test set to see how well it generalises 

In [None]:
from sklearn.model_selection import train_test_split

# Split into 80% training and 20% testing
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

Then we must: 

Train the model on X_train, y_train

Evaluate the model on X_test, y_test

#### 4.2. Baseline Models

Before using XGBoost, we start simple (Logistic Regression) to have something to compare to

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Train a logistic regression
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)

# Test the model
y_pred = model.predict(X_test)

# Evaluate
print("Baseline Accuracy:", accuracy_score(y_test, y_pred))

# Confusion Matrix
conf_mat = confusion_matrix(y_test, y_pred)
print("Confusion Matrix:\n", conf_mat)

# Plot Confusion Matrix nicely
plt.figure(figsize=(6,5))
sns.heatmap(conf_mat, annot=True, fmt="d", cmap="Blues")
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
plt.title('Confusion Matrix (Logistic Regression)')
plt.show()

# Classification Report
report = classification_report(y_test, y_pred)
print("Classification Report:\n", report)

#### 4.3. Improvements (nested cv, gridsearch etc)

- Cross-validation: Instead of training/testing once, you split into many folds (like 5 times) to get a more reliable accuracy.
- GridSearch: Try different model settings (hyperparameters) to find the best ones automatically.
- XGBoost with binary predictions

In [None]:
from sklearn.model_selection import cross_val_score

scores = cross_val_score(model, X, y, cv=5)  # 5-fold cross-validation
print("Cross-validated accuracy:", scores.mean())

In [None]:
from sklearn.model_selection import GridSearchCV
from xgboost import XGBClassifier

# Define model
xgb_model = XGBClassifier()

# Define parameters to try
param_grid = {
    'max_depth': [3, 5, 7],
    'learning_rate': [0.01, 0.1, 0.2],
    'n_estimators': [50, 100, 200]
}

# Grid Search
grid = GridSearchCV(xgb_model, param_grid, cv=5)
grid.fit(X_train, y_train)

print("Best parameters:", grid.best_params_)
print("Best cross-validated accuracy:", grid.best_score_)
