In [1]:
import logging

# Logger setup: the root logger writes INFO-and-above records to a log file.
logger = logging.getLogger()
logging.basicConfig(
    filename='dataset_log.log',  # file that receives the log records
    format='%(asctime)s [%(levelname)s] - %(message)s',
    level=logging.INFO,
)

In [2]:
# Command line that _parse_args() parses when the notebook runs inside a Jupyter kernel.
args_text = (
    '--base-model sentence-transformers/paraphrase-MiniLM-L6-v2 '
    '--dataset all --n-word 30000 --epochs-1 100 --epochs-2 50 '
    '--bsz 32 --stage-2-lr 2e-2 --stage-2-repeat 5 --coeff-1-dist 50 '
    '--n-cluster 20 '
    '--stage-1-ckpt trained_model/news_model_paraphrase-MiniLM-L6-v2_stage1_20t_2000w_99e.ckpt '
    '--palmetto-dir /home/minseo/jupyter_dir/PTM'
)

In [3]:
import re
import os
import sys
import time
import copy
import math
import argparse
import string
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtools.optim import RangerLars
import gensim.downloader
import itertools

from scipy.stats import ortho_group
from scipy.optimize import linear_sum_assignment as linear_assignment
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer

import numpy as np
from tqdm import tqdm_notebook
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.datasets import fetch_20newsgroups
from nltk.corpus import stopwords

from sklearn.feature_extraction.text import CountVectorizer
from utils import AverageMeter
from collections import OrderedDict

import pandas as pd
from sklearn.preprocessing import normalize
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import mutual_info_score
from gensim.corpora.dictionary import Dictionary
from pytorch_transformers import *
from sklearn.mixture import GaussianMixture
import scipy.stats
from sklearn.decomposition import PCA
from sklearn.cluster import OPTICS
from nltk.corpus import stopwords

from gensim.models.coherencemodel import CoherenceModel
from tqdm import tqdm
import scipy.sparse as sp
import nltk
from nltk.corpus import stopwords

from datetime import datetime
from itertools import combinations
import gensim.downloader
from scipy.linalg import qr
from data import *
from data import TwitterDataset, RedditDataset, YoutubeDataset, BertDataset
from model import ContBertTopicExtractorAE
from evaluation import get_topic_qualities

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from collections import OrderedDict
from torch.utils.data import ConcatDataset
import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

[nltk_data] Downloading package punkt to /home/minseo/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/minseo/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /home/minseo/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [4]:
# Order CUDA devices by PCI bus id and expose only devices 0 and 1 to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0,1",
})

In [5]:
def _parse_args(argv=None):
    """Build the experiment's argument parser and parse the run configuration.

    Args:
        argv: Optional sequence of argument strings. When given it is parsed
            directly. When omitted, the source depends on the environment:
            inside a Jupyter kernel (IPython shell class ``ZMQInteractiveShell``)
            the notebook-level ``args_text`` string is split on whitespace and
            parsed; anywhere else ``sys.argv`` is parsed as usual.

    Returns:
        argparse.Namespace: the parsed options (dashes in option names become
        underscores, e.g. ``--epochs-1`` -> ``epochs_1``).
    """
    parser = argparse.ArgumentParser(description='Contrastive topic modeling')

    # Number of epochs for each stage.
    parser.add_argument('--epochs-1', default=50, type=int,
                        help='Number of training epochs for Stage 1')
    parser.add_argument('--epochs-2', default=10, type=int,
                        help='Number of training epochs for Stage 2')
    # Batch size shared by the stages.
    parser.add_argument('--bsz', type=int, default=64,
                        help='Batch size')
    # Which dataset(s) to load.
    parser.add_argument('--dataset', default='twitter', type=str,
                        choices=['twitter', 'reddit', 'youtube', 'all'],
                        help='Name of the dataset')
    # Cluster count defaults to 20; the topic count follows it unless given.
    parser.add_argument('--n-cluster', default=20, type=int,
                        help='Number of clusters')
    parser.add_argument('--n-topic', type=int,
                        help='Number of topics. If not specified, use same value as --n-cluster')
    # Vocabulary size defaults to 30000 words.
    parser.add_argument('--n-word', default=30000, type=int,
                        help='Number of words in vocabulary')

    parser.add_argument('--base-model', type=str,
                        help='Name of base model in huggingface library.')

    # NOTE(review): the notebook sets CUDA_VISIBLE_DEVICES to "0,1", so ids 2 and 3
    # from this default may not exist at runtime - confirm before using them.
    parser.add_argument('--gpus', default=[0, 1, 2, 3], type=int, nargs='+',
                        help='List of GPU numbers to use. Use 0 1 2 3 by default')

    parser.add_argument('--coeff-1-sim', default=1.0, type=float,
                        help='Coefficient for NN dot product similarity loss (Phase 1)')
    parser.add_argument('--coeff-1-dist', default=1.0, type=float,
                        help='Coefficient for NN SWD distribution loss (Phase 1)')
    # The notebook's configuration cell falls back to 1 / n_cluster when this is omitted.
    parser.add_argument('--dirichlet-alpha-1', type=float,
                        help='Parameter for Dirichlet distribution (Phase 1). Use 1/n_cluster by default.')

    parser.add_argument('--stage-1-ckpt', type=str,
                        help='Name of torch checkpoint file Stage 1. If this argument is given, skip Stage 1.')

    parser.add_argument('--coeff-2-recon', default=1.0, type=float,
                        help='Coefficient for VAE reconstruction loss (Phase 2)')
    parser.add_argument('--coeff-2-regul', default=1.0, type=float,
                        help='Coefficient for VAE KLD regularization loss (Phase 2)')
    parser.add_argument('--coeff-2-cons', default=1.0, type=float,
                        help='Coefficient for CL consistency loss (Phase 2)')
    parser.add_argument('--coeff-2-dist', default=1.0, type=float,
                        help='Coefficient for CL SWD distribution matching loss (Phase 2)')
    parser.add_argument('--dirichlet-alpha-2', type=float,
                        help='Parameter for Dirichlet distribution (Phase 2). Use same value as dirichlet-alpha-1 by default.')
    parser.add_argument('--stage-2-lr', default=2e-1, type=float,
                        help='Learning rate of phase 2')

    parser.add_argument('--stage-2-repeat', default=5, type=int,
                        help='Repetition count of phase 2')

    parser.add_argument('--result-file', type=str,
                        help='File name for result summary')
    parser.add_argument('--palmetto-dir', type=str,
                        help='Directory where palmetto JAR and the Wikipedia index are. For evaluation')

    # Explicit arguments always win; this also makes the parser usable from tests.
    if argv is not None:
        return parser.parse_args(args=list(argv))

    # Only a Jupyter notebook / qtconsole kernel counts as "in Jupyter"; terminal
    # IPython, other shells and plain Python all fall through to sys.argv.
    try:
        is_in_jupyter = get_ipython().__class__.__name__ == 'ZMQInteractiveShell'
    except NameError:
        is_in_jupyter = False

    if is_in_jupyter:
        return parser.parse_args(args=args_text.split())
    return parser.parse_args()

def _is_missing_text(text):
    """Return True for placeholders (``None`` or a float NaN) that stand in for a missing text."""
    return text is None or (isinstance(text, float) and math.isnan(text))


# Load the requested dataset(s), drop missing texts, and tag each text with its platform.
def data_load(dataset_name, sample_size=10000):
    """Load texts from one platform dataset, or from all three.

    Args:
        dataset_name: One of ``'twitter'``, ``'reddit'``, ``'youtube'`` or
            ``'all'`` (Twitter, then Reddit, then YouTube).
        sample_size: Forwarded as ``sample_size`` to each dataset constructor.

    Returns:
        tuple: ``(textData, should_measure_hungarian)`` where ``textData`` is a
        list of ``(text, platform_label)`` pairs with missing texts removed and
        ``should_measure_hungarian`` is always ``False``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    platforms = (('twitter', "Twitter"), ('reddit', "Reddit"), ('youtube', "YouTube"))
    selected = [(key, label) for key, label in platforms if dataset_name in (key, 'all')]
    if not selected:
        raise ValueError("Invalid dataset name!")

    # Looked up only after validation so an invalid name fails before any dataset is touched.
    dataset_classes = {'twitter': TwitterDataset, 'reddit': RedditDataset, 'youtube': YoutubeDataset}

    should_measure_hungarian = False
    textData = []
    for key, label in selected:
        dataset = dataset_classes[key](sample_size=sample_size)
        textData += [(text, label) for text in dataset.texts if not _is_missing_text(text)]

    return textData, should_measure_hungarian


In [6]:
args = _parse_args()

# Training schedule.
bsz, epochs_1, epochs_2 = args.bsz, args.epochs_1, args.epochs_2

# The topic count follows the cluster count unless --n-topic was given; the
# resolved value is written back onto args.
n_cluster = args.n_cluster
n_topic = n_cluster if args.n_topic is None else args.n_topic
args.n_topic = n_topic

textData, should_measure_hungarian = data_load(args.dataset)

ema_alpha = 0.99
n_word = args.n_word

# Dirichlet concentration: phase 1 falls back to 1 / n_cluster, phase 2 falls back to phase 1.
dirichlet_alpha_1 = (1 / n_cluster) if args.dirichlet_alpha_1 is None else args.dirichlet_alpha_1
dirichlet_alpha_2 = dirichlet_alpha_1 if args.dirichlet_alpha_2 is None else args.dirichlet_alpha_2

# Base model name, plus its last path component (the part after the final '/').
bert_name = args.base_model
bert_name_short = bert_name.rsplit('/', 1)[-1]
gpu_ids = args.gpus

# Stage 1 is skipped whenever a Stage-1 checkpoint was supplied.
skip_stage_1 = args.stage_1_ckpt is not None

In [7]:
# Notebook-level set of jargon / noise tokens (case-sensitive entries).
# The three rows presumably come from the Twitter, Reddit and YouTube corpora respectively - TODO confirm.
jargons = {
    'bigdata', 'RT', 'announces', 'rival', 'MSFT', 'launches', 'SEO', 'tweets', 'WSJ', 'unveils', 'MBA', 'machinelearning', 'HN', 'artificalintelligence', 'DALLE', 'NFT', 'NLP', 'edtech', 'cybersecurity', 'malware', 'CEO', 'ML', 'JPEG', 'GOOGL', 'UX', 'artificialintelligence', 'technews', 'fintech', 'tweet', 'CHATGPT', 'viral', 'blockchain', 'BARD', 'reportedly', 'BREAKING', 'deeplearning', 'educators', 'datascience', 'GOOG', 'competitor', 'FREE', 'startups', 'AGIX', 'CHAT', 'CNET', 'integrating', 'BB', 'maker', 'passes', 'NYT', 'yall', 'rescue', 'heres', 'NYC', 'LLM', 'airdrop', 'powered', 'podcast', 'Q', 'PM', 'newsletter', 'startup', 'TSLA', 'WIRED', 'digitalmarketing', 'CTO', 'changer', 'announced', 'database', 'launched', 'warns', 'popularity', 'AWS', 'nocode', 'trending', 'metaverse', 'cc', 'PR', 'RSS', 'MWC', 'USMLE', 'copywriting', 'marketers', 'tweeting', 'amid', 'AGI', 'socialmedia', 'webinar', 'agrees', 'invests', 'launch', 'killer', 'GM', 'bullish', 'edchat', 'RLHF', 'integration', 'fastestgrowing', 'CNN', 'exam',
    'deleted', 'gif', 'giphy', 'dm', 'removed', 'remindme', 'yup', 'sydney', 'yep', 'patched', 'nope', 'giphydownsized', 'vpn', 'ascii', 'ah', 'chadgpt', 'nerfed', 'jesus', 'xd', 'wtf', 'upvote', 'nah', 'op', 'mods', 'hahaha', 'nsfw', 'huh', 'holy', 'iq', 'jailbreak', 'blah', 'bruh', 'yea', 'agi', 'porn', 'waitlist', 'nerf', 'downvoted', 'refresh', 'omg', 'sus', 'characterai', 'meth', 'chinese', 'sub', 'rick', 'american', 'elon', 'sam', 'quack', 'youchat', 'uk', 'chad', 'archived', 'youcom', 'screenshot', 'llm', 'hitler', 'lmao', 'playground', 'rpg', 'delete', 'tldr', 'davinci', 'trump', 'hangman', 'haha', 'tay', 'karma', 'john', 'chatgtp', 'url', 'wokegpt', 'offended', 'fucked', 'redditor', 'ceo', 'agreed', 'emojis', 'cheers', 'ais', 'tag', 'wow', 'lmfao', 'p', 'rip', 'chats', 'hmm', 'bypass', 'llms', 'temperature', 'login', 'cgpt', 'windows', 'novelai', 'biden', 'donald', 'christmas', 'ms', 'cringe',
    'ZRONX', 'rook', 'thumbnail', 'vid', 'bhai', 'bishop', 'circle', 'subscribed', 'quot', 'bless', 'tutorial', 'XD', 'sir', 'GEMX', 'profitable', 'earning', 'quotquot', 'enjoyed', 'ur', 'bra', 'JIM', 'broker', 'levy', 'vids', 'stare', 'tutorials', 'subscribers', 'sponsor', 'hai', 'lifechanging', 'curve', 'shorts', 'earn', 'trader', 'PC', 'folders', 'informative', 'br', 'chess', 'jontron', 'brother', 'T', 'YT', 'upload', 'O', 'subscriber', 'intro', 'DAN', 'aint', 'download', 'LOL', 'shes', 'moves', 'telegram', 'shortlisted', 'liked', 'websiteapp', 'watched', 'grant', 'plz', 'KINGDOM', 'YOU', 'MESSIAH', 'mate', 'ki', 'subs', 'pawn', 'hes', 'U', 'HACKBANZER', 'ka', 'brbr', 'affiliate', 'clip', 'beast', 'trade', 'ive', 'ho', 'approved', 'bhi', 'gotta', 'profits', 'wanna', 'subscribe', 'funds', 'labels', 'recommended', 'audio', 'uploaded', 'appreciated', 'UBI', 'pls', 'upto', 'alot', 'twist', 'GTP', 'accent', 'monetized', 'S', 'btw',
}

In [8]:
from sklearn.model_selection import train_test_split

# Instantiate each platform dataset (see data.py).
twitter_ds = TwitterDataset()
reddit_ds = RedditDataset()
youtube_ds = YoutubeDataset()

# Fixed seed: without it the 70/30 split - and every vocabulary, BoW matrix and
# nearest-neighbour index derived from it - changes on each kernel restart.
SPLIT_SEED = 42

train_twitter_texts, test_twitter_texts = train_test_split(
    twitter_ds.texts, train_size=0.7, test_size=0.3, random_state=SPLIT_SEED)
train_reddit_texts, test_reddit_texts = train_test_split(
    reddit_ds.texts, train_size=0.7, test_size=0.3, random_state=SPLIT_SEED)
train_youtube_texts, test_youtube_texts = train_test_split(
    youtube_ds.texts, train_size=0.7, test_size=0.3, random_state=SPLIT_SEED)

# Concatenate in Twitter, Reddit, YouTube order. list() keeps `+` a concatenation
# even if a dataset exposes its texts as an array-like rather than a plain list.
train_total_text_list = list(train_twitter_texts) + list(train_reddit_texts) + list(train_youtube_texts)

test_total_text_list = list(test_twitter_texts) + list(test_reddit_texts) + list(test_youtube_texts)

In [9]:
class BertDataset(Dataset):
    """Tokenized transformer inputs paired with TF-IDF word distributions.

    Each item is ``(org_input, bow)``: the tokenizer output for one cleaned text
    (``input_ids`` / ``attention_mask`` squeezed, padded or truncated to 512
    tokens) and that text's TF-IDF vector normalised to sum to one.

    Note: this definition shadows the ``BertDataset`` name imported from the
    project's ``data`` module earlier in the notebook.

    Args:
        bert: Model name or path handed to ``AutoTokenizer.from_pretrained``.
        text_list: Iterable of raw text strings; empty strings are dropped.
        N_word: ``max_features`` for the TF-IDF vectorizer.
        vectorizer: Optional already-fitted vectorizer to reuse. When ``None`` a
            new ``TfidfVectorizer`` is fitted on the preprocessed texts.
        lemmatize: When true, tokens are lemmatized with NLTK's
            ``WordNetLemmatizer`` during preprocessing.
    """

    def __init__(self, bert, text_list, N_word, vectorizer=None, lemmatize=False):
        self.lemmatize = lemmatize
        # Empty strings are dropped before cleaning, so dataset indices refer to
        # this filtered list rather than to ``text_list``.
        self.nonempty_text = [self._clean_text(text) for text in text_list if len(text) > 0]

        self.jargons = set(['bigdata', 'RT', 'announces', 'rival', 'MSFT', 'launches', 'SEO', 'tweets', 'WSJ', 'unveils', 'MBA', 'machinelearning', 'HN', 'artificalintelligence', 'DALLE', 'NFT', 'NLP', 'edtech', 'cybersecurity', 'malware', 'CEO', 'ML', 'JPEG', 'GOOGL', 'UX', 'artificialintelligence', 'technews', 'fintech', 'tweet', 'CHATGPT', 'viral', 'blockchain', 'BARD', 'reportedly', 'BREAKING', 'deeplearning', 'educators', 'datascience', 'GOOG', 'competitor', 'FREE', 'startups', 'AGIX', 'CHAT', 'CNET', 'integrating', 'BB', 'maker', 'passes', 'NYT', 'yall', 'rescue', 'heres', 'NYC', 'LLM', 'airdrop', 'powered', 'podcast', 'Q', 'PM', 'newsletter', 'startup', 'TSLA', 'WIRED', 'digitalmarketing', 'CTO', 'changer', 'announced', 'database', 'launched', 'warns', 'popularity', 'AWS', 'nocode', 'trending', 'metaverse', 'cc', 'PR', 'RSS', 'MWC', 'USMLE', 'copywriting', 'marketers', 'tweeting', 'amid', 'AGI', 'socialmedia', 'webinar', 'agrees', 'invests', 'launch', 'killer', 'GM', 'bullish', 'edchat', 'RLHF', 'integration', 'fastestgrowing', 'CNN', 'exam',
                  'deleted', 'gif', 'giphy', 'dm', 'removed', 'remindme', 'yup', 'sydney', 'yep', 'patched', 'nope', 'giphydownsized', 'vpn', 'ascii', 'ah', 'chadgpt', 'nerfed', 'jesus', 'xd', 'wtf', 'upvote', 'nah', 'op', 'mods', 'hahaha', 'nsfw', 'huh', 'holy', 'iq', 'jailbreak', 'blah', 'bruh', 'yea', 'agi', 'porn', 'waitlist', 'nerf', 'downvoted', 'refresh', 'omg', 'sus', 'characterai', 'meth', 'chinese', 'sub', 'rick', 'american', 'elon', 'sam', 'quack', 'youchat', 'uk', 'chad', 'archived', 'youcom', 'screenshot', 'llm', 'hitler', 'lmao', 'playground', 'rpg', 'delete', 'tldr', 'davinci', 'trump', 'hangman', 'haha', 'tay', 'karma', 'john', 'chatgtp', 'url', 'wokegpt', 'offended', 'fucked', 'redditor', 'ceo', 'agreed', 'emojis', 'cheers', 'ais', 'tag', 'wow', 'lmfao', 'p', 'rip', 'chats', 'hmm', 'bypass', 'llms', 'temperature', 'login', 'cgpt', 'windows', 'novelai', 'biden', 'donald', 'christmas', 'ms', 'cringe',
                   'ZRONX', 'rook', 'thumbnail', 'vid', 'bhai', 'bishop', 'circle', 'subscribed', 'quot', 'bless', 'tutorial', 'XD', 'sir', 'GEMX', 'profitable', 'earning', 'quotquot', 'enjoyed', 'ur', 'bra', 'JIM', 'broker', 'levy', 'vids', 'stare', 'tutorials', 'subscribers', 'sponsor', 'hai', 'lifechanging', 'curve', 'shorts', 'earn', 'trader', 'PC', 'folders', 'informative', 'br', 'chess', 'jontron', 'brother', 'T', 'YT', 'upload', 'O', 'subscriber', 'intro', 'DAN', 'aint', 'download', 'LOL', 'shes', 'moves', 'telegram', 'shortlisted', 'liked', 'websiteapp', 'watched', 'grant', 'plz', 'KINGDOM', 'YOU', 'MESSIAH', 'mate', 'ki', 'subs', 'pawn', 'hes', 'U', 'HACKBANZER', 'ka', 'brbr', 'affiliate', 'clip', 'beast', 'trade', 'ive', 'ho', 'approved', 'bhi', 'gotta', 'profits', 'wanna', 'subscribe', 'funds', 'labels', 'recommended', 'audio', 'uploaded', 'appreciated', 'UBI', 'pls', 'upto', 'alot', 'twist', 'GTP', 'accent', 'monetized', 'S', 'btw'
                  ])

        self.tokenizer = AutoTokenizer.from_pretrained(bert)
        self.stopwords_list = self._build_stopwords(
            TfidfVectorizer(stop_words="english").get_stop_words(), self.jargons)
        self.N_word = N_word

        if vectorizer is None:
            self.vectorizer = TfidfVectorizer(stop_words=None, max_features=self.N_word, token_pattern=r'\b[a-zA-Z]{2,}\b')
            self.vectorizer.fit(self.preprocess_ctm(self.nonempty_text))
        else:
            self.vectorizer = vectorizer

        self.org_list = []
        self.bow_list = []
        for sent in tqdm(self.nonempty_text):
            org_input = self.tokenizer(sent, padding='max_length', truncation=True, max_length=512, return_tensors='pt')
            # Remove the singleton batch dimension added by return_tensors='pt'.
            org_input['input_ids'] = torch.squeeze(org_input['input_ids'])
            org_input['attention_mask'] = torch.squeeze(org_input['attention_mask'])
            self.org_list.append(org_input)
            self.bow_list.append(self.vectorize(sent))

    @staticmethod
    def _clean_text(text):
        """Collapse whitespace and strip e-mail-like tokens and single quotes from one text."""
        text = re.sub("\n", " ", text)            # new lines -> spaces
        text = re.sub(r'\S*@\S*\s?', '', text)    # tokens containing '@' (e-mails, handles)
        text = re.sub(r'\s+', ' ', text)          # runs of whitespace -> one space
        return re.sub("\'", "", text)             # distracting single quotes

    @staticmethod
    def _build_stopwords(english_stop_words, jargons):
        """Merge the English stop words with the jargon terms, lower-casing the jargon.

        ``preprocess_ctm`` lower-cases every document before filtering, so an
        upper-case entry such as ``'RT'`` could never match unless it is
        normalised the same way.
        """
        return set(english_stop_words).union(word.lower() for word in jargons)

    def vectorize(self, text):
        """Return the TF-IDF vector of ``text`` as a 1-D float tensor that sums to one."""
        text = self.preprocess_ctm([text])
        vectorized_input = self.vectorizer.transform(text)
        vectorized_input = vectorized_input.toarray()
        vectorized_input = vectorized_input.astype(np.float64)

        # Get word distribution from BoW; the small constant keeps an all-zero
        # row from dividing by zero.
        vectorized_input += 1e-8
        vectorized_input = vectorized_input / vectorized_input.sum(axis=1, keepdims=True)
        assert abs(vectorized_input.sum() - vectorized_input.shape[0]) < 0.01
        vectorized_label = torch.tensor(vectorized_input, dtype=torch.float)
        return vectorized_label[0]

    def preprocess_ctm(self, documents):
        """Lower-case, replace punctuation with spaces, drop stop words and optionally lemmatize.

        Args:
            documents: Iterable of text strings.

        Returns:
            list[str]: One whitespace-joined string of kept tokens per document.
        """
        punctuation_to_space = str.maketrans(string.punctuation, ' ' * len(string.punctuation))
        preprocessed_docs = []
        for doc in documents:
            tokens = doc.lower().translate(punctuation_to_space).split()
            preprocessed_docs.append(' '.join(w for w in tokens if w not in self.stopwords_list))
        if self.lemmatize:
            # Imported here so the class does not rely on a star-import elsewhere
            # in the notebook happening to expose this name.
            from nltk.stem import WordNetLemmatizer
            lemmatizer = WordNetLemmatizer()
            preprocessed_docs = [' '.join(lemmatizer.lemmatize(w) for w in doc.split())
                                 for doc in preprocessed_docs]
        return preprocessed_docs

    def __len__(self):
        return len(self.nonempty_text)

    def __getitem__(self, idx):
        return self.org_list[idx], self.bow_list[idx]

In [10]:
# Build the train and test datasets; each call tokenizes its texts and, because
# vectorizer=None, fits its own TF-IDF vectorizer on them.
# NOTE(review): the test set therefore gets a vocabulary of its own, so train and
# test BoW columns are not aligned - confirm this is intended, or pass
# vectorizer=trainds.vectorizer for the test set.
trainds = BertDataset(bert=bert_name, text_list=train_total_text_list, N_word=n_word, vectorizer=None, lemmatize=True)

testds = BertDataset(bert=bert_name, text_list=test_total_text_list, N_word=n_word, vectorizer=None, lemmatize=True)

100%|██████████| 105000/105000 [03:30<00:00, 499.48it/s]
100%|██████████| 45000/45000 [01:24<00:00, 532.75it/s]


In [11]:
# Build the dense BoW (TF-IDF) matrices from the datasets' own cleaned, non-empty
# texts, run through the same preprocessing the vectorizers were fitted with.
# This keeps row i aligned with dataset item i (the datasets drop empty strings)
# and makes the tokens match the fitted vocabulary (stop words removed, lemmatized).
# Memory: the training matrix is dense float64 - 25.2 GB at (105000, 30000).
total_bow_matrix = trainds.vectorizer.transform(trainds.preprocess_ctm(trainds.nonempty_text)).toarray()

test_bow_matrix = testds.vectorizer.transform(testds.preprocess_ctm(testds.nonempty_text)).toarray()

In [12]:
# Sanity check: dimensions of the training BoW matrix.
total_bow_matrix.shape

(105000, 30000)

In [13]:
# Report the training BoW matrix's dimensions and its in-memory size in bytes.
print("배열의 크기:", total_bow_matrix.shape)
print("배열의 메모리 사용량:", total_bow_matrix.nbytes, "bytes")

배열의 크기: (105000, 30000)
배열의 메모리 사용량: 25200000000 bytes


In [14]:
# Re-derive n_word from the column count of the training BoW matrix; this
# replaces the value previously copied from args.n_word.
n_word = total_bow_matrix.shape[1]

# Reformulate the BoW

In [15]:
def dist_match_loss(hiddens, alpha=1.0):
    """Distribution-matching loss of ``hiddens`` under a freshly drawn random projection.

    A square standard-normal matrix with one row and column per hidden feature
    is QR-decomposed; the Q factor is moved to the device of ``hiddens`` and
    handed to ``get_swd_loss`` together with ``alpha``.

    Args:
        hiddens: Tensor whose last dimension is the feature dimension.
        alpha: Forwarded to ``get_swd_loss``.

    Returns:
        The value returned by ``get_swd_loss``.
    """
    n_features = hiddens.shape[-1]
    # Only the Q factor is needed; the triangular factor is discarded.
    q_factor, _ = qr(np.random.randn(n_features, n_features))
    projection = torch.Tensor(q_factor).to(hiddens.device)
    return get_swd_loss(hiddens, projection, alpha)


def js_div_loss(hidden1, hidden2):
    """Symmetric KL-to-mixture loss between two batches of probability vectors.

    With ``m = 0.5 * (hidden1 + hidden2)`` this returns
    ``KL(hidden1 || m) + KL(hidden2 || m)`` using the ``'batchmean'`` reduction,
    i.e. twice the usual Jensen-Shannon divergence (no 0.5 factor is applied).

    The KL terms are computed with ``F.kl_div`` directly instead of a
    notebook-level ``kldiv`` criterion, so the function no longer depends on a
    global that is only created in a later cell.

    Args:
        hidden1: Tensor of probabilities.
        hidden2: Tensor of probabilities with the same shape as ``hidden1``.

    Returns:
        Scalar tensor holding the loss.
    """
    mixture_log = (0.5 * (hidden1 + hidden2)).log()
    return (F.kl_div(mixture_log, hidden1, reduction='batchmean')
            + F.kl_div(mixture_log, hidden2, reduction='batchmean'))


def get_swd_loss(states, rand_w, alpha=1.0):
    """Compare projected ``states`` with projected samples from a symmetric Dirichlet prior.

    Both ``states`` and one Dirichlet draw per row of ``states`` are multiplied
    by ``rand_w``; along every projected direction the values are sorted over
    the batch, and the squared differences between the two sorted sets are
    summed over directions and averaged over the batch positions.
    (Presumably "SWD" stands for sliced Wasserstein distance - TODO confirm.)

    Args:
        states: 2-D tensor, one row per sample.
        rand_w: Projection matrix applied on the right of ``states``.
        alpha: Concentration used for every component of the Dirichlet prior.

    Returns:
        Scalar tensor holding the loss.
    """
    n_samples, n_dims = states.shape[0], states.shape[1]
    target_device = states.device

    # One row per projected direction, values sorted across the batch.
    projected_sorted = torch.sort(torch.matmul(states, rand_w).t(), dim=1).values

    # Prior samples: one Dirichlet(alpha, ..., alpha) draw per sample, projected
    # and sorted exactly like the states.
    prior_draws = np.random.dirichlet([alpha] * n_dims, n_samples)
    prior = torch.matmul(torch.Tensor(prior_draws).to(target_device), rand_w)
    prior_sorted = torch.sort(prior.t(), dim=1).values

    squared_gap = (prior_sorted - projected_sorted) ** 2
    return torch.mean(torch.sum(squared_gap, dim=0))

In [16]:
from sklearn.metrics.pairwise import cosine_similarity

In [17]:
def compute_max_cosine_similarity_indices(total_bow_matrix, batch_size=500):
    """For every row, return the index of the other row with the highest cosine similarity.

    Rows are processed ``batch_size`` at a time so that only a
    ``(batch_size, n_rows)`` block of similarities is held at once. Row norms
    are computed a single time up front rather than re-normalising the whole
    matrix for every batch.

    Args:
        total_bow_matrix: 2-D dense array or SciPy sparse matrix, one row per document.
        batch_size: Number of rows compared against the full matrix per step.

    Returns:
        numpy.ndarray: int64 array of length ``n_rows``; entry ``i`` is the index
        of the row most similar to row ``i``, with row ``i`` itself excluded.
        All-zero rows have similarity 0 to every row.
    """
    from scipy.sparse import issparse

    # Same object as the notebook-level ``logger`` (the root logger), looked up
    # locally so the function does not depend on that global.
    log = logging.getLogger()

    n_rows = total_bow_matrix.shape[0]
    max_similarity_indices = np.zeros(n_rows, dtype=np.int64)

    # L2 norm of every row, computed once. Zero norms are replaced by 1 so that
    # all-zero rows produce similarity 0 instead of a division by zero.
    if issparse(total_bow_matrix):
        squared_norms = np.asarray(total_bow_matrix.multiply(total_bow_matrix).sum(axis=1)).ravel()
    else:
        squared_norms = np.einsum('ij,ij->i', total_bow_matrix, total_bow_matrix)
    row_norms = np.sqrt(squared_norms.astype(np.float64))
    row_norms[row_norms == 0] = 1.0

    for start_idx in range(0, n_rows, batch_size):
        end_idx = min(start_idx + batch_size, n_rows)

        dots = total_bow_matrix[start_idx:end_idx] @ total_bow_matrix.T
        if issparse(dots):
            dots = dots.toarray()
        batch_similarity = np.asarray(dots, dtype=np.float64)
        batch_similarity /= row_norms[start_idx:end_idx, None]
        batch_similarity /= row_norms[None, :]

        # Exclude each row's similarity with itself.
        batch_similarity[np.arange(end_idx - start_idx), np.arange(start_idx, end_idx)] = -1

        # Index of the largest remaining similarity in every row.
        max_similarity_indices[start_idx:end_idx] = np.argmax(batch_similarity, axis=1)

        log.info(f"{end_idx}/{n_rows} 데이터 처리 완료")

    return max_similarity_indices

In [None]:
# For every document, find the index of its most cosine-similar other document.
# Despite the variable names, each result is a 1-D array of row indices rather
# than a similarity matrix.
similarity_matrix = compute_max_cosine_similarity_indices(total_bow_matrix)

test_similarity_matrix = compute_max_cosine_similarity_indices(test_bow_matrix)

In [None]:
# Sanity check: shape of the training nearest-neighbour index array.
similarity_matrix.shape

In [None]:
class Stage2Dataset(Dataset):
    """Pre-computed encoder embeddings and BoW distributions with a positive pair per item.

    Item ``idx`` is ``(idx, embedding[idx], embedding[pos], bow[idx], bow[pos])``
    where ``pos = similarity_matrix[idx]``.

    Args:
        encoder: Module called as ``encoder(input_ids=..., attention_mask=...)``
            whose output exposes ``'pooler_output'``.
        ds: Source dataset providing ``org_list`` (tokenized inputs) and
            ``nonempty_text`` (the matching cleaned texts).
        similarity_matrix: Indexable by item position; entry ``i`` is the
            dataset index used as the positive sample for item ``i``.
        N_word: ``max_features`` for the TF-IDF vectorizer fitted here.
        k: Unused; kept so existing callers keep working.
        lemmatize: When true, tokens are lemmatized with NLTK's
            ``WordNetLemmatizer`` during preprocessing.
    """

    def __init__(self, encoder, ds, similarity_matrix, N_word, k=1, lemmatize=False):
        self.lemmatize = lemmatize
        self.ds = ds
        self.org_list = self.ds.org_list
        self.nonempty_text = self.ds.nonempty_text
        self.N_word = N_word

        self.jargons = set(['bigdata', 'RT', 'announces', 'rival', 'MSFT', 'launches', 'SEO', 'tweets', 'WSJ', 'unveils', 'MBA', 'machinelearning', 'HN', 'artificalintelligence', 'DALLE', 'NFT', 'NLP', 'edtech', 'cybersecurity', 'malware', 'CEO', 'ML', 'JPEG', 'GOOGL', 'UX', 'artificialintelligence', 'technews', 'fintech', 'tweet', 'CHATGPT', 'viral', 'blockchain', 'BARD', 'reportedly', 'BREAKING', 'deeplearning', 'educators', 'datascience', 'GOOG', 'competitor', 'FREE', 'startups', 'AGIX', 'CHAT', 'CNET', 'integrating', 'BB', 'maker', 'passes', 'NYT', 'yall', 'rescue', 'heres', 'NYC', 'LLM', 'airdrop', 'powered', 'podcast', 'Q', 'PM', 'newsletter', 'startup', 'TSLA', 'WIRED', 'digitalmarketing', 'CTO', 'changer', 'announced', 'database', 'launched', 'warns', 'popularity', 'AWS', 'nocode', 'trending', 'metaverse', 'cc', 'PR', 'RSS', 'MWC', 'USMLE', 'copywriting', 'marketers', 'tweeting', 'amid', 'AGI', 'socialmedia', 'webinar', 'agrees', 'invests', 'launch', 'killer', 'GM', 'bullish', 'edchat', 'RLHF', 'integration', 'fastestgrowing', 'CNN', 'exam',
                  'deleted', 'gif', 'giphy', 'dm', 'removed', 'remindme', 'yup', 'sydney', 'yep', 'patched', 'nope', 'giphydownsized', 'vpn', 'ascii', 'ah', 'chadgpt', 'nerfed', 'jesus', 'xd', 'wtf', 'upvote', 'nah', 'op', 'mods', 'hahaha', 'nsfw', 'huh', 'holy', 'iq', 'jailbreak', 'blah', 'bruh', 'yea', 'agi', 'porn', 'waitlist', 'nerf', 'downvoted', 'refresh', 'omg', 'sus', 'characterai', 'meth', 'chinese', 'sub', 'rick', 'american', 'elon', 'sam', 'quack', 'youchat', 'uk', 'chad', 'archived', 'youcom', 'screenshot', 'llm', 'hitler', 'lmao', 'playground', 'rpg', 'delete', 'tldr', 'davinci', 'trump', 'hangman', 'haha', 'tay', 'karma', 'john', 'chatgtp', 'url', 'wokegpt', 'offended', 'fucked', 'redditor', 'ceo', 'agreed', 'emojis', 'cheers', 'ais', 'tag', 'wow', 'lmfao', 'p', 'rip', 'chats', 'hmm', 'bypass', 'llms', 'temperature', 'login', 'cgpt', 'windows', 'novelai', 'biden', 'donald', 'christmas', 'ms', 'cringe',
                   'ZRONX', 'rook', 'thumbnail', 'vid', 'bhai', 'bishop', 'circle', 'subscribed', 'quot', 'bless', 'tutorial', 'XD', 'sir', 'GEMX', 'profitable', 'earning', 'quotquot', 'enjoyed', 'ur', 'bra', 'JIM', 'broker', 'levy', 'vids', 'stare', 'tutorials', 'subscribers', 'sponsor', 'hai', 'lifechanging', 'curve', 'shorts', 'earn', 'trader', 'PC', 'folders', 'informative', 'br', 'chess', 'jontron', 'brother', 'T', 'YT', 'upload', 'O', 'subscriber', 'intro', 'DAN', 'aint', 'download', 'LOL', 'shes', 'moves', 'telegram', 'shortlisted', 'liked', 'websiteapp', 'watched', 'grant', 'plz', 'KINGDOM', 'YOU', 'MESSIAH', 'mate', 'ki', 'subs', 'pawn', 'hes', 'U', 'HACKBANZER', 'ka', 'brbr', 'affiliate', 'clip', 'beast', 'trade', 'ive', 'ho', 'approved', 'bhi', 'gotta', 'profits', 'wanna', 'subscribe', 'funds', 'labels', 'recommended', 'audio', 'uploaded', 'appreciated', 'UBI', 'pls', 'upto', 'alot', 'twist', 'GTP', 'accent', 'monetized', 'S', 'btw'
                  ])

        self.stopwords_list = self._build_stopwords(
            TfidfVectorizer(stop_words="english").get_stop_words(), self.jargons)

        self.vectorizer = TfidfVectorizer(stop_words=None, max_features=self.N_word, token_pattern=r'\b[a-zA-Z]{2,}\b')
        self.vectorizer.fit(self.preprocess_ctm(self.nonempty_text))
        self.bow_list = []
        for sent in tqdm(self.nonempty_text):
            self.bow_list.append(self.vectorize(sent))

        # Positive-sample lookup: position -> index of the paired item.
        self.pos_dict = similarity_matrix

        # Pre-compute one embedding per item. The pass runs in eval mode and
        # without autograd bookkeeping, because the results are detached and
        # stored; the encoder's previous train/eval mode is restored afterwards.
        self.embedding_list = []
        encoder_device = next(encoder.parameters()).device
        was_training = encoder.training
        encoder.eval()
        try:
            with torch.no_grad():
                for org_input in tqdm(self.org_list):
                    org_input_ids = org_input['input_ids'].to(encoder_device).reshape(1, -1)
                    org_attention_mask = org_input['attention_mask'].to(encoder_device).reshape(1, -1)
                    embedding = encoder(input_ids=org_input_ids, attention_mask=org_attention_mask)
                    self.embedding_list.append(embedding['pooler_output'].squeeze().detach().cpu())
        finally:
            encoder.train(was_training)

    @staticmethod
    def _build_stopwords(english_stop_words, jargons):
        """Merge the English stop words with the jargon terms, lower-casing the jargon.

        ``preprocess_ctm`` lower-cases every document before filtering, so an
        upper-case entry such as ``'RT'`` could never match unless it is
        normalised the same way.
        """
        return set(english_stop_words).union(word.lower() for word in jargons)

    def __len__(self):
        return len(self.org_list)

    def preprocess_ctm(self, documents):
        """Lower-case, replace punctuation with spaces, drop stop words and optionally lemmatize.

        Args:
            documents: Iterable of text strings.

        Returns:
            list[str]: One whitespace-joined string of kept tokens per document.
        """
        punctuation_to_space = str.maketrans(string.punctuation, ' ' * len(string.punctuation))
        preprocessed_docs = []
        for doc in documents:
            tokens = doc.lower().translate(punctuation_to_space).split()
            preprocessed_docs.append(' '.join(w for w in tokens if w not in self.stopwords_list))
        if self.lemmatize:
            # Imported here so the class does not rely on a star-import elsewhere
            # in the notebook happening to expose this name.
            from nltk.stem import WordNetLemmatizer
            lemmatizer = WordNetLemmatizer()
            preprocessed_docs = [' '.join(lemmatizer.lemmatize(w) for w in doc.split())
                                 for doc in preprocessed_docs]
        return preprocessed_docs

    def vectorize(self, text):
        """Return the TF-IDF vector of ``text`` as a 1-D float tensor that sums to one."""
        text = self.preprocess_ctm([text])
        vectorized_input = self.vectorizer.transform(text)
        vectorized_input = vectorized_input.toarray().astype(np.float64)

        # Get word distribution from BoW. Only an all-zero vector is smoothed
        # (it becomes uniform); non-empty vectors keep their exact zeros.
        if vectorized_input.sum() == 0:
            vectorized_input += 1e-8
        vectorized_input = vectorized_input / vectorized_input.sum(axis=1, keepdims=True)
        assert abs(vectorized_input.sum() - vectorized_input.shape[0]) < 0.01

        vectorized_label = torch.tensor(vectorized_input, dtype=torch.float)
        return vectorized_label[0]

    def __getitem__(self, idx):
        pos_idx = self.pos_dict[idx]
        return idx, self.embedding_list[idx], self.embedding_list[pos_idx], self.bow_list[idx], self.bow_list[pos_idx]


In [None]:
# Instantiate the topic model; its encoder is used below to pre-compute embeddings.
# NOTE(review): bert_dim is hard-coded to 768 - confirm it matches the hidden
# size of the encoder selected by --base-model.
model =  ContBertTopicExtractorAE(N_topic=n_topic, N_word=args.n_word, bert=bert_name, bert_dim=768)

In [None]:
# Release cached, currently unused GPU memory held by PyTorch.
torch.cuda.empty_cache()

In [None]:
# Build the Stage-2 datasets: each pre-computes encoder embeddings and BoW
# vectors and pairs every item with its nearest neighbour from the index arrays.
finetuneds = Stage2Dataset(model.encoder, trainds, similarity_matrix, n_word, lemmatize=True)  

testfinetuneds = Stage2Dataset(model.encoder, testds, test_similarity_matrix, n_word, lemmatize=True) 

# KL-divergence criterion with 'batchmean' reduction, kept as a notebook-level global.
kldiv = torch.nn.KLDivLoss(reduction='batchmean')
# Vocabulary lookups: token -> column index, and the reverse column index -> token.
vocab_dict = finetuneds.vectorizer.vocabulary_
vocab_dict_reverse = {i:v for v, i in vocab_dict.items()}
print(n_word)

In [None]:
# Sanity check: number of BoW vectors in the Stage-2 training set
# (len(finetuneds.embedding_list) can be inspected the same way).
len(finetuneds.bow_list)

# Stage 3

In [None]:
def measure_hungarian_score(topic_dist, train_target):
    """Accuracy of arg-max topic predictions after remapping predicted labels onto target labels.

    The predicted label of each sample is the arg-max over axis 1 of
    ``topic_dist``. ``_hungarian_match`` supplies ``(predicted, target)`` label
    pairs; every prediction is rewritten to its paired target label and compared
    with ``train_target``.

    Args:
        topic_dist: 2-D array of per-sample topic scores (samples x classes).
        train_target: Ground-truth label per sample.

    Returns:
        float: Fraction of samples whose remapped prediction equals its target.
    """
    predicted = torch.tensor(np.argmax(topic_dist, axis=1))
    target = torch.tensor(train_target).to(predicted.device)
    n_samples, n_classes = predicted.shape[0], topic_dist.shape[1]

    label_pairs = _hungarian_match(predicted, target, n_samples, n_classes)

    # Predictions whose label appears in no pair keep the initial value 0.
    remapped = torch.zeros(n_samples).to(predicted.device)
    for predicted_label, target_label in label_pairs:
        remapped[predicted == predicted_label] = int(target_label)

    n_correct = int((remapped == target.float()).sum())
    return n_correct / float(n_samples)

In [None]:
# Turn the Hungarian-matching accuracy measurement on, overriding the False
# value that data_load() returned earlier.
should_measure_hungarian = True

# Main

In [None]:
from sklearn.metrics import mutual_info_score
from scipy.stats import chi2_contingency 
from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora.dictionary import Dictionary

# Stage-2 training followed by topic-coherence evaluation on the test corpus.
# NOTE(review): this forces a single run regardless of the value already held
# in args.stage_2_repeat.
args.stage_2_repeat = 1
# Only the disabled get_topic_qualities snippet at the end of the loop appends here.
results_list = []

for repeat_idx in range(args.stage_2_repeat):
    # Fresh model per repeat. The topic-word matrix `beta` is replaced by a
    # Xavier-initialised (N_topic, n_word) parameter and `beta_batchnorm` by an
    # empty nn.Sequential, which returns its input unchanged.
    # NOTE(review): the constructor receives args.n_word while beta is sized
    # with n_word -- confirm the two agree.
    model = ContBertTopicExtractorAE(N_topic=n_topic, N_word=args.n_word, bert=bert_name, bert_dim=768)
    model.beta = nn.Parameter(torch.Tensor(model.N_topic, n_word))
    nn.init.xavier_uniform_(model.beta)
    model.beta_batchnorm = nn.Sequential()
    model.cuda(gpu_ids[0])

    # Loss meters are created once per repeat and never reset, so the values
    # printed after each epoch are averages over every batch seen so far
    # (assuming AverageMeter accumulates until explicitly reset).
    losses = AverageMeter()
    dlosses = AverageMeter()  # not updated anywhere in this cell
    rlosses = AverageMeter()
    closses = AverageMeter()
    distlosses = AverageMeter()

    # Loaders (marked as "modified" by the original author).
    # memoryloader and memory_queue are created but not consumed in this cell.
    trainloader = DataLoader(finetuneds, batch_size=bsz, shuffle=True, num_workers=0)
    memoryloader = DataLoader(finetuneds, batch_size=bsz * 2, shuffle=False, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.stage_2_lr)

    memory_queue = F.softmax(torch.randn(512, n_topic).cuda(gpu_ids[0]), dim=1)
    print("Coeff   / regul: {:.5f} - recon: {:.5f} - c: {:.5f} - dist: {:.5f} ".format(
        args.coeff_2_regul,
        args.coeff_2_recon,
        args.coeff_2_cons,
        args.coeff_2_dist,
    ))

    # NOTE(review): the epoch count is hard-coded; args_text passes
    # `--epochs-2 50`, so this presumably mirrors that option -- confirm
    # before switching to the parsed value.
    for epoch in range(50):
        # Train mode for the model as a whole, but keep the encoder sub-module
        # in eval mode.
        model.train()
        model.encoder.eval()
        for batch_idx, batch in enumerate(trainloader):
            _, org_input, pos_input, org_bow, pos_bow = batch
            org_input = org_input.cuda(gpu_ids[0])
            org_bow = org_bow.cuda(gpu_ids[0])
            pos_input = pos_input.cuda(gpu_ids[0])
            pos_bow = pos_bow.cuda(gpu_ids[0])

            # Actual number of samples in this batch; the final batch can be
            # smaller than `bsz` because the loader does not drop it.
            batch_size = org_input.size(0)

            org_dists, org_topic_logit = model.decode(org_input)
            pos_dists, pos_topic_logit = model.decode(pos_input)

            org_topic = F.softmax(org_topic_logit, dim=1)
            pos_topic = F.softmax(pos_topic_logit, dim=1)

            # Reconstruction loss: binary cross-entropy between the decoded
            # word distributions and the bag-of-words targets, summed over
            # the vocabulary axis and averaged over the batch, for both the
            # original and the paired sample; the total is then halved.
            recons_loss = torch.mean(-torch.sum(torch.log(org_dists + 1E-10) * (org_bow), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log((1 - org_dists) + 1E-10) * (1 - org_bow), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log(pos_dists + 1E-10) * (pos_bow), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log((1 - pos_dists) + 1E-10) * (1 - pos_bow), axis=1), axis=0)
            recons_loss *= 0.5

            # Consistency loss: negative mean dot product between the topic
            # distributions of each sample and its pair.
            pos_sim = torch.sum(org_topic * pos_topic, dim=-1)
            cons_loss = -pos_sim.mean()

            # Distribution loss on the original samples' topic distributions.
            # torch.cat needs a sequence of tensors, hence the 1-tuple.
            distmatch_loss = dist_match_loss(torch.cat((org_topic,), dim=0), dirichlet_alpha_2)

            loss = args.coeff_2_recon * recons_loss + \
                   args.coeff_2_cons * cons_loss + \
                   args.coeff_2_dist * distmatch_loss

            # Weight each meter update by the real batch size rather than the
            # configured `bsz`, so a short final batch is not over-counted.
            losses.update(loss.item(), batch_size)
            closses.update(cons_loss.item(), batch_size)
            rlosses.update(recons_loss.item(), batch_size)
            distlosses.update(distmatch_loss.item(), batch_size)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print("Epoch-{} / recon: {:.5f} - dist: {:.5f} - cons: {:.5f}".format(epoch, rlosses.avg, distlosses.avg, closses.avg))

    print("------- Evaluation results -------")
    # Word set of each topic: the 10 highest-weighted vocabulary indices in
    # every row of beta, mapped back to words.
    all_list = {}
    for topic_idx, top_indices in enumerate(model.beta.cpu().topk(10, dim=1).indices):
        word_list = [vocab_dict_reverse[word_idx.item()] for word_idx in top_indices]
        all_list[topic_idx] = word_list
        print("topic-{}".format(topic_idx), word_list)

    topic_words_list = list(all_list.values())

    # Reference corpus: preprocessed test documents split on whitespace. The
    # topic words are added to the dictionary as well (presumably so that
    # every topic word has an id even if it is absent from the test corpus).
    reference_corpus = [doc.split() for doc in testds.preprocess_ctm(testds.nonempty_text)]
    reference_dictionary = Dictionary(reference_corpus)
    reference_dictionary.add_documents(topic_words_list)

    # Build the c_v coherence model and compute its score.
    cv_model = CoherenceModel(topics=topic_words_list, texts=reference_corpus, dictionary=reference_dictionary, coherence='c_v', topn=10)
    cv_score = cv_model.get_coherence()

    # Build the NPMI coherence model and compute its score.
    npmi_model = CoherenceModel(topics=topic_words_list, texts=reference_corpus, dictionary=reference_dictionary, coherence='c_npmi', topn=10)
    npmi_score = npmi_model.get_coherence()

    # Build the UMass coherence model and compute its score.
    umass_model = CoherenceModel(topics=topic_words_list, texts=reference_corpus, dictionary=reference_dictionary, coherence='u_mass', topn=10)
    umass_score = umass_model.get_coherence()

    # Build the UCI coherence model and compute its score.
    uci_model = CoherenceModel(topics=topic_words_list, texts=reference_corpus, dictionary=reference_dictionary, coherence='c_uci', topn=10)
    uci_score = uci_model.get_coherence()

    # Report every coherence score.
    print(f"c_v Score: {cv_score}")
    print(f"NPMI Score: {npmi_score}")
    print(f"UMass Score: {umass_score}")
    print(f"UCI Score: {uci_score}")

    # Disabled: file-based topic-quality report that would populate results_list.
#     now = datetime.now().strftime('%y%m%d_%H%M%S')
#     results = get_topic_qualities(topic_words_list,
#             reference_corpus,
#             dictionary=dictionary,
#             filename=f'results/{now}.txt')

#     print(results)
#     print()
#     results_list.append(results)

In [None]:
    # Rebuild the UMass coherence model and recompute its score from the
    # topic words, reference corpus and dictionary left in the kernel by the
    # main cell.
    umass_model = CoherenceModel(
        topics=topic_words_list,
        texts=reference_corpus,
        dictionary=reference_dictionary,
        coherence='u_mass',
        topn=10,
    )
    umass_score = umass_model.get_coherence()

    print(f"UMass Score: {umass_score}")