In [13]:
# from dataset import get_ds

  from .autonotebook import tqdm as notebook_tqdm


In [70]:
import torch
from datasets import load_dataset
from transformers import Qwen2_5OmniProcessor, Qwen2_5OmniThinkerForConditionalGeneration

# Hugging Face Hub identifier of the checkpoint used by this notebook.
model_id = "Qwen/Qwen2.5-Omni-3B"

# Load the "thinker" model; dtype and device placement are both left to
# the library's "auto" setting.
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
	model_id,
	torch_dtype="auto",
	device_map="auto",
)

# Processors already loaded in this kernel, keyed by model id. Loading a
# processor is expensive, and build_conversation is called once per example.
_PROCESSOR_CACHE = {}


def _get_processor(model_id):
	"""Return the processor for `model_id`, loading it at most once."""
	if model_id not in _PROCESSOR_CACHE:
		_PROCESSOR_CACHE[model_id] = Qwen2_5OmniProcessor.from_pretrained(model_id)
	return _PROCESSOR_CACHE[model_id]


def build_conversation(model_id, word):
	"""Render the chat prompt asking when `word` is said in an audio clip.

	Args:
		model_id: Identifier passed to `Qwen2_5OmniProcessor.from_pretrained`
			to obtain the processor whose chat template is applied.
		word: The word to ask about; it is interpolated into the user turn.

	Returns:
		The value returned by `processor.apply_chat_template` with
		`tokenize=False` and `add_generation_prompt=True`, i.e. the
		untokenized prompt.
	"""
	processor = _get_processor(model_id)

	conversation = [
		{
			"role": "system",
			"content": [
				{"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}
			],
		},
		{
			"role": "user",
			"content": [
				# The audio entry is only a placeholder here; the real waveform
				# is supplied separately when the processor is called.
				{"type": "audio", "audio": "PLACEHOLDER AUDIO"},
				{"type": "text", "text": f"When is \"{word}\" said?"},
			],
		},
	]

	return processor.apply_chat_template(
		conversation,
		tokenize=False,
		add_generation_prompt=True,
	)

def get_ds(model_id, split='train_clean_100', slice=None):
	"""Build a torch-formatted dataset of processor inputs and per-frame labels.

	Each example of the "gilkeyio/librispeech-alignments" dataset is turned
	into the processor outputs for a "When is <word> said?" prompt about the
	example's first aligned word, plus a label vector with a single 1 at the
	position where that word ends.

	Args:
		model_id: Identifier used to load the `Qwen2_5OmniProcessor` and to
			build the prompt via `build_conversation`.
		split: Name of the dataset split to read.
		slice: If truthy, only the first `slice` examples of the split are
			processed; otherwise the whole split is used.

	Returns:
		The mapped dataset with the original columns removed and the format
		set to 'torch'. Columns: 'prompt', 'audio_frames', 'input_ids',
		'attention_mask', 'input_features', 'feature_attention_mask',
		'labels'.
	"""
	# Loaded once and shared by every preprocess_fn call through the closure.
	processor = Qwen2_5OmniProcessor.from_pretrained(model_id)

	def preprocess_fn(example):
		"""Convert one raw example into model inputs and a label vector."""
		audio = example['audio']
		words = example['words']

		# Only the first aligned word of each utterance is used as the target.
		first_word = words[0]

		prompt = build_conversation(model_id, first_word['word'])
		audio_frames = audio['array']

		inputs = processor(
			text=prompt,
			audio=audio_frames,
			return_tensors='pt',
			padding=True,
		)

		input_ids = inputs['input_ids']
		attention_mask = inputs['attention_mask']
		input_features = inputs['input_features']
		feature_attention_mask = inputs['feature_attention_mask']

		# Number of valid feature frames (length of audio in centiseconds).
		num_frames = int(feature_attention_mask.sum(dim=-1))
		# Each embedding is 4 centiseconds long.
		# NOTE(review): assumes the audio encoder emits floor(frames / 4)
		# embeddings — confirm against the model's audio feature length.
		num_labels = num_frames // 4
		labels = torch.zeros(num_labels)

		# Seconds -> centiseconds (x100), then 4 centiseconds per label (/4).
		end_idx = int(first_word['end'] * 25)
		# Clamp so an end time at or past the final slot marks the last label
		# instead of raising IndexError; leave empty label vectors untouched.
		if num_labels > 0:
			labels[min(max(end_idx, 0), num_labels - 1)] = 1

		# audio_features = model.get_audio_features(
		# 	input_features=input_features.to(device=model.device),
		# 	feature_attention_mask=feature_attention_mask.to(device=model.device),
		# )

		return {
			'prompt': prompt,
			'audio_frames': audio_frames,
			# The processor output is batched (batch of one); drop that axis.
			'input_ids': input_ids[0],
			'attention_mask': attention_mask[0],
			'input_features': input_features[0],
			'feature_attention_mask': feature_attention_mask[0],
			'labels': labels,
			# 'audio_features': audio_features.to('cpu'),
		}

	base_ds = load_dataset("gilkeyio/librispeech-alignments")[split]
	if slice:
		base_ds = base_ds.select(range(slice))

	ds = base_ds.map(preprocess_fn, remove_columns=base_ds.column_names)

	ds.set_format(type='torch')

	return ds

Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}
Loading checkpoint shards: 100%|██████████| 3/3 [00:16<00:00,  5.50s/it]


In [None]:
# Preprocess the whole default split. This is a long-running step: the
# captured progress bar reports 28538 examples at roughly 1.8 s each.
ds = get_ds(
	model_id='Qwen/Qwen2.5-Omni-3B',
)

Map:   2%|▏         | 514/28538 [14:44<13:57:59,  1.79s/ examples]

In [67]:
# Inspect one preprocessed example. Fetch the row once instead of
# re-indexing the dataset for every field.
i = 4
sample = ds[i]

print(sample['feature_attention_mask'].sum())
# 'audio_features' is only present when get_ds precomputes it (that code is
# currently commented out), so guard the lookup to avoid a KeyError.
if 'audio_features' in sample:
	print(sample['audio_features'].shape)
else:
	print("audio_features not present in dataset")
print(sample['labels'].shape)

tensor(861)
torch.Size([215, 2048])
torch.Size([215])
