In [9]:
import pandas as pd
from dateutil import parser
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType

# Ensure CUDA_LAUNCH_BLOCKING is set for better debugging
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Define file paths
data_dir = "/data1/dxw_data/llm/redbook_final/script_next/"
rawdata_path = os.path.join(data_dir, "rawdata_20%.csv")
after2monthdata_path = os.path.join(data_dir, "after2monthdata_20%_with_trend.csv")
image_dir = os.path.join(data_dir, "data_img_20%")

# Check if files and directories exist
assert os.path.exists(rawdata_path), f"File not found: {rawdata_path}"
assert os.path.exists(after2monthdata_path), f"File not found: {after2monthdata_path}"
assert os.path.exists(image_dir), f"Directory not found: {image_dir}"

print("All files and directories are verified to exist.")

# Read CSV files
rawdata = pd.read_csv(rawdata_path)
after2monthdata = pd.read_csv(after2monthdata_path)
print(f"Raw data samples: {len(rawdata)}")
print(f"After 2 months data samples: {len(after2monthdata)}")

# Convert date columns to standard format
def parse_date(date_str):
    try:
        return parser.parse(date_str)
    except ValueError:
        return None

rawdata['post_date'] = rawdata['post_date'].apply(parse_date)
rawdata = rawdata.dropna(subset=['post_date'])
print("Converted post_date to standard format and removed invalid dates.")

# Function to randomly sample a subset of the data
def get_subset_indices(data, fraction=0.01):
    data_size = len(data)
    indices = list(range(data_size))
    np.random.shuffle(indices)
    split = int(np.floor(fraction * data_size))
    return indices[:split]

# Random sampling of 1/100th of the data
subset_indices = get_subset_indices(rawdata)
rawdata = rawdata.iloc[subset_indices]
after2monthdata = after2monthdata[after2monthdata['post_id'].isin(rawdata['post_id'])]
print(f"Randomly sampled {len(rawdata)} raw data samples.")
print(f"Filtered after 2 months data to {len(after2monthdata)} samples.")

# Split data into training and testing sets
train_rawdata = rawdata[(rawdata['post_date'].dt.month >= 1) & (rawdata['post_date'].dt.month <= 9)]
test_rawdata = rawdata[(rawdata['post_date'].dt.month >= 10) & (rawdata['post_date'].dt.month >= 10)]
print(f"Training raw data samples: {len(train_rawdata)}")
print(f"Testing raw data samples: {len(test_rawdata)}")

train_after2monthdata = after2monthdata[after2monthdata['post_id'].isin(train_rawdata['post_id'])]
test_after2monthdata = after2monthdata[after2monthdata['post_id'].isin(test_rawdata['post_id'])]
print(f"Training after 2 month data samples: {len(train_after2monthdata)}")
print(f"Testing after 2 month data samples: {len(test_after2monthdata)}")

# Remove non-finite values in the 'trend' column
train_after2monthdata = train_after2monthdata.replace([np.inf, -np.inf], np.nan).dropna(subset=['trend'])
test_after2monthdata = test_after2monthdata.replace([np.inf, -np.inf], np.nan).dropna(subset=['trend'])
print(f"Training after 2 month data samples after cleaning: {len(train_after2monthdata)}")
print(f"Testing after 2 month data samples after cleaning: {len(test_after2monthdata)}")

# Ensure that the 'trend' column has correct integer labels
train_after2monthdata['trend'] = train_after2monthdata['trend'].astype(int)
test_after2monthdata['trend'] = test_after2monthdata['trend'].astype(int)

# Define a sliding window function
def sliding_window(data, window_size=60):
    num_windows = len(data) - window_size + 1
    return [data[i:i + window_size] for i in range(num_windows)]

# Get sliding windows for training and testing data
train_windows = sliding_window(train_rawdata)
test_windows = sliding_window(test_rawdata)
print(f"Number of training windows: {len(train_windows)}")
print(f"Number of testing windows: {len(test_windows)}")

# Custom Dataset Class
class MultimodalDataset(Dataset):
    def __init__(self, windows, after2monthdata, image_dir, transform=None, max_images=1):
        self.windows = windows
        self.after2monthdata = after2monthdata
        self.image_dir = image_dir
        self.transform = transform
        self.max_images = max_images
        self.data = self._prepare_data()
        print(f"Number of processed samples: {len(self.data)}")  # Debug information

    def _prepare_data(self):
        data = []
        for window in tqdm(self.windows, desc="Processing data"):
            window_data = []
            for _, row in window.iterrows():
                poster_id = row['poster_id']
                post_id = row['post_id']
                image_files = [f for f in os.listdir(self.image_dir) if f"{poster_id}_{post_id}" in f]
                images = []
                for image_file in image_files[:self.max_images]:
                    image_path = os.path.join(self.image_dir, image_file)
                    try:
                        image = Image.open(image_path).convert('RGB')
                    except Exception as e:
                        print(f"Error opening image {image_path}: {e}")
                        continue
                    if self.transform:
                        image = self.transform(image)
                    images.append(image)
                while len(images) < self.max_images:
                    images.append(torch.zeros((3, 224, 224)))
                if images:
                    summary = row['summary']
                    numerical_list = [float(row['post_comments']), float(row['post_like']), float(row['post_collect'])]
                    window_data.append((summary, images, numerical_list))
            if window_data:
                last_day_post_id = window.iloc[-1]['post_id']
                label_data = self.after2monthdata[self.after2monthdata['post_id'] == last_day_post_id]['trend']
                if not label_data.empty:
                    label = label_data.values[0]
                    one_hot_label = F.one_hot(torch.tensor(label + 1), num_classes=3)  # -1, 0, 1 -> 0, 1, 2
                    data.append((window_data, one_hot_label))
        return data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        window_data, label = self.data[idx]
        summaries, images, numerical_lists = zip(*window_data)
        images = [torch.stack(image_set).float() for image_set in images]
        return summaries, torch.stack(images), torch.tensor(numerical_lists).float(), label.float()

# Image Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
print("Image transformations defined.")

# Create Datasets
train_dataset = MultimodalDataset(train_windows, train_after2monthdata, image_dir, transform=transform)
test_dataset = MultimodalDataset(test_windows, test_after2monthdata, image_dir, transform=transform)
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

# Create Data Loaders
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
print(f"Number of batches in train_loader: {len(train_loader)}")
print(f"Number of batches in test_loader: {len(test_loader)}")

All files and directories are verified to exist.
Raw data samples: 42698
After 2 months data samples: 43683
Converted post_date to standard format and removed invalid dates.
Randomly sampled 426 raw data samples.
Filtered after 2 months data to 694 samples.
Training raw data samples: 327
Testing raw data samples: 99
Training after 2 month data samples: 500
Testing after 2 month data samples: 194
Training after 2 month data samples after cleaning: 498
Testing after 2 month data samples after cleaning: 194
Number of training windows: 268
Number of testing windows: 40
Image transformations defined.


Processing data: 100%|██████████| 268/268 [15:08<00:00,  3.39s/it]


Number of processed samples: 121


Processing data: 100%|██████████| 40/40 [02:16<00:00,  3.40s/it]

Number of processed samples: 12
Number of training samples: 121
Number of test samples: 12
Number of batches in train_loader: 121
Number of batches in test_loader: 12





In [10]:
# Define model components
device = "cuda:4" if torch.cuda.is_available() else "cpu"

# Load ImageBind model
imagebind_model = imagebind_model.imagebind_huge(pretrained=True)
imagebind_model.eval()
imagebind_model.to(device)
print("ImageBind model loaded and set to evaluation mode.")

# Define CrossAttentionFusionLSTM model with MLP for numerical features
class CrossAttentionFusionLSTM(nn.Module):
    def __init__(self, text_embedding_dim, vision_embedding_dim, common_embedding_dim, num_heads, numerical_feature_dim):
        super(CrossAttentionFusionLSTM, self).__init__()
        self.text_linear = nn.Linear(text_embedding_dim, common_embedding_dim)
        self.vision_linear = nn.Linear(vision_embedding_dim, common_embedding_dim)
        self.numerical_mlp = nn.Sequential(
            nn.Linear(numerical_feature_dim, common_embedding_dim),
            nn.ReLU(),
            nn.Linear(common_embedding_dim, common_embedding_dim)
        )
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=common_embedding_dim, nhead=num_heads), num_layers=2)
        self.lstm = nn.LSTM(common_embedding_dim, common_embedding_dim, batch_first=True)
        self.fc = nn.Linear(common_embedding_dim, 3)

    def forward(self, text_embeddings, vision_embeddings, numerical_embeddings):
        text_embeddings = self.text_linear(text_embeddings)
        vision_embeddings = self.vision_linear(vision_embeddings)
        numerical_embeddings = self.numerical_mlp(numerical_embeddings)
        min_len = min(text_embeddings.size(1), vision_embeddings.size(1), numerical_embeddings.size(1))
        text_embeddings = text_embeddings[:, :min_len, :]
        vision_embeddings = vision_embeddings[:, :min_len, :]
        numerical_embeddings = numerical_embeddings[:, :min_len, :]
        multimodal_embeddings = torch.cat((text_embeddings, vision_embeddings, numerical_embeddings), dim=1)
        multimodal_embeddings = self.transformer_encoder(multimodal_embeddings)
        lstm_out, _ = self.lstm(multimodal_embeddings)
        lstm_out = lstm_out[:, -1, :]
        output = self.fc(lstm_out)
        return output
print("Model architecture defined.")

# Padding function
def pad_embeddings(embeddings, target_length):
    if embeddings.size(1) < target_length:
        padding = torch.zeros((embeddings.size(0), target_length - embeddings.size(1), embeddings.size(2)), device=embeddings.device)
        embeddings = torch.cat((embeddings, padding), dim=1)
    return embeddings

# Get embeddings function
def get_embeddings(text_list, image_tensors):
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(text_list, device),
        ModalityType.VISION: image_tensors.to(device)
    }
    
    with torch.no_grad():
        embeddings = imagebind_model(inputs)
    
    return embeddings

# Train model function
def train_model(model, train_loader, criterion, optimizer, num_epochs):
    model.train()
    log_file = open("log.txt", "w")
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
            summaries, images, numerical_lists, labels = batch
            images, labels, numerical_lists = images.to(device), labels.to(device), numerical_lists.to(device)
            text_embeddings_list = []
            vision_embeddings_list = []
            for day_summaries, day_images in zip(summaries[0], images[0]):
                embeddings = get_embeddings(day_summaries, day_images)
                text_embeddings_list.append(embeddings[ModalityType.TEXT])
                vision_embeddings_list.append(embeddings[ModalityType.VISION])
            
            max_len = 60
            text_embeddings_list = text_embeddings_list[:max_len]
            vision_embeddings_list = vision_embeddings_list[:max_len]
            numerical_lists = numerical_lists[:, :max_len, :]

            text_embeddings = torch.stack(text_embeddings_list, dim=0)
            vision_embeddings = torch.stack(vision_embeddings_list, dim=0)

            target_length = max(text_embeddings.size(1), vision_embeddings.size(1), numerical_lists.size(1))
            text_embeddings = pad_embeddings(text_embeddings, target_length)
            vision_embeddings = pad_embeddings(vision_embeddings, target_length)
            numerical_lists = pad_embeddings(numerical_lists, target_length)

            optimizer.zero_grad()
            outputs = model(text_embeddings, vision_embeddings, numerical_lists)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        epoch_loss = running_loss / len(train_loader)
        log_file.write(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}\n')
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}')
    log_file.close()

# Initialize model, criterion, and optimizer
text_embedding_dim = 1024
vision_embedding_dim = 1024
common_embedding_dim = 768
num_heads = 8
numerical_feature_dim = 3

model = CrossAttentionFusionLSTM(text_embedding_dim, vision_embedding_dim, common_embedding_dim, num_heads, numerical_feature_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
print("Model, criterion, and optimizer initialized.")

# Train the model
num_epochs = 10
train_model(model, train_loader, criterion, optimizer, num_epochs)

# Save the model
torch.save(model.state_dict(), os.path.join(data_dir, "multimodal_model.pth"))
print("Model training completed and saved!")

# Evaluation function
def evaluate_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            summaries, images, numerical_lists, labels = batch
            images, labels, numerical_lists = images.to(device), labels.to(device), numerical_lists.to(device)
            text_embeddings_list = []
            vision_embeddings_list = []
            for day_summaries, day_images in zip(summaries[0], images[0]):
                embeddings = get_embeddings(day_summaries, day_images)
                text_embeddings_list.append(embeddings[ModalityType.TEXT])
                vision_embeddings_list.append(embeddings[ModalityType.VISION])
            
            max_len = 60
            text_embeddings_list = text_embeddings_list[:max_len]
            vision_embeddings_list = vision_embeddings_list[:max_len]
            numerical_lists = numerical_lists[:, :max_len, :]

            text_embeddings = torch.stack(text_embeddings_list, dim=0)
            vision_embeddings = torch.stack(vision_embeddings_list, dim=0)

            target_length = max(text_embeddings.size(1), vision_embeddings.size(1), numerical_lists.size(1))
            text_embeddings = pad_embeddings(text_embeddings, target_length)
            vision_embeddings = pad_embeddings(vision_embeddings, target_length)
            numerical_lists = pad_embeddings(numerical_lists, target_length)

            outputs = model(text_embeddings, vision_embeddings, numerical_lists)
            _, predicted = torch.max(outputs.data, 1)
            _, labels_max = torch.max(labels.data, 1)
            total += labels.size(0)
            correct += (predicted == labels_max).sum().item()
    
    if total > 0:
        accuracy = correct / total
        print(f'Accuracy: {accuracy * 100:.2f}%')
    else:
        accuracy = 0
        print("No data to evaluate.")
    
    return accuracy

# Evaluate the model
test_accuracy = evaluate_model(model, test_loader)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

ImageBind model loaded and set to evaluation mode.
Model architecture defined.
Model, criterion, and optimizer initialized.


Training Epoch 1/10: 100%|██████████| 121/121 [01:51<00:00,  1.08it/s]


Epoch [1/10], Loss: 0.5580742075608289


Training Epoch 2/10: 100%|██████████| 121/121 [02:11<00:00,  1.09s/it]


Epoch [2/10], Loss: 0.4985420394165457


Training Epoch 3/10: 100%|██████████| 121/121 [02:10<00:00,  1.08s/it]


Epoch [3/10], Loss: 0.48576662624793604


Training Epoch 4/10: 100%|██████████| 121/121 [02:08<00:00,  1.06s/it]


Epoch [4/10], Loss: 0.4791376998971316


Training Epoch 5/10: 100%|██████████| 121/121 [02:06<00:00,  1.04s/it]


Epoch [5/10], Loss: 0.46107262255977993


Training Epoch 6/10: 100%|██████████| 121/121 [02:04<00:00,  1.03s/it]


Epoch [6/10], Loss: 0.4958052186310784


Training Epoch 7/10: 100%|██████████| 121/121 [02:03<00:00,  1.02s/it]


Epoch [7/10], Loss: 0.4697736882843262


Training Epoch 8/10: 100%|██████████| 121/121 [02:02<00:00,  1.01s/it]


Epoch [8/10], Loss: 0.4795233825267839


Training Epoch 9/10: 100%|██████████| 121/121 [02:01<00:00,  1.00s/it]


Epoch [9/10], Loss: 0.47239513288844714


Training Epoch 10/10: 100%|██████████| 121/121 [02:01<00:00,  1.00s/it]


Epoch [10/10], Loss: 0.4655125682388455
Model training completed and saved!


Evaluating: 100%|██████████| 12/12 [00:13<00:00,  1.15s/it]

Accuracy: 100.00%
Test Accuracy: 100.00%



