# Calculate phrase contribution to image selection using Fisher's exact test

input: 
- rank data: pairs with selection labels

    I load them from ranking_data/*.json
    
- prompt list: str, (positive or negative)

    I load them from *_data.msgpack

output: 

phrase contribution list CSV with 5 fields:
- phrase: str
- num_select: frequency of this phrase contained in prompts of selected images
- num_unselect: frequency of this phrase contained in prompts of unselected images
- odds_ratio: phrase contribution to selection
- p_value: p-value of Fisher's exact test; a lower value means stronger evidence that the phrase is associated with selection

Approach for Statistics and Analysis:

Data Preparation: 

- For each prompt, we have two key pieces of information – the prompt itself and the label indicating whether it was selected or not. Divide the data into two groups based on selected and unselected labels.

- Calculate Frequencies: For each phrase, calculate its frequency in the selected prompts and in the unselected prompts.

- Fisher's Exact Test: Use Fisher's exact test to assess whether there is a significant difference between the observed events in the two groups. This statistical test is suitable for analyzing 2x2 contingency table data.

Odds Ratio:

In statistics, the odds ratio compares the odds of an event occurring in two groups. In this context, the odds ratio is the odds of an event (e.g., the occurrence of a phrase) in the selected prompts divided by its odds in the unselected prompts.

An odds ratio greater than 1 indicates that the event is more likely to occur in the selected prompts. If it's less than 1, the event is more likely in the unselected prompts.

P-value:

The P-value is used to determine whether the observed differences are likely due to random factors. It is the probability, under the null hypothesis (the phrase is unrelated to selection), of observing a difference at least as extreme as the one seen. A low P-value (typically less than 0.05) provides evidence to reject the null hypothesis, suggesting a significant difference between the two groups.

A small P-value indicates statistically significant differences, providing sufficient evidence to reject the null hypothesis and suggesting a significant difference between the two groups.

In [12]:
import os
import sys
import json
import glob

from PIL import Image

import re

from scipy.stats import fisher_exact

import numpy as np
import pandas as pd

from matplotlib import pyplot

import msgpack

from tqdm.auto import tqdm

In [17]:
# Root of the ranking output; the pair files are globbed from
# <RANK_DIR>/ranking_data/*.json.
RANK_DIR = '../kcg-ml-image-pipeline/output/environmental/ranking_v1/'
# Root of the dataset; each image's prompts are read from
# <DATA_DIR>/<image path>_data.msgpack.
DATA_DIR = '../kcg-ml-image-pipeline/output/dataset/data/'

# load rank data

In [4]:
# Collect every ranking pair as (image_1, image_2, selected_image_index).
paths = sorted(glob.glob(os.path.join(RANK_DIR, 'ranking_data', '*.json')))

rank_pairs = list()
for path in tqdm(paths):
    # Use a context manager so each file handle is closed as soon as it has
    # been parsed instead of lingering until garbage collection.
    with open(path) as f:
        js = json.load(f)

    # Drop the file extension and the 'datasets/' component so the result can
    # later be joined onto DATA_DIR.
    file_path_1 = os.path.splitext(js['image_1_metadata']['file_path'])[0].replace('datasets/', '')
    file_path_2 = os.path.splitext(js['image_2_metadata']['file_path'])[0].replace('datasets/', '')

    rank_pairs.append((file_path_1, file_path_2, js['selected_image_index']))

  0%|          | 0/49097 [00:00<?, ?it/s]

In [5]:
rank_pairs = pd.DataFrame(rank_pairs, columns=['image_1', 'image_2', 'selected_image_index'])

# Reorder every pair as (selected, unselected): index 0 means image_1 was
# chosen, any other value puts image_2 first.
ordered_pairs = []
for first, second, chosen in rank_pairs.itertuples(index=False, name=None):
    if chosen == 0:
        ordered_pairs.append((first, second))
    else:
        ordered_pairs.append((second, first))

# stats phrase

In [13]:
def format_prompt(prompt):
    """Tidy brackets, whitespace and separators in a raw prompt string.

    Empty bracket groups are removed, whitespace runs are collapsed, runs of
    ',' / ';' become a single ',', and leading/trailing punctuation is trimmed.
    Returns the cleaned string.
    """
    # A bracket pair ((), <>, [], {}) containing only whitespace/punctuation.
    empty_group = r'(\([\s,\.:;\|]*\))|(\<[\s,\.:;\|]*\>)|(\[[\s,\.:;\|]*\])|(\{[\s,\.:;\|]*\})'

    text = prompt.strip()

    # Removing one empty group can expose another (e.g. "(())"), so repeat
    # until a pass removes nothing.
    removed = 1
    while removed:
        text, removed = re.subn(empty_group, '', text)

    # Applied in this exact order; each step sees the previous step's output.
    substitutions = (
        (r'([\[\(\{\<])\s', r'\1'),  # drop one whitespace char after an opening bracket
        (r'\s([\]\)\}\>])', r'\1'),  # drop one whitespace char before a closing bracket
        (r'\s+', ' '),               # collapse whitespace runs to a single space
        (r'(\s?[,;])+', r','),       # runs of ',' / ';' become one comma
        (r'^[\.,;\s]+', ''),         # trim leading punctuation and whitespace
        (r'[\.,;\s]+$', ''),         # trim trailing punctuation and whitespace
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)

    return text

def remove_weight(prompt):
    """Strip attention weights and bracket characters from a prompt.

    Every ':' together with any following run of digits, '-', '.', ',' or
    whitespace is replaced by ', ', then all bracket characters
    ((), [], {}, <>) are removed. Returns the resulting string; it is not
    re-normalised here, so it may contain empty or doubled separators.
    """
    # Raw string: the original non-raw literal relied on invalid escape
    # sequences ('\s', '\.'), which newer Python versions warn about.
    prompt = re.sub(r':[-\s0-9,\.]*', ', ', prompt)
    prompt = re.sub(r'[\(\[\{\<\>\}\]\)]+', '', prompt)
    return prompt

def prompt_to_tags(prompt):
    """Yield the non-empty, stripped comma-separated tags of a prompt.

    The prompt is normalised, stripped of weights/brackets, and normalised
    again before being split on ','.
    """
    cleaned = format_prompt(remove_weight(format_prompt(prompt)))

    for piece in cleaned.split(','):
        stripped = piece.strip()
        if stripped:
            yield stripped

In [14]:
def load_prompt_from_data(path):
    """Read one image's prompts from its msgpack data file.

    Args:
        path: path to a msgpack file whose top-level object has the keys
            'positive_prompt' and 'negative_prompt'.

    Returns:
        Tuple (positive_prompt, negative_prompt).
    """
    # Context manager closes the handle right away; this function is called
    # inside a loop over all ranking pairs.
    with open(path, 'rb') as f:
        data = msgpack.load(f)

    return data['positive_prompt'], data['negative_prompt']

In [18]:
def _count_tags(stats, prompt, column):
    """Add one prompt's tags to `stats` (phrase -> [num_select, num_unselect]).

    `column` is 0 for a selected image and 1 for an unselected one. Each
    distinct tag is counted once per prompt, so a count is "number of prompts
    containing the phrase" and can never exceed the number of prompts in its
    group, which is what the 2x2 contingency table built later requires.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the row
    # order of the resulting table stays deterministic.
    for tag in dict.fromkeys(prompt_to_tags(prompt)):
        stats.setdefault(tag, [0, 0])[column] += 1


# phrase -> [count in selected prompts, count in unselected prompts]
positive_stats = dict()
negative_stats = dict()

for image_1, image_2 in tqdm(ordered_pairs, leave=False):

    # ordered_pairs holds (selected, unselected), so the position in the pair
    # is also the column to increment: 0 = selected, 1 = unselected.
    for column, image in enumerate((image_1, image_2)):

        positive_prompt, negative_prompt = load_prompt_from_data(os.path.join(DATA_DIR, f'{image}_data.msgpack'))

        _count_tags(positive_stats, positive_prompt, column)
        _count_tags(negative_stats, negative_prompt, column)

  0%|          | 0/49097 [00:00<?, ?it/s]

# get result

In [20]:
def stats_to_df(stats, n_total=None):
    """Build the phrase table and run Fisher's exact test for every phrase.

    Args:
        stats: dict mapping phrase -> [num_select, num_unselect].
        n_total: number of prompts in each group (selected and unselected).
            Defaults to len(ordered_pairs), since every ranking pair
            contributes exactly one selected and one unselected prompt.

    Returns:
        DataFrame indexed by 'phrase' with the columns num_select,
        num_unselect, odds_ratio and p_value. Rare phrases are kept; filter on
        the count columns afterwards if a minimum support is wanted.
    """
    if n_total is None:
        n_total = len(ordered_pairs)

    df = pd.DataFrame(stats.values(), columns=['num_select', 'num_unselect'])
    df['phrase'] = stats.keys()
    df.set_index('phrase', inplace=True)

    # Named so they do not shadow the imported `os` module.
    odds_ratios, p_values = list(), list()
    for num_select, num_unselect in tqdm(df.itertuples(index=False, name=None), total=df.shape[0], leave=False):
        # 2x2 contingency table:
        #   rows    = prompt contains the phrase / does not contain it
        #   columns = selected / unselected
        table = [
            [num_select, num_unselect],
            [n_total - num_select, n_total - num_unselect],
        ]
        odds_ratio, p_value = fisher_exact(table)
        odds_ratios.append(odds_ratio)
        p_values.append(p_value)

    df['odds_ratio'] = odds_ratios
    df['p_value'] = p_values

    return df

In [21]:
# One result table per prompt type, built in the same order as before.
positive_df, negative_df = map(stats_to_df, (positive_stats, negative_stats))

  0%|          | 0/94073 [00:00<?, ?it/s]

  0%|          | 0/11993 [00:00<?, ?it/s]

In [22]:
# to_csv does not create missing parent directories, so make sure the output
# folder exists before writing.
os.makedirs('data', exist_ok=True)

positive_df.to_csv('data/positive_freqs.csv')
negative_df.to_csv('data/negative_freqs.csv')

In [23]:
# Positive-prompt phrases with more than 50 occurrences in each group and
# p < 0.05; show the 50 largest odds ratios in ascending order.
significant_positive = positive_df.query('num_select > 50 and num_unselect > 50 and p_value < 5e-2')
significant_positive.sort_values('odds_ratio').tail(50)

Unnamed: 0_level_0,num_select,num_unselect,odds_ratio,p_value
phrase,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
leotard,381,248,1.540485,1.164458e-07
palm trees,114,74,1.541799,0.004287916
official alternate costume,114,74,1.541799,0.004287916
stylized,94,61,1.542021,0.009874552
vines,112,72,1.556826,0.003884107
tile based,13770,9805,1.56201,1.8106649999999997e-193
color grading,217,139,1.563642,4.067724e-05
black armor,86,55,1.564625,0.01119885
cinematic shot,253,162,1.564638,8.83122e-06
sunbeam,94,60,1.567754,0.007581123


In [24]:
# Negative-prompt phrases with more than 50 occurrences in each group and
# p < 0.05; show the 50 largest odds ratios in ascending order.
significant_negative = negative_df.query('num_select > 50 and num_unselect > 50 and p_value < 5e-2')
significant_negative.sort_values('odds_ratio').tail(50)

Unnamed: 0_level_0,num_select,num_unselect,odds_ratio,p_value
phrase,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
wa_lolita,146,100,1.461372,0.003979249
missing_face,267,183,1.461526,8.41964e-05
milf,76,52,1.462254,0.04150568
big flower,95,65,1.462433,0.02146769
moon,98,67,1.463612,0.01914112
cloned face,193,132,1.463945,0.0008305208
multiple objects,126,86,1.466313,0.007188122
fused anatomy,106,72,1.473244,0.01307601
age spot,96,65,1.477857,0.0176944
poorly drawn,416,282,1.479238,4.043683e-07
