In [7]:
import pandas as pd
from dateutil import parser
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import torch.nn as nn
import torch.optim as optim
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
import torch.nn.functional as F

# Ensure CUDA_LAUNCH_BLOCKING is set for better debugging
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Define file paths
data_dir = "/data1/dxw_data/llm/redbook_final/script_next/"
rawdata_path = os.path.join(data_dir, "rawdata_20%.csv")
after2monthdata_path = os.path.join(data_dir, "after2monthdata_20%_with_trend.csv")
image_dir = os.path.join(data_dir, "data_img_20%")

# Check if files exist
assert os.path.exists(rawdata_path), f"File not found: {rawdata_path}"
assert os.path.exists(after2monthdata_path), f"File not found: {after2monthdata_path}"
assert os.path.exists(image_dir), f"Directory not found: {image_dir}"

# Read CSV files
rawdata = pd.read_csv(rawdata_path)
after2monthdata = pd.read_csv(after2monthdata_path)

# Convert date columns to standard format
def parse_date(date_str):
    try:
        return parser.parse(date_str)
    except ValueError:
        return None

rawdata['post_date'] = rawdata['post_date'].apply(parse_date)
rawdata = rawdata.dropna(subset=['post_date'])

# Only use 1/10th of the data
def get_subset_indices(data, fraction=0.01):
    data_size = len(data)
    indices = list(range(data_size))
    np.random.shuffle(indices)
    split = int(np.floor(fraction * data_size))
    return indices[:split]

# Random sampling of 1/10th of the data
subset_indices = get_subset_indices(rawdata)
rawdata = rawdata.iloc[subset_indices]
after2monthdata = after2monthdata[after2monthdata['post_id'].isin(rawdata['post_id'])]

print(f"Total raw data samples: {len(rawdata)}")
print(f"Total after 2 month data samples: {len(after2monthdata)}")

# Split data into training and testing sets
train_rawdata = rawdata[(rawdata['post_date'].dt.month >= 1) & (rawdata['post_date'].dt.month <= 9)]
test_rawdata = rawdata[(rawdata['post_date'].dt.month >= 10) & (rawdata['post_date'].dt.month >= 10)]

train_after2monthdata = after2monthdata[after2monthdata['post_id'].isin(train_rawdata['post_id'])]
test_after2monthdata = after2monthdata[after2monthdata['post_id'].isin(test_rawdata['post_id'])]

print(f"Training raw data samples: {len(train_rawdata)}")
print(f"Testing raw data samples: {len(test_rawdata)}")
print(f"Training after 2 month data samples: {len(train_after2monthdata)}")
print(f"Testing after 2 month data samples: {len(test_after2monthdata)}")

# Remove non-finite values in the 'proportion' column
train_after2monthdata = train_after2monthdata.replace([np.inf, -np.inf], np.nan).dropna(subset=['proportion'])
test_after2monthdata = test_after2monthdata.replace([np.inf, -np.inf], np.nan).dropna(subset=['proportion'])

print(f"Training after 2 month data samples after cleaning: {len(train_after2monthdata)}")
print(f"Testing after 2 month data samples after cleaning: {len(test_after2monthdata)}")

# Define a sliding window function
def sliding_window(data, window_size=60):
    num_windows = len(data) - window_size + 1
    return [data[i:i + window_size] for i in range(num_windows)]

# Get sliding windows for training and testing data
train_windows = sliding_window(train_rawdata)
test_windows = sliding_window(test_rawdata)

print(f"Number of training windows: {len(train_windows)}")
print(f"Number of testing windows: {len(test_windows)}")

# Custom Dataset Class
class MultimodalDataset(Dataset):
    def __init__(self, windows, after2monthdata, image_dir, transform=None, max_images=1):
        self.windows = windows
        self.after2monthdata = after2monthdata
        self.image_dir = image_dir
        self.transform = transform
        self.max_images = max_images
        self.data = self._prepare_data()
        print(f"Number of processed samples: {len(self.data)}")  # Debug information

    def _prepare_data(self):
        data = []
        for window in tqdm(self.windows, desc="Processing data"):
            window_data = []
            for _, row in window.iterrows():
                poster_id = row['poster_id']
                post_id = row['post_id']
                numerical_list = [float(row['post_comments']), float(row['post_like']), float(row['post_collect'])]
                image_files = [f for f in os.listdir(self.image_dir) if f"{poster_id}_{post_id}" in f]
                images = []
                for image_file in image_files[:self.max_images]:
                    image_path = os.path.join(self.image_dir, image_file)
                    try:
                        image = Image.open(image_path).convert('RGB')
                    except Exception as e:
                        print(f"Error opening image {image_path}: {e}")
                        continue
                    if self.transform:
                        image = self.transform(image)
                    images.append(image)
                while len(images) < self.max_images:
                    images.append(torch.zeros((3, 224, 224)))
                if images:
                    summary = row['summary']
                    window_data.append((summary, images, numerical_list, row['post_date'], row['post_id']))  # 添加 post_date 和 post_id
            if window_data:
                last_day_post_id = window.iloc[-1]['post_id']
                label_data = self.after2monthdata[self.after2monthdata['post_id'] == last_day_post_id]['proportion']
                if not label_data.empty:
                    label = label_data.values[0]
                    data.append((window_data, label))
        return data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        window_data, label = self.data[idx]
        summaries, images, numerical_lists, post_dates, post_ids = zip(*window_data)  # 添加 post_dates 和 post_ids
        images = [torch.stack(image_set).float() for image_set in images]
        numerical_lists = torch.tensor(numerical_lists, dtype=torch.float32)
        return summaries, torch.stack(images), numerical_lists, torch.tensor(label, dtype=torch.float32), post_dates, post_ids  # 返回 post_dates 和 post_ids

# Image Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Create Datasets
train_dataset = MultimodalDataset(train_windows, train_after2monthdata, image_dir, transform=transform)
test_dataset = MultimodalDataset(test_windows, test_after2monthdata, image_dir, transform=transform)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

# Create Data Loaders
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

print(f"Number of batches in train_loader: {len(train_loader)}")
print(f"Number of batches in test_loader: {len(test_loader)}")


  rawdata = pd.read_csv(rawdata_path)


Total raw data samples: 426
Total after 2 month data samples: 637
Training raw data samples: 350
Testing raw data samples: 76
Training after 2 month data samples: 491
Testing after 2 month data samples: 146
Training after 2 month data samples after cleaning: 485
Testing after 2 month data samples after cleaning: 145
Number of training windows: 291
Number of testing windows: 17


Processing data: 100%|██████████| 291/291 [17:40<00:00,  3.64s/it]


Number of processed samples: 133


Processing data: 100%|██████████| 17/17 [01:01<00:00,  3.61s/it]

Number of processed samples: 8
Number of training samples: 133
Number of test samples: 8
Number of batches in train_loader: 133
Number of batches in test_loader: 8





In [None]:
# 打印输出的值


In [8]:
# Model Definition
device = "cuda:3" if torch.cuda.is_available() else "cpu"

# Load imagebind model
imagebind_model = imagebind_model.imagebind_huge(pretrained=True)
imagebind_model.eval()
imagebind_model.to(device)
print("ImageBind model loaded and set to evaluation mode.")

class CrossAttentionFusionLSTM(nn.Module):
    def __init__(self, text_embedding_dim, vision_embedding_dim, common_embedding_dim, numerical_feature_dim, num_heads):
        super(CrossAttentionFusionLSTM, self).__init__()
        self.text_linear = nn.Linear(text_embedding_dim, common_embedding_dim)
        self.vision_linear = nn.Linear(vision_embedding_dim, common_embedding_dim)
        self.numerical_mlp = nn.Sequential(
            nn.Linear(numerical_feature_dim, common_embedding_dim),
            nn.ReLU(),
            nn.Linear(common_embedding_dim, common_embedding_dim)
        )
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=common_embedding_dim, nhead=num_heads), num_layers=2)
        self.lstm = nn.LSTM(common_embedding_dim, common_embedding_dim, batch_first=True)
        self.fc = nn.Linear(common_embedding_dim, 1)

    def forward(self, text_embeddings, vision_embeddings, numerical_features):
        text_embeddings = self.text_linear(text_embeddings)
        vision_embeddings = self.vision_linear(vision_embeddings)
        numerical_embeddings = self.numerical_mlp(numerical_features)
        min_len = min(text_embeddings.size(1), vision_embeddings.size(1), numerical_embeddings.size(1))
        text_embeddings = text_embeddings[:, :min_len, :]
        vision_embeddings = vision_embeddings[:, :min_len, :]
        numerical_embeddings = numerical_embeddings[:, :min_len, :]
        multimodal_embeddings = torch.cat((text_embeddings, vision_embeddings, numerical_embeddings), dim=1)
        multimodal_embeddings = self.transformer_encoder(multimodal_embeddings)
        lstm_out, _ = self.lstm(multimodal_embeddings)
        lstm_out = lstm_out[:, -1, :]
        output = self.fc(lstm_out)
        return output

def pad_embeddings(embeddings, target_length):
    if embeddings.size(1) < target_length:
        padding = torch.zeros((embeddings.size(0), target_length - embeddings.size(1), embeddings.size(2)), device=embeddings.device)
        embeddings = torch.cat((embeddings, padding), dim=1)
    return embeddings

def get_embeddings(text_list, image_tensors):
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(text_list, device),
        ModalityType.VISION: image_tensors.to(device)
    }
    
    with torch.no_grad():
        embeddings = imagebind_model(inputs)
    
    return embeddings

def train_model(model, train_loader, criterion, optimizer, num_epochs):
    model.train()
    log_file = open("log.txt", "w")
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
            summaries, images, numerical_lists, labels, post_dates, post_ids = batch
            images, numerical_lists, labels = images.to(device), numerical_lists.to(device), labels.to(device)
            text_embeddings_list = []
            vision_embeddings_list = []
            for day_summaries, day_images in zip(summaries[0], images[0]):
                embeddings = get_embeddings(day_summaries, day_images)
                text_embeddings_list.append(embeddings[ModalityType.TEXT])
                vision_embeddings_list.append(embeddings[ModalityType.VISION])
            
            max_len = 60
            text_embeddings_list = text_embeddings_list[:max_len]
            vision_embeddings_list = vision_embeddings_list[:max_len]
            numerical_lists = numerical_lists[:, :max_len, :]

            text_embeddings = torch.stack(text_embeddings_list, dim=0)
            vision_embeddings = torch.stack(vision_embeddings_list, dim=0)

            target_length = max(text_embeddings.size(1), vision_embeddings.size(1), numerical_lists.size(1))
            text_embeddings = pad_embeddings(text_embeddings, target_length)
            vision_embeddings = pad_embeddings(vision_embeddings, target_length)
            numerical_lists = pad_embeddings(numerical_lists, target_length)

            optimizer.zero_grad()
            outputs = model(text_embeddings, vision_embeddings, numerical_lists)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        epoch_loss = running_loss / len(train_loader)
        log_file.write(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}\n')
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}')
    log_file.close()

# Initialize model, criterion, and optimizer
text_embedding_dim = 1024
vision_embedding_dim = 1024
common_embedding_dim = 768
numerical_feature_dim = 3
num_heads = 8

model = CrossAttentionFusionLSTM(text_embedding_dim, vision_embedding_dim, common_embedding_dim, numerical_feature_dim, num_heads).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
num_epochs = 10
train_model(model, train_loader, criterion, optimizer, num_epochs)

# Save the model
torch.save(model.state_dict(), os.path.join(data_dir, "multimodal_regression_model.pth"))

print("Model training completed and saved!")

ImageBind model loaded and set to evaluation mode.


Training Epoch 1/10:   0%|          | 0/133 [00:00<?, ?it/s]


TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'pandas._libs.tslibs.timestamps.Timestamp'>

In [None]:
# Evaluation Code
def evaluate_model(model, test_loader):
    model.eval()
    total_loss = 0.0
    criterion = nn.MSELoss()
    all_post_dates = []
    all_post_ids = []
    all_predictions = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            summaries, images, numerical_lists, labels, post_dates, post_ids = batch
            images, numerical_lists, labels = images.to(device), numerical_lists.to(device), labels.to(device)
            text_embeddings_list = []
            vision_embeddings_list = []
            for day_summaries, day_images in zip(summaries[0], images[0]):
                embeddings = get_embeddings(day_summaries, day_images)
                text_embeddings_list.append(embeddings[ModalityType.TEXT])
                vision_embeddings_list.append(embeddings[ModalityType.VISION])
            
            max_len = 60
            text_embeddings_list = text_embeddings_list[:max_len]
            vision_embeddings_list = vision_embeddings_list[:max_len]
            numerical_lists = numerical_lists[:, :max_len, :]

            text_embeddings = torch.stack(text_embeddings_list, dim=0)
            vision_embeddings = torch.stack(vision_embeddings_list, dim=0)

            target_length = max(text_embeddings.size(1), vision_embeddings.size(1), numerical_lists.size(1))
            text_embeddings = pad_embeddings(text_embeddings, target_length)
            vision_embeddings = pad_embeddings(vision_embeddings, target_length)
            numerical_lists = pad_embeddings(numerical_lists, target_length)

            outputs = model(text_embeddings, vision_embeddings, numerical_lists)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            all_post_dates.extend(post_dates)
            all_post_ids.extend(post_ids)
            all_predictions.extend(outputs.cpu().numpy())

    avg_loss = total_loss / len(test_loader)
    print(f'Evaluation Loss: {avg_loss}')
    
    for pd, pi, pred in zip(all_post_dates, all_post_ids, all_predictions):
        print(f"Post Date: {pd}, Post ID: {pi}, Predicted Trend: {pred}")

    return avg_loss

# Evaluate the model
test_loss = evaluate_model(model, test_loader)

print(f"Evaluation Loss: {test_loss}")