In [3]:
import importlib
import json
import pathlib
import logging
import pandas as pd
import numpy as np
import pyLDAvis

In [4]:
# Load the model registry: a JSON object mapping short model keys
# (e.g. "topicGPT") to dotted class paths, resolved into classes below.
# JSON text is UTF-8 by specification (RFC 8259), so decode it explicitly
# rather than relying on the platform-dependent locale default of open().
with open("./static/config/modelRegistry.json", "r", encoding="utf-8") as f:
    model_classes = json.load(f)
    
def load_class_from_path(class_path: str):
    """Import and return the object named by a fully qualified dotted path.

    Args:
        class_path: A path such as ``"package.module.ClassName"``. Everything
            before the last dot is imported as a module; the final component
            is looked up on that module.

    Returns:
        The object bound to the final component (normally a class).

    Raises:
        ValueError: If ``class_path`` has no module part or no attribute part
            (no dot at all, a leading dot, or a trailing dot).
        ModuleNotFoundError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    # rpartition never raises: it returns empty pieces for malformed input,
    # so we can report a clear error instead of an unpacking failure or a
    # confusing getattr(module, "").
    module_path, sep, class_name = class_path.rpartition(".")
    if not sep or not module_path or not class_name:
        raise ValueError(
            f"Expected a dotted path like 'package.module.Class', got {class_path!r}"
        )
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# Turn every dotted path from the registry file into the class it names,
# keeping the registry's short keys (e.g. "topicGPT") and their order.
# The bare name on the last line displays the resolved mapping.
MODEL_REGISTRY = dict(
    zip(model_classes.keys(), map(load_class_from_path, model_classes.values()))
)

MODEL_REGISTRY

{'tomotopyLDA': tova.topic_models.models.traditional.tomotopy_lda_tm_model.TomotopyLDATMmodel,
 'CTM': tova.topic_models.models.traditional.ctmtm_model.CTMTMmodel,
 'topicGPT': tova.topic_models.models.llm_based.topicgpt.topicgpt_tm_model.TopicGPTTMmodel,
 'OpenTopicRAGModel': tova.topic_models.models.llm_based.topicrag.open_topic_rag_model.OpenTopicRAGModel}

In [5]:
# Run selection: which registry entry to train, plus this run's name and id.
# NOTE(review): `id` shadows the Python builtin; it is kept because the
# training cell passes it on as `id=id` to the model constructor.
model, model_name, id = "topicGPT", "test_nb", "xxx"

In [6]:
# Load the training corpus as a list of {"id": ..., "raw_text": ...} records,
# the form passed to `train_model` in the training cell.
data_file = "data_test/bills_sample_100.csv"
N_TRAIN_SAMPLES = 1000  # upper bound on the number of documents used
SAMPLE_SEED = 42        # fixed seed so the sampled subset is reproducible

bills_df = pd.read_csv(data_file)
# DataFrame.sample(n) raises when n exceeds the number of rows (replace=False
# by default), so cap n at the rows actually available.
train_df = (
    bills_df
    .sample(n=min(N_TRAIN_SAMPLES, len(bills_df)), random_state=SAMPLE_SEED)
    .rename(columns={"summary": "raw_text"})
)
train_data = train_df[["id", "raw_text"]].to_dict(orient="records")

In [None]:
# Resolve the class registered under `model`. `.get` yields None for an
# unknown key, which is turned into an explicit error here.
model_cls = MODEL_REGISTRY.get(model)
if model_cls is None:
    raise ValueError(f"Unknown model: {model}")

# Optional extra constructor arguments; empty means the model's own defaults
# apply. Example of overrides that were used here before:
#   {"num_topics": 50, "preprocess_text": False}
# (whether a given model accepts these keys is not checked in this cell).
tr_params = {}

# Fixed arguments for this test run; the output path and logger name are
# both derived from `model_name`.
init_kwargs = dict(
    model_name=model_name,
    corpus_id="c_4e3634ace8f94d8e899142ef637348c0",
    id=id,
    model_path=pathlib.Path(f"data/tests/test_{model_name}"),
    load_model=False,
    logger=logging.getLogger(f"test_logger_{model_name}"),
)
tm_model = model_cls(**init_kwargs, **tr_params)

mo = tm_model.train_model(train_data)

In [6]:
# Grab the fitted distributions from the wrapped model for inspection below.
# NOTE(review): these are private (`_`-prefixed) attributes of the inner
# model object, so they may change without notice. Their exact types are not
# established here; `betas` displays as a 2-D float32 array, and `thetas` is
# later densified with `.toarray()`, which suggests a SciPy sparse matrix.
inner_model = tm_model.tm_model
thetas = inner_model._thetas
alphas = inner_model._alphas
betas = inner_model._betas
vocab = inner_model._vocab

betas

array([[2.5195919e-04, 1.1186427e-03, 7.5819204e-04, ..., 1.3817464e-04,
        3.4841048e-04, 1.0000000e-12],
       [1.0000000e-12, 2.7100969e-04, 3.8097534e-04, ..., 1.0000000e-12,
        1.0000000e-12, 5.4258919e-05],
       [1.0000000e-12, 4.2435946e-04, 6.1949290e-04, ..., 2.0160321e-04,
        1.8074561e-04, 5.8819231e-05],
       ...,
       [1.0000000e-12, 1.0000000e-12, 4.6570780e-04, ..., 1.0000000e-12,
        3.5667632e-04, 1.0000000e-12],
       [1.0000000e-12, 2.1579172e-04, 1.9501198e-04, ..., 5.3309137e-04,
        2.9871159e-04, 1.0000000e-12],
       [1.0000000e-12, 1.0000000e-12, 1.0000000e-12, ..., 1.0000000e-12,
        1.0000000e-12, 1.0000000e-12]], dtype=float32)

In [8]:
# Sanity check: each row of `betas` is expected to be a probability
# distribution (presumably one topic's word distribution), i.e. sum to 1.
# Assert it, with a tolerance for float32 rounding, instead of relying on
# someone eyeballing the output. The per-row sums are still displayed.
beta_row_sums = np.sum(betas, axis=1)
np.testing.assert_allclose(beta_row_sums, 1.0, rtol=1e-5)
beta_row_sums

array([1.        , 0.99999994, 1.        , 1.        , 0.9999999 ,
       0.99999994, 1.        , 1.        , 1.        ], dtype=float32)

In [9]:
# Build a pyLDAvis visualisation of the fitted model.
# Documents whose topic vector is all zeros are dropped. At most
# MAX_VIS_DOCS of the remaining ones are kept, as a seeded random subset,
# so the plot is reproducible on re-run.
MAX_VIS_DOCS = 10000
VIS_SEED = 42

# Row sums are taken on the sparse matrix directly, avoiding densifying the
# whole doc-topic matrix just to find non-empty rows. np.asarray(...).ravel()
# flattens the (n, 1) result whether SciPy returns a matrix or an array.
validDocs = np.asarray(thetas.sum(axis=1)).ravel() > 0
nValidDocs = int(np.sum(validDocs))
ndocs = min(MAX_VIS_DOCS, nValidDocs)

# Sorted random subset of row positions among the valid documents.
rng = np.random.default_rng(VIS_SEED)
perm = np.sort(rng.permutation(nValidDocs)[:ndocs])

# Every sampled document is given length 1, i.e. equal weight.
doc_len = [1] * ndocs
# Approximate integer term frequencies: ndocs * (alphas @ betas), rounded.
# NOTE(review): assumes `alphas` are corpus-level topic weights — confirm.
vocabfreq = np.round(ndocs * alphas.dot(betas)).astype(int)

vis_data = pyLDAvis.prepare(
    betas,
    thetas[validDocs, :][perm, :].toarray(),
    doc_len,
    vocab,
    vocabfreq,
    lambda_step=0.05,
    sort_topics=False,
    n_jobs=-1,
)

# OpenTopicRAG model

In [8]:
# Second run: the same data pipeline, trained with OpenTopicRAGModel.
model = "OpenTopicRAGModel"
# NOTE(review): this is the same model_name as the topicGPT run, so both runs
# use model_path data/tests/test_test_nb and logger test_logger_test_nb.
# Confirm that sharing (or overwriting) that output directory is intended.
model_name = "test_nb"
id = "xxx"  # NOTE(review): shadows the builtin; passed on as `id=id` below

data_file = "data_test/bills_sample_100.csv"
N_TRAIN_SAMPLES = 1000  # upper bound on the number of documents used
SAMPLE_SEED = 42        # fixed seed so the sampled subset is reproducible

bills_df = pd.read_csv(data_file)
# DataFrame.sample(n) raises when n exceeds the number of rows (replace=False
# by default), so cap n at the rows actually available.
train_df = (
    bills_df
    .sample(n=min(N_TRAIN_SAMPLES, len(bills_df)), random_state=SAMPLE_SEED)
    .rename(columns={"summary": "raw_text"})
)
train_data = train_df[["id", "raw_text"]].to_dict(orient="records")

In [None]:
# Look up the class registered under `model`. `.get` returns None for an
# unknown key; fail loudly instead of calling None further down.
model_cls = MODEL_REGISTRY.get(model)
if model_cls is None:
    raise ValueError(f"Unknown model: {model}")

# Optional constructor overrides; empty means the model's defaults are used.
# Example of overrides that were used here before:
#   {"num_topics": 50, "preprocess_text": False}
tr_params = {}

# The output directory and logger are both derived from `model_name`.
run_dir = pathlib.Path(f"data/tests/test_{model_name}")
run_logger = logging.getLogger(f"test_logger_{model_name}")

tm_model = model_cls(
    model_name=model_name,
    corpus_id="c_4e3634ace8f94d8e899142ef637348c0",
    id=id,
    model_path=run_dir,
    load_model=False,
    logger=run_logger,
    **tr_params,
)

mo = tm_model.train_model(train_data)

Loaded config file static/config/config.yaml and section logger.
Logs will be saved in data/logs
Loaded config file static/config/config.yaml and section llm.
tova.prompter.prompter - INFO - Using OLLAMA API with host: http://kumo01.tsc.uc3m.es:11434


`torch_dtype` is deprecated! Use `dtype` instead!
OpenTopicRAG Iterations:   0%|                              | 0/2 [00:00<?, ?it/s]      

Cache key: 1ffbc4a22df67728e35fbee67138df3e
Cache miss: computing results...
