In [None]:
import unsloth
from unsloth import FastLanguageModel
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
def upload_model_and_tokenizer(model_name):
    
    if model_name not in ["unsloth/gemma-2-2b-it-bnb-4bit",
                          "gsarti/gemma-2-2b-rebus-solver-fp16", 
                          "gsarti/llama-3.1-8b-rebus-solver-fp16", 
                          "gsarti/phi3-mini-rebus-solver-fp16"]:
        raise ValueError("Choose a gsarti finetuned model for rebus solver.")
        
    else:
        model, tokenizer = FastLanguageModel.from_pretrained(
                model_name = model_name,
                max_seq_length = 1248,
                load_in_4bit = True,
        )

        if model_name in ["gsarti/llama-3.1-8b-rebus-solver-fp16", 
                          "gsarti/phi3-mini-rebus-solver-fp16"]:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(model_name)     
    return model, tokenizer

In [None]:
model, tokenizer = upload_model_and_tokenizer("unsloth/gemma-2-2b-it-bnb-4bit")

In [None]:
template = """<bos><start_of_turn>user
Risolvi gli indizi tra parentesi per ottenere una prima lettura, e usa la chiave di lettura per ottenere la soluzione del rebus.

Rebus: {rebus}
Chiave risolutiva: {key}<end_of_turn>
<start_of_turn>model"""

In [1]:
from datasets import load_dataset
train_dataset_name = "saracandu/eureka-rebus-grpo"
dataset = load_dataset(train_dataset_name, split = "train")

In [None]:
dataset[0]

In [None]:
inputs = tokenizer(dataset['prompt'][0], return_tensors="pt")["input_ids"].to('cuda:0')
outputs = model.generate(input_ids = inputs, max_new_tokens = 1248, use_cache = True)

In [None]:
model_generations = tokenizer.batch_decode(outputs)
print(model_generations[0])

In [None]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = False, # True or "unsloth" for very long context
    random_state = 4249,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

In [2]:
import torch
from trl import GRPOConfig, GRPOTrainer
import wandb
wandb.login(key="5a69225ea1d050c9c21f67c2db85febf61fa8fb1")

2025-05-09 15:39:22.350683: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-09 15:39:22.392919: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-09 15:39:22.392957: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-09 15:39:22.394384: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-05-09 15:39:22.402436: I tensorflow/core/platform/cpu_feature_guar

True

In [3]:
training_args = GRPOConfig(
    output_dir = "GRPO",
    learning_rate = 2e-5,
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 2,
    max_prompt_length = 1248, 
    max_completion_length = 1248,
    num_generations = 4, # i.e. number of competitive completions considered during optimization
    optim = "adamw_8bit",
    num_train_epochs = 1,
    bf16 = False, 
    report_to = ["wandb"],
    remove_unused_columns = False, 
    logging_steps = 1
)

In [4]:
import re

def estrai_soluzione(input_string):
    # Extract only the solution
    match = re.search(r"Soluzione: (.+?)\n", input_string, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        return "NotFound"

def estrai_indizi(input_string):

    # Extract the `[...] = relevantPart`
    pattern = r"\[([^\]]+)\] = ([^\n]+)"
    indizi = re.findall(pattern, input_string)
    risposte = [risposta for _, risposta in indizi]
    
    # Extract the `... = relevantLetters`
    pattern_sigle = r"[-–•]?\s*([A-Z]+(?:\s+[A-Z]+)*)\s*=\s*[^\n]+"    
    risposte_sigle = re.findall(pattern_sigle, input_string)
    
    # Combina le risposte estratte dalle parentesi e quelle singole
    return {'letters': risposte_sigle, 'words': risposte}

def estrai_primalet(input_string):
    # Extract the first-pass (prima lettura in italian)
    match = re.search(r"Prima lettura: (.+?)\n", input_string, re.DOTALL)
    if match:
        return match.group(1).strip()
    else:
        return "NotFound"


def estrai_rebus_e_chiave(testo):
    # Extract from the problem formulation: 
    #   - the rebus problem 
    #   - the key (i.e. bounds of number of letters of the solution)
    rebus_match = re.search(r"Rebus:\s*(.+?)\s*Chiave di lettura:", testo, re.DOTALL)
    chiave_match = re.search(r"Chiave di lettura:\s*(.+)", testo)

    rebus_raw = rebus_match.group(1).strip() if rebus_match else ""
    chiave = chiave_match.group(1).strip() if chiave_match else ""

    return rebus_raw, chiave

In [5]:
def exact_match_solution(prompts, completions, ground_truth, **kwargs):
    # Estrazione delle soluzioni
    predicted = [estrai_soluzione(completion) for completion in completions]
    gold = estrai_soluzione(ground_truth[0])
    print(gold)
    print(predicted)
    
    scores = []
    for guess in predicted:
        if guess == "NotFound":
            scores.append(0)
            continue
        try:
            scores.append(1.0 if guess == gold else 0.0)
        except:
            scores.append(0)
            continue
    return scores


def perc_correct_words_solution(prompts, completions, ground_truth, **kwargs):
    gold = estrai_soluzione(ground_truth[0]).lower().split()
    scores = []

    for completion in completions:
        pred = estrai_soluzione(completion)
        print(pred)
        if not pred:
            continue

        pred = pred.lower().split()
        score = 0
        for pw, gw in zip(pred, gold):
            if pw == gw:
                score += 1
            elif len(pw) == len(gw):
                score += 0.5
        scores.append(score / len(gold))

    return scores


def exact_match_primalet(prompts, completions, ground_truth, **kwargs):
    predicted = [estrai_primalet(completion) for completion in completions]
    golden = estrai_primalet(ground_truth[0]).lower().replace(" ", "")
    scores = []
    for guess in predicted:
        if guess == "NotFound":
            scores.append(0)
            continue
        try:
            scores.append(1.0 if guess.lower().replace(" ", "") == golden else 0.0)
        except:
            scores.append(0)
            continue
    return scores


def perc_correct_defres(prompts, completions, ground_truth, **kwargs):
    predicted = [estrai_indizi(completion.replace("*", "")) for completion in completions] 
    golden = estrai_indizi(ground_truth[0])
    word_scores = []
    letter_scores = []
    for pred in predicted:
        wscore = 0
        for pw, gw in zip(pred['words'], golden['words']):
            if pw == gw:
                wscore += 1
            elif len(pw) == len(gw):
                wscore += 0.5
        word_scores.append(wscore / len(golden['words']))

        lscore = 0
        for pw, gw in zip(pred['letters'], golden['letters']):
            if pw.lower().replace(" ", "") == gw.lower().replace(" ", ""):
                lscore += 1
        letter_scores.append(lscore / len(golden['letters']))
        
    return [word_scores[i] + letter_scores[i] for i in range(len(predicted))]

In [6]:
first_three_rows = dataset.select(range(5))
first_three_rows

Dataset({
    features: ['prompt', 'ground_truth'],
    num_rows: 5
})

In [7]:
first_three_rows['ground_truth'][2]

'Procediamo alla risoluzione del rebus passo per passo:\n- D O = D O\n- [Tristi, mogi] = mesti\n- [Quello del fucile non abbaia!] = cane\n- G = G\n- [Lo curvano i frutti] = ramo\n- L T = L T\n- [Strumento di terracotta] = ocarina\n\nPrima lettura: D O mesti cane G ramo L T ocarina\n\nOra componiamo la soluzione seguendo la chiave risolutiva:\n9 = Domestica\n5 = negra\n5 = molto\n6 = carina\n\nSoluzione: Domestica negra molto carina\n'

In [8]:
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

model_id = "gsarti/gemma-2-2b-rebus-solver-fp16"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    attn_implementation="eager"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

lora_config = LoraConfig(
    task_type = "CAUSAL_LM",
    r = 16, # rank of the surrogate matrices!
    lora_alpha = 32, # scale factor controlling the impact of the modifications
    target_modules = "all-linear" # applied to all linear transformations in the model
)
model = get_peft_model(model, lora_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [9]:
import wandb

trainer = GRPOTrainer(
    model=model,
    reward_funcs=[exact_match_solution],
    args=training_args,
    train_dataset=first_three_rows,
)

wandb.init(project="GRPO")
trainer.train()

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.




Domestica negra molto carina
['Domestica negra molto carina', 'Domestica negra molto carina', 'Domestica negra molto carina', 'Domestica negra molto carina']


OutOfMemoryError: CUDA out of memory. Tried to allocate 220.00 MiB. GPU 0 has a total capacity of 31.73 GiB of which 184.25 MiB is free. Including non-PyTorch memory, this process has 31.55 GiB memory in use. Of the allocated memory 30.72 GiB is allocated by PyTorch, and 472.19 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# perc_correct_words_solution, exact_match_primalet, perc_correct_defres

In [None]:
completions = []

for i in range(4):
    outputs = model.generate(input_ids = inputs, max_new_tokens = 500, use_cache = True, do_sample=True)
    model_generations = tokenizer.batch_decode(outputs)
    completions.append(model_generations[0])

In [None]:
perc_correct_words_solution(dataset['prompt'][0], completions, dataset['ground_truth'][0])