In [1]:
import os
import time
import copy
import random
import pickle
import numpy as np
from skimage.color import gray2rgb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import models
from torchsummary import summary

from config import models_folder, output_data_folder
from config import n_mels

from model_definitions import SpectrogramEncoderNet, MultiSiameseContrastiveClassifierNet
from data_generators import ContrastiveDataGenerator
from project_utils import ModelSaveAndLogHandler

In [2]:
# Expose only GPU index 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
# Side length of the (square) spectrogram images fed to the networks;
# taken from the project config value `n_mels`.
IMG_HEIGHT = n_mels

# Number of candidates per example; the classifier input holds CANDIDATE_SIZE + 1 images.
CANDIDATE_SIZE = 5

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)
torch.cuda.is_available()

cuda


True

In [5]:
# # Mobile net
# model = models.mobilenet_v2(pretrained=False)
# # Dense net
# model = models.densenet121(pretrained=False)
# summary(model, input_size=(3, IMG_HEIGHT, IMG_HEIGHT), device='cpu')

In [6]:
# # mobile net classifier
# model.classifier

In [7]:
# Layer-by-layer summary of the encoder for a single 3-channel, square input image (run on CPU).
summary(
    SpectrogramEncoderNet(),
    input_size=(3, IMG_HEIGHT, IMG_HEIGHT),
    device='cpu',
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 64, 64]           9,408
       BatchNorm2d-2           [-1, 64, 64, 64]             128
              ReLU-3           [-1, 64, 64, 64]               0
         MaxPool2d-4           [-1, 64, 32, 32]               0
       BatchNorm2d-5           [-1, 64, 32, 32]             128
              ReLU-6           [-1, 64, 32, 32]               0
            Conv2d-7          [-1, 128, 32, 32]           8,192
       BatchNorm2d-8          [-1, 128, 32, 32]             256
              ReLU-9          [-1, 128, 32, 32]               0
           Conv2d-10           [-1, 32, 32, 32]          36,864
      BatchNorm2d-11           [-1, 96, 32, 32]             192
             ReLU-12           [-1, 96, 32, 32]               0
           Conv2d-13          [-1, 128, 32, 32]          12,288
      BatchNorm2d-14          [-1, 128,

          Conv2d-125            [-1, 128, 8, 8]          32,768
     BatchNorm2d-126            [-1, 128, 8, 8]             256
            ReLU-127            [-1, 128, 8, 8]               0
          Conv2d-128             [-1, 32, 8, 8]          36,864
     BatchNorm2d-129            [-1, 288, 8, 8]             576
            ReLU-130            [-1, 288, 8, 8]               0
          Conv2d-131            [-1, 128, 8, 8]          36,864
     BatchNorm2d-132            [-1, 128, 8, 8]             256
            ReLU-133            [-1, 128, 8, 8]               0
          Conv2d-134             [-1, 32, 8, 8]          36,864
     BatchNorm2d-135            [-1, 320, 8, 8]             640
            ReLU-136            [-1, 320, 8, 8]               0
          Conv2d-137            [-1, 128, 8, 8]          40,960
     BatchNorm2d-138            [-1, 128, 8, 8]             256
            ReLU-139            [-1, 128, 8, 8]               0
          Conv2d-140             [-1, 32

            ReLU-253            [-1, 128, 8, 8]               0
          Conv2d-254             [-1, 32, 8, 8]          36,864
     BatchNorm2d-255            [-1, 960, 8, 8]           1,920
            ReLU-256            [-1, 960, 8, 8]               0
          Conv2d-257            [-1, 128, 8, 8]         122,880
     BatchNorm2d-258            [-1, 128, 8, 8]             256
            ReLU-259            [-1, 128, 8, 8]               0
          Conv2d-260             [-1, 32, 8, 8]          36,864
     BatchNorm2d-261            [-1, 992, 8, 8]           1,984
            ReLU-262            [-1, 992, 8, 8]               0
          Conv2d-263            [-1, 128, 8, 8]         126,976
     BatchNorm2d-264            [-1, 128, 8, 8]             256
            ReLU-265            [-1, 128, 8, 8]               0
          Conv2d-266             [-1, 32, 8, 8]          36,864
     _DenseBlock-267           [-1, 1024, 8, 8]               0
     BatchNorm2d-268           [-1, 1024

In [8]:
# Layer-by-layer summary of the full classifier; the input stacks CANDIDATE_SIZE + 1 images
# of shape (3, IMG_HEIGHT, IMG_HEIGHT) (run on CPU).
summary(
    MultiSiameseContrastiveClassifierNet(),
    input_size=(CANDIDATE_SIZE + 1, 3, IMG_HEIGHT, IMG_HEIGHT),
    device='cpu',
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 64, 64]           9,408
       BatchNorm2d-2           [-1, 64, 64, 64]             128
              ReLU-3           [-1, 64, 64, 64]               0
         MaxPool2d-4           [-1, 64, 32, 32]               0
       BatchNorm2d-5           [-1, 64, 32, 32]             128
              ReLU-6           [-1, 64, 32, 32]               0
            Conv2d-7          [-1, 128, 32, 32]           8,192
       BatchNorm2d-8          [-1, 128, 32, 32]             256
              ReLU-9          [-1, 128, 32, 32]               0
           Conv2d-10           [-1, 32, 32, 32]          36,864
      BatchNorm2d-11           [-1, 96, 32, 32]             192
             ReLU-12           [-1, 96, 32, 32]               0
           Conv2d-13          [-1, 128, 32, 32]          12,288
      BatchNorm2d-14          [-1, 128,

          Conv2d-125            [-1, 128, 8, 8]          32,768
     BatchNorm2d-126            [-1, 128, 8, 8]             256
            ReLU-127            [-1, 128, 8, 8]               0
          Conv2d-128             [-1, 32, 8, 8]          36,864
     BatchNorm2d-129            [-1, 288, 8, 8]             576
            ReLU-130            [-1, 288, 8, 8]               0
          Conv2d-131            [-1, 128, 8, 8]          36,864
     BatchNorm2d-132            [-1, 128, 8, 8]             256
            ReLU-133            [-1, 128, 8, 8]               0
          Conv2d-134             [-1, 32, 8, 8]          36,864
     BatchNorm2d-135            [-1, 320, 8, 8]             640
            ReLU-136            [-1, 320, 8, 8]               0
          Conv2d-137            [-1, 128, 8, 8]          40,960
     BatchNorm2d-138            [-1, 128, 8, 8]             256
            ReLU-139            [-1, 128, 8, 8]               0
          Conv2d-140             [-1, 32

            ReLU-253            [-1, 128, 8, 8]               0
          Conv2d-254             [-1, 32, 8, 8]          36,864
     BatchNorm2d-255            [-1, 960, 8, 8]           1,920
            ReLU-256            [-1, 960, 8, 8]               0
          Conv2d-257            [-1, 128, 8, 8]         122,880
     BatchNorm2d-258            [-1, 128, 8, 8]             256
            ReLU-259            [-1, 128, 8, 8]               0
          Conv2d-260             [-1, 32, 8, 8]          36,864
     BatchNorm2d-261            [-1, 992, 8, 8]           1,984
            ReLU-262            [-1, 992, 8, 8]               0
          Conv2d-263            [-1, 128, 8, 8]         126,976
     BatchNorm2d-264            [-1, 128, 8, 8]             256
            ReLU-265            [-1, 128, 8, 8]               0
          Conv2d-266             [-1, 32, 8, 8]          36,864
     _DenseBlock-267           [-1, 1024, 8, 8]               0
     BatchNorm2d-268           [-1, 1024

          Conv2d-380          [-1, 128, 32, 32]           8,192
     BatchNorm2d-381          [-1, 128, 32, 32]             256
            ReLU-382          [-1, 128, 32, 32]               0
          Conv2d-383           [-1, 32, 32, 32]          36,864
     BatchNorm2d-384           [-1, 96, 32, 32]             192
            ReLU-385           [-1, 96, 32, 32]               0
          Conv2d-386          [-1, 128, 32, 32]          12,288
     BatchNorm2d-387          [-1, 128, 32, 32]             256
            ReLU-388          [-1, 128, 32, 32]               0
          Conv2d-389           [-1, 32, 32, 32]          36,864
     BatchNorm2d-390          [-1, 128, 32, 32]             256
            ReLU-391          [-1, 128, 32, 32]               0
          Conv2d-392          [-1, 128, 32, 32]          16,384
     BatchNorm2d-393          [-1, 128, 32, 32]             256
            ReLU-394          [-1, 128, 32, 32]               0
          Conv2d-395           [-1, 32, 

     BatchNorm2d-508            [-1, 320, 8, 8]             640
            ReLU-509            [-1, 320, 8, 8]               0
          Conv2d-510            [-1, 128, 8, 8]          40,960
     BatchNorm2d-511            [-1, 128, 8, 8]             256
            ReLU-512            [-1, 128, 8, 8]               0
          Conv2d-513             [-1, 32, 8, 8]          36,864
     BatchNorm2d-514            [-1, 352, 8, 8]             704
            ReLU-515            [-1, 352, 8, 8]               0
          Conv2d-516            [-1, 128, 8, 8]          45,056
     BatchNorm2d-517            [-1, 128, 8, 8]             256
            ReLU-518            [-1, 128, 8, 8]               0
          Conv2d-519             [-1, 32, 8, 8]          36,864
     BatchNorm2d-520            [-1, 384, 8, 8]             768
            ReLU-521            [-1, 384, 8, 8]               0
          Conv2d-522            [-1, 128, 8, 8]          49,152
     BatchNorm2d-523            [-1, 128

          Conv2d-636            [-1, 128, 8, 8]         126,976
     BatchNorm2d-637            [-1, 128, 8, 8]             256
            ReLU-638            [-1, 128, 8, 8]               0
          Conv2d-639             [-1, 32, 8, 8]          36,864
     _DenseBlock-640           [-1, 1024, 8, 8]               0
     BatchNorm2d-641           [-1, 1024, 8, 8]           2,048
            ReLU-642           [-1, 1024, 8, 8]               0
          Conv2d-643            [-1, 512, 8, 8]         524,288
       AvgPool2d-644            [-1, 512, 4, 4]               0
     BatchNorm2d-645            [-1, 512, 4, 4]           1,024
            ReLU-646            [-1, 512, 4, 4]               0
          Conv2d-647            [-1, 128, 4, 4]          65,536
     BatchNorm2d-648            [-1, 128, 4, 4]             256
            ReLU-649            [-1, 128, 4, 4]               0
          Conv2d-650             [-1, 32, 4, 4]          36,864
     BatchNorm2d-651            [-1, 544

In [9]:
### Training data
# Every entry found in the training folder is passed to the generator as a sample file.
training_folder = os.path.join(output_data_folder, "training_dataset_full_spectrogram/vox1_dev_wav")
spectrogram_samples_files = [
    os.path.join(training_folder, file_name)
    for file_name in os.listdir(training_folder)
]

candidate_size = CANDIDATE_SIZE
# batch_size = 15   # mobilenet_v2
batch_size = 6   # densenet121
num_batches = 2000 // batch_size   # presumably batches generated per epoch -- confirm in ContrastiveDataGenerator
num_sub_samples = 200

training_data_generator = ContrastiveDataGenerator(
    spectrogram_samples_files,
    candidate_size,
    batch_size,
    num_batches,
    num_sub_samples,
    IMG_HEIGHT,
)

In [10]:
### Validation data
# Validation set stored as a pickle under the output data folder.
validation_set_file = os.path.join(
    output_data_folder,
    "validation_sets",
    "contrastive_validation_set.pickle",
)

# NOTE: unpickling can execute arbitrary code -- only load files produced by this project.
with open(validation_set_file, 'rb') as f:
    validation_data = pickle.load(f)

In [11]:
def train_model(model, criterion, optimizer, scheduler, num_epochs, training_data_generator, validation_data, log_handler):
    """Train ``model`` and evaluate it on the validation data after every epoch.

    Each epoch runs a 'train' phase over ``training_data_generator.generate_batches()``
    followed by a 'val' phase over ``validation_data``. Whenever the validation
    accuracy exceeds the best value seen so far, the model is saved through
    ``log_handler`` (a regular checkpoint plus a TorchScript export).

    The function reads the notebook-level globals ``device`` and ``IMG_HEIGHT``.

    Args:
        model: network called as ``model(inputs)`` where ``inputs`` is a list of
            tensors; predictions are the argmax of its output along dim 1.
        criterion: loss callable, invoked as ``criterion(outputs, labels)``.
        optimizer: optimizer, stepped once per training batch.
        scheduler: learning-rate scheduler; ``step()`` is called once per epoch,
            right after the training phase.
        num_epochs: number of epochs to run.
        training_data_generator: object whose ``generate_batches()`` yields
            ``(input_imgs, labels)`` pairs.
        validation_data: iterable of ``(input_imgs, labels)`` pairs.
        log_handler: object providing ``print``, ``save_pytorch_model`` and
            ``save_pytorch_model_as_torchscript``.

    Raises:
        ValueError: if a phase produces no samples, since no loss or accuracy
            can be computed for it.
    """
    since = time.time()
    best_acc = 0.0

    t1 = time.time()
    for epoch in range(num_epochs):
        log_handler.print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        log_handler.print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_training = phase == 'train'
            if is_training:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            # Count the samples actually processed so the epoch averages stay correct
            # even when batches differ in size (e.g. a shorter final batch).
            samples_seen = 0

            # Iterate over data.
            data_generator = training_data_generator.generate_batches() if is_training else validation_data
            for input_imgs, labels in data_generator:
                inputs = [img.to(device) for img in input_imgs]
                labels = labels.to(device)
                samples_in_batch = inputs[0].size(0)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # gradients are tracked only during the training phase
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_training:
                        loss.backward()
                        optimizer.step()

                # statistics: the loss is scaled by the batch size so that the
                # epoch figure is a per-sample average
                running_loss += loss.item() * samples_in_batch
                running_corrects += torch.sum(preds == labels).item()
                samples_seen += samples_in_batch

            if samples_seen == 0:
                raise ValueError("No samples were produced for the '{}' phase".format(phase))

            if is_training:
                # The scheduler is advanced once per epoch, not once per batch.
                scheduler.step()

            epoch_loss = running_loss / samples_seen
            epoch_acc = running_corrects / samples_seen
            log_handler.print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # checkpoint the model whenever validation accuracy improves
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                log_handler.save_pytorch_model(model, "best_model_{}.pt".format(model.__class__.__name__))
                # NOTE(review): the TorchScript example is a list of two CPU tensors --
                # confirm it matches the inputs the model expects and that the handler
                # deals with device placement.
                example = [torch.rand(1, 3, IMG_HEIGHT, IMG_HEIGHT), torch.rand(1, 3, IMG_HEIGHT, IMG_HEIGHT)]
                log_handler.save_pytorch_model_as_torchscript(model, "mobile_model.pt", (example,))

        # end of epoch
        log_handler.print("Time taken is {} seconds".format(int(time.time()-t1)))
        t1 = time.time()
        log_handler.print()

    time_elapsed = time.time() - since
    log_handler.print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    log_handler.print('Best val Acc: {:4f}'.format(best_acc))
    

In [12]:
### Train

# --- Hyper-parameters, model, loss and optimizer ---
epochs = 70
# epochs = 50

model_ft = MultiSiameseContrastiveClassifierNet().to(device)
criterion = nn.CrossEntropyLoss()

# All parameters of the model are optimized
optimizer_ft = optim.Adam(model_ft.parameters(), lr = 0.0001)

# Learning-rate schedule: cyclic between base_lr and max_lr.
# Earlier alternative (step decay):
# learning_rate_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)
# NOTE(review): train_model calls scheduler.step() once per epoch, whereas CyclicLR is
# documented to be stepped after every batch (default step_size_up=2000) -- confirm the
# resulting schedule is the intended one.
learning_rate_scheduler = lr_scheduler.CyclicLR(
    optimizer_ft,
    base_lr=0.0001,
    max_lr=0.001,   # 0.001 seems better
    cycle_momentum=False,
)


# --- Logging / checkpoint folder ---
model_save_folder = os.path.join(models_folder, "contrastive_encoder")
log_handler = ModelSaveAndLogHandler(
    model_save_folder,
    enable_model_saving=True,
    enable_logging=True,
)

# Keep a copy of the model definition file next to the saved models.
# NOTE(review): absolute, machine-specific path -- will not resolve on another machine.
model_def_src_file_path = os.path.join(r"D:\Desktop\projects\speaker_recognition_voxceleb1\scripts", "model_definitions.py")
log_handler.save_model_definition_file(model_def_src_file_path)
print(log_handler.folder)

# Run description written to the log
log_handler.print("Description: Candidates: 5, Encoding: 128, Projection: None")
log_handler.print("Base Model: densenet121")

# --- Run training ---
train_model(
    model_ft,
    criterion,
    optimizer_ft,
    learning_rate_scheduler,
    epochs,
    training_data_generator,
    validation_data,
    log_handler,
)

D:\Desktop\projects\speaker_recognition_voxceleb1\output_data\models\contrastive_encoder\2020-04-06_19-42-12
Description: Candidates: 5, Encoding: 128, Projection: None
Base Model: densenet121
Epoch 0/69
----------
train Loss: 1.5184 Acc: 0.3423
val Loss: 1.3822 Acc: 0.4947
MODEL SAVED
MODEL SAVED (MOBILE)
Time taken is 503 seconds

Epoch 1/69
----------
train Loss: 1.4167 Acc: 0.4104
val Loss: 1.2773 Acc: 0.5343
MODEL SAVED
MODEL SAVED (MOBILE)
Time taken is 499 seconds

Epoch 2/69
----------
train Loss: 1.3332 Acc: 0.4575
val Loss: 1.2419 Acc: 0.5574
MODEL SAVED
MODEL SAVED (MOBILE)
Time taken is 501 seconds

Epoch 3/69
----------
train Loss: 1.3324 Acc: 0.4620
val Loss: 1.2401 Acc: 0.5679
MODEL SAVED
MODEL SAVED (MOBILE)
Time taken is 499 seconds

Epoch 4/69
----------
train Loss: 1.3088 Acc: 0.4725
val Loss: 1.2324 Acc: 0.5774
MODEL SAVED
MODEL SAVED (MOBILE)
Time taken is 504 seconds

Epoch 5/69
----------
train Loss: 1.3329 Acc: 0.4585
val Loss: 1.2507 Acc: 0.5830
MODEL SAVED
MOD

train Loss: 1.1246 Acc: 0.6612
val Loss: 1.1272 Acc: 0.6762
MODEL SAVED
MODEL SAVED (MOBILE)
Time taken is 512 seconds

Epoch 68/69
----------
train Loss: 1.1306 Acc: 0.6617
val Loss: 1.1326 Acc: 0.6887
MODEL SAVED
MODEL SAVED (MOBILE)
Time taken is 509 seconds

Epoch 69/69
----------
train Loss: 1.1200 Acc: 0.6727
val Loss: 1.1354 Acc: 0.6647
Time taken is 507 seconds

Training complete in 586m 7s
Best val Acc: 0.688722


In [13]:
# Accuracy expected from uniform random guessing among CANDIDATE_SIZE candidates,
# as a baseline for the accuracies logged during training.
random_acc = 1.0 / CANDIDATE_SIZE
random_acc

0.2

### TODO: Overall
* Contrastive classifier
    * separate train and validate methods

* (Done) Model saving / checkpointing
* **Build binary classifier**