In [1]:
from pathlib import Path

import hdbscan
import numpy as np
import pandas as pd
import umap
from sklearn.feature_extraction.text import CountVectorizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the sample dataset (raw columns include "questions", "answers",
# "image_classes" and "hidden_states"; see the preview further down).
# NOTE(review): pd.read_pickle can execute arbitrary code while loading — only
# use this with a pickle file you trust.
data = pd.read_pickle("sample_data_no_images.pkl")

In [3]:
# Strip a fixed-width wrapper from every question: drop the first 11 and the
# last 10 characters of each string.
# NOTE(review): the wrapper text itself is not visible here — confirm these
# offsets against how the "questions" column was produced.
data["questions"] = data["questions"].map(lambda text: text[11:-10])

# Singular column names are used by every later cell.
data = data.rename(columns={"questions": "question", "answers": "answer"})

In [4]:
# Preview the cleaned frame; pandas elides the middle rows of long frames.
data

Unnamed: 0,question,answer,image_classes,hidden_states
0,what ethnicity of cuisine does this restaurant...,A Chinese Cuisine restaurant.,"[Traffic sign, Stop sign]","[17.110271453857422, 13.300082206726074, 13.05..."
1,what does the license plate read?,A vehicle with a license plate that reads VE 8...,"[Person, Human body, Human leg, Human hair, Ma...","[13.568666458129883, 12.785804748535156, 13.19..."
2,what numbers are on the tail of the helicopter?,AAU.,"[Toy, Vehicle, Helicopter, Aircraft]","[12.305842399597168, 13.563484191894531, 14.84..."
3,what is written on the button with a blue outl...,A Kern EMB 220-1 is on the table.,"[Vehicle registration plate, Saucer, Plate, Pl...","[12.50804328918457, 15.923273086547852, 13.584..."
4,what is the 2nd letter of the word on the cont...,'A',"[Drum, Drink]","[13.862640380859375, 12.851885795593262, 14.51..."
...,...,...,...,...
315,what three letters are above the protein plus?,Arealkalita.,"[Drink, Food]","[12.743967056274414, 19.589651107788086, 14.17..."
316,what is the product shown?,A fruit shoot fruit shoot.,[Poster],"[13.869406700134277, 21.49329948425293, 13.056..."
317,what are the first two words on that piece of ...,'please note',"[Furniture, Cabinetry, Bed, Chest of drawers, ...","[11.29504108428955, 19.662620544433594, 13.571..."
318,what flavor is on the cup?,ABYSS,"[Beer, Dessert, Drink, Dairy, Mug, Coffee cup,...","[14.276371955871582, 15.000089645385742, 12.89..."


# The Topic Model

### Reduce the Dimensionality of the Embeddings With UMAP

In [5]:
# Both projections start from the same stacked hidden-state matrix, so build it once.
hidden_states = np.array(data["hidden_states"].tolist())


def reduce_embeddings(vectors, n_components, seed=42):
    """Project ``vectors`` down to ``n_components`` dimensions with UMAP.

    Every projection in this notebook shares the same neighbourhood settings
    (8 neighbours, ``min_dist=0.0``, cosine metric), so they differ only in the
    output dimension.

    Args:
        vectors: 2-D array with one row per sample.
        n_components: Target dimensionality.
        seed: ``random_state`` for UMAP. Fixing it makes runs reproducible, but
            UMAP then warns that it overrides ``n_jobs`` to 1.

    Returns:
        Array with one row per input row and ``n_components`` columns.
    """
    return umap.UMAP(
        n_neighbors=8,
        n_components=n_components,
        min_dist=0.0,
        metric="cosine",
        random_state=seed,
    ).fit_transform(vectors)


# 10d for plotting (scatter plot and parallel coords)
embeddings_10d = reduce_embeddings(hidden_states, n_components=10)

# 72d for clustering (better results)
embeddings_72d = reduce_embeddings(hidden_states, n_components=72)

embeddings_10d

  warn(f"n_jobs value {self.n_jobs} overridden to 1 by setting random_state. Use no seed for parallelism.")
  warn(f"n_jobs value {self.n_jobs} overridden to 1 by setting random_state. Use no seed for parallelism.")


array([[ 4.2421017,  8.802336 ,  7.2136083, ...,  5.379241 ,  2.1277182,
         5.257646 ],
       [ 7.3908844,  8.068256 ,  6.378817 , ...,  5.313352 ,  2.6506248,
         4.7870746],
       [ 8.004892 ,  6.6811657,  5.6447587, ...,  5.188428 ,  3.1786869,
         4.902034 ],
       ...,
       [ 8.011806 ,  7.072491 ,  6.2453227, ...,  4.952343 ,  3.2267723,
         4.3398314],
       [ 7.9592156,  6.408976 ,  5.5930977, ...,  5.1610017,  3.303684 ,
         4.8432846],
       [ 3.4523036, 10.258271 ,  8.231936 , ...,  5.407035 ,  1.3668879,
         5.4077477]], dtype=float32)

### Cluster With HDBSCAN

In [6]:
# Cluster the 72d UMAP projection. Clusters need at least 8 members;
# prediction_data=True asks hdbscan to cache extra data for labelling new
# points later (per the hdbscan docs).
clusterer = hdbscan.HDBSCAN(min_cluster_size=8, prediction_data=True)
clusters = clusterer.fit(embeddings_72d)

clusters

In [7]:
# Distinct cluster labels found; -1 is the label for points HDBSCAN left
# unclustered (treated as outliers below).
np.unique(clusters.labels_)

array([-1,  0,  1,  2,  3,  4,  5,  6,  7])

### Augment the Dataset
Append each row's 10-dimensional UMAP coordinates as the columns `feature #1` … `feature #10`, and add the HDBSCAN cluster label assigned to that row.

In [8]:
# Attach every UMAP coordinate as its own column ("feature #1", "feature #2", ...).
# The column count follows the embedding's width instead of a hard-coded 10, so
# it stays correct if n_components of the plotting projection changes.
for dim in range(embeddings_10d.shape[1]):
    data[f"feature #{dim + 1}"] = embeddings_10d[:, dim]

# HDBSCAN label per row; -1 is the outlier label.
data["cluster"] = clusters.labels_

data

Unnamed: 0,question,answer,image_classes,hidden_states,feature #1,feature #2,feature #3,feature #4,feature #5,feature #6,feature #7,feature #8,feature #9,feature #10,cluster
0,what ethnicity of cuisine does this restaurant...,A Chinese Cuisine restaurant.,"[Traffic sign, Stop sign]","[17.110271453857422, 13.300082206726074, 13.05...",4.242102,8.802336,7.213608,4.882900,4.183401,4.297847,4.980998,5.379241,2.127718,5.257646,6
1,what does the license plate read?,A vehicle with a license plate that reads VE 8...,"[Person, Human body, Human leg, Human hair, Ma...","[13.568666458129883, 12.785804748535156, 13.19...",7.390884,8.068256,6.378817,6.395576,2.956427,6.172572,4.550172,5.313352,2.650625,4.787075,7
2,what numbers are on the tail of the helicopter?,AAU.,"[Toy, Vehicle, Helicopter, Aircraft]","[12.305842399597168, 13.563484191894531, 14.84...",8.004892,6.681166,5.644759,6.707337,2.781179,5.975961,4.813974,5.188428,3.178687,4.902034,1
3,what is written on the button with a blue outl...,A Kern EMB 220-1 is on the table.,"[Vehicle registration plate, Saucer, Plate, Pl...","[12.50804328918457, 15.923273086547852, 13.584...",5.752650,7.207048,6.732515,5.473076,3.470276,4.768158,5.173018,5.138033,2.906142,4.820735,3
4,what is the 2nd letter of the word on the cont...,'A',"[Drum, Drink]","[13.862640380859375, 12.851885795593262, 14.51...",7.783505,8.071488,6.367298,6.286996,2.856590,6.182775,4.447919,5.273427,2.748135,4.394160,7
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
315,what three letters are above the protein plus?,Arealkalita.,"[Drink, Food]","[12.743967056274414, 19.589651107788086, 14.17...",3.272186,9.924520,8.195949,4.719154,4.919437,4.045883,5.143302,5.552808,1.188666,5.581168,5
316,what is the product shown?,A fruit shoot fruit shoot.,[Poster],"[13.869406700134277, 21.49329948425293, 13.056...",3.385476,8.945896,7.260214,4.776760,4.452409,4.073992,4.955898,5.439129,2.179989,5.493263,4
317,what are the first two words on that piece of ...,'please note',"[Furniture, Cabinetry, Bed, Chest of drawers, ...","[11.29504108428955, 19.662620544433594, 13.571...",8.011806,7.072491,6.245323,6.142936,2.739525,5.709261,4.764238,4.952343,3.226772,4.339831,1
318,what flavor is on the cup?,ABYSS,"[Beer, Dessert, Drink, Dairy, Mug, Coffee cup,...","[14.276371955871582, 15.000089645385742, 12.89...",7.959216,6.408976,5.593098,6.595272,2.791255,5.790640,4.895515,5.161002,3.303684,4.843285,1


## Find Frequent Words in Clusters


In [9]:
# Number of highest-scoring words reported per cluster.
n_top_words = 10

def generate_c_tf_idf(docs_per_topic, size_of_data, col):
    """Compute class-based TF-IDF (c-TF-IDF) scores for one text column.

    Each row of ``docs_per_topic`` is treated as a single document that holds
    all the text of one cluster. A word's score in a cluster is its frequency
    within that cluster, multiplied by
    ``log(size_of_data / total count of the word across all clusters)``.

    Args:
        docs_per_topic: Frame with one row per cluster; ``col`` holds that
            cluster's concatenated text.
        size_of_data: Number of original (ungrouped) documents, used in the
            idf term.
        col: Name of the text column to vectorize.

    Returns:
        ``(tf_idf, count)``: ``tf_idf`` has one row per vocabulary word and one
        column per row of ``docs_per_topic``; ``count`` is the fitted
        ``CountVectorizer`` whose vocabulary indexes those rows.
    """
    count = CountVectorizer(ngram_range=(1, 1), stop_words="english")
    # fit_transform is equivalent to fit(docs).transform(docs), in one pass.
    counts = count.fit_transform(docs_per_topic[col].values).toarray()

    words_per_cluster = counts.sum(axis=1)
    # A cluster whose text is entirely stop words has no counted words. Give it
    # zero term frequency instead of dividing by zero, which produced NaN.
    tf = np.divide(
        counts.T,
        words_per_cluster,
        out=np.zeros(counts.T.shape, dtype=float),
        where=words_per_cluster != 0,
    )

    # Every vocabulary word occurs at least once (the vocabulary was fitted on
    # these same docs), so this denominator is never zero.
    occurrences = counts.sum(axis=0)
    idf = np.log(np.divide(size_of_data, occurrences)).reshape(-1, 1)

    return np.multiply(tf, idf), count


def extract_top_n_words_per_topic(grouped, c_tf_idf, count: "CountVectorizer", n=20):
    """Return the ``n`` highest-scoring words for every cluster.

    Args:
        grouped: Frame with one row per cluster and a ``cluster`` column, in the
            same row order that was passed to ``generate_c_tf_idf``.
        c_tf_idf: Score matrix with one row per vocabulary word and one column
            per row of ``grouped`` (as returned by ``generate_c_tf_idf``).
        count: The fitted vectorizer whose vocabulary indexes the rows of
            ``c_tf_idf``.
        n: Number of words to keep per cluster; ``n <= 0`` yields empty lists.

    Returns:
        Dict mapping cluster label -> list of ``(word, score)`` pairs, best first.

    Raises:
        ValueError: if ``grouped`` and ``c_tf_idf`` disagree on the number of
            clusters.
    """
    words = count.get_feature_names_out()
    # Row i of the transposed matrix belongs to row i of ``grouped``. Take the
    # labels in that row order; sorting them separately mislabels the rows
    # whenever ``grouped`` is not already sorted by cluster.
    labels = grouped["cluster"].tolist()
    scores_per_topic = c_tf_idf.T
    if len(labels) != scores_per_topic.shape[0]:
        raise ValueError(
            f"grouped has {len(labels)} clusters but c_tf_idf has "
            f"{scores_per_topic.shape[0]} columns"
        )

    # Sort descending, then keep the first n. Slicing ``[:, -n:]`` would return
    # every word when n == 0.
    ranked = np.argsort(scores_per_topic, axis=1)[:, ::-1][:, :max(n, 0)]
    return {
        label: [(words[j], scores_per_topic[i, j]) for j in ranked[i]]
        for i, label in enumerate(labels)
    }

def print_to_words(top_words):
    """Print the top words of each topic as a side-by-side text table.

    Topics are listed in ascending id order, except the outlier topic ``-1``,
    which is moved to the end. Topics are shown two at a time. Each pair gets a
    dashed header box with the topic titles, followed by one numbered row per
    rank showing ``word  score |`` for each topic in the pair.

    Args:
        top_words: Mapping of topic id -> list of ``(word, score)`` pairs, best
            first. Scores must support ``.round(5)`` (e.g. numpy floats).
    """
    topic_order = sorted(top_words)
    if -1 in top_words:
        # Cluster ids start at -1, so it sorts first; show the outliers last.
        topic_order = topic_order[1:] + [-1]

    for first in range(0, len(topic_order), 2):
        pair = topic_order[first:first + 2]
        rule = "-" * (33 * len(pair) + 2)
        last_pos = len(pair) - 1

        print(rule)
        for pos, topic_id in enumerate(pair):
            title = "Outliers (-1)" if topic_id == -1 else f"Topic #{topic_id}"
            print("|  " if pos == 0 else "", title.ljust(29), "| ",
                  end="\n" if pos == last_pos else "")
        print(rule)

        # zip stops at the shortest word list in the pair.
        word_lists = [top_words[t] for t in pair]
        for rank, entries in enumerate(zip(*word_lists), start=1):
            for pos, (word, score) in enumerate(entries):
                print(str(rank).ljust(3) if pos == 0 else "",
                      word.ljust(20), str(score.round(5)).ljust(8), "|", end=" ")
            print()
        print()
        

### Class-Based Term-Frequency Inverse-Document-Frequency (in answers)
All answers in a cluster are concatenated into one document. c-TF-IDF then scores how frequent and how distinctive each word is in that cluster relative to the other clusters.

In [10]:
# One document per cluster: all of the cluster's answers joined by spaces.
grouped_answers = (
    data
    .groupby(["cluster"], as_index=False)
    .agg({"answer": " ".join})
)
c_tf_idf_ans, count_ans = generate_c_tf_idf(
    docs_per_topic=grouped_answers,
    size_of_data=len(data),
    col="answer",
)

grouped_answers

Unnamed: 0,cluster,answer
0,-1,A person is holding a bottle of irit A book ti...
1,0,A car with a license plate that says 10033. A3...
2,1,AAU. A. A picture of a man with a shirt that s...
3,2,A person is holding a phone that says Run Reco...
4,3,A Kern EMB 220-1 is on the table. A gallon of ...
5,4,A4 poster advertises Oktoberfest. A sign adver...
6,5,A can of Corona Extra sits on a shelf. A bottl...
7,6,A Chinese Cuisine restaurant. A phone booth is...
8,7,A vehicle with a license plate that reads VE 8...


#### Extract top n Words From Each Topic (in answers)

In [11]:
top_words_in_answers = extract_top_n_words_per_topic(
    grouped=grouped_answers,
    c_tf_idf=c_tf_idf_ans,
    count=count_ans,
    n=n_top_words,
)

print_to_words(top_words_in_answers)

--------------------------------------------------------------------
|   Topic #0                      |  Topic #1                      | 
--------------------------------------------------------------------
1   boat                 0.37808  |  box                  0.55088  | 
2   written              0.18455  |  says                 0.07573  | 
3   plane                0.15935  |  book                 0.07565  | 
4   helicopter           0.15935  |  10                   0.07363  | 
5   white                0.12256  |  00                   0.07221  | 
6   united               0.10488  |  holding              0.06372  | 
7   24                   0.10488  |  person               0.05425  | 
8   knife                0.10488  |  ruler                0.05232  | 
9   rambo                0.10488  |  clock                0.04814  | 
10  ray                  0.10488  |  cup                  0.04814  | 

--------------------------------------------------------------------
|   Topic #2          

### Class-Based Term-Frequency Inverse-Document-Frequency (in questions)
All questions in a cluster are concatenated into one document. c-TF-IDF then scores how frequent and how distinctive each word is in that cluster relative to the other clusters.

In [12]:
# One document per cluster: all of the cluster's questions joined by spaces.
grouped_questions = (
    data
    .groupby(["cluster"], as_index=False)
    .agg({"question": " ".join})
)
c_tf_idf_ques, count_ques = generate_c_tf_idf(
    docs_per_topic=grouped_questions,
    size_of_data=len(data),
    col="question",
)

grouped_questions

Unnamed: 0,cluster,question
0,-1,what brand name is written in black on counter...
1,0,what is the box car's serial number? what is t...
2,1,what numbers are on the tail of the helicopter...
3,2,what is the total distance in kilometers? what...
4,3,what is written on the button with a blue outl...
5,4,who presents this? where does is say to visit?...
6,5,what is lite? what s in the white carton? what...
7,6,what ethnicity of cuisine does this restaurant...
8,7,what does the license plate read? what is the ...


#### Extract top n Words From Each Topic (in questions)

In [13]:
top_words_in_questions = extract_top_n_words_per_topic(
    grouped=grouped_questions,
    c_tf_idf=c_tf_idf_ques,
    count=count_ques,
    n=n_top_words,
)
print_to_words(top_words_in_questions)

--------------------------------------------------------------------
|   Topic #0                      |  Topic #1                      | 
--------------------------------------------------------------------
1   boat                 0.48689  |  does                 0.13569  | 
2   plane                0.34657  |  letter               0.13293  | 
3   number               0.32421  |  cup                  0.11843  | 
4   airplane             0.28195  |  kind                 0.11277  | 
5   helicopter           0.25943  |  numbers              0.1124   | 
6   tail                 0.24345  |  bottle               0.1124   | 
7   letters              0.20494  |  say                  0.10795  | 
8   island               0.16023  |  ruler                0.08414  | 
9   serial               0.16023  |  brand                0.07975  | 
10  service              0.16023  |  plane                0.07493  | 

--------------------------------------------------------------------
|   Topic #2          

## Export the Augmented Data and Topic Words

In [14]:
# Persist the augmented frame (original columns + UMAP features + cluster label).
# to_csv does not create missing directories, so make sure the output folder exists.
# NOTE(review): list-valued columns (e.g. hidden_states, image_classes) are
# written as their string repr in CSV.
Path("./out_data").mkdir(parents=True, exist_ok=True)
data.to_csv("./out_data/embeds.csv", index=False)

In [15]:
def convert_topics_to_df(topics):
    """Flatten a ``topic -> [(word, score), ...]`` mapping into a long-format frame.

    Args:
        topics: Mapping of topic id to its ``(word, score)`` pairs.

    Returns:
        DataFrame with one row per (topic, word) pair and the columns
        ``clusters``, ``word`` and ``score``, in the mapping's iteration order.
    """
    rows = [
        (topic_id, word, value)
        for topic_id, pairs in topics.items()
        for word, value in pairs
    ]
    return pd.DataFrame(rows, columns=["clusters", "word", "score"])

# Write the top words per topic, for answers and for questions, to separate CSVs.
for topics, csv_path in (
    (top_words_in_answers, "./out_data/answer_topics.csv"),
    (top_words_in_questions, "./out_data/question_topics.csv"),
):
    convert_topics_to_df(topics).to_csv(csv_path, index=False)