In [1]:
import os
from dev.constants import gdrive_path

import pandas as pd

from bertopic import BERTopic
from bertopic.representation import MaximalMarginalRelevance
from sklearn.feature_extraction import text
from sklearn.feature_extraction.text import CountVectorizer

In [6]:
def get_context(id, claims):
    subset = claims.loc[claims["statementID"] == id]
    assert subset["context"].nunique() == 1
    context = subset["context"].unique().item()
    return context

path = f"{gdrive_path}/climatex_full/claims"
files = os.listdir(path)

documents, ids = [], []
for file in files:
    claims = pd.read_json(f"{path}/{file}", orient="records", lines=True)
    if len(claims) == 0: continue
    report_ids = claims["statementID"].tolist()
    report_documents = [get_context(id, claims) for id in report_ids]
    documents.extend(report_documents)
    ids.extend(report_ids)

In [8]:
cv = CountVectorizer(stop_words="english")
mmr = MaximalMarginalRelevance(diversity=0.2)
tm = BERTopic(vectorizer_model=cv, representation_model=mmr, n_gram_range=(1, 3))
topics, p = tm.fit_transform(documents)

In [9]:
topic_info = tm.get_topic_info()
topic_words = {topic: tm.get_topic(topic) for topic in topic_info.Topic.unique() if topic > -1}
print(len(topic_words))
print(topic_words)

208


In [11]:
tm.reduce_topics(documents, nr_topics=50)
topics, p = tm.topics_, tm.probabilities_

topic_info = tm.get_topic_info()
topic_words = {topic: tm.get_topic(topic) for topic in topic_info.Topic.unique() if topic > -1}

In [12]:
topic_words

{0: [('adaptation', 0.030036396538321457),
  ('development', 0.022980679395816633),
  ('climate', 0.020231112037490503),
  ('indigenous', 0.017248037240933056),
  ('governance', 0.01471697312323599),
  ('resilient', 0.013746690220535446),
  ('risks', 0.012700728575514469),
  ('decisionmaking', 0.011564294655694094),
  ('poverty', 0.009668365709042052),
  ('impacts', 0.009187383046448137)],
 1: [('coral', 0.021556517561663782),
  ('acidification', 0.0197949257485088),
  ('reefs', 0.014211322099875358),
  ('warming', 0.012933631899691482),
  ('organisms', 0.01264316471944909),
  ('aquaculture', 0.011544484159819082),
  ('reef', 0.010999185478923673),
  ('phytoplankton', 0.010870240020419137),
  ('fisheries', 0.010687153231031189),
  ('ecosystems', 0.009233799162519098)],
 2: [('precipitation', 0.03706850275940039),
  ('africa', 0.025636392456185673),
  ('monsoon', 0.024091112577177352),
  ('regions', 0.018390734983817533),
  ('warming', 0.015827399928873186),
  ('models', 0.0142181409068

In [13]:
hand_labels = {
    "other": [0, 8, 12, 18, 20, 22, 29, 31, 36, 37, 40, 41],
    "coral_reefs": [1],
    "precipitation": [2, 10, 46],
    "finance": [3, 42],
    "coasts": [4, 7, 39],
    "emissions": [5],
    "sea_level": [6],
    "forests": [9, 13, 26],
    "cryosphere": [11],
    "oceana": [14],
    "cyclones": [15],
    "energy": [16, 47],
    "farming": [17, 19],
    "wetlands": [21],
    "oceans": [23],
    "migration": [24],
    "disease": [25],
    "cdr": [27],
    "cities": [28, 44],
    "flooding": [30],
    "technology": [32],
    "tropical_islands": [33],
    "solar": [34],
    "climate_discourse": [35],
    "transport": [38],
    "dryland": [43],
    "lakes": [45],
    "hydropower": [48]
}

def get_label(ix):
    for label in hand_labels.keys():
        if ix in hand_labels[label]: return label

id2topic = {ids[i]: get_label(topics[i]) for i in range(len(ids))}

In [18]:
for file in files:
    claims = pd.read_json(f"{path}/{file}", orient="records", lines=True)
    if len(claims) == 0: continue
    claims["topic"] = claims["statementID"].apply(lambda id: id2topic[id])
    claims.to_json(f"{gdrive_path}/climatex_full/topics/{file.split('_')[0]}_topics.jsonl", orient="records", lines=True)