# Transfer tdnn_xvector weights to the new model

In [1]:
import os
%load_ext autoreload
%autoreload 2
%pylab
%matplotlib inline
    
import pandas as pd
import pickle
import numpy as np
import sys
import os

Using matplotlib backend: TkAgg
Populating the interactive namespace from numpy and matplotlib


In [2]:
# Pin CUDA device numbering to PCI bus order and expose only device 3.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",   # see issue #152
    "CUDA_VISIBLE_DEVICES": "3",
})
# Let the `model.*` imports used later resolve against the sv_system checkout.
sys.path.append('/host/projects/sv_experiments/sv_system/')

In [3]:
import torch

# Load the trained tdnn_xvector checkpoint onto the CPU.  Only one GPU is
# visible to this process (CUDA_VISIBLE_DEVICES above), so tensors that were
# saved from a different device index could not be restored without
# remapping.  The weights are only copied between state dicts in this
# notebook, so no GPU is needed.
checkpoint = torch.load(
    "../sv_system/models/voxc2_fbank64_vad/tdnn_xvector_softmax/fbank64_200f_800f_v00/model_best.pth.tar",
    map_location="cpu",
)

In [4]:
# Pull the trained weights out of the checkpoint; they are consumed in
# iteration order by the positional transfer further down.
saved_state = checkpoint['state_dict']

In [5]:
from model.tdnnModel import tdnn_xvector

# Settings used to build both the original model and the untied variant.
config = {"input_dim": 64, "loss": "softmax"}
# Original architecture with 6114 output labels.
# NOTE(review): saved_model is not referenced again in the visible cells.
saved_model = tdnn_xvector(config, n_labels=6114)

In [6]:
import torch.nn as nn
from model.tdnnModel import tdnn_xvector, st_pool_layer

class tdnn_xvector_untied(nn.Module):
    """x-vector TDNN whose classifier head is untied from the embedding trunk.

    ``self.tdnn`` bundles the frame-level ``Conv1d``/``BatchNorm1d``/``ReLU``
    stack, ``st_pool_layer`` and the first affine layer
    ``Linear(3000, base_width)``; its output is what :meth:`embed` returns.
    Everything after that affine layer (``tdnn6_bn`` ... ``tdnn8_last``) is
    kept as separate attributes so the embedding position can be changed
    without touching the trunk.

    NOTE(review): ``st_pool_layer`` is presumably statistics pooling that
    turns 1500 channels into 3000 features -- confirm in model.tdnnModel.

    Args:
        config: mapping with ``'input_dim'`` (number of input channels of the
            first ``Conv1d``) and ``'loss'`` (``"softmax"`` or ``"angular"``).
        base_width: channel/feature width of the hidden layers.
        n_labels: output size of the final classification layer.

    Raises:
        NotImplementedError: if ``config['loss']`` is neither ``"softmax"``
            nor ``"angular"``.
    """
    def __init__(self, config, base_width=512, n_labels=31):
        super(tdnn_xvector_untied, self).__init__()
        inDim = config['input_dim']
        self.tdnn = nn.Sequential(
            nn.Conv1d(inDim, base_width, stride=1, dilation=1, kernel_size=5),
            nn.BatchNorm1d(base_width),
            nn.ReLU(True),
            nn.Conv1d(base_width, base_width, stride=1, dilation=3, kernel_size=3),
            nn.BatchNorm1d(base_width),
            nn.ReLU(True),
            nn.Conv1d(base_width, base_width, stride=1, dilation=4, kernel_size=3),
            nn.BatchNorm1d(base_width),
            nn.ReLU(True),
            nn.Conv1d(base_width, base_width, stride=1, dilation=1, kernel_size=1),
            nn.BatchNorm1d(base_width),
            nn.ReLU(True),
            nn.Conv1d(base_width, 1500, stride=1, dilation=1, kernel_size=1),
            nn.BatchNorm1d(1500),
            nn.ReLU(True),
            st_pool_layer(),
            nn.Linear(3000, base_width),
        )

        loss_type = config["loss"]
        if loss_type == "angular":
            # NOTE(review): AngleLinear is not imported by the visible cells;
            # it has to be in scope already for this branch to work -- confirm.
            last_fc = AngleLinear(base_width, n_labels)
        elif loss_type == "softmax":
            last_fc = nn.Linear(base_width, n_labels)
        else:
            raise NotImplementedError(
                "not implemented loss: {!r}".format(loss_type))

        # Registration order below fixes the order of state_dict() keys, which
        # positional weight transfer relies on -- do not reorder.
        self.tdnn6_bn = nn.BatchNorm1d(base_width)
        self.tdnn6_relu = nn.ReLU(True)
        self.tdnn7_affine = nn.Linear(base_width, base_width)
        self.tdnn7_bn = nn.BatchNorm1d(base_width)
        self.tdnn7_relu = nn.ReLU(True)
        self.tdnn8_last = last_fc

        self._initialize_weights()

    def embed(self, x):
        """Return the output of the trunk's first affine layer for ``x``.

        ``x`` is fed straight into the first ``Conv1d``, so it must already be
        laid out as (batch, input_dim, time); the squeeze/permute that once
        converted (batch, time, freq) input here is disabled.
        """
        x = self.tdnn(x)

        return x

    def forward(self, x):
        """Run the trunk and then the untied head, returning its output."""
        x = self.embed(x)
        x = self.tdnn6_bn(x)
        x = self.tdnn6_relu(x)
        x = self.tdnn7_affine(x)
        x = self.tdnn7_bn(x)
        x = self.tdnn7_relu(x)
        x = self.tdnn8_last(x)

        return x

    def _initialize_weights(self):
        """Re-initialise every conv, batch-norm and linear submodule in place.

        Conv weights are drawn from N(0, sqrt(2 / (kernel_elems * out_channels))),
        linear weights from N(0, 0.01), batch-norm weights are set to 1, and
        all biases are zeroed.
        """
        # Imported here so the class does not depend on `math` having been
        # injected into the notebook namespace by `%pylab`.
        import math

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Conv1d):
                n = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

In [7]:
# Build the untied variant with the same config and label count (6114) that
# were used for the original architecture.
new_model = tdnn_xvector_untied(config, base_width=512, n_labels=6114)

In [8]:
# State dict of the freshly initialised model; its key order defines the
# target slots for the transferred weights.
new_state = new_model.state_dict()

In [9]:
# Transfer the saved tensors by position: the i-th saved entry is assigned to
# the i-th key of the new model (presumably because the two models name their
# layers differently -- confirm against the checkpoint keys).
# zip() stops silently at the shorter side, so check first that both state
# dicts line up; otherwise some layers would keep their random init.
if len(new_state) != len(saved_state):
    raise ValueError(
        "state dict size mismatch: new model has {} entries, "
        "checkpoint has {}".format(len(new_state), len(saved_state)))
for (k, old_v), (saved_k, v) in zip(new_state.items(), saved_state.items()):
    # A shape mismatch means the positional pairing is off; fail here with
    # both key names rather than later inside load_state_dict.
    if old_v.shape != v.shape:
        raise ValueError(
            "shape mismatch for {} <- {}: {} vs {}".format(
                k, saved_k, tuple(old_v.shape), tuple(v.shape)))
    new_state[k] = v

In [10]:
# Copy the transferred tensors into new_model's parameters and buffers.
new_model.load_state_dict(new_state)

In [11]:
# Display the module tree for a visual check of the layer layout.
new_model

tdnn_xvector_untied(
  (tdnn): Sequential(
    (0): Conv1d(64, 512, kernel_size=(5,), stride=(1,))
    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv1d(512, 512, kernel_size=(3,), stride=(1,), dilation=(3,))
    (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace)
    (6): Conv1d(512, 512, kernel_size=(3,), stride=(1,), dilation=(4,))
    (7): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace)
    (9): Conv1d(512, 512, kernel_size=(1,), stride=(1,))
    (10): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace)
    (12): Conv1d(512, 1500, kernel_size=(1,), stride=(1,))
    (13): BatchNorm1d(1500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): ReLU(inplace)
    (15): st_pool_layer()
    (16): Linear(in_features=3000, out_featur

In [14]:
# Swap the untied model's weights into the loaded checkpoint; every other
# checkpoint entry is kept as loaded.
checkpoint['state_dict'] = new_model.state_dict()

In [16]:
# Write the converted checkpoint under a new file name so the original
# model_best.pth.tar is left untouched.
torch.save(checkpoint, "../sv_system/models/voxc2_fbank64_vad/tdnn_xvector_softmax/voxc2_fbank64_untied_model.pth.tar")