In [1]:
import os
os.environ['HF_HOME'] = '/mnt/sagemaker-nvme/cache'
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
import torch
import pandas as pd
from sklearn.cluster import KMeans
from utils import last_token_pool, get_detailed_instruct
from transformers import AutoTokenizer, AutoModel


def merge_personas(exp_name):

    file_path = f'{exp_name}.csv'
    data = pd.read_csv(file_path)
    
    def get_formatted_persona_dim(row):
        task = 'Given a persona dimension description, retrieve semantically similar persona dimension descriptions.'
        persona = f"{row['name']}: {row['description']}. Candidate values: {row['candidate_values']}"
        return get_detailed_instruct(task, persona)
    
    data['formatted'] = data.apply(get_formatted_persona_dim, axis=1)
    
    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained('Salesforce/SFR-Embedding-2_R')
    model = AutoModel.from_pretrained('Salesforce/SFR-Embedding-2_R', device_map='auto')
    
    # Get the embeddings
    max_length = 4096
    input_texts = data['formatted'].to_list()
    batch_dict = tokenizer(input_texts, max_length=max_length, padding=True, truncation=True, return_tensors="pt")
    
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        # print(embeddings)

    # Clustering
    num_clusters = 20
    clustering_model = KMeans(n_clusters=num_clusters)
    clustering_model.fit(embeddings)
    data['cluster'] = clustering_model.labels_
    data = data.sort_values(by='cluster')
    data.to_csv(f'{exp_name}_clustered.csv')

    for idx in range(num_clusters):
        print(idx)
        for _, row in enumerate(data[data['cluster'] == idx]['formatted']):
            print(row.split('\n')[1])
        print('\n\n')


In [2]:
merge_personas('low')

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

0
Household Composition: having children in the home and the associated concerns about gun safety. Candidate values: ['has children in the home', 'no children in the home']
Household Composition: whether the person has children living in their home, which may impact their views on gun safety practices. Candidate values: ['has children', 'no children']
Parental Status: whether the person has children living in their home or not. Candidate values: ['has children', 'no children']
Household Composition: whether the respondent has children living in their home. Candidate values: ['has children', 'no children']
Parental Status: whether the person has children under the age of 18. Candidate values: ['has children under 18', 'does not have children under 18']
Parental Status: having children in the home and the associated concerns for their safety. Candidate values: ['has children', 'does not have children']
Parental Status: whether the individual has children in their home. Candidate values: 

In [3]:
merge_personas('low_simple')

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some parameters are on the meta device device because they were offloaded to the cpu.


0
Demographic Factors: The participant's demographic characteristics, such as age, gender, and socioeconomic status, which may influence their perspective on gun ownership and societal perceptions.. Candidate values: ['Age', 'Gender', 'Socioeconomic Status']
Demographic Factors: The survey participant's demographic characteristics, such as age, gender, and socioeconomic status, that may be associated with their reason for gun ownership.. Candidate values: ['Age', 'Gender', 'Income level', 'Education level']
Demographic Factors: The participant's demographic characteristics, such as age, gender, and socioeconomic status, which may influence their perception of crime and safety.. Candidate values: ['Age', 'Gender', 'Socioeconomic Status']
Demographic Factors: The individual's personal characteristics, such as age, gender, and socioeconomic status, which may influence their attitudes and behaviors related to firearm carrying.. Candidate values: ['Age', 'Gender', 'Income Level', 'Education

In [4]:


# Summarization
summarizer = pipeline('summarization')
def summarize_text(texts):
    combined_text = " ".join(texts)
    summary = summarizer(combined_text, max_length=50, min_length=25, do_sample=False)
    return summary[0]['summary_text']

data['summary'] = data.groupby('cluster')['candidate_values'].transform(lambda x: summarize_text(x.tolist()))

# Save the result
data.to_csv('/mnt/data/merged_user_persona_dimensions.csv', index=False)


NameError: name 'pipeline' is not defined