In [1]:
import os
import time
import copy
import random
import pickle
import numpy as np
from skimage.color import gray2rgb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import models
from torchsummary import summary

from config import models_folder, output_data_folder
from config import n_mels

from model_definitions import SpectrogramEncoderNet, MultiSiameseContrastiveClassifierNet
from data_generators import ContrastiveDataGenerator, BaseDataGenerator
from project_utils import ModelSaveAndLogHandler

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [3]:
IMG_HEIGHT = n_mels
CANDIDATE_SIZE = 5

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
torch.cuda.is_available()

cuda


True

In [5]:
# # Mobile net
# model = models.mobilenet_v2(pretrained=False)
# # Dense net
# model = models.densenet121(pretrained=False)
# summary(model, input_size=(3, IMG_HEIGHT, IMG_HEIGHT), device='cpu')

In [6]:
# # mobile net classifier
# model.classifier

In [7]:
summary(SpectrogramEncoderNet(), input_size=(3, IMG_HEIGHT, IMG_HEIGHT), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 64, 64]             864
       BatchNorm2d-2           [-1, 32, 64, 64]              64
             ReLU6-3           [-1, 32, 64, 64]               0
            Conv2d-4           [-1, 32, 64, 64]             288
       BatchNorm2d-5           [-1, 32, 64, 64]              64
             ReLU6-6           [-1, 32, 64, 64]               0
            Conv2d-7           [-1, 16, 64, 64]             512
       BatchNorm2d-8           [-1, 16, 64, 64]              32
  InvertedResidual-9           [-1, 16, 64, 64]               0
           Conv2d-10           [-1, 96, 64, 64]           1,536
      BatchNorm2d-11           [-1, 96, 64, 64]             192
            ReLU6-12           [-1, 96, 64, 64]               0
           Conv2d-13           [-1, 96, 32, 32]             864
      BatchNorm2d-14           [-1, 96,

     BatchNorm2d-125            [-1, 160, 4, 4]             320
InvertedResidual-126            [-1, 160, 4, 4]               0
          Conv2d-127            [-1, 960, 4, 4]         153,600
     BatchNorm2d-128            [-1, 960, 4, 4]           1,920
           ReLU6-129            [-1, 960, 4, 4]               0
          Conv2d-130            [-1, 960, 4, 4]           8,640
     BatchNorm2d-131            [-1, 960, 4, 4]           1,920
           ReLU6-132            [-1, 960, 4, 4]               0
          Conv2d-133            [-1, 160, 4, 4]         153,600
     BatchNorm2d-134            [-1, 160, 4, 4]             320
InvertedResidual-135            [-1, 160, 4, 4]               0
          Conv2d-136            [-1, 960, 4, 4]         153,600
     BatchNorm2d-137            [-1, 960, 4, 4]           1,920
           ReLU6-138            [-1, 960, 4, 4]               0
          Conv2d-139            [-1, 960, 4, 4]           8,640
     BatchNorm2d-140            [-1, 960

In [8]:
summary(MultiSiameseContrastiveClassifierNet(), input_size=(CANDIDATE_SIZE+1, 3, IMG_HEIGHT, IMG_HEIGHT), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 64, 64]             864
       BatchNorm2d-2           [-1, 32, 64, 64]              64
             ReLU6-3           [-1, 32, 64, 64]               0
            Conv2d-4           [-1, 32, 64, 64]             288
       BatchNorm2d-5           [-1, 32, 64, 64]              64
             ReLU6-6           [-1, 32, 64, 64]               0
            Conv2d-7           [-1, 16, 64, 64]             512
       BatchNorm2d-8           [-1, 16, 64, 64]              32
  InvertedResidual-9           [-1, 16, 64, 64]               0
           Conv2d-10           [-1, 96, 64, 64]           1,536
      BatchNorm2d-11           [-1, 96, 64, 64]             192
            ReLU6-12           [-1, 96, 64, 64]               0
           Conv2d-13           [-1, 96, 32, 32]             864
      BatchNorm2d-14           [-1, 96,

     BatchNorm2d-125            [-1, 160, 4, 4]             320
InvertedResidual-126            [-1, 160, 4, 4]               0
          Conv2d-127            [-1, 960, 4, 4]         153,600
     BatchNorm2d-128            [-1, 960, 4, 4]           1,920
           ReLU6-129            [-1, 960, 4, 4]               0
          Conv2d-130            [-1, 960, 4, 4]           8,640
     BatchNorm2d-131            [-1, 960, 4, 4]           1,920
           ReLU6-132            [-1, 960, 4, 4]               0
          Conv2d-133            [-1, 160, 4, 4]         153,600
     BatchNorm2d-134            [-1, 160, 4, 4]             320
InvertedResidual-135            [-1, 160, 4, 4]               0
          Conv2d-136            [-1, 960, 4, 4]         153,600
     BatchNorm2d-137            [-1, 960, 4, 4]           1,920
           ReLU6-138            [-1, 960, 4, 4]               0
          Conv2d-139            [-1, 960, 4, 4]           8,640
     BatchNorm2d-140            [-1, 960

     BatchNorm2d-252            [-1, 384, 8, 8]             768
           ReLU6-253            [-1, 384, 8, 8]               0
          Conv2d-254            [-1, 384, 8, 8]           3,456
     BatchNorm2d-255            [-1, 384, 8, 8]             768
           ReLU6-256            [-1, 384, 8, 8]               0
          Conv2d-257             [-1, 96, 8, 8]          36,864
     BatchNorm2d-258             [-1, 96, 8, 8]             192
InvertedResidual-259             [-1, 96, 8, 8]               0
          Conv2d-260            [-1, 576, 8, 8]          55,296
     BatchNorm2d-261            [-1, 576, 8, 8]           1,152
           ReLU6-262            [-1, 576, 8, 8]               0
          Conv2d-263            [-1, 576, 8, 8]           5,184
     BatchNorm2d-264            [-1, 576, 8, 8]           1,152
           ReLU6-265            [-1, 576, 8, 8]               0
          Conv2d-266             [-1, 96, 8, 8]          55,296
     BatchNorm2d-267             [-1, 96

In [9]:
### Training data
training_folder = os.path.join(output_data_folder, "training_dataset_full_spectrogram/vox1_dev_wav")
spectrogram_samples_files = [os.path.join(training_folder, file) for file in os.listdir(training_folder)]
candidate_size = CANDIDATE_SIZE
batch_size = 15   # mobilenet_v2
# batch_size = 6   # densenet121
num_batches = 2000 // batch_size
# num_batches = 60 // batch_size
num_sub_samples = 200
# num_sub_samples = 70
training_data_generator = ContrastiveDataGenerator(spectrogram_samples_files, candidate_size, batch_size, num_batches, num_sub_samples, IMG_HEIGHT)

In [10]:
### Validation data
validation_set_file = os.path.join(output_data_folder, "validation_sets", "contrastive_validation_set.pickle")
with open(validation_set_file, 'rb') as f:
    validation_data = pickle.load(f)

In [11]:
def intra_class_variance_reduction(contrastive_model, contrastive_sub_samples, log_handler):   
    num_classes = 5
    samples_per_class = 10
    
    # prep for training
    criterion = nn.MSELoss(reduction='mean')
#     criterion = nn.L1Loss(reduction='mean')
    
    # train
    total_loss = 0.0
    spectrogram_sub_samples = random.sample(contrastive_sub_samples, num_classes)  # sample classes
    for spectrogram in spectrogram_sub_samples:   # treat one spectrogram/user as one class
        input_imgs = [BaseDataGenerator.get_sliding_img_slice_from_spectrogram(spectrogram) for _ in range(samples_per_class)]
        input_imgs = torch.tensor(input_imgs)
        inputs = input_imgs.to(device)  
        
        encoded_outputs = contrastive_model.encoder(inputs)
        mean = torch.mean(encoded_outputs, dim=0)   # get mean encoding vector of this class
        mean = mean.repeat(samples_per_class, 1)   # mean vector
        loss = criterion(encoded_outputs, mean)   # MSE against mean (variance)
        total_loss += loss
    return total_loss


In [12]:
def train_model(model, criterion, optimizer, scheduler, num_epochs, training_data_generator, validation_data, log_handler):
    since = time.time()
    best_acc = 0.0
    
    # intra class variance reduction 
    run_variance_reduction_on_epoch = [*range(1, num_epochs)]
#     run_variance_reduction_on_epoch = [*range(1, num_epochs, 2)]   # alternate
    variance_reduction_frequency = 2   # every n batches
#     loss_variance_scale = 0.20
    loss_variance_scale = 1.0

    t1 = time.time()
    for epoch in range(num_epochs):
        log_handler.print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        log_handler.print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            
            # intra class variance reduction
            run_variance_reduction = phase == 'train' and epoch in run_variance_reduction_on_epoch
            if run_variance_reduction: log_handler.print("-- variance reduction")
            
            # Main training (contrastive training)
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            
            # Iterate over data.
            batches_used = 0
            data_generator = training_data_generator.generate_batches() if phase == 'train' else validation_data
            for data in data_generator:
                batches_used += 1                
                
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):   # gradient only for train
                    
                    # intra class variance reduction
                    loss_var = None
                    if batches_used % variance_reduction_frequency == 0 and run_variance_reduction:
                        loss_var = intra_class_variance_reduction(model, training_data_generator.sub_samples, log_handler)
                        loss_var *= loss_variance_scale
                        loss_var.backward()
                    
                    # Main training (contrastive training)
                    input_imgs, labels = data
                    inputs = [img.to(device) for img in input_imgs]
                    labels = labels.to(device)
                    
                    outputs = model(inputs)      
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
    
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                if loss_var is not None: loss += loss_var   # Overall loss with variance reduction
                running_loss += loss.item() * inputs[0].size(0)
                running_corrects += torch.sum(preds == labels.data)
                
            if phase == 'train':
                scheduler.step()
            
            epoch_loss = running_loss / (batches_used * inputs[0].size(0))
            epoch_acc = running_corrects.double() / (batches_used * inputs[0].size(0))
            log_handler.print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc and epoch > 0:
                best_acc = epoch_acc
                log_handler.save_pytorch_model(model, "best_model_{}.pt".format(model.__class__.__name__))
                example = [torch.rand(1, 3, IMG_HEIGHT, IMG_HEIGHT), torch.rand(1, 3, IMG_HEIGHT, IMG_HEIGHT)]
                log_handler.save_pytorch_model_as_torchscript(model, "mobile_model.pt", (example,))

        # end of epoch
        log_handler.print("Time taken is {} seconds".format(int(time.time()-t1)))
        t1 = time.time()
        log_handler.print()

    time_elapsed = time.time() - since
    log_handler.print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    log_handler.print('Best val Acc: {:4f}'.format(best_acc))
    

In [13]:
### Train

epochs = 70
# epochs = 50

model_ft = MultiSiameseContrastiveClassifierNet().to(device)
criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.Adam(model_ft.parameters(), lr = 0.0001)

# Decay LR by a factor of 0.1 every 7 epochs
# learning_rate_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=10, gamma=0.1)
learning_rate_scheduler = lr_scheduler.CyclicLR(optimizer_ft, base_lr=0.0001, max_lr=0.001, cycle_momentum=False)   # 0.001 seems better


### Train 

# Logger
model_save_folder = os.path.join(models_folder, "contrastive_encoder")
log_handler = ModelSaveAndLogHandler(model_save_folder, enable_model_saving=True, enable_logging=True)   # init
model_def_src_file_path = os.path.join(r"D:\Desktop\projects\speaker_recognition_voxceleb1\scripts", "model_definitions.py")
log_handler.save_model_definition_file(model_def_src_file_path)   # copy model def file
print(log_handler.folder)

# Description
log_handler.print("Description: Candidates: 5, Encoding: 128, Projection: None")
log_handler.print("Base Model: mobileNetV2")
# log_handler.print("Base Model: densenet121")
log_handler.print("With Intra Class Variance Reduction")

# Train
# train_model(model_ft, criterion, optimizer_ft, learning_rate_scheduler, epochs, num_batches, training_data_generator, log_handler)
train_model(model_ft, criterion, optimizer_ft, learning_rate_scheduler, epochs, training_data_generator, validation_data, log_handler)

D:\Desktop\projects\speaker_recognition_voxceleb1\output_data\models\contrastive_encoder\2020-04-08_02-37-55
Description: Candidates: 5, Encoding: 128, Projection: None
Base Model: mobileNetV2
With Intra Class Variance Reduction
Epoch 0/69
----------
train Loss: 1.5229 Acc: 0.2962
val Loss: 1.4878 Acc: 0.3253
Time taken is 205 seconds

Epoch 1/69
----------
-- variance reduction
train Loss: 1.5531 Acc: 0.3143
val Loss: 1.4971 Acc: 0.3368
MODEL SAVED
MODEL SAVED (MOBILE)
Time taken is 258 seconds

Epoch 2/69
----------
-- variance reduction
train Loss: 1.5168 Acc: 0.3434
val Loss: 1.4521 Acc: 0.3459
MODEL SAVED
MODEL SAVED (MOBILE)
Time taken is 260 seconds

Epoch 3/69
----------
-- variance reduction
train Loss: 1.5327 Acc: 0.3268
val Loss: 1.4436 Acc: 0.3679
MODEL SAVED
MODEL SAVED (MOBILE)
Time taken is 260 seconds

Epoch 4/69
----------
-- variance reduction
train Loss: 1.4882 Acc: 0.3539
val Loss: 1.4430 Acc: 0.3554
Time taken is 253 seconds

Epoch 5/69
----------
-- variance reduc

val Loss: 1.3577 Acc: 0.4331
Time taken is 210 seconds

Epoch 59/69
----------
-- variance reduction
train Loss: 1.3962 Acc: 0.4095
val Loss: 1.3561 Acc: 0.4371
Time taken is 211 seconds

Epoch 60/69
----------
-- variance reduction
train Loss: 1.4031 Acc: 0.3930
val Loss: 1.3511 Acc: 0.4195
Time taken is 211 seconds

Epoch 61/69
----------
-- variance reduction
train Loss: 1.3845 Acc: 0.4301
val Loss: 1.3485 Acc: 0.4286
Time taken is 211 seconds

Epoch 62/69
----------
-- variance reduction
train Loss: 1.3815 Acc: 0.3990
val Loss: 1.3575 Acc: 0.4326
Time taken is 213 seconds

Epoch 63/69
----------
-- variance reduction
train Loss: 1.3787 Acc: 0.3935
val Loss: 1.3550 Acc: 0.4311
Time taken is 214 seconds

Epoch 64/69
----------
-- variance reduction
train Loss: 1.4076 Acc: 0.3990
val Loss: 1.3676 Acc: 0.4321
Time taken is 212 seconds

Epoch 65/69
----------
-- variance reduction
train Loss: 1.3689 Acc: 0.4241
val Loss: 1.3601 Acc: 0.4276
Time taken is 214 seconds

Epoch 66/69
--------

In [14]:
random_acc = 1 / CANDIDATE_SIZE
random_acc

0.2

### TODO: Overall
* Contrastive classifier
    * separate train and validate methods

* (Done) Model saving / checkpointing
* **Build binary classifier**