In [1]:
# Build the text-record file restricted to the images that had their black area removed
import os
import pandas as pd
from tqdm import tqdm

# Define the paths
image_folder = '/data1/dxw_data/llm/redbook_final/script_next/combined_seg_img_pure_094'
csv_file = '/data1/dxw_data/llm/redbook_final/script_next/rawdata_20%.csv'
output_file = '/data1/dxw_data/llm/redbook_final/script_next/matching_records.csv'

# Read the CSV file into a DataFrame.
# The ID columns are forced to str: the IDs parsed from the filenames below are
# always strings, so an ID column inferred as a numeric dtype would never
# compare equal to them.
# NOTE(review): the recorded output of this cell shows a warning pointing at
# this read_csv call; if it is a mixed-dtype warning for other columns,
# consider pinning their dtypes as well.
df = pd.read_csv(csv_file, dtype={'poster_id': str, 'post_id': str})

# Index the CSV once: (poster_id, post_id) -> positional row numbers.
# A dictionary lookup per file replaces a full-frame boolean scan per file.
rows_by_key = df.groupby(['poster_id', 'post_id'], sort=False).indices

# Positional indices of the matching rows, collected file by file
matching_positions = []

# Iterate over the files in the image folder with a progress bar.
# The listing is sorted so the output row order is reproducible between runs.
for filename in tqdm(sorted(os.listdir(image_folder)), desc="Processing images", unit="file"):
    if not filename.endswith('.png'):
        continue

    # Filenames are split on '_' and read as <date>_<poster_id>_<post_id>_...
    parts = filename.split('_')
    if len(parts) < 3:
        continue
    poster_id, post_id = parts[1], parts[2]

    # Every CSV row carrying this ID pair is kept (once per matching file)
    matching_positions.extend(rows_by_key.get((poster_id, post_id), ()))

if matching_positions:
    # Gather all matching rows into a single DataFrame
    matching_df = df.iloc[matching_positions]

    # Save the matching records to a new CSV file
    matching_df.to_csv(output_file, index=False)

    print(f'Matching records saved to {output_file}')
else:
    print('No matching records found.')


  df = pd.read_csv(csv_file)
Processing images: 100%|██████████| 10379/10379 [00:52<00:00, 196.03file/s]


Matching records saved to /data1/dxw_data/llm/redbook_final/script_next/matching_records.csv


In [1]:
import pandas as pd
import re
import emoji
import jieba
from rake_nltk import Rake
from tqdm import tqdm
import numpy as np

# Input locations
matching_records_file = '/data1/dxw_data/llm/redbook_final/script_next/matching_records.csv'
stopwords_file_path = 'stopwords_cn.txt'

# Build the stopword set: one entry per line of the stopword file
with open(stopwords_file_path, 'r', encoding='utf-8') as stopword_handle:
    stopword_lines = stopword_handle.read().splitlines()
stopwords = set(stopword_lines)

# Read the matching records into a DataFrame
df = pd.read_csv(matching_records_file)

# Function to clean the summary text
def clean_text(text, stopwords):
    """Normalize a raw summary into a space-separated string of content words.

    Steps, in order: emoji are replaced by their textual names, the
    site-specific boilerplate patterns are removed and '#' is replaced by a
    space, digits are dropped, every character that is neither alphanumeric
    nor whitespace is removed, the remainder is segmented with jieba, and
    stopwords are filtered out.

    Args:
        text: The summary text (str).
        stopwords: Collection of tokens to discard; membership is tested
            with ``in``.

    Returns:
        The surviving tokens joined by single spaces ('' if none survive).
    """
    # Convert emoji to text
    text = emoji.demojize(text)

    # Remove specific text patterns
    text = re.sub(r'- 小红书,,', '', text)
    text = re.sub(r',,\d{2}-\d{2},,', '', text)
    text = re.sub(r'#', ' ', text)

    # Remove digits
    text = re.sub(r'\d+', '', text)

    # Remove special characters (anything that is not alphanumeric or whitespace)
    cleaned_text = ''.join(char for char in text if char.isalnum() or char.isspace())

    # Word segmentation, then drop stopwords. Whitespace-only tokens are dropped
    # too: jieba returns runs of whitespace as tokens of their own, and keeping
    # them would leave runs of spaces in the joined result.
    filtered_words = [
        word
        for word in jieba.cut(cleaned_text)
        if word.strip() and word not in stopwords
    ]

    return ' '.join(filtered_words)

# Function to extract keywords using RAKE
def extract_keywords(text):
    """Run RAKE over ``text`` and return its ranked phrases.

    A new ``Rake`` instance with default settings is created on every call.
    """
    rake = Rake()
    rake.extract_keywords_from_text(text)
    ranked_phrases = rake.get_ranked_phrases()
    return ranked_phrases

# Function to clean the extracted keywords
def clean_keywords(keywords):
    """Keep only the non-blank string entries of a keyword list.

    Args:
        keywords: Expected to be a list of keyword strings.

    Returns:
        A list of the entries that are strings with non-whitespace content,
        or ``np.nan`` when ``keywords`` is not a list or nothing survives.
    """
    # Anything that is not a list counts as "no keywords"
    if not isinstance(keywords, list):
        return np.nan

    kept = []
    for candidate in keywords:
        if isinstance(candidate, str) and candidate.strip():
            kept.append(candidate)

    # An empty result is reported as NaN rather than as an empty list
    return kept or np.nan

# Apply the cleaning function with progress bar.
# Missing summaries are mapped to '' instead of going through str(), which
# would turn NaN into the literal text "nan" and later yield a bogus "nan"
# keyword for that row.
tqdm.pandas(desc="Cleaning summaries")
df['summary_cleaned'] = df['summary'].progress_apply(
    lambda x: clean_text('' if pd.isna(x) else str(x), stopwords)
)

# Apply the keyword extraction with progress bar
tqdm.pandas(desc="Extracting keywords")
df['rake_keywords'] = df['summary_cleaned'].progress_apply(extract_keywords)

# Clean the extracted keywords: drop blank / non-string entries; a row left
# with no keywords at all gets NaN.
df['rake_keywords'] = df['rake_keywords'].apply(clean_keywords)

# Save the updated DataFrame back to the CSV file (this overwrites the input file).
# NOTE(review): 'rake_keywords' holds Python lists here, so the CSV stores
# their string representation; readers of the file get a str, not a list.
df.to_csv(matching_records_file, index=False)

print(f'Updated records with cleaned summaries and keywords saved to {matching_records_file}')


Cleaning summaries:   0%|          | 0/12807 [00:00<?, ?it/s]Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.591 seconds.
Prefix dict has been built successfully.
Cleaning summaries: 100%|██████████| 12807/12807 [00:09<00:00, 1286.89it/s]
Extracting keywords: 100%|██████████| 12807/12807 [00:09<00:00, 1360.67it/s]


Updated records with cleaned summaries and keywords saved to /data1/dxw_data/llm/redbook_final/script_next/matching_records.csv


In [3]:
import os
import torch
import pandas as pd
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
from imagebind import data
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

# Paths
csv_file = '/data1/dxw_data/llm/redbook_final/script_next/matching_records.csv'
image_folder = '/data1/dxw_data/llm/redbook_final/script_next/combined_seg_img_pure_094'
output_folder = '/data1/dxw_data/llm/redbook_final/script_next/combined_embeddings'
os.makedirs(output_folder, exist_ok=True)

# Device setup: first CUDA device when available, otherwise CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the pretrained model and put it in inference mode on the chosen device
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)

# Load CSV. The ID columns are read as str because they are later matched
# against pieces of image filenames, which are always strings.
df = pd.read_csv(csv_file, dtype={'poster_id': str, 'post_id': str})

# Image transformation.
# The normalization statistics are the ones ImageBind's own vision loader
# (imagebind.data.load_and_transform_vision_data) applies, not the ImageNet
# defaults; the model should see inputs normalized the way it was trained.
# NOTE(review): embeddings produced earlier with the ImageNet statistics are
# not comparable with new ones - regenerate them together.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711],
    ),
])

# Function to load and transform a single image
def load_image(image_path):
    """Load one image as a model-ready tensor.

    The file is opened, converted to RGB, passed through the module-level
    ``transform``, given a leading batch dimension, and moved to ``device``.
    """
    rgb_image = Image.open(image_path).convert('RGB')
    batched = transform(rgb_image).unsqueeze(0)
    return batched.to(device)

# Function to generate embeddings
def generate_embeddings(text, image_path):
    """Run the model on the given text and/or image.

    Args:
        text: Text to embed, or None to leave the text modality out.
        image_path: Path of the image to embed, or None to leave the vision
            modality out.

    Returns:
        Whatever ``model`` returns for the assembled inputs; the caller
        indexes it by ``ModalityType``.
    """
    # Each modality contributes an entry only when its argument is provided
    text_inputs = (
        {ModalityType.TEXT: data.load_and_transform_text([text], device)}
        if text is not None
        else {}
    )
    vision_inputs = (
        {ModalityType.VISION: load_image(image_path)}
        if image_path is not None
        else {}
    )
    model_inputs = {**text_inputs, **vision_inputs}

    # Inference only: no gradients are tracked
    with torch.no_grad():
        return model(model_inputs)

# Index the image folder once: (poster_id, post_id) -> path of the first
# matching image (in sorted filename order). Filenames are read as
# <date>_<poster_id>_<post_id>_... ; matching whole '_'-separated tokens avoids
# the false hits a plain substring test can produce (e.g. ID "23" inside "123"),
# and a single pass replaces one directory listing per CSV row.
image_index = {}
for fname in sorted(os.listdir(image_folder)):
    if not fname.endswith('.png'):
        continue
    name_parts = os.path.splitext(fname)[0].split('_')
    if len(name_parts) >= 3:
        image_index.setdefault(
            (name_parts[1], name_parts[2]), os.path.join(image_folder, fname)
        )

skipped_no_keywords = 0
skipped_no_image = 0

# Iterate over each row in the CSV and process
for idx, row in tqdm(df.iterrows(), total=df.shape[0], desc="Processing rows"):
    poster_id = row['poster_id']
    post_id = row['post_id']
    rake_keywords = row['rake_keywords']

    # Rows without keywords come back from the CSV as NaN (a float), which the
    # text loader cannot tokenize; skip them instead of crashing mid-run.
    # NOTE(review): when the keyword column was written from Python lists it is
    # read back as the list's string representation (brackets and quotes
    # included) - confirm that this is the text that should be embedded.
    if not isinstance(rake_keywords, str) or not rake_keywords.strip():
        skipped_no_keywords += 1
        continue

    # Find corresponding image file
    image_path = image_index.get((str(poster_id), str(post_id)))
    if image_path is None:
        skipped_no_image += 1
        continue

    # Generate embeddings
    embeddings = generate_embeddings(rake_keywords, image_path)

    # Concatenate text and vision embeddings along the feature dimension
    concatenated_embeddings = torch.cat(
        (embeddings[ModalityType.TEXT], embeddings[ModalityType.VISION]), dim=1
    )

    # Save the concatenated embeddings. They are moved to CPU first so the
    # files can be loaded (and converted to NumPy) on a machine without a GPU.
    output_file = os.path.join(output_folder, f"{poster_id}_{post_id}_embedding.pt")
    torch.save(concatenated_embeddings.cpu(), output_file)

print(f"Embedding process complete. Embeddings saved to {output_folder}.")
print(f"Skipped rows - no keywords: {skipped_no_keywords}, no image: {skipped_no_image}")





KeyboardInterrupt: 

In [None]:
import os
import json
import torch
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from tqdm import tqdm
import shutil

# Paths
embedding_folder = '/data1/dxw_data/llm/redbook_final/script_next/combined_embeddings'
image_folder = embedding_folder.replace('combined_embeddings', 'combined_seg_img_pure_094')
output_folder = '/data1/dxw_data/llm/redbook_final/script_next/combined_seg_img_pure_094_cluster_imagebind3'
os.makedirs(output_folder, exist_ok=True)

# Collect the embedding files (sorted so row order, and therefore the
# clustering input, is reproducible between runs)
embedding_files = sorted(
    os.path.join(embedding_folder, fname)
    for fname in os.listdir(embedding_folder)
    if fname.endswith('.pt')
)
if not embedding_files:
    raise FileNotFoundError(f'No .pt embedding files found in {embedding_folder}')

# Load all embeddings into a list
all_embeddings = []
for embedding_file in tqdm(embedding_files, desc="Loading embeddings"):
    # map_location='cpu': tensors saved from a CUDA device would otherwise be
    # restored onto the GPU, where .numpy() raises (and loading fails outright
    # on a machine without CUDA).
    embedding = torch.load(embedding_file, map_location='cpu')
    all_embeddings.append(embedding)

# Concatenate all embeddings into a single tensor, then convert to NumPy once
all_embeddings = torch.cat(all_embeddings, dim=0)
embedding_matrix = all_embeddings.numpy()
n_samples = embedding_matrix.shape[0]

# Determine optimal number of clusters using Average Silhouette Method
candidate_k_values = [2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 30, 40, 50, 60, 70, 80, 90, 100, 120, 150, 200]  # Discrete values for k
# The silhouette score is only defined for 2 <= k <= n_samples - 1
k_values = [k for k in candidate_k_values if k < n_samples]
if not k_values:
    raise ValueError(f'Need at least 3 embeddings to cluster, found {n_samples}')

silhouette_scores = []
for k in k_values:
    kmeans = KMeans(n_clusters=k, random_state=0)
    labels = kmeans.fit_predict(embedding_matrix)
    score = silhouette_score(embedding_matrix, labels)
    silhouette_scores.append(score)

# Plot silhouette scores
plt.figure(figsize=(30, 6))  # Increased width to three times the original
plt.plot(k_values, silhouette_scores, marker='o')
plt.xlabel('Number of clusters (k)')
plt.ylabel('Average Silhouette Score')
plt.title('Average Silhouette Score vs. Number of Clusters')
plt.grid(True)
plt.show()

# Select the optimal number of clusters based on the silhouette scores
optimal_k = k_values[silhouette_scores.index(max(silhouette_scores))]

# Perform clustering with optimal k
kmeans = KMeans(n_clusters=optimal_k, random_state=0)
labels = kmeans.fit_predict(embedding_matrix)

# Index the source images by "<poster_id>_<post_id>". Image filenames are read
# as <date>_<poster_id>_<post_id>_..., while embedding files are named
# <poster_id>_<post_id>_embedding.pt, so swapping only the suffix never yields
# an existing image name; the lookup has to go through the ID pair.
images_by_post = {}
image_names = sorted(os.listdir(image_folder)) if os.path.isdir(image_folder) else []
for image_name in image_names:
    if not image_name.endswith('.png'):
        continue
    name_parts = os.path.splitext(image_name)[0].split('_')
    if len(name_parts) >= 3:
        images_by_post.setdefault(f"{name_parts[1]}_{name_parts[2]}", []).append(image_name)

# Save clustered images to output folders based on the clustering results
for idx, label in tqdm(enumerate(labels), desc="Saving clustered images", total=len(labels)):
    label_folder = os.path.join(output_folder, str(label))
    os.makedirs(label_folder, exist_ok=True)

    # Recover the "<poster_id>_<post_id>" key from the embedding file name
    embedding_file = embedding_files[idx]
    post_key = os.path.basename(embedding_file).replace('_embedding.pt', '')

    # Every image of that post is copied; fall back to the bare
    # "<poster_id>_<post_id>.png" name in case such a file exists.
    # NOTE(review): if a post has several images, confirm whether all of them
    # or only the one used for the embedding should land in the cluster folder.
    candidate_names = images_by_post.get(post_key)
    if not candidate_names:
        legacy_name = f"{post_key}.png"
        candidate_names = [legacy_name] if os.path.exists(os.path.join(image_folder, legacy_name)) else []

    # Copy image(s) to corresponding cluster folder
    for image_filename in candidate_names:
        source_image_path = os.path.join(image_folder, image_filename)
        shutil.copy(source_image_path, os.path.join(label_folder, image_filename))

# Save clustering labels to JSON file (embedding file name -> cluster label)
labels_json = {os.path.basename(embedding_files[idx]): int(label) for idx, label in enumerate(labels)}
with open(os.path.join(output_folder, 'labels.json'), 'w') as f:
    json.dump(labels_json, f)

print(f'Clustering complete. Output saved to {output_folder}')
