# Construct-Text Similarity: zero-shot classification with embeddings and using your own lexicons

Author: Daniel Low (Harvard University)

Please ask for an updated reference if you use this code (a preprint will be out soon).

In [None]:
# config
# Selects the input/output directories set up further down ('local' or 'openmind').
location = 'local'
preprocessing = True # If False, load preprocessed data
# Intended to toggle subsampling of the training set down to the minority-class size.
balance_training_set = True
# Names of the two groups compared; they match the 'dv' labels assigned when the subsets are built.
classes = ['suicidal_desire', 'active_rescue']
# Seed passed as random_state to the pandas sampling/shuffling calls.
random_seed = 123

In [None]:
!python --version # tested with Python 3.10.12

In [None]:
# !pip install -q deplacy==2.0.5
# !pip install -q flair==0.13.0
# !pip install -q --upgrade urllib3==2.0.7
# !pip install -q sentence-transformers==2.2.2


In [None]:
'''
Authors: Daniel M. Low
License: See license in github repository
'''

import os
import dill
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import datetime
import spacy
import deplacy
import importlib
pd.set_option("display.max_columns", None)
# pd.options.display.width = 0

# local scripts
import sys
sys.path.append('./../concept-tracker') # wherever cts.py is
from concept_tracker import cts
from concept_tracker.utils.tokenizer import spacy_tokenizer
import srl_constructs



In [None]:
# config

# UTC timestamp string. datetime.utcnow() is deprecated since Python 3.12;
# an aware "now" in UTC formats to exactly the same string.
ts = datetime.datetime.now(datetime.timezone.utc).strftime('%y-%m-%dT%H-%M-%S')

# Pick the data directories for the machine the notebook runs on.
if location == 'openmind':
  input_dir = '/nese/mit/group/sig/projects/dlow/ctl/datasets/'
  # NOTE(review): unlike input_dir this path has no leading '/', so it is relative
  # to the working directory — confirm that is intended.
  output_dir = 'home/dlow/zero_shot/data/output/'
elif location == 'local':
  input_dir = './../../../data/ctl/input/datasets/'
  output_dir = './data/output/'
else:
  # Fail with a clear message instead of a NameError on output_dir below.
  raise ValueError(f"Unknown location {location!r}; expected 'openmind' or 'local'")
os.makedirs(output_dir, exist_ok=True)



# Preprocessing



## Documents to measure

In [None]:
# Conversation-level train/test tables, read from parquet files under input_dir.
train_path = input_dir + 'train10_train_metadata_messages_clean.gzip'
test_path = input_dir + 'train10_test_metadata_messages_clean.gzip'
train, test = (pd.read_parquet(path, engine='pyarrow') for path in (train_path, test_path))
# X_train = pd.read_csv('./data/input/ctl/X_train_all_with_interaction_preprocessed_24-03-07T04-25-04.csv', index_col = 0) # preprocessing
# X_test = pd.read_csv('./data/input/ctl/X_test_all_with_interaction_preprocessed_24-03-07T04-25-04.csv', index_col = 0) # preprocessing


In [None]:
"""
# this is a variable I created with CTL's suicide risk assessment ladder which helps make sure certain constructs are not present at lower levels. See train_test_split_ids.ipynb for function
def get_true_risk_8(row):
	if (row['3rd_party'] ==1 or row['testing'] == 1 or row['prank'] == 1):
		return -1
	elif (row['active_rescue'] > 0):
		return 8 # active rescue
	
	elif (row['ir_flag'] > 0):
		return 7 # high risk
	
	elif (row['timeframe'] > 0):
		return 6 # high risk
	
	elif (row['suicidal_capability'] > 0):
		return 5 # high risk
	
	elif (row['suicidal_intent']>0):
		return 4
	elif row['self_harm']>0:
		return 3

	elif (row['suicidal_desire']>0 or row['suicide']>0):
		return 2
	else: 
		return 1

		"""

In [None]:
# Word-count summary for conversations at ladder level 8 (active rescue)
train.loc[train['suicide_ladder_8'] == 8, 'word_count_with_interaction'].describe()


In [None]:
# Word-count summary for ladder levels 6 and 7.
# Active rescues tend to have shorter conversations than imminent risk without active rescue.
train.loc[train['suicide_ladder_8'].isin([6, 7]), 'word_count_with_interaction'].describe()

In [None]:
# Word-count summary for conversations at ladder level 2
train.loc[train['suicide_ladder_8'].isin([2]), 'word_count_with_interaction'].describe()

In [None]:
def _build_labelled_subset(df_i, balance):
	"""Return a shuffled table of active-rescue and suicidal-desire conversations with a 'dv' label.

	df_i: conversation-level DataFrame with 'suicide_ladder_8' and 'suicidal_desire' columns.
	balance: if True, subsample the suicidal-desire rows to the number of active-rescue rows.

	The input frame is not modified; the result has a fresh 0..n-1 index.
	"""
	# .copy() so that adding 'dv' below writes to an independent frame and not to a
	# slice of df_i (which triggers pandas' SettingWithCopyWarning).
	active_rescue = df_i[df_i['suicide_ladder_8'] == 8].copy() # imminent risk which could not be de-escalated
	suicidal = df_i[df_i['suicide_ladder_8'].isin([2])] # no imminent risk, no intent or capability; desire or other forms of suicide without any of those tags
	suicidal_desire = suicidal[suicidal['suicidal_desire'] == 1].copy() # no imminent risk, no intent or capability; desire confirmed

	if balance:
		# Subsample to match the minority class. Active rescue is assumed to be the smaller
		# group; sample() raises ValueError if it is not.
		smallest_group = active_rescue.shape[0]
		suicidal_desire = suicidal_desire.sample(n=smallest_group, random_state=random_seed)

	# add label
	active_rescue['dv'] = 'active_rescue'
	suicidal_desire['dv'] = 'suicidal_desire'

	# Shuffle reproducibly and renumber the rows
	return pd.concat([active_rescue, suicidal_desire]).sample(frac=1, random_state=random_seed).reset_index(drop=True)


# Train: balanced only when the config flag asks for it
train_subset = _build_labelled_subset(train, balance=balance_training_set)
display(train_subset['dv'].value_counts())

# Test: never subsampled
test_subset = _build_labelled_subset(test, balance=False)
display(test_subset['dv'].value_counts())


In [None]:
# You can try to automate for all types 

# dfs = {}
# for df_i, name in zip([train, test], ['train', 'test']):	
# 	display(df_i['suicide_ladder_8'].value_counts() ) # this is a variable I created with CTL's suicide risk assessment ladder which helps make sure certain constructs are not present at lower levels. See train_test_split_ids.ipynb for function
# 	# see get_true_risk_8 function in train_test_split.ipynb for interpretation
# 	active_rescue = df_i[df_i['suicide_ladder_8']==8] # immiment risk which could not be de-escaleted
# 	imminent_risk = df_i[df_i['suicide_ladder_8'].isin([6,7])] # has intent and capability and timeframe
# 	timeframe = imminent_risk[imminent_risk['timeframe']==1] # has intent and capability and timeframe
# 	suicidal = df_i[df_i['suicide_ladder_8'].isin([2])] # no imminent risk, no intent or capability; desire or other forms of suicide without any of those tags
# 	suicidal_desire = suicidal[suicidal['suicidal_desire']==1] # no imminent risk, no intent or capability; desire confirmed
	
# 	if name =='train' and balance_training_set:
# 		print(1)
# 		smallest_group = active_rescue.shape[0]
# 		timeframe = timeframe.sample(n= smallest_group, random_state=random_seed)
# 		suicidal_desire = suicidal_desire.sample(n= smallest_group, random_state=random_seed)
# 		df_i_subset = pd.concat([active_rescue, timeframe, suicidal_desire]).sample(frac=1, random_state=random_seed).reset_index(drop=True)
# 	else:
# 		df_i_subset = pd.concat([active_rescue, timeframe, suicidal_desire]).sample(frac=1).reset_index(drop=True)


# 	dfs[name] = df_i_subset.copy()


## Construct Text Similarity feature extraction

Constructs:

1. Build constructs
2. Encode constructs

Documents:

3. Tokenize documents
4. Encode documents

Similarity:

5. Compute cosine similarity between constructs and docs and take maximum similarity per document

### 1. Build constructs


In [None]:
import dill

In [None]:
# 1. Build constructs
'''
construct_tokens_d = {
	'annoyance': ['annoyed', 'bothering me', 'annoying'],
	'anger': ['angry', 'rage'],
	'gratitude': ['grateful', 'thank you']}
'''

# I'm going to use prototypical tokens from the Suicide Risk Lexicon (the ones clinicians labelled as 3/3 on average)

# Open the pickle in a context manager so the file handle is closed after loading.
with open("./../lexicon/data/input/lexicons/suicide_risk_lexicon_validated_prototypical_tokens_24-03-06T00-47-30.pickle", "rb") as handle:
	srl = dill.load(handle)
constructs_to_measure = srl_constructs.constructs_in_order

# remove_constructs = [
# 				'Bullying', # 'tell me to kill myself' matches with 'kill myself'
# 					 'Social withdrawal',  #'want to be alone' matches with many loneliness comments
# 					 'Suicide exposure' #'suicide aftermath' matches with 'suicide' etc
# 					 'Agitated',# matches 'frustrated', 'stress'
# 					 'Emotional pain & psychache',
# 					 'Grief & bereavement', #'commited suicide' with 'suicide'
# 					 'Perfectionism',
# 					 'Discrimination',#"treated with less respect" matches with common phrase by therapists "No one deserves to be treated that way"
# 					 ]
# constructs_to_measure = [x for x in constructs_to_measure if x not in remove_constructs]


# construct name -> list of tokens, in the order given by constructs_to_measure
construct_tokens_d = {
	construct_name: srl.constructs[construct_name]['tokens']
	for construct_name in constructs_to_measure
}

### 2. Encode constructs

For faster processing, you can randomly select up to 10-20 tokens from each construct. If a single construct has 50 tokens, it adds a lot of time without much increase in performance.

In [None]:
if preprocessing:

	embeddings_dir = './../lexicon/data/input/lexicons/'
	# Previously computed token embeddings (token -> embedding). The context manager
	# closes the file handle once the pickle is loaded.
	with open(embeddings_dir+'embeddings_lexicon-tokens_all-MiniLM-L6-v2.pickle', "rb") as handle:
		prior_embeddings = dill.load(handle)

	# single dictionary for all tokens, not split by construct
	construct_embeddings_d = {}
	tokens_to_encode = []
	for construct in srl.constructs.keys():
		tokens = srl.constructs[construct]['tokens']
		for token_i in tokens:
			if token_i in prior_embeddings:
				# reuse the cached embedding
				construct_embeddings_d[token_i] = prior_embeddings[token_i]
			else:
				tokens_to_encode.append(token_i)

	# A token can belong to several constructs; encode each missing token only once
	# (dict.fromkeys keeps the first-seen order).
	tokens_to_encode = list(dict.fromkeys(tokens_to_encode))

	print('tokens_to_encode',len(tokens_to_encode))
	if len(tokens_to_encode)>0:
		from sentence_transformers import SentenceTransformer
		embeddings_name = 'all-MiniLM-L6-v2'
		sentence_embedding_model = SentenceTransformer(embeddings_name)       # load embedding
		# NOTE(review): convert_to_tensor=True yields torch tensors; confirm the cached
		# embeddings in prior_embeddings use a compatible type.
		embeddings = sentence_embedding_model.encode(tokens_to_encode, convert_to_tensor=True,show_progress_bar=True)
		embeddings_d = dict(zip(tokens_to_encode, embeddings))
		construct_embeddings_d.update(embeddings_d)

### 3. Tokenize and clean docs into complete clauses (subject + predicate; split when a conjunction joins two complete clauses)


In [None]:
# This was already done in another script
def clean_ctl_text(message_with_interaction):
	"""Return a list with one cleaned string per input message.

	Each item is converted with str(); items whose string form is 'nan' become ''.
	Punctuation runs are then normalized ('!.' -> '!', '?.' -> '?', '....' -> '...',
	'...' -> '... ') and the result is passed through clean.remove_multiple_spaces.
	"""
	from concept_tracker.utils import clean
	# Fast: 1 sec every 10 000 messages
	punctuation_fixes = (('!.', '!'), ('?.', '?'), ('....', '...'), ('...', '... '))

	cleaned_messages = []
	for message in message_with_interaction:
		text = str(message)
		if text == 'nan':
			text = ''
		# the substitutions are order-dependent, so apply them one after the other
		for old, new in punctuation_fixes:
			text = text.replace(old, new)
		cleaned_messages.append(clean.remove_multiple_spaces(text))
	return cleaned_messages


# This needs to be done after tokenizing
def clean_ctl_conversation(docs):
	"""Clean tokenized conversations clause by clause.

	docs: iterable of documents, each an iterable of clause strings.

	Every clause is split on newlines; speaker prefixes and scrubbing placeholders are
	removed and the spacing around apostrophes and commas is tightened. Fragments that
	are a bare speaker name, a part marker, the unreadable-message placeholder, or
	5 characters or fewer are dropped.

	Returns a list with one list of cleaned clause strings per document.
	"""
	# The substitutions are order-dependent: three groups, with a strip() after the
	# first and after the second group.
	before_punct_strip = (('texter : ', ''), ('counselor : ', ''), ('\n', '. '))
	before_space_strip = ((" '", "'"), (' ’', "'"), (' ,', ','))
	after_space_strip = (
		('observer : ', ''),
		(" n't", "n't"),
		(" ( 1/2 )", ""),
		('{ { URL } }', ''),
		('[ scrubbed ]', ''),
		('  ', ' '),
	)
	discard = ('texter', 'counselor', 'observer', '', '-- UNREADABLE MESSAGE --', '( 2/2 )', '( 1/2 )')

	def substitute(text, pairs):
		for old, new in pairs:
			text = text.replace(old, new)
		return text

	def tidy(fragment):
		fragment = substitute(fragment, before_punct_strip).strip('.,:\n')
		fragment = substitute(fragment, before_space_strip).strip(' ')
		return substitute(fragment, after_space_strip)

	cleaned_docs = []
	for doc in docs:
		fragments = (tidy(line) for clause in doc for line in clause.split('\n'))
		cleaned_docs.append([frag for frag in fragments if frag not in discard and len(frag) > 5])
	return cleaned_docs

In [None]:
if preprocessing:
	# 10 minutes for both train and test subsets
	# Each subset is tokenized into clauses, cleaned, and written to its own CSV (train first, then test).
	subsets_and_paths = (
		(train_subset, './data/input/ctl/X_train_all_with_interaction_desire_active_rescue_subset_tokenized_clauses.csv'),
		(test_subset, './data/input/ctl/X_test_all_with_interaction_desire_active_rescue_subset_tokenized_clauses.csv'),
	)
	for subset_df, csv_path in subsets_and_paths:
		docs_clean = subset_df['message_with_interaction_clean'].values
		# turning newlines into sentence breaks helps tokenize by clause, specially for CTL data
		docs_clean = [n.replace('\n', '. ') for n in docs_clean]
		docs_clauses = spacy_tokenizer(
			docs_clean,
			language='en',
			model='en_core_web_sm',
			method='clause',
			lowercase=False,
			display_tree=False,
			remove_punct=False,
			clause_remove_conj=True,
		)
		docs_clauses = clean_ctl_conversation(docs_clauses)
		subset_df['message_with_interaction_clean_clauses'] = docs_clauses
		subset_df.to_csv(csv_path, index=False)



In [None]:
# Reload the tokenized subsets from disk (after a CSV round trip the clause lists come back as strings).
train_tokenized_csv = './data/input/ctl/X_train_all_with_interaction_desire_active_rescue_subset_tokenized_clauses.csv'
test_tokenized_csv = './data/input/ctl/X_test_all_with_interaction_desire_active_rescue_subset_tokenized_clauses.csv'
train_subset = pd.read_csv(train_tokenized_csv)
test_subset = pd.read_csv(test_tokenized_csv)


### 4. Encode documents

1 second per conversation if tokenized into clauses (one embedding per clause)

1000 conversations in 15 minutes


In [None]:
# Old code, now I've sped it up with batch processing below

# if preprocessing:
# 	from sentence_transformers import SentenceTransformer
# 	embeddings_name = 'all-MiniLM-L6-v2'
# 	sentence_embedding_model = SentenceTransformer(embeddings_name)       # load embedding
# 	import pickle

# 	# already encoded

# 	with open('./data/input/ctl/embeddings/'+f'embeddings_all-MiniLM-L6-v2_docs_clauses_with-interaction_24-03-07T04-25-04.pickle', 'rb') as handle:
# 			docs_embeddings_d = pickle.load(handle)
# 	import tqdm 
# 	for df_i_subset in [train_subset, test_subset]: # TODO add train_subset
# 		docs_to_encode = []
# 		docs_embeddings_d_subset = {}
# 		for conversation_id in df_i_subset['conversation_id'].tolist():
# 			if conversation_id in docs_embeddings_d.keys():
# 				embedding = docs_embeddings_d[conversation_id]
# 				docs_embeddings_d_subset[conversation_id] = embedding
# 			else:
# 				docs_to_encode.append(conversation_id)
# 		print(len(docs_to_encode))

	
# 		if len(docs_to_encode)>0:
# 			print('encoding...')
# 			from sentence_transformers import SentenceTransformer
# 			embeddings_name = 'all-MiniLM-L6-v2'
# 			sentence_embedding_model = SentenceTransformer(embeddings_name)       # load embedding
# 			for doc in tqdm.tqdm(docs_to_encode):
# 				doc_clauses = df_i_subset[df_i_subset['conversation_id'] == doc]['message_with_interaction_clean_clauses'].values[0]
# 				doc_clauses = eval(doc_clauses) # when reloading a DF where each cell is a list of lists
# 				embeddings = sentence_embedding_model.encode(doc_clauses, convert_to_tensor=True,show_progress_bar=False)	
# 				docs_embeddings_d[doc] = embeddings


# 	# re save pickle
# 	import dill
# 	with open('./data/input/ctl/embeddings/'+'embeddings_all-MiniLM-L6-v2_docs_clauses_with-interaction_24-03-07T04-25-04.pickle', 'wb') as handle:
# 		dill.dump(docs_embeddings_d, handle, protocol=dill.HIGHEST_PROTOCOL)

In [None]:
# Small sample for quick experiments: seeded so it is reproducible, and capped at the
# number of available rows so sample() cannot fail on a tiny test set.
test_subset_toy = test_subset.sample(n=min(10, len(test_subset)), random_state=random_seed)

In [None]:
%time 

# Encode clauses

if preprocessing:
	import ast
	import pickle

	import tqdm
	from sentence_transformers import SentenceTransformer

	embeddings_name = 'all-MiniLM-L6-v2'
	sentence_embedding_model = SentenceTransformer(embeddings_name)
	docs_embeddings_path = './data/input/ctl/embeddings/'+'embeddings_all-MiniLM-L6-v2_docs_clauses_with-interaction_24-03-07T04-25-04.pickle'

	# Load existing embeddings (dictionary keyed by conversation_id)
	with open(docs_embeddings_path, 'rb') as handle:
		docs_embeddings_d = pickle.load(handle)

	# docs_embeddings_d = {} #Warning, here it is empty instead of loading

	for df_i_subset in [train_subset, test_subset]:
		conversation_ids = df_i_subset['conversation_id'].tolist()
		# Only conversations without a cached embedding need to be encoded
		docs_to_encode = [doc_id for doc_id in conversation_ids if doc_id not in docs_embeddings_d]
		print('docs_to_encode', len(docs_to_encode))
		if docs_to_encode:
			print('Encoding...')

			# Build the id -> clauses lookup in one pass instead of filtering the whole
			# DataFrame once per document. setdefault keeps the first row of a repeated id.
			clauses_by_id = {}
			for doc_id, clauses in zip(conversation_ids, df_i_subset['message_with_interaction_clean_clauses']):
				clauses_by_id.setdefault(doc_id, clauses)

			clauses_to_encode = []
			for doc_id in docs_to_encode:
				clauses = clauses_by_id[doc_id]
				if isinstance(clauses, str):
					# The column holds the string form of a list after a CSV round trip.
					# literal_eval parses it back without executing arbitrary code (replaces eval).
					clauses = ast.literal_eval(clauses)
				elif not isinstance(clauses, list):
					raise TypeError(f'Expected a list of clauses or its string form for conversation {doc_id!r}, got {type(clauses).__name__}')
				clauses_to_encode.append(clauses)

			# Flatten the list of lists and remember the split points
			flat_clauses, split_indices = [], [0]
			for clauses in clauses_to_encode:
				split_indices.append(split_indices[-1] + len(clauses))
				flat_clauses.extend(clauses)

			# Process in batches
			batch_size = 512  # Adjust based on your available memory and requirements
			all_embeddings = []
			for i in tqdm.tqdm(range(0, len(flat_clauses), batch_size)):
				batch_clauses = flat_clauses[i:i+batch_size]
				batch_embeddings = sentence_embedding_model.encode(batch_clauses, convert_to_tensor=False, show_progress_bar=False)
				all_embeddings.extend(batch_embeddings)

			# Assign embeddings to respective documents
			for i, doc_id in enumerate(docs_to_encode):
				start, end = split_indices[i], split_indices[i+1]
				docs_embeddings_d[doc_id] = np.array(all_embeddings[start:end])

	# Save updated embeddings (overwrites the file loaded above)
	with open(docs_embeddings_path, 'wb') as handle:
		pickle.dump(docs_embeddings_d, handle, protocol=pickle.HIGHEST_PROTOCOL)


### 5. Compute cosine similarity between constructs and docs
10 sec per 500 docs

In [None]:
from importlib import reload
reload(cts)

In [None]:
# Display the lexicon embeddings directory.
# NOTE: embeddings_dir is only assigned inside the `if preprocessing:` cell that encodes the constructs.
embeddings_dir

In [None]:
if preprocessing:

	# Split the document embeddings by subset. The id sets are built once so that each
	# membership test is a hash lookup rather than a scan of the whole id column per key.
	train_ids = set(train_subset['conversation_id'].tolist())
	test_ids = set(test_subset['conversation_id'].tolist())
	docs_embeddings_d_train = {k: v for k, v in docs_embeddings_d.items() if k in train_ids}
	docs_embeddings_d_test = {k: v for k, v in docs_embeddings_d.items() if k in test_ids}


	# X_train
	# ========================================================
	# we'll measure all the constructs in construct_tokens_d for the docs in docs
	print('Extract features for training set')
	X_train, X_train_cosine_scores_per_doc = cts.measure(
				construct_tokens_d = construct_tokens_d,
				construct_embeddings_d = construct_embeddings_d,
				docs_embeddings_d = docs_embeddings_d_train,
				method = 'lexicon_clause',
				summary_stat = ['max'],
				return_cosine_similarity=True,
				minmaxscaler = None,
				doc_id_col_name = 'conversation_id',
				remove_stat_name_from_col_name = True
			)

	# Attach the construct features to the conversations and save both the table and the raw cosine scores
	train_subset = train_subset.merge(X_train, on = 'conversation_id')
	train_subset.to_csv('./data/input/ctl/X_train_all_with_interaction_desire_active_rescue_subset_tokenized_clauses_cts-prototypes.csv', index=False)
	with open('./data/input/ctl/embeddings/'+'X_train_all_with_interaction_desire_active_rescue_subset_tokenized_clauses_cts-prototypes_cosine_similarities.pickle', 'wb') as handle:
		dill.dump(X_train_cosine_scores_per_doc, handle, protocol=dill.HIGHEST_PROTOCOL)

	# X_test
	# ========================================================
	print('Extract features for test set')
	X_test, X_test_cosine_scores_per_doc = cts.measure(
				construct_tokens_d = construct_tokens_d,
				construct_embeddings_d = construct_embeddings_d,
				docs_embeddings_d = docs_embeddings_d_test,
				method = 'lexicon_clause',
				summary_stat = ['max'],
				return_cosine_similarity=True,
				minmaxscaler = None,
				doc_id_col_name = 'conversation_id',
				remove_stat_name_from_col_name = True
			)

	test_subset = test_subset.merge(X_test, on = 'conversation_id')

	test_subset.to_csv('./data/input/ctl/X_test_all_with_interaction_desire_active_rescue_subset_tokenized_clauses_cts-prototypes.csv', index=False )
	with open('./data/input/ctl/embeddings/'+'X_test_all_with_interaction_desire_active_rescue_subset_tokenized_clauses_cts-prototypes_cosine_similarities.pickle', 'wb') as handle:
		dill.dump(X_test_cosine_scores_per_doc, handle, protocol=dill.HIGHEST_PROTOCOL)



# Load preprocessed data and run some descriptive statistics

In [None]:
# Overrides the config value set at the top so the next cell reloads the saved feature tables.
preprocessing = False

In [None]:

if not preprocessing:
	# Feature tables written by the preprocessing step, reloaded when that step was skipped
	cts_csv_paths = {
		'train': './data/input/ctl/X_train_all_with_interaction_desire_active_rescue_subset_tokenized_clauses_cts-prototypes.csv',
		'test': './data/input/ctl/X_test_all_with_interaction_desire_active_rescue_subset_tokenized_clauses_cts-prototypes.csv',
	}
	train_subset = pd.read_csv(cts_csv_paths['train'])
	test_subset = pd.read_csv(cts_csv_paths['test'])


In [None]:
from concept_tracker.feature_extraction import pronouns, verbs
from importlib import reload
reload(pronouns)


In [None]:
# Clause lists are stored as strings in the CSV, so each cell is parsed back with eval.
# NOTE(review): eval executes whatever the cell contains — only run this on files you produced yourself.
doc_clauses = [eval(clauses_str) for clauses_str in train_subset['message_with_interaction_clean_clauses'].values]
doc_proportions_df = verbs.extract_tenses_and_aspects(doc_clauses, list_of_clauses=True)

# Pronoun counts are computed on the 'message_clean' column, with newlines turned into sentence breaks
texter_messages = [message.replace('\n', '. ') for message in train_subset['message_clean']]
pronouns_df = pronouns.count_pronouns(texter_messages, normalize=True)

# All three tables need the same number of rows before they are placed side by side
n_train_rows = train_subset.shape[0]
assert n_train_rows == doc_proportions_df.shape[0] == pronouns_df.shape[0]

train_feature_tables = [train_subset, doc_proportions_df, pronouns_df]
train_subset = pd.concat(train_feature_tables, axis=1)
train_subset.to_csv(
	'./data/input/ctl/X_train_all_with_interaction_desire_active_rescue_subset_tokenized_clauses_cts-prototypes.csv',
	index=False,
)


In [None]:
# Same feature extraction as for the training subset, applied to the test subset.
# NOTE(review): eval executes whatever the cell contains — only run this on files you produced yourself.
doc_clauses = [eval(clauses_str) for clauses_str in test_subset['message_with_interaction_clean_clauses']]
doc_proportions_df = verbs.extract_tenses_and_aspects(doc_clauses, list_of_clauses=True)

texter_messages = [message.replace('\n', '. ') for message in test_subset['message_clean']]
pronouns_df = pronouns.count_pronouns(texter_messages, normalize=True)

# All three tables need the same number of rows before they are placed side by side
n_test_rows = test_subset.shape[0]
assert n_test_rows == doc_proportions_df.shape[0] == pronouns_df.shape[0]

test_feature_tables = [test_subset, doc_proportions_df, pronouns_df]
test_subset = pd.concat(test_feature_tables, axis=1)


In [None]:
# Write the test table back to the same CSV that the "load preprocessed data" cell reads.
test_subset.to_csv('./data/input/ctl/X_test_all_with_interaction_desire_active_rescue_subset_tokenized_clauses_cts-prototypes.csv', index=False )

In [None]:
test_subset

In [None]:
# The zero-shot feature is able to separate docs about the construct from docs
# not about the construct often quite well, depending on how the construct was built
construct = 'Active suicidal ideation & suicidal planning'
import plotly.express as px

# (Removed a leftover `df = px.data.tips()`: it loaded an unrelated sample dataset that was never used.)

# Use <br> line breaks so each message appears on its own line in the hover box
train_subset['hover_text'] = train_subset['message_with_interaction_clean'].apply(lambda x: x.replace('\n', '<br>'))

# One box per label, with every conversation drawn as a point
fig = px.box(train_subset, x='dv', y=construct, points="all", hover_data = ['conversation_id',
																			'hover_text',
																			# 'message_with_interaction_clean'
																			])
fig.update_layout(yaxis_range=[-0.1,1])

fig.show()



# Evaluate

In [None]:

from sklearn import metrics

construct = 'Active suicidal ideation & suicidal planning'

df = test_subset.copy()

# The construct's similarity feature is used directly as the classification score
y_similarity = df[construct].values

# Binary target: 1 = active rescue, 0 = suicidal desire
y = df['dv'].replace({'active_rescue':1, 'suicidal_desire':0}).values
roc_auc = metrics.roc_auc_score(y, y_similarity)
print(f'roc auc {roc_auc:.2}')
f1 = metrics.f1_score(y, y_similarity>0.5)
print(f'f1 score {f1:.2}') # you can't use 0.5 as the threshold, the cosine similarity isn't calibrated, it depends on each model_name type

# Get optimal threshold
false_positive_rate_list, true_positive_rate_list, thresholds = metrics.roc_curve(y,y_similarity) # Important: other metrics take binary predictions y_pred. Here we test different thresholds, so we need continuous scores (this will change the outputs)
# Index of the cutoff that maximizes TPR - FPR (Youden's J statistic)
i_opt = np.argmax(np.array(true_positive_rate_list)-np.array(false_positive_rate_list))
print(i_opt)
cut_opt = thresholds[i_opt]
print(f'Optimal cutoff value: {cut_opt:.3f}')
# roc_curve counts a sample as positive when score >= threshold, so use >= here as well;
# with > the sample that sits exactly on the cutoff would be misclassified.
f1_opt = metrics.f1_score(y, y_similarity>=cut_opt)
# Fixed: this used to print `f1` (the score at the 0.5 threshold) instead of `f1_opt`.
print(f'f1 score with optimal cutoff {f1_opt:.2}') # although this is if we have access to a test set.

ax = sns.kdeplot(data=df, x=construct, hue=y)
# plt.title(f'ROC AUC = {roc_auc:.2f}\nF1 = {f1:.2f}')
plt.title(f'ROC AUC = {roc_auc:.2f}')

# Add a vertical line for the optimal cut-off
line = plt.axvline(x=cut_opt, color='r', linestyle='dashed', linewidth=1)
label_line = f'Opt thresh {cut_opt:.2f}: F1 = {f1_opt:.2f}'

# Custom legend entries: the two classes, then the cut-off line.
# NOTE(review): the labels are matched to the plotted artists by position — verify that
# the first KDE curve really is the active-rescue class.
labels = [
	'active_rescue'.replace('_', ' ').capitalize(),
	'suicidal_desire'.replace('_', ' ').capitalize(),
	label_line,
]

# Redraw the legend with the combined handles and labels
ax.legend(labels=labels)

plt.show()



# Explainability: construct tokens vs. doc tokens


In [None]:
def return_cosine_df(doc_id, construct, docs_clauses, construct_tokens_d, cosine_scores_per_doc):
  """Return the token-by-clause cosine-similarity table for one document and one construct.

  doc_id: key (or position) of the document in docs_clauses.
  construct: construct name, a key of construct_tokens_d.
  docs_clauses: container indexed by doc_id that yields the document's clauses (column labels).
  construct_tokens_d: construct name -> list of tokens (row labels).
  cosine_scores_per_doc: mapping that holds the score matrix under the key '<doc_id>_<construct>'.

  Returns a DataFrame with one row per construct token and one column per clause.
  """
  column_labels = docs_clauses[doc_id]
  row_labels = construct_tokens_d[construct]
  similarity_matrix = cosine_scores_per_doc[f'{doc_id}_{construct}']
  return pd.DataFrame(similarity_matrix, index=row_labels, columns=column_labels)


In [None]:

'''
construct = 'gratitude'
doc_id = 65

construct = 'anger'
doc_id = 22

construct = 'annoyance'
doc_id = 64
'''

import ast

construct = 'Direct self-injury'
# NOTE(review): doc_id is used both as the key into the clause mapping below and in the
# '<doc_id>_<construct>' key of the cosine scores, so it presumably has to be a
# conversation_id of the test subset. 0 is kept from the original cell — confirm.
doc_id = 0

# The original call had an empty third argument (a syntax error). return_cosine_df needs the
# clauses of each document, so build conversation_id -> list of clauses from the test subset.
# The column holds the string form of a list after a CSV round trip, hence literal_eval.
test_docs_clauses = {
	conversation_id: ast.literal_eval(clauses) if isinstance(clauses, str) else clauses
	for conversation_id, clauses in zip(test_subset['conversation_id'], test_subset['message_with_interaction_clean_clauses'])
}

# NOTE: X_test_cosine_scores_per_doc only exists in the session after the feature-extraction
# cell has run with preprocessing = True; otherwise load the cosine-similarities pickle saved there.
df = return_cosine_df(doc_id, construct, test_docs_clauses, construct_tokens_d, X_test_cosine_scores_per_doc)
display(df)