# Testing a real-world scenario of user registration and verification

In [2]:
from classifier import Deep_Classifier
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
import mne

In [3]:
# Directory in which save_user() writes one "<user_name>.pth" file per user.
path_to_registry = "./registry"
# exist_ok=True already tolerates an existing directory, so a separate
# os.path.exists() check is redundant (and check-then-create can race).
os.makedirs(path_to_registry, exist_ok=True)

In [4]:
def load_chkpoint(model, chk_path):
    """Load a saved state dict into ``model`` and return the model.

    Args:
        model: the instance of the model whose parameters are overwritten.
        chk_path (string): path to the .pth file containing the state dict
            of the model.

    Returns:
        The same ``model`` instance, after ``load_state_dict`` was applied.
    """
    # map_location="cpu" lets a checkpoint that was saved from a GPU be read
    # on a machine without CUDA; load_state_dict then copies the values into
    # the model's existing parameters.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    chk_point = torch.load(chk_path, map_location="cpu")
    model.load_state_dict(chk_point)

    return model

In [16]:
def sample_windows(eeg_signal, extract_multiples, window_size, slide_interval=16):
    """Cut one or many fixed-width windows out of an EEG signal.

    Args:
        eeg_signal (ndarray): 2-D array; windows are taken along axis 1.
        extract_multiples (bool): when false, only the first window is
            returned; when true, a window is slid across the whole signal.
        window_size (int): number of points along axis 1 in each window.
        slide_interval (int): how far the window start advances between
            consecutive windows. Must be positive when windows are slid.

    Returns:
        ndarray: ``eeg_signal[:, :window_size]`` when ``extract_multiples``
        is false. Otherwise an array stacking every full window, of shape
        (num_windows, eeg_signal.shape[0], window_size); if the signal is
        shorter than ``window_size`` no window fits and the array is empty.

    Raises:
        ValueError: if ``extract_multiples`` is set and ``slide_interval``
            is not positive.
    """
    if not extract_multiples:
        return eeg_signal[:, :window_size]

    if slide_interval <= 0:
        # A zero or negative step would never advance past the end of the
        # signal, so the extraction loop would never terminate.
        raise ValueError(
            f"slide_interval must be a positive integer, got {slide_interval}"
        )

    # Last start index at which a complete window still fits.
    last_start = eeg_signal.shape[1] - window_size

    samples = [
        eeg_signal[:, start:start + window_size]
        for start in range(0, last_start + 1, slide_interval)
    ]
    return np.array(samples)

In [6]:
def find_centroids(input_arr,k=2):
    """
    Fit k-means clustering on the samples and return the cluster centroids.

    Args:
        input_arr (ndarray) : array of shape num_samples * dimension_of_each_sample
        k (int) : number of clusters to form
    Returns:
        centroids (ndarray) : cluster centres of shape k * dimension_of_each_sample
    """
    # random_state is pinned so repeated runs on the same data agree.
    fitted = KMeans(n_clusters=k, random_state=42).fit(input_arr)
    return fitted.cluster_centers_

In [7]:
def save_user(centroids,user_name):
    """
    Persist a user's signatures to "<user_name>.pth" inside path_to_registry.

    The centroids are written with torch.save, wrapped in a dict under the
    key "signatures"; a confirmation line is printed afterwards.
    """
    path = os.path.join(path_to_registry, f"{user_name}.pth")
    torch.save({"signatures": centroids}, path)
    print(f"---- saved signature successfully at {path} ----")



    

In [24]:
def register_user(model, eeg_signal,user_name,extract_multiples=False):
    """Register a user in the authentication system.

    The signal is cut into windows, passed through the model, and each model
    output is scaled to unit length along dim 1. With ``extract_multiples``
    the outputs are additionally reduced to k-means centroids. The result is
    stored through ``save_user``.

    Args:
        model: the instance of the model
        eeg_signal (ndarray) : unepoched eeg_signal of the user performing a
            predefined task, of shape nchan * sample_points
        user_name: the name of the user
        extract_multiples (boolean) : determines whether to extract multiple
            signatures from one user

    Returns:
        None. The signatures are written to disk and progress is printed.
    """

    window_size = 80  # no of points to be used in one window, fixed !

    input_arr = sample_windows(eeg_signal, extract_multiples, window_size)
    input_tensor = torch.Tensor(input_arr)
    if input_tensor.dim() < 3:
        # a single window has no batch dimension yet; add one
        input_tensor = torch.unsqueeze(input_tensor, dim=0)

    model.eval()
    with torch.no_grad():
        signatures = model(input_tensor)  # extracting the signature from the model

        # Convert each signature to a unit vector. keepdim=True keeps the
        # reduced axis as size 1 so the division broadcasts per row; without
        # it a (batch, dim) / (batch,) division lines the norms up with the
        # wrong axis (or raises) whenever batch > 1.
        norms = torch.norm(signatures, dim=1, keepdim=True)
        signatures = signatures / norms

    # converting to numpy arrays
    centroids = signatures.detach().numpy()

    if extract_multiples:
        # collapse the per-window signatures into k-means centroids
        centroids = find_centroids(centroids)

    save_user(centroids, user_name)
    print(f"---- user {user_name} registered successfully ----")

In [9]:
def load_model(chk_points=None):
    """Build the Deep_Classifier and optionally restore saved weights.

    Args:
        chk_points (string, optional): path to a .pth file holding the
            model's state dict. With the default ``None`` the freshly
            constructed model is returned as-is.

    Returns:
        The ``Deep_Classifier`` instance.
    """
    # NOTE(review): the meaning of the constructor argument 10 is not visible
    # here -- confirm against classifier.Deep_Classifier.
    model = Deep_Classifier(10)
    if chk_points is not None:
        # The argument used to be accepted but silently ignored; honour it.
        model = load_chkpoint(model, chk_points)
    return model


In [10]:
def load_user(path_edf):
    """
    Read the EDF file at ``path_edf`` with mne (preloaded, non-verbose) and
    return what ``get_data()`` yields for it.
    """
    raw = mne.io.read_raw_edf(path_edf, preload=True, verbose=False)
    signal = raw.get_data()
    return signal

    

In [21]:
# Random array of shape (64, 179); none of the other visible cells reads it.
iota_data = np.random.rand(64,179)

In [13]:
# Build the model through load_model(), passing no checkpoint argument.
model = load_model()

In [14]:
# Load the recording stored in this EDF file through load_user().
eeg_signal = load_user(path_edf="../../../data/eeg-motor-movementimagery-dataset-1.0.0/files/S109/S109R14.edf")

In [25]:
# Register the loaded signal under the name "test_user_v2_109";
# extract_multiples=True is forwarded to register_user.
register_user(model,eeg_signal,user_name="test_user_v2_109",extract_multiples=True)

: 