In [2]:
import polars as pl
from nltk import word_tokenize
import numpy as np

In [13]:
# Read the training set; the preview below shows it has `id`, `text` and `author` columns.
df = pl.read_csv('train.csv')
df.head(5)

id,text,author
str,str,str
"""id26305""","""This process, ...","""EAP"""
"""id17569""","""It never once ...","""HPL"""
"""id11008""","""In his left ha...","""EAP"""
"""id27763""","""How lovely is ...","""MWS"""
"""id12958""","""Finding nothin...","""HPL"""


In [35]:
# Number of topics; the count tables built below use one column per topic, named k0 .. k{k-1}.
k = 10
# Smoothing added to the document-topic counts in calc_prob (capped at 0.1).
alpha = min(50 / k, 0.1)
# Smoothing added to the token-topic counts in calc_prob.
beta = 0.01

In [60]:
# Long-format token table: one row per token occurrence, carrying the id of the
# document it came from and that document's author. Text is lower-cased before
# tokenising, and punctuation marks are kept as tokens by word_tokenize.
W = df.with_columns([
    pl.col('text').apply(lambda x: word_tokenize(x.lower())).alias('words')
]).explode('words').select([
    pl.col('id').alias('doc_id'),
    pl.col('words').alias('token'),
    pl.col('author')
])

# Give every token a uniformly random initial topic in range(k).
# The labels must come from k (not a literal 10): the count tables built later
# are addressed by the column names k0 .. k{k-1}, so a mismatch would either
# raise here or leave topics without a column.
W = W.with_columns([
    pl.Series(name='topic', values=np.random.choice(k, size=len(W)))
])
W

doc_id,token,author,topic
str,str,str,i32
"""id26305""","""this""","""EAP""",4
"""id26305""","""process""","""EAP""",8
"""id26305""",""",""","""EAP""",4
"""id26305""","""however""","""EAP""",5
"""id26305""",""",""","""EAP""",7
"""id26305""","""afforded""","""EAP""",4
"""id26305""","""me""","""EAP""",7
"""id26305""","""no""","""EAP""",8
"""id26305""","""means""","""EAP""",1
"""id26305""","""of""","""EAP""",2


In [64]:
# Count how often each (token, topic) pair occurs, then relabel topics as
# 'k<topic>' so they can serve as column names after the pivot.
token_topic_counts = W.groupby(['token', 'topic']).count()
token_topic_counts = token_topic_counts.with_columns([
    pl.col('topic').apply(lambda t: f'k{t}')
])

# V: one row per distinct token, one count column per topic; pairs that never
# occur are filled with 0.
V = token_topic_counts.pivot(values='count', index='token', columns='topic').fill_null(0)

# Vs: single-row frame holding the column total of every topic column in V.
Vs = V.select([pl.col([f'k{i}' for i in range(k)]).sum()])

# Vlen: number of rows in V, i.e. the number of distinct tokens.
Vlen = V.height
print(Vlen)
Vs

25372


k0,k1,k2,k3,k4,k5,k6,k7,k8,k9
u32,u32,u32,u32,u32,u32,u32,u32,u32,u32
59818,59347,59464,59559,59851,59348,59005,59328,59508,59694


In [97]:
# Count how often each (document, topic) pair occurs, relabelling topics as
# 'k<topic>' so they can serve as column names after the pivot.
doc_topic_counts = W.groupby(['doc_id', 'topic']).count().with_columns([
    pl.col('topic').apply(lambda t: f'k{t}')
])

# D: one row per document, one count column per topic; missing pairs become 0.
D = doc_topic_counts.pivot(values='count', index='doc_id', columns='topic').fill_null(0)

# Per-document total over the topic columns (the 'sum' column), kept next to
# the document id.
doc_totals = D.select([
    pl.col('doc_id'),
    pl.sum(pl.col([f'k{i}' for i in range(k)]))
])

# Ds: the totals reshaped into a single row with one column per document id,
# so a total can be read as Ds[0, doc_id]. The constant 'index' column exists
# only to give the pivot a row key and is dropped afterwards.
Ds = (
    doc_totals
    .with_columns([pl.lit(0).alias('index')])
    .pivot(values='sum', index="index", columns='doc_id')
    .drop('index')
)
Ds

id09068,id06978,id04229,id13790,id13382,id11560,id22552,id18346,id12382,id05574,id05225,id19269,id10118,id00297,id13308,id15925,id16907,id04343,id06290,id12915,id21223,id17271,id05749,id07607,id05128,id17674,id20490,id02820,id12889,id14340,id17428,id22094,id08836,id02088,id03208,id27942,id10765,...,id18983,id13931,id06277,id19407,id22178,id05508,id20648,id10741,id19410,id25635,id25311,id25776,id15555,id16694,id22388,id27232,id10362,id17585,id00908,id04068,id22183,id01048,id27732,id26963,id12085,id26189,id20233,id05517,id06845,id21595,id01037,id08334,id03150,id05111,id17846,id26673,id16673
u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,...,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32
29,27,42,49,19,30,37,44,44,6,58,7,14,32,16,65,25,34,11,20,9,71,40,25,16,32,29,53,28,32,36,18,30,36,14,38,32,...,19,6,8,8,12,13,8,9,6,4,5,6,7,14,6,7,8,5,13,4,13,5,7,7,6,8,6,5,10,6,7,7,6,5,4,4,5


In [134]:
def calc_prob(row):
    """Return the normalised topic distribution for one token occurrence.

    For every candidate topic i the unnormalised weight is

        (alpha + cdk) / (k*alpha + cd) * (beta + ckv) / (Vlen*beta + ck)

    where cdk is the document's count for topic i (from D), cd the document's
    total count (from Ds), ckv the token's count for topic i (from V) and ck
    the total count of topic i (from Vs).

    Parameters
    ----------
    row : one-row slice of W; its 'doc_id' and 'token' columns are read.

    Returns
    -------
    numpy.ndarray of length k whose entries sum to 1.
    """
    doc_id = row["doc_id"][0]
    token = row["token"][0]

    # These lookups do not depend on the candidate topic, so do them once
    # instead of once per topic.
    doc_counts = D.filter(pl.col('doc_id') == doc_id)
    token_counts = V.filter(pl.col('token') == token)
    cd = Ds[0, doc_id]

    p = []
    for i in range(k):
        col = f'k{i}'
        cdk = doc_counts[col][0]
        ckv = token_counts[col][0]
        # Total for the *candidate* topic i. Reading the row's current topic
        # here would apply the same normaliser to every candidate.
        ck = Vs[0, col]
        p.append((alpha + cdk)/(k*alpha + cd) * (beta + ckv)/(Vlen*beta + ck))
    p = np.array(p)
    p = p/p.sum()
    return p




# Row position of every token / document in the count tables. Counts are then
# updated with frame[row, column], which writes into the frame itself.
# Writing through V.filter(...)[col][0] only modifies a temporary filtered copy
# and leaves V and D unchanged.
token_pos = {tok: pos for pos, tok in enumerate(V['token'].to_list())}
doc_pos = {doc: pos for pos, doc in enumerate(D['doc_id'].to_list())}

# NOTE(review): every step still goes through polars frames one cell at a time;
# if this is too slow, consider keeping the counts in numpy arrays instead.
for it in range(3):
    for i in range(len(W)):
        if i % 5000 == 0:
            print('iteration:', it, 'in, token:', i, 'of', len(W))
        row = W[i, :]
        doc_id = row["doc_id"][0]
        v_row = token_pos[row["token"][0]]
        d_row = doc_pos[doc_id]

        # Remove the token's current assignment from all four count tables so
        # calc_prob sees the counts without this token.
        old_col = f'k{int(row["topic"][0])}'
        V[v_row, old_col] -= 1
        D[d_row, old_col] -= 1
        Ds[0, doc_id] -= 1
        Vs[0, old_col] -= 1

        # Draw the new topic from the computed conditional distribution;
        # a uniform draw here would ignore p and never move the assignments
        # toward the data.
        p = calc_prob(row)
        new_topic = int(np.random.choice(np.arange(k), p=p))
        W[i, "topic"] = new_topic

        # Add the token back under its new topic.
        new_col = f'k{new_topic}'
        V[v_row, new_col] += 1
        D[d_row, new_col] += 1
        Ds[0, doc_id] += 1
        Vs[0, new_col] += 1
    print('iteration comlete', it)


iteration: 0 in, token: 0 of 594922
iteration: 0 in, token: 5000 of 594922


KeyboardInterrupt: 

In [118]:
# Scratch check that a polars frame accepts in-place scalar assignment through
# frame[row, column]. The toy frame gets its own name so the training data
# bound to `df` earlier in the notebook is not overwritten.
setitem_demo = pl.DataFrame({
    'a': range(5),
    'b': range(5)
})
setitem_demo[3, 'b'] = 99

print(setitem_demo)

# Same kind of single-cell write against Ds; 29 is the value displayed for
# this document when Ds was built.
Ds[0, 'id09068'] = 29
Ds

shape: (5, 2)
┌─────┬─────┐
│ a   ┆ b   │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 0   ┆ 0   │
│ 1   ┆ 1   │
│ 2   ┆ 2   │
│ 3   ┆ 99  │
│ 4   ┆ 4   │
└─────┴─────┘


id09068,id06978,id04229,id13790,id13382,id11560,id22552,id18346,id12382,id05574,id05225,id19269,id10118,id00297,id13308,id15925,id16907,id04343,id06290,id12915,id21223,id17271,id05749,id07607,id05128,id17674,id20490,id02820,id12889,id14340,id17428,id22094,id08836,id02088,id03208,id27942,id10765,...,id18983,id13931,id06277,id19407,id22178,id05508,id20648,id10741,id19410,id25635,id25311,id25776,id15555,id16694,id22388,id27232,id10362,id17585,id00908,id04068,id22183,id01048,id27732,id26963,id12085,id26189,id20233,id05517,id06845,id21595,id01037,id08334,id03150,id05111,id17846,id26673,id16673
u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,...,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32,u32
29,27,42,49,19,30,37,44,44,6,58,7,14,32,16,65,25,34,11,20,9,71,40,25,16,32,29,53,28,32,36,18,30,36,14,38,32,...,19,6,8,8,12,13,8,9,6,4,5,6,7,14,6,7,8,5,13,4,13,5,7,7,6,8,6,5,10,6,7,7,6,5,4,4,5
