# Notebook for animal fluency feature extraction 

In [2]:
from gensim.models import KeyedVectors, Word2Vec
import scipy.stats as stats
import pandas as pd
import numpy as np
import ast

In [3]:
# Load the raw fluency data from the spreadsheet "animal_fluency.xlsx".
#
# Expected columns:
# - FILE: a unique designator for the recording
# - ANIMALS: the human-transcribed animals; repeats are designated with
#   "(...)" and non-animals were not transcribed
# - GROUP: diagnostic group (1 = control, 2 = aMCI, 3 = AD)

animal_fluency = pd.read_excel(io='animal_fluency.xlsx')


In [12]:
# Per-participant production counts derived from the ANIMALS transcript:
# - every whitespace-separated token (repetitions included),
# - only the tokens that do not start with "(" (repetitions excluded),
# - the number of "(" characters, i.e. the number of marked repetitions.

animal_tokens = animal_fluency['ANIMALS'].apply(lambda transcript: transcript.split())

animal_fluency['participant_num_animals_with_repetitions'] = animal_tokens.apply(len)
animal_fluency['participant_num_animals_without_repetitions'] = animal_tokens.apply(
    lambda tokens: sum(1 for token in tokens if token[0] != '(')
)
animal_fluency['participant_num_repetitions'] = animal_fluency['ANIMALS'].apply(
    lambda transcript: int(transcript.count('('))
)
                                                                           

In [None]:
# code for computing categories of animals based on cosine distance between consecutive animals

# for example, we can make a column for w2v cosines:
# first load the pretrained word2vec semantic model, from here: https://code.google.com/archive/p/word2vec/
w2vmodel = KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin', binary=True)

# One list per participant: the similarity of every pair of consecutive animals,
# so a participant who produced n tokens gets n - 1 values.
cosines = []
for animal_list in animal_fluency['ANIMALS']:
    # ANIMALS is a single string, so it has to be split into tokens first;
    # indexing the raw string would compare individual characters.
    # The repeat markers "(...)" and any commas are transcription markup, not
    # part of the word, so they are stripped before the embedding lookup.
    # NOTE(review): a token that is missing from the embedding vocabulary will
    # presumably still raise KeyError in w2vmodel.similarity -- confirm how
    # out-of-vocabulary animals should be handled.
    animals = [token.strip('(),') for token in animal_list.split()]
    cosines.append([
        w2vmodel.similarity(first, second)
        for first, second in zip(animals, animals[1:])
    ])
animal_fluency['w2v_cosines'] = cosines


# thresholded function splits animals into appropriate categories given a threshold
# if the cosine falls below the threshold, a new category is started
# requires a column in the dataframe which is a list of cosine values between consecutive animals
# can be computed with any embeddings

def thresholded(threshold, c):
    """Group each participant's animals into categories using a cosine threshold.

    Reads the module-level ``animal_fluency`` dataframe.

    Args:
        threshold: a cosine at or below this value starts a new category.
        c: name of the ``animal_fluency`` column holding, per participant, the
            cosines between consecutive animals. Each cell may be an actual
            sequence of numbers or its string representation (as obtained when
            the list was stored in a spreadsheet and read back).

    Returns:
        One entry per participant; each entry is a list of categories and each
        category is a list of animal tokens in production order.
    """
    threshold_groups_X = []
    for animals, raw_cosines in zip(animal_fluency.ANIMALS, animal_fluency[c]):

        # Work on a copy so the dataframe cell is never mutated by the insert below.
        if isinstance(raw_cosines, str):
            cosines = list(ast.literal_eval(raw_cosines))
        else:
            cosines = list(raw_cosines)

        # The first animal has no predecessor; 0 is the sentinel meaning
        # "no cosine available, stay in the current category".
        cosines.insert(0, 0)

        this_group = []
        this_participant = []

        # split() (rather than split(' ')) keeps the tokens aligned with the
        # token counts used elsewhere in this notebook when spacing is irregular.
        for an, co in zip(animals.split(), cosines):
            # can edit to not take into account repeated animals as well
            # by keeping track of whether the "an" variable has been seen previously

            if co != 0 and co <= threshold:
                # similarity dropped: close the current category, start a new one
                this_participant.append(this_group)
                this_group = [an]
            else:
                this_group.append(an)

        this_participant.append(this_group)
        threshold_groups_X.append(this_participant)

    return threshold_groups_X


# using the thresholded function to create cosine features
# with w2v embeddings
#
# For every threshold in the sweep, each participant gets six features:
# number of categories, mean / std / min / max category size, and the
# number of categories divided by that participant's total number of words.
d_vals = {}

for k in np.arange(0.8, 0.975, .005):
    # round away floating point noise from arange so the column names stay clean
    y = round(k, 3)
    v = thresholded(y, 'w2v_cosines')

    new_name_num = str(y) + '_participant_num_categories_w2v'
    new_name_avg = str(y) + '_avg_items_per_category_w2v'
    new_name_std = str(y) + '_std_items_per_category_w2v'
    new_name_min = str(y) + '_min_items_per_category_w2v'
    new_name_max = str(y) + '_max_items_per_category_w2v'
    new_name_catdivwords = str(y) + '_participant_categories_div_totalwords_w2v'

    d_vals[new_name_num] = []
    d_vals[new_name_avg] = []
    d_vals[new_name_std] = []
    d_vals[new_name_min] = []
    d_vals[new_name_max] = []
    d_vals[new_name_catdivwords] = []

    for this_participant in v:
        catstats = [len(cat) for cat in this_participant]
        # The denominator is this participant's own word count. Flattening the
        # categories of *all* participants would divide by the number of
        # categories in the whole dataset instead.
        total_words = sum(catstats)

        d_vals[new_name_num].append(len(this_participant))
        d_vals[new_name_avg].append(np.average(catstats))
        d_vals[new_name_std].append(np.std(catstats))
        d_vals[new_name_min].append(np.amin(catstats))
        d_vals[new_name_max].append(np.amax(catstats))
        # a participant with no words has no meaningful ratio
        d_vals[new_name_catdivwords].append(
            float(len(this_participant)) / float(total_words) if total_words else float('nan')
        )


# adding thresholding features to the dataframe
for feature_name, feature_values in d_vals.items():
    animal_fluency[feature_name] = feature_values
    

In [None]:
# for reference, also BERT cosine lists:

import torch
from transformers import BertTokenizer, BertModel

# The tokenizer and the encoder are both loaded from the same pretrained checkpoint.
model = BertModel.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

def get_anim_cosines(a_list):
    """Return the cosine similarity between consecutive animals using BERT.

    Each animal is encoded on its own with the module-level ``tokenizer`` and
    run through the module-level ``model``. The animal is represented by the
    last-hidden-state vector at sequence position 0.

    Args:
        a_list: sequence of animal strings in production order.

    Returns:
        A list with ``len(a_list) - 1`` similarities (empty when fewer than two
        animals are given); entry ``i`` compares ``a_list[i]`` with
        ``a_list[i + 1]``.
    """
    # Imported here because the file only imports scipy.stats, which does not
    # bring the name ``spatial`` into scope.
    from scipy.spatial.distance import cosine as cosine_distance

    animal_vectors = []

    # Only inference is needed, so skip building the autograd graph.
    with torch.no_grad():
        for a in a_list:
            # NOTE(review): the tokenizer is expected to add the special tokens
            # ([CLS] first, [SEP] last) itself, and some words are split into
            # several word pieces, so the sequence is at least 3 ids long --
            # confirm against the tokenizer configuration.
            tokenized_text = tokenizer.encode(a)

            # add the batch dimension expected by the model
            input_ids = torch.tensor(tokenized_text).unsqueeze(0)

            # outputs[0] is the last hidden state, indexed below as
            # [batch][token]; there is a single batch entry.
            last_hidden_state = model(input_ids)[0]

            # vector at position 0 of the only batch entry (the [CLS] slot)
            cls_vector = last_hidden_state[0][0]
            animal_vectors.append(cls_vector.detach().numpy())

    # scipy returns a cosine *distance*; convert it to a similarity
    return [
        1 - cosine_distance(anim1, anim2)
        for anim1, anim2 in zip(animal_vectors, animal_vectors[1:])
    ]

# BERT cosines between consecutive animals, one list per participant.
all_BERT_cosines = []
for a_list in animal_fluency['ANIMALS']:
    # ANIMALS is one whitespace-separated string per participant
    anims = a_list.split()
    all_BERT_cosines.append(get_anim_cosines(anims))

animal_fluency['BERT_cosines'] = all_BERT_cosines


In [None]:
# computing the (updated) troyer categories

# Maps each category name to the list of animal spellings that belong to it.
# An animal may be listed under several categories (e.g. 'camel' is in both
# 'Africa' and 'Beasts of burden'), and several entries are listed in both
# singular and plural form.
# NOTE(review): a few lists contain the same spelling twice (e.g. 'chicken' in
# 'Farm'); this does not affect membership tests but may be unintentional.
troyer_categories = {'Africa': ['flamingo','water-buffalo','wild-dog','aardvark', 'antelope', 'kudu','buffalo', 'camel', 'chameleon', 'cheetah','chimpanzee', 'cobra', 'eland', 'elephant', 'gazelle', 'giraffe', 'gnu', 'gorilla','gorillas','hippopotamus', 'hyena', 'impala', 'jackal', 'lemur', 'leopard', 'lion', 'manatee', 'mongoose', 'monkey', 'ostrich', 'panther', 'rhinoceros', 'rhino','tiger', 'tigers','wildebeest', 'warthog', 'zebra'],
                     'Australia': ['koala-bear', 'koala','emu', 'kangaroo', 'kiwi', 'opossum', 'possum','platypus', 'Tasmaniandevil', 'wallaby', 'wombat'],
                     'Arctic/Far North': ['kodiak-bear','auk', 'caribou', 'musk ox', 'penguin', 'polar bear','reindeer', 'seal', 'wolverine'],
                     'Farm': ['chicken','pony', 'ponies','lamb','chicken', 'cow', 'bull','donkey', 'burrow','ass','ferret', 'goat', 'goats','horse', 'horses','steer','mule','mules', 'pig','boar','hog','sheep', 'ram','turkey'],
                     'North America': ['kodiak-bear','panda','bears','badger', 'bear', 'grizzly','beaver', 'bobcat', 'caribou', 'chipmunk','cougar','catamount', 'deer', 'elk', 'fox', 'foxes','moose', 'mountain-lion', 'puma', 'rabbit', 'raccoon','skunk', 'squirrel', 'chipmunk','wolf', 'wolves', 'vole'],
                     'Water': ['catfish','zebra-fish', 'puffer-fish', 'seahorse','tadpole','stingray','goldfish','eel','crabs','crab','alligator', 'auk', 'beaver', 'crocodile', 'dolphin', 'porpoise', 'fish', 'frog','bullfrog','lobster', 'manatee', 'muskrat', 'newt', 'octopus', 'otter', 'sea-otter','oyster', 'penguin','platypus', 'salamander', 'sealion','sea-lion', 'seal', 'shark', 'toad', 'turtle', 'whale'],
                     'Beasts of burden': ['camel', 'donkey', 'burrow','ass','horses','horse','steer','pony','ponies', 'llama', 'ox', 'oxen','alpaca'],
                     'Fur': ['beaver', 'chinchilla', 'fox', 'foxes','mink', 'rabbit', 'alpaca'],
                     'Pets': ['fish','bird','ferret','snake','lizard','tarantula','puppies','puppy','kitten','kittens','bulldog','budgie', 'bunny_rabbit', 'canary', 'cat','cats', 'dog','dogs','poodle','gerbil', 'goldfish','golden retriever', 'guinea-pig', 'hamster', 'parrot', 'rabbit'],
                     'Bird': ['emu','turkeys','chickadee','albatross','crow','snow-bird','song-bird','cockatoo','black-bird','bat','blackbird','fowl','snowbird','flamingo','ostriches','yellow-finch','peacock','wood-pecker','vulture','pigeon','goose','geese','hawk','sparrow','rooster','duck','swan','mandarin-duck','cardinal','blue-bird','hen','pheasant','ibis','eagles','dove','doves','birds','falcon','owl','owls','bird','budgie', 'condor', 'eagle', 'finch', 'kiwi', 'macaw', 'parrot', 'parakeet','pelican', 'penguin', 'robin', 'toucan', 'woodpecker'],
                     'Bovine': ['bison', 'buffalo', 'cow', 'bull','musk ox', 'yak', 'water-buffalo'],
                     'Canine': ['coyote', 'dog','dogs', 'fox', 'foxes','hyena', 'jackal', 'wolf'],
                     'Deer': ['gazelles','antelope', 'kudu','caribou', 'eland', 'elk', 'gazelle', 'gnu', 'impala', 'moose','reindeer', 'wildebeest', 'deer'],
                     'Feline': ['bobcat', 'cat', 'cats','cheetah', 'cougar', 'catamount','jaguar', 'leopard', 'lion', 'lions','lynx','mountain-lion', 'ocelot', 'panther', 'puma', 'tiger', 'tigers'],
                     'Fish': ['starfish','catfish','zebra-fish', 'puffer-fish','tadpole','bass', 'guppy', 'salmon', 'trout', 'goldfish'],
                     'Insect': ['tarantula','fly','flies','butterflies','bumblebee','butterfly','spider','insect','cricket','wasp','bee','ant', 'beetle', 'cockroach', 'flea', 'fly', 'praying mantis'],
                     'Insectivores': ['aardvark', 'anteater', 'armadillo','hedgehog', 'mole', 'shrew'],
                     'Primate': ['monkies','ape', 'baboon', 'chimpanzee', 'gibbon', 'gorilla', 'gorillas','human','people','lemur','sloth', 'marmoset', 'monkey','monkeys','chimpanzees', 'orangutan', 'shrew'],
                     'Rabbit': ['coney', 'hare', 'pika', 'rabbit'],
                     'Reptile/Amphibian': ['snail','worm','reptile','komodo-dragon','alligator', 'chameleon', 'crocodile', 'frog', 'bullfrog','gecko','iguana', 'lizard', 'newt', 'salamander', 'copperhead','black-snake', 'snake', 'rattlesnake', 'rattle-snake','garden-snake','anaconda','coral-snake','cobra','boa-constrictor','toad', 'tortoise', 'turtle'],
                     'Rodent': ['possum','opossum','vole','beaver', 'chinchilla', 'chipmunk', 'gerbil', 'gopher', 'groundhog', 'guinea-pig', 'hamster', 'hedgehog', 'marmot', 'mole', 'mouse', 'mice','muskrat', 'porcupine', 'rat', 'squirrel', 'chipmunk','woodchuck'],
                     'Weasel': ['weasel','badger', 'ferret', 'marten', 'mink', 'mongoose', 'otter', 'sea-otter', 'polecat','skunk'],
                     'Dinosaur': ['dinosaur']}

# For every participant, look up the candidate Troyer categories of each
# non-repeated animal. category_lists[i] holds one list of category names per
# such animal, in production order (an empty list when nothing matches).
category_lists = []

for al in animal_fluency['ANIMALS']:
    catlist = []
    # ANIMALS is a single string: split it into tokens, otherwise the loop
    # would look up individual characters.
    for animal in al.split():
        # can remove this check if "forbidden" (repeated) words should be counted
        if '(' in animal:
            continue

        # remove a trailing comma from the animal
        animal = animal.strip().rstrip(',')

        # accept the word as transcribed, its naive plural, and its naive
        # singular (last character dropped)
        spellings = (animal, animal + 's', animal[:-1])
        categories = [
            name
            for name, members in troyer_categories.items()
            if any(spelling in members for spelling in spellings)
        ]

        catlist.append(categories)

    category_lists.append(catlist)

# NOTE: the final Troyer categories are entered by hand in the loop below; the string that follows shows an example session.
'''
THIS IS WHAT IS SHOWN ON EACH ITERATION

dog ['Pets', 'Canine']
cat ['Pets', 'Feline']
fish ['Water', 'Pets']
pig ['Farm']
bunny_rabbit ['Rabbit', 'Pets']
guinea_pig ['Pets']
hamster ['Pets', 'Rodent']
giraffe ['Africa']
hippo ['Africa']
monkey ['Africa', 'Primate']
gorilla ['Africa', 'Primate']
rhinoceros ['Africa']
zebra ['Africa']
horse ['Farm', 'Beasts of burden']
cow ['Farm', 'Bovine']
chicken ['Farm']
alligator ['Water', 'Reptile/Amphibian']
crocodile ['Water', 'Reptile/Amphibian']      
cheetah ['Africa', 'Feline']
birds ['Pets', 'Bird']
squirrels ['North America', 'Rodent']
snakes ['Pets', 'Reptile/Amphibian']
rodents ['Rodent']
skunks ['North America', 'Weasel']
boar ['Farm']

THIS IS WHAT IS TYPED INTO THE INPUT:
[dog cat fish] [pig] [bunny_rabbit guinea_pig hamster] [giraffe hippo monkey gorilla rhinoceros zebra] [horse cow chicken] [alligator crocodile] [cheetah] [birds] [squirrels] [snakes] [rodents skunks] [boar]
'''

# Show each participant's animals next to their candidate categories, then
# read the hand-made grouping (e.g. "[dog cat] [pig]") from standard input.
cat_list = []

for al, cl in zip(animal_fluency['ANIMALS'], category_lists):
    # category_lists skips repeated animals (tokens containing "("), so skip
    # them here as well to keep each animal aligned with its categories.
    shown_animals = [a for a in al.split() if '(' not in a]
    for a, c in zip(shown_animals, cl):
        print(a, c)
    categorized = input()
    cat_list.append(categorized)

# Summary statistics of the hand-made Troyer grouping, one value per participant.
participant_num_categories = []
participant_avg_items_per_category = []
participant_std_items_per_category = []
participant_min_items_per_category = []
participant_max_items_per_category = []

for cat in cat_list:

    # Anything that is not a string (e.g. a missing value) carries no grouping:
    # record zeros for every statistic. Checking the type explicitly replaces a
    # bare except that would also have hidden unrelated errors.
    if not isinstance(cat, str):
        participant_num_categories.append(0)
        participant_avg_items_per_category.append(0)
        participant_std_items_per_category.append(0)
        participant_min_items_per_category.append(0)
        participant_max_items_per_category.append(0)
        continue

    participant_num_categories.append(cat.count('['))

    # "[a b] [c]" -> ["[a b", "c]"] -> sizes [2, 1]
    lengths = [
        len(wordlist.replace('[', '').replace(']', '').split())
        for wordlist in cat.split('] [')
    ]

    participant_avg_items_per_category.append(np.array(lengths).mean())
    participant_std_items_per_category.append(np.array(lengths).std())
    participant_min_items_per_category.append(min(lengths))
    participant_max_items_per_category.append(max(lengths))


In [None]:
# Number of distinct animals per participant, computed earlier in the notebook
# and stored as a dataframe column.
unique_animal_counts = animal_fluency['participant_num_animals_without_repetitions']

# computing categories / total animals
participant_categories_div_totalwords = []
for count, cat in zip(unique_animal_counts, participant_num_categories):
    # a participant without any animals gets 0, like the other Troyer fallbacks
    participant_categories_div_totalwords.append(cat / count if count else 0)


# calculating pivot words
# An animal placed in more than one hand-made category is counted once per
# category, so (words summed over categories) - (distinct animals) is the
# number of extra placements.
# NOTE(review): the hand-entered groupings are read from cat_list; confirm that
# this is the same data the original `animal_fluency2.CATEGORIES` referred to.
participant_pivot_words = []
for count, cat in zip(unique_animal_counts, cat_list):
    if not isinstance(cat, str):
        # no grouping available for this participant
        participant_pivot_words.append(0)
        continue
    lengths = sum(
        len(wordlist.replace('[', '').replace(']', '').split())
        for wordlist in cat.split('] [')
    )
    participant_pivot_words.append(lengths - count)


# adding troyer features to dataframe
animal_fluency['participant_num_categories_troyer'] = participant_num_categories
animal_fluency['participant_avg_items_per_category_troyer'] = participant_avg_items_per_category
animal_fluency['participant_std_items_per_category_troyer'] = participant_std_items_per_category
animal_fluency['participant_min_items_per_category_troyer'] = participant_min_items_per_category
animal_fluency['participant_max_items_per_category_troyer'] = participant_max_items_per_category
animal_fluency['participant_categories_div_totalwords_troyer'] = participant_categories_div_totalwords
animal_fluency['participant_pivot_words_troyer'] = participant_pivot_words


In [None]:
## LSA features (animal vector length, cosines, etc) were computed with direct access to the http://lsa.colorado.edu/ service

In [None]:
# Write the dataframe, including every feature column added above, to a spreadsheet.
animal_fluency.to_excel(excel_writer='animal_fluency_features.xlsx')

# Computing the F-statistics (one-way ANOVA) for individual features

In [None]:
# One-way ANOVA of every feature column across the diagnostic groups.
# The group subsets do not depend on the feature, so they are built once
# instead of once per column.
ones = animal_fluency[animal_fluency.GROUP == 1]
twos = animal_fluency[animal_fluency.GROUP == 2]
threes = animal_fluency[animal_fluency.GROUP == 3]
twosthrees = animal_fluency[(animal_fluency.GROUP == 2) | (animal_fluency.GROUP == 3)]

for column in animal_fluency.columns:
    # skip the identifier/label columns and the raw cosine lists
    if column in ('FILE', 'ANIMALS', 'GROUP') or 'cosines' in column:
        continue

    # missing values are dropped per group before each test
    group1 = ones[column].dropna()
    group2 = twos[column].dropna()
    group3 = threes[column].dropna()
    group23 = twosthrees[column].dropna()

    f1, p1 = stats.f_oneway(group1, group2, group3)
    f2, p2 = stats.f_oneway(group1, group2)
    f3, p3 = stats.f_oneway(group1, group3)
    f4, p4 = stats.f_oneway(group2, group3)
    f5, p5 = stats.f_oneway(group1, group23)

    print(column, 'overall', (f1,p1), '1vs2', (f2,p2),'1vs3', (f3,p3),'2vs3', (f4,p4), '1vs23', (f5,p5))