# Tutorial

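> Train a RetroFit encoder-decoder model (DistilBERT encoder, GPT-2 decoder) on the Python split of CodeSearchNet with PyTorch Lightning.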

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
import torch
import wandb

import pytorch_lightning as pl

# from codecarbon import EmissionsTracker
from pathlib import Path
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.loggers import WandbLogger
from retrofit.data import RetroDataset
from retrofit.model import RetroFitModelWrapper
from transformers import AutoTokenizer, EncoderDecoderModel

In [None]:
# Dataset column handed to RetroDataset via its `column` argument.
column = "whole_func_string"

# Pretrained checkpoints used for the encoder and decoder halves.
encoder_name, decoder_name = "distilbert-base-uncased", "gpt2"

# One tokenizer per half; the decoder tokenizer reuses its EOS token for padding.
encoder_tokenizer = AutoTokenizer.from_pretrained(encoder_name)
decoder_tokenizer = AutoTokenizer.from_pretrained(decoder_name)
decoder_tokenizer.pad_token = decoder_tokenizer.eos_token

# Data module that is later passed to the trainer.
retro_ds = RetroDataset(
    "code_search_net",
    "flax-sentence-embeddings/st-codesearch-distilroberta-base",
    encoder_tokenizer,
    decoder_tokenizer,
    dataset_config="python", column=column,
    batch_size=2, k=2, n_perc=1,
)

# Combine the two pretrained checkpoints into a single encoder-decoder model and
# point its special-token ids at the decoder tokenizer's.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(encoder_name, decoder_name)
model.config.pad_token_id = decoder_tokenizer.pad_token_id
model.config.decoder_start_token_id = decoder_tokenizer.bos_token_id

# Wrap the model with its optimisation settings for training.
retro_model = RetroFitModelWrapper(model, weight_decay=0.1, lr=5e-4, freeze_decoder=True)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.2.crossattention.c_attn.weight', 'h.7.crossattention.c_attn.weight', 'h.0.crossattention.c_attn.weight', 'h.8.crossattention.masked_bias', 'h.9.ln_cros

In [None]:
def train_model(model, data_module, num_epochs, output_dir, name=None):
    """
    Fit ``model`` on ``data_module`` with a PyTorch Lightning trainer, logging
    to Weights & Biases and checkpointing on the monitored ``val_loss``.

    Args:
        model (pl.LightningModule): The model to train.
        data_module (pl.LightningDataModule): The data module to use for training.
        num_epochs (int): The maximum number of epochs to train for; the
            early-stopping callback on ``val_loss`` may end training sooner.
        output_dir (pathlib.Path): The directory whose ``checkpoints``
            sub-directory receives the checkpoints.
        name (str): The run name given to the W&B logger.
    Returns:
        best_model_path (str): The path to the best model's checkpoint as
            reported by the checkpoint callback (an empty string when no
            checkpoint was saved).
    """
    # pl.seed_everything(115, workers=True)
    wandb_logger = WandbLogger(project="Retrofit", name=name)
    checkpoint_dir = str(output_dir / "checkpoints")

    # saves a file like: <checkpoint_dir>/retrofit-epoch=02-val_loss=0.32.ckpt
    checkpoint_callback = ModelCheckpoint(
        monitor="val_loss",
        dirpath=checkpoint_dir,
        filename="retrofit-{epoch:02d}-{val_loss:.2f}",
        save_top_k=5,
        mode="min",
    )
    trainer = pl.Trainer(
        logger=wandb_logger,
        default_root_dir=checkpoint_dir,
        gpus=torch.cuda.device_count(),
        max_epochs=num_epochs,
        # Only a fifth of the train/val batches are used per epoch.
        limit_train_batches=0.2,
        limit_val_batches=0.2,
        precision=16,
        callbacks=[
            checkpoint_callback,
            EarlyStopping(monitor="val_loss"),
            TQDMProgressBar(refresh_rate=1),
        ],
    )
    # tracker = EmissionsTracker(output_dir=output_dir.parent.parent, project_name=name)

    # train the model and track emissions
    # tracker.start()
    trainer.fit(model, data_module)
    # tracker.stop()

    # save the best model to wandb
    # The callback reports "" (not None) when nothing was checkpointed, so test
    # truthiness; an `is not None` check would upload an empty path.
    best_model_path = checkpoint_callback.best_model_path
    if best_model_path:
        wandb.save(best_model_path)

    # # save the emissions csv file
    # wandb.save(str(output_dir.parent.parent / "emissions.csv"))

    return best_model_path

In [None]:
# Short two-epoch run named "test"; outputs go under the "model" sub-directory.
out_dir = Path("/workspace/retrofit/data/output/")
num_epochs = 2

best_model_path = train_model(
    retro_model, retro_ds, num_epochs, out_dir / "model", name="test"
)
# model = RetroFitModel.load_from_checkpoint(best_model_path)

Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Reusing dataset code_search_net (/root/.cache/huggingface/datasets/code_search_net/python/1.0.0/80a244ab541c6b2125350b764dc5c2b715f65f00de7a56107a28915fac173a27)
Reusing dataset code_search_net (/root/.cache/huggingface/datasets/code_search_net/python/1.0.0/80a244ab541c6b2125350b764dc5c2b715f65f00de7a56107a28915fac173a27)


  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/4122 [00:00<?, ?ex/s]

  0%|          | 0/231 [00:00<?, ?ex/s]

  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mnatedog[0m (use `wandb login --relogin` to force relogin)



  | Name  | Type                | Params
----------------------------------------------
0 | model | EncoderDecoderModel | 219 M 
----------------------------------------------
219 M     Trainable params
0         Non-trainable params
219 M     Total params
438.339   Total estimated model params size (MB)
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


Validation sanity check: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

