In [1]:
# Generate the data for the project: change to the parent directory and run the
# `src.utils.clean` module with python3 in a subshell (IPython `!` shell escape).
# NOTE(review): presumably this is what produces the ../out/*_mini.json files read
# further down — confirm against src/utils/clean.
! cd ..; python3 -m src.utils.clean

In [2]:
# package imports
import pandas as pd
from dotenv import load_dotenv
from pathlib import Path
import os

In [3]:
def read_json(file):
    """Load a split-oriented JSON file and shape it into prompt/completion records.

    Parameters
    ----------
    file : path or file-like object
        Passed straight to ``pd.read_json`` with ``orient='split'``. The data
        must provide ``query`` and ``question`` columns.

    Returns
    -------
    pandas.DataFrame
        Exactly three columns, in this order: ``prompt``, ``completion`` and
        ``ground_truth``. ``prompt`` wraps each row's ``query`` in a fixed
        instruction; ``completion`` and ``ground_truth`` both carry the row's
        ``question``.
    """
    frame = pd.read_json(file, orient='split')

    def to_prompt(query):
        # Wrap a single SQL query in the fixed instruction text.
        return f"Convert the following SQL query into a natural language question. Your response must be a single sentence in the form of a clear and concise question.\n Query:{query}"

    # Read `query` before `question`, matching the column access order callers
    # would observe if one of the columns were missing.
    prompts = frame['query'].apply(to_prompt)
    questions = frame['question'].tolist()

    shaped = frame.assign(
        prompt=prompts,
        ground_truth=questions,
        completion=questions,
    )
    return shaped[['prompt', 'completion', 'ground_truth']]

In [4]:
# Load the train and test splits as prompt/completion frames via read_json.
# Paths are relative to the notebook's working directory.
TRAIN_JSON = '../out/train_mini.json'
TEST_JSON = '../out/test_mini.json'

pd_train = read_json(TRAIN_JSON)
pd_test = read_json(TEST_JSON)

In [5]:
# Global configuration for the rest of the notebook.
# 1. Point python-dotenv at ../prod.env (relative to the notebook's working
#    directory) and load it.
# 2. Set MODEL_DIR to the absolute form of ../model. This assignment runs after
#    load_dotenv, so it is the value left in os.environ regardless of what the
#    env file contains — keep the statement order.
# NOTE(review): MODEL_DIR is presumably read by the training/llms modules
# imported below — confirm.
dotenv_path = Path('../prod.env')
load_dotenv(dotenv_path=dotenv_path)
os.environ['MODEL_DIR'] = os.path.abspath('../model')

In [None]:
# custom imports - autoreload reloads your functions when you change them
%reload_ext autoreload
%autoreload 2
from data import DATASET
from training import GRPO, SFT
from llms import HF_PAYLOADS

In [6]:
# Build the GRPO object (from the `training` module) with the
# HF_PAYLOADS.QWEN_05B payload and the train/test frames loaded earlier.
# NOTE(review): this cell depends on the "custom imports" cell for GRPO and
# HF_PAYLOADS; that cell has no execution count while this one shows In [6], so
# verify the notebook still works under Restart Kernel -> Run All.
grpo = GRPO(HF_PAYLOADS.QWEN_05B, pd_train, pd_test)

In [None]:
# Start training on the GRPO object built in the previous cell.
# NOTE(review): presumably long-running — confirm before including in a full re-run.
grpo.train()

In [None]:
# Build the SFT object (from the `training` module) with the same
# HF_PAYLOADS.QWEN_05B payload and train/test frames passed to GRPO.
sft = SFT(HF_PAYLOADS.QWEN_05B, pd_train, pd_test)

In [None]:
# Start training on the SFT object built in the previous cell.
# NOTE(review): presumably long-running — confirm before including in a full re-run.
sft.train()

In [None]:
# Print the value of transformers' `file_utils.default_cache_path`.
# This cell does NOT clear anything: it only shows the path so the cache can be
# removed manually.
# NOTE(review): this import lives outside the notebook's import cells; consider
# moving it to the top so dependencies are visible in one place.
from transformers import file_utils
print(file_utils.default_cache_path)