In [2]:
import pandas as pd
from dateutil import parser
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import torch.nn as nn
import torch.optim as optim
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
import torch.nn.functional as F
import json

# Ensure CUDA_LAUNCH_BLOCKING is set for better debugging
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Define file paths
data_dir = "/home/disk1/red_disk1/test"
rawdata_path = os.path.join(data_dir, "poster_test_fashion_nlpclean.csv")
after2monthdata_path = os.path.join(data_dir, "after2monthdata_20%_with_trend.csv")
image_dir = os.path.join(data_dir, "combined_seg_img")

# Check if files exist
assert os.path.exists(rawdata_path), f"File not found: {rawdata_path}"
assert os.path.exists(after2monthdata_path), f"File not found: {after2monthdata_path}"
assert os.path.exists(image_dir), f"Directory not found: {image_dir}"

# Read CSV files
rawdata = pd.read_csv(rawdata_path)
after2monthdata = pd.read_csv(after2monthdata_path)

# Convert date columns to standard format
def parse_date(date_str):
    try:
        return parser.parse(date_str)
    except ValueError:
        return None

rawdata['post_date'] = rawdata['post_date'].apply(parse_date)
rawdata = rawdata.dropna(subset=['post_date'])

# Only use 1/10th of the data
# def get_subset_indices(data, fraction=0.01):
def get_subset_indices(data, fraction=1):
    data_size = len(data)
    indices = list(range(data_size))
    np.random.shuffle(indices)
    split = int(np.floor(fraction * data_size))
    return indices[:split]

# Random sampling of 1/10th of the data
subset_indices = get_subset_indices(rawdata)
rawdata = rawdata.iloc[subset_indices]
after2monthdata = after2monthdata[after2monthdata['post_id'].isin(rawdata['post_id'])]

print(f"Total raw data samples: {len(rawdata)}")
print(f"Total after 2 month data samples: {len(after2monthdata)}")

# Split data into training and testing sets
train_rawdata = rawdata[(rawdata['post_date'].dt.month >= 1) & (rawdata['post_date'].dt.month <= 9)]
test_rawdata = rawdata[(rawdata['post_date'].dt.month >= 10) & (rawdata['post_date'].dt.month >= 10)]

train_after2monthdata = after2monthdata[after2monthdata['post_id'].isin(train_rawdata['post_id'])]
test_after2monthdata = after2monthdata[after2monthdata['post_id'].isin(test_rawdata['post_id'])]

print(f"Training raw data samples: {len(train_rawdata)}")
print(f"Testing raw data samples: {len(test_rawdata)}")
print(f"Training after 2 month data samples: {len(train_after2monthdata)}")
print(f"Testing after 2 month data samples: {len(test_after2monthdata)}")

# Remove non-finite values in the 'proportion' column
train_after2monthdata = train_after2monthdata.replace([np.inf, -np.inf], np.nan).dropna(subset=['proportion'])
test_after2monthdata = test_after2monthdata.replace([np.inf, -np.inf], np.nan).dropna(subset=['proportion'])

print(f"Training after 2 month data samples after cleaning: {len(train_after2monthdata)}")
print(f"Testing after 2 month data samples after cleaning: {len(test_after2monthdata)}")

# Define a sliding window function
def sliding_window(data, window_size=60):
    num_windows = len(data) - window_size + 1
    return [data[i:i + window_size] for i in range(num_windows)]

# Get sliding windows for training and testing data
train_windows = sliding_window(train_rawdata)
test_windows = sliding_window(test_rawdata)

print(f"Number of training windows: {len(train_windows)}")
print(f"Number of testing windows: {len(test_windows)}")

# Custom Dataset Class
class MultimodalDataset(Dataset):
    def __init__(self, windows, after2monthdata, image_dir, transform=None, max_images=1, save_test_data=False, test_data_path='test_regression_data.json'):
        self.windows = windows
        self.after2monthdata = after2monthdata
        self.image_dir = image_dir
        self.transform = transform
        self.max_images = max_images
        self.save_test_data = save_test_data
        self.test_data_path = test_data_path
        self.data = self._prepare_data()
        print(f"Number of processed samples: {len(self.data)}")  # Debug information

    def _prepare_data(self):
        data = []
        test_data = []
        for window in tqdm(self.windows, desc="Processing data"):
            window_data = []
            for _, row in window.iterrows():
                poster_id = row['poster_id']
                post_id = row['post_id']
                post_date = row['post_date']
                numerical_list = [float(row['post_comments']), float(row['post_like']), float(row['post_collect'])]
                image_files = [f for f in os.listdir(self.image_dir) if f"{poster_id}_{post_id}" in f]
                images = []
                for image_file in image_files[:self.max_images]:
                    image_path = os.path.join(self.image_dir, image_file)
                    try:
                        image = Image.open(image_path).convert('RGB')
                    except Exception as e:
                        print(f"Error opening image {image_path}: {e}")
                        continue
                    if self.transform:
                        image = self.transform(image)
                    images.append(image)
                while len(images) < self.max_images:
                    images.append(torch.zeros((3, 224, 224)))
                if images:
                    summary = row['summary']
                    window_data.append((summary, images, numerical_list))
            if window_data:
                last_day_post_id = window.iloc[-1]['post_id']
                label_data = self.after2monthdata[self.after2monthdata['post_id'] == last_day_post_id]['proportion']
                if not label_data.empty:
                    label = float(label_data.values[0])
                    data.append((window_data, label))
                    
                    # If saving test data, collect necessary information
                    if self.save_test_data:
                        test_data.append({
                            'post_date': post_date.isoformat(),
                            'post_id': str(last_day_post_id),  # Convert to string
                            'true_label': float(label),
                            'predicted_label': None  # Will be updated during evaluation
                        })
                        
        if self.save_test_data and test_data:
            with open(self.test_data_path, 'w') as f:
                json.dump(test_data, f, ensure_ascii=False, indent=4)
                
        return data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        window_data, label = self.data[idx]
        summaries, images, numerical_lists = zip(*window_data)
        images = [torch.stack(image_set).float() for image_set in images]
        numerical_lists = torch.tensor(numerical_lists, dtype=torch.float32)
        return summaries, torch.stack(images), numerical_lists, torch.tensor(label, dtype=torch.float32)

# Image Transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Create Datasets
train_dataset = MultimodalDataset(train_windows, train_after2monthdata, image_dir, transform=transform)
test_dataset = MultimodalDataset(test_windows, test_after2monthdata, image_dir, transform=transform, save_test_data=True, test_data_path='test_regression_data.json')

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

# Create Data Loaders
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

print(f"Number of batches in train_loader: {len(train_loader)}")
print(f"Number of batches in test_loader: {len(test_loader)}")


# Model Definition
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load imagebind model
imagebind_model = imagebind_model.imagebind_huge(pretrained=True)
imagebind_model.eval()
imagebind_model.to(device)
print("ImageBind model loaded and set to evaluation mode.")

class CrossAttentionFusionLSTM(nn.Module):
    def __init__(self, text_embedding_dim, vision_embedding_dim, common_embedding_dim, numerical_feature_dim, num_heads):
        super(CrossAttentionFusionLSTM, self).__init__()
        self.text_linear = nn.Linear(text_embedding_dim, common_embedding_dim)
        self.vision_linear = nn.Linear(vision_embedding_dim, common_embedding_dim)
        self.numerical_mlp = nn.Sequential(
            nn.Linear(numerical_feature_dim, common_embedding_dim),
            nn.ReLU(),
            nn.Linear(common_embedding_dim, common_embedding_dim)
        )
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=common_embedding_dim, nhead=num_heads), num_layers=2)
        self.lstm = nn.LSTM(common_embedding_dim, common_embedding_dim, batch_first=True)
        self.fc = nn.Linear(common_embedding_dim, 1)

    def forward(self, text_embeddings, vision_embeddings, numerical_features):
        text_embeddings = self.text_linear(text_embeddings)
        vision_embeddings = self.vision_linear(vision_embeddings)
        numerical_embeddings = self.numerical_mlp(numerical_features)
        min_len = min(text_embeddings.size(1), vision_embeddings.size(1), numerical_embeddings.size(1))
        text_embeddings = text_embeddings[:, :min_len, :]
        vision_embeddings = vision_embeddings[:, :min_len, :]
        numerical_embeddings = numerical_embeddings[:, :min_len, :]
        multimodal_embeddings = torch.cat((text_embeddings, vision_embeddings, numerical_embeddings), dim=1)
        multimodal_embeddings = self.transformer_encoder(multimodal_embeddings)
        lstm_out, _ = self.lstm(multimodal_embeddings)
        lstm_out = lstm_out[:, -1, :]
        output = self.fc(lstm_out)
        return output

def pad_embeddings(embeddings, target_length):
    if embeddings.size(1) < target_length:
        padding = torch.zeros((embeddings.size(0), target_length - embeddings.size(1), embeddings.size(2)), device=embeddings.device)
        embeddings = torch.cat((embeddings, padding), dim=1)
    return embeddings

def get_embeddings(text_list, image_tensors):
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(text_list, device),
        ModalityType.VISION: image_tensors.to(device)
    }
    
    with torch.no_grad():
        embeddings = imagebind_model(inputs)
    
    return embeddings

def train_model(model, train_loader, criterion, optimizer, num_epochs):
    model.train()
    log_file = open("log.txt", "w")
    for epoch in range(num_epochs):
        running_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
            summaries, images, numerical_lists, labels = batch
            images, numerical_lists, labels = images.to(device), numerical_lists.to(device), labels.to(device)
            text_embeddings_list = []
            vision_embeddings_list = []
            for day_summaries, day_images in zip(summaries[0], images[0]):
                embeddings = get_embeddings(day_summaries, day_images)
                text_embeddings_list.append(embeddings[ModalityType.TEXT])
                vision_embeddings_list.append(embeddings[ModalityType.VISION])
            
            max_len = 60
            text_embeddings_list = text_embeddings_list[:max_len]
            vision_embeddings_list = vision_embeddings_list[:max_len]
            numerical_lists = numerical_lists[:, :max_len, :]

            text_embeddings = torch.stack(text_embeddings_list, dim=0)
            vision_embeddings = torch.stack(vision_embeddings_list, dim=0)

            target_length = max(text_embeddings.size(1), vision_embeddings.size(1), numerical_lists.size(1))
            text_embeddings = pad_embeddings(text_embeddings, target_length)
            vision_embeddings = pad_embeddings(vision_embeddings, target_length)
            numerical_lists = pad_embeddings(numerical_lists, target_length)

            optimizer.zero_grad()
            outputs = model(text_embeddings, vision_embeddings, numerical_lists)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        epoch_loss = running_loss / len(train_loader)
        log_file.write(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}\n')
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}')
    log_file.close()

# Initialize model, criterion, and optimizer
text_embedding_dim = 1024
vision_embedding_dim = 1024
common_embedding_dim = 768
numerical_feature_dim = 3
num_heads = 8

model = CrossAttentionFusionLSTM(text_embedding_dim, vision_embedding_dim, common_embedding_dim, numerical_feature_dim, num_heads).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
num_epochs = 10
train_model(model, train_loader, criterion, optimizer, num_epochs)

# Save the model
torch.save(model.state_dict(), os.path.join(data_dir, "multimodal_model.pth"))

print("Model training completed and saved!")

# Evaluation Code
def evaluate_model(model, test_loader, test_data_path):
    model.eval()
    total_loss = 0.0
    criterion = nn.MSELoss()
    with torch.no_grad():
        test_data = json.load(open(test_data_path))
        for batch_idx, batch in enumerate(tqdm(test_loader, desc="Evaluating")):
            summaries, images, numerical_lists, labels = batch
            images, numerical_lists, labels = images.to(device), numerical_lists.to(device), labels.to(device)
            text_embeddings_list = []
            vision_embeddings_list = []
            for day_summaries, day_images in zip(summaries[0], images[0]):
                embeddings = get_embeddings(day_summaries, day_images)
                text_embeddings_list.append(embeddings[ModalityType.TEXT])
                vision_embeddings_list.append(embeddings[ModalityType.VISION])
            
            max_len = 60
            text_embeddings_list = text_embeddings_list[:max_len]
            vision_embeddings_list = vision_embeddings_list[:max_len]
            numerical_lists = numerical_lists[:, :max_len, :]

            text_embeddings = torch.stack(text_embeddings_list, dim=0)
            vision_embeddings = torch.stack(vision_embeddings_list, dim=0)

            target_length = max(text_embeddings.size(1), vision_embeddings.size(1), numerical_lists.size(1))
            text_embeddings = pad_embeddings(text_embeddings, target_length)
            vision_embeddings = pad_embeddings(vision_embeddings, target_length)
            numerical_lists = pad_embeddings(numerical_lists, target_length)

            outputs = model(text_embeddings, vision_embeddings, numerical_lists)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            # Update the predicted labels in the test data and print the details
            test_data[batch_idx]['predicted_label'] = float(outputs.item())
            print(f"Post Date: {test_data[batch_idx]['post_date']}, Post ID: {test_data[batch_idx]['post_id']}, Predicted Label: {outputs.item()}, Actual Label: {labels.item()}")
        
        # Save the updated test data with predicted labels
        with open(test_data_path, 'w') as f:
            json.dump(test_data, f, ensure_ascii=False, indent=4)
    
    avg_loss = total_loss / len(test_loader)
    print(f'Evaluation Loss: {avg_loss}')
    
    return avg_loss

# Evaluate the model
test_data_path = 'test_regression_data.json'
test_loss = evaluate_model(model, test_loader, test_data_path)

print(f"Evaluation Loss: {test_loss}")


Total raw data samples: 265
Total after 2 month data samples: 1055
Training raw data samples: 184
Testing raw data samples: 81
Training after 2 month data samples: 724
Testing after 2 month data samples: 331
Training after 2 month data samples after cleaning: 723
Testing after 2 month data samples after cleaning: 319
Number of training windows: 125
Number of testing windows: 22


Processing data: 100%|██████████| 125/125 [08:44<00:00,  4.20s/it]


Number of processed samples: 124


Processing data: 100%|██████████| 22/22 [01:32<00:00,  4.21s/it]


Number of processed samples: 22
Number of training samples: 124
Number of test samples: 22
Number of batches in train_loader: 124
Number of batches in test_loader: 22
ImageBind model loaded and set to evaluation mode.


  return F.mse_loss(input, target, reduction=self.reduction)
Training Epoch 1/10: 100%|██████████| 124/124 [01:31<00:00,  1.35it/s]


Epoch [1/10], Loss: 0.0861009943067297


Training Epoch 2/10: 100%|██████████| 124/124 [01:31<00:00,  1.35it/s]


Epoch [2/10], Loss: 0.0002702249478584547


Training Epoch 3/10: 100%|██████████| 124/124 [01:30<00:00,  1.36it/s]


Epoch [3/10], Loss: 0.00030019032980191466


Training Epoch 4/10: 100%|██████████| 124/124 [01:30<00:00,  1.37it/s]


Epoch [4/10], Loss: 0.0003026837674551396


Training Epoch 5/10: 100%|██████████| 124/124 [01:31<00:00,  1.36it/s]


Epoch [5/10], Loss: 0.00027475367878644137


Training Epoch 6/10: 100%|██████████| 124/124 [01:31<00:00,  1.35it/s]


Epoch [6/10], Loss: 0.0002503497493110228


Training Epoch 7/10: 100%|██████████| 124/124 [01:31<00:00,  1.36it/s]


Epoch [7/10], Loss: 0.00023776095949290174


Training Epoch 8/10: 100%|██████████| 124/124 [01:31<00:00,  1.36it/s]


Epoch [8/10], Loss: 0.0001819767307164653


Training Epoch 9/10: 100%|██████████| 124/124 [01:31<00:00,  1.36it/s]


Epoch [9/10], Loss: 0.0001500019463167221


Training Epoch 10/10: 100%|██████████| 124/124 [01:31<00:00,  1.35it/s]


Epoch [10/10], Loss: 0.00016902124942944343
Model training completed and saved!


Evaluating:   5%|▍         | 1/22 [00:00<00:06,  3.14it/s]

Post Date: 2023-12-16T00:00:00, Post ID: 657d8b8a000000000902544a, Predicted Label: 0.004779127426445484, Actual Label: 0.01250363513827324


Evaluating:   9%|▉         | 2/22 [00:00<00:08,  2.43it/s]

Post Date: 2023-10-24T00:00:00, Post ID: 6537c097000000002202e685, Predicted Label: 0.0048146927729249, Actual Label: 0.020743444561958313


Evaluating:  14%|█▎        | 3/22 [00:01<00:07,  2.66it/s]

Post Date: 2023-11-15T00:00:00, Post ID: 655384070000000032037b4d, Predicted Label: 0.0048105763271451, Actual Label: 0.006252894643694162


Evaluating:  18%|█▊        | 4/22 [00:01<00:06,  2.81it/s]

Post Date: 2023-11-01T00:00:00, Post ID: 654226eb0000000025015bf5, Predicted Label: 0.004779224283993244, Actual Label: 0.008105604909360409


Evaluating:  23%|██▎       | 5/22 [00:02<00:10,  1.69it/s]

Post Date: 2023-11-30T00:00:00, Post ID: 65686e42000000003202e79e, Predicted Label: 0.004814726300537586, Actual Label: 0.021537749096751213


Evaluating:  27%|██▋       | 6/22 [00:03<00:09,  1.61it/s]

Post Date: 2023-11-17T00:00:00, Post ID: 65573fe80000000032030efe, Predicted Label: 0.004814700223505497, Actual Label: 0.017600741237401962


Evaluating:  32%|███▏      | 7/22 [00:04<00:11,  1.32it/s]

Post Date: 2023-11-18T00:00:00, Post ID: 65584c77000000000f029380, Predicted Label: 0.004814700223505497, Actual Label: 0.002547475742176175


Evaluating:  36%|███▋      | 8/22 [00:04<00:10,  1.32it/s]

Post Date: 2023-10-06T00:00:00, Post ID: 651fe113000000001a015650, Predicted Label: 0.004814726300537586, Actual Label: 0.03119814209640026


Evaluating:  41%|████      | 9/22 [00:05<00:10,  1.20it/s]

Post Date: 2023-11-09T00:00:00, Post ID: 654ce1fc000000001b035b34, Predicted Label: 0.004814726300537586, Actual Label: 0.01505326572805643


Evaluating:  45%|████▌     | 10/22 [00:06<00:09,  1.31it/s]

Post Date: 2023-12-08T00:00:00, Post ID: 65730356000000000901a544, Predicted Label: 0.004814726300537586, Actual Label: 0.011922070756554604


Evaluating:  50%|█████     | 11/22 [00:07<00:08,  1.32it/s]

Post Date: 2023-11-06T00:00:00, Post ID: 654738dc000000002201e312, Predicted Label: 0.004814700223505497, Actual Label: 0.009958313778042793


Evaluating:  55%|█████▍    | 12/22 [00:07<00:06,  1.60it/s]

Post Date: 2023-12-06T00:00:00, Post ID: 656f3c8d0000000016005376, Predicted Label: 0.004779898561537266, Actual Label: 0.01366676390171051


Evaluating:  59%|█████▉    | 13/22 [00:08<00:06,  1.44it/s]

Post Date: 2023-11-09T00:00:00, Post ID: 654ca5590000000032030ccb, Predicted Label: 0.004814726300537586, Actual Label: 0.00602130638435483


Evaluating:  64%|██████▎   | 14/22 [00:08<00:05,  1.56it/s]

Post Date: 2023-12-14T00:00:00, Post ID: 657aeb26000000000503ae06, Predicted Label: 0.004814722575247288, Actual Label: 0.01773771457374096


Evaluating:  68%|██████▊   | 15/22 [00:09<00:04,  1.59it/s]

Post Date: 2023-10-23T00:00:00, Post ID: 65366b7b000000002201fca5, Predicted Label: 0.004818600602447987, Actual Label: 0.008961168117821217


Evaluating:  73%|███████▎  | 16/22 [00:10<00:04,  1.47it/s]

Post Date: 2023-11-28T00:00:00, Post ID: 6565c89a000000003202f234, Predicted Label: 0.0048146964982151985, Actual Label: 0.0328855961561203


Evaluating:  77%|███████▋  | 17/22 [00:10<00:03,  1.66it/s]

Post Date: 2023-10-21T00:00:00, Post ID: 65336ddb00000000250200e5, Predicted Label: 0.0048144506290555, Actual Label: 0.005808164831250906


Evaluating:  82%|████████▏ | 18/22 [00:13<00:05,  1.33s/it]

Post Date: 2023-12-18T00:00:00, Post ID: 658029df000000001502f1e0, Predicted Label: 0.004814726300537586, Actual Label: 0.008723465725779533


Evaluating:  86%|████████▋ | 19/22 [00:14<00:03,  1.04s/it]

Post Date: 2023-11-13T00:00:00, Post ID: 65520916000000001100cc0c, Predicted Label: 0.004814510233700275, Actual Label: 0.006252894643694162


Evaluating:  91%|█████████ | 20/22 [00:14<00:01,  1.20it/s]

Post Date: 2023-12-20T00:00:00, Post ID: 6582e634000000003a00c480, Predicted Label: 0.004810654558241367, Actual Label: 0.009595812298357487


Evaluating:  95%|█████████▌| 21/22 [00:14<00:00,  1.48it/s]

Post Date: 2023-12-05T00:00:00, Post ID: 656ed2fc0000000009021f57, Predicted Label: 0.00477925781160593, Actual Label: 0.01686536706984043


Evaluating: 100%|██████████| 22/22 [00:15<00:00,  1.40it/s]

Post Date: 2023-11-27T00:00:00, Post ID: 6564120e0000000033002013, Predicted Label: 0.004814726300537586, Actual Label: 0.0023158870171755552
Evaluation Loss: 0.00013176486892131254
Evaluation Loss: 0.00013176486892131254



