In [1]:
import torch
from supervoice_valle import SupervoceNARModel, SupervoceARModel, Tokenizer
from train.dataset import load_sampler
from IPython.display import Audio, display
import matplotlib.pyplot as plt
from vocos import Vocos

In [2]:
# Everything below moves models/tensors onto this device; fall back to CPU when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"
vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(device)
tokenizer = Tokenizer("./tokenizer_text.model")
sampler = load_sampler("./external_datasets/libriheavy/libriheavy_cuts_small.jsonl.gz", "./external_datasets/libriheavy-encodec/", 1)

def _load_checkpoint(model, path):
    """Load ``checkpoint['model']`` from ``path`` into ``model`` and prepare it for inference.

    The checkpoint is read onto the CPU first (``map_location="cpu"``) so it
    loads regardless of the device it was saved from. The model is then moved
    to ``device`` and switched to eval mode.

    Returns:
        ``(model, step)`` where ``step`` is ``checkpoint['step']``.
    """
    checkpoint = torch.load(path, map_location = "cpu")
    model.load_state_dict(checkpoint['model'])
    model = model.to(device)
    model.eval()
    return model, checkpoint['step']

# Load NAR
nar_model, nar_step = _load_checkpoint(SupervoceNARModel(), "./output/valle-35.pt")
print(nar_step)

# Load AR
ar_model, ar_step = _load_checkpoint(SupervoceARModel(), "./output/valle-ar-1.pt")
print(ar_step)

# Previously `step` was overwritten and ended up holding the AR step; keep that
# for compatibility, but both steps are now available separately.
step = ar_step



506000
4000


In [3]:
def inference_nar(text, audio, coarse_tokens):
    """Predict codebooks 1..7 with the NAR model, greedily, one codebook at a time.

    Each step feeds the codebooks predicted so far (starting from
    ``coarse_tokens``) back into ``nar_model`` together with the text and the
    audio prompt. It then appends the argmax token of every position.

    Args:
        text: transcript string, encoded with the global ``tokenizer``.
        audio: prompt codes; concatenated in front of the result along dim 1.
            In this notebook it is called with ``audio[:, :75*3]`` (all
            codebooks of the prompt). NOTE(review): assumes shape (8, T_prompt).
        coarse_tokens: codebook-0 tokens of the part to generate. In this
            notebook it is a 1-D slice of codebook 0.

    Returns:
        ``torch.cat([audio, generated], dim=1)``, where ``generated`` stacks
        ``coarse_tokens`` and the seven predicted codebooks.
    """
    condition_text = tokenizer.encode(text).to(device)
    condition_audio = audio.to(device)
    codebooks = [coarse_tokens.to(device)]
    # Pure inference: no autograd graph is needed, so don't build one.
    with torch.no_grad():
        for codec in range(1, 8):
            logits = nar_model(
                condition_text = [condition_text],
                condition_audio = [condition_audio],
                audio = [torch.stack(codebooks)],
                codec = [codec]
            )[0]
            # Greedy choice; the argmax of the logits equals the argmax of their softmax.
            codebooks.append(torch.argmax(logits, dim=-1))
    generated = torch.stack(codebooks)
    return torch.cat([condition_audio, generated], dim = 1)

def inference_ar(text, audio, max_tokens = 1000, temperature = 0.0):
    """Autoregressively extend a sequence of codebook-0 tokens with the AR model.

    Args:
        text: transcript string, encoded with the global ``tokenizer``.
        audio: 1-D prompt of codebook-0 tokens (called here with
            ``audio[0, :75*3]``). It is kept at the start of the result.
        max_tokens: no further tokens are generated once the sequence is longer
            than this, so the result has at most ``max_tokens + 1`` tokens,
            matching the previous hard-coded limit of 1000.
        temperature: ``0`` (default) picks the most likely token at every step
            (greedy, the original behaviour). A positive value samples from
            ``softmax(logits / temperature)`` instead, which can help avoid the
            repetition loops greedy decoding tends to fall into.

    Returns:
        1-D tensor: the prompt followed by the generated tokens. The stop token
        (any id above 1023) is not included.
    """
    condition_text = tokenizer.encode(text).to(device)
    audio_tokens = audio.to(device)
    # Pure inference: disable autograd so each step doesn't build a graph.
    with torch.no_grad():
        # Checking the cap before the forward pass avoids one wasted model call.
        while audio_tokens.shape[0] <= max_tokens:
            # Only the logits for the next (last) position are needed.
            logits = ar_model(
                text = [condition_text],
                audio = [audio_tokens]
            )[0][-1]
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                code = torch.multinomial(probs, num_samples=1)
            else:
                code = torch.argmax(logits, dim=-1, keepdim=True)
            # Ids 0..1023 are codec tokens; anything larger ends generation
            # (presumably the model's end-of-audio token).
            if code[0] > 1023:
                break
            audio_tokens = torch.cat([audio_tokens, code])
    return audio_tokens

def decode(tokens, bandwidth_id = 2):
    """Turn Encodec codes into a waveform with the global Vocos vocoder.

    Args:
        tokens: codec codes, passed to ``vocos.codes_to_features`` after being
            moved to ``device``.
        bandwidth_id: index into the vocoder's bandwidth list. The default 2 is
            6 kbps for ``vocos-encodec-24khz``, as before.

    Returns:
        Whatever ``vocos.decode`` returns. It is played at 24 kHz elsewhere in
        this notebook.
    """
    # Use the notebook-wide device rather than hard-coding "cuda".
    features = vocos.codes_to_features(tokens.to(device))
    bandwidth = torch.tensor([bandwidth_id]).to(device)
    return vocos.decode(features, bandwidth_id=bandwidth)

In [4]:
# Draw one utterance. The sampler's outputs are indexed with [0], so keep only
# the first item of each (presumably a batch of size 1 — verify against load_sampler).
batch_audio, batch_text = sampler()
audio = batch_audio[0]
text = batch_text[0]
print(text)
# Play back the ground-truth codes through the vocoder for reference.
waveform = decode(audio).cpu()
display(Audio(data=waveform, rate=24000))

I flew out of the house, and concealed myself in a thicket of bushes. There I remained in an agony of fear for two hours. Suddenly, a reptile of some kind seized my leg. In my fright, I struck a blow which loosened its hold, but I could not tell whether I had killed it; it was so dark, I could not see what it was; I only knew it was something cold and slimy. The pain I felt soon indicated that the bite was poisonous. I was compelled to leave my place of concealment, and I groped my way back into the house.


In [5]:
# NAR prediction: the first 75*3 frames (about 3 s at Encodec's 75 frames/s) of every
# codebook form the prompt. The ground-truth codebook 0 is supplied for the rest,
# so only codebooks 1-7 are predicted.
prompt_frames = 75 * 3
prompt_codes = audio[:, :prompt_frames]
target_coarse = audio[0, prompt_frames:]
predicted = inference_nar(text, prompt_codes, target_coarse)
display(Audio(data=decode(predicted).cpu(), rate=24000))

In [6]:
# AR + NAR prediction: the AR model continues codebook 0 from the 75*3-frame prompt.
# The NAR model then fills in codebooks 1-7 for the newly generated frames.
prompt_frames = 75 * 3
predicted_ar = inference_ar(text, audio[0, :prompt_frames])
# inference_ar returns the prompt followed by the new tokens; keep only the new ones.
generated_coarse = predicted_ar[prompt_frames:]
predicted = inference_nar(text, audio[:, :prompt_frames], generated_coarse)
display(Audio(data=decode(predicted).cpu(), rate=24000))

In [7]:
# Inspect the raw codebook-0 tokens produced by the AR pass (prompt included).
ar_tokens = predicted_ar.cpu().numpy()
print(ar_tokens)

[ 738  475  133  133  876  876  835  133  835  126 1017  133  738  876
  106  133  835  738  738  106  738  876  835  133  876  276   73  951
  951  400 1008  979   25  103  103  661  463  126   25  537  373  373
  698  311  559  994 1023  921  291  155  942  921  501  230  716  730
  730  372  372  958  683  126  730  216  699  432  126 1011  491  699
  533  939  501  642  602  602  602  677  677  499 1008  850  213   85
  372  372  627  747   20   11  402  402  393 1001  523  753  879  699
  537  373   53  160  819  176  373  983  275  192   73  372 1008  491
  224  879  604  433   25  779  999  276  433   25  926  457  136  967
  321  604 1017  835   25  463  463  103  463  276  698  921  796  912
  540   95  683  216  833  321  126 1019  661  133  751  321  523  428
  372  192  192  872 1011  604  779  463  463  463  463  463  373  887
  976  593  627  501  428  428  428   20  393  879  432  463  395  395
  537  537  395  583  573  208  871  871  871  666  666  879  604  463
  463 