In [1]:
args_text = '--base-model sentence-transformers/paraphrase-MiniLM-L6-v2 ' + \
            '--dataset all --n-word 2000 --epochs-1 100 --epochs-2 50 ' + \
            '--bsz 32 --stage-2-lr 2e-2 --stage-2-repeat 5 --coeff-1-dist 50 ' + \
            '--n-cluster 20 ' + \
            '--stage-1-ckpt trained_model/news_model_paraphrase-MiniLM-L6-v2_stage1_20t_2000w_99e.ckpt ' + \
            '--palmetto-dir /home/don12/twitter_crawling_don/UTopic'

In [2]:
import re
import os
import sys
import time
import copy
import math
import argparse
import string
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtools.optim import RangerLars
import gensim.downloader
import itertools

from scipy.stats import ortho_group
from scipy.optimize import linear_sum_assignment as linear_assignment
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer

import numpy as np
from tqdm import tqdm_notebook
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.datasets import fetch_20newsgroups
from nltk.corpus import stopwords

from sklearn.feature_extraction.text import CountVectorizer
from utils import AverageMeter
from collections import OrderedDict

import pandas as pd
from sklearn.preprocessing import normalize
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from gensim.corpora.dictionary import Dictionary
from pytorch_transformers import *
from sklearn.mixture import GaussianMixture
import scipy.stats
from sklearn.decomposition import PCA
from sklearn.cluster import OPTICS
from nltk.corpus import stopwords

from gensim.models.coherencemodel import CoherenceModel
from tqdm import tqdm
import scipy.sparse as sp
import nltk
from nltk.corpus import stopwords

from datetime import datetime
from itertools import combinations
import gensim.downloader
from scipy.linalg import qr
from data import *
from data import TwitterDataset, RedditDataset, YoutubeDataset, BertDataset
from model import ContBertTopicExtractorAE
from evaluation import get_topic_qualities

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from collections import OrderedDict
from torch.utils.data import ConcatDataset
import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

[nltk_data] Downloading package punkt to /home/don12/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/don12/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /home/don12/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [3]:
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "2,3" 

In [4]:
def _parse_args():
    parser = argparse.ArgumentParser(description='Contrastive topic modeling')
    # 각 stage에서의 epochs수 
    parser.add_argument('--epochs-1', default=50, type=int,
                        help='Number of training epochs for Stage 1')   
    parser.add_argument('--epochs-2', default=10, type=int,
                        help='Number of training epochs for Stage 2')
    #각 stage에서의 batch size
    parser.add_argument('--bsz', type=int, default=64,
                        help='Batch size')
    #data set정의 
    parser.add_argument('--dataset', default='twitter', type=str,
                        choices=['twitter', 'reddit', 'youtube','all'],
                        help='Name of the dataset')
    # 클러스터 수와 topic의 수는 20 (k==20)
    parser.add_argument('--n-cluster', default=20, type=int,
                        help='Number of clusters')
    parser.add_argument('--n-topic', type=int,
                        help='Number of topics. If not specified, use same value as --n-cluster')
    # 단어vocabulary는 2000로 setting
    parser.add_argument('--n-word', default=2000, type=int,
                        help='Number of words in vocabulary')
    
    parser.add_argument('--base-model', type=str,
                        help='Name of base model in huggingface library.')
    
    parser.add_argument('--gpus', default=[0,1,2,3], type=int, nargs='+',
                        help='List of GPU numbers to use. Use 0 by default')
   
    parser.add_argument('--coeff-1-sim', default=1.0, type=float,
                        help='Coefficient for NN dot product similarity loss (Phase 1)')
    parser.add_argument('--coeff-1-dist', default=1.0, type=float,
                        help='Coefficient for NN SWD distribution loss (Phase 1)')
    parser.add_argument('--dirichlet-alpha-1', type=float,
                        help='Parameter for Dirichlet distribution (Phase 1). Use 1/n_topic by default.')
    
    parser.add_argument('--stage-1-ckpt', type=str,
                        help='Name of torch checkpoint file Stage 1. If this argument is given, skip Stage 1.')
 
    parser.add_argument('--coeff-2-recon', default=1.0, type=float,
                        help='Coefficient for VAE reconstruction loss (Phase 2)')
    parser.add_argument('--coeff-2-regul', default=1.0, type=float,
                        help='Coefficient for VAE KLD regularization loss (Phase 2)')
    parser.add_argument('--coeff-2-cons', default=1.0, type=float,
                        help='Coefficient for CL consistency loss (Phase 2)')
    parser.add_argument('--coeff-2-dist', default=1.0, type=float,
                        help='Coefficient for CL SWD distribution matching loss (Phase 2)')
    parser.add_argument('--dirichlet-alpha-2', type=float,
                        help='Parameter for Dirichlet distribution (Phase 2). Use same value as dirichlet-alpha-1 by default.')
    parser.add_argument('--stage-2-lr', default=2e-1, type=float,
                        help='Learning rate of phase 2')
    
    parser.add_argument('--stage-2-repeat', default=5, type=int,
                        help='Repetition count of phase 2')
    
    parser.add_argument('--result-file', type=str,
                        help='File name for result summary')
    parser.add_argument('--palmetto-dir', type=str,
                        help='Directory where palmetto JAR and the Wikipedia index are. For evaluation')
    
    
    # Check if the code is run in Jupyter notebook
    is_in_jupyter = False
    try:
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            is_in_jupyter = True   # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            is_in_jupyter = False  # Terminal running IPython
        else:
            is_in_jupyter = False  # Other type (?)
    except NameError:
        is_in_jupyter = False
    
    if is_in_jupyter:
        return parser.parse_args(args=args_text.split())
    else:
        return parser.parse_args()

#데이터들을 로드함(텍스트중에 none값을 제거/dataset_name에 따라 textData에 들어가는 내용이 달라짐)
def data_load(dataset_name, sample_size=1000):
    should_measure_hungarian = False
    textData = []
    
    if dataset_name == 'twitter':
        dataset = TwitterDataset(sample_size=sample_size)
        textData = [(text, "Twitter") for text in dataset.texts if not (isinstance(text, float) and math.isnan(text))]
    elif dataset_name == 'reddit':
        dataset = RedditDataset(sample_size=sample_size)
        textData = [(text, "Reddit") for text in dataset.texts if not (isinstance(text, float) and math.isnan(text))]
    elif dataset_name == 'youtube':
        dataset = YoutubeDataset(sample_size=sample_size)
        textData = [(text, "YouTube") for text in dataset.texts if not (isinstance(text, float) and math.isnan(text))]
    elif dataset_name == 'all':
        twitter_dataset = TwitterDataset(sample_size=sample_size)
        reddit_dataset = RedditDataset(sample_size=sample_size)
        youtube_dataset = YoutubeDataset(sample_size=sample_size)
        
        # filtering NaN values
        textData += [(text, "Twitter") for text in twitter_dataset.texts if not (isinstance(text, float) and math.isnan(text))]
        textData += [(text, "Reddit") for text in reddit_dataset.texts if not (isinstance(text, float) and math.isnan(text))]
        textData += [(text, "YouTube") for text in youtube_dataset.texts if not (isinstance(text, float) and math.isnan(text))]
    else:
        raise ValueError("Invalid dataset name!")
    
    return textData, should_measure_hungarian


In [5]:
args = _parse_args()
bsz = args.bsz
epochs_1 = args.epochs_1
epochs_2 = args.epochs_2

n_cluster = args.n_cluster
n_topic = args.n_topic if (args.n_topic is not None) else n_cluster
args.n_topic = n_topic

textData, should_measure_hungarian = data_load(args.dataset)

ema_alpha = 0.99
n_word = args.n_word
if args.dirichlet_alpha_1 is None:
    dirichlet_alpha_1 = 1 / n_cluster
else:
    dirichlet_alpha_1 = args.dirichlet_alpha_1
if args.dirichlet_alpha_2 is None:
    dirichlet_alpha_2 = dirichlet_alpha_1
else:
    dirichlet_alpha_2 = args.dirichlet_alpha_2
    
bert_name = args.base_model
bert_name_short = bert_name.split('/')[-1]
gpu_ids = args.gpus

skip_stage_1 = (args.stage_1_ckpt is not None)

In [6]:
jargons = set(['bigdata', 'RT', 'announces', 'rival', 'MSFT', 'launches', 'SEO', 'tweets', 'WSJ', 'unveils', 'MBA', 'machinelearning', 'HN', 'artificalintelligence', 'DALLE', 'NFT', 'NLP', 'edtech', 'cybersecurity', 'malware', 'CEO', 'ML', 'JPEG', 'GOOGL', 'UX', 'artificialintelligence', 'technews', 'fintech', 'tweet', 'CHATGPT', 'viral', 'blockchain', 'BARD', 'reportedly', 'BREAKING', 'deeplearning', 'educators', 'datascience', 'GOOG', 'competitor', 'FREE', 'startups', 'AGIX', 'CHAT', 'CNET', 'integrating', 'BB', 'maker', 'passes', 'NYT', 'yall', 'rescue', 'heres', 'NYC', 'LLM', 'airdrop', 'powered', 'podcast', 'Q', 'PM', 'newsletter', 'startup', 'TSLA', 'WIRED', 'digitalmarketing', 'CTO', 'changer', 'announced', 'database', 'launched', 'warns', 'popularity', 'AWS', 'nocode', 'trending', 'metaverse', 'cc', 'PR', 'RSS', 'MWC', 'USMLE', 'copywriting', 'marketers', 'tweeting', 'amid', 'AGI', 'socialmedia', 'webinar', 'agrees', 'invests', 'launch', 'killer', 'GM', 'bullish', 'edchat', 'RLHF', 'integration', 'fastestgrowing', 'CNN', 'exam',
                  'deleted', 'gif', 'giphy', 'dm', 'removed', 'remindme', 'yup', 'sydney', 'yep', 'patched', 'nope', 'giphydownsized', 'vpn', 'ascii', 'ah', 'chadgpt', 'nerfed', 'jesus', 'xd', 'wtf', 'upvote', 'nah', 'op', 'mods', 'hahaha', 'nsfw', 'huh', 'holy', 'iq', 'jailbreak', 'blah', 'bruh', 'yea', 'agi', 'porn', 'waitlist', 'nerf', 'downvoted', 'refresh', 'omg', 'sus', 'characterai', 'meth', 'chinese', 'sub', 'rick', 'american', 'elon', 'sam', 'quack', 'youchat', 'uk', 'chad', 'archived', 'youcom', 'screenshot', 'llm', 'hitler', 'lmao', 'playground', 'rpg', 'delete', 'tldr', 'davinci', 'trump', 'hangman', 'haha', 'tay', 'karma', 'john', 'chatgtp', 'url', 'wokegpt', 'offended', 'fucked', 'redditor', 'ceo', 'agreed', 'emojis', 'cheers', 'ais', 'tag', 'wow', 'lmfao', 'p', 'rip', 'chats', 'hmm', 'bypass', 'llms', 'temperature', 'login', 'cgpt', 'windows', 'novelai', 'biden', 'donald', 'christmas', 'ms', 'cringe',
                   'ZRONX', 'rook', 'thumbnail', 'vid', 'bhai', 'bishop', 'circle', 'subscribed', 'quot', 'bless', 'tutorial', 'XD', 'sir', 'GEMX', 'profitable', 'earning', 'quotquot', 'enjoyed', 'ur', 'bra', 'JIM', 'broker', 'levy', 'vids', 'stare', 'tutorials', 'subscribers', 'sponsor', 'hai', 'lifechanging', 'curve', 'shorts', 'earn', 'trader', 'PC', 'folders', 'informative', 'br', 'chess', 'jontron', 'brother', 'T', 'YT', 'upload', 'O', 'subscriber', 'intro', 'DAN', 'aint', 'download', 'LOL', 'shes', 'moves', 'telegram', 'shortlisted', 'liked', 'websiteapp', 'watched', 'grant', 'plz', 'KINGDOM', 'YOU', 'MESSIAH', 'mate', 'ki', 'subs', 'pawn', 'hes', 'U', 'HACKBANZER', 'ka', 'brbr', 'affiliate', 'clip', 'beast', 'trade', 'ive', 'ho', 'approved', 'bhi', 'gotta', 'profits', 'wanna', 'subscribe', 'funds', 'labels', 'recommended', 'audio', 'uploaded', 'appreciated', 'UBI', 'pls', 'upto', 'alot', 'twist', 'GTP', 'accent', 'monetized', 'S', 'btw'
                  ])

In [7]:
twitter_ds = TwitterDataset()
reddit_ds = RedditDataset()
youtube_ds = YoutubeDataset()

# 각 데이터셋에서 가져온 데이터를 합칩니다.
total_text_list = []
for ds in [twitter_ds, reddit_ds, youtube_ds]:
    for idx in range(len(ds)):
        text, target, platform_label = ds[idx]
        total_text_list.append({"text": text, "target": target, "platform": platform_label})

In [8]:
total_text_list

[{'text': 'ChatGPT is remarkable The ability to repurpose and condense information into coherent sentences is impressive It might make you believe it is human though its extreme literalmindedness leaves it a long way from convincing you it is British\n\n',
  'target': tensor(0),
  'platform': 'Twitter'},
 {'text': ' Exclusive OpenAI Used Kenyan Workers on Less Than  Per Hour to Make ChatGPT Less Toxic',
  'target': tensor(1),
  'platform': 'Twitter'},
 {'text': 'Drop us some more interesting tech in the comments below which you think can be a game changer \n\nbonuz  G ai metaverse chatGPT',
  'target': tensor(2),
  'platform': 'Twitter'},
 {'text': ' Ask ChatGPT to do it for you',
  'target': tensor(3),
  'platform': 'Twitter'},
 {'text': ' I will feed these parameters to ChatGPT and let it write the email',
  'target': tensor(4),
  'platform': 'Twitter'},
 {'text': ' students using the  hats methods to fully understand chatgpt artificialintelligence edchatie ',
  'target': tensor(5),


In [9]:
len(total_text_list)

3000

In [10]:
class BertDataset(Dataset):
    def __init__(self, bert, text_list, N_word, vectorizer=None, lemmatize=False):
        self.lemmatize = lemmatize
        self.nonempty_text = [entry["text"] for entry in text_list if len(entry["text"]) > 0]
        self.platforms = [entry["platform"] for entry in text_list if len(entry["text"]) > 0]
        
        self.platform_to_index = {"Twitter": 0, "Reddit": 1, "Youtube": 2}
        self.platform_numbers = [self.platform_to_index[platform] for platform in self.platforms]
        
        # Remove new lines
        self.nonempty_text = [re.sub("\n"," ", sent) for sent in self.nonempty_text]
                
        # Remove Emails
        self.nonempty_text = [re.sub('\S*@\S*\s?', '', sent) for sent in self.nonempty_text]
        
        # Remove new line characters
        self.nonempty_text = [re.sub('\s+', ' ', sent) for sent in self.nonempty_text]
        
        # Remove distracting single quotes
        self.nonempty_text = [re.sub("\'", "", sent) for sent in self.nonempty_text]
        
        self.jargons = set(['bigdata', 'RT', 'announces', 'rival', 'MSFT', 'launches', 'SEO', 'tweets', 'WSJ', 'unveils', 'MBA', 'machinelearning', 'HN', 'artificalintelligence', 'DALLE', 'NFT', 'NLP', 'edtech', 'cybersecurity', 'malware', 'CEO', 'ML', 'JPEG', 'GOOGL', 'UX', 'artificialintelligence', 'technews', 'fintech', 'tweet', 'CHATGPT', 'viral', 'blockchain', 'BARD', 'reportedly', 'BREAKING', 'deeplearning', 'educators', 'datascience', 'GOOG', 'competitor', 'FREE', 'startups', 'AGIX', 'CHAT', 'CNET', 'integrating', 'BB', 'maker', 'passes', 'NYT', 'yall', 'rescue', 'heres', 'NYC', 'LLM', 'airdrop', 'powered', 'podcast', 'Q', 'PM', 'newsletter', 'startup', 'TSLA', 'WIRED', 'digitalmarketing', 'CTO', 'changer', 'announced', 'database', 'launched', 'warns', 'popularity', 'AWS', 'nocode', 'trending', 'metaverse', 'cc', 'PR', 'RSS', 'MWC', 'USMLE', 'copywriting', 'marketers', 'tweeting', 'amid', 'AGI', 'socialmedia', 'webinar', 'agrees', 'invests', 'launch', 'killer', 'GM', 'bullish', 'edchat', 'RLHF', 'integration', 'fastestgrowing', 'CNN', 'exam',
                  'deleted', 'gif', 'giphy', 'dm', 'removed', 'remindme', 'yup', 'sydney', 'yep', 'patched', 'nope', 'giphydownsized', 'vpn', 'ascii', 'ah', 'chadgpt', 'nerfed', 'jesus', 'xd', 'wtf', 'upvote', 'nah', 'op', 'mods', 'hahaha', 'nsfw', 'huh', 'holy', 'iq', 'jailbreak', 'blah', 'bruh', 'yea', 'agi', 'porn', 'waitlist', 'nerf', 'downvoted', 'refresh', 'omg', 'sus', 'characterai', 'meth', 'chinese', 'sub', 'rick', 'american', 'elon', 'sam', 'quack', 'youchat', 'uk', 'chad', 'archived', 'youcom', 'screenshot', 'llm', 'hitler', 'lmao', 'playground', 'rpg', 'delete', 'tldr', 'davinci', 'trump', 'hangman', 'haha', 'tay', 'karma', 'john', 'chatgtp', 'url', 'wokegpt', 'offended', 'fucked', 'redditor', 'ceo', 'agreed', 'emojis', 'cheers', 'ais', 'tag', 'wow', 'lmfao', 'p', 'rip', 'chats', 'hmm', 'bypass', 'llms', 'temperature', 'login', 'cgpt', 'windows', 'novelai', 'biden', 'donald', 'christmas', 'ms', 'cringe',
                   'ZRONX', 'rook', 'thumbnail', 'vid', 'bhai', 'bishop', 'circle', 'subscribed', 'quot', 'bless', 'tutorial', 'XD', 'sir', 'GEMX', 'profitable', 'earning', 'quotquot', 'enjoyed', 'ur', 'bra', 'JIM', 'broker', 'levy', 'vids', 'stare', 'tutorials', 'subscribers', 'sponsor', 'hai', 'lifechanging', 'curve', 'shorts', 'earn', 'trader', 'PC', 'folders', 'informative', 'br', 'chess', 'jontron', 'brother', 'T', 'YT', 'upload', 'O', 'subscriber', 'intro', 'DAN', 'aint', 'download', 'LOL', 'shes', 'moves', 'telegram', 'shortlisted', 'liked', 'websiteapp', 'watched', 'grant', 'plz', 'KINGDOM', 'YOU', 'MESSIAH', 'mate', 'ki', 'subs', 'pawn', 'hes', 'U', 'HACKBANZER', 'ka', 'brbr', 'affiliate', 'clip', 'beast', 'trade', 'ive', 'ho', 'approved', 'bhi', 'gotta', 'profits', 'wanna', 'subscribe', 'funds', 'labels', 'recommended', 'audio', 'uploaded', 'appreciated', 'UBI', 'pls', 'upto', 'alot', 'twist', 'GTP', 'accent', 'monetized', 'S', 'btw'
                  ])
        
        self.tokenizer = AutoTokenizer.from_pretrained(bert)
        self.stopwords_list = set(TfidfVectorizer(stop_words="english").get_stop_words()).union(self.jargons)
        self.N_word = N_word
        
        if vectorizer == None:
            self.vectorizer = TfidfVectorizer(stop_words=None, max_features=self.N_word, token_pattern=r'\b[a-zA-Z]{2,}\b')
            self.vectorizer.fit(self.preprocess_ctm(self.nonempty_text))
        else:
            self.vectorizer = vectorizer
            
        self.org_list = []
        self.bow_list = []
        for sent in tqdm(self.nonempty_text):
            org_input = self.tokenizer(sent, padding='max_length', truncation=True, max_length=512, return_tensors='pt')
            org_input['input_ids'] = torch.squeeze(org_input['input_ids'])
            org_input['attention_mask'] = torch.squeeze(org_input['attention_mask'])
            self.org_list.append(org_input)
            self.bow_list.append(self.vectorize(sent))
            
        
    def vectorize(self, text):
        text = self.preprocess_ctm([text])
        vectorized_input = self.vectorizer.transform(text)
        vectorized_input = vectorized_input.toarray()
        vectorized_input = vectorized_input.astype(np.float64)

        # Get word distribution from BoW
        vectorized_input += 1e-8
        vectorized_input = vectorized_input / vectorized_input.sum(axis=1, keepdims=True)
        assert abs(vectorized_input.sum() - vectorized_input.shape[0]) < 0.01
        vectorized_label = torch.tensor(vectorized_input, dtype=torch.float)
        return vectorized_label[0]
        
        
    def preprocess_ctm(self, documents):
        preprocessed_docs_tmp = documents
        preprocessed_docs_tmp = [doc.lower() for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [doc.translate(
            str.maketrans(string.punctuation, ' ' * len(string.punctuation))) for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [' '.join([w for w in doc.split() if len(w) > 0 and w not in self.stopwords_list])
                             for doc in preprocessed_docs_tmp]
        if self.lemmatize:
            lemmatizer = WordNetLemmatizer()
            preprocessed_docs_tmp = [' '.join([lemmatizer.lemmatize(w) for w in doc.split()])
                                     for doc in preprocessed_docs_tmp]
        return preprocessed_docs_tmp        
    
    
    def __len__(self):
        return len(self.nonempty_text)
    

    def __getitem__(self, idx):
        return self.org_list[idx], self.bow_list[idx], self.platform_numbers[idx]

In [11]:
#trainds = BertDataset(bert=bert_name, text_list=total_text_list, N_word=n_word, vectorizer=None, lemmatize=True)
# BertDataset 클래스를 초기화합니다.
trainds = BertDataset(bert=bert_name, text_list=total_text_list, N_word=n_word, vectorizer=None, lemmatize=True)

100%|██████████| 3000/3000 [00:08<00:00, 336.84it/s]


In [12]:
len(trainds.platform_numbers)

3000

In [13]:
# total_text_list를 문자열 리스트로 변경합니다.
total_text_list_str = [entry['text'] for entry in total_text_list]
total_bow_matrix = trainds.vectorizer.transform(total_text_list_str).toarray()

In [14]:
#확인용 코드
total_bow_matrix.shape

(3000, 2000)

In [15]:
n_word = total_bow_matrix.shape[1]

# Re_fornulate the bow

In [16]:
def dist_match_loss(hiddens, alpha=1.0):
    device = hiddens.device
    hidden_dim = hiddens.shape[-1]
    H = np.random.randn(hidden_dim, hidden_dim)
    Q, R = qr(H) 
    rand_w = torch.Tensor(Q).to(device)
    loss_dist_match = get_swd_loss(hiddens, rand_w, alpha)
    return loss_dist_match


def js_div_loss(hidden1, hidden2):
    m = 0.5 * (hidden1 + hidden2)
    return kldiv(m.log(), hidden1) + kldiv(m.log(), hidden2)


def get_swd_loss(states, rand_w, alpha=1.0):
    device = states.device
    states_shape = states.shape
    states = torch.matmul(states, rand_w)
    states_t, _ = torch.sort(states.t(), dim=1)

    # Random vector with length from normal distribution
    states_prior = torch.Tensor(np.random.dirichlet([alpha]*states_shape[1], states_shape[0])).to(device) # (bsz, dim)
    states_prior = torch.matmul(states_prior, rand_w) # (dim, dim)
    states_prior_t, _ = torch.sort(states_prior.t(), dim=1) # (dim, bsz)
    return torch.mean(torch.sum((states_prior_t - states_t)**2, axis=0))

In [17]:
from sklearn.metrics.pairwise import cosine_similarity

In [18]:
def compute_cosine_similarity_matrix(total_bow_matrix, batch_size=500):
    n_rows = total_bow_matrix.shape[0]
    similarity_matrix = None  # 초기화
    
    for i in range(0, n_rows, batch_size):
        start_idx = i
        end_idx = min(i + batch_size, n_rows)
        batch_data = total_bow_matrix[start_idx:end_idx]
        
        batch_similarity = cosine_similarity(batch_data, total_bow_matrix)
        
        if similarity_matrix is None:
            similarity_matrix = np.zeros((n_rows, n_rows))
        
        similarity_matrix[start_idx:end_idx, :] = batch_similarity
        
        # 대각선 원소를 0으로 설정
        np.fill_diagonal(similarity_matrix, -1)
    
    return similarity_matrix

In [19]:
# 코사인 유사도 매트릭스를 계산합니다.
similarity_matrix = compute_cosine_similarity_matrix(total_bow_matrix)

In [20]:
similarity_matrix.shape

(3000, 3000)

In [21]:
class Stage2Dataset(Dataset):
    def __init__(self, encoder, ds, similarity_matrix, N_word, k=1, lemmatize=False):
        self.lemmatize = lemmatize
        self.ds = ds
        self.org_list = self.ds.org_list
        self.nonempty_text = self.ds.nonempty_text
        self.N_word = N_word
        
        self.platforms = self.ds.platforms
        self.platform_to_index = {"Twitter": 0, "Reddit": 1, "Youtube": 2}
        self.platform_idx = [self.platform_to_index[platform] for platform in self.platforms]
        
        self.jargons = set(['bigdata', 'RT', 'announces', 'rival', 'MSFT', 'launches', 'SEO', 'tweets', 'WSJ', 'unveils', 'MBA', 'machinelearning', 'HN', 'artificalintelligence', 'DALLE', 'NFT', 'NLP', 'edtech', 'cybersecurity', 'malware', 'CEO', 'ML', 'JPEG', 'GOOGL', 'UX', 'artificialintelligence', 'technews', 'fintech', 'tweet', 'CHATGPT', 'viral', 'blockchain', 'BARD', 'reportedly', 'BREAKING', 'deeplearning', 'educators', 'datascience', 'GOOG', 'competitor', 'FREE', 'startups', 'AGIX', 'CHAT', 'CNET', 'integrating', 'BB', 'maker', 'passes', 'NYT', 'yall', 'rescue', 'heres', 'NYC', 'LLM', 'airdrop', 'powered', 'podcast', 'Q', 'PM', 'newsletter', 'startup', 'TSLA', 'WIRED', 'digitalmarketing', 'CTO', 'changer', 'announced', 'database', 'launched', 'warns', 'popularity', 'AWS', 'nocode', 'trending', 'metaverse', 'cc', 'PR', 'RSS', 'MWC', 'USMLE', 'copywriting', 'marketers', 'tweeting', 'amid', 'AGI', 'socialmedia', 'webinar', 'agrees', 'invests', 'launch', 'killer', 'GM', 'bullish', 'edchat', 'RLHF', 'integration', 'fastestgrowing', 'CNN', 'exam',
                  'deleted', 'gif', 'giphy', 'dm', 'removed', 'remindme', 'yup', 'sydney', 'yep', 'patched', 'nope', 'giphydownsized', 'vpn', 'ascii', 'ah', 'chadgpt', 'nerfed', 'jesus', 'xd', 'wtf', 'upvote', 'nah', 'op', 'mods', 'hahaha', 'nsfw', 'huh', 'holy', 'iq', 'jailbreak', 'blah', 'bruh', 'yea', 'agi', 'porn', 'waitlist', 'nerf', 'downvoted', 'refresh', 'omg', 'sus', 'characterai', 'meth', 'chinese', 'sub', 'rick', 'american', 'elon', 'sam', 'quack', 'youchat', 'uk', 'chad', 'archived', 'youcom', 'screenshot', 'llm', 'hitler', 'lmao', 'playground', 'rpg', 'delete', 'tldr', 'davinci', 'trump', 'hangman', 'haha', 'tay', 'karma', 'john', 'chatgtp', 'url', 'wokegpt', 'offended', 'fucked', 'redditor', 'ceo', 'agreed', 'emojis', 'cheers', 'ais', 'tag', 'wow', 'lmfao', 'p', 'rip', 'chats', 'hmm', 'bypass', 'llms', 'temperature', 'login', 'cgpt', 'windows', 'novelai', 'biden', 'donald', 'christmas', 'ms', 'cringe',
                   'ZRONX', 'rook', 'thumbnail', 'vid', 'bhai', 'bishop', 'circle', 'subscribed', 'quot', 'bless', 'tutorial', 'XD', 'sir', 'GEMX', 'profitable', 'earning', 'quotquot', 'enjoyed', 'ur', 'bra', 'JIM', 'broker', 'levy', 'vids', 'stare', 'tutorials', 'subscribers', 'sponsor', 'hai', 'lifechanging', 'curve', 'shorts', 'earn', 'trader', 'PC', 'folders', 'informative', 'br', 'chess', 'jontron', 'brother', 'T', 'YT', 'upload', 'O', 'subscriber', 'intro', 'DAN', 'aint', 'download', 'LOL', 'shes', 'moves', 'telegram', 'shortlisted', 'liked', 'websiteapp', 'watched', 'grant', 'plz', 'KINGDOM', 'YOU', 'MESSIAH', 'mate', 'ki', 'subs', 'pawn', 'hes', 'U', 'HACKBANZER', 'ka', 'brbr', 'affiliate', 'clip', 'beast', 'trade', 'ive', 'ho', 'approved', 'bhi', 'gotta', 'profits', 'wanna', 'subscribe', 'funds', 'labels', 'recommended', 'audio', 'uploaded', 'appreciated', 'UBI', 'pls', 'upto', 'alot', 'twist', 'GTP', 'accent', 'monetized', 'S', 'btw'
                  ])
        
#         english_stopwords = nltk.corpus.stopwords.words('english')
#         self.stopwords_list = set(english_stopwords)
        self.stopwords_list = set(TfidfVectorizer(stop_words="english").get_stop_words()).union(self.jargons)
        
        self.vectorizer = TfidfVectorizer(stop_words=None, max_features=self.N_word, token_pattern=r'\b[a-zA-Z]{2,}\b')
        self.vectorizer.fit(self.preprocess_ctm(self.nonempty_text)) 
        self.bow_list = []
        for sent in tqdm(self.nonempty_text):
            self.bow_list.append(self.vectorize(sent))
            
        self.similarity_matrix = torch.tensor(similarity_matrix)  # NumPy 배열을 PyTorch 텐서로 변환
        sim_weight, sim_indices = self.similarity_matrix.topk(k=k, dim=-1)
        zip_iterator = zip(np.arange(len(sim_weight)), sim_indices.squeeze().data.numpy())
        self.pos_dict = dict(zip_iterator)
        
        self.embedding_list = []
        encoder_device = next(encoder.parameters()).device
        for org_input in tqdm(self.org_list):
            org_input_ids = org_input['input_ids'].to(encoder_device).reshape(1, -1)
            org_attention_mask = org_input['attention_mask'].to(encoder_device).reshape(1, -1)
            embedding = encoder(input_ids = org_input_ids, attention_mask = org_attention_mask)
            self.embedding_list.append(embedding['pooler_output'].squeeze().detach().cpu())
            
    
    def __len__(self):
        return len(self.org_list)
        
    def preprocess_ctm(self, documents):
        preprocessed_docs_tmp = documents
        preprocessed_docs_tmp = [doc.lower() for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [doc.translate(
            str.maketrans(string.punctuation, ' ' * len(string.punctuation))) for doc in preprocessed_docs_tmp]
        preprocessed_docs_tmp = [' '.join([w for w in doc.split() if len(w) > 0 and w not in self.stopwords_list])
                                 for doc in preprocessed_docs_tmp]
        if self.lemmatize:
            lemmatizer = WordNetLemmatizer()
            preprocessed_docs_tmp = [' '.join([lemmatizer.lemmatize(w) for w in doc.split()])
                                     for doc in preprocessed_docs_tmp]
        return preprocessed_docs_tmp
        
    def vectorize(self, text):
        text = self.preprocess_ctm([text])
        vectorized_input = self.vectorizer.transform(text)
        vectorized_input = vectorized_input.toarray().astype(np.float64)
#         vectorized_input = (vectorized_input != 0).astype(np.float64)

        # Get word distribution from BoW
        if vectorized_input.sum() == 0:
            vectorized_input += 1e-8
        vectorized_input = vectorized_input / vectorized_input.sum(axis=1, keepdims=True)
        assert abs(vectorized_input.sum() - vectorized_input.shape[0]) < 0.01
        
        vectorized_label = torch.tensor(vectorized_input, dtype=torch.float)
        return vectorized_label[0]
        
        
#     def __getitem__(self, idx):
#         pos_idx = self.pos_dict[idx]
#         return self.embedding_list[idx], self.embedding_list[pos_idx], self.bow_list[idx], self.bow_list[pos_idx]
    def __getitem__(self, idx):
#         print("Requested idx:", idx)  # idx 확인
        pos_idx = self.pos_dict[idx]
#         print("Requested pos_idx:", pos_idx)  # pos_idx 확인

#         if idx >= len(self.embedding_list) or pos_idx >= len(self.embedding_list):
#             print(f"Invalid index. idx: {idx}, pos_idx: {pos_idx}")  # 인덱스 유효성 확인

        return self.embedding_list[idx], self.embedding_list[pos_idx], self.bow_list[idx], self.bow_list[pos_idx], self.platform_idx 


In [22]:
model =  ContBertTopicExtractorAE(N_topic=n_topic, N_word=args.n_word, bert=bert_name, bert_dim=768)

In [23]:
torch.cuda.empty_cache()

In [24]:
finetuneds = Stage2Dataset(model.encoder, trainds, similarity_matrix, n_word, lemmatize=True)    

kldiv = torch.nn.KLDivLoss(reduction='batchmean')
vocab_dict = finetuneds.vectorizer.vocabulary_
vocab_dict_reverse = {i:v for v, i in vocab_dict.items()}
print(n_word)

100%|██████████| 3000/3000 [00:04<00:00, 680.33it/s]
100%|██████████| 3000/3000 [08:00<00:00,  6.25it/s] 

2000





In [25]:
# len(finetuneds.embedding_list)
finetuneds.platform_idx[999:1000]

[0]

# Stage 3

In [26]:
def measure_hungarian_score(topic_dist, train_target):
    dist = topic_dist
    train_target_filtered = train_target
    flat_predict = torch.tensor(np.argmax(dist, axis=1))
    flat_target = torch.tensor(train_target_filtered).to(flat_predict.device)
    num_samples = flat_predict.shape[0]
    num_classes = dist.shape[1]
    match = _hungarian_match(flat_predict, flat_target, num_samples, num_classes)    
    reordered_preds = torch.zeros(num_samples).to(flat_predict.device)
    for pred_i, target_i in match:
        reordered_preds[flat_predict == pred_i] = int(target_i)
    acc = int((reordered_preds == flat_target.float()).sum()) / float(num_samples)
    return acc

In [27]:
from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora.dictionary import Dictionary
from gensim import models
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Main

In [35]:
from sklearn.datasets import fetch_20newsgroups
from gensim.models import LdaModel
from gensim.corpora import Dictionary
from gensim.models.coherencemodel import CoherenceModel
import gensim.downloader

# 20Newsgroups 데이터 로드
newsgroups = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))

# 간단한 전처리: 토큰화 및 불용어 제거
processed_data = [text.split() for text in newsgroups.data]


In [34]:
results_list = []

for i in range(args.stage_2_repeat):
    model = ContBertTopicExtractorAE(N_topic=n_topic, N_word=args.n_word, bert=bert_name, bert_dim=768)
#     model.load_state_dict(torch.load(model_stage1_name), strict=True)
    model.beta = nn.Parameter(torch.Tensor(model.N_topic, n_word))
    nn.init.xavier_uniform_(model.beta)
    model.beta_batchnorm = nn.Sequential()
    model.cuda(gpu_ids[0])
    
    losses = AverageMeter()
    dlosses = AverageMeter() 
    rlosses = AverageMeter()
    closses = AverageMeter()
    distlosses = AverageMeter()
    trainloader = DataLoader(finetuneds, batch_size=bsz, shuffle=True, num_workers=0, )
    memoryloader = DataLoader(finetuneds, batch_size=bsz * 2, shuffle=False, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.stage_2_lr)

    memory_queue = F.softmax(torch.randn(512, n_topic).cuda(gpu_ids[0]), dim=1)
    print("Coeff   / regul: {:.5f} - recon: {:.5f} - c: {:.5f} - dist: {:.5f} ".format(args.coeff_2_regul, 
                                                                                        args.coeff_2_recon,
                                                                                        args.coeff_2_cons,
                                                                                        args.coeff_2_dist))
    for epoch in range(50):
        model.train()
        model.encoder.eval()
        for batch_idx, batch in enumerate(trainloader):  
            org_input, pos_input, org_bow, pos_bow, platform_idx = batch
            org_input = org_input.cuda(gpu_ids[0])
            org_bow = org_bow.cuda(gpu_ids[0])
            pos_input = pos_input.cuda(gpu_ids[0])
            pos_bow = pos_bow.cuda(gpu_ids[0])
            #platform_idx = platform_idx.cuda(gpu_ids[1])

            batch_size = org_input.size(0) #org_input_ids.size(0)

            org_dists, org_topic_logit = model.decode(org_input)
            pos_dists, pos_topic_logit = model.decode(pos_input)

            org_topic = F.softmax(org_topic_logit, dim=1)
            pos_topic = F.softmax(pos_topic_logit, dim=1)

            # reconstruction loss
            # batchmean
#             org_target = torch.matmul(org_topic.detach(), weight_cands)
#             pos_target = torch.matmul(pos_topic.detach(), weight_cands)
            
#             _, org_target = torch.max(org_topic.detach(), 1)
#             _, pos_target = torch.max(pos_topic.detach(), 1)
            
            recons_loss = torch.mean(-torch.sum(torch.log(org_dists + 1E-10) * (org_bow), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log((1-org_dists) + 1E-10) * (1-org_bow), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log(pos_dists + 1E-10) * (pos_bow), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log((1-pos_dists) + 1E-10) * (1-pos_bow), axis=1), axis=0)
            recons_loss *= 0.5

            # consistency loss
            pos_sim = torch.sum(org_topic * pos_topic, dim=-1)
            cons_loss = -pos_sim.mean()

            # distribution loss
            # batchmean
#             distmatch_loss = dist_match_loss(torch.cat((org_topic), dim=0), dirichlet_alpha_2)
            distmatch_loss = dist_match_loss(torch.cat((org_topic,), dim=0), dirichlet_alpha_2)


            loss = args.coeff_2_recon * recons_loss + \
                   args.coeff_2_cons * cons_loss + \
                   args.coeff_2_dist * distmatch_loss 
            
            losses.update(loss.item(), bsz)
            closses.update(cons_loss.item(), bsz)
            rlosses.update(recons_loss.item(), bsz)
            distlosses.update(distmatch_loss.item(), bsz)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        print("Epoch-{} / recon: {:.5f} - dist: {:.5f} - cons: {:.5f}".format(epoch, rlosses.avg, distlosses.avg, closses.avg))

    print("------- Evaluation results -------")
    all_list = {}
    for e, i in enumerate(model.beta.cpu().topk(10, dim=1).indices):
        word_list = []
        for j in i:
            word_list.append(vocab_dict_reverse[j.item()])
        all_list[e] = word_list
        print("topic-{}".format(e), word_list)

    topic_words_list = list(all_list.values())
    now = datetime.now().strftime('%y%m%d_%H%M%S')

    # 사전 생성
    reference_corpus=[doc.split() for doc in trainds.preprocess_ctm(trainds.nonempty_text)]
    dictionary = Dictionary(reference_corpus)
    
    results = get_topic_qualities(topic_words_list, 
            reference_corpus, 
            dictionary=dictionary,
            filename=f'results/{now}.txt')
        
    print(results)
    print()
    results_list.append(results)

Coeff   / regul: 1.00000 - recon: 1.00000 - c: 1.00000 - dist: 1.00000 
Epoch-0 / recon: 8.19787 - dist: 0.19313 - cons: -0.05066
Epoch-1 / recon: 7.97000 - dist: 0.18286 - cons: -0.05138
Epoch-2 / recon: 7.82764 - dist: 0.17776 - cons: -0.05238
Epoch-3 / recon: 7.73250 - dist: 0.17430 - cons: -0.05315
Epoch-4 / recon: 7.66503 - dist: 0.17179 - cons: -0.05520
Epoch-5 / recon: 7.61366 - dist: 0.16945 - cons: -0.05768
Epoch-6 / recon: 7.57116 - dist: 0.16626 - cons: -0.06012
Epoch-7 / recon: 7.53304 - dist: 0.16253 - cons: -0.06271
Epoch-8 / recon: 7.49382 - dist: 0.15904 - cons: -0.06682
Epoch-9 / recon: 7.45578 - dist: 0.15572 - cons: -0.07042
Epoch-10 / recon: 7.41938 - dist: 0.15214 - cons: -0.07376
Epoch-11 / recon: 7.38605 - dist: 0.14928 - cons: -0.07678
Epoch-12 / recon: 7.35467 - dist: 0.14672 - cons: -0.07918
Epoch-13 / recon: 7.32466 - dist: 0.14444 - cons: -0.08150
Epoch-14 / recon: 7.29607 - dist: 0.14266 - cons: -0.08358
Epoch-15 / recon: 7.26834 - dist: 0.14094 - cons: -0.

In [36]:
results_list

[{'topic_N': 20,
  'umass': -11.423192198477533,
  'c_v': 0.4755596031551386,
  'c_npmi': -0.33745228907922276,
  'c_uci': -10.206727676021092,
  'sim_w2v': 0.13471363696883312,
  'diversity': 0.83,
  'filename': 'results/230924_215308.txt'},
 {'topic_N': 20,
  'umass': -12.18752079625053,
  'c_v': 0.46308604655611435,
  'c_npmi': -0.3395225331797592,
  'c_uci': -10.334344919772427,
  'sim_w2v': 0.122508331523936,
  'diversity': 0.875,
  'filename': 'results/230924_215642.txt'},
 {'topic_N': 20,
  'umass': -11.286020078889836,
  'c_v': 0.46648086032474956,
  'c_npmi': -0.31613546040513274,
  'c_uci': -9.713479010053522,
  'sim_w2v': 0.12041743544328608,
  'diversity': 0.78,
  'filename': 'results/230924_220119.txt'},
 {'topic_N': 20,
  'umass': -11.009834960307064,
  'c_v': nan,
  'c_npmi': inf,
  'c_uci': inf,
  'sim_w2v': 0.12845896413922314,
  'diversity': 0.835,
  'filename': 'results/230924_220452.txt'},
 {'topic_N': 20,
  'umass': -11.799854424197113,
  'c_v': 0.48566321536481116

The history saving thread hit an unexpected error (OperationalError('database or disk is full')).History will not be written to the database.


In [37]:
results_df = pd.DataFrame(results_list)
print(results_df)
print('mean')
print(results_df.mean())
print('std')
print(results_df.std())

   topic_N      umass       c_v    c_npmi      c_uci   sim_w2v  diversity  \
0       20 -11.423192  0.475560 -0.337452 -10.206728  0.134714      0.830   
1       20 -12.187521  0.463086 -0.339523 -10.334345  0.122508      0.875   
2       20 -11.286020  0.466481 -0.316135  -9.713479  0.120417      0.780   
3       20 -11.009835       NaN       inf        inf  0.128459      0.835   
4       20 -11.799854  0.485663 -0.326492  -9.983765  0.136775      0.885   

                    filename  
0  results/230924_215308.txt  
1  results/230924_215642.txt  
2  results/230924_220119.txt  
3  results/230924_220452.txt  
4  results/230924_220827.txt  
mean
topic_N      20.000000
umass       -11.541284
c_v           0.472697
c_npmi             inf
c_uci              inf
sim_w2v       0.128575
diversity     0.841000
dtype: float64
std
topic_N      0.000000
umass        0.459898
c_v          0.010121
c_npmi            NaN
c_uci             NaN
sim_w2v      0.007216
diversity    0.041743
dtype: float

In [38]:
results_df

Unnamed: 0,topic_N,umass,c_v,c_npmi,c_uci,sim_w2v,diversity,filename
0,20,-11.423192,0.47556,-0.337452,-10.206728,0.134714,0.83,results/230924_215308.txt
1,20,-12.187521,0.463086,-0.339523,-10.334345,0.122508,0.875,results/230924_215642.txt
2,20,-11.28602,0.466481,-0.316135,-9.713479,0.120417,0.78,results/230924_220119.txt
3,20,-11.009835,,inf,inf,0.128459,0.835,results/230924_220452.txt
4,20,-11.799854,0.485663,-0.326492,-9.983765,0.136775,0.885,results/230924_220827.txt


In [165]:
from collections import defaultdict
import math

def calculate_mi(total_text_list, topic_words_list):
    platform_topic_counts = defaultdict(lambda: defaultdict(int))  # platform_topic_counts[platform][topic_id]
    platform_counts = defaultdict(int)
    topic_counts = defaultdict(int)

    for entry in total_text_list:
        platform = entry['platform']
        content = entry['text']
        platform_counts[platform] += 1
        for topic_id, words in enumerate(topic_words_list):
            if any(word in content for word in words):
                platform_topic_counts[platform][topic_id] += 1
                topic_counts[topic_id] += 1

    # Calculate p(x,y), p(x), p(y)
    total_docs = len(total_text_list)
    p_x_y = {}  # Joint probability for platform and topic
    p_x = {}    # Marginal probability for platform
    p_y = {}    # Marginal probability for topic

    for platform, count in platform_counts.items():
        p_x[platform] = count / total_docs
        for topic_id in range(len(topic_words_list)):
            p_x_y[(platform, topic_id)] = platform_topic_counts[platform][topic_id] / total_docs

    for topic_id, count in topic_counts.items():
        p_y[topic_id] = count / total_docs

    # Calculate MI
    mi_values = {}
    for platform in platform_counts.keys():
        for topic_id in range(len(topic_words_list)):
            px = p_x[platform]
            py = p_y[topic_id]
            pxy = p_x_y[(platform, topic_id)]
            if pxy == 0:  # If joint probability is 0, then MI contribution is also 0
                mi = 0
            else:
                mi = pxy * math.log2(pxy / (px * py))
            mi_values[(platform, topic_id)] = mi

    return mi_values


In [169]:
mi_results = calculate_mi(total_text_list, topic_words_list)
print(mi_results)

{('Twitter', 0): -0.015491122468771365, ('Twitter', 1): -0.023671540911905938, ('Twitter', 2): -0.024455653850044366, ('Twitter', 3): -0.012732353200770711, ('Twitter', 4): -0.02528162770568063, ('Twitter', 5): -0.005900108728396244, ('Twitter', 6): -0.013600697830971035, ('Twitter', 7): -0.01489129276141035, ('Twitter', 8): -0.02061820899528352, ('Twitter', 9): -0.01875973730552711, ('Twitter', 10): -0.033921837687139175, ('Twitter', 11): -0.0361598701472769, ('Twitter', 12): -0.03428194133637023, ('Twitter', 13): -0.008152881569010137, ('Twitter', 14): -0.018129468325203704, ('Twitter', 15): -0.01366762714618823, ('Twitter', 16): -0.022098183629628808, ('Twitter', 17): -0.016705730585034478, ('Twitter', 18): -0.021062521096866915, ('Twitter', 19): -0.015419855740593067, ('Reddit', 0): 0.0016192085576820288, ('Reddit', 1): 0.011640058540141057, ('Reddit', 2): 0.012394265366869134, ('Reddit', 3): 0.04017390235314121, ('Reddit', 4): -0.005920981888741308, ('Reddit', 5): 0.00397266792070

In [32]:
results_list = []

for i in range(args.stage_2_repeat):
    model = ContBertTopicExtractorAE(N_topic=n_topic, N_word=args.n_word, bert=bert_name, bert_dim=768)
    model.beta = nn.Parameter(torch.Tensor(model.N_topic, n_word))
    nn.init.xavier_uniform_(model.beta)
    model.beta_batchnorm = nn.Sequential()
    model.cuda(gpu_ids[0])
    
    losses = AverageMeter()
    dlosses = AverageMeter() 
    rlosses = AverageMeter()
    closses = AverageMeter()
    distlosses = AverageMeter()
    trainloader = DataLoader(finetuneds, batch_size=bsz, shuffle=True, num_workers=0, )
    memoryloader = DataLoader(finetuneds, batch_size=bsz * 2, shuffle=False, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.stage_2_lr)

    memory_queue = F.softmax(torch.randn(512, n_topic).cuda(gpu_ids[0]), dim=1)
    print("Coeff   / regul: {:.5f} - recon: {:.5f} - c: {:.5f} - dist: {:.5f} ".format(args.coeff_2_regul, 
                                                                                        args.coeff_2_recon,
                                                                                        args.coeff_2_cons,
                                                                                        args.coeff_2_dist))
    topic_counts = np.zeros((3, 20)) 
    
    for epoch in range(50):
        model.train()
        model.encoder.eval()
        for batch_idx, batch in enumerate(trainloader):  
            org_input, pos_input, org_bow, pos_bow, platform_idx = batch
            org_input = org_input.cuda(gpu_ids[0])
            org_bow = org_bow.cuda(gpu_ids[0])
            pos_input = pos_input.cuda(gpu_ids[0])
            pos_bow = pos_bow.cuda(gpu_ids[0])
            platform_idx_tensor = torch.stack(platform_idx).cuda(gpu_ids[1])

            batch_size = org_input.size(0) #org_input_ids.size(0)

            org_dists, org_topic_logit = model.decode(org_input)
            pos_dists, pos_topic_logit = model.decode(pos_input)

            org_topic = F.softmax(org_topic_logit, dim=1)
            pos_topic = F.softmax(pos_topic_logit, dim=1)
            
            # 2. 토픽 카운트 업데이트
            _, org_topic_idx = torch.max(org_topic, 1)
            for p, t in zip(platform_idx_tensor.cpu().numpy(), org_topic_idx.cpu().numpy()):  # GPU 텐서를 CPU로 옮기고 numpy 배열로 변환
                topic_counts[p][t] += 1
            

            # reconstruction loss
            # batchmean
#             org_target = torch.matmul(org_topic.detach(), weight_cands)
#             pos_target = torch.matmul(pos_topic.detach(), weight_cands)
            
#             _, org_target = torch.max(org_topic.detach(), 1)
#             _, pos_target = torch.max(pos_topic.detach(), 1)
            
            recons_loss = torch.mean(-torch.sum(torch.log(org_dists + 1E-10) * (org_bow), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log((1-org_dists) + 1E-10) * (1-org_bow), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log(pos_dists + 1E-10) * (pos_bow), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log((1-pos_dists) + 1E-10) * (1-pos_bow), axis=1), axis=0)
            recons_loss *= 0.5

            # consistency loss
            pos_sim = torch.sum(org_topic * pos_topic, dim=-1)
            cons_loss = -pos_sim.mean()

            # distribution loss
            # batchmean
#             distmatch_loss = dist_match_loss(torch.cat((org_topic), dim=0), dirichlet_alpha_2)
            distmatch_loss = dist_match_loss(torch.cat((org_topic,), dim=0), dirichlet_alpha_2)


            loss = args.coeff_2_recon * recons_loss + \
                   args.coeff_2_cons * cons_loss + \
                   args.coeff_2_dist * distmatch_loss 
            
            losses.update(loss.item(), bsz)
            closses.update(cons_loss.item(), bsz)
            rlosses.update(recons_loss.item(), bsz)
            distlosses.update(distmatch_loss.item(), bsz)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        topic_probabilities = topic_counts / topic_counts.sum(axis=1, keepdims=True)
        df = pd.DataFrame(topic_probabilities, columns=[f"topic-{i}" for i in range(20)])

        # 4. MI를 계산합니다.
        P_X_Y = topic_probabilities / topic_probabilities.sum() # joint probability
        P_X = df.sum(axis=0).values
        P_Y = df.sum(axis=1).values

        MI = 0
        for i in range(3):
            for j in range(20):
                if P_X_Y[i][j] > 0:
                    MI += P_X_Y[i][j] * log2(P_X_Y[i][j] / (P_X[j] * P_Y[i]))

        print(f"Mutual Information: {MI}")

        print("Epoch-{} / recon: {:.5f} - dist: {:.5f} - cons: {:.5f}".format(epoch, rlosses.avg, distlosses.avg, closses.avg))

    print("------- Evaluation results -------")
    all_list = {}
    for e, i in enumerate(model.beta.cpu().topk(10, dim=1).indices):
        word_list = []
        for j in i:
            word_list.append(vocab_dict_reverse[j.item()])
        all_list[e] = word_list
        print("topic-{}".format(e), word_list)

    topic_words_list = list(all_list.values())
    now = datetime.now().strftime('%y%m%d_%H%M%S')

    # 사전 생성
    reference_corpus=[doc.split() for doc in trainds.preprocess_ctm(trainds.nonempty_text)]
    dictionary = Dictionary(reference_corpus)

    results = get_topic_qualities(topic_words_list, 
                                  reference_corpus=[doc.split() for doc in trainds.preprocess_ctm(trainds.nonempty_text)], 
                                  dictionary=dictionary,
                                  filename=f'results/{now}.txt')

    print(results)
    print()
    results_list.append(results)

Coeff   / regul: 1.00000 - recon: 1.00000 - c: 1.00000 - dist: 1.00000 
Mutual Information: 0
Epoch-0 / recon: 8.19721 - dist: 0.19125 - cons: -0.05046
Mutual Information: 0
Epoch-1 / recon: 7.97339 - dist: 0.18278 - cons: -0.05071


KeyboardInterrupt: 

In [161]:
print(P_X_Y)
print(P_X)
print(P_Y)
print(topic_counts)

[[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan]
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan]
 [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
  nan nan]]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0.]
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]


In [151]:
for batch_idx, batch in enumerate(trainloader):  
    print(len(batch))
    break


5


In [None]:
results_list = []
mi_values = []
num_platforms = 3000
num_topics = 20

for i in range(args.stage_2_repeat):
    model = ContBertTopicExtractorAE(N_topic=n_topic, N_word=args.n_word, bert=bert_name, bert_dim=768)
#     model.load_state_dict(torch.load(model_stage1_name), strict=True)
    model.beta = nn.Parameter(torch.Tensor(model.N_topic, n_word))
    nn.init.xavier_uniform_(model.beta)
    model.beta_batchnorm = nn.Sequential()
    model.cuda(gpu_ids[0])
    
    losses = AverageMeter()
    dlosses = AverageMeter() 
    rlosses = AverageMeter()
    closses = AverageMeter()
    distlosses = AverageMeter()
    trainloader = DataLoader(finetuneds, batch_size=bsz, shuffle=True, num_workers=0)
    memoryloader = DataLoader(finetuneds, batch_size=bsz * 2, shuffle=False, num_workers=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.stage_2_lr)

    memory_queue = F.softmax(torch.randn(512, n_topic).cuda(gpu_ids[0]), dim=1)
    print("Coeff   / regul: {:.5f} - recon: {:.5f} - c: {:.5f} - dist: {:.5f} ".format(args.coeff_2_regul, 
                                                                                        args.coeff_2_recon,
                                                                                        args.coeff_2_cons,
                                                                                        args.coeff_2_dist))
    
    topic_platform_count = {i: {platform: 0 for platform in trainds.platform_numbers} for i in range(n_topic)}

    for epoch in range(50):
        model.train()
        model.encoder.eval()
 
        for batch_idx, batch in enumerate(trainloader):  
            org_input, pos_input, org_bow, pos_bow, platform_info = batch
            org_input = org_input.cuda(gpu_ids[0])
            org_bow = org_bow.cuda(gpu_ids[0])
            pos_input = pos_input.cuda(gpu_ids[0])
            pos_bow = pos_bow.cuda(gpu_ids[0])
            platform_info = platform_info.cuda(gpu_ids[0])  # 만약 platform_info도 tensor라면 GPU로 이동

            batch_size = org_input.size(0) #org_input_ids.size(0)

            org_dists, org_topic_logit = model.decode(org_input)
            pos_dists, pos_topic_logit = model.decode(pos_input)

            org_topic = F.softmax(org_topic_logit, dim=1)
            pos_topic = F.softmax(pos_topic_logit, dim=1)
            
            for topic_num, topic_dist in enumerate(org_topic):
                most_likely_topic = torch.argmax(topic_dist).item()
                # platform_info를 사용하여 topic_platform_count 업데이트
                topic_platform_count[most_likely_topic][platform_info[topic_num].item()] += 1
           
            # reconstruction loss
            # batchmean
#             org_target = torch.matmul(org_topic.detach(), weight_cands)
#             pos_target = torch.matmul(pos_topic.detach(), weight_cands)
            
#             _, org_target = torch.max(org_topic.detach(), 1)
#             _, pos_target = torch.max(pos_topic.detach(), 1)
            
            recons_loss = torch.mean(-torch.sum(torch.log(org_dists + 1E-10) * (org_bow), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log((1-org_dists) + 1E-10) * (1-org_bow), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log(pos_dists + 1E-10) * (pos_bow), axis=1), axis=0)
            recons_loss += torch.mean(-torch.sum(torch.log((1-pos_dists) + 1E-10) * (1-pos_bow), axis=1), axis=0)
            recons_loss *= 0.5

            # consistency loss
            pos_sim = torch.sum(org_topic * pos_topic, dim=-1)
            cons_loss = -pos_sim.mean()

            # distribution loss
            # batchmean
#             distmatch_loss = dist_match_loss(torch.cat((org_topic), dim=0), dirichlet_alpha_2)
            distmatch_loss = dist_match_loss(torch.cat((org_topic,), dim=0), dirichlet_alpha_2)


            loss = args.coeff_2_recon * recons_loss + \
                   args.coeff_2_cons * cons_loss + \
                   args.coeff_2_dist * distmatch_loss 
            
            losses.update(loss.item(), bsz)
            closses.update(cons_loss.item(), bsz)
            rlosses.update(recons_loss.item(), bsz)
            distlosses.update(distmatch_loss.item(), bsz)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        p_topic = org_topic.mean(dim=0)
        p_platform = torch.tensor([platform_info.eq(i).float().mean() for i in range(num_platforms)], device=gpu_ids[0])

        p_joint = torch.zeros(num_topics, num_platforms, device=gpu_ids[0])
        total_count = sum(sum(platform_count.values()) for platform_count in topic_platform_count.values())

        for i in range(num_topics):
            for j in range(num_platforms):
                p_joint[i][j] = topic_platform_count[i][j] / total_count

        mi = torch.sum(p_joint * (torch.log(p_joint / (p_topic[:, None] * p_platform[None, :]) + 1e-10)))
        mi_values.append(mi.item())
        
        print("Epoch-{} / recon: {:.5f} - dist: {:.5f} - cons: {:.5f}".format(epoch, rlosses.avg, distlosses.avg, closses.avg))

    print("------- Evaluation results -------")
#     topic_probs = {}
#     for topic in topic_platform_count.keys():
#         topic_probs[topic] = sum(topic_platform_count[topic].values()) / 3000.0

    all_list = {}
    for e, i in enumerate(model.beta.cpu().topk(10, dim=1).indices):
        word_list = []
        for j in i:
            word_list.append(vocab_dict_reverse[j.item()])
        all_list[e] = word_list
        print("topic-{}".format(e), word_list)
        print(f"Topic-{e} platform probabilities: {topic_probs[e]}")
        
    topic_words_list = list(all_list.values())
    now = datetime.now().strftime('%y%m%d_%H%M%S')

    # 사전 생성
    dictionary = Dictionary(reference_corpus)

    results = get_topic_qualities(topic_words_list, 
                                  reference_corpus=[doc.split() for doc in trainds.preprocess_ctm(trainds.nonempty_text)], 
                                  dictionary=dictionary,
                                  filename=f'results/{now}.txt')
    for epoch, mi_value in enumerate(mi_values):
        print(f"Epoch-{epoch} MI: {mi_value:.5f}")

    print(results)
    print()
    results_list.append(results)