In [9]:
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    BitsAndBytesConfig,
    Trainer
    
)
import torch.nn.functional as F
from peft import LoraConfig, PeftModel
from trl import SFTTrainer

In [None]:
# Hugging Face hub id of the base checkpoint to fine-tune
model_name = "Qwen/Qwen2.5-3B-Instruct"


################################################################################
# QLoRA parameters
################################################################################

# Rank of the LoRA update matrices, LoRA scaling factor (alpha), and dropout
# probability applied inside the LoRA layers.
lora_r, lora_alpha, lora_dropout = 64, 16, 0.1

################################################################################
# bitsandbytes parameters
################################################################################

# Whether the base model should be loaded in 4-bit precision
use_4bit = True

# Name of the torch dtype used for compute on 4-bit weights, and the 4-bit
# quantization scheme ("fp4" or "nf4").
bnb_4bit_compute_dtype, bnb_4bit_quant_type = "float16", "nf4"

# Nested (double) quantization of the 4-bit base model
use_nested_quant = False


################################################################################
# TrainingArguments parameters
################################################################################

# Directory that receives checkpoints and predictions
output_dir = "/DATA/rohan_kirti/niladri/fine_tune/train_dataset/results"

# How many passes over the training set
num_train_epochs = 1

# Mixed-precision switches (bf16 is the usual choice on Ampere-class GPUs such as the A100)
fp16 = bf16 = False

# Per-GPU batch size for training and for evaluation
per_device_train_batch_size = per_device_eval_batch_size = 4

# Update steps over which gradients are accumulated before an optimizer step
gradient_accumulation_steps = 1

# Gradient checkpointing switch
gradient_checkpointing = True

# Gradient clipping threshold (maximum gradient norm)
max_grad_norm = 0.3

# Optimizer name, initial learning rate, and weight decay
# (decay applies to all layers except bias/LayerNorm weights)
optim = "paged_adamw_32bit"
learning_rate = 2e-4
weight_decay = 0.001

# Learning-rate schedule and the fraction of steps spent on linear warmup
lr_scheduler_type = "cosine"
warmup_ratio = 0.03

# Hard cap on training steps; -1 leaves num_train_epochs in control
max_steps = -1

# Batch together sequences of similar length (less padding, faster training)
group_by_length = True

# Checkpoint and logging cadence, in update steps
save_steps = 0
logging_steps = 2

################################################################################
# SFT parameters
################################################################################

# Maximum sequence length (None = no explicit limit set here)
max_seq_length = None

# Whether to pack several short examples into one input sequence
packing = False

# Device placement strategy; "auto" delegates placement instead of pinning to a single GPU
device_map = "auto"

In [3]:

# 1. Load dataset

dataset = load_dataset("csv", data_files="/DATA/rohan_kirti/niladri/Expert Fine Tuning/Engagement/Persuasion_results.csv")

# Expected dataset columns: conversation_id, turn_no, speaker, utterance, new_agent_reply, label, reason
# Sort the distinct labels so the label -> id mapping is identical on every run.
# Iterating a bare set() gives an arbitrary order that can change between kernel
# sessions, which would silently permute the classifier's output units.
# key=str keeps the sort well-defined even if the column mixes types (e.g. None).
labels = sorted(set(dataset["train"]["label"]), key=str)
label2id = {label: idx for idx, label in enumerate(labels)}
id2label = {idx: label for label, idx in label2id.items()}

def preprocess(example):
    """Attach the integer class index of ``example["label"]`` as ``example["label_id"]``.

    Raises:
        KeyError: if the example's label is not present in ``label2id``.
    """
    example["label_id"] = label2id[example["label"]]
    return example

dataset = dataset.map(preprocess)


Generating train split: 13383 examples [00:00, 42043.10 examples/s]
Map: 100%|██████████| 13383/13383 [00:01<00:00, 12721.09 examples/s]


In [None]:
# QLoRA configuration objects (4-bit quantization + LoRA adapter settings)
# NOTE(review): neither bnb_config nor peft_config is referenced by the model/trainer
# cells visible in this notebook — confirm whether they were meant to be applied.
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

# Purely informational hint: suggest bf16 when the GPU reports compute capability >= 8.
# Guarded by is_available() because get_device_capability() raises when no CUDA
# device is present, which would abort the cell on a CPU-only machine.
if compute_dtype == torch.float16 and use_4bit and torch.cuda.is_available():
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        print("Your GPU supports bfloat16: accelerate training with bf16=True")
        print("=" * 80)


# LoRA adapter configuration
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
)

In [None]:

# 2. Tokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize(example, max_prompt_length=256, max_reason_length=256):
    """Build one teacher-forced training sequence: prompt tokens followed by reason tokens.

    The model computes its reasoning loss by shifting ``reasoning_labels`` against
    the logits produced for ``input_ids``, so both must have the same length and be
    position-aligned. Tokenizing the prompt and the reason as two unrelated
    sequences (as was done before) breaks that alignment, so here:

    * ``input_ids``        = prompt tokens + reason tokens (+ EOS when the tokenizer has one)
    * ``reasoning_labels`` = -100 for every prompt position (ignored by the loss),
                             then the reason tokens (+ EOS)

    Args:
        example: one dataset row with "new_agent_reply", "reason" and "label_id".
        max_prompt_length: truncation limit (in tokens) for the prompt part.
        max_reason_length: truncation limit (in tokens) for the reason part,
            not counting the appended EOS.

    Returns:
        dict with "input_ids", "attention_mask", "reasoning_labels" (all the same
        length) and the integer "label_ids".
    """
    prompt_ids = tokenizer(
        example["new_agent_reply"], truncation=True, max_length=max_prompt_length
    )["input_ids"]

    # No special tokens here: the reason continues the prompt, so a BOS-style
    # prefix must not be inserted in the middle of the sequence.
    reason_ids = tokenizer(
        example["reason"],
        truncation=True,
        max_length=max_reason_length,
        add_special_tokens=False,
    )["input_ids"]
    if tokenizer.eos_token_id is not None:
        # Supervise the end of the explanation so generation learns when to stop.
        reason_ids = reason_ids + [tokenizer.eos_token_id]

    input_ids = prompt_ids + reason_ids
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "reasoning_labels": [-100] * len(prompt_ids) + reason_ids,
        "label_ids": example["label_id"],
    }

tokenized = dataset.map(tokenize, batched=False)



# 3. Collator

def collate_fn(batch, pad_token_id=0, label_pad_id=-100):
    """Left-pad a list of tokenized examples into rectangular LongTensors.

    Examples have different lengths, and ``torch.tensor`` refuses ragged nested
    lists, so every sequence is padded to one common length (the longest
    ``input_ids`` / ``reasoning_labels`` in the batch). Padding goes on the LEFT
    so that the final position of every row is a real token.

    Args:
        batch: list of dicts with "input_ids", "attention_mask",
            "reasoning_labels" (lists of ints) and "label_ids" (int).
        pad_token_id: filler for ``input_ids``; padded positions get
            attention_mask 0, so any valid vocabulary index works.
        label_pad_id: filler for ``reasoning_labels``; -100 is the ignore index
            used by the loss.

    Returns:
        dict of LongTensors: "input_ids", "attention_mask", "reasoning_labels"
        with shape (batch, max_len) and "label_ids" with shape (batch,).
    """
    target_len = max(
        (max(len(x["input_ids"]), len(x["reasoning_labels"])) for x in batch),
        default=0,
    )

    def _left_pad(values, fill):
        # Prepend `fill` until the sequence reaches the common batch length.
        values = list(values)
        return [fill] * (target_len - len(values)) + values

    return {
        "input_ids": torch.tensor(
            [_left_pad(x["input_ids"], pad_token_id) for x in batch], dtype=torch.long
        ),
        "attention_mask": torch.tensor(
            [_left_pad(x["attention_mask"], 0) for x in batch], dtype=torch.long
        ),
        "reasoning_labels": torch.tensor(
            [_left_pad(x["reasoning_labels"], label_pad_id) for x in batch], dtype=torch.long
        ),
        "label_ids": torch.tensor([x["label_ids"] for x in batch], dtype=torch.long),
    }

### Loss

In [None]:

class ReasoningWithLabelModel(nn.Module):
    """Causal language model with an additional linear classification head.

    A single forward pass yields two training signals:

    * a classification loss from a linear head applied to the hidden state of the
      last real (attended) token of each sequence, and
    * a token-level negative log-likelihood over ``reasoning_labels`` using the
      usual causal shift (the logits at position t predict the label at t + 1).

    The total loss is ``reasoning_loss + loss_weight * classification_loss``.

    Args:
        model_name: checkpoint name/path given to ``AutoModelForCausalLM.from_pretrained``.
        num_labels: number of classes for the classification head.
        loss_weight: multiplier applied to the classification loss.
    """

    def __init__(self, model_name, num_labels, loss_weight=1.0):
        super().__init__()
        self.lm = AutoModelForCausalLM.from_pretrained(model_name)
        hidden_size = self.lm.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.loss_weight = float(loss_weight)

    @staticmethod
    def _pool_last_token(last_hidden, attention_mask):
        """Return the hidden state of the last attended token of each row, shape (B, H)."""
        if attention_mask is None:
            # Without a mask every position counts as real, so the final one is used.
            return last_hidden[:, -1, :]
        batch_size, seq_len = attention_mask.shape
        positions = torch.arange(seq_len, device=attention_mask.device)
        # Highest index whose mask is 1. Correct for right- and left-padded rows;
        # plain [:, -1] would read a padding position on right-padded rows.
        last_index = (attention_mask.long() * positions).max(dim=1).values
        rows = torch.arange(batch_size, device=last_hidden.device)
        return last_hidden[rows, last_index.to(last_hidden.device)]

    @staticmethod
    def _classification_loss(logits, label_ids, label_onehot):
        """Cross-entropy against one-hot/soft targets, summed over classes, averaged over the batch."""
        if label_onehot is None:
            # Hard class indices -> one-hot targets.
            label_onehot = F.one_hot(label_ids, num_classes=logits.size(-1)).float()
        log_probs = F.log_softmax(logits, dim=-1)                # (B, C)
        return -(label_onehot * log_probs).sum(dim=-1).mean()

    @staticmethod
    def _reasoning_loss(lm_logits, reasoning_labels):
        """Per-example mean token NLL (positions labelled -100 are ignored), averaged over the batch."""
        shift_logits = lm_logits[:, :-1, :]                      # (B, T-1, V)
        shift_labels = reasoning_labels[:, 1:]                   # (B, T-1)
        ignore_mask = shift_labels == -100

        log_probs = F.log_softmax(shift_logits, dim=-1)          # (B, T-1, V)
        # gather() needs a valid index everywhere; ignored slots are pointed at
        # index 0 and their contribution is zeroed right after.
        safe_labels = torch.where(ignore_mask, torch.zeros_like(shift_labels), shift_labels)
        token_log_probs = log_probs.gather(dim=-1, index=safe_labels.unsqueeze(-1)).squeeze(-1)
        token_log_probs = torch.where(
            ignore_mask, torch.zeros_like(token_log_probs), token_log_probs
        )

        # Normalise by the number of supervised tokens; the clamp avoids a division
        # by zero for rows whose labels are all ignored.
        valid_counts = (~ignore_mask).sum(dim=-1).clamp_min(1)   # (B,)
        nll_per_example = -(token_log_probs.sum(dim=-1) / valid_counts)
        return nll_per_example.mean()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        reasoning_labels=None,   # token ids to predict (use -100 for ignore)
        label_ids=None,          # class indices (LongTensor) OR
        label_onehot=None,       # one-hot / soft labels (FloatTensor)
    ):
        """Run the LM once and compute the classification and reasoning losses.

        ``reasoning_labels`` is expected to be position-aligned with ``input_ids``
        (same sequence axis), since it is shifted against the LM logits.

        Returns:
            dict with "loss" (None when no targets are given), "reasoning_loss",
            "classification_loss" (each None when its targets are absent) and the
            classification "logits" of shape (batch, num_labels).
        """
        lm_outputs = self.lm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,  # needed to pool a sequence representation
            use_cache=False             # no past_key_values during training
        )

        last_hidden = lm_outputs.hidden_states[-1]               # (B, T, H)
        pooled = self._pool_last_token(last_hidden, attention_mask)
        # Match the head's parameter dtype in case the LM runs in a different precision.
        logits = self.classifier(pooled.to(self.classifier.weight.dtype))  # (B, C)

        classification_loss = None
        if (label_onehot is not None) or (label_ids is not None):
            classification_loss = self._classification_loss(logits, label_ids, label_onehot)

        reasoning_loss = None
        if reasoning_labels is not None:
            reasoning_loss = self._reasoning_loss(lm_outputs.logits, reasoning_labels)

        # Combine whichever losses are available.
        total_loss = None
        if (classification_loss is not None) and (reasoning_loss is not None):
            total_loss = reasoning_loss + self.loss_weight * classification_loss
        elif reasoning_loss is not None:
            total_loss = reasoning_loss
        elif classification_loss is not None:
            total_loss = self.loss_weight * classification_loss

        return {
            "loss": total_loss,
            "reasoning_loss": reasoning_loss,
            "classification_loss": classification_loss,
            "logits": logits,
        }


In [None]:

# 5. Trainer

# Training hyper-parameters, all taken from the configuration cell.
# per_device_eval_batch_size is forwarded as well; it was configured but previously
# never passed, so evaluation silently used the library default.
# NOTE(review): `gradient_checkpointing` from the config cell is not forwarded here;
# confirm the wrapped model supports it before enabling it.
training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to="tensorboard",
)



In [None]:

# Build the joint reasoning + classification model with one output unit per distinct label.
# NOTE(review): bnb_config / peft_config are not passed here — confirm whether the
# QLoRA settings were meant to be applied to this model.
model = ReasoningWithLabelModel(
    model_name,
    num_labels=len(label2id),
    loss_weight=1.0,
)

In [None]:
# Loading a single CSV yields only a "train" split, so indexing tokenized["test"]
# unconditionally raises KeyError. Use the test split when one exists, otherwise
# train without an evaluation set.
eval_split = tokenized["test"] if "test" in tokenized else None

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=tokenized["train"],
    eval_dataset=eval_split,
    data_collator=collate_fn,
)

In [None]:
# Run fine-tuning, framed by banner lines so the start and end are easy to spot in the output.
print("=" * 10, "Starting training...", sep="\n")
trainer.train()
print("Training complete.", "=" * 10, sep="\n")


In [None]:
new_model = "/DATA/rohan_kirti/niladri/fine_tune/train_dataset/fine_tune model"
print("="*10)
# ReasoningWithLabelModel subclasses nn.Module and defines no save_pretrained(), so
# calling it on trainer.model raises AttributeError. Persist the parts explicitly:
#   1) the wrapped language model (this also creates the output directory),
#   2) the tokenizer, so the checkpoint can be reloaded on its own,
#   3) the classification head together with the label -> id mapping needed to
#      interpret its outputs.
trainer.model.lm.save_pretrained(new_model)
tokenizer.save_pretrained(new_model)
torch.save(
    {"classifier": trainer.model.classifier.state_dict(), "label2id": label2id},
    f"{new_model}/classifier_head.pt",
)
print("Model saved")
print("="*10)