## Step 1: Import Modules

In [1]:
import warnings
from transformers.utils import logging as hf_transformers_logging
import logging
import torch
import json
import os

# Suppress noisy warning / log messages that are not useful while running this notebook.
# FutureWarnings are ignored globally.
warnings.simplefilter(action='ignore', category=FutureWarning)
# Ignore the torch.meshgrid "indexing argument" UserWarning (matched by message regex).
warnings.filterwarnings("ignore", message=".*torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument.*", category=UserWarning)
# Only errors from the transformers library logger are shown.
hf_transformers_logging.set_verbosity_error()
# Only errors from huggingface_hub's file-download logger are shown.
logging.getLogger("huggingface_hub.file_download").setLevel(logging.ERROR)

from PIL import Image
import matplotlib.pyplot as plt
import IPython.display as ipd
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPImageProcessor
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
import librosa
import numpy as np
import torchvision.transforms as T # torchvision 임포트 추가
from tqdm import tqdm

from pretrained_model.CLAP import load_clap_model, get_clap_intermediate_patch_embeddings
from model.projector import Projector

## Step 2: Define Model Configurations

In [2]:
# Root of the local Hugging Face hub cache, resolved the same way huggingface_hub does:
# $HF_HUB_CACHE, else $HF_HOME/hub, else $XDG_CACHE_HOME/huggingface/hub, else
# ~/.cache/huggingface/hub. This replaces a hard-coded "/home/<user>/..." prefix so the
# notebook runs for other users; on the original machine it resolves to the same path.
_HF_HOME_DIR = os.path.expanduser(
    os.environ.get("HF_HOME", os.path.join(os.environ.get("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
)
_HF_HUB_CACHE_DIR = os.path.expanduser(os.environ.get("HF_HUB_CACHE", os.path.join(_HF_HOME_DIR, "hub")))


def _hf_snapshot_path(repo_id, revision):
    """Return the local hub-cache directory of ``repo_id`` pinned at snapshot ``revision``."""
    repo_folder = "models--" + repo_id.replace("/", "--")
    return os.path.join(_HF_HUB_CACHE_DIR, repo_folder, "snapshots", revision)


# Model registry: short model type -> hub id ("name_or_path") and the local snapshot
# directory whose weights load_models passes to load_checkpoint_and_dispatch.
MODEL_CONFIGS = {
    "emu2": {
        "name_or_path": "BAAI/Emu2",
        "snapshot_path": _hf_snapshot_path("BAAI/Emu2", "fa835ec101e52da5e081695107e1ddd3c7c4d88a"),
    },
    "emu2chat": {
        "name_or_path": "BAAI/Emu2-Chat",
        "snapshot_path": _hf_snapshot_path("BAAI/Emu2-Chat", "20ea30b04f8fee599cf97535e655c200df728501"),
    },
}

## Step 3: Define Function to Load Models

In [3]:
def _default_max_memory():
    """Reserve 20GiB on each of GPUs 0-3 that actually exist on this machine."""
    return {gpu_idx: '20GiB' for gpu_idx in range(min(4, torch.cuda.device_count()))}


def _most_common_device_index(device_map):
    """Return the device hosting the most modules in ``device_map`` if it is an int GPU index, else 0.

    Ties go to the device that appears first in ``device_map``'s iteration order.
    """
    from collections import Counter

    if not device_map:
        return 0
    device, _count = Counter(device_map.values()).most_common(1)[0]
    # Non-integer placements ('cpu', 'disk') cannot back a 'cuda:N' target -> fall back to GPU 0.
    return device if isinstance(device, int) else 0


def _pin_lm_head(device_map, device_idx):
    """Assign the first ``device_map`` key matching an lm_head pattern to ``device_idx`` (in place).

    Patterns are tried from most to least specific; a key matches when it contains the
    pattern and ends with ``"lm_head"``.

    Returns:
        True if a key was re-assigned, False if no lm_head entry was found.
    """
    for key_part in ("model.decoder.lm.lm_head", "decoder.lm.lm_head", "lm_head"):
        for existing_key in device_map:
            if key_part in existing_key and existing_key.endswith("lm_head"):
                device_map[existing_key] = device_idx
                return True
    return False


def load_models(model_type, projector_checkpoint_path, clap_checkpoint_path, max_memory=None):
    """Load the tokenizer, the Emu model, the CLAP audio model and the user-trained Projector.

    The Emu model is first built with empty weights. Its weights are then dispatched over
    the available GPUs with accelerate (or onto the CPU when CUDA is unavailable).

    Args:
        model_type: Key of MODEL_CONFIGS (e.g. 'emu2' or 'emu2chat').
        projector_checkpoint_path: Path to the Projector ``state_dict`` (loaded with ``torch.load``).
        clap_checkpoint_path: Checkpoint path forwarded to ``load_clap_model``.
        max_memory: Optional per-device memory budget for ``infer_auto_device_map``,
            e.g. ``{4: '20GiB', 5: '20GiB'}``. When None, 20GiB is reserved on each of
            GPUs 0-3 that exist on this machine.

    Returns:
        ``(tokenizer, emu_model, clap_model, projector_model, target_device, aux_device,
        emu_image_size)``. ``target_device`` is the device chosen for ``lm_head``. CLAP and
        the Projector are placed on ``aux_device``, which currently equals
        ``target_device``. ``emu_image_size`` comes from
        ``config.vision_config['image_size']``.

    Raises:
        ValueError: if ``model_type`` is not a key of MODEL_CONFIGS.
    """
    print(f"'{model_type}' 모델 로딩 시작...")

    if model_type not in MODEL_CONFIGS:
        raise ValueError(f"지원되지 않는 모델 타입입니다: {model_type}. 사용 가능: {list(MODEL_CONFIGS.keys())}")

    selected_model_config = MODEL_CONFIGS[model_type]
    model_name_or_path = selected_model_config["name_or_path"]
    snapshot_path = selected_model_config["snapshot_path"]

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    print(f"Tokenizer ({model_name_or_path}) 로드 완료.")

    # Build the module structure only (no weight allocation); weights are dispatched below.
    with init_empty_weights():
        emu_model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        )
    print(f"Emu 모델 ({model_name_or_path}) 구조 로드 완료.")

    if torch.cuda.is_available():
        max_memory_setting = max_memory if max_memory is not None else _default_max_memory()
        device_map = infer_auto_device_map(
            emu_model,
            max_memory=max_memory_setting,
            no_split_module_classes=['Block', 'LlamaDecoderLayer'],
        )
        # Put lm_head on the GPU that already holds most modules; that GPU also hosts CLAP/Projector.
        main_device_idx_for_lm_head = _most_common_device_index(device_map)
        if not _pin_lm_head(device_map, main_device_idx_for_lm_head):
            print("[경고] Device map에서 lm_head를 자동으로 찾지 못했습니다. 수동 조정이 필요할 수 있습니다.")

        target_device = f'cuda:{main_device_idx_for_lm_head}'
        aux_device = target_device
    else:
        device_map = {"": "cpu"}
        target_device = 'cpu'
        aux_device = 'cpu'

    with tqdm(total=1, desc=f"Emu 모델 ({model_name_or_path}) 가중치 로딩 중", unit="op") as pbar:
        emu_model_loaded = load_checkpoint_and_dispatch(
            emu_model,
            snapshot_path,
            device_map=device_map
        ).eval()
        pbar.update(1)
    print(f'Emu 모델 ({model_name_or_path}) 로드 완료!')

    clap_model = load_clap_model(checkpoint_path=clap_checkpoint_path, device=aux_device)
    print(f'CLAP 모델 로드 완료 (device: {aux_device})')

    user_projector_model = Projector(
        input_patch_dim=512,
        num_input_patches=256,
        output_seq_len=256,
        output_embed_dim=1792,
        projector_transformer_hidden_dim=768,
        projector_num_transformer_layers=8,
        projector_num_heads=8,
        projector_dropout=0.1
    ).to(aux_device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True on torch versions that support it).
    user_projector_model.load_state_dict(torch.load(projector_checkpoint_path, map_location=aux_device))
    user_projector_model.eval()
    print(f'사용자 Projector 모델 로드 완료 (가중치: {projector_checkpoint_path}, device: {aux_device})')

    print("모든 모델 로딩 완료.")
    return tokenizer, emu_model_loaded, clap_model, user_projector_model, target_device, aux_device, emu_model_loaded.config.vision_config['image_size']

## Step 4: Preprocessing the Data for Each Modality

### Step 4.1: Preprocessing Audio for CLAP

In [4]:
def preprocess_audio_for_clap(audio_path, target_sr=48000, target_duration_sec=10):
    """Load an audio file as a fixed-length waveform tensor for CLAP.

    The file is loaded at its native sample rate with ``librosa.load`` (mono by default)
    and resampled to ``target_sr`` if needed. It is then zero-padded at the end, or
    truncated, to exactly ``target_duration_sec`` seconds.

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``.
        target_sr: Output sample rate in Hz.
        target_duration_sec: Output duration in seconds. May be fractional; the sample
            count is rounded to the nearest integer.

    Returns:
        ``{'waveform': float32 tensor of shape (1, num_samples), 'sample_rate': target_sr}``,
        or None if anything fails (the error is printed).
    """
    try:
        waveform, sr = librosa.load(audio_path, sr=None)
        if sr != target_sr:
            waveform = librosa.resample(waveform, orig_sr=sr, target_sr=target_sr)

        # Must be an int: a float length breaks the slice below (and np.pad's widths).
        target_length = int(round(target_sr * target_duration_sec))
        current_length = waveform.shape[0]

        if current_length < target_length:
            waveform = np.pad(waveform, (0, target_length - current_length), 'constant')
        elif current_length > target_length:
            waveform = waveform[:target_length]

        return {'waveform': torch.tensor(waveform, dtype=torch.float32).unsqueeze(0), 'sample_rate': target_sr}
    except Exception as e:
        print(f"오디오 파일 처리 오류 {audio_path}: {e}")
        return None

### Step 4.2: Preprocessing Images for Emu's Visual Encoder (CLIP Normalization)

In [5]:
# Normalization constants for Emu's image input (see prepare_image_input in modeling_emu.py).
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

def preprocess_image_for_emu(image_pil: Image.Image, image_size: int):
    """Turn a PIL image into the batched, normalized tensor Emu's visual encoder takes.

    Follows ``prepare_image_input`` in modeling_emu.py. The image is resized with bicubic
    interpolation to a square of ``image_size`` x ``image_size``, converted to a tensor,
    and normalized with the OpenAI CLIP mean/std.

    Args:
        image_pil: The image to preprocess.
        image_size: Target side length (taken from the Emu model's ``vision_config``).

    Returns:
        A tensor with a leading batch dimension: ``[1, C, image_size, image_size]``.
    """
    steps = [
        T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(OPENAI_DATASET_MEAN, OPENAI_DATASET_STD),
    ]
    pipeline = T.Compose(steps)
    single_image = pipeline(image_pil)
    # Index with None to prepend the batch axis -> [1, C, H, W].
    return single_image[None, ...]

## Step 5: Define Function to Run Inference

### Step 5.1: Define a Function that Returns the Intermediate Modality Embedding

In [6]:
def get_intermediate_modality_embedding(input_type, file_path, emu_model, clap_model, user_projector_model, aux_device, emu_image_size):
    """Encode an image or audio file into the intermediate embedding that Emu consumes.

    - ``'image'``: preprocess with ``preprocess_image_for_emu`` and encode with
      ``emu_model.model.encode_image``.
    - ``'audio'``: preprocess with ``preprocess_audio_for_clap``, extract CLAP patch
      embeddings with ``get_clap_intermediate_patch_embeddings``, then map them with
      ``user_projector_model``. The CLAP output is moved to the projector's device and dtype first.

    Both branches run under ``torch.no_grad()``.

    Args:
        input_type: ``'image'`` or ``'audio'``.
        file_path: Path to the input file.
        emu_model: Loaded Emu model (used by the image branch).
        clap_model: Loaded CLAP model (used by the audio branch).
        user_projector_model: Trained Projector (used by the audio branch).
        aux_device: Device the preprocessed inputs are moved to.
        emu_image_size: Square image side length for Emu preprocessing (image branch).

    Returns:
        The embedding tensor. Returns None if the file does not exist, ``input_type`` is
        unsupported, or any step fails; the error and traceback are printed.
    """
    import traceback

    if not os.path.exists(file_path):
        print(f"입력 파일 {file_path}를 찾을 수 없습니다.")
        return None

    if input_type == 'image':
        try:
            # The context manager closes the underlying file once the RGB copy exists.
            with Image.open(file_path) as image_file:
                image_pil = image_file.convert('RGB')
            pixel_values = preprocess_image_for_emu(image_pil, emu_image_size)

            # ``dtype`` is a method on some Emu wrappers and a property on HF models.
            target_dtype = emu_model.dtype() if callable(emu_model.dtype) else emu_model.dtype
            pixel_values = pixel_values.to(device=aux_device, dtype=target_dtype)

            encode_image = getattr(getattr(emu_model, 'model', None), 'encode_image', None)
            if not callable(encode_image):
                error_msg = "Emu 모델 또는 그 내부 'model' 객체에서 'encode_image' 메소드를 찾을 수 없습니다."
                print(f"오류: {error_msg}")
                raise RuntimeError(error_msg)

            # Same no_grad policy as the audio branch: no autograd graph is kept.
            with torch.no_grad():
                return encode_image(pixel_values)

        except Exception as e:
            print(f"이미지 파일 {file_path} 처리 중 오류 (get_intermediate_modality_embedding): {e}")
            traceback.print_exc()
            return None

    elif input_type == 'audio':
        try:
            processed_audio_data = preprocess_audio_for_clap(file_path)
            if processed_audio_data is None:
                return None

            waveform = processed_audio_data['waveform'].to(aux_device)

            with torch.no_grad():
                audio_embeddings_clap = get_clap_intermediate_patch_embeddings(
                    clap_model_instance=clap_model,
                    audio_waveforms=waveform,
                    device=aux_device
                )
                if audio_embeddings_clap is None or audio_embeddings_clap.shape[0] == 0:
                    print("CLAP 모델에서 오디오 임베딩을 추출하지 못했습니다.")
                    return None

                # Match the projector's own device/dtype (it may differ from aux_device).
                projector_param = next(user_projector_model.parameters())
                return user_projector_model(
                    audio_embeddings_clap.to(device=projector_param.device, dtype=projector_param.dtype)
                )

        except Exception as e:
            print(f"오디오 파일 {file_path} 처리 중 오류 (get_intermediate_modality_embedding): {e}")
            traceback.print_exc()
            return None

    print(f"지원되지 않는 입력 타입: {input_type}")
    return None

In [7]:
def run_inference(input_type, file_path, query, max_new_tokens, length_penalty, tokenizer, emu_model, clap_model, projector_model, emu_target_device, aux_device, emu_image_size):
    """Generate text for an image or audio file by injecting its embedding into Emu's prompt.

    Steps:
    1. Get the modality embedding with ``get_intermediate_modality_embedding``.
    2. Map it with ``emu_model.project_up``.
    3. Overwrite the embeddings of the ``<image>`` placeholder tokens in the prompt with
       those features.
    4. Beam-search decode (5 beams) with the inner LM's ``generate``.

    Args:
        input_type: ``'image'`` or ``'audio'``.
        file_path: Path to the input file.
        query: Prompt text; when empty a default ``[<IMG_PLH>]Describe the ...`` prompt is used.
        max_new_tokens: Generation length limit.
        length_penalty: Beam-search length penalty.
        tokenizer, emu_model, clap_model, projector_model: Objects returned by ``load_models``.
        emu_target_device: Device the prompt ids / attention mask are first moved to.
        aux_device: Device used for modality preprocessing.
        emu_image_size: Image size for Emu preprocessing.

    Returns:
        The first decoded output string. Returns None if the file is missing, embedding
        extraction fails, or generation raises; errors are printed.
    """
    import traceback

    if not os.path.exists(file_path):
        print(f"입력 파일 {file_path}를 찾을 수 없습니다.")
        return None

    DEFAULT_IMAGE_TOKEN = "<image>"
    query_text = query
    outputs = None

    with torch.no_grad():
        try:
            intermediate_modality_embed = get_intermediate_modality_embedding(
                input_type,
                file_path,
                emu_model,
                clap_model,
                projector_model,
                aux_device,
                emu_image_size
            )
            if intermediate_modality_embed is None:
                print("중간 모달리티 임베딩 추출 실패.")
                return None

            # ``dtype`` is a method on some Emu wrappers and a property on HF models.
            emu_dtype = emu_model.dtype() if callable(emu_model.dtype) else emu_model.dtype
            project_up = emu_model.project_up
            modality_features_projected = project_up(
                intermediate_modality_embed.to(device=project_up.weight.device, dtype=emu_dtype)
            )

            if not query_text:
                query_text = '[<IMG_PLH>]Describe the image in details:' if input_type == 'image' else '[<IMG_PLH>]Describe the audio in details:'

            inputs_for_text = emu_model.build_input_ids(
                text=[query_text],
                tokenizer=tokenizer,
                image=None
            )
            input_ids = inputs_for_text["input_ids"].to(emu_target_device)
            attention_mask = inputs_for_text["attention_mask"].to(emu_target_device)

            text_embedding_layer = emu_model.model.decoder.lm.model.embed_tokens
            text_embeds = text_embedding_layer(input_ids.to(text_embedding_layer.weight.device))

            # An out-of-vocabulary token resolves to the unk id; matching on it would
            # overwrite unrelated unk positions, so fail loudly instead.
            image_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
            unk_token_id = getattr(tokenizer, 'unk_token_id', None)
            if image_token_id is None or (unk_token_id is not None and image_token_id == unk_token_id):
                raise RuntimeError(f"오류 ({input_type}): 토크나이저에서 <image> 토큰 ID를 찾을 수 없습니다.")

            n_query = emu_model.n_query
            batch_idx = 0
            image_token_positions = (input_ids[batch_idx] == image_token_id).nonzero(as_tuple=True)[0]
            if len(image_token_positions) == 0:
                raise RuntimeError(f"오류 ({input_type}): <image> 토큰을 찾지 못하여 특징을 주입할 수 없습니다.")

            single_modality_feature_sequence = modality_features_projected.squeeze(0)
            num_tokens_to_replace = min(len(image_token_positions), n_query, single_modality_feature_sequence.shape[0])
            if len(image_token_positions) != n_query or single_modality_feature_sequence.shape[0] != n_query:
                print(f"경고 ({input_type}): <image> 토큰 수({len(image_token_positions)}) 또는 피쳐 시퀀스 길이({single_modality_feature_sequence.shape[0]})가 n_query({n_query})와 다릅니다. {num_tokens_to_replace}개 주입.")

            # Vectorized replacement of the first num_tokens_to_replace placeholder embeddings.
            target_positions = image_token_positions[:num_tokens_to_replace].to(text_embeds.device)
            text_embeds[batch_idx, target_positions] = single_modality_feature_sequence[:num_tokens_to_replace].to(
                device=text_embeds.device, dtype=text_embeds.dtype
            )

            generation_params = {
                "inputs_embeds": text_embeds,
                "attention_mask": attention_mask.to(text_embeds.device),
                "max_new_tokens": max_new_tokens,
                "length_penalty": length_penalty,
                "num_beams": 5,
                "min_length": 1,
                "do_sample": False,
                "penalty_alpha": None,
                "top_p": None,
                "top_k": None,
                "temperature": None,
                "repetition_penalty": 1.0,
            }
            outputs = emu_model.model.decoder.lm.generate(**generation_params)

        except Exception as e:
            print(f"{input_type} 처리 또는 생성 중 오류 발생: {e}")
            traceback.print_exc()
            outputs = None

    if outputs is None:
        print("텍스트 생성에 실패했습니다 (outputs가 None입니다).")
        return None

    output_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return output_text[0]

## Step 6: Execution

In [8]:
# Which Emu variant to load, plus the trained Projector and CLAP checkpoints.
MODEL_TYPE_TO_LOAD = 'emu2chat'
PROJECTOR_CHECKPOINT_PATH_FOR_LOAD = '/home/jongmin/reference/Emu/Emu2/checkpoint/exp_projector_transformer/projector_epoch_028.pt'
CLAP_CHECKPOINT_PATH_FOR_LOAD = './music_speech_audioset_epoch_15_esc_89.98.pt'

# load_models returns a 7-tuple; unpack it into the globals used by the cells below.
(
    tokenizer_instance,
    emu_core_instance,
    clap_instance,
    user_projector_instance,
    llm_device_instance,
    aux_device_instance,
    emu_image_size_instance,
) = load_models(
    model_type=MODEL_TYPE_TO_LOAD,
    projector_checkpoint_path=PROJECTOR_CHECKPOINT_PATH_FOR_LOAD,
    clap_checkpoint_path=CLAP_CHECKPOINT_PATH_FOR_LOAD,
)

'emu2chat' 모델 로딩 시작...
Tokenizer (BAAI/Emu2-Chat) 로드 완료.


Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

Emu 모델 (BAAI/Emu2-Chat) 구조 로드 완료.


Emu 모델 (BAAI/Emu2-Chat) 가중치 로딩 중: 100%|██████████| 1/1 [02:04<00:00, 124.61s/op]


Emu 모델 (BAAI/Emu2-Chat) 로드 완료!
Load the specified checkpoint ./music_speech_audioset_epoch_15_esc_89.98.pt from users.
Load Checkpoint...
CLAP audio model loaded from ./music_speech_audioset_epoch_15_esc_89.98.pt and moved to cuda:1.
CLAP 모델 로드 완료 (device: cuda:1)




사용자 Projector 모델 로드 완료 (가중치: /home/jongmin/reference/Emu/Emu2/checkpoint/exp_projector_transformer/projector_epoch_028.pt, device: cuda:1)
모든 모델 로딩 완료.


In [9]:
# Single-file sanity check: run one inference and print the generated text.
INPUT_TYPE_INFERENCE = 'audio'  # 'image' or 'audio'
# Other example inputs:
#   '/home/jongmin/reference/Emu/Emu2/examples/aud_example.wav'
#   '/home/jongmin/reference/Emu/Emu2/examples/img_example.jpg'  (with INPUT_TYPE_INFERENCE = 'image')
FILE_PATH_INFERENCE = '/mnt/lynx1/datasets/places205/wavs/238/utterance_60731.wav'
QUERY_TEXT_INFERENCE = '[<IMG_PLH>] Explain specifically what you can see in the image:'

MAX_NEW_TOKENS_INFERENCE = 64
LENGTH_PENALTY_INFERENCE = -1.0

print(f"추론 실행 시작: 입력 타입='{INPUT_TYPE_INFERENCE}', 파일='{FILE_PATH_INFERENCE}'")
output = run_inference(
    # what to run on
    input_type=INPUT_TYPE_INFERENCE,
    file_path=FILE_PATH_INFERENCE,
    query=QUERY_TEXT_INFERENCE,
    # decoding settings
    max_new_tokens=MAX_NEW_TOKENS_INFERENCE,
    length_penalty=LENGTH_PENALTY_INFERENCE,
    # models and devices produced by load_models
    tokenizer=tokenizer_instance,
    emu_model=emu_core_instance,
    clap_model=clap_instance,
    projector_model=user_projector_instance,
    emu_target_device=llm_device_instance,
    aux_device=aux_device_instance,
    emu_image_size=emu_image_size_instance,
)
print(output)

추론 실행 시작: 입력 타입='audio', 파일='/mnt/lynx1/datasets/places205/wavs/238/utterance_60731.wav'
A river running through a lush green forest.


In [10]:
def _load_test_samples(test_json_path):
    """Return the list of dict samples under ``'data'`` in ``test_json_path``, or None (error printed)."""
    try:
        with open(test_json_path, 'r', encoding='utf-8') as f:
            content = json.load(f)
    except FileNotFoundError:
        print(f"오류: '{test_json_path}' 파일을 찾을 수 없습니다.")
        return None
    except json.JSONDecodeError:
        print(f"오류: '{test_json_path}' 파일을 파싱하는 중 오류가 발생했습니다.")
        return None

    # A top-level list or a missing 'data' key used to crash with AttributeError/TypeError.
    samples = content.get('data') if isinstance(content, dict) else None
    if not isinstance(samples, list):
        print(f"오류: '{test_json_path}' 파일에 'data' 리스트가 없습니다.")
        return None
    if not all(isinstance(s, dict) for s in samples):
        print(f"오류: '{test_json_path}' 파일의 형식이 잘못되었습니다. 각 항목은 딕셔너리여야 합니다.")
        return None
    return samples


def _infer_single_sample(sample, audio_base_path, inference_kwargs):
    """Run ``run_inference`` for one sample; return its text or an ``"Error: ..."`` string."""
    import traceback

    wav_relative_path = sample.get('wav')
    if not wav_relative_path:
        print(f"경고: 샘플에 'wav' 키가 없습니다. 건너뜁니다: {sample}")
        return "Error: 'wav' key missing in sample"
    if not isinstance(wav_relative_path, str):
        print(f"경고: 'wav' 경로가 문자열이 아닙니다. 건너뜁니다: {sample}")
        return "Error: 'wav' path is not a string"

    file_path_inference = os.path.join(audio_base_path, wav_relative_path)
    if not os.path.exists(file_path_inference):
        print(f"경고: 오디오 파일을 찾을 수 없습니다: {file_path_inference}. 건너뜁니다.")
        return f"Error: Audio file not found at {file_path_inference}"

    try:
        output_text = run_inference(file_path=file_path_inference, **inference_kwargs)
    except Exception as e:
        print(f"run_inference 실행 중 오류 발생 (파일: {file_path_inference}): {e}")
        traceback.print_exc()
        return f"Error: Exception during inference - {str(e)}"

    # run_inference returns None when it handled an internal failure itself.
    if output_text is None:
        print(f"추론 실패 (run_inference가 None을 반환) 파일: {file_path_inference}")
        return "Error: Inference failed (run_inference returned None)"
    return output_text


def process_test_data_and_save_results(
    test_json_path,
    result_json_path,
    audio_base_path,
    input_type,
    query_text,
    max_new_tokens,
    length_penalty,
    tokenizer,
    emu_model,
    clap_model,
    projector_model,
    emu_target_device,
    aux_device,
    emu_image_size
):
    """Run inference for every sample of a test JSON file and write the results to JSON.

    The input file must look like ``{"data": [ {...}, ... ]}``. Each sample's ``'wav'``
    path is joined to ``audio_base_path``. Every sample is copied into the output with an
    added ``'output_text'`` field, which holds either the generated text or an
    ``"Error: ..."`` description. The list is written to ``result_json_path``
    (UTF-8, indent 4).

    The remaining arguments are forwarded to ``run_inference``.

    Returns:
        None. Problems with the input file are printed and nothing is written.
    """
    test_samples = _load_test_samples(test_json_path)
    if test_samples is None:
        return

    inference_kwargs = dict(
        input_type=input_type,
        query=query_text,
        max_new_tokens=max_new_tokens,
        length_penalty=length_penalty,
        tokenizer=tokenizer,
        emu_model=emu_model,
        clap_model=clap_model,
        projector_model=projector_model,
        emu_target_device=emu_target_device,
        aux_device=aux_device,
        emu_image_size=emu_image_size,
    )

    print(f"'{test_json_path}'에서 샘플을 읽어 추론을 시작합니다...")

    results_data = []
    for sample in tqdm(test_samples, desc="추론 진행"):
        current_result = sample.copy()  # keep every original field of the sample
        current_result['output_text'] = _infer_single_sample(sample, audio_base_path, inference_kwargs)
        results_data.append(current_result)

    try:
        with open(result_json_path, 'w', encoding='utf-8') as f:
            json.dump(results_data, f, indent=4, ensure_ascii=False)
        print(f"추론 결과가 '{result_json_path}' 파일에 성공적으로 저장되었습니다.")
    except OSError:
        print(f"오류: '{result_json_path}' 파일을 쓰는 중 오류가 발생했습니다.")
    except Exception as e:
        print(f"결과 저장 중 예기치 않은 오류 발생: {e}")

In [11]:
# Batch evaluation: run every sample of TEST_JSON_PATH and save the outputs to RESULT_JSON_PATH.
AUDIO_BASE_PATH = '/mnt/lynx1/datasets/places205/'
IMAGE_BASE_PATH = '/mnt/lynx1/datasets/places205/vision/torralba/deeplearning/images256/'  # not used by this cell
TEST_JSON_PATH = 'test.json'
RESULT_JSON_PATH = 'result_ver2_epoch028.json'

INPUT_TYPE_INFERENCE = 'audio'
QUERY_TEXT_INFERENCE = '[<IMG_PLH>] Describe the image in details:'
MAX_NEW_TOKENS_INFERENCE = 64
LENGTH_PENALTY_INFERENCE = -1.0

process_test_data_and_save_results(
    test_json_path=TEST_JSON_PATH,
    result_json_path=RESULT_JSON_PATH,
    audio_base_path=AUDIO_BASE_PATH,
    input_type=INPUT_TYPE_INFERENCE,
    query_text=QUERY_TEXT_INFERENCE,
    max_new_tokens=MAX_NEW_TOKENS_INFERENCE,
    length_penalty=LENGTH_PENALTY_INFERENCE,
    tokenizer=tokenizer_instance,
    emu_model=emu_core_instance,
    clap_model=clap_instance,
    projector_model=user_projector_instance,
    emu_target_device=llm_device_instance,
    aux_device=aux_device_instance,
    emu_image_size=emu_image_size_instance,
)

'test.json'에서 샘플을 읽어 추론을 시작합니다...


추론 진행: 100%|██████████| 100/100 [04:42<00:00,  2.83s/it]

추론 결과가 'result_ver2_epoch028.json' 파일에 성공적으로 저장되었습니다.





In [None]:
# Result file to browse.
# NOTE(review): the batch-evaluation cell writes 'result_ver2_epoch028.json'; change this
# to inspect that run.
RESULT_JSON_TO_LOAD = 'result.json'
AUDIO_BASE_PATH_FOR_DISPLAY = '/mnt/lynx1/datasets/places205/'
IMAGE_BASE_PATH_FOR_DISPLAY = '/mnt/lynx1/datasets/places205/vision/torralba/deeplearning/images256/'


def _display_image_entry(data_item, image_base):
    """Plot the image at ``data_item['image']`` (relative to ``image_base``), or print why it cannot."""
    image_relative_path = data_item.get('image')
    if image_relative_path and isinstance(image_relative_path, str):
        image_full_path = os.path.join(image_base, image_relative_path)
        try:
            if os.path.exists(image_full_path):
                # The context manager closes the file after plotting.
                with Image.open(image_full_path) as img:
                    fig, ax = plt.subplots(figsize=(5, 5))
                    ax.imshow(img)
                    ax.set_title(f"Image: {os.path.basename(image_full_path)}")
                    ax.axis('off')
                    plt.show()
            else:
                print(f"이미지 파일을 찾을 수 없습니다: {image_full_path}")
        except Exception as e:
            print(f"이미지 로드 중 오류 발생 ({image_full_path}): {e}")
    elif 'image' not in data_item:
        print("데이터에 'image' 키가 없습니다.")
    else:
        print(f"'image' 경로가 유효하지 않습니다: {image_relative_path}")


def _display_audio_entry(data_item, audio_base):
    """Show an audio player for ``data_item['wav']`` (relative to ``audio_base``), or print why it cannot."""
    audio_relative_path = data_item.get('wav')
    if audio_relative_path and isinstance(audio_relative_path, str):
        audio_full_path = os.path.join(audio_base, audio_relative_path)
        try:
            if os.path.exists(audio_full_path):
                print(f"Audio: {os.path.basename(audio_full_path)}")
                ipd.display(ipd.Audio(audio_full_path))
            else:
                print(f"오디오 파일을 찾을 수 없습니다: {audio_full_path}")
        except Exception as e:
            print(f"오디오 로드/재생 중 오류 발생 ({audio_full_path}): {e}")
    elif 'wav' not in data_item:
        print("데이터에 'wav' 키가 없습니다.")
    else:
        print(f"'wav' 경로가 유효하지 않습니다: {audio_relative_path}")


def display_result_by_index(result_file_path, index_to_display, audio_base, image_base):
    """Show one entry of a result JSON file: its image, an audio player, and ASR vs. generated text.

    Args:
        result_file_path: JSON file containing a list of result dicts.
        index_to_display: 1-based index of the entry to show.
        audio_base: Directory that the entry's ``'wav'`` path is relative to.
        image_base: Directory that the entry's ``'image'`` path is relative to.

    Returns:
        None. Problems (missing file, bad format, out-of-range index) are printed.
    """
    try:
        with open(result_file_path, 'r', encoding='utf-8') as f:
            results = json.load(f)
    except FileNotFoundError:
        print(f"오류: '{result_file_path}' 파일을 찾을 수 없습니다.")
        return
    except json.JSONDecodeError:
        print(f"오류: '{result_file_path}' 파일을 파싱하는 중 오류가 발생했습니다.")
        return

    if not isinstance(results, list) or not results:
        print(f"'{result_file_path}' 파일이 비어있거나 잘못된 형식입니다.")
        return

    # The index is 1-based (as requested by the user).
    actual_index = index_to_display - 1
    if not 0 <= actual_index < len(results):
        print(f"오류: 인덱스 {index_to_display}는 유효한 범위(1 ~ {len(results)})를 벗어났습니다.")
        return

    data_item = results[actual_index]
    # A non-dict entry would otherwise crash on .get below.
    if not isinstance(data_item, dict):
        print(f"오류: 인덱스 {index_to_display}의 항목이 딕셔너리가 아닙니다: {data_item}")
        return

    print(f"--- 데이터 인덱스: {index_to_display} ---")

    _display_image_entry(data_item, image_base)
    _display_audio_entry(data_item, audio_base)

    # Compare the ASR transcript with the generated text.
    asr_text = data_item.get('asr_text', "N/A (asr_text 없음)")
    output_text = data_item.get('output_text', "N/A (output_text 없음)")

    print("\nASR Text (원문):")
    print(f"  {asr_text}")
    print("\nOutput Text (추론 결과):")
    print(f"  {output_text}")
    print("--- --- ---")

In [None]:
# Show the 3rd entry (1-based) of the chosen result file.
display_result_by_index(
    result_file_path=RESULT_JSON_TO_LOAD,
    index_to_display=3,
    audio_base=AUDIO_BASE_PATH_FOR_DISPLAY,
    image_base=IMAGE_BASE_PATH_FOR_DISPLAY,
)

# Section 2. Comparing Audio and Image Embeddings

## Step 1: Import Modules and Prepare Data

In [None]:
import json
from tqdm import tqdm
import traceback # 오류 로깅용

# Dataset of paired audio/image entries used for the embedding comparison.
JSON_FILE_PATH = '/home/jongmin/reference/Emu/Emu2/train1k.json'

# Base directories that each entry's relative 'wav' / 'image' paths are joined to.
DEFAULT_AUDIO_BASE_PATH = '/mnt/lynx1/datasets/places205/'
DEFAULT_IMAGE_BASE_PATH = '/mnt/lynx1/datasets/places205/vision/torralba/deeplearning/images256/'

# Accumulators for the pooled per-sample embeddings.
audio_emb_list, image_emb_list = [], []

## Step 2: Computing Embeddings for Both Modalities

In [None]:
# Extract one max-pooled embedding (max over dim 1) per modality for each JSON entry.
# - Both lists are reset so re-running this cell does not append duplicates.
# - A pair is stored only when BOTH modalities succeed, so audio_emb_list[i] and
#   image_emb_list[i] always come from the same entry.
audio_emb_list = []
image_emb_list = []

if not os.path.exists(JSON_FILE_PATH):
    print(f"오류: JSON 파일 '{JSON_FILE_PATH}'을(를) 찾을 수 없습니다. 스크립트를 종료합니다.")
else:
    with open(JSON_FILE_PATH, 'r', encoding='utf-8') as f:
        dataset_json_content = json.load(f)

    data_entries = dataset_json_content.get('data', [])
    current_audio_base_path = DEFAULT_AUDIO_BASE_PATH
    current_image_base_path = DEFAULT_IMAGE_BASE_PATH

    if 'emu_core_instance' in globals():
        emu_core_instance.eval()
    if 'clap_instance' in globals():
        clap_instance.eval()
    if 'user_projector_instance' in globals():
        user_projector_instance.eval()

    for entry in tqdm(data_entries, desc="데이터셋 임베딩 추출 중"):
        wav_relative_path = entry.get('wav')
        image_relative_path = entry.get('image')
        # os.path.join raises TypeError on a missing (None) path, which used to abort the whole loop.
        if not isinstance(wav_relative_path, str) or not isinstance(image_relative_path, str):
            print(f"경고: 'wav' 또는 'image' 경로가 유효하지 않아 건너뜁니다: {entry}")
            continue

        audio_full_path = os.path.join(current_audio_base_path, wav_relative_path)
        image_full_path = os.path.join(current_image_base_path, image_relative_path)

        # --- Image embedding ---
        img_embedding_to_store = None
        try:
            with torch.no_grad():
                img_intermediate_embed = get_intermediate_modality_embedding(
                    input_type='image',
                    file_path=image_full_path,
                    emu_model=emu_core_instance,
                    clap_model=None,
                    user_projector_model=None,
                    aux_device=aux_device_instance,
                    emu_image_size=emu_image_size_instance,
                )
            if img_intermediate_embed is not None:
                img_embedding_to_store = torch.max(img_intermediate_embed, dim=1).values.squeeze(0).detach().cpu()
                del img_intermediate_embed
            else:
                print(f"이미지 중간 임베딩 추출 실패: {image_full_path}")
        except Exception as e:
            print(f"이미지 처리 중 오류 발생 ({image_full_path}): {e}")
            traceback.print_exc()

        # --- Audio embedding ---
        audio_embedding_to_store = None
        try:
            with torch.no_grad():
                audio_intermediate_embed = get_intermediate_modality_embedding(
                    input_type='audio',
                    file_path=audio_full_path,
                    emu_model=None,
                    clap_model=clap_instance,
                    user_projector_model=user_projector_instance,
                    aux_device=aux_device_instance,
                    emu_image_size=None,
                )
            if audio_intermediate_embed is not None:
                audio_embedding_to_store = torch.max(audio_intermediate_embed, dim=1).values.squeeze(0).detach().cpu()
                del audio_intermediate_embed
            else:
                print(f"오디오 중간 임베딩 추출 실패: {audio_full_path}")
        except Exception as e:
            print(f"오디오 처리 중 오류 발생 ({audio_full_path}): {e}")
            traceback.print_exc()

        # Keep the two lists index-aligned: store the pair only if both succeeded.
        if img_embedding_to_store is not None and audio_embedding_to_store is not None:
            image_emb_list.append(img_embedding_to_store)
            audio_emb_list.append(audio_embedding_to_store)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print(f"\n--- 임베딩 추출 완료 ---")
print(f"총 추출된 이미지 임베딩 수: {len(image_emb_list)}")
print(f"총 추출된 오디오 임베딩 수: {len(audio_emb_list)}")

if image_emb_list:
    print(f"첫 번째 이미지 임베딩 형태: {image_emb_list[0].shape}, dtype: {image_emb_list[0].dtype} (on CPU)")
else:
    print("이미지 임베딩이 추출되지 않았습니다.")

if audio_emb_list:
    print(f"첫 번째 오디오 임베딩 형태: {audio_emb_list[0].shape}, dtype: {audio_emb_list[0].dtype} (on CPU)")
else:
    print("오디오 임베딩이 추출되지 않았습니다.")
        


In [None]:
# Stack the per-sample vectors along a new leading sample dimension.
# torch.stack on an empty list gives a cryptic error, so report the likely cause instead.
if not audio_emb_list or not image_emb_list:
    raise ValueError(
        f"임베딩 리스트가 비어 있습니다 (audio: {len(audio_emb_list)}, image: {len(image_emb_list)}). "
        "임베딩 추출 셀을 먼저 실행하세요."
    )

audio_embeddings_tensor = torch.stack(audio_emb_list, dim=0)
image_embeddings_tensor = torch.stack(image_emb_list, dim=0)

print(audio_embeddings_tensor.shape)
print(image_embeddings_tensor.shape)

In [None]:
from sklearn import datasets
from mpl_toolkits.mplot3d import Axes3D

from DOSNES.dosnes.dosnes import DOSNES

In [None]:
# DOSNES takes one float matrix: cast BOTH modalities to float32 before .numpy()
# (numpy has no bfloat16, so .numpy() on a bf16 tensor raises). Audio rows go first,
# image rows after.
all_embeddings = np.concatenate([
    audio_embeddings_tensor.to(torch.float32).numpy(),
    image_embeddings_tensor.to(torch.float32).numpy(),
], axis=0)

# Label 0 = audio, 1 = image, in the same row order as all_embeddings.
labels = np.array([0] * len(audio_embeddings_tensor) + [1] * len(image_embeddings_tensor))

dosnes = DOSNES(metric="cosine", verbose=1, random_state=42)
embedded = dosnes.fit_transform(all_embeddings, y=labels, filename="dosnes_result_CS_Divergence.gif")

In [None]:
# Dump the raw per-sample embedding lists: audio first, then image.
for emb_list in (audio_emb_list, image_emb_list):
    print(emb_list)

In [None]:
# Sanity check on one sample: squared L2 norm (dot product with itself) of its image
# and audio embeddings. Computed in float32 because the image embedding may be bfloat16
# (Emu is loaded with torch_dtype=torch.bfloat16), where summing many squares loses precision.
SAMPLE_INDEX = 10
image_vec = image_emb_list[SAMPLE_INDEX].to(torch.float32)
audio_vec = audio_emb_list[SAMPLE_INDEX].to(torch.float32)
print(torch.dot(image_vec, image_vec))
print(torch.dot(audio_vec, audio_vec))

In [None]:
# Print the CLAP model's module structure (str() of the module).
print(str(clap_instance))