##### 수정된 코드

In [23]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import matplotlib.pyplot as plt

from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
from imagebind import data

In [24]:
class MultimodalDataset(Dataset):
    """Dataset of aligned (audio, image, text) samples loaded via ImageBind.

    Audio files are the ``.wav`` entries of ``audio_dir`` in sorted order.
    Each one is paired with an image of the same base name and a ``.png``
    extension inside ``image_dir``, and with the line of ``text_file`` found
    at the same position (line ``i`` belongs to the ``i``-th sorted audio
    file).

    Args:
        audio_dir: Directory scanned for ``.wav`` files.
        image_dir: Directory expected to contain one ``.png`` per audio file.
        text_file: Text file holding one label per line.
        transform: Stored on the instance; this class never applies it.

    Note:
        ``__getitem__`` reads the module-level ``device`` variable, so that
        name must be defined before the first sample is fetched.
    """

    _AUDIO_EXT = '.wav'
    _IMAGE_EXT = '.png'

    def __init__(self, audio_dir, image_dir, text_file, transform=None):
        self.audio_dir = audio_dir
        self.image_dir = image_dir
        self.text_file = text_file
        self.transform = transform

        self.audio_files = sorted(
            f for f in os.listdir(audio_dir) if f.endswith(self._AUDIO_EXT)
        )
        # Swap only the trailing extension. A plain str.replace would also
        # rewrite a '.wav' that appears in the middle of a file name.
        self.image_files = [
            f[:-len(self._AUDIO_EXT)] + self._IMAGE_EXT
            for f in self.audio_files
        ]

        with open(text_file, 'r', encoding='utf-8') as f:
            self.texts = [line.strip() for line in f]

    def __len__(self) -> int:
        """Return the number of audio files found in ``audio_dir``."""
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Load and transform the sample at position ``idx``.

        Returns:
            A 4-tuple ``(audio_data, image_data, text_data, names)`` where the
            first three are the outputs of the ImageBind ``data`` loaders and
            ``names`` is ``(audio_filename, image_filename, text)``.

        Raises:
            IndexError: If ``idx`` is outside the audio file list or the
                labels list.
        """
        audio_filename = self.audio_files[idx]
        image_filename = self.image_files[idx]
        text = self.texts[idx]

        audio_path = os.path.join(self.audio_dir, audio_filename)
        image_path = os.path.join(self.image_dir, image_filename)

        # `device` is resolved from module scope at call time.
        audio_data = data.load_and_transform_audio_data([audio_path], device)
        image_data = data.load_and_transform_vision_data([image_path], device)
        text_data = data.load_and_transform_text([text], device)

        # Assumes image_data may arrive as a 5D tensor [1, 3, 1, H, W];
        # TODO confirm against the loader. Drop the extra axis if so.
        if image_data.ndim > 4:
            image_data = image_data.squeeze(2)

        return audio_data, image_data, text_data, (audio_filename, image_filename, text)


##### 각 modality에 해당하는 data 준비

In [25]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load the pretrained ImageBind "huge" model, move it to the chosen device
# and switch it to evaluation mode.
model = imagebind_model.imagebind_huge(pretrained=True)
model = model.to(device)
model.eval()

# Input locations for the three modalities.
audio_dir = './modalities/audios'
image_dir = './modalities/frames_test'  # for test
text_file = './modalities/labels.txt'  # for test

dataset = MultimodalDataset(
    audio_dir=audio_dir,
    image_dir=image_dir,
    text_file=text_file,
)
# dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

In [36]:
# Pairwise similarity scores, one array per sample until flattened below.
vt_sims = []
at_sims = []
va_sims = []

# (audio_filename, image_filename, text) for every collected score, in order.
filenames_list = []

# Index the dataset explicitly instead of writing `for ... in dataset`.
# Iterating an object that only defines __getitem__ treats *any* IndexError as
# "end of data", so a labels file shorter than the audio list would silently
# cut the run short and skew the medians computed later. With explicit indices
# such a mismatch raises instead.
with torch.no_grad():
    for idx in range(len(dataset)):
        audio_batch, image_batch, text_batch, filenames = dataset[idx]

        inputs = {
            ModalityType.AUDIO: audio_batch,
            ModalityType.VISION: image_batch,
            ModalityType.TEXT: text_batch
        }

        embeddings = model(inputs)

        # Similarity = matrix product of one modality's embeddings with the
        # transpose of the other's.
        vt_sim = embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T
        at_sim = embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T
        va_sim = embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T

        # Move the tensors to the CPU before converting them to NumPy.
        vt_sims.append(vt_sim.cpu().numpy())
        at_sims.append(at_sim.cpu().numpy())
        va_sims.append(va_sim.cpu().numpy())

        # One filename record per row of the similarity matrix.
        filenames_list.extend([filenames] * len(vt_sim))

        print(f'Files: {filenames}')  # lets you check the modality pairs line up
        print('Vision x Text similarity: ', vt_sim)
        print('Audio x Text similarity: ', at_sim)
        print('Vision x Audio similarity: ', va_sim)
        print('----------------------------------------------------------\n')

# axis=None flattens every per-sample array into a single 1-D array.
vt_sims = np.concatenate(vt_sims, axis=None)
at_sims = np.concatenate(at_sims, axis=None)
va_sims = np.concatenate(va_sims, axis=None)


Files: ('dog growling_32720.wav', 'dog growling_32720.png', 'dog growling')
Vision x Text similarity:  tensor([[30.6504]], device='cuda:0')
Audio x Text similarity:  tensor([[649.1145]], device='cuda:0')
Vision x Audio similarity:  tensor([[9.3685]], device='cuda:0')
----------------------------------------------------------

Files: ('male singing_27393.wav', 'male singing_27393.png', 'male singing')
Vision x Text similarity:  tensor([[26.5837]], device='cuda:0')
Audio x Text similarity:  tensor([[363.7698]], device='cuda:0')
Vision x Audio similarity:  tensor([[8.2802]], device='cuda:0')
----------------------------------------------------------

Files: ('people babbling_88641.wav', 'people babbling_88641.png', 'people babbling')
Vision x Text similarity:  tensor([[18.0223]], device='cuda:0')
Audio x Text similarity:  tensor([[162.3833]], device='cuda:0')
Vision x Audio similarity:  tensor([[6.6675]], device='cuda:0')
----------------------------------------------------------

Files: 

In [35]:
# print(vt_sims)

[30.650362 26.583687 18.02227  23.24356  31.663048]


In [None]:
# 각 similarity의 mean, median, max 출력
# def statistics(data):
#     mean_val = np.mean(data)
#     median_val = np.median(data)
#     max_val = np.max(data)
    
#     print(f'Mean: {mean_val}, Median: {median_val}, Max: {max_val}')
    
# print('Vision x Text'), statistics(vt_sims)
# print('Audio x Text'), statistics(at_sims)
# print('Vision x Audio'), statistics(va_sims)

In [38]:
# The three similarities are not suited to an arithmetic mean, so each one is
# summarised by its own median instead.
# Samples that fall below the median on all three similarities are to be
# removed (implemented this way for now).
vt_median, at_median, va_median = (
    np.median(scores) for scores in (vt_sims, at_sims, va_sims)
)

print(
    f'visual-text median: {vt_median}',
    f'audio-text median: {at_median}',
    f'visual-audio median: {va_median}',
    sep='\n',
)

# for i, (vt, at, va) in enumerate(zip(vt_sims, at_sims, va_sims)):
#     if vt < vt_median and at < at_median and va < va_median: # 모두 넘지 못하는 경우
#         print(filenames_list[i])

visual-text median: 26.58368682861328
audio-text median: 457.72607421875
visual-audio median: 8.142138481140137
('people babbling_88641.wav', 'people babbling_88641.png', 'people babbling')


In [40]:
# To drop the unwanted audio from the full dataset later, save only the
# matching audio file names to a text file, one name per line.

score_triples = zip(vt_sims, at_sims, va_sims)
with open('/mnt/storage1/vggsoundsync/below_median_test.txt', 'w') as out_file:  # adjust the path
    for idx, (vt, at, va) in enumerate(score_triples):
        below_every_median = vt < vt_median and at < at_median and va < va_median
        if not below_every_median:
            continue  # at least one similarity reaches its median: keep the sample
        audio_filename = filenames_list[idx][0]
        out_file.write(f"{audio_filename}\n")
        print(audio_filename)

people babbling_88641.wav


##### 위에서 저장한 txt(제거 대상)를 이용해 제거

In [None]:
# Move only the audio files identified above into another directory
# (they are relocated, not permanently deleted).
import shutil
import os

source_dir = '/mnt/storage1/trainvideo_10_audios'  # check the path
target_dir = 'below'

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(target_dir, exist_ok=True)

# NOTE(review): the list is read from a relative path; confirm it points at
# the file produced by the filtering step.
with open('below_median_test.txt', 'r') as file:  # check the path
    for line in file:
        filename = line.strip()
        if not filename:
            # A blank line would make source_file resolve to source_dir
            # itself, and shutil.move would then relocate the whole directory.
            continue

        source_file = os.path.join(source_dir, filename)
        target_file = os.path.join(target_dir, filename)

        # Only regular files are moved; missing entries are skipped silently.
        if os.path.isfile(source_file):
            shutil.move(source_file, target_file)