<a href="https://colab.research.google.com/github/kjun1/CMVC/blob/main/cross.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

タスク

とりあえず論文読み直し

WaveRNN読む

voice encoderの最終層に非時系列化層の追加 or broadcast versionの意味理解

In [1]:
import torch
from torch import nn, optim

In [2]:
# Target device for tensors created in this notebook; fixed to the CPU here.
device = "cpu"

In [1]:
#import sys
#sys.path.append('../')
from cmvc import *

# レイヤー作成

In [3]:
class CBGLayer(nn.Module):
  """Gated convolution block: Conv2d -> BatchNorm2d -> GLU.

  Two parallel Conv2d + BatchNorm2d branches receive the same input. Their
  outputs are concatenated along the channel axis and passed through GLU,
  so the first branch provides the values and the second branch the gate.
  The block therefore emits ``out_channels`` channels.

  Args:
    in_channels: Channel count of the input tensor.
    out_channels: Channel count produced by each branch and by the block.
    kernel_size: Kernel size forwarded to both ``nn.Conv2d`` branches.
    stride: Stride forwarded to both ``nn.Conv2d`` branches.
    padding: Padding forwarded to both ``nn.Conv2d`` branches.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0):
    super().__init__()
    conv_args = (in_channels, out_channels, kernel_size, stride, padding)

    # Sub-modules are created in the order conv1, conv2, bn1, bn2 so that
    # parameter registration and random initialisation order stay the same.
    self.conv1 = nn.Conv2d(*conv_args)
    self.conv2 = nn.Conv2d(*conv_args)
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.bn2 = nn.BatchNorm2d(out_channels)

  def forward(self, x):
    """Return ``bn1(conv1(x))`` gated by ``bn2(conv2(x))`` via GLU."""
    value_branch = self.bn1(self.conv1(x))
    gate_branch = self.bn2(self.conv2(x))

    # GLU splits the channel axis in half: first half * sigmoid(second half).
    stacked = torch.cat((value_branch, gate_branch), dim=1)
    return nn.functional.glu(stacked, dim=1)

In [4]:
class CBLLayer(nn.Module):
  """Convolution block: Conv2d -> BatchNorm2d -> LeakyReLU.

  Args:
    in_channels: Channel count of the input tensor.
    out_channels: Channel count produced by the convolution.
    kernel_size: Kernel size forwarded to ``nn.Conv2d``.
    stride: Stride forwarded to ``nn.Conv2d``.
    padding: Padding forwarded to ``nn.Conv2d``.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0):
    super().__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
    self.bn = nn.BatchNorm2d(out_channels)
    self.lrelu = nn.LeakyReLU()

  def forward(self, x):
    """Apply convolution, batch normalisation and LeakyReLU in that order."""
    return self.lrelu(self.bn(self.conv(x)))

In [5]:
class DBGLayer(nn.Module):
  """Gated deconvolution block: ConvTranspose2d -> BatchNorm2d -> GLU.

  Two parallel ConvTranspose2d + BatchNorm2d branches receive the same
  input. Their outputs are concatenated along the channel axis and passed
  through GLU, so the first branch provides the values and the second
  branch the gate. The block emits ``out_channels`` channels.

  Args:
    in_channels: Channel count of the input tensor.
    out_channels: Channel count produced by each branch and by the block.
    kernel_size: Kernel size forwarded to both ``nn.ConvTranspose2d`` layers.
    stride: Stride forwarded to both ``nn.ConvTranspose2d`` layers.
    padding: Padding forwarded to both ``nn.ConvTranspose2d`` layers.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0):
    super().__init__()
    deconv_args = (in_channels, out_channels, kernel_size, stride, padding)

    # Sub-modules are created in the order deconv1, deconv2, bn1, bn2 so
    # that parameter registration and initialisation order stay the same.
    self.deconv1 = nn.ConvTranspose2d(*deconv_args)
    self.deconv2 = nn.ConvTranspose2d(*deconv_args)
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.bn2 = nn.BatchNorm2d(out_channels)

  def forward(self, x):
    """Return ``bn1(deconv1(x))`` gated by ``bn2(deconv2(x))`` via GLU."""
    value_branch = self.bn1(self.deconv1(x))
    gate_branch = self.bn2(self.deconv2(x))

    # GLU splits the channel axis in half: first half * sigmoid(second half).
    stacked = torch.cat((value_branch, gate_branch), dim=1)
    return nn.functional.glu(stacked, dim=1)

In [6]:
class DBSLayer(nn.Module):
  """Deconvolution block: ConvTranspose2d -> BatchNorm2d -> Softplus.

  Args:
    in_channels: Channel count of the input tensor.
    out_channels: Channel count produced by the transposed convolution.
    kernel_size: Kernel size forwarded to ``nn.ConvTranspose2d``.
    stride: Stride forwarded to ``nn.ConvTranspose2d``.
    padding: Padding forwarded to ``nn.ConvTranspose2d``.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0):
    super().__init__()
    self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
    self.bn = nn.BatchNorm2d(out_channels)
    self.softplus = nn.Softplus()

  def forward(self, x):
    """Apply deconvolution, batch normalisation and Softplus in that order."""
    return self.softplus(self.bn(self.deconv(x)))

In [7]:
class FlattenLayer(nn.Module):
  """Flatten a (N, C, H, W) tensor to (N, C*H*W).

  The batch axis is kept and every remaining axis is merged into one.
  """

  def forward(self, x):
    """Return ``x`` with all axes after the first collapsed into one.

    ``torch.flatten`` is used instead of ``view(N, -1)`` because ``view``
    raises on non-contiguous inputs (e.g. after ``permute``) and cannot
    infer ``-1`` for an empty batch.
    """
    return torch.flatten(x, start_dim=1)

class ReshapeLayer(nn.Module):
  """Reshape a (N, C*H*W) tensor to (N, C, H, W) with H == W.

  The spatial side length is derived from the flattened width, so the
  feature map is always treated as square.
  """

  def forward(self, x, out_channel):
    """Return ``x`` viewed as ``(N, out_channel, side, side)``.

    Args:
      x: 2-D tensor of shape (N, out_channel * side * side).
      out_channel: Channel count of the reshaped tensor.
    """
    batch, flat_width = x.size(0), x.size(1)
    # Truncated square root of the per-channel element count.
    side = int((flat_width / out_channel) ** 0.5)
    return x.view(batch, out_channel, side, side)

# Encoder Decoder

## UttrEncoder

In [None]:
class Hyper_UE:
  """Hyper-parameter holder sketched for the utterance encoder.

  The values match the ones hard-coded in ``UttrEncoder.uttr_enc_d1``
  (``in_channels=1``, ``kernel_size=(3, 9)``); nothing in this notebook
  reads this class yet.
  """

  # Presumably the channel count of the encoder input -- TODO confirm.
  x_c = 1

  # Presumably the kernel size of the first encoder layer -- TODO confirm.
  d1_k = (3, 9)


In [3]:
class UttrEncoder(nn.Module):
  """Utterance encoder that maps a 1-channel feature map to a sampled latent.

  Three gated convolution blocks (``CBGLayer``) are followed by a plain
  ``nn.Conv2d`` whose 16 output channels are split into an 8-channel mean
  and an 8-channel log-variance. With the 36-row inputs used in this
  notebook the first spatial axis is reduced to 1 by the last layer.
  """

  def __init__(self):
    super().__init__()
    self.uttr_enc_d1 = CBGLayer(in_channels=1,
                                out_channels=16,
                                kernel_size=(3, 9),
                                stride=(1, 1),
                                padding=(1, 4))

    self.uttr_enc_d2 = CBGLayer(in_channels=16,
                                out_channels=32,
                                kernel_size=(4, 8),
                                stride=(2, 2),
                                padding=(1, 3))

    self.uttr_enc_d3 = CBGLayer(in_channels=32,
                                out_channels=32,
                                kernel_size=(4, 8),
                                stride=(2, 2),
                                padding=(1, 3))

    # Final layer has no BatchNorm/GLU: its raw output parameterises the
    # latent distribution (8 mean channels + 8 log-variance channels).
    self.uttr_enc_d4 = nn.Conv2d(in_channels=32,
                                 out_channels=16,
                                 kernel_size=(9, 5),
                                 stride=(9, 1),
                                 padding=(0, 2))

  def uttr_encoder(self, x):
    """Encode an utterance feature map.

    Args:
      x: Input tensor with 1 channel on axis 1.

    Returns:
      Tuple ``(mean, log_var)``; each holds 8 of the 16 output channels.
    """
    x = self.uttr_enc_d1(x)
    x = self.uttr_enc_d2(x)
    x = self.uttr_enc_d3(x)
    x = self.uttr_enc_d4(x)

    # First half of the channels is the mean, second half the log-variance.
    mean, log_var = torch.split(x, 8, dim=1)

    return mean, log_var

  def uttr_sample_z(self, mean, log_var):
    """Sample the utterance latent with the reparameterisation trick.

    Args:
      mean: Mean of the latent Gaussian.
      log_var: Log-variance of the latent Gaussian, same shape as ``mean``.

    Returns:
      ``mean + std * epsilon`` with ``epsilon ~ N(0, I)``.
    """
    # randn_like follows mean's device and dtype, so the module no longer
    # depends on a notebook-level ``device`` global.
    epsilon = torch.randn_like(mean)
    # log_var is a log-variance, hence std = exp(log_var / 2).
    return mean + torch.exp(0.5 * log_var) * epsilon

  def forward(self, x):
    """Encode ``x`` and return one latent sample ``z``."""
    mean, log_var = self.uttr_encoder(x)
    z = self.uttr_sample_z(mean, log_var)
    return z


## FaceEncoder

In [4]:
class FaceEncoder(nn.Module):
  """Face encoder that maps a 3-channel image to a sampled latent code.

  Five stride-2 convolution stages are followed by a flatten and two linear
  layers. The final 16 features are split into an 8-dim mean and an 8-dim
  log-variance. ``forward`` returns the sample with two trailing singleton
  axes appended.
  """

  def __init__(self):
    super().__init__()
    self.face_enc_d1 = nn.Sequential(nn.Conv2d(in_channels=3,
                                               out_channels=32,
                                               kernel_size=(6, 6),
                                               stride=(2, 2),
                                               padding=(2, 2)),
                                     nn.LeakyReLU())

    self.face_enc_d2 = CBLLayer(in_channels=32,
                                out_channels=64,
                                kernel_size=(6, 6),
                                stride=(2, 2),
                                padding=(2, 2))

    self.face_enc_d3 = CBLLayer(in_channels=64,
                                out_channels=128,
                                kernel_size=(4, 4),
                                stride=(2, 2),
                                padding=(1, 1))

    self.face_enc_d4 = CBLLayer(in_channels=128,
                                out_channels=128,
                                kernel_size=(4, 4),
                                stride=(2, 2),
                                padding=(1, 1))

    self.face_enc_d5 = CBLLayer(in_channels=128,
                                out_channels=256,
                                kernel_size=(4, 4),
                                stride=(2, 2),
                                padding=(1, 1))

    # Assumes the feature map is already (N, C, 1, 1) when it is flattened,
    # which holds for the 32x32 inputs used in this notebook; the following
    # Linear(256, ...) depends on it.
    self.face_enc_d6 = FlattenLayer()

    self.face_enc_d7 = nn.Sequential(nn.Linear(256, 256),
                                     nn.LeakyReLU())

    self.face_enc_d8 = nn.Sequential(nn.Linear(256, 16),
                                     nn.LeakyReLU())

  def face_encoder(self, y):
    """Encode a face image.

    Args:
      y: Input tensor with 3 channels on axis 1.

    Returns:
      Tuple ``(mean, log_var)``; each holds 8 of the 16 output features.
    """
    y = self.face_enc_d1(y)
    y = self.face_enc_d2(y)
    y = self.face_enc_d3(y)
    y = self.face_enc_d4(y)
    y = self.face_enc_d5(y)
    y = self.face_enc_d6(y)
    y = self.face_enc_d7(y)
    y = self.face_enc_d8(y)

    # First half of the features is the mean, second half the log-variance.
    mean, log_var = torch.split(y, 8, dim=1)

    return mean, log_var

  def face_sample_z(self, mean, log_var):
    """Sample the face latent with the reparameterisation trick.

    Args:
      mean: Mean of the latent Gaussian.
      log_var: Log-variance of the latent Gaussian, same shape as ``mean``.

    Returns:
      ``mean + std * epsilon`` with ``epsilon ~ N(0, I)``.
    """
    # randn_like follows mean's device and dtype, so the module no longer
    # depends on a notebook-level ``device`` global.
    epsilon = torch.randn_like(mean)
    # log_var is a log-variance, hence std = exp(log_var / 2).
    return mean + torch.exp(0.5 * log_var) * epsilon

  def forward(self, y):
    """Encode ``y`` and return one latent sample shaped (N, 8, 1, 1)."""
    mean, log_var = self.face_encoder(y)
    z = self.face_sample_z(mean, log_var)
    # Append two singleton axes so the code can be used as a 4-D tensor.
    z = z.unsqueeze(-1).unsqueeze(-1)
    return z



## VoiceEncoder

In [5]:
class VoiceEncoder(nn.Module):
  """Voice encoder that predicts a latent code from an utterance feature map.

  Six gated convolution blocks (``CBGLayer``) are followed by a plain
  ``nn.Conv2d`` whose 16 output channels are split into an 8-channel mean
  and an 8-channel log-variance. In ``Net`` the sampled code is passed to
  the face decoder.
  """

  def __init__(self):
    super().__init__()
    self.voice_enc_d1 = CBGLayer(in_channels=1,
                                 out_channels=32,
                                 kernel_size=(3, 9),
                                 stride=(1, 1),
                                 padding=(1, 4))

    self.voice_enc_d2 = CBGLayer(in_channels=32,
                                 out_channels=64,
                                 kernel_size=(4, 8),
                                 stride=(2, 2),
                                 padding=(1, 3))

    self.voice_enc_d3 = CBGLayer(in_channels=64,
                                 out_channels=128,
                                 kernel_size=(4, 8),
                                 stride=(2, 2),
                                 padding=(1, 3))

    self.voice_enc_d4 = CBGLayer(in_channels=128,
                                 out_channels=128,
                                 kernel_size=(4, 8),
                                 stride=(2, 2),
                                 padding=(1, 3))

    self.voice_enc_d5 = CBGLayer(in_channels=128,
                                 out_channels=128,
                                 kernel_size=(4, 5),
                                 stride=(4, 1),
                                 padding=(0, 2))

    self.voice_enc_d6 = CBGLayer(in_channels=128,
                                 out_channels=64,
                                 kernel_size=(1, 5),
                                 stride=(1, 1),
                                 padding=(0, 2))

    # Final layer has no BatchNorm/GLU: its raw output parameterises the
    # latent distribution (8 mean channels + 8 log-variance channels).
    self.voice_enc_d7 = nn.Conv2d(in_channels=64,
                                  out_channels=16,
                                  kernel_size=(1, 5),
                                  stride=(1, 1),
                                  padding=(0, 2))

  def voice_encoder(self, x):
    """Encode an utterance feature map into face-code statistics.

    Args:
      x: Input tensor with 1 channel on axis 1.

    Returns:
      Tuple ``(mean, log_var)``; each holds 8 of the 16 output channels.
    """
    x = self.voice_enc_d1(x)
    x = self.voice_enc_d2(x)
    x = self.voice_enc_d3(x)
    x = self.voice_enc_d4(x)
    x = self.voice_enc_d5(x)
    x = self.voice_enc_d6(x)
    x = self.voice_enc_d7(x)

    # Open question from the original author: does the final Conv2d predict
    # the mean with half of its channels and log_var with the other half?
    # The split below implements exactly that reading.
    mean, log_var = torch.split(x, 8, dim=1)

    return mean, log_var

  def voice_sample_z(self, mean, log_var):
    """Sample the latent code with the reparameterisation trick.

    Args:
      mean: Mean of the latent Gaussian.
      log_var: Log-variance of the latent Gaussian, same shape as ``mean``.

    Returns:
      ``mean + std * epsilon`` with ``epsilon ~ N(0, I)``.
    """
    # randn_like follows mean's device and dtype, so the module no longer
    # depends on a notebook-level ``device`` global.
    epsilon = torch.randn_like(mean)
    # log_var is a log-variance, hence std = exp(log_var / 2).
    return mean + torch.exp(0.5 * log_var) * epsilon

  def forward(self, x):
    """Encode ``x`` and return one latent sample.

    The two trailing axes are squeezed, which yields (N, 8) only when both
    have been reduced to size 1 (true for the (N, 1, 36, 8) inputs used in
    this notebook); ``squeeze`` leaves larger axes untouched.
    """
    mean, log_var = self.voice_encoder(x)
    z = self.voice_sample_z(mean, log_var)

    z = z.squeeze(-1).squeeze(-1)
    return z


## UttrDecoder

In [6]:
class UttrDecoder(nn.Module):
  """Utterance decoder that maps a latent ``z`` (and code ``c``) to features.

  Three gated deconvolution blocks (``DBGLayer``) are followed by a plain
  ``nn.ConvTranspose2d`` whose 2 output channels are split into a 1-channel
  mean and a 1-channel log-variance of the reconstructed utterance.
  """

  def __init__(self):
    super().__init__()

    self.uttr_dec_d1 = DBGLayer(in_channels=8,
                                out_channels=16,
                                kernel_size=(9, 5),
                                stride=(9, 1),
                                padding=(0, 2))

    self.uttr_dec_d2 = DBGLayer(in_channels=16,
                                out_channels=16,
                                kernel_size=(4, 8),
                                stride=(2, 2),
                                padding=(1, 3))

    self.uttr_dec_d3 = DBGLayer(in_channels=16,
                                out_channels=8,
                                kernel_size=(4, 8),
                                stride=(2, 2),
                                padding=(1, 3))

    # Final layer has no BatchNorm/GLU: its raw output parameterises the
    # output distribution (1 mean channel + 1 log-variance channel).
    self.uttr_dec_d4 = nn.ConvTranspose2d(in_channels=8,
                                          out_channels=2,
                                          kernel_size=(3, 9),
                                          stride=(1, 1),
                                          padding=(1, 4))

  def uttr_decoder(self, z, c):
    """Decode a latent into utterance statistics.

    NOTE(review): the second result of every ``torch.broadcast_tensors``
    call is discarded, so only the *shape* of ``c`` can influence ``x``;
    the values of ``c`` never reach the layers. Confirm whether ``c`` was
    meant to be concatenated or added before each layer.

    Args:
      z: Latent tensor with 8 channels on axis 1.
      c: Conditioning code with 8 channels on axis 1, broadcastable to ``z``.

    Returns:
      Tuple ``(mean, log_var)``, one channel each.
    """
    x, _ = torch.broadcast_tensors(z, c)
    x = self.uttr_dec_d1(x)

    # The code is duplicated on the channel axis to match x's 16 channels.
    x, _ = torch.broadcast_tensors(x, torch.cat((c, c), 1))
    x = self.uttr_dec_d2(x)

    x, _ = torch.broadcast_tensors(x, torch.cat((c, c), 1))
    x = self.uttr_dec_d3(x)

    x, _ = torch.broadcast_tensors(x, c)
    x = self.uttr_dec_d4(x)

    # First channel is the mean, second channel the log-variance.
    mean, log_var = torch.split(x, 1, dim=1)

    return mean, log_var

  def uttr_sample_z(self, mean, log_var):
    """Sample the reconstructed utterance with the reparameterisation trick.

    Args:
      mean: Mean of the output Gaussian.
      log_var: Log-variance of the output Gaussian, same shape as ``mean``.

    Returns:
      ``mean + std * epsilon`` with ``epsilon ~ N(0, I)``.
    """
    # randn_like follows mean's device and dtype, so the module no longer
    # depends on a notebook-level ``device`` global.
    epsilon = torch.randn_like(mean)
    # log_var is a log-variance, hence std = exp(log_var / 2).
    return mean + torch.exp(0.5 * log_var) * epsilon

  def forward(self, z, c):
    """Decode ``(z, c)`` and return one sample of the output distribution."""
    mean, log_var = self.uttr_decoder(z, c)
    z = self.uttr_sample_z(mean, log_var)
    return z


## FaceDecoder

In [7]:
class FaceDecoder(nn.Module):
  """Face decoder that maps an 8-dim code to a sampled 3-channel image.

  Two linear layers expand the code to 2048 features, which are reshaped to
  128 channels and upsampled by four ``DBSLayer`` blocks. A final
  ``nn.Conv2d`` emits 6 channels that are split into a 3-channel mean and a
  3-channel log-variance.

  NOTE(review): the notebook's recorded run prints a 36x36 output for a
  32x32 face input -- confirm whether the two sizes are meant to match.
  """

  def __init__(self):
    super().__init__()
    self.face_dec_d1 = nn.Sequential(nn.Linear(8, 128),
                                     nn.Softplus())

    self.face_dec_d2 = nn.Sequential(nn.Linear(128, 2048),
                                     nn.Softplus())

    self.face_dec_d3 = ReshapeLayer()

    self.face_dec_d4 = DBSLayer(in_channels=128,
                                out_channels=128,
                                kernel_size=(3, 3),
                                stride=(2, 2),
                                padding=(2, 2))

    self.face_dec_d5 = DBSLayer(in_channels=128,
                                out_channels=128,
                                kernel_size=(6, 6),
                                stride=(2, 2),
                                padding=(2, 2))

    self.face_dec_d6 = DBSLayer(in_channels=128,
                                out_channels=64,
                                kernel_size=(6, 6),
                                stride=(2, 2),
                                padding=(2, 2))

    self.face_dec_d7 = DBSLayer(in_channels=64,
                                out_channels=32,
                                kernel_size=(6, 6),
                                stride=(2, 2),
                                padding=(2, 2))

    # Final layer has no BatchNorm/activation: its raw output parameterises
    # the image distribution (3 mean channels + 3 log-variance channels).
    self.face_dec_d8 = nn.Conv2d(in_channels=32,
                                 out_channels=6,
                                 kernel_size=(5, 5),
                                 stride=(1, 1))

  def face_decoder(self, c):
    """Decode a code into image statistics.

    Args:
      c: 2-D tensor of shape (N, 8).

    Returns:
      Tuple ``(mean, log_var)``, three channels each.
    """
    y = self.face_dec_d1(c)
    y = self.face_dec_d2(y)
    # 2048 features -> 128 channels of a square feature map.
    y = self.face_dec_d3(y, 128)
    y = self.face_dec_d4(y)
    y = self.face_dec_d5(y)
    y = self.face_dec_d6(y)
    y = self.face_dec_d7(y)
    y = self.face_dec_d8(y)

    # First three channels are the mean, last three the log-variance.
    mean, log_var = torch.split(y, 3, dim=1)

    return mean, log_var

  def face_sample_z(self, mean, log_var):
    """Sample the reconstructed face with the reparameterisation trick.

    Args:
      mean: Mean of the output Gaussian.
      log_var: Log-variance of the output Gaussian, same shape as ``mean``.

    Returns:
      ``mean + std * epsilon`` with ``epsilon ~ N(0, I)``.
    """
    # randn_like follows mean's device and dtype, so the module no longer
    # depends on a notebook-level ``device`` global.
    epsilon = torch.randn_like(mean)
    # log_var is a log-variance, hence std = exp(log_var / 2).
    return mean + torch.exp(0.5 * log_var) * epsilon

  def forward(self, y):
    """Decode the code ``y`` and return one sample of the image distribution."""
    mean, log_var = self.face_decoder(y)
    z = self.face_sample_z(mean, log_var)
    return z


# Net

In [4]:
class Net(nn.Module):
  """Full pipeline wiring the five encoder/decoder modules together.

  ``forward`` encodes the utterance and the face, decodes an utterance from
  both, runs the voice encoder on that utterance and finally decodes a face
  from the resulting code.
  """

  def __init__(self):
    super().__init__()
    # Construction order is kept so parameter registration is unchanged.
    self.ue = UttrEncoder()
    self.ud = UttrDecoder()
    self.fe = FaceEncoder()
    self.fd = FaceDecoder()
    self.ve = VoiceEncoder()

  def forward(self, x, y):
    """Run the utterance -> face path and return the decoded face.

    Args:
      x: Utterance input passed to the utterance encoder.
      y: Face input passed to the face encoder.

    Returns:
      Output of the face decoder.
    """
    uttr_latent = self.ue(x)
    face_code = self.fe(y)

    decoded_uttr = self.ud(uttr_latent, face_code)
    print(decoded_uttr.size())  # shape trace of the decoded utterance

    voice_code = self.ve(decoded_uttr)
    print(voice_code.size())  # shape trace of the code fed to the face decoder

    return self.fd(voice_code)

  def loss(self, x, y):
    """Placeholder for the reconstruction + KL-divergence loss.

    Not implemented yet; always returns ``None``.
    """
    return None

# 確認

## net

In [2]:
# Build the full model once; the train/eval check cells below reuse it.
net = Net()

### train

In [None]:
# Shape check in training mode with a batch of two dummy inputs:
# x is the utterance input, y is the face input.
x = torch.ones((2, 1, 36, 8), device=device)
y = torch.ones((2, 3, 32, 32), device=device)

net.train()

# Call the module itself rather than .forward() so that nn.Module.__call__
# (and any registered hooks) is used.
print(net(x, y).size())

torch.Size([2, 1, 36, 8])
torch.Size([2, 8])
torch.Size([2, 3, 36, 36])


### eval

In [None]:
# Shape check in evaluation mode with a single dummy sample:
# x is the utterance input, y is the face input.
x = torch.ones((1, 1, 36, 8), device=device)
y = torch.ones((1, 3, 32, 32), device=device)

net.eval()
with torch.no_grad():
  # Call the module itself rather than .forward() so that
  # nn.Module.__call__ (and any registered hooks) is used.
  print(net(x, y).size())

torch.Size([1, 1, 36, 8])
torch.Size([1, 8])
torch.Size([1, 3, 36, 36])


## encoder

In [None]:
# Stand-alone FaceEncoder instance for the shape check in the next cell.
model = FaceEncoder()

In [None]:
# Dummy face batch for the FaceEncoder. Axes:
#   1: batch size
#   2: channel (the first conv layer expects 3)
#   3: image height
#   4: image width
# (The original note described mfcc_size / uttr_len, which are the axes of
# the utterance inputs, not of this face tensor.)
x = torch.ones((1, 3, 32, 32), device=device)

model.eval()
with torch.no_grad():
  # Call the module itself rather than .forward() so that
  # nn.Module.__call__ (and any registered hooks) is used.
  print(model(x).size())

torch.Size([1, 8])


## decoder

In [None]:
# Stand-alone UttrDecoder instance for the shape check in the next cell.
model = UttrDecoder()

In [None]:
# Dummy latent (8 channels, 100 steps on the last axis) and an 8-channel
# code with singleton spatial axes. Both tensors are created on `device`;
# previously only c was moved there.
z = torch.ones((1, 8, 1, 100), device=device)
c = torch.ones((1, 8, 1, 1), device=device)

model.eval()
with torch.no_grad():
  # Call the module itself rather than .forward() so that
  # nn.Module.__call__ (and any registered hooks) is used.
  print(model(z, c).size())

torch.Size([1, 1, 36, 400])


## broadcastの仕組み

In [None]:
# Demonstration of tensor broadcasting: a (1, 8, 9, 100) tensor is added to
# a (1, 8, 1, 1) tensor.
x = torch.ones((1,8,9,100))
print(x)
print(x.size())
# y has size 1 on the last two axes, so it is expanded along them in x + y.
y = torch.ones((1,8,1,1))
print(y)
print(y.size())

# Element-wise sum after broadcasting y to x's shape.
print(x+y)

tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         ...,

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 