# Learning to communicate about shared procedural abstractions
# Analysis

The majority of our plots and all of our statistical analyses can be found in `./stats.Rmd`



In [1]:
import os
import sys
os.getcwd()
sys.path.append("..")

import numpy as np
import scipy.stats as stats
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

from collections import Counter
import json
import re
import ast

import  matplotlib
from matplotlib import pylab, mlab, pyplot
%matplotlib inline
from IPython.core.pylabtools import figsize, getfigs
plt = pyplot
import matplotlib as mpl
mpl.rcParams['pdf.fonttype'] = 42
import seaborn as sns
sns.set_context('talk')
sns.set_style('darkgrid')

from IPython.display import clear_output
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", message="numpy.dtype size changed")
warnings.filterwarnings("ignore", message="numpy.ufunc size changed")

In [2]:
## directory & file hierarchy
# Every location is derived from the project root, two levels above the working directory.
proj_dir = os.path.abspath('../..')
datavol_dir = os.path.join(proj_dir, 'data')
analysis_dir = os.path.abspath('.')
results_dir = os.path.join(proj_dir, 'results')
plot_dir = os.path.join(results_dir, 'plots')
csv_dir = os.path.join(results_dir, 'csv')
json_dir = os.path.join(results_dir, 'json')
exp_dir = os.path.abspath(os.path.join(proj_dir, 'behavioral_experiments'))
png_dir = os.path.abspath(os.path.join(datavol_dir, 'png'))

## add helpers to python path
stimuli_dir = os.path.join(proj_dir, 'stimuli')
if stimuli_dir not in sys.path:
    sys.path.append(stimuli_dir)

## make sure the output directories exist
# exist_ok=True makes the call a no-op for directories that are already there,
# so no separate existence check is needed. json_dir is only named here, not created.
for out_dir in (results_dir, plot_dir, csv_dir):
    os.makedirs(out_dir, exist_ok=True)

In [3]:
results_dir

'/Users/will/compositional-abstractions/results'

## Read dataframes

In [4]:
# read in the dataframe for each eventType
df_block = pd.read_csv(os.path.join(csv_dir,'df_block.csv'))
df_chat = pd.read_csv(os.path.join(csv_dir,'df_chat.csv'))
df_exit = pd.read_csv(os.path.join(csv_dir,'df_exit.csv'))
df_trial = pd.read_csv(os.path.join(csv_dir,'df_trial.csv'))

In [5]:
print('n:', df_block.gameid.nunique())

n: 73


In [6]:
# iterationNames
list(df_trial.iterationName.unique())

['pilot0', 'pilot1', 'pilot2', 'pilot3', 'pilot4', 'pilot4b']

## Exclusion criteria

In [7]:
# 75% Accuracy on 75% of trials
# Step 1: sum trialScore within each (gameid, trialNum) and flag the trials whose total exceeds 75.
# Step 2: sum those boolean flags per gameid, i.e. count the flagged trials of every game.
df75 = pd.DataFrame(df_trial.groupby(['gameid', 'trialNum'])['trialScore'].sum()>75).groupby(['gameid']).sum()
df75['trials'] = df75['trialScore']

# keep only the games with at least 9 flagged trials
# NOTE(review): 9 is presumably 75% of 12 trials per game -- confirm against the experiment design
df75 = df75[df75['trials']>=9]
includedGames = list(df75.reset_index().gameid)

print("Total dyads achieving 75% Accuracy on 75% of trials:",len(df75))

Total dyads achieving 75% Accuracy on 75% of trials: 49


In [8]:
# Exclude from analysis
# Restrict every event-level dataframe to the games listed in includedGames.
df_block = df_block[df_block.gameid.isin(includedGames)]
df_chat = df_chat[df_chat.gameid.isin(includedGames)]
df_exit = df_exit[df_exit.gameid.isin(includedGames)]
df_trial = df_trial[df_trial.gameid.isin(includedGames)]

# Referring Expressions Annotations

## Create annotations dataframe

Here we load and merge data from two rounds of annotations.
1. We originally asked two annotators to identify the referring expressions in the chat data.
2. To make these annotations more robust, we later asked an additional two annotators, this time using a custom-built interface. The task can be found in `tasks/annotation`, and basic wrangling of that data can be found in `analysis/annotation`.

Here we wrangle data from both rounds into a common format, and export it for other analyses.

In [None]:
# All three files live in csv_dir (results_dir/csv); build the paths from it,
# as the "Read dataframes" cell does, instead of formatting results_dir by hand.

# first set of annotations (cogsci 2021)
df_jj = pd.read_csv(os.path.join(csv_dir, 'JJ_content.csv'))

# second set of annotations (2023)
df_zc = pd.read_csv(os.path.join(csv_dir, 'ref_exp_annotations_2023.csv'))

# chat data carrying the message_id column used to join the two annotation rounds
# NOTE: this rebinds df_chat, so the exclusion-filtered frame read from df_chat.csv
# earlier in the notebook is no longer reachable under this name.
df_chat = pd.read_csv(os.path.join(csv_dir, 'df_chat_ids_cogsci21.csv'))

In [None]:
# expose the camelCase identifiers under the snake_case names used by the annotation frames
for snake_name, camel_name in (('dyad_gameid', 'gameid'),
                               ('turn_num', 'turnNum'),
                               ('trial_num', 'trialNum')):
    df_chat[snake_name] = df_chat[camel_name]

In [None]:
# message index: the turn number halved and truncated to an integer
df_jj['message_num'] = (df_jj['turnNum'] / 2).astype(int)

# columns kept from the first annotation round (header typos are repaired by the rename below)
jj_columns = ['gameid', 'trialNum', 'message_num', 'turnNum', 'message',
              'block_justin', 'toer_justin', 'scene_justin', 'Flagged', 'phrases_justin',
              'block_julia', 'tower_juli', 'scene_juli', 'phrases_julia']

# common identifier names, plus "<category>_<annotator>" spellings for the annotation columns
jj_renames = {
    'gameid': 'dyad_gameid',
    'trialNum': 'trial_num',
    'turnNum': 'turn_num',
    'toer_justin': 'tower_justin',
    'tower_juli': 'tower_julia',
    'scene_juli': 'scene_julia',
}

jj_merge_keys = ['dyad_gameid', 'trial_num', 'turn_num']

# select -> rename -> attach message_id from the chat data -> order by game, trial and message
df_jj_small = (
    df_jj[jj_columns]
    .rename(columns=jj_renames)
    .merge(df_chat[jj_merge_keys + ['message_id']], on=jj_merge_keys, how='left')
    .sort_values(['dyad_gameid', 'trial_num', 'message_num'])
    .reset_index(drop=True)
)


In [None]:
# melt and pivot: go from one column per (category, annotator) to one row per (message, annotator)

id_columns = ['dyad_gameid', 'trial_num', 'turn_num', 'message_id']

# annotation columns are named "<category>_<annotator>"
suffix_columns = [col for col in df_jj_small.columns if col.endswith(('_justin', '_julia'))]

suffix_df = df_jj_small[id_columns + suffix_columns]

# long format: one row per (message, annotation column)
melted_df = pd.melt(suffix_df, id_vars=id_columns, var_name='Type', value_name='Value')

# split "<category>_<annotator>" into its two parts, then discard the combined label
melted_df[['Category', 'Suffix']] = melted_df['Type'].str.split('_', expand=True)
melted_df = melted_df.drop(columns=['Type'])

# wide again: the annotator joins the row key and each category becomes a column
pivoted_df = (
    melted_df
    .pivot(index=id_columns + ['Suffix'], columns='Category', values='Value')
    .reset_index()
    .rename(columns={'Suffix': 'workerID'})
)


In [None]:
# adjust a bad annotation
# rows whose tower value is the letter 'L' get the count 1 instead
pivoted_df.loc[pivoted_df['tower'] == 'L','tower'] = 1

In [None]:
# convert to ints
# Missing annotations count as zero references.
# The columns are replaced outright (`df[col] = ...`) rather than written through
# `.loc[:, col]`: from pandas 2.0 the `.loc` form writes into the existing column in
# place, so the values would become ints while the column dtype stayed `object`.
for count_col in ['block', 'tower', 'scene']:
    pivoted_df[count_col] = pivoted_df[count_col].fillna(0).astype(int)

In [None]:
# merge in metadata
# left-join the message_num and message columns of df_jj_small onto every (message, annotator) row,
# matching on message_id and dyad_gameid
pivoted_df_merged = pivoted_df.merge(df_jj_small[['dyad_gameid','message_id','message_num','message']], 
                 on=['message_id','dyad_gameid'], how='left')

In [None]:
# Normalise the annotated phrases: lower-case them and strip the characters  ~ ( ) , ' : ;
# `regex=True` is passed explicitly because the default of `Series.str.replace` became
# `regex=False` in pandas 2.0; under that default an escaped pattern such as r'\(' is
# searched for literally (backslash included) and never matches, so nothing is removed.
pivoted_df_merged['content'] = (
    pivoted_df_merged['phrases']
    .str.lower()
    .str.replace(r"[~(),':;]", '', regex=True)
)

In [None]:
# columns kept from the second annotation round
zc_columns = ['workerID', 'message_id', 'dyad_gameid', 'msgNum', 'message', 'block', 'tower', 'refExps']

# select -> attach the trial number from the chat data -> rename to the shared column names
# -> derive turn_num as twice the message number -> order by annotator, game, trial and message
df_zc_small = (
    df_zc[zc_columns]
    .merge(df_chat[['message_id', 'trialNum']], how='left', on='message_id')
    .rename(columns={'trialNum': 'trial_num', 'msgNum': 'message_num'})
    .assign(turn_num=lambda frame: (frame['message_num'] * 2).astype(int))
    .sort_values(['workerID', 'dyad_gameid', 'trial_num', 'message_num'])
    .reset_index(drop=True)
)


In [None]:
# Normalise the annotated referring expressions: lower-case them and strip the
# characters  ~ ( ) , ' : ;
# `regex=True` is passed explicitly because the default of `Series.str.replace` became
# `regex=False` in pandas 2.0; under that default an escaped pattern such as r'\(' is
# searched for literally (backslash included) and never matches, so nothing is removed.
df_zc_small['content'] = (
    df_zc_small['refExps']
    .str.lower()
    .str.replace(r"[~(),':;]", '', regex=True)
)

In [None]:
# stack the two annotation rounds into a single frame with a fresh 0..n-1 index
df_ref_exps = pd.concat([pivoted_df_merged, df_zc_small], ignore_index=True)

# attach the left/right target of each message and join them into one label, "<left>_<right>"
df_ref_exps = df_ref_exps.merge(df_chat[['message_id','leftTarget','rightTarget']], how ='left', on='message_id')
df_ref_exps.loc[:,'tower_pair'] = df_ref_exps.leftTarget + '_' + df_ref_exps.rightTarget
# repetition number: trial_num / 3 truncated to an integer, shifted to start at 1
# NOTE(review): presumably three trials per repetition -- confirm against the experiment design
df_ref_exps.loc[:,'rep'] = ((df_ref_exps.trial_num)/ 3).astype(int) + 1

In [None]:
# force the content column to strings
# NOTE(review): astype(str) renders missing values as the literal text 'nan' -- confirm this is intended
df_ref_exps.loc[:,'content'] = df_ref_exps.loc[:,'content'].astype(str)

In [None]:
#df_ref_exps.to_csv('{}/results/csv/df_ref_exps.csv'.format(analysis_dir))

## Inter-rater reliability

In [None]:
df_ref_exps_melt = df_ref_exps.melt(id_vars=['workerID','dyad_gameid','message_id','message_num','trial_num','tower_pair','rep'], value_vars=['block','tower'], value_name='n_refs')
df_ref_exps_melt = df_ref_exps_melt.rename(columns={'variable': 'exp_type'})
df_ref_exps_melt

In [None]:
# df_ref_exps_melt.to_csv('{}/results/csv/df_ref_exps_melt.csv'.format(results_dir))

In [None]:
df_ref_exps_table = df_ref_exps.pivot(index='message_id', columns='workerID', values=['block','tower'])
df_ref_exps_table

In [None]:
# a message counts as agreed only when all four annotators recorded the same block count
block_counts = df_ref_exps_table['block']
prop_all_agree_block = np.mean(
    (block_counts['charles'] == block_counts['julia'])
    & (block_counts['julia'] == block_counts['justin'])
    & (block_counts['justin'] == block_counts['zoe'])
)

print('%.1f%% total agreement on blocks' % (prop_all_agree_block * 100))

In [None]:
# a message counts as agreed only when all four annotators recorded the same tower count
tower_counts = df_ref_exps_table['tower']
prop_all_agree_tower = np.mean(
    (tower_counts['charles'] == tower_counts['julia'])
    & (tower_counts['julia'] == tower_counts['justin'])
    & (tower_counts['justin'] == tower_counts['zoe'])
)

print('%.1f%% total agreement on towers' % (prop_all_agree_tower * 100))

## calculate inter rater reliability with ICC
https://en.wikipedia.org/wiki/Intraclass_correlation

In [None]:
import pingouin as pg
# https://www.statology.org/intraclass-correlation-coefficient-python/

In [None]:
df_ref_exps_melt.n_refs = pd.to_numeric(df_ref_exps_melt.n_refs)

In [None]:
pg.intraclass_corr(data = df_ref_exps_melt, targets="message_id", raters="workerID", ratings="n_refs")

In [None]:
pg.intraclass_corr(data = df_ref_exps_melt.query('exp_type=="block"'), 
                   targets="message_id", raters="workerID", ratings="n_refs")                                       

In [None]:
pg.intraclass_corr(data = df_ref_exps_melt.query('exp_type=="tower"'), 
                   targets="message_id", raters="workerID", ratings="n_refs")                                       

## Comparing to baseline distributions

In [None]:
import random

In [None]:
random.seed(0)

def shuffle_counts(df, within_exp_type=True, coupled=False):
    '''
    Return a copy of `df` in which each annotator's `n_refs` values are permuted across rows.

    Shuffling is done separately for every `workerID`. Only `n_refs` is rewritten, so the
    permutation breaks the link between a count and the row it was recorded for. `df`
    itself is not modified. Randomness comes from the `random` module, so seed it with
    `random.seed` for reproducible output.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format counts with the columns `workerID`, `exp_type` and `n_refs`.
    within_exp_type : bool
        If True, counts are only exchanged between rows of the same `exp_type`.
        If False, each annotator's counts are pooled across expression types.
    coupled : bool
        Only meaningful together with `within_exp_type=True`. If True, ONE permutation is
        drawn per annotator and applied to every expression type, so counts that sat at
        the same position within their expression type before the shuffle still share a
        position afterwards. This relies on every expression type listing an annotator's
        rows in the same order (as `DataFrame.melt` produces them).
        If False, every expression type gets its own independent permutation.

    Returns
    -------
    pandas.DataFrame
        Shuffled copy of `df` with `n_refs` cast to int.

    Raises
    ------
    ValueError
        If `coupled=True` and an annotator has a different number of rows for different
        expression types, in which case no common permutation exists.
    '''

    df_shuffled = df.copy()

    if coupled and not within_exp_type:
        # There is nothing to couple when expression types are pooled: report the
        # unsupported combination and hand the counts back unshuffled.
        print('does not make sense to ask for coupled block and tower responses across expression type')
        df_shuffled['n_refs'] = df_shuffled['n_refs'].astype(int)
        return df_shuffled

    exp_types = df.exp_type.unique()

    for worker in df.workerID.unique():
        worker_rows = df.workerID == worker

        if not within_exp_type:
            # pool every count of this annotator, whatever its expression type
            counts = list(df.loc[worker_rows, 'n_refs'])
            random.shuffle(counts)
            df_shuffled.loc[worker_rows, 'n_refs'] = counts
            continue

        # one boolean row mask per expression type for this annotator
        type_rows = [worker_rows & (df.exp_type == exp_type) for exp_type in exp_types]

        if coupled:
            sizes = {int(rows.sum()) for rows in type_rows}
            if len(sizes) > 1:
                raise ValueError(
                    'coupled shuffle needs the same number of rows for every exp_type '
                    'of worker {!r}; got sizes {}'.format(worker, sorted(sizes)))

            # a single permutation, reused for every expression type
            order = list(range(sizes.pop())) if sizes else []
            random.shuffle(order)

            for rows in type_rows:
                counts = df.loc[rows, 'n_refs'].to_numpy()
                df_shuffled.loc[rows, 'n_refs'] = counts[order]
        else:
            # an independent permutation for each expression type
            for rows in type_rows:
                counts = list(df.loc[rows, 'n_refs'])
                random.shuffle(counts)
                df_shuffled.loc[rows, 'n_refs'] = counts

    df_shuffled['n_refs'] = df_shuffled['n_refs'].astype(int)

    return df_shuffled

In [None]:
df_ref_exps_melt_shuffled = shuffle_counts(df_ref_exps_melt, within_exp_type=True, coupled=True)

In [None]:
df_ref_exps_melt_shuffled

In [None]:
df_ref_exps_shuffled_table = df_ref_exps_melt_shuffled.pivot(index='message_id', columns=['exp_type', 'workerID'], values=['n_refs'])['n_refs']

In [None]:
def prop_agreement(df_table, level = 'block', raters = ('charles', 'julia', 'justin', 'zoe')):
    """
    Proportion of rows on which every rater recorded the same count.

    Parameters
    ----------
    df_table : pandas.DataFrame
        Wide table whose columns are addressed as `(level, rater)` pairs.
    level : str
        First element of the column key, e.g. 'block' or 'tower'.
    raters : sequence of str
        Raters whose columns are compared. Defaults to the four annotators used in this
        notebook, which reproduces the previous hard-coded comparison.

    Returns
    -------
    float
        Mean of the row-wise "all raters agree" indicator. A row containing a missing
        value never counts as agreement, because NaN compares unequal to everything.

    Raises
    ------
    ValueError
        If fewer than two raters are given.
    """
    raters = list(raters)
    if len(raters) < 2:
        raise ValueError('agreement needs at least two raters')

    # equality is transitive, so comparing each rater with the next one is enough
    neighbour_agreement = [
        (df_table[level, first] == df_table[level, second]).to_numpy()
        for first, second in zip(raters, raters[1:])
    ]
    all_agree = np.logical_and.reduce(neighbour_agreement)

    return np.mean(all_agree)

In [None]:
prop_agreement(df_ref_exps_shuffled_table, 'block')

In [None]:
prop_agreement(df_ref_exps_shuffled_table, 'tower')

In [None]:
random.seed(0)

# null distributions obtained from repeatedly shuffled annotations
agreement_baseline = {'block': [], 'tower': []}
icc_baseline = []

for iteration in range(50):

    # shuffle the counts, then rebuild the message x (exp_type, workerID) table
    df_ref_exps_melt_shuffled = shuffle_counts(df_ref_exps_melt, within_exp_type=True, coupled=True)
    df_ref_exps_shuffled_table = df_ref_exps_melt_shuffled.pivot(
        index='message_id', columns=['exp_type', 'workerID'], values=['n_refs'])['n_refs']

    # proportion of full agreement under the shuffle, for each expression type
    for level in ('block', 'tower'):
        agreement_baseline[level].append(prop_agreement(df_ref_exps_shuffled_table, level))

    # ICC of the shuffled ratings; row 0 of pingouin's result table is kept
    shuffled_icc = pg.intraclass_corr(data=df_ref_exps_melt_shuffled,
                                      targets="message_id",
                                      raters="workerID",
                                      ratings="n_refs")
    icc_baseline.append(shuffled_icc.loc[0, "ICC"])
    

In [None]:
overall_icc = pg.intraclass_corr(data = df_ref_exps_melt, targets="message_id", raters="workerID", ratings="n_refs").loc[0,"ICC"]

In [None]:
# fig, ax = plt.subplots(figsize=(10,4))
sns.displot(icc_baseline, height=5, aspect=2)
plt.axvline(overall_icc, color='r', linestyle='--')
plt.show()

In [None]:
# fig, ax = plt.subplots(figsize=(10,4))
sns.displot(agreement_baseline['block'], height=5, aspect=2)
plt.axvline(prop_all_agree_block, color='r', linestyle='--')
plt.show()

In [None]:
sns.displot(agreement_baseline['tower'], height=5, aspect=2)
plt.axvline(prop_all_agree_tower, color='r', linestyle='--')

# Change in referring expression across repetitions

In [None]:
df_ref_exps = pd.read_csv(os.path.join(csv_dir,'df_ref_exps.csv'))
df_ref_exps.head()

In [None]:
from nltk.corpus import stopwords
stop = stopwords.words('english')
# set copy of the stop-word list: constant-time membership tests while filtering every word
stop_words = set(stop)

# Blank out missing values before the cast: astype(str) would otherwise turn NaN into the
# literal text 'nan', which the later steps would tokenise and count as a word.
df_ref_exps['content'] = df_ref_exps['content'].fillna('').astype(str)
df_ref_exps['content'] = df_ref_exps['content'].apply(
    lambda text: ' '.join(word for word in text.split() if word not in stop_words))
df_ref_exps['content'].head()

In [None]:
# convert number words

# `num2words` is not imported in the cells visible in this notebook. With the former bare
# `except:` the resulting NameError was swallowed and every token passed through
# unchanged. Import it here; if the package is missing, keep the pass-through behaviour
# but say so instead of hiding it.
# NOTE(review): results produced while the conversion was silently inactive may differ
# once the package is available -- confirm which behaviour the reported figures used.
try:
    from num2words import num2words
except ImportError:
    num2words = None
    warnings.warn('num2words is not installed; numerals in referring expressions '
                  'will be left as digits')

def num_2_words(sentence):
    """
    Spell out the numeric tokens of a whitespace-separated sentence.

    Each token is passed to `num2words`; a token it cannot convert (or every token, when
    the package is unavailable) is kept as it is.

    Parameters
    ----------
    sentence : str
        Text to convert.

    Returns
    -------
    str
        The converted tokens joined by single spaces, without leading or trailing
        whitespace; '' for an empty or all-whitespace sentence.
    """
    converted = []
    for word in sentence.split():
        if num2words is None:
            converted.append(word)
            continue
        try:
            converted.append(num2words(word))
        except Exception:
            # not a number num2words understands: keep the original token
            converted.append(word)
    return ' '.join(converted)

df_ref_exps['content'] = df_ref_exps['content'].apply(lambda x: num_2_words(x))

In [None]:
# lemmatize
import nltk
from nltk.tokenize import RegexpTokenizer


tokenizer = RegexpTokenizer(r'\w+')
w_tokenizer = nltk.tokenize.WhitespaceTokenizer()
lemmatizer = nltk.stem.WordNetLemmatizer()

def lemmatize_text(text):
    """Split `text` with the module-level `tokenizer` and return the lemma of each token, in order."""
    lemmas = []
    for token in tokenizer.tokenize(text):
        lemmas.append(lemmatizer.lemmatize(token))
    return lemmas

df_ref_exps['BOW_lemmatized'] = df_ref_exps['content'].apply(lemmatize_text)
df_ref_exps['BOW_lemmatized'] = df_ref_exps['BOW_lemmatized'].apply(lambda x: [i.upper() for i in x])

df_ref_exps[['message','content','BOW_lemmatized']].head()

In [None]:
## get word frequencies
df_ref_exps['word_freq'] = df_ref_exps['BOW_lemmatized'].apply(lambda x: Counter(x))
df_ref_exps.head()

In [None]:
## concatenate lemmatized tokens, separated by spaces
df_ref_exps['BOW_concat'] = df_ref_exps['BOW_lemmatized'].apply(lambda x: ' '.join(x))

In [None]:
# Currently, the word counts represent the counts from all 4 of our naive raters. 
# So that we can examine how frequently different words were used, we need to convert these values into proportions.
split_words = df_ref_exps['BOW_concat'].apply(lambda x: x.split())
all_words = list(pd.Series([st for row in split_words for st in row]).unique())
support = {}
for word in all_words:
    support[word] = 0.000000001
    
def get_pdist(row):
    """
    Turn the word counts of one row into relative frequencies over the shared vocabulary.

    Starts from a copy of the module-level `support` dict (every known word mapped to a
    tiny placeholder value) and overwrites the entries of the words counted in
    `row['word_freq']` with count / total count of that row.

    Returns the resulting word -> value dict.
    """
    word_counts = row['word_freq']
    total = np.sum(list(word_counts.values()))

    pdist = support.copy()
    pdist.update({word: count / total for word, count in word_counts.items()})
    return pdist

In [None]:
df_ref_exps['word_pdist'] = df_ref_exps.apply(get_pdist, axis = 1)
df_ref_exps['word_pdist_numeric'] = df_ref_exps['word_pdist'].apply(lambda dist: list(dist.values()))

In [None]:
df_all_words = df_ref_exps[['dyad_gameid', 'rep', 'BOW_concat']]

In [None]:
# One 0/1 column per vocabulary word: 1 if the word occurs among the row's tokens.
# Each row is tokenised once and all indicator columns are built together, then attached
# in a single concat. Writing the columns one at a time through `.loc` on a frame sliced
# from df_ref_exps raises SettingWithCopyWarning, fragments the frame, and re-splits every
# row once per word.
token_sets = df_all_words['BOW_concat'].apply(lambda text: set(text.split()))
word_indicators = pd.DataFrame(
    {w: [int(w in tokens) for tokens in token_sets] for w in all_words},
    index=df_all_words.index,
)
# dropping indicator columns left over from an earlier run keeps this cell idempotent
df_all_words = pd.concat(
    [df_all_words.drop(columns=all_words, errors='ignore'), word_indicators],
    axis=1,
)

In [None]:
# For each repetition, the number of rows containing each word.
# `sum(numeric_only=True)` replaces `agg(sum)`: passing the builtin `sum` is deprecated in
# recent pandas, and it also "summed" the id/text columns by concatenating their strings.
df_all_words_reps = df_all_words.groupby('rep').sum(numeric_only=True)
df_all_words_reps

### figure 3A

In [None]:
# examine the change in word frequencies between trials.
# prep data
# pool the lemmatized tokens of every row in a repetition into one space-separated string
df_ref_exps_rep = df_ref_exps.groupby('rep')['BOW_concat'].apply(lambda group:' '.join(group)).reset_index()
# token counts per repetition
df_ref_exps_rep['word_freq'] = df_ref_exps_rep['BOW_concat'].apply(lambda x: Counter(x.split()))
# relative frequency of every word per repetition, computed by get_pdist
df_ref_exps_rep['word_pdist'] = df_ref_exps_rep.apply(get_pdist, axis=1)
# the same values as a plain list, in the dictionary's iteration order
df_ref_exps_rep['word_pdist_numeric'] = df_ref_exps_rep['word_pdist'].apply(lambda dist: list(dist.values()))
# index the rows by repetition number so a repetition can be looked up with .loc[rep, ...]
df_ref_exps_rep.index=df_ref_exps_rep['rep']

In [None]:
# calculate difference in proportion between reps (currently hardcoded to be 1 and 4)
rep_a = 1
rep_b = 4

pdist_a = df_ref_exps_rep.loc[rep_a, 'word_pdist']
pdist_b = df_ref_exps_rep.loc[rep_b, 'word_pdist']

# per word: proportion in rep_b minus proportion in rep_a
rep_diff = {word: pdist_b[word] - prop_a for word, prop_a in pdist_a.items()}

In [None]:
# find largest n increase/ decrease in proportion across reps
n = 6

# find the largest increase in proportion between reps
# (words sorted by rep_diff in descending order, first n kept)
top_n = dict(sorted(rep_diff.items(), key=lambda item: item[1], reverse=True)[:n])

# find the largest decrease in proportion between reps
# (words sorted by rep_diff in ascending order, first n kept)
bottom_n = dict(sorted(rep_diff.items(), key=lambda item: item[1], reverse=False)[:n])

# per repetition, the lemmatized token lists of all rows concatenated (summing lists appends them)
df_grouped = df_ref_exps.groupby('rep').agg({'BOW_lemmatized': 'sum'})


In [None]:
from matplotlib.ticker import FormatStrFormatter

font = {'fontname':'Helvetica'}
sns.set_theme(style='white')

# number of words shown at each end of the ranking
x_limit = 6

labels, values = zip(*rep_diff.items())

# word positions ordered by change in proportion, descending and ascending
indSort_high = np.argsort(values)[::-1]
indSort_low = np.argsort(values)

# keep the x_limit largest decreases followed by the x_limit largest increases,
# arranged so that the bars rise from left to right
labels = np.concatenate([np.array(labels)[indSort_low][:x_limit],
                         np.array(labels)[indSort_high][:x_limit][::-1]])
values = np.concatenate([np.array(values)[indSort_low][:x_limit],
                         np.array(values)[indSort_high][:x_limit][::-1]])

indexes = np.arange(len(labels))

fig = plt.figure(num=None, figsize=(7, 11), dpi=80, facecolor='w', edgecolor='k')
ax = fig.add_subplot(111)
ax.bar(indexes, values, color = "#7D7D7D")
ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

# add labels
# ax.bar centres each bar on its x position, so the word labels belong at `indexes`
# itself; the former `indexes + bar_width` pushed every label to the right of its bar.
plt.xticks(indexes, labels, rotation='vertical', fontsize=16, **font)
# set tick positions and label styling in one call, so the styling applies to the ticks
# that are actually drawn rather than to the ones they replace
plt.yticks(np.arange(-.13,.06, .02), fontsize=16, **font)
plt.ylabel("change in proportion", size = 24, **font)
ax.axes.get_xaxis().set_visible(True)
plt.show()

## Change in referring expression visualizations

In [None]:
# Exclude a couple of participants that have very different data (which make all other participants appear closer together)

exclude_ids = ['4338-9f5cbb3a-a351-45e4-9b99-dc2737fd4658', #"xxoooxxxo"-type responses
               '9387-db1af5ad-b089-48ad-a730-baee40f08177'  # lots of empty messages
              ]

# drop=True discards the old row labels instead of storing them in a new 'index' column;
# without it every re-run of this cell adds another column and eventually fails on the name clash
df_ref_exps = df_ref_exps[~df_ref_exps.dyad_gameid.isin(exclude_ids)].reset_index(drop=True)

In [None]:
# word count (across all four raters)
df_ref_exps_trial = df_ref_exps.groupby(['dyad_gameid','rep','trial_num'])['BOW_concat'].apply(lambda x: ' '.join(x)).reset_index()
df_ref_exps_trial['word_freq'] = df_ref_exps_trial['BOW_concat'].apply(lambda x: Counter(x.split()))
df_ref_exps_trial

In [None]:
# find which words were used by each participant in every trial

df_all_words_trial = df_ref_exps_trial[['dyad_gameid', 'rep', 'trial_num' ,'BOW_concat']]

# One 0/1 column per vocabulary word, in the order of all_words. Each trial is tokenised
# once and the indicator columns are attached in a single concat, instead of being
# assigned one at a time onto a frame sliced from df_ref_exps_trial (which raises
# SettingWithCopyWarning and re-splits every row once per word).
trial_token_sets = df_all_words_trial['BOW_concat'].apply(lambda text: set(text.split()))
trial_word_indicators = pd.DataFrame(
    {w: [int(w in tokens) for tokens in trial_token_sets] for w in all_words},
    index=df_all_words_trial.index,
)
df_all_words_trial = pd.concat([df_all_words_trial, trial_word_indicators], axis=1)

## visualize change across repetition, with TSNE (figure 2B)

In [None]:
# identify some demonstrative examples
ps = [29, 38, 71]

# 29
# ['BLOCK', 'BLUE', 'RED']
# ['BLUE', 'RED']

# 38
# ['BLOCK', 'BLUE', 'L', 'RED', 'TWO']
# ['L', 'U', 'UPSIDE']

# 71
# ['1', 'HORIZONTAL', 'ONE', 'TWO', 'VERTICAL']
# ['L', 'LOWERCASE', 'N', 'SHAPE']

In [None]:
# visualizations using tsne, colored by rep

np.random.seed(0)

# word-indicator vectors of the first and the last repetition, stacked rep 1 first
# NOTE(review): the 'TWO':'TWR' label slice presumably spans all word columns (first and
# last vocabulary word) -- confirm, since it silently changes if the vocabulary does
first_rep = df_all_words_trial[df_all_words_trial.rep == 1]
last_rep = df_all_words_trial[df_all_words_trial.rep == 4]

# The arrows below pair row i of rep 1 with row i of rep 4 by position, so both halves
# must have the same length (and, presumably, the same dyad/trial order -- TODO confirm).
assert len(first_rep) == len(last_rep), \
    'rep 1 and rep 4 contribute different numbers of rows; arrows would pair the wrong points'

both_reps = pd.concat([first_rep.loc[:,'TWO':'TWR'], last_rep.loc[:,'TWO':'TWR']], axis=0)
rep_labels = list(pd.concat([first_rep.rep, last_rep.rep]))

# reduce to 21 principal components, then embed in 2D with t-SNE
pca = PCA(n_components=21)
pca_result = pca.fit_transform(both_reps)

tsne = TSNE(perplexity=15)
X_embedded = tsne.fit_transform(pca_result)
n = int(len(X_embedded[:,0])/2)

sns.set_style('white')

plt.figure(figsize=(8,8))
palette = ['#99CCCC']


def draw_shift_arrow(i, **style):
    """Draw an arrow from embedded point i (rep 1) to embedded point i + n (rep 4)."""
    start_x, start_y = X_embedded[i, 0], X_embedded[i, 1]
    plt.arrow(start_x,
              start_y,
              X_embedded[i+n, 0] - start_x,
              X_embedded[i+n, 1] - start_y,
              shape='full',
              length_includes_head=True,
              overhang=0,
              alpha=0.8,
              **style)


background_style = dict(lw=0.2, color='#21606B', head_width=0.7, head_length=0.8)
highlight_style = dict(lw=2, color='#FF0000', head_width=0.6, head_length=0.7)

# ordinary arrows first, so the highlighted examples in `ps` are drawn on top of them
for i in range(n):
    if i not in ps:
        draw_shift_arrow(i, **background_style)

for i in range(n):
    if i in ps:
        draw_shift_arrow(i, **highlight_style)

# mark the starting points (first n embedded rows, i.e. rep 1)
sns.scatterplot(x = X_embedded[:n,0],
                y = X_embedded[:n,1],
                hue=rep_labels[:n],
                legend=False,
                palette=palette,
                alpha=0.7,
                s=50,
                linewidth=0.5)

plt.tick_params(
    axis='both',        # changes apply to both axes
    which='both',       # both major and minor ticks are affected
    bottom=False,       # ticks along the bottom edge are off
    left=False,         # ticks along the left edge are off
    labelbottom=False,  # labels along the bottom edge are off
    labelleft=False)    # labels along the left edge are off

# write into plot_dir (created in the setup cell) rather than a hand-written relative path
plt.savefig(os.path.join(plot_dir, 'rep4_clusters.pdf'))