In [1]:
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import MDSWriter, LocalDataset
from tqdm import tqdm
from typing import List
import torch
import json

class ListOfDict(Encoding):
    """Column encoding that stores a JSON-serializable list as UTF-8 JSON bytes."""

    def encode(self, obj: List[dict]) -> bytes:
        """Serialize ``obj`` with ``json.dumps`` and return the UTF-8 encoded bytes."""
        return json.dumps(obj).encode('utf-8')

    def decode(self, data: bytes) -> List[dict]:
        """Inverse of ``encode``: parse UTF-8 JSON bytes back into Python objects."""
        return json.loads(data.decode('utf-8'))

# Register the custom encoding under the name 'list_of_dict' so that column
# specs can refer to it by that string.
# NOTE(review): `_encodings` is an underscore-prefixed (private) registry of the
# streaming package — confirm it is still available after upgrading the library.
_encodings['list_of_dict'] = ListOfDict

In [2]:
import json
import os

In [3]:
# Mapping from speaker labels ('human' / 'gpt') to chat roles.
# NOTE(review): no visible cell reads `roles`; the loader copies c['role']
# unchanged — confirm whether this mapping is still needed.
roles = dict(human='user', gpt='assistant')

In [None]:
# Build two parallel instruction sets (English and Malay) from the JSONL dump.
# Every kept record pairs one existing audio file with the first two chat turns,
# and the first turn is prefixed with the '<audio> ' placeholder.
AUDIO_ROOT = '/home/ubuntu'
INSTRUCTION_FILE = 'mixtral-audio-instruction.jsonl'

data_en, data_ms = [], []
# JSON text is UTF-8; do not depend on the platform's default locale encoding.
with open(INSTRUCTION_FILE, encoding='utf-8') as fopen:
    for line in tqdm(fopen):
        record = json.loads(line)

        # Point at the 'filter-audio' copy of the file and skip records whose
        # audio is not present on disk.
        audio_path = os.path.join(AUDIO_ROOT, record['filename'])
        audio_path = audio_path.replace('output-audio', 'filter-audio')
        if not os.path.exists(audio_path):
            continue

        chat = record['chat']
        if not chat:
            # No first turn to attach the '<audio>' tag to; indexing [0] below
            # would raise IndexError, so skip the record instead.
            continue

        en = [{'role': turn['role'], 'content': turn['content']} for turn in chat]
        ms = [{'role': turn['role'], 'content': turn['content_ms']} for turn in chat]

        en[0]['content'] = '<audio> ' + en[0]['content']
        ms[0]['content'] = '<audio> ' + ms[0]['content']

        # Only the first two turns are kept per record.
        data_en.append({'filename': [audio_path], 'conversations': en[:2]})
        data_ms.append({'filename': [audio_path], 'conversations': ms[:2]})

16048it [00:00, 29531.22it/s]

In [None]:
# Sanity check: show how many records were collected for each language.
len(data_en), len(data_ms)

In [None]:
# Inspect the first English record ('filename' list plus 'conversations').
data_en[0]

In [None]:
# Column name -> encoding name for the MDS writer. Both columns use the custom
# JSON-based 'list_of_dict' encoding; note that 'filename' actually holds a list
# of path strings, which the JSON encoding handles just the same.
columns = dict.fromkeys(('conversations', 'filename'), 'list_of_dict')

# Hash algorithm names handed to MDSWriter via its `hashes` argument.
hashes = ('sha1', 'xxh64')

In [None]:
# Remove any previous 'mosaic-audio' output directory before writing.
# Pure-Python replacement for `!rm -rf mosaic-audio`: portable and does not
# spawn a shell. A missing directory is ignored, like `rm -f`.
# NOTE(review): presumably needed so MDSWriter starts from an empty directory — confirm.
import shutil

shutil.rmtree('mosaic-audio', ignore_errors=True)

In [None]:
# Combined write order: all English records first, then all Malay records.
data = [*data_en, *data_ms]

In [None]:
def _normalize_media_tags(text):
    """Normalize ``<image>`` / ``<audio>`` placeholders in one chat message.

    Newlines directly adjacent to a placeholder are folded away, the text is
    stripped, and each placeholder is then expanded to an open/close pair
    (``<audio>`` -> ``<audio> </audio>``).
    """
    text = (
        text
        .replace('\n<image>', ' <image>')
        .replace('<image>\n', '<image>')
        .replace('\n<audio>', ' <audio>')
        .replace('<audio>\n', '<audio>')
        .strip()
    )
    return text.replace('<image>', '<image> </image>').replace('<audio>', '<audio> </audio>')


# Write every record to the 'mosaic-audio' MDS directory.
# The normalized sample is built as a NEW dict: the records in `data` are shared
# with `data_en` / `data_ms`, and rewriting them in place would make this cell
# non-idempotent (a second run would expand '<audio> </audio>' again).
skipped = 0
last_error = None
with MDSWriter(out='mosaic-audio', columns=columns, compression=None, hashes=hashes) as out:
    for record in tqdm(data):
        try:
            sample = {
                **record,
                'conversations': [
                    {**turn, 'content': _normalize_media_tags(turn['content'])}
                    for turn in record['conversations']
                ],
            }
            out.write(sample)
        except Exception as exc:
            # Best-effort: skip records that cannot be normalized or written,
            # but keep count. Catching Exception (not a bare except) lets
            # KeyboardInterrupt still stop the loop.
            skipped += 1
            last_error = exc

if skipped:
    print(f'skipped {skipped} record(s); last error: {last_error!r}')

In [None]:
# Re-open the directory that was just written to read the samples back.
dataset = LocalDataset('mosaic-audio')

In [None]:
# Spot-check one sample (index 1) as read back from 'mosaic-audio'.
dataset[1]