In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from hmm import HMM

In [2]:
anime = pd.read_csv('anime.csv')
rating = pd.read_csv('rating.csv')

In [3]:
anime = anime.sort_values(by='anime_id')
anime = anime[anime['members'] > 10000]

In [4]:
# only consider slice of life anime
anime = anime[anime['genre'].str.contains('Slice of Life') == True]

In [5]:
# aggregrate people's votes -> counts of which users voted more
counts = rating.groupby(by='user_id').agg('count')
counts = counts.rename(index=str, columns={'anime_id': 'counts'})
del counts['rating']
counts = counts.reset_index()

In [6]:
# only consider anime rated >= 7
rating = rating[rating['rating'] >= 7]

# people who rated more than n anime
top_counts = counts[(counts['counts'] >= 80)]  # ~20k users

In [7]:
# filter only ratings from people who satisfy the above criteria
rating_ = rating[rating['user_id'].isin(top_counts['user_id'])]  # filtered rating database - include top users only
rating_.head()

Unnamed: 0,user_id,anime_id,rating
47,1,8074,10
81,1,11617,10
83,1,11757,10
101,1,15451,10
156,3,20,8


In [8]:
len(rating_['anime_id'])  # 2653526

TIME_STEP = 5000  # anime_ids between time steps. min = 17; max = 34525.

# need to group by anime_id
rating_ = rating_.assign(time=rating_['anime_id']/TIME_STEP)
rating_['time'] = rating_['time'].apply(lambda x: int(x))

In [9]:
# filtered rating database - include only non-filtered anime
rating_ = rating_[rating_['anime_id'].isin(anime['anime_id'])]

# randomly sample 1000 users
user_ids = rating_['user_id'].unique()
user_ids = np.random.choice(user_ids, 200, replace=False)
animes = anime['anime_id'].unique()
anime_id_to_index = {}
for i in range(len(animes)):
    anime = animes[i]
    anime_id_to_index[anime] = i
    
# build an "observation sequence"
max_time = rating_['time'].max()
observation_seqs = np.zeros(shape=(len(user_ids), max_time), dtype=object)
for u in range(len(user_ids)):
    for t in range(max_time):  
        observation_seqs[u][t] = []
for u in tqdm(range(len(user_ids))):
    rating_per_user = rating_[rating_['user_id'] == user_ids[u]]
    for t in range(max_time):
        observation_seqs[u][t] = \
            rating_per_user[rating_per_user['time'] == t]['anime_id'].map(anime_id_to_index).tolist()
        if len(observation_seqs[u][t]) == 0:
             # fill in blanks with random anime -> TODO according to NBD
            observation_seqs[u][t] = [anime_id_to_index[np.random.choice(animes)]] 

100%|██████████| 200/200 [00:01<00:00, 116.52it/s]


In [10]:
# hold-out time t anime
time_t_anime = []
for u in range(len(observation_seqs)):
    time_t_anime.append(observation_seqs[u][-1])

In [11]:
observation_seqs = np.resize(observation_seqs, (len(user_ids), max_time-1))

In [12]:
# average anime watched per time-step
total = 0
for seq in observation_seqs:
    total += sum([len(x) for x in seq])/len(seq)
total /= len(observation_seqs)
total

4.634999999999999

In [13]:
n_states = 15
n_items = len(animes)
hmm = HMM(n_items, n_states, n_items + 5)

hmm.baum_welch(observation_seqs)

  5%|▌         | 1/20 [00:03<01:15,  3.96s/it]

132.002911746 68.6619166658
0.0367978081533 0.0131067707743
6.87767157709 1.90277790811
1.70039742385e-07 4.32709107634e-08
562.02186765 267.301742639
23.8853533974 7.27324350564
2209.49461367 2453.54575799
318.028147995 160.923692443
3.9630731154e-07 1.13815940797e-07
1.00980116937 0.308452816164
7.64881206685e-06 2.14101316584e-06
0.0147235389331 0.00416962488947
0.00787210845089 0.00245191122866
97.3787906173 36.0417944044
1038.43759267 627.317523738


 10%|█         | 2/20 [00:07<01:07,  3.76s/it]

87.629065382 49.6243967614
2.48600097859 1.4359753271
27.4857723208 8.61542539698
6.52800398038e-19 1.61898625112e-19
641.96170123 340.931026342
22.010119519 7.26682330052
1204.22225622 1714.62973337
47.7778708409 28.2873662906
2.1879619985 0.62304368694
15.1911756571 5.20952709003
2.87529478278 0.822825813373
1.27032285067e-08 3.30436653187e-09
2.75063254546 0.792563739253
7.96950574635 3.39660436588
405.557412775 385.281885317


 15%|█▌        | 3/20 [00:11<01:04,  3.79s/it]

35.6108199665 32.6543128043
36.4471403329 27.8169947319
241.244610311 87.0641530338
1.13867860675e-06 2.82359284297e-07
82.7461830747 61.8627552096
23.1882707625 8.11490115874
745.646148318 1155.12229398
48.4074539618 33.0800912739
1.60185228545e-09 4.38684485531e-10
15.3391855947 5.71065900267
1.04906415394e-08 2.88257144273e-09
2.78721127529e-09 6.85574720477e-10
3.62821417034e-09 1.00459425154e-09
13.4060435234 6.67865462797
382.915830678 447.938432215


 20%|██        | 4/20 [00:15<01:01,  3.87s/it]

37.5979633303 50.576222277
27.9562080458 48.3224882043
377.83904256 142.795990683
0.101233689959 0.0252878125754
114.459993707 170.516057424
22.7793730794 8.00069999606
512.931608312 830.749745935
42.7454869315 32.7270308536
1.17480315669 0.308927540668
16.4918783186 6.73263865009
1.23863786409 0.326513936414
0.0350949188584 0.00870189819884
1.22454529742 0.32318314394
17.5881110368 10.7395643597
456.229233919 598.411996415


 25%|██▌       | 5/20 [00:19<00:58,  3.87s/it]

33.8810888315 64.2949612983
57.7750357461 89.763028113
79.5193337455 29.575573316
1.58578078286e-05 3.927428574e-06
138.402007176 273.066469144
29.7024966 10.5943814222
427.621784494 702.614566845
51.9658039297 43.9907554897
5.5942336609 1.41814795002
22.4551247682 11.8103828585
12.2430973514 3.10898927271
9.01731258566e-09 2.22237531796e-09
6.95634805523 1.76690449641
24.3139841596 20.0245338502
361.024480279 509.502968733


Process ForkPoolWorker-4:
Traceback (most recent call last):
Process ForkPoolWorker-8:
Process ForkPoolWorker-3:
Process ForkPoolWorker-5:
Process ForkPoolWorker-6:
  File "/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Process ForkPoolWorker-1:
  File "/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/anaconda3/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "/anaconda3/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
Traceback (most recent call last):
Traceback (most recent call last):
  File "/anaco

  File "/anaconda3/lib/python3.6/site-packages/scipy/stats/_multivariate.py", line 2878, in logpmf
    x, xcond = self._process_quantiles(x, n, p)
  File "/anaconda3/lib/python3.6/site-packages/scipy/stats/_multivariate.py", line 2911, in pmf
    return np.exp(self.logpmf(x, n, p))
  File "/anaconda3/lib/python3.6/site-packages/scipy/stats/_multivariate.py", line 2832, in _process_quantiles
    if xx.size != 0 and not xx.shape[-1] == p.shape[-1]:
  File "/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
KeyboardInterrupt
KeyboardInterrupt
  File "/anaconda3/lib/python3.6/site-packages/scipy/special/basic.py", line 2165, in comb
    k, N = asarray(k), asarray(N)
  File "/anaconda3/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "/anaconda3/lib/python3.6/multiprocessing/pool.py", line 44, in mapstar
    return list(map(*args))
  File "/anaconda3/lib/python3.6/site

KeyboardInterrupt: 

In [None]:
hmm.A[2].sum()