In [4]:
# Vocabulary, document ids, and latent topic ids iterated by the EM code.
words = ['w1', 'w2', 'w3']
docs = ['d1', 'd2']
topics = ['z1', 'z2']

# n_wd[d][w]: number of occurrences of word w in document d.
n_wd = {
    'd1': {'w1': 2, 'w2': 1, 'w3': 0},
    'd2': {'w1': 0, 'w2': 1, 'w3': 2},
}

# Initial P(z|d), indexed as P_z_d[d][z]; each document's row sums to 1.
P_z_d = {
    'd1': {'z1': 0.65, 'z2': 0.35},
    'd2': {'z1': 0.8, 'z2': 0.2},
}

# Initial P(w|z), indexed as P_w_z[z][w]; each topic's row sums to 1.
P_w_z = {
    'z1': {'w1': 0.45, 'w2': 0.3, 'w3': 0.25},
    'z2': {'w1': 0.7, 'w2': 0.2, 'w3': 0.1},
}

def calc_responsibility(P_z_d, P_w_z):
    """E-step: compute the topic posterior P(z|d,w) for every (d, w) pair.

    Iterates over the module-level ``docs``, ``words`` and ``topics`` lists.

    Args:
        P_z_d: Nested mapping read as ``P_z_d[d][z]`` (P(z|d)).
        P_w_z: Nested mapping read as ``P_w_z[z][w]`` (P(w|z)).

    Returns:
        Nested dict ``resp[d][w][z]``. Each entry is
        ``P_z_d[d][z] * P_w_z[z][w]`` divided by the sum of that product
        over all topics, or 0 when the sum is zero.
    """
    posterior = {}
    for d in docs:
        by_word = {}
        for w in words:
            # Unnormalized joint terms, kept in topic order so the
            # normalizer is accumulated in the same order as the terms.
            joint_terms = [P_z_d[d][z] * P_w_z[z][w] for z in topics]
            normalizer = sum(joint_terms)
            by_word[w] = {
                z: (term / normalizer if normalizer != 0 else 0)
                for z, term in zip(topics, joint_terms)
            }
        posterior[d] = by_word
    return posterior

def flatten_responsibility(resp):
    """Flatten ``resp[d][w][z]`` into a single list.

    Values are emitted with the module-level ``docs`` as the outermost
    order, then ``words``, then ``topics`` as the innermost.
    """
    return [resp[d][w][z] for d in docs for w in words for z in topics]

def flatten_params(P_w_z, P_z_d):
    """Flatten both parameter tables into plain lists.

    Uses the module-level ``topics``, ``words`` and ``docs`` lists for
    ordering.

    Returns:
        A pair ``(flat_pwz, flat_pzd)``: ``P_w_z[z][w]`` listed topic by
        topic (words innermost), and ``P_z_d[d][z]`` listed document by
        document (topics innermost).
    """
    word_given_topic = [P_w_z[z][w] for z in topics for w in words]
    topic_given_doc = [P_z_d[d][z] for d in docs for z in topics]
    return word_given_topic, topic_given_doc

def l1_diff_list(a, b):
    """Return the element-wise absolute differences ``|a[i] - b[i]|``.

    Elements are paired with ``zip``, so the result is only as long as the
    shorter of the two inputs.
    """
    diffs = []
    for left, right in zip(a, b):
        diffs.append(abs(left - right))
    return diffs

# Convergence threshold: the loop stops once every element-wise change in the
# responsibilities and in both parameter tables is below this value.
tolerance = 0.1
max_iter = 50

# Flattened values from the previous iteration. They stay None until the first
# iteration has finished, so the first pass never tests for convergence.
prev_resp = None
prev_pwz = None
prev_pzd = None

for iteration in range(1, max_iter+1):
    print("\n====================")
    print(f"Iteration {iteration}")
    print("====================\n")

    # E-step
    resp = calc_responsibility(P_z_d, P_w_z)
    print("E-step: Calculate P(z|d,w) for every (d,w,z):")
    for d in docs:
        for w in words:
            # Recomputed only so the normalizer can be shown in the trace.
            denom = sum(P_z_d[d][z] * P_w_z[z][w] for z in topics)
            for z in topics:
                num = P_z_d[d][z] * P_w_z[z][w]
                print(f"P({z}|{d},{w}) = P({z}|{d}) * P({w}|{z}) / sum_z' P(z'|{d}) * P({w}|z') = {P_z_d[d][z]:.4f} * {P_w_z[z][w]:.4f} / {denom:.4f} = {resp[d][w][z]:.4f}")
    print()

    # M-step: Update P(w|z)
    print("M-step: Update P(w|z):")
    new_P_w_z = {z: {} for z in topics}
    for z in topics:
        numerator_sum = 0
        numerators = {}
        for w in words:
            numerators[w] = sum(n_wd[d][w] * resp[d][w][z] for d in docs)
            numerator_sum += numerators[w]
        for w in words:
            # A topic with no expected counts gets 0 rather than dividing by zero.
            result = numerators[w] / numerator_sum if numerator_sum else 0
            new_P_w_z[z][w] = result
            print(f"P({w}|{z}) numerator = Sum_d n({w},d)*P({z}|d,{w}) = {numerators[w]:.4f}, denominator = {numerator_sum:.4f}, result = {result:.4f}")
    print()

    # M-step: Update P(z|d)
    print("M-step: Update P(z|d):")
    new_P_z_d = {d: {} for d in docs}
    for d in docs:
        numerator_sum = 0
        numerators = {}
        for z in topics:
            numerators[z] = sum(n_wd[d][w] * resp[d][w][z] for w in words)
            numerator_sum += numerators[z]
        for z in topics:
            result = numerators[z] / numerator_sum if numerator_sum else 0
            new_P_z_d[d][z] = result
            # The sum runs over words, so the label names the fixed document
            # and leaves w symbolic instead of echoing a stale loop variable.
            print(f"P({z}|{d}) numerator = Sum_w n(w,{d})*P({z}|{d},w) = {numerators[z]:.4f}, denominator = {numerator_sum:.4f}, result = {result:.4f}")
    print()

    # Adopt the new estimates before testing convergence so that P_w_z and
    # P_z_d hold the latest M-step result even when the loop breaks below.
    P_w_z = new_P_w_z
    P_z_d = new_P_z_d

    # Compute differences for convergence
    flat_resp = flatten_responsibility(resp)
    flat_pwz_new, flat_pzd_new = flatten_params(new_P_w_z, new_P_z_d)

    if prev_resp is not None:
        resp_diffs = l1_diff_list(prev_resp, flat_resp)
        pwz_diffs = l1_diff_list(prev_pwz, flat_pwz_new)
        pzd_diffs = l1_diff_list(prev_pzd, flat_pzd_new)

        print("Differences for E-step responsibilities:")
        for i, v in enumerate(resp_diffs, 1):
            print(f"  Responsibility {i}: {v:.6f}")
        print("Differences for M-step P(w|z):")
        for i, v in enumerate(pwz_diffs, 1):
            print(f"  P(w|z) {i}: {v:.6f}")
        print("Differences for M-step P(z|d):")
        for i, v in enumerate(pzd_diffs, 1):
            print(f"  P(z|d) {i}: {v:.6f}")

        if all(v < tolerance for v in resp_diffs) and all(v < tolerance for v in pwz_diffs) and all(v < tolerance for v in pzd_diffs):
            print(f"\nConverged at iteration {iteration}")
            break
        else:
            print("\nNo convergence yet\n")

    prev_resp = flat_resp
    prev_pwz = flat_pwz_new
    prev_pzd = flat_pzd_new
else:
    # Runs only when the loop finished without `break`, so a run that
    # converges exactly on the last allowed iteration is not misreported.
    print(f"Stopped after maximum iterations ({max_iter}) without full convergence.")



Iteration 1

E-step: Calculate P(z|d,w) for every (d,w,z):
P(z1|d1,w1) = P(z1|d1) * P(w1|z1) / sum_z' P(z'|d1) * P(w1|z') = 0.6500 * 0.4500 / 0.5375 = 0.5442
P(z2|d1,w1) = P(z2|d1) * P(w1|z2) / sum_z' P(z'|d1) * P(w1|z') = 0.3500 * 0.7000 / 0.5375 = 0.4558
P(z1|d1,w2) = P(z1|d1) * P(w2|z1) / sum_z' P(z'|d1) * P(w2|z') = 0.6500 * 0.3000 / 0.2650 = 0.7358
P(z2|d1,w2) = P(z2|d1) * P(w2|z2) / sum_z' P(z'|d1) * P(w2|z') = 0.3500 * 0.2000 / 0.2650 = 0.2642
P(z1|d1,w3) = P(z1|d1) * P(w3|z1) / sum_z' P(z'|d1) * P(w3|z') = 0.6500 * 0.2500 / 0.1975 = 0.8228
P(z2|d1,w3) = P(z2|d1) * P(w3|z2) / sum_z' P(z'|d1) * P(w3|z') = 0.3500 * 0.1000 / 0.1975 = 0.1772
P(z1|d2,w1) = P(z1|d2) * P(w1|z1) / sum_z' P(z'|d2) * P(w1|z') = 0.8000 * 0.4500 / 0.5000 = 0.7200
P(z2|d2,w1) = P(z2|d2) * P(w1|z2) / sum_z' P(z'|d2) * P(w1|z') = 0.2000 * 0.7000 / 0.5000 = 0.2800
P(z1|d2,w2) = P(z1|d2) * P(w2|z1) / sum_z' P(z'|d2) * P(w2|z') = 0.8000 * 0.3000 / 0.2800 = 0.8571
P(z2|d2,w2) = P(z2|d2) * P(w2|z2) / sum_z' P(z'|d