# SFT Base Model on Science Papers

**NB**: Tools used on this assignment.

### Applications
* GitHub Copilot was used to accelerate only very basic completions and was **not** used as an engine for function development
* ChatGPT-4 was used for generating only matplotlib/seaborn code and is clearly cited where used for such purposes

### Key Packages
* transformers
    * *Note:* The GPT2 model was replaced with the excellent 2.7B param microsoft/phi-2 base model for improved performance (https://huggingface.co/microsoft/phi-2)
* accelerate
    * *Note:* This notebook is compatible with both single and multi GPU clusters (we train on 4x GPUs here)
* torch
    * *Note:* Raw PyTorch code is used for finetuning, pairing nicely with accelerate

### Compute
* VSCode used for non-training development
    * Apple Silicon M1 Pro, 16 GB CPU RAM
* RunPod used for training development (https://www.runpod.io/)
    * Python 3 Engine via RunPod, 4x A100 SXM4 (80 GB) GPUs

In [None]:
!nvidia-smi

In [None]:
!nvcc --version

In [None]:
#!python -m pip install --upgrade pip

In [None]:
# install for training, don't for local dev
#!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [None]:
#!pip list

In [None]:
# import torch
# torch._C._cuda_getDeviceCount()

In [None]:
# VSCode local setup
#!python3 -m venv .venv_atla
#!source .venv_atla/bin/activate

In [None]:
# %%capture
#!pip3 install -q -U transformers accelerate datasets cleantext matplotlib seaborn evaluate transformers[sentencepiece]

In [None]:
!git config --global user.email "dryanfurman@gmail.com"
!git config --global user.name "Daniel Furman"

In [None]:
from huggingface_hub import login

# from google.colab import userdata

login("")  # userdata.get('HF_TOKEN')

## Data Exploration

In [None]:
import json
import re
import time

import cleantext
import numpy as np
from tqdm.notebook import tqdm

In [None]:
# first, upload your file
file_path = "scientific papers.txt"

data = []

with open(file_path, "r") as file:
    data = json.load(file)

In [None]:
# to do: explore the dataset
print("**Step 0.** Size of the dataset: \n")
print(f"There are {len(data)} elements in the dataset", "\n\n")

print("**Step 1.** Inspect an element at random: \n")
rand_int = int(np.random.randint(len(data)))  # uniform over all valid indices
print(f"Keys of the element at index {rand_int}: {data[rand_int].keys()}")
# for key in data[rand_int].keys():
#    if key == "article_text":
#        print(f"This is the first element's {key}: {data[rand_int][key][:50]}\n")
#    else:
#        print(f"This is the first element's {key}: {data[rand_int][key]}\n")

# inspected element and looks good

In [None]:
print("**Step 2** Basic descriptive stats on article_text:")
print("**NB**: Remove punctuation in word counts to match cleaning step\n")
num_words = []
for itr in range(len(data)):
    num_words.append(
        len(
            cleantext.clean_words(
                " ".join(data[itr]["article_text"]),
                clean_all=False,  # Execute all cleaning operations
                extra_spaces=True,  # Remove extra white spaces
                stemming=False,  # Stem the words
                stopwords=False,  # Remove stop words
                lowercase=False,  # Convert to lowercase
                numbers=False,  # Remove all digits
                punct=True,  # Remove all punctuations
                stp_lang="english",  # Language for stop words
            )
        )
    )

print(f"Mean of number of words (no punct) {np.mean(num_words)}")
print(f"Std of number of words (no punct) {np.std(num_words)}", "\n")
np.save("num_words.npy", np.array(num_words))

In [None]:
# Distribution plot of num words per article

# Reference: ChatGPT-4 generation with slight modifications, dated Feb 10, 2024 PST
# Prompt used: (attached num_words.npy as file) "Make me an excellent distribution plot of the attached numpy array. It contains the number of words contained in scientific papers for 1000 different papers. Make it a very nice, professional plot. Use gridlines."
# Generated code:

# Load the numpy array from the uploaded file
num_words = np.load("num_words.npy")

# Display the first few elements to understand its structure
# print(num_words[:10])

import matplotlib.pyplot as plt
import seaborn as sns

# Setting the style for the plot
sns.set_style("darkgrid")

# Creating the distribution plot
plt.figure(figsize=(10, 6))
sns.histplot(num_words, kde=True, bins=50, edgecolor="black")
plt.axvline(
    x=np.mean(num_words), color="tab:orange", label="Mean", linestyle="--", alpha=0.75
)
plt.axvline(
    x=np.mean(num_words) + np.std(num_words),
    color="tab:green",
    label="Standard Dev",
    linestyle="--",
    alpha=0.75,
)
plt.axvline(
    x=np.mean(num_words) - np.std(num_words),
    color="tab:green",
    linestyle="--",
    alpha=0.75,
)

# Adding titles and labels
plt.title("Distribution of Word Counts in Scientific Papers", fontsize=16)
plt.xlabel("Number of Words", fontsize=14)
plt.ylabel("Frequency", fontsize=14)
plt.legend()

# Adding gridlines
plt.grid(True, which="both", linestyle="--", linewidth=0.5)

# Show the plot
plt.show()

* *Note*: The distribution is right-skewed, with a long tail to the right of the mean. These longer articles will contribute more examples to the training set once each article is split into self.context_length sized chunks. Given more time, I'd inspect the longest articles to confirm that they are high quality.
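
A minimal sketch of how the skew and the longest articles could be checked, assuming a plain list of per-article word counts like num_words above. The helper name summarize_word_counts and the toy counts are hypothetical and only illustrate the idea:

```python
import statistics


def summarize_word_counts(num_words, top_k=3):
    """Return basic skew diagnostics and the indices of the longest articles."""
    mean = statistics.fmean(num_words)
    median = statistics.median(num_words)
    std = statistics.pstdev(num_words)  # population std, same default as np.std
    # Moment coefficient of skewness; a positive value indicates a right tail
    skewness = sum(((n - mean) / std) ** 3 for n in num_words) / len(num_words)
    # Indices of the top_k longest articles, longest first
    longest = sorted(
        range(len(num_words)), key=lambda i: num_words[i], reverse=True
    )[:top_k]
    return {"mean": mean, "median": median, "skewness": skewness, "longest": longest}


# Hypothetical word counts, for illustration only
toy_counts = [3000, 3500, 4000, 4200, 5000, 21000]
summary = summarize_word_counts(toy_counts, top_k=2)
print(summary)
```

A mean well above the median together with a positive skewness is consistent with the long right tail in the plot, and the returned indices are the articles I'd read first.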

## Data Cleaning

1. Replace all mathematical formulas and the references to them with _[math formula]_ e.g.
    * _@xmath2..._ -> _[math formula]_
2. Eliminate all punctuation marks
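
The two steps above can be sketched with the standard library alone. This is a simplified illustration, not the clean_text function used in this notebook: it only handles bare "@xmathN" references (not the inline formula bodies delimited by "$ ]" or "] ]"), and the names simple_clean and _PLACEHOLDER are hypothetical:

```python
import re
import string

MATH_TAG = "[math formula]"
_PLACEHOLDER = "MATHFORMULATAG"  # hypothetical sentinel that survives punctuation removal


def simple_clean(text: str) -> str:
    # Step 1: swap every formula reference such as "@xmath12" for a sentinel word
    text = re.sub(r"@xmath\d+", _PLACEHOLDER, text)
    # Step 2: remove ASCII punctuation (this would also strip the tag's brackets,
    # which is why a bracket-free sentinel is used until the end)
    text = text.translate(str.maketrans("", "", string.punctuation))
    # Collapse the extra whitespace left behind by the removals
    text = re.sub(r"\s+", " ", text).strip()
    # Restore the bracketed tag
    return text.replace(_PLACEHOLDER, MATH_TAG)


cleaned = simple_clean("we find @xmath2 , where x > 0 .")
print(cleaned)
```

The sentinel serves the same purpose as re-adding the brackets after punctuation removal: the tag has to be protected from the punctuation step.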

In [None]:
# check a few papers for their formula structure
# paper at index 125 has formula references as "@xmathi" and formulas as "@xmathi _formula_ $ ]"
# paper at index 455 has formula references as "@xmathi" and formulas as "@xmathi _formula_ ] ]"
# paper at index 984 has formula references as "@xmathi" and formulas as "@xmathi _formula_ ] ]"
# paper at index 95 has formula references as "@xmathi" and formulas as both "@xmathi _formula_ ] ]" and "@xmathi _formula_ $ ]"
# paper at index 684 has formula references as "@xmathi" and no formulas
# paper at index 427 has formula references as "@xmathi" and formulas as "@xmathi _formula_ $ ]"

# after checking a handful of papers, it seems we can extract formula references with "@xmathi" for
# each i index in test.split("@xmath") and formulas with "$ ]" and "] ]" delimiters
# we'd want to check more papers given more time for any edge cases or other delimiters

rand_int = int(np.random.randint(len(data)))  # uniform over all valid indices
test = " ".join(data[rand_int]["article_text"])
# test.split("@xmath")

In [None]:
def clean_text(element: dict) -> str:
    """Clean one dataset element and return its article text as a single string."""

    # Step 1:
    # replace math formulas with [math formula] tag
    new_data_element = " ".join(element["article_text"])
    # if formulas are present, remove them and replace with tag
    if "@xmath" in new_data_element:
        math_splits = new_data_element.split("@xmath")
        good_elements = []
        # grab zeroth element after split, before the first math formula
        # for loop to replace formula with tag
        good_elements.append(math_splits[0])
        for math_element in math_splits[1:]:
            if "] ]" in math_element:
                math_element = "xmath" + math_element
                formula_content = re.search(r"xmath(.*?)] ]", math_element).group(1)
                content_after_formula = " ".join(
                    math_element.split(formula_content)[1:]
                ).replace("] ]", "")
                good_elements.append("[math formula] " + content_after_formula)
            elif "$ ]" in math_element:
                math_element = "xmath" + math_element
                formula_content = re.search(r"xmath(.*?)\$ ]", math_element).group(1)
                content_after_formula = " ".join(
                    math_element.split(formula_content)[1:]
                ).replace("$ ]", "")
                good_elements.append("[math formula] " + content_after_formula)
            else:
                content_after_formula = math_element.lstrip("0123456789.- ")
                good_elements.append("[math formula] " + content_after_formula)
        new_data_element = " ".join(good_elements)

    # Step 2:
    # remove punct
    new_data_element = cleantext.clean(
        new_data_element,
        clean_all=False,  # Execute all cleaning operations
        extra_spaces=False,  # Remove extra white spaces
        stemming=False,  # Stem the words
        stopwords=False,  # Remove stop words
        lowercase=False,  # Convert to lowercase
        numbers=False,  # Remove all digits
        punct=True,  # Remove all punctuations
        stp_lang="english",  # Language for stop words
    )

    # Step 3:
    # remove extra spaces
    new_data_element = cleantext.clean(
        new_data_element,
        clean_all=False,  # Execute all cleaning operations
        extra_spaces=True,  # Remove extra white spaces
        stemming=False,  # Stem the words
        stopwords=False,  # Remove stop words
        lowercase=False,  # Convert to lowercase
        numbers=False,  # Remove all digits
        punct=False,  # Remove all punctuations
        stp_lang="english",  # Language for stop words
    )

    # add the brackets back to the math formula tags, since punctuation removal above stripped them
    new_data_element = new_data_element.replace("math formula", "[math formula]")
    return new_data_element

In [None]:
# let's now clean the dataset and check some test indices along the way
# these indices correspond to the ones checked above

check_indices = [125, 455, 984, 95, 684, 427]
for itr in range(len(data)):
    if itr in check_indices:
        # print(f"Original text at index {itr}: {' '.join(data[itr]['article_text'])}")
        pass
    data[itr]["article_text"] = clean_text(data[itr])
    if itr in check_indices:
        # print(f"Cleaned text at index {itr}: {data[itr]['article_text']}", "\n")
        pass

# checked and looks good

In [None]:
# the cleaning pipeline looks good, let's proceed to training

## Training Class

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    get_scheduler,
)
from accelerate import Accelerator
from datasets import DatasetDict


class FineTuner:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.context_length = 256
        config = AutoConfig.from_pretrained(
            "microsoft/phi-2",
            vocab_size=len(self.tokenizer),
            n_ctx=self.context_length,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
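        # NB: from_config builds the phi-2 architecture with randomly initialized weights;
        # AutoModelForCausalLM.from_pretrained("microsoft/phi-2") would load the pretrained weights instead
        self.model = AutoModelForCausalLM.from_config(config)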

    def train(self, dataset, batch_size=8, num_train_epochs=5, learning_rate=5e-4):
        """
        Train the model on the provided dataset without using Hugging Face Trainer.

        Args:
            dataset (Dataset): Huggingface dataset object.
            batch_size (int): Training batch size.
            num_train_epochs (int): Number of training epochs.
            learning_rate (float): Learning rate for optimizer.
        """

        # Sources used:
        # * https://huggingface.co/learn/nlp-course/chapter7/6
        # * https://huggingface.co/docs/accelerate/en/basic_tutorials/notebook

        # Training loop
        accelerator = Accelerator(mixed_precision="bf16")
        model_name = "phi-2-scientific-papers-base-v0.1"
        gradient_accumulation_steps = 1
        save_chkpt_steps = 300
        eval_steps = 100
        log_steps = 25

        dataset = dataset.shuffle(seed=43)
        ds_train = dataset.select(range(750))
        ds_valid = dataset.select(range(750, 1000))

        # assert there is no leakage between train and val slices
        ds_train_pandas = ds_train.to_pandas()
        ds_valid_pandas = ds_valid.to_pandas()
        assert (
            ds_train_pandas["article_text"].isin(ds_valid_pandas["article_text"]).sum()
            == 0
        )

        # create one dataset dict with train/valid splits
        raw_datasets = DatasetDict(
            {
                "train": ds_train,
                "valid": ds_valid,
            }
        )

        # creates chunks out of the article_text with self.context_length number of tokens each
        # these are the examples we will pass for language modeling
        def tokenize(element):
            outputs = self.tokenizer(
                element["article_text"],
                truncation=True,
                max_length=self.context_length,
                return_overflowing_tokens=True,
                return_length=True,
            )
            input_batch = []
            for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
                if length == self.context_length:
                    input_batch.append(input_ids)
            return {"input_ids": input_batch}

        tokenized_dataset = raw_datasets.map(
            tokenize, batched=True, remove_columns=raw_datasets["train"].column_names
        )

        model_size = sum(t.numel() for t in self.model.parameters())
        print(f"Model size: {model_size/1000**2:.1f}M parameters")

        def loss_fcn(inputs, logits):
            # Shift so that tokens < n predict n
            shift_labels = inputs[..., 1:].contiguous()
            shift_logits = logits[..., :-1, :].contiguous()
            # Calculate per-token loss
            loss_fct = CrossEntropyLoss(reduction="none")
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            # Resize and average loss per sample
            loss_per_sample = loss.view(
                shift_logits.size(0), shift_logits.size(1)
            ).mean(axis=1)
            final_loss = loss_per_sample.mean()
            return final_loss

        tokenized_dataset.set_format("torch")
        train_dataloader = DataLoader(
            tokenized_dataset["train"], batch_size=batch_size, shuffle=True
        )
        eval_dataloader = DataLoader(tokenized_dataset["valid"], batch_size=batch_size)

        weight_decay = 0.1

        def get_grouped_params(model, no_decay=["bias", "LayerNorm.weight"]):
            params_with_wd, params_without_wd = [], []
            for n, p in model.named_parameters():
                if any(nd in n for nd in no_decay):
                    params_without_wd.append(p)
                else:
                    params_with_wd.append(p)
            return [
                {"params": params_with_wd, "weight_decay": weight_decay},
                {"params": params_without_wd, "weight_decay": 0.0},
            ]

        def evaluate():
            self.model.eval()
            losses = []
            for eval_step, batch in enumerate(eval_dataloader):
                with torch.no_grad():
                    outputs = self.model(batch["input_ids"], labels=batch["input_ids"])

                # reshape the 0-dim loss to 1-dim so torch.cat below also works on a single process
                losses.append(accelerator.gather(outputs.loss.reshape(1)))
            loss = torch.mean(torch.cat(losses))
            try:
                perplexity = torch.exp(loss)
            except OverflowError:
                perplexity = float("inf")
            return loss.item(), perplexity.item()

        optimizer = AdamW(get_grouped_params(self.model), lr=learning_rate)
        self.model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
            self.model, optimizer, train_dataloader, eval_dataloader
        )

        num_update_steps_per_epoch = len(train_dataloader)
        num_training_steps = num_train_epochs * num_update_steps_per_epoch
        print(f"num_training_steps: {num_training_steps}")
        lr_scheduler = get_scheduler(
            name="cosine",
            optimizer=optimizer,
            num_warmup_steps=500,
            num_training_steps=num_training_steps,
        )

        self.model.train()
        global_step = 0
        train_logs = []
        val_logs = []
        for epoch in tqdm(range(num_train_epochs)):
            if accelerator.is_main_process:
                print(f"Started epoch {epoch + 1} of {num_train_epochs}")
            for epoch_step, batch in tqdm(
                enumerate(train_dataloader, start=1),
                total=num_training_steps // num_train_epochs,
            ):
                logits = self.model(batch["input_ids"]).logits
                loss = loss_fcn(batch["input_ids"], logits)
                loss = loss / gradient_accumulation_steps
                accelerator.backward(loss)
                # step the optimizer every gradient_accumulation_steps batches
                if epoch_step % gradient_accumulation_steps == 0:
                    accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1

                # train logging
                if global_step % log_steps == 0:
                    train_log = {
                        "steps": global_step,
                        "loss/train": loss.item() * gradient_accumulation_steps,
                        "last_lr": lr_scheduler.get_last_lr()[0],
                    }
                    accelerator.print(train_log)
                    train_logs.append(train_log)

                # save chkpt logging at save_chkpt_steps and last step
                if (
                    (global_step % (save_chkpt_steps * gradient_accumulation_steps))
                    == 0
                ) or (global_step == num_training_steps):
                    eval_loss, perplexity = evaluate()
                    val_log = {
                        "steps": global_step,
                        "loss/eval": eval_loss,
                        "perplexity/eval": perplexity,
                    }
                    accelerator.print(val_log)
                    val_logs.append(val_log)
                    self.model.train()
                    accelerator.wait_for_everyone()
                    unwrapped_model = accelerator.unwrap_model(self.model)
                    unwrapped_model.save_pretrained(
                        model_name, save_function=accelerator.save
                    )
                    time.sleep(5)
                    try:
                        if accelerator.is_main_process:
                            self.tokenizer.save_pretrained(model_name)
                            # push to hub
                            model_id_load = f"dfurman/{model_name}"
                            # tokenizer
                            tokenizer_push = AutoTokenizer.from_pretrained(model_name)
                            tokenizer_push.push_to_hub(
                                model_id_load, use_auth_token=True
                            )
                            # model
                            model_push = AutoModelForCausalLM.from_pretrained(
                                model_name,
                            )
                            model_push.push_to_hub(
                                model_id_load,
                                use_auth_token=True,
                                safe_serialization=True,
                                commit_message=f"Training in progress step {global_step} of {num_training_steps}",
                                blocking=False,
                            )
                    except Exception as err:
                        print(f"ERROR: Chkpt saving failed for this step: {err}")

                # eval logging
                elif (global_step % (eval_steps * gradient_accumulation_steps)) == 0:
                    eval_loss, perplexity = evaluate()
                    val_log = {
                        "steps": global_step,
                        "loss/eval": eval_loss,
                        "perplexity/eval": perplexity,
                    }
                    accelerator.print(val_log)
                    val_logs.append(val_log)
                    self.model.train()
                    accelerator.wait_for_everyone()

        # save train_logs & val_logs
        with open("train_logs.json", "w") as fout:
            json.dump(train_logs, fout)
        with open("val_logs.json", "w") as fout:
            json.dump(val_logs, fout)

In [None]:
from datasets import Dataset

# Don't reduce the size of the data, memory is not an issue and A100 GPUs go brrrr (they are super fast)
# Extract only 'article_text' from each dictionary
article_texts = [d["article_text"] for d in data]

# Create a dictionary with 'article_text' as the key
data_dict = {"article_text": article_texts}

# Create the Hugging Face Dataset
dataset = Dataset.from_dict(data_dict)

In [None]:
dataset

In [None]:
# dataset[0]

In [None]:
def training_function():
    fine_tuner = FineTuner()
    fine_tuner.train(dataset)

In [None]:
from accelerate import notebook_launcher

notebook_launcher(training_function, num_processes=4)

In [None]:
# prints during training look good!
# only error... checkpoint saving failed on last upload
# we captured 2700/3000 global steps and are only skipping the last 300 steps in the final epoch
# proceeding due to time constraints

In [None]:
print("done")

## GPU Usage During Training Run

* We can see that all 4 GPUs are utilized efficiently (~99% avg VRAM consumption)



![](../assets/mid_training_GPU_usage.png)

# Logging
Implement logging in the above FineTuner class and visualise the logs



### Three kinds of logs were captured at varying rates during training

* Train logs (completed steps, loss/train, last_lr)
    * 120 such logs were captured in this run
* Validation log (completed steps, loss/eval, perplexity/eval)
    * 30 such logs were captured in this run
* Checkpoint caching (saves model to remote repo at each checkpoint step)
    * 10 such chkpts were captured in this run
    * Logged to https://huggingface.co/dfurman/phi-2-scientific-papers-base-v0.1/commits/main

The plan is to visualize 1) loss/train and loss/eval on the same plot, 2) perplexity/eval on its own plot, and 3) the lr progression on its own plot.
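
As a sanity check, the capture counts listed above follow from the logging cadence set in FineTuner.train (log_steps=25, eval_steps=100, save_chkpt_steps=300) and the 3000 global steps of this run. The helper count_events is a hypothetical name used only for this arithmetic:

```python
def count_events(num_training_steps: int, every: int) -> int:
    """Number of steps in 1..num_training_steps that are multiples of `every`."""
    return num_training_steps // every


num_training_steps = 3000  # total global steps reported for this run
log_steps, eval_steps, save_chkpt_steps = 25, 100, 300  # values set in FineTuner.train

train_logs_expected = count_events(num_training_steps, log_steps)
val_logs_expected = count_events(num_training_steps, eval_steps)
chkpts_expected = count_events(num_training_steps, save_chkpt_steps)
print(train_logs_expected, val_logs_expected, chkpts_expected)  # 120 30 10
```

Every checkpoint step is also a multiple of eval_steps, so the evaluations run at checkpoint time are already included in the 30 validation logs.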

In [None]:
import json
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# loss plot

# Reference: ChatGPT-4 generation with slight modifications, dated Feb 11, 2024 PST
# Prompt used: (attached logs files) "Make me an excellent logging plot from the attached json file. It contains the logs from training a causal language model. Plot the loss/train and eval loss/loss on the same plot each against steps. Use gridlines and make it a very professional plot."
# Generated code:

# Setting the style for the plot
sns.set_style("darkgrid")

# Load the data from the training and validation logs
train_logs_path = "logs/train_logs.json"
val_logs_path = "logs/val_logs.json"

with open(train_logs_path, "r") as file:
    train_logs = json.load(file)

with open(val_logs_path, "r") as file:
    val_logs = json.load(file)

# Extracting the required data
steps_train, loss_train = zip(
    *[(log["steps"], log["loss/train"]) for log in train_logs]
)
steps_val, loss_eval = zip(*[(log["steps"], log["loss/eval"]) for log in val_logs])

# Create the plot
plt.figure(figsize=(10, 6))

# Plot training loss
plt.plot(steps_train, loss_train, "--.", label="loss/train")

# Plot evaluation loss
plt.plot(steps_val, loss_eval, "--.", label="loss/eval")

# Adding titles and labels
plt.title("Training and Evaluation Loss Logs", fontsize=16)
plt.xlabel("global_step", fontsize=14)
plt.ylabel("loss", fontsize=14)
plt.legend()

# Adding gridlines for better readability
plt.grid(True)

# Display the plot
plt.show()

* We observe a fairly standard loss plot. There is potential overfitting on the train set past ~1500 global_steps, with the validation loss drifting further from the train loss in the right half of the plot. We'd want to explore the impact of this by running vibe check prompts for each chkpt saved during training - see the "Evaluations" section for more.
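
A rough way to quantify that drift is to compute the gap between eval and train loss at matching steps and to find the step with the lowest eval loss. The sketch below assumes the same log schema as train_logs.json and val_logs.json; the function names and the toy values are hypothetical, not results from this run:

```python
def generalization_gap(train_logs, val_logs):
    """Pair each validation log with the train log at the same step.

    Returns a list of (step, eval_loss - train_loss) tuples.
    """
    train_by_step = {log["steps"]: log["loss/train"] for log in train_logs}
    return [
        (log["steps"], log["loss/eval"] - train_by_step[log["steps"]])
        for log in val_logs
        if log["steps"] in train_by_step
    ]


def best_eval_step(val_logs):
    """Return the global step with the lowest evaluation loss."""
    return min(val_logs, key=lambda log: log["loss/eval"])["steps"]


# Hypothetical log entries, for illustration only
toy_train = [
    {"steps": 100, "loss/train": 6.0},
    {"steps": 200, "loss/train": 5.0},
    {"steps": 300, "loss/train": 4.0},
]
toy_val = [
    {"steps": 100, "loss/eval": 6.5},
    {"steps": 200, "loss/eval": 6.0},
    {"steps": 300, "loss/eval": 6.25},
]
gaps = generalization_gap(toy_train, toy_val)
print(gaps, best_eval_step(toy_val))
```

Keep in mind that loss/train is logged from a single batch, so the gap is noisy; a widening trend matters more than any individual value.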

In [None]:
# eval perplexity
# modified from loss plot code above

# Setting the style for the plot
sns.set_style("darkgrid")

# Load the data from the validation logs
val_logs_path = "logs/val_logs.json"

with open(val_logs_path, "r") as file:
    val_logs = json.load(file)

# Extracting the required data
steps_val, perplexity_eval = zip(
    *[(log["steps"], log["perplexity/eval"]) for log in val_logs]
)

# Create the plot
plt.figure(figsize=(10, 6))

# Plot evaluation perplexity
plt.plot(steps_val, perplexity_eval, "--.", label="perplexity/eval", color="tab:green")

# Adding titles and labels
plt.title("Evaluation Perplexity Logs", fontsize=16)
plt.xlabel("global_step", fontsize=14)
plt.ylabel("perplexity/eval", fontsize=14)
plt.legend()

# Adding gridlines for better readability
plt.grid(True)

# Display the plot
plt.show()

* The perplexity plot also looks good, and it shows the same pattern: in the right half of the graph, we make little progress on validation set performance (which is what we care about).
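
For reference, the perplexity logged here is the exponential of the mean cross-entropy loss on the validation set, so the two curves carry the same information on different scales. A small sketch of that relationship (the function name is hypothetical):

```python
import math


def perplexity_from_loss(mean_cross_entropy: float) -> float:
    """Perplexity is exp(mean cross-entropy), with the loss measured in nats."""
    try:
        return math.exp(mean_cross_entropy)
    except OverflowError:
        # math.exp raises for very large inputs; treat that as infinite perplexity
        return float("inf")


print(perplexity_from_loss(0.0))  # a perfect model has perplexity 1.0
```

Because of the exponential, a small change in eval loss late in training shows up as a small but visible change in perplexity, which is why the right half of both plots looks flat.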

In [None]:
# learning rate
# modified from loss plot code above

# Setting the style for the plot
sns.set_style("darkgrid")

# Load the data from the training logs
train_logs_path = "logs/train_logs.json"

with open(train_logs_path, "r") as file:
    train_logs = json.load(file)

# Extracting the required data
steps_train, lr_train = zip(*[(log["steps"], log["last_lr"]) for log in train_logs])

# Create the plot
plt.figure(figsize=(10, 6))

# Plot learning rate
plt.plot(steps_train, lr_train, "--", label="learning rate", color="tab:red")

# Adding titles and labels
plt.title("Learning Rate Logs", fontsize=16)
plt.xlabel("global_step", fontsize=14)
plt.ylabel("learning rate", fontsize=14)
plt.legend()

# Adding gridlines for better readability
plt.grid(True)

# Display the plot
plt.show()

* The learning rate plot looks good, and it matches the desired cosine scheduling set for the run. I have had good success with cosine scheduling in the past - and it is what I use for my baselines as a result. We'd want to test other lr values and schedules given more time.
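
The expected shape of that curve can be written down directly. The sketch below is my own re-implementation of linear warmup followed by a half-cosine decay, which is how I understand the "cosine" schedule to behave; it is not the library's code. The defaults mirror this run (peak lr 5e-4, 500 warmup steps, 3000 training steps):

```python
import math


def cosine_lr_with_warmup(
    step, base_lr=5e-4, num_warmup_steps=500, num_training_steps=3000
):
    """Learning rate at `step`: linear warmup, then half-cosine decay to zero."""
    if step < num_warmup_steps:
        # Linear ramp from 0 to base_lr over the warmup steps
        return base_lr * step / max(1, num_warmup_steps)
    # Fraction of the post-warmup phase that has elapsed, in [0, 1]
    progress = (step - num_warmup_steps) / max(
        1, num_training_steps - num_warmup_steps
    )
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


for step in (0, 250, 500, 1750, 3000):
    print(step, cosine_lr_with_warmup(step))
```

Comparing a few of these values against last_lr in train_logs.json is a quick check that the scheduler was stepped once per optimizer update.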

# Evaluations



"**Building solid evals should be the starting point** for any LLM-based system or product (as well as conventional machine learning systems)." - Eugene Yan (https://eugeneyan.com/writing/llm-patterns/#evals-to-measure-performance)

**Run "vibe checks"** (one such example below in "Sample usage")
* at each logged chkpt, run 1-2 vibe check prompts to explore evolution of performance across the training
    * it is very possible we overfit on the train set, as per the eval loss and eval perplexity, so the above chkpt vibe checks may reveal that the best model is actually at the ~1500th global step
* for the best model identified, run ~20 vibe check prompts that are representative to explore best model performance

**Run any existing evals related to the task**
* look in EleutherAI's lm-eval package for existing evals that reflect our use case, run these, and compare against other LLMs' performance

**Run a high-quality custom eval on held out examples**

**NB** This is the most important step!

* create ~1k test set examples that represent a diverse and high-quality snapshot of your production / test-time use case
* conduct manual and automated testing on the above
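
One possible shape for the automated part is a small harness that runs every held-out prompt through the model and scores the completions. Everything in the sketch below is hypothetical: the example schema, the keyword-match metric, and the stand-in toy_generate function (in practice generate_fn would wrap model.generate):

```python
from typing import Callable, Dict, List


def run_custom_eval(
    examples: List[Dict[str, str]], generate_fn: Callable[[str], str]
) -> float:
    """Return the fraction of examples whose completion contains the expected keyword."""
    if not examples:
        return 0.0
    hits = 0
    for example in examples:
        completion = generate_fn(example["prompt"])
        if example["expected_keyword"].lower() in completion.lower():
            hits += 1
    return hits / len(examples)


# Hypothetical held-out examples and a stand-in generator, for illustration only
toy_examples = [
    {"prompt": "We suggest that [math formula] proves", "expected_keyword": "theorem"},
    {"prompt": "The spectrum was measured with", "expected_keyword": "telescope"},
]


def toy_generate(prompt: str) -> str:
    return prompt + " the theorem stated above"


score = run_custom_eval(toy_examples, toy_generate)
print(f"keyword hit rate: {score:.2f}")
```

Keyword matching is a crude metric; I'd treat it as a placeholder for whatever scoring fits the production use case, and pair it with manual review of a sample of completions.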

# Sample usage

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# load model and tok
model = AutoModelForCausalLM.from_pretrained(
    "dfurman/phi-2-scientific-papers-base-v0.1",
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(
    "dfurman/phi-2-scientific-papers-base-v0.1", trust_remote_code=True
)

model

In [None]:
# run a vibe check prompt

input_sample = "We suggest that [math formula] proves"
# move the inputs to the same device as the model before generating
inputs = tokenizer(input_sample, return_tensors="pt", return_attention_mask=False).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=10, temperature=0.1, do_sample=True)
text = tokenizer.batch_decode(outputs)[0]
print(text)

# Next steps

We have effectively created a base model here, by language modeling solely on unstructured article text from scientific papers. This is analogous to continued pretraining for the scientific paper domain. In other words, we are essentially creating a scientific paper completer at this stage of the assistant training process. Next, we want to create an assistant model capable of Q&A (https://karpathy.ai/stateofgpt.pdf, slide 3).

Here are the next steps I would follow:

**Task-relevant SFT and DPO**
1. Create a test set of ~300-1000 Q&A examples, hold these out as your eval
2. Curate a training set of Q&A examples, ideally 100k examples, at least 10k
3. Transfer learn base model weights and further train base with instruction tuning yielding a SFT model (train on assistant completions only) (https://huggingface.co/docs/trl/sft_trainer#train-on-completions-only)
4. Create a preference dataset for DPO
5. Transfer learn SFT model weights and perform DPO alignment/training
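
To illustrate what "train on assistant completions only" in step 3 means at the label level: the prompt tokens are masked out of the loss so that only the completion tokens contribute. This is a minimal sketch with hypothetical token ids and a hypothetical helper name, not the TRL implementation; -100 is the default ignore_index of PyTorch's CrossEntropyLoss:

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def mask_prompt_labels(input_ids, prompt_length):
    """Copy input_ids into labels and mask the first prompt_length positions."""
    labels = list(input_ids)
    for i in range(min(prompt_length, len(labels))):
        labels[i] = IGNORE_INDEX
    return labels


# Hypothetical token ids: 3 prompt tokens followed by 2 completion tokens
toy_input_ids = [11, 12, 13, 14, 15]
toy_labels = mask_prompt_labels(toy_input_ids, prompt_length=3)
print(toy_labels)
```

With labels built this way, the question tokens still condition the model through attention but produce no gradient signal of their own.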

**NB** Consider leveraging a strong LLM for generating synthetic data, either as augmentation or primary driver of data curation

**To obtain better performance**
* More compute evenly split between growing model size and dataset size
* More compute on expanding context window sized chunking (self.context_length)

**Better logging**
* Move logging from JSON to Weights & Biases, and add additional metrics such as GPU usage stats and elapsed time