In [1]:
# Command-line arguments used when this notebook runs inside Jupyter
# (_parse_args splits this string on whitespace instead of reading sys.argv).
args_text = (
    '--base-model sentence-transformers/paraphrase-MiniLM-L6-v2 '
    '--dataset all --n-word 2000 --epochs-1 100 --epochs-2 50 '
    '--bsz 32 --stage-2-lr 2e-2 --stage-2-repeat 5 --coeff-1-dist 50 '
    '--n-cluster 20 '
    '--stage-1-ckpt trained_model/news_model_paraphrase-MiniLM-L6-v2_stage1_20t_2000w_99e.ckpt'
)

In [2]:
import re
import os
import sys
import time
import copy
import math
import argparse
import string
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtools.optim import RangerLars
import gensim.downloader
import itertools

from scipy.stats import ortho_group
from scipy.optimize import linear_sum_assignment as linear_assignment
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer

import numpy as np
from tqdm import tqdm_notebook
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.datasets import fetch_20newsgroups
from nltk.corpus import stopwords

from sklearn.feature_extraction.text import CountVectorizer
from utils import AverageMeter
from collections import OrderedDict

import pandas as pd
from sklearn.preprocessing import normalize
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from gensim.corpora.dictionary import Dictionary
from pytorch_transformers import *
from sklearn.mixture import GaussianMixture
import scipy.stats
from sklearn.decomposition import PCA
from sklearn.cluster import OPTICS
from nltk.corpus import stopwords

from gensim.models.coherencemodel import CoherenceModel
from tqdm import tqdm
import scipy.sparse as sp
import nltk
from nltk.corpus import stopwords

from datetime import datetime
from itertools import combinations
import gensim.downloader
from scipy.linalg import qr
from data import *
from data import TwitterDataset, RedditDataset, YoutubeDataset
from model import ContBertTopicExtractorAE
from evaluation import get_topic_qualities
import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

[nltk_data] Downloading package punkt to /home/minseo/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to /home/minseo/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /home/minseo/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [3]:
# Number GPUs by PCI bus order (so ids start from 0 in a stable order), then expose only GPUs 0 and 1.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0,1",
})

# Check whether CUDA is usable (the value is displayed as the cell output).
torch.cuda.is_available()

# Data Loading

In [4]:
def _running_in_jupyter():
    """Return True when executing inside a Jupyter notebook / qtconsole kernel."""
    try:
        # get_ipython is only injected as a builtin when running under IPython.
        shell = get_ipython().__class__.__name__
    except NameError:
        return False  # plain Python interpreter
    # ZMQInteractiveShell = Jupyter notebook or qtconsole; TerminalInteractiveShell
    # (terminal IPython) and any other shell type are treated as "not Jupyter".
    return shell == 'ZMQInteractiveShell'


def _parse_args(argv=None):
    """Build the command-line parser for contrastive topic modeling and parse arguments.

    Args:
        argv: optional list of argument strings to parse. When None (the default),
            arguments come from the module-level ``args_text`` string if running in
            Jupyter, otherwise from ``sys.argv``.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: raised by argparse on invalid arguments (e.g. unknown --dataset).
    """
    parser = argparse.ArgumentParser(description='Contrastive topic modeling')
    # Number of epochs for each stage
    parser.add_argument('--epochs-1', default=50, type=int,
                        help='Number of training epochs for Stage 1')
    parser.add_argument('--epochs-2', default=10, type=int,
                        help='Number of training epochs for Stage 2')
    # Batch size used by the stages
    parser.add_argument('--bsz', type=int, default=64,
                        help='Batch size')
    # Dataset selection
    parser.add_argument('--dataset', default='twitter', type=str,
                        choices=['twitter', 'reddit', 'youtube','all'],
                        help='Name of the dataset')
    # Number of clusters (the number of topics defaults to the same value)
    parser.add_argument('--n-cluster', default=20, type=int,
                        help='Number of clusters')
    parser.add_argument('--n-topic', type=int,
                        help='Number of topics. If not specified, use same value as --n-cluster')
    # Vocabulary size
    parser.add_argument('--n-word', default=2000, type=int,
                        help='Number of words in vocabulary')

    parser.add_argument('--base-model', type=str,
                        help='Name of base model in huggingface library.')

    # The help text now states the real default ([0, 1, 2, 3]).
    parser.add_argument('--gpus', default=[0,1,2,3], type=int, nargs='+',
                        help='List of GPU numbers to use. Use 0 1 2 3 by default')

    parser.add_argument('--coeff-1-sim', default=1.0, type=float,
                        help='Coefficient for NN dot product similarity loss (Phase 1)')
    parser.add_argument('--coeff-1-dist', default=1.0, type=float,
                        help='Coefficient for NN SWD distribution loss (Phase 1)')
    parser.add_argument('--dirichlet-alpha-1', type=float,
                        help='Parameter for Dirichlet distribution (Phase 1). Use 1/n_topic by default.')

    parser.add_argument('--stage-1-ckpt', type=str,
                        help='Name of torch checkpoint file Stage 1. If this argument is given, skip Stage 1.')

    parser.add_argument('--coeff-2-recon', default=1.0, type=float,
                        help='Coefficient for VAE reconstruction loss (Phase 2)')
    parser.add_argument('--coeff-2-regul', default=1.0, type=float,
                        help='Coefficient for VAE KLD regularization loss (Phase 2)')
    parser.add_argument('--coeff-2-cons', default=1.0, type=float,
                        help='Coefficient for CL consistency loss (Phase 2)')
    parser.add_argument('--coeff-2-dist', default=1.0, type=float,
                        help='Coefficient for CL SWD distribution matching loss (Phase 2)')
    parser.add_argument('--dirichlet-alpha-2', type=float,
                        help='Parameter for Dirichlet distribution (Phase 2). Use same value as dirichlet-alpha-1 by default.')
    parser.add_argument('--stage-2-lr', default=2e-1, type=float,
                        help='Learning rate of phase 2')

    parser.add_argument('--stage-2-repeat', default=5, type=int,
                        help='Repetition count of phase 2')

    parser.add_argument('--result-file', type=str,
                        help='File name for result summary')
    parser.add_argument('--palmetto-dir', type=str,
                        help='Directory where palmetto JAR and the Wikipedia index are. For evaluation')

    if argv is None and _running_in_jupyter():
        # Jupyter passes its own kernel arguments in sys.argv; use the notebook's args_text instead.
        argv = args_text.split()
    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(args=argv)

def _is_missing_text(text):
    """Return True for entries that are not usable documents: None or a float NaN."""
    return text is None or (isinstance(text, float) and math.isnan(text))


def data_load(dataset_name, sample_size=100000):
    """Load (text, platform_label) pairs for one platform or for all three.

    Missing entries are dropped: None, or float NaN (presumably empty cells in the
    underlying tables).

    Args:
        dataset_name: 'twitter', 'reddit', 'youtube' or 'all'
            ('all' = Twitter, then Reddit, then YouTube).
        sample_size: forwarded to each dataset constructor.

    Returns:
        (textData, should_measure_hungarian). textData is a list of (text, label)
        tuples with label in {"Twitter", "Reddit", "YouTube"}; should_measure_hungarian
        is always False here.

    Raises:
        ValueError: if dataset_name is not one of the names above.
    """
    # Dataset class and platform label for each single-platform choice.
    platform_sources = {
        'twitter': [(TwitterDataset, "Twitter")],
        'reddit': [(RedditDataset, "Reddit")],
        'youtube': [(YoutubeDataset, "YouTube")],
    }
    platform_sources['all'] = (platform_sources['twitter']
                               + platform_sources['reddit']
                               + platform_sources['youtube'])

    if dataset_name not in platform_sources:
        raise ValueError("Invalid dataset name!")

    should_measure_hungarian = False

    # Construct every requested dataset first (same order as before), then filter.
    loaded = [(dataset_cls(sample_size=sample_size), label)
              for dataset_cls, label in platform_sources[dataset_name]]

    textData = [(text, label)
                for dataset, label in loaded
                for text in dataset.texts
                if not _is_missing_text(text)]

    return textData, should_measure_hungarian


In [5]:
args = _parse_args()

# Fail fast, before the slow data loading: bert_name is split below and cannot be None.
if args.base_model is None:
    raise ValueError("--base-model is required (e.g. --base-model sentence-transformers/paraphrase-MiniLM-L6-v2)")

bsz = args.bsz
epochs_1 = args.epochs_1
epochs_2 = args.epochs_2

n_cluster = args.n_cluster
# --n-topic falls back to --n-cluster; store the resolved value back on args.
n_topic = args.n_topic if args.n_topic is not None else n_cluster
args.n_topic = n_topic

textData, should_measure_hungarian = data_load(args.dataset)

ema_alpha = 0.99  # presumably the EMA decay factor used by later cells — TODO confirm
n_word = args.n_word

# NOTE(review): the --dirichlet-alpha-1 help text promises a default of 1/n_topic, but
# 1/n_cluster is used here. The two only differ when --n-topic is passed; confirm which is intended.
dirichlet_alpha_1 = args.dirichlet_alpha_1 if args.dirichlet_alpha_1 is not None else 1 / n_cluster
# The Stage-2 alpha defaults to the Stage-1 value.
dirichlet_alpha_2 = args.dirichlet_alpha_2 if args.dirichlet_alpha_2 is not None else dirichlet_alpha_1

bert_name = args.base_model
bert_name_short = bert_name.split('/')[-1]  # model name without the hub organisation prefix
gpu_ids = args.gpus

# Per the --stage-1-ckpt help: when a Stage-1 checkpoint is given, Stage 1 is skipped.
skip_stage_1 = (args.stage_1_ckpt is not None)

In [6]:
from sklearn.metrics.pairwise import cosine_similarity

In [7]:
# Platform jargon excluded from the bag-of-words vocabulary on top of the standard English
# stop words. The three groups keep the original list's grouping; judging by their contents
# they appear to be Twitter, Reddit and YouTube vocabulary respectively. Matching is
# case-sensitive (the vectorizers in this notebook use lowercase=False), hence both 'AGI' and 'agi'.
_twitter_jargon = [
    'bigdata', 'RT', 'announces', 'rival', 'MSFT', 'launches', 'SEO', 'tweets', 'WSJ', 'unveils',
    'MBA', 'machinelearning', 'HN', 'artificalintelligence', 'DALLE', 'NFT', 'NLP', 'edtech', 'cybersecurity', 'malware',
    'CEO', 'ML', 'JPEG', 'GOOGL', 'UX', 'artificialintelligence', 'technews', 'fintech', 'tweet', 'CHATGPT',
    'viral', 'blockchain', 'BARD', 'reportedly', 'BREAKING', 'deeplearning', 'educators', 'datascience', 'GOOG', 'competitor',
    'FREE', 'startups', 'AGIX', 'CHAT', 'CNET', 'integrating', 'BB', 'maker', 'passes', 'NYT',
    'yall', 'rescue', 'heres', 'NYC', 'LLM', 'airdrop', 'powered', 'podcast', 'Q', 'PM',
    'newsletter', 'startup', 'TSLA', 'WIRED', 'digitalmarketing', 'CTO', 'changer', 'announced', 'database', 'launched',
    'warns', 'popularity', 'AWS', 'nocode', 'trending', 'metaverse', 'cc', 'PR', 'RSS', 'MWC',
    'USMLE', 'copywriting', 'marketers', 'tweeting', 'amid', 'AGI', 'socialmedia', 'webinar', 'agrees', 'invests',
    'launch', 'killer', 'GM', 'bullish', 'edchat', 'RLHF', 'integration', 'fastestgrowing', 'CNN', 'exam',
]
_reddit_jargon = [
    'deleted', 'gif', 'giphy', 'dm', 'removed', 'remindme', 'yup', 'sydney', 'yep', 'patched',
    'nope', 'giphydownsized', 'vpn', 'ascii', 'ah', 'chadgpt', 'nerfed', 'jesus', 'xd', 'wtf',
    'upvote', 'nah', 'op', 'mods', 'hahaha', 'nsfw', 'huh', 'holy', 'iq', 'jailbreak',
    'blah', 'bruh', 'yea', 'agi', 'porn', 'waitlist', 'nerf', 'downvoted', 'refresh', 'omg',
    'sus', 'characterai', 'meth', 'chinese', 'sub', 'rick', 'american', 'elon', 'sam', 'quack',
    'youchat', 'uk', 'chad', 'archived', 'youcom', 'screenshot', 'llm', 'hitler', 'lmao', 'playground',
    'rpg', 'delete', 'tldr', 'davinci', 'trump', 'hangman', 'haha', 'tay', 'karma', 'john',
    'chatgtp', 'url', 'wokegpt', 'offended', 'fucked', 'redditor', 'ceo', 'agreed', 'emojis', 'cheers',
    'ais', 'tag', 'wow', 'lmfao', 'p', 'rip', 'chats', 'hmm', 'bypass', 'llms',
    'temperature', 'login', 'cgpt', 'windows', 'novelai', 'biden', 'donald', 'christmas', 'ms', 'cringe',
]
_youtube_jargon = [
    'ZRONX', 'rook', 'thumbnail', 'vid', 'bhai', 'bishop', 'circle', 'subscribed', 'quot', 'bless',
    'tutorial', 'XD', 'sir', 'GEMX', 'profitable', 'earning', 'quotquot', 'enjoyed', 'ur', 'bra',
    'JIM', 'broker', 'levy', 'vids', 'stare', 'tutorials', 'subscribers', 'sponsor', 'hai', 'lifechanging',
    'curve', 'shorts', 'earn', 'trader', 'PC', 'folders', 'informative', 'br', 'chess', 'jontron',
    'brother', 'T', 'YT', 'upload', 'O', 'subscriber', 'intro', 'DAN', 'aint', 'download',
    'LOL', 'shes', 'moves', 'telegram', 'shortlisted', 'liked', 'websiteapp', 'watched', 'grant', 'plz',
    'KINGDOM', 'YOU', 'MESSIAH', 'mate', 'ki', 'subs', 'pawn', 'hes', 'U', 'HACKBANZER',
    'ka', 'brbr', 'affiliate', 'clip', 'beast', 'trade', 'ive', 'ho', 'approved', 'bhi',
    'gotta', 'profits', 'wanna', 'subscribe', 'funds', 'labels', 'recommended', 'audio', 'uploaded', 'appreciated',
    'UBI', 'pls', 'upto', 'alot', 'twist', 'GTP', 'accent', 'monetized', 'S', 'btw',
]
jargons = set(_twitter_jargon) | set(_reddit_jargon) | set(_youtube_jargon)

In [8]:
def create_dataloaders(dataset_list, batch_size=64, shuffle=True, num_workers=0):
    """Wrap every dataset in its own DataLoader, all with the same settings.

    Args:
        dataset_list: iterable of datasets (one per platform in this notebook).
        batch_size: samples per batch, forwarded to each DataLoader.
        shuffle: whether each DataLoader reshuffles every epoch.
        num_workers: worker processes per DataLoader (0 = load in the main process).

    Returns:
        list of DataLoader objects in the same order as ``dataset_list``.
    """
    loader_kwargs = dict(batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    return [DataLoader(ds, **loader_kwargs) for ds in dataset_list]


def process_text_data(texts, all_texts, extra_stop_words=None):
    """Vectorize ``texts`` with TF-IDF fitted on ``all_texts``, then L1-normalize each row.

    The vocabulary excludes scikit-learn's English stop words plus ``extra_stop_words``
    (default: the notebook-level ``jargons`` set, the same exclusion list main() uses, so
    the columns line up with the corpus matrix built there). Tokens are runs of only
    uppercase or only lowercase ASCII letters matched case-sensitively; mixed-case words
    such as "ChatGPT" therefore produce no token.

    Args:
        texts: documents to transform.
        all_texts: corpus the vectorizer is fitted on (defines the column vocabulary).
        extra_stop_words: optional iterable of additional words to exclude.

    Returns:
        Dense numpy array of shape (len(texts), vocabulary size) whose rows sum to ~1.
        Rows with no vocabulary tokens stay all-zero.

    NOTE: the vectorizer is re-fitted on ``all_texts`` on every call, which is costly
    for a large corpus.
    """
    if extra_stop_words is None:
        # Single source of truth: reuse the module-level jargon set instead of a private copy.
        extra_stop_words = jargons

    # Standard English stop words + jargon
    exclude_words = set(TfidfVectorizer(stop_words="english").get_stop_words()).union(extra_stop_words)

    vectorizer = TfidfVectorizer(stop_words=list(exclude_words),
                                 lowercase=False,
                                 token_pattern=r'\b([A-Z]+|[a-z]+)\b')
    vectorizer.fit(all_texts)
    bow_matrix_texts = vectorizer.transform(texts).toarray()

    # L1-normalize each row; epsilon keeps all-zero rows at zero instead of dividing by 0.
    epsilon = 1e-10
    row_sums = bow_matrix_texts.sum(axis=1)
    normalized_bow_matrix_texts = bow_matrix_texts / (row_sums[:, np.newaxis] + epsilon)

    print("Shape of the normalized BoW matrix:", normalized_bow_matrix_texts.shape)

    return normalized_bow_matrix_texts

def compute_bow_batchwise(batch_texts, bow_matrix, full_texts, vectorizer=None):
    """For each batch text, return the index of the most similar *other* corpus text.

    Similarity is the cosine similarity between the batch texts' TF-IDF vectors and the
    rows of ``bow_matrix``. Every corpus entry whose text is identical to the batch text
    is excluded (similarity forced to -1). This covers both the sample itself and exact
    duplicates of it; previously only the first occurrence was masked, so a duplicated
    sample could be returned as its own positive.

    Args:
        batch_texts: sequence of strings (e.g. the text field of a DataLoader batch).
        bow_matrix: dense array or scipy sparse matrix with one row per entry of
            ``full_texts``, built with the same vocabulary as the batch vectors.
        full_texts: list of all corpus texts, aligned with the rows of ``bow_matrix``.
        vectorizer: optional already-fitted vectorizer whose vocabulary matches
            ``bow_matrix``. If given, batch texts are transformed with it instead of
            re-fitting TF-IDF on ``full_texts`` through ``process_text_data``.

    Returns:
        list[int] with one index into ``full_texts`` per batch text (empty for an empty
        batch). On ties (e.g. a batch text with no vocabulary tokens has similarity 0 with
        everything), the lowest tied index is returned.
    """
    if len(batch_texts) == 0:
        return []

    if vectorizer is not None:
        batch_vectors = vectorizer.transform(batch_texts)
    else:
        batch_vectors = process_text_data(batch_texts, full_texts)

    # Cosine similarity is invariant to per-row scaling, so L1 vs L2 row normalization
    # of the batch vectors does not change the result.
    cosine_sim_matrix = cosine_similarity(batch_vectors, bow_matrix)

    # One pass over the corpus: map each distinct batch text to every position it occupies.
    wanted = set(batch_texts)
    positions = {}
    for corpus_idx, text in enumerate(full_texts):
        if text in wanted:
            positions.setdefault(text, []).append(corpus_idx)

    # Exclude the sample itself and all exact duplicates of it.
    for row, batch_text in enumerate(batch_texts):
        self_indices = positions.get(batch_text)
        if self_indices:
            cosine_sim_matrix[row, self_indices] = -1

    # Best remaining match per batch text.
    return np.argmax(cosine_sim_matrix, axis=1).tolist()

In [9]:
# Platform labels, in the same order as the datasets built in main().
platform_names = ["Twitter", "Reddit", "YouTube"]


def main():
    """Preview BoW-based positive samples for the first batch of each platform.

    Loads the three platform datasets and fits one TF-IDF vectorizer on their concatenated
    texts, using the same stop-word/jargon exclusion and token pattern as
    ``process_text_data``. Then, for one batch per platform, it prints each text next to
    the most similar other corpus text found by ``compute_bow_batchwise``.

    Returns:
        tuple ``(dataloaders, bow_matrix, all_texts)`` so later notebook cells can reuse
        them. ``bow_matrix`` is a scipy sparse matrix with one row per entry of ``all_texts``.
    """
    # Initialize each dataset (defined in data.py)
    twitter_ds = TwitterDataset()
    reddit_ds = RedditDataset()
    youtube_ds = YoutubeDataset()

    platform_datasets = [twitter_ds, reddit_ds, youtube_ds]
    all_texts = twitter_ds.texts + reddit_ds.texts + youtube_ds.texts

    # Same preprocessing as process_text_data, so both produce the same column vocabulary.
    exclude_words = set(TfidfVectorizer(stop_words="english").get_stop_words()).union(jargons)
    bow_vectorizer = TfidfVectorizer(stop_words=list(exclude_words),
                                     lowercase=False,
                                     token_pattern=r'\b([A-Z]+|[a-z]+)\b')

    # Keep the corpus matrix sparse. A dense (n_texts x vocab) float64 copy is enormous
    # (the vocabulary here has ~123k terms), and cosine_similarity accepts sparse input.
    bow_matrix = bow_vectorizer.fit_transform(all_texts)

    # One DataLoader per platform, 8 texts per batch.
    dataloaders = create_dataloaders(platform_datasets, batch_size=8)

    # Show results for the first batch of each platform only.
    for platform_idx, dataloader in enumerate(dataloaders):
        texts, targets = next(iter(dataloader))

        print(f"Platform: {platform_names[platform_idx]}")

        positive_sample_indices = compute_bow_batchwise(texts, bow_matrix, all_texts)

        for sample_idx, positive_sample_index in enumerate(positive_sample_indices):
            print(f"Sample text {sample_idx + 1}: {texts[sample_idx]}")
            print(f"Sample Index {sample_idx + 1}: {targets[sample_idx]}")
            positive_text = all_texts[positive_sample_index]
            print(f"Positive sample for text {sample_idx + 1}: {positive_text}")
            print(f"Positive sample Index for text {sample_idx + 1}: {positive_sample_index}")
            print('-' * 20)

        print('-' * 50)

    return dataloaders, bow_matrix, all_texts


if __name__ == "__main__":
    # Bind the results at notebook level: the cells below read these names.
    dataloaders, bow_matrix, all_texts = main()

Platform: Twitter
Shape of the normalized BoW matrix: (8, 122907)
Sample text 1:  Yes but ChatGPT democratises cheating Any student with an internet connection can do it for any topic without payment yet
Sample Index 1: 56919
Positive sample for text 1: i would say this technology democratises this ability some way
Positive sample Index for text 1: 169946
--------------------
Sample text 2: I used Chat GPT to take a transcribed video using  and then give me  title ideas

 The titles were brilliant Wow just wow
Sample Index 2: 9312
Positive sample for text 2: I Love the title of the video
Positive sample Index for text 2: 240960
--------------------
Sample text 3: Does Ketamine hold therapeutic value ChatGPT 
Sample Index 3: 80506
Positive sample for text 3: ChatGPT Response on its value for Industry 

Read 
Positive sample Index for text 3: 14370
--------------------
Sample text 4: AI must be regulated says CTO of ChatGPT maker OpenAI

Sample Index 4: 94259
Positive sample for text 4: 

In [None]:
# Positive-sample preview for the first batch of every platform.
# Requires `dataloaders`, `bow_matrix`, `all_texts` and `platform_names` at notebook level.
for loader_no, loader in enumerate(dataloaders):
    first_texts, first_targets = next(iter(loader))  # only the first batch
    print(f"Platform: {platform_names[loader_no]}")

    best_matches = compute_bow_batchwise(first_texts, bow_matrix, all_texts)

    for k, match_idx in enumerate(best_matches):
        n = k + 1  # 1-based numbering in the printout
        print(f"Sample text {n}: {first_texts[k]}")
        print(f"Sample Index {n}: {first_targets[k]}")
        print(f"Positive sample for text {n}: {all_texts[match_idx]}")
        print(f"Positive sample Index for text {n}: {match_idx}")
        print('-' * 20)

    print('-' * 50)


In [None]:
# Full pass over every batch of every platform: print the first sample of each batch,
# then the BoW positive for every sample (0-based numbering).
# NOTE: compute_bow_batchwise runs once per batch over the whole corpus, so this is slow
# and produces a very large amount of output.
for platform_no, loader in enumerate(dataloaders):
    for batch_no, (batch_texts, batch_targets) in enumerate(loader):
        print(f"Platform: {platform_names[platform_no]}")
        print(f"Batch index: {batch_no}")
        print("Sample text:", batch_texts[0])
        print("Sample Index:", batch_targets[0])
        print('-' * 50)

        matches = compute_bow_batchwise(batch_texts, bow_matrix, all_texts)
        for k, match_idx in enumerate(matches):
            print(f"Positive sample for text {k}: {all_texts[match_idx]}")

        print('-' * 50)