In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.distributions import Normal
import random
import numpy as np
import pysnooper
import pdb

In [2]:
# s = torch.ones([32,40])
# d = torch.ones([32,40])
# torch.stack([s,d],1).shape

In [3]:
# embedding = nn.Embedding(10, 3)
# input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
# embedding(input).shape

In [4]:
class DisentangledVAE(nn.Module):
    """GRU-based VAE whose latent vector is split into named slices.

    The latent vector ``z`` has ``3 * feature_dims + z_other_dims`` entries,
    laid out as ``[key | meter | culture | other]``.  The first three slices
    are additionally projected into the word-embedding space by
    ``music_embeds`` (see ``music_encoder``).

    Args:
        roll_dims: Size of one input time step; also the size of each decoder
            output step.
        hidden_dims: Hidden size of every GRU / GRUCell.
        embed_dims: Size of the keyword embeddings and of the projected
            latent slices.
        feature_dims: Size of each of the key, meter and culture slices.
        z_other_dims: Size of the remaining latent slice.
        word2idx: Mapping from keyword string to embedding row; it is indexed
            with ``str(keyword)`` and its ``len`` sizes the embedding table.
        n_step: Number of steps generated by ``decoder``.
        k: Constant of the teacher-forcing schedule
            ``eps = k / (k + exp(iteration / k))``.
    """

    def __init__(self,
                 roll_dims,
                 hidden_dims,
                 embed_dims,
                 feature_dims,
                 z_other_dims,
                 word2idx,
                 n_step,
                 k=1000):
        super().__init__()
        # NOTE: sub-modules are created in the original order so that
        # parameter ordering and RNG consumption stay the same.
        self.gru_0 = nn.GRU(
            roll_dims,
            hidden_dims,
            batch_first=True,
            bidirectional=True)
        self.feature_dims = feature_dims
        # The three attribute slices share one size; the individual names are
        # kept for compatibility with code that reads them.
        z_key_dims = z_meter_dims = z_culture_dims = feature_dims
        self.z_dims = z_key_dims + z_meter_dims + z_culture_dims + z_other_dims
        self.z_key_dims = z_key_dims
        self.z_meter_dims = z_meter_dims
        self.z_culture_dims = z_culture_dims
        self.z_other_dims = z_other_dims

        # The encoder GRU is bidirectional, hence the factor 2.
        self.linear_mu = nn.Linear(hidden_dims * 2, self.z_dims)
        self.linear_var = nn.Linear(hidden_dims * 2, self.z_dims)
        self.grucell_1 = nn.GRUCell(self.z_dims + roll_dims, hidden_dims)
        # Not used by any method of this class; kept so the set of parameters
        # (and therefore saved state dicts) is unchanged.
        self.grucell_2 = nn.GRUCell(hidden_dims, hidden_dims)
        self.linear_init = nn.Linear(self.z_dims, hidden_dims)
        self.linear_out = nn.Linear(hidden_dims, roll_dims)
        self.n_step = n_step
        self.roll_dims = roll_dims
        self.hidden_dims = hidden_dims

        # Teacher-forcing state (see ``decoder``).
        self.sample = None   # ground-truth batch stored by ``forward`` in training
        self.eps = 1         # teacher-forcing threshold, starts at "always"
        self.iteration = 0   # number of training-mode ``forward`` calls
        self.k = torch.FloatTensor([k])

        # NLP: keyword -> embedding
        self.embed_dims = embed_dims
        self.word2idx = word2idx
        self.word_embeds = nn.Embedding(len(self.word2idx), self.embed_dims)

        # Projects one latent slice (key / meter / culture) to embed_dims.
        self.music_embeds = nn.Linear(self.feature_dims, self.embed_dims)

        self.weight_init()

    def weight_init(self):
        """Orthogonally initialise the weight matrices of every ``nn.GRU``.

        The previous implementation iterated over ``self.parameters()`` and
        compared each parameter's type with module classes, so it never
        matched anything and performed no initialisation at all.

        ``nn.Linear`` layers keep PyTorch's default initialisation: the old
        Linear branch called ``normal_()`` without a tensor (it would have
        raised had it been reachable), and a unit-variance normal would feed
        very large values into the ``exp`` in ``encoder``.
        """
        for module in self.modules():
            if isinstance(module, nn.GRU):
                for param in module.parameters():
                    # orthogonal_ needs at least 2 dimensions; biases are 1-D.
                    if param.dim() >= 2:
                        torch.nn.init.orthogonal_(param)

    def _sampling(self, x):
        """Return a one-hot tensor marking the arg-max of each row of ``x``.

        Args:
            x: 2-D tensor ``(batch, classes)``.

        Returns:
            Tensor with the shape, dtype and device of ``x`` holding a single
            1 per row.
        """
        idx = x.max(1)[1]
        one_hot = torch.zeros_like(x)
        # Build the row index on x's own device rather than guessing from
        # torch.cuda.is_available().
        rows = torch.arange(x.size(0), device=x.device)
        one_hot[rows, idx] = 1
        return one_hot

    def encoder(self, x):
        """Encode a batch of sequences into a diagonal Normal over ``z``.

        Args:
            x: Batch-first input of the encoder GRU,
                ``(batch, seq_len, roll_dims)``.

        Returns:
            ``(distribution, mu, var)`` where ``var`` is ``exp`` of the
            ``linear_var`` output and is passed to ``Normal`` as its scale
            (second positional argument).
        """
        _, h_n = self.gru_0(x)
        # h_n is (directions, batch, hidden): put batch first and concatenate
        # both directions into one vector per sample.
        h_n = h_n.transpose(0, 1).contiguous()
        h_n = h_n.view(h_n.size(0), -1)
        mu = self.linear_mu(h_n)
        var = self.linear_var(h_n).exp()
        distribution = Normal(mu, var)
        return distribution, mu, var

    def decoder(self, z):
        """Autoregressively decode ``n_step`` steps from the latent ``z``.

        In training mode the next input is, with probability ``eps``, the
        ground-truth step from ``self.sample`` (teacher forcing); otherwise it
        is the one-hot arg-max of the previous prediction.  In eval mode the
        arg-max is always used.

        Args:
            z: Latent batch ``(batch, z_dims)``.

        Returns:
            Log-probabilities of shape ``(batch, n_step, roll_dims)``.
        """
        # Start token: one-hot on the last roll dimension, created on z's
        # device so CPU models also work on CUDA-capable hosts.
        out = torch.zeros((z.size(0), self.roll_dims),
                          dtype=z.dtype, device=z.device)
        out[:, -1] = 1.
        steps = []
        hx = torch.tanh(self.linear_init(z))
        for i in range(self.n_step):
            out = torch.cat([out, z], 1)
            hx = self.grucell_1(out, hx)
            out = F.log_softmax(self.linear_out(hx), 1)
            steps.append(out)
            if self.training:
                p = torch.rand(1).item()
                if p < self.eps:
                    out = self.sample[:, i, :]
                else:
                    out = self._sampling(out)
                # Schedule update is deliberately left inside the loop:
                # step 0 of a call still uses the value from the previous call.
                self.eps = self.k / \
                    (self.k + torch.exp(self.iteration / self.k))
            else:
                out = self._sampling(out)
        return torch.stack(steps, 1)

    def language_encoder(self, keywords):
        """Look up embeddings for the given keywords.

        Args:
            keywords: Nested iterable; each element is converted with ``str``
                and looked up in ``word2idx``.  The resulting index matrix is
                transposed, so ``keywords[a][b]`` ends up at ``[b, a]``.

        Returns:
            ``(language_matrix, language_index)``: the embeddings and the
            (transposed) index tensor.
        """
        indices = [[self.word2idx[str(word)] for word in group]
                   for group in keywords]
        # Create the indices where the embedding table lives instead of
        # unconditionally calling .cuda(), which fails on CPU-only machines.
        device = self.word_embeds.weight.device
        language_index = torch.tensor(
            indices, dtype=torch.long, device=device).transpose(0, 1)
        language_matrix = self.word_embeds(language_index)
        return language_matrix, language_index

    def music_encoder(self, z):
        """Project the key, meter and culture slices of ``z`` to embed_dims.

        Args:
            z: Latent batch whose first ``3 * feature_dims`` columns are the
                key, meter and culture slices, in that order.

        Returns:
            Tensor of shape ``(batch, 3, embed_dims)``.
        """
        f = self.feature_dims
        z_key = z[:, :f].clone()
        z_meter = z[:, f: 2 * f].clone()
        z_culture = z[:, 2 * f: 3 * f].clone()
        z_matrix = torch.stack([self.music_embeds(z_key),
                                self.music_embeds(z_meter),
                                self.music_embeds(z_culture)], 1)
        return z_matrix

    def forward(self, x, keywords):
        """Run the VAE on ``x`` and index the accompanying keywords.

        Args:
            x: Input batch ``(batch, seq_len, roll_dims)``; in training mode
                it is also stored as the teacher-forcing target.
            keywords: Passed to ``language_encoder``.

        Returns:
            Tuple ``(recon, mean, stddev, language_index, z_matrix)``.
        """
        if self.training:
            self.sample = x
            self.iteration += 1
        dis, mu, var = self.encoder(x)
        # A reparameterised sample is drawn in both training and eval mode
        # (the original code had two identical branches here).
        z = dis.rsample()
        z_matrix = self.music_encoder(dis.mean)
        recon = self.decoder(z)
        # Only the indices are returned; the embedding matrix is not used here.
        _, language_index = self.language_encoder(keywords)  # key, meter, culture
        output = (recon, dis.mean, dis.stddev, language_index, z_matrix, )
        return output