# Setting(dataset, parameter)

In [1]:
import os

# 원하는 저장 경로를 변수에 저장 
desired_cache_path = "/mnt/ssd1/don_ssd1/twitter_crawling_don/UTopic/hub" # 본인 mnt 디렉토리 주소 입력

# TRANSFORMERS_CACHE 환경변수 설정
os.environ["TRANSFORMERS_CACHE"] = desired_cache_path

In [2]:
args_text = '--base-model sentence-transformers/paraphrase-MiniLM-L6-v2 ' + \
            '--dataset all --n-word 30000 ' + \
            '--bsz 32 --stage-2-lr 2e-2 --stage-2-repeat 5 ' + \
            '--n-cluster 3 '

In [3]:
import re
import os
import time
import argparse
import string
import torch
import torch.nn as nn
import torch.nn.functional as F
import gensim.downloader
import itertools

from sentence_transformers import SentenceTransformer

import numpy as np
from tqdm import tqdm_notebook
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from sklearn.datasets import fetch_20newsgroups
from nltk.corpus import stopwords

from utils import AverageMeter

import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import mutual_info_score
from gensim.corpora.dictionary import Dictionary
from pytorch_transformers import *
import scipy.stats

from gensim.models.coherencemodel import CoherenceModel
from tqdm import tqdm
import nltk

from datetime import datetime
import gensim.downloader
from scipy.linalg import qr
from data import *
from data import TwitterDataset, RedditDataset, YoutubeDataset, BertDataset
from model import ContBertTopicExtractorAE

from data import BertDataset, Stage2Dataset
import random
from sklearn.metrics.pairwise import cosine_similarity

import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

[nltk_data] Downloading package punkt to /home/don12/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/don12/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /home/don12/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [4]:
!nvidia-smi

Thu Mar 14 23:06:22 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.105.01   Driver Version: 515.105.01   CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A6000    Off  | 00000000:18:00.0 Off |                  Off |
| 39%   61C    P2   154W / 300W |  45715MiB / 49140MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000    Off  | 00000000:3B:00.0 Off |                  Off |
| 40%   68C    P2   221W / 300W |  25373MiB / 49140MiB |     55%      Default |
|       

In [4]:
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "0,1" 

In [5]:
def _parse_args():
    parser = argparse.ArgumentParser(description='Contrastive topic modeling')

    #각 stage에서의 batch size
    parser.add_argument('--bsz', type=int, default=64,
                        help='Batch size')
    #data set정의 
    parser.add_argument('--dataset', default='twitter', type=str,
                        choices=['twitter', 'reddit', 'youtube','all'],
                        help='Name of the dataset')
    # 클러스터 수와 topic의 수는 20 (k==20)
    parser.add_argument('--n-cluster', default=3, type=int,
                        help='Number of clusters')
    parser.add_argument('--n-topic', type=int,
                        help='Number of topics. If not specified, use same value as --n-cluster')
    # 단어vocabulary는 2000로 setting
    parser.add_argument('--n-word', default=30000, type=int,
                        help='Number of words in vocabulary')
    
    parser.add_argument('--base-model', type=str,
                        help='Name of base model in huggingface library.')
    
    parser.add_argument('--gpus', default=[0,1,2,3], type=int, nargs='+',
                        help='List of GPU numbers to use. Use 0 by default')
   
    parser.add_argument('--dirichlet-alpha-1', type=float,
                        help='Parameter for Dirichlet distribution (Phase 1). Use 1/n_topic by default.')
 
    parser.add_argument('--coeff-2-recon', default=1.0, type=float,
                        help='Coefficient for VAE reconstruction loss (Phase 2)')
    parser.add_argument('--coeff-2-regul', default=1.0, type=float,
                        help='Coefficient for VAE KLD regularization loss (Phase 2)')
    parser.add_argument('--coeff-2-cons', default=1.0, type=float,
                        help='Coefficient for CL consistency loss (Phase 2)')
    parser.add_argument('--coeff-2-dist', default=1.0, type=float,
                        help='Coefficient for CL SWD distribution matching loss (Phase 2)')
    parser.add_argument('--dirichlet-alpha-2', type=float,
                        help='Parameter for Dirichlet distribution (Phase 2). Use same value as dirichlet-alpha-1 by default.')
    parser.add_argument('--stage-2-lr', default=2e-1, type=float,
                        help='Learning rate of phase 2')
    
    parser.add_argument('--stage-2-repeat', default=5, type=int,
                        help='Repetition count of phase 2')
    
    parser.add_argument('--result-file', type=str,
                        help='File name for result summary')
    
    
    # Check if the code is run in Jupyter notebook
    is_in_jupyter = False
    try:
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            is_in_jupyter = True   # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            is_in_jupyter = False  # Terminal running IPython
        else:
            is_in_jupyter = False  # Other type (?)
    except NameError:
        is_in_jupyter = False
    
    if is_in_jupyter:
        return parser.parse_args(args=args_text.split())
    else:
        return parser.parse_args()

In [6]:
from sklearn.datasets import fetch_20newsgroups
import numpy as np

# 20newsgroups 데이터셋 불러오기
dataset = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))


documents = dataset.data
targets = dataset.target
target_names = dataset.target_names

# 각 주제별로 한 개의 문서를 선택하여 출력
for i in range(len(target_names)):
    # 현재 주제와 일치하는 문서들의 인덱스 찾기
    doc_indices = np.where(targets == i)[0]
    
    # 첫 번째 문서의 인덱스 선택
    doc_idx = doc_indices[0]
    
    # 해당 문서와 주제 이름 출력
    print(f"{target_names[i]}")
    #print(f"문서 예시:\n{documents[doc_idx][:500]}")  # 첫 500자만 출력
    print("-" * 80)
# 20newsgroups 데이터셋의 주제들을 'politics', 'sport', 'tech'로 매핑하는 함수를 작성합니다.

def map_topics_to_labels(targets, target_names):
    # 주제를 라벨로 매핑하는 딕셔너리
    topic_to_label = {
        'talk.politics.guns': 'politics',
        'talk.politics.mideast': 'politics',
        'talk.politics.misc': 'politics',
        'rec.sport.baseball': 'sport',
        'rec.sport.hockey': 'sport',
        'comp.os.ms-windows.misc': 'tech',
        'comp.sys.ibm.pc.hardware': 'tech',
        'comp.sys.mac.hardware': 'tech',
        'comp.windows.x': 'tech'
    }

    # 타겟(주제) ID를 라벨로 변환
    labels = np.array([topic_to_label.get(target_names[target], "other") for target in targets])
    return labels

# 20newsgroups 데이터셋의 타겟을 'politics', 'sport', 'tech' 라벨로 매핑
mapped_labels = map_topics_to_labels(targets, target_names)

# 결과 확인을 위해 첫 10개의 매핑된 라벨을 출력
mapped_labels[:10]
# 'politics', 'sport', 'tech' 라벨을 가진 문서만 필터링하는 함수

def filter_documents(documents, labels, valid_labels):
    # 유효한 라벨을 가진 문서의 인덱스를 찾습니다.
    valid_indices = [i for i, label in enumerate(labels) if label in valid_labels]

    # 해당 인덱스의 문서만 필터링합니다.
    filtered_documents = [documents[i] for i in valid_indices]
    filtered_labels = [labels[i] for i in valid_indices]

    return filtered_documents, filtered_labels

# 'politics', 'sport', 'tech' 라벨을 가진 문서 필터링
valid_labels = ['politics', 'sport', 'tech']
filtered_documents, filtered_labels = filter_documents(documents, mapped_labels, valid_labels)

alt.atheism
--------------------------------------------------------------------------------
comp.graphics
--------------------------------------------------------------------------------
comp.os.ms-windows.misc
--------------------------------------------------------------------------------
comp.sys.ibm.pc.hardware
--------------------------------------------------------------------------------
comp.sys.mac.hardware
--------------------------------------------------------------------------------
comp.windows.x
--------------------------------------------------------------------------------
misc.forsale
--------------------------------------------------------------------------------
rec.autos
--------------------------------------------------------------------------------
rec.motorcycles
--------------------------------------------------------------------------------
rec.sport.baseball
--------------------------------------------------------------------------------
rec.sport.hockey
---

In [8]:
nyt_df = pd.read_csv('nyt-articles-2020.csv')

newsdesk_to_label = {
    'Science': 'tech',
    'Technology': 'tech',
    'Business': 'tech',
    'Sports': 'sport',
    'OpEd': 'politics',
    'U.S.': 'politics',
    'New York': 'politics',
    'Politics': 'politics'
}

nyt_df['label'] = nyt_df['newsdesk'].map(newsdesk_to_label)
filtered_nyt_df = nyt_df[nyt_df['label'].isin(['politics', 'sport', 'tech'])]

In [9]:
import os
import pandas as pd

data_folder_path = '/mnt/ssd1/don_ssd1/twitter_crawling_don/UTopic/data/bbc'
labels = ['politics', 'sport', 'tech']

data = []
for label in labels:
    folder_path = os.path.join(data_folder_path, label)
    for filename in os.listdir(folder_path):
        file_path = os.path.join(folder_path, filename)
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()
            data.append({'label': label, 'document': content})

df = pd.DataFrame(data)

In [10]:
newsgroups_data = list(zip(filtered_documents, filtered_labels))
nyt_data = list(zip(filtered_nyt_df['abstract'].tolist(), filtered_nyt_df['label'].tolist()))
bbc_data = list(zip(df['document'].tolist(), df['label'].tolist()))

In [11]:
import random

# 샘플링 사이즈와 라벨 설정
total_sample_size = 3600
labels = ['politics', 'sport', 'tech']
platforms = ['20newsgroups', 'nyt', 'bbc']
samples_per_label = total_sample_size // len(labels)

# 데이터와 라벨, 플랫폼을 라벨별로 분리하여 저장하는 함수
def sample_by_label(data, label, num_samples, platform):
    filtered_data = [(text, lbl, platform) for text, lbl in data if lbl == label]
    return random.sample(filtered_data, min(num_samples, len(filtered_data)))

# 데이터를 훈련과 테스트로 분할하는 함수
def split_train_test(data, train_ratio=0.8):
    train_size = int(len(data) * train_ratio)
    return data[:train_size], data[train_size:]

# 훈련 데이터에서 검증 데이터 분리하는 함수
def split_train_valid(data, valid_ratio_from_train=0.1):
    valid_size = int(len(data) * valid_ratio_from_train)
    train_size = len(data) - valid_size
    return data[:train_size], data[train_size:]

train_total_text_list = []
train_total_label_list = []
valid_total_text_list = []
valid_total_label_list = []
test_total_text_list = []
test_total_label_list = []
train_total_platform_list = []
valid_total_platform_list = []
test_total_platform_list = []

# 각 데이터셋에 대해 라벨별로 샘플링
data_sources = [(newsgroups_data, '20newsgroups'), (nyt_data, 'nyt'), (bbc_data, 'bbc')]

for label in labels:
    for data, platform in data_sources:
        sampled_data = sample_by_label(data, label, samples_per_label // len(labels), platform)
        # 훈련 및 테스트 데이터 분할
        train_data, test_data = split_train_test(sampled_data)
        # 훈련 데이터에서 검증 데이터 분리
        train_data, valid_data = split_train_valid(train_data)

        train_total_text_list.extend([text for text, lbl, plt in train_data])
        train_total_label_list.extend([lbl for text, lbl, plt in train_data])
        train_total_platform_list.extend([plt for text, lbl, plt in train_data])
        valid_total_text_list.extend([text for text, lbl, plt in valid_data])
        valid_total_label_list.extend([lbl for text, lbl, plt in valid_data])
        valid_total_platform_list.extend([plt for text, lbl, plt in valid_data])
        test_total_text_list.extend([text for text, lbl, plt in test_data])
        test_total_label_list.extend([lbl for text, lbl, plt in test_data])
        test_total_platform_list.extend([plt for text, lbl, plt in test_data])


In [12]:
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


len(set(['coinex', 'announces', 'seos', 'chatgptpowered', 'launches', 'bigdata', 'unveils', 'openaichatgpt', 'tags', 'hn', 'stablediffusion', 'chatgptstyle', 'reportedly', 'mba', 'marketers', 'baidu', 'technews', 'fintech', 'chatgptlike', 'elonmusk', 'notion', 'goog', 'googleai', 'digitalmarketing', 'artificalintelligence', 'rt', 'googl', 'bardai', 'edtech', 'malware', 'wharton', 'agix', 'chatgptplus', 'datascience', 'deeplearning', 'msft', 'weirdness', 'tweets', 'amid', 'aitools', 'cybersecurity', 'airdrop', 'cc', 'valentines', 'startups', 'snapchat', 'generativeai', 'buzzfeed', 'fastestgrowing', 'anthropic', 'maker', 'rival', 'techcrunch', 'aiart', 'nocode', 'invests', 'cybercriminals', 'abstracts', 'nyc', 'webinar', 'retweet', 'educators', 'brilliance', 'rescue', 'daysofcode', 'gm', 'rtechnology', 'linkedin', 'licensing', 'copywriting', 'copywriters', 'contentmarketing', 'revolutionizing', 'technologynews', 'warns', 'metaverse', 'cofounder', 'trending', 'founders', 'aipowered', 'openaichat', 'releases', 'microsofts', 'chinas', 'infosec', 'launching', 'jasper', 'nfts', 'newsletter', 'chatgptgod', 'futureofwork', 'digitaltransformation', 'founder', 'feb', 'buzz', 'rn', 'ux', 'courtesy', 'nick', 'claude',
        'remindme', 'giphy', 'gif', 'deleted', 'giphydownsized', 'chadgpt', 'removed', 'patched', 'nerfed', 'yup', 'waitlist', 'refresh', 'sydney', 'mods', 'nsfw', 'characterai', 'screenshot', 'downvoted', 'youcom', 'meth', 'ascii', 'karma', 'hahaha', 'hangman', 'chatopenaicom', 'emojis', 'porn', 'redditor', 'vpn', 'upvotes', 'blah', 'upvote', 'violated', 'yep', 'joking', 'nope', 'offended', 'mod', 'bruh', 'roleplay', 'ops', 'bob', 'dans', 'redditors', 'nerf', 'firefox', 'trolling', 'sarcastic', 'huh', 'turbo', 'troll', 'patch', 'tag', 'url', 'sus', 'erotica', 'chad', 'gotcha', 'basilisk', 'login', 'lmfao', 'temperature', 'poll', 'emoji', 'rick', 'dm', 'jailbreak', 'orange', 'sub', 'quack', 'davinci', 'uh', 'flagged', 'op', 'markdown', 'flair', 'cares', 'refreshing', 'hitler', 'cookies', 'hmm', 'yikes', 'erotic', 'gti', 'paywall', 'elaborate', 'yea', 'ah', 'uncensored', 'rude', 'colour', 'bitch', 'therapy', 'neutered', 'deny', 'chats', 'jailbroken', 'cake', 'dungeon', 'dang',
        'zronx', 'tuce', 'jontron', 'levy', 'bishop', 'rook', 'thumbnail', 'quotquot', 'jon', 'linus', 'hrefaboutinvalidzcsafeza', 'beluga', 'vid', 'bhai', 'gemx', 'raid', 'ohio', 'circle', 'subscribed', 'anna', 'stare', 'canva', 'napster', 'shapiro', 'sponsor', 'broker', 'websiteapp', 'manoj', 'subscriber', 'bluewillow', 'alex', 'vids', 'legends', 'ryan', 'shes', 'hackbanzer', 'quotoquot', 'pictory', 'youtuber', 'profitable', 'pawn', 'joma', 'folders', 'lifechanging', 'thomas', 'ur', 'plz', 'mike', 'scott', 'casey', 'adrian', 'enjoyed', 'stockfish', 'invideo', 'shortlisted', 'hikaru', 'bless', 'corpsb', 'chatgbt', 'bfuture', 'curve', 'accent', 'amc', 'tutorials', 'gotham', 'mrs', 'earning', 'bra', 'elo', 'oliver', 'youtubers', 'quotcontinuequot', 'membership', 'labels', 'dagogo', 'eonr', 'hai', 'quotai', 'affiliate', 'congratulationsbryou', 'subscribers', 'thumbnails', 'azn', 'beast', 'tom', 'trader', 'garetz', 'quot', 'subbed', 'pls', 'quotchatgpt', 'gtp', 'machina', 'quoti', 'bret', 'terminator', 'watchingbrdm', 'quothow', 'nowi', 'mint']))

len(set(['chatgptpowered', 'announces', 'seos', 'launches', 'unveils', 'bigdata', 'stablediffusion', 'coinex', 'chatgptstyle', 'tags', 'mba', 'baidu', 'openaichatgpt', 'edtech', 'technews', 'reportedly', 'notion', 'wharton', 'chatgptlike', 'hn', 'fintech', 'artificalintelligence', 'googleai', 'marketers', 'bardai', 'elonmusk', 'goog', 'googl', 'msft', 'abstracts', 'chatgptplus', 'rescue', 'rt', 'tweets', 'deeplearning', 'datascience', 'malware', 'generativeai', 'buzzfeed', 'digitalmarketing', 'agix', 'invests', 'amid', 'airdrop', 'anthropic', 'retweet', 'rn', 'cybercriminals', 'coauthor', 'rival', 'educators', 'forbes', 'startups', 'chatgptgod', 'snapchat', 'infosec', 'aiart', 'cybersecurity', 'valentines', 'technologynews', 'copywriting', 'courtesy', 'chinas', 'newsletter', 'gm', 'maker', 'aipowered', 'gamechanger', 'cc', 'daysofcode', 'nfts', 'nocode', 'techcrunch', 'ux', 'linkedin', 'metaverse', 'weirdness', 'webinar', 'futureofwork', 'blockchain', 'bullish', 'socialmedia', 'edchat', 'ernie', 'azure', 'fastestgrowing', 'launching', 'cofounder', 'tweeting', 'hackers', 'trending', 'bloomberg', 'ftw', 'licensing', 'competitor', 'chatgptai', 'classrooms', 'integrating', 'gtgt', 'openaichat', 'brilliance', 'chatgptdown', 'microsofts', 'warns', 'jpeg', 'chatsonic', 'viral', 'aitools', 'frenzy', 'claude', 'alibaba', 'wsj', 'bings', 'fet', 'founders', 'aiwritten', 'layoffs', 'contentmarketing', 'opera', 'iot', 'digitaltransformation', 'jasper', 'ethereum', 'mindblowing', 'classroom', 'everyones', 'cnet', 'phishing', 'haiku', 'chatgptmaker', 'yall', 'revolutionize', 'disrupt', 'kenyan', 'googlebard', 'releases', 'feb', 'buzz', 'passes', 'yearold', 'march', 'storm', 'bb', 'dey', 'introduces', 'firms', 'nft', 'startup', 'atlantic', 'valuation', 'chatgptgenerated', 'qampa', 'mem', 'rtechnology', 'insider', 'jpmorgan', 'craze', 'revolutionizing', 'bans', 'aichatgpt', 'whatsapp', 'digitalhealth', 'blurry', 'ens', 'davos', 'officially', 'analytics', 'york', 'nyc', 'weekend', 'aichatbot', 'obsessed', 'bingchatgpt', 'trends', 'leveraging', 'tipping', 'boost', 'killer', 'november', 'pros', 'unroll', 'introducing', 'integrates', 'nlp', 'tweet', 'gainer', 'launched', 'imgnai', 'cave', 'changer', 'highered', 'copywriters', 'agrees', 'multibilliondollar', 'nick', 'fastest', 'contentcreation', 'ht', 'founder', 'fluent',
'deleted', 'giphy', 'remindme', 'gif', 'removed', 'giphydownsized', 'patched', 'downvoted', 'nerfed', 'sydney', 'nsfw', 'youcom', 'characterai', 'blah', 'chadgpt', 'yup', 'refresh', 'ascii', 'waitlist', 'hangman', 'redditor', 'meth', 'huh', 'vpn', 'porn', 'nope', 'screenshot', 'quack', 'upvote', 'emojis', 'nerf', 'mods', 'uncensored', 'karma', 'firefox', 'hahaha', 'archived', 'fucked', 'temperature', 'cringe', 'yep', 'chatopenaicom', 'davinci', 'violated', 'rick', 'cookies', 'upvotes', 'hitler', 'offended', 'roleplay', 'cows', 'sus', 'woah', 'trolling', 'chad', 'erotic', 'bud', 'ampd', 'patch', 'sub', 'troll', 'lobotomized', 'bruh', 'sarcastic', 'turbo', 'jailbreaking', 'moderation', 'deny', 'poll', 'orange', 'jailbreak', 'complaining', 'flagged', 'smut', 'bitch', 'op', 'url', 'logged', 'neutered', 'tag', 'ah', 'cake', 'chats', 'wokegpt', 'lmfao', 'asshole', 'erotica', 'delete', 'ops', 'downvote', 'cache', 'gay', 'yea', 'joking', 'lame', 'dm', 'unrestricted', 'login', 'gti', 'rude', 'markdown', 'therapist', 'mod', 'redditors', 'inspect', 'refreshing', 'bypass', 'colour', 'rip', 'workaround', 'retarded', 'yikes', 'sarcasm', 'lasted', 'moron', 'forgets', 'username', 'regenerate', 'basilisk', 'triggered', 'uh', 'hourly', 'jailbroken', 'youchat', 'playground', 'bob', 'flair', 'oof', 'vram', 'repost', 'counting', 'idiot', 'fart', 'eh', 'emoji', 'edgy', 'censoring', 'nah', 'dick', 'morty', 'tos', 'cgpt', 'screenshots', 'tldr', 'vietnam', 'wtf', 'guessing', 'automoderator', 'chatgptx', 'formatting', 'haikusbot', 'hmm', 'angry', 'clinicalillusionist', 'poop', 'nazi', 'insult', 'svg', 'censored', 'filters', 'nerfing', 'cutoff', 'nevermind', 'dungeon', 'automod', 'goddamn', 'shitty', 'polite', 'logging', 'wdym', 'filtered', 'iq', 'censor', 'ip', 'chicken', 'novelai', 'dans', 'sex', 'milk', 'gotcha', 'reload', 'sassy', 'dnd', 'blocked', 'retry', 'odd', 'batman', 'locally', 'af', 'donald', 'satire', 'gtit', 'kanye', 'log', 'drugs', 'upvoted', 'closedai', 'annoyed', 'restricted', 'umm',
'tuce', 'zronx', 'quotquot', 'jontron', 'levy', 'jon', 'rook', 'bishop', 'beluga', 'thumbnail', 'ohio', 'linus', 'gemx', 'vid', 'hrefaboutinvalidzcsafeza', 'bhai', 'raid', 'stare', 'napster', 'pictory', 'subscribed', 'anna', 'circle', 'ur', 'pawn', 'stockfish', 'websiteapp', 'shapiro', 'ryan', 'gotham', 'manoj', 'subscriber', 'broker', 'folders', 'sponsor', 'youtubers', 'hikaru', 'bluewillow', 'ltlets', 'canva', 'joma', 'shorts', 'legends', 'lifechanging', 'hackbanzer', 'labels', 'vids', 'membership', 'profitable', 'scott', 'mrs', 'shes', 'adrian', 'bless', 'earning', 'maher', 'quothow', 'chatgbt', 'affiliate', 'oliver', 'thomas', 'shortlisted', 'subscribers', 'elo', 'alex', 'quotoquot', 'plz', 'jim', 'invideo', 'corpsb', 'bfuture', 'hai', 'enjoyed', 'mike', 'terminator', 'thx', 'trader', 'quot', 'gtp', 'quotchatgpt', 'amc', 'youtuber', 'tom', 'quoti', 'quotai', 'greg', 'accent', 'antichrist', 'subs', 'yt', 'gbt', 'curve', 'brother', 'tutorials', 'ben', 'shadow', 'nowi', 'quotcontinuequot', 'congratulationsbryou', 'ka', 'watchingbrdm', 'dagogo', 'pls', 'fx', 'garetz', 'bret', 'azn', 'uploaded', 'funds', 'silver', 'ring', 'intro', 'anlt', 'telegram', 'mint', 'ambulance', 'terrifying', 'casey', 'ke', 'brthanks', 'bra', 'bhi', 'machina', 'thankyou', 'vanoss', 'aa', 'kya', 'dread', 'harry', 'mate', 'idk', 'portfolio', 'leila', 'upto', 'legend', 'subbed', 'magnus', 'beast', 'earned', 'mosh', 'checkmate', 'quotit', 'quotim', 'brbr', 'mittens', 'madan', 'quotthe', 'ho', 'ltt', 'sigmoid', 'quotdont', 'helpdesk', 'clip', 'eonr', 'quotyou', 'monique', 'youquot', 'brrepent', 'tutorial', 'rooks', 'grant', 'upload', 'moves', 'laughed', 'shouldve', 'reynolds', 'delirious', 'informative', 'funniest', 'subscribe', 'ki', 'awzx', 'imperative', 'brthank', 'tho', 'quotwhat', 'bri', 'bhaiya', 'xd', 'screwed', 'logo', 'profits', 'quotthis', 'renders', 'channels', 'jarvis', 'hexagon', 'earn', 'heavens', 'lucid', 'kar', 'quotits', 'roblox', 'giveaway', 'knight', 'brit', 'hamish', 'skynet', 'bye', 'couldve'])
)

In [14]:
class BertDataset(Dataset):
    def __init__(self, bert, text_list, platform_label, label_list, N_word, vectorizer=None, lemmatize=False):
        self.lemmatize = lemmatize
        self.nonempty_text = [text for text in text_list if len(text) > 0]
        
        # Remove new lines
        self.nonempty_text = [re.sub("\n"," ", sent) for sent in self.nonempty_text]
                
        # Remove Emails
        self.nonempty_text = [re.sub('\S*@\S*\s?', '', sent) for sent in self.nonempty_text]
        
        # Remove new line characters
        self.nonempty_text = [re.sub('\s+', ' ', sent) for sent in self.nonempty_text]
        
        # Remove distracting single quotes
        self.nonempty_text = [re.sub("\'", "", sent) for sent in self.nonempty_text]
        
        self.jargons = set(['homosexual', 'simms', 'scsi', 'dma', 'ulf', 'color', 'armenian', 'widget', 'ide', 'centris', 'braves', 'leafs', 'faq', 'simm', 'gm', 'ethernet', 'stats', 'font', 'espn', 'meg', 'rgb', 'vram', 'buffalo', 'batf', 'printer', 'cpu', 'com', 'bruins', 'toronto', 'armenians', 'gant', 'lebanese', 'pitching', 'koresh', 'icon', 'dos', 'baerga', 'arabs', 'sabres', 'mb', 'dx', 'ottawa', 'azerbaijan', 'mailing', 'bitmap', 'norton', 'jews', 'gif', 'slave', 'pocklington', 'mhz', 'rectum', 'don', 'cdrom', 'ihr', 'irq', 'gld', 'cleveland', 'quadra', 'cursor', 'flames', 'bios', 'runs', 'wings', 'dl', 'penguins', 'motherboard', 'hp', 'gay', 'alomar', 'davidians', 'dale', 'colormap', 'israel', 'shrill', 'mattingly', 'pitcher', 'xr', 'chipset', 'diamond', 'cubs', 'detroit', 'fuhr', 'ordonly', 'floppy', 'gritz', 'sandberg', 'yankees', 'arab', 'clinton', 'controller', 'oh', 'vga', 'xman', 'hey', 'blasphemy', 'farrs', 'bosnia', 'wondering', 'libxmulibxmua', 'trumps', 'covid', 'biden', 'trump', 'vaccine', 'republicans', 'stock', 'outbreak', 'quarterback', 'chiefs', 'nfl', 'bidens', 'quarantine', 'lockdown', 'twitter', 'patriots', 'bowl', 'cancellation', 'facebook', 'ravens', 'steelers', 'canceled', 'soccer', 'woods', 'progressive', 'immune', 'pandemic', 'nba', 'colts', 'infection', 'gauff', 'vicepresidential', 'hospitalized', 'caucus', 'lawmaker', 'buttigieg', 'antibody', 'basketball', 'postseason', 'viral', 'blessing', 'marshmallow', 'ncaa', 'genetic', 'bernie', 'cheating', 'confront', 'djokovic', 'kentucky', 'bloombergs', 'pelicans', 'mate', 'distancing', 'floyds', 'patient', 'seahawks', 'pga', 'nomination', 'mds', 'tokyo', 'upended', 'biotech', 'lakers', 'immunity', 'shutdown', 'clinical', 'presidentelect', 'obama', 'scientist', 'readers', 'presidential', 'mlb', 'columnist', 'mask', 'furlough', 'republican', 'obsessing', 'donald', 'sanders', 'narrative', 'indistinguishable', 'astros', 'gym', 'coronavirus', 'apex', 'uncertainty', 'normalcy', 'miami', 'tennessee', 'rays', 'jacksonville', 'nominee', 'insect', 'alarmist', 'bee', 'safely', 'instagram', 'michigan', 'expanded', 'amazon', 'ukip', 'blog', 'patent', 'kenteris', 'mourinho', 'henman', 'moya', 'kilroysilk', 'gerrard', 'mirza', 'spam', 'thanou', 'wenger', 'souness', 'commodore', 'uwb', 'holmes', 'benitez', 'conte', 'bittorrent', 'blunkett', 'mock', 'spyware', 'ds', 'liverpool', 'robot', 'hague', 'iaaf', 'domain', 'parry', 'roddick', 'eu', 'bluray', 'agassi', 'nintendo', 'bt', 'aragones', 'screensaver', 'asylum', 'lycos', 'argonaut', 'hunting', 'hewitt', 'capriati', 'donation', 'firefox', 'sec', 'cabir', 'bnp', 'oleary', 'wifi', 'ferguson', 'safin', 'anelka', 'davenport', 'bellamy', 'blackpool', 'hiphop', 'robben', 'rfid', 'uefa', 'federer', 'drinking', 'skype', 'psp', 'radcliffe', 'cunningham', 'lions', 'hingis', 'straw', 'directive', 'dvd', 'ogara', 'ruddock', 'seafarer', 'hunt', 'browser', 'livingstone', 'seed', 'southampton', 'mcletchie', 'award', 'jaynes', 'poster', 'balco', 'referendum', 'edinburgh', 'hdtv', 'pension', 'indoor', 'mini', 'simonetti', 'grid', 'chepkemei', 'wilkinson', 'robinson', 'marathon', 'ink', 'dvds', 'chelsea'])
        
        self.tokenizer = AutoTokenizer.from_pretrained(bert)
        self.model = AutoModel.from_pretrained(bert)
        self.stopwords_list = set(TfidfVectorizer(stop_words="english").get_stop_words()).union(self.jargons)
        self.N_word = N_word
        
        if vectorizer == None:
            self.vectorizer = TfidfVectorizer(stop_words=None, max_features=self.N_word, token_pattern=r'\b[a-zA-Z]{2,}\b')
            self.vectorizer.fit(self.preprocess_ctm(self.nonempty_text))
        else:
            self.vectorizer = vectorizer
            
        self.org_list = []
        self.bow_list = []
        for sent in tqdm(self.nonempty_text):
            org_input = self.tokenizer(sent, padding='max_length', truncation=True, max_length=512, return_tensors='pt')
            org_input['input_ids'] = torch.squeeze(org_input['input_ids'])
            org_input['attention_mask'] = torch.squeeze(org_input['attention_mask'])
            self.org_list.append(org_input)
            self.bow_list.append(self.vectorize(sent))
        ## platform label    
        self.platform_label_list = platform_label
        self.label_list = label_list
            
        
    def vectorize(self, text):
        text = self.preprocess_ctm([text])
        vectorized_input = self.vectorizer.transform(text)
        vectorized_input = vectorized_input.toarray()
        vectorized_input = vectorized_input.astype(np.float64)

        # Get word distribution from BoW
        vectorized_input += 1e-8
        vectorized_input = vectorized_input / vectorized_input.sum(axis=1, keepdims=True)
        assert abs(vectorized_input.sum() - vectorized_input.shape[0]) < 0.01
        vectorized_label = torch.tensor(vectorized_input, dtype=torch.float)
        return vectorized_label[0]
        
        
    def preprocess_ctm(self, documents):
        preprocessed_docs_tmp = documents
        preprocessed_docs_tmp = [doc.lower() for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [doc.translate(
            str.maketrans(string.punctuation, ' ' * len(string.punctuation))) for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [' '.join([w for w in doc.split() if len(w) > 0 and w not in self.stopwords_list])
                             for doc in preprocessed_docs_tmp]
        if self.lemmatize:
            lemmatizer = WordNetLemmatizer()
            preprocessed_docs_tmp = [' '.join([lemmatizer.lemmatize(w) for w in doc.split()])
                                     for doc in preprocessed_docs_tmp]
        return preprocessed_docs_tmp    
    
    # mean_pooling 함수 정의
    def mean_pooling(self, model_output, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
        sum_embeddings = torch.sum(model_output * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask
    
    
    def __len__(self):
        return len(self.nonempty_text)
    

    def __getitem__(self, idx):
        sentence = self.nonempty_text[idx]
        encoded_input = self.tokenizer(sentence, padding=True, truncation=True, max_length=512, return_tensors='pt')
        
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        
        # mean_pooling 함수를 사용하여 문장 임베딩의 평균을 계산
        pooled_embedding = self.mean_pooling(model_output.last_hidden_state, encoded_input['attention_mask'])

        return self.org_list[idx], self.bow_list[idx], pooled_embedding, self.platform_label_list[idx], self.label_list[idx]


In [15]:
args = _parse_args()
bsz = args.bsz

n_cluster = args.n_cluster
n_topic = args.n_topic if (args.n_topic is not None) else n_cluster
args.n_topic = n_topic

# textData, should_measure_hungarian = data_load(args.dataset)

ema_alpha = 0.99
n_word = args.n_word
if args.dirichlet_alpha_1 is None:
    dirichlet_alpha_1 = 1 / n_cluster
else:
    dirichlet_alpha_1 = args.dirichlet_alpha_1
if args.dirichlet_alpha_2 is None:
    dirichlet_alpha_2 = dirichlet_alpha_1
else:
    dirichlet_alpha_2 = args.dirichlet_alpha_2
    
bert_name = args.base_model
bert_name_short = bert_name.split('/')[-1]
gpu_ids = args.gpus

In [16]:
# 시드 고정 함수
def set_seed(seed_value):
    """모든 랜덤 요소에 대한 시드를 고정합니다."""
    random.seed(seed_value)  
    np.random.seed(seed_value)  
    torch.manual_seed(seed_value)  
    if torch.cuda.is_available():  
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

In [17]:
set_seed(41)

In [18]:
trainds = BertDataset(bert=bert_name, text_list=train_total_text_list, platform_label=train_total_platform_list, label_list=train_total_label_list, N_word=n_word, vectorizer=None, lemmatize=True)
valds = BertDataset(bert=bert_name, text_list=valid_total_text_list, platform_label=valid_total_platform_list, label_list=valid_total_label_list, N_word=n_word, vectorizer=None, lemmatize=True)
testds = BertDataset(bert=bert_name, text_list=test_total_text_list, platform_label=test_total_platform_list, label_list=test_total_label_list, N_word=n_word, vectorizer=None, lemmatize=True)

2024-03-15 08:00:43.216871: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-03-15 08:00:43.377873: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-15 08:00:44.051537: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda-11.7/lib64:
2024-03-15 08:00:44.051627: W tensorflow/compiler/xla/strea

# 원본 데이터 Mean pooling 진행

# mean_pooling 함수 정의
def mean_pooling(embeddings, attention_mask):
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    sum_embeddings = torch.sum(embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

In [19]:
# trainds
trainds_embeddings = []
for item in trainds:
    _, _, pooled_embedding, _, _ = item 
    trainds_embeddings.append(pooled_embedding)

train_mean_pooled_embeddings = torch.stack(trainds_embeddings)

# valds
valds_embeddings = []
for item in valds:
    _, _, pooled_embedding, _, _ = item 
    valds_embeddings.append(pooled_embedding)

val_mean_pooled_embeddings = torch.stack(valds_embeddings)

# testds
testds_embeddings = []
for item in testds:
    _, _, pooled_embedding, _, _ = item  
    testds_embeddings.append(pooled_embedding)

test_mean_pooled_embeddings = torch.stack(testds_embeddings)

# Re_fornulate the bow

In [20]:
def dist_match_loss(hiddens, alpha=1.0):
    device = hiddens.device
    hidden_dim = hiddens.shape[-1]
    H = np.random.randn(hidden_dim, hidden_dim)
    Q, R = qr(H) 
    rand_w = torch.Tensor(Q).to(device)
    loss_dist_match = get_swd_loss(hiddens, rand_w, alpha)
    return loss_dist_match


def js_div_loss(hidden1, hidden2):
    m = 0.5 * (hidden1 + hidden2)
    return kldiv(m.log(), hidden1) + kldiv(m.log(), hidden2)


def get_swd_loss(states, rand_w, alpha=1.0):
    device = states.device
    states_shape = states.shape
    states = torch.matmul(states, rand_w)
    states_t, _ = torch.sort(states.t(), dim=1)

    # Random vector with length from normal distribution
    states_prior = torch.Tensor(np.random.dirichlet([alpha]*states_shape[1], states_shape[0])).to(device) # (bsz, dim)
    states_prior = torch.matmul(states_prior, rand_w) # (dim, dim)
    states_prior_t, _ = torch.sort(states_prior.t(), dim=1) # (dim, bsz)
    return torch.mean(torch.sum((states_prior_t - states_t)**2, axis=0))

# Get pos_similarity

In [21]:
def compute_max_cosine_similarity_indices(mean_pooled_embeddings, batch_size=500):
    n_rows = mean_pooled_embeddings.shape[0]
    max_similarity_indices = np.zeros(n_rows, dtype=np.int64)

    for start_idx in range(0, n_rows, batch_size):
        end_idx = min(start_idx + batch_size, n_rows)
        batch_data = mean_pooled_embeddings[start_idx:end_idx]

        batch_similarity = cosine_similarity(batch_data, mean_pooled_embeddings)

        # 자기 자신과의 유사도를 -1로 설정
        for i, original_idx in enumerate(range(start_idx, end_idx)):
            batch_similarity[i, original_idx] = -1

        # 각 행에서 가장 큰 값을 가진 인덱스 찾기
        max_indices = np.argmax(batch_similarity, axis=1)
        max_similarity_indices[start_idx:end_idx] = max_indices

    return max_similarity_indices

In [22]:
# mean_pooled_embeddings를 2차원 형태로 변환
train_mean_pooled_embeddings_2d = train_mean_pooled_embeddings.squeeze()
val_mean_pooled_embeddings_2d = val_mean_pooled_embeddings.squeeze()
test_mean_pooled_embeddings_2d = test_mean_pooled_embeddings.squeeze()

In [23]:
# 코사인 유사도 매트릭스 계산
train_similarity_matrix = compute_max_cosine_similarity_indices(train_mean_pooled_embeddings_2d)
val_similarity_matrix = compute_max_cosine_similarity_indices(val_mean_pooled_embeddings_2d)
test_similarity_matrix = compute_max_cosine_similarity_indices(test_mean_pooled_embeddings_2d)  # 테스트 데이터셋에 대해서도 동일하게

In [26]:
model =  ContBertTopicExtractorAE(N_topic=n_topic, N_word=args.n_word, bert=bert_name, bert_dim=768)

In [27]:
class Stage2Dataset(Dataset):
    def __init__(self, encoder, ds, similarity_matrix, N_word, k=1, lemmatize=False):
        self.lemmatize = lemmatize
        self.ds = ds
        self.org_list = self.ds.org_list
        self.nonempty_text = self.ds.nonempty_text
        self.N_word = N_word
        self.jargons =  set(['homosexual', 'simms', 'scsi', 'dma', 'ulf', 'color', 'armenian', 'widget', 'ide', 'centris', 'braves', 'leafs', 'faq', 'simm', 'gm', 'ethernet', 'stats', 'font', 'espn', 'meg', 'rgb', 'vram', 'buffalo', 'batf', 'printer', 'cpu', 'com', 'bruins', 'toronto', 'armenians', 'gant', 'lebanese', 'pitching', 'koresh', 'icon', 'dos', 'baerga', 'arabs', 'sabres', 'mb', 'dx', 'ottawa', 'azerbaijan', 'mailing', 'bitmap', 'norton', 'jews', 'gif', 'slave', 'pocklington', 'mhz', 'rectum', 'don', 'cdrom', 'ihr', 'irq', 'gld', 'cleveland', 'quadra', 'cursor', 'flames', 'bios', 'runs', 'wings', 'dl', 'penguins', 'motherboard', 'hp', 'gay', 'alomar', 'davidians', 'dale', 'colormap', 'israel', 'shrill', 'mattingly', 'pitcher', 'xr', 'chipset', 'diamond', 'cubs', 'detroit', 'fuhr', 'ordonly', 'floppy', 'gritz', 'sandberg', 'yankees', 'arab', 'clinton', 'controller', 'oh', 'vga', 'xman', 'hey', 'blasphemy', 'farrs', 'bosnia', 'wondering', 'libxmulibxmua', 'trumps', 'covid', 'biden', 'trump', 'vaccine', 'republicans', 'stock', 'outbreak', 'quarterback', 'chiefs', 'nfl', 'bidens', 'quarantine', 'lockdown', 'twitter', 'patriots', 'bowl', 'cancellation', 'facebook', 'ravens', 'steelers', 'canceled', 'soccer', 'woods', 'progressive', 'immune', 'pandemic', 'nba', 'colts', 'infection', 'gauff', 'vicepresidential', 'hospitalized', 'caucus', 'lawmaker', 'buttigieg', 'antibody', 'basketball', 'postseason', 'viral', 'blessing', 'marshmallow', 'ncaa', 'genetic', 'bernie', 'cheating', 'confront', 'djokovic', 'kentucky', 'bloombergs', 'pelicans', 'mate', 'distancing', 'floyds', 'patient', 'seahawks', 'pga', 'nomination', 'mds', 'tokyo', 'upended', 'biotech', 'lakers', 'immunity', 'shutdown', 'clinical', 'presidentelect', 'obama', 'scientist', 'readers', 'presidential', 'mlb', 'columnist', 'mask', 'furlough', 'republican', 'obsessing', 'donald', 'sanders', 'narrative', 'indistinguishable', 'astros', 'gym', 'coronavirus', 'apex', 'uncertainty', 'normalcy', 'miami', 'tennessee', 'rays', 'jacksonville', 'nominee', 'insect', 'alarmist', 'bee', 'safely', 'instagram', 'michigan', 'expanded', 'amazon', 'ukip', 'blog', 'patent', 'kenteris', 'mourinho', 'henman', 'moya', 'kilroysilk', 'gerrard', 'mirza', 'spam', 'thanou', 'wenger', 'souness', 'commodore', 'uwb', 'holmes', 'benitez', 'conte', 'bittorrent', 'blunkett', 'mock', 'spyware', 'ds', 'liverpool', 'robot', 'hague', 'iaaf', 'domain', 'parry', 'roddick', 'eu', 'bluray', 'agassi', 'nintendo', 'bt', 'aragones', 'screensaver', 'asylum', 'lycos', 'argonaut', 'hunting', 'hewitt', 'capriati', 'donation', 'firefox', 'sec', 'cabir', 'bnp', 'oleary', 'wifi', 'ferguson', 'safin', 'anelka', 'davenport', 'bellamy', 'blackpool', 'hiphop', 'robben', 'rfid', 'uefa', 'federer', 'drinking', 'skype', 'psp', 'radcliffe', 'cunningham', 'lions', 'hingis', 'straw', 'directive', 'dvd', 'ogara', 'ruddock', 'seafarer', 'hunt', 'browser', 'livingstone', 'seed', 'southampton', 'mcletchie', 'award', 'jaynes', 'poster', 'balco', 'referendum', 'edinburgh', 'hdtv', 'pension', 'indoor', 'mini', 'simonetti', 'grid', 'chepkemei', 'wilkinson', 'robinson', 'marathon', 'ink', 'dvds', 'chelsea'])
        self.stopwords_list = set(TfidfVectorizer(stop_words="english").get_stop_words()).union(self.jargons)
        
        self.vectorizer = TfidfVectorizer(stop_words=None, max_features=self.N_word, token_pattern=r'\b[a-zA-Z]{2,}\b')
        self.vectorizer.fit(self.preprocess_ctm(self.nonempty_text)) 
        self.bow_list = []
        for sent in tqdm(self.nonempty_text):
            self.bow_list.append(self.vectorize(sent))
            

        self.pos_dict = similarity_matrix
        
        self.embedding_list = []
        encoder_device = next(encoder.parameters()).device
        for org_input in tqdm(self.org_list):
            org_input_ids = org_input['input_ids'].to(encoder_device).reshape(1, -1)
            org_attention_mask = org_input['attention_mask'].to(encoder_device).reshape(1, -1)
            embedding = encoder(input_ids = org_input_ids, attention_mask = org_attention_mask)
            self.embedding_list.append(embedding['pooler_output'].squeeze().detach().cpu())
            
        self.platform_label_list = self.ds.platform_label_list
        self.label_list = self.ds.label_list
            
    
    def __len__(self):
        return len(self.org_list)
        
    def preprocess_ctm(self, documents):
        preprocessed_docs_tmp = documents
        preprocessed_docs_tmp = [doc.lower() for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [doc.translate(
            str.maketrans(string.punctuation, ' ' * len(string.punctuation))) for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [' '.join([w for w in doc.split() if len(w) > 0 and w not in self.stopwords_list])
                                 for doc in preprocessed_docs_tmp]
        if self.lemmatize:
            lemmatizer = WordNetLemmatizer()
            preprocessed_docs_tmp = [' '.join([lemmatizer.lemmatize(w) for w in doc.split()])
                                     for doc in preprocessed_docs_tmp]
        return preprocessed_docs_tmp
        
    def vectorize(self, text):
        text = self.preprocess_ctm([text])
        vectorized_input = self.vectorizer.transform(text)
        vectorized_input = vectorized_input.toarray().astype(np.float64)
#         vectorized_input = (vectorized_input != 0).astype(np.float64)

        # Get word distribution from BoW
        if vectorized_input.sum() == 0:
            vectorized_input += 1e-8
        vectorized_input = vectorized_input / vectorized_input.sum(axis=1, keepdims=True)
        assert abs(vectorized_input.sum() - vectorized_input.shape[0]) < 0.01
        
        vectorized_label = torch.tensor(vectorized_input, dtype=torch.float)
        return vectorized_label[0]
        
        
    def __getitem__(self, idx):
        # 데이터셋의 크기를 초과하지 않도록 idx 값을 조정
        if idx >= len(self.org_list):
            idx = len(self.org_list) - 1  
        pos_idx = self.pos_dict[idx]

        return idx, self.embedding_list[idx], self.embedding_list[pos_idx], self.bow_list[idx], self.bow_list[pos_idx], self.platform_label_list[idx], self.label_list[idx]


In [28]:
finetuneds = Stage2Dataset(model.encoder, trainds, train_similarity_matrix, n_word, lemmatize=True)
valfinetuneds = Stage2Dataset(model.encoder, valds, val_similarity_matrix, n_word, lemmatize=True) 
testfinetuneds = Stage2Dataset(model.encoder, testds, test_similarity_matrix, n_word, lemmatize=True) 

kldiv = torch.nn.KLDivLoss(reduction='batchmean')
vocab_dict = finetuneds.vectorizer.vocabulary_
vocab_dict_reverse = {i:v for v, i in vocab_dict.items()}
print(n_word)

100%|██████████| 2583/2583 [00:04<00:00, 546.59it/s]
100%|██████████| 2583/2583 [45:15<00:00,  1.05s/it] 
100%|██████████| 287/287 [00:00<00:00, 579.77it/s]
100%|██████████| 287/287 [01:10<00:00,  4.09it/s]
100%|██████████| 713/713 [00:01<00:00, 613.78it/s]
100%|██████████| 713/713 [06:56<00:00,  1.71it/s]

30000





In [25]:
train_similarity_matrix.shape

(2583,)

In [29]:
# len(finetuneds.embedding_list)
len(finetuneds.bow_list)

2583

# Stage 3

In [30]:
def measure_hungarian_score(topic_dist, train_target):
    dist = topic_dist
    train_target_filtered = train_target
    flat_predict = torch.tensor(np.argmax(dist, axis=1))
    flat_target = torch.tensor(train_target_filtered).to(flat_predict.device)
    num_samples = flat_predict.shape[0]
    num_classes = dist.shape[1]
    match = _hungarian_match(flat_predict, flat_target, num_samples, num_classes)    
    reordered_preds = torch.zeros(num_samples).to(flat_predict.device)
    for pred_i, target_i in match:
        reordered_preds[flat_predict == pred_i] = int(target_i)
    acc = int((reordered_preds == flat_target.float()).sum()) / float(num_samples)
    return acc

In [31]:
should_measure_hungarian = True

In [32]:
torch.cuda.empty_cache()

# Seperate Platform dataset

In [33]:
from torch.utils.data import DataLoader, Dataset, Sampler
from collections import defaultdict

class PlatformSampler(Sampler):
    def __init__(self, dataset, platform_label):
        self.indices = [i for i, label in enumerate(dataset.platform_label_list) if label == platform_label]
    
    def __iter__(self):
        return iter(self.indices)
    
    def __len__(self):
        return len(self.indices)

In [34]:
def create_platform_dataloader(dataset, platform_label, batch_size=32, num_workers=0):
    sampler = PlatformSampler(dataset, platform_label)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers)
    return dataloader

# Main

In [49]:
from sklearn.metrics import mutual_info_score
from scipy.stats import chi2_contingency 
from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora.dictionary import Dictionary
from coherence import get_topic_coherence

args.stage_2_repeat = 1
results_list = []

for i in range(args.stage_2_repeat):
    model = ContBertTopicExtractorAE(N_topic=n_topic, N_word=args.n_word, bert=bert_name, bert_dim=768)
    model.beta = nn.Parameter(torch.Tensor(model.N_topic, n_word))
    nn.init.xavier_uniform_(model.beta)
    model.beta_batchnorm = nn.Sequential()
    model.cuda(gpu_ids[0])
    
    losses = AverageMeter()
    dlosses = AverageMeter() 
    rlosses = AverageMeter()
    closses = AverageMeter()
    distlosses = AverageMeter()

    trainloader = DataLoader(finetuneds, batch_size=bsz, shuffle=True, num_workers=0)
    testloader = DataLoader(testfinetuneds, batch_size=bsz, shuffle=True, num_workers=0)
    newsgroups_data_trainloader = create_platform_dataloader(finetuneds, '20newsgroups', batch_size=bsz, num_workers=0)
    nyt_data_trainloader = create_platform_dataloader(finetuneds, 'nyt', batch_size=bsz, num_workers=0)
    bbc_trainloader = create_platform_dataloader(finetuneds, 'bbc', batch_size=bsz, num_workers=0)
    memoryloader = DataLoader(finetuneds, batch_size=bsz * 2, shuffle=False, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.stage_2_lr)

    memory_queue = F.softmax(torch.randn(512, n_topic).cuda(gpu_ids[0]), dim=1)
    print("Coeff   / regul: {:.5f} - recon: {:.5f} - c: {:.5f} - dist: {:.5f} ".format(args.coeff_2_regul, 
                                                                                        args.coeff_2_recon,
                                                                                        args.coeff_2_cons,
                                                                                        args.coeff_2_dist))
    best_npmi = -1
    best_epoch = 0
    best_model_state = None  # 모델 상태를 저장하기 위한 변수
    
    # 각 플랫폼별 DataLoader의 이터레이터 생성
    newsgroups_iter = iter(newsgroups_data_trainloader)
    nyt_iter = iter(nyt_data_trainloader)
    bbc_iter = iter(bbc_trainloader)

    max_length = max(len(newsgroups_data_trainloader), len(nyt_data_trainloader), len(bbc_trainloader))

    for epoch in range(100):
        model.train()
        model.encoder.eval()

        for _ in range(max_length):
            # 플랫폼별 DataLoader에서 배치를 순차적으로 가져오기
            try:
                newsgroups_batch = next(newsgroups_iter)
            except StopIteration:
                newsgroups_iter = iter(newsgroups_data_trainloader)
                newsgroups_batch = next(newsgroups_iter)

            try:
                nyt_batch = next(nyt_iter)
            except StopIteration:
                nyt_iter = iter(nyt_data_trainloader)
                nyt_batch = next(nyt_iter)

            # 기존 코드의 수정 사항
            try:
                bbc_batch = next(bbc_iter)
            except StopIteration:
                bbc_iter = iter(bbc_trainloader)  # 새로운 이터레이터 생성
                bbc_batch = next(bbc_iter)  # 새로운 이터레이터에서 첫 번째 배치를 가져옴

            # 각 배치에 대한 학습 로직 구현
            for batch in [newsgroups_batch, nyt_batch, bbc_batch]:
                _, org_input, pos_input, org_bow, pos_bow, _ ,_= batch
                org_input = org_input.cuda(gpu_ids[0])
                org_bow = org_bow.cuda(gpu_ids[0])
                pos_input = pos_input.cuda(gpu_ids[0])
                pos_bow = pos_bow.cuda(gpu_ids[0])

                batch_size = org_input.size(0) #org_input_ids.size(0)

                org_dists, org_topic_logit = model.decode(org_input)
                pos_dists, pos_topic_logit = model.decode(pos_input)

                org_topic = F.softmax(org_topic_logit, dim=1)
                pos_topic = F.softmax(pos_topic_logit, dim=1)

                # 텐서 크기 맞춰줌
                org_dists = org_dists[:, :org_bow.size(1)]
                pos_dists = pos_dists[:, :pos_bow.size(1)]

                recons_loss = torch.mean(-torch.sum(torch.log(org_dists + 1E-10) * (org_bow), axis=1), axis=0)
                recons_loss += torch.mean(-torch.sum(torch.log((1-org_dists) + 1E-10) * (1-org_bow), axis=1), axis=0)
                recons_loss += torch.mean(-torch.sum(torch.log(pos_dists + 1E-10) * (pos_bow), axis=1), axis=0)
                recons_loss += torch.mean(-torch.sum(torch.log((1-pos_dists) + 1E-10) * (1-pos_bow), axis=1), axis=0)
                recons_loss *= 0.5

                # consistency loss
                pos_sim = torch.sum(org_topic * pos_topic, dim=-1)
                cons_loss = -pos_sim.mean()

                distmatch_loss = dist_match_loss(torch.cat((org_topic,), dim=0), dirichlet_alpha_2)


                loss = args.coeff_2_recon * recons_loss + \
                       args.coeff_2_cons * cons_loss + \
                       args.coeff_2_dist * distmatch_loss 

                losses.update(loss.item(), bsz)
                closses.update(cons_loss.item(), bsz)
                rlosses.update(recons_loss.item(), bsz)
                distlosses.update(distmatch_loss.item(), bsz)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            
        print("Epoch-{} / recon: {:.5f} - dist: {:.5f} - cons: {:.5f}".format(epoch, rlosses.avg, distlosses.avg, closses.avg))

        # Epoch 마다 실행
        model.eval()

        # 각 토픽에 대한 상위 10개 단어 추출
        top_words_per_topic = {}
        for topic_idx in range(model.N_topic):
            top_words_indices = model.beta[topic_idx].topk(10).indices
            top_words = [vocab_dict_reverse[idx.item()] for idx in top_words_indices]
            top_words_per_topic[topic_idx] = top_words
            
        reference_corpus=[doc.split() for doc in valds.preprocess_ctm(valds.nonempty_text)]
        topic_words_list = list(top_words_per_topic.values())
        result = get_topic_coherence(topic_words_list, reference_corpus)
        avg_npmi = result['NPMI']

        # 최적의 NPMI와 epoch 추적
        if avg_npmi > best_npmi:
            best_npmi = avg_npmi
            best_epoch = epoch
            best_model_state = model.state_dict()  # 현재 모델 상태 저장

    print(f"Best Epoch: {best_epoch} with NPMI: {best_npmi}")
    # 훈련 완료 후, 최적 모델 상태 저장
    torch.save(best_model_state, 'our_best_model_state.pth')
    model.load_state_dict(torch.load('our_best_model_state.pth'))
    
    print("------- Evaluation results -------")
    #각 토픽당 가지는 워드셋
    all_list = {}
    for e, i in enumerate(model.beta.cpu().topk(10, dim=1).indices):
        word_list = []
        for j in i:
            word_list.append(vocab_dict_reverse[j.item()])
        all_list[e] = word_list
        print("topic-{}".format(e), word_list)

Coeff   / regul: 1.00000 - recon: 1.00000 - c: 1.00000 - dist: 1.00000 
Epoch-0 / recon: 10.68118 - dist: 0.04441 - cons: -0.33590
Epoch-1 / recon: 10.38936 - dist: 0.04398 - cons: -0.33562
Epoch-2 / recon: 10.22574 - dist: 0.04602 - cons: -0.34295
Epoch-3 / recon: 10.12088 - dist: 0.04785 - cons: -0.36033
Epoch-4 / recon: 10.04564 - dist: 0.04752 - cons: -0.37663
Epoch-5 / recon: 9.98726 - dist: 0.04712 - cons: -0.39320
Epoch-6 / recon: 9.93895 - dist: 0.04642 - cons: -0.40729
Epoch-7 / recon: 9.89825 - dist: 0.04492 - cons: -0.41888
Epoch-8 / recon: 9.86333 - dist: 0.04387 - cons: -0.42810
Epoch-9 / recon: 9.83305 - dist: 0.04332 - cons: -0.43593
Epoch-10 / recon: 9.80656 - dist: 0.04257 - cons: -0.44235
Epoch-11 / recon: 9.78316 - dist: 0.04182 - cons: -0.44817
Epoch-12 / recon: 9.76237 - dist: 0.04155 - cons: -0.45309
Epoch-13 / recon: 9.74369 - dist: 0.04114 - cons: -0.45738
Epoch-14 / recon: 9.72688 - dist: 0.04090 - cons: -0.46125
Epoch-15 / recon: 9.71166 - dist: 0.04071 - cons

# Hungarian matching & Label purity

- 헝가리안 매칭은 topic 수 3개로 설정 (acc 재기 위함)
- label purity는 topic 수 20개

In [37]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import mutual_info_score
from scipy.stats import chi2_contingency
from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora.dictionary import Dictionary
from coherence import get_topic_coherence

# 헝가리안 매칭을 위한 함수 정의
def measure_hungarian_score(topic_dist, train_target):
    dist = topic_dist
    train_target_filtered = train_target
    flat_predict = torch.tensor(np.argmax(dist, axis=1))
    flat_target = torch.tensor(train_target_filtered).to(flat_predict.device)
    num_samples = flat_predict.shape[0]
    num_classes = dist.shape[1]
    match = _hungarian_match(flat_predict, flat_target, num_samples, num_classes)    
    reordered_preds = torch.zeros(num_samples).to(flat_predict.device)
    for pred_i, target_i in match:
        reordered_preds[flat_predict == pred_i] = int(target_i)
    acc = int((reordered_preds == flat_target.float()).sum()) / float(num_samples)
    return acc

def _hungarian_match(predicted, target, num_samples, num_classes):
    matrix = np.zeros((num_classes, num_classes), dtype=np.uint64)
    for i in range(num_samples):
        matrix[predicted[i], target[i]] += 1
    row_ind, col_ind = linear_sum_assignment(matrix.max() - matrix)
    return list(zip(row_ind, col_ind))

# 테스트 데이터 로더 및 모델 준비
test_loader = DataLoader(testfinetuneds, batch_size=bsz, shuffle=False, num_workers=0)

# 테스트 데이터셋에 대한 라벨 할당 및 정확도 계산
topic_labels = [[] for _ in range(model.N_topic)]
for batch in test_loader:
    _, org_input, _, _, _, platform_labels, actual_labels = batch
    org_input = org_input.to(device)
    with torch.no_grad():
        org_topic_logit = model.decode(org_input)[1]
        org_topic = F.softmax(org_topic_logit, dim=1)
    dominant_topics = torch.argmax(org_topic, dim=1).cpu().numpy()
    for topic_idx, label in zip(dominant_topics, actual_labels):
        topic_labels[topic_idx].append(label)


In [50]:
# 라벨('politics', 'sport', 'tech')

topic_distributions = []
for batch in testloader:
    _, org_input, _, _, _, platform_labels, actual_labels = batch
    org_input = org_input.to(device)
    with torch.no_grad():
        _, org_topic_logit = model.decode(org_input)
        org_topic = F.softmax(org_topic_logit, dim=1)
    topic_distributions.append(org_topic.cpu().numpy())

topic_dist_test = np.concatenate(topic_distributions, axis=0)
test_target = np.array(test_total_label_list)  

test_target_adjusted = test_target[:len(topic_dist_test)]

label_to_int = {label: idx for idx, label in enumerate(sorted(set(test_total_label_list)))}
int_test_target = np.array([label_to_int[label] for label in test_total_label_list])

int_test_target_adjusted = int_test_target[:len(topic_dist_test)]

# 테스트 데이터셋에 대한 정확도 계산
accuracy_test = measure_hungarian_score(topic_dist_test, int_test_target_adjusted)
print("Test Accuracy:", accuracy_test)

('sport', 'politics', 'tech', 'sport', 'tech', 'politics', 'tech', 'sport', 'tech', 'tech', 'politics', 'sport', 'sport', 'sport', 'politics', 'politics', 'sport', 'politics', 'sport', 'tech', 'politics', 'tech', 'sport', 'politics', 'politics', 'politics', 'tech', 'sport', 'sport', 'politics', 'sport', 'tech')
('sport', 'politics', 'sport', 'sport', 'tech', 'politics', 'sport', 'politics', 'tech', 'politics', 'politics', 'sport', 'sport', 'sport', 'sport', 'politics', 'tech', 'tech', 'tech', 'tech', 'tech', 'tech', 'tech', 'politics', 'sport', 'tech', 'sport', 'tech', 'politics', 'tech', 'tech', 'politics')
('sport', 'politics', 'politics', 'sport', 'politics', 'sport', 'sport', 'sport', 'tech', 'politics', 'tech', 'tech', 'sport', 'sport', 'politics', 'politics', 'tech', 'sport', 'sport', 'tech', 'sport', 'politics', 'politics', 'sport', 'politics', 'sport', 'sport', 'tech', 'sport', 'sport', 'tech', 'sport')
('tech', 'sport', 'politics', 'politics', 'sport', 'politics', 'sport', 'sp

In [56]:
# 플랫폼 
topic_distributions = []
for batch in testloader:
    _, org_input, _, _, _, platform_labels, actual_labels = batch
    org_input = org_input.to(device)
    with torch.no_grad():
        _, org_topic_logit = model.decode(org_input)
        org_topic = F.softmax(org_topic_logit, dim=1)
    topic_distributions.append(org_topic.cpu().numpy())

# 배열 형태로 변환
topic_dist_test = np.concatenate(topic_distributions, axis=0)
test_target = np.array(test_total_platform_list)  # 실제 라벨을 배열 형태로 변환

test_target_adjusted = test_target[:len(topic_dist_test)]
label_to_int = {label: idx for idx, label in enumerate(sorted(set(test_total_platform_list)))}
int_test_target = np.array([label_to_int[label] for label in test_total_platform_list])
int_test_target_adjusted = int_test_target[:len(topic_dist_test)]

# 테스트 데이터셋에 대한 정확도 계산
accuracy_test = measure_hungarian_score(topic_dist_test, int_test_target_adjusted)
print("Test Accuracy:", accuracy_test)

Test Accuracy: 0.3380084151472651


In [44]:
# 라벨 퓨리티 계산
def calculate_label_purity(topic_labels):
    purity_scores = []
    for labels in topic_labels:
        if not labels:
            continue
        most_common_label = max(set(labels), key=labels.count)
        purity = labels.count(most_common_label) / len(labels)
        purity_scores.append(purity)
    average_purity = sum(purity_scores) / len(purity_scores) if purity_scores else 0
    return purity_scores, average_purity

# 토픽 별 라벨 할당
topic_labels = [[] for _ in range(model.N_topic)] 

for batch in trainloader:
    _, org_input, _, _, _, platform_labels, actual_labels = batch
    print(actual_labels)
    org_input = org_input.to(device)
    with torch.no_grad():  
        org_topic_logit = model.decode(org_input)[1] 
        org_topic = F.softmax(org_topic_logit, dim=1)  

    dominant_topics = torch.argmax(org_topic, dim=1).cpu().numpy()  # 가장 높은 확률의 토픽 인덱스 추출

    for topic_idx, label in zip(dominant_topics, actual_labels):
        topic_labels[topic_idx].append(label)


purity_scores, average_purity = calculate_label_purity(topic_labels)

print("각 토픽의 라벨 퓨리티:", purity_scores)
print("평균 라벨 퓨리티:", average_purity)

('tech', 'tech', 'politics', 'tech', 'tech', 'sport', 'sport', 'politics', 'politics', 'tech', 'sport', 'tech', 'tech', 'sport', 'sport', 'sport', 'politics', 'tech', 'tech', 'sport', 'sport', 'sport', 'sport', 'tech', 'politics', 'sport', 'sport', 'politics', 'politics', 'sport', 'politics', 'sport')
('politics', 'tech', 'tech', 'tech', 'politics', 'tech', 'politics', 'tech', 'sport', 'tech', 'tech', 'politics', 'tech', 'tech', 'politics', 'politics', 'politics', 'tech', 'sport', 'tech', 'sport', 'tech', 'politics', 'tech', 'tech', 'sport', 'sport', 'sport', 'tech', 'sport', 'sport', 'tech')
('sport', 'sport', 'politics', 'politics', 'sport', 'politics', 'tech', 'politics', 'tech', 'sport', 'sport', 'tech', 'politics', 'tech', 'politics', 'sport', 'sport', 'politics', 'sport', 'sport', 'tech', 'tech', 'tech', 'sport', 'sport', 'politics', 'tech', 'tech', 'politics', 'politics', 'tech', 'politics')
('sport', 'tech', 'politics', 'tech', 'tech', 'tech', 'politics', 'politics', 'tech', 'p

# MI

In [182]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# 기존 코드의 나머지 부분...

# 모델을 평가 모드로 설정
model.eval()

# 테스트 데이터 로더 설정
testloader = DataLoader(testfinetuneds, batch_size=bsz * 2, shuffle=False, num_workers=0)
    
# 토픽 라벨과 플랫폼 라벨 추출
ourmodel_test_topic_labels = []
test_platform_labels = []

for batch in testloader:
#     _, org_embedding, _, org_bow, _, platform_labels = batch
    _, org_embedding, _, org_bow, _, platform_labels, actual_labels = batch
    org_embedding = org_embedding.to(gpu_ids[0])
    _, topic_logit = model.decode(org_embedding)
    topic_label = torch.argmax(F.softmax(topic_logit, dim=1), dim=1)
    ourmodel_test_topic_labels.extend(topic_label.cpu().numpy())
    test_platform_labels.extend(platform_labels)

# 플랫폼별 토픽 분포 계산
topic_dist_df_test = pd.crosstab(pd.Series(ourmodel_test_topic_labels, name='Topic'),
                            pd.Series(test_platform_labels, name='Platform'), normalize='index')

# 플랫폼별 및 전체에 대한 토픽 분포를 계산
platform_counts = pd.Series(test_platform_labels).value_counts()
platform_probabilities = platform_counts / platform_counts.sum()

# 전체 데이터셋에 대한 토픽 분포의 엔트로피 계산 (H(Y))
topic_probabilities = pd.Series(ourmodel_test_topic_labels).value_counts(normalize=True)
H_Y = -np.sum(topic_probabilities * np.log(topic_probabilities + 1e-10))

# 각 플랫폼별 조건부 엔트로피 계산 및 H(Y|X) 계산
H_Y_given_X_total = 0
for platform in platform_probabilities.index:
    # 해당 플랫폼에 대한 토픽 라벨 필터링
    platform_indices = [i for i, x in enumerate(test_platform_labels) if x == platform]
    platform_topic_labels = [ourmodel_test_topic_labels[i] for i in platform_indices]
    platform_topic_prob = pd.Series(platform_topic_labels).value_counts(normalize=True)
    
    # 플랫폼별 조건부 엔트로피 계산
    H_Y_given_X = -np.sum(platform_topic_prob * np.log(platform_topic_prob + 1e-10))
    H_Y_given_X_total += platform_probabilities[platform] * H_Y_given_X

# 상호정보량(MI) 계산
mi = H_Y - H_Y_given_X_total
# 플랫폼 분포의 엔트로피 H(X) 계산
H_X = -np.sum(platform_probabilities * np.log(platform_probabilities + 1e-10))

# 상호정보량(MI)을 이미 계산했다고 가정
# MI = H_Y - H_Y_given_X_total

# 정규화된 상호정보량(NMI) 계산
NMI = 2 * mi / (H_Y + H_X)

print("Normalized Mutual Information (NMI):", NMI)

print('H(Y):', H_Y)
print('H(Y|X):', H_Y_given_X_total)
print('Mutual Information (MI):', mi)

mi_score = mutual_info_score(ourmodel_test_topic_labels, test_platform_labels)
print("Original Mutual Information Score:", mi_score)

Normalized Mutual Information (NMI): 0.03602773971758664
H(Y): 2.8874862389659492
H(Y|X): 2.815682920913974
Mutual Information (MI): 0.07180331805197504
Original Mutual Information Score: 0.0718033180846549


# Coherence & Topic Diversity

In [172]:
# 주제 다양성 계산 함수
def calculate_topic_diversity(topic_words_list):
    all_words = set()
    for words in topic_words_list:
        all_words.update(words)
    unique_words_count = len(all_words)
    total_words_count = sum(len(words) for words in topic_words_list)
    return unique_words_count / total_words_count

# 주제 다양성 계산
topic_diversity = calculate_topic_diversity(topic_words_list)
print("Topic Diversity:",topic_diversity)

Topic Diversity: 0.975


In [173]:
topic_words_list = list(all_list.values())
reference_corpus=[doc.split() for doc in testds.preprocess_ctm(testds.nonempty_text)]

topics=topic_words_list
texts=reference_corpus
print(get_topic_coherence(topics, texts))

{'NPMI': 0.6073848301599067, 'UCI': 3.7812460006494004, 'UMASS': -1.091080646445589, 'CV': 0.7076743326717905, 'Topic_Diversity': 0.975}


In [235]:
reference_corpus = [doc.split() for doc in testds.preprocess_ctm(testds.nonempty_text)]
platform_labels = testds.platform_label_list  

platform_corpora = {'bbc': [], 'nyt': [], '20newsgroups': []}

for doc, platform in zip(reference_corpus, platform_labels):
    if platform in platform_corpora:
        platform_corpora[platform].append(doc)

bbc_texts = platform_corpora['bbc']
nyt_texts = platform_corpora['nyt']
news_texts = platform_corpora['20newsgroups']

# 각 플랫폼 별 문서 수 확인
print(f"'bbc' 문서 수: {len(bbc_texts)}")
print(f"'nyt' 문서 수: {len(nyt_texts)}")
print(f"'20newsgroups' 문서 수: {len(news_texts)}")


'bbc' 문서 수: 233
'nyt' 문서 수: 240
'20newsgroups' 문서 수: 240


In [236]:
print(get_topic_coherence(topics, bbc_texts))

{'NPMI': 0.547141676456582, 'UCI': 2.9629632092033424, 'UMASS': -1.1028965338813588, 'CV': 0.7743565001071542, 'Topic_Diversity': 0.975}


In [237]:
print(get_topic_coherence(topics, nyt_texts))

{'NPMI': 0.6110145593992476, 'UCI': 2.798549250650759, 'UMASS': -0.6914497119293894, 'CV': 0.8624470661183867, 'Topic_Diversity': 0.975}


In [238]:
print(get_topic_coherence(topics, news_texts))

{'NPMI': 0.7354736894690324, 'UCI': 4.03086575972686, 'UMASS': -0.44793986730701374, 'CV': 0.81696670820916, 'Topic_Diversity': 0.975}
