In [2]:
import json
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
import numpy as np

# Load JSON file
with open('/data1/dxw_data/llm/redbook/captions_labeled.json', 'r') as f:
    data_json = json.load(f)

# Set device
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Instantiate and load imagebind model
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)


ImageBindModel(
  (modality_preprocessors): ModuleDict(
    (vision): RGBDTPreprocessor(
      (cls_token): tensor((1, 1, 1280), requires_grad=True)
      
      (rgbt_stem): PatchEmbedGeneric(
        (proj): Sequential(
          (0): PadIm2Video()
          (1): Conv3d(3, 1280, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
        )
      )
      (pos_embedding_helper): SpatioTemporalPosEmbeddingHelper(
        (pos_embed): tensor((1, 257, 1280), requires_grad=True)
        
      )
    )
    (text): TextPreprocessor(
      (pos_embed): tensor((1, 77, 1024), requires_grad=True)
      (mask): tensor((77, 77), requires_grad=False)
      
      (token_embedding): Embedding(49408, 1024)
    )
    (audio): AudioPreprocessor(
      (cls_token): tensor((1, 1, 768), requires_grad=True)
      
      (rgbt_stem): PatchEmbedGeneric(
        (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10), bias=False)
        (norm_layer): LayerNorm((768,), eps=1e-05, elementwise_affine=

In [3]:
# Preprocess transforms for images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Function to load and transform text data
def load_and_transform_text(text_list, device):
    return data.load_and_transform_text(text_list, device)

# Function to load and transform vision data
def load_and_transform_vision_data(image_paths, device):
    images = [transform(Image.open(image_path).convert('RGB')).unsqueeze(0) for image_path in image_paths]
    images = torch.cat(images).to(device)
    return images

# Function to extract embeddings using imagebind model
def extract_embeddings(texts, images):
    inputs = {
        ModalityType.TEXT: load_and_transform_text(texts, device),
        ModalityType.VISION: load_and_transform_vision_data(images, device),
    }
    with torch.no_grad():
        embeddings = model(inputs)
    return embeddings

# Prepare dataset
image_paths = []
captions = []
labels = []

for item in data_json:
    image_path = os.path.join('/data1/dxw_data/llm/redbook/data', item['image'])
    if os.path.exists(image_path):
        image_paths.append(image_path)
        captions.append(item['caption'])
        labels.append(item['label'])

# Extract embeddings
embeddings = extract_embeddings(captions, image_paths)
image_embeddings = embeddings[ModalityType.VISION].cpu().numpy()
text_embeddings = embeddings[ModalityType.TEXT].cpu().numpy()


In [5]:
# Combine image and text features
combined_features = np.hstack((image_embeddings, text_embeddings))

# Convert labels to numpy array
labels = np.array(labels)

# Split the dataset into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(combined_features, labels, test_size=0.2, random_state=42)

# Define and train the MLP model
mlp = make_pipeline(StandardScaler(), MLPClassifier(hidden_layer_sizes=(512, 256), max_iter=500, random_state=42))
mlp.fit(X_train, y_train)

# Evaluate the model
train_accuracy = mlp.score(X_train, y_train)
test_accuracy = mlp.score(X_test, y_test)

print(f'Train Accuracy: {train_accuracy:.4f}')
print(f'Test Accuracy: {test_accuracy:.4f}')

Train Accuracy: 1.0000
Test Accuracy: 0.5745
