Dataset Shaping

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import seaborn as sn
import os

# %matplotlib inline
# plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
# plt.rcParams['image.interpolation'] = 'nearest'
# # change this setting for other datasets
# plt.rcParams['image.cmap'] = TBD

# for auto-reloading external modules

# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
# %load_ext autoreload
# %autoreload 2

# Download the dataset
For this we use the `tf.keras.datasets.mnist` module. MNIST contains 70,000 28x28 grayscale images of the handwritten digits 0–9, split into 60,000 training images and 10,000 test images. Pixel values are integers in [0, 255], and the labels are integers between 0 and 9.

In [None]:
# Load MNIST with the Keras loader bundled in TensorFlow. load_data() returns
# ((x_train, y_train), (x_test, y_test)) as NumPy arrays; the data is downloaded
# on first use and cached locally afterwards.
train_dataset, test_dataset = tf.keras.datasets.mnist.load_data()

# Raw images (pixel values in [0, 255]) and their integer labels.
X_train_q = train_dataset[0]
y_train = train_dataset[1]

X_test_q = test_dataset[0]
y_test = test_dataset[1]

# Every image must have exactly one label.
assert X_train_q.shape[0] == y_train.shape[0]
assert X_test_q.shape[0] == y_test.shape[0]

print("Shape of train dataset: {}".format(X_train_q.shape))
print("Shape of train labels: {}".format(y_train.shape))
print("Shape of test dataset: {}".format(X_test_q.shape))
print("Shape of test labels: {}".format(y_test.shape))

# Visualize the data
We can visualize the data by plotting a grid of randomly chosen training images: one column per class (digits 0–9), several samples per column, and the class name shown above each column.

In [None]:
# Visualize some examples from the dataset.
# Each column is one class (digit) and each row is a randomly chosen
# training image of that class.

classes = [str(digit) for digit in range(10)]  # MNIST labels are the digits 0-9
num_classes = len(classes)
samples_per_class = 6
rng = np.random.default_rng(0)  # fixed seed so the figure is reproducible
fig = plt.figure(figsize=(18, 10))
for y, cls in enumerate(classes):
    # Indices of all training images whose label equals this class id.
    idxs = np.flatnonzero(y_train == y)
    idxs = rng.choice(idxs, samples_per_class, replace=False)
    for i, idx in enumerate(idxs):
        # Row-major subplot index: row i (sample), column y (class).
        plt_idx = i * num_classes + y + 1
        plt.subplot(samples_per_class, num_classes, plt_idx)
        # Single-channel images: without an explicit cmap matplotlib would
        # render them with its default color map instead of grayscale.
        plt.imshow(X_train_q[idx].astype('uint8'), cmap='gray')
        plt.axis('off')
        if i == 0:
            plt.title(cls)
plt.show()

# Rescale the input data
The input data is in the range [0, 255]. Before feeding it to the model we rescale it to [0, 1] by dividing by 255, and we add a trailing channel axis so that each image has shape 28x28x1. Furthermore, we convert the integer labels to one-hot encoded vectors of length 10.

In [None]:
# Scale pixels from [0, 255] to float32 values in [0, 1] and append a trailing
# channel axis: (N, 28, 28) -> (N, 28, 28, 1).
X_train = (X_train_q.astype(np.float32) / 255.0)[..., np.newaxis]
X_test = (X_test_q.astype(np.float32) / 255.0)[..., np.newaxis]

# One-hot encode from the raw integer labels in train_dataset / test_dataset
# rather than from y_train / y_test: this cell overwrites y_train / y_test, so
# reading them back would break (or double-encode) when the cell is re-run.
# squeeze() tolerates label arrays shaped (N, 1) as well as (N,).
y_train = tf.one_hot(train_dataset[1].squeeze(), depth=10)
y_test = tf.one_hot(test_dataset[1].squeeze(), depth=10)

print("Shape of train dataset: {}".format(X_train.shape))
print("Shape of train labels: {}".format(y_train.shape))
print("Shape of test dataset: {}".format(X_test.shape))
print("Shape of test labels: {}".format(y_test.shape))

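# Create the torch model
The model below is adapted from the VGGVox model in clovaai/voxceleb_trainer: https://github.com/clovaai/voxceleb_trainer/blob/master/models/VGGVox.py

Note: its forward pass computes a 40-band mel spectrogram from raw 16 kHz audio. It therefore cannot be applied directly to the MNIST images prepared above without replacing that front end.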


In [1]:
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter




In [None]:
class VGGModel(nn.Module):
    """VGG-style CNN that maps a batch of waveforms to fixed-size embeddings.

    Adapted from the VGGVox model in clovaai/voxceleb_trainer. The forward pass
    converts the input to a 40-band mel spectrogram (16 kHz sample rate,
    log-compressed when ``log_input`` is True), runs it through a VGG-like
    convolution stack, pools over time with the selected encoder, and
    projects the result to ``nOut`` dimensions.

    NOTE(review): this front end expects raw audio, not the MNIST images
    prepared earlier in this notebook; using it on those images requires
    replacing ``torchfb`` / ``instancenorm`` — TODO confirm intended use.

    Args:
        nOut: Size of the output embedding.
        encoder_type: Temporal pooling: "MAX" (adaptive max pool), "TAP"
            (adaptive average pool) or "SAP" (self-attentive pooling).
        log_input: If True, take the log of the mel spectrogram.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If ``encoder_type`` is not "MAX", "TAP" or "SAP".
    """

    def __init__(self, nOut = 1024, encoder_type='SAP', log_input=True, **kwargs):
        super().__init__()

        print('Embedding size is %d, encoder %s.'%(nOut, encoder_type))

        self.encoder_type = encoder_type
        self.log_input    = log_input

        # With 40 mel bands the stack below reduces the frequency axis to 1,
        # leaving a (batch, 512, 1, frames) feature map.
        self.netcnn = nn.Sequential(
            nn.Conv2d(1, 96, kernel_size=(5,7), stride=(1,2), padding=(2,2)),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(1,3), stride=(1,2)),

            nn.Conv2d(96, 256, kernel_size=(5,5), stride=(2,2), padding=(1,1)),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(3,3), stride=(2,2)),

            nn.Conv2d(256, 384, kernel_size=(3,3), padding=(1,1)),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=(3,3), padding=(1,1)),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),

            nn.Conv2d(256, 256, kernel_size=(3,3), padding=(1,1)),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(3,3), stride=(2,2)),

            nn.Conv2d(256, 512, kernel_size=(4,1), padding=(0,0)),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

        if self.encoder_type == "MAX":
            self.encoder = nn.AdaptiveMaxPool2d((1,1))
            out_dim = 512
        elif self.encoder_type == "TAP":
            self.encoder = nn.AdaptiveAvgPool2d((1,1))
            out_dim = 512
        elif self.encoder_type == "SAP":
            # Self-attentive pooling: a projection plus a learned context vector.
            self.sap_linear = nn.Linear(512, 512)
            self.attention = self.new_parameter(512, 1)
            out_dim = 512
        else:
            raise ValueError('Undefined encoder')

        self.fc = nn.Linear(out_dim, nOut)

        self.instancenorm   = nn.InstanceNorm1d(40)
        self.torchfb        = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400, hop_length=160, f_min=0.0, f_max=8000, pad=0, n_mels=40)

    def new_parameter(self, *size):
        """Return a trainable parameter of the given size, Xavier-normal initialised."""
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out

    def forward(self, x):
        """Compute embeddings for a batch of waveforms.

        Args:
            x: Waveform tensor, expected shape (batch, samples).

        Returns:
            Tensor of shape (batch, nOut).
        """
        # Feature extraction is a fixed preprocessing step: no gradients, and
        # always run in full precision even if the caller enables autocast.
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=False):
                x = self.torchfb(x) + 1e-6  # epsilon keeps log() finite on zero bins
                if self.log_input:
                    x = x.log()
                # (batch, 40, frames) -> (batch, 1, 40, frames) for Conv2d.
                x = self.instancenorm(x).unsqueeze(1)

        x = self.netcnn(x)

        if self.encoder_type == "MAX" or self.encoder_type == "TAP":
            x = self.encoder(x)
            x = x.view((x.size()[0], -1))

        elif self.encoder_type == "SAP":
            # (batch, 512, 1, frames) -> (batch, frames, 512)
            x = x.permute(0, 2, 1, 3)
            x = x.squeeze(dim=1).permute(0, 2, 1)  # batch * L * D
            h = torch.tanh(self.sap_linear(x))
            # Attention weights over frames, normalised to sum to 1 per item.
            w = torch.matmul(h, self.attention).squeeze(dim=2)
            w = F.softmax(w, dim=1).view(x.size(0), x.size(1), 1)
            x = torch.sum(x * w, dim=1)

        x = self.fc(x)

        return x

