In [1]:
import copy
import inspect
import os
import pickle
import random
import time

import numpy as np
from skimage.color import gray2rgb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import models
from torchsummary import summary

from config import models_folder, output_data_folder
from config import n_mels

from model_definitions import SpectrogramEncoderNet, MultiSiameseContrastiveClassifierNet
from data_generators import ContrastiveDataGenerator, BaseDataGenerator
from project_utils import ModelSaveAndLogHandler, load_module_from_file

In [2]:
# Restrict this process to the GPU with index 0.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [3]:
# IMG_HEIGHT: side length of the square model inputs, taken from the configured n_mels.
# CANDIDATE_SIZE: number of candidates per contrastive sample (the models receive
# CANDIDATE_SIZE + 1 images per sample).
IMG_HEIGHT, CANDIDATE_SIZE = n_mels, 5

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)
torch.cuda.is_available()

cuda


True

In [5]:
# # Mobile net
# model = models.mobilenet_v2(pretrained=False)
# # Dense net
# model = models.densenet121(pretrained=False)
# summary(model, input_size=(3, IMG_HEIGHT, IMG_HEIGHT), device='cpu')

In [6]:
# # mobile net classifier
# model.classifier

In [7]:
# Print a layer-by-layer summary of the encoder for a single 3-channel square image (on CPU).
summary(
    SpectrogramEncoderNet(),
    input_size=(3, IMG_HEIGHT, IMG_HEIGHT),
    device='cpu',
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 64, 64]           9,408
       BatchNorm2d-2           [-1, 64, 64, 64]             128
              ReLU-3           [-1, 64, 64, 64]               0
         MaxPool2d-4           [-1, 64, 32, 32]               0
       BatchNorm2d-5           [-1, 64, 32, 32]             128
              ReLU-6           [-1, 64, 32, 32]               0
            Conv2d-7          [-1, 128, 32, 32]           8,192
       BatchNorm2d-8          [-1, 128, 32, 32]             256
              ReLU-9          [-1, 128, 32, 32]               0
           Conv2d-10           [-1, 32, 32, 32]          36,864
      BatchNorm2d-11           [-1, 96, 32, 32]             192
             ReLU-12           [-1, 96, 32, 32]               0
           Conv2d-13          [-1, 128, 32, 32]          12,288
      BatchNorm2d-14          [-1, 128,

          Conv2d-125            [-1, 128, 8, 8]          32,768
     BatchNorm2d-126            [-1, 128, 8, 8]             256
            ReLU-127            [-1, 128, 8, 8]               0
          Conv2d-128             [-1, 32, 8, 8]          36,864
     BatchNorm2d-129            [-1, 288, 8, 8]             576
            ReLU-130            [-1, 288, 8, 8]               0
          Conv2d-131            [-1, 128, 8, 8]          36,864
     BatchNorm2d-132            [-1, 128, 8, 8]             256
            ReLU-133            [-1, 128, 8, 8]               0
          Conv2d-134             [-1, 32, 8, 8]          36,864
     BatchNorm2d-135            [-1, 320, 8, 8]             640
            ReLU-136            [-1, 320, 8, 8]               0
          Conv2d-137            [-1, 128, 8, 8]          40,960
     BatchNorm2d-138            [-1, 128, 8, 8]             256
            ReLU-139            [-1, 128, 8, 8]               0
          Conv2d-140             [-1, 32

            ReLU-253            [-1, 128, 8, 8]               0
          Conv2d-254             [-1, 32, 8, 8]          36,864
     BatchNorm2d-255            [-1, 960, 8, 8]           1,920
            ReLU-256            [-1, 960, 8, 8]               0
          Conv2d-257            [-1, 128, 8, 8]         122,880
     BatchNorm2d-258            [-1, 128, 8, 8]             256
            ReLU-259            [-1, 128, 8, 8]               0
          Conv2d-260             [-1, 32, 8, 8]          36,864
     BatchNorm2d-261            [-1, 992, 8, 8]           1,984
            ReLU-262            [-1, 992, 8, 8]               0
          Conv2d-263            [-1, 128, 8, 8]         126,976
     BatchNorm2d-264            [-1, 128, 8, 8]             256
            ReLU-265            [-1, 128, 8, 8]               0
          Conv2d-266             [-1, 32, 8, 8]          36,864
     _DenseBlock-267           [-1, 1024, 8, 8]               0
     BatchNorm2d-268           [-1, 1024

In [8]:
# Print a layer-by-layer summary of the contrastive classifier (on CPU).
# Each sample is a stack of CANDIDATE_SIZE + 1 three-channel square images.
summary(
    MultiSiameseContrastiveClassifierNet(),
    input_size=(CANDIDATE_SIZE+1, 3, IMG_HEIGHT, IMG_HEIGHT),
    device='cpu',
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 64, 64]           9,408
       BatchNorm2d-2           [-1, 64, 64, 64]             128
              ReLU-3           [-1, 64, 64, 64]               0
         MaxPool2d-4           [-1, 64, 32, 32]               0
       BatchNorm2d-5           [-1, 64, 32, 32]             128
              ReLU-6           [-1, 64, 32, 32]               0
            Conv2d-7          [-1, 128, 32, 32]           8,192
       BatchNorm2d-8          [-1, 128, 32, 32]             256
              ReLU-9          [-1, 128, 32, 32]               0
           Conv2d-10           [-1, 32, 32, 32]          36,864
      BatchNorm2d-11           [-1, 96, 32, 32]             192
             ReLU-12           [-1, 96, 32, 32]               0
           Conv2d-13          [-1, 128, 32, 32]          12,288
      BatchNorm2d-14          [-1, 128,

          Conv2d-125            [-1, 128, 8, 8]          32,768
     BatchNorm2d-126            [-1, 128, 8, 8]             256
            ReLU-127            [-1, 128, 8, 8]               0
          Conv2d-128             [-1, 32, 8, 8]          36,864
     BatchNorm2d-129            [-1, 288, 8, 8]             576
            ReLU-130            [-1, 288, 8, 8]               0
          Conv2d-131            [-1, 128, 8, 8]          36,864
     BatchNorm2d-132            [-1, 128, 8, 8]             256
            ReLU-133            [-1, 128, 8, 8]               0
          Conv2d-134             [-1, 32, 8, 8]          36,864
     BatchNorm2d-135            [-1, 320, 8, 8]             640
            ReLU-136            [-1, 320, 8, 8]               0
          Conv2d-137            [-1, 128, 8, 8]          40,960
     BatchNorm2d-138            [-1, 128, 8, 8]             256
            ReLU-139            [-1, 128, 8, 8]               0
          Conv2d-140             [-1, 32

            ReLU-253            [-1, 128, 8, 8]               0
          Conv2d-254             [-1, 32, 8, 8]          36,864
     BatchNorm2d-255            [-1, 960, 8, 8]           1,920
            ReLU-256            [-1, 960, 8, 8]               0
          Conv2d-257            [-1, 128, 8, 8]         122,880
     BatchNorm2d-258            [-1, 128, 8, 8]             256
            ReLU-259            [-1, 128, 8, 8]               0
          Conv2d-260             [-1, 32, 8, 8]          36,864
     BatchNorm2d-261            [-1, 992, 8, 8]           1,984
            ReLU-262            [-1, 992, 8, 8]               0
          Conv2d-263            [-1, 128, 8, 8]         126,976
     BatchNorm2d-264            [-1, 128, 8, 8]             256
            ReLU-265            [-1, 128, 8, 8]               0
          Conv2d-266             [-1, 32, 8, 8]          36,864
     _DenseBlock-267           [-1, 1024, 8, 8]               0
     BatchNorm2d-268           [-1, 1024

          Conv2d-380          [-1, 128, 32, 32]           8,192
     BatchNorm2d-381          [-1, 128, 32, 32]             256
            ReLU-382          [-1, 128, 32, 32]               0
          Conv2d-383           [-1, 32, 32, 32]          36,864
     BatchNorm2d-384           [-1, 96, 32, 32]             192
            ReLU-385           [-1, 96, 32, 32]               0
          Conv2d-386          [-1, 128, 32, 32]          12,288
     BatchNorm2d-387          [-1, 128, 32, 32]             256
            ReLU-388          [-1, 128, 32, 32]               0
          Conv2d-389           [-1, 32, 32, 32]          36,864
     BatchNorm2d-390          [-1, 128, 32, 32]             256
            ReLU-391          [-1, 128, 32, 32]               0
          Conv2d-392          [-1, 128, 32, 32]          16,384
     BatchNorm2d-393          [-1, 128, 32, 32]             256
            ReLU-394          [-1, 128, 32, 32]               0
          Conv2d-395           [-1, 32, 

     BatchNorm2d-508            [-1, 320, 8, 8]             640
            ReLU-509            [-1, 320, 8, 8]               0
          Conv2d-510            [-1, 128, 8, 8]          40,960
     BatchNorm2d-511            [-1, 128, 8, 8]             256
            ReLU-512            [-1, 128, 8, 8]               0
          Conv2d-513             [-1, 32, 8, 8]          36,864
     BatchNorm2d-514            [-1, 352, 8, 8]             704
            ReLU-515            [-1, 352, 8, 8]               0
          Conv2d-516            [-1, 128, 8, 8]          45,056
     BatchNorm2d-517            [-1, 128, 8, 8]             256
            ReLU-518            [-1, 128, 8, 8]               0
          Conv2d-519             [-1, 32, 8, 8]          36,864
     BatchNorm2d-520            [-1, 384, 8, 8]             768
            ReLU-521            [-1, 384, 8, 8]               0
          Conv2d-522            [-1, 128, 8, 8]          49,152
     BatchNorm2d-523            [-1, 128

          Conv2d-636            [-1, 128, 8, 8]         126,976
     BatchNorm2d-637            [-1, 128, 8, 8]             256
            ReLU-638            [-1, 128, 8, 8]               0
          Conv2d-639             [-1, 32, 8, 8]          36,864
     _DenseBlock-640           [-1, 1024, 8, 8]               0
     BatchNorm2d-641           [-1, 1024, 8, 8]           2,048
            ReLU-642           [-1, 1024, 8, 8]               0
          Conv2d-643            [-1, 512, 8, 8]         524,288
       AvgPool2d-644            [-1, 512, 4, 4]               0
     BatchNorm2d-645            [-1, 512, 4, 4]           1,024
            ReLU-646            [-1, 512, 4, 4]               0
          Conv2d-647            [-1, 128, 4, 4]          65,536
     BatchNorm2d-648            [-1, 128, 4, 4]             256
            ReLU-649            [-1, 128, 4, 4]               0
          Conv2d-650             [-1, 32, 4, 4]          36,864
     BatchNorm2d-651            [-1, 544

In [9]:
### Training data
# Generator hyper-parameters.
candidate_size = CANDIDATE_SIZE
# batch_size = 15   # mobilenet_v2
batch_size = 6   # densenet121
num_batches = 2000 // batch_size   # roughly 2000 samples per epoch
num_sub_samples = 200

# Every file found in the training folder is treated as one spectrogram sample file.
training_folder = os.path.join(output_data_folder, "training_dataset_full_spectrogram/vox1_dev_wav")
spectrogram_samples_files = [
    os.path.join(training_folder, entry_name)
    for entry_name in os.listdir(training_folder)
]

training_data_generator = ContrastiveDataGenerator(
    spectrogram_samples_files,
    candidate_size,
    batch_size,
    num_batches,
    num_sub_samples,
    IMG_HEIGHT,
)

In [10]:
### Validation data
# The validation set is read from a pickle file under the output data folder.
# NOTE: unpickling can execute arbitrary code; only load files produced by this project.
validation_set_file = os.path.join(
    output_data_folder,
    "validation_sets",
    "contrastive_validation_set.pickle",
)
with open(validation_set_file, mode='rb') as pickle_file:
    validation_data = pickle.load(pickle_file)

In [11]:
def intra_class_variance_reduction(contrastive_model, contrastive_sub_samples, log_handler,
                                   num_classes=2, samples_per_class=10):
    """Compute an intra-class variance loss for the encoder of a contrastive model.

    A few spectrograms are sampled at random and each one is treated as its own
    class. For every class, several image slices are drawn from the spectrogram,
    encoded with ``contrastive_model.encoder`` and compared (MSE) against the mean
    encoding of that class. The per-class losses are summed.

    Args:
        contrastive_model: model exposing an ``encoder`` sub-module.
        contrastive_sub_samples: sequence of spectrograms to sample classes from;
            must contain at least ``num_classes`` items (``random.sample`` raises
            ``ValueError`` otherwise).
        log_handler: unused; kept so existing callers keep working.
        num_classes: number of spectrograms (classes) sampled per call.
        samples_per_class: number of image slices drawn per spectrogram.

    Returns:
        The summed loss. No backward pass is run here; the caller decides whether
        to scale and back-propagate it.
    """
    criterion = nn.MSELoss(reduction='mean')
#     criterion = nn.L1Loss(reduction='mean')

    total_loss = 0.0
    sampled_spectrograms = random.sample(contrastive_sub_samples, num_classes)  # sample classes
    for spectrogram in sampled_spectrograms:   # treat one spectrogram/user as one class
        slices = [
            BaseDataGenerator.get_sliding_img_slice_from_spectrogram(spectrogram)
            for _ in range(samples_per_class)
        ]
        # Stack into one ndarray first: building a tensor straight from a list of
        # ndarrays is the slow path PyTorch warns about.
        # NOTE(review): assumes each slice is a numpy array of a common shape/dtype - confirm.
        inputs = torch.tensor(np.array(slices)).to(device)

        encoded_outputs = contrastive_model.encoder(inputs)
        # Mean encoding of this class, broadcast back to one row per sample so the
        # MSE against it measures the spread of the encodings around their mean.
        class_mean = torch.mean(encoded_outputs, dim=0, keepdim=True).expand_as(encoded_outputs)
        total_loss = total_loss + criterion(encoded_outputs, class_mean)
    return total_loss


In [12]:
def train_model(model, criterion, optimizer, scheduler, num_epochs, training_data_generator, validation_data, log_handler):
    """Train a contrastive classifier with optional intra-class variance reduction.

    Every epoch runs a training phase followed by a validation phase. During
    training, on every ``variance_reduction_frequency``-th batch an additional
    variance loss (see ``intra_class_variance_reduction``) is back-propagated
    before the main loss; both gradients are applied in the same optimizer step.
    Whenever validation accuracy improves (from the second epoch on) the model is
    saved through ``log_handler``.

    Args:
        model: network called with a list of image tensors; returns per-class scores.
        criterion: loss applied to ``(outputs, labels)``.
        optimizer: optimizer over the model parameters.
        scheduler: learning-rate scheduler, stepped once per epoch after training.
        num_epochs: number of epochs to run.
        training_data_generator: object providing ``generate_batches()`` and ``sub_samples``.
        validation_data: iterable of ``(input_imgs, labels)`` batches.
        log_handler: object providing ``print``, ``save_pytorch_model`` and
            ``save_pytorch_model_as_torchscript``.

    Returns:
        None. Progress is reported and models are stored through ``log_handler``.
    """
    since = time.time()
    best_acc = 0.0

    # Intra class variance reduction settings.
    # Currently enabled on every epoch; restrict this set (e.g. range(1, num_epochs)
    # or range(1, num_epochs, 2)) to skip the first epoch or to alternate.
    run_variance_reduction_on_epoch = set(range(num_epochs))
    variance_reduction_frequency = 2   # every n batches
    loss_variance_scale = 0.20

    t1 = time.time()
    for epoch in range(num_epochs):
        log_handler.print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        log_handler.print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_training = phase == 'train'

            # intra class variance reduction only ever runs while training
            run_variance_reduction = is_training and epoch in run_variance_reduction_on_epoch
            if run_variance_reduction:
                log_handler.print("-- variance reduction")

            if is_training:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            samples_seen = 0   # actual number of samples, so uneven batch sizes average correctly
            batches_used = 0

            data_generator = training_data_generator.generate_batches() if is_training else validation_data
            for data in data_generator:
                batches_used += 1

                # zero the parameter gradients
                optimizer.zero_grad()

                loss_var_value = 0.0   # variance loss of this batch as a plain float (statistics only)
                with torch.set_grad_enabled(is_training):   # gradient only for train

                    # intra class variance reduction: its gradients accumulate with the
                    # main loss below and are applied by the same optimizer step
                    if run_variance_reduction and batches_used % variance_reduction_frequency == 0:
                        loss_var = intra_class_variance_reduction(model, training_data_generator.sub_samples, log_handler)
                        loss_var = loss_var * loss_variance_scale
                        loss_var.backward()
                        loss_var_value = loss_var.item()

                    # Main training (contrastive training)
                    input_imgs, labels = data
                    inputs = [img.to(device) for img in input_imgs]
                    labels = labels.to(device)

                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_training:
                        loss.backward()
                        optimizer.step()

                # statistics, kept as Python numbers so no autograd graph is retained
                current_batch_size = inputs[0].size(0)
                running_loss += (loss.item() + loss_var_value) * current_batch_size
                running_corrects += torch.sum(preds == labels).item()
                samples_seen += current_batch_size

            if is_training:
                # NOTE(review): the scheduler is stepped once per epoch. The notebook passes
                # a CyclicLR, which PyTorch documents as stepping after every batch -
                # confirm that per-epoch stepping is intended.
                scheduler.step()

            if samples_seen == 0:
                # Nothing was processed, so there is nothing to average or to compare.
                log_handler.print('{} phase produced no batches'.format(phase))
                continue

            epoch_loss = running_loss / samples_seen
            epoch_acc = running_corrects / samples_seen
            log_handler.print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # save the model whenever validation accuracy improves (never on the first epoch)
            if phase == 'val' and epoch_acc > best_acc and epoch > 0:
                best_acc = epoch_acc
                log_handler.save_pytorch_model(model, "best_model_{}.pt".format(model.__class__.__name__))
                example = [torch.rand(1, 3, IMG_HEIGHT, IMG_HEIGHT), torch.rand(1, 3, IMG_HEIGHT, IMG_HEIGHT)]
                log_handler.save_pytorch_model_as_torchscript(model, "mobile_model.pt", (example,))

        # end of epoch
        log_handler.print("Time taken is {} seconds".format(int(time.time()-t1)))
        t1 = time.time()
        log_handler.print()

    time_elapsed = time.time() - since
    log_handler.print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    log_handler.print('Best val Acc: {:4f}'.format(best_acc))
    

In [13]:
# load model from contrastive training
def load_contrastive_encoder_model(run_name="2020-04-08_02-34-53"):
    """Rebuild a previously trained contrastive classifier from a saved run.

    The model class is taken from the ``model_definitions.py`` copy stored in the
    run folder (not from the currently imported module), and the weights come from
    ``best_model_MultiSiameseContrastiveClassifierNet.pt`` in the same folder.

    Args:
        run_name: name of the run folder under
            ``<models_folder>/contrastive_encoder/good_models``. Defaults to the
            run this notebook was originally continued from.

    Returns:
        The model with the stored weights loaded, placed on the CPU.
    """
    encoder_model_folder = os.path.join(models_folder, "contrastive_encoder", "good_models", run_name)

    # Import the model definition exactly as it was when the run was saved.
    module_file = os.path.join(encoder_model_folder, "model_definitions.py")
    module_name = "MultiSiameseContrastiveClassifierNet"
    module = load_module_from_file(module_file, module_name)

    model = module.MultiSiameseContrastiveClassifierNet()
    state_dict_file = os.path.join(encoder_model_folder, "best_model_MultiSiameseContrastiveClassifierNet.pt")
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted runs.
    model.load_state_dict(torch.load(state_dict_file, map_location="cpu"))
    return model

In [14]:
### Train

epochs = 70
# epochs = 50
# epochs = 30

model_ft = MultiSiameseContrastiveClassifierNet().to(device)
# model_ft = load_contrastive_encoder_model().to(device)   # continue training
criterion = nn.CrossEntropyLoss()

# All parameters of the model are optimized
optimizer_ft = optim.Adam(model_ft.parameters(), lr = 0.0001)

# Learning-rate schedule: cyclic between base_lr and max_lr.
# Alternative: decay LR by a factor of 0.1 every 10 epochs
# learning_rate_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)
learning_rate_scheduler = lr_scheduler.CyclicLR(optimizer_ft, base_lr=0.0001, max_lr=0.001, cycle_momentum=False)   # 0.001 seems better


### Train

# Logger
model_save_folder = os.path.join(models_folder, "contrastive_encoder")
log_handler = ModelSaveAndLogHandler(model_save_folder, enable_model_saving=True, enable_logging=True)   # init
# Locate model_definitions.py from the imported class instead of a machine-specific
# absolute path, so the notebook also runs from other checkouts.
model_def_src_file_path = inspect.getsourcefile(MultiSiameseContrastiveClassifierNet)
log_handler.save_model_definition_file(model_def_src_file_path)   # copy model def file
print(log_handler.folder)

# Description
log_handler.print("Description: Candidates: 5, Encoding: 128, Projection: None")
# log_handler.print("Base Model: mobileNetV2")
log_handler.print("Base Model: densenet121")
log_handler.print("With Intra Class Variance Reduction")
# log_handler.print("Continued from 2020-04-08_02-34-53")

# Train
train_model(model_ft, criterion, optimizer_ft, learning_rate_scheduler, epochs, training_data_generator, validation_data, log_handler)

D:\Desktop\projects\speaker_recognition_voxceleb1\output_data\models\contrastive_encoder\2020-04-08_16-43-10
Description: Candidates: 5, Encoding: 128, Projection: None
Base Model: mobileNetV2
With Intra Class Variance Reduction
Continued from 2020-04-08_02-34-53
Epoch 0/69
----------
-- variance reduction
train Loss: 1.4895 Acc: 0.3789
val Loss: 1.3208 Acc: 0.5479
Time taken is 645 seconds

Epoch 1/69
----------
-- variance reduction
train Loss: 1.3866 Acc: 0.4575
val Loss: 1.2705 Acc: 0.5799
MODEL SAVED
MODEL SAVED (MOBILE)
Time taken is 648 seconds

Epoch 2/69
----------
-- variance reduction
train Loss: 1.3563 Acc: 0.4590
val Loss: 1.2262 Acc: 0.5805
MODEL SAVED
MODEL SAVED (MOBILE)
Time taken is 645 seconds

Epoch 3/69
----------
-- variance reduction
train Loss: 1.3494 Acc: 0.4595
val Loss: 1.2486 Acc: 0.5664
Time taken is 642 seconds

Epoch 4/69
----------
-- variance reduction
train Loss: 1.3237 Acc: 0.4805
val Loss: 1.2274 Acc: 0.5855
MODEL SAVED
MODEL SAVED (MOBILE)
Time take

train Loss: 1.1537 Acc: 0.6216
val Loss: 1.1287 Acc: 0.6737
Time taken is 645 seconds

Epoch 57/69
----------
-- variance reduction
train Loss: 1.1427 Acc: 0.6416
val Loss: 1.1451 Acc: 0.6717
Time taken is 632 seconds

Epoch 58/69
----------
-- variance reduction
train Loss: 1.1574 Acc: 0.6191
val Loss: 1.1265 Acc: 0.6927
MODEL SAVED
MODEL SAVED (MOBILE)
Time taken is 650 seconds

Epoch 59/69
----------
-- variance reduction
train Loss: 1.1462 Acc: 0.6351
val Loss: 1.1298 Acc: 0.6722
Time taken is 636 seconds

Epoch 60/69
----------
-- variance reduction
train Loss: 1.1422 Acc: 0.6341
val Loss: 1.1330 Acc: 0.6667
Time taken is 642 seconds

Epoch 61/69
----------
-- variance reduction
train Loss: 1.1576 Acc: 0.6146
val Loss: 1.1310 Acc: 0.6807
Time taken is 634 seconds

Epoch 62/69
----------
-- variance reduction
train Loss: 1.1399 Acc: 0.6517
val Loss: 1.1301 Acc: 0.6877
Time taken is 631 seconds

Epoch 63/69
----------
-- variance reduction
train Loss: 1.1626 Acc: 0.6246
val Loss: 1.

In [15]:
# Accuracy expected when one of the CANDIDATE_SIZE candidates is picked uniformly at random.
random_acc = 1.0 / CANDIDATE_SIZE
random_acc

0.2

### TODO: Overall
* Contrastive classifier
    * separate train and validate methods

* (Done) Model saving / checkpointing
* **Build binary classifier**