In [9]:

from safetensors import torch  # NOTE(review): binds `torch` to safetensors.torch, not PyTorch; rebound below
import tensorflow as tf
import torch  # must stay after the safetensors import so `torch` refers to PyTorch itself

from models.WAVENET.wavenet_v1.audio_data import WavenetDataset
from models.WAVENET.wavenet_v1.wavenet_model import WaveNetModel, load_latest_model_from
from models.WAVENET.wavenet_v1.wavenet_training import generate_audio, WavenetTrainer

# Modified WaveNet implementation, adapted from: https://github.com/Vichoko/pytorch-wavenet/tree/master

# Pick the tensor types once: CUDA variants when a GPU is available,
# CPU variants otherwise. `dtype` is the data type, `ltype` the label type.
use_cuda = torch.cuda.is_available()
if use_cuda:
    print('use gpu')
    dtype, ltype = torch.cuda.FloatTensor, torch.cuda.LongTensor
else:
    dtype, ltype = torch.FloatTensor, torch.LongTensor

# Build the WaveNet model used for training in this cell.
model = WaveNetModel(layers=10,
                     blocks=3,
                     dilation_channels=32,
                     residual_channels=32,
                     skip_channels=1024,
                     end_channels=512,
                     output_length=16,
                     dtype=dtype,
                     bias=True)
# Only move the model to the GPU when one is actually available; the original
# unconditional model.cuda() fails on CPU-only machines even though
# `use_cuda` was already computed for exactly this purpose.
if use_cuda:
    model.cuda()
print('model: ', model)
print('receptive field: ', model.receptive_field)
print('parameter count: ', model.parameter_count())

# Item length is the model's receptive field plus its output length, minus one.
item_length = model.receptive_field + model.output_length - 1
data = WavenetDataset(
    dataset_file='./example.npz',
    file_location='../unpacked_data',
    item_length=item_length,
    target_length=model.output_length,
    test_stride=500,
)
print(f'the dataset has {len(data)} items')


def generate_and_log_samples(step, sample_length=32000, temperatures=(0.5, 1.), snapshot_path='snapshots'):
    """Generate one audio clip per sampling temperature from the latest snapshot.

    The latest model found under ``snapshot_path`` is loaded with
    ``use_cuda=False`` and ``generate_audio`` is called once per temperature.
    Each result is converted to a float32 TensorFlow tensor.

    Args:
        step: Training step the clips belong to. Currently only referenced by
            the disabled logger call below.
        sample_length: Value passed to ``generate_audio`` as ``length``.
        temperatures: Iterable of sampling temperatures; one clip is generated
            for each entry.
        snapshot_path: Directory passed to ``load_latest_model_from``.

    Returns:
        dict mapping each temperature to the float32 tensor of its generated
        clip. The original implementation discarded the clips; returning them
        is backward compatible for callers that ignore the result.
    """
    gen_model = load_latest_model_from(snapshot_path, use_cuda=False)
    print("start generating...")
    clips = {}
    for temperature in temperatures:
        samples = generate_audio(gen_model,
                                 length=sample_length,
                                 temperatures=[temperature])
        clips[temperature] = tf.convert_to_tensor(samples, dtype=tf.float32)
        # logger.audio_summary(f'temperature_{temperature}', clips[temperature], step, sr=16000)
    print("audio clips generated")
    return clips


# Wire the model and dataset into the trainer; snapshots are written under
# 'snapshots' with the 'birdset_model' name every 1000 steps.
trainer = WavenetTrainer(
    model=model,
    dataset=data,
    lr=0.001,
    dtype=dtype,
    ltype=ltype,
    snapshot_path='snapshots',
    snapshot_name='birdset_model',
    snapshot_interval=1000,
)

# NOTE(review): start_data is not used anywhere else in this cell.
start_data = data[250000][0]
display(data)

# Long-running: 12 epochs with batches of 16.
trainer.train(batch_size=16, epochs=12)

use gpu
model:  WaveNetModel(
  (filter_convs): ModuleList(
    (0-29): 30 x Conv1d(32, 32, kernel_size=(2,), stride=(1,))
  )
  (gate_convs): ModuleList(
    (0-29): 30 x Conv1d(32, 32, kernel_size=(2,), stride=(1,))
  )
  (residual_convs): ModuleList(
    (0-29): 30 x Conv1d(32, 32, kernel_size=(1,), stride=(1,))
  )
  (skip_convs): ModuleList(
    (0-29): 30 x Conv1d(32, 1024, kernel_size=(1,), stride=(1,))
  )
  (start_conv): Conv1d(256, 32, kernel_size=(1,), stride=(1,))
  (end_conv_1): Conv1d(1024, 512, kernel_size=(1,), stride=(1,))
  (end_conv_2): Conv1d(512, 256, kernel_size=(1,), stride=(1,))
)
receptive field:  3070
parameter count:  1834592
one hot input
the dataset has 255490 items
start training...


<audio_data.WavenetDataset at 0x1abdc6542d0>

epoch 0
loss at step 50: 4.003303761482239
one training step does take approximately 0.2564183950424194 seconds)
loss at step 100: 3.17235077381134
loss at step 150: 3.122062883377075
loss at step 200: 3.1990980386734007


100%|██████████| 32/32 [00:02<00:00, 10.73it/s]


validation loss: 2.68060103058815
validation accuracy: 20.2392578125%
loss at step 250: 3.056705141067505
loss at step 300: 2.9848251390457152
loss at step 350: 2.883783016204834
loss at step 400: 2.9968074655532835


100%|██████████| 32/32 [00:02<00:00, 10.94it/s]


validation loss: 2.5347808115184307
validation accuracy: 23.6083984375%
loss at step 450: 2.9816877841949463
loss at step 500: 2.7024767446517943
loss at step 550: 2.6369041442871093
loss at step 600: 2.561952242851257


100%|██████████| 32/32 [00:02<00:00, 10.84it/s]


validation loss: 2.2553861029446125
validation accuracy: 28.9306640625%
loss at step 650: 2.667593116760254
loss at step 700: 2.6250334119796754
loss at step 750: 2.5522208881378172
loss at step 800: 2.518535180091858


100%|██████████| 32/32 [00:03<00:00, 10.65it/s]


validation loss: 2.1644764691591263
validation accuracy: 29.7119140625%
loss at step 850: 2.4818655228614808
loss at step 900: 2.3894595432281496
loss at step 950: 2.4699355220794676
loss at step 1000: 2.3447110867500305


100%|██████████| 32/32 [00:03<00:00, 10.64it/s]


validation loss: 2.0521650724112988
validation accuracy: 31.06689453125%
loss at step 1050: 2.3461958932876588
loss at step 1100: 2.3993005299568178
loss at step 1150: 2.385872085094452
loss at step 1200: 2.299815812110901


100%|██████████| 32/32 [00:02<00:00, 10.88it/s]


validation loss: 1.9788086414337158
validation accuracy: 32.7880859375%
loss at step 1250: 2.3319837641716004
loss at step 1300: 2.343190920352936
loss at step 1350: 2.354561774730682
loss at step 1400: 2.338815813064575


100%|██████████| 32/32 [00:02<00:00, 10.84it/s]


validation loss: 1.9704119563102722
validation accuracy: 33.203125%
loss at step 1450: 2.306089787483215
loss at step 1500: 2.4071853876113893
loss at step 1550: 2.2570983457565306
loss at step 1600: 2.3501757526397706


100%|██████████| 32/32 [00:02<00:00, 11.03it/s]


validation loss: 1.948568519204855
validation accuracy: 32.5927734375%
loss at step 1650: 2.2863951349258422
loss at step 1700: 2.29395361661911
loss at step 1750: 2.2980816650390623
loss at step 1800: 2.225964250564575


100%|██████████| 32/32 [00:02<00:00, 10.84it/s]


validation loss: 1.911602895706892
validation accuracy: 34.46044921875%
loss at step 1850: 2.2462178683280944
loss at step 1900: 2.2050483798980713
loss at step 1950: 2.274432137012482
loss at step 2000: 2.233248646259308


100%|██████████| 32/32 [00:02<00:00, 10.93it/s]


validation loss: 1.8863034956157207
validation accuracy: 34.11865234375%
loss at step 2050: 2.240848572254181
loss at step 2100: 2.20605482339859
loss at step 2150: 2.1702775979042053
loss at step 2200: 2.330989754199982


100%|██████████| 32/32 [00:02<00:00, 10.97it/s]


validation loss: 1.895939439535141
validation accuracy: 34.5947265625%
loss at step 2250: 2.245310831069946
loss at step 2300: 2.296190061569214
loss at step 2350: 2.30474228143692
loss at step 2400: 2.2152951908111573


100%|██████████| 32/32 [00:02<00:00, 10.86it/s]


validation loss: 1.904658231884241
validation accuracy: 33.8623046875%
loss at step 2450: 2.251043429374695
loss at step 2500: 2.193372700214386
loss at step 2550: 2.196246702671051
loss at step 2600: 2.252015283107758


100%|██████████| 32/32 [00:02<00:00, 10.97it/s]


validation loss: 1.889542255550623
validation accuracy: 33.69140625%
loss at step 2650: 2.1721480894088745
loss at step 2700: 2.196977105140686
loss at step 2750: 2.245991811752319
loss at step 2800: 2.2117829728126526


100%|██████████| 32/32 [00:02<00:00, 10.95it/s]


validation loss: 1.862830314785242
validation accuracy: 35.4736328125%
loss at step 2850: 2.2048026299476624
loss at step 2900: 2.259955871105194
loss at step 2950: 2.2767953205108644
loss at step 3000: 2.213124165534973


100%|██████████| 32/32 [00:02<00:00, 10.89it/s]


validation loss: 1.8657510839402676
validation accuracy: 34.97314453125%
loss at step 3050: 2.233362600803375
loss at step 3100: 2.2708731722831725
loss at step 3150: 2.2551782941818237
loss at step 3200: 2.2770800161361695


100%|██████████| 32/32 [00:03<00:00, 10.51it/s]


validation loss: 1.8434064015746117
validation accuracy: 35.7421875%
loss at step 3250: 2.2411942076683045
loss at step 3300: 2.170863084793091
loss at step 3350: 2.137935893535614
loss at step 3400: 2.1998226928710936


100%|██████████| 32/32 [00:02<00:00, 10.93it/s]


validation loss: 1.8692993074655533
validation accuracy: 34.765625%
loss at step 3450: 2.149662711620331
loss at step 3500: 2.1991938757896423
loss at step 3550: 2.2115960836410524
loss at step 3600: 2.2119243288040162


100%|██████████| 32/32 [00:02<00:00, 10.84it/s]


validation loss: 1.8473293669521809
validation accuracy: 35.41259765625%
loss at step 3650: 2.1984219765663147
loss at step 3700: 2.2338526034355164
loss at step 3750: 2.1555791163444518
loss at step 3800: 2.218813989162445


100%|██████████| 32/32 [00:02<00:00, 10.89it/s]


validation loss: 1.8260559365153313
validation accuracy: 35.6201171875%
loss at step 3850: 2.22371280670166
loss at step 3900: 2.170864789485931
loss at step 3950: 2.1177186036109923
loss at step 4000: 2.1874665093421934


100%|██████████| 32/32 [00:02<00:00, 10.82it/s]


validation loss: 1.8406949862837791
validation accuracy: 35.5712890625%
loss at step 4050: 2.19281329870224
loss at step 4100: 2.257584426403046
loss at step 4150: 2.195015525817871
loss at step 4200: 2.208439621925354


100%|██████████| 32/32 [00:02<00:00, 10.95it/s]


validation loss: 1.8868031799793243
validation accuracy: 33.3251953125%
loss at step 4250: 2.2077652287483214
loss at step 4300: 2.115995399951935
loss at step 4350: 2.2108562755584718
loss at step 4400: 2.2333449649810793


100%|██████████| 32/32 [00:02<00:00, 10.96it/s]


validation loss: 1.848360724747181
validation accuracy: 35.1806640625%
loss at step 4450: 2.2453002524375916
loss at step 4500: 2.186497552394867
loss at step 4550: 2.1627466750144957
loss at step 4600: 2.18040607213974


100%|██████████| 32/32 [00:02<00:00, 10.96it/s]


validation loss: 1.818458680063486
validation accuracy: 35.9375%
loss at step 4650: 2.228012819290161
loss at step 4700: 2.1051745128631594
loss at step 4750: 2.2310316300392152
loss at step 4800: 2.1123380517959593


100%|██████████| 32/32 [00:02<00:00, 10.93it/s]


validation loss: 1.8222182095050812
validation accuracy: 36.1572265625%
loss at step 4850: 2.2041117405891417
loss at step 4900: 2.231197829246521
loss at step 4950: 2.126080870628357
loss at step 5000: 2.2354145216941834


100%|██████████| 32/32 [00:03<00:00, 10.45it/s]


validation loss: 1.808344617486
validation accuracy: 36.46240234375%
loss at step 5050: 2.128793613910675
loss at step 5100: 2.111220064163208
loss at step 5150: 2.175026159286499
loss at step 5200: 2.1888611459732057


100%|██████████| 32/32 [00:02<00:00, 10.90it/s]


validation loss: 1.8083001375198364
validation accuracy: 36.21826171875%
loss at step 5250: 2.131729016304016
loss at step 5300: 2.2021947860717774
loss at step 5350: 2.11373996257782
loss at step 5400: 2.2044270706176756


100%|██████████| 32/32 [00:02<00:00, 11.00it/s]


validation loss: 1.8359516151249409
validation accuracy: 35.09521484375%
loss at step 5450: 2.157803859710693
loss at step 5500: 2.094798765182495
loss at step 5550: 2.148282861709595
loss at step 5600: 2.1233441185951234


100%|██████████| 32/32 [00:02<00:00, 11.02it/s]


validation loss: 1.8194425106048584
validation accuracy: 35.94970703125%
loss at step 5650: 2.179317100048065
loss at step 5700: 2.2135221362113953
loss at step 5750: 2.1237929034233094
loss at step 5800: 2.0585733437538147


100%|██████████| 32/32 [00:02<00:00, 11.00it/s]


validation loss: 1.8182228319346905
validation accuracy: 35.80322265625%
loss at step 5850: 2.186076986789703
loss at step 5900: 2.110675039291382
loss at step 5950: 2.179688696861267
loss at step 6000: 2.1197976064682007


100%|██████████| 32/32 [00:02<00:00, 11.02it/s]


validation loss: 1.8013546168804169
validation accuracy: 36.36474609375%
loss at step 6050: 2.1196470856666565
loss at step 6100: 2.194414567947388
loss at step 6150: 2.0871270847320558
loss at step 6200: 2.1260515952110293


100%|██████████| 32/32 [00:02<00:00, 10.99it/s]


validation loss: 1.8133989796042442
validation accuracy: 35.9130859375%
loss at step 6250: 2.0966803860664367
loss at step 6300: 2.223388433456421
loss at step 6350: 2.1731936526298523
loss at step 6400: 2.1702924466133116


100%|██████████| 32/32 [00:02<00:00, 10.93it/s]


validation loss: 1.8038352131843567
validation accuracy: 36.03515625%
loss at step 6450: 2.1525226616859436
loss at step 6500: 2.161122477054596
loss at step 6550: 2.2170011806488037
loss at step 6600: 2.1465731143951414


100%|██████████| 32/32 [00:02<00:00, 10.84it/s]


validation loss: 1.820192288607359
validation accuracy: 36.474609375%
loss at step 6650: 2.1856304383277894
loss at step 6700: 2.2025769686698915
loss at step 6750: 2.163448386192322
loss at step 6800: 2.1714552521705626


100%|██████████| 32/32 [00:02<00:00, 11.02it/s]


validation loss: 1.8136172071099281
validation accuracy: 35.6689453125%
loss at step 6850: 2.1932574129104614
loss at step 6900: 2.2048912620544434
loss at step 6950: 2.099041244983673
loss at step 7000: 2.1515074825286864


100%|██████████| 32/32 [00:02<00:00, 10.81it/s]


validation loss: 1.821145884692669
validation accuracy: 35.6689453125%
loss at step 7050: 2.134585528373718
loss at step 7100: 2.1827123975753784
loss at step 7150: 2.115551266670227
loss at step 7200: 2.0582373189926146


100%|██████████| 32/32 [00:02<00:00, 10.96it/s]


validation loss: 1.8199075870215893
validation accuracy: 36.09619140625%
loss at step 7250: 2.1502810430526735
loss at step 7300: 2.1253521966934206
loss at step 7350: 2.1854027485847474
loss at step 7400: 2.087223174571991


100%|██████████| 32/32 [00:02<00:00, 10.79it/s]


validation loss: 1.8098802044987679
validation accuracy: 35.75439453125%
loss at step 7450: 2.103738327026367
loss at step 7500: 2.135127522945404
loss at step 7550: 2.1118759059906007
loss at step 7600: 2.1430595922470093


100%|██████████| 32/32 [00:02<00:00, 10.93it/s]


validation loss: 1.8014345318078995
validation accuracy: 37.02392578125%
loss at step 7650: 2.0886248922348023
loss at step 7700: 2.121959753036499
loss at step 7750: 2.1757443714141846
loss at step 7800: 2.1833918142318725


100%|██████████| 32/32 [00:03<00:00, 10.52it/s]


validation loss: 1.857930961996317
validation accuracy: 35.3759765625%
loss at step 7850: 2.101304681301117
loss at step 7900: 2.2043867897987366
loss at step 7950: 2.0913902831077578
loss at step 8000: 2.114764218330383


100%|██████████| 32/32 [00:02<00:00, 10.86it/s]


validation loss: 1.780763953924179
validation accuracy: 37.21923828125%
loss at step 8050: 2.1116508078575134
loss at step 8100: 2.184465310573578
loss at step 8150: 2.145395510196686
loss at step 8200: 2.115126848220825


100%|██████████| 32/32 [00:03<00:00, 10.40it/s]


validation loss: 1.7925753220915794
validation accuracy: 36.6455078125%
loss at step 8250: 2.090951557159424
loss at step 8300: 2.105887713432312
loss at step 8350: 2.1078706860542296
loss at step 8400: 2.1019226145744323


100%|██████████| 32/32 [00:02<00:00, 10.72it/s]


validation loss: 1.79171771556139
validation accuracy: 36.75537109375%
loss at step 8450: 2.1482628083229063
loss at step 8500: 2.105095474720001
loss at step 8550: 2.0653357434272768
loss at step 8600: 2.1338274335861205


100%|██████████| 32/32 [00:02<00:00, 10.96it/s]


validation loss: 1.7851762808859348
validation accuracy: 36.4990234375%
loss at step 8650: 2.1114230704307557
loss at step 8700: 2.1048057532310485
loss at step 8750: 2.1271958112716676
loss at step 8800: 2.195327832698822


100%|██████████| 32/32 [00:03<00:00, 10.43it/s]


validation loss: 1.7833297923207283
validation accuracy: 35.9619140625%
loss at step 8850: 2.0822678351402284
loss at step 8900: 2.1290112113952637
loss at step 8950: 2.1150684094429018
loss at step 9000: 2.0983270311355593


100%|██████████| 32/32 [00:03<00:00, 10.65it/s]


validation loss: 1.7894704267382622
validation accuracy: 36.92626953125%
loss at step 9050: 2.151949589252472
loss at step 9100: 2.0514287519454957
loss at step 9150: 2.145778338909149
loss at step 9200: 2.0387615156173706


100%|██████████| 32/32 [00:02<00:00, 10.98it/s]


validation loss: 1.7867044992744923
validation accuracy: 36.6455078125%
loss at step 9250: 2.104103357791901
loss at step 9300: 2.1316096234321593
loss at step 9350: 2.122466802597046
loss at step 9400: 2.1493461108207703


100%|██████████| 32/32 [00:02<00:00, 10.89it/s]


validation loss: 1.77762046828866
validation accuracy: 36.4990234375%
loss at step 9450: 2.1047444939613342
loss at step 9500: 2.1071203541755676
loss at step 9550: 2.186663565635681
loss at step 9600: 2.1067179012298585


100%|██████████| 32/32 [00:02<00:00, 10.85it/s]


validation loss: 1.7851886972784996
validation accuracy: 36.3037109375%
loss at step 9650: 2.1217751789093016
loss at step 9700: 2.108080413341522


KeyboardInterrupt: 

In [12]:

from models.WAVENET.wavenet_v1.audio_data import WavenetDataset

# Select tensor types and the target device together so they cannot disagree.
# `dtype` is the data type, `ltype` the label type.
use_cuda = torch.cuda.is_available()
if use_cuda:
    print('use gpu')
    dtype = torch.cuda.FloatTensor
    ltype = torch.cuda.LongTensor
else:
    dtype = torch.FloatTensor
    ltype = torch.LongTensor
# The original hard-coded torch.device("cuda") even when no GPU was present;
# fall back to the CPU in that case.
device = torch.device("cuda" if use_cuda else "cpu")

# Rebuild the architecture with the same hyperparameters as the training cell
# so the saved state dict matches the module layout.
model = WaveNetModel(layers=10,
                     blocks=3,
                     dilation_channels=32,
                     residual_channels=32,
                     skip_channels=1024,
                     end_channels=512,
                     output_length=16,
                     dtype=dtype,
                     bias=True)
# map_location='cpu' lets a checkpoint saved from the GPU load on a CPU-only
# machine; the model itself is still on the CPU at this point.
# NOTE(review): weights_only=False unpickles arbitrary objects and can execute
# code from the file. Only load checkpoints you produced yourself, and switch
# to weights_only=True if the file is a plain state dict of tensors.
state_dict = torch.load('birdset_modelwavenet_model.pth',
                        map_location='cpu',
                        weights_only=False)
model.load_state_dict(state_dict)
model.eval()
# Move to the GPU only when one is available (the original called
# model.cuda() unconditionally, which fails on CPU-only hosts).
if use_cuda:
    model.cuda()

# Item length is the model's receptive field plus its output length, minus one.
# NOTE(review): file_location here is './unpacked_data' while the training
# cell uses '../unpacked_data' -- confirm which path is intended.
item_length = model.receptive_field + model.output_length - 1
data = WavenetDataset(
    dataset_file='./example.npz',
    file_location='./unpacked_data',
    item_length=item_length,
    target_length=model.output_length,
    test_stride=500,
)
print(f'the dataset has {len(data)} items')



# Seed for generation: first element of dataset item 250000.
# NOTE(review): the index is hard-coded and must be < len(data).
seed_item = data[250000][0]
# torch.max(..., 0)[1] yields the indices of the maxima along dim 0
# (presumably turning the one-hot item back into class indices -- confirm).
# Place the result on the GPU only when one is available instead of the
# original hard-coded 'cuda'.
start_data = torch.max(seed_item, 0)[1].to('cuda' if use_cuda else 'cpu')
def prog_callback(step, total_steps):
    """Print generation progress as a floored whole-number percentage."""
    percent = 100 * step // total_steps
    print(f"{percent}% generated")

# Generate 64000 samples seeded with start_data; prog_callback is passed as
# the progress callback with a progress interval of 1000.
generated = model.generate_fast(
    num_samples=64000,
    first_samples=start_data,
    temperature=1.0,
    regularize=0.,
    progress_interval=1000,
    progress_callback=prog_callback,
)

use gpu
one hot input
the dataset has 255490 items
torch.Size([1, 256, 1])
0% generated
1% generated
2% generated
4% generated
one generating step does take approximately 0.011269981861114503 seconds)
5% generated
7% generated
8% generated
10% generated
11% generated
13% generated
14% generated
16% generated
17% generated
19% generated
20% generated
22% generated
23% generated
25% generated
26% generated
28% generated
29% generated
31% generated
32% generated
34% generated
35% generated
37% generated
38% generated
40% generated
41% generated
43% generated
44% generated
46% generated
47% generated
49% generated
50% generated
52% generated
53% generated
55% generated
56% generated
58% generated
59% generated
61% generated
62% generated
64% generated
65% generated
67% generated
68% generated
70% generated
71% generated
73% generated
74% generated
76% generated
77% generated
79% generated
80% generated
81% generated
83% generated
84% generated
86% generated
87% generated
89% generated
90% 

In [13]:
import IPython.display as ipd
import soundfile as sf

# Sample rate used for both playback and the written WAV file.
SAMPLE_RATE = 16000

print(generated.shape)
# A bare ipd.Audio(...) only renders when it is the last expression of a cell;
# in the original it was followed by sf.write, so no player was shown.
# Display it explicitly instead.
ipd.display(ipd.Audio(generated, rate=SAMPLE_RATE))

sf.write('output_file.wav', generated, SAMPLE_RATE)

(64000,)
