In [None]:
from datasets import Dataset
from transformers import (
    AutoTokenizer, AutoModelForQuestionAnswering,
    TrainingArguments, Trainer, default_data_collator
)
import json
import numpy as np


In [None]:
!pip install evaluate


Collecting evaluate
  Downloading evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.5-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.5


In [None]:
!wget https://raw.githubusercontent.com/sajjjadayobi/PersianQA/main/dataset/pqa_train.json
!wget https://raw.githubusercontent.com/sajjjadayobi/PersianQA/main/dataset/pqa_test.json


--2025-08-15 13:32:19--  https://raw.githubusercontent.com/sajjjadayobi/PersianQA/main/dataset/pqa_train.json
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5269074 (5.0M) [text/plain]
Saving to: ‘pqa_train.json’


2025-08-15 13:32:19 (51.5 MB/s) - ‘pqa_train.json’ saved [5269074/5269074]

--2025-08-15 13:32:20--  https://raw.githubusercontent.com/sajjjadayobi/PersianQA/main/dataset/pqa_test.json
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 643506 (628K) [text/plain]
Saving to: ‘pqa_test.json’


2025-08-15 13:32:20 (11.2 MB/s)

In [None]:
from collections import OrderedDict
from pathlib import Path
import json


def c2dict(ds):
    """Pivot a sequence of QA records into column-oriented form.

    Returns an ``OrderedDict`` whose keys are ``'answers'``, ``'context'`` and
    ``'question'`` (in that order); each value is the list of that field taken
    from every record of ``ds``, in the records' original order.
    """
    columns = OrderedDict()
    for field in ('answers', 'context', 'question'):
        columns[field] = [record[field] for record in ds]
    return columns


def read_qa(path):
    """Load a SQuAD-style JSON file and flatten it to one record per question.

    Args:
        path: Location of a JSON file shaped like
            ``{"data": [{"title", "paragraphs": [{"context", "qas": [...]}]}]}``.

    Returns:
        A list of dicts with the keys ``title``, ``context``, ``question``,
        ``id`` and ``answers`` (``{"answer_start": [...], "text": [...]}``).
        Title, context, question and answer texts are whitespace-stripped.
        ``answer_start`` values are re-based so that they index into the
        *stripped* context and point at the first character of the *stripped*
        answer text.
    """
    with open(Path(path), encoding="utf-8") as f:
        squad = json.load(f)

    ds = []
    for example in squad["data"]:
        title = example.get("title", "").strip()
        for paragraph in example["paragraphs"]:
            raw_context = paragraph["context"]
            context = raw_context.strip()
            # Number of characters strip() removed from the front of the context:
            # every character offset into the context moves left by this amount.
            context_shift = len(raw_context) - len(raw_context.lstrip())
            for qa in paragraph["qas"]:
                answer_starts, answers = [], []
                for answer in qa["answers"]:
                    raw_text = answer["text"]
                    # Leading whitespace dropped from the answer moves its start right.
                    text_shift = len(raw_text) - len(raw_text.lstrip())
                    start = answer["answer_start"] + text_shift - context_shift
                    answer_starts.append(max(0, start))
                    answers.append(raw_text.strip())
                ds.append({
                    "title": title,
                    "context": context,
                    "question": qa["question"].strip(),
                    "id": qa["id"],
                    "answers": {
                        "answer_start": answer_starts,
                        "text": answers},
                })
    return ds



if __name__ == "__main__":
    train_ds = read_qa('pqa_train.json')
    test_ds  = read_qa('pqa_test.json')

    # Example: show the first flattened training record. Kept inside the guard so it
    # only runs when train_ds has actually been built above.
    print(train_ds[0])
    # >>> {'title': 'شرکت فولاد مبارکه اصفهان',
    # >>> 'context': 'شرکت فولاد مبارکۀ اصفهان، بزرگ ترین واحد صنعتی خصوصی در ایران و بزرگ ترین مجتمع تولید فولاد در کشور ایران است، که در شرق شهر مبارکه قرار دارد. فولاد مبارکه هم اکنون محرک بسیاری از صنایع بالادستی و پایین دستی است. فولاد مبارکه در ۱۱ دوره جایزۀ ملی تعالی سازمانی و ۶ دوره جایزۀ شرکت دانشی در کشور رتبۀ نخست را بدست آورده است و همچنین این شرکت در سال ۱۳۹۱ برای نخستین بار به عنوان تنها شرکت ایرانی با کسب امتیاز ۶۵۴ تندیس زرین جایزۀ ملی تعالی سازمانی را از آن خود کند. شرکت فولاد مبارکۀ اصفهان در ۲۳ دی ماه ۱۳۷۱ احداث شد و اکنون بزرگ ترین واحدهای صنعتی و بزرگترین مجتمع تولید فولاد در ایران است. این شرکت در زمینی به مساحت ۳۵ کیلومتر مربع در نزدیکی شهر مبارکه و در ۷۵ کیلومتری جنوب غربی شهر اصفهان واقع شده است. مصرف آب این کارخانه در کمترین میزان خود، ۱٫۵٪ از دبی زاینده رود برابر سالانه ۲۳ میلیون متر مکعب در سال است و خود یکی از عوامل کم آبی زاینده رود شناخته می شود.',
    # >>> 'question': 'شرکت فولاد مبارکه در کجا واقع شده است',
    # >>> 'id': 1,
    # >>> 'answers': {'answer_start': [114], 'text': ['در شرق شهر مبارکه']}}

{'title': 'شرکت فولاد مبارکه اصفهان', 'context': 'شرکت فولاد مبارکۀ اصفهان، بزرگ\u200cترین واحد صنعتی خصوصی در ایران و بزرگ\u200cترین مجتمع تولید فولاد در کشور ایران است، که در شرق شهر مبارکه قرار دارد. فولاد مبارکه هم\u200cاکنون محرک بسیاری از صنایع بالادستی و پایین\u200cدستی است. فولاد مبارکه در ۱۱ دوره جایزۀ ملی تعالی سازمانی و ۶ دوره جایزۀ شرکت دانشی در کشور رتبۀ نخست را بدست آورده\u200cاست و همچنین این شرکت در سال ۱۳۹۱ برای نخستین\u200cبار به عنوان تنها شرکت ایرانی با کسب امتیاز ۶۵۴ تندیس زرین جایزۀ ملی تعالی سازمانی را از آن خود کند. شرکت فولاد مبارکۀ اصفهان در ۲۳ دی ماه ۱۳۷۱ احداث شد و اکنون بزرگ\u200cترین واحدهای صنعتی و بزرگترین مجتمع تولید فولاد در ایران است. این شرکت در زمینی به مساحت ۳۵ کیلومتر مربع در نزدیکی شهر مبارکه و در ۷۵ کیلومتری جنوب غربی شهر اصفهان واقع شده\u200cاست. مصرف آب این کارخانه در کمترین میزان خود، ۱٫۵٪ از دبی زاینده\u200cرود برابر سالانه ۲۳ میلیون متر مکعب در سال است و خود یکی از عوامل کم\u200cآبی زاینده\u200cرود شناخته می\u200cشود.', 'question': 'شرکت فو

In [None]:
from sklearn.model_selection import train_test_split

# Split into 80% train and 20% validation (fixed random_state for a repeatable split)
train_raw, val_raw = train_test_split(train_ds, test_size=0.2, random_state=42)


In [None]:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from tqdm import tqdm
import tensorflow as tf




# Hugging Face Hub id of the checkpoint loaded by the two calls below.
model_name = "m3hrdadfi/albert-fa-base-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE: the recorded cell output reports qa_outputs.weight / qa_outputs.bias as newly
# initialized for this checkpoint, i.e. the QA head is untrained until fine-tuning.
model = AutoModelForQuestionAnswering.from_pretrained(model_name)


from tqdm import tqdm
import torch


class AnswerPredictor:
    """Run an extractive-QA model over (question, context) pairs and decode answer spans.

    Args:
        model: Question-answering model; it is switched to eval mode and moved to ``device``.
        tokenizer: Tokenizer matching ``model``. It is called with
            ``return_offsets_mapping=True`` and its output must provide ``sequence_ids(i)``.
        device: Device the model and each input batch are moved to.
        n_best: Number of top start / end positions combined into candidate spans.
        max_length: ``max_length`` passed to the tokenizer (context is truncated to fit).
        stride: ``stride`` passed to the tokenizer.
        no_answer: When True, return an empty answer if the CLS ("null") score beats
            the best span score.
    """

    def __init__(self, model, tokenizer, device='cuda', n_best=10, max_length=512, stride=256, no_answer=False):
        self.model = model.eval().to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.max_length = max_length
        self.stride = stride
        self.no_answer = no_answer
        self.n_best = n_best

    def model_predictions(self, questions, contexts, batch_size=1):
        """Tokenize all pairs at once and run the model over them in batches.

        Returns:
            ``(tokens, start_logits, end_logits)`` where the logits are reshaped to
            ``(len(contexts), -1)``.

        Raises:
            ValueError: if ``len(contexts)`` is not a multiple of ``batch_size``.
        """
        n = len(contexts)
        if n % batch_size != 0:
            # The stacked per-batch logits are reshaped with view(n, -1) below, which
            # only works when every batch has the same size.
            raise ValueError("number of samples must be divisible by batch_size")

        tokens = self.tokenizer(questions, contexts, add_special_tokens=True,
                                return_token_type_ids=True, return_tensors="pt", padding=True,
                                return_offsets_mapping=True, truncation="only_second",
                                max_length=self.max_length, stride=self.stride)

        start_logits, end_logits = [], []
        for i in tqdm(range(0, n - batch_size + 1, batch_size)):
            with torch.no_grad():
                out = self.model(tokens['input_ids'][i:i + batch_size].to(self.device),
                                 tokens['attention_mask'][i:i + batch_size].to(self.device),
                                 tokens['token_type_ids'][i:i + batch_size].to(self.device))
            start_logits.append(out.start_logits)
            end_logits.append(out.end_logits)

        return tokens, torch.stack(start_logits).view(n, -1), torch.stack(end_logits).view(n, -1)

    def __call__(self, questions, contexts, batch_size=1, answer_max_len=100):
        """Predict one answer per (question, context) pair.

        Returns:
            Dict mapping the pair's position to ``{"text": str, "score": float}``.
            ``text`` is a slice of the context, or ``""`` when no valid span exists
            (or, with ``no_answer=True``, when the null score wins).
        """
        tokens, starts, ends = self.model_predictions(questions, contexts, batch_size=batch_size)
        # Plain Python ints: safe for list indexing regardless of the logits' device.
        start_indexes = starts.argsort(dim=-1, descending=True)[:, :self.n_best].tolist()
        end_indexes = ends.argsort(dim=-1, descending=True)[:, :self.n_best].tolist()
        preds = {}
        for i, (c, q) in enumerate(zip(contexts, questions)):
            min_null_score = (starts[i][0] + ends[i][0]).item()  # index 0 is the CLS token
            offset = tokens['offset_mapping'][i]
            # Offsets come back as a tensor, so special and padding tokens are (0, 0)
            # rather than None. Use the sequence ids to keep context tokens only;
            # otherwise a span ending on SEP/padding decodes to an empty string.
            in_context = [seq_id == 1 for seq_id in tokens.sequence_ids(i)]
            valid_answers = []
            for start_index in start_indexes[i]:
                if not in_context[start_index]:
                    continue
                for end_index in end_indexes[i]:
                    if not in_context[end_index]:
                        continue
                    if end_index < start_index or (end_index - start_index + 1) > answer_max_len:
                        continue
                    start_char = int(offset[start_index][0])
                    end_char = int(offset[end_index][1])
                    valid_answers.append({"score": (starts[i][start_index] + ends[i][end_index]).item(),
                                          "text": c[start_char: end_char]})

            if valid_answers:
                best_answer = max(valid_answers, key=lambda x: x["score"])
            else:
                best_answer = {"text": "", "score": min_null_score}
            if self.no_answer and best_answer["score"] < min_null_score:
                best_answer = {"text": "", "score": min_null_score}
            preds[i] = best_answer
        return preds


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/684 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/1.88M [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/73.1M [00:00<?, ?B/s]

Some weights of AlbertForQuestionAnswering were not initialized from the model checkpoint at m3hrdadfi/albert-fa-base-v2 and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Ans

In [None]:

import os, json, re, string, math, random, numpy as np
from collections import Counter
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler

from datasets import Dataset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, get_linear_schedule_with_warmup

SEED = 42


def set_seed(seed=SEED):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device).

    Also sets ``torch.backends.cudnn.deterministic`` to True.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True


set_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# --- Model / tokenization ---
# MODEL_NAME: Hub id passed to from_pretrained for both tokenizer and model.
# MAX_LEN / DOC_STRIDE: max_length and stride given to the tokenizer in prepare_*_features.
# --- Optimization ---
# TRAIN_BATCH / EVAL_BATCH: DataLoader batch sizes.
# GRAD_ACCUM: the loss is divided by this and the optimizer steps every GRAD_ACCUM batches.
# BASE_LR / HEAD_LR: learning rates for the non-"qa_outputs" and "qa_outputs" parameter groups.
# WEIGHT_DECAY: applied to the base group only (the head group uses 0.0).
# WARMUP_RATIO: fraction of the total update steps used for linear warmup.
# MAX_GRAD_NORM: gradient-norm clipping value.
# USE_AMP: enables autocast / GradScaler; on only when CUDA is available.
# PAD_TO_MAX: when True, features are padded with padding="max_length".
MODEL_NAME = "m3hrdadfi/albert-fa-base-v2"
MAX_LEN      = 384
DOC_STRIDE   = 160
TRAIN_BATCH  = 6
EVAL_BATCH   = 8
GRAD_ACCUM   = 1
EPOCHS       = 3
BASE_LR      = 3e-5
HEAD_LR      = 1e-4
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.1
MAX_GRAD_NORM= 1.0
USE_AMP      = torch.cuda.is_available()
PAD_TO_MAX   = True

# --- Threshold tuning ---
# ALPHA: weight of HasAns_f1 in the balanced score (NoAns_f1 gets 1 - ALPHA).
# MIN_NOANS_F1: grid points whose NoAns_f1 falls below this are skipped during tuning.
# GRID_STEPS: number of grid points per threshold axis.
ALPHA        = 0.6
MIN_NOANS_F1 = 60.0
GRID_STEPS   = 41

# Directory that receives the best checkpoint and best_thresholds.json.
SAVE_DIR     = "./best_albert_fa_qa_ft"
os.makedirs(SAVE_DIR, exist_ok=True)

def read_qa(path):
    """Load a SQuAD-style JSON file and flatten it to one record per question.

    Args:
        path: Location of a JSON file shaped like
            ``{"data": [{"title", "paragraphs": [{"context", "qas": [...]}]}]}``.

    Returns:
        A list of dicts with the keys ``title``, ``context``, ``question``,
        ``id`` and ``answers`` (``{"answer_start": [...], "text": [...]}``).
        Title, context, question and answer texts are whitespace-stripped.
        ``answer_start`` values are re-based so that they index into the
        *stripped* context and point at the first character of the *stripped*
        answer text.
    """
    with open(Path(path), encoding="utf-8") as f:
        squad = json.load(f)

    ds = []
    for example in squad["data"]:
        title = example.get("title", "").strip()
        for paragraph in example["paragraphs"]:
            raw_ctx = paragraph["context"]
            ctx = raw_ctx.strip()
            # Characters strip() removed from the front of the context: every
            # character offset into the context moves left by this amount.
            ctx_shift = len(raw_ctx) - len(raw_ctx.lstrip())
            for qa in paragraph["qas"]:
                ans_starts, ans_texts = [], []
                for a in qa["answers"]:
                    raw_text = a["text"]
                    # Leading whitespace dropped from the answer moves its start right.
                    text_shift = len(raw_text) - len(raw_text.lstrip())
                    ans_starts.append(max(0, a["answer_start"] + text_shift - ctx_shift))
                    ans_texts.append(raw_text.strip())
                ds.append({
                    "title": title,
                    "context": ctx,
                    "question": qa["question"].strip(),
                    "id": qa["id"],
                    "answers": {"answer_start": ans_starts, "text": ans_texts}
                })
    return ds

def to_hf_dataset(obj):
    """Coerce ``obj`` to a ``Dataset``.

    A ``Dataset`` is returned unchanged, a list goes through
    ``Dataset.from_list`` and a dict through ``Dataset.from_dict``.

    Raises:
        TypeError: for any other input; the offending type is the exception argument.
    """
    if isinstance(obj, Dataset):
        converted = obj
    elif isinstance(obj, list):
        converted = Dataset.from_list(obj)
    elif isinstance(obj, dict):
        converted = Dataset.from_dict(obj)
    else:
        raise TypeError(type(obj))
    return converted

# Translation tables are built once instead of on every call.
# Stage 1: fold Arabic yeh/kaf to their Persian forms and delete the Arabic combining
# marks U+064B..U+065F. range() excludes its end, so the bound is 0x0660 to include U+065F.
_CHAR_FOLD_TABLE = str.maketrans({
    "\u064a": "\u06cc",
    "\u0643": "\u06a9",
    **{chr(code): None for code in range(0x064B, 0x0660)},
})
# Stage 2: delete ASCII punctuation, typographic quotes/dashes/ellipsis, Arabic
# punctuation (semicolon, comma, question mark) and invisible joiner / direction marks
# (ZWNJ, ZWJ, LRM, RLM), written as escapes so they stay visible in the source.
_PUNCT_TABLE = str.maketrans(
    "", "",
    string.punctuation
    + "\u00ab\u00bb\u201c\u201d\u2026\u2013\u2014"
    + "\u200c\u200d\u200e\u200f"
    + "\u061b\u060c\u061f",
)


def normalize_text(s: str):
    """Normalize Persian/Arabic/Latin text for answer comparison.

    Steps, in order: ``None`` becomes ``""``; Arabic yeh/kaf are mapped to Persian
    yeh/kaf; Arabic diacritics are removed; the text is lower-cased; punctuation and
    invisible joiner/direction marks are removed; runs of whitespace collapse to a
    single space and the ends are trimmed.

    Returns:
        The normalized string.
    """
    if s is None:
        return ""
    s = str(s).translate(_CHAR_FOLD_TABLE)
    s = s.lower()
    s = s.translate(_PUNCT_TABLE)
    return re.sub(r"\s+", " ", s).strip()

def f1_score(prediction, ground_truth):
    """Token-level F1 between two strings after ``normalize_text``.

    Tokens are whitespace-separated and overlap is counted as a multiset
    intersection. Two empty strings score 1.0; exactly one empty string scores 0.0.
    """
    pred_tokens = normalize_text(prediction).split()
    gold_tokens = normalize_text(ground_truth).split()

    if not pred_tokens and not gold_tokens:
        return 1.0
    if not pred_tokens or not gold_tokens:
        return 0.0

    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    if (precision + recall) == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

def exact_match_score(prediction, ground_truth):
    """Return 1.0 when both strings are equal after ``normalize_text``, else 0.0."""
    is_match = normalize_text(prediction) == normalize_text(ground_truth)
    return float(is_match)

def metric_max_over_ground_truths(metric_fn, prediction, golds):
    """Score ``prediction`` against every gold answer and return the best value.

    A falsy ``golds`` (``None`` or empty) is treated as a single empty gold answer.
    """
    references = golds if golds else [""]
    scores = [metric_fn(prediction, reference) for reference in references]
    return max(scores)

def prepare_train_features(examples):
    """Tokenize a batch of examples into training features with token-level answer labels.

    Uses the module-level ``tokenizer`` with ``MAX_LEN`` / ``DOC_STRIDE`` and
    ``return_overflowing_tokens=True``, so one example can yield several features.
    For each feature, ``start_positions`` / ``end_positions`` are the token indices of
    the example's FIRST answer, or the CLS index when the example has no answer or the
    answer is not fully inside this feature's context window.

    Args:
        examples: Batched mapping with "question", "context" and "answers" columns;
            each answers entry is indexed as {"answer_start": [...], "text": [...]}.

    Returns:
        The tokenizer encoding without "overflow_to_sample_mapping" and
        "offset_mapping", plus "start_positions" and "end_positions".
    """
    enc = tokenizer(
        examples["question"], examples["context"],
        truncation="only_second", max_length=MAX_LEN, stride=DOC_STRIDE,
        return_overflowing_tokens=True, return_offsets_mapping=True,
        padding="max_length" if PAD_TO_MAX else False,
    )
    # Maps each produced feature back to the index of the example it came from.
    sample_mapping = enc.pop("overflow_to_sample_mapping")
    offset_mapping = enc["offset_mapping"]

    start_positions, end_positions = [], []
    for i in range(len(offset_mapping)):
        sample_idx = sample_mapping[i]
        answers = examples["answers"][sample_idx]
        input_ids    = enc["input_ids"][i]
        cls_index    = input_ids.index(tokenizer.cls_token_id)
        sequence_ids = enc.sequence_ids(i)
        offsets      = offset_mapping[i]

        # Unanswerable example: label both positions with the CLS token.
        if len(answers["answer_start"]) == 0:
            start_positions.append(cls_index); end_positions.append(cls_index)
            continue

        # Character span [start_char, end_char) of the first answer.
        start_char = answers["answer_start"][0]
        end_char   = start_char + len(answers["text"][0])

        # Locate the first and last token whose sequence id is 1 (the second input, i.e. the context).
        idx = 0
        while idx < len(sequence_ids) and sequence_ids[idx] != 1: idx += 1
        context_start = idx
        while idx < len(sequence_ids) and sequence_ids[idx] == 1: idx += 1
        context_end = idx - 1

        # Answer not fully covered by this window: label with CLS.
        if not (offsets[context_start][0] <= start_char and offsets[context_end][1] >= end_char):
            start_positions.append(cls_index); end_positions.append(cls_index)
        else:
            # Walk forward to the first token whose character span ends after start_char.
            st = context_start
            while st <= context_end and offsets[st][0] <= start_char:
                if offsets[st][1] > start_char: break
                st += 1
            # Walk backward to the last token whose character span starts before end_char.
            en = context_end
            while en >= context_start and offsets[en][1] >= end_char:
                if offsets[en][0] < end_char: break
                en -= 1
            start_positions.append(st); end_positions.append(en)

    # Offsets are only needed to compute the labels above.
    enc.pop("offset_mapping")
    enc["start_positions"] = start_positions
    enc["end_positions"]   = end_positions
    return enc

def prepare_eval_features(examples, indices):
    """Tokenize a batch of examples into evaluation features.

    Each feature records the dataset index of the example it came from
    ("example_index") and keeps character offsets only for context tokens
    (sequence id 1); every other position in "offset_mapping" is ``None``.

    Args:
        examples: Batched mapping with "question" and "context" columns.
        indices: Dataset indices of the examples in this batch.

    Returns:
        The tokenizer encoding without "overflow_to_sample_mapping", with the masked
        "offset_mapping" and the added "example_index".
    """
    pad_mode = "max_length" if PAD_TO_MAX else False
    enc = tokenizer(
        examples["question"], examples["context"],
        truncation="only_second", max_length=MAX_LEN, stride=DOC_STRIDE,
        return_overflowing_tokens=True, return_offsets_mapping=True,
        padding=pad_mode,
    )
    feature_to_sample = enc.pop("overflow_to_sample_mapping")
    raw_offsets = enc["offset_mapping"]

    owners = []
    context_offsets = []
    for feature_no, feature_offsets in enumerate(raw_offsets):
        owners.append(indices[feature_to_sample[feature_no]])
        seq_ids = enc.sequence_ids(feature_no)
        context_offsets.append([
            pair if seq_ids[pos] == 1 else None
            for pos, pair in enumerate(feature_offsets)
        ])

    enc["example_index"] = owners
    enc["offset_mapping"] = context_offsets
    return enc

def collate_train(batch):
    """Stack the training features of ``batch`` into one tensor per field.

    Returns a dict with "input_ids", "attention_mask", "start_positions" and
    "end_positions", each built with ``torch.tensor`` from the per-item values.
    """
    fields = ("input_ids", "attention_mask", "start_positions", "end_positions")
    return {name: torch.tensor([item[name] for item in batch]) for name in fields}

def collate_eval(batch):
    """Collate evaluation features.

    "input_ids" and "attention_mask" become tensors; "example_index" and
    "offset_mapping" stay plain Python lists (one entry per item), since the
    offsets contain ``None`` and cannot be turned into a tensor.
    """
    collated = {}
    for name in ("input_ids", "attention_mask"):
        collated[name] = torch.tensor([item[name] for item in batch])
    for name in ("example_index", "offset_mapping"):
        collated[name] = [item[name] for item in batch]
    return collated

def pick_best_span_for_feature(start_logits, end_logits, offset_mapping, input_ids, tokenizer,
                               max_answer_len=60, n_best_size=50):
    """Find the best answer span and the null (CLS) score for one feature.

    Start and end logits are turned into log-probabilities with ``log_softmax``.
    The ``n_best_size`` most likely start and end positions are combined; a pair is
    valid when both positions have a non-``None`` offset, the end is not before the
    start, and the span covers at most ``max_answer_len`` tokens.

    Returns:
        ``(best_score, best_span, null_score)``. ``best_span`` is the character pair
        ``(start_char, end_char)`` taken from ``offset_mapping``, or ``None`` (with
        ``best_score == -1e30``) when no valid pair exists. ``null_score`` is the
        summed start/end log-probability at the first CLS position (position 0 if
        the CLS id does not occur in ``input_ids``).
    """
    token_ids = np.asarray(input_ids)
    cls_hits = np.where(token_ids == tokenizer.cls_token_id)[0]
    cls_idx = int(cls_hits[0]) if len(cls_hits) else 0

    start_logp = F.log_softmax(torch.tensor(np.asarray(start_logits)), dim=-1).numpy()
    end_logp = F.log_softmax(torch.tensor(np.asarray(end_logits)), dim=-1).numpy()
    null_score = float(start_logp[cls_idx] + end_logp[cls_idx])

    n_positions = len(offset_mapping)

    def in_context(pos):
        # Positions beyond the offsets list or with a None offset are not answer tokens.
        return pos < n_positions and offset_mapping[pos] is not None

    # Candidates in descending log-probability order.
    start_candidates = [p for p in np.argsort(start_logp)[-n_best_size:][::-1] if in_context(p)]
    end_candidates = [p for p in np.argsort(end_logp)[-n_best_size:][::-1] if in_context(p)]

    best_score, best_span = -1e30, None
    for start_pos in start_candidates:
        for end_pos in end_candidates:
            if end_pos < start_pos or (end_pos - start_pos + 1) > max_answer_len:
                continue
            candidate = float(start_logp[start_pos] + end_logp[end_pos])
            if candidate > best_score:
                best_score = candidate
                best_span = (offset_mapping[start_pos][0], offset_mapping[end_pos][1])
    return best_score, best_span, null_score

@torch.no_grad()
def collect_scores_per_example(model, dataloader, features_ds, raw_ds, tokenizer,
                               device="cuda", max_answer_len=60, n_best_size=50):
    """Run inference and aggregate the per-feature scores into one record per example.

    Args:
        model: QA model; switched to eval mode and moved to ``device``.
        dataloader: Yields batches with "input_ids", "attention_mask",
            "offset_mapping" and "example_index".
        features_ds: Not used by this function; kept for call compatibility.
        raw_ds: The raw examples (a ``Dataset`` or any iterable of dicts with a
            "context" entry), addressable by the batches' example indices.
        tokenizer: Forwarded to ``pick_best_span_for_feature``.
        device: Device for the model and the input tensors.
        max_answer_len, n_best_size: Forwarded to ``pick_best_span_for_feature``.

    Returns:
        ``(per_example, raw_list)``. ``per_example`` maps example index to
        ``{"best_span_score", "null_score", "span", "pred_text"}``:
        the best span over all of the example's windows, the LOWEST null score over
        those windows, and the context text of the best span ("" if there is none).
    """
    raw_list = raw_ds.to_list() if isinstance(raw_ds, Dataset) else list(raw_ds)
    model.eval().to(device)
    per_example = {}
    for batch in tqdm(dataloader, desc="Infer (collect scores)"):
        input_ids      = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        s = out.start_logits.detach().cpu().numpy()
        e = out.end_logits.detach().cpu().numpy()
        ids_cpu = batch["input_ids"].cpu().numpy()
        offsets = batch["offset_mapping"]; exidx = batch["example_index"]
        for i, ex_idx in enumerate(exidx):
            best_span_score, best_span_chars, null_score = pick_best_span_for_feature(
                s[i], e[i], offsets[i], ids_cpu[i],
                tokenizer, max_answer_len=max_answer_len, n_best_size=n_best_size
            )
            rec = per_example.get(ex_idx)
            if rec is None:
                per_example[ex_idx] = {"best_span_score": best_span_score, "null_score": null_score, "span": best_span_chars}
                continue
            if best_span_score > rec["best_span_score"]:
                rec["best_span_score"] = best_span_score; rec["span"] = best_span_chars
            # An example is unanswerable only if EVERY window says so, so keep the
            # minimum null score across windows. Keeping the maximum would let a
            # single window that merely lacks the answer force a "no answer" verdict.
            if null_score < rec["null_score"]:
                rec["null_score"] = null_score

    for ex_idx, rec in per_example.items():
        if rec["span"] is None:
            rec["pred_text"] = ""
            continue
        cs, ce = rec["span"]
        ctx = raw_list[ex_idx]["context"]
        # Clamp the character span to the context before slicing.
        cs = max(0, min(len(ctx), cs)); ce = max(0, min(len(ctx), ce))
        rec["pred_text"] = ctx[cs:ce]
    return per_example, raw_list

def evaluate_with_threshold_2d(per_example, raw_list, t_diff, s_min):
    """Score predictions under a pair of no-answer thresholds.

    An example is predicted as "no answer" (empty string) when
    ``null_score - best_span_score > t_diff``, or ``best_span_score < s_min``, or it
    has no span at all; otherwise its ``pred_text`` is used. Examples whose gold
    answer list is empty are scored against the empty string.

    Args:
        per_example: Mapping example index -> record with "null_score",
            "best_span_score", "span" and "pred_text".
        raw_list: Raw examples, indexable by the keys of ``per_example``.
        t_diff: Upper bound on the null-minus-span score difference.
        s_min: Lower bound on the best span score.

    Returns:
        Dict with overall, HasAns and NoAns exact-match / F1 percentages, the
        corresponding counts, and the thresholds used. All percentages are 0.0 when
        the respective group (or ``per_example`` itself) is empty.
    """
    total = len(per_example)
    em = f1 = 0.0
    has_total = no_total = has_em = has_f1 = no_em = no_f1 = 0.0
    for ex_idx, rec in per_example.items():
        ex = raw_list[ex_idx]
        golds = ex.get("answers", {}).get("text", [])
        is_has = len(golds) > 0
        diff = rec["null_score"] - rec["best_span_score"]
        abstain = diff > t_diff or rec["best_span_score"] < s_min or rec["span"] is None
        pred = "" if abstain else rec["pred_text"]
        if not is_has:
            golds = [""]
        em_i = metric_max_over_ground_truths(exact_match_score, pred, golds)
        f1_i = metric_max_over_ground_truths(f1_score, pred, golds)
        em += em_i; f1 += f1_i
        if is_has:
            has_total += 1; has_em += em_i; has_f1 += f1_i
        else:
            no_total += 1; no_em += em_i; no_f1 += f1_i

    # Same empty-group guard as the HasAns / NoAns ratios: no ZeroDivisionError on an
    # empty evaluation set.
    overall_denom = max(1, total)
    return {
        "exact": 100 * em / overall_denom, "f1": 100 * f1 / overall_denom, "total": total,
        "HasAns_exact": 100 * has_em / max(1, has_total), "HasAns_f1": 100 * has_f1 / max(1, has_total), "HasAns_total": has_total,
        "NoAns_exact": 100 * no_em / max(1, no_total), "NoAns_f1": 100 * no_f1 / max(1, no_total), "NoAns_total": no_total,
        "thresh_diff": float(t_diff), "s_min": float(s_min),
    }

def tune_threshold_2d_balanced(per_example, raw_list, grid=41, alpha=0.6, use_overall=False, min_noans_f1=60.0):
    """Grid-search the two no-answer thresholds on a set of scored examples.

    The grid spans the 1st..99th percentile (widened by 2.0 on each side) of the
    null-minus-span differences and of the best span scores, ``grid`` points per axis.
    Each grid point is evaluated with ``evaluate_with_threshold_2d`` and scored by
    overall F1 (``use_overall=True``) or by
    ``alpha * HasAns_f1 + (1 - alpha) * NoAns_f1``.

    Points with ``NoAns_f1 < min_noans_f1`` are skipped; if that rejects every point,
    the best-scoring point of the whole grid is returned instead. Ties keep the
    first point in grid order.

    Returns:
        ``((t_diff, s_min), metrics)`` where ``metrics`` is the winning evaluation
        dict with "best_thresh_diff" and "best_smin" added, or
        ``((0.0, -1e9), None)`` when there is nothing to evaluate.
    """
    diffs = [rec["null_score"] - rec["best_span_score"] for rec in per_example.values()]
    spans = [rec["best_span_score"] for rec in per_example.values()]
    if not diffs:
        return (0.0, -1e9), None
    d_lo, d_hi = np.percentile(diffs, 1) - 2.0, np.percentile(diffs, 99) + 2.0
    s_lo, s_hi = np.percentile(spans, 1) - 2.0, np.percentile(spans, 99) + 2.0

    def score_of(res):
        return res["f1"] if use_overall else (alpha * res["HasAns_f1"] + (1 - alpha) * res["NoAns_f1"])

    # Track the constrained winner and the unconstrained one in a single sweep, so the
    # fallback never has to re-evaluate the whole grid.
    best = None; best_pair = (0.0, -1e9); best_score = None
    fallback = None; fallback_pair = (0.0, -1e9); fallback_score = None

    for td in np.linspace(d_lo, d_hi, grid):
        for smin in np.linspace(s_lo, s_hi, grid):
            res = evaluate_with_threshold_2d(per_example, raw_list, td, smin)
            score = score_of(res)
            pair = (float(td), float(smin))
            if fallback is None or score > fallback_score:
                fallback, fallback_pair, fallback_score = res, pair, score
            if res["NoAns_f1"] < min_noans_f1:
                continue
            if best is None or score > best_score:
                best, best_pair, best_score = res, pair, score

    if best is None:
        best, best_pair = fallback, fallback_pair
    if best is None:
        # Empty grid (grid <= 0): nothing was evaluated.
        return (0.0, -1e9), None

    best["best_thresh_diff"], best["best_smin"] = best_pair[0], best_pair[1]
    return best_pair, best

# --- Data: read both splits, then carve a 20% validation set out of the training file ---
train_all = read_qa("pqa_train.json")
test_all  = read_qa("pqa_test.json")
train_raw, val_raw = train_test_split(train_all, test_size=0.2, random_state=SEED, shuffle=True)

train_ds_raw = to_hf_dataset(train_raw)
val_ds_raw   = to_hf_dataset(val_raw)
test_ds_raw  = to_hf_dataset(test_all)

# --- Model: a fast tokenizer is requested; the feature builders rely on offset mappings ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model     = AutoModelForQuestionAnswering.from_pretrained(
    MODEL_NAME,
    ignore_mismatched_sizes=True  # the QA head is freshly initialized; this is the safer option
).to(device)

# Enable gradient checkpointing when the model class advertises support for it.
if getattr(model, "supports_gradient_checkpointing", False):
    model.gradient_checkpointing_enable()
if hasattr(model.config, "use_cache"):
    model.config.use_cache = False

# --- Features: the raw columns are dropped; one example may expand into several features ---
train_features = train_ds_raw.map(
    prepare_train_features, batched=True, remove_columns=train_ds_raw.column_names
)
# with_indices=True hands the dataset indices to prepare_eval_features (stored as example_index).
val_features = val_ds_raw.map(
    prepare_eval_features, batched=True, with_indices=True, remove_columns=val_ds_raw.column_names
)
test_features = test_ds_raw.map(
    prepare_eval_features, batched=True, with_indices=True, remove_columns=test_ds_raw.column_names
)

# Only the training loader shuffles; the evaluation loaders keep feature order.
train_loader = DataLoader(train_features, batch_size=TRAIN_BATCH, shuffle=True,  collate_fn=collate_train, num_workers=0)
val_loader   = DataLoader(val_features,   batch_size=EVAL_BATCH,  shuffle=False, collate_fn=collate_eval,  num_workers=0)
test_loader  = DataLoader(test_features,  batch_size=EVAL_BATCH,  shuffle=False, collate_fn=collate_eval,  num_workers=0)

# Split the parameters into the QA head (names containing "qa_outputs") and everything else,
# so that the two groups can use different learning rates and weight decay.
head_names = set(n for n,p in model.named_parameters() if "qa_outputs" in n)
base_params = [p for n,p in model.named_parameters() if n not in head_names]
head_params = [p for n,p in model.named_parameters() if n in head_names]

optimizer = AdamW([
    {"params": base_params, "lr": BASE_LR, "weight_decay": WEIGHT_DECAY},
    {"params": head_params, "lr": HEAD_LR, "weight_decay": 0.0},
])

# Linear warmup/decay schedule sized in optimizer updates (not batches); ceil() counts a
# trailing partial accumulation group as one update.
num_update_steps_per_epoch = math.ceil(len(train_loader) / GRAD_ACCUM)
t_total   = EPOCHS * num_update_steps_per_epoch
num_warm  = int(WARMUP_RATIO * t_total)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warm, num_training_steps=t_total)
scaler    = GradScaler(enabled=USE_AMP)

# Best validation result so far: balanced score, epoch, tuned thresholds and their metrics.
best_state = {"score": -1e9, "epoch": 0, "diff": 0.0, "smin": -1e9, "dev_metrics": None}

# Training loop: one pass over train_loader per epoch, then threshold tuning on the
# validation set; the checkpoint with the best balanced score is the only one kept.
step_id = 0
for epoch in range(1, EPOCHS+1):
    model.train()
    running = 0.0
    # Start every epoch without gradients left over from a previous accumulation group.
    optimizer.zero_grad(set_to_none=True)
    for step, batch in enumerate(train_loader, 1):
        input_ids      = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        start_positions= batch["start_positions"].to(device)
        end_positions  = batch["end_positions"].to(device)

        with autocast(enabled=USE_AMP):
            out = model(
                input_ids=input_ids, attention_mask=attention_mask,
                start_positions=start_positions, end_positions=end_positions,
                return_dict=True
            )
            # Scale so that the accumulated gradient averages over the group.
            loss = out.loss / GRAD_ACCUM

        scaler.scale(loss).backward()
        # Step after every full accumulation group AND on the last batch of the epoch:
        # the scheduler length (ceil(len / GRAD_ACCUM)) already counts that trailing
        # partial group, and without the flush its gradients would leak into the next epoch.
        if (step % GRAD_ACCUM) == 0 or step == len(train_loader):
            scaler.unscale_(optimizer)  # clip on the true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            scaler.step(optimizer); scaler.update()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)
        # Undo the accumulation scaling so the reported value is the per-batch loss.
        running += loss.item() * GRAD_ACCUM; step_id += 1

        if step % 100 == 0:
            print(f"[Epoch {epoch} | Step {step}/{len(train_loader)}] loss={running/step:.4f}")

    print(f"\n[Epoch {epoch}] Train loss: {running/max(1,len(train_loader)):.4f}")

    model.eval()
    with torch.no_grad():
        val_per_ex, val_raw_list = collect_scores_per_example(
            model, val_loader, val_features, val_ds_raw, tokenizer,
            device=device, max_answer_len=60, n_best_size=50
        )
        (best_diff, best_smin), dev_metrics = tune_threshold_2d_balanced(
            val_per_ex, val_raw_list, grid=GRID_STEPS, alpha=ALPHA, use_overall=False, min_noans_f1=MIN_NOANS_F1
        )

    balanced_score = ALPHA*dev_metrics["HasAns_f1"] + (1-ALPHA)*dev_metrics["NoAns_f1"]
    print("Dev metrics (balanced tuning):")
    for k in ["exact","f1","HasAns_exact","HasAns_f1","NoAns_exact","NoAns_f1","total","HasAns_total","NoAns_total"]:
        print(f"  {k}: {dev_metrics[k]:.2f}")
    print(f"  best_diff={best_diff:.4f} | best_smin={best_smin:.4f} | balanced={balanced_score:.2f}")

    if balanced_score > best_state["score"] + 1e-6:
        best_state = {
            "score": float(balanced_score),
            "epoch": epoch,
            "diff": float(best_diff),
            "smin": float(best_smin),
            "dev_metrics": dev_metrics
        }
        # Best-effort removal of the previous best checkpoint (one directory level deep).
        for f in os.listdir(SAVE_DIR):
            p = os.path.join(SAVE_DIR, f)
            try:
                if os.path.isdir(p):
                    for ff in os.listdir(p): os.remove(os.path.join(p, ff))
                    os.rmdir(p)
                else:
                    os.remove(p)
            except OSError as err:
                # Keep going, but report it: a stale checkpoint left behind is worth knowing about.
                print(f"[Epoch {epoch}] warning: could not remove {p}: {err}")
        ckpt_dir = os.path.join(SAVE_DIR, f"epoch_{epoch}")
        os.makedirs(ckpt_dir, exist_ok=True)
        model.save_pretrained(ckpt_dir); tokenizer.save_pretrained(ckpt_dir)
        with open(os.path.join(SAVE_DIR, "best_thresholds.json"), "w", encoding="utf-8") as f:
            json.dump({"best_diff": best_diff, "best_smin": best_smin, "epoch": epoch, "balanced": balanced_score}, f, ensure_ascii=False, indent=2)
        print(f"[Epoch {epoch}] ✅ New best saved at: {ckpt_dir}")
    else:
        print(f"[Epoch {epoch}] No improvement. Best balanced={best_state['score']:.2f} (epoch {best_state['epoch']})")

print("\n=== Training finished ===")
print("Best state:", best_state)

# Reload the checkpoint of the best epoch together with the thresholds tuned on the
# validation set for that epoch.
best_epoch = best_state["epoch"]
best_diff  = best_state["diff"]
best_smin  = best_state["smin"]
best_path  = os.path.join(SAVE_DIR, f"epoch_{best_epoch}")

best_model = AutoModelForQuestionAnswering.from_pretrained(best_path).to(device)
best_model.eval()

# Score the test set once and apply the validation-tuned thresholds as they are
# (no tuning on the test set).
test_per_ex, test_raw_list = collect_scores_per_example(
    best_model, test_loader, test_features, test_ds_raw, tokenizer,
    device=device, max_answer_len=60, n_best_size=50
)
test_metrics = evaluate_with_threshold_2d(test_per_ex, test_raw_list, best_diff, best_smin)

print("\n=== TEST (balanced thresholds from best epoch) ===")
for k in ["exact","f1","HasAns_exact","HasAns_f1","NoAns_exact","NoAns_f1","total","HasAns_total","NoAns_total"]:
    print(f"{k}: {test_metrics[k]:.2f}")

def get_example_id(ex):
    """Return ``(key, value)`` for the first id-like key found in ``ex``.

    Keys are tried in the order "id", "qas_id", "guid", "example_id";
    ``(None, None)`` is returned when none of them is present.
    """
    id_keys = ("id", "qas_id", "guid", "example_id")
    found = next((key for key in id_keys if key in ex), None)
    if found is None:
        return None, None
    return found, ex[found]

# Build {example id: predicted text} for the test set with the tuned thresholds,
# falling back to the example's position when it carries no id-like key.
preds = {}
for ex_idx, rec in test_per_ex.items():
    diff = rec["null_score"] - rec["best_span_score"]
    if diff > best_diff or rec["best_span_score"] < best_smin or rec["span"] is None:
        pred = ""
    else:
        pred = rec["pred_text"]
    key_name, key_val = get_example_id(test_raw_list[ex_idx])
    if key_val is not None:
        ex_id = str(key_val)
    else:
        ex_id = str(ex_idx)
    preds[ex_id] = pred

with open("predictions_test.json", "w", encoding="utf-8") as f:
    json.dump(preds, f, ensure_ascii=False, indent=2)
print("\nSaved: predictions_test.json")


Device: cuda


Some weights of AlbertForQuestionAnswering were not initialized from the model checkpoint at m3hrdadfi/albert-fa-base-v2 and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/7206 [00:00<?, ? examples/s]

Map:   0%|          | 0/1802 [00:00<?, ? examples/s]

Map:   0%|          | 0/930 [00:00<?, ? examples/s]

  scaler    = GradScaler(enabled=USE_AMP)
  with autocast(enabled=USE_AMP):


[Epoch 1 | Step 100/1214] loss=5.5507
[Epoch 1 | Step 200/1214] loss=4.8585
[Epoch 1 | Step 300/1214] loss=4.4591
[Epoch 1 | Step 400/1214] loss=4.1852
[Epoch 1 | Step 500/1214] loss=3.9033
[Epoch 1 | Step 600/1214] loss=3.6919
[Epoch 1 | Step 700/1214] loss=3.5068
[Epoch 1 | Step 800/1214] loss=3.3720
[Epoch 1 | Step 900/1214] loss=3.2567
[Epoch 1 | Step 1000/1214] loss=3.1655
[Epoch 1 | Step 1100/1214] loss=3.0705
[Epoch 1 | Step 1200/1214] loss=2.9986

[Epoch 1] Train loss: 2.9888


Infer (collect scores): 100%|██████████| 228/228 [00:51<00:00,  4.42it/s]


Dev metrics (balanced tuning):
  exact: 37.90
  f1: 55.06
  HasAns_exact: 21.84
  HasAns_f1: 46.39
  NoAns_exact: 75.14
  NoAns_f1: 75.14
  total: 1802.00
  HasAns_total: 1259.00
  NoAns_total: 543.00
  best_diff=-0.0239 | best_smin=-13.4948 | balanced=57.89
[Epoch 1] ✅ New best saved at: ./best_albert_fa_qa_ft/epoch_1
[Epoch 2 | Step 100/1214] loss=1.6086
[Epoch 2 | Step 200/1214] loss=1.5400
[Epoch 2 | Step 300/1214] loss=1.5233
[Epoch 2 | Step 400/1214] loss=1.5500
[Epoch 2 | Step 500/1214] loss=1.5608
[Epoch 2 | Step 600/1214] loss=1.5593
[Epoch 2 | Step 700/1214] loss=1.5594
[Epoch 2 | Step 800/1214] loss=1.5662
[Epoch 2 | Step 900/1214] loss=1.5753
[Epoch 2 | Step 1000/1214] loss=1.5753
[Epoch 2 | Step 1100/1214] loss=1.5669
[Epoch 2 | Step 1200/1214] loss=1.5653

[Epoch 2] Train loss: 1.5652


Infer (collect scores): 100%|██████████| 228/228 [00:58<00:00,  3.88it/s]


Dev metrics (balanced tuning):
  exact: 40.90
  f1: 57.21
  HasAns_exact: 23.03
  HasAns_f1: 46.38
  NoAns_exact: 82.32
  NoAns_f1: 82.32
  total: 1802.00
  HasAns_total: 1259.00
  NoAns_total: 543.00
  best_diff=-2.0056 | best_smin=-14.6039 | balanced=60.76
[Epoch 2] ✅ New best saved at: ./best_albert_fa_qa_ft/epoch_2
[Epoch 3 | Step 100/1214] loss=0.8102
[Epoch 3 | Step 200/1214] loss=0.8168
[Epoch 3 | Step 300/1214] loss=0.8143
[Epoch 3 | Step 400/1214] loss=0.8259
[Epoch 3 | Step 500/1214] loss=0.8262
[Epoch 3 | Step 600/1214] loss=0.8214
[Epoch 3 | Step 700/1214] loss=0.8142
[Epoch 3 | Step 800/1214] loss=0.8105
[Epoch 3 | Step 900/1214] loss=0.8084
[Epoch 3 | Step 1000/1214] loss=0.8094
[Epoch 3 | Step 1100/1214] loss=0.8115
[Epoch 3 | Step 1200/1214] loss=0.8207

[Epoch 3] Train loss: 0.8201


Infer (collect scores): 100%|██████████| 228/228 [00:50<00:00,  4.51it/s]


Dev metrics (balanced tuning):
  exact: 40.40
  f1: 58.82
  HasAns_exact: 23.19
  HasAns_f1: 49.55
  NoAns_exact: 80.29
  NoAns_f1: 80.29
  total: 1802.00
  HasAns_total: 1259.00
  NoAns_total: 543.00
  best_diff=-0.0443 | best_smin=-3.9205 | balanced=61.85
[Epoch 3] ✅ New best saved at: ./best_albert_fa_qa_ft/epoch_3

=== Training finished ===
Best state: {'score': 61.8501787221662, 'epoch': 3, 'diff': -0.044261203094269064, 'smin': -3.920539168134331, 'dev_metrics': {'exact': 40.399556048834626, 'f1': 58.817040869066304, 'total': 1802, 'HasAns_exact': 23.193010325655283, 'HasAns_f1': 49.553858336820895, 'HasAns_total': 1259.0, 'NoAns_exact': 80.29465930018416, 'NoAns_f1': 80.29465930018416, 'NoAns_total': 543.0, 'thresh_diff': -0.044261203094269064, 's_min': -3.920539168134331, 'best_thresh_diff': -0.044261203094269064, 'best_smin': -3.920539168134331}}


Infer (collect scores): 100%|██████████| 117/117 [00:24<00:00,  4.83it/s]



=== TEST (balanced thresholds from best epoch) ===
exact: 49.78
f1: 67.96
HasAns_exact: 38.86
HasAns_f1: 64.82
NoAns_exact: 75.27
NoAns_f1: 75.27
total: 930.00
HasAns_total: 651.00
NoAns_total: 279.00

Saved: predictions_test.json
