In [None]:
!pip install llmcompressor

In [None]:
import numpy as np
from pathlib import Path
from tqdm import tqdm

from transformers import AutoTokenizer
from huggingface_hub import snapshot_download
from datasets import Dataset

from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

In [None]:
# Hub repo that supplies both the model weights and the tokenizer.
MODEL_ID = "cmeraki/mimi_tts_hf_stage"

# Load the model to be quantized, then the tokenizer from the same repo.
model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

In [None]:
# Download the calibration dataset snapshot and load the array stored inside it.
data_path = snapshot_download('cmeraki/quantization_data_sample', repo_type='dataset')
quant_data = np.load(Path(data_path) / 'quantization_data_sample.npy')

print(quant_data.shape)

# Calibration settings reused by the dataset-building and oneshot cells.
num_samples = 8192
max_seq_len = 1024

def create_dataset_from_tokens(token_array, tokenizer, num_samples=2048, seed=None):
    """Decode token sequences to text and return a shuffled subset as a Dataset.

    Args:
        token_array: Iterable of token-id sequences; each element is passed to
            ``tokenizer.decode``.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
        num_samples: Maximum number of rows to keep after shuffling. If fewer
            rows are available, all of them are returned instead of raising.
        seed: Optional seed forwarded to ``Dataset.shuffle`` for a reproducible
            selection. ``None`` (the default) keeps the shuffle unseeded.

    Returns:
        A ``Dataset`` with a single ``'text'`` column holding at most
        ``num_samples`` decoded strings in shuffled order.
    """
    # Special tokens are kept in the decoded text (skip_special_tokens=False).
    texts = [tokenizer.decode(d, skip_special_tokens=False) for d in tqdm(token_array)]
    dataset = Dataset.from_dict({'text': texts})

    # Clamp the request: select(range(n)) fails when n exceeds the row count.
    num_rows = min(num_samples, len(dataset))
    return dataset.shuffle(seed=seed).select(range(num_rows))

# Build the text calibration dataset from the loaded token array.
ds = create_dataset_from_tokens(
    token_array=quant_data,
    tokenizer=tokenizer,
    num_samples=num_samples,
)

In [None]:
# Quantization recipe: FP8 scheme targeting "Linear", with "lm_head" in the
# ignore list.
recipe = QuantizationModifier(
    targets="Linear",
    scheme="FP8",
    ignore=["lm_head"],
)

# Run one-shot quantization of `model` with the recipe, using `ds` for
# calibration and the sample count / sequence length configured earlier.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=max_seq_len,
    num_calibration_samples=num_samples,
)

In [None]:
# Upload the model first, then the tokenizer, to the same Hub repository.
# NOTE(review): confirm push_to_hub writes the weights in compressed form for
# the installed llmcompressor version.
for artifact in (model, tokenizer):
    artifact.push_to_hub(repo_id='cmeraki/indri-tts-775m-fp8')