In [1]:
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import numpy as np

from scipy.sparse import coo_matrix, lil_matrix

from sqlalchemy.dialects.postgresql import array_agg
from sqlalchemy.sql import functions as func

from osp_graphs.v1_db import session, Text, Field, Subfield, SubfieldDocument, Citation, Document

In [3]:
# Select the N_TEXTS most-cited texts that are both valid and displayable.
# For each one, fetch (id, title, authors, document_ids). `document_ids` is the
# array_agg of Citation.document_id over that text's citation rows. It is not
# de-duplicated: if a document has several Citation rows for a text, its id
# appears several times.
N_TEXTS = 1000

count = func.count(Citation.text_id)

texts = (session
    .query(Text.id, Text.title, Text.authors,
           array_agg(Citation.document_id).label('document_ids'))
    .join(Citation)
    # SQLAlchemy overloads `==` to emit SQL; `is True` would not build a filter.
    .filter(Text.valid == True)  # noqa: E712
    .filter(Text.display == True)  # noqa: E712
    .group_by(Text.id)
    # Text.id breaks ties in citation count. Without a unique ordering,
    # LIMIT returns an arbitrary subset of texts tied at the cutoff.
    .order_by(count.desc(), Text.id)
    .limit(N_TEXTS)
    .all())

In [10]:
# Row order for the text-document matrix: position i holds the id of the
# i-th row returned by the `texts` query. The dict is the inverse lookup,
# from text id to row position.
text_ids = [text_id for text_id, *_ in texts]
text_id_to_idx = dict(zip(text_ids, range(len(text_ids))))

In [11]:
# Column order for the text-document matrix: every distinct document id that
# appears in any selected text's aggregated citation list.
# The ids are sorted so the column order does not depend on set iteration
# order, or on the (unspecified) order in which the database returned the
# aggregated arrays.
# NOTE(review): sorting assumes the document ids are mutually comparable
# (e.g. all ints, no NULLs). Confirm against the Citation schema.
doc_ids = sorted({did for t in texts for did in t[-1]})
doc_id_to_idx = {did: i for i, did in enumerate(doc_ids)}

In [13]:
# Number of selected texts, i.e. the number of rows of the text-document matrix.
len(text_ids)

1000

In [14]:
# Number of distinct citing documents, i.e. the number of columns of the matrix.
len(doc_ids)

225711

In [17]:
tdm = np.zeros((len(text_ids), len(doc_ids)))

In [18]:
for t in texts:
    for did in t[-1]:
        tidx = text_id_to_idx[t[0]]
        didx = doc_id_to_idx[did]
        tdm[tidx][didx] += 1

In [19]:
# Display the text-document matrix. Rows follow `text_ids` order and columns
# follow `doc_ids` order.
tdm

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])