In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class Sampling(nn.Module):
    """Reparameterization-trick layer: ``z = z_mean + exp(0.5 * z_log_var) * eps``.

    ``eps`` is standard-normal noise from a private CPU random stream seeded once
    with ``seed``. The stream's state is carried over between calls, so every
    forward pass draws *fresh* noise while the whole sequence of draws stays
    reproducible for a given seed. (Re-seeding on every call would return the
    exact same ``eps`` each time, i.e. no sampling noise across batches.)

    Args:
        seed: Seed for the private noise stream.
    """

    def __init__(self, seed=1337):
        super().__init__()
        self.seed = seed
        # Stored as a plain (CPU) state tensor rather than a torch.Generator
        # object so the module remains deep-copyable / picklable. It is not a
        # parameter or buffer, so ``module.to(device)`` leaves it on the CPU,
        # which is what a CPU generator's ``set_state`` requires.
        self._rng_state = torch.Generator().manual_seed(seed).get_state()

    def forward(self, inputs):
        """Draw ``z`` from ``N(z_mean, exp(z_log_var))``.

        Args:
            inputs: Tuple ``(z_mean, z_log_var)`` of tensors with the same shape.

        Returns:
            Tensor shaped like ``z_mean``, on its device and with its dtype.
        """
        z_mean, z_log_var = inputs
        generator = torch.Generator()
        generator.set_state(self._rng_state)
        # Draw on CPU for a device-independent stream, then advance the saved state.
        epsilon = torch.randn(z_mean.shape, generator=generator)
        self._rng_state = generator.get_state()
        # Match z_mean's device/dtype (avoids CPU/CUDA mismatch in the sum below).
        epsilon = epsilon.to(z_mean)
        return z_mean + torch.exp(0.5 * z_log_var) * epsilon


In [None]:
class Encoder(nn.Module):
    """Conv1d VAE encoder: ``(batch, 1, input_length)`` -> ``(z_mean, z_log_var, z)``.

    Two stride-2 ``Conv1d`` layers downsample the sequence, a 16-unit dense layer
    summarises the flattened feature map, and two linear heads give the latent
    mean and log-variance. ``z`` is then drawn from them by a ``Sampling`` layer.

    Args:
        latent_dim: Size of the latent space.
        input_length: Length of the 1-D input sequence. The input width of
            ``fc1`` is derived from it instead of being hard-coded; the default
            of 81 shrinks to 21 after both convolutions, i.e. the original
            ``64 * 21``.

    Raises:
        ValueError: If ``input_length`` is too short to survive both convolutions.
    """

    def __init__(self, latent_dim, input_length=81):
        super().__init__()
        # Module creation order is unchanged so parameter init is unchanged too.
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.flatten = nn.Flatten()
        conv_length = self._conv_output_length(input_length)
        self.fc1 = nn.Linear(self.conv2.out_channels * conv_length, 16)
        self.fc_mean = nn.Linear(16, latent_dim)
        self.fc_log_var = nn.Linear(16, latent_dim)
        self.sampling = Sampling()

    def _conv_output_length(self, input_length):
        """Sequence length after ``conv1`` then ``conv2`` (Conv1d output-size formula)."""
        length = input_length
        for conv in (self.conv1, self.conv2):
            length = (
                length
                + 2 * conv.padding[0]
                - conv.dilation[0] * (conv.kernel_size[0] - 1)
                - 1
            ) // conv.stride[0] + 1
        if length < 1:
            raise ValueError(
                f"input_length={input_length} is too short for the convolution stack"
            )
        return length

    def forward(self, x):
        """Encode ``x``.

        Args:
            x: Tensor presumably shaped ``(batch, 1, input_length)``.

        Returns:
            ``(z_mean, z_log_var, z)``, each ``(batch, latent_dim)``.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        z_mean = self.fc_mean(x)
        z_log_var = self.fc_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        return z_mean, z_log_var, z

In [None]:
class Decoder(nn.Module):
    """Conv-transpose VAE decoder: ``(batch, latent_dim)`` -> ``(batch, 1, output_length)``.

    A linear layer produces a ``2 x base_length`` feature map, which two stride-2
    ``ConvTranspose1d`` layers upsample. A final sigmoid keeps outputs in (0, 1).

    Args:
        latent_dim: Size of the latent space.
        output_length: Length of the reconstructed sequence. The default of 81
            reproduces the original fixed layers exactly (21 -> 41 -> 81, no
            output padding).

    Raises:
        ValueError: If ``output_length`` is less than 1.
    """

    def __init__(self, latent_dim, output_length=81):
        super().__init__()
        if output_length < 1:
            raise ValueError(f"output_length must be >= 1, got {output_length}")
        # Each ConvTranspose1d(kernel_size=3, stride=2, padding=1) maps L -> 2L - 1 + op,
        # where output_padding op is 0 or 1. Solve backwards from the target length.
        mid_length, pad2 = divmod(output_length + 1, 2)
        base_length, pad1 = divmod(mid_length + 1, 2)
        self.fc = nn.Linear(latent_dim, base_length * 2)
        self.reshape = nn.Unflatten(1, (2, base_length))
        self.deconv1 = nn.ConvTranspose1d(
            in_channels=2, out_channels=4, kernel_size=3, stride=2, padding=1, output_padding=pad1
        )
        self.deconv2 = nn.ConvTranspose1d(
            in_channels=4, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=pad2
        )

    def forward(self, z):
        """Decode latent vectors ``z`` into sigmoid-activated sequences."""
        x = F.relu(self.fc(z))
        x = self.reshape(x)
        x = F.relu(self.deconv1(x))
        x = torch.sigmoid(self.deconv2(x))
        return x

In [None]:
class VAE(nn.Module):
    """Variational autoencoder wiring an encoder and a decoder together.

    Args:
        encoder: Module whose forward returns ``(z_mean, z_log_var, z)``.
        decoder: Module mapping ``z`` to a reconstruction with values in [0, 1]
            (required by the binary cross-entropy used in ``train_step``).
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        """Encode ``x``, sample ``z`` and decode it.

        Returns:
            ``(reconstruction, z_mean, z_log_var)``.
        """
        z_mean, z_log_var, z = self.encoder(x)
        reconstruction = self.decoder(z)
        return reconstruction, z_mean, z_log_var

    def train_step(self, data, optimizer):
        """Run one optimisation step on a batch.

        The loss is the per-sample BCE reconstruction error summed over all
        non-batch dims, plus the per-sample KL divergence to ``N(0, I)`` summed
        over latent dims. Both terms are averaged over the batch.

        Args:
            data: Batch tensor with the batch on dim 0 and at least 2 dims;
                presumably values in [0, 1] (BCE targets).
            optimizer: Optimizer over this model's parameters.

        Returns:
            ``(total_loss, reconstruction_loss, kl_loss)`` as Python floats.
        """
        optimizer.zero_grad()
        reconstruction, z_mean, z_log_var = self(data)
        reconstruction_loss = (
            F.binary_cross_entropy(reconstruction, data, reduction='none')
            .flatten(start_dim=1)  # collapse every non-batch dim
            .sum(dim=1)
            .mean()
        )
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)) per latent dim. Sum per
        # sample (dim=1) BEFORE averaging over the batch; summing everything to
        # a scalar first and then calling sum(dim=1) raises IndexError.
        kl_per_sample = -0.5 * torch.sum(
            1 + z_log_var - z_mean.pow(2) - z_log_var.exp(), dim=1
        )
        kl_loss = kl_per_sample.mean()
        total_loss = reconstruction_loss + kl_loss
        total_loss.backward()
        optimizer.step()
        return total_loss.item(), reconstruction_loss.item(), kl_loss.item()