In [1]:
import pandas as pd
import numpy as np
from typing import List, Dict
import matplotlib.pyplot as plt
import random
import plotly.express as px
import warnings
from collections import Counter
from sklearn.feature_extraction.text import TfidfVectorizer
# Suppress every warning for the rest of the session (this also hides numpy/pandas runtime warnings).
warnings.filterwarnings("ignore")


# Seed Python's built-in `random` module only; numpy's generator is not seeded here.
random.seed(42)

The following cell contains the code for creating the author-vector CSV.

In [2]:
# Load the document feature vectors from CSV (path is relative to the notebook's working directory).
doc_df = pd.read_csv("data/features/document_vectors.csv")
# Every column except the two identifier columns `doc_id` and `author_id`.
doc_values = doc_df.loc[:, ~doc_df.columns.isin(['doc_id', 'author_id'])]
# Distinct author ids; being a set, it carries no ordering guarantee.
author_ids = set(doc_df["author_id"])

def make_author_vector(doc_vectors:np.ndarray) -> np.ndarray:
    """Collapse a stack of document vectors into one vector by averaging along axis 0."""
    column_means = np.mean(doc_vectors, axis=0)
    return column_means

def make_author_vector_df(doc_df:pd.DataFrame, author_ids) -> pd.DataFrame:
    """Creates author vectors by averaging each author's documents into one

    Args:
        doc_df: Document table containing a ``doc_id`` and an ``author_id``
            column; every other column is averaged per author.
        author_ids: Iterable of author ids to build a vector for.

    Returns:
        A DataFrame indexed by author id whose columns are the columns of
        ``doc_df`` minus ``doc_id`` and ``author_id``. An empty ``author_ids``
        yields an empty frame that still carries those columns.
    """
    # drop() returns a new frame, so doc_df itself is never modified and no
    # explicit deep copy is required.
    feature_df = doc_df.drop(columns=["author_id", "doc_id"])

    # Sets of strings iterate in an order that changes between interpreter
    # runs (hash randomisation), which would make the row order of the result
    # irreproducible. Sort sets; keep already-ordered iterables as given.
    if isinstance(author_ids, (set, frozenset)):
        try:
            ordered_ids = sorted(author_ids)
        except TypeError:
            # Mixed, mutually unorderable id types: fall back to a string key.
            ordered_ids = sorted(author_ids, key=str)
    else:
        ordered_ids = list(author_ids)

    author_ids_to_avs = {}
    for author_id in ordered_ids:
        doc_vectors = feature_df.loc[doc_df["author_id"] == author_id].values
        author_ids_to_avs[author_id] = make_author_vector(doc_vectors)

    # Build row-wise with explicit columns so that an empty mapping produces a
    # (0, n_features) frame instead of failing on a column-length mismatch.
    return pd.DataFrame(
        list(author_ids_to_avs.values()),
        index=list(author_ids_to_avs.keys()),
        columns=feature_df.columns,
    )


# Average each author's document rows into a single vector per author.
# NOTE(review): nothing in this cell writes `av` to disk, although the cell is
# described as creating the author vector csv — confirm where it is saved.
av = make_author_vector_df(doc_df, author_ids)




In [5]:
from scipy.stats import zscore
from components.processing import author_vectors, authors_df, docs_df


def get_threshold_zscores_idxs(zscores, threshold:float):
    """Return the positions of all entries whose absolute value is at least ``threshold``."""
    return [
        position
        for position, value in enumerate(zscores)
        if abs(value) >= threshold
    ]


def get_identifying_features(author_id:str, threshold=2.0):
    """
    Given an author, calculates their zscores for all features and selects the ones that deviate the most from the
    mean. These features are what separate this author from the average author

    Args:
        author_id: Value to look up in the ``author_id`` column of ``authors_df``.
        threshold: Minimum absolute z-score a feature needs in order to be kept.

    Returns:
        The author's row of z-scores restricted to the entries with
        ``|z| >= threshold``, in their original column order.

    Raises:
        IndexError: If ``author_id`` does not occur in ``authors_df``.
    """
    zscores = zscore(author_vectors)

    # Resolve the author's row as a *position*, because it is consumed by
    # .iloc below. Using the index label (``.index[0]``) is only equivalent
    # when authors_df has a default 0..n-1 index.
    # NOTE(review): assumes authors_df rows line up positionally with
    # author_vectors rows — confirm against components.processing.
    matches = np.flatnonzero((authors_df["author_id"] == author_id).to_numpy())
    if matches.size == 0:
        raise IndexError(f"author_id {author_id!r} not found in authors_df")
    author_zscores = zscores.iloc[matches[0]]

    selected_zscores = get_threshold_zscores_idxs(author_zscores, threshold)
    return author_zscores.iloc[selected_zscores]


def get_author_entries(author_id:str) -> pd.DataFrame:
    """Return the rows of ``docs_df`` whose ``author_id`` column equals ``author_id``."""
    is_author = docs_df["author_id"] == author_id
    return docs_df.loc[is_author]

def features_to_show(author_id:str, n:int=12) -> List[str]:
    """Given an author id, returns up to n of this author's identifying features

    Args:
        author_id: Author to look up via ``get_identifying_features``.
        n: Maximum number of feature names to return. Defaults to 12, the
            cap the previous hard-coded slice applied.

    Returns:
        At most ``n`` feature names, taken from the front of the index
        returned by ``get_identifying_features`` (no re-ranking by z-score
        magnitude is done here).

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    features = get_identifying_features(author_id).index.to_list()
    # Slicing already copes with lists shorter than n, so no length check is
    # needed (the old `len > 10` guard disagreed with its own `[:12]` slice).
    return features[:n]



        
            

# Example: z-scores of one author that meet the default threshold (2.0).
author = "en_110"
get_identifying_features(author)

POS Bigram: ('ADJ', 'PUNCT')    2.338344
POS Bigram: ('VERB', 'ADJ')     2.548491
POS Bigram: ('ADJ', 'ADP')      2.399235
Letter: a                      -2.417109
Letter: n                      -2.099656
Letter: u                       2.285633
Letter: y                       2.093896
Letter: ó                       7.416198
Emoji: 😊                        2.269565
Emoji: ❤️                       2.529055
Dependency label: acomp         2.368167
Name: 33, dtype: float64

In [5]:
# Same lookup for a second author, for comparison.
get_identifying_features("en_35")

POS Unigram: ADV               2.391919
POS Bigram: ('AUX', 'ADV')     2.259470
POS Bigram: ('ADV', 'ADJ')     2.515201
Function word: them            2.747071
Function word: did             2.361154
Function word: down            2.296019
Function word: under           2.342981
Function word: again           2.244927
Function word: only            2.102501
Function word: than            2.021664
Punctuation: *                 2.241301
Letter: b                      2.083508
Dependency label: advmod       2.769957
Dependency label: auxpass      2.170734
Dependency label: nsubjpass    2.328799
Morphology tag: Past           2.508990
Name: 0, dtype: float64

In [5]:
from dash import dash_table, html


# Sample dataset fetched over the network from the plotly datasets repository (requires internet access).
df = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/solar.csv')




DataTable(data=[{'State': 'California', 'Number of Solar Plants': 289, 'Installed Capacity (MW)': 4395, 'Average MW Per Plant': 15.3, 'Generation (GWh)': 10826}, {'State': 'Arizona', 'Number of Solar Plants': 48, 'Installed Capacity (MW)': 1078, 'Average MW Per Plant': 22.5, 'Generation (GWh)': 2550}, {'State': 'Nevada', 'Number of Solar Plants': 11, 'Installed Capacity (MW)': 238, 'Average MW Per Plant': 21.6, 'Generation (GWh)': 557}, {'State': 'New Mexico', 'Number of Solar Plants': 33, 'Installed Capacity (MW)': 261, 'Average MW Per Plant': 7.9, 'Generation (GWh)': 590}, {'State': 'Colorado', 'Number of Solar Plants': 20, 'Installed Capacity (MW)': 118, 'Average MW Per Plant': 5.9, 'Generation (GWh)': 235}, {'State': 'Texas', 'Number of Solar Plants': 12, 'Installed Capacity (MW)': 187, 'Average MW Per Plant': 15.6, 'Generation (GWh)': 354}, {'State': 'North Carolina', 'Number of Solar Plants': 148, 'Installed Capacity (MW)': 669, 'Average MW Per Plant': 4.5, 'Generation (GWh)': 11

DataTable(data=[{'hello': [1, 2, 3], 'world': [4, 5, 6]}, {'hello': [1, 6, 3], 'world': [4, 5, 6]}])