In [None]:


class VectorQuantizer(nn.Module):
    """
    Discretization bottleneck part of the VQ-VAE.

    Inputs:
    - n_embeddings : number of embeddings (codebook size)
    - embeddings_dim : dimension of each embedding; must equal the channel
      count of the tensor given to ``forward``
    - beta : commitment cost. In the VQ-VAE loss it weights the commitment
      term ||z_e(x)-sg[e]||^2
    - legacy : if True (default), keep this implementation's historical
      behaviour where beta multiplies the codebook term ||sg[z_e(x)]-e||^2
      instead. The loss *value* is the same either way; only the gradient
      split between encoder and codebook differs. Pass False to weight the
      commitment term as described above.
    """

    def __init__(self, n_embeddings, embeddings_dim, beta=0.25, legacy=True):
        super(VectorQuantizer, self).__init__()
        self.n_embeddings = n_embeddings
        self.embeddings_dim = embeddings_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_embeddings, self.embeddings_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_embeddings, 1.0 / self.n_embeddings)

    def forward(self, z):
        """
        Inputs the output of the encoder network z and maps every spatial
        position to the closest embedding vector e_j.

        z (continuous) -> z_q (discrete)

        z.shape = (batch, channel, depth, height, width)

        quantization pipeline:

            1. get encoder input (B,C,D,H,W)
            2. flatten input to (B*D*H*W,C)
            3. pick the nearest codebook row for every flattened vector
            4. reshape back to (B,C,D,H,W)

        Returns (loss, z_q, perplexity, min_encodings, min_encoding_indices)
        where min_encodings is the (B*D*H*W, n_embeddings) one-hot matrix and
        min_encoding_indices the matching (B*D*H*W, 1) index tensor.
        """
        # reshape z -> (batch, depth, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 4, 1).contiguous()
        z_flattened = z.view(-1, self.embeddings_dim)
        codebook = self.embedding.weight

        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(codebook ** 2, dim=1) - 2 * \
            torch.matmul(z_flattened, codebook.t())

        # find closest encodings; allocate the one-hot buffer on the input's
        # device rather than relying on a module-level `device` global
        min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
        min_encodings = torch.zeros(
            min_encoding_indices.shape[0], self.n_embeddings,
            device=z.device, dtype=codebook.dtype)
        min_encodings.scatter_(1, min_encoding_indices, 1)

        # get quantized latent vectors
        z_q = torch.matmul(min_encodings, codebook).view(z.shape)

        # compute loss for embedding
        commitment = torch.mean((z_q.detach() - z) ** 2)  # gradient -> encoder
        codebook_term = torch.mean((z_q - z.detach()) ** 2)  # gradient -> codebook
        if self.legacy:
            loss = commitment + self.beta * codebook_term
        else:
            loss = self.beta * commitment + codebook_term

        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()

        # perplexity
        e_mean = torch.mean(min_encodings, dim=0)
        perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))

        # reshape back to match original input shape
        z_q = z_q.permute(0, 4, 1, 2, 3).contiguous()

        return loss, z_q, perplexity, min_encodings, min_encoding_indices

In [None]:

class VQVAE(nn.Module):
    """
    3D convolutional VQ-VAE: encoder -> pre-quantization conv ->
    VectorQuantizer -> decoder.

    Inputs:
    - num_embeddings : codebook size passed to VectorQuantizer
    - embedding_dim : codebook vector size; also the channel count produced
      by the pre-quantization conv and consumed by the decoder

    Every (de)convolution uses kernel (3,4,4), stride (1,2,2), padding 1, so
    depth is preserved while height/width are halved three times on the way
    in and doubled three times on the way out. Input is a single-channel
    volume (batch, 1, depth, height, width).
    """

    def __init__(self, num_embeddings, embedding_dim):
        super(VQVAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=(3,4,4), stride=(1,2,2), padding=1),
            nn.ReLU(),
            nn.Conv3d(16, 1, kernel_size=(3,4,4), stride=(1,2,2), padding=1),
            nn.ReLU()
        )
        # lifts the 1-channel encoder output to embedding_dim channels
        self.pre_quantization_conv = nn.Conv3d(
            1, embedding_dim, kernel_size=(3,4,4), stride=(1,2,2), padding = 1)

        self.VQ = VectorQuantizer(num_embeddings, embedding_dim)

        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(embedding_dim, 1, kernel_size=(3,4,4), stride=(1,2,2), padding = 1
                              ),
            nn.ReLU(),
            nn.ConvTranspose3d(1, 16, kernel_size=(3,4,4), stride=(1,2,2), padding=1),
            nn.ReLU(),
            nn.ConvTranspose3d(16, 1, kernel_size=(3,4,4), stride=(1,2,2), padding=1),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Run the convolutional encoder (without the pre-quantization conv)."""
        return self.encoder(x)

    def decode(self, x):
        """Map a quantized latent volume back to a reconstruction in [0, 1]."""
        return self.decoder(x)

    def forward(self, x):
        """
        Encode, quantize and decode x.

        Returns (embedding_loss, x_recon, perplexity). The reconstruction is
        returned (previously it was computed and then discarded in favour of
        the quantized latent, which cannot be compared against the input).
        """
        z_e = self.pre_quantization_conv(self.encode(x))
        embedding_loss, z_q, perplexity, _, _ = self.VQ(z_e)
        x_recon = self.decode(z_q)
        return embedding_loss, x_recon, perplexity

 


In [None]:

class VAE(nn.Module):
    """
    3D convolutional VAE with a dense bottleneck.

    Two stride-(1,2,2) convolutions shrink height/width by 4 and leave depth
    alone; the flattened size is hard-coded to 1 x channel_size x 64 x 64, so
    the layer arithmetic implies inputs of shape
    (batch, 1, channel_size, 256, 256).

    Args:
        latent_dim: size of the latent vector.
        channel_size: depth of the input volume.
    """

    def __init__(self, latent_dim=8, channel_size=9):
        super().__init__()
        conv_cfg = dict(kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=1)
        hidden_features = 64 * channel_size
        flat_features = 64 * 64 * channel_size

        encoder_layers = [
            nn.Conv3d(1, 16, **conv_cfg),
            nn.LeakyReLU(),
            nn.Conv3d(16, 1, **conv_cfg),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(flat_features, hidden_features),
            nn.LeakyReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Latent space heads
        self.mu = nn.Linear(hidden_features, latent_dim)
        self.logvar = nn.Linear(hidden_features, latent_dim)

        decoder_layers = [
            nn.Linear(latent_dim, hidden_features),
            nn.LeakyReLU(),
            nn.Linear(hidden_features, flat_features),
            nn.Unflatten(1, (1, channel_size, 64, 64)),
            nn.ConvTranspose3d(1, 16, **conv_cfg),
            nn.LeakyReLU(),
            nn.ConvTranspose3d(16, 1, **conv_cfg),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior for x."""
        hidden = self.encoder(x)
        return self.mu(hidden), self.logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Sample mu + eps * sigma with eps ~ N(0, I), sigma = exp(logvar / 2)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, x):
        """Map a latent vector to a reconstruction in [0, 1]."""
        return self.decoder(x)

    def forward(self, x):
        """Return (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar

 


In [None]:

class VAE(nn.Module):
    """
    Deeper 3D convolutional VAE with a 32-channel bottleneck.

    The encoder halves height/width three times (depth is preserved), so the
    layer arithmetic implies inputs of shape
    (batch, 1, channel_size, 8 * latent_pixel_size, 8 * latent_pixel_size).

    Args:
        latent_dim: size of the latent vector.
        channel_size: depth of the input volume.
        latent_pixel_size: height/width of the feature map before flattening.
    """

    def __init__(self, latent_dim=64, channel_size=9, latent_pixel_size=16):
        super().__init__()
        stem = dict(kernel_size=(5, 6, 6), stride=(1, 2, 2))
        halve = dict(kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=1)
        same = dict(kernel_size=3, stride=1, padding=1)
        bottleneck_shape = (32, channel_size, latent_pixel_size, latent_pixel_size)
        flat_features = 32 * latent_pixel_size * latent_pixel_size * channel_size

        encoder_layers = [
            nn.Conv3d(1, 8, padding=(2, 2, 2), **stem),
            nn.LeakyReLU(),
            nn.Conv3d(8, 16, **halve),
            nn.LeakyReLU(),
            nn.Conv3d(16, 16, **same),
            nn.LeakyReLU(),
            nn.Conv3d(16, 32, **halve),
            nn.LeakyReLU(),
            nn.Conv3d(32, 32, **same),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.LeakyReLU(),
            nn.Dropout(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Latent space heads
        self.mu = nn.Linear(256, latent_dim)
        self.logvar = nn.Linear(256, latent_dim)

        decoder_layers = [
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(),
            nn.Linear(256, flat_features),
            nn.Unflatten(1, bottleneck_shape),
            nn.ConvTranspose3d(32, 32, **same),
            nn.LeakyReLU(),
            nn.ConvTranspose3d(32, 16, **halve),
            nn.LeakyReLU(),
            nn.ConvTranspose3d(16, 16, **same),
            nn.LeakyReLU(),
            nn.ConvTranspose3d(16, 8, **halve),
            nn.LeakyReLU(),
            # NOTE(review): second LeakyReLU in a row; kept so numerics and
            # Sequential indices (state_dict keys) stay unchanged.
            nn.LeakyReLU(),
            nn.ConvTranspose3d(8, 1, padding=2, **stem),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior for x."""
        hidden = self.encoder(x)
        return self.mu(hidden), self.logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Sample mu + eps * sigma with eps ~ N(0, I), sigma = exp(logvar / 2)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, x):
        """Map a latent vector to a reconstruction in [0, 1]."""
        return self.decoder(x)

    def forward(self, x):
        """Return (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar

 


In [None]:

class VAE(nn.Module):
    """
    Three-stage 3D convolutional VAE with a 1-channel 32x32 bottleneck.

    Three stride-(1,2,2) convolutions shrink height/width by 8 and preserve
    depth; the flattened size is hard-coded to 1 x channel_size x 32 x 32, so
    the layer arithmetic implies inputs of shape
    (batch, 1, channel_size, 256, 256).

    Args:
        latent_dim: size of the latent vector.
        channel_size: depth of the input volume.
    """

    def __init__(self, latent_dim=64, channel_size=9):
        super().__init__()
        conv_cfg = dict(kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=1)
        flat_features = 32 * 32 * channel_size

        encoder_layers = [
            nn.Conv3d(1, 8, **conv_cfg),
            nn.LeakyReLU(),
            nn.Conv3d(8, 4, **conv_cfg),
            nn.LeakyReLU(),
            nn.Conv3d(4, 1, **conv_cfg),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(flat_features, 64),
            nn.LeakyReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Latent space heads
        self.mu = nn.Linear(64, latent_dim)
        self.logvar = nn.Linear(64, latent_dim)

        decoder_layers = [
            nn.Linear(latent_dim, 64),
            nn.LeakyReLU(),
            nn.Linear(64, flat_features),
            nn.Unflatten(1, (1, channel_size, 32, 32)),
            nn.ConvTranspose3d(1, 4, **conv_cfg),
            nn.LeakyReLU(),
            nn.ConvTranspose3d(4, 8, **conv_cfg),
            nn.LeakyReLU(),
            nn.ConvTranspose3d(8, 1, **conv_cfg),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior for x."""
        hidden = self.encoder(x)
        return self.mu(hidden), self.logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Sample mu + eps * sigma with eps ~ N(0, I), sigma = exp(logvar / 2)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, x):
        """Map a latent vector to a reconstruction in [0, 1]."""
        return self.decoder(x)

    def forward(self, x):
        """Return (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar

 


In [None]:

class VAE(nn.Module):
    """
    3D convolutional VAE with a 1-channel bottleneck and a wide first kernel.

    The encoder halves height/width three times (depth is preserved), so the
    layer arithmetic implies inputs of shape
    (batch, 1, channel_size, 8 * latent_pixel_size, 8 * latent_pixel_size).

    Args:
        latent_dim: size of the latent vector.
        channel_size: depth of the input volume.
        latent_pixel_size: height/width of the feature map before flattening.
    """

    def __init__(self, latent_dim=64, channel_size=9, latent_pixel_size=16):
        super().__init__()
        stem = dict(kernel_size=(5, 6, 6), stride=(1, 2, 2))
        halve = dict(kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=1)
        same = dict(kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        bottleneck_shape = (1, channel_size, latent_pixel_size, latent_pixel_size)
        flat_features = latent_pixel_size * latent_pixel_size * channel_size

        encoder_layers = [
            nn.Conv3d(1, 8, padding=(2, 2, 2), **stem),
            nn.LeakyReLU(),
            # no activation between the next two convolutions
            nn.Conv3d(8, 4, **halve),
            nn.Conv3d(4, 4, **same),
            nn.LeakyReLU(),
            nn.Conv3d(4, 1, **halve),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.LeakyReLU(),
            nn.Dropout(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Latent space heads
        self.mu = nn.Linear(256, latent_dim)
        self.logvar = nn.Linear(256, latent_dim)

        decoder_layers = [
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(),
            nn.Linear(256, flat_features),
            nn.Unflatten(1, bottleneck_shape),
            nn.ConvTranspose3d(1, 4, **halve),
            nn.LeakyReLU(),
            # no activation between the next two transposed convolutions
            nn.ConvTranspose3d(4, 4, **same),
            nn.ConvTranspose3d(4, 8, **halve),
            nn.LeakyReLU(),
            nn.ConvTranspose3d(8, 1, padding=2, **stem),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior for x."""
        hidden = self.encoder(x)
        return self.mu(hidden), self.logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Sample mu + eps * sigma with eps ~ N(0, I), sigma = exp(logvar / 2)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, x):
        """Map a latent vector to a reconstruction in [0, 1]."""
        return self.decoder(x)

    def forward(self, x):
        """Return (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar

 


In [None]:

class VAE(nn.Module):
    """
    3D convolutional VAE with a 16-channel bottleneck.

    The encoder halves height/width twice (depth is preserved), so the layer
    arithmetic implies inputs of shape
    (batch, 1, channel_size, 4 * latent_pixel_size, 4 * latent_pixel_size).

    Args:
        latent_dim: size of the latent vector.
        channel_size: depth of the input volume.
        latent_pixel_size: height/width of the feature map before flattening.
    """

    def __init__(self, latent_dim=64, channel_size=9, latent_pixel_size=16):
        super().__init__()
        stem = dict(kernel_size=(5, 6, 6), stride=(1, 2, 2))
        halve = dict(kernel_size=(3, 4, 4), stride=(1, 2, 2), padding=1)
        same = dict(kernel_size=3, stride=1, padding=1)
        bottleneck_shape = (16, channel_size, latent_pixel_size, latent_pixel_size)
        flat_features = 16 * latent_pixel_size * latent_pixel_size * channel_size

        encoder_layers = [
            nn.Conv3d(1, 8, padding=(2, 2, 2), **stem),
            nn.LeakyReLU(),
            nn.Conv3d(8, 16, **same),
            nn.LeakyReLU(),
            nn.Conv3d(16, 16, **halve),
            nn.LeakyReLU(),
            nn.Conv3d(16, 16, **same),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.LeakyReLU(),
            nn.Dropout(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Latent space heads
        self.mu = nn.Linear(256, latent_dim)
        self.logvar = nn.Linear(256, latent_dim)

        decoder_layers = [
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(),
            nn.Linear(256, flat_features),
            nn.Unflatten(1, bottleneck_shape),
            nn.ConvTranspose3d(16, 16, **same),
            nn.LeakyReLU(),
            nn.ConvTranspose3d(16, 16, **halve),
            nn.LeakyReLU(),
            nn.ConvTranspose3d(16, 8, **same),
            nn.LeakyReLU(),
            # NOTE(review): second LeakyReLU in a row; kept so numerics and
            # Sequential indices (state_dict keys) stay unchanged.
            nn.LeakyReLU(),
            nn.ConvTranspose3d(8, 1, padding=2, **stem),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior for x."""
        hidden = self.encoder(x)
        return self.mu(hidden), self.logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Sample mu + eps * sigma with eps ~ N(0, I), sigma = exp(logvar / 2)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, x):
        """Map a latent vector to a reconstruction in [0, 1]."""
        return self.decoder(x)

    def forward(self, x):
        """Return (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar

 


In [None]:

class VAE(nn.Module):
    """
    2D convolutional VAE with a 32-channel bottleneck.

    The encoder halves height/width twice, so the layer arithmetic implies
    single-channel inputs of shape
    (batch, 1, 4 * latent_pixel_size, 4 * latent_pixel_size).

    Args:
        latent_dim: size of the latent vector.
        latent_pixel_size: height/width of the feature map before flattening.
    """

    def __init__(self, latent_dim=16, latent_pixel_size=16):
        super().__init__()
        bottleneck_shape = (32, latent_pixel_size, latent_pixel_size)
        flat_features = 32 * latent_pixel_size * latent_pixel_size

        encoder_layers = [
            # strided 5x5 stem: halves height/width
            nn.Conv2d(1, 16, kernel_size=5, stride=2, padding=2),
            nn.LeakyReLU(),
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(),
            nn.Conv2d(32, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.LeakyReLU(),
            nn.Dropout(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Latent space heads
        self.mu = nn.Linear(256, latent_dim)
        self.logvar = nn.Linear(256, latent_dim)

        decoder_layers = [
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(),
            nn.Linear(256, flat_features),
            nn.Unflatten(1, bottleneck_shape),
            nn.ConvTranspose2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(16, 8, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(),
            # NOTE(review): second LeakyReLU in a row; kept so numerics and
            # Sequential indices (state_dict keys) stay unchanged.
            nn.LeakyReLU(),
            # output_padding=1 makes the stride-2 upsample exactly double the size
            nn.ConvTranspose2d(8, 1, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return (mu, logvar) of the approximate posterior for x."""
        hidden = self.encoder(x)
        return self.mu(hidden), self.logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Sample mu + eps * sigma with eps ~ N(0, I), sigma = exp(logvar / 2)."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, x):
        """Map a latent vector to a reconstruction in [0, 1]."""
        return self.decoder(x)

    def forward(self, x):
        """Return (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar

 
