In [42]:
import re
import os
import math
import json
from copy import deepcopy
from typing import List, Dict, Tuple, Optional

import torch
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
# QLoRA / quantization imports
from transformers import BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import wandb
from huggingface_hub import login
import random
from sklearn.model_selection import train_test_split

In [43]:
with open('/project/hf_api_login_key.txt', 'r') as f:
    hf_api = f.read()
login(token=hf_api)

In [44]:
# ------------------------------------------------------------------
# put these at the VERY TOP of your script, before importing HF libs
# ------------------------------------------------------------------
import os
os.environ["HF_HOME"] = "/SERVER_FAST_SSD/hf"
os.environ["TRANSFORMERS_CACHE"] = "/SERVER_FAST_SSD/hf/transformers"
os.environ["HF_DATASETS_CACHE"] = "/SERVER_FAST_SSD/hf/datasets"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/SERVER_FAST_SSD/hf/hub"
os.environ["WANDB_DISABLED"] = "false"   # flip to "false" if you actually want wandb

OUTPUT_DIR = "/project/testing_finetuning"  # your requested path
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [45]:
# ---------------------------
# 0) Load data
# ---------------------------
with open('./referencegame_data_DS.json', 'r') as f:
    data = json.load(f)

In [46]:
# -----------------------------------------
# 1) Dataset + collator
# -----------------------------------------
class BanditDataset(Dataset):
    def __init__(self, rows: List[dict], tokenizer: AutoTokenizer, max_len: int = 4096):
        self.rows = rows
        self.tok = tokenizer
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token
        self.max_len = max_len

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        r = self.rows[i]
        if r["role"] == "gen":
            prompt, action = r["prompt_speaker"], r["utterance"]
        else:
            prompt, action = r["prompt_listener"], r["guess"]

        enc_p = self.tok(prompt, add_special_tokens=False)
        enc_a = self.tok(action, add_special_tokens=False)

        input_ids = enc_p["input_ids"] + enc_a["input_ids"]
        labels = [-100] * len(enc_p["input_ids"]) + enc_a["input_ids"]
        attn = [1] * len(input_ids)

        if len(input_ids) > self.max_len:
            overflow = len(input_ids) - self.max_len
            input_ids = input_ids[overflow:]
            labels = labels[overflow:]
            attn = attn[overflow:]

        logp_behavior = r.get("logp_behavior", None)
        has_behav = float((logp_behavior is not None) and (r["reward"] == -1))

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attn, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
            "reward": torch.tensor(float(r["reward"]), dtype=torch.float),
            "logp_behavior": torch.tensor(0.0 if logp_behavior is None else float(logp_behavior), dtype=torch.float),
            "has_behav": torch.tensor(has_behav, dtype=torch.float),
        }

class PadCollator:
    def __init__(self, pad_id: int):
        self.pad_id = pad_id

    def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        max_len = max(x["input_ids"].shape[0] for x in batch)
        for x in batch:
            pad = max_len - x["input_ids"].shape[0]
            if pad > 0:
                x["input_ids"] = torch.nn.functional.pad(x["input_ids"], (0, pad), value=self.pad_id)
                x["attention_mask"] = torch.nn.functional.pad(x["attention_mask"], (0, pad), value=0)
                x["labels"] = torch.nn.functional.pad(x["labels"], (0, pad), value=-100)
        out = {k: torch.stack([x[k] for x in batch]) for k in ["input_ids", "attention_mask", "labels", "reward", "logp_behavior", "has_behav"]}
        return out


In [47]:
def ips_reinforce_loss(model, batch, ips_clip: float = 5.0) -> torch.Tensor:
    out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_dict=True)
    logprobs = torch.log_softmax(out.logits, dim=-1)
    labels = batch["labels"]
    mask = (labels != -100)

    # sequence log-prob
    labels_safe = torch.where(mask, labels, torch.zeros_like(labels))
    tok_lp = torch.gather(logprobs, -1, labels_safe.unsqueeze(-1)).squeeze(-1)
    seq_logp = (tok_lp * mask).sum(dim=1)

    reward     = batch["reward"].to(seq_logp.dtype)
    logp_behav = batch["logp_behavior"].to(seq_logp.dtype)
    has_behav  = (batch["has_behav"] > 0.5)

    # importance ratio
    with torch.no_grad():
        print(seq_logp)
        print(logp_behav)
        diff   = seq_logp.detach() - logp_behav
        ratio  = torch.exp(diff)
        if ips_clip is not None and ips_clip > 0:
            ratio = torch.clamp(ratio, max=ips_clip)
        # if no behavior policy, set c=1
        c = torch.where((reward < 0) & has_behav, ratio, torch.ones_like(ratio))
        print(c)

    # final loss
    per_ex = - c * reward * seq_logp
    loss = per_ex.mean()
    return torch.nan_to_num(loss, nan=0.0, neginf=0.0, posinf=0.0)


In [48]:
# def ips_reinforce_loss(model, batch) -> torch.Tensor:
#     out = model(
#         input_ids=batch["input_ids"],
#         attention_mask=batch["attention_mask"],
#         return_dict=True,
#     )
#     logits = out.logits  # [B,T,V]
#     logprobs = torch.log_softmax(logits, dim=-1)
#     labels = batch["labels"]                     # [B,T]

#     mask = (labels != -100)
#     labels_safe = torch.where(mask, labels, torch.zeros_like(labels))
#     token_lp = torch.gather(logprobs, dim=-1, index=labels_safe.unsqueeze(-1)).squeeze(-1)
#     token_lp = token_lp * mask
#     logp_curr = token_lp.sum(dim=1)              # [B]

#     reward = batch["reward"]
#     logp_behav = batch["logp_behavior"]
#     has_behav = (batch["has_behav"] > 0.5)

#     ratio = torch.exp(logp_curr - logp_behav)
#     c = torch.where((reward < 0) & has_behav, ratio, torch.ones_like(ratio))
#     c = c.detach()                               # do not backprop through c

#     per_ex = -(reward * c * logp_curr)
#     keep = ~((reward < 0) & (~has_behav))        # drop negatives w/o behavior prob
#     if keep.sum() == 0:
#         # avoid NaN if the whole batch is filtered
#         return per_ex.mean() * 0.0
#     return per_ex[keep].mean()


In [49]:
# only reinforc
def reinforce_loss(model, batch) -> torch.Tensor:
    out = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        return_dict=True,
    )
    logprobs = torch.log_softmax(out.logits, dim=-1)
    labels = batch["labels"]                      # [B, T]
    mask = (labels != -100)

    # sequence log-prob of the action (sum over action tokens only)
    labels_safe = torch.where(mask, labels, torch.zeros_like(labels))
    tok_lp = torch.gather(logprobs, -1, labels_safe.unsqueeze(-1)).squeeze(-1)
    seq_logp = (tok_lp * mask).sum(dim=1)        # [B], typically ≤ 0

    reward = batch["reward"].to(seq_logp.dtype)  # ±1

    # (optional) tiny stabilizers while debugging:
    seq_logp = torch.nan_to_num(seq_logp, nan=0.0, neginf=-1e4, posinf=0.0)
    # Or use per-token average instead of sum:
    # seq_len = mask.sum(dim=1).clamp(min=1).float()
    # seq_logp = (seq_logp / seq_len).clamp(min=-5.0, max=0.0)

    per_ex = -(reward * seq_logp)
    loss = per_ex.mean()
    return torch.nan_to_num(loss, nan=0.0, neginf=0.0, posinf=0.0)


In [50]:
# -----------------------------------------
# 6) Trainer subclass to use our loss
# -----------------------------------------
class IPSReinforceTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # kwargs may contain num_items_in_batch etc. — we don't use them.
        loss = ips_reinforce_loss(model, inputs)
        return (loss, {"loss": loss}) if return_outputs else loss

In [51]:
"""# -----------------------------------------
# Trainer subclass using plain REINFORCE
# -----------------------------------------
class ReinforceTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        loss = reinforce_loss(model, inputs)
        return (loss, {"loss": loss}) if return_outputs else loss
"""

'# -----------------------------------------\n# Trainer subclass using plain REINFORCE\n# -----------------------------------------\nclass ReinforceTrainer(Trainer):\n    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n        loss = reinforce_loss(model, inputs)\n        return (loss, {"loss": loss}) if return_outputs else loss\n'

In [52]:
def load_tokenizer(model_ckpt: str):
    tok = AutoTokenizer.from_pretrained(model_ckpt, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok

def _bf16_supported() -> bool:
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return major >= 8  # Ampere (8.0) or newer

def load_qlora_base(model_ckpt: str):
    compute_dtype = torch.bfloat16 if _bf16_supported() else torch.float16

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    
    if not torch.cuda.is_available():
        raise RuntimeError("QLoRA 4-bit requires CUDA; no GPU visible.")
    device_index = torch.cuda.current_device()  # usually 0 on your A100 box
    device_map = {"": 0}
    #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    

    model = AutoModelForCausalLM.from_pretrained(
        model_ckpt,
        quantization_config=bnb_config,
        #device_map="auto",
        device_map=device_map,
        #device_map=device,
        attn_implementation="eager",  # more stable while debugging; switch to "flash_attention_2" later
    )
    model.config.use_cache = False
    try:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    except TypeError:
        model.gradient_checkpointing_enable()

    # Prepare for k-bit training (fix layer norms, input grads, etc.)
    from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
    try:
        model = prepare_model_for_kbit_training(
            model, use_gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )
    except TypeError:
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

    peft_config = LoraConfig(
        r=16,
        lora_alpha=8,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj"],  # only these, per paper
    )
    
    model = get_peft_model(model, peft_config)
    return model

def main():
    # --- 0) Set seed
    random.seed(9)
    
    # --- 1) Build/prepare data as before ---

    # shuffle data and split into train and dev
    train_data, dev_data = train_test_split(data, test_size=0.15, random_state=42, shuffle=True)
    
    # --- 2) Tokenizer & QLoRA model ---
    model_ckpt = "meta-llama/Llama-3.1-8B-Instruct"
    tokenizer = load_tokenizer(model_ckpt)
    model = load_qlora_base(model_ckpt)

    # --- 3) Dataset/Collator ---
    dataset = BanditDataset(train_data, tokenizer, max_len=4096)
    data_collator = PadCollator(pad_id=tokenizer.pad_token_id)

    # --- 4) Trainer args for QLoRA ---
    # Use a higher LR typical for LoRA adapters; use paged optimizers for 4-bit
    # prefer bf16 on Ampere+, else fp16
    use_bf16 = _bf16_supported()
    args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        per_device_train_batch_size=32,
        gradient_accumulation_steps=1,
        learning_rate=0.0001,#2e-4,
        num_train_epochs=1,
        lr_scheduler_type="constant",
        warmup_ratio=0.0,
        weight_decay=0.1,
        logging_steps=1,
        save_steps=100,
        save_total_limit=2,
        report_to="wandb",
        run_name="whyareurunnin",
        bf16=use_bf16,
        fp16=not use_bf16,                 # <- only one of these should be True
        optim="adamw_torch", #"paged_adamw_8bit",
        remove_unused_columns=False,
        )
    
    trainer = IPSReinforceTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    processing_class=tokenizer,
    data_collator=data_collator,
        )

    print('Start training')
    trainer.train()
    
    # Save explicitly
    print('Save the model')
    trainer.save_model()               # saves PEFT adapter weights
    tokenizer.save_pretrained(OUTPUT_DIR)

    # Later for inference:
    # from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
    # from peft import PeftModel

    # base = AutoModelForCausalLM.from_pretrained(
    #     "meta-llama/Llama-3.1-8B-Instruct",
    #     quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16,
    #                                            bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4"),
    #     device_map="auto",
    # )
    # model = PeftModel.from_pretrained(base, OUTPUT_DIR)
    # tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")


    # push adapters
    model.push_to_hub("imge/llama_v1")
    tokenizer.push_to_hub("imge/llama_v1")

In [53]:
main()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Start training
tensor([-258.4668, -233.5099, -210.8028,  -53.1822, -277.1786,  -34.8920,
        -253.8637,  -35.4692,  -92.0188,  -48.6660, -199.8020, -585.5114,
         -55.7941, -253.8636,  -38.1408,  -50.2399,  -52.5271,  -54.4470,
        -199.8020, -272.1948,  -51.8451, -258.4668,  -55.1483,  -48.4044,
        -123.5749, -180.5758, -382.9168, -667.1732,  -50.0904, -129.7884,
        -120.1087,  -53.8062], device='cuda:0', grad_fn=<SumBackward1>)
tensor([   0.0000,    0.0000, -207.0000,  -56.7500,    0.0000,    0.0000,
        -251.0000,    0.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000, -251.0000,    0.0000,  -54.2500,    0.0000,    0.0000,
           0.0000, -266.0000,    0.0000,    0.0000,  -57.0000,    0.0000,
           0.0000, -178.0000, -378.0000,    0.0000,    0.0000,    0.0000,
           0.0000,  -56.0000], device='cuda:0')
tensor([1.0000e+00, 1.0000e+00, 2.2308e-02, 5.0000e+00, 1.0000e+00, 1.0000e+00,
        5.7059e-02, 1.0000e+00, 1.0000e+00, 1

Step,Training Loss
1,78.2194
2,-101.283
3,-25.2729
4,85.3773
5,-135.8621
6,4.4431
7,-69.0905
8,41.5
9,102.3095
10,108.452


tensor([-505.1823, -559.5725,  -53.2788, -252.4946,  -55.6971, -662.4315,
        -111.2572, -252.6646, -663.4072, -159.8318,  -51.9672,  -52.2444,
         -55.2432, -244.9593,  -49.4636,  -50.1026,  -53.5878, -258.0309,
         -53.4318, -379.8980, -588.1936, -722.5264,  -51.1584, -562.0634,
         -54.6298,  -53.1316, -252.5452, -221.4471, -181.6573, -248.1126,
         -55.7697,  -50.7195], device='cuda:0', grad_fn=<SumBackward1>)
tensor([-492.0000, -552.0000,    0.0000, -251.0000,  -56.2500, -664.0000,
           0.0000,    0.0000, -644.0000,    0.0000,  -55.0000,  -55.0000,
         -58.0000,    0.0000,  -51.0000,  -55.0000,    0.0000,    0.0000,
         -55.0000, -378.0000, -580.0000,    0.0000,    0.0000, -552.0000,
         -55.5000,    0.0000, -251.0000,    0.0000,    0.0000, -249.0000,
           0.0000,    0.0000], device='cuda:0')
tensor([1.8837e-06, 5.1440e-04, 1.0000e+00, 2.2435e-01, 1.7383e+00, 4.7994e+00,
        1.0000e+00, 1.0000e+00, 3.7286e-09, 1.0000e+00, 5.00

tensor([-294.4225, -557.9563,  -59.8030,  -59.3786,  -60.8472,  -58.5339,
        -252.3942,  -60.1437, -341.9130,  -40.2898,  -59.0742, -420.9152,
         -61.0883,  -63.6726,  -57.0076, -415.5738,  -39.6594,  -41.0366,
         -54.9229, -587.7743,  -57.1811,  -57.0502, -587.4235,  -57.8369,
        -160.4869, -300.9436, -269.2845, -504.3029, -634.9478, -230.5892,
        -273.6657, -380.0076], device='cuda:0', grad_fn=<SumBackward1>)
tensor([-294.0000, -552.0000,    0.0000,    0.0000,    0.0000,    0.0000,
           0.0000,    0.0000, -334.0000,    0.0000,  -58.7500, -416.0000,
         -55.0000,  -57.7500,  -54.0000,    0.0000,    0.0000,    0.0000,
         -52.2500,    0.0000,  -53.7500,  -54.5000,    0.0000,  -52.5000,
           0.0000,    0.0000, -266.0000, -492.0000,    0.0000,    0.0000,
        -268.0000, -378.0000], device='cuda:0')
tensor([6.5542e-01, 2.5895e-03, 1.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 1.0000e+00, 3.6595e-04, 1.0000e+00, 7.23

KeyboardInterrupt: 

In [None]:
print(torch.cuda.is_available())

print(torch.cuda.device_count())
print(torch.cuda.current_device())
print(torch.cuda.device(0))
print(torch.cuda.get_device_name(0))


Tried out the losses without IPS coefficient, with the normal loss, and these were the results.

Step 	Training Loss
1 	518.321400
2 	406.030500
3 	-916.910800
4 	-245.761400
5 	-996.767900
6 	-439.983500
7 	27.368200
8 	-303.980600
9 	98.382900
10 	1272.444300
11 	581.502000
12 	-508.267100
13 	-1107.622900
14 	-2049.277300
15 	715.464200
16 	97.236000
17 	-1123.329300
18 	734.981900
19 	-921.665000
20 	2865.062700
21 	109.311600
22 	-711.254700
23 	-748.666000
24 	-454.268800
25 	-2073.575900
26 	-1976.982200
27 	-246.352200
28 	-2102.958000
29 	-2487.609900
30 	704.824800
31 	-836.676900
32 	-1319.440900
33 	-2682.269300


there is MOST DEFENITELY an issue with my loss somewher and I need to calmly debug it ... UFFF T_T


In [None]:
# See if another process is using the GPU
!nvidia-smi


# Check driver and CUDA versions
#nvidia-smi | head -n 3
#python -c "import torch; print(torch.version.cuda); print(torch.backends.cudnn.version())"
