In [1]:
import json
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import resnet50
from PIL import Image
from gensim.models import KeyedVectors
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
import numpy as np

# Load JSON file
with open('/data1/dxw_data/llm/redbook/captions_labeled.json', 'r') as f:
    data = json.load(f)

# Load pretrained ResNet
resnet = resnet50(pretrained=False)
resnet.load_state_dict(torch.load('/data1/dxw_data/llm/resnet/resnet50-19c8e357.pth'))
resnet = nn.Sequential(*list(resnet.children())[:-1])  # Remove the classification layer
resnet.eval()

word2vec_path = '/data1/dxw_data/llm/word2vec/GoogleNews-vectors-negative300.bin.gz'
word2vec_model  = KeyedVectors.load_word2vec_format(word2vec_path, binary=True)




In [2]:
# Preprocess transforms for images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Function to extract image features
def extract_image_features(image_path):
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0)
    with torch.no_grad():
        features = resnet(image).squeeze().numpy()
    return features

# Function to extract text features
def extract_text_features(caption):
    words = caption.split()
    word_vectors = []
    for word in words:
        if word in word2vec_model:
            vector = word2vec_model[word]
            word_vectors.append(vector)
    if not word_vectors:
        return np.zeros(word2vec_model.vector_size)
    return np.mean(word_vectors, axis=0)

# Prepare dataset
image_features = []
text_features = []
labels = []

for item in data:
    image_path = os.path.join('/data1/dxw_data/llm/redbook/data', item['image'])
    if os.path.exists(image_path):
        img_feat = extract_image_features(image_path)
        txt_feat = extract_text_features(item['caption'])
        image_features.append(img_feat)
        text_features.append(txt_feat)
        labels.append(item['label'])

# Convert to numpy arrays
image_features = np.array(image_features)
text_features = np.array(text_features)
labels = np.array(labels)

# Combine image and text features
combined_features = np.hstack((image_features, text_features))

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(combined_features, labels, test_size=0.2, random_state=42)
print("X_train: ",X_train)
print("X_test: ",X_test)
print("y_train: ",y_train)
print("y_test: ",y_test)


X_train:  [[ 0.1407708   0.2051944   2.1315029  ... -0.03464926  0.05758793
   0.04998261]
 [ 0.03572091  0.73791885  2.3801253  ... -0.01762899  0.04897704
   0.04665799]
 [ 0.01663122  0.6184664   2.2751641  ... -0.01778995  0.06903729
   0.05579346]
 ...
 [ 0.09745294  0.8403073   1.8254905  ... -0.04706735  0.03224621
   0.046666  ]
 [ 0.11417428  0.752815    0.92227274 ... -0.00628855  0.05029674
   0.010082  ]
 [ 0.29126734  0.67746425  0.8735173  ... -0.0061261   0.056993
  -0.00525035]]
X_test:  [[ 1.3038322e-01  3.7690082e-01  2.4440727e+00 ... -1.2686593e-02
   7.4628919e-02  4.0783111e-02]
 [ 1.2907702e-01  7.8734642e-01  1.2004075e+00 ... -2.1322507e-02
   7.8972891e-02 -4.7176466e-03]
 [ 2.1855718e-02  1.0589955e+00  4.7144836e-01 ... -3.2391463e-04
   4.0346749e-02  2.8854102e-02]
 ...
 [ 9.3258291e-02  1.0867412e+00  1.0544221e+00 ... -2.8580261e-02
   6.6824973e-02  4.8487276e-02]
 [ 2.0029029e-01  2.0941079e-01  2.2312903e+00 ...  2.4784633e-03
   4.5201983e-02 -6.7858

In [3]:

# Define and train the MLP model
mlp = make_pipeline(StandardScaler(), MLPClassifier(hidden_layer_sizes=(512, 256), max_iter=500, random_state=42))
mlp.fit(X_train, y_train)

# Evaluate the model
train_accuracy = mlp.score(X_train, y_train)
test_accuracy = mlp.score(X_test, y_test)

print(f'Train Accuracy: {train_accuracy:.4f}')
print(f'Test Accuracy: {test_accuracy:.4f}')

Train Accuracy: 1.0000
Test Accuracy: 0.6809
