In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
import numpy as np
import speech_task_generator as tg
import os
import math
import argparse
import random

In [2]:
# --- Data -------------------------------------------------------------
DATA_DIR = "../vec/"          # passed as data_folder to tg.voxceleb_speech_folder
TRAIN_NUM = 30                # passed as train_num to tg.voxceleb_speech_folder

# --- Model ------------------------------------------------------------
FEATURE_DIM = 512             # length of each input vector (RelationNetwork input_size)
RELATION_DIM = 8              # hidden width of the relation MLP (RelationNetwork hidden_size)

# --- Episode shape ----------------------------------------------------
CLASS_NUM = 5                 # classes per episode
SAMPLE_NUM_PER_CLASS = 5      # support items per class ("train" split loader)
BATCH_NUM_PER_CLASS = 15      # query items per class ("test" split loader)

# --- Optimisation -----------------------------------------------------
# NOTE(review): EPISODE / TEST_EPISODE are not used by the cells shown here;
# presumably the number of training / evaluation episodes — confirm.
EPISODE = 10 ** 6
TEST_EPISODE = 1000
LEARNING_RATE = 1e-4          # Adam learning rate

In [19]:
class RelationNetwork(nn.Module):
    """Relation Network scoring query vectors against per-class prototypes.

    Each class prototype is the sum of that class's support vectors. Every
    (prototype, query) pair is concatenated and fed through a two-layer MLP
    (fc1 -> ReLU -> fc2 -> sigmoid), giving one score in (0, 1) per pair.

    Args:
        input_size: length of each input feature vector.
        hidden_size: width of the hidden layer of the relation MLP.
    """

    def __init__(self, input_size, hidden_size):
        super(RelationNetwork, self).__init__()
        self.input_size = input_size  # vector feature dim
        self.fc1 = nn.Linear(input_size * 2, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

    def forward(self, sample, query, num_class):
        """Score every query against every class prototype.

        Args:
            sample: support vectors, ``(sample_per_class * num_class) x input_size``.
                Rows must be grouped contiguously by class: the first
                ``sample_per_class`` rows form prototype 0, the next block
                prototype 1, and so on.
            query: query vectors, flattened to ``(-1, input_size)``. The number
                of queries does not have to be a multiple of ``num_class``.
            num_class: number of classes (prototypes).

        Returns:
            Tensor of shape ``num_query x num_class``. Entry ``[i, j]`` is the
            sigmoid relation score between query ``i`` and prototype ``j``.
        """
        # Sum support vectors within each class -> num_class x input_size.
        # sum(dim=1) already drops the per-class axis; an extra squeeze would
        # also drop the feature axis when input_size == 1.
        prototypes = sample.reshape(num_class, -1, self.input_size).sum(dim=1)

        queries = query.reshape(-1, self.input_size)
        num_query = queries.size(0)

        # Broadcast both to num_query x num_class x input_size. expand() makes
        # views without copying; torch.cat below materialises the pairs once.
        prototype_ext = prototypes.unsqueeze(0).expand(num_query, -1, -1)
        query_ext = queries.unsqueeze(1).expand(-1, num_class, -1)

        # Each pair is laid out as [prototype, query].
        relation_pairs = torch.cat((prototype_ext, query_ext), dim=2).view(-1, self.input_size * 2)

        out = F.relu(self.fc1(relation_pairs))
        out = torch.sigmoid(self.fc2(out))
        return out.view(num_query, num_class)

def weights_init(m):
    """Initialise one module's parameters in place (intended for ``Module.apply``).

    Dispatch is by substring of the module's class name:

    * contains ``Conv``: weight ~ N(0, sqrt(2 / n)), where ``n`` is
      ``out_channels`` times the product of all kernel sizes, so it works for
      1-D, 2-D and 3-D kernels. The bias, if present, is zeroed.
    * contains ``BatchNorm``: weight = 1, bias = 0. Either is skipped when it
      is ``None`` (e.g. ``affine=False``).
    * contains ``Linear``: weight ~ N(0, 0.01), and the bias, if present, is set to 1.

    Other modules are left untouched. All writes are in place, so parameters
    keep their device and dtype.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        n = m.out_channels
        for k in m.kernel_size:  # kernel_size is a tuple of length 1, 2 or 3
            n *= k
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm') != -1:
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None:
            # In-place fill keeps the bias on its device/dtype; assigning a new
            # torch.ones(...) tensor would always be CPU float32.
            m.bias.data.fill_(1)

In [4]:
# Split the files under DATA_DIR into meta-train / meta-test lists
# (TRAIN_NUM controls the split — see tg.voxceleb_speech_folder).
metatrain_speech_files, metatest_speech_files = tg.voxceleb_speech_folder(
    data_folder=DATA_DIR,
    train_num=TRAIN_NUM,
)

In [21]:
relation_network = RelationNetwork(input_size=FEATURE_DIM, hidden_size=RELATION_DIM)
# Module.apply returns the module itself, so the cell displays its repr.
relation_network.apply(weights_init)

RelationNetwork(
  (fc1): Linear(in_features=1024, out_features=8, bias=True)
  (fc2): Linear(in_features=8, out_features=1, bias=True)
)

In [22]:
# Adam over every relation-network parameter; StepLR multiplies the LR by 0.5
# every 100000 scheduler steps.
relation_network_optim = torch.optim.Adam(
    relation_network.parameters(),
    lr=LEARNING_RATE,
)
relation_network_scheduler = StepLR(
    relation_network_optim,
    step_size=100000,
    gamma=0.5,
)

In [7]:
task = tg.VoxFewshotTask(
    metatrain_speech_files, CLASS_NUM, SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS
)
# Support set: SAMPLE_NUM_PER_CLASS per class from the "train" split, unshuffled.
# NOTE(review): presumably unshuffled so rows stay grouped by class, which
# RelationNetwork.forward relies on when building prototypes — confirm in tg.
sample_dataloader = tg.get_data_loader(
    task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train", shuffle=False
)
# Query set: BATCH_NUM_PER_CLASS per class from the "test" split, shuffled.
batch_dataloader = tg.get_data_loader(
    task, num_per_class=BATCH_NUM_PER_CLASS, split="test", shuffle=True
)

{'../vec/id10305.txt': 0, '../vec/id10309.txt': 1, '../vec/id10302.txt': 2, '../vec/id10287.txt': 3, '../vec/id10293.txt': 4}
../vec/id10305.txt 0
[0, 0, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
../vec/id10309.txt 1
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
../vec/id10302.txt 2
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
../vec/id10287.txt 3
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
../vec/id10293.txt 4
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [9]:
# Draw one episode: the support set and the query set.
# Use the Python 3 iterator protocol; DataLoader iterators in recent PyTorch
# releases no longer expose a Python-2-style `.next()` method.
samples, sample_labels = next(iter(sample_dataloader))
batches, batch_labels = next(iter(batch_dataloader))

In [10]:
samples.shape  # torch.Size, same value as samples.size()

torch.Size([25, 512])

In [13]:
samples  # support-set features; torch summarises large tensors with "..." when displayed

tensor([[ 1.7426, -1.9934,  1.3741,  ...,  0.5516, -2.0889,  2.4397],
        [ 0.9145,  0.1659,  0.9504,  ..., -1.8690, -2.7939,  1.6301],
        [-2.8732,  2.7469,  3.8594,  ...,  3.5852,  2.7004,  6.5310],
        ...,
        [ 1.5173,  1.6825,  0.3520,  ...,  1.3454, -1.7398,  3.7136],
        [-4.3484,  3.3825,  3.7333,  ...,  2.6806, -4.7237,  0.6005],
        [-1.4059,  2.5253,  1.5449,  ...,  6.7129, -1.8187,  0.5236]])

In [11]:
batches.shape  # torch.Size, same value as batches.size()

torch.Size([75, 512])

In [12]:
sample_labels.shape  # torch.Size, same value as sample_labels.size()

torch.Size([25])

In [23]:
# Relation scores: one row per query, one column per class prototype.
relations = relation_network(samples, batches, CLASS_NUM)

In [25]:
relations.shape  # torch.Size, same value as relations.size()

torch.Size([75, 5])

In [26]:
mse = nn.MSELoss()
# One-hot targets: row i has a 1 in column batch_labels[i]. The row count comes
# from the labels themselves (not BATCH_NUM_PER_CLASS * CLASS_NUM), so any query
# batch size lines up with `relations`. `.to(relations)` matches its dtype/device.
# NOTE(review): column j of `relations` is the prototype built from the j-th
# contiguous class block of `samples`; these targets are only correct if that
# block order equals the label index order — confirm against tg's ordering.
one_hot_labels = F.one_hot(batch_labels.view(-1).long(), num_classes=CLASS_NUM).to(relations)
loss = mse(relations, one_hot_labels)