In [4]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(1, '/scr/gmachi/prospector-guide/prospectors-v2/src')

from data import retrieve_encoder, get_embeddings

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, encoder_include):
        super(VectorQuantizer, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.commitment_cost = commitment_cost
        self.encoder_include = encoder_include

        # Embedding table
        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

    def forward(self, inputs):
        if self.encoder_include:
            inputs = get_embeddings(self.encoder, inputs)
            # NOTE: may need to implement batched version of above function
        
        # Flatten input for matching with embedding table
        inputs_flat = inputs.view(-1, self.embedding_dim)

        # Compute distances between input and embedding vectors
        distances = torch.cdist(inputs_flat.unsqueeze(1), self.embeddings.weight.unsqueeze(0))

        # Find nearest embedding index for each input
        encoding_indices = torch.argmin(distances, dim=-1)
        quantized = self.embeddings(encoding_indices).view_as(inputs)

        # Compute commitment loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Pass-through gradient
        quantized = inputs + (quantized - inputs).detach()

        return quantized, loss, encoding_indices

class Decoder(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super(Decoder, self).__init__()
        self.embedding_dim = embedding_dim
        self.embeddings = nn.Embedding(num_embeddings, embedding_dim)
        self.embeddings.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

    def forward(self, encoding_indices):
        return self.embeddings(encoding_indices)

class VQVAE1D(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25, encoder_include=False):
        super(VQVAE1D, self).__init__()
        self.embedding_dim = embedding_dim
        self.encoder_include = encoder_include
        if encoder_include:
            self.encoder = retrieve_encoder()

        # Vector quantizer
        self.vector_quantizer = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost, encoder_include)

        # Decoder
        self.decoder = Decoder(num_embeddings, embedding_dim)

    def forward(self, inputs):
        # Vector quantization
        quantized, vq_loss, encoding_indices = self.vector_quantizer(inputs)

        # Decode quantized representations
        decoded = self.decoder(encoding_indices)

        return decoded, vq_loss, encoding_indices

In [9]:
num_embeddings = 30 # this is K
embedding_dim = 1536 # enter what we need for ESM3-open

# Instantiate VQ-VAE
vq_vae = VQVAE1D(num_embeddings, embedding_dim)

# Example input (batch_size=4, sequence_length=10, embedding_dim=64)
input_embeddings = torch.randn(4, 10, embedding_dim)
print("Input shape:", input_embeddings.shape)


# Forward pass
decoded, vq_loss, encoding_indices = vq_vae(input_embeddings)

print("Decoded output shape:", decoded.shape)
print("VQ Loss:", vq_loss.item())
print("Encoding indices:", encoding_indices)

Input shape: torch.Size([4, 10, 1536])
Decoded output shape: torch.Size([40, 1, 1536])
VQ Loss: 1.247421383857727
Encoding indices: tensor([[27],
        [ 0],
        [29],
        [ 4],
        [ 5],
        [ 4],
        [ 1],
        [ 8],
        [ 9],
        [ 7],
        [13],
        [17],
        [13],
        [19],
        [18],
        [19],
        [10],
        [ 1],
        [15],
        [ 6],
        [22],
        [19],
        [14],
        [11],
        [10],
        [12],
        [15],
        [14],
        [ 0],
        [ 8],
        [ 4],
        [22],
        [29],
        [ 5],
        [ 2],
        [15],
        [11],
        [14],
        [25],
        [29]])


In [12]:
import torch.optim as optim

num_embeddings = 30
embedding_dim = 1536

# Instantiate VQ-VAE
vq_vae = VQVAE1D(num_embeddings, embedding_dim)
optimizer = optim.Adam(vq_vae.parameters(), lr=1e-3)

# Example input (batch_size=4, sequence_length=10, embedding_dim=64)
input_embeddings = torch.randn(4, 10, embedding_dim)

# Training loop
epochs = 100
for epoch in range(epochs):
    vq_vae.train()
    optimizer.zero_grad()

    # Forward pass
    decoded, vq_loss, _ = vq_vae(input_embeddings)

    decoded = decoded.view_as(input_embeddings)
    assert decoded.shape == input_embeddings.shape, f"Shape mismatch: {decoded.shape} vs {input_embeddings.shape}"

    # Reconstruction loss (MSE)
    recon_loss = F.mse_loss(decoded, input_embeddings)
    total_loss = recon_loss + vq_loss

    # Backward pass
    total_loss.backward()
    optimizer.step()

    # Print losses
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Recon Loss: {recon_loss.item():.4f}, VQ Loss: {vq_loss.item():.4f}, Total Loss: {total_loss.item():.4f}")

print("Training complete.")

Epoch [10/100], Recon Loss: 0.9999, VQ Loss: 1.2477, Total Loss: 2.2476
Epoch [20/100], Recon Loss: 0.9883, VQ Loss: 1.2333, Total Loss: 2.2215
Epoch [30/100], Recon Loss: 0.9769, VQ Loss: 1.2191, Total Loss: 2.1960
Epoch [40/100], Recon Loss: 0.9658, VQ Loss: 1.2053, Total Loss: 2.1711
Epoch [50/100], Recon Loss: 0.9550, VQ Loss: 1.1917, Total Loss: 2.1467
Epoch [60/100], Recon Loss: 0.9444, VQ Loss: 1.1785, Total Loss: 2.1230
Epoch [70/100], Recon Loss: 0.9341, VQ Loss: 1.1656, Total Loss: 2.0997
Epoch [80/100], Recon Loss: 0.9240, VQ Loss: 1.1530, Total Loss: 2.0770
Epoch [90/100], Recon Loss: 0.9141, VQ Loss: 1.1407, Total Loss: 2.0548
Epoch [100/100], Recon Loss: 0.9045, VQ Loss: 1.1287, Total Loss: 2.0331
Training complete.


In [11]:
import itertools
K = 5
n = 3

def create_kernel(n, K):
    mono = list(range(K))
    bi = [x for x in itertools.combinations(mono, 2)] + [(x, x) for x in mono]
    kernel = dict.fromkeys(mono + bi, 0.0)
    if n > 2:
        tri = [x for x in itertools.combinations(mono, 3)] + [(x, x, x) for x in mono]
        kernel = dict.fromkeys(mono + bi + tri, 0.0)  
    return kernel

len(create_kernel(n, K).keys())

35