# Setting(dataset, parameter)

In [1]:
args_text = '--base-model sentence-transformers/paraphrase-MiniLM-L6-v2 ' + \
            '--dataset all --n-word 30000 ' + \
            '--bsz 32 --stage-2-lr 2e-2 --stage-2-repeat 5 ' + \
            '--n-cluster 20 '

In [2]:
import re
import os
import time
import argparse
import string
import torch
import torch.nn as nn
import torch.nn.functional as F
import gensim.downloader
import itertools

from sentence_transformers import SentenceTransformer

import numpy as np
from tqdm import tqdm_notebook
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from sklearn.datasets import fetch_20newsgroups
from nltk.corpus import stopwords

from utils import AverageMeter

import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import mutual_info_score
from sklearn.model_selection import train_test_split
from gensim.corpora.dictionary import Dictionary
from pytorch_transformers import *
import scipy.stats

from gensim.models.coherencemodel import CoherenceModel
from tqdm import tqdm
import nltk

from datetime import datetime
import gensim.downloader
from scipy.linalg import qr
from data import *
from data import TwitterDataset, RedditDataset, YoutubeDataset, BertDataset
from model import ContBertTopicExtractorAE

from data import BertDataset, Stage2Dataset
import random
from sklearn.metrics.pairwise import cosine_similarity

import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

[nltk_data] Downloading package punkt to /home/don12/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/don12/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /home/don12/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [3]:
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0

In [4]:
def _parse_args():
    parser = argparse.ArgumentParser(description='Contrastive topic modeling')
    parser.add_argument('--bsz', type=int, default=64,
                        help='Batch size')
    parser.add_argument('--dataset', default='twitter', type=str,
                        choices=['twitter', 'reddit', 'youtube','all'],
                        help='Name of the dataset')
    parser.add_argument('--n-cluster', default=20, type=int,
                        help='Number of clusters')
    parser.add_argument('--n-topic', type=int,
                        help='Number of topics. If not specified, use same value as --n-cluster')
    parser.add_argument('--n-word', default=30000, type=int,
                        help='Number of words in vocabulary')
    
    parser.add_argument('--base-model', type=str,
                        help='Name of base model in huggingface library.')
    
    parser.add_argument('--dirichlet-alpha-1', type=float,
                        help='Parameter for Dirichlet distribution (Phase 1). Use 1/n_topic by default.')
 
    parser.add_argument('--coeff-2-recon', default=1.0, type=float,
                        help='Coefficient for VAE reconstruction loss (Phase 2)')
    parser.add_argument('--coeff-2-regul', default=1.0, type=float,
                        help='Coefficient for VAE KLD regularization loss (Phase 2)')
    parser.add_argument('--coeff-2-cons', default=1.0, type=float,
                        help='Coefficient for CL consistency loss (Phase 2)')
    parser.add_argument('--coeff-2-dist', default=1.0, type=float,
                        help='Coefficient for CL SWD distribution matching loss (Phase 2)')
    parser.add_argument('--dirichlet-alpha-2', type=float,
                        help='Parameter for Dirichlet distribution (Phase 2). Use same value as dirichlet-alpha-1 by default.')
    parser.add_argument('--stage-2-lr', default=2e-1, type=float,
                        help='Learning rate of phase 2')
    
    parser.add_argument('--stage-2-repeat', default=5, type=int,
                        help='Repetition count of phase 2')
    
    parser.add_argument('--result-file', type=str,
                        help='File name for result summary')
    
    
    # Check if the code is run in Jupyter notebook
    is_in_jupyter = False
    try:
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            is_in_jupyter = True   # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            is_in_jupyter = False  # Terminal running IPython
        else:
            is_in_jupyter = False  # Other type (?)
    except NameError:
        is_in_jupyter = False
    
    if is_in_jupyter:
        return parser.parse_args(args=args_text.split())
    else:
        return parser.parse_args()

In [5]:
args = _parse_args()
bsz = args.bsz

n_cluster = args.n_cluster
n_topic = args.n_topic if (args.n_topic is not None) else n_cluster
args.n_topic = n_topic

# textData, should_measure_hungarian = data_load(args.dataset)

ema_alpha = 0.99
n_word = args.n_word
if args.dirichlet_alpha_1 is None:
    dirichlet_alpha_1 = 1 / n_cluster
else:
    dirichlet_alpha_1 = args.dirichlet_alpha_1
if args.dirichlet_alpha_2 is None:
    dirichlet_alpha_2 = dirichlet_alpha_1
else:
    dirichlet_alpha_2 = args.dirichlet_alpha_2
    
bert_name = args.base_model
bert_name_short = bert_name.split('/')[-1]

In [6]:
twitter_ds = TwitterDataset()
reddit_ds = RedditDataset()
youtube_ds = YoutubeDataset()

train_twitter_texts, test_twitter_texts, train_twitter_labels, test_twitter_labels = train_test_split(
    twitter_ds.texts, twitter_ds.labels, train_size=0.7, random_state=42)

train_reddit_texts, test_reddit_texts, train_reddit_labels, test_reddit_labels = train_test_split(
    reddit_ds.texts, reddit_ds.labels, train_size=0.7, random_state=42)

train_youtube_texts, test_youtube_texts, train_youtube_labels, test_youtube_labels = train_test_split(
    youtube_ds.texts, youtube_ds.labels, train_size=0.7, random_state=42)

# Split train data 9:1 between train and val
train_twitter_texts, val_twitter_texts, train_twitter_labels, val_twitter_labels = train_test_split(
    train_twitter_texts, train_twitter_labels, test_size=0.1, random_state=42)

train_reddit_texts, val_reddit_texts, train_reddit_labels, val_reddit_labels = train_test_split(
    train_reddit_texts, train_reddit_labels, test_size=0.1, random_state=42)

train_youtube_texts, val_youtube_texts, train_youtube_labels, val_youtube_labels = train_test_split(
    train_youtube_texts, train_youtube_labels, test_size=0.1, random_state=42)


train_total_label = train_twitter_labels + train_reddit_labels + train_youtube_labels
train_total_text_list = train_twitter_texts + train_reddit_texts + train_youtube_texts

val_total_label = val_twitter_labels + val_reddit_labels + val_youtube_labels
val_total_text_list = val_twitter_texts + val_reddit_texts + val_youtube_texts

test_total_label = test_twitter_labels + test_reddit_labels + test_youtube_labels
test_total_text_list = test_twitter_texts + test_reddit_texts + test_youtube_texts

In [7]:
device = torch.device("cuda:2" if torch.cuda.is_available() else 'cpu')
print(device)

cuda:2


In [8]:
class BertDataset(Dataset):
    def __init__(self, bert, text_list, platform_label, N_word, vectorizer=None, lemmatize=False):
        self.lemmatize = lemmatize
        self.nonempty_text = [text for text in text_list if len(text) > 0]
        
        # Remove new lines
        self.nonempty_text = [re.sub("\n"," ", sent) for sent in self.nonempty_text]
                
        # Remove Emails
        self.nonempty_text = [re.sub('\S*@\S*\s?', '', sent) for sent in self.nonempty_text]
        
        # Remove new line characters
        self.nonempty_text = [re.sub('\s+', ' ', sent) for sent in self.nonempty_text]
        
        # Remove distracting single quotes
        self.nonempty_text = [re.sub("\'", "", sent) for sent in self.nonempty_text]
        
        # jargons list via c-TF-IDF 
        self.jargons =  set(['coinex', 'announces', 'seos', 'chatgptpowered', 'launches', 'bigdata', 'unveils', 'openaichatgpt', 'tags', 'hn', 'stablediffusion', 'chatgptstyle', 'reportedly', 'mba', 'marketers', 'baidu', 'technews', 'fintech', 'chatgptlike', 'elonmusk', 'notion', 'goog', 'googleai', 'digitalmarketing', 'artificalintelligence', 'rt', 'googl', 'bardai', 'edtech', 'malware', 'wharton', 'agix', 'chatgptplus', 'datascience', 'deeplearning', 'msft', 'weirdness', 'tweets', 'amid', 'aitools', 'cybersecurity', 'airdrop', 'cc', 'valentines', 'startups', 'snapchat', 'generativeai', 'buzzfeed', 'fastestgrowing', 'anthropic', 'maker', 'rival', 'techcrunch', 'aiart', 'nocode', 'invests', 'cybercriminals', 'abstracts', 'nyc', 'webinar', 'retweet', 'educators', 'brilliance', 'rescue', 'daysofcode', 'gm', 'rtechnology', 'linkedin', 'licensing', 'copywriting', 'copywriters', 'contentmarketing', 'revolutionizing', 'technologynews', 'warns', 'metaverse', 'cofounder', 'trending', 'founders', 'aipowered', 'openaichat', 'releases', 'microsofts', 'chinas', 'infosec', 'launching', 'jasper', 'nfts', 'newsletter', 'chatgptgod', 'futureofwork', 'digitaltransformation', 'founder', 'feb', 'buzz', 'rn', 'ux', 'courtesy', 'nick', 'claude',
'remindme', 'giphy', 'gif', 'deleted', 'giphydownsized', 'chadgpt', 'removed', 'patched', 'nerfed', 'yup', 'waitlist', 'refresh', 'sydney', 'mods', 'nsfw', 'characterai', 'screenshot', 'downvoted', 'youcom', 'meth', 'ascii', 'karma', 'hahaha', 'hangman', 'chatopenaicom', 'emojis', 'porn', 'redditor', 'vpn', 'upvotes', 'blah', 'upvote', 'violated', 'yep', 'joking', 'nope', 'offended', 'mod', 'bruh', 'roleplay', 'ops', 'bob', 'dans', 'redditors', 'nerf', 'firefox', 'trolling', 'sarcastic', 'huh', 'turbo', 'troll', 'patch', 'tag', 'url', 'sus', 'erotica', 'chad', 'gotcha', 'basilisk', 'login', 'lmfao', 'temperature', 'poll', 'emoji', 'rick', 'dm', 'jailbreak', 'orange', 'sub', 'quack', 'davinci', 'uh', 'flagged', 'op', 'markdown', 'flair', 'cares', 'refreshing', 'hitler', 'cookies', 'hmm', 'yikes', 'erotic', 'gti', 'paywall', 'elaborate', 'yea', 'ah', 'uncensored', 'rude', 'colour', 'bitch', 'therapy', 'neutered', 'deny', 'chats', 'jailbroken', 'cake', 'dungeon', 'dang',
'zronx', 'tuce', 'jontron', 'levy', 'bishop', 'rook', 'thumbnail', 'quotquot', 'jon', 'linus', 'hrefaboutinvalidzcsafeza', 'beluga', 'vid', 'bhai', 'gemx', 'raid', 'ohio', 'circle', 'subscribed', 'anna', 'stare', 'canva', 'napster', 'shapiro', 'sponsor', 'broker', 'websiteapp', 'manoj', 'subscriber', 'bluewillow', 'alex', 'vids', 'legends', 'ryan', 'shes', 'hackbanzer', 'quotoquot', 'pictory', 'youtuber', 'profitable', 'pawn', 'joma', 'folders', 'lifechanging', 'thomas', 'ur', 'plz', 'mike', 'scott', 'casey', 'adrian', 'enjoyed', 'stockfish', 'invideo', 'shortlisted', 'hikaru', 'bless', 'corpsb', 'chatgbt', 'bfuture', 'curve', 'accent', 'amc', 'tutorials', 'gotham', 'mrs', 'earning', 'bra', 'elo', 'oliver', 'youtubers', 'quotcontinuequot', 'membership', 'labels', 'dagogo', 'eonr', 'hai', 'quotai', 'affiliate', 'congratulationsbryou', 'subscribers', 'thumbnails', 'azn', 'beast', 'tom', 'trader', 'garetz', 'quot', 'subbed', 'pls', 'quotchatgpt', 'gtp', 'machina', 'quoti', 'bret', 'terminator', 'watchingbrdm', 'quothow', 'nowi', 'mint'])

        
        self.tokenizer = AutoTokenizer.from_pretrained(bert)
        self.model = AutoModel.from_pretrained(bert).to(device)
        self.stopwords_list = set(TfidfVectorizer(stop_words="english").get_stop_words()).union(self.jargons)
        self.N_word = N_word
        
        if vectorizer == None:
            self.vectorizer = TfidfVectorizer(stop_words=None, max_features=self.N_word, token_pattern=r'\b[a-zA-Z]{2,}\b')
            self.vectorizer.fit(self.preprocess_ctm(self.nonempty_text))
        else:
            self.vectorizer = vectorizer
            
        self.org_list = []
        self.bow_list = []
        self.platform_label_list = platform_label
        
        for sent in tqdm(self.nonempty_text):
            org_input = self.tokenizer(sent, padding='max_length', truncation=True, max_length=512, return_tensors='pt')
            org_input['input_ids'] = torch.squeeze(org_input['input_ids'])
            org_input['attention_mask'] = torch.squeeze(org_input['attention_mask'])
            self.org_list.append(org_input)
            self.bow_list.append(self.vectorize(sent))
            
    def vectorize(self, text):
        text = self.preprocess_ctm([text])
        vectorized_input = self.vectorizer.transform(text)
        vectorized_input = vectorized_input.toarray()
        vectorized_input = vectorized_input.astype(np.float64)

        # Get word distribution from BoW
        vectorized_input += 1e-8
        vectorized_input = vectorized_input / vectorized_input.sum(axis=1, keepdims=True)
        assert abs(vectorized_input.sum() - vectorized_input.shape[0]) < 0.01
        vectorized_label = torch.tensor(vectorized_input, dtype=torch.float)
        return vectorized_label[0]
        
        
    def preprocess_ctm(self, documents):
        preprocessed_docs_tmp = documents
        preprocessed_docs_tmp = [doc.lower() for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [doc.translate(
            str.maketrans(string.punctuation, ' ' * len(string.punctuation))) for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [' '.join([w for w in doc.split() if len(w) > 0 and w not in self.stopwords_list])
                             for doc in preprocessed_docs_tmp]
        if self.lemmatize:
            lemmatizer = WordNetLemmatizer()
            preprocessed_docs_tmp = [' '.join([lemmatizer.lemmatize(w) for w in doc.split()])
                                     for doc in preprocessed_docs_tmp]
        return preprocessed_docs_tmp    
    
    # mean_pooling 함수 정의
    def mean_pooling(self, model_output, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
        sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask
    
    
    def __len__(self):
        return len(self.nonempty_text)
    

    def __getitem__(self, idx):
        sentence = self.nonempty_text[idx]
        encoded_input = self.tokenizer(sentence, padding=True, truncation=True, max_length=512, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**encoded_input.to(device))
        pooled_embedding = self.mean_pooling(model_output.last_hidden_state, encoded_input['attention_mask'])
        return self.org_list[idx], self.bow_list[idx], pooled_embedding, self.platform_label_list[idx]


In [9]:
trainds = BertDataset(bert=bert_name, text_list=train_total_text_list, platform_label = train_total_label, N_word=n_word, vectorizer=None, lemmatize=True)
valds = BertDataset(bert=bert_name, text_list=val_total_text_list, platform_label = val_total_label, N_word=n_word, vectorizer=None, lemmatize=True)
testds = BertDataset(bert=bert_name, text_list=test_total_text_list, platform_label = test_total_label, N_word=n_word, vectorizer=None, lemmatize=True)

2024-03-20 19:30:03.310093: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-03-20 19:30:03.505009: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-20 19:30:04.424355: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-11.7/lib64:
2024-03-20 19:30:04.424478: W tensorflow/compiler/xla/strea

#  Mean pooling progress

In [10]:
# train data set
trainds_embeddings = []
for _,_,pooled_embedding,_ in trainds:
    trainds_embeddings.append(pooled_embedding.to(device))  
train_mean_pooled_embeddings = torch.stack(trainds_embeddings).to(device) 

# validation data set
valds_embeddings = []
for _,_,pooled_embedding,_ in valds:
    valds_embeddings.append(pooled_embedding.to(device))
val_mean_pooled_embeddings = torch.stack(valds_embeddings).to(device)

# test data set
testds_embeddings = []
for _,_,pooled_embedding,_ in testds:
    testds_embeddings.append(pooled_embedding.to(device))
test_mean_pooled_embeddings = torch.stack(testds_embeddings).to(device)

# Re_fornulate the bow

In [11]:
def dist_match_loss(hiddens, alpha=1.0):
    device = hiddens.device
    hidden_dim = hiddens.shape[-1]
    H = np.random.randn(hidden_dim, hidden_dim)
    Q, R = qr(H) 
    rand_w = torch.Tensor(Q).to(device)
    loss_dist_match = get_swd_loss(hiddens, rand_w, alpha)
    return loss_dist_match


def js_div_loss(hidden1, hidden2):
    m = 0.5 * (hidden1 + hidden2)
    return kldiv(m.log(), hidden1) + kldiv(m.log(), hidden2)


def get_swd_loss(states, rand_w, alpha=1.0):
    device = states.device
    states_shape = states.shape
    states = torch.matmul(states, rand_w)
    states_t, _ = torch.sort(states.t(), dim=1)

    # Random vector with length from normal distribution
    states_prior = torch.Tensor(np.random.dirichlet([alpha]*states_shape[1], states_shape[0])).to(device) # (bsz, dim)
    states_prior = torch.matmul(states_prior, rand_w) # (dim, dim)
    states_prior_t, _ = torch.sort(states_prior.t(), dim=1) # (dim, bsz)
    return torch.mean(torch.sum((states_prior_t - states_t)**2, axis=0))

# Get pos_similarity

In [12]:
def compute_max_cosine_similarity_indices(mean_pooled_embeddings, batch_size=500):
    n_rows = mean_pooled_embeddings.size(0)
    max_similarity_indices = torch.zeros(n_rows, dtype=torch.int64, device=mean_pooled_embeddings.device)

    for start_idx in range(0, n_rows, batch_size):
        end_idx = min(start_idx + batch_size, n_rows)
        batch_data = mean_pooled_embeddings[start_idx:end_idx]

        batch_data_norm = torch.nn.functional.normalize(batch_data, p=2, dim=1)
        mean_pooled_embeddings_norm = torch.nn.functional.normalize(mean_pooled_embeddings, p=2, dim=1)
        
        # cosine similarty
        batch_similarity = torch.mm(batch_data_norm, mean_pooled_embeddings_norm.transpose(0, 1))
        batch_similarity[torch.arange(end_idx-start_idx), torch.arange(start_idx, end_idx)] = -1

        max_indices = torch.argmax(batch_similarity, dim=1)
        max_similarity_indices[start_idx:end_idx] = max_indices

    return max_similarity_indices

In [13]:
# Convert mean_pooled_embeddings to two-dimensional form
train_mean_pooled_embeddings_2d = train_mean_pooled_embeddings.to(device).squeeze()
val_mean_pooled_embeddings_2d = val_mean_pooled_embeddings.to(device).squeeze()
test_mean_pooled_embeddings_2d = test_mean_pooled_embeddings.to(device).squeeze()

In [14]:
# Calculate the cosine similarity matrix
train_similarity_matrix = compute_max_cosine_similarity_indices(train_mean_pooled_embeddings_2d)
val_similarity_matrix = compute_max_cosine_similarity_indices(val_mean_pooled_embeddings_2d)
test_similarity_matrix = compute_max_cosine_similarity_indices(test_mean_pooled_embeddings_2d)  

In [15]:
model =  ContBertTopicExtractorAE(N_topic=n_topic, N_word=args.n_word, bert=bert_name, bert_dim=768)
model.to(device)  

ContBertTopicExtractorAE(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 384, padding_idx=0)
      (position_embeddings): Embedding(512, 384)
      (token_type_embeddings): Embedding(2, 384)
      (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (LayerNorm): LayerNorm((384,), eps=1e-12, ele

In [16]:
class Stage2Dataset(Dataset):
    def __init__(self, encoder, ds, similarity_matrix, N_word, k=1, lemmatize=False):
        self.lemmatize = lemmatize
        self.ds = ds
        self.org_list = self.ds.org_list
        self.nonempty_text = self.ds.nonempty_text
        self.N_word = N_word
            
      
        self.jargons =  set(['coinex', 'announces', 'seos', 'chatgptpowered', 'launches', 'bigdata', 'unveils', 'openaichatgpt', 'tags', 'hn', 'stablediffusion', 'chatgptstyle', 'reportedly', 'mba', 'marketers', 'baidu', 'technews', 'fintech', 'chatgptlike', 'elonmusk', 'notion', 'goog', 'googleai', 'digitalmarketing', 'artificalintelligence', 'rt', 'googl', 'bardai', 'edtech', 'malware', 'wharton', 'agix', 'chatgptplus', 'datascience', 'deeplearning', 'msft', 'weirdness', 'tweets', 'amid', 'aitools', 'cybersecurity', 'airdrop', 'cc', 'valentines', 'startups', 'snapchat', 'generativeai', 'buzzfeed', 'fastestgrowing', 'anthropic', 'maker', 'rival', 'techcrunch', 'aiart', 'nocode', 'invests', 'cybercriminals', 'abstracts', 'nyc', 'webinar', 'retweet', 'educators', 'brilliance', 'rescue', 'daysofcode', 'gm', 'rtechnology', 'linkedin', 'licensing', 'copywriting', 'copywriters', 'contentmarketing', 'revolutionizing', 'technologynews', 'warns', 'metaverse', 'cofounder', 'trending', 'founders', 'aipowered', 'openaichat', 'releases', 'microsofts', 'chinas', 'infosec', 'launching', 'jasper', 'nfts', 'newsletter', 'chatgptgod', 'futureofwork', 'digitaltransformation', 'founder', 'feb', 'buzz', 'rn', 'ux', 'courtesy', 'nick', 'claude',
'remindme', 'giphy', 'gif', 'deleted', 'giphydownsized', 'chadgpt', 'removed', 'patched', 'nerfed', 'yup', 'waitlist', 'refresh', 'sydney', 'mods', 'nsfw', 'characterai', 'screenshot', 'downvoted', 'youcom', 'meth', 'ascii', 'karma', 'hahaha', 'hangman', 'chatopenaicom', 'emojis', 'porn', 'redditor', 'vpn', 'upvotes', 'blah', 'upvote', 'violated', 'yep', 'joking', 'nope', 'offended', 'mod', 'bruh', 'roleplay', 'ops', 'bob', 'dans', 'redditors', 'nerf', 'firefox', 'trolling', 'sarcastic', 'huh', 'turbo', 'troll', 'patch', 'tag', 'url', 'sus', 'erotica', 'chad', 'gotcha', 'basilisk', 'login', 'lmfao', 'temperature', 'poll', 'emoji', 'rick', 'dm', 'jailbreak', 'orange', 'sub', 'quack', 'davinci', 'uh', 'flagged', 'op', 'markdown', 'flair', 'cares', 'refreshing', 'hitler', 'cookies', 'hmm', 'yikes', 'erotic', 'gti', 'paywall', 'elaborate', 'yea', 'ah', 'uncensored', 'rude', 'colour', 'bitch', 'therapy', 'neutered', 'deny', 'chats', 'jailbroken', 'cake', 'dungeon', 'dang',
'zronx', 'tuce', 'jontron', 'levy', 'bishop', 'rook', 'thumbnail', 'quotquot', 'jon', 'linus', 'hrefaboutinvalidzcsafeza', 'beluga', 'vid', 'bhai', 'gemx', 'raid', 'ohio', 'circle', 'subscribed', 'anna', 'stare', 'canva', 'napster', 'shapiro', 'sponsor', 'broker', 'websiteapp', 'manoj', 'subscriber', 'bluewillow', 'alex', 'vids', 'legends', 'ryan', 'shes', 'hackbanzer', 'quotoquot', 'pictory', 'youtuber', 'profitable', 'pawn', 'joma', 'folders', 'lifechanging', 'thomas', 'ur', 'plz', 'mike', 'scott', 'casey', 'adrian', 'enjoyed', 'stockfish', 'invideo', 'shortlisted', 'hikaru', 'bless', 'corpsb', 'chatgbt', 'bfuture', 'curve', 'accent', 'amc', 'tutorials', 'gotham', 'mrs', 'earning', 'bra', 'elo', 'oliver', 'youtubers', 'quotcontinuequot', 'membership', 'labels', 'dagogo', 'eonr', 'hai', 'quotai', 'affiliate', 'congratulationsbryou', 'subscribers', 'thumbnails', 'azn', 'beast', 'tom', 'trader', 'garetz', 'quot', 'subbed', 'pls', 'quotchatgpt', 'gtp', 'machina', 'quoti', 'bret', 'terminator', 'watchingbrdm', 'quothow', 'nowi', 'mint'])


        self.stopwords_list = set(TfidfVectorizer(stop_words="english").get_stop_words()).union(self.jargons)
        
        self.vectorizer = TfidfVectorizer(stop_words=None, max_features=self.N_word, token_pattern=r'\b[a-zA-Z]{2,}\b')
        self.vectorizer.fit(self.preprocess_ctm(self.nonempty_text)) 
        self.bow_list = []
        for sent in tqdm(self.nonempty_text):
            self.bow_list.append(self.vectorize(sent))
        self.pos_dict = similarity_matrix
        
        self.embedding_list = []
        self.platform_label_list = self.ds.platform_label_list
        encoder_device = next(encoder.parameters()).device
        for org_input in tqdm(self.org_list):
            org_input_ids = org_input['input_ids'].to(encoder_device).reshape(1, -1)
            org_attention_mask = org_input['attention_mask'].to(encoder_device).reshape(1, -1)
            embedding = encoder(input_ids = org_input_ids, attention_mask = org_attention_mask)
            self.embedding_list.append(embedding['pooler_output'].squeeze().detach().cpu())
            
        
            
    
    def __len__(self):
        return len(self.org_list)
        
    def preprocess_ctm(self, documents):
        preprocessed_docs_tmp = documents
        preprocessed_docs_tmp = [doc.lower() for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [doc.translate(
            str.maketrans(string.punctuation, ' ' * len(string.punctuation))) for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [' '.join([w for w in doc.split() if len(w) > 0 and w not in self.stopwords_list])
                                 for doc in preprocessed_docs_tmp]
        if self.lemmatize:
            lemmatizer = WordNetLemmatizer()
            preprocessed_docs_tmp = [' '.join([lemmatizer.lemmatize(w) for w in doc.split()])
                                     for doc in preprocessed_docs_tmp]
        return preprocessed_docs_tmp
        
    def vectorize(self, text):
        text = self.preprocess_ctm([text])
        vectorized_input = self.vectorizer.transform(text)
        vectorized_input = vectorized_input.toarray().astype(np.float64)

        # Get word distribution from BoW
        if vectorized_input.sum() == 0:
            vectorized_input += 1e-8
        vectorized_input = vectorized_input / vectorized_input.sum(axis=1, keepdims=True)
        assert abs(vectorized_input.sum() - vectorized_input.shape[0]) < 0.01
        vectorized_label = torch.tensor(vectorized_input, dtype=torch.float).to(device) 
        return vectorized_label[0]
        
        
    def __getitem__(self, idx):
        pos_idx = self.pos_dict[idx]
        return idx, self.embedding_list[idx], self.embedding_list[pos_idx], self.bow_list[idx], self.bow_list[pos_idx],self.platform_label_list[idx]


In [17]:
finetuneds = Stage2Dataset(model.encoder, trainds, train_similarity_matrix, n_word, lemmatize=True)
valfinetuneds = Stage2Dataset(model.encoder, valds, val_similarity_matrix, n_word, lemmatize=True) 
testfinetuneds = Stage2Dataset(model.encoder, testds, test_similarity_matrix, n_word, lemmatize=True) 

kldiv = torch.nn.KLDivLoss(reduction='batchmean')
vocab_dict = finetuneds.vectorizer.vocabulary_
vocab_dict_reverse = {i:v for v, i in vocab_dict.items()}
print(n_word)

100%|██████████| 37800/37800 [01:00<00:00, 623.94it/s]
100%|██████████| 37800/37800 [04:47<00:00, 131.62it/s]
100%|██████████| 4200/4200 [00:06<00:00, 697.67it/s]
100%|██████████| 4200/4200 [00:31<00:00, 132.96it/s]
100%|██████████| 18000/18000 [00:27<00:00, 664.31it/s]
100%|██████████| 18000/18000 [02:17<00:00, 130.66it/s]

30000





# Stage 3

In [18]:
def measure_hungarian_score(topic_dist, train_target):
    dist = topic_dist
    train_target_filtered = train_target
    flat_predict = torch.tensor(np.argmax(dist, axis=1))
    flat_target = torch.tensor(train_target_filtered).to(flat_predict.device)
    num_samples = flat_predict.shape[0]
    num_classes = dist.shape[1]
    match = _hungarian_match(flat_predict, flat_target, num_samples, num_classes)    
    reordered_preds = torch.zeros(num_samples).to(flat_predict.device)
    for pred_i, target_i in match:
        reordered_preds[flat_predict == pred_i] = int(target_i)
    acc = int((reordered_preds == flat_target.float()).sum()) / float(num_samples)
    return acc

In [19]:
should_measure_hungarian = True

In [20]:
torch.cuda.empty_cache()

# Seperate Platform dataset

In [21]:
from torch.utils.data import DataLoader, Dataset, Sampler
from collections import defaultdict

# Implementing a custom sampler
class PlatformSampler(Sampler):
    def __init__(self, dataset, platform_label):
        self.indices = [i for i, label in enumerate(dataset.platform_label_list) if label == platform_label]
    
    def __iter__(self):
        return iter(self.indices)
    
    def __len__(self):
        return len(self.indices)

In [22]:
def create_platform_dataloader(dataset, platform_label, batch_size=32, num_workers=0):
    sampler = PlatformSampler(dataset, platform_label)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers)
    return dataloader

# Main

In [23]:
from sklearn.metrics import mutual_info_score
from scipy.stats import chi2_contingency 
from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora.dictionary import Dictionary
from coherence import get_topic_coherence

# Seed fixation functions
def set_seed(seed_value):
    random.seed(seed_value)  
    np.random.seed(seed_value)  
    torch.manual_seed(seed_value)  
    if torch.cuda.is_available():  
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(41)

args.stage_2_repeat = 1
results_list = []

for i in range(args.stage_2_repeat):
    model = ContBertTopicExtractorAE(N_topic=n_topic, N_word=args.n_word, bert=bert_name, bert_dim=768)
    model.beta = nn.Parameter(torch.Tensor(model.N_topic, n_word))
    nn.init.xavier_uniform_(model.beta)
    model.beta_batchnorm = nn.Sequential()
    model.cuda(device)
    
    losses = AverageMeter()
    dlosses = AverageMeter() 
    rlosses = AverageMeter()
    closses = AverageMeter()
    distlosses = AverageMeter()
    ##수정
    twitter_trainloader = create_platform_dataloader(finetuneds, 'twitter', batch_size=bsz, num_workers=0)
    reddit_trainloader = create_platform_dataloader(finetuneds, 'reddit', batch_size=bsz, num_workers=0)
    youtube_trainloader = create_platform_dataloader(finetuneds, 'youtube', batch_size=bsz, num_workers=0)
    memoryloader = DataLoader(finetuneds, batch_size=bsz * 2, shuffle=False, num_workers=0)
    ##
    optimizer = torch.optim.Adam(model.parameters(), lr=args.stage_2_lr)

    memory_queue = F.softmax(torch.randn(512, n_topic).cuda(device), dim=1)
    print("Coeff   / regul: {:.5f} - recon: {:.5f} - c: {:.5f} - dist: {:.5f} ".format(args.coeff_2_regul, 
                                                                                        args.coeff_2_recon,
                                                                                        args.coeff_2_cons,
                                                                                        args.coeff_2_dist))

    best_npmi = -1
    best_epoch = 0
    best_model_state = None  
    
    # Create an iterator for each platform-specific DataLoader
    twitter_iter = iter(twitter_trainloader)
    reddit_iter = iter(reddit_trainloader)
    youtube_iter = iter(youtube_trainloader)

    # Calculate the length of the longest DataLoader to scope the training loop
    max_length = max(len(twitter_trainloader), len(reddit_trainloader), len(youtube_trainloader))

    for epoch in range(100):
        model.train()
        model.encoder.eval()

        for _ in range(max_length):
            # Sequentially fetching batches from platform-specific DataLoaders
            try:
                twitter_batch = next(twitter_iter)
            except StopIteration:
                # When Twitter DataLoader reaches the end, restart the iterator
                twitter_iter = iter(twitter_trainloader)
                twitter_batch = next(twitter_iter)

            try:
                reddit_batch = next(reddit_iter)
            except StopIteration:
                reddit_iter = iter(reddit_trainloader)
                reddit_batch = next(reddit_iter)

            try:
                youtube_batch = next(youtube_iter)
            except StopIteration:
                youtube_iter = iter(youtube_trainloader)
                youtube_batch = next(youtube_iter)

            # Implement learning logic for each batch
            for batch in [twitter_batch, reddit_batch, youtube_batch]:
                _, org_input, pos_input, org_bow, pos_bow, _ = batch
                org_input = org_input.cuda(device)
                org_bow = org_bow.cuda(device)
                pos_input = pos_input.cuda(device)
                pos_bow = pos_bow.cuda(device)

                batch_size = org_input.size(0) #org_input_ids.size(0)

                org_dists, org_topic_logit = model.decode(org_input)
                pos_dists, pos_topic_logit = model.decode(pos_input)

                org_topic = F.softmax(org_topic_logit, dim=1)
                pos_topic = F.softmax(pos_topic_logit, dim=1)

                org_dists = org_dists[:, :org_bow.size(1)]
                pos_dists = pos_dists[:, :pos_bow.size(1)]

                recons_loss = torch.mean(-torch.sum(torch.log(org_dists + 1E-10) * (org_bow), axis=1), axis=0)
                recons_loss += torch.mean(-torch.sum(torch.log((1-org_dists) + 1E-10) * (1-org_bow), axis=1), axis=0)
                recons_loss += torch.mean(-torch.sum(torch.log(pos_dists + 1E-10) * (pos_bow), axis=1), axis=0)
                recons_loss += torch.mean(-torch.sum(torch.log((1-pos_dists) + 1E-10) * (1-pos_bow), axis=1), axis=0)
                recons_loss *= 0.5

                # consistency loss
                pos_sim = torch.sum(org_topic * pos_topic, dim=-1)
                cons_loss = -pos_sim.mean()

                # distribution loss
                # batchmean
                distmatch_loss = dist_match_loss(torch.cat((org_topic,), dim=0), dirichlet_alpha_2)


                loss = args.coeff_2_recon * recons_loss + \
                       args.coeff_2_cons * cons_loss + \
                       args.coeff_2_dist * distmatch_loss
            
            

                losses.update(loss.item(), bsz)
                closses.update(cons_loss.item(), bsz)
                rlosses.update(recons_loss.item(), bsz)
                distlosses.update(distmatch_loss.item(), bsz)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            
        print("Epoch-{} / recon: {:.5f} - dist: {:.5f} - cons: {:.5f}".format(epoch, rlosses.avg, distlosses.avg, closses.avg))

    
        model.eval()

        # Extract the top 10 words for each topic
        top_words_per_topic = {}
        for topic_idx in range(model.N_topic):
            top_words_indices = model.beta[topic_idx].topk(10).indices
            top_words = [vocab_dict_reverse[idx.item()] for idx in top_words_indices]
            top_words_per_topic[topic_idx] = top_words
            
        reference_corpus=[doc.split() for doc in valds.preprocess_ctm(valds.nonempty_text)]
        topic_words_list = list(top_words_per_topic.values())
        result = get_topic_coherence(topic_words_list, reference_corpus)
        avg_npmi = result['NPMI']

        # Optimal NPMI and epoch tracking
        if avg_npmi > best_npmi:
            best_npmi = avg_npmi
            best_epoch = epoch
            best_model_state = model.state_dict()  # 현재 모델 상태 저장

    print(f"Best Epoch: {best_epoch} with NPMI: {best_npmi}")
    torch.save(best_model_state, 'our_best_model_state.pth')
    model.load_state_dict(torch.load('our_best_model_state.pth'))
    
    print("------- Evaluation results -------")
    # Each topic has its own wordset
    all_list = {}
    for e, i in enumerate(model.beta.cpu().topk(10, dim=1).indices):
        word_list = []
        for j in i:
            word_list.append(vocab_dict_reverse[j.item()])
        all_list[e] = word_list
        print("topic-{}".format(e), word_list)

Coeff   / regul: 1.00000 - recon: 1.00000 - c: 1.00000 - dist: 1.00000 
Epoch-0 / recon: 9.21985 - dist: 0.16057 - cons: -0.07186
Epoch-1 / recon: 8.99693 - dist: 0.14935 - cons: -0.10571
Epoch-2 / recon: 8.86256 - dist: 0.14490 - cons: -0.12476
Epoch-3 / recon: 8.75588 - dist: 0.14007 - cons: -0.13809
Epoch-4 / recon: 8.66498 - dist: 0.13428 - cons: -0.14827
Epoch-5 / recon: 8.58487 - dist: 0.12931 - cons: -0.15606
Epoch-6 / recon: 8.51427 - dist: 0.12522 - cons: -0.16180
Epoch-7 / recon: 8.45157 - dist: 0.12224 - cons: -0.16571
Epoch-8 / recon: 8.39531 - dist: 0.12025 - cons: -0.16844
Epoch-9 / recon: 8.34415 - dist: 0.11897 - cons: -0.17017
Epoch-10 / recon: 8.29713 - dist: 0.11828 - cons: -0.17118
Epoch-11 / recon: 8.25348 - dist: 0.11795 - cons: -0.17158
Epoch-12 / recon: 8.21254 - dist: 0.11803 - cons: -0.17151
Epoch-13 / recon: 8.17392 - dist: 0.11846 - cons: -0.17108
Epoch-14 / recon: 8.13720 - dist: 0.11914 - cons: -0.17041
Epoch-15 / recon: 8.10210 - dist: 0.11998 - cons: -0.

In [24]:
topic_words_list = list(all_list.values())
reference_corpus=[doc.split() for doc in testds.preprocess_ctm(testds.nonempty_text)]

topics=topic_words_list
texts=reference_corpus
print(get_topic_coherence(topics, texts))

{'NPMI': 0.7122771917868125, 'UCI': 5.188708587945945, 'UMASS': -1.1483764321230514, 'CV': 0.7297783087113847, 'Topic_Diversity': 0.955}


# MI Calulate

In [26]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


model.eval()
testloader = DataLoader(testfinetuneds, batch_size=bsz * 2, shuffle=False, num_workers=0)
    

ourmodel_test_topic_labels = []
test_platform_labels = []

for batch in testloader:
    _, org_embedding, _, org_bow, _, platform_labels = batch
    org_embedding = org_embedding.to(device)
    _, topic_logit = model.decode(org_embedding)
    topic_label = torch.argmax(F.softmax(topic_logit, dim=1), dim=1)
    ourmodel_test_topic_labels.extend(topic_label.cpu().numpy())
    test_platform_labels.extend(platform_labels)

# Calculate topic distribution by platform
topic_dist_df_test = pd.crosstab(pd.Series(ourmodel_test_topic_labels, name='Topic'),
                            pd.Series(test_platform_labels, name='Platform'), normalize='index')

# Calculate topic distribution by platform and overall
platform_counts = pd.Series(test_platform_labels).value_counts()
platform_probabilities = platform_counts / platform_counts.sum()

# Calculate the entropy of the topic distribution over the entire dataset (H(Y))
topic_probabilities = pd.Series(ourmodel_test_topic_labels).value_counts(normalize=True)
H_Y = -np.sum(topic_probabilities * np.log(topic_probabilities + 1e-10))

# Compute conditional entropy and H(Y|X) for each platform
H_Y_given_X_total = 0
for platform in platform_probabilities.index:
    platform_indices = [i for i, x in enumerate(test_platform_labels) if x == platform]
    platform_topic_labels = [ourmodel_test_topic_labels[i] for i in platform_indices]
    platform_topic_prob = pd.Series(platform_topic_labels).value_counts(normalize=True)
    
    H_Y_given_X = -np.sum(platform_topic_prob * np.log(platform_topic_prob + 1e-10))
    H_Y_given_X_total += platform_probabilities[platform] * H_Y_given_X

mi = H_Y - H_Y_given_X_total
H_X = -np.sum(platform_probabilities * np.log(platform_probabilities + 1e-10))


print('H(Y):', H_Y)
print('H(Y|X):', H_Y_given_X_total)
print('Mutual Information (MI):', mi)

mi_score = mutual_info_score(ourmodel_test_topic_labels, test_platform_labels)
print("Original Mutual Information Score:", mi_score)

H(Y): 2.9031547807445492
H(Y|X): 2.8742561848737305
Mutual Information (MI): 0.028898595870818777
Original Mutual Information Score: 0.02889859587081954


# Seperate platform

In [28]:
# check the length of reference_corpus
total_length = len(reference_corpus)

if total_length >= 18000:
    twitter_texts = reference_corpus[:6000]
    reddit_texts = reference_corpus[6000:12000]
    youtube_texts = reference_corpus[12000:18000]
else:
    print("Not enough data.")

twitter_dictionary = Dictionary(twitter_texts)
twitter_dictionary.add_documents(topic_words_list)

reddit_dictionary = Dictionary(reddit_texts)
reddit_dictionary.add_documents(topic_words_list)

youtube_dictionary = Dictionary(youtube_texts)
youtube_dictionary.add_documents(topic_words_list)

## Twitter

In [29]:
print(get_topic_coherence(topics, twitter_texts))

{'NPMI': 0.5257036585398003, 'UCI': 4.529252108654194, 'UMASS': -1.8080285939387835, 'CV': 0.8279785814522855, 'Topic_Diversity': 0.955}


## Reddit

In [30]:
print(get_topic_coherence(topics, reddit_texts))

{'NPMI': 0.7669463298158827, 'UCI': 3.5512475447898737, 'UMASS': -0.9595442389240559, 'CV': 0.7988206610347721, 'Topic_Diversity': 0.955}


## Youtube

In [31]:
print(get_topic_coherence(topics, youtube_texts))

{'NPMI': 0.6444031679463139, 'UCI': 5.15491069547586, 'UMASS': -0.8707065182521553, 'CV': 0.8135890235745974, 'Topic_Diversity': 0.955}
