In [1]:
import numpy as np
import pandas as pd
from pprint import pprint

In [2]:

# Function defs

def tokenize(corpus: list) -> list:
    """Split every sentence of a corpus into whitespace-separated tokens.

    Args:
        corpus: An iterable of sentence strings. A single bare string is
            also accepted and is treated as a one-sentence corpus.

    Returns:
        A list with one entry per sentence; each entry is the list of
        tokens produced by ``str.split()`` for that sentence.
    """
    # Iterating a bare string yields single characters, which would turn
    # every character into its own "sentence"; wrap it in a list instead.
    if isinstance(corpus, str):
        corpus = [corpus]
    return [sentence.split() for sentence in corpus]

In [3]:
def generate_center_context_pair(tokens, window: int) -> dict:
    """Collect, for every word, the words seen within ``window`` positions of it.

    Args:
        tokens: Sequence of tokenized sentences (each a sequence of words).
        window: Number of positions to look at on each side of the center
            word, within the same sentence.

    Returns:
        A dict mapping each center word to a list of its context words.
        Occurrences are accumulated across all sentences and repeats are
        kept, so a context word can appear several times in a list.
    """
    pairs = {}
    for sentence in tokens:
        for pos, word in enumerate(sentence):
            # Register the word even if it ends up with no neighbours.
            neighbours = pairs.setdefault(word, [])
            # Clamp the window to the sentence boundaries.
            first = max(pos - window, 0)
            stop = min(pos + window + 1, len(sentence))
            neighbours.extend(
                sentence[j] for j in range(first, stop) if j != pos
            )
    return pairs

In [4]:
def generate_jdt(cc_pair: dict) -> list:
    """Flatten a center -> contexts mapping into ``[center, context]`` rows.

    Args:
        cc_pair: Dict mapping each center word to an iterable of its
            context words.

    Returns:
        A list of two-element lists, one per (center, context) occurrence,
        in the iteration order of ``cc_pair`` and of each context list.
    """
    return [
        [center, context]
        for center, contexts in cc_pair.items()
        for context in contexts
    ]

In [5]:
def all_p_of_context_given_center(joint_distrib_table: pd.DataFrame):
    """Count how often each (center, context) pair occurs in the table.

    Note that, despite the name, the values returned are raw occurrence
    counts; they are not normalised into probabilities.

    Args:
        joint_distrib_table: DataFrame with ``'center'`` and ``'context'``
            columns, one row per observed pair.

    Returns:
        A dict keyed by ``(center, context)`` tuples whose values are the
        number of rows with that pair.
    """
    return (
        joint_distrib_table
        .groupby(['center', 'context'])
        .size()
        .to_dict()
    )

In [6]:
# Toy corpus: one whitespace-separated sentence per entry.
corpus = [
    "he is a king",
    "she is a queen",
    "he is a man",
    "she is a woman",
    "warsaw is poland capital",
    "berlin is germany capital",
    "paris is france capital",
    # Disabled extra sentence, kept for reference:
    # "Sxi este juna kaj bela",
]

In [14]:
def main(window: int = 2):
    """Run the pipeline on the module-level ``corpus`` and print the results.

    Steps: tokenize the corpus, gather center/context pairs, build the
    joint distribution table as a DataFrame, print its first ten rows, then
    pretty-print the per-pair counts.

    Args:
        window: Context window size passed to
            ``generate_center_context_pair``. Defaults to 2.

    Side effects:
        Rebinds the module-level name ``jdt`` to the joint distribution
        DataFrame and writes to stdout.
    """
    tokens = tokenize(corpus)
    cc_pair = generate_center_context_pair(tokens, window)

    # `jdt` is deliberately kept as a global, as in the original cell;
    # presumably other cells read it after main() has run.
    global jdt
    # Build the frame straight from the row list. Going through a 2-D numpy
    # array and slicing columns fails with IndexError when there are no
    # pairs, because an empty list converts to a 1-D array.
    jdt = pd.DataFrame(generate_jdt(cc_pair), columns=['center', 'context'])
    print("Joint Distribution Table")
    print(jdt[:10])

    cc_pair_counts = all_p_of_context_given_center(jdt)
    pprint(cc_pair_counts)

if __name__ == "__main__":
    main()


Joint Distribution Table
  center context
0     he      is
1     he       a
2     he      is
3     he       a
4     is      he
5     is       a
6     is    king
7     is     she
8     is       a
9     is   queen
{('a', 'he'): 2,
 ('a', 'is'): 4,
 ('a', 'king'): 1,
 ('a', 'man'): 1,
 ('a', 'queen'): 1,
 ('a', 'she'): 2,
 ('a', 'woman'): 1,
 ('berlin', 'germany'): 1,
 ('berlin', 'is'): 1,
 ('capital', 'france'): 1,
 ('capital', 'germany'): 1,
 ('capital', 'is'): 3,
 ('capital', 'poland'): 1,
 ('france', 'capital'): 1,
 ('france', 'is'): 1,
 ('france', 'paris'): 1,
 ('germany', 'berlin'): 1,
 ('germany', 'capital'): 1,
 ('germany', 'is'): 1,
 ('he', 'a'): 2,
 ('he', 'is'): 2,
 ('is', 'a'): 4,
 ('is', 'berlin'): 1,
 ('is', 'capital'): 3,
 ('is', 'france'): 1,
 ('is', 'germany'): 1,
 ('is', 'he'): 2,
 ('is', 'king'): 1,
 ('is', 'man'): 1,
 ('is', 'paris'): 1,
 ('is', 'poland'): 1,
 ('is', 'queen'): 1,
 ('is', 'she'): 2,
 ('is', 'warsaw'): 1,
 ('is', 'woman'): 1,
 ('king', 'a'): 1,
 ('king',