In [1]:
import os
from time import time
import torch
import torchaudio
from cpc.feature_loader import loadModel, getCheckpointData
from cpc.train import getCriterion

## Define helper functions

In [2]:
def loadCriterion(pathCheckpoint, downsampling, nSpeakers, nPhones):
    """Build the CPC criterion described by a checkpoint and load its weights.

    Args:
        pathCheckpoint: path to a CPC checkpoint file; the training arguments are
            read from its directory via ``getCheckpointData``.
        downsampling: encoder downsampling factor, forwarded to ``getCriterion``.
        nSpeakers: number of speakers, forwarded to ``getCriterion``.
        nPhones: number of phones (may be None), forwarded to ``getCriterion``.

    Returns:
        The criterion with the checkpoint's ``"cpcCriterion"`` weights loaded
        (tensors mapped to CPU; the caller moves the module to a device).

    Raises:
        RuntimeError: if the weights do not fit the criterion, including after
            the speaker-embedding key rename described below.
    """
    _, _, locArgs = getCheckpointData(os.path.dirname(pathCheckpoint))
    criterion = getCriterion(locArgs, downsampling, nSpeakers, nPhones)

    state_dict = torch.load(pathCheckpoint, 'cpu')
    criterion_state = state_dict["cpcCriterion"]

    try:
        criterion.load_state_dict(criterion_state)
    except RuntimeError:
        # For newer versions of CPC the speaker embedding is stored under
        # 'speaker_norm.emb.weight', while this criterion expects 'speakerEmb.weight'.
        # Only retry with the renamed key when that key is actually present;
        # otherwise the mismatch has another cause, so surface the original error
        # instead of a confusing KeyError.
        legacy_key = 'speaker_norm.emb.weight'
        if legacy_key not in criterion_state:
            raise
        criterion_state['speakerEmb.weight'] = criterion_state.pop(legacy_key)
        criterion.load_state_dict(criterion_state)

    return criterion

def getPositiveSamples(encodedData, nPredicts=12):
    """Collect the positive (true future) encodings for every prediction step.

    Used by compute_score_CPC in place of ``cpcCriterion.sampleClean`` when no
    negative samples are wanted (the 'sigmoid' estimator).

    Args:
        encodedData: encoder output indexed as (batch, time, dim).
        nPredicts: number of future steps the criterion predicts.

    Returns:
        A list of ``nPredicts`` tensors. For T >= nPredicts, entry ``step - 1`` is
        ``encodedData[:, step : T - nPredicts + step]`` viewed as
        (batch, 1, T - nPredicts, dim).
    """
    nBatch, _, nDim = encodedData.size()

    def _shifted(step):
        # For every step but the last, (step - nPredicts) is the negative stop index;
        # for the last step it is 0, and "or None" turns it into "slice to the end".
        window = encodedData[:, step:(step - nPredicts) or None]
        return window.view(nBatch, 1, window.size(1), nDim)

    return [_shifted(step) for step in range(1, nPredicts + 1)]

def scoring(logprobs):
    """Print and return the pairwise accuracy of a flat list of scores.

    Scores are read as consecutive pairs ``(logprobs[2*i], logprobs[2*i + 1])``.
    The first element of each pair is the one expected to score higher. A pair
    is counted as correct only when it is strictly greater, so ties count as
    incorrect. With an odd-length list the trailing unpaired score is ignored.

    Args:
        logprobs: sequence of comparable numbers (e.g. floats from compute_score_CPC).

    Returns:
        The fraction of correct pairs in [0, 1], or None when there is no
        complete pair (nothing is printed in that case).
    """
    n = len(logprobs) // 2
    true_preds = sum(1 for i in range(n) if logprobs[2 * i] > logprobs[2 * i + 1])
    if n == 0:
        return None
    print("Test accuracy: {}/{} ({:.2f}%)".format(true_preds, n, 100*true_preds/n))
    return true_preds / n

## Function to compute the CPC pseudo log-probability score

In [3]:
def compute_score_CPC(seqPath, 
                      cpcModel, 
                      cpcCriterion, 
                      speakerLabel=0,
                      nTemporal=12,
                      logits_scaling=1,
                      reduce_method='sum',
                      prob_estimator='negative_sampling',
                      n_negative_sampling=None
                     ):
    '''
    Compute a CPC pseudo log-probability score for one audio file.

    The file is encoded with ``cpcModel``. The criterion then scores, from each
    context frame, the candidates for each of its ``nPredicts`` future steps.
    For each of the first ``nTemporal`` steps, the log-probability of the true
    (positive) candidate is reduced over time with ``reduce_method``. The
    per-step values are averaged over the steps actually used.

    Comment on some useful args:
        logits_scaling:  put this high to avoid having 0. log proba for near temporal steps when 
                         using sigmoid, but it seems that 1 (default) gives the best results
         reduce_method:  'sum' or 'mean' over time; 'sum' seems to work best
        prob_estimator:  using 'sigmoid' is faster as we don't need to compute negative samples,
                         but using 'negative_sampling' seems to have better results as this is
                         the way the CPC model is trained (however this makes the scores vary
                         between runs)
   n_negative_sampling:  leave this to None and the criterion keeps its configured number of
                         negative samples (128 by default); any other value applies to this
                         call only
             nTemporal:  number of prediction steps to use (>= 1); capped at the criterion's
                         nPredicts

    Returns:
        The score as a Python float (a log-probability, so higher = more probable).

    Raises:
        ValueError: if the file yields no more encoded frames than nPredicts.
    '''
    assert reduce_method in ['sum', 'mean']
    assert prob_estimator in ['sigmoid', 'negative_sampling']
    assert nTemporal >= 1, f"nTemporal must be >= 1, got {nTemporal}"
    # Run on the device the model lives on instead of assuming CUDA.
    device = next(cpcModel.parameters()).device
    with torch.no_grad():
        # Read the input signal: channels x frames
        seq = torchaudio.load(seqPath)[0]
        if seq.size(0) > 1:
            # Down-mix multi-channel audio; flattening would concatenate the channels in time.
            seq = seq.mean(dim=0, keepdim=True)
        seq = seq.view(1, 1, -1).to(device) # 1 x 1 x frames

        # Read CPC features. Clear the recurrent state first so that each file is
        # scored independently (presumably needed because gAR.keepHidden may be set).
        cpcModel.gAR.hidden = None
        cFeature, encodedData, _ = cpcModel(seq, label=None)
        ## cFeature: 1 x T x D_feat
        ## encodedData: 1 x T x D_enc

        # Prepare CPC features for criterion
        batchSize, seqSize, _ = cFeature.size()
        windowSize = seqSize - cpcCriterion.nPredicts # T - nPredicts
        if windowSize <= 0:
            raise ValueError(
                f"{seqPath}: {seqSize} encoded frames is too short for "
                f"{cpcCriterion.nPredicts} prediction steps")
        cFeature = cFeature[:, :windowSize] # 1 x (T - nPredicts) x D_feat

        # Get the candidate encodings for each prediction step
        if prob_estimator == 'negative_sampling':
            if n_negative_sampling is None:
                sampledData, _ = cpcCriterion.sampleClean(encodedData, windowSize)
            else:
                # Override the number of negatives for this call only and restore it,
                # so later calls are not silently affected by this one.
                previousNegatives = cpcCriterion.negativeSamplingExt
                cpcCriterion.negativeSamplingExt = n_negative_sampling
                try:
                    sampledData, _ = cpcCriterion.sampleClean(encodedData, windowSize)
                finally:
                    cpcCriterion.negativeSamplingExt = previousNegatives
            # nPredicts x [1 x (1 + n_negative) x (T - nPredicts) x D_enc]
        else:
            sampledData = getPositiveSamples(encodedData, cpcCriterion.nPredicts) # nPredicts x [1 x 1 x (T - nPredicts) x D_enc]

        # Speaker embeddings
        if cpcCriterion.speakerEmb is not None:
            label = torch.tensor(speakerLabel, device=cFeature.device)
            l_ = label.view(batchSize, 1).expand(batchSize, windowSize) # 1 x (T - nPredicts)
            embeddedSpeaker = cpcCriterion.speakerEmb(l_) # 1 x (T - nPredicts) x D_spkemb
            cFeature = torch.cat([cFeature, embeddedSpeaker], dim=2) # 1 x (T - nPredicts) x (D_feat+D_spkemb)

        # Compute the criterion outputs
        predictions = cpcCriterion.wPrediction(cFeature, sampledData) # nPredicts x [1 x (1 + n_negative) x (T - nPredicts)]

        # Compute the pseudo log-probas; row 0 of each step's candidates is the positive.
        steps = predictions[:nTemporal]
        lp_score = 0.
        for outputs in steps:
            logits = outputs[0]/logits_scaling
            if logits.size(0) == 1:
                # Only the positive candidate: binary (sigmoid) probability.
                # logsigmoid is the numerically stable form of sigmoid().log().
                logprobs = torch.nn.functional.logsigmoid(logits[0])
            else:
                # Positive vs. negatives: softmax over the candidate axis.
                # log_softmax avoids -inf from log() of an underflowed probability.
                logprobs = logits.log_softmax(0)[0]
            if reduce_method == 'sum':
                lp_score += logprobs.sum()
            else:
                lp_score += logprobs.mean()
        # Average over the steps actually used (fewer than nTemporal if nTemporal > nPredicts).
        lp_score /= len(steps)

    return lp_score.item()

## Load CPC model and criterion

In [4]:
# Checkpoint to evaluate. Set the CPC_CHECKPOINT environment variable to use another
# checkpoint without editing this cell.
# Alternative checkpoint:
#   /private/home/mriviere/FairInternal/CPC_torch/Librispeech100/channel_norm_attention_dropout_2levels_multihead/checkpoint_170.pt
DEFAULT_CHECKPOINT = "/private/home/mriviere/FairInternal/CPC_torch/Librilight_subsample/6k_top_ctc/checkpoint_30.pt"
pathCheckpoint = os.environ.get("CPC_CHECKPOINT", DEFAULT_CHECKPOINT)
# Number of speakers passed to the criterion; presumably must match the value used
# when the checkpoint was trained — TODO confirm for other checkpoints.
N_SPEAKERS = 7504

# Load CPC model; keepHidden is enabled on the autoregressive part
# (compute_score_CPC resets gAR.hidden before every file).
cpcModel = loadModel([pathCheckpoint])[0].cuda()
cpcModel.gAR.keepHidden = True
cpcModel.eval()
# Load CPC criterion with the encoder's downsampling factor and nPhones=None
cpcCriterion = loadCriterion(pathCheckpoint, cpcModel.gEncoder.DOWNSAMPLING, N_SPEAKERS, None).cuda()
cpcCriterion.eval()
print('CPC model and criterion loaded!')

Loading checkpoint /private/home/mriviere/FairInternal/CPC_torch/Librilight_subsample/6k_top_ctc/checkpoint_30.pt
Loading the state dict at /private/home/mriviere/FairInternal/CPC_torch/Librilight_subsample/6k_top_ctc/checkpoint_30.pt
Using 6 speaker embeddings for 7504 speakers
Activating multi-head rnn
CPC model and criterion loaded!


### Testing on a single file

In [5]:
# Sanity check: score one sWUGGY dev file with the default settings.
compute_score_CPC(
    seqPath=('/private/home/ntuanh/Projects/ZeroSpeech/data/test/sWUGGY/final/audio/'
             'synthesis_16k/inter/dev/voiceA/0_obscenely_inter_dev_final_voiceA.wav'),
    cpcModel=cpcModel,
    cpcCriterion=cpcCriterion,
)

-232.22927856445312

### sWUGGY dev

In [6]:
pathAudio = "/private/home/ntuanh/Projects/ZeroSpeech/data/test/sWUGGY/final/audio/synthesis_16k/inter/dev/voiceA/"
# Keep only the .wav files, ordered by their leading numeric index ("<idx>_<word>_...").
# `scoring` later compares the scores of consecutive files in this order.
filelist = list(filter(lambda name: name.endswith('.wav'), os.listdir(pathAudio)))
filelist.sort(key=lambda name: int(name.split('_')[0]))
print(f'{len(filelist)} files found!')
filelist[:6]

10000 files found!


['0_obscenely_inter_dev_final_voiceA.wav',
 '1_opsenely_inter_dev_final_voiceA.wav',
 '2_oxidation_inter_dev_final_voiceA.wav',
 '3_accidation_inter_dev_final_voiceA.wav',
 '4_alida_inter_dev_final_voiceA.wav',
 '5_aleca_inter_dev_final_voiceA.wav']

In [7]:
%%time
stime = time()
scores_files = []
for i, file in enumerate(filelist):
    sc = compute_score_CPC(seqPath=os.path.join(pathAudio, file), 
                            cpcModel=cpcModel, 
                            cpcCriterion=cpcCriterion)
    scores_files.append(sc)
    if i % 100 == 0:
        print(f'{i+1}/{len(filelist)} files computed in {time()-stime:.2f} seconds', end = '\r')
print(f'\n...done in {time()-stime:.2f} seconds.')
scoring(scores_files)

9901/10000 files computed in 83.15 seconds
...done in 83.94 seconds.
Test accuracy: 2877/5000 (57.54%)
CPU times: user 1min 16s, sys: 1.29 s, total: 1min 17s
Wall time: 1min 23s


### text-WUGGY set

In [8]:
pathAudio = "/private/home/ntuanh/Projects/ZeroSpeech/data/test/WUGGY-text/devLS_10k/audio/wavs-16k/"
# Keep only the .wav files, ordered by their leading numeric index ("<idx>_<word>.wav").
# `scoring` later compares the scores of consecutive files in this order.
filelist = list(filter(lambda name: name.endswith('.wav'), os.listdir(pathAudio)))
filelist.sort(key=lambda name: int(name.split('_')[0]))
print(f'{len(filelist)} files found!')
filelist[:6]

20000 files found!


['0_aback.wav',
 '1_aball.wav',
 '2_abandon.wav',
 '3_agandon.wav',
 '4_abandoning.wav',
 '5_afandoning.wav']

In [9]:
%%time
stime = time()
scores_files = []
for i, file in enumerate(filelist):
    sc = compute_score_CPC(seqPath=os.path.join(pathAudio, file), 
                            cpcModel=cpcModel, 
                            cpcCriterion=cpcCriterion)
    scores_files.append(sc)
    if i % 100 == 0:
        print(f'{i+1}/{len(filelist)} files computed in {time()-stime:.2f} seconds', end = '\r')
print(f'\n...done in {time()-stime:.2f} seconds.')
scoring(scores_files)

19901/20000 files computed in 165.93 seconds
...done in 166.74 seconds.
Test accuracy: 6002/10000 (60.02%)
CPU times: user 2min 36s, sys: 2.02 s, total: 2min 38s
Wall time: 2min 46s
