In [None]:
# -----------------------更新后的思路---------------------- #

In [None]:
import os
import json
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet50
from gensim.models import KeyedVectors
from PIL import Image

# Image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Paths
original_image_path = '/data1/dxw_data/llm/redbook/1000/data2'
segmented_image_path = '/data1/dxw_data/llm/redbook/combine_final'
text_feature_path = '/data1/dxw_data/llm/redbook/time_sequence/captions_with_hotness_and_time.json'
sequence_data_path = '/data1/dxw_data/llm/redbook-refine/train_data.json'

# Load pretrained ResNet
resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load('/data1/dxw_data/llm/resnet/resnet50-19c8e357.pth'))
resnet = nn.Sequential(*list(resnet.children())[:-1])  # Remove the classification layer
resnet.eval()

# Load pretrained Word2Vec model
word2vec_path = '/data1/dxw_data/llm/word2vec/GoogleNews-vectors-negative300.bin.gz' #! 可以修改为中文大模型large-chinese-word2vec
word2vec_model = KeyedVectors.load_word2vec_format(word2vec_path, binary=True)

# Function to extract image features
def extract_image_features(image_path):
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)
    with torch.no_grad():
        features = resnet(image).squeeze().numpy()
    return features

# Function to extract text features
def extract_text_features(caption):
    words = caption.split()
    word_vectors = []
    for word in words:
        if word in word2vec_model:
            vector = word2vec_model[word]
            word_vectors.append(vector)
    if not word_vectors:
        return np.zeros(word2vec_model.vector_size)
    return np.mean(word_vectors, axis=0)

# Define the dataset class
class HotnessDataset(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

# Define the LSTM model class
class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        lstm_out = lstm_out[:, -1, :]  # Take the output of the last time step
        out = self.fc(lstm_out)
        out = self.sigmoid(out)
        return out

# Load sequence data to get the 序号
with open(sequence_data_path, 'r') as f:
    sequence_data = json.load(f)

sequence_map = {int(item['序号']): item for item in sequence_data}

# Load JSON file with text features
with open(text_feature_path, 'r') as f:
    data = json.load(f)

# Prepare dataset
combined_features = []
labels = []
times = []

for item in data:
    seq_num = int(item['image'])  # Convert seq_num to integer
    # print(seq_num)
    text = item['caption']
    # print(text)
    print(sequence_map)
    segmented_image_path='/data1/dxw_data/llm/redbook/combine_final'
    if seq_num in sequence_map:
        image_path = os.path.join(original_image_path, f"{seq_num}.png")
        segmented_image_path = os.path.join(segmented_image_path, f"{seq_num}.png")

        # Debugging statements to check paths and existence
        print(f"Checking paths for seq_num: {seq_num}")
        print(f"Original image path: {image_path}")
        print(f"Segmented image path: {segmented_image_path}")

        if os.path.exists(image_path) and os.path.exists(segmented_image_path):
            print(f"Paths exist for seq_num: {seq_num}")

            # Extract features
            image_features = extract_image_features(image_path)
            mask_features = extract_image_features(segmented_image_path)
            text_features = extract_text_features(text)

            # Flatten image features to 1D if necessary
            image_features = image_features.flatten()
            mask_features = mask_features.flatten()

            print("image_features.shape", image_features.shape)
            print("mask_features.shape", mask_features.shape)
            print("text_features.shape", text_features.shape)

            # Combine features
            combined_feature = np.hstack((mask_features, image_features, text_features))
            combined_features.append(combined_feature)

            labels.append(sequence_map[seq_num]['本产品当前火爆'])
            times.append(sequence_map[seq_num]['天数'])
        else:
            print(f"Paths do not exist for seq_num: {seq_num}")

# Convert to numpy arrays
combined_features = np.array(combined_features)
labels = np.array(labels)
times = np.array(times)

print(f"Total combined features: {len(combined_features)}")
print(f"Total labels: {len(labels)}")
print(f"Total times: {len(times)}")




In [20]:

# Function to encode time information
def encode_time(times, max_time):
    times = np.array(times)
    sin_time = np.sin(2 * np.pi * times / max_time)
    cos_time = np.cos(2 * np.pi * times / max_time)
    return np.vstack((sin_time, cos_time)).T

encoded_times = encode_time(times, max_time=100)

# Split data into training and testing
train_indices = np.where(times <= 80)[0]
test_indices = np.where(times > 80)[0]

time_train = encoded_times[train_indices]
time_test = encoded_times[test_indices]

# Ensure combined_features is 2D before concatenating
if combined_features.ndim == 1:
    combined_features = combined_features.reshape(-1, 1)

# Concatenate time information with features
X_train = np.hstack((combined_features[train_indices], time_train))
X_test = np.hstack((combined_features[test_indices], time_test))
y_train = labels[train_indices]
y_test = labels[test_indices]

# Function to create sliding windows
def create_sliding_windows(X, y, window_size):
    features = []
    labels = []
    for i in range(len(X) - window_size):
        window = X[i:i + window_size]
        label = y[i + window_size]
        features.append(window)
        labels.append(label)
    return np.array(features), np.array(labels)

window_size = 5

# Create sliding windows for training and testing
X_train, y_train = create_sliding_windows(X_train, y_train, window_size)
X_test, y_test = create_sliding_windows(X_test, y_test, window_size)

# Convert data to PyTorch tensors
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32)

# Create datasets and dataloaders
train_dataset = HotnessDataset(X_train, y_train)
test_dataset = HotnessDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

input_dim = X_train.shape[2]
hidden_dim = 64
num_layers = 2
output_dim = 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LSTMModel(input_dim, hidden_dim, num_layers, output_dim).to(device)

# Loss and optimizer
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training the model
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.squeeze(), labels)
        loss.backward()
        optimizer.step()
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Evaluate the model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        predicted = (outputs.squeeze() > 0.5).float()
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    accuracy = correct / total
    print(f'Test Accuracy: {accuracy * 100:.2f}%')


Epoch [1/20], Loss: 0.3704
Epoch [2/20], Loss: 0.4480
Epoch [3/20], Loss: 0.7093
Epoch [4/20], Loss: 0.4657
Epoch [5/20], Loss: 0.5305
Epoch [6/20], Loss: 0.5723
Epoch [7/20], Loss: 0.7054
Epoch [8/20], Loss: 0.5726
Epoch [9/20], Loss: 0.6155
Epoch [10/20], Loss: 0.4105
Epoch [11/20], Loss: 0.7482
Epoch [12/20], Loss: 0.6135
Epoch [13/20], Loss: 0.5729
Epoch [14/20], Loss: 0.4557
Epoch [15/20], Loss: 0.4462
Epoch [16/20], Loss: 0.4221
Epoch [17/20], Loss: 0.6632
Epoch [18/20], Loss: 0.4899
Epoch [19/20], Loss: 0.6165
Epoch [20/20], Loss: 0.4476
Test Accuracy: 75.38%
