# Clustering GRBs

### Imports

In [None]:
import numpy as np
import re
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from glob import glob
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
from tqdm import tqdm
import pandas as pd
from sklearn.manifold import TSNE
import seaborn as sns


### Set parameters for Clustering

In [None]:
# Number of principal components kept by the PCA step (passed as `n_components`).
num_pca_components = 20
# Neighbourhood radius handed to DBSCAN as `eps`; distances are measured in
# the PCA-reduced space because DBSCAN is fitted on the PCA output.
dbscan_eps = 2.5
# Value handed to DBSCAN as `min_samples`.
dbscan_min_samples = 2
# Number of clusters requested from KMeans (passed as `n_clusters`).
num_kmeans_clusters = 20

### Load and preprocess data

In [None]:

def extract_number(filename):
    """Return the integer formed by the digits directly before ``.npy``.

    The match is anchored to the end of the name and requires a literal
    ``.npy`` extension, so ``"../clean_bursts/080714086.npy"`` yields
    ``80714086`` while names such as ``"12_npy"`` or ``"notes.txt"`` yield
    ``None``.

    Args:
        filename: File name or path to inspect.

    Returns:
        The digits preceding the extension as an ``int`` (leading zeros are
        lost by the conversion), or ``None`` when the name does not end in
        ``<digits>.npy``.
    """
    # Escape the dot (a bare "." matches any character) and anchor at the end
    # so only the real extension is accepted.
    match = re.search(r'(\d+)\.npy$', filename)
    if match is None:
        return None
    return int(match.group(1))
    
# Load every saved burst. glob() returns files in an unspecified order, so
# sort them to keep row order (and therefore cluster output) reproducible.
files = sorted(glob('../clean_bursts/*'))
if not files:
    raise FileNotFoundError("no light-curve files found in '../clean_bursts/'")

data_list = []
data_list_names = []
for file in files:
    data_list.append(np.load(file))
    data_list_names.append(extract_number(file))

# Length of the longest curve; shorter curves are padded up to it below.
max_length = max(len(arr) for arr in data_list)


def _min_max_scale(arr):
    """Rescale a curve to [0, 1]; a flat curve becomes all zeros."""
    shifted = arr - arr.min()
    peak = shifted.max()
    if peak == 0:
        # Zero range would give 0/0 = NaN, which PCA later rejects.
        return np.zeros_like(shifted, dtype=float)
    return shifted / peak


data_list = [_min_max_scale(arr) for arr in data_list]
# Right-pad each curve with its own minimum so all rows have equal length.
data_list = [np.pad(arr, (0, max_length - len(arr)), mode='minimum') for arr in data_list]

# One row per burst, one column per time bin.
light_curves = np.stack(data_list)

# # Standardize the data
# scaler = StandardScaler()
# light_curves_standardized = scaler.fit_transform(light_curves)

### Apply PCA for dimension reduction

In [None]:
# Reduce each padded light curve to `num_pca_components` principal components.
# random_state is pinned (matching the KMeans cells) so that solver choices
# that use randomness give the same projection on every run.
pca = PCA(n_components=num_pca_components, random_state=42)
X_pca = pca.fit_transform(light_curves)

### Run DBSCAN

In [None]:

# Initialize DBSCAN model
dbscan = DBSCAN(eps=dbscan_eps, min_samples=dbscan_min_samples)

# Cluster the PCA-reduced light curves; the label -1 marks noise points.
cluster_assignments_dbscan = dbscan.fit_predict(X_pca)

# Count the members of every label in one vectorised pass (labels come back
# sorted, so the report order is deterministic).
cluster_labels, cluster_sizes = np.unique(cluster_assignments_dbscan, return_counts=True)

total_curves = 0
count_curves = {}  # cluster size -> number of clusters having that size
for label, size in zip(cluster_labels, cluster_sizes):
    print(f'{size} in cluster {label}')
    count_curves[size] = count_curves.get(size, 0) + 1
    total_curves += size

print(f'total number of curves is {total_curves}')
for size, n_clusters_of_size in count_curves.items():
    print(f'there are {n_clusters_of_size} clusters with {size} elements')



### Plot DBSCAN

In [None]:

# Maximum number of example curves drawn per cluster.
num_show = 10

# Map each DBSCAN label to the row indices of (at most num_show) member curves.
indices_dict = {}
for num in np.unique(cluster_assignments_dbscan):
    indices = np.where(cluster_assignments_dbscan == num)[0]
    indices_dict[num] = indices[:num_show]

# Label -1 is DBSCAN's noise bucket, not a cluster. It is absent when every
# curve was assigned to a cluster, so a default avoids a KeyError.
indices_dict.pop(-1, None)

for clusters, members in indices_dict.items():
    plt.figure(figsize=(9, 3))
    plt.suptitle(f"cluster {clusters}", fontsize=18, y=0.95)

    # Two rows of subplots; enough columns to fit every shown curve.
    num_cols = int(np.ceil(len(members) / 2))
    for i, indices in enumerate(members, 1):
        ax = plt.subplot(2, num_cols, i)
        ax.get_yaxis().set_visible(False)
        # Padding was appended at the end only, so strip trailing zeros and
        # keep any genuine leading zeros of the curve.
        ax.plot(np.trim_zeros(light_curves[indices, :], trim='b'), 'b')
    plt.show()


### Save DBSCAN results

In [None]:
# Group the burst identifiers by the DBSCAN label assigned to each curve.
clust_dict = {}
for burst_id, label in zip(data_list_names, cluster_assignments_dbscan):
    clust_dict.setdefault(label, []).append(burst_id)

# One spreadsheet column per cluster; wrapping each list in a Series lets
# columns of different lengths share a single DataFrame.
cluster_columns = {label: pd.Series(members) for label, members in clust_dict.items()}
df = pd.DataFrame(cluster_columns)
df.to_excel('dbscan_clustering.xlsx')

### Run Kmeans

In [None]:
num_clusters = num_kmeans_clusters
kmeans = KMeans(n_clusters=num_clusters, random_state=42)

# Cluster the PCA-reduced light curves.
cluster_assignments_kmeans = kmeans.fit_predict(X_pca)

# Count the members of every label in one vectorised pass (labels come back
# sorted, so the report order is deterministic).
cluster_labels, cluster_sizes = np.unique(cluster_assignments_kmeans, return_counts=True)

total_curves = 0
count_curves = {}  # cluster size -> number of clusters having that size
for label, size in zip(cluster_labels, cluster_sizes):
    print(f'{size} in cluster {label}')
    count_curves[size] = count_curves.get(size, 0) + 1
    total_curves += size

print(f'total number of curves is {total_curves}')
for size, n_clusters_of_size in count_curves.items():
    print(f'there are {n_clusters_of_size} clusters with {size} elements')




### Plot Kmeans results

In [None]:

# Maximum number of example curves drawn per cluster.
num_show = 10

# Map each KMeans label to the row indices of (at most num_show) member curves.
indices_dict = {}
for num in np.unique(cluster_assignments_kmeans):
    indices = np.where(cluster_assignments_kmeans == num)[0]
    indices_dict[num] = indices[:num_show]

for clusters, members in indices_dict.items():
    plt.figure(figsize=(9, 3))
    plt.suptitle(f"cluster {clusters}", fontsize=18, y=0.95)

    # Two rows of subplots; enough columns to fit every shown curve.
    num_cols = int(np.ceil(len(members) / 2))
    for i, indices in enumerate(members, 1):
        ax = plt.subplot(2, num_cols, i)
        ax.get_yaxis().set_visible(False)
        # Padding was appended at the end only, so strip trailing zeros and
        # keep any genuine leading zeros of the curve.
        ax.plot(np.trim_zeros(light_curves[indices, :], trim='b'), 'b')
    plt.show()


# indices_dict = {}
# num_show = 3

# for num in np.unique(cluster_assignments_kmeans):
#     indices = np.where(cluster_assignments_kmeans == num)[0]
#     stop = max(num_show, len(indices))
#     indices_dict[num] = indices[:num_show]

# for clusters in indices_dict.keys():
#     for indices in indices_dict[clusters]:
#         plt.plot(light_curves[indices,:])
#         plt.title(f'cluster {clusters}')
#         plt.show()

### Save Kmeans results

In [None]:
# Collect, for every KMeans label, the identifiers of the bursts it contains.
clust_dict = {}
for burst_id, label in zip(data_list_names, cluster_assignments_kmeans):
    clust_dict.setdefault(label, []).append(burst_id)

# Each cluster becomes one column; Series-wrapping lets the shorter columns
# coexist with longer ones in the same DataFrame.
cluster_columns = {label: pd.Series(members) for label, members in clust_dict.items()}
df = pd.DataFrame(cluster_columns)
df.to_excel('kmeans_clustering.xlsx')

### Find the ideal k (inertia vs. number of clusters)

In [None]:
# Sweep k and record the KMeans inertia for each value.
# KMeans cannot form more clusters than there are samples, so cap the sweep.
max_k = min(149, len(X_pca))
k_values = list(range(1, max_k + 1))

inertias = []
for k in tqdm(k_values):
    # Use a throw-away model: rebinding `kmeans` / `cluster_assignments_kmeans`
    # here would overwrite the main clustering that later cells still plot.
    elbow_model = KMeans(n_clusters=k, random_state=42)
    elbow_model.fit(X_pca)
    inertias.append(elbow_model.inertia_)

# Plot against the actual k values rather than the 0-based list index.
plt.plot(k_values, inertias)
plt.xlabel('k (number of clusters)')
plt.ylabel('inertia')



### t-SNE for visualization

In [None]:
# Embed the (un-reduced) light curves in 2-D for visual inspection.
# random_state is pinned so the layout is reproducible between runs.
# NOTE(review): newer scikit-learn releases rename `n_iter` to `max_iter` —
# confirm against the installed version.
tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300, random_state=42)
tsne_results = tsne.fit_transform(light_curves)

# Build the DBSCAN palette from the labels actually present: noise (-1) is
# drawn in black and every real cluster gets its own HUSL colour, regardless
# of how many clusters were found or whether any noise exists.
dbscan_labels = np.unique(cluster_assignments_dbscan)
husl_colors = iter(sns.husl_palette(n_colors=len(dbscan_labels), s=0.7, l=0.6))
color_palete = {
    label: 'black' if label == -1 else next(husl_colors)
    for label in dbscan_labels
}

plt.figure(figsize=(16,10))
sns.scatterplot(
    x=tsne_results[:,0], y=tsne_results[:,1],
    palette=color_palete,
    hue=cluster_assignments_dbscan,
    legend="full",
    alpha=1
)

# Same idea for KMeans: exactly one colour per label that occurs.
kmeans_labels = np.unique(cluster_assignments_kmeans)
kmeans_palette = dict(
    zip(kmeans_labels, sns.husl_palette(n_colors=len(kmeans_labels), s=0.7, l=0.6))
)

plt.figure(figsize=(16,10))
sns.scatterplot(
    x=tsne_results[:,0], y=tsne_results[:,1],
    palette=kmeans_palette,
    hue=cluster_assignments_kmeans,
    legend="full",
    alpha=1
)