# Setting(dataset, parameter)

In [1]:
args_text = '--base-model sentence-transformers/paraphrase-MiniLM-L6-v2 ' + \
            '--dataset all --n-word 30000 ' + \
            '--bsz 32 --stage-2-lr 2e-2 --stage-2-repeat 5 ' + \
            '--n-cluster 20 '

In [2]:
import re
import os
import time
import argparse
import string
import torch
import torch.nn as nn
import torch.nn.functional as F
import gensim.downloader
import itertools

from sentence_transformers import SentenceTransformer

import numpy as np
from tqdm import tqdm_notebook
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from sklearn.datasets import fetch_20newsgroups
from nltk.corpus import stopwords

from utils import AverageMeter

import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import mutual_info_score
from gensim.corpora.dictionary import Dictionary
from pytorch_transformers import *
import scipy.stats

from gensim.models.coherencemodel import CoherenceModel
from tqdm import tqdm
import nltk

from datetime import datetime
import gensim.downloader
from scipy.linalg import qr
from data import *
from data import TwitterDataset, RedditDataset, YoutubeDataset, BertDataset
from model import ContBertTopicExtractorAE

from data import BertDataset, Stage2Dataset
import random
from sklearn.metrics.pairwise import cosine_similarity

import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd
[nltk_data] Downloading package punkt to /home/minseo/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/minseo/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /home/minseo/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [3]:
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "0,1" 

In [4]:
def _parse_args():
    parser = argparse.ArgumentParser(description='Contrastive topic modeling')

    #각 stage에서의 batch size
    parser.add_argument('--bsz', type=int, default=64,
                        help='Batch size')
    #data set정의 
    parser.add_argument('--dataset', default='twitter', type=str,
                        choices=['twitter', 'reddit', 'youtube','all'],
                        help='Name of the dataset')
    # 클러스터 수와 topic의 수는 20 (k==20)
    parser.add_argument('--n-cluster', default=20, type=int,
                        help='Number of clusters')
    parser.add_argument('--n-topic', type=int,
                        help='Number of topics. If not specified, use same value as --n-cluster')
    # 단어vocabulary는 2000로 setting
    parser.add_argument('--n-word', default=30000, type=int,
                        help='Number of words in vocabulary')
    
    parser.add_argument('--base-model', type=str,
                        help='Name of base model in huggingface library.')
    
    parser.add_argument('--gpus', default=[0,1,2,3], type=int, nargs='+',
                        help='List of GPU numbers to use. Use 0 by default')
   
    parser.add_argument('--dirichlet-alpha-1', type=float,
                        help='Parameter for Dirichlet distribution (Phase 1). Use 1/n_topic by default.')
 
    parser.add_argument('--coeff-2-recon', default=1.0, type=float,
                        help='Coefficient for VAE reconstruction loss (Phase 2)')
    parser.add_argument('--coeff-2-regul', default=1.0, type=float,
                        help='Coefficient for VAE KLD regularization loss (Phase 2)')
    parser.add_argument('--coeff-2-cons', default=1.0, type=float,
                        help='Coefficient for CL consistency loss (Phase 2)')
    parser.add_argument('--coeff-2-dist', default=1.0, type=float,
                        help='Coefficient for CL SWD distribution matching loss (Phase 2)')
    parser.add_argument('--dirichlet-alpha-2', type=float,
                        help='Parameter for Dirichlet distribution (Phase 2). Use same value as dirichlet-alpha-1 by default.')
    parser.add_argument('--stage-2-lr', default=2e-1, type=float,
                        help='Learning rate of phase 2')
    
    parser.add_argument('--stage-2-repeat', default=5, type=int,
                        help='Repetition count of phase 2')
    
    parser.add_argument('--result-file', type=str,
                        help='File name for result summary')
    
    
    # Check if the code is run in Jupyter notebook
    is_in_jupyter = False
    try:
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            is_in_jupyter = True   # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            is_in_jupyter = False  # Terminal running IPython
        else:
            is_in_jupyter = False  # Other type (?)
    except NameError:
        is_in_jupyter = False
    
    if is_in_jupyter:
        return parser.parse_args(args=args_text.split())
    else:
        return parser.parse_args()

In [5]:
args = _parse_args()
bsz = args.bsz

n_cluster = args.n_cluster
n_topic = args.n_topic if (args.n_topic is not None) else n_cluster
args.n_topic = n_topic

# textData, should_measure_hungarian = data_load(args.dataset)

ema_alpha = 0.99
n_word = args.n_word
if args.dirichlet_alpha_1 is None:
    dirichlet_alpha_1 = 1 / n_cluster
else:
    dirichlet_alpha_1 = args.dirichlet_alpha_1
if args.dirichlet_alpha_2 is None:
    dirichlet_alpha_2 = dirichlet_alpha_1
else:
    dirichlet_alpha_2 = args.dirichlet_alpha_2
    
bert_name = args.base_model
bert_name_short = bert_name.split('/')[-1]
gpu_ids = args.gpus

In [6]:
# 시드 고정 함수
def set_seed(seed_value):
    """모든 랜덤 요소에 대한 시드를 고정합니다."""
    random.seed(seed_value)  
    np.random.seed(seed_value)  
    torch.manual_seed(seed_value)  
    if torch.cuda.is_available():  
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

In [7]:
set_seed(41)

In [8]:
from sklearn.model_selection import train_test_split

# 각 데이터셋 초기화(data.py에서 확인 가능)
twitter_ds = TwitterDataset()
reddit_ds = RedditDataset()
youtube_ds = YoutubeDataset()

from sklearn.model_selection import train_test_split

# 전체 데이터를 train과 test로 7:3으로 분할
train_twitter_texts, test_twitter_texts, train_twitter_labels, test_twitter_labels = train_test_split(
    twitter_ds.texts, twitter_ds.labels, train_size=0.7, random_state=41)

train_reddit_texts, test_reddit_texts, train_reddit_labels, test_reddit_labels = train_test_split(
    reddit_ds.texts, reddit_ds.labels, train_size=0.7, random_state=41)

train_youtube_texts, test_youtube_texts, train_youtube_labels, test_youtube_labels = train_test_split(
    youtube_ds.texts, youtube_ds.labels, train_size=0.7, random_state=41)

# train 데이터를 다시 train과 val로 9:1로 분할
train_twitter_texts, val_twitter_texts, train_twitter_labels, val_twitter_labels = train_test_split(
    train_twitter_texts, train_twitter_labels, test_size=0.1, random_state=41)

train_reddit_texts, val_reddit_texts, train_reddit_labels, val_reddit_labels = train_test_split(
    train_reddit_texts, train_reddit_labels, test_size=0.1, random_state=41)

train_youtube_texts, val_youtube_texts, train_youtube_labels, val_youtube_labels = train_test_split(
    train_youtube_texts, train_youtube_labels, test_size=0.1, random_state=41)

# 각 데이터의 플랫폼을 합침
train_total_label = train_twitter_labels + train_reddit_labels + train_youtube_labels
train_total_text_list = train_twitter_texts + train_reddit_texts + train_youtube_texts

val_total_label = val_twitter_labels + val_reddit_labels + val_youtube_labels
val_total_text_list = val_twitter_texts + val_reddit_texts + val_youtube_texts

test_total_label = test_twitter_labels + test_reddit_labels + test_youtube_labels
test_total_text_list = test_twitter_texts + test_reddit_texts + test_youtube_texts

In [13]:
trainds = BertDataset(bert=bert_name, text_list=train_total_text_list, platform_label = train_total_label, N_word=n_word, vectorizer=None, lemmatize=True)
valds = BertDataset(bert=bert_name, text_list=val_total_text_list, platform_label = val_total_label, N_word=n_word, vectorizer=None, lemmatize=True)
testds = BertDataset(bert=bert_name, text_list=test_total_text_list, platform_label = test_total_label, N_word=n_word, vectorizer=None, lemmatize=True)

[nltk_data] Downloading package punkt to /home/minseo/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/minseo/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /home/minseo/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
100%|██████████| 37800/37800 [00:53<00:00, 712.15it/s]
100%|██████████| 4200/4200 [00:04<00:00, 855.18it/s]
100%|██████████| 18000/18000 [00:23<00:00, 770.20it/s]


# 원본 데이터 Mean pooling 진행

# mean_pooling 함수 정의
def mean_pooling(embeddings, attention_mask):
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    sum_embeddings = torch.sum(embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

In [14]:
# 모든 문장의 평균 풀링 임베딩 계산
trainds_embeddings = []
for _,_,pooled_embedding,_ in trainds:
    trainds_embeddings.append(pooled_embedding)

# 평균 풀링된 임베딩 매트릭스 생성
train_mean_pooled_embeddings = torch.stack(trainds_embeddings)

# 모든 문장의 평균 풀링 임베딩 계산
valds_embeddings = []
for _,_,pooled_embedding,_ in valds:
    valds_embeddings.append(pooled_embedding)

# 평균 풀링된 임베딩 매트릭스 생성
val_mean_pooled_embeddings = torch.stack(valds_embeddings)

# 모든 문장의 평균 풀링 임베딩 계산
testds_embeddings = []
for _,_,pooled_embedding,_ in testds:
    testds_embeddings.append(pooled_embedding)

# 평균 풀링된 임베딩 매트릭스 생성
test_mean_pooled_embeddings = torch.stack(testds_embeddings)

# Re_fornulate the bow

In [15]:
def dist_match_loss(hiddens, alpha=1.0):
    device = hiddens.device
    hidden_dim = hiddens.shape[-1]
    H = np.random.randn(hidden_dim, hidden_dim)
    Q, R = qr(H) 
    rand_w = torch.Tensor(Q).to(device)
    loss_dist_match = get_swd_loss(hiddens, rand_w, alpha)
    return loss_dist_match


def js_div_loss(hidden1, hidden2):
    m = 0.5 * (hidden1 + hidden2)
    return kldiv(m.log(), hidden1) + kldiv(m.log(), hidden2)


def get_swd_loss(states, rand_w, alpha=1.0):
    device = states.device
    states_shape = states.shape
    states = torch.matmul(states, rand_w)
    states_t, _ = torch.sort(states.t(), dim=1)

    # Random vector with length from normal distribution
    states_prior = torch.Tensor(np.random.dirichlet([alpha]*states_shape[1], states_shape[0])).to(device) # (bsz, dim)
    states_prior = torch.matmul(states_prior, rand_w) # (dim, dim)
    states_prior_t, _ = torch.sort(states_prior.t(), dim=1) # (dim, bsz)
    return torch.mean(torch.sum((states_prior_t - states_t)**2, axis=0))

# Get pos_similarity

In [16]:
def compute_max_cosine_similarity_indices(mean_pooled_embeddings, batch_size=500):
    n_rows = mean_pooled_embeddings.shape[0]
    max_similarity_indices = np.zeros(n_rows, dtype=np.int64)

    for start_idx in range(0, n_rows, batch_size):
        end_idx = min(start_idx + batch_size, n_rows)
        batch_data = mean_pooled_embeddings[start_idx:end_idx]

        batch_similarity = cosine_similarity(batch_data, mean_pooled_embeddings)

        # 자기 자신과의 유사도를 -1로 설정
        for i, original_idx in enumerate(range(start_idx, end_idx)):
            batch_similarity[i, original_idx] = -1

        # 각 행에서 가장 큰 값을 가진 인덱스 찾기
        max_indices = np.argmax(batch_similarity, axis=1)
        max_similarity_indices[start_idx:end_idx] = max_indices

        #logger.info(f"{end_idx}/{n_rows} 데이터 처리 완료")

    return max_similarity_indices

In [17]:
# mean_pooled_embeddings를 2차원 형태로 변환
train_mean_pooled_embeddings_2d = train_mean_pooled_embeddings.squeeze()
val_mean_pooled_embeddings_2d = val_mean_pooled_embeddings.squeeze()
test_mean_pooled_embeddings_2d = test_mean_pooled_embeddings.squeeze()

In [18]:
# 코사인 유사도 매트릭스 계산
train_similarity_matrix = compute_max_cosine_similarity_indices(train_mean_pooled_embeddings_2d)
val_similarity_matrix = compute_max_cosine_similarity_indices(val_mean_pooled_embeddings_2d)
test_similarity_matrix = compute_max_cosine_similarity_indices(test_mean_pooled_embeddings_2d)  # 테스트 데이터셋에 대해서도 동일하게

In [19]:
train_similarity_matrix.shape

(37800,)

In [20]:
model =  ContBertTopicExtractorAE(N_topic=n_topic, N_word=args.n_word, bert=bert_name, bert_dim=768)

In [21]:
finetuneds = Stage2Dataset(model.encoder, trainds, train_similarity_matrix, n_word, lemmatize=True)

valfinetuneds = Stage2Dataset(model.encoder, valds, val_similarity_matrix, n_word, lemmatize=True) 

testfinetuneds = Stage2Dataset(model.encoder, testds, test_similarity_matrix, n_word, lemmatize=True) 

kldiv = torch.nn.KLDivLoss(reduction='batchmean')
vocab_dict = finetuneds.vectorizer.vocabulary_
vocab_dict_reverse = {i:v for v, i in vocab_dict.items()}
print(n_word)

100%|██████████| 37800/37800 [00:31<00:00, 1196.41it/s]
100%|██████████| 37800/37800 [19:05<00:00, 32.98it/s]
100%|██████████| 4200/4200 [00:02<00:00, 1548.81it/s]
100%|██████████| 4200/4200 [02:04<00:00, 33.85it/s]
100%|██████████| 18000/18000 [00:14<00:00, 1275.59it/s]
100%|██████████| 18000/18000 [08:57<00:00, 33.46it/s]

30000





In [22]:
# len(finetuneds.embedding_list)
len(finetuneds.bow_list)

37800

# Stage 3

In [23]:
def measure_hungarian_score(topic_dist, train_target):
    dist = topic_dist
    train_target_filtered = train_target
    flat_predict = torch.tensor(np.argmax(dist, axis=1))
    flat_target = torch.tensor(train_target_filtered).to(flat_predict.device)
    num_samples = flat_predict.shape[0]
    num_classes = dist.shape[1]
    match = _hungarian_match(flat_predict, flat_target, num_samples, num_classes)    
    reordered_preds = torch.zeros(num_samples).to(flat_predict.device)
    for pred_i, target_i in match:
        reordered_preds[flat_predict == pred_i] = int(target_i)
    acc = int((reordered_preds == flat_target.float()).sum()) / float(num_samples)
    return acc

In [24]:
should_measure_hungarian = True

In [25]:
torch.cuda.empty_cache()

# Seperate Platform dataset

In [26]:
from torch.utils.data import DataLoader, Dataset, Sampler
from collections import defaultdict

# 커스텀 샘플러 구현
class PlatformSampler(Sampler):
    def __init__(self, dataset, platform_label):
        self.indices = [i for i, label in enumerate(dataset.platform_label_list) if label == platform_label]
    
    def __iter__(self):
        return iter(self.indices)
    
    def __len__(self):
        return len(self.indices)

In [32]:
# 플랫폼별 데이터로더 생성 함수
def create_platform_dataloader(dataset, platform_label, batch_size=32):
    sampler = PlatformSampler(dataset, platform_label)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return dataloader

def create_platform_dataloader(dataset, platform_label, batch_size=32, num_workers=0):
    # PlatformSampler는 플랫폼 라벨에 따라 데이터셋에서 샘플링하는 커스텀 샘플러입니다.
    # 이 샘플러는 여기서 정의하거나, 필요에 따라 다른 샘플링 로직을 구현할 수 있습니다.
    sampler = PlatformSampler(dataset, platform_label)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers)
    return dataloader

# Main

In [34]:
from sklearn.metrics import mutual_info_score
from scipy.stats import chi2_contingency 
from gensim.models.coherencemodel import CoherenceModel
from gensim.corpora.dictionary import Dictionary
from coherence import get_topic_coherence

args.stage_2_repeat = 1
results_list = []

for i in range(args.stage_2_repeat):
    model = ContBertTopicExtractorAE(N_topic=n_topic, N_word=args.n_word, bert=bert_name, bert_dim=768)
    model.beta = nn.Parameter(torch.Tensor(model.N_topic, n_word))
    nn.init.xavier_uniform_(model.beta)
    model.beta_batchnorm = nn.Sequential()
    model.cuda(gpu_ids[0])
    
    losses = AverageMeter()
    dlosses = AverageMeter() 
    rlosses = AverageMeter()
    closses = AverageMeter()
    distlosses = AverageMeter()
    ##수정
    twitter_trainloader = create_platform_dataloader(finetuneds, 'twitter', batch_size=bsz, num_workers=0)
    reddit_trainloader = create_platform_dataloader(finetuneds, 'reddit', batch_size=bsz, num_workers=0)
    youtube_trainloader = create_platform_dataloader(finetuneds, 'youtube', batch_size=bsz, num_workers=0)
    memoryloader = DataLoader(finetuneds, batch_size=bsz * 2, shuffle=False, num_workers=0)
    ##
    optimizer = torch.optim.Adam(model.parameters(), lr=args.stage_2_lr)

    memory_queue = F.softmax(torch.randn(512, n_topic).cuda(gpu_ids[0]), dim=1)
    print("Coeff   / regul: {:.5f} - recon: {:.5f} - c: {:.5f} - dist: {:.5f} ".format(args.coeff_2_regul, 
                                                                                        args.coeff_2_recon,
                                                                                        args.coeff_2_cons,
                                                                                        args.coeff_2_dist))
    # 최적 epoch 추적을 위한 변수 초기화
    best_npmi = -1
    best_epoch = 0
    best_model_state = None  # 모델 상태를 저장하기 위한 변수
    
    for epoch in range(100):
        model.train()
        model.encoder.eval()
        # 각 플랫폼별 DataLoader를 순회
        for platform_loader in [twitter_trainloader, reddit_trainloader, youtube_trainloader]:
            for batch_idx, batch in enumerate(platform_loader): 
                _, org_input, pos_input, org_bow, pos_bow, _ = batch
                org_input = org_input.cuda(gpu_ids[0])
                org_bow = org_bow.cuda(gpu_ids[0])
                pos_input = pos_input.cuda(gpu_ids[0])
                pos_bow = pos_bow.cuda(gpu_ids[0])

                batch_size = org_input.size(0) #org_input_ids.size(0)

                org_dists, org_topic_logit = model.decode(org_input)
                pos_dists, pos_topic_logit = model.decode(pos_input)

                org_topic = F.softmax(org_topic_logit, dim=1)
                pos_topic = F.softmax(pos_topic_logit, dim=1)

                # reconstruction loss
                # batchmean
    #             org_target = torch.matmul(org_topic.detach(), weight_cands)
    #             pos_target = torch.matmul(pos_topic.detach(), weight_cands)

    #             _, org_target = torch.max(org_topic.detach(), 1)
    #             _, pos_target = torch.max(pos_topic.detach(), 1)

                # 텐서 크기 맞춰줌
                org_dists = org_dists[:, :org_bow.size(1)]
                pos_dists = pos_dists[:, :pos_bow.size(1)]

                recons_loss = torch.mean(-torch.sum(torch.log(org_dists + 1E-10) * (org_bow), axis=1), axis=0)
                recons_loss += torch.mean(-torch.sum(torch.log((1-org_dists) + 1E-10) * (1-org_bow), axis=1), axis=0)
                recons_loss += torch.mean(-torch.sum(torch.log(pos_dists + 1E-10) * (pos_bow), axis=1), axis=0)
                recons_loss += torch.mean(-torch.sum(torch.log((1-pos_dists) + 1E-10) * (1-pos_bow), axis=1), axis=0)
                recons_loss *= 0.5

                # consistency loss
                pos_sim = torch.sum(org_topic * pos_topic, dim=-1)
                cons_loss = -pos_sim.mean()

                # distribution loss
                # batchmean
    #             distmatch_loss = dist_match_loss(torch.cat((org_topic), dim=0), dirichlet_alpha_2)
                distmatch_loss = dist_match_loss(torch.cat((org_topic,), dim=0), dirichlet_alpha_2)


                loss = args.coeff_2_recon * recons_loss + \
                       args.coeff_2_cons * cons_loss + \
                       args.coeff_2_dist * distmatch_loss 

                losses.update(loss.item(), bsz)
                closses.update(cons_loss.item(), bsz)
                rlosses.update(recons_loss.item(), bsz)
                distlosses.update(distmatch_loss.item(), bsz)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            
        print("Epoch-{} / recon: {:.5f} - dist: {:.5f} - cons: {:.5f}".format(epoch, rlosses.avg, distlosses.avg, closses.avg))

        # Epoch 마다 실행
        model.eval()

        # 각 토픽에 대한 상위 10개 단어 추출
        top_words_per_topic = {}
        for topic_idx in range(model.N_topic):
            top_words_indices = model.beta[topic_idx].topk(10).indices
            top_words = [vocab_dict_reverse[idx.item()] for idx in top_words_indices]
            top_words_per_topic[topic_idx] = top_words
            
        reference_corpus=[doc.split() for doc in valds.preprocess_ctm(valds.nonempty_text)]
        topic_words_list = list(top_words_per_topic.values())
        result = get_topic_coherence(topic_words_list, reference_corpus)
        avg_npmi = result['NPMI']

        # 최적의 NPMI와 epoch 추적
        if avg_npmi > best_npmi:
            best_npmi = avg_npmi
            best_epoch = epoch
            best_model_state = model.state_dict()  # 현재 모델 상태 저장

    print(f"Best Epoch: {best_epoch} with NPMI: {best_npmi}")
    # 훈련 완료 후, 최적 모델 상태 저장
    torch.save(best_model_state, 'our_best_model_state.pth')
    model.load_state_dict(torch.load('our_best_model_state.pth'))
    
    print("------- Evaluation results -------")
    #각 토픽당 가지는 워드셋
    all_list = {}
    for e, i in enumerate(model.beta.cpu().topk(10, dim=1).indices):
        word_list = []
        for j in i:
            word_list.append(vocab_dict_reverse[j.item()])
        all_list[e] = word_list
        print("topic-{}".format(e), word_list)

Coeff   / regul: 1.00000 - recon: 1.00000 - c: 1.00000 - dist: 1.00000 
Epoch-0 / recon: 9.18218 - dist: 0.15830 - cons: -0.09501
Epoch-1 / recon: 8.90246 - dist: 0.14250 - cons: -0.11963
Epoch-2 / recon: 8.75436 - dist: 0.13385 - cons: -0.13652
Epoch-3 / recon: 8.65146 - dist: 0.12876 - cons: -0.14787
Epoch-4 / recon: 8.57034 - dist: 0.12491 - cons: -0.15582
Epoch-5 / recon: 8.50191 - dist: 0.12191 - cons: -0.16149
Epoch-6 / recon: 8.44209 - dist: 0.11951 - cons: -0.16552
Epoch-7 / recon: 8.38847 - dist: 0.11777 - cons: -0.16833
Epoch-8 / recon: 8.33908 - dist: 0.11635 - cons: -0.17019
Epoch-9 / recon: 8.29337 - dist: 0.11545 - cons: -0.17123
Epoch-10 / recon: 8.25057 - dist: 0.11504 - cons: -0.17156
Epoch-11 / recon: 8.21015 - dist: 0.11517 - cons: -0.17135
Epoch-12 / recon: 8.17176 - dist: 0.11559 - cons: -0.17083
Epoch-13 / recon: 8.13523 - dist: 0.11630 - cons: -0.16991
Epoch-14 / recon: 8.10010 - dist: 0.11723 - cons: -0.16874
Epoch-15 / recon: 8.06629 - dist: 0.11842 - cons: -0.

In [35]:
topic_words_list = list(all_list.values())
reference_corpus=[doc.split() for doc in testds.preprocess_ctm(testds.nonempty_text)]

topics=topic_words_list
texts=reference_corpus
print(get_topic_coherence(topics, texts))

{'NPMI': 0.3756525631944724, 'UCI': 2.836169828792919, 'UMASS': -3.391847991779266, 'CV': 0.7104940248302872, 'Topic_Diversity': 0.975}


# MI Calulate

In [38]:
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import mutual_info_score

# 모델을 평가 모드로 설정
model.eval()

# 테스트 데이터 로더 설정
testloader = DataLoader(testfinetuneds, batch_size=bsz * 2, shuffle=False, num_workers=0)

# 토픽 라벨과 플랫폼 라벨 추출
ourmodel_test_topic_labels = []
test_platform_labels = []

for batch in testloader:
    _, org_embedding, _, org_bow, _, platform_labels = batch
    org_embedding = org_embedding.to(gpu_ids[0])  # gpu_ids[1]는 사용 환경에 따라 다를 수 있습니다.
    _, topic_logit = model.decode(org_embedding)
    topic_label = torch.argmax(F.softmax(topic_logit, dim=1), dim=1)
    ourmodel_test_topic_labels.extend(topic_label.cpu().numpy())
    test_platform_labels.extend(platform_labels)  # 수정된 부분

# 플랫폼별 토픽 분포 계산
topic_dist_df_test = pd.crosstab(pd.Series(ourmodel_test_topic_labels, name='Topic'),
                                 pd.Series(test_platform_labels, name='Platform'), normalize='index')

# 플랫폼별 및 전체에 대한 토픽 분포를 계산
platform_counts = pd.Series(test_platform_labels).value_counts()
platform_probabilities = platform_counts / platform_counts.sum()

# 전체 데이터셋에 대한 토픽 분포의 엔트로피 계산 (H(Y))
topic_probabilities = pd.Series(ourmodel_test_topic_labels).value_counts(normalize=True)
H_Y = -np.sum(topic_probabilities * np.log2(topic_probabilities + 1e-10))

# 각 플랫폼별 조건부 엔트로피 계산 및 H(Y|X) 계산
H_Y_given_X_total = 0
for platform in platform_probabilities.index:
    # 해당 플랫폼에 대한 토픽 라벨 필터링
    platform_indices = [i for i, x in enumerate(test_platform_labels) if x == platform]
    platform_topic_labels = [ourmodel_test_topic_labels[i] for i in platform_indices]
    platform_topic_prob = pd.Series(platform_topic_labels).value_counts(normalize=True)
    
    # 플랫폼별 조건부 엔트로피 계산
    H_Y_given_X = -np.sum(platform_topic_prob * np.log(platform_topic_prob + 1e-10))
    H_Y_given_X_total += platform_probabilities[platform] * H_Y_given_X

# Mutual Information (MI) 계산
mi = H_Y - H_Y_given_X_total

# 각 토픽에서 가장 높은 플랫폼 비율 추출
max_platform_distribution = topic_dist_df_test.max(axis=1)

# 평균 purity 계산
average_purity = max_platform_distribution.mean()

print("각 토픽에서 가장 높은 플랫폼의 비율 평균 (purity):", average_purity)
print('H(Y):', H_Y)
print('H(Y|X):', H_Y_given_X)
print('Mutual Information (MI):', mi)

# sklearn의 mutual_info_score를 이용한 MI 계산으로 검증
mi_score = mutual_info_score(ourmodel_test_topic_labels, test_platform_labels)
print("Original Mutual Information Score:", mi_score)

H(Y): 2.854358622943483
H(Y|X): 2.960724069021441
H(X|Y): 1.0635194901934535
Mutual Information (MI): -0.10636544607795795
Original Mutual Information Score: 0.03084546522123911


In [39]:
topic_dist_df_test.T

Topic,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
Platform,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
reddit,0.351759,0.458071,0.258687,0.240572,0.631579,0.269311,0.296909,0.341987,0.35376,0.36128,0.395793,0.172078,0.179688,0.286325,0.328261,0.401471,0.306859,0.345266,0.364625,0.33543
twitter,0.319095,0.209644,0.383687,0.53446,0.31467,0.411273,0.318985,0.290564,0.29805,0.23628,0.217973,0.318182,0.433594,0.399573,0.297826,0.267647,0.358604,0.322171,0.298419,0.323899
youtube,0.329146,0.332285,0.357625,0.224967,0.053751,0.319415,0.384106,0.367449,0.348189,0.402439,0.386233,0.50974,0.386719,0.314103,0.373913,0.330882,0.334537,0.332564,0.336957,0.340671


# Seperate platform

In [40]:
# reference_corpus의 길이 확인
total_length = len(reference_corpus)

# 각 플랫폼별로 6,000개씩 데이터가 충분한지 확인
if total_length >= 18000:
    # 트위터 데이터 분할
    twitter_texts = reference_corpus[:6000]

    # 레딧 데이터 분할
    reddit_texts = reference_corpus[6000:12000]

    # 유튜브 데이터 분할
    youtube_texts = reference_corpus[12000:18000]
else:
    print("데이터가 충분하지 않습니다.")

twitter_dictionary = Dictionary(twitter_texts)
twitter_dictionary.add_documents(topic_words_list)

reddit_dictionary = Dictionary(reddit_texts)
reddit_dictionary.add_documents(topic_words_list)

youtube_dictionary = Dictionary(youtube_texts)
youtube_dictionary.add_documents(topic_words_list)

## Twitter

In [41]:
print(get_topic_coherence(topics, twitter_texts))

{'NPMI': 0.17281190004854755, 'UCI': 1.4734833541766035, 'UMASS': -4.897262862010767, 'CV': 0.7754231691410781, 'Topic_Diversity': 0.975}


## Reddit

In [42]:
print(get_topic_coherence(topics, reddit_texts))

{'NPMI': 0.46981132807081494, 'UCI': 3.092224774259657, 'UMASS': -2.5446757068500663, 'CV': 0.7607069509740174, 'Topic_Diversity': 0.975}


## Youtube

In [43]:
print(get_topic_coherence(topics, youtube_texts))

{'NPMI': 0.251779756176881, 'UCI': 2.1096697750568794, 'UMASS': -3.874286438876698, 'CV': 0.7634598786378521, 'Topic_Diversity': 0.975}
