In [1]:
import numpy as np
import pandas as pd
from pprint import pprint

In [2]:
# Function defs
def tokenize(corpus: "list[str] | str") -> list:
    """Split every sentence of ``corpus`` into whitespace-separated tokens.

    Args:
        corpus: An iterable of sentence strings. A single bare string is
            accepted as well and is treated as a one-sentence corpus.

    Returns:
        A list with one entry per sentence; each entry is the list of tokens
        produced by ``str.split()`` for that sentence.
    """
    if isinstance(corpus, str):
        # Iterating a str yields single characters, which would turn every
        # character into its own "sentence"; wrap it so it is split as a whole.
        corpus = [corpus]
    return [sentence.split() for sentence in corpus]

In [3]:
def generate_center_context_pair(tokens, window: int) -> dict:
    """Collect, for every word, the words that appear near it.

    Args:
        tokens: Iterable of tokenized sentences (each a sequence of words).
        window: Number of positions to look at on each side of the center
            word.

    Returns:
        A dict mapping each word to the list of its context words, in
        order of appearance. Contexts never cross sentence boundaries and
        repeated occurrences accumulate into the same list. A word with no
        neighbours inside the window still gets an (empty) entry.
    """
    contexts_by_center = {}
    for sentence in tokens:
        n_words = len(sentence)
        for pos, word in enumerate(sentence):
            # Clamp the window to the sentence so no index check is needed.
            first = max(0, pos - window)
            stop = min(n_words, pos + window + 1)
            neighbours = [sentence[j] for j in range(first, stop) if j != pos]
            contexts_by_center.setdefault(word, []).extend(neighbours)
    return contexts_by_center

In [4]:
def generate_jdt(cc_pair: dict) -> list:
    """Flatten a center -> contexts mapping into ``[center, context]`` rows.

    Args:
        cc_pair: Dict mapping each center word to a list of context words.

    Returns:
        A list of two-element lists, one per (center, context) occurrence,
        in the mapping's iteration order. Centers with no contexts
        contribute no rows.
    """
    return [
        [center_word, context_word]
        for center_word, context_words in cc_pair.items()
        for context_word in context_words
    ]

In [5]:
def all_p_of_context_given_center(joint_distrib_table: pd.DataFrame):
    """Return the counts needed for P(context | center) for every pair.

    Args:
        joint_distrib_table: DataFrame with a ``'center'`` and a
            ``'context'`` column, one row per observed pair.

    Returns:
        A dict keyed by ``(center, context)`` whose values are two-element
        lists ``[pair_count, center_total]``: how often the pair occurs and
        how many rows share that center. The probability itself is
        ``pair_count / center_total``; it is left undivided here.
    """
    # Numerator: occurrences of each (center, context) pair.
    pair_counts = joint_distrib_table.groupby(['center', 'context']).size().to_dict()

    # Denominator for the probability: rows per center.
    center_totals = joint_distrib_table.groupby('center').size().to_dict()

    # Look the total up by value. Matching keys with `is` only works when
    # both groupby results happen to hand back the very same objects, and
    # otherwise silently leaves the totals out.
    return {
        (center, context): [count, center_totals[center]]
        for (center, context), count in pair_counts.items()
    }

In [6]:
# Toy corpus used by the demo: one whitespace-separated sentence per entry.
corpus = [
    "he is a king",
    "she is a queen",
    "he is a man",
    "she is a woman",
    "warsaw is poland capital",
    "berlin is germany capital",
    "paris is france capital",
    # Disabled extra sentence:
    # "Sxi este juna kaj bela",
]

In [7]:
def main(window: int = 2):
    """Run the demo on the module-level ``corpus`` and print each stage.

    Prints the corpus, builds the (center, context) table from it, prints
    that table, then pretty-prints the per-pair counts returned by
    ``all_p_of_context_given_center``.

    Args:
        window: Context window size passed to
            ``generate_center_context_pair``. Defaults to 2.
    """
    pprint(corpus)

    tokens = tokenize(corpus)
    cc_pair = generate_center_context_pair(tokens, window)

    # Kept as a module-level name so the table remains reachable after
    # main() returns.
    global jdt
    # Build the frame straight from the [center, context] rows. Naming the
    # columns up front also keeps an empty corpus working; slicing columns
    # out of an intermediate array fails when there are no rows.
    jdt = pd.DataFrame(generate_jdt(cc_pair), columns=['center', 'context'])
    print("Joint Distribution Table")
    print(jdt)

    cc_pair_counts = all_p_of_context_given_center(jdt)
    pprint(cc_pair_counts)

# Entry point: run the demo when this code is executed as the main module.
if __name__ == "__main__":
    main()


['he is a king',
 'she is a queen',
 'he is a man',
 'she is a woman',
 'warsaw is poland capital',
 'berlin is germany capital',
 'paris is france capital']
Joint Distribution Table
     center  context
0        he       is
1        he        a
2        he       is
3        he        a
4        is       he
5        is        a
6        is     king
7        is      she
8        is        a
9        is    queen
10       is       he
11       is        a
12       is      man
13       is      she
14       is        a
15       is    woman
16       is   warsaw
17       is   poland
18       is  capital
19       is   berlin
20       is  germany
21       is  capital
22       is    paris
23       is   france
24       is  capital
25        a       he
26        a       is
27        a     king
28        a      she
29        a       is
..      ...      ...
40      she        a
41      she       is
42      she        a
43    queen       is
44    queen        a
45      man       is
46      man        