In [18]:
#Llama Tokenizer - Step 1
from transformers import MimiModel, AutoFeatureExtractor, AutoTokenizer
import torch
import numpy as np

class TextTokenizer:
    def __init__(self, name='Llama_tokenizer'):
        self.tokenizer = AutoTokenizer.from_pretrained(name, legacy=False)
        print("text vocab size", self.tokenizer.vocab_size)

    def encode(self, text: str):
        tokens = self.tokenizer.encode(text)
        return tokens

    def decode(self, tokens):
        return self.tokenizer.decode(tokens)
    
class MimiTokenizer:
    def __init__(self, device):    
        self.device = device
        self.model = MimiModel.from_pretrained("kyutai/mimi")
        self.model.to(device)
        self.model.eval()
        self.feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi", device=device)
        self.sampling_rate = self.feature_extractor.sampling_rate
        self.n_codebooks = 8
        self.vocab_size = 2048

    @torch.inference_mode()
    def encode(self, waveform):
        inputs = self.feature_extractor(raw_audio=waveform, 
                                        sampling_rate=self.sampling_rate, 
                                        return_tensors="pt").to(self.device)
            
        output = self.model.encode(inputs["input_values"], inputs["padding_mask"], num_quantizers=self.n_codebooks)
        tokens = output.audio_codes[0].cpu().numpy()
        return tokens

    def decode(self, tokens):
        assert len(tokens.shape) == 2
        tokens = torch.tensor(np.expand_dims(tokens, axis=0)).to(self.device)
        output = self.model.decode(tokens)
        waveform = output.audio_values.cpu()
        return waveform

2024-12-17 21:17:57.807816: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-17 21:17:57.815273: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1734450477.823825 2237071 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1734450477.826167 2237071 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-17 21:17:57.835489: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [11]:
import json
tokenizer = TextTokenizer()
with open('/home/subhash/.cache/indri/lj_speech/annotation/metadata.jsonl') as f:
    for line in f:
        data = json.loads(line)
        text = data['raw_text']
        tokens = tokenizer.encode(text)
        print(text)
        print(tokens)
        print(tokenizer.decode(tokens))
        break

text vocab size 128000
Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition
[128000, 39628, 11, 304, 279, 1193, 5647, 449, 902, 584, 527, 520, 3118, 11920, 11, 44642, 505, 1455, 422, 539, 505, 682, 279, 19071, 323, 44948, 15609, 304, 279, 68033]
<|begin_of_text|>Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition


In [28]:
# Mimi Tokeniser - Step 2
import numpy as np
mimi_tokens = np.load("/home/subhash/.cache/indri/lj_speech/tokens/mimi/LJ040-0046.npy")
print(mimi_tokens.shape)
#print(mimi_tokens)
type(mimi_tokens)

(8, 100)


numpy.ndarray

In [26]:
# Weave Tokens, codebook offset - Step 3
def weave_tokens(tokens):
    result = []
    max_length = max(len(codebook) for codebook in tokens)
    
    for i in range(max_length):
        for codebook_index, codebook in enumerate(tokens):
            if i < len(codebook):
                offset = 2048 * codebook_index + 128000
                result.append(codebook[i] + offset)                 
    return np.array(result)

weave_audio = weave_tokens(mimi_tokens.tolist())
weave_audio

array([128995, 130325, 132834, 135350, 138081, 138819, 140988, 143449,
       128572, 131272, 133038, 135276, 136552, 139416, 141865, 142522,
       129117, 131597, 133406, 134146, 136322, 139289, 141923, 143985,
       129069, 131186, 133662, 135220, 136512, 139853, 141760, 142410,
       129620, 131843, 133121, 134814, 137436, 138387, 141084, 142444,
       128080, 131002, 133038, 135147, 137077, 139665, 141897, 142752,
       129134, 131203, 132924, 134668, 136548, 140002, 140409, 143451,
       128715, 130220, 133849, 135225, 136357, 139775, 140460, 143596,
       129650, 131937, 132126, 134545, 136933, 139077, 141937, 143864,
       129532, 131639, 132479, 134835, 137883, 138246, 141974, 143201,
       128937, 131824, 133411, 134384, 136248, 138246, 141565, 143949,
       129077, 130323, 132375, 134147, 138069, 139010, 141512, 142954,
       129352, 130386, 133860, 134608, 138148, 140177, 141688, 143239,
       128749, 130955, 133781, 135938, 137985, 140133, 141270, 143942,
      

In [15]:
import torch
import numpy as np
from transformers import AutoTokenizer

class TTSTokenizer:
    def __init__(self, text_tokenizer_name='tts_tokenizer', audio_tokenizer_name='tts_tokenizer'):
        self.text_tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_name, legacy=False)
        self.audio_tokenizer = AutoTokenizer.from_pretrained(audio_tokenizer_name, legacy=False)
        print("text vocab size", self.audio_tokenizer.vocab_size)

    def encode(self, input_data, add_special_tokens=True):
        if isinstance(input_data, str):
            encoded_tokens = self.text_tokenizer.encode(
                input_data, 
                return_tensors='pt', 
                add_special_tokens=add_special_tokens
            )
            return encoded_tokens
        elif isinstance(input_data, list) and all(isinstance(item, str) for item in input_data):
            encoded_tokens = self.audio_tokenizer.encode(
                input_data, 
                return_tensors='pt', 
                add_special_tokens=add_special_tokens
            )
            return encoded_tokens
        else:
            raise TypeError("Input must be a string or a list of strings")

    def decode(self, tokens):
        if not isinstance(tokens, torch.Tensor):
            raise TypeError("Input must be a torch tensor of tokens")
        
        try:
            decoded_text = self.text_tokenizer.decode(tokens)
            return decoded_text
        except:
            try:
                decoded_tokens = self.audio_tokenizer.decode(tokens)
                return torch.tensor(decoded_tokens)
            except:
                raise ValueError("Unable to decode the provided tokens")

In [18]:
tokenizer = TTSTokenizer(text_tokenizer_name='tts_tokenizer', audio_tokenizer_name='tts_tokenizer')
audio_decoding = tokenizer.decode(tokens=torch.tensor(weave_audio))
text_decoding = tokenizer.decode(tokens=torch.tensor([128000, 39628, 11, 304, 279, 1193, 5647, 449, 902, 584, 527, 520, 3118, 11920, 11, 44642, 505, 1455, 422, 539, 505, 682, 279, 19071, 323, 44948, 15609, 304, 279, 68033]))
print(text_decoding)
print(audio_decoding)

text vocab size 128000


NameError: name 'weave_audio' is not defined

In [37]:
#Appending the tokens to a single sequence
#text_tokens, task_tokens, speaker_tokens, audio_start_tokens, audio_tokens, common_stop_token.
tokenizer = TTSTokenizer(text_tokenizer_name='tts_tokenizer', audio_tokenizer_name='tts_tokenizer')

MIMI = '[mimi]'
CONVERT = '[convert]'
CONTINUE = '[continue]'
DEFAULT_SPEAKER = '[spkr_unk]'
COMMON_STOP = '[stop]'

import torch

@torch.no_grad()
def append_tokens(text, audio_tokens, speaker=DEFAULT_SPEAKER):
    audio_tokens = torch.tensor(audio_tokens, dtype=torch.int32).clone().detach()
    text_tokens = torch.tensor(tokenizer.encode(text), dtype=torch.int32).view(-1).clone().detach()
    convert_tokens = torch.tensor(tokenizer.encode(CONVERT, add_special_tokens=False), dtype=torch.int32).view(-1).clone().detach()
    continue_tokens = torch.tensor(tokenizer.encode(CONTINUE, add_special_tokens=False), dtype=torch.int32).view(-1).clone().detach()
    speaker_tokens = torch.tensor(tokenizer.encode(speaker, add_special_tokens=False), dtype=torch.int32).view(-1).clone().detach()
    mimi_tokens = torch.tensor(tokenizer.encode(MIMI, add_special_tokens=False), dtype=torch.int32).view(-1).clone().detach()
    stop_tokens = torch.tensor(tokenizer.encode(COMMON_STOP, add_special_tokens=False), dtype=torch.int32).view(-1).clone().detach()
    
    result = torch.cat([
        text_tokens,
        convert_tokens,
        #continue_tokens,
        speaker_tokens,
        mimi_tokens,
        audio_tokens,
        stop_tokens
    ])
    
    return result

In [38]:
result = append_tokens("Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition", weave_audio)

  text_tokens = torch.tensor(tokenizer.encode(text), dtype=torch.int32).view(-1).clone().detach()
  convert_tokens = torch.tensor(tokenizer.encode(CONVERT, add_special_tokens=False), dtype=torch.int32).view(-1).clone().detach()
  continue_tokens = torch.tensor(tokenizer.encode(CONTINUE, add_special_tokens=False), dtype=torch.int32).view(-1).clone().detach()
  speaker_tokens = torch.tensor(tokenizer.encode(speaker, add_special_tokens=False), dtype=torch.int32).view(-1).clone().detach()
  mimi_tokens = torch.tensor(tokenizer.encode(MIMI, add_special_tokens=False), dtype=torch.int32).view(-1).clone().detach()
  stop_tokens = torch.tensor(tokenizer.encode(COMMON_STOP, add_special_tokens=False), dtype=torch.int32).view(-1).clone().detach()


In [39]:
result

tensor([128000,  39628,     11,    304,    279,   1193,   5647,    449,    902,
           584,    527,    520,   3118,  11920,     11,  44642,    505,   1455,
           422,    539,    505,    682,    279,  19071,    323,  44948,  15609,
           304,    279,  68033, 144642, 144645, 144641, 128995, 130325, 132834,
        135350, 138081, 138819, 140988, 143449, 128572, 131272, 133038, 135276,
        136552, 139416, 141865, 142522, 129117, 131597, 133406, 134146, 136322,
        139289, 141923, 143985, 129069, 131186, 133662, 135220, 136512, 139853,
        141760, 142410, 129620, 131843, 133121, 134814, 137436, 138387, 141084,
        142444, 128080, 131002, 133038, 135147, 137077, 139665, 141897, 142752,
        129134, 131203, 132924, 134668, 136548, 140002, 140409, 143451, 128715,
        130220, 133849, 135225, 136357, 139775, 140460, 143596, 129650, 131937,
        132126, 134545, 136933, 139077, 141937, 143864, 129532, 131639, 132479,
        134835, 137883, 138246, 141974, 

In [63]:
import os
import torch
import json
import numpy as np
CACHE_DIR = '/home/subhash/.cache/indri'
def load_tokens(dataset_dir):
    metadata_path = f"{CACHE_DIR}/{dataset_dir}/annotation/metadata.jsonl"
    tokens_dir = os.path.join(CACHE_DIR, dataset_dir, 'tokens', 'mimi')
    with open(metadata_path, 'r', encoding='utf-8') as file:
        for line_number, line in enumerate(file, 1):
            data = json.loads(line.strip())            
            file_path = os.path.join(tokens_dir, data['id'] + '.npy')
            
            audio_tokens = np.load(file_path)
            weave_audio = weave_tokens(audio_tokens.tolist())
            yield data['raw_text'], weave_audio, data['speaker_id']

In [67]:
dataset = 'lj_speech'
for raw_text, audio_tokens, speaker in load_tokens(dataset_dir=dataset):
    with open('allowed_speakers.jsonl', 'r', encoding='utf-8') as file:
        allowed_speakers = [json.loads(line.strip()) for line in file]
    entry = next((item for item in allowed_speakers if item['dataset'] == dataset and item['speaker'] == speaker), None)
    if entry:
        combined = entry['combined']
    else:
        combined = DEFAULT_SPEAKER
    result = append_tokens(raw_text, audio_tokens, speaker=combined)
    print(result)
    print(tokenizer.decode(result))
    break

tensor([128000,  39628,     11,  ..., 140557, 142493, 144644],
       dtype=torch.int32)
<|begin_of_text|>Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition[convert][spkr_unk][mimi][aco_1442][aco_3215][aco_4335][aco_6316][aco_8867][aco_11749][aco_14000][aco_15936][aco_1463][aco_3102][aco_4741][aco_6853][aco_9438][aco_11023][aco_12637][aco_15755][aco_12][aco_3327][aco_5098][aco_7331][aco_8508][aco_11952][aco_12220][aco_14360][aco_1133][aco_3558][aco_5358][aco_6739][aco_7940][aco_11184][aco_12999][aco_15792][aco_595][aco_2796][aco_5106][aco_7093][aco_8529][aco_11455][aco_13701][aco_15488][aco_1516][aco_2061][aco_5212][aco_6954][aco_8614][aco_10897][aco_13674][aco_14271]<|reserved_special_token_178|>[aco_2240][aco_4828][aco_6706][aco_8849][aco_10478][aco_12378][aco_15799][aco_1051][aco_3045][aco_4302][aco_7390][aco_8951][aco_11910][aco_12257][aco_15170][aco_1639][aco_2699][aco_4302][aco_60

  text_tokens = torch.tensor(tokenizer.encode(text), dtype=torch.int32).view(-1).clone().detach()
  convert_tokens = torch.tensor(tokenizer.encode(CONVERT, add_special_tokens=False), dtype=torch.int32).view(-1).clone().detach()
  continue_tokens = torch.tensor(tokenizer.encode(CONTINUE, add_special_tokens=False), dtype=torch.int32).view(-1).clone().detach()
  speaker_tokens = torch.tensor(tokenizer.encode(speaker, add_special_tokens=False), dtype=torch.int32).view(-1).clone().detach()
  mimi_tokens = torch.tensor(tokenizer.encode(MIMI, add_special_tokens=False), dtype=torch.int32).view(-1).clone().detach()
  stop_tokens = torch.tensor(tokenizer.encode(COMMON_STOP, add_special_tokens=False), dtype=torch.int32).view(-1).clone().detach()


In [None]:
# Sample data (replace this with your actual data)


# Function to get combined ID based on dataset and speaker
def get_combined_id(data, dataset, speaker):
    entry = next((item for item in data if item['dataset'] == dataset and item['speaker'] == speaker), None)
    return entry['combined'] if entry else None

# Example usage
dataset_input = "mls_eng_10k"
speaker_input = "2156"
combined_id = get_combined_id(data, dataset_input, speaker_input)

print(combined_id)  # Output: [spkr_mls_eng_10k_2156]

In [None]:
import pickle

def load_pickle_file(file_path):
    with open(file_path, 'rb') as file:
        data = pickle.load(file)
    return data

pickle_file_path = 'tokens/lj_speech_tokens.pkl'  
data = load_pickle_file(pickle_file_path)
data


In [2]:
from llama_model import Llama, LlamaConfig
def calculate_parameters(config: LlamaConfig) -> int:
    total_params = 0

    total_params += config.vocab_size * config.dim 

    for _ in range(config.n_layers):
        total_params += (config.dim * config.n_heads * (config.dim // config.n_heads)) * 3  # wq, wk, wv
        total_params += (config.n_heads * (config.dim // config.n_heads) * config.dim)  # wo

        hidden_dim = int(4 * config.dim)  # Assuming hidden_dim is 4 * dim
        total_params += (config.dim * hidden_dim) + (hidden_dim * config.dim)  # w1 and w2
        total_params += (config.dim * hidden_dim)  # w3

    total_params += config.dim * 2 * config.n_layers

    return total_params

In [5]:
config = LlamaConfig(
    dim=2048,
    n_layers=12,
    n_heads=16,
    vocab_size=144645,
    max_seq_len=2048
)
total_params = calculate_parameters(config)
print(f"Total number of parameters: {total_params}")

Total number of parameters: 1101588480


In [7]:
import pickle

# Load the pickle file
with open('tokens/lj_speech_tokens.pkl', 'rb') as f:
    data = pickle.load(f)

# Initialize variables to keep track of the maximum length tensor
max_length = 0
max_tensor = None

# Iterate through the tensors in the data
for tensor in data:
    length = tensor.shape[0]  # Assuming the length is the first dimension
    if length > max_length:
        max_length = length
        max_tensor = tensor

# Output the tensor with the maximum length
print("Tensor with the maximum length:", max_tensor)
print("Maximum length:", max_length)

  return torch.load(io.BytesIO(b))


Tensor with the maximum length: tensor([128000,     33,   4361,  ..., 141628, 143887, 144644],
       dtype=torch.int32)
Maximum length: 1063


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import pickle
from llama_model import Llama, LlamaConfig
import torch.nn.functional as F

class TokenDataset(Dataset):
    def __init__(self, tokens_file, max_token=144641):
        with open(tokens_file, 'rb') as f:
            all_tokens = pickle.load(f)
        
        self.tokens = [
            seq.tolist() for seq in all_tokens
        ]
        
        print(f"Total sequences: {len(self.tokens)}")
        print(f"Max sequence length after filtering: {max(len(seq) for seq in self.tokens)}")

    def __len__(self):
        return len(self.tokens)
    
    def __getitem__(self, idx):
        sequence = self.tokens[idx]
        
        # Find the index of the max_token
        if 144641 in sequence:
            split_idx = sequence.index(144641) + 1
        else:
            split_idx = len(sequence)
        
        # Split the sequence into input and output
        input_seq = sequence[:split_idx]
        output_seq = sequence[split_idx:]
        
        # Convert to tensor directly 
        input_seq = torch.tensor(input_seq, dtype=torch.long)
        output_seq = torch.tensor(output_seq, dtype=torch.long)
        
        return input_seq, output_seq

def collate_fn(batch):
    # Unpack the batch
    inputs, targets = zip(*batch)
    
    # Ensure inputs and targets have matching lengths
    max_input_len = max(len(inp) for inp in inputs)
    max_target_len = max(len(tgt) for tgt in targets)
    max_len = min(max_input_len, max_target_len)  # Limit to shorter sequence
    
    # Pad and truncate sequences
    inputs_padded = torch.stack([
        F.pad(inp[:max_len], (0, max_len - len(inp[:max_len])), value=0) 
        for inp in inputs
    ])
    
    targets_padded = torch.stack([
        F.pad(tgt[:max_len], (0, max_len - len(tgt[:max_len])), value=-100)  # -100 for ignore_index in loss
        for tgt in targets
    ])

    # Create attention mask
    attention_mask = (inputs_padded != 0).float()
    
    return inputs_padded, targets_padded, attention_mask

def get_vocab_size(tokens_file):
    with open(tokens_file, 'rb') as f:
        tokens = pickle.load(f)
    
    # Find the maximum token, ensuring it doesn't exceed max_token
    max_token_in_data = max(max(seq) for seq in tokens)
    return max_token_in_data + 1

def train_model(model, dataloader, optimizer, num_epochs=10):
    model.train()
    device = next(model.parameters()).device
  
    for epoch in range(num_epochs):
        total_loss = 0  # Initialize total_loss for each epoch
        for batch in dataloader:
            inputs, targets, attention_mask = batch
            inputs, targets, attention_mask = inputs.to(device), targets.to(device), attention_mask.to(device)
            
            print(f"Inputs shape: {inputs.shape}")
            print(f"Targets shape: {targets.shape}")
            if attention_mask is not None:
                print(f"Attention mask shape: {attention_mask.shape}")
            
            optimizer.zero_grad()
            loss = model.forward_loss(inputs, targets, attention_mask=attention_mask)
            loss.backward(retain_graph=True)
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader):.4f}")

def main():
    # Configuration
    tokens_file = 'tokens/lj_speech_tokens.pkl'
    max_token = 144641  # Specific max token you mentioned
    
    # Get vocabulary size based on the max token
    vocab_size = get_vocab_size(tokens_file)
    print(f"Detected Vocabulary Size: {vocab_size}")
    
    config = LlamaConfig(
        dim=512,  # Reduced model size for faster training
        n_layers=4,
        n_heads=8,
        vocab_size=vocab_size,  # Important: use exact vocab size
        max_seq_len=256,  # Increased to accommodate longer sequences
        multiple_of=256,
        use_scaled_rope=True
    )
    
    # Create dataset and dataloader
    dataset = TokenDataset(tokens_file, max_token)
    dataloader = DataLoader(
        dataset, 
        batch_size=2,
        shuffle=True, 
        collate_fn=collate_fn
    )
    
    # Initialize model
    torch.manual_seed(42)  # For reproducibility
    model = Llama(config).cuda()
    
    # Configure optimizer
    optimizer = model.configure_optimizers(
        weight_decay=0.01, 
        learning_rate=1e-4, 
        betas=(0.9, 0.95), 
        device_type='cuda'
    )
    
    # Enable anomaly detection for debugging
    torch.autograd.set_detect_anomaly(True)
    
    # Train the model
    train_model(model, dataloader, optimizer, num_epochs=10)

if __name__ == "__main__":
    main()


In [None]:
#Inferencing the loss

In [3]:
import torch
import torch.nn.functional as F
import pickle
from llama_model import Llama, LlamaConfig
from torch.utils.data import DataLoader, Dataset

class TokenDataset(Dataset):
    def __init__(self, tokens_file):
        with open(tokens_file, 'rb') as f:
            self.tokens = pickle.load(f)
    
    def __len__(self):
        return len(self.tokens)
    
    def __getitem__(self, idx):
        sequence = self.tokens[idx]
        special_token = 144641
        if special_token in sequence:
            split_idx = (sequence == special_token).nonzero(as_tuple=True)[0].item() + 1
        else:
            split_idx = len(sequence)
        input_seq = sequence[:split_idx]
        output_seq = sequence[split_idx:]
        input_seq = torch.tensor(input_seq, dtype=torch.long)
        output_seq = torch.tensor(output_seq, dtype=torch.long)
        return input_seq, output_seq

def collate_fn(batch):
    inputs, targets = zip(*batch)
    max_input_len = max(len(inp) for inp in inputs)
    max_target_len = max(len(tgt) for tgt in targets)
    max_len = min(max_input_len, max_target_len)
    inputs_padded = torch.stack([F.pad(inp[:max_len], (0, max_len - len(inp[:max_len])), value=0) for inp in inputs])
    targets_padded = torch.stack([F.pad(tgt[:max_len], (0, max_len - len(tgt[:max_len])), value=-100) for tgt in targets])
    attention_mask = (inputs_padded != 0).float()
    return inputs_padded, targets_padded, attention_mask

config = LlamaConfig(
    dim=1024,  
    n_layers=12,
    n_heads=16,
    vocab_size=144646,  
    max_seq_len=2048,  
    multiple_of=2048,
    use_scaled_rope=True
)
model = Llama(config).cuda()
state_dict = torch.load('models/llama_model_epoch_100.pth')
model.load_state_dict(state_dict)
model.eval()

tokens_file = 'tokens/lj_speech_tokens.pkl'
dataset = TokenDataset(tokens_file)
dataloader = DataLoader(dataset, batch_size=24, shuffle=False, collate_fn=collate_fn)

total_loss = 0
total_correct = 0
total_tokens = 0

with torch.no_grad():
    for batch in dataloader:
        inputs, targets, attention_mask = batch
        inputs, targets, attention_mask = inputs.cuda(), targets.cuda(), attention_mask.cuda()
        
        start_pos = 0
        seqlen = inputs.size(1)
        freqs_cis = model._prepare_rotary_embeddings(inputs, seqlen)
        mask = torch.triu(torch.full((seqlen, seqlen), float('-inf'), device=inputs.device), diagonal=1)
        
        h = model.tok_embeddings(inputs)
        for layer in model.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        
        h = model.norm(h)
        outputs = model.output(h)
        
        loss = model.forward_loss(inputs, targets, attention_mask=attention_mask)
        total_loss += loss.item()
        
        _, predicted = torch.max(outputs, dim=-1)
        correct = (predicted == targets).float() * attention_mask
        total_correct += correct.sum().item()
        total_tokens += attention_mask.sum().item()

average_loss = total_loss / len(dataloader)
accuracy = total_correct / total_tokens

print(f"Loss: {average_loss:.4f}, Accuracy: {accuracy:.4f}")

  state_dict = torch.load('models/llama_model_epoch_100.pth')
  return torch.load(io.BytesIO(b))
  input_seq = torch.tensor(input_seq, dtype=torch.long)
  output_seq = torch.tensor(output_seq, dtype=torch.long)


Loss: 0.2663, Accuracy: 0.9007


In [13]:
import torch
from llama_model import Llama, LlamaConfig

config = LlamaConfig(
    dim=1024,  
    n_layers=12,
    n_heads=16,
    vocab_size=144646,  
    max_seq_len=2048,  
    multiple_of=2048,
    use_scaled_rope=True
)
model = Llama(config).cuda()
state_dict = torch.load('models/llama_model_epoch_100.pth')
model.load_state_dict(state_dict)
model.eval()

  state_dict = torch.load('models/llama_model_epoch_100.pth')


Llama(
  (tok_embeddings): Embedding(144646, 1024)
  (layers): ModuleList(
    (0-11): 12 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=1024, out_features=1024, bias=False)
        (wk): Linear(in_features=1024, out_features=1024, bias=False)
        (wv): Linear(in_features=1024, out_features=1024, bias=False)
        (wo): Linear(in_features=1024, out_features=1024, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=1024, out_features=4096, bias=False)
        (w2): Linear(in_features=4096, out_features=1024, bias=False)
        (w3): Linear(in_features=1024, out_features=4096, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=1024, out_features=144646, bias=False)
)

In [4]:
import torch.nn.functional as F
import torch

@torch.inference_mode()
def generate_output_tokens(model, input_tokens, max_new_tokens, temperature=0.8, top_k=50, stop_token=144644):
    idx = torch.tensor(input_tokens, dtype=torch.long).unsqueeze(0).cuda()
    start_pos = 0
    
    for layer in model.layers:
        layer.attention.cache_k = torch.zeros_like(layer.attention.cache_k)
        layer.attention.cache_v = torch.zeros_like(layer.attention.cache_v)
    
    for _ in range(max_new_tokens):
        idx_cond = idx if idx.size(1) <= model.config.max_seq_len else idx[:, -model.config.max_seq_len:]        
        h = model.tok_embeddings(idx_cond)        
        freqs_cis = model._prepare_rotary_embeddings(h, idx_cond.size(1))        
        seqlen = idx_cond.size(1)
        mask = torch.triu(torch.full((seqlen, seqlen), float('-inf'), device=h.device), diagonal=1)
        
        for layer in model.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        
        logits = model.norm(h)
        logits = model.output(logits)[:, -1, :] / temperature
        
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float('-inf')
        
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        
        if stop_token is not None and idx_next.item() == stop_token:
            break
        
        idx = torch.cat((idx, idx_next), dim=1)
        start_pos += 1

    generated_tokens = idx[:, len(input_tokens):]
    return generated_tokens.squeeze().cpu()

In [5]:
input_token = [128000, 135, 286, 790, 144642, 144645, 144641]
output = generate_output_tokens(model, input_token, max_new_tokens=1024)
print(len(output), output)

1024 tensor([140589, 134447, 131458,  ..., 129250, 129304, 129670])


In [33]:
def deflatten_tokens(tokens, n_codebooks, per_codebook_size):
    arr = []
    for i in range(n_codebooks):
        arr.append(tokens[i::n_codebooks])
    acoustic_tokens = np.stack(arr)
    return acoustic_tokens

In [34]:
output.shape

torch.Size([1024])

In [35]:
output

tensor([140589, 134447, 131458,  ..., 129250, 129304, 129670])

In [36]:
import numpy as np
ac = deflatten_tokens(tokens=output,n_codebooks=8,per_codebook_size=2048)

In [37]:
ac

array([[140589, 128933, 129304, ..., 129250, 128933, 128363],
       [134447, 130207, 129953, ..., 128727, 129304, 129953],
       [131458, 128727, 129947, ..., 129304, 129250, 129670],
       ...,
       [129670, 129250, 128363, ..., 128727, 129437, 129250],
       [129437, 129901, 128363, ..., 129282, 129250, 129304],
       [130207, 129947, 129250, ..., 129437, 129304, 129670]])

In [38]:
ac = ac-128000

In [39]:
ac.dtype

dtype('int64')

In [40]:
print(f"ac type: {type(ac)}")
print(f"ac shape: {ac.shape}")
print(f"ac dtype: {ac.dtype}")
print(f"ac sample values: {ac[:5]}")

ac type: <class 'numpy.ndarray'>
ac shape: (8, 128)
ac dtype: int64
ac sample values: [[12589   933  1304  1953  1054  1437   363  1947   363  2207   727  1250
   1953  1947  1054   363  1371  1947   363  1947  2207  1304  1670  1947
   1670  1947   363  1670   933  2273   868  2273  1670  1054  1282  1670
   1437  1953   363  1371  1054  1393  1250  1304  1670  1304  1670  1304
    933  1953  1947  1670  1250  1282  1250  2273  1670  1953  1670  2207
   1250   363  1304  1250  1953   363  1953  1250  1953  2207  1250   363
   1670  1670   363  1947  1670  1054  1304  2207   933   933  1304  1304
   1250   727  1437  1947  1250  1437  1437   363  1437  1947  1670  1437
   2273  2273  1437   933  1953  1953  1054  1437  1304   363  1953   363
   2273  1437  1282  1250  1230   363  1670  1304  1250  2207  1437  1953
   1947  1250   933  1282  1670  1250   933   363]
 [ 6447  2207  1953  2273  1054   727  1437  1947  1947   398  1054  1054
   1670  1953  1947  1250  1670  2207  1250  1953

In [41]:
# If ac_tensor is 1D, add a dimension to make it 2D
if ac_tensor.ndim == 1:
    ac_tensor = ac_tensor.unsqueeze(0)  # Add a batch dimension

# Ensure it's a long tensor
ac_tensor = ac_tensor.long()

# Verify the shape before decoding
print(f"Tensor shape before decode: {ac_tensor.shape}")

# Now try decoding
audio = mimi_tokenizer.decode(ac_tensor)

# Save the audio
torchaudio.save(
    'test.wav', 
    audio, 
    sample_rate=mimi_tokenizer.sampling_rate,
    compression=torchaudio.io.CodecConfig(bit_rate=128000)
)

Tensor shape before decode: torch.Size([1, 1024])


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [42]:
import torch
def flatten_tokens(arr: torch.tensor,
                   per_codebook_size: int):
    
    c, n = arr.shape
    i_values = np.arange(c) * per_codebook_size
    arr += i_values.reshape(c, 1)
    flat_arr = arr.reshape(c * n, order='F')
    return flat_arr

In [None]:
def deflatten_tokens(tokens, n_codebooks, per_codebook_size):
    arr = []
    for i in range(n_codebooks):
        arr.append(tokens[i::n_codebooks] - per_codebook_size * i)
    acoustic_tokens = np.stack(arr)
    return acoustic_tokens

In [17]:
import pickle

def load_pickle_file(file_path):
    with open(file_path, 'rb') as file:
        data = pickle.load(file)
    return data

pickle_file_path = 'tokens/lj_speech_tokens.pkl'  
data = load_pickle_file(pickle_file_path)
maxlen = 0
for item in data:
    maxlen = max(maxlen,len(item))
maxlen

1063

In [17]:
tensor([128000,    258,   1694,  71561,   6617,     13, 144642, 144645, 144641,
        128678, 130835, 132339, 134203, 137663, 139106, 142070, 142354, 129536,
        130238, 132125, 134707, 136590, 138658, 141753, 142495, 128362, 131923,
        133219, 135748, 137040, 138904, 140309, 142608, 129057, 131923, 132339,
        134833, 136254, 138522, 141680, 143399, 128389, 131442, 133814, 135748,
        136937, 139761, 142242, 143471, 129558, 130834, 133087, 135506, 137091,
        139920, 142253, 143824, 129054, 130238, 133462, 135701, 136499, 138697,
        141120, 144007, 129343, 130297, 132950, 135831, 137259, 139195, 141398,
        143899, 128321, 131400, 133395, 134382, 136325, 139199, 140550, 143850,
        129839, 131895, 132960, 136035, 137376, 139298, 142203, 144094, 128902,
        130505, 132643, 135349, 136999, 139242, 140509, 142800, 129389, 131868,
        134110, 134746, 137426, 138776, 141839, 143836, 129389, 131787, 132564,
        136076, 137925, 138423, 141994, 142439, 128483, 131513, 133818, 135280,
        136566, 139729, 141788, 143399, 129003, 130468, 132238, 134491, 137093,
        140108, 142029, 144369, 128722, 131481, 132955, 135161, 136952, 139786,
        142080, 142402, 128049, 130442, 133451, 136184, 137869, 139712, 140897,
        144011, 129407, 132080, 132740, 134385, 136581, 139019, 140432, 142363,
        128653, 131845, 134115, 135469, 137923, 138431, 141712, 143732, 128283,
        130951, 133686, 135985, 136323, 139514, 140513, 143179, 129655, 130807,
        133689, 135412, 136692, 139551, 141882, 142395, 129287, 130868, 133245,
        135905, 136533, 139764, 140614, 142874, 129120, 130955, 133112, 134286,
        137134, 138251, 141905, 142930, 130042, 130364, 132879, 135492, 136612,
        140107, 141107, 144344, 144644], dtype=torch.int32)

tensor([129536, 135589, 132156,  ..., 129692, 129698, 128678])

In [16]:
tokenizer = TTSTokenizer()
print(tokenizer.decode(output))

text vocab size 128000
[aco_1535][aco_9662][aco_4155][aco_2433][aco_3836][aco_471][aco_142][aco_1048][aco_1137][aco_994][aco_994][aco_798][aco_471][aco_1048][aco_1026][aco_1697][aco_1691][aco_1048][aco_1697][aco_1181][aco_1115][aco_677][aco_1137][aco_1137][aco_1181][aco_1048][aco_107][aco_1691][aco_1181][aco_798][aco_2017][aco_107][aco_1691][aco_107][aco_677][aco_107][aco_2017][aco_107][aco_994][aco_1048][aco_798][aco_2017][aco_1691][aco_612][aco_1691][aco_612][aco_2017][aco_798][aco_1691][aco_994][aco_1697][aco_107][aco_994][aco_994][aco_107][aco_107][aco_1026][aco_1181][aco_1181][aco_107][aco_1048][aco_994][aco_677][aco_1697][aco_1697][aco_1951][aco_2017][aco_1691][aco_677][aco_2017][aco_798][aco_1181][aco_994][aco_1697][aco_677][aco_1951][aco_471][aco_1137][aco_994][aco_1691][aco_1048][aco_1697][aco_1697][aco_107][aco_1048][aco_1115][aco_1697][aco_1026][aco_1691][aco_1181][aco_60][aco_1691][aco_994][aco_798][aco_1691][aco_1181][aco_1115][aco_1137][aco_1048][aco_994][aco_1414][aco_19