(1, 64, 3)


In [77]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# define encoder in pytorch
class Encoder(nn.Module):
    def __init__(self, latent_dim,nl):
        super(Encoder, self).__init__()
        self.nl=nl
        self.conv1 = nn.Conv1d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 8, 3, padding=1)
        self.conv3 = nn.Conv1d(8, 8, 3, padding=1)
        self.conv4 = nn.Conv1d(8, 4, 3, padding=1)
        self.fc1 = nn.Linear(4*nl, latent_dim*2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        nl=self.nl
        print(x.shape)
        x = F.max_pool1d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool1d(x, 2)
        x = F.relu(self.conv3(x))
        x = F.max_pool1d(x, 2)
        x = F.relu(self.conv4(x))
        x = F.max_pool1d(x, 2)
        print(x.size())
        x = x.view(-1, 4*nl)
        x = self.fc1(x)
        return x

# define decoder in pytorch
class Decoder(nn.Module):
    def __init__(self, latent_dim,nl):
        super(Decoder, self).__init__()
        self.nl=nl
        self.fc1 = nn.Linear(latent_dim, nl*4)
        self.conv1 = nn.ConvTranspose1d(4, 8, 3, padding=1)
        self.conv2 = nn.ConvTranspose1d(8, 8, 3, padding=1)
        self.conv3 = nn.ConvTranspose1d(8, 16, 3, padding=1)
        self.conv4 = nn.ConvTranspose1d(16, 3, 3, padding=1)

    def forward(self, x):
        nl=self.nl
        x = F.relu(self.fc1(x))
        x = x.view(-1, 4, self.nl)
        x = F.relu(self.conv1(x))
        x = F.interpolate(x, scale_factor=2)
        x = F.relu(self.conv2(x))
        x = F.interpolate(x, scale_factor=2)
        x = F.relu(self.conv3(x))
        x = F.interpolate(x, scale_factor=2)
        x = F.relu(self.conv4(x))
        x = F.interpolate(x, scale_factor=2)
        return x

# define variational autoencoder in pytorch
class VAE(nn.Module):
    def __init__(self, latent_dim,nl):
        super(VAE, self).__init__()
        self.encoder = Encoder(latent_dim,nl)
        self.decoder = Decoder(latent_dim,nl)

    def forward(self, x):
        x = self.encoder(x)
        z_mean = x[:, 0:latent_dim]
        z_log_var = x[:, latent_dim:]
        z = self.reparameterization(z_mean, z_log_var)
        x = self.decoder(z)
        return x, z_mean, z_log_var

    def reparameterization(self, z_mean, z_log_var):
        std = torch.exp(0.5 * z_log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(z_mean)

latent_dim=5

encoder=Encoder(latent_dim,4)
decoder=Decoder(latent_dim,4)
vae=VAE(latent_dim,4)
x=torch.randn(10,3,64)
xp=vae(x)
# train the the variational autoencoder model
optimizer = optim.Adam(vae.parameters(), lr=1e-3)   
criterion = nn.MSELoss()
# include the KL divergence loss
def loss_function(x_hat, x, z_mean, z_log_var):
    recon_loss = criterion(x_hat, x)
    kl_div = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())
    return recon_loss + kl_div

x_data = torch.randn(1000,3,64)
# make a custom dataloader
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)
training_dataset = CustomDataset(x_data)
train_loader = torch.utils.data.DataLoader(training_dataset, batch_size=64, shuffle=True)

for epoch in range(1):
    optimizer.zero_grad()
    x = next(iter(train_loader))
    # sample randomnly a subset from x_data
    x_hat, z_mean, z_log_var = vae(x)
    loss = criterion(x_hat, x)
    loss.backward()
    optimizer.step()
    print('epoch [{}/{}], loss:{:.4f}'.format(epoch+1, 100, loss.item()))



torch.Size([10, 16, 64])
torch.Size([10, 4, 4])
torch.Size([10, 16, 64])
torch.Size([10, 4, 4])
epoch [1/100], loss:0.9324


In [82]:
print(x.size(1))
print(len(iter(train_loader)))

3
100
