# Anchored CorEx on the STS Benchmark dataset

### Requirements

In [1]:
import numpy as np
import scipy.sparse as ss
import pandas as pd
import pickle

from corextopic import corextopic as ct
from pathlib import Path
from sklearn.feature_extraction.text import TfidfVectorizer
from matplotlib import pyplot as plt

from sentence_similarity.data import Pipeline, PipelineConfig, STSBenchmark

# Directory the notebook reads its inputs from; stop early if it is missing.
data_dir = Path("data")
if not data_dir.exists():
    # Explicit raise instead of `assert`: assertions are stripped under
    # `python -O`, which would silently skip this check.
    raise FileNotFoundError("data_dir does not exist.")

# Directory the notebook writes its artifacts to; created if necessary.
output_dir = Path("data")
output_dir.mkdir(exist_ok=True, parents=True)

## Preprocessing & Vectorizer

In [2]:
# Preprocessing flags: stop-word removal is on; number, symbol and
# punctuation removal are off; the POS-tag filter list is empty.
pipeline_options = {
    "filtered_pos_tags": [],
    "remove_stop_words": True,
    "remove_numbers": False,
    "remove_symbols": False,
    "remove_punctuation": False,
}
config = PipelineConfig(**pipeline_options)
pipeline = Pipeline(config)

# Store the configuration in the data directory.
config.save(data_dir)

In [3]:
# Load the "train" partition of the STS Benchmark dataset from data_dir
sts_benchmark = STSBenchmark(data_dir, partition="train")

In [4]:
# Run both sentence columns (s1, then s2) through the preprocessing pipeline
s1_preprocessed, s2_preprocessed = map(
    pipeline, (sts_benchmark.s1, sts_benchmark.s2)
)

Preprocessing: 100%|██████████| 5552/5552 [00:03<00:00, 1540.68it/s]
Preprocessing: 100%|██████████| 5552/5552 [00:02<00:00, 1915.40it/s]


In [5]:
# Side-by-side view: original first sentences next to their preprocessed form
pd.concat((sts_benchmark.s1, s1_preprocessed), axis="columns")

Unnamed: 0,s1,0
0,A plane is taking off.,plane take .
1,A man is playing a large flute.,man play large flute.
2,A man is spreading shreded cheese on a pizza.,man spread shred cheese pizza.
3,Three men are playing chess.,man play chess.
4,A man is playing the cello.,man play cello.
...,...,...
5547,"Palestinian hunger striker, Israel reach deal","palestinian hunger striker, Israel reach deal"
5548,Assad says Syria will comply with UN arms reso...,Assad say Syria comply UN arm resolution
5549,South Korean President Sorry For Ferry Response,south korean President sorry Ferry Response
5550,Food price hikes raise concerns in Iran,food price hike raise concern Iran


In [6]:
# Fit the TF-IDF vectorizer on the preprocessed sentences of both columns
vectorizer = TfidfVectorizer(
    strip_accents="ascii",
    binary=True,
    ngram_range=(1, 1),
)
all_sentences = pd.concat([s1_preprocessed, s2_preprocessed])
doc_word = ss.csr_matrix(vectorizer.fit_transform(all_sentences))

# Persist the fitted vectorizer alongside the other outputs
with (output_dir / "vectorizer.bin").open("wb") as f:
    pickle.dump(vectorizer, f)

print(doc_word.shape)  # n_docs x m_words

# Words labelling the matrix columns (used for readable topics and anchoring)
words = list(np.asarray(vectorizer.get_feature_names_out()))
digit_count = sum(1 for token in words if token.isdigit())
print("# digits:", digit_count)

(11104, 9393)
# digits: 360


## CorEx topic model

In [7]:
# Train the CorEx topic model with 50 topics (n_hidden=50, max_iter=300).
# `words` holds the labels of the columns of `doc_word`.
topic_model = ct.Corex(n_hidden=50, words=words, max_iter=300)
topic_model.fit(doc_word, words=words)

# Plot the model's recorded tc_history to inspect how training progressed
# (presumably the total correlation per iteration -- confirm against the
# corextopic documentation).
plt.plot(topic_model.tc_history)
plt.show()

In [8]:
# Save the trained topic model to output_dir as "corex_model.bin"
topic_model.save(output_dir / "corex_model.bin")

In [9]:
# Print every topic of the CorEx topic model as "<index>: <word>, <word>, ..."
topics = topic_model.get_topics()
for n, topic in enumerate(topics):
    # The word is the first element of each topic entry. Reading it directly,
    # instead of unpacking zip(*topic), avoids a ValueError on a topic that
    # has no entries; such a topic is printed with an empty word list.
    topic_words = [entry[0] for entry in topic]
    print(f"{n}: " + ", ".join(topic_words))

0: man, play, woman, dog, guitar, white, black, ride, run, slice
1: stock, percent, index, share, cent, nasdaq, composite, point, trading, close
2: dow, jones, average, industrial, credit, dji, outfielder, hate, ninth, extradition
3: vessel, fill, wonder, jury, nullification, total, bell, microsoft, monteith, cory
4: san, suu, kyi, illness, proud, liberty, aung, richard, schedule, plutonium
5: usd, yuan, magnitude, quake, strengthen, usgs, weaken, hurricane, cenc, jolt
6: ratify, amend, aviv, tel, asylum, universe, explanation, existence, withdraw, moscow
7: criticism, king, whitey, ernst, ii, albert, wrongdoing, einstein, bulgergirlfriend, hunt
8: decker, double, bus, drive, hood, wheel, ferris, druce, fucking, volt
9: requirement, corn, guess, hi, atlantic, doping, absolutely, pacific, terral, actually
10: caucus, statue, lenin, topple, murdoch, majority, minimum, santorum, maine, kiev
11: plead, guilty, torch, ahmadinejad, temple, steer, jetblue, iranparliament, appearance, buddhist