In [9]:
import platform
import sys
import time
import numpy as np
import torch
from tqdm import tqdm
from pathlib import Path
from loguru import logger
from pprint import PrettyPrinter
import torch.nn as nn
#from parallel import DataParallelModel, DataParallelCriterion
from torch.utils.tensorboard import SummaryWriter
#from tools.utils import setup_seed, AverageMeter, a2t, t2a
from tools.utils import setup_seed, AverageMeter
from tools.loss import BiDirectionalRankingLoss, TripletLoss, NTXent, WeightTriplet
from models.ASE_model import ASE
#from data_handling.DataLoader import get_dataloader
import wandb
from tensorboardX import SummaryWriter

In [3]:
import yaml
from dotmap import DotMap
import easydict


def get_config(config_name='settings'):

    with open('settings/{}.yaml'.format(config_name), 'r') as f:

        config = yaml.load(f, Loader=yaml.FullLoader)
    config = DotMap(config)
    return config


args = easydict.EasyDict({
    
    "dataset": "Clotho",
    "lr": 0.0001,
    "config": "settings",
    "loss": "triplet",
    "freeze": "False",
    "batch":24, 
    "margin":0.2,
    "seed":20
})

config = get_config(args.config)
config

DotMap(mode='train', exp_name='exp', dataset='Clotho', text_encoder='sbert', joint_embed=1024, wav=DotMap(sr=32000, window_size=1024, hop_length=320, mel_bins=64), bert_encoder=DotMap(type='bert-base-uncased', freeze=True), cnn_encoder=DotMap(model='Cnn14', pretrained=True, freeze=True), data=DotMap(batch_size=24, num_workers=8), training=DotMap(margin=0.2, freeze=True, loss='ntxent', spec_augmentation=True, epochs=50, lr=0.0001, clip_grad=2, seed=20, resume=False, l2_norm=True, dropout=0.2, csv=True), path=DotMap(vocabulary='data/{}/pickles/words_list.p', word2vec='pretrained_models/w2v_all_vocabulary.model', resume_model=''), _ipython_display_=DotMap(), _repr_mimebundle_=DotMap())

In [1]:
import torch
import random
import numpy as np
import h5py
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader


class AudioCaptionDataset(Dataset):

    def __init__(self, dataset='Clotho', split='train'):
        """
        load audio clip's waveform and corresponding caption
        Args:
            dataset: 'AudioCaps', 'Clotho
            split: 'train', 'val', 'test'
        """
        
        super(AudioCaptionDataset, self).__init__()
        self.dataset = dataset
        self.split = split
        self.h5_path = f'data/{dataset}/hdf5s/{split}/{split}.h5'
        
        
        if dataset == 'AudioCaps' and split == 'train':
            self.is_train = True
            self.num_captions_per_audio = 1
            with h5py.File(self.h5_path, 'r') as hf:
                self.audio_keys = [audio_name.decode() for audio_name in hf['audio_name'][:]]
                # audio_names: [str] 
                self.captions = [caption.decode() for caption in hf['caption'][:]]
         
    
        else:
            self.is_train = False
            #self.is_train = True
            self.num_captions_per_audio = 5
            
            with h5py.File(self.h5_path, 'r') as hf:
                self.audio_keys = [audio_name.decode() for audio_name in hf['audio_name'][:]]
                # audio_names: [str] -> decode()처리해줘서 string
                
                self.captions = [caption for caption in hf['caption'][:]]
                #self.audio_lengths = [length for length in hf['audio_length'][:]]
                
                if dataset == 'Clotho':
                    self.audio_lengths = [length for length in hf['audio_length'][:]]
                # [cap_1, cap_2, ..., cap_5]

               
    def __len__(self):
        return len(self.audio_keys) * self.num_captions_per_audio

    
    def __getitem__(self, index):

        audio_idx = index // self.num_captions_per_audio
        audio_name = self.audio_keys[audio_idx]
        with h5py.File(self.h5_path, 'r') as hf:
            waveform = hf['waveform'][audio_idx]

        if self.dataset == 'AudioCaps' and self.is_train:
            caption = self.captions[audio_idx]
        else:
            captions = self.captions[audio_idx]
            cap_idx = index % self.num_captions_per_audio
            caption = captions[cap_idx].decode()

        if self.dataset == 'Clotho':
            length = self.audio_lengths[audio_idx]
            return waveform, caption, audio_idx, length, index
        else:
            return waveform, caption, audio_idx, len(waveform), index


def collate_fn(batch_data):
    """

    Args:
        batch_data:

    Returns:

    """

    max_audio_length = max([i[3] for i in batch_data])

    wav_tensor = []
    for waveform, _, _, _, _ in batch_data:
        if max_audio_length > waveform.shape[0]:
            padding = torch.zeros(max_audio_length - waveform.shape[0]).float()
            temp_audio = torch.cat([torch.from_numpy(waveform).float(), padding])
        else:
            temp_audio = torch.from_numpy(waveform[:max_audio_length]).float()
        wav_tensor.append(temp_audio.unsqueeze_(0))

    wavs_tensor = torch.cat(wav_tensor)
    captions = [i[1] for i in batch_data]
    audio_ids = torch.Tensor([i[2] for i in batch_data])
    indexs = np.array([i[4] for i in batch_data])

    return wavs_tensor, captions, audio_ids, indexs


def get_dataloader(split, config):
    dataset = AudioCaptionDataset(config.dataset, split)
    if split == 'train':
        shuffle = True
        drop_last = True
    else:
        shuffle = False
        drop_last = False

    return DataLoader(dataset=dataset,
                      batch_size=config.data.batch_size,
                      shuffle=shuffle,
                      drop_last=drop_last,
                      num_workers=config.data.num_workers,
                      collate_fn=collate_fn)



- train dataloader 확인하기

In [4]:
train_loader = get_dataloader('train_augment',config)

In [7]:
for batch_id, batch_data in tqdm(enumerate(train_loader), total=len(train_loader)):
    if batch_id ==2:
        break
        
    print(batch_id)
    print(batch_data)
    print('---')

    audios, captions, audio_ids, _ = batch_data
    print(audios.shape)
    print(audios)
    print(captions)
    print(audio_ids)
    print('---------------------------------')


  0%|                                                                                | 1/1600 [00:00<18:23,  1.45it/s]

0
(tensor([[ 0.0098,  0.0055, -0.0023,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0098,  0.0055, -0.0023,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0098,  0.0055, -0.0023,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0014, -0.0007, -0.0019,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0014, -0.0007, -0.0019,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0014, -0.0007, -0.0019,  ...,  0.0000,  0.0000,  0.0000]]), ['a muddled noise of broken channel of the tv', 'a television blares the rhythm of a static tv ', 'loud television static dips in and out of focus', 'the loud buzz of static constantly changes pitch and volume ', 'heavy static and the beginnings of a signal on a transistor radio', 'a person is turning a map over and over ', 'a person is very carefully rapping a gift for someone else ', 'a person is very carefully wrapping a gift for someone else ', 'he sighed as he turned the pages of the book stopping to scan the information ', 'papers are being turned st

  0%|                                                                                | 2/1600 [00:01<18:38,  1.43it/s]

1
(tensor([[ 0.0014, -0.0007, -0.0019,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0029,  0.0037,  0.0032,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0029,  0.0037,  0.0032,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0027,  0.0050,  0.0062,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0027,  0.0050,  0.0062,  ...,  0.0000,  0.0000,  0.0000],
        8., 8., 8., 9., 9., 9.]), array([24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
       41, 42, 43, 44, 45, 46, 47]))
---
torch.Size([24, 928520])
tensor([[ 0.0014, -0.0007, -0.0019,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0029,  0.0037,  0.0032,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0029,  0.0037,  0.0032,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0027,  0.0050,  0.0062,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0027,  0.0050,  0.0062,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0027,  0.0050,  0.0062,  ...,  0.0000,  0.0000,  0.0000]])
tensor([4., 5., 5., 5., 5


