In [1]:
import pandas as pd
import numpy as np
import os
import csv
from tqdm import tqdm
import torch

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
from one_peace.models import from_pretrained

In [2]:
al_model = torch.load("/workspace/jaeyoung/checkpoints/onepeace/esd_pretrained_al_1009/checkpoint_best.pt")
vl_model = torch.load("/workspace/jaeyoung/checkpoints/onepeace/mmtts_vl_1013/checkpoint_last.pt")

In [19]:
labels = ['happy', 'angry', 'sad', 'neutral', 'surprised', 'anxious', 'disgusted', 'contempted']
img_list = ['/workspace/jaeyoung/datasets/mm-tts-dataset/video_image_save/M003_angry_level_1_001/0.jpg', '/workspace/jaeyoung/datasets/mm-tts-dataset/video_image_save/M003_fear_level_1_007/0.jpg', '/workspace/jaeyoung/datasets/mm-tts-dataset/video_image_save/M003_happy_level_2_001/1.jpg']
src_img = vl_model.process_image(img_list)

with torch.no_grad():
    img_logits = vl_model.extract_image_features(src_img)
    predicted_label_ids = img_logits.argmax(1).cpu().tolist()
print(predicted_label_ids)

# Step 7: Map Predicted Label IDs to Label Names and Print Results
for img_path, label_id in zip(img_list, predicted_label_ids):
    print('label_id:',label_id)
    if 0 <= label_id < len(labels):
        label = labels[label_id]
    else:
        label = "Unknown"  # Handle unexpected label IDs
    print(f'Image: {img_path}, Predicted Label: {label}')

[1076, 67, 632]
label_id: 1076
Image: /workspace/jaeyoung/datasets/mm-tts-dataset/video_image_save/M003_angry_level_1_001/0.jpg, Predicted Label: Unknown
label_id: 67
Image: /workspace/jaeyoung/datasets/mm-tts-dataset/video_image_save/M003_fear_level_1_007/0.jpg, Predicted Label: Unknown
label_id: 632
Image: /workspace/jaeyoung/datasets/mm-tts-dataset/video_image_save/M003_happy_level_2_001/1.jpg, Predicted Label: Unknown


In [20]:
# After obtaining img_logits, print its shape
print(f"Logits shape: {img_logits.size()}")  # e.g., torch.Size([3, 1080])

Logits shape: torch.Size([3, 1536])


In [3]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

# Define Labels
labels = ['happy', 'angry', 'sad', 'neutral', 'surprised', 'anxious', 'disgusted', 'contempted']
num_classes = len(labels)

# Define Transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Add normalization based on the pretrained model's requirements
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Load a Pretrained Model (e.g., ResNet50)
model = models.resnet50(pretrained=True)

# Replace the Final Layer
model.fc = nn.Linear(model.fc.in_features, num_classes)
model = model.cuda()

# Load Trained Weights if Available
# model.load_state_dict(torch.load('emotion_classifier.pth'))
model.eval()

# Prediction Function
def predict_emotion(img_path):
    image = Image.open(img_path).convert('RGB')
    image = transform(image).unsqueeze(0).cuda()  # Add batch dimension
    with torch.no_grad():
        logits = model(image)
        predicted_label_id = logits.argmax(1).item()
        predicted_label = labels[predicted_label_id] if 0 <= predicted_label_id < len(labels) else "Unknown"
    return predicted_label

# List of Image Paths
img_list = [
    '/workspace/jaeyoung/datasets/mm-tts-dataset/video_image_save/M003_angry_level_1_001/0.jpg',
    '/workspace/jaeyoung/datasets/mm-tts-dataset/video_image_save/M003_fear_level_1_007/0.jpg',
    '/workspace/jaeyoung/datasets/mm-tts-dataset/video_image_save/M003_happy_level_2_001/1.jpg'
]

# Predict and Print Results
for img_path in img_list:
    label = predict_emotion(img_path)
    print(f'Image: {img_path}, Predicted Label: {label}')


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 114MB/s]


Image: /workspace/jaeyoung/datasets/mm-tts-dataset/video_image_save/M003_angry_level_1_001/0.jpg, Predicted Label: anxious
Image: /workspace/jaeyoung/datasets/mm-tts-dataset/video_image_save/M003_fear_level_1_007/0.jpg, Predicted Label: anxious
Image: /workspace/jaeyoung/datasets/mm-tts-dataset/video_image_save/M003_happy_level_2_001/1.jpg, Predicted Label: anxious


In [15]:
al_model['extra_state']

{'metrics': OrderedDict([('default',
               [(10,
                 'loss',
                 'AverageMeter',
                 {'val': tensor(0.7070),
                  'sum': tensor(1024.),
                  'count': 7658,
                  'round': 3}),
                (10,
                 'logit_scale_exp',
                 'AverageMeter',
                 {'val': tensor(14.2500, requires_grad=True),
                  'sum': tensor(4096., requires_grad=True),
                  'count': 7658,
                  'round': 3}),
                (10,
                 'nsentences',
                 'AverageMeter',
                 {'val': 12, 'sum': 122500, 'count': 7658, 'round': 3}),
                (10,
                 'sample_size',
                 'AverageMeter',
                 {'val': 1, 'sum': 7658, 'count': 7658, 'round': 3}),
                (10,
                 'a2t_accuracy',
                 'AverageMeter',
                 {'val': tensor(58.3333),
                  

In [27]:
al_model['cfg']

{'_name': None,
 'common': {'_name': None,
  'no_progress_bar': False,
  'log_interval': 10,
  'log_format': 'simple',
  'log_file': None,
  'aim_repo': None,
  'aim_run_hash': None,
  'tensorboard_logdir': None,
  'wandb_project': None,
  'azureml_logging': False,
  'seed': 3407,
  'cpu': False,
  'tpu': False,
  'bf16': True,
  'memory_efficient_bf16': True,
  'fp16': False,
  'memory_efficient_fp16': False,
  'fp16_no_flatten_grads': False,
  'fp16_init_scale': 128,
  'fp16_scale_window': 256,
  'fp16_scale_tolerance': 0.0,
  'on_cpu_convert_precision': False,
  'min_loss_scale': 0.0001,
  'threshold_loss_scale': None,
  'amp': False,
  'amp_batch_retries': 2,
  'amp_init_scale': 128,
  'amp_scale_window': None,
  'user_dir': '/workspace/jaeyoung/StoryTeller/ONE-PEACE/one_peace/user_module',
  'empty_cache_freq': 0,
  'all_gather_list_size': 16384,
  'model_parallel_size': 1,
  'quantization_config_path': None,
  'profile': False,
  'reset_logging': False,
  'suppress_crashes': Fals

In [4]:
al_model['model']

OrderedDict([('logit_scale', tensor(2.6562)),
             ('encoder_wrapper.text_adapter.cls_embedding',
              tensor([[[-0.0157, -0.0098, -0.0216,  ..., -0.0126,  0.0059,  0.0281]]])),
             ('encoder_wrapper.text_adapter.rp_bucket',
              tensor([[513, 511, 511,  ..., 511, 511, 511],
                      [512, 255, 254,  ...,   0,   0,   0],
                      [512, 256, 255,  ...,   0,   0,   0],
                      ...,
                      [512, 510, 510,  ..., 255, 254, 253],
                      [512, 510, 510,  ..., 256, 255, 254],
                      [512, 510, 510,  ..., 257, 256, 255]])),
             ('encoder_wrapper.text_adapter.embed_tokens.weight',
              tensor([[-0.0011,  0.0056, -0.0155,  ...,  0.0194,  0.0011,  0.0186],
                      [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
                      [-0.0034, -0.0172, -0.0283,  ..., -0.0082, -0.0129,  0.0104],
                      ...,
              

In [4]:
vision_keys = [x for x in vl_model['model'].keys() if x not in al_model['model'].keys()]

In [13]:
vl_model['model'].keys()

odict_keys(['logit_scale', 'encoder_wrapper.text_adapter.cls_embedding', 'encoder_wrapper.text_adapter.rp_bucket', 'encoder_wrapper.text_adapter.embed_tokens.weight', 'encoder_wrapper.text_adapter.embed_positions.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.0.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.1.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.2.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.3.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.4.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.5.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.6.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.7.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.8.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.9.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.10.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.11.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.12.weig

In [14]:
vl_model_keys = vl_model['model'].keys()
for key in vl_model_keys:
    if 'text' in key:
        print(key)

encoder_wrapper.text_adapter.cls_embedding
encoder_wrapper.text_adapter.rp_bucket
encoder_wrapper.text_adapter.embed_tokens.weight
encoder_wrapper.text_adapter.embed_positions.weight
encoder_wrapper.text_adapter.rel_pos_table_list.0.weight
encoder_wrapper.text_adapter.rel_pos_table_list.1.weight
encoder_wrapper.text_adapter.rel_pos_table_list.2.weight
encoder_wrapper.text_adapter.rel_pos_table_list.3.weight
encoder_wrapper.text_adapter.rel_pos_table_list.4.weight
encoder_wrapper.text_adapter.rel_pos_table_list.5.weight
encoder_wrapper.text_adapter.rel_pos_table_list.6.weight
encoder_wrapper.text_adapter.rel_pos_table_list.7.weight
encoder_wrapper.text_adapter.rel_pos_table_list.8.weight
encoder_wrapper.text_adapter.rel_pos_table_list.9.weight
encoder_wrapper.text_adapter.rel_pos_table_list.10.weight
encoder_wrapper.text_adapter.rel_pos_table_list.11.weight
encoder_wrapper.text_adapter.rel_pos_table_list.12.weight
encoder_wrapper.text_adapter.rel_pos_table_list.13.weight
encoder_wrapper

In [5]:
vl_model['model']['encoder_wrapper.text_adapter.rel_pos_table_list.0.weight']

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.3457,  0.3301,  0.4141,  ...,  0.3555,  0.3750,  0.3418],
        [ 0.3945,  0.2852,  0.4355,  ...,  0.2871,  0.4531,  0.3340],
        [-0.4844, -0.7734, -0.6133,  ..., -0.8164, -0.2949, -0.6367]])

In [6]:
al_model['model']['encoder_wrapper.text_adapter.rel_pos_table_list.0.weight']

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.3457,  0.3301,  0.4141,  ...,  0.3555,  0.3750,  0.3418],
        [ 0.3945,  0.2852,  0.4355,  ...,  0.2871,  0.4531,  0.3340],
        [-0.4844, -0.7734, -0.6133,  ..., -0.8164, -0.2949, -0.6367]])

In [20]:
# audio-text model 에 image-text model feature 추가

for key in vision_keys:
	if key in vl_model['model']:
		if key not in al_model['model']:
			print(f'Update {key} in audio-text model')
			al_model['model'][key] = vl_model['model'][key]
		else:
			print(f'{key} already exists')

Update encoder_wrapper.image_adapter.cls_embedding in audio-text model
Update encoder_wrapper.image_adapter.pos_embed in audio-text model
Update encoder_wrapper.image_adapter.position_idx in audio-text model
Update encoder_wrapper.image_adapter.rp_bucket in audio-text model
Update encoder_wrapper.image_adapter.embed_images.0.weight in audio-text model
Update encoder_wrapper.image_adapter.embed_images.0.bias in audio-text model
Update encoder_wrapper.image_adapter.embed_images.1.layer_norm.weight in audio-text model
Update encoder_wrapper.image_adapter.embed_images.1.layer_norm.bias in audio-text model
Update encoder_wrapper.image_adapter.embed_images.3.weight in audio-text model
Update encoder_wrapper.image_adapter.embed_images.3.bias in audio-text model
Update encoder_wrapper.image_adapter.embed_images.4.layer_norm.weight in audio-text model
Update encoder_wrapper.image_adapter.embed_images.4.layer_norm.bias in audio-text model
Update encoder_wrapper.image_adapter.embed_images.6.weigh

In [None]:
# feature_extractor.py

import torch
from one_peace.models import from_pretrained  # Adjust based on actual library

class FeatureExtractor:
    def __init__(self, device='cuda'):
        self.device = device if torch.cuda.is_available() else 'cpu'
        
        # Load Audio Retrieval Model
        self.audio_model = from_pretrained(
            "ONE-PEACE_AudioTextRetrieval",  # Replace with actual identifier
            model_type="one_peace_audio_retrieval",
            device=self.device,
            dtype="float32"
        )
        self.audio_model.eval()  # Freeze weights
        
        # Load Vision Retrieval Model
        self.vision_model = from_pretrained(
            "ONE-PEACE_VisionTextRetrieval",  # Replace with actual identifier
            model_type="one_peace_vision_retrieval",
            device=self.device,
            dtype="float32"
        )
        self.vision_model.eval()  # Freeze weights
    
    def extract_audio_features(self, audio_paths):
        """
        Extract embeddings from audio files.
        
        Args:
            audio_paths (list of str): Paths to audio files.
        
        Returns:
            torch.Tensor: Audio embeddings of shape [batch_size, audio_feature_dim]
        """
        # Implement based on OnePeace's API
        # Example:
        audio_features = self.audio_model.encode_audio(audio_paths)
        return audio_features  # Adjust shape as needed
    
    def extract_vision_features(self, image_paths):
        """
        Extract embeddings from image files.
        
        Args:
            image_paths (list of str): Paths to image files.
        
        Returns:
            torch.Tensor: Image embeddings of shape [batch_size, vision_feature_dim]
        """
        # Implement based on OnePeace's API
        # Example:
        vision_features = self.vision_model.encode_image(image_paths)
        return vision_features  # Adjust shape as needed


In [None]:
# classifier.py

import torch.nn as nn

class EmotionClassifier(nn.Module):
    def __init__(self, audio_feature_dim, vision_feature_dim, num_classes):
        super(EmotionClassifier, self).__init__()
        self.fc1 = nn.Linear(audio_feature_dim + vision_feature_dim, 512)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)
    
    def forward(self, audio_features, vision_features):
        """
        Forward pass through the classifier.
        
        Args:
            audio_features (torch.Tensor): Audio embeddings [batch_size, audio_feature_dim]
            vision_features (torch.Tensor): Vision embeddings [batch_size, vision_feature_dim]
        
        Returns:
            torch.Tensor: Logits [batch_size, num_classes]
        """
        combined = torch.cat((audio_features, vision_features), dim=1)
        x = self.fc1(combined)
        x = self.relu(x)
        x = self.dropout(x)
        logits = self.fc2(x)
        return logits


In [23]:
torch.save(al_model, '/workspace/jaeyoung/checkpoints/one_peace_fusion/al_vl_0923.pt')

In [24]:
vl_model['model'].keys()

odict_keys(['logit_scale', 'encoder_wrapper.text_adapter.cls_embedding', 'encoder_wrapper.text_adapter.rp_bucket', 'encoder_wrapper.text_adapter.embed_tokens.weight', 'encoder_wrapper.text_adapter.embed_positions.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.0.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.1.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.2.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.3.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.4.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.5.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.6.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.7.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.8.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.9.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.10.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.11.weight', 'encoder_wrapper.text_adapter.rel_pos_table_list.12.weig

In [15]:
for key in vision_keys:
	if key in vl_model['model']:
		if 'text' in key:
			print(f'Update {key} in vision-text model')
			vl_model['model'][key] = al_model['model'][key]
		else:
			print(f'{key} already exists')

encoder_wrapper.image_adapter.cls_embedding already exists
encoder_wrapper.image_adapter.pos_embed already exists
encoder_wrapper.image_adapter.position_idx already exists
encoder_wrapper.image_adapter.rp_bucket already exists
encoder_wrapper.image_adapter.embed_images.0.weight already exists
encoder_wrapper.image_adapter.embed_images.0.bias already exists
encoder_wrapper.image_adapter.embed_images.1.layer_norm.weight already exists
encoder_wrapper.image_adapter.embed_images.1.layer_norm.bias already exists
encoder_wrapper.image_adapter.embed_images.3.weight already exists
encoder_wrapper.image_adapter.embed_images.3.bias already exists
encoder_wrapper.image_adapter.embed_images.4.layer_norm.weight already exists
encoder_wrapper.image_adapter.embed_images.4.layer_norm.bias already exists
encoder_wrapper.image_adapter.embed_images.6.weight already exists
encoder_wrapper.image_adapter.embed_images.6.bias already exists
encoder_wrapper.image_adapter.rel_pos_table_list.0.weight already exi

In [17]:
torch.save(vl_model, '/workspace/jaeyoung/checkpoints/one_peace_fusion/vl_txt_update_1014.pt')