In [1]:
import os
from dev.constants import data_storage

import pandas as pd

from bertopic import BERTopic
from bertopic.representation import MaximalMarginalRelevance
from sklearn.feature_extraction.text import CountVectorizer

In [2]:
def get_context(id, claims):
    """Return the single context value recorded for one statement.

    Args:
        id: Value compared against the ``statementID`` column.
        claims: DataFrame with ``statementID`` and ``context`` columns.

    Returns:
        The one distinct ``context`` value found among the matching rows.

    Raises:
        ValueError: Raised by ``.item()`` when the matching rows do not
            contain exactly one distinct context (no match, or several
            different contexts).
    """
    is_match = claims["statementID"] == id
    distinct_contexts = claims.loc[is_match, "context"].unique()
    # .item() only succeeds on a one-element array, which enforces that a
    # statement maps to exactly one context.
    return distinct_contexts.item()

# Directory that is scanned for JSON-lines claims files.
path = f"{data_storage}/climatex/claims"
# os.listdir returns entries in arbitrary order; sorting keeps the order of
# `documents` / `ids` (and therefore the topic model input) stable across
# machines and runs.
files = sorted(os.listdir(path))

# Parallel lists: documents[i] is the context text belonging to ids[i].
documents, ids = [], []
for file in files:
    claims = pd.read_json(f"{path}/{file}", orient="records", lines=True)
    if len(claims) == 0:
        continue  # an empty file contributes no documents
    report_ids = claims["statementID"].tolist()
    # One document per row; get_context also checks that each statement id
    # resolves to exactly one context within this file.
    report_documents = [get_context(statement_id, claims) for statement_id in report_ids]
    documents.extend(report_documents)
    ids.extend(report_ids)

In [3]:
# Vectorizer used for topic-word extraction: English stop words are removed.
cv = CountVectorizer(stop_words="english")
mmr = MaximalMarginalRelevance(diversity=0.2)
# n_gram_range is deliberately NOT passed to BERTopic: the library documents
# that parameter as unused when a custom vectorizer_model is supplied, so
# passing it here has no effect on the extracted terms. To get multi-word
# terms, set ngram_range=(1, 3) on the CountVectorizer instead -- note that
# this changes the topic words and may change the later topic reduction, so
# any hand-assigned topic labels would have to be redone.
tm = BERTopic(vectorizer_model=cv, representation_model=mmr)
# topics[i] is the topic id assigned to documents[i]; p holds the
# probabilities returned alongside them.
topics, p = tm.fit_transform(documents)

In [4]:
# Collect the word list of every topic whose id is greater than -1
# (topic -1 is left out) and report how many topics that leaves.
topic_info = tm.get_topic_info()
topic_words = {
    topic: tm.get_topic(topic)
    for topic in topic_info["Topic"].unique()
    if topic > -1
}
print(len(topic_words))

206


In [5]:
# Reduce the fitted model to nr_topics=50; the return value is not used, the
# updated assignments are read back from the model's attributes.
tm.reduce_topics(documents, nr_topics=50)
topics = tm.topics_
p = tm.probabilities_

# Rebuild the per-topic word lists for the reduced model, again leaving out
# topic -1, and display them as the cell result.
topic_info = tm.get_topic_info()
topic_words = dict(
    (topic, tm.get_topic(topic))
    for topic in topic_info["Topic"].unique()
    if topic > -1
)
topic_words

{0: [('adaptation', 0.03155865292886547),
  ('climate', 0.01775888646417554),
  ('resilient', 0.015377696585186845),
  ('sustainable', 0.014625887823410469),
  ('mitigation', 0.01313189529755759),
  ('governance', 0.012202722757808558),
  ('planning', 0.0114317131784781),
  ('risk', 0.011423441064552595),
  ('pathways', 0.010652431050440465),
  ('policies', 0.008911855494814304)],
 1: [('ocean', 0.018712375195927923),
  ('marine', 0.016837240881997777),
  ('coral', 0.016315365529268412),
  ('ecosystems', 0.014860319118587347),
  ('warming', 0.013838129130424285),
  ('fish', 0.011315973794166597),
  ('reefs', 0.010623771284285351),
  ('organisms', 0.009465064345214927),
  ('reef', 0.009356563838200313),
  ('phytoplankton', 0.008564224464569135)],
 2: [('europe', 0.0218200544626815),
  ('risks', 0.01962904452396857),
  ('warming', 0.017954217508676255),
  ('confidence', 0.017865168538423577),
  ('risk', 0.014311700514681629),
  ('mediterranean', 0.011926012354175338),
  ('precipitation',

In [7]:
# Hand-assigned label -> list of integer topic ids carrying that label.
# get_label() matches entries of `topics` against these lists, so the ids
# refer to whatever topic numbering `topics` currently holds.
# NOTE(review): the ids presumably come from the reduced (nr_topics=50) model
# run shown above -- re-fitting the model may renumber topics, so confirm the
# mapping after any re-run.
# The ids 0..48 each appear exactly once; -1 is not listed, so a document
# whose topic is -1 gets no label (get_label returns None for it).
hand_labels = {
    # catch-all for topics without a more specific label
    "other": [0, 2, 13, 18, 19, 22, 26, 31, 39, 41, 42, 43, 45, 46, 47, 48],
    "marine_ecosystems": [1],
    "migration": [3],
    "water": [4],
    "precipitation": [5],
    "biodiversity": [6],
    "polar_regions": [7],
    "finance": [8],
    "forests": [9],
    "disease": [10],
    "australasia": [11],
    "farming": [12, 38],
    "climate_models": [14],
    "emissions": [15, 27],
    "geoengineering": [16],
    "sea_level": [17],
    "fishing": [20],
    "tropical_islands": [21],
    "aerosols": [23, 28],
    "bioenergy": [24],
    "decision_making": [25],
    "energy": [29, 44],
    "temperature": [30],
    "technology": [32],
    "ocean_acidification": [33],
    "transport": [34],
    "peat": [35],
    "coasts": [36],
    "mountains": [37],
    "sustainable_development": [40]
}
# Number of distinct labels (not the number of topic ids).
len(hand_labels)

30

In [8]:
def get_label(ix):
    """Look up the hand-assigned label for a topic id.

    Args:
        ix: Topic id to search for in the lists stored in ``hand_labels``.

    Returns:
        The first label (in ``hand_labels`` insertion order) whose id list
        contains ``ix``, or ``None`` when no list contains it.
    """
    matching_labels = (
        label for label, topic_ids in hand_labels.items() if ix in topic_ids
    )
    return next(matching_labels, None)

# Label for every statement id, taken from the topic assigned to the document
# at the same position (ids and topics are index-aligned).
id2topic = {statement_id: get_label(topics[i]) for i, statement_id in enumerate(ids)}

# Make sure the output directory exists before writing; to_json does not
# create missing parent directories.
topics_dir = f"{data_storage}/climatex/topics"
os.makedirs(topics_dir, exist_ok=True)

for file in files:
    claims = pd.read_json(f"{path}/{file}", orient="records", lines=True)
    if len(claims) == 0:
        continue  # empty files were skipped when loading, so they have no labels
    # Attach the label to every row; a statement whose topic has no entry in
    # hand_labels carries None here.
    claims["topic"] = claims["statementID"].apply(lambda statement_id: id2topic[statement_id])
    # Output name: the part of the input file name before the first "_",
    # followed by "_topics.jsonl".
    claims.to_json(f"{topics_dir}/{file.split('_')[0]}_topics.jsonl", orient="records", lines=True)