處理健祐哥切好的 dataset

In [9]:
import torch
from torch.utils.data import Dataset
from pathlib import Path
import json
import llama
from llama import Tokenizer
import torchaudio
from torchvision import transforms
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
import copy
import logging
import whisper
from tqdm import tqdm

# Only records at ERROR level and above reach the root logger's handler.
logging.basicConfig(level=logging.ERROR)

# Root directory scanned by BigSuperbDataset: one sub-directory per task.
data_path = Path("/work/u8915687/big-superb/big-superb-train-data")
# LLaMA-7B tokenizer used to encode prompts and labels.
tokenizer = Tokenizer(
    model_path="/home/u8915687/lab/big-superb/Macaw-LLM2/weights/llama_7B/tokenizer.model"
)

In [24]:
import torch
import yaml
from torch.utils.data import Dataset
from PIL import Image
import json
import llama.utils
from llama import Tokenizer
import copy
import torchvision.transforms as transforms
import pandas as pd
import random
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
import torchaudio


try:
    from torchvision.transforms import InterpolationMode
except ImportError:
    # torchvision without InterpolationMode (presumably an older release):
    # fall back to PIL's resampling constant.
    BICUBIC = Image.BICUBIC
else:
    BICUBIC = InterpolationMode.BICUBIC

class BigSuperbDataset(Dataset):
    """Instruction-tuning dataset built from BIG-superb task folders.

    Layout read by ``__init__``::

        <root>/<task>/<split>/metadata.json
        <root>/<task>/missing_files/...        (fallback location for "file2")

    Only split directories whose stem equals ``used_data_split`` are loaded,
    and tasks rejected by ``_filter_dataset`` are skipped.

    Samples are read with ``"instruction"``, ``"label"`` and an optional
    ``"text"``. In the primary root each sample also carries ``"file"`` (and
    optionally ``"file2"``) relative to its split directory. In the optional
    secondary root, the metadata keys themselves are the audio file names.

    Args:
        data_path: ``pathlib.Path`` of the primary data root.
        tokenizer: object exposing ``encode(text, bos=..., eos=...)`` that
            returns a sequence of token ids.
        data_path2: optional second root with the same directory layout.
        max_length: length that ``input_ids`` / ``labels`` are padded or
            truncated to.
        used_data_split: split directory stem to load (e.g. ``"train"``).
        audio_input_type: ``"imagebind"`` (multi-clip fbank features) or
            ``"whisper"`` (Whisper log-mel spectrogram). Any other value
            yields samples without an ``"audio"`` entry.
    """

    def __init__(self, data_path, tokenizer, data_path2=None, max_length=128, used_data_split="train", audio_input_type="imagebind"):
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.datas = []
        self.used_datasets = []
        self.used_data_split = used_data_split
        self.audio_input_type = audio_input_type

        # Primary root: paths come from the "file"/"file2" fields.
        for task_path, data_split in self._iter_used_splits(data_path):
            for d in self._read_metadata(data_split).values():
                d["file"] = str(data_split / d["file"])
                if d.get("file2"):
                    d["file2"] = self._resolve_file2(d["file2"], data_split, task_path)
                self.datas.append(d)

        # Secondary root: the metadata keys are the audio file names.
        if data_path2 is not None:
            for _task_path, data_split in self._iter_used_splits(data_path2):
                for file_name, d in self._read_metadata(data_split).items():
                    d["file"] = str(data_split / file_name)
                    self.datas.append(d)

        # Clip placement used by _load_and_transform_audio: 3 clips of 2 s.
        self.clip_sampler = ConstantClipsPerVideoSampler(
            clip_duration=2, clips_per_video=3
        )
        print("Used datasets", len(self.used_datasets), self.used_datasets)
        print(len(self.datas))

    def _iter_used_splits(self, root):
        """Yield ``(task_path, split_dir)`` for every kept task directory under *root*.

        Side effect: appends each kept task's name to ``self.used_datasets``.
        Non-directory entries are skipped instead of crashing ``iterdir()``.
        """
        for task_path in root.iterdir():
            if not task_path.is_dir() or self._filter_dataset(task_path):
                continue
            self.used_datasets.append(task_path.stem)
            for data_split in task_path.iterdir():
                if data_split.is_dir() and data_split.stem == self.used_data_split:
                    yield task_path, data_split

    @staticmethod
    def _read_metadata(data_split):
        """Load ``<data_split>/metadata.json``, closing the file afterwards."""
        with (data_split / "metadata.json").open(encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def _resolve_file2(file2, data_split, task_path):
        """Return the path string of a secondary audio file.

        A name without any "." gets a ".wav" suffix. The file is looked up in
        *data_split* first, then in ``<task_path>/missing_files``.

        Raises:
            FileNotFoundError: if neither location contains the file. This is
                an explicit raise rather than an ``assert``, so it still fires
                under ``python -O``.
        """
        if "." not in file2:
            file2 = file2 + ".wav"
        for candidate in (data_split / file2, task_path / "missing_files" / file2):
            if candidate.exists():
                return str(candidate)
        raise FileNotFoundError(file2)

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, idx):
        """Build the model inputs for sample *idx*.

        Returns:
            dict with
              ``instruction``: the original (not lowercased) instruction.
              ``input_ids``: LongTensor[max_length] holding prompt + label
                tokens (BOS/EOS), right-padded with 0 or truncated.
              ``labels``: like ``input_ids`` but with prompt positions and
                padding set to 0.
              ``input_mask``: FloatTensor[max_length], 1.0 on real tokens.
              ``audio``: output of ``_load_and_transform_audio`` ("imagebind")
                or ``_load_whisper_audio`` ("whisper"); absent otherwise.
        """
        data = self.datas[idx]

        instruction = data["instruction"].lower()
        text = data["text"].lower() if data.get("text") else None
        prompt = llama.format_prompt(instruction, text)

        prompt_ids = torch.tensor(
            self.tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.long
        )
        input_ids = torch.tensor(
            self.tokenizer.encode(prompt + data["label"], bos=True, eos=True),
            dtype=torch.long,
        )

        # Pad with -1 as a temporary "not a token" marker (assumes real ids >= 0).
        padding = self.max_length - input_ids.size(0)
        if padding > 0:
            input_ids = torch.cat(
                (input_ids, torch.full((padding,), -1, dtype=torch.long))
            )
        else:
            input_ids = input_ids[:self.max_length]

        # Only the answer part is supervised: blank out the prompt positions.
        labels = input_ids.clone()
        labels[:prompt_ids.size(0)] = -1

        input_mask = input_ids.ge(0)
        label_mask = labels.ge(0)
        input_ids[~input_mask] = 0
        labels[~label_mask] = 0

        new_data = {
            "instruction": data["instruction"],
            "input_ids": input_ids,
            "labels": labels,
            "input_mask": input_mask.float(),
        }

        audio_paths = [data["file"]]
        if data.get("file2"):
            audio_paths.append(data["file2"])

        if self.audio_input_type == "imagebind":
            new_data["audio"] = self._load_and_transform_audio(audio_paths)
        elif self.audio_input_type == "whisper":
            new_data["audio"] = self._load_whisper_audio(audio_paths)

        return new_data

    def _load_whisper_audio(self, audio_paths):
        """Concatenate the given files and return a Whisper log-mel spectrogram.

        Each file is read with ``whisper.load_audio``. An empty file is
        replaced by ``16000 * 3`` zero samples (3 s at 16 kHz) and its path is
        printed. The concatenation is padded/trimmed to ``16000 * 5`` samples
        before ``whisper.log_mel_spectrogram``.
        """
        waveforms = []
        for audio_path in audio_paths:
            waveform = torch.tensor(whisper.load_audio(audio_path))
            if waveform.size(0) == 0:
                waveform = torch.zeros([16000 * 3])
                print(audio_path)
            waveforms.append(waveform)

        audio = torch.cat(waveforms, dim=0)
        audio = whisper.pad_or_trim(audio, 16000 * 5)
        return whisper.log_mel_spectrogram(audio)

    def _load_and_transform_audio(self, 
            audio_paths,
            num_mel_bins=128,
            target_length=204,
            sample_rate=16000,
            clip_duration=2,
            clips_per_video=3,
            mean=-4.268,
            std=9.138
        ):
        """Load audio files and turn them into normalized fbank clips.

        The files are resampled to *sample_rate* and concatenated along time.
        The result is cut into clips chosen by ``self.clip_sampler``; each
        clip goes through ``_def_waveform2melspec`` and is normalized with
        (*mean*, *std*). An empty file is replaced by ``16000 * 3`` zero
        samples at 16 kHz and its path is printed.

        Note: *clip_duration* and *clips_per_video* are not used here. Clip
        placement comes from ``self.clip_sampler`` (built in ``__init__``);
        the two parameters are kept only for signature compatibility.

        Returns:
            Tensor of shape (num_clips, 1, num_mel_bins, target_length).
        """
        waveforms = []
        for audio_path in audio_paths:
            waveform, sr = torchaudio.load(audio_path)
            if waveform.size(1) == 0:
                waveform = torch.zeros([1, 16000 * 3])
                sr = 16000
                print(audio_path)
            if sample_rate != sr:
                waveform = torchaudio.functional.resample(
                    waveform, orig_freq=sr, new_freq=sample_rate
                )
            waveforms.append(waveform)
        waveform = torch.cat(waveforms, dim=1)

        all_clips_timepoints = self._get_clip_timepoints(
            self.clip_sampler, waveform.size(1) / sample_rate
        )
        normalize = transforms.Normalize(mean=mean, std=std)
        all_clips = []
        for start, end in all_clips_timepoints:
            waveform_clip = waveform[:, int(start * sample_rate):int(end * sample_rate)]
            melspec = self._def_waveform2melspec(
                waveform_clip, sample_rate, num_mel_bins, target_length
            )
            all_clips.append(normalize(melspec))

        return torch.stack(all_clips, dim=0)

    def _get_clip_timepoints(self, clip_sampler, duration):
        """Return ``[(start, end), ...]`` for every clip *clip_sampler* produces over *duration*.

        The sampler is called repeatedly with the previous clip's end until it
        reports the last clip.
        """
        all_clips_timepoints = []
        is_last_clip = False
        end = 0.0
        while not is_last_clip:
            start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
            all_clips_timepoints.append((start, end))
        return all_clips_timepoints

    def _def_waveform2melspec(self, waveform, sample_rate, num_mel_bins, target_length):
        """Compute a Kaldi-style fbank padded/cropped to *target_length* frames.

        Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102

        Returns:
            Tensor of shape (1, num_mel_bins, target_length), i.e. a
            one-channel "image".
        """
        # Out-of-place on purpose: *waveform* is a slice (view) of the whole
        # recording, and sampled clips may overlap. An in-place subtraction
        # would alter the samples seen by later clips.
        waveform = waveform - waveform.mean()
        fbank = torchaudio.compliance.kaldi.fbank(
            waveform,
            htk_compat=True,
            sample_frequency=sample_rate,
            use_energy=False,
            window_type="hanning",
            num_mel_bins=num_mel_bins,
            dither=0.0,
            frame_length=25,
            frame_shift=10,
        )
        # [num_frames, mel_bins] -> [mel_bins, num_frames]
        fbank = fbank.transpose(0, 1)

        # Right-pad with zeros or crop to exactly target_length frames.
        p = target_length - fbank.size(1)
        if p > 0:
            fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
        elif p < 0:
            fbank = fbank[:, :target_length]
        return fbank.unsqueeze(0)

    def _filter_dataset(self, task_path):
        """Return True for task directories that must be skipped (HowFarAreYou*)."""
        return task_path.stem.startswith("HowFarAreYou")

In [25]:
# Training split with Whisper log-mel audio features.
train_dataset = BigSuperbDataset(
    data_path,
    tokenizer,
    audio_input_type="whisper",
)

Used datasets 22 ['SpoofDetection_Asvspoof2017', 'SpeechTextMatching_LibrispeechTrainClean360', 'DialogueActClassification_DailyTalk', 'DialogueEmotionClassification_DailyTalk', 'SpokenTermDetection_Tedlium2Train', 'NoiseSNRLevelPredictionGaussian_VoxcelebMusan', 'EnhancementDetection_LibrittsTrainClean360Wham', 'SpeakerCounting_LibrittsTrainClean100', 'SpeakerVerification_Aishell1Train', 'SpoofDetection_ASVspoof2015', 'SpeakerVerification_LibrispeechTrainClean100', 'SpeakerVerification_Voxceleb1Train', 'SpokenTermDetection_LibrispeechTrainClean100', 'SpeechDetection_LibrispeechTrainClean100', 'SpeakerVerification_Tedlium2Train', 'NoiseDetectionGaussian_VoxcelebMusan', 'SpeechTextMatching_LibrispeechTrainClean100', 'SpeechDetection_Aishell1Train', 'SpeechTextMatching_Tedlium2Train', 'SpeechDetection_Tedlium2Train', 'ReverberationDetectionSmallRoom_VoxcelebRirsNoises', 'SpeechDetection_Voxceleb1Train']
108014


In [None]:
# Load every sample once so unreadable files fail here rather than in
# training; the audio loaders print the paths of empty files.
n_samples = len(train_dataset)
for sample_idx in tqdm(range(n_samples)):
    train_dataset[sample_idx]

 25%|████████████████████████████████████████████████████████████████▍                                                                                                                                                                                                 | 26956/108014 [2:07:15<10:31:25,  2.14it/s]

# Test: sample 50 examples from the original HF dataset and try a DataLoader on them

In [96]:
from datasets import load_from_disk, Dataset
from ImageBind.data import my_load_and_transform_audio_data

# Datasets stored for datasets.load_from_disk, addressed relative to this root.
org_data_path = Path("/work/u8915687/big-superb/train_datasets")
all_datasets = ['BigSuperbPrivate/SpeechDetection_LibrispeechTrainClean100']

tokenizer = Tokenizer(
    model_path="/home/u8915687/lab/big-superb/Macaw-LLM2/weights/llama_7B/tokenizer.model"
)
def prepare_dataset(b, max_length=128):
    """Turn one raw example into model inputs (same recipe as BigSuperbDataset.__getitem__).

    Args:
        b: example dict with ``instruction``, ``label``, optional ``text`` and
            ``audio`` (a dict holding the waveform under ``"array"``).
        max_length: length ``input_ids`` / ``labels`` are padded/truncated to.

    Returns:
        dict with ``instruction`` (original case), ``input_ids``, ``labels``,
        ``input_mask`` (float) and ``audio`` (first element returned by
        ``my_load_and_transform_audio_data``).
    """
    instruction = b["instruction"].lower()
    text = b["text"].lower() if b.get("text") else None
    prompt = llama.format_prompt(instruction, text)

    prompt_ids = torch.tensor(
        tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.long
    )
    input_ids = torch.tensor(
        tokenizer.encode(prompt + b["label"], bos=True, eos=True), dtype=torch.long
    )

    # Pad with -1 as a temporary "not a token" marker (assumes real ids >= 0).
    padding = max_length - input_ids.size(0)
    if padding > 0:
        input_ids = torch.cat((input_ids, torch.full((padding,), -1, dtype=torch.long)))
    else:
        input_ids = input_ids[:max_length]

    # Only the answer part is supervised: blank out the prompt positions.
    labels = input_ids.clone()
    labels[:prompt_ids.size(0)] = -1

    input_mask = input_ids.ge(0)
    label_mask = labels.ge(0)
    input_ids[~input_mask] = 0
    labels[~label_mask] = 0

    # Explicit float32, as in collate_fn, so the dtype does not depend on
    # whether the array arrives as a Python list or a (possibly float64) ndarray.
    waveform = torch.tensor(b["audio"]["array"], dtype=torch.float32).unsqueeze(0)

    return {
        "instruction": b["instruction"],
        "input_ids": input_ids,
        "labels": labels,
        "input_mask": input_mask.float(),
        "audio": my_load_and_transform_audio_data(waveform)[0],
    }

for dataset_name in all_datasets:
    ori_dataset = load_from_disk(org_data_path / dataset_name)

    # 50 shuffled training examples (fewer if the split is smaller).
    # select() keeps the Arrow-backed rows and their declared features;
    # slicing to a dict and rebuilding with Dataset.from_dict materialized the
    # rows in Python and re-inferred the schema. Note: `Dataset` in this cell
    # is datasets.Dataset, which shadows torch's Dataset imported earlier.
    # td is overwritten on each iteration; only the last dataset's sample is kept.
    train_split = ori_dataset["train"].shuffle(seed=42)
    td = train_split.select(range(min(50, len(train_split))))

    # td = td.map(prepare_dataset)
    # td = td.with_format("torch")
    # td = td.with_format("torch")

In [73]:
# Inspect a single sample (index 65000 - 1986 = 63014).
train_dataset[65000 - 1986]

{'task_name': 'SpeechDetection_LibrispeechTrainClean100_0',
 'name': '289-121652-0038.flac',
 'instruction': 'Does the recording contain speech from human? The answer could be yes or no.',
 'input_ids': tensor([    1, 13866,   338,   385, 15278,   393, 16612,   263,  3414, 29889,
         14350,   263,  2933,   393,  7128,  2486,  1614,  2167,   278,  2009,
         29889,    13,    13,  2277, 29937,  2799,  4080, 29901,    13, 13221,
           278, 16867,  1712, 12032,   515,  5199, 29973,   278,  1234,  1033,
           367,  4874,   470,   694, 29889,    13,    13,  2277, 29937, 13291,
         29901,  3582,     2,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
        

In [120]:
def collate_fn(b):
    """Batch raw examples into stacked audio features plus text prompts/labels.

    Args:
        b: list of example dicts with ``audio["array"]``, ``instruction``,
            ``label`` and optional ``text``.

    Returns:
        dict with ``audio`` (features stacked along dim 0), ``prompts``
        (formatted from the lowercased instruction/text), ``instructions``
        (original case) and ``labels`` (raw label strings).
    """
    audios = []
    prompts = []
    labels = []
    for data in b:
        waveform = torch.tensor(data["audio"]["array"], dtype=torch.float32).unsqueeze(0)
        audios.append(my_load_and_transform_audio_data(waveform)[0])

        text = data["text"].lower() if data.get("text") else None
        prompts.append(llama.format_prompt(data["instruction"].lower(), text))
        labels.append(data["label"])

    return {
        "audio": torch.stack(audios),
        "prompts": prompts,
        "instructions": [d["instruction"] for d in b],
        "labels": labels,
    }

# Sequential (unshuffled) batches of 16 raw examples from td, batched by collate_fn.
data_loader = torch.utils.data.DataLoader(
    td, batch_size=16, shuffle=False, collate_fn=collate_fn
)

In [122]:
# Pull the first batch to exercise collate_fn end to end.
next(iter(data_loader))

KeyError: 'input_ids'