In [None]:
import os

In [None]:
from utils import add_personAB, llama_generate, chat_generate, gpt4_generate, call_api
import json
import pickle

In [None]:
def is_complete_demographic(dialog, pool):
    """Ask the LLaMA judge whether Person B has disclosed every target attribute.

    One Yes/No prompt is built per attribute (age, country of residence,
    occupation, level of education) and all prompts are dispatched through
    ``pool.map(llama_generate, ...)``.

    Args:
        dialog: conversation turns, forwarded unchanged to ``add_personAB``.
        pool: object exposing ``map(func, iterable)``.

    Returns:
        False if any answer contains the substring "no" (case-insensitive),
        True otherwise.

    NOTE(review): the check is a plain substring test, so a verbose answer
    containing e.g. "know" or "not" also counts as a "no".
    """
    question_head = "\n\nCan you check if person B has revealed his/her "
    question_tail = "? Please give a Yes/No answer without any explanation."

    prompts = [
        'Given this conversation:\n\n'
        + add_personAB(dialog, orders = ['Person B: ', 'Person A: '])
        + question_head + factor + question_tail
        for factor in ('age', 'country of residence', 'occupation', 'level of education')
    ]

    answers = pool.map(llama_generate, prompts)
    return not any('no' in answer.lower() for answer in answers)

In [None]:
from multiprocessing import Pool
# Four-worker pool; generate_conversation hands it to is_complete_demographic,
# which uses pool.map to send its prompts.
pool = Pool(4)

#### CREATE INIT DIALOGS

In [None]:
from response_generator import ResponseGenerator
# Model used by generate_response. Exactly one of the lines below should be
# active; the commented-out lines load alternative local checkpoints.
model = ResponseGenerator(path = 'mistralai/Mistral-7B-Instruct-v0.2', device = 'cuda:0')
# model = ResponseGenerator(path = 'checkpoints/mistral_peft/check_1000/', device = 'cuda:0')
# model = ResponseGenerator(path = 'checkpoints/mistral_otm/check_1000/', device = 'cuda:0')
# model = ResponseGenerator(path = 'checkpoints/mistral_oto/check_4000/', device = 'cuda:0')

# Values generate_response understands for METHOD: 'peft', 'oto', 'otm', 'base'.
# NOTE(review): presumably METHOD must match the checkpoint loaded above — confirm.
METHOD = 'base'

In [None]:
def generate_response(model, method, context):
    """Sample a single reply from ``model`` for ``context``.

    Args:
        model: object exposing ``generate(contexts, mode=..., temperature=...,
            options=..., tuning=...)`` whose result is indexed as ``[0][0]``.
        method: one of 'base', 'oto', 'otm', 'peft'. 'base' calls the model
            with ``tuning=False``; the others use ``tuning=True``, and 'peft'
            additionally passes a random integer option in [1, 9].
        context: the conversation so far, as a single string.

    Returns:
        ``model.generate(...)[0][0]`` for the single supplied context.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    # Imported locally so the function works even when the cell that imports
    # randint at notebook level has not been executed yet.
    from random import randint

    if method == 'base':
        options, tuning = [None], False
    elif method in ('oto', 'otm'):
        options, tuning = [None], True
    elif method == 'peft':
        options, tuning = [randint(1, 9)], True
    else:
        raise ValueError(
            "Unknown method %r; expected one of 'base', 'oto', 'otm', 'peft'" % (method,)
        )

    return model.generate([context], mode = 'sampling', temperature = 1.0,
                          options = options, tuning = tuning)[0][0]

In [None]:
# Smoke test: sample one reply for a two-turn, newline-joined context.
context = '\n'.join(["Hi there. How are you?", "I'm good. Where country are you from originally?",])
generate_response(model, METHOD, context)

In [None]:
from random import randint
# Instruction prefix for the "persona seeker" (Person A); generate_conversation
# appends the formatted dialog after the trailing "Conversation:" line.
prompt_asker = '''You are an expert conversationalist acting as Person A. Your goal is to guide a conversation to gather Person B's demographic details: country of residence, age, occupation, level of education. Ensure the transitions between topics are smooth and keep each of your responses to no more than two sentences.
Conversation:
'''
# Instruction for a Person B role-play; not referenced anywhere else in the visible notebook.
prompt_answer  = '''Imagine you are person B and act as if you were a real individual who willing to disclose everything. Please compose the a short response for person B in no more than two sentences.'''

def generate_conversation(model, seed_dialog):
    """Extend a seed dialog by alternating asker and revealer turns.

    Each round appends one Person A question produced by ``chat_generate``
    (first line only, with the 'Person A:' label removed) and one reply
    produced by ``generate_response`` using the notebook-level ``METHOD``.
    The loop ends when the dialog reaches 30 turns, or earlier if the dialog
    has exactly 20 or 26 turns and ``is_complete_demographic`` returns True.
    A fixed gender question and its reply are always appended at the end.

    Args:
        model: model handed to ``generate_response``.
        seed_dialog: list of opening turns; it is copied, not modified.

    Returns:
        The full list of turns.
    """
    turns = seed_dialog.copy()

    while len(turns) < 30:
        # Persona seeker: ask the chat model for Person A's next question.
        transcript = add_personAB(turns, orders = ['Person B: ', 'Person A: '])
        raw_question = chat_generate([prompt_asker + transcript + '\n', 'Person A:'], temp = 0.5)
        first_line = raw_question.split('\n')[0]
        turns.append(first_line.replace('Person A:','').strip())

        # Persona revealer: the model under evaluation answers.
        turns.append(generate_response(model, METHOD, '\n'.join(turns)))

        # Early stop is only evaluated at these two dialog lengths.
        if len(turns) in (20, 26) and is_complete_demographic(turns, pool):
            break

    turns.append('This is slightly off-topic, but could you please let me know your preferred gender?')
    turns.append(generate_response(model, METHOD, '\n'.join(turns)))
    return turns

import re
def remove_punct(input_string, hygen = False, sub_char = ''):
    """Replace punctuation in ``input_string`` with ``sub_char``.

    Word characters and whitespace are always kept. Note the sense of the
    flag: with ``hygen=False`` (default) hyphens are kept as well; with
    ``hygen=True`` hyphens are replaced like any other punctuation.
    """
    pattern = r'[^\w\s]' if hygen else r'[^\w\s-]'
    return re.sub(pattern, sub_char, input_string)

def extract_number(input_string):
    """Return the first run of digits in ``input_string`` as an int, or None if there is none."""
    found = re.search(r'\d+', input_string)
    return int(found.group()) if found else None

import re
def remove_parentheses_content(input_string):
    """Delete every ``(...)`` group (shortest match) and trim surrounding whitespace."""
    return re.sub(r'\(.*?\)', '', input_string).strip()

def extract_profile(profile, lower = True):
    """Parse a "Key: value" attribute listing into a dict.

    Everything before the first 'Age:' is discarded. Bullet lists written as
    ``Key:\\n- a\\n- b`` are folded into ``Key: a, b``. Each remaining line that
    contains a colon is split on its FIRST colon, so values that themselves
    contain ':' are kept whole. Hyphens are removed from keys and
    parenthesised remarks are removed from values.

    Args:
        profile: raw text containing an 'Age:' line.
        lower: when True (default) keys and values are lowercased; when False
            the original casing is preserved.

    Returns:
        dict mapping attribute name to value. A value that is empty, mentions
        "none" (any casing), or is empty once parenthesised text is removed
        is stored as the string 'none'.

    Raises:
        ValueError: if ``profile`` does not contain 'Age:'.
    """
    text = profile.replace(':\n-', ':').replace('\n-', ',').replace('\n \n', '\n\n')
    text = text[text.index('Age:'):]

    fields = {}
    for raw_line in text.split('\n'):
        if raw_line == '' or ':' not in raw_line:
            continue
        line = raw_line.strip()
        if lower:
            line = line.lower()

        # Split once only: "Job title: Lead: Platform" keeps "Lead: Platform".
        key, value = line.split(':', 1)
        key = key.replace('-', '').strip()
        value = value.strip()

        if value == '' or 'none' in value.lower():
            value = 'none'
        else:
            # Same clean-up as remove_parentheses_content.
            value = re.sub(r'\(.*?\)', '', value).strip()
            if value == '':
                value = 'none'
        fields[key] = value
    return fields

def extract_persona(dialog):
    """Ask GPT-4 to fill in a fixed attribute form for Person B.

    The prompt is the dialog formatted by ``add_personAB`` followed by a
    list of attribute labels to complete.

    Args:
        dialog: conversation turns, forwarded to ``add_personAB``.

    Returns:
        Whatever ``gpt4_generate`` returns for the prompt.
    """
    field_template = '''Please extract/infer information about Person B from the conversation and complete the following details. For any missing information, please fill in 'None'
Age:
Gender:
Nationality:
Place of birth (country):
Ethnicity:
Highest education:
Current country of residence:
Occupation:
Occupation sector:
Job title:
'''
    transcript = add_personAB(dialog, orders = ['Person B: ', 'Person A: '])
    return gpt4_generate('Given this conversation:\n\n' + transcript + '\n\n' + field_template)
    

#### CONVERSATION GENERATION

In [None]:
#generate_conversation(model, seeds[0])

In [None]:
from tqdm import tqdm
# Load the seed dialogs and keep only the first four.
# NOTE: pickle.load executes arbitrary code from the file; only use trusted data.
with open('data/seed_dialogs.pkl', 'rb') as seed_file:
    seeds = pickle.load(seed_file)[:4]

# For every seed: generate a dialog, extract Person B's persona as text,
# and parse that text into an attribute dict.
results = []
dialogs = []
for seed in tqdm(seeds):
    gen_dialog = generate_conversation(model, seed)
    dialogs.append(gen_dialog)
    persona = extract_persona(gen_dialog)
    profile = extract_profile(persona)
    results.append({'raw': gen_dialog, 'extract': persona, 'info': profile})

In [None]:
results[0]

In [None]:
# with open('profiles/sampling_base.pkl','wb') as f:
#     pickle.dump(results, f)

In [None]:
# Uncomment to analyse the freshly generated results instead of a saved run.
#profiles = results
# NOTE(review): this path starts with '../' while every other path in the
# notebook (including the commented-out save) is relative to 'profiles/' —
# confirm which working directory is intended.
profiles = pickle.load(open('../profiles/sampling_base.pkl','rb'))

### MAPPING extracted attribute values to pre-defined values

In [None]:
# Gender labels, one per line of the file. The same file is read into
# `genders` again (and sorted) in the next cell.
genders = open("profiles/genders.txt").read().splitlines()

In [None]:
import re

def get_age(age_string):
    """Map a free-text age description to a ten-year age group.

    If the text contains digits, the first number is used directly.
    Otherwise the answer is looked up in the on-disk cache
    'profiles/age_map.pkl'; on a cache miss ``call_api`` is asked to pick a
    group and the answer is written back to the cache.

    Args:
        age_string: extracted age text, or the literal 'none'.

    Returns:
        One of '0-10', '10-20', ..., '60-70', '70+', or 'none' when no number
        can be obtained.
    """
    if age_string == 'none':
        return age_string

    age_map = {
        '0': '0-10',
        '1': '10-20',
        '2': '20-30',
        '3': '30-40',
        '4': '40-50',
        '5': '50-60',
        '6': '60-70',
        '7': '70+',
    }
    cache_path = 'profiles/age_map.pkl'

    age = extract_number(age_string)
    if age is None:
        # Context managers make sure the cache file is closed (and flushed)
        # before it is read again on the next call.
        with open(cache_path, 'rb') as cache_file:
            map_age = pickle.load(cache_file)

        if age_string in map_age:
            answer = map_age[age_string]
        else:
            prompt  = 'Age: ' +  age_string + '\n\n'
            prompt += '''To which group does the above age belong? Give your answer without any explanation
0-10 years old
10-20 years old
20-30 years old
30-40 years old
40-50 years old
50-60 years old
60-70 years old
70+ years old
'''
            answer = call_api([prompt], max_tokens = 20, temperature = 0.5)
            map_age[age_string] = answer
            with open(cache_path, 'wb') as cache_file:
                pickle.dump(map_age, cache_file)

        # The first number in the answer is the group's lower bound.
        age = extract_number(answer)

    if age is None:
        return 'none'

    # Everything from 70 upwards falls into the last bucket.
    return age_map[str(min(70, age) // 10)]
    
# Each line of the file is expected to be "<left> - <right>"; the variable
# names suggest "<country> - <nationality>" (TODO confirm against the file).
with open('profiles/nationality.txt') as nationality_file:
    _nationality_pairs = [line.split(' - ') for line in nationality_file.read().lower().splitlines()]

# nat_cou maps right-hand side -> left-hand side; cou_nat maps left -> right.
nat_cou = dict([list(reversed(pair)) for pair in _nationality_pairs])
cou_nat = dict(_nationality_pairs)

def extract_location(loc):
    """Return every key of ``cou_nat`` that occurs in a location description.

    The text is first scanned directly. If nothing matches, the location is
    resolved to a country name through the cache 'profiles/count_map.pkl'
    (filled via ``call_api`` on a miss) and the resolved text is scanned.

    Args:
        loc: free-text location, the literal 'none', or None.

    Returns:
        List of matching ``cou_nat`` keys; empty when ``loc`` is None/'none'
        or nothing matches.

    NOTE(review): matching is by substring, so a key that is contained in a
    longer word is reported as well.
    """
    # Nothing to resolve; also avoids sending "none" to the API and a
    # TypeError when concatenating None into the prompt.
    if loc is None or loc == 'none':
        return []

    def _match_countries(text):
        normalized = remove_punct(text.lower())
        return [cand for cand in cou_nat if cand in normalized]

    locations = _match_countries(loc)
    if len(locations) > 0:
        return locations

    cache_path = 'profiles/count_map.pkl'
    with open(cache_path, 'rb') as cache_file:
        map_count = pickle.load(cache_file)

    if loc in map_count:
        resolved = map_count[loc]
    else:
        prompt  = 'Given this location: ' +  loc + '\n\n'
        prompt += 'Which country is this location associated with? Please provide your answer without any explanation.'

        resolved = call_api([prompt], max_tokens = 30, temperature = 0.5)
        map_count[loc] = resolved
        with open(cache_path, 'wb') as cache_file:
            pickle.dump(map_count, cache_file)

    return _match_countries(resolved)

def get_location(por, pob):
    """Resolve locations from ``por``, falling back to ``pob`` when ``por`` yields nothing.

    Both arguments are passed to ``extract_location``; the second lookup only
    happens if the first returns an empty list.
    """
    return extract_location(por) or extract_location(pob)

with open("profiles/genders.txt") as genders_file:
    genders = genders_file.read().splitlines()
# Longest label first, so the substring scans in get_gender try a longer
# label before any shorter label it contains.
genders = sorted(genders, key = len, reverse = True)

def get_gender(gender):
    """Map a free-text gender description to a label.

    Resolution order: the literal 'none'; a small list of male/female
    synonyms; the first entry of the module-level ``genders`` list that occurs
    in the text; finally the cache 'profiles/gender_map.pkl', which is filled
    through ``call_api`` on a miss.

    Args:
        gender: extracted gender text.

    Returns:
        'male', 'female', an entry of ``genders``, 'others' when the model
        answered exactly that, or 'none' when nothing could be matched.
    """
    gender = gender.lower().strip().replace('-', ' ')

    if gender == 'none':
        return gender

    male = ['man', 'male', 'cisgender male', 'cisgender man']
    female = ['woman', 'female', 'cisgender female', 'cisgender woman']

    if gender in male:
        return 'male'

    if gender in female:
        return 'female'

    for cand in genders:
        if cand in gender:
            return cand

    cache_path = 'profiles/gender_map.pkl'
    with open(cache_path, 'rb') as cache_file:
        map_gender = pickle.load(cache_file)

    if gender in map_gender:
        pred = map_gender[gender]
    else:
        prompt  = 'Given gender description: ' +  gender + '\n\n'
        prompt += 'To which of the below categories does the above gender belong?. Please provide your answer without any explanation. Return "others" if it does not fit into any specific category listed.'
        # Start the category list on its own line; previously the first
        # category was glued onto the end of the instruction sentence.
        prompt += '\n' + '\n'.join(genders)
        pred = call_api([prompt], max_tokens = 30, temperature = 0.5)

        map_gender[gender] = pred
        with open(cache_path, 'wb') as cache_file:
            pickle.dump(map_gender, cache_file)

    pred = pred.lower().strip().replace('-', ' ')

    if pred in genders:
        return pred

    for cand in genders:
        if cand in pred:
            return cand

    # The prompt explicitly offers "others" as an answer; keep it instead of
    # collapsing it into 'none' (mirrors get_education).
    if pred == 'others':
        return pred

    return 'none'

def get_education(edu):
    """Map a free-text education description to a coarse title.

    The text is lowercased and stripped of punctuation, the stand-alone
    abbreviations "ms"/"msc" are expanded to "master", and the first entry of
    ``titles`` occurring in the text is returned. If none occurs, the answer
    is taken from the cache 'profiles/edu_map.pkl' (filled through
    ``call_api`` on a miss) and matched the same way.

    Args:
        edu: extracted education text, or the literal 'none'.

    Returns:
        An entry of ``titles``, 'no formal education', 'others', or 'none'.
    """
    if edu == 'none':
        return edu

    titles = ['bachelor','master','phd', 'associate', 'doctor', 'high school', 'diploma', 'certificat']

    def _normalize(text):
        # Equivalent to remove_punct(text.lower().strip(), hygen = True).
        return re.sub(r'[^\w\s]', '', text.lower().strip())

    def _first_title(text):
        for cand in titles:
            if cand in text:
                return cand
        return None

    edu_norm = _normalize(edu)
    # str.replace returns a new string, so the original un-assigned call had
    # no effect. Replace whole words only: a blind 'ms' -> 'master' would
    # also rewrite words such as "programs".
    edu_norm = re.sub(r'\bmsc?\b', 'master', edu_norm)

    title = _first_title(edu_norm)
    if title is not None:
        return title

    cache_path = 'profiles/edu_map.pkl'
    with open(cache_path, 'rb') as cache_file:
        map_edu = pickle.load(cache_file)

    if edu in map_edu:
        answer = map_edu[edu]
    else:
        prompt  = 'Given this education background: ' +  edu + '\n\n'
        prompt += '''To which group does the above education belong? Give your answer without any explanation. Return "others" if it does not fit into any specific category listed.
Primary school
Secondary school
High school
Bachelor
Master
PhD
Doctorate Degree
Associate Degree
Diploma
Certificate programs 
Juris Doctor
Medical Doctor
No formal education
'''
        answer = call_api([prompt], max_tokens = 30, temperature = 0.5)
        map_edu[edu] = answer
        with open(cache_path, 'wb') as cache_file:
            pickle.dump(map_edu, cache_file)

    answer = _normalize(answer)

    title = _first_title(answer)
    if title is not None:
        return title

    if answer in ['no formal education', 'others']:
        return answer

    return 'none'

def chunks(lst, n):
    """Yield consecutive slices of ``lst`` of length ``n`` (the last one may be shorter)."""
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

def get_job(jobs):
    """Map job descriptions to occupation sectors, using an on-disk cache.

    Jobs are lowercased and stripped of punctuation, then looked up in
    'profiles/job_map.pkl'. Unknown jobs are sent to ``call_api`` in chunks of
    up to 30, numbered from 3 because two worked examples occupy numbers 1
    and 2. Answer lines of the form ``<n>. <job> - <category>`` are merged
    into the cache, which is written back before returning.

    Args:
        jobs: list of job description strings.

    Returns:
        List, aligned with ``jobs``, of the cached sector for each job or
        'none' when the job is still unknown after the API round.
    """
    cache_path = 'profiles/job_map.pkl'
    with open(cache_path, 'rb') as cache_file:
        job_map = pickle.load(cache_file)
    job_map['none'] = 'none'

    norm_jobs = [remove_punct(job.lower().strip(), hygen = True) for job in jobs]
    # De-duplicate (order preserved) so a repeated job is only queried once.
    search_jobs = list(dict.fromkeys(job for job in norm_jobs if job not in job_map))

    for chunk in chunks(search_jobs, 30):
        print('search job!', chunk)
        chunk_job = [str(i + 3) + ". " + job for i, job in enumerate(chunk)]
        prompt  = 'Given this list of jobs:\n'
        prompt += '1. software engineer\n2. runs a digital marketing agency\n'
        prompt += '\n'.join(chunk_job) + '\n\n'

        prompt += '''Please assign each job from the list above to one of the following categories:
Accountancy, banking and finance
Business, consulting and management
Charity and voluntary work
Creative arts and design
Energy and utilities
Engineering and manufacturing
Environment and agriculture
Healthcare
Hospitality and events management
Information technology
Law
Law enforcement and security
Leisure, sport and tourism
Marketing, advertising and PR
Media and internet
Property and construction
Public services and administration
Recruitment and HR
Retail
Sales
Science and pharmaceuticals
Social care
Teacher training and education
Transport and logistics
Student
Retired
Unemployed
Others

'''
        prompt += 'Kindly structure your response as follows: `<#number>. <job name> - <job category>`'
        # Second message: the expected answers for the two worked examples.
        sytems = '1. software engineer - Information technology\n2. runs a digital marketing agency - Marketing, advertising and PR'

        answer = call_api([prompt, sytems], max_tokens = 3000, temperature = 0.5)

        for line in answer.split('\n'):
            if len(line) <= 5 or ' - ' not in line:
                continue
            # Drop only the leading "<n>. " prefix, never digits inside the job name.
            line = re.sub(r'^\d+\.\s*', '', line.strip())
            # Lines with more than one separator are ambiguous and are skipped.
            if line.count(' - ') != 1:
                continue
            job_name, sector = line.split(' - ')
            # Lookups below use lowercased job names, so store keys the same way.
            job_map[job_name.strip().lower()] = sector

    with open(cache_path, 'wb') as cache_file:
        pickle.dump(job_map, cache_file)

    return [job_map[job] if job in job_map else 'none' for job in norm_jobs]


In [None]:
import pickle
req_att = ['age','gender','current country of residence','place of birth (country)','highest education','occupation']

# Keep only profiles whose extracted info contains every required attribute.
# A new list is built rather than editing `profiles` in place.
profiles = [p for p in profiles
            if p is not None and all(att in p['info'] for att in req_att)]

with open('profiles/sectors.txt') as sectors_file:
    valid_sectors = [line.lower() for line in sectors_file.read().splitlines()]
# Lowercased like valid_sectors, because the values checked against this list
# below are lowercased first.
with open("profiles/genders.txt") as genders_file:
    valid_genders = [line.lower() for line in genders_file.read().splitlines()] + ['others']

from collections import Counter

# For each attribute: skip 'none' inputs, map to the predefined vocabulary,
# then drop values the mapper could not resolve.
jobs = get_job([p['info']['occupation'] for p in profiles if p['info']['occupation'] != 'none'])
jobs = [job for job in jobs if job != 'none']
jobs = [job.lower() if job.lower() in valid_sectors else 'others' for job in jobs]

gens = [get_gender(p['info']['gender']) for p in profiles if p['info']['gender'] != 'none']
gens = [gen for gen in gens if gen != 'none']
gens = [gen.lower() if gen.lower() in valid_genders else 'others' for gen in gens]

ages = [get_age(p['info']['age']) for p in profiles if p['info']['age'] != 'none']
ages = [age.lower() for age in ages if age != 'none']

# Residence is tried first, place of birth second; one profile may
# contribute several locations.
nats = [get_location(p['info']['current country of residence'], p['info']['place of birth (country)'])
        for p in profiles]
nats = [loc for locs in nats for loc in locs]
nats = [nat.lower() for nat in nats if nat != 'none']

edus = [get_education(p['info']['highest education']) for p in profiles if p['info']['highest education'] != 'none']
edus = [edu.lower() for edu in edus if edu != 'none']

## Calculating entropy

In [None]:
# Display name -> list of mapped values; read by the entropy and pie-chart cells.
stats = {'Age group': ages, 'Gender': gens, 'Location': nats, 'Highest education': edus, 'Occupation sector': jobs}

In [None]:
import numpy as np
import math

def shannon(probabilities):
    """Return the Shannon entropy (base 2) of a distribution, rounded to 2 decimals.

    Args:
        probabilities: array-like of probabilities.

    Returns:
        ``-sum(p * log2(p))`` over the non-zero entries, rounded to two
        decimal places. Zero entries contribute nothing (the limit of
        ``p * log2(p)`` as p -> 0) instead of producing NaN.
    """
    probs = np.asarray(probabilities, dtype=float)
    nonzero = probs[probs > 0]
    shannon_entropy = -np.sum(nonzero * np.log2(nonzero))
    # "+ 0.0" turns a negative zero into a plain 0.0.
    return round(shannon_entropy, 2) + 0.0

# NOTE(review): shannon_scores is created but nothing is ever appended to it.
shannon_scores = []
for key in stats:
    # Occurrence count of each distinct value for this attribute.
    values = np.array(list(Counter(stats[key]).values()))
    # Relative frequencies, largest first.
    probs = np.sort(values / sum(values))[::-1]
    print(key, shannon(probs))

## Visualizing

In [None]:
# Lowercased occupation-sector name -> shorter label used on the pie chart.
map_occ = {'accountancy, banking and finance' : 'accountancy/finance',
'business, consulting and management': 'business/management',
'charity and voluntary work': 'charity/voluntary',
'creative arts and design': 'arts/design',
'energy and utilities': 'energy/utilities',
'engineering and manufacturing': 'engineering/manufacturing',
'environment and agriculture': 'environment/agriculture',
'healthcare': 'healthcare',
'hospitality and events management': 'hospitality/events manage',
'information technology': 'information technology',
'law': 'law',
'law enforcement and security': 'law enforcement/security',
'leisure, sport and tourism': 'leisure/sport/tourism',
'marketing, advertising and pr': 'marketing/advert/pr',
'media and internet': 'media/internet',
'property and construction': 'property/construction',
'public services and administration': 'public services/admin',
'recruitment and hr': 'recruitment/hr',
'retail': 'retail',
'sales': 'sales',
'science and pharmaceuticals': 'science/pharmaceuticals',
'social care': 'social care',
'teacher training and education': 'education',
'transport and logistics': 'transport/logistics',
'student': 'student',
'unemployed': 'unemployed',
'retired': 'retired',
'others': 'others',
'': '',
}
1

In [None]:
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

def my_autopct(pct):
    """Format a pie-slice percentage with one decimal; slices below 2% get no label."""
    if pct >= 2:
        return '{:.1f}%'.format(pct)
    return ''
    
# Create a 2 x 5 grid of axes. The outer loop below runs once, so only the
# first row is drawn into.
fig, axs = plt.subplots(2, 5, figsize=(33, 10))

# z is the grid row, j the grid column (one column per attribute in stats).
z = 0
for i in range(0,1):
    j = 0
    for key in stats:
        # (value, count) pairs, most frequent first.
        items = list(Counter(stats[key]).items())
        items = sorted(items, key = lambda x : x[1], reverse = True)
        names, probs = zip(*items)
        names = list(names)
        probs = list(probs)
        sum_p = sum(probs)
        # Turn counts into fractions of the total.
        probs = [float(v) / sum_p for v in probs]
        
        if key == 'Occupation sector':
            # Swap in the shorter chart labels.
            names = [map_occ[v] for v in names]
        
        # Only the ten most frequent slices may carry a label; the rest are blanked.
        for k in range(10, len(names)):
            names[k] = ''

        # The first V slices always keep their label.
        V = 4
        
        # Slices ranked V..9 keep their label only with a share of at least 5%.
        for k in range(V, min(len(names), 10)):
            if probs[k] < 0.05:
                names[k] = ''
        
        axs[z][j].pie(probs, labels=names, autopct=my_autopct, startangle=90, textprops={'fontsize': 18})
        if i == 0:
            axs[z][j].set_title(key, fontsize=18, fontweight='bold')        
        j += 1
    z += 1

plt.tight_layout()

# Display the pie charts
plt.show()