In [None]:
!pip install torch==2.6.0

In [None]:
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
!pip install unsloth

In [None]:
!pip freeze > unsloth_requirements.txt

In [None]:
!pip install unsloth_zoo

In [None]:
import os

# Kernel / loss switches exported before `unsloth` is imported (next cell), presumably
# so the libraries see them at load time.
# NOTE(review): confirm each flag is actually read by the installed library versions.
os.environ.update(
    {
        "XFORMERS_MORE_DETAILS": "1",
        "DISABLE_CUT_CE": "1",
        "FLASH_ATTENTION_FORCE_DISABLED": "1",
        "XFORMERS_DISABLED": "1",
    }
)

In [None]:
# Unsloth is imported first, presumably so its patches are in place before other
# Hugging Face libraries load.
from unsloth import FastModel
from google.colab import userdata
from huggingface_hub import login

# Read the Hugging Face token from Colab secrets once; it is reused for the hub
# login here and for the explicit `token=` arguments in later cells.
token = userdata.get('HF_TOKEN')
login(token=token)

In [None]:
import gc
import torch

# Route PyTorch scaled-dot-product attention through the reference "math" kernel:
# enable it explicitly and switch off the fused flash / memory-efficient paths.
sdp_backends = torch.backends.cuda
sdp_backends.enable_math_sdp(enabled=True)
sdp_backends.enable_flash_sdp(enabled=False)
sdp_backends.enable_mem_efficient_sdp(enabled=False)  # optional: could be left enabled

In [None]:
# Maximum sequence length (tokens) handed to the Unsloth loader.
# NOTE(review): schema + question + SQL may exceed 256 tokens for larger schemas, and
# because the SQL answer comes last in the prompt template, truncation would drop it.
# Check the token-length distribution of the training text before keeping this value.
max_seq_length = 256

# Non-quantized (load_in_4bit=False) Gemma 3 270M with eager attention.
model, tokenizer = FastModel.from_pretrained(
    model_name="unsloth/gemma-3-270m",
    token=token,
    max_seq_length=max_seq_length,
    load_in_4bit=False,
    attn_implementation="eager",
)

# Re-assert eager attention on the loaded config and switch off `use_fast_kernels`.
# NOTE(review): these are plain attribute assignments; confirm the installed
# transformers/Unsloth versions actually read these attribute names.
model.config.attn_implementation = "eager"
model.use_fast_kernels = False

In [None]:
# Wrap the base model with LoRA adapters on all linear layers.
model = FastModel.get_peft_model(
    model,
    # LoRA hyper-parameters: rank 8, alpha equal to the rank, no dropout.
    r=8,
    lora_alpha=8,
    lora_dropout=0.0,
    target_modules="all-linear",
    bias="none",
    use_rslora=False,   # rank-stabilized LoRA disabled
    loftq_config=None,  # LoftQ initialization disabled
    # Memory saving and reproducible adapter initialization.
    use_gradient_checkpointing=True,
    random_state=42,
)

In [None]:
from datasets import load_dataset

# Fixed seed (same value as the LoRA `random_state` and trainer `seed`) so the
# shuffled 35k-row training subsample is identical across runs.
DATA_SEED = 42
N_TRAIN_SAMPLES = 35_000

ds = load_dataset("gretelai/synthetic_text_to_sql", token=token)
ds = ds.shuffle(seed=DATA_SEED)
# Clamp so a smaller split cannot make `select` index past the end.
train_ds = ds["train"].select(range(min(N_TRAIN_SAMPLES, len(ds["train"]))))

In [None]:
# Supervised fine-tuning template: schema + natural-language question in, gold SQL
# out. `{context}`, `{question}` and `{response}` are filled per dataset row.
user_prompt = """Given the <USER_QUERY> and the <SCHEMA>, generate the corresponding SQL command to retrieve the desired data, considering the query's syntax, semantics, and schema constraints.

<SCHEMA>
{context}
</SCHEMA>

<USER_QUERY>
{question}
</USER_QUERY>

<RESPONSE>
{response}
</RESPONSE>
"""


def _eos_suffix():
    """Return the global tokenizer's EOS token, or "" if it is not available."""
    try:
        return tokenizer.eos_token or ""
    except (NameError, AttributeError):
        # Best effort: `tokenizer` is a notebook global that may not be loaded yet.
        return ""


def format_function(example):
    """Render one or many dataset rows into training text terminated by EOS.

    Accepts either a single row (string values) or a batched slice (equal-length
    lists), because a trainer may call ``formatting_func`` in either mode.

    Args:
        example: mapping with ``sql_context``, ``sql_prompt`` and ``sql``; missing
            keys are treated as empty strings.

    Returns:
        list[str]: one formatted string per row (a single-element list for one row).
    """
    # Closing each sample with EOS is recommended when the tokenizer provides one.
    eos = _eos_suffix()
    questions = example.get("sql_prompt", "")
    if isinstance(questions, (list, tuple)):
        # Batched call: format every row instead of the whole column lists at once.
        empty = [""] * len(questions)
        contexts = example.get("sql_context", empty)
        responses = example.get("sql", empty)
        return [
            user_prompt.format(context=ctx, question=q, response=resp) + eos
            for ctx, q, resp in zip(contexts, questions, responses)
        ]
    text = user_prompt.format(
        context=example.get("sql_context", ""),
        question=questions,
        response=example.get("sql", ""),
    )
    return [text + eos]

In [None]:

def format_prompts_for_training(example):
    """Add a ``text`` field holding the rendered prompt, gold SQL and EOS token.

    Mutates and returns ``example`` so ``Dataset.map`` keeps every original column
    and gains ``text``.
    """
    rendered = user_prompt.format(
        context=example["sql_context"],
        question=example["sql_prompt"],
        response=example["sql"],
    )
    eos_suffix = tokenizer.eos_token or ""
    example["text"] = rendered + eos_suffix
    return example


# Materialize the training text once; the column name matches the trainer's `dataset_text_field`.
train_ds = train_ds.map(format_prompts_for_training)

In [None]:
import torch
from trl import SFTTrainer, SFTConfig

# bf16 needs compute capability >= 8.0 (Ampere or newer); older GPUs such as the T4
# fall back to fp16. The capability is checked directly because
# `torch.cuda.is_bf16_supported()` can report True for emulated bf16 on older cards.
use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_ds,
    eval_dataset = None,
    # No `formatting_func`: `train_ds` already holds the rendered prompt + SQL + EOS
    # in its "text" column. Passing a formatting function too made that column dead,
    # and the function then had to cope with batched inputs if the trainer mapped it
    # with `batched=True`.
    args = SFTConfig(
        gradient_checkpointing = True,
        dataset_text_field = "text",
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,   # effective per-device batch of 8
        warmup_steps = 5,
        num_train_epochs = 1,              # one full pass over train_ds
        learning_rate = 2e-4,
        logging_steps = 100,
        optim = "adamw_torch_fused",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 42,
        output_dir = "./model/gemma-270m-Text2SQL-Fine-tuned",
        report_to = "wandb",
        run_name = "Gemma-Text2SQL",
        # Mixed precision matched to the hardware instead of hard-coding fp16.
        fp16 = not use_bf16,
        bf16 = use_bf16,
    ),
)

In [None]:
import wandb

# Authenticate with Weights & Biases; the trainer reports there via `report_to="wandb"`.
wandb.login()

In [None]:
import os

# Plain-Python replacement for the IPython-only `%env` magic. The W&B project name is
# presumably read when the W&B run starts, i.e. at `trainer.train()`.
os.environ["WANDB_PROJECT"] = "Gemma-Text2SQL"

In [None]:
# Run fine-tuning; keep the returned training statistics for later inspection.
trainer_stats = trainer.train()

In [None]:
# Persist the trained weights to the `output_dir` configured in SFTConfig (presumably
# only the LoRA adapter, since the model is PEFT-wrapped).
trainer.save_model(trainer.args.output_dir)

In [None]:
# Upload the merged model (save_method="merged_16bit") together with its tokenizer.
HUB_REPO_ID = "leotod/gemma-270m-Text2SQL-Fine-tuned"
model.push_to_hub_merged(
    HUB_REPO_ID,
    tokenizer,
    save_method="merged_16bit",
    token=token,
)

In [None]:
import gc
import torch

# Release GPU memory: drop the large references *before* collecting, otherwise
# `gc.collect()` runs while `model`/`trainer` are still reachable. `pop(..., None)`
# keeps the cell safe to re-run after the names are already gone.
globals().pop("trainer", None)
globals().pop("model", None)
gc.collect()
torch.cuda.empty_cache()