In [1]:
import pandas as pd
import numpy as np
from PIL import Image
import requests
import torch
import re
import ast
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr

In [2]:
# Load the two annotated datasets from hard-coded absolute paths on the author's machine.
# NOTE(review): read_pickle unpickles the file, which can execute arbitrary code —
# only load pickle files from a trusted source.
df4 = pd.read_pickle("/home/student/HallucinationsLLM/data/team4_df.pkl")
# index_col=0 turns the first spreadsheet column into the DataFrame index.
df5 = pd.read_excel("/home/student/HallucinationsLLM/data/team5_clean_dataset.xlsx", index_col=0)

In [3]:
# Image links that occur in both datasets: df4 is compared by its index values,
# df5 by its 'image_link' column. The list is built from a set, so its order carries no meaning.
image_overlap = list(set(df4.index).intersection(df5['image_link']))
image_overlap

['https://cdn.pixabay.com/photo/2016/11/29/05/26/beach-1867524_1280.jpg',
 'https://cdn.pixabay.com/photo/2016/11/18/13/23/action-1834465_1280.jpg',
 'https://cdn.pixabay.com/photo/2019/11/29/08/34/space-4660847_1280.jpg']

In [4]:
def get_max_key_value(d):
    """Return the ``(key, value)`` pair with the largest value in mapping ``d``.

    A falsy ``d`` (empty mapping or ``None``) yields ``(None, None)``.
    When several keys share the largest value, the first one in iteration order wins.
    """
    if not d:
        return None, None
    best_key = max(d, key=lambda k: d[k])
    return best_key, d[best_key]

def process_df4(df, temperature=0.7):
    """Bring the team-4 frame into the column layout used by the rest of the notebook.

    Steps, in order:
      1. Fill missing 'hallucinations' values from 'hillucination_text', then drop
         'hillucination_text'.
      2. Copy the index into an 'image_link' column and replace the index with 0..n-1.
      3. For pred_1..pred_4, replace the value (passed to ``get_max_key_value``,
         presumably a dict of candidate -> probability) with its top key, and store
         the matching value in 'pred_{i}_prob'.
      4. Rename 'text' -> 'description' and 'generated_logits' -> 'logits'.
      5. Add a constant 'temperature' column.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing 'hallucinations', 'hillucination_text', 'pred_1'..'pred_4',
        'text' and 'generated_logits'. It is not modified.
    temperature : float, default 0.7
        Value written to every row of the 'temperature' column.

    Returns
    -------
    pandas.DataFrame
        A new, processed frame.
    """
    # The previous df['hallucinations'].fillna(..., inplace=True) acted on a column
    # selection: on older pandas it mutated the caller's frame, and under Copy-on-Write
    # it silently has no effect. assign() builds a new frame with the filled column.
    df = df.assign(hallucinations=df['hallucinations'].fillna(df['hillucination_text']))
    df = df.drop('hillucination_text', axis=1)
    df['image_link'] = df.index
    df = df.reset_index(drop=True)
    for i in range(1, 5):
        # Each cell becomes a two-element Series -> (top key, its value).
        df[[f'pred_{i}', f'pred_{i}_prob']] = df[f'pred_{i}'].apply(lambda x: pd.Series(get_max_key_value(x)))

    df = df.rename(columns={'text': 'description', 'generated_logits': 'logits'})
    df['temperature'] = temperature
    return df

In [None]:
# Normalise the team-4 frame, then stack it on top of df5 with a fresh 0..n-1 index.
# NOTE: process_df4 drops 'hillucination_text', so running this cell a second time on the
# already-processed df4 fails; reload df4 first.
df4 = process_df4(df4)
df = pd.concat([df4, df5], ignore_index=True)

In [4]:
# Continue the analysis with df5 only. This rebinds `df` (replacing any combined frame
# assigned to `df` earlier) and is an alias, not a copy: columns written to `df` in
# later cells are written to `df5` as well.
df = df5

In [17]:
def validate_brackets(string):
    """Check that square brackets in ``string`` are balanced and never nested.

    Returns ``True`` only when every "[" is followed by its own "]" before the next "["
    opens, no "]" appears without a preceding "[", and no "[" is left open at the end
    of the string. A string without brackets is valid.
    """
    depth = 0
    for c in string:
        if c == "[":
            depth += 1
        elif c == "]":
            depth -= 1
        # depth < 0: stray "]"; depth > 1: a second "[" before the first was closed.
        if depth < 0 or depth > 1:
            return False
    # An unclosed "[" leaves depth at 1; the earlier version accepted such strings.
    return depth == 0

def validate_spaces(s):
    """Return ``False`` if any ``[...]`` span in ``s`` touches a letter on either side.

    Each shortest ``[...]`` match is inspected: the character immediately before the
    "[" and the character immediately after the "]" must not be alphabetic.
    Punctuation, digits, whitespace and string boundaries are all accepted.
    A string with no ``[...]`` span is valid.
    """
    for span in re.finditer(r'\[.*?\]', s):
        lo, hi = span.span()
        glued_before = lo > 0 and s[lo - 1].isalpha()
        glued_after = hi < len(s) and s[hi].isalpha()
        if glued_before or glued_after:
            return False
    return True

def check(logits):
    """Collect the distinct tokens in ``logits`` that contain "<" or ">".

    ``logits`` is an iterable of 2-item entries whose first item is the token string;
    the second item is ignored. The result is a de-duplicated list in no particular order.
    """
    suspicious = [entry_token for entry_token, _ in logits
                  if "<" in entry_token or ">" in entry_token]
    return list(set(suspicious))


def clean_text(text):
    """Strip special-token markup from ``text`` and trim surrounding whitespace.

    Substitutions are applied in order: ``<0x0A>`` becomes a single space, ``</s>`` is
    removed, then any remaining ``<...>`` run is removed.
    """
    substitutions = (
        (r'<0x0A>', ' '),   # becomes a space so neighbouring words stay separated
        (r'</s>', ''),
        (r'<[^>]*>', ''),   # any other angle-bracketed tag
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()

def count_words(text):
    """Return the number of whitespace-separated words in ``text``."""
    return len(text.split())

def extract_brackets_len(text):
    """Return the length, in words, of every bracketed span in ``text``.

    The text is first passed through ``clean_text`` and split on whitespace. For each
    word containing "[", words are counted from that word up to and including the first
    word containing "]". If no such word follows, the count runs to the end of the text.
    A word holding several "[" contributes a single entry.
    """
    tokens = clean_text(text).split()
    span_lengths = []
    for start, token in enumerate(tokens):
        if "[" not in token:
            continue
        length = 0
        for candidate in tokens[start:]:
            length += 1
            if "]" in candidate:
                break
        span_lengths.append(length)
    return span_lengths

def dot_in_hal(text):
    """Print the words of bracketed spans that contain a "." and return the span lengths.

    Spans are walked word by word after ``clean_text`` and a whitespace split. A span
    starts at a word containing "[" and ends at the first following word containing "]"
    (or at the end of the text). Every word in the span that contains "." is printed.
    The "." test runs before the "]" test, so the closing word is printed even when its
    period sits after the "]" (e.g. ``[road].``).

    Returns the list of span lengths in words.
    """
    tokens = clean_text(text).split()
    span_lengths = []
    for start, token in enumerate(tokens):
        if "[" not in token:
            continue
        length = 0
        for candidate in tokens[start:]:
            length += 1
            if "." in candidate:
                print(candidate)
            if "]" in candidate:
                break
        span_lengths.append(length)
    return span_lengths


def count_brackets(string):
    """Return how many "[" characters ``string`` contains."""
    return sum(1 for c in string if c == "[")


def columns_compatibility(description, hallucinations, hedges, context1, context2, context3, context4):
    """Count the distinct texts among a row's description and annotated hallucinations.

    ``description`` is compared with spaces removed; ``hallucinations`` with spaces and
    square brackets removed. The result is 1 when the two normalised strings are
    identical and 2 otherwise.

    ``hedges`` and ``context1``..``context4`` are accepted but are not part of the
    comparison.
    """
    def _strip_markup(annotated):
        # Drop the bracket markers and all spaces so only the underlying text is compared.
        return annotated.replace("[", "").replace("]", "").replace(" ", "")

    distinct_texts = {description.replace(" ", ""), _strip_markup(hallucinations)}
    return len(distinct_texts)

In [18]:
# Per row: number of distinct normalised texts (1 = description and hallucinations agree).
rows_compatibilty = df.apply(
    lambda row: columns_compatibility(
        row['description'],
        row['hallucinations'],
        row['hedges'],
        row['context_1'],
        row['context_2'],
        row['context_3'],
        row['context_4'],
    ),
    axis=1,
)
# Show only the rows where the texts disagree.
rows_compatibilty[rows_compatibilty.ne(1)]

1      2
2      2
4      2
5      2
6      2
7      2
8      2
9      2
10     2
11     2
12     2
13     2
14     2
15     2
16     2
17     2
18     2
30     2
31     2
32     2
34     2
35     2
36     2
37     2
38     2
39     2
40     2
41     2
42     2
43     2
44     2
45     2
46     2
47     2
48     2
49     2
100    2
103    2
126    2
136    2
137    2
dtype: int64

In [6]:
# Number of "[" per row of the hallucinations column; list the rows that have none.
a = df['hallucinations'].map(count_brackets)
a[a.eq(0)].index

Index([7, 18, 24, 56, 126], dtype='int64')

In [7]:
cols_with_brackets = ['hallucinations', 'hedges', 'context_1', 'context_2', 'context_3', 'context_4']
for col in cols_with_brackets:
    # Row-wise pass/fail flags for the two structural checks, plus how many rows passed each.
    validation_result = df[col].apply(validate_brackets)
    valid_spaces = df[col].apply(validate_spaces)
    passed_num = validation_result.astype(int).sum()
    valid_space_num = valid_spaces.astype(int).sum()

    # The zero-bracket report is only produced for columns whose name contains 'context'.
    if 'context' in col:
        brackets_count = df[col].apply(count_brackets)
        zero_brackets = brackets_count[brackets_count == 0]
        if len(zero_brackets) > 0:
            print(f"{col} has zero brackets: {zero_brackets.index.tolist()}")

    # Each report prints the number of rows that PASSED, then the indexes that failed.
    if passed_num != len(df):
        print(f"{col} validation test failed: {passed_num}")
        print("indexes", validation_result[~validation_result].index.tolist())
    if valid_space_num != len(df):
        print(f"{col} space test failed: {valid_space_num}")
        print("indexes", valid_spaces[~valid_spaces].index.tolist())

In [8]:
# Parse the 'logits' column from its string form into Python objects
# (check() below unpacks each element as a (token, value) pair).
# The isinstance guard keeps this cell re-runnable: ast.literal_eval raises on a value
# that has already been parsed, so non-string values are passed through unchanged.
df['logits'] = df['logits'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
# Distinct tokens containing "<" or ">" per row, flattened into one set for display.
lst = df['logits'].apply(check).tolist()
flattened_list = [item for sublist in lst for item in sublist]
set(flattened_list)

{' <', '.<', '</s>', '>', '><'}

In [8]:
# Run dot_in_hal for its printed diagnostics only; binding the result to `_` keeps the
# returned Series of length lists out of the cell output.
_ = df['hallucinations'].apply(dot_in_hal)

[group].
dogs].
[back].
[frisbee].
[frisbee].
[arms].
[seated].
[water].
ball].
[rain].
conditions].
bowl].
[nose].
positions].
[backpack].
musician].
[dishes].
[rice].
[bench].
riders].
[runway].
[paintbrush].
[airplane].
air].
fruits].
open].
cups].
space].
[kneeling].
[boys].
table].
[backpacks].
witch].
light].
[road].
edge].
furniture].
[upwards].
[ball].
windows].
[wallpaper].
[turn].
controller].
bowls].
fish].
wrist].
[flags].
dishes].
[pedestrians].
[wall].
road].
phone].
turned].


In [9]:
# Word length of every bracketed span, stored per row as a list.
df['hal_lens'] = df['hallucinations'].apply(extract_brackets_len)

# Flatten the per-row lists into one list to get the overall length distribution.
hal_lens = []
for len_list in df['hal_lens']:
    hal_lens += len_list

print(pd.Series(hal_lens).value_counts())

def is_max_greater_than_2(lst):
    """Return True when ``lst`` is non-empty and its largest element is at least 2.

    NOTE(review): the name says "greater than 2" but the comparison is ``>= 2``, so a
    maximum of exactly 2 also returns True — confirm which threshold is intended.
    """
    return len(lst) > 0 and max(lst) >= 2

# Index labels of the rows that have at least one bracketed span of two or more words.
df.loc[df['hal_lens'].map(is_max_greater_than_2)].index

1     348
2      60
3      22
4      15
5       7
9       4
8       2
7       2
13      1
6       1
Name: count, dtype: int64


Index([  1,   3,  13,  14,  32,  33,  36,  41,  43,  44,  45,  50,  51,  53,
        54,  58,  59,  60,  61,  64,  65,  66,  68,  69,  70,  72,  74,  75,
        77,  78,  79,  81,  82,  84,  85,  87,  88,  89,  90,  91,  92,  93,
        94,  97,  99, 103, 104, 105, 107, 108, 110, 111, 112, 113, 114, 117,
       119, 120, 123, 124, 129, 130, 131, 132, 133, 134, 135, 136, 138, 139,
       140, 143, 145, 147],
      dtype='int64')