In [3]:
import pandas as pd
from dateutil import parser
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
import json

# Ensure CUDA_LAUNCH_BLOCKING is set for better debugging
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Define file paths
data_dir = "/data1/dxw_data/llm/redbook_final/script_next/"
rawdata_path = os.path.join(data_dir, "rawdata_20%.csv")
after2monthdata_path = os.path.join(data_dir, "after2monthdata_20%_with_trend.csv")
image_dir = os.path.join(data_dir, "data_img_20%")

# Check if files and directories exist
assert os.path.exists(rawdata_path), f"File not found: {rawdata_path}"
assert os.path.exists(after2monthdata_path), f"File not found: {after2monthdata_path}"
assert os.path.exists(image_dir), f"Directory not found: {image_dir}"

print("All files and directories are verified to exist.")

# Read CSV files
rawdata = pd.read_csv(rawdata_path)
after2monthdata = pd.read_csv(after2monthdata_path)
print(f"Raw data samples: {len(rawdata)}")
print(f"After 2 months data samples: {len(after2monthdata)}")

# Convert date columns to standard format
def parse_date(date_str):
    try:
        return parser.parse(date_str)
    except ValueError:
        return None

rawdata['post_date'] = rawdata['post_date'].apply(parse_date)
rawdata = rawdata.dropna(subset=['post_date'])
print("Converted post_date to standard format and removed invalid dates.")

# Function to randomly sample a subset of the data
def get_subset_indices(data, fraction=0.01):
    data_size = len(data)
    indices = list(range(data_size))
    np.random.shuffle(indices)
    split = int(np.floor(fraction * data_size))
    return indices[:split]

# Random sampling of 1/100th of the data
subset_indices = get_subset_indices(rawdata)
rawdata = rawdata.iloc[subset_indices]
after2monthdata = after2monthdata[after2monthdata['post_id'].isin(rawdata['post_id'])]
print(f"Randomly sampled {len(rawdata)} raw data samples.")
print(f"Filtered after 2 months data to {len(after2monthdata)} samples.")

# Split data into training and testing sets
train_rawdata = rawdata[(rawdata['post_date'].dt.month >= 1) & (rawdata['post_date'].dt.month <= 9)]
test_rawdata = rawdata[(rawdata['post_date'].dt.month >= 10) & (rawdata['post_date'].dt.month >= 10)]
print(f"Training raw data samples: {len(train_rawdata)}")
print(f"Testing raw data samples: {len(test_rawdata)}")

train_after2monthdata = after2monthdata[after2monthdata['post_id'].isin(train_rawdata['post_id'])]
test_after2monthdata = after2monthdata[after2monthdata['post_id'].isin(test_rawdata['post_id'])]
print(f"Training after 2 month data samples: {len(train_after2monthdata)}")
print(f"Testing after 2 month data samples: {len(test_after2monthdata)}")

# Remove non-finite values in the 'trend' column
train_after2monthdata = train_after2monthdata.replace([np.inf, -np.inf], np.nan).dropna(subset=['trend'])
test_after2monthdata = test_after2monthdata.replace([np.inf, -np.inf], np.nan).dropna(subset=['trend'])
print(f"Training after 2 month data samples after cleaning: {len(train_after2monthdata)}")
print(f"Testing after 2 month data samples after cleaning: {len(test_after2monthdata)}")

# Ensure that the 'trend' column has correct integer labels
train_after2monthdata['trend'] = train_after2monthdata['trend'].astype(int)
test_after2monthdata['trend'] = test_after2monthdata['trend'].astype(int)

# Define a sliding window function
def sliding_window(data, window_size=60):
    num_windows = len(data) - window_size + 1
    return [data[i:i + window_size] for i in range(num_windows)]

# Get sliding windows for training and testing data
train_windows = sliding_window(train_rawdata)
test_windows = sliding_window(test_rawdata)
print(f"Number of training windows: {len(train_windows)}")
print(f"Number of testing windows: {len(test_windows)}")

# Custom Dataset Class
class MultimodalDataset(Dataset):
    def __init__(self, windows, after2monthdata, image_dir, transform=None, max_images=1, save_test_data=False, test_data_path='test_classification_data.json'):
        self.windows = windows
        self.after2monthdata = after2monthdata
        self.image_dir = image_dir
        self.transform = transform
        self.max_images = max_images
        self.save_test_data = save_test_data
        self.test_data_path = test_data_path
        self.data = self._prepare_data()
        print(f"Number of processed samples: {len(self.data)}")  # Debug information

    def _prepare_data(self):
        data = []
        test_data = []
        for window in tqdm(self.windows, desc="Processing data"):
            window_data = []
            for _, row in window.iterrows():
                poster_id = row['poster_id']
                post_id = row['post_id']
                post_date = row['post_date']
                image_files = [f for f in os.listdir(self.image_dir) if f"{poster_id}_{post_id}" in f]
                images = []
                for image_file in image_files[:self.max_images]:
                    image_path = os.path.join(self.image_dir, image_file)
                    try:
                        image = Image.open(image_path).convert('RGB')
                    except Exception as e:
                        print(f"Error opening image {image_path}: {e}")
                        continue
                    if self.transform:
                        image = self.transform(image)
                    images.append(image)
                while len(images) < self.max_images:
                    images.append(torch.zeros((3, 224, 224)))
                if images:
                    summary = row['summary']
                    numerical_list = [float(row['post_comments']), float(row['post_like']), float(row['post_collect'])]
                    window_data.append((summary, images, numerical_list))
            if window_data:
                last_day_post_id = window.iloc[-1]['post_id']
                label_data = self.after2monthdata[self.after2monthdata['post_id'] == last_day_post_id]['trend']
                if not label_data.empty:
                    label = int(label_data.values[0])  # Convert to standard int
                    one_hot_label = F.one_hot(torch.tensor(label + 1), num_classes=3)  # -1, 0, 1 -> 0, 1, 2
                    data.append((window_data, one_hot_label))
                    
                    # If saving test data, collect necessary information
                    if self.save_test_data:
                        test_data.append({
                            'post_date': post_date.isoformat(),
                            'post_id': str(last_day_post_id),  # Convert to string
                            'true_label': label,
                            'predicted_label': None  # Will be updated during evaluation
                        })
                        
        if self.save_test_data and test_data:
            with open(self.test_data_path, 'w') as f:
                json.dump(test_data, f, ensure_ascii=False, indent=4)
                
        return data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        window_data, label = self.data[idx]
        summaries, images, numerical_lists = zip(*window_data)
        images = [torch.stack(image_set).float() for image_set in images]
        return summaries, torch.stack(images), torch.tensor(numerical_lists).float(), label.float()

# Image Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
print("Image transformations defined.")

# Create Datasets
train_dataset = MultimodalDataset(train_windows, train_after2monthdata, image_dir, transform=transform)
test_dataset = MultimodalDataset(test_windows, test_after2monthdata, image_dir, transform=transform, save_test_data=True, test_data_path='test_classification_data.json')
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

# Create Data Loaders
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
print(f"Number of batches in train_loader: {len(train_loader)}")
print(f"Number of batches in test_loader: {len(test_loader)}")

# Define model components
device = "cuda:4" if torch.cuda.is_available() else "cpu"

# Load ImageBind model
imagebind_model = imagebind_model.imagebind_huge(pretrained=True)
imagebind_model.eval()
imagebind_model.to(device)
print("ImageBind model loaded and set to evaluation mode.")

# Define CrossAttentionFusionLSTM model with MLP for numerical features
class CrossAttentionFusionLSTM(nn.Module):
    def __init__(self, text_embedding_dim, vision_embedding_dim, common_embedding_dim, num_heads, numerical_feature_dim):
        super(CrossAttentionFusionLSTM, self).__init__()
        self.text_linear = nn.Linear(text_embedding_dim, common_embedding_dim)
        self.vision_linear = nn.Linear(vision_embedding_dim, common_embedding_dim)
        self.numerical_mlp = nn.Sequential(
            nn.Linear(numerical_feature_dim, common_embedding_dim),
            nn.ReLU(),
            nn.Linear(common_embedding_dim, common_embedding_dim)
        )
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=common_embedding_dim, nhead=num_heads), num_layers=2)
        self.lstm = nn.LSTM(common_embedding_dim, common_embedding_dim, batch_first=True)
        self.fc = nn.Linear(common_embedding_dim, 3)

    def forward(self, text_embeddings, vision_embeddings, numerical_embeddings):
        text_embeddings = self.text_linear(text_embeddings)
        vision_embeddings = self.vision_linear(vision_embeddings)
        numerical_embeddings = self.numerical_mlp(numerical_embeddings)
        min_len = min(text_embeddings.size(1), vision_embeddings.size(1), numerical_embeddings.size(1))
        text_embeddings = text_embeddings[:, :min_len, :]
        vision_embeddings = vision_embeddings[:, :min_len, :]
        numerical_embeddings = numerical_embeddings[:, :min_len, :]
        multimodal_embeddings = torch.cat((text_embeddings, vision_embeddings, numerical_embeddings), dim=1)
        multimodal_embeddings = self.transformer_encoder(multimodal_embeddings)
        lstm_out, _ = self.lstm(multimodal_embeddings)
        lstm_out = lstm_out[:, -1, :]
        output = self.fc(lstm_out)
        return output
print("Model architecture defined.")

# Padding function
def pad_embeddings(embeddings, target_length):
    if embeddings.size(1) < target_length:
        padding = torch.zeros((embeddings.size(0), target_length - embeddings.size(1), embeddings.size(2)), device=embeddings.device)
        embeddings = torch.cat((embeddings, padding), dim=1)
    return embeddings

# Get embeddings function
def get_embeddings(text_list, image_tensors):
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(text_list, device),
        ModalityType.VISION: image_tensors.to(device)
    }
    
    with torch.no_grad():
        embeddings = imagebind_model(inputs)
    
    return embeddings

# Train model function
def train_model(model, train_loader, criterion, optimizer, num_epochs):
    model.train()
    log_file = open("log.txt", "w")
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
            summaries, images, numerical_lists, labels = batch
            images, labels, numerical_lists = images.to(device), labels.to(device), numerical_lists.to(device)
            text_embeddings_list = []
            vision_embeddings_list = []
            for day_summaries, day_images in zip(summaries[0], images[0]):
                embeddings = get_embeddings(day_summaries, day_images)
                text_embeddings_list.append(embeddings[ModalityType.TEXT])
                vision_embeddings_list.append(embeddings[ModalityType.VISION])
            
            max_len = 60
            text_embeddings_list = text_embeddings_list[:max_len]
            vision_embeddings_list = vision_embeddings_list[:max_len]
            numerical_lists = numerical_lists[:, :max_len, :]

            text_embeddings = torch.stack(text_embeddings_list, dim=0)
            vision_embeddings = torch.stack(vision_embeddings_list, dim=0)

            target_length = max(text_embeddings.size(1), vision_embeddings.size(1), numerical_lists.size(1))
            text_embeddings = pad_embeddings(text_embeddings, target_length)
            vision_embeddings = pad_embeddings(vision_embeddings, target_length)
            numerical_lists = pad_embeddings(numerical_lists, target_length)

            optimizer.zero_grad()
            outputs = model(text_embeddings, vision_embeddings, numerical_lists)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        epoch_loss = running_loss / len(train_loader)
        log_file.write(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}\n')
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}')
    log_file.close()

# Initialize model, criterion, and optimizer
text_embedding_dim = 1024
vision_embedding_dim = 1024
common_embedding_dim = 768
num_heads = 8
numerical_feature_dim = 3

model = CrossAttentionFusionLSTM(text_embedding_dim, vision_embedding_dim, common_embedding_dim, num_heads, numerical_feature_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
print("Model, criterion, and optimizer initialized.")

# Train the model
num_epochs = 2
train_model(model, train_loader, criterion, optimizer, num_epochs)

# Save the model
torch.save(model.state_dict(), os.path.join(data_dir, "multimodal_model.pth"))
print("Model training completed and saved!")

# Evaluation function
def evaluate_model(model, test_loader, test_data_path):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        test_data = json.load(open(test_data_path))
        for batch_idx, batch in enumerate(tqdm(test_loader, desc="Evaluating")):
            summaries, images, numerical_lists, labels = batch
            images, labels, numerical_lists = images.to(device), labels.to(device), numerical_lists.to(device)
            text_embeddings_list = []
            vision_embeddings_list = []
            for day_summaries, day_images in zip(summaries[0], images[0]):
                embeddings = get_embeddings(day_summaries, day_images)
                text_embeddings_list.append(embeddings[ModalityType.TEXT])
                vision_embeddings_list.append(embeddings[ModalityType.VISION])
            
            max_len = 60
            text_embeddings_list = text_embeddings_list[:max_len]
            vision_embeddings_list = vision_embeddings_list[:max_len]
            numerical_lists = numerical_lists[:, :max_len, :]

            text_embeddings = torch.stack(text_embeddings_list, dim=0)
            vision_embeddings = torch.stack(vision_embeddings_list, dim=0)

            target_length = max(text_embeddings.size(1), vision_embeddings.size(1), numerical_lists.size(1))
            text_embeddings = pad_embeddings(text_embeddings, target_length)
            vision_embeddings = pad_embeddings(vision_embeddings, target_length)
            numerical_lists = pad_embeddings(numerical_lists, target_length)

            outputs = model(text_embeddings, vision_embeddings, numerical_lists)
            _, predicted = torch.max(outputs.data, 1)
            _, labels_max = torch.max(labels.data, 1)
            total += labels.size(0)
            correct += (predicted == labels_max).sum().item()

            # Update the predicted labels in the test data and print the details
            test_data[batch_idx]['predicted_label'] = int(predicted.item())
            print(f"Post Date: {test_data[batch_idx]['post_date']}, Post ID: {test_data[batch_idx]['post_id']}, Predicted Label: {predicted.item()}, Actual Label: {labels_max.item()}")
        
        # Save the updated test data with predicted labels
        with open(test_data_path, 'w') as f:
            json.dump(test_data, f, ensure_ascii=False, indent=4)
    
    if total > 0:
        accuracy = correct / total
        print(f'Accuracy: {accuracy * 100:.2f}%')
    else:
        accuracy = 0
        print("No data to evaluate.")
    
    return accuracy

# Evaluate the model
test_data_path = 'test_classification_data.json'
test_accuracy = evaluate_model(model, test_loader, test_data_path)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")


All files and directories are verified to exist.


  rawdata = pd.read_csv(rawdata_path)


Raw data samples: 42698
After 2 months data samples: 43683
Converted post_date to standard format and removed invalid dates.
Randomly sampled 426 raw data samples.
Filtered after 2 months data to 649 samples.
Training raw data samples: 323
Testing raw data samples: 103
Training after 2 month data samples: 447
Testing after 2 month data samples: 202
Training after 2 month data samples after cleaning: 437
Testing after 2 month data samples after cleaning: 202
Number of training windows: 264
Number of testing windows: 44
Image transformations defined.


Processing data: 100%|██████████| 264/264 [15:17<00:00,  3.48s/it]


Number of processed samples: 106


Processing data: 100%|██████████| 44/44 [02:44<00:00,  3.73s/it]


Number of processed samples: 22
Number of training samples: 106
Number of test samples: 22
Number of batches in train_loader: 106
Number of batches in test_loader: 22
ImageBind model loaded and set to evaluation mode.
Model architecture defined.
Model, criterion, and optimizer initialized.


Training Epoch 1/2: 100%|██████████| 106/106 [01:24<00:00,  1.26it/s]


Epoch [1/2], Loss: 0.5750053165808934


Training Epoch 2/2: 100%|██████████| 106/106 [01:34<00:00,  1.12it/s]


Epoch [2/2], Loss: 0.437958725077926
Model training completed and saved!


Evaluating:   5%|▍         | 1/22 [00:00<00:11,  1.88it/s]

Post Date: 2024-11-21T00:00:00, Post ID: 655c80b000000000330093c0, Predicted Label: 2, Actual Label: 2


Evaluating:   9%|▉         | 2/22 [00:01<00:14,  1.36it/s]

Post Date: 2024-11-26T00:00:00, Post ID: 6562b917000000003300563e, Predicted Label: 2, Actual Label: 2


Evaluating:  14%|█▎        | 3/22 [00:02<00:15,  1.19it/s]

Post Date: 2024-10-31T00:00:00, Post ID: 653fec5f000000001f034cce, Predicted Label: 2, Actual Label: 2


Evaluating:  18%|█▊        | 4/22 [00:03<00:16,  1.07it/s]

Post Date: 2024-12-21T00:00:00, Post ID: 65840a5a000000000602b98c, Predicted Label: 2, Actual Label: 2


Evaluating:  23%|██▎       | 5/22 [00:03<00:11,  1.44it/s]

Post Date: 2024-10-31T00:00:00, Post ID: 6540ff7c0000000025020b89, Predicted Label: 2, Actual Label: 2


Evaluating:  27%|██▋       | 6/22 [00:04<00:11,  1.39it/s]

Post Date: 2024-10-16T00:00:00, Post ID: 652d2f0e000000001a01540a, Predicted Label: 2, Actual Label: 2


Evaluating:  32%|███▏      | 7/22 [00:04<00:08,  1.78it/s]

Post Date: 2024-10-28T00:00:00, Post ID: 6531fda1000000001f03f950, Predicted Label: 2, Actual Label: 2


Evaluating:  36%|███▋      | 8/22 [00:07<00:16,  1.17s/it]

Post Date: 2024-10-26T00:00:00, Post ID: 653a1632000000001f007a20, Predicted Label: 2, Actual Label: 2


Evaluating:  41%|████      | 9/22 [00:08<00:15,  1.22s/it]

Post Date: 2024-11-07T00:00:00, Post ID: 6549b6690000000025008fb1, Predicted Label: 2, Actual Label: 2


Evaluating:  45%|████▌     | 10/22 [00:09<00:13,  1.15s/it]

Post Date: 2024-10-28T00:00:00, Post ID: 653ce06c00000000250099fe, Predicted Label: 2, Actual Label: 2


Evaluating:  50%|█████     | 11/22 [00:10<00:10,  1.01it/s]

Post Date: 2024-11-14T00:00:00, Post ID: 6553966d000000000f02ba6c, Predicted Label: 2, Actual Label: 1


Evaluating:  55%|█████▍    | 12/22 [00:10<00:07,  1.28it/s]

Post Date: 2024-12-12T00:00:00, Post ID: 657844db000000000801fb87, Predicted Label: 2, Actual Label: 2


Evaluating:  59%|█████▉    | 13/22 [00:10<00:06,  1.47it/s]

Post Date: 2024-11-03T00:00:00, Post ID: 6544b4d6000000001f03fcb8, Predicted Label: 2, Actual Label: 2


Evaluating:  64%|██████▎   | 14/22 [00:12<00:07,  1.11it/s]

Post Date: 2024-11-08T00:00:00, Post ID: 654b85f7000000003103ed3f, Predicted Label: 2, Actual Label: 2


Evaluating:  68%|██████▊   | 15/22 [00:13<00:07,  1.08s/it]

Post Date: 2024-10-24T00:00:00, Post ID: 6537a2e8000000002202fe5f, Predicted Label: 2, Actual Label: 2


Evaluating:  73%|███████▎  | 16/22 [00:14<00:05,  1.07it/s]

Post Date: 2024-12-23T00:00:00, Post ID: 658648370000000009023b33, Predicted Label: 2, Actual Label: 2


Evaluating:  77%|███████▋  | 17/22 [00:15<00:04,  1.16it/s]

Post Date: 2024-11-06T00:00:00, Post ID: 6548b4e3000000001d015b26, Predicted Label: 2, Actual Label: 2


Evaluating:  82%|████████▏ | 18/22 [00:16<00:04,  1.13s/it]

Post Date: 2024-10-18T00:00:00, Post ID: 652f8c70000000001c016d99, Predicted Label: 2, Actual Label: 2


Evaluating:  86%|████████▋ | 19/22 [00:17<00:02,  1.08it/s]

Post Date: 2024-12-27T00:00:00, Post ID: 658c1ec2000000001000c781, Predicted Label: 2, Actual Label: 2


Evaluating:  91%|█████████ | 20/22 [00:17<00:01,  1.20it/s]

Post Date: 2024-10-17T00:00:00, Post ID: 6512dbff000000001e02221e, Predicted Label: 2, Actual Label: 2


Evaluating:  95%|█████████▌| 21/22 [00:18<00:00,  1.26it/s]

Post Date: 2024-10-27T00:00:00, Post ID: 653bb723000000002201d7fa, Predicted Label: 2, Actual Label: 2


Evaluating: 100%|██████████| 22/22 [00:18<00:00,  1.16it/s]

Post Date: 2024-12-22T00:00:00, Post ID: 65857227000000000901cdb5, Predicted Label: 2, Actual Label: 2
Accuracy: 95.45%
Test Accuracy: 95.45%



