In [9]:
import json
import os
import torch
import torchaudio
from torch.utils.data import Dataset

class AudioTextDataset(Dataset):
    def __init__(self, json_path):
        self.dirname = os.path.dirname(json_path)
        with open(json_path, "r") as f:
            self.data = json.load(f)["annotation"]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]
        audio_path = os.path.join(self.dirname, entry["path"])
        text = entry["text"]

        try:
            # torchaudio를 사용하여 로드 및 리샘플링 (mono 변환)
            waveform, sr = torchaudio.load(audio_path)
            waveform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=48000)(waveform)
            waveform = torch.mean(waveform, dim=0)  # 스테레오 → 모노 변환
        except Exception as e:
            print(f"오디오 로드 실패: {audio_path} - {str(e)}")
            return None  # 실패한 경우 None 반환

        return {"waveform": waveform, "text": text, "path": audio_path}


def collate_fn(batch):
    """DataLoader에서 None 값 제거 및 배치 변환"""
    batch = [b for b in batch if b is not None]  # None 제거
    if len(batch) == 0:
        return None  # 빈 배치 방지

    waveforms = [b["waveform"] for b in batch]
    texts = [b["text"] for b in batch]
    paths = [b["path"] for b in batch]

    return {"waveforms": torch.stack(waveforms), "texts": texts, "paths": paths}

In [10]:
import torch.distributed as dist

def setup(rank, world_size):
    """멀티 GPU 환경 설정"""
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    """프로세스 종료"""
    dist.destroy_process_group()

In [11]:
import json
import torch
import torch.nn.functional as F
from transformers import ClapModel, ClapProcessor
from torch.utils.data import DataLoader, DistributedSampler

# CLAP 모델 경로
MODEL_NAME = "laion/clap-htsat-unfused"

def clap_refine(rank, world_size, input_json_path, output_json_path, batch_size=8, similarity_threshold=0.5):
    """멀티 GPU를 활용한 CLAP 유사도 필터링"""
    setup(rank, world_size)

    device = torch.device(f"cuda:{rank}")
    processor = ClapProcessor.from_pretrained(MODEL_NAME)
    model = ClapModel.from_pretrained(MODEL_NAME).to(device)
    model.eval()

    dataset = AudioTextDataset(input_json_path)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, collate_fn=collate_fn, num_workers=4, pin_memory=True)

    filtered_annotations = []

    for batch in dataloader:
        if batch is None:
            continue

        waveforms = batch["waveforms"].to(device)
        texts = batch["texts"]
        paths = batch["paths"]

        with torch.no_grad():
            # 오디오 및 텍스트 임베딩 생성
            inputs = processor(audios=[w.cpu().numpy() for w in waveforms], return_tensors="pt", sampling_rate=48000)
            inputs = {key: val.to(device) for key, val in inputs.items()}

            text_inputs = processor(text=texts, return_tensors="pt")
            text_inputs = {key: val.to(device) for key, val in text_inputs.items()}

            audio_embeds = model.get_audio_features(**inputs)
            text_embeds = model.get_text_features(**text_inputs)

            similarities = F.cosine_similarity(audio_embeds, text_embeds).cpu().numpy()

        # 임계값 체크 후 저장
        for i, sim in enumerate(similarities):
            if sim >= similarity_threshold:
                filtered_annotations.append({"path": paths[i], "text": texts[i]})
                print(f"[GPU {rank}] 유지: {paths[i]} - 유사도: {sim:.4f}")
            else:
                print(f"[GPU {rank}] 제거: {paths[i]} - 유사도: {sim:.4f}")

    # 모든 GPU의 데이터를 모아서 저장
    gathered_data = [None] * world_size
    dist.all_gather_object(gathered_data, filtered_annotations)

    if rank == 0:
        all_filtered = [item for sublist in gathered_data for item in sublist]
        with open(output_json_path, "w") as f:
            json.dump({"annotation": all_filtered}, f, indent=4)
        print(f"필터링 완료! {len(all_filtered)}개 항목 저장됨")

    cleanup()

In [12]:
import torch.multiprocessing as mp

def main(input_json_path="input.json", output_json_path="output.json", batch_size=8):
    """멀티 GPU 실행"""
    world_size = torch.cuda.device_count()
    print(f"사용 가능한 GPU 수: {world_size}")

    mp.spawn(
        clap_refine,
        args=(world_size, input_json_path, output_json_path, batch_size),
        nprocs=world_size,
        join=True,
    )

main("/data/dataset/stage1_sample1.json")

사용 가능한 GPU 수: 2


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/root/miniconda3/envs/salmonn/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/root/miniconda3/envs/salmonn/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'clap_refine' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/root/miniconda3/envs/salmonn/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/root/miniconda3/envs/salmonn/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'clap_refine' on <module '__main__' (built-in)>
W0204 19:15:07.196000 588919 site-packages/torch/multiprocessing/spawn.py:160] Terminating process 616

ProcessExitedException: process 0 terminated with exit code 1