# Create dataset for training

In [5]:
from datasets import load_dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizerFast

In [11]:
dataset = load_dataset("code_search_net", "python", trust_remote_code=True)

Using the latest cached version of the module from /Users/nitinmittapally/.cache/huggingface/modules/datasets_modules/datasets/code_search_net/8f2524e6b62f65af5f5d65c53715c654db7b08dc93e0b7bcce2ab2f286a75be1 (last modified on Tue Dec 10 17:00:41 2024) since it couldn't be found locally at code_search_net, or remotely on the Hugging Face Hub.


In [39]:
def filter_non_ascii(text):
    """
    Remove non-ASCII characters from text.
    """
    return ''.join(char for char in text if ord(char) < 128)
    
def clean_docstring(doc_string, max_words=20):
    """
    Preprocess the documentation string:
    - Truncate at the first empty line or limit to the first 20 words.
    """
    # Split the documentation into lines
    lines = doc_string.split("\n")
    processed_lines = []

    for line in lines:
        stripped_line = line.strip()
        # Stop if we encounter an empty line
        if not stripped_line:
            break
        processed_lines.append(stripped_line)
    return filter_non_ascii(". ".join(processed_lines))

def clean_code(code):
    """
    Normalize code indentation to PEP 8 standards:
    - Use 4 spaces per indentation level.
    - Dynamically adjust indentation levels based on leading spaces.
    - Skip empty lines for indentation calculations.
    """
    lines = code.split("\n")
    cleaned_lines = []
    current_indent_level = 0  # Track the current indentation level
    previous_spaces = 0  # Track the leading spaces of the last non-empty line

    for line in lines:
        stripped_line = line.lstrip()  # Remove leading whitespace
        leading_spaces = len(line) - len(stripped_line)  # Count leading spaces

        if not stripped_line:  # If the line is empty
            cleaned_lines.append("")  # Preserve it as a blank line
            continue  # Skip further processing for this line

        # Compare leading spaces with the previous meaningful line
        if leading_spaces > previous_spaces:
            current_indent_level += 1  # Increase indentation level
        elif leading_spaces < previous_spaces:
            current_indent_level = max(0, current_indent_level - 1)  # Decrease indentation level

        # Update the previous_spaces for the next comparison
        previous_spaces = leading_spaces

        # Construct the cleaned line with spaces
        cleaned_line = (" " * (current_indent_level * 4)) + stripped_line
        cleaned_lines.append(cleaned_line)

    return filter_non_ascii("\n".join(cleaned_lines))
    
def preprocess_dataset(record):
    if record['func_documentation_string'] and record['func_code_string']:
        return {
            "description": clean_docstring(record['func_documentation_string']),
            "code": clean_code(record['func_code_string'])
        }
    return None

In [74]:
from torch.utils.data import Dataset, DataLoader
class CodeDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return preprocess_dataset(self.data[index])
    

class CodeInputCollator:
    def __init__(self, tokenizer, max_len, device="mps"):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.device = device

    def __call__(self, batch):
        batch = [example for example in batch if example is not None]
        inputs = [example["description"] for example in batch]
        outputs = [example["code"] for example in batch]

        tokens = self.tokenizer(
            inputs, outputs,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=self.max_len
        )

        return {
            "input_ids": tokens["input_ids"][:, :-1].to(self.device),
            "attention_mask": tokens["attention_mask"][:, :-1].to(self.device),
            "labels": tokens["input_ids"][:, 1:].to(self.device)
            
        }


In [75]:
# Load the custom fast tokenizer
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="./tokenizer/custom_tokenizer.json",  # Path to the saved tokenizer
    unk_token="<unk>",                        # Unknown token
    pad_token="<pad>",                        # Padding token
    bos_token="<s>",                          # Beginning-of-sequence token
    eos_token="</s>"                          # End-of-sequence token
)

In [121]:
type(tokenizer)

transformers.tokenization_utils_fast.PreTrainedTokenizerFast

In [89]:
inputs = [
    "Write function for hello world",
    "Write function for hello world",
    "Write function for hello world",
    "Write function for hello world",
    "Write function for hello world",
]
outputs = [
    "def hello_world():\n    print('Hello, World!')",
    "def hello_world():\n    print('Hello, World!')",
    "def hello_world():\n    print('Hello, World!')",
    "def hello_world():\n    print('Hello, World!')",
    "def hello_world():\n    print('Hello, World!')",
]
x = tokenizer(inputs, outputs)

In [95]:
tokenizer("Write function for hello world", "def hello_world():\n    print('Hello, World!')")

{'input_ids': [0, 2250, 399, 139, 6627, 3959, 2, 167, 6627, 75, 3959, 20, 21, 38, 468, 20, 19, 9348, 24, 10639, 13, 19, 21, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [77]:
train_dataset = CodeDataset(dataset["train"])
val_dataset = CodeDataset(dataset["validation"])
test_dataset = CodeDataset(dataset["test"])

In [120]:
type(dataset["train"][0])

dict

In [78]:

codeInputCollator = CodeInputCollator(tokenizer, max_len=512)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=codeInputCollator
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=codeInputCollator
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=codeInputCollator
)

In [79]:
len(train_dataset), len(val_dataset), len(test_dataset)

(412178, 23107, 22176)

In [96]:
for batch in train_dataloader:

    print(batch["input_ids"].shape)
    break

torch.Size([32, 511])


# Create model

In [102]:
from transformers import GPT2LMHeadModel, GPT2Config

config = GPT2Config(
    vocab_size=tokenizer.vocab_size,  # Match tokenizer vocab size
    n_positions=2048,                        # Maximum sequence length (same as max_len)
    n_embd=768,                             # Embedding dimension
    n_layer=12,                             # Number of transformer layers
    n_head=12,                              # Number of attention heads
    bos_token_id=tokenizer.bos_token_id,  # Beginning-of-sequence token
    eos_token_id=tokenizer.eos_token_id   # End-of-sequence token
)



In [104]:
import torch

In [107]:
# Initialize the GPT-2 model with the updated configuration
model = GPT2LMHeadModel(config)

# Resize token embeddings to match the tokenizer's vocabulary
model.resize_token_embeddings(tokenizer.vocab_size)

# Move model to the appropriate device
device = "mps"
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(32000, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=32000, bias=False)
)

In [108]:
# Verify embedding size
print(f"Embedding size: {model.transformer.wte.weight.size()}")
print(f"Model is on device: {next(model.parameters()).device}")

Embedding size: torch.Size([32000, 768])
Model is on device: mps:0


In [113]:
from torch.utils.tensorboard import SummaryWriter
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR

# Define optimizer
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)

# Define learning rate scheduler (optional)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

In [116]:
from tqdm import tqdm
import math

# Initialize TensorBoard writer
writer = SummaryWriter(log_dir="./tensorboard_logs")

# Training settings
num_epochs = 3
gradient_accumulation_steps = 4
validation_frequency = 500  # Perform validation every 500 training steps
save_path = "./code_generator_model"

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    step_loss = 0
    step_count = 0

    print(f"Epoch {epoch + 1}/{num_epochs}")
    for step, batch in enumerate(tqdm(train_dataloader)):
        # Move batch to device
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)

        # Forward pass
        outputs = model(input_ids, labels=labels)
        loss = outputs.loss / gradient_accumulation_steps
        total_loss += loss.item()
        step_loss += loss.item()

        # Backward pass
        loss.backward()

        if (step + 1) % gradient_accumulation_steps == 0 or (step + 1) == len(train_dataloader):
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

        # Perform validation periodically
        if (step + 1) % validation_frequency == 0:
            model.eval()  # Switch to evaluation mode
            val_loss = 0

            with torch.no_grad():
                for val_batch in val_dataloader:
                    input_ids = val_batch["input_ids"].to(device)
                    attention_mask = val_batch["attention_mask"].to(device)
                    labels = val_batch["labels"].to(device)

                    val_outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                    val_loss += val_outputs.loss.item()

            avg_val_loss = val_loss / len(val_dataloader)
            val_perplexity = math.exp(avg_val_loss) if avg_val_loss < 300 else float("inf")

            # Log validation metrics to TensorBoard
            writer.add_scalar("Loss/Validation", avg_val_loss, epoch * len(train_dataloader) + step + 1)
            writer.add_scalar("Perplexity/Validation", val_perplexity, epoch * len(train_dataloader) + step + 1)

            print(f"Step {step + 1}: Validation Loss: {avg_val_loss:.4f}, Perplexity: {val_perplexity:.4f}")

            model.train()  # Switch back to training mode

    # Compute average training loss and perplexity for the epoch
    avg_train_loss = total_loss / len(train_dataloader)
    train_perplexity = math.exp(avg_train_loss) if avg_train_loss < 300 else float("inf")

    # Log training metrics to TensorBoard
    writer.add_scalar("Loss/Train", avg_train_loss, epoch + 1)
    writer.add_scalar("Perplexity/Train", train_perplexity, epoch + 1)

    print(f"Epoch {epoch + 1} Training Loss: {avg_train_loss:.4f}, Perplexity: {train_perplexity:.4f}")

    # Save the model after each epoch
    model.save_pretrained(f"{save_path}/epoch_{epoch + 1}")
    custom_tokenizer.save_pretrained(f"{save_path}/epoch_{epoch + 1}")

# Close TensorBoard writer
writer.close()

Epoch 1/3


  0%|          | 0/12881 [00:27<?, ?it/s]


RuntimeError: MPS backend out of memory (MPS allocated: 17.37 GB, other allocations: 432.70 MB, max allowed: 18.13 GB). Tried to allocate 382.50 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).