In [207]:
# imports
import ast
import pandas as pd
from pathlib import Path

import sys
sys.path.append('..')

from modeling.bertopic_models import get_texts, EMBEDDING_MODELS

from bertopic import BERTopic

In [208]:
# set global vars: input/output locations, given as relative paths
# (presumably resolved from the notebook's own folder, like sys.path.append('..'))
CORPUS_DIR = Path('..', 'corpora', 'UN General Debate Corpus', 'TXT')
MODELS_DIR = Path('..', 'models')
PATH_METADATA = Path('..', 'metadata', 'enhanced_metadata.csv')

In [209]:
# load every saved BERTopic model into a dict keyed by embedding-model name;
# names with no saved directory under MODELS_DIR are skipped with a message
TOPIC_MODELS = {}

for name, embedding_model in EMBEDDING_MODELS.items():
    saved_model_path = Path(MODELS_DIR) / name
    if not saved_model_path.exists():
        print(f"Skipping model '{name}' because it does not exist")
        continue
    print(f"Loading model {name}")
    # Hand BERTopic the configured embedding model from EMBEDDING_MODELS
    # (previously the loop variable was ignored and only the key string was passed).
    TOPIC_MODELS[name] = BERTopic.load(saved_model_path, embedding_model=embedding_model)

TOPIC_MODELS.keys()



Loading model all-MiniLM-L6-v2




Loading model all-mpnet-base-v2




Loading model all-MiniLM-L12-v2


dict_keys(['all-MiniLM-L6-v2', 'all-mpnet-base-v2', 'all-MiniLM-L12-v2'])

In [210]:
# set more global vars: which loaded topic model to inspect, and the
# top_n_topics limit used by the per-class visualizations below
MODEL_NAME, TOP_N = 'all-MiniLM-L12-v2', 20

In [211]:
# pick specific model; the loading cell may have skipped it if its saved
# directory was missing, so fail with a readable message in that case
if MODEL_NAME not in TOPIC_MODELS:
    raise KeyError(f"Model {MODEL_NAME!r} was not loaded; available: {list(TOPIC_MODELS)}")
model = TOPIC_MODELS[MODEL_NAME]

In [212]:
# load the corpus; keep the full entries in all_texts (their second element
# carries the country/session/year ids) and pull out just the first element
# of each entry as the document text
all_texts = get_texts(CORPUS_DIR)
texts = []
for entry in all_texts:
    texts.append(entry[0])

In [213]:
# per-speech metadata table (one row per ISO Code / Session / Year)
metadata = pd.read_csv(filepath_or_buffer=PATH_METADATA)

In [214]:
# list the available metadata columns (candidate class_name values)
metadata.keys()

Index(['Year', 'Session', 'ISO Code', 'Country', 'Name of Person Speaking',
       'Post', 'Population', 'TFR', 'HDI', 'GDP', 'Unemployment Rate', 'Gini',
       'CO2', 'Democracy Index', 'Region Name', 'Sub-region Name',
       'text_path'],
      dtype='object')

In [215]:
def get_topics_per_class(class_name, topic_model=None, texts_with_ids=None, meta=None):
    """Compute topic frequencies per metadata class with a fitted BERTopic model.

    Each corpus entry is matched to its metadata row by ISO code, session and
    year, and labelled with that row's ``class_name`` value. Entries without a
    matching row are labelled ``'none'``. When several rows match, the first
    one is used.

    Args:
        class_name: Metadata column to group by, e.g. ``'Region Name'``.
        topic_model: Object exposing ``topics_per_class(docs, classes=...)``;
            defaults to the notebook-level ``model``.
        texts_with_ids: Sequence of ``(text, ids)`` pairs where ``ids`` has
            ``'country'``, ``'session'`` and ``'year'`` keys; defaults to the
            notebook-level ``all_texts``.
        meta: Metadata DataFrame with ``'ISO Code'``, ``'Session'``, ``'Year'``
            and ``class_name`` columns; defaults to the notebook-level ``metadata``.

    Returns:
        Whatever ``topic_model.topics_per_class`` returns (a DataFrame for BERTopic).

    Raises:
        KeyError: If ``class_name`` (or a lookup column) is missing from ``meta``.
        ValueError: If a session or year id cannot be converted to ``int``.
    """
    # Fall back to notebook globals so existing `get_topics_per_class(name)` calls keep working.
    if topic_model is None:
        topic_model = model
    if texts_with_ids is None:
        texts_with_ids = all_texts
    if meta is None:
        meta = metadata

    # Fail loudly on a mistyped column instead of silently labelling every text 'none'.
    if class_name not in meta.columns:
        raise KeyError(f"Column {class_name!r} not found in metadata")

    # Build the (ISO code, session, year) -> class lookup once instead of
    # filtering the whole frame per text; keep='first' mirrors `.values[0]`.
    key_cols = ['ISO Code', 'Session', 'Year']
    unique_rows = meta.drop_duplicates(subset=key_cols, keep='first')
    lookup = dict(zip(
        zip(unique_rows['ISO Code'], unique_rows['Session'], unique_rows['Year']),
        unique_rows[class_name],
    ))

    # Build docs and labels from the same sequence so they are always aligned.
    docs = []
    txt_to_class = []
    for text, ids in texts_with_ids:
        key = (ids['country'], int(ids['session']), int(ids['year']))
        docs.append(text)
        txt_to_class.append(lookup.get(key, 'none'))

    return topic_model.topics_per_class(docs, classes=txt_to_class)
    

In [216]:
# topic frequencies grouped by UN region, plotted with top_n_topics=TOP_N
# NOTE: inline rendering of the returned figure needs nbformat>=4.2.0 in the kernel env
topics_per_class = get_topics_per_class(class_name='Region Name')
model.visualize_topics_per_class(topics_per_class=topics_per_class, top_n_topics=TOP_N)

6it [00:17,  2.86s/it]


ValueError: Mime type rendering requires nbformat>=4.2.0 but it is not installed

In [217]:
# topic frequencies grouped by UN sub-region, plotted with top_n_topics=TOP_N
# NOTE: inline rendering of the returned figure needs nbformat>=4.2.0 in the kernel env
topics_per_class = get_topics_per_class(class_name='Sub-region Name')
model.visualize_topics_per_class(topics_per_class=topics_per_class, top_n_topics=TOP_N)

18it [00:17,  1.02it/s]


ValueError: Mime type rendering requires nbformat>=4.2.0 but it is not installed