In [1]:
import os
import importlib
import data.kdataset as kdataset
from SpeakerNet import SpeakerNet
from util import *
import gc
import torch
import yaml

from validation import *

In [2]:
# Experiment config; point this at 'toy.yaml' for a small debug run.
CONFIG_PATH = './configs/K_NeXt_TDNN.yaml'

with open(CONFIG_PATH) as file:
    config = yaml.safe_load(file)

# Hyper-parameters shared by the cells below.
params = config['PARAMS']
BATCH_SIZE = params['BATCH_SIZE']
# float(): PyYAML loads exponent literals without a decimal point (e.g. `1e-3`)
# as strings — presumably the reason for this cast.
BASE_LR = float(params['BASE_LR'])
NUM_WORKER = params['NUM_WORKER']
CHANNEL_SIZE = params['CHANNEL_SIZE']
EMBEDDING_SIZE = params['EMBEDDING_SIZE']
MAX_FRAME = params['MAX_FRAME']
SAMPLING_RATE = params['SAMPLING_RATE']
MAX_EPOCH = params['MAX_EPOCH']
DEVICE = params['DEVICE']
BASE_PATH = params['BASE_PATH']

In [None]:
from train import train

# True: resume from `ckpt_name`; False: train from scratch.
RESUME = True
ckpt_name = 'ckpt_5.pt'

if RESUME:
    train(config, MAX_EPOCH, BATCH_SIZE, NUM_WORKER, BASE_LR, BASE_PATH, DEVICE,
          ckpt=True, ckpt_name=ckpt_name)
else:
    train(config, MAX_EPOCH, BATCH_SIZE, NUM_WORKER, BASE_LR, BASE_PATH, DEVICE)

Setting Train Dataset...
Read pkl...

Number of speakers : 1471
Number of utterances : 5875802

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5875802 entries, 0 to 5875801
Data columns (total 3 columns):
 #   Column    Dtype 
---  ------    ----- 
 0   wavfiles  object
 1   labels    int64 
 2   speakers  object
dtypes: int64(1), object(2)
memory usage: 134.5+ MB
None

Setting Model...
Initialised AAMSoftmax margin 0.300 scale 40.000
⚡ feature_extractor ⚡
Mel_Spectrogram(
  (pre_emphasis): PreEmphasis()
  (mel_spectrogram): MelSpectrogram(
    (spectrogram): Spectrogram()
    (mel_scale): MelScale()
  )
)
⚡ spec_aug ⚡
SpecAugment(
  (fm): FrequencyMasking()
  (tm): TimeMasking()
)
⚡ model ⚡
NeXtTDNN(
  (stem): ModuleList(
    (0): Sequential(
      (0): Conv1d(80, 192, kernel_size=(4,), stride=(1,))
      (1): LayerNorm()
    )
  )
  (stages): ModuleList(
    (0-2): 3 x Sequential(
      (0): TSConvNeXt_light(
        (dwconv): Conv1d(192, 192, kernel_size=(65,), stride=(1,), paddi

  checkpoint = torch.load(os.path.join(config['CHECKPOINT']['ckpt_path'], ckpt_name))


SpeakerNet(
  1.63 M, 84.888% Params, 420.52 MMac, 98.346% MACs, 
  (feature_extractor): Mel_Spectrogram(
    0, 0.000% Params, 0.0 Mac, 0.000% MACs, 
    (pre_emphasis): PreEmphasis(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
    (mel_spectrogram): MelSpectrogram(
      0, 0.000% Params, 0.0 Mac, 0.000% MACs, 
      (spectrogram): Spectrogram(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
      (mel_scale): MelScale(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
    )
  )
  (spec_aug): SpecAugment(
    0, 0.000% Params, 0.0 Mac, 0.000% MACs, 
    (fm): FrequencyMasking(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
    (tm): TimeMasking(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
  )
  (model): NeXtTDNN(
    1.32 M, 68.832% Params, 395.25 MMac, 92.439% MACs, 
    (stem): ModuleList(
      (0): Sequential(
        61.63 k, 3.215% Params, 18.43 MMac, 4.310% MACs, 
        (0): Conv1d(61.63 k, 3.215% Params, 18.43 MMac, 4.310% MACs, 80, 192, kernel_size=(4,), stride=(1,))
        (1): LayerNorm(0, 0.000% P

 12%|█▏        | 1395/11476 [3:14:41<23:27:30,  8.38s/it, 1000 step loss : 6.088515281677246]

In [3]:
print('Load train dataset..')
# Positional arguments follow the key order of the TRAIN_DATASET section of the config.
asv_dataset = kdataset.asv_dataset(*config['TRAIN_DATASET'].values())

# The full dataset feeds the loader (no train/validation split is made here).
train_loader = torch.utils.data.DataLoader(
    dataset=asv_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKER,
    pin_memory=True,
    drop_last=True,
)


Load train dataset..
Read pkl...

Number of speakers : 1471
Number of utterances : 5875802

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5875802 entries, 0 to 5875801
Data columns (total 3 columns):
 #   Column    Dtype 
---  ------    ----- 
 0   wavfiles  object
 1   labels    int64 
 2   speakers  object
dtypes: int64(1), object(2)
memory usage: 134.5+ MB
None


In [4]:
def _load_attr(module_name, attr_name):
    """Import ``module_name`` and return its attribute ``attr_name``."""
    return getattr(importlib.import_module(module_name), attr_name)

# Optimizer / scheduler hyper-parameters (previously inline literals).
WEIGHT_DECAY = 0.01
LR_STEP_SIZE = 10   # the training loops call scheduler.step() once per epoch
LR_GAMMA = 0.8

# Front-end: feature extraction followed by SpecAugment.
FeatureExtractorCls = _load_attr('preprocessing.mel_transform', 'feature_extractor')
feature_extractor = FeatureExtractorCls(*config['FEATURE_EXTRACTOR'].values()).to(DEVICE)

SpecAugCls = _load_attr('preprocessing.spec_aug', 'spec_aug')
spec_aug = SpecAugCls(*config['SPEC_AUG'].values()).to(DEVICE)

# Backbone.
model_cfg = config['MODEL']
ModelCls = _load_attr('models.NeXt_TDNN', 'MainModel')
model = ModelCls(
    depths=model_cfg['depths'],
    dims=model_cfg['dims'],
    kernel_size=model_cfg['kernel_size'],
    block=model_cfg['block'],
).to(DEVICE)

# Pooling / embedding head.
AggregationCls = _load_attr('aggregation.vap_bn_tanh_fc_bn', 'Aggregation')
aggregation = AggregationCls(*config['AGGREGATION'].values()).to(DEVICE)

LossCls = _load_attr('loss.aamsoftmax', 'LossFunction')
# Not moved explicitly; presumably reaches DEVICE via SpeakerNet(...).to(DEVICE) — verify.
loss_function = LossCls(*config['LOSS'].values())

speaker_net = SpeakerNet(
    feature_extractor=feature_extractor,
    spec_aug=spec_aug,
    model=model,
    aggregation=aggregation,
    loss_function=loss_function,
).to(DEVICE)

OptimizerCls = _load_attr('optimizer.adamw', 'Optimizer')
# The learning rate is scaled linearly with the batch size.
optimizer = OptimizerCls(speaker_net.parameters(), lr=BASE_LR * BATCH_SIZE, weight_decay=WEIGHT_DECAY)

SchedulerCls = _load_attr('scheduler.steplr', 'Scheduler')
scheduler = SchedulerCls(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

Initialised AAMSoftmax margin 0.300 scale 40.000
⚡ feature_extractor ⚡
Mel_Spectrogram(
  (pre_emphasis): PreEmphasis()
  (mel_spectrogram): MelSpectrogram(
    (spectrogram): Spectrogram()
    (mel_scale): MelScale()
  )
)
⚡ spec_aug ⚡
SpecAugment(
  (fm): FrequencyMasking()
  (tm): TimeMasking()
)
⚡ model ⚡
NeXtTDNN(
  (stem): ModuleList(
    (0): Sequential(
      (0): Conv1d(80, 192, kernel_size=(4,), stride=(1,))
      (1): LayerNorm()
    )
  )
  (stages): ModuleList(
    (0-2): 3 x Sequential(
      (0): TSConvNeXt_light(
        (dwconv): Conv1d(192, 192, kernel_size=(65,), stride=(1,), padding=(32,), groups=192)
        (norm): LayerNorm()
        (pwconv1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (grn): GRN()
        (pwconv2): Linear(in_features=768, out_features=192, bias=True)
        (drop_path): Identity()
      )
    )
  )
  (MFA): Sequential(
    (0): Conv1d(576, 576, kernel_size=(1,), stride=(1,))
    (1): L

In [5]:
# Parameter count and MACs of the full SpeakerNet.
# Original note: "don't run this — the model gets moved to the CPU".
# The model is therefore moved back to DEVICE right after the measurement.
# Dummy input length in samples: 160 * 300 + 240
# (presumably 300 frames at a 160-sample hop plus window overlap — TODO confirm
# against the FEATURE_EXTRACTOR config).
MMAC_INPUT_SAMPLES = 160 * 300 + 240
mmac_stats = get_model_param_mmac(speaker_net, MMAC_INPUT_SAMPLES, DEVICE)
speaker_net.to(DEVICE)
mmac_stats

SpeakerNet(
  1.63 M, 84.888% Params, 420.52 MMac, 98.346% MACs, 
  (feature_extractor): Mel_Spectrogram(
    0, 0.000% Params, 0.0 Mac, 0.000% MACs, 
    (pre_emphasis): PreEmphasis(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
    (mel_spectrogram): MelSpectrogram(
      0, 0.000% Params, 0.0 Mac, 0.000% MACs, 
      (spectrogram): Spectrogram(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
      (mel_scale): MelScale(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
    )
  )
  (spec_aug): SpecAugment(
    0, 0.000% Params, 0.0 Mac, 0.000% MACs, 
    (fm): FrequencyMasking(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
    (tm): TimeMasking(0, 0.000% Params, 0.0 Mac, 0.000% MACs, )
  )
  (model): NeXtTDNN(
    1.32 M, 68.832% Params, 395.25 MMac, 92.439% MACs, 
    (stem): ModuleList(
      (0): Sequential(
        61.63 k, 3.215% Params, 18.43 MMac, 4.310% MACs, 
        (0): Conv1d(61.63 k, 3.215% Params, 18.43 MMac, 4.310% MACs, 80, 192, kernel_size=(4,), stride=(1,))
        (1): LayerNorm(0, 0.000% P

('427.59', '1.92', 418.415136, 1.627416, 155.27652, 1.917144)

In [5]:
from eval_metric import compute_eer
from backend.cosine_similarity_full import cosine_similarity_full
from backend.euclidean_distance_full import euclidean_distance_full
import soundfile as sf
import tqdm

def make_enrollment(gt_model, enr_df_path, base_df, n_trials=50):
    """Build (and pickle) a table of verification trials, one enrollment utterance per label.

    For every label in ``base_df``, the first utterance of that label is the enrollment
    utterance. It is paired with ``n_trials`` cohort utterances that alternate between
    target trials (even index: same label, ``cohort_label`` = 1) and non-target trials
    (odd index: a different label, ``cohort_label`` = 0). Cohorts are drawn uniformly at
    random with replacement. When the label has other utterances, a target cohort is
    never the enrollment file itself.

    Args:
        gt_model: unused; kept for compatibility with existing callers.
        enr_df_path: directory where the table is written as ``enr_df.pkl``.
        base_df: DataFrame with at least ``labels`` and ``wavfiles`` columns.
        n_trials: number of trials per enrollment label (default 50).

    Returns:
        DataFrame with the columns of ``base_df`` plus ``cohort`` and ``cohort_label``,
        indexed 0..len-1.

    Raises:
        ValueError: non-target trials are requested but ``base_df`` has a single label.
    """
    import pandas as pd

    enr_path = os.path.join(enr_df_path, 'enr_df.pkl')
    n_pos = (n_trials + 1) // 2  # even trial indices -> target trials
    n_neg = n_trials // 2        # odd trial indices -> non-target trials

    trials = []
    for label in tqdm.tqdm(base_df.labels.unique()):
        # Split the table once per label instead of re-querying it for every trial.
        is_same = base_df.labels == label
        same_wavs = base_df.wavfiles[is_same]
        other_wavs = base_df.wavfiles[~is_same]
        enr_row = base_df[is_same].iloc[0].to_dict()

        # A target trial against the enrollment file itself is trivially a match.
        pos_pool = same_wavs[same_wavs != enr_row['wavfiles']]
        if pos_pool.empty:
            pos_pool = same_wavs
        if n_neg and other_wavs.empty:
            raise ValueError('make_enrollment needs at least two distinct labels '
                             'to draw non-target cohorts')

        # Draw all cohorts for this label at once. Sampling with replacement matches the
        # previous one-draw-per-trial behaviour and avoids an O(N) permutation per draw.
        positives = iter(pos_pool.sample(n=n_pos, replace=True).tolist())
        negatives = iter(other_wavs.sample(n=n_neg, replace=True).tolist())

        for i in range(n_trials):
            if i % 2 == 0:
                cohort, cohort_label = next(positives), 1
            else:
                cohort, cohort_label = next(negatives), 0
            trials.append({**enr_row, 'cohort': cohort, 'cohort_label': cohort_label})

    # Build the frame once rather than concatenating inside the loop.
    enr_df = pd.DataFrame(trials)
    enr_df.to_pickle(enr_path)
    return enr_df

def validation(model, base_path, device):
    """Score every enrollment/cohort trial with ``model`` and return (cosine EER, Euclidean EER).

    Trials are read from ``<base_path>/enr_df.pkl``. If that file is missing, they are
    built from ``<base_path>/train_df.pkl`` with ``make_enrollment``, which caches them.

    Args:
        model: module called as ``model(wave.unsqueeze(0).to(device))`` on a 1-D float
            waveform; it is switched to eval mode. It is assumed to return one embedding
            per input — TODO confirm.
        base_path: directory containing ``enr_df.pkl`` or ``train_df.pkl``.
        device: device the waveforms are moved to before the forward pass.

    Returns:
        Tuple of the first values returned by ``compute_eer`` for the cosine scores
        and for the Euclidean distances.
    """
    import pandas as pd

    model.eval()

    enr_cache = os.path.join(base_path, 'enr_df.pkl')
    if os.path.isfile(enr_cache):
        enr_df = pd.read_pickle(enr_cache)
    else:
        base_df = pd.read_pickle(os.path.join(base_path, 'train_df.pkl'))
        enr_df = make_enrollment(model, base_path, base_df)
        del base_df

    def embed(wav_path):
        # Read one wav file and embed it as a batch of size 1.
        wave, _ = sf.read(wav_path)
        return model(torch.FloatTensor(wave).unsqueeze(0).to(device))

    cos_sim_list, euc_dist_list, valid_label = [], [], []
    with torch.no_grad():
        for _, row in enr_df.iterrows():
            enr_emb = embed(row['wavfiles'])
            # The cohort embedding must come from the cohort's own audio. The earlier code
            # re-used the enrollment waveform here, so every trial scored an utterance
            # against itself.
            spk_emb = embed(row['cohort'])
            valid_label.append(row['cohort_label'])

            cos_sim = cosine_similarity_full(enr_emb, spk_emb)
            cos_sim_list.append(cos_sim.detach().cpu().numpy())

            # NOTE(review): distances are passed to compute_eer unchanged. If compute_eer
            # treats higher scores as "same speaker", the Euclidean EER comes out
            # inverted; confirm this and negate the distances if so.
            euc_dist = euclidean_distance_full(enr_emb, spk_emb)
            euc_dist_list.append(euc_dist.detach().cpu().numpy())

    cos_eer, _ = compute_eer(cos_sim_list, valid_label)
    euc_eer, _ = compute_eer(euc_dist_list, valid_label)
    return cos_eer, euc_eer

# Extract the unique speakers to build an enrollment file, then
# compare each enrollment wav's embedding with the cohort embedding (spk_emb).

In [6]:
print('Model Training..')
print()

ckpt_dir = config['CHECKPOINT']['ckpt_path']
# Create the checkpoint directory up front. Otherwise a missing folder only
# surfaces in torch.save, after a whole (multi-hour) epoch has been trained.
os.makedirs(ckpt_dir, exist_ok=True)

for epoch in range(MAX_EPOCH):
    running_loss = 0.0

    speaker_net.train()
    gc.collect()
    torch.cuda.empty_cache()
    print('=== Epoch : {0} ==='.format(epoch))
    for idx, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()

        spk_emb = speaker_net(x.to(DEVICE))
        loss, _ = loss_function(spk_emb, y.to(DEVICE))
        step_loss = loss.item()  # read once, reused for the running sum and logging
        running_loss += step_loss

        loss.backward()
        optimizer.step()

        if idx % 100 == 0:
            print('{0} step loss : {1}'.format(idx, step_loss))

    scheduler.step()
    # max(..., 1) avoids ZeroDivisionError when drop_last leaves no full batch.
    epoch_loss = running_loss / max(len(train_loader), 1)
    print('-- Epoch {0} loss : {1}'.format(epoch, epoch_loss))

    # validation
    cos_eer, euc_eer = validation(speaker_net, BASE_PATH, DEVICE)
    print('Cosine EER : {0}, Euclidean EER : {1}'.format(cos_eer, euc_eer))

    ckpt_name = config['CHECKPOINT']['filename'].format(epoch)
    torch.save({'epoch': epoch,
                'model': speaker_net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'loss': epoch_loss,
                'cos_eer': cos_eer,
                'euc_eer': euc_eer,
                }, os.path.join(ckpt_dir, ckpt_name))
    print('-- Epoch {0} ckpt saved..'.format(epoch))
    print()


Model Training..

=== Epoch : 0 ===
0 step loss : 23.00664520263672
100 step loss : 18.883506774902344
200 step loss : 18.5505313873291
300 step loss : 18.00434684753418
400 step loss : 17.644161224365234
500 step loss : 17.54828643798828
600 step loss : 17.274524688720703
700 step loss : 16.927719116210938
800 step loss : 17.214963912963867
900 step loss : 16.62836265563965
1000 step loss : 16.54280662536621
1100 step loss : 16.301109313964844
1200 step loss : 16.11001968383789
1300 step loss : 15.71225643157959
1400 step loss : 15.703887939453125
1500 step loss : 15.203147888183594
1600 step loss : 15.463879585266113
1700 step loss : 14.993700981140137
1800 step loss : 14.833720207214355
1900 step loss : 14.880380630493164
2000 step loss : 14.52582836151123
2100 step loss : 15.052695274353027
2200 step loss : 14.097025871276855
2300 step loss : 14.471131324768066
2400 step loss : 13.708948135375977
2500 step loss : 13.32651424407959
2600 step loss : 14.166337013244629
2700 step loss 

TypeError: an integer is required

In [6]:
# Resume training from a saved checkpoint.
ckpt_name = 'ckpt_1.pt'
ckpt_dir = config['CHECKPOINT']['ckpt_path']
# map_location restores tensors onto DEVICE regardless of where they were saved.
checkpoint = torch.load(os.path.join(ckpt_dir, ckpt_name), map_location=DEVICE)
# strict=False tolerates mismatched keys; report them instead of dropping them silently.
load_result = speaker_net.load_state_dict(checkpoint["model"], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('[warning] missing keys: {0}, unexpected keys: {1}'.format(
        load_result.missing_keys, load_result.unexpected_keys))
optimizer.load_state_dict(checkpoint["optimizer"])
scheduler.load_state_dict(checkpoint["scheduler"])
checkpoint_epoch = checkpoint["epoch"]
# Fail now, not after an epoch of training, if the checkpoint folder is missing.
os.makedirs(ckpt_dir, exist_ok=True)

for epoch in range(checkpoint_epoch + 1, MAX_EPOCH):
    running_loss = 0.0

    speaker_net.train()
    gc.collect()
    torch.cuda.empty_cache()
    print('=== Epoch : {0} ==='.format(epoch))
    for idx, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()

        spk_emb = speaker_net(x.to(DEVICE))
        loss, _ = loss_function(spk_emb, y.to(DEVICE))
        step_loss = loss.item()  # read once, reused for the running sum and logging
        running_loss += step_loss

        loss.backward()
        optimizer.step()

        if idx % 100 == 0:
            print('{0} step loss : {1}'.format(idx, step_loss))

    scheduler.step()
    # max(..., 1) avoids ZeroDivisionError when drop_last leaves no full batch.
    epoch_loss = running_loss / max(len(train_loader), 1)
    print('-- Epoch {0} loss : {1}'.format(epoch, epoch_loss))

    # validation
    cos_eer, euc_eer = validation(speaker_net, BASE_PATH, DEVICE)
    print('Cosine EER : {0}, Euclidean EER : {1}'.format(cos_eer, euc_eer))

    ckpt_name = config['CHECKPOINT']['filename'].format(epoch)
    torch.save({'epoch': epoch,
                'model': speaker_net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'loss': epoch_loss,
                'cos_eer': cos_eer,
                'euc_eer': euc_eer,
                }, os.path.join(ckpt_dir, ckpt_name))
    print('-- Epoch {0} ckpt saved..'.format(epoch))
    print()

  checkpoint = torch.load(os.path.join(config['CHECKPOINT']['ckpt_path'], ckpt_name))


=== Epoch : 2 ===


KeyboardInterrupt: 

In [3]:
# Held-out test set, evaluated one utterance per batch.
test_dataset = kdataset.asv_dataset(*config['TEST_DATASET'].values())

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=4,
    pin_memory=True,
    drop_last=False,
    shuffle=False,
)

# Peek at the first batch to check the input shape.
x, label = next(iter(test_loader))
print(x.shape)

Read pkl...

Number of speakers : 2
Number of utterances : 278

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 278 entries, 0 to 277
Data columns (total 3 columns):
 #   Column    Non-Null Count  Dtype 
---  ------    --------------  ----- 
 0   wavfiles  278 non-null    object
 1   labels    278 non-null    int64 
 2   speakers  278 non-null    object
dtypes: int64(1), object(2)
memory usage: 6.6+ KB
None
torch.Size([1, 39200])


In [4]:
# Sanity check: AAM-softmax accuracy on the test set.
# NOTE(review): the test labels are scored against the training-class logits, so this
# accuracy only means something if the test label ids map to the same speakers as the
# training ids — confirm.
speaker_net.eval()
# Follow the model's actual device; an earlier cell may have moved it to the CPU.
# Assumes loss_function lives on the same device (as a SpeakerNet submodule) — verify.
model_device = next(speaker_net.parameters()).device
with torch.no_grad():
    for idx, (x, y) in enumerate(test_loader):
        spk_emb = speaker_net(x.to(model_device))
        _, acc = loss_function(spk_emb, y.to(model_device))
        if idx % 50 == 0:
            print(acc)

NameError: name 'speaker_net' is not defined

Single WAV file inference

In [35]:
# Inference on a single wav file: produce one speaker embedding.
import soundfile as sf

WAV_PATH = 'B0001-0001M1113-2__000_0-00200760.wav'

audio, sr = sf.read(WAV_PATH)
# Put the input on whatever device the model currently lives on.
model_device = next(speaker_net.parameters()).device
test_audio = torch.FloatTensor(audio).unsqueeze(dim=0).to(model_device)

speaker_net.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    out = speaker_net(test_audio)
out.shape

torch.Size([1, 192])