In [6]:
# In[ ]:

import os, glob, pandas as pd, torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.modeling_outputs import SequenceClassifierOutput
from peft import LoraConfig, get_peft_model, TaskType
import torch.nn as nn
from safetensors.torch import load_file as load_safetensors
from sklearn.metrics import accuracy_score

# ── USER SETTINGS ──────────────────────────────────────────────────────────────
# Paths default to the original cluster locations but can be overridden with
# environment variables, so the notebook runs elsewhere without editing code.
output_dir = os.environ.get(
    "REASONING_OUTPUT_DIR", "/n/netscratch/gershman_lab/Lab/amuppidi/reasoning"
)
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
csv_dir    = os.environ.get(
    "REASONING_CSV_DIR",
    "/n/home04/amuppidi/reasoning-scheduling/data/gsm8k_results_with_difficulty",
)
# LoRA settings must match the checkpoint being loaded: a strict
# load_state_dict fails when the adapter keys/shapes differ.
use_lora   = True
lora_r     = 16
lora_alpha = 32
batch_size = 8
device     = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ── HELPERS ────────────────────────────────────────────────────────────────────
def load_difficulty_csvs(csv_dir, split):
    """Read and concatenate every per-shard difficulty CSV for one split.

    Shards are matched by ``gsm8k_Y_{split}_*_with_difficulty.csv`` inside
    ``csv_dir`` and read in lexicographic filename order. Only the
    ``question_text`` and ``difficulty`` columns are kept.

    Args:
        csv_dir: Directory holding the shard CSVs.
        split: Split name inserted into the filename pattern (e.g. "train").

    Returns:
        A single DataFrame with a fresh 0..N-1 index.

    Raises:
        FileNotFoundError: If no file matches; the message is the glob pattern.
    """
    glob_pattern = os.path.join(csv_dir, f"gsm8k_Y_{split}_*_with_difficulty.csv")
    shard_paths = glob.glob(glob_pattern)
    if len(shard_paths) == 0:
        raise FileNotFoundError(glob_pattern)
    shard_paths.sort()

    wanted_columns = ["question_text","difficulty"]
    frames = [pd.read_csv(shard, usecols=wanted_columns) for shard in shard_paths]
    return pd.concat(frames, ignore_index=True)

class DifficultyDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing question text with an integer difficulty label.

    Text is tokenized lazily in ``__getitem__``. Each item is truncated and
    padded to exactly ``max_length`` tokens, so items can be stacked directly.
    """

    # Difficulty string -> class index; an unknown label raises KeyError.
    label2id = {"easy":0,"medium":1,"hard":2}

    def __init__(self, df, tokenizer, max_length=512):
        # df must provide "question_text" and "difficulty" columns.
        self.texts = df["question_text"].tolist()
        self.labels = list(map(self.label2id.__getitem__, df["difficulty"]))
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, i):
        """Return 1-D tokenizer tensors plus a 0-D long ``labels`` tensor for row ``i``."""
        encoded = self.tokenizer(
            self.texts[i],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        item = {}
        for key, tensor in encoded.items():
            # Drop the leading batch dimension of size 1 added by return_tensors="pt".
            item[key] = tensor.squeeze(0)
        item["labels"] = torch.tensor(self.labels[i], dtype=torch.long)
        return item

def collate_fn(batch):
    """Stack a list of dataset items into one batch dict.

    Items must already share one sequence length (``DifficultyDataset`` pads
    every item to ``max_length``), so ``torch.stack`` applies directly.

    Args:
        batch: Non-empty list of dicts with 1-D ``input_ids`` /
            ``attention_mask`` tensors and, optionally, a scalar ``labels``.

    Returns:
        Dict with ``input_ids`` and ``attention_mask`` of shape (B, L), plus
        ``labels`` of shape (B,) when the items carry labels.
    """
    out = {
        "input_ids":      torch.stack([b["input_ids"]      for b in batch]),
        "attention_mask": torch.stack([b["attention_mask"] for b in batch]),
    }
    # Stack instead of torch.tensor(list_of_tensors): consistent with the keys
    # above and avoids element-by-element conversion. as_tensor also accepts
    # plain Python ints. Labels are optional so label-free batches collate too.
    if "labels" in batch[0]:
        out["labels"] = torch.stack([torch.as_tensor(b["labels"]) for b in batch])
    return out

class DifficultyFinetuner(nn.Module):
    """Causal-LM backbone (optionally LoRA-adapted) with an MLP classification head.

    The head reads the final-layer hidden state at one token position per
    sequence (see ``forward``) and predicts one of ``num_labels`` classes.

    Args:
        model_name: Model id/path passed to ``AutoModelForCausalLM.from_pretrained``.
        use_lora: If True, wrap the backbone in a LoRA adapter on ``q_proj``/``v_proj``.
        r: LoRA rank.
        alpha: LoRA ``lora_alpha`` scaling.
        num_labels: Number of output classes.
    """

    def __init__(self, model_name, use_lora, r, alpha, num_labels):
        super().__init__()
        self.lm = AutoModelForCausalLM.from_pretrained(
            model_name, output_hidden_states=True, device_map="auto"
        )
        if use_lora:
            cfg = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=r, lora_alpha=alpha,
                lora_dropout=0.05,
                target_modules=["q_proj","v_proj"],
            )
            self.lm = get_peft_model(self.lm, cfg)
        H = self.lm.config.hidden_size
        # Submodule names and order are unchanged so saved state_dict keys still match.
        self.classifier = nn.Sequential(
            nn.LayerNorm(H),
            nn.Linear(H, H//2),
            nn.ReLU(),
            nn.Linear(H//2, num_labels),
        )

    def forward(self, input_ids, attention_mask, labels=None):
        """Classify each sequence from one final-layer hidden state.

        Args:
            input_ids: Token ids passed straight to the backbone (presumably (B, L)).
            attention_mask: (B, L) mask; its row sums select the token position.
            labels: Optional class indices (B,), used for a cross-entropy loss.

        Returns:
            SequenceClassifierOutput with ``logits`` (B, num_labels) and ``loss``
            (None when ``labels`` is None).
        """
        out = self.lm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        hs = out.hidden_states[-1]                    # (B, L, H)
        # NOTE(review): (#non-pad tokens - 1) is the last real token only under
        # RIGHT padding. With left padding it lands on a pad position. It is kept
        # as-is because trained checkpoints depend on it — confirm
        # tokenizer.padding_side before reusing this for new training.
        lengths = (attention_mask.sum(dim=1) - 1).to(hs.device)
        # Build the gather indices on hs's device: under device_map="auto" the
        # last hidden state may live on a different GPU than attention_mask.
        batch_idx = torch.arange(hs.size(0), device=hs.device)
        last = hs[batch_idx, lengths]
        head_device = next(self.classifier.parameters()).device
        logits = self.classifier(last.to(head_device))
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels.to(logits.device))
        return SequenceClassifierOutput(loss=loss, logits=logits)

from tqdm.notebook import tqdm  # Add this import

def run_eval(model, loader, device):
    """Compute classification accuracy of ``model`` over every batch in ``loader``.

    Switches the model to eval mode and leaves it there. Forward passes run
    under ``torch.no_grad()``, and argmax predictions are compared to the
    batch labels.

    Args:
        model: Module whose forward accepts ``input_ids``/``attention_mask`` and
            returns an object with a ``logits`` attribute.
        loader: Iterable of dicts with ``input_ids``, ``attention_mask``, ``labels``.
        device: Device the model inputs are moved to.

    Returns:
        Accuracy in [0, 1], as computed by ``sklearn.metrics.accuracy_score``.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    preds, labs = [], []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating", leave=False):
            # Only the model inputs go to the device; labels stay on CPU for scoring.
            inputs = {k: batch[k].to(device) for k in ("input_ids", "attention_mask")}
            preds.append(model(**inputs).logits.argmax(-1).cpu())
            labs.append(batch["labels"].cpu())
    if not preds:
        raise ValueError("run_eval: loader yielded no batches")
    return accuracy_score(torch.cat(labs).numpy(), torch.cat(preds).numpy())


In [4]:

# ── LOAD TOKENIZER & MODEL ────────────────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

model = DifficultyFinetuner(
    model_name, use_lora, lora_r, lora_alpha, num_labels=3
).to(device)

# Load the checkpoint onto CPU so this also works when `device` fell back to
# CPU. load_state_dict copies each tensor into the existing parameter, so the
# weights end up on the model's device either way.
sd = load_safetensors(os.path.join(output_dir, "model.safetensors"), device="cpu")
model.load_state_dict(sd)
del sd  # free the host-side copy of the weights once they are in the model

# ── PREPARE DATALOADERS ───────────────────────────────────────────────────────
train_df = load_difficulty_csvs(csv_dir, "train")
test_df  = load_difficulty_csvs(csv_dir, "test")

train_ds = DifficultyDataset(train_df, tokenizer)
test_ds  = DifficultyDataset(test_df, tokenizer)

# No shuffling (DataLoader default): evaluation order stays deterministic.
train_loader = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, collate_fn=collate_fn)





In [7]:
# ── RUN & PRINT ───────────────────────────────────────────────────────────────
# Evaluate the train split first, then the test split, with the shared run_eval helper.
train_acc, test_acc = [
    run_eval(model, split_loader, device)
    for split_loader in (train_loader, test_loader)
]
print(f"▶ Train Accuracy = {train_acc*100:.2f}%")
print(f"▶ Test  Accuracy = {test_acc*100:.2f}%")

Evaluating:   0%|          | 0/932 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/162 [00:00<?, ?it/s]

▶ Train Accuracy = 74.44%
▶ Test  Accuracy = 66.31%


In [8]:
# ── SETTINGS ───────────────────────────────────────────────────────────────────
# Paths default to the original cluster locations; override via env vars.
output_dir = os.environ.get(
    "REASONING_OUTPUT_DIR", "/n/netscratch/gershman_lab/Lab/amuppidi/reasoning"
)
csv_base   = os.environ.get(
    "REASONING_CSV_DIR",
    "/n/home04/amuppidi/reasoning-scheduling/data/gsm8k_results_with_difficulty",
)
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
use_lora   = True
lora_r     = 16
lora_alpha = 32
batch_size = 16   # NOTE: rebinds the notebook-global batch_size for later cells
device     = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ── LOAD TOKENIZER & MODEL ────────────────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# Release any model from an earlier cell *before* building the new one, so two
# copies of the backbone are not held in memory at the same time.
if "model" in globals():
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

model = DifficultyFinetuner(
    model_name=model_name,
    use_lora=use_lora,
    r=lora_r,
    alpha=lora_alpha,
    num_labels=3
).to(device)

# Load the fine-tuned weights on CPU. load_state_dict copies them into the
# existing parameters in place, so no further .to(device) is needed.
sd = load_safetensors(os.path.join(output_dir, "model.safetensors"), device="cpu")
model.load_state_dict(sd)
del sd


# ── PREDICTION UTIL ────────────────────────────────────────────────────────────
def get_preds(model, loader, device=None):
    """Return argmax class indices for every example in ``loader``, in order.

    Args:
        model: Module whose forward accepts ``input_ids``/``attention_mask`` and
            returns an object with a ``logits`` attribute (argmax over the last dim).
        loader: Iterable of dicts containing at least ``input_ids`` and
            ``attention_mask`` tensors. Other keys (e.g. ``labels``) are ignored.
        device: Where to send the inputs. Defaults to the device of the model's
            first parameter (CPU if it has none), rather than relying on a
            notebook-global ``device``.

    Returns:
        List of Python ints, one per example.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    all_preds = []
    with torch.no_grad():
        for batch in loader:
            inp = {k: batch[k].to(device) for k in ("input_ids","attention_mask")}
            logits = model(**inp).logits
            all_preds.extend(logits.argmax(-1).cpu().tolist())
    return all_preds

# Reverse DifficultyDataset.label2id (class index -> difficulty name) to decode predictions.
id2label = dict((idx, name) for name, idx in DifficultyDataset.label2id.items())


# ── PROCESS EACH SPLIT & FILE ─────────────────────────────────────────────────
# Adds/overwrites a `model_predicted_difficulty` column in every shard CSV.
# (glob is already imported at the top of the notebook.)
for split in ["train", "test"]:
    pattern = os.path.join(csv_base, f"gsm8k_Y_{split}_*_with_difficulty.csv")
    for path in sorted(glob.glob(pattern)):
        # 1) load existing CSV
        df = pd.read_csv(path)
        # 2) build dataset & loader (no shuffling, so predictions stay row-aligned)
        ds     = DifficultyDataset(df, tokenizer)
        loader = DataLoader(ds, batch_size=batch_size, collate_fn=collate_fn)
        # 3) get raw preds (0/1/2)
        raw_preds = get_preds(model, loader)
        # 4) map to strings
        df["model_predicted_difficulty"] = [id2label[p] for p in raw_preds]
        # 5) overwrite CSV safely: write a sibling temp file, then atomically
        #    swap it in, so an interruption cannot leave a truncated source file.
        #    The ".tmp" suffix keeps it out of the glob pattern above.
        tmp_path = path + ".tmp"
        df.to_csv(tmp_path, index=False)
        os.replace(tmp_path, path)
        print(f"✅ Updated {os.path.basename(path)} with {len(df)} rows")




✅ Updated gsm8k_Y_train_0_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_10_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_11_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_12_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_13_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_14_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_15_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_16_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_17_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_18_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_19_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_1_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_20_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_21_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_22_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_23_with_difficulty.csv with 100 rows
✅ Updated gsm8k_Y_train_24