In [1]:
import warnings
import torch 
from collections import OrderedDict
from data_final import audio_data_loader
from functools import cmp_to_key
from model import AutoEncoder #, MaskedMSE
import os
import glob
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
import numpy as np
from utils import check_grad, get_params, get_arguments, get_optimizer, save_model, load_model

In [2]:
# Probe once whether a CUDA device is usable; later cells branch on this flag.
cuda_available = bool(torch.cuda.is_available())
cuda_available

True

In [3]:
# Unpack the two configuration groups returned by get_arguments(): options read by the
# training cells (e.g. optimizer, learning_rate) and keyword arguments for audio_data_loader.
train_params, dataset_params = get_arguments()

In [4]:
# Build the training batch iterable and the validation data from the dataset options.
# NOTE(review): `validation` is not used by any visible cell — TODO wire up a validation pass.
train_loader, validation = audio_data_loader(
    **dataset_params
)

specs array size =  (97, 1025)
9 training pieces in total
2 validation pieces in total


In [8]:
# Validate the GPU configuration and enable the cuDNN autotuner.
# NOTE: `net` has not been constructed at this point in the notebook (it is created by
# `net = AutoEncoder()` further down), so it cannot be wrapped in nn.DataParallel here:
# doing so raised NameError on a fresh kernel, and even when it ran, the wrapper was
# discarded as soon as `net` was rebound. Wrap the model only after it has been built.
if not cuda_available:
    warnings.warn("Cuda is not available, can not train model using multi-gpu.")
else:
    # A missing or empty "device_ids" entry means single-GPU training
    # (.get avoids a KeyError when the key is removed).
    device_ids = train_params.get("device_ids")
    if device_ids:
        batch_size = dataset_params["batch_size"]
        num_gpu = len(device_ids)
        # Keep each batch evenly divisible across the requested GPUs.
        assert batch_size % num_gpu == 0, (
            "batch_size ({}) must be divisible by the number of GPUs ({})".format(
                batch_size, num_gpu))
    torch.backends.cudnn.benchmark = True
    

In [9]:
# Mean-squared error between the network output and the target (the input itself, for this autoencoder).
criterion = torch.nn.MSELoss()

In [10]:
# Presumably a count of epochs already trained (e.g. for resuming from a checkpoint).
# NOTE(review): no visible cell reads or updates this; the training loop always
# starts at epoch 0 — confirm intent or remove.
epoch_trained = 0

In [11]:
# Fetch a single batch up front to inspect what the training loader yields.
loader_iterator = iter(train_loader)
sample_batch = next(loader_iterator)
del loader_iterator

In [12]:
# Build the autoencoder and move it to the GPU only when CUDA is actually available
# (an unconditional .cuda() crashes on CPU-only machines).
net = AutoEncoder()
if cuda_available:
    net = net.cuda()
    # Multi-GPU: wrap only now that the model exists. A missing or empty
    # "device_ids" entry means single-GPU training.
    device_ids = train_params.get("device_ids")
    if device_ids:
        net = nn.DataParallel(net, device_ids=device_ids)

In [13]:
# Build the optimizer over the model's parameters from the configured name and hyperparameters.
optimizer = get_optimizer(
    net,
    train_params['optimizer'],
    train_params['learning_rate'],
    train_params['momentum'],
)

In [14]:
# Switch the model into training mode; the returned module is echoed as the cell output.
net.train(mode=True)

AutoEncoder(
  (encoder): Sequential(
    (0): Layer(
      (linear): Linear(in_features=1025, out_features=512, bias=False)
      (tanh): Tanh()
    )
    (1): Layer(
      (linear): Linear(in_features=512, out_features=256, bias=False)
      (tanh): Tanh()
    )
    (2): Layer(
      (linear): Linear(in_features=256, out_features=120, bias=False)
      (tanh): Tanh()
    )
  )
  (decoder): Sequential(
    (0): Layer(
      (linear): Linear(in_features=120, out_features=256, bias=False)
      (tanh): Tanh()
    )
    (1): Layer(
      (linear): Linear(in_features=256, out_features=512, bias=False)
      (tanh): Tanh()
    )
    (2): Layer(
      (linear): Linear(in_features=512, out_features=1025, bias=False)
      (tanh): Tanh()
    )
  )
)

In [35]:
# Train the autoencoder to reconstruct its own input batches, printing per-epoch loss.
num_epochs = 1000
log_every = 100          # print a running average loss every `log_every` batches
use_grad_check = False   # set True to skip updates whose gradients check_grad rejects

for epoch in range(num_epochs):
    print('Starting epoch {}'.format(epoch))
    # Re-assert training mode in case another cell switched the model to eval().
    net.train()
    total_loss = 0.0
    num_trained = 0
    for batch in train_loader:
        optimizer.zero_grad()
        spec = batch
        if cuda_available:
            # `async=` is a SyntaxError on Python >= 3.7; PyTorch renamed it to `non_blocking`.
            spec = spec.cuda(non_blocking=True)

        # Autoencoder objective: the input batch is also the reconstruction target.
        outputs = net(spec)
        loss = criterion(outputs, spec)
        loss.backward()

        if use_grad_check and check_grad(net.parameters(), train_params['clip_grad'], train_params['ignore_grad']):
            print('Not a finite gradient or too big, ignoring.')
            # Skip the step entirely: stepping with zeroed grads can still move the
            # parameters through optimizer state such as momentum.
            optimizer.zero_grad()
            continue

        optimizer.step()
        # .item() gives a Python float, so the running sum does not hold a device tensor.
        total_loss += loss.item()
        num_trained += 1

        if num_trained % log_every == 0:
            avg_loss = total_loss / num_trained
            print("Step: {} loss is {}".format(num_trained, avg_loss))

    print('Total loss is {}'.format(total_loss))

Starting epoch 0
Total loss is 2.286532402038574
Starting epoch 1
Total loss is 0.39703840017318726
Starting epoch 2
Total loss is 0.14279711246490479
Starting epoch 3
Total loss is 0.09704403579235077
Starting epoch 4
Total loss is 0.07172279059886932
Starting epoch 5
Total loss is 0.06878972798585892
Starting epoch 6
Total loss is 0.07229939848184586
Starting epoch 7
Total loss is 0.06698562204837799
Starting epoch 8
Total loss is 0.06411643326282501
Starting epoch 9
Total loss is 0.0650828406214714
Starting epoch 10
Total loss is 0.06559184938669205
Starting epoch 11
Total loss is 0.0661793202161789
Starting epoch 12
Total loss is 0.06613273918628693
Starting epoch 13
Total loss is 0.06596813350915909
Starting epoch 14
Total loss is 0.06384192407131195
Starting epoch 15
Total loss is 0.06344491988420486
Starting epoch 16
Total loss is 0.06403090804815292
Starting epoch 17
Total loss is 0.06459600478410721
Starting epoch 18
Total loss is 0.06494856625795364
Starting epoch 19
Total lo

Total loss is 0.062064118683338165
Starting epoch 160
Total loss is 0.06335780769586563
Starting epoch 161
Total loss is 0.06245000287890434
Starting epoch 162
Total loss is 0.06411527097225189
Starting epoch 163
Total loss is 0.062267787754535675
Starting epoch 164
Total loss is 0.06255017966032028
Starting epoch 165
Total loss is 0.06205958127975464
Starting epoch 166
Total loss is 0.06242125853896141
Starting epoch 167
Total loss is 0.06144166737794876
Starting epoch 168
Total loss is 0.062162090092897415
Starting epoch 169
Total loss is 0.06195732578635216
Starting epoch 170
Total loss is 0.06138031929731369
Starting epoch 171
Total loss is 0.062004268169403076
Starting epoch 172
Total loss is 0.06458047032356262
Starting epoch 173
Total loss is 0.06150752678513527
Starting epoch 174
Total loss is 0.06111336871981621
Starting epoch 175
Total loss is 0.06140970438718796
Starting epoch 176
Total loss is 0.062392257153987885
Starting epoch 177
Total loss is 0.06139809638261795
Startin

Total loss is 0.03826826065778732
Starting epoch 321
Total loss is 0.037860460579395294
Starting epoch 322
Total loss is 0.03763153776526451
Starting epoch 323
Total loss is 0.03758428618311882
Starting epoch 324
Total loss is 0.03750171884894371
Starting epoch 325
Total loss is 0.03775808587670326
Starting epoch 326
Total loss is 0.03754569962620735
Starting epoch 327
Total loss is 0.037615202367305756
Starting epoch 328
Total loss is 0.03773214668035507
Starting epoch 329
Total loss is 0.037716712802648544
Starting epoch 330
Total loss is 0.03784027323126793
Starting epoch 331
Total loss is 0.03773725777864456
Starting epoch 332
Total loss is 0.03767384588718414
Starting epoch 333
Total loss is 0.03791946917772293
Starting epoch 334
Total loss is 0.03768625482916832
Starting epoch 335
Total loss is 0.03755716234445572
Starting epoch 336
Total loss is 0.03760906308889389
Starting epoch 337
Total loss is 0.03817372769117355
Starting epoch 338
Total loss is 0.03782078996300697
Starting 

Total loss is 0.037127453833818436
Starting epoch 484
Total loss is 0.037258949130773544
Starting epoch 485
Total loss is 0.03683307394385338
Starting epoch 486
Total loss is 0.03705732896924019
Starting epoch 487
Total loss is 0.03712984919548035
Starting epoch 488
Total loss is 0.03676048666238785
Starting epoch 489
Total loss is 0.037027668207883835
Starting epoch 490
Total loss is 0.03678838536143303
Starting epoch 491
Total loss is 0.03673188015818596
Starting epoch 492
Total loss is 0.036734938621520996
Starting epoch 493
Total loss is 0.036462899297475815
Starting epoch 494
Total loss is 0.03657720237970352
Starting epoch 495
Total loss is 0.03628719970583916
Starting epoch 496
Total loss is 0.035997647792100906
Starting epoch 497
Total loss is 0.03519417345523834
Starting epoch 498
Total loss is 0.03482190519571304
Starting epoch 499
Total loss is 0.03489011153578758
Starting epoch 500
Total loss is 0.035518959164619446
Starting epoch 501
Total loss is 0.035418085753917694
Star

Total loss is 0.028655674308538437
Starting epoch 646
Total loss is 0.028734512627124786
Starting epoch 647
Total loss is 0.028863845393061638
Starting epoch 648
Total loss is 0.02868509106338024
Starting epoch 649
Total loss is 0.02877984009683132
Starting epoch 650
Total loss is 0.028719540685415268
Starting epoch 651
Total loss is 0.028835274279117584
Starting epoch 652
Total loss is 0.028634825721383095
Starting epoch 653
Total loss is 0.028669381514191628
Starting epoch 654
Total loss is 0.028657864779233932
Starting epoch 655
Total loss is 0.028660451993346214
Starting epoch 656
Total loss is 0.028772728517651558
Starting epoch 657
Total loss is 0.02877657860517502
Starting epoch 658
Total loss is 0.0287478007376194
Starting epoch 659
Total loss is 0.02921685017645359
Starting epoch 660
Total loss is 0.029166093096137047
Starting epoch 661
Total loss is 0.02889075130224228
Starting epoch 662
Total loss is 0.029436172917485237
Starting epoch 663
Total loss is 0.02907239831984043
S

Total loss is 0.026345418766140938
Starting epoch 811
Total loss is 0.026385009288787842
Starting epoch 812
Total loss is 0.026443425565958023
Starting epoch 813
Total loss is 0.026397382840514183
Starting epoch 814
Total loss is 0.026329822838306427
Starting epoch 815
Total loss is 0.026290714740753174
Starting epoch 816
Total loss is 0.026357052847743034
Starting epoch 817
Total loss is 0.026337165385484695
Starting epoch 818
Total loss is 0.026631344109773636
Starting epoch 819
Total loss is 0.027662860229611397
Starting epoch 820
Total loss is 0.027490487322211266
Starting epoch 821
Total loss is 0.02683614008128643
Starting epoch 822
Total loss is 0.026293614879250526
Starting epoch 823
Total loss is 0.02628028765320778
Starting epoch 824
Total loss is 0.026211576536297798
Starting epoch 825
Total loss is 0.0265843253582716
Starting epoch 826
Total loss is 0.02614990994334221
Starting epoch 827
Total loss is 0.026255890727043152
Starting epoch 828
Total loss is 0.02647622860968113

Total loss is 0.026147879660129547
Starting epoch 975
Total loss is 0.026131443679332733
Starting epoch 976
Total loss is 0.026029853150248528
Starting epoch 977
Total loss is 0.025850774720311165
Starting epoch 978
Total loss is 0.026028703898191452
Starting epoch 979
Total loss is 0.02596956677734852
Starting epoch 980
Total loss is 0.026082806289196014
Starting epoch 981
Total loss is 0.026197103783488274
Starting epoch 982
Total loss is 0.026149187237024307
Starting epoch 983
Total loss is 0.02604951523244381
Starting epoch 984
Total loss is 0.025919217616319656
Starting epoch 985
Total loss is 0.02606874890625477
Starting epoch 986
Total loss is 0.025881659239530563
Starting epoch 987
Total loss is 0.02601666934788227
Starting epoch 988
Total loss is 0.02609666809439659
Starting epoch 989
Total loss is 0.02590044215321541
Starting epoch 990
Total loss is 0.026022862643003464
Starting epoch 991
Total loss is 0.02586446702480316
Starting epoch 992
Total loss is 0.025859946385025978


Not a finite gradient or too big, ignoring.


  befgad = torch.nn.utils.clip_grad_norm(params, clip_th)
