Import Libraries

In [76]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_curve
from datetime import datetime
import numpy as np
import pandas as pd
import wandb
import sklearn
import json
import os
import random
import gc 


Data Preprocessing

In [1]:
## Global Variables
'''
M : Length of each sequence
'''

M = 100

In [2]:
## Root Dataset Directory

ROOT = "Dataset/Raw_Temp"

In [79]:

with open(os.path.join(ROOT,'fixed_data.txt'),'r') as f :
  fixed_data = json.load(f)

with open(os.path.join(ROOT,'free_data.txt'),'r') as f :
  free_data = json.load(f)

with open(os.path.join(ROOT,'Raw_Temp_Gay_Marriage_Fixed.json')) as f :
  gay_marriage_fixed = json.load(f)

with open(os.path.join(ROOT,'Raw_Temp_Gay_Marriage_Free.json')) as f :
  gay_marriage_free = json.load(f)

with open(os.path.join(ROOT,'Raw_Temp_Gun_Control_Fixed.json')) as f :
  gun_control_fixed = json.load(f)

with open(os.path.join(ROOT,'Raw_Temp_Gun_Control_Free.json')) as f :
  gun_control_free = json.load(f)

with open(os.path.join(ROOT,'Raw_Temp_Restaurant_Review_Fixed.json')) as f :
  rest_fixed = json.load(f)

with open(os.path.join(ROOT,'Raw_Temp_Restaurant_Review_Free.json')) as f :
  rest_free = json.load(f)


In [80]:
with open(os.path.join(ROOT,'Buffalo_Fixed.json')) as f :
  buffalo_fixed = json.load(f)
  
with open(os.path.join(ROOT,'Buffalo_Free.json')) as f :
  buffalo_free = json.load(f)

Dividing into batches

In [81]:
## We divide each user's keystroke into batches of size 15.

def divide_into_batches(x) :
  num = len(x) //M
  list_text = []

  for i in range(num):
    temp = x[int(i*len(x)/num):min(int((i+1)*len(x)/num),len(x))]
    
    if len(temp) < 0.8*M :
      continue
      print("Too Small",len(temp))
    
    list_text.append(temp)

  return list_text

In [82]:
def process_data_buffalo(data,verbose=0) :
  for key in data.keys() :
    if verbose == 1 :
      data[key] = json.loads(data[key])

    timestamp_kd = []
    timestamp_ku = []
    list_ = data[key]["keyboard_data"]
    
    for i in range(len(data[key]["keyboard_data"])) :
      if list_[i][1].lower() == "kd" :
        timestamp_kd.append(int(list_[i][2]) - int(list_[i-1][2]) if i > 0 else int(list_[i][2]))
      else :
        timestamp_ku.append(int(list_[i][2]) - int(list_[i-1][2]) if i > 0 else int(list_[i][2]))
      
    timestamp_kd = np.array(timestamp_kd, dtype=np.float32)
    timestamp_ku = np.array(timestamp_ku, dtype=np.float32)
    
    if len(timestamp_kd) > 0 :
      ## Min Max Scaling
      timestamp_kd = (timestamp_kd - min(timestamp_kd))/(max(timestamp_kd) - min(timestamp_kd))
    
    if len(timestamp_ku) > 0 :
      timestamp_ku = (timestamp_ku - min(timestamp_ku))/(max(timestamp_ku) - min(timestamp_ku))
    
    ku_count = 0
    kd_count = 0
    
    for i in range(len(data[key]["keyboard_data"])) :
      ## Swap 0 element with 1 element
      temp = list_[i][0]
      list_[i][0] = list_[i][1]
      list_[i][1] = temp
      
      if list_[i][0].lower() == "ku" :
        list_[i][2] = float(timestamp_ku[ku_count])
        ku_count += 1
      else :
        list_[i][2] = float(timestamp_kd[kd_count])
        kd_count += 1

    data[key]["keyboard_data"] = list_[1:]
    data[key] = divide_into_batches(data[key]["keyboard_data"])

  return data

buffalo_free = process_data_buffalo(buffalo_free)
buffalo_fixed = process_data_buffalo(buffalo_fixed)

In [83]:
def process_data(data,verbose=0) :
  for key in data.keys() :
    if verbose == 1 :
      data[key] = json.loads(data[key])

    timestamp_kd = []
    timestamp_ku = []
    list_ = data[key]["keyboard_data"]
    
    for i in range(len(data[key]["keyboard_data"])) :
      if list_[i][0] == "KD" :
        timestamp_kd.append(list_[i][2] - list_[i-1][2] if i > 0 else list_[i][2])
      else :
        timestamp_ku.append(list_[i][2] - list_[i-1][2] if i > 0 else list_[i][2])
      
    timestamp_kd = np.array(timestamp_kd, dtype=np.float32)
    timestamp_ku = np.array(timestamp_ku, dtype=np.float32)
    
    if len(timestamp_kd) > 0 :
      ## Min Max Scaling
      timestamp_kd = (timestamp_kd - min(timestamp_kd))/(max(timestamp_kd) - min(timestamp_kd))
    
    if len(timestamp_ku) > 0 :
      timestamp_ku = (timestamp_ku - min(timestamp_ku))/(max(timestamp_ku) - min(timestamp_ku))
    
    ku_count = 0
    kd_count = 0
    
    for i in range(len(data[key]["keyboard_data"])) :
      if list_[i][0] == "KU" :
        list_[i][2] = float(timestamp_ku[ku_count])
        ku_count += 1
      else :
        list_[i][2] = float(timestamp_kd[kd_count])
        kd_count += 1

    data[key]["keyboard_data"] = list_[1:]
    data[key] = divide_into_batches(data[key]["keyboard_data"])

  return data

free_data = process_data(buffalo_free)
fixed_data = process_data(buffalo_fixed)

In [None]:
# Separate the data into training and testing data based on keyboard id
key_0_free = {}
key_1_free = {}
key_2_free = {}
key_3_free = {}

key_0_fixed = {}
key_1_fixed = {}
key_2_fixed = {}
key_3_fixed = {}

for key in free_data.keys() :
    new_key = key[:4]
    
    if key[4] == '0' :
        key_0_free[new_key] = free_data[key]
    
    elif key[4] == '1' : 
        key_1_free[new_key] = free_data[key]

    elif key[4] == '2' :
        key_2_free[new_key] = free_data[key]
        
    elif key[4] == '3' :
        key_3_free[new_key] = free_data[key]
        
for key in fixed_data.keys() :
    new_key = key[:4]
    
    if key[4] == '0' :
        key_0_fixed[new_key] = fixed_data[key]
    
    elif key[4] == '1' : 
        key_1_fixed[new_key] = fixed_data[key]

    elif key[4] == '2' :
        key_2_fixed[new_key] = fixed_data[key]
        
    elif key[4] == '3' :
        key_3_fixed[new_key] = fixed_data[key]
        

Train-Test Dataset Formation & Splitting

In [84]:
def update_dict(adder, addee):
  """
  Updates a dictionary by adding the values from another dictionary.

  Parameters:
  - adder (dict): The dictionary to be updated.
  - addee (dict): The dictionary whose values will be added to the `adder` dictionary.

  Returns:
  - dict: The updated dictionary.
  """
  
  for key in addee:
    if key in adder:
      adder[key].extend(addee[key])
    else:
      adder[key] = addee[key]

  return adder

In [85]:
'''
    We can create any combination of datasets for training and 
    testing in this pipeline to create the training and testing sets.
    
    key_0_free : Keyboard - 0 Free Data
    key_1_free : Keyboard - 1 Free Data
    key_2_free : Keyboard - 2 Free Data
    key_3_free : Keyboard - 3 Free Data
    
    key_0_fixed : Keyboard - 0 Fixed Data
    key_1_fixed : Keyboard - 1 Fixed Data
    key_2_fixed : Keyboard - 2 Fixed Data
    key_3_fixed : Keyboard - 3 Fixed Data
'''

fixed_data_test = {}
free_data_test = {}

fixed_data_test.update(key_0_fixed)
free_data_test.update(key_0_free)

fixed_data_train = {}
free_data_train = {}

fixed_data_train.update(key_3_fixed)
fixed_data_train.update(key_2_fixed)
fixed_data_train.update(key_1_fixed)
free_data_train.update(key_3_free)
free_data_train.update(key_2_free)
free_data_train.update(key_1_free)

## For specific case combine the datasets into the training set only. Leave the test set empty.

In [86]:
## Separate data based on keyboard type

##Key board map to map each key to a number
keyboard_map = {
    'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10,
    'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20,
    'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26,
    '0': 27, 'd0' : 27, 'd1' : 28, '1': 28, '2': 29, 'd2': 29, '3': 30, 'd3': 30, '4': 31, 'd4' : 31,'5': 32, 'd5':32, '6': 33, 'd6':33, '7': 34, 'd7': 34, '8': 35, 'd8': 35, '9': 36, 'd9':36,
    'f1': 37, 'f2': 38, 'f3': 39, 'f4': 40, 'f5': 41, 'f6': 42, 'f7': 43, 'f8': 44, 'f9': 45, 'f10': 46,
    'f11': 47, 'f12': 48, 'esc': 49, '`': 50, '-': 51, 'subtract':51, '=': 52, 'backspace': 53,'back':53, 'tab': 54, '[': 55, ']': 56,
    '\\': 57, 'capslock': 58,'capital':58, ';': 59, '\'': 60, 'enter': 61, 'return':61, 'shift': 62, ',': 63, '.': 64, 'decimal':64, 'oemperiod': 64 , '/': 65 ,'divide':65, 'control': 66,'ctrl':66,
    'alt': 67, ' ': 68, 'printscreen': 69, 'scrolllock': 70,'scroll':70, 'pause': 71, 'insert': 72, 'home': 73, 'pageup': 74,
    'delete': 75, 'end': 76, 'pagedown': 77, 'arrowup': 78, 'arrowleft': 79, 'arrowdown': 80, 'arrowright': 81,
    'numlock': 82, 'numpad0': 83, 'numpad1': 84, 'numpad2': 85, 'numpad3': 86, 'numpad4': 87, 'numpad5': 88,
    'numpad6': 89, 'numpad7': 90, 'numpad8': 91, 'numpad9': 92, 'numpadmultiply': 93, 'numpadadd': 94,
    'numpadsubtract': 95, 'numpaddecimal': 96, 'numpaddivide': 97, 'numpadenter': 98, 'contextmenu': 99,
    'leftctrl': 100, 'leftshift': 101, 'leftshiftkey':101, 'leftalt': 102, 'lmenu':102, 'leftmeta': 103, 'rightctrl': 104,'rcontrolkey':104, 'rightshift': 105, 'rshiftkey': 105,
    'rightalt': 106, 'rmenu':106, 'rightmeta': 107, ':': 108, 'colon': 108, 'unidentified':0, ')': 109, '(':110, 'meta':111, '≠':112, '@':113, '>':114,'<':115, '*':116,'+':117,'add' : 117,'#':118,'$':119, '"':120, 'process':121,'_':122,
    '{':123,'}':124,'?':134,'f1':135,'f2':136,'f3':137,'f4':138,'f5':139,'f6':140,'f7':141,'f14':142,'´':143,'':0,'©':144,'escape':49,'clear':145,'lcontrolkey':100,'lshiftkey':101, 'space':68, 'left': 146, 'right': 147, 'up':148, 'down': 149, 'apps':150,
    'rwin' : 151, 'next': 152, 'lwin' : 153, 'browserback':154, 'browserforward':155, 'browserrefresh':156, 'browserstop':157, 'browsersearch':158, 'browserfavorites':159, 'browserhome':160, 'volumemute':161, 'volumedown':162, 'volumeup':163, 'medianexttrack':164, 'mediaprevioustrack':165
}

In [87]:
## Replace the character with the ascii value, KD : 0, KU : 1, timestamp : relative
def convert_list(list_) :
  """
  Converts a list of key events into a transformed list.

  Args:
    list_ (list): A list of key events, where each event is represented as a tuple of three elements: 
                  the key action ('KD' for Key Down or 'KU' for Key Up), the key value, and the timestamp.

  Returns:
    list: A transformed list where each event is represented as a list with the following elements:
          - 0 for Key Down or 1 for Key Up
          - The normalized value of the key (if applicable)
          - The timestamp difference between the current event and the previous event

  """
  start_time = 0
  trans_list = []
  timestamp_list_kd = []
  timestamp_list_ku = []
  shift_count = 0

  for i in range(len(list_)) :
    temp = []

    ## Assigning value to Key Up and Key Down
    if list_[i][0] == 'KD' :
      temp.append(0)
    else:
      temp.append(1)

    ## Convert to ascii value
    if (len(str(list_[i][1]).lower()) > 1 and isinstance(list_[i][1],int)) or list_[i][1].lower().find("oem") != -1 or list_[i][1].lower().find("lbutton,") != -1:
      continue
    else :
      temp.append(keyboard_map[str(list_[i][1]).lower()]/255)
      
    if str(list_[i][1]).lower() == 'shift' :
      shift_count += 1

    ## Store the diff in timestamp
    if i>=0 :
      temp.append(float(list_[i][2]))
  
    if shift_count < len(list_)*0.2 :    
      trans_list.append(temp)

  return trans_list

def convert_data(data):
  """
  Convert the given data dictionary into a new dictionary with converted lists.

  Args:
    data (dict): The input data dictionary.

  Returns:
    dict: The converted data dictionary.

  """
  p_data = {}

  for key in data.keys():
    list_compiled = []

    for list_ in data[key]:
      temp = convert_list(list_)

      if len(temp) > 0:
        list_compiled.append(temp)

    if len(list_compiled) > 0:
      p_data[key] = list_compiled

  return p_data


## Converting data
free_data_train = convert_data(free_data_train)
fixed_data_train = convert_data(fixed_data_train)
free_data_test = convert_data(free_data_test)
fixed_data_test = convert_data(fixed_data_test)

In [88]:
## Create validation set
free_data_val = {}
fixed_data_val = {}

for key in free_data_train.keys() :
    if len(free_data_train[key]) > 1 :
        val_sect = int(0.1*len(free_data_train[key]))
        free_data_val[key] = free_data_train[key][:val_sect]
        free_data_train[key] = free_data_train[key][3*val_sect:]
        
        
for key in fixed_data_train.keys() :
    if len(fixed_data_train[key]) > 1 :
        fixed_data_val[key] = fixed_data_train[key][:int(0.1*len(fixed_data_train[key]))]       
        fixed_data_train[key] = fixed_data_train[key][int(0.3*len(fixed_data_train[key])):]

In [89]:
## Padding and clipping sequences

mask_free_train = {}
mask_fixed_train = {}
mask_fixed_val = {}
mask_free_val = {}
mask_free_test = {}
mask_fixed_test = {}

def pad_clip_seq(x) :
  # print(x.shape)
  curr_mask = [1]*len(x)

  if(len(x) > M) :
    ## If length is greater than the sequence length M : Clip the sequence
    
    x = x[:M]
    curr_mask = curr_mask[:M]

  ## If length is less than the sequence length M : Pad the sequence 
  for i in range(max(0,M-len(x))) :
    x.append([-1,-1,-1])
    curr_mask.append(0)

  return x,curr_mask

def return_pad_seq(data):
  """
  Pad sequences in the given data dictionary and return the padded sequences along with the corresponding masks.

  Args:
    data (dict): A dictionary containing sequences to be padded.

  Returns:
    tuple: A tuple containing the padded sequences and their corresponding masks.

  """
  mask = {}

  for key in data.keys():
    new_list = []
    mask_l = []

    for list_ in data[key]:
      list_, curr_mask = pad_clip_seq(list_)

      new_list.append(list_)
      mask_l.append(curr_mask)

    data[key] = new_list
    mask[key] = mask_l

  return data, mask

free_data_train,mask_free_train = return_pad_seq(free_data_train)
fixed_data_train,mask_fixed_train = return_pad_seq(fixed_data_train)
fixed_data_val,mask_fixed_val = return_pad_seq(fixed_data_val)
free_data_val,mask_free_val = return_pad_seq(free_data_val)
free_data_test,mask_free_test = return_pad_seq(free_data_test)
fixed_data_test,mask_fixed_test = return_pad_seq(fixed_data_test)

Combining Pairs

In [90]:
## Combine Pairs of a given set of fixed and free data
def combine_pairs(fixed_data,free_data,mask_fixed,mask_free) :
  data = {}
  mask = {}
  y_data = {}

  for key in fixed_data.keys() :
    if key not in free_data.keys() :
      continue
    else: 
      data[key] = []
      mask[key] = []
      y_data[key] = []

      ## For each user, we create pairs of fixed and free data, Label : 1
      for fixed_index in range(len(fixed_data[key])) :
        for free_index in range(len(free_data[key])) :
          data[key].append([fixed_data[key][fixed_index],free_data[key][free_index]])
          mask[key].append([mask_fixed[key][fixed_index],mask_free[key][free_index]])
          y_data[key].append(1)

        for fixed_index_2 in range(len(fixed_data[key])) :
          if fixed_data[key][fixed_index_2] == fixed_data[key][fixed_index] or ([fixed_data[key][fixed_index_2],fixed_data[key][fixed_index]] in data[key]) or ([fixed_data[key][fixed_index],fixed_data[key][fixed_index_2]] in data[key]):
            continue
          else :
            data[key].append([fixed_data[key][fixed_index],fixed_data[key][fixed_index_2]])
            mask[key].append([mask_fixed[key][fixed_index],mask_fixed[key][fixed_index_2]])
            y_data[key].append(0)
      
      ## For each user, we create pairs of fixed and free data, Label : 0
      for free_index in range(len(free_data[key])) :
        for free_index_2 in range(len(free_data[key])) :
          if free_data[key][free_index_2] == free_data[key][free_index] or ([free_data[key][free_index_2],free_data[key][free_index]] in data[key]) or ([free_data[key][free_index],free_data[key][free_index_2]] in data[key]):
            continue
          else :
            data[key].append([free_data[key][free_index],free_data[key][free_index_2]])
            mask[key].append([mask_free[key][free_index],mask_free[key][free_index_2]])
            y_data[key].append(0)

  return data,y_data,mask

In [92]:
new_data_train,y_data_train,mask_train = combine_pairs(fixed_data_train,free_data_train,mask_fixed_train,mask_free_train)
new_data_val,y_data_val,mask_val = combine_pairs(fixed_data_val,free_data_val,mask_fixed_val,mask_free_val)
new_data_test,y_data_test,mask_test = combine_pairs(fixed_data_test,free_data_test,mask_fixed_test,mask_free_test)

In [93]:
## Convert the data into a single list

def combine(new_data, mask, y_data):
  """
  Combines the data, mask, and y_data into lists.

  Args:
    new_data (dict): A dictionary containing the new data.
    mask (dict): A dictionary containing the mask.
    y_data (dict): A dictionary containing the y_data.

  Returns:
    tuple: A tuple containing the combined data list, combined mask list, and y_data list.
  """
  combined_data_list = []
  combined_mask_list = []
  y_data_list = []

  for key in new_data.keys():
    if len(combined_data_list) == 0:
      combined_data_list = new_data[key]
      combined_mask_list = mask[key]
      y_data_list = y_data[key]
    else:
      combined_data_list.extend(new_data[key])
      combined_mask_list.extend(mask[key])
      y_data_list.extend(y_data[key])

  return combined_data_list, combined_mask_list, y_data_list

combined_data_list_train,combined_mask_list_train,y_data_list_train = combine(new_data_train,mask_train,y_data_train)
combined_data_list_val,combined_mask_list_val,y_data_list_val = combine(new_data_val,mask_val,y_data_val)
combined_data_list_test,combined_mask_list_test,y_data_list_test = combine(new_data_test,mask_test,y_data_test)

In [94]:
## Number of unique samples in training set and their distribution

len(combined_data_list_train)
print(np.unique(y_data_list_train,return_counts=True))

(array([0, 1]), array([180524, 179535]))


Class Balancing for Training Set

In [95]:
indexes = np.arange(len(combined_data_list_train))
y_data_list_train = np.array(y_data_list_train)
class_0_index = indexes[y_data_list_train == 0]
class_1_index = indexes[y_data_list_train == 1]

# print(len(class_1_index))
min_length = min(len(class_0_index),len(class_1_index))

indexes = np.concatenate((class_0_index[:min_length],class_1_index[:min_length]))
random.shuffle(indexes)

combined_data_list_train = np.array(combined_data_list_train)
combined_mask_list_train = np.array(combined_mask_list_train)
y_data_list_train = np.array(y_data_list_train)

combined_data_list_train = combined_data_list_train[indexes]
combined_mask_list_train = combined_mask_list_train[indexes]
y_data_list_train = y_data_list_train[indexes]

del indexes,class_0_index,class_1_index
gc.collect()

0

In [96]:
print(np.unique(y_data_list_train,return_counts=True))

(array([0, 1]), array([179535, 179535]))


In [97]:
## Creating Data Loader

class CustomDataset(Dataset) :
  def __init__(self,data,mask,label) :
    self.data = data
    self.mask = mask
    self.label = label

  def __getitem__(self,index) :
    x = torch.tensor(self.data[index],dtype=torch.float32)
    m = torch.tensor(self.mask[index],dtype=torch.float32)

    counts = np.unique(m[0,:],return_counts=True)
    length1 = counts[1][-1]
    counts = np.unique(m[1,:],return_counts=True)
    length2 = counts[1][-1]

    temp = torch.zeros(2)
    temp[self.label[index]] = 1

    return x[0,:],x[1,:],self.label[index],length1,length2

  def __len__(self) :
    return len(self.data)



In [98]:
## Config 
BATCH_SIZE = 512

In [99]:
train_data = CustomDataset(combined_data_list_train,combined_mask_list_train,y_data_list_train)
val_data = CustomDataset(combined_data_list_val,combined_mask_list_val,y_data_list_val)
test_data = CustomDataset(combined_data_list_test,combined_mask_list_test,y_data_list_test)

train_loader = DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)
val_loader = DataLoader(dataset=val_data,batch_size=BATCH_SIZE,shuffle=True)
test_loader = DataLoader(dataset=test_data,batch_size=BATCH_SIZE,shuffle=True)


Model & Loss function

In [100]:
def contrastive_loss(fv1, fv2, y, alpha):
    # Move all inputs to the specified device
    fv1 = fv1.to(device)
    fv2 = fv2.to(device)
    y = y.to(device)

    # Element-wise square of the difference
    squared_diff = (fv1 - fv2) ** 2

    # Summing the squared differences along the last dimension and taking the square root for Euclidean distance
    d = torch.sqrt(torch.sum(squared_diff, dim=-1))

    # Element-wise maximum
    max_part = torch.clamp_min(alpha - d, min=0) ** 2

    # Compute the loss
    loss = (((1 - y) * d ** 2) / 2) + ((y * max_part) / 2)
    return torch.mean(loss)

In [101]:
import torch
import torch.nn as nn

class TypeNet(nn.Module):
    def __init__(self, sequence_length, in_dim, hidden_dim_1, hidden_dim_2, output_dim, dropout):
        super(TypeNet, self).__init__()
        self.dropout = dropout
        self.lstm1 = nn.LSTM(input_size=in_dim, hidden_size=hidden_dim_1, batch_first=True, dropout=0.2)
        self.lstm2 = nn.LSTM(input_size=hidden_dim_1, hidden_size=hidden_dim_2, batch_first=True, dropout=0.2)
        self.bn1 = nn.BatchNorm1d(num_features=in_dim)
        self.bn2 = nn.BatchNorm1d(num_features=hidden_dim_1)
        self.bn3 = nn.BatchNorm1d(num_features=hidden_dim_2*sequence_length)
        self.bn4 = nn.BatchNorm1d(num_features=output_dim)
        self.act1 = nn.Tanh()
        self.act2 = nn.Tanh()
        self.act3 = nn.Tanh()
        self.dropout_1 = nn.Dropout(p=dropout)
        self.dropout_2 = nn.Dropout(p=dropout)
        self.fc1 = nn.Linear(sequence_length*hidden_dim_2, output_dim)
        self.softmax = nn.Softmax(dim=1)

        # Weight initialization
        for name, param in self.lstm1.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param.data)
            elif 'weight_hh' in name:
                nn.init.kaiming_uniform_(param.data)
        for name, param in self.lstm2.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param.data)
            elif 'weight_hh' in name:
                nn.init.kaiming_uniform_(param.data)
        nn.init.xavier_uniform_(self.fc1.weight)

    def forward(self, x, length):
        x = x.to(device)

        x = torch.movedim(x, 2, 1)
        out = self.bn1(x)
        out = torch.movedim(out, 2, 1)

        out = nn.utils.rnn.pack_padded_sequence(out, length, batch_first=True, enforce_sorted=False)
        out, _ = self.lstm1(out)
        out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)

        out = self.act1(out)
        out = self.dropout_1(out)

        out = torch.movedim(out, 2, 1)
        out = self.bn2(out)
        out = torch.movedim(out, 2, 1)

        out_p = nn.utils.rnn.pack_padded_sequence(out, length, batch_first=True, enforce_sorted=False)
        out_p, _ = self.lstm2(out_p)
        out_p, _ = nn.utils.rnn.pad_packed_sequence(out_p, batch_first=True)

        out_p = self.act2(out_p)
        out_p = torch.reshape(out_p, (out.shape[0], out.shape[1]*out.shape[2]))
        out = self.bn3(out_p)
        out = self.fc1(out)

        return out


In [102]:
class CustomTypeNet(nn.Module) :
  def __init__(self,sequence_length,in_dim,hidden_dim_1,hidden_dim_2,output_dim,dropout) :
    super(CustomTypeNet,self).__init__()
    self.tn1 = TypeNet(sequence_length,in_dim,hidden_dim_1,hidden_dim_2,output_dim,dropout)
    self.tn2 = TypeNet(sequence_length,in_dim,hidden_dim_1,hidden_dim_2,output_dim,dropout)

  def forward(self,x1,x2,length1,length2) :
    x1 = self.tn1(x1,length1)
    x2 = self.tn2(x2,length2)

    return x1,x2


In [103]:


def calculate_eer(d, labels):
    # Element-wise square of the difference
    with torch.no_grad() :
      # squared_diff = (fv1 - fv2) ** 2

      # # Summing the squared differences along the last dimension and taking the square root for Euclidean distance
      # d = torch.sqrt(torch.sum(squared_diff, dim=1)).detach().cpu().numpy()

      # Calculate the False Positive Rates, True Positive Rates, and thresholds
      fpr, tpr, thresholds = roc_curve(labels, d, pos_label=1)      

      # Handle cases where tpr or fpr contains nan values
      tpr = np.nan_to_num(tpr)
      fpr = np.nan_to_num(fpr)

      # Find the EER
      eer_threshold = thresholds[np.argmin(np.absolute((1 - tpr) - fpr))]
      eer = fpr[np.argmin(np.absolute((1 - tpr) - fpr))]

      return eer, eer_threshold, np.average(fpr), np.average(1-tpr)


In [104]:
device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")

Training Custom Model

In [105]:
## Model Parameters
EPOCHS = 200
LR = 0.001
DROPOUT = 0.5

In [106]:
now = datetime.now()
current_time = now.strftime("%H_%M_%S")
new_path = current_time

In [107]:
from torch.optim import Adam
from tqdm import tqdm
import gc
import torch.nn as nn


model = CustomTypeNet(M,3,128,128,128,DROPOUT)
model.to(device)
loss = nn.BCELoss()

# for param in model.parameters() :
#   print(param.device)

optimizer = Adam(model.parameters(),lr=LR,weight_decay=0.001)



In [108]:
import wandb

## Replace the Dataset with the combination of the dataset chosen for tracing the model performance

run = wandb.init(
    # Set the project where this run will be logged
    project="Content Specific TypeNet EER V2",
    # Track hyperparameters and run metadata
    config={
        "learning_rate": LR,
        "epochs": EPOCHS,
        "BATCH_SIZE": BATCH_SIZE,
        "M": M,
        "Dropout": DROPOUT,
        "Dataset": "RF",
        "Model Path": new_path
    },
)

VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▂▃▄▅▅▆▇█
f1_test,▃▁█▅▂▁▆▁▁
f1_train,▁▃▅▆▇▇▇██
f1_val,▂▁▄▄██▇▇▇
fnr_test,█▂▄▁▄▄▆▅▆
fnr_train,█▅▄▃▂▂▁▁▁
fnr_val,▇█▄▄▁▁▂▃▃
fpr_test,█▃▂▁▃▄▆▅▄
fpr_train,█▅▄▃▂▂▂▁▁
fpr_val,▇▇█▇▃▂▃▁▁

0,1
epoch,8
f1_test,0.01928
f1_train,0.79065
f1_val,0.82733
fnr_test,0.47433
fnr_train,0.27829
fnr_val,0.20574
fpr_test,0.46251
fpr_train,0.28002
fpr_val,0.33091


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113080288568097, max=1.0…

In [109]:
checkpoint_path = "18_52_42"

In [110]:
is_new_model = True

In [111]:
def load_latest_model() : 
  dir_list = os.listdir(os.path.join(os.getcwd(),checkpoint_path))
  num_max = 0
  
  for item in dir_list : 
    res = [int(i) for i in item if i.isdigit()]
    num_max = max(res[0],num_max)
    
  return str(num_max)
    
  

def load_model() : 
  print("Loading Model Checkpoint ....")
  curr_epoch = load_latest_model()
  model_path = "_model{}.pth".format(curr_epoch)
  print("Loading : {}".format(model_path))
  model = torch.load(os.path.join(checkpoint_path,model_path))
  return model


In [112]:
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_score, recall_score
from torch.nn.functional import cosine_similarity, binary_cross_entropy
from torch.nn import CosineSimilarity

loss_plot = []
val_loss_plot = []
train_eer_loss = []
val_eer_loss = []
iter = []
path = None
curr_epoch = 0
cos = CosineSimilarity(dim=1)

## Loading model from checkpoint
if is_new_model : 
  path = None
else: 
  path = checkpoint_path
  model = load_model()
  curr_epoch = int(load_latest_model()) + 1


for epoch in range(curr_epoch,EPOCHS+curr_epoch) :
  train_loss = 0
  val_loss = 0
  train_count = 0
  val_count = 0
  acc_train_avg = 0
  acc_val_avg = 0
  f1_train = 0
  f1_val = 0
  predicted_train = []
  label_train = []
  predicted_val = []
  label_val = []
  torch.cuda.empty_cache()
  
  eer_train = 0
  eer_val = 0
  eer_test = 0
  
  fpr_train = 0
  fpr_val = 0
  fpr_test = 0
  
  fnr_train = 0
  fnr_val = 0
  fnr_test = 0

  print("EPOCH - {}".format(epoch))
  
  ### Training Loop ###
  
  for item in tqdm(train_loader) :
    ## Train Model 
    try: 
      model.train()
      x1,x2,label,length1,length2 = item
      x1 = x1.to(device)
      x2 = x2.to(device)
      label = label.to(device)
      label = label.float()
      optimizer.zero_grad()
      output = model(x1,x2,length1,length2)
      d = (cos(output[0],output[1]) + 1)/2
      loss_ = loss(d,label)
      eer_result = calculate_eer(d.detach().cpu().numpy(),np.array(label.cpu()))
      eer_train += eer_result[0]/output[0].shape[0]    
      fpr_train += eer_result[2]
      fnr_train += eer_result[3]
      
      train_loss += loss_.item()
      loss_.backward()
      optimizer.step()

      curr_pred = (d.detach().cpu().numpy() >= 0.33)*1
      curr_label = label.detach().cpu().numpy()
      predicted_train.extend(curr_pred)
      label_train.extend(curr_label)
      
      ## Accuracy Calculation
      acc_curr = np.sum(curr_pred == curr_label)/output[0].shape[0]
      acc_train_avg += acc_curr
        
      ## Increase Count
      train_count += 1
      
      ## Delete variables
      del output,x1,x2,length1,length2,label,loss_,curr_pred,curr_label,d,eer_result
      
    except Exception as e:
      continue
    
  ### Validation Loop ###

  with torch.no_grad() :
    try: 
      for item in tqdm(val_loader) :
      ## Testing Model 
        x1,x2,label,length1,length2 = item
        x1 = x1.to(device)
        x2 = x2.to(device)
        label = label.to(device)
        label = label.float()
        optimizer.zero_grad()
        output = model(x1,x2,length1,length2)
        d = (cosine_similarity(output[0],output[1]) + 1)/2
        loss_ = loss(d,label)
        val_loss += torch.mean(loss_).item()
        
        eer_result = calculate_eer(d.detach().cpu().numpy(),np.array(label.cpu()))
        eer_val += eer_result[0]/output[0].shape[0]    
        fpr_val += eer_result[2]
        fnr_val += eer_result[3]    

        curr_pred = (d.detach().cpu().numpy() >= 0.33)*1
        curr_label = label.detach().cpu().numpy()
        predicted_val.extend(curr_pred)
        label_val.extend(curr_label)
        
        ## Accuracy Calculation
        acc_curr = np.sum(curr_pred == curr_label)/output[0].shape[0]
        acc_val_avg += acc_curr
        
        ## Increase Count
        val_count += 1

        ## Deleting 
        del output,x1,x2,length1,length2,label,loss_,curr_pred,curr_label,d,eer_result
      
    except Exception as e:
      continue
      

    
  acc_avg_test = 0
  av_f1 = 0
  av_prec = 0
  av_recall = 0
  test_count = 0
  
  ### Testing Loop ###

  with torch.no_grad() :
    try : 
      for item in tqdm(test_loader) :
        x1,x2,label,length1,length2 = item
        x1 = x1.to(device)
        x2 = x2.to(device)
        label = label.to(device)
        label = label.float()
        optimizer.zero_grad()
        model.eval()
        output = model(x1,x2,length1,length2)
        d = (cosine_similarity(output[0],output[1]) + 1)/2
        loss_ = loss(d,label)
        eer_result = calculate_eer(d.detach().cpu().numpy(),np.array(label.cpu()))
        eer_test += eer_result[0]/output[0].shape[0]    
        fpr_test += eer_result[2]
        fnr_test += eer_result[3]
      
        test_count += 1
        
        threshold = 0.33
        acc_avg_test += np.sum((d.detach().cpu().numpy() >= threshold)*1 == label.detach().cpu().numpy())/output[0].shape[0]
        av_f1 += f1_score((d.detach().cpu().numpy() >= threshold)*1,label.detach().cpu().numpy())
        av_prec += precision_score((d.detach().cpu().numpy() >= threshold)*1,label.detach().cpu().numpy())
        av_recall += recall_score((d.detach().cpu().numpy() >= threshold)*1,label.detach().cpu().numpy())

        del output,x1,x2,length1,length2,label,loss_,d,eer_result
        
    except Exception as e:
      continue
  
  ## Saving Model Checkpoint
  if path: 
    torch.save(model,os.path.join(os.getcwd(),path,'_model{}.pth').format(epoch))
  else: 
    path = new_path
    print(new_path)
    os.mkdir(os.path.join(os.getcwd(),path))
    
  ## Calculate f1 score
  f1_train = f1_score(label_train,predicted_train)
  prec_train = precision_score(label_train,predicted_train)
  prec_val = precision_score(label_val,predicted_val)
  recall_train = recall_score(label_train,predicted_train)
  recall_val = recall_score(label_val,predicted_val)
  f1_val = f1_score(label_val,predicted_val)
  
  print("Precision train/val/test: {}/{}/{}".format(prec_train,prec_val,av_prec/test_count))
  print("Recall train/val/test: {}/{}/{}".format(recall_train,recall_val,av_recall/test_count))
  print("F1 train/val/test: {}/{}/{}".format(f1_train,f1_val,av_f1/test_count))
  print("EER Train/Val/Test: {}/{}/{}".format(eer_train/train_count,eer_val/val_count,eer_test/test_count))
  print("FPR Train/Val/Test: {}/{}/{}".format(fpr_train/train_count,fpr_val/val_count,fpr_test/test_count))
  print("FNR Train/Val/Test: {}/{}/{}".format(fnr_train/train_count,fnr_val/val_count,fnr_test/test_count))
  
  ## Logging into wandb
  wandb.log({"train_accuracy": acc_train_avg/train_count,"val_accuracy": acc_val_avg/val_count,"train_loss": train_loss/train_count,"val_loss": val_loss/val_count,"train_f1": f1_train,"val_f1": f1_val,"test_accuracy":acc_avg_test/test_count,"test_f1":av_f1/test_count,"train_eer":eer_train/train_count,"val_eer":eer_val/val_count,"test_eer":eer_test/test_count, "train_fpr":fpr_train/train_count,"val_fpr":fpr_val/val_count,"test_fpr":fpr_test/test_count,"train_fnr":fnr_train/train_count,"val_fnr":fnr_val/val_count,"test_fnr":fnr_test/test_count,"train_precision":prec_train,"val_precision":prec_val,"test_precision":av_prec/test_count,"train_recall":recall_train,"val_recall":recall_val,"test_recall":av_recall/test_count,"threshold":threshold,"epoch":epoch,"path":path,"f1_train":f1_train,"f1_val":f1_val,"f1_test":av_f1/test_count,"prec_train":prec_train,"prec_val":prec_val,"prec_test":av_prec/test_count,"recall_train":recall_train,"recall_val":recall_val,"recall_test":av_recall/test_count,"fpr_train":fpr_train/train_count,"fpr_val":fpr_val/val_count,"fpr_test":fpr_test/test_count,"fnr_train":fnr_train/train_count,"fnr_val":fnr_val/val_count,"fnr_test":fnr_test/test_count})
  
  ## Appending to list
  val_loss_plot.append(val_loss/val_count)
  loss_plot.append(train_loss/train_count)
  iter.append(epoch)
  train_eer_loss.append(acc_train_avg/train_count)
  val_eer_loss.append(acc_val_avg/val_count)

  print("---------------------------")
  gc.collect()


EPOCH - 0


100%|██████████| 702/702 [01:39<00:00,  7.09it/s]
100%|██████████| 7/7 [00:00<00:00,  8.06it/s]
100%|██████████| 638/638 [01:34<00:00,  6.78it/s]


13_36_12
Precision train/val/test: 0.7170316275340491/0.8348717948717949/0.8484135389280921
Recall train/val/test: 0.7372044448157741/0.80039331366765/0.6738389516332158
F1 train/val/test: 0.7269781199106891/0.8172690763052208/0.7508094129990311
EER Train/Val/Test: 0.0005315213596573939/0.0013179846287291323/0.00048738479677669274
FPR Train/Val/Test: 0.3642737933465132/0.3402169317215728/0.34045630954005035
FNR Train/Val/Test: 0.2908530093661181/0.24216820107864254/0.2523365448053171
---------------------------
EPOCH - 1


100%|██████████| 702/702 [01:35<00:00,  7.34it/s]
100%|██████████| 7/7 [00:00<00:00,  7.43it/s]
100%|██████████| 638/638 [01:31<00:00,  6.94it/s]


Precision train/val/test: 0.7154187088374943/0.8631415241057543/0.806167788336597
Recall train/val/test: 0.8325451861753975/0.8185840707964602/0.6869043725810972
F1 train/val/test: 0.7695507680268133/0.8402725208175624/0.741372391332377
EER Train/Val/Test: 0.00044766509339633307/0.0012609743211180549/0.0005037591667368033
FPR Train/Val/Test: 0.31961223392967336/0.3631108830957022/0.33733601683022263
FNR Train/Val/Test: 0.26243474463458333/0.20611191792354672/0.27724488491340726
---------------------------
EPOCH - 2


100%|██████████| 702/702 [01:41<00:00,  6.93it/s]
100%|██████████| 7/7 [00:00<00:00,  7.03it/s]
100%|██████████| 638/638 [01:32<00:00,  6.90it/s]


Precision train/val/test: 0.7265296011890018/0.8644346871569704/0.8221680021520388
Recall train/val/test: 0.8495056674186092/0.7743362831858407/0.732399982571974
F1 train/val/test: 0.783219799773532/0.816908713692946/0.774378241872549
EER Train/Val/Test: 0.0004130109140802081/0.0016003640466151561/0.00042950285519294224
FPR Train/Val/Test: 0.3069538991533151/0.3571760257180279/0.33457303562475593
FNR Train/Val/Test: 0.2492343767435341/0.23499030907460391/0.2529509360710879
---------------------------
EPOCH - 3


100%|██████████| 702/702 [01:35<00:00,  7.33it/s]
100%|██████████| 7/7 [00:00<00:00,  7.78it/s]
100%|██████████| 638/638 [01:33<00:00,  6.83it/s]


Precision train/val/test: 0.7385926001092057/0.8374384236453202/0.8528170779917419
Recall train/val/test: 0.8513715988525914/0.8357915437561455/0.714629575369391
F1 train/val/test: 0.7909822916343238/0.8366141732283464/0.7772968530038153
EER Train/Val/Test: 0.00039656491231323626/0.0015873365339922746/0.0004268786592432416
FPR Train/Val/Test: 0.3003978490463039/0.3447751850437495/0.33648511435331135
FNR Train/Val/Test: 0.24218392822362816/0.23325659459335638/0.23214244114695531
---------------------------
EPOCH - 4


100%|██████████| 702/702 [01:38<00:00,  7.14it/s]
100%|██████████| 7/7 [00:00<00:00,  7.41it/s]
100%|██████████| 638/638 [01:34<00:00,  6.78it/s]


Precision train/val/test: 0.750679122124724/0.8683236382866208/0.8385070346904746
Recall train/val/test: 0.8557997047929373/0.8072763028515241/0.7361867729211613
F1 train/val/test: 0.799800109835481/0.836687898089172/0.7836805717782845
EER Train/Val/Test: 0.00037988325461033186/0.0012034791937401185/0.0004045110608039949
FPR Train/Val/Test: 0.291121595540677/0.32605615372017216/0.32648469081412174
FNR Train/Val/Test: 0.23642299517801738/0.23437156788308605/0.24038969370862912
---------------------------
EPOCH - 5


100%|██████████| 702/702 [01:34<00:00,  7.40it/s]
100%|██████████| 7/7 [00:00<00:00,  7.08it/s]
100%|██████████| 638/638 [01:34<00:00,  6.79it/s]


Precision train/val/test: 0.756432827403575/0.8643190056965303/0.8059722377314853
Recall train/val/test: 0.8591528114295263/0.8205506391347099/0.7226183553703556
F1 train/val/test: 0.8045273177728517/0.8418663303909205/0.7616872666449656
EER Train/Val/Test: 0.0003700065424028381/0.0021653986779277307/0.0004565640347697842
FPR Train/Val/Test: 0.2860663713469839/0.3394943684721094/0.33376934835434136
FNR Train/Val/Test: 0.23278918985718727/0.2185924872580716/0.26011221252007255
---------------------------
EPOCH - 6


100%|██████████| 702/702 [01:34<00:00,  7.40it/s]
100%|██████████| 7/7 [00:00<00:00,  7.87it/s]
100%|██████████| 638/638 [01:33<00:00,  6.81it/s]


Precision train/val/test: 0.7611263025644815/0.9051383399209486/0.8056147555367698
Recall train/val/test: 0.8604561784610243/0.788102261553589/0.7961306443317208
F1 train/val/test: 0.8077490196078432/0.8425755584756899/0.8005463968390361
EER Train/Val/Test: 0.00036171490942411985/0.001554264072696255/0.0003903677818894055
FPR Train/Val/Test: 0.2831958282673315/0.36270136329133357/0.32563456101614313
FNR Train/Val/Test: 0.22929969832275146/0.19265845208260207/0.2313000852898007
---------------------------
EPOCH - 7


100%|██████████| 702/702 [01:35<00:00,  7.34it/s]
100%|██████████| 7/7 [00:00<00:00,  7.02it/s]
100%|██████████| 638/638 [01:35<00:00,  6.72it/s]


Precision train/val/test: 0.7678628088684409/0.885792349726776/0.75104264235428
Recall train/val/test: 0.8626785863480658/0.7969518190757129/0.8107614948777332
F1 train/val/test: 0.8125139348282836/0.8390269151138716/0.7794052534932732
EER Train/Val/Test: 0.00035567850317759595/0.001554805106543766/0.00042864459975128325
FPR Train/Val/Test: 0.2789831653425427/0.3481107097044109/0.3201334274671568
FNR Train/Val/Test: 0.22724918384543447/0.21731830576945396/0.23429778245841165
---------------------------
EPOCH - 8


100%|██████████| 702/702 [01:38<00:00,  7.15it/s]
100%|██████████| 7/7 [00:01<00:00,  6.96it/s]
100%|██████████| 638/638 [01:36<00:00,  6.62it/s]


Precision train/val/test: 0.7731517762215423/0.8665644171779141/0.8101667979833486
Recall train/val/test: 0.8627509956275935/0.8333333333333334/0.7951928916804543
F1 train/val/test: 0.8154976887194769/0.849624060150376/0.8023108734598179
EER Train/Val/Test: 0.00034999952787164834/0.0021719758981473356/0.0003957968967819832
FPR Train/Val/Test: 0.27624314076220263/0.34957917955236784/0.3198601442574365
FNR Train/Val/Test: 0.22358433514533751/0.21879570537664864/0.21043320258830517
---------------------------
EPOCH - 9


100%|██████████| 702/702 [01:35<00:00,  7.32it/s]
100%|██████████| 7/7 [00:00<00:00,  7.32it/s]
100%|██████████| 638/638 [01:33<00:00,  6.80it/s]
