In [2]:
# use this to generate the data for the project
! cd ..; python3 -m src.utils.clean

In [1]:
# package imports
import pandas as pd
from dotenv import load_dotenv
from pathlib import Path
from huggingface_hub import login
import os

In [None]:
import torch

# Report the CUDA environment that PyTorch can see.
# The availability flag and device count are safe to query on any machine.
print(f'\nAvailable cuda = {torch.cuda.is_available()}')
print(f'\nGPUs availables = {torch.cuda.device_count()}')
if torch.cuda.is_available():
    # These queries initialise CUDA and raise when no usable GPU/driver is present,
    # so they are only attempted once availability has been confirmed.
    print(f'\nCurrent device = {torch.cuda.current_device()}')
    print(f'\nCurrent Device location = {torch.cuda.device(0)}')
    print(f'\nName of the device = {torch.cuda.get_device_name(0)}')
    # Per-device properties, followed by total memory converted from bytes to GB (1e9).
    for device in range(torch.cuda.device_count()):
        print(f'\nDevice {device} = {torch.cuda.get_device_properties(device)}')
        print(f'\nDevice {device} = {torch.cuda.get_device_properties(device).total_memory / 1e9} GB')
else:
    print('\nNo CUDA device detected; skipping per-device details.')

In [4]:
def read_json(file):
    """Load a split-oriented JSON file and shape it into prompt/completion records.

    The file is read with ``pd.read_json(..., orient='split')`` and must provide
    ``query`` and ``question`` columns.

    Returns a DataFrame with exactly three columns, in this order:
      * ``prompt``       - a fixed instruction followed by the row's ``query``
      * ``completion``   - the row's ``question``
      * ``ground_truth`` - the row's ``question`` (same values as ``completion``)
    """
    def build_prompt(query):
        # Instruction text is fixed; only the query is interpolated.
        return f"Convert the following SQL query into a natural language question. Your response must be a single sentence in the form of a clear and concise question.\n Query:{query}"

    frame = pd.read_json(file, orient='split')
    frame['prompt'] = frame['query'].apply(build_prompt)

    # Both target columns carry the reference question text.
    for target in ('ground_truth', 'completion'):
        frame[target] = frame['question'].tolist()

    wanted = ['prompt', 'completion', 'ground_truth']
    return frame[wanted]

In [5]:
# Read the train and test splits (train first, then test) into prompt/completion frames.
pd_train, pd_test = read_json('../out/train.json'), read_json('../out/test.json')

In [None]:
# Environment / credentials setup.
# 1. Load variables defined in ../prod.env.
dotenv_path = Path('../prod.env')
load_dotenv(dotenv_path)
# 2. Publish the absolute path of ../model through the MODEL_DIR environment variable.
os.environ.update(MODEL_DIR=os.path.abspath('../model'))
# 3. Authenticate with the Hugging Face Hub using the HF_TOKEN environment variable
#    (presumably supplied by prod.env - TODO confirm).
login(token=os.environ.get('HF_TOKEN'))

In [None]:
# custom imports - modules described by the author as custom (data, training, llms).
# The autoreload extension is (re)loaded and set to mode 2 before importing them, so that
# edits to those modules are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2
from data import DATASET
from training import GRPO, SFT
from llms import HF_PAYLOADS

In [9]:
import torch
# Ask PyTorch to release its cached CUDA memory before the trainers below are built
# (presumably to reclaim GPU memory left over from earlier runs in this kernel).
torch.cuda.empty_cache()

In [None]:
# Build the GRPO trainer (from the `training` module) for the QWEN_0_5B payload,
# passing the train and test prompt/completion frames.
# NOTE(review): constructor side effects (model download, GPU allocation) are not visible here - confirm in training.
grpo = GRPO(HF_PAYLOADS.QWEN_0_5B, pd_train, pd_test)

In [None]:
# Run GRPO training. NOTE(review): duration and where outputs are written are defined in the `training` module, not here.
grpo.train()

In [None]:
# Build the SFT trainer (from the `training` module) with the same QWEN_0_5B payload
# and the same train/test frames that are passed to the GRPO trainer.
# NOTE(review): whether this can share a GPU with a still-live `grpo` object is not visible here - confirm.
sft = SFT(HF_PAYLOADS.QWEN_0_5B, pd_train, pd_test)

In [None]:
# Run SFT training. NOTE(review): duration and where outputs are written are defined in the `training` module, not here.
sft.train()

In [None]:
# Helper for clearing the cache: this cell only PRINTS transformers' default cache path.
# Nothing is deleted here - remove the printed directory manually to actually clear the cache.
from transformers import file_utils
print(file_utils.default_cache_path)