In [None]:
class CODEFUSION(nn.Module):
    """CodeFusion-style diffusion model over code-token embeddings.

    Sub-modules created in ``__init__``:
      * ``encoder``: pre-trained ``Salesforce/codet5-small``; only its encoder
        stack is run, to encode utterances when ``pretrain_mode`` is False.
      * ``code_embedding`` (L): code token id -> embedding table.
      * ``time_embed``: MLP applied to the sinusoidal timestep encoding.
      * ``denoiser`` (N) and ``decoder`` (D): stacks of ``TransformerBlock``
        that cross-attend to the utterance encoding.
      * ``classification_head`` (H): hidden state -> vocabulary logits, with its
        weight tied to ``code_embedding``.

    Args:
        num_heads: Attention heads per TransformerBlock.
        num_layers_denoiser: Number of blocks in the denoiser N.
        num_layers_decoder: Number of blocks in the decoder D.
        dropout: Dropout probability passed to every TransformerBlock.
    """

    def __init__(self, num_heads, num_layers_denoiser, num_layers_decoder, dropout=0.1):
        super().__init__()

        # Pre-trained CodeT5; its config fixes the vocabulary size and hidden width.
        self.encoder = T5ForConditionalGeneration.from_pretrained('Salesforce/codet5-small')
        encoder_config = self.encoder.config
        vocab_size = encoder_config.vocab_size
        model_channels = encoder_config.d_model
        # Stored so forward() sizes the sinusoidal encoding to match time_embed's input.
        self.model_channels = model_channels

        # Embedding layer L for code tokens.
        self.code_embedding = nn.Embedding(vocab_size, model_channels)

        # Timestep MLP: sinusoidal encoding (model_channels) -> hidden width.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, model_channels),
        )

        # Denoiser (N).
        self.denoiser = nn.ModuleList([
            TransformerBlock(model_channels, num_heads, dropout=dropout)
            for _ in range(num_layers_denoiser)
        ])

        # Decoder (D).
        self.decoder = nn.ModuleList([
            TransformerBlock(model_channels, num_heads, dropout=dropout)
            for _ in range(num_layers_decoder)
        ])

        # Classification head (H); tie its weight to the code embedding table.
        # Both weights are (vocab_size, model_channels), so the shapes agree.
        self.classification_head = nn.Linear(model_channels, vocab_size)
        self.classification_head.weight = self.code_embedding.weight

    def forward(self, code, Es, timesteps, scaled_gaussian_noise, pretrain_mode=True):
        """Run one noisy forward pass through N and D.

        Args:
            code: Code token ids; dim 1 is treated as the sequence axis
                (assumed ``(batch, seq)``).
            Es: Utterance token ids for the CodeT5 encoder. Only read when
                ``pretrain_mode`` is False.
            timesteps: 1-D tensor with one diffusion timestep per batch element.
            scaled_gaussian_noise: Noise already scaled by the schedule. It is
                added to the code embeddings, so it must broadcast to them.
            pretrain_mode: If True, random Gaussian context replaces the
                utterance encoding and the noise-prediction outputs are returned.

        Returns:
            If ``pretrain_mode``: ``(predicted_noise, decoded_embeddings, emb_code)``,
            where ``predicted_noise`` is the noisy embeddings minus the denoiser
            output. Otherwise: vocabulary logits from the classification head.
        """
        emb_code = self.code_embedding(code)
        emb_timestep = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        # Forward diffusion step: add the pre-scaled noise. This matches how noisy
        # embeddings are formed for the stand-alone Denoiser's training loop.
        noisy_emb_code = emb_code + scaled_gaussian_noise
        # Broadcast each example's timestep embedding over the sequence axis.
        emb_inputs = noisy_emb_code + emb_timestep.unsqueeze(1)

        if pretrain_mode:
            # No utterances during pre-training: cross-attend to random context.
            encoded_utterances = torch.randn_like(emb_inputs)
        else:
            # T5ForConditionalGeneration.forward requires decoder inputs, so
            # run only its encoder stack to get the utterance hidden states.
            encoded_utterances = self.encoder.get_encoder()(input_ids=Es).last_hidden_state

        denoised_embeddings = emb_inputs
        for layer in self.denoiser:
            denoised_embeddings = layer(denoised_embeddings, context=encoded_utterances)

        decoded_embeddings = denoised_embeddings
        for layer in self.decoder:
            decoded_embeddings = layer(decoded_embeddings, context=encoded_utterances)

        if pretrain_mode:
            predicted_noise = noisy_emb_code - denoised_embeddings
            return predicted_noise, decoded_embeddings, emb_code

        # Predict code tokens.
        return self.classification_head(decoded_embeddings)

    def get_embeds(self, input_ids):
        """Return the code embeddings (layer L) for ``input_ids``."""
        return self.code_embedding(input_ids)

In [ ]:
import torch.nn as nn
import math

class TransformerBlock(nn.Module):
    """Post-norm transformer layer: self-attention, optional cross-attention, feed-forward.

    Each sub-layer's output is added residually and followed by its own LayerNorm.

    Args:
        dim: Model width. ``nn.MultiheadAttention`` requires it to be divisible
            by ``num_attention_heads``.
        num_attention_heads: Heads in both the self- and cross-attention layers.
        dropout: Dropout inside both attention layers and after the feed-forward.
        batch_first: If True (default), tensors are ``(batch, seq, dim)``.
            Callers in this notebook build inputs that way, with the sequence
            axis at ``size(1)``. With PyTorch's default of False, attention would
            run across the batch axis and mix unrelated examples.
    """

    def __init__(self, dim, num_attention_heads, dropout=0.1, *, batch_first=True):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            dim, num_attention_heads, dropout=dropout, batch_first=batch_first
        )
        self.norm1 = nn.LayerNorm(dim)
        self.cross_attn = nn.MultiheadAttention(
            dim, num_attention_heads, dropout=dropout, batch_first=batch_first
        )
        self.norm2 = nn.LayerNorm(dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
            nn.Dropout(dropout),
        )
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, context=None):
        """Apply the block.

        Args:
            x: Input sequence, ``(batch, seq, dim)`` when ``batch_first`` is True.
            context: Optional sequence to cross-attend to, in the same layout as
                ``x``. Its sequence length may differ from ``x``'s. The
                cross-attention sub-layer is skipped when this is None.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention; the attention weights are never used, so skip computing them.
        attn_out, _ = self.self_attn(x, x, x, need_weights=False)
        x = self.norm1(x + attn_out)

        # Cross-attention with the encoded utterance, if provided.
        if context is not None:
            cross_out, _ = self.cross_attn(x, context, context, need_weights=False)
            x = self.norm2(x + cross_out)

        # Position-wise feed-forward.
        return self.norm3(x + self.feed_forward(x))

def timestep_embedding(timesteps, dim, max_period=10000):
    """Build sinusoidal embeddings for a batch of diffusion timesteps.

    With ``half = dim // 2``, the first ``half`` columns are ``sin(t * f_k)`` and
    the next ``half`` are ``cos(t * f_k)``. The frequencies are
    ``f_k = exp(-log(max_period) * k / max(half - 1, 1))`` for ``k = 0 .. half - 1``.
    When ``dim`` is odd, a zero column is appended so the width is exactly ``dim``.

    Args:
        timesteps: 1-D tensor with one timestep per batch element.
        dim: Width of the returned embedding.
        max_period: Controls the lowest frequency of the sinusoids.

    Returns:
        float32 tensor of shape ``(len(timesteps), dim)`` on ``timesteps.device``.
    """
    half_dim = dim // 2
    # For dim in {2, 3}, half_dim - 1 is 0 and would divide by zero; clamping
    # the denominator yields a single sinusoid of frequency 1 instead.
    exponent_scale = math.log(max_period) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -exponent_scale
    )
    args = timesteps[:, None].float() * freqs[None, :]
    emb = torch.cat((torch.sin(args), torch.cos(args)), dim=1)
    if dim % 2:
        # Build the pad from the row count, not by slicing emb, so dim == 1
        # (where emb is empty) still gets its one column.
        emb = torch.cat((emb, emb.new_zeros((emb.shape[0], 1))), dim=1)
    return emb

class Denoiser(nn.Module):
    """Stand-alone denoiser: a timestep-conditioned stack of TransformerBlocks.

    Args:
        d_model: Embedding width of the inputs.
        nhead: Attention heads per TransformerBlock.
        num_layers: Number of TransformerBlocks.
        dropout: Dropout probability passed to each block.
    """

    def __init__(self, d_model, nhead, num_layers, dropout=0.1):
        super().__init__()
        # Stored so forward() sizes the sinusoidal encoding to match time_embed's input.
        self.d_model = d_model
        self.denoiser = nn.ModuleList([
            TransformerBlock(d_model, nhead, dropout=dropout) for _ in range(num_layers)
        ])
        time_embed_dim = d_model * 4
        self.time_embed = nn.Sequential(
            nn.Linear(d_model, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, d_model),
        )

    def forward(self, emb_code, Es, timesteps, pretrain_mode=True):
        """Denoise a batch of (noisy) code embeddings.

        Args:
            emb_code: Noisy code embeddings; dim 1 is treated as the sequence
                axis (assumed ``(batch, seq, d_model)``).
            Es: Already-encoded utterances, used as cross-attention context
                when ``pretrain_mode`` is False (assumed ``(batch, ctx_seq, d_model)``).
            timesteps: 1-D tensor with one diffusion timestep per batch element.
            pretrain_mode: If True, ``Es`` is ignored and random Gaussian
                context is used instead.

        Returns:
            Tensor of denoised embeddings with the same shape as ``emb_code``.
        """
        emb = self.time_embed(timestep_embedding(timesteps, self.d_model))
        # Broadcast each example's timestep embedding over the sequence axis.
        emb_inputs = emb_code + emb.unsqueeze(1)

        if pretrain_mode:
            # No utterances during pre-training: cross-attend to random context.
            encoded_utterances = torch.randn_like(emb_inputs)
        else:
            # This module has no text encoder of its own; Es is the encoding.
            encoded_utterances = Es

        denoised_embeddings = emb_inputs
        for layer in self.denoiser:
            denoised_embeddings = layer(denoised_embeddings, context=encoded_utterances)

        return denoised_embeddings

In [ ]:
def mean_flat(tensor):
    """Average ``tensor`` over every axis except the leading (batch) axis.

    For an input with at least two dimensions, the result has shape
    ``(tensor.shape[0],)`` and holds one mean per batch element.
    """
    non_batch_axes = list(range(1, tensor.ndim))
    return tensor.mean(dim=non_batch_axes)

In [ ]:
import torch.optim as optim
import torch
import numpy as np

# Toy pre-training run for the stand-alone Denoiser on random embeddings.
# NOTE(review): code_snippets only determines seq_len (the longest snippet's
# length in characters, not tokens); the snippets are never tokenized or embedded.
# No RNG seed is set in this cell, so runs are not reproducible.
code_snippets = [
    "for i in range(5): print(i)",
    "import numpy as np",
    "x = np.array([1, 2, 3])",
    "def add(a, b): return a + b",
    "print(add(5, 3))"
]

# Hyperparameters
embed_dim = 512
seq_len = max(len(snippet) for snippet in code_snippets)
batch_size = 5  # equals len(code_snippets) today, but the two are not linked

# Random stand-ins for code embeddings and encoded utterances. They are drawn
# once and reused as the same fixed batch in every iteration below.
original_embeddings = torch.randn(batch_size, seq_len, embed_dim)
E_s = torch.randn(batch_size, seq_len, embed_dim)  # Simulated encoded utterances, shape (batch, seq, embed_dim)

# Denoiser model and optimizer
denoiser = Denoiser(embed_dim, num_layers=10, nhead=8)
optimizer = optim.AdamW(denoiser.parameters(), lr=5e-4)
loss_fn = torch.nn.MSELoss()  # NOTE(review): not used below; the loss is computed manually

max_epochs = 100
max_timestep = 1200
# sqrt_noise_scheduler is defined outside this cell. Its result is indexed with
# a tensor of timesteps in [0, max_timestep), so it must have at least
# max_timestep entries — presumably one noise scale per timestep; verify.
noise_schedule = sqrt_noise_scheduler(max_timestep)
#uniform_sampler = UniformSampler(denoiser)

# Each "epoch" is a single optimisation step on the same fixed batch.
for epoch in range(max_epochs):
    optimizer.zero_grad()
    
    # Draw one uniform integer timestep in [0, max_timestep) per example
    timesteps = torch.randint(0, max_timestep, (batch_size,))
    #print(timesteps.shape)
    # Look up each example's noise level; reshape to (batch, 1, 1) so it
    # broadcasts over the sequence and embedding axes
    noise_levels = noise_schedule[timesteps].view(-1, 1, 1)
    
    # Fresh Gaussian noise, scaled per example by its noise level
    scaled_gaussian_noise = torch.randn_like(original_embeddings) * noise_levels
    
    # Forward diffusion: noisy = clean + scaled noise
    noisy_embeddings_with_scaled_noise = original_embeddings + scaled_gaussian_noise
    
    # NOTE(review): E_s_noisy is assigned but never used; the denoiser receives E_s
    E_s_noisy = scaled_gaussian_noise  # Optional: Use scaled noise as part of E_s_noisy if needed

    #print(timesteps.shape)
    #print(noisy_embeddings_with_scaled_noise.shape)
    # Forward pass. pretrain_mode is left at its default (True), under which
    # the Denoiser ignores E_s and cross-attends to random context instead.
    denoised_embedding = denoiser(noisy_embeddings_with_scaled_noise, E_s, timesteps)

    # Treat (noisy - denoised) as the predicted noise; driving it toward the true
    # noise trains the denoiser's output toward original_embeddings
    predicted_noise = noisy_embeddings_with_scaled_noise - denoised_embedding
    # Squared error between predicted noise and the actual (scaled) noise used
    noise_prediction_loss = (predicted_noise - scaled_gaussian_noise) ** 2
    #print(noise_prediction_loss.shape)
    # Per-example mean over seq/embed axes, then mean over the batch. Every
    # example has the same element count, so this equals a plain MSE.
    loss = mean_flat(noise_prediction_loss).mean()
    #print(loss)

    # Backward pass and optimization step
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch}, Loss: {loss.item()}, Noise Level: {noise_levels.mean().item()}")