In [2]:
# With loss curves and t-SNE plots; fusion model embedding size 256
# =========================
# Multi-Stage Vision/Text → Late Fusion Classifier (with EMA) + Metrics/Plots/TSNE/CM
# =========================

# -------- Libraries -------- 
import os, re, random, copy
import math
import torch, torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms, models
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from torch.optim import AdamW
import pandas as pd
import numpy as np
import easyocr
from PIL import Image
import torch.nn.functional as F

# extra for metrics/plots
from sklearn.metrics import f1_score, confusion_matrix, recall_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# -------- Parameters --------
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of target classes; sizes every classifier head and the confusion matrix.
num_classes = 5

# paths
# Image roots are read with torchvision's ImageFolder further down.
train_dir = "E:/book-clear-mix/BookCovers-split-5-c/train"
test_dir  = "E:/book-clear-mix/BookCovers-split-5-c/test"
# OCR results are cached as CSV files in this directory (created on the spot).
ocr_cache_dir = "ocr_cache"
os.makedirs(ocr_cache_dir, exist_ok=True)
train_csv = os.path.join(ocr_cache_dir, "train.csv")
test_csv  = os.path.join(ocr_cache_dir, "test.csv")

# results root
# Each training stage writes its metrics and figures into a sub-directory of this.
RESULTS_DIR = "results"
os.makedirs(RESULTS_DIR, exist_ok=True)

# common
# Seed the torch, Python and NumPy RNGs with one shared value.
seed = 42
torch.manual_seed(seed); random.seed(seed); np.random.seed(seed)


import torch, numpy as np, random, os

def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Value passed to every random number generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Prevent cuDNN from selecting non-deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs again through set_seed, which also sets the cuDNN determinism flags.
set_seed(42)
# ======= Helpers: text cleaning / OCR CSV =======
def clean_ocr_text(text: str) -> str:
    """Normalise raw OCR output into a lowercase, single-spaced string.

    Every character that is not a lowercase Latin letter, a digit, whitespace
    or inside the Persian range of the pattern is replaced by a space, and
    runs of whitespace are collapsed. A falsy input (``None``, ``""``) yields
    an empty string.

    Args:
        text: Raw OCR text, possibly ``None``.

    Returns:
        The cleaned text.
    """
    lowered = text.lower() if text else ""
    filtered = re.sub(r"[^a-z0-9\sآ-ی]", " ", lowered)
    tokens = filtered.split()
    return " ".join(tokens)

def build_ocr_csv(imagefolder_dir, csv_path, use_gpu=True):
    """Run OCR over an ImageFolder tree and cache the results as a CSV file.

    Nothing is done when ``csv_path`` already exists. Otherwise every sample
    of the ImageFolder is read with an English+Persian EasyOCR reader and a
    row ``(path, label, text)`` is written. Images with no detected text, or
    for which OCR raises, get the placeholder text ``"empty"``.

    Args:
        imagefolder_dir: Root directory in torchvision ImageFolder layout.
        csv_path: Destination CSV file.
        use_gpu: Forwarded to ``easyocr.Reader`` as ``gpu``.
    """
    if os.path.exists(csv_path):
        print(f"[!] CSV exists: {csv_path}")
        return

    print(f"[+] Building OCR CSV for: {imagefolder_dir}")
    reader = easyocr.Reader(['en','fa'], gpu=use_gpu)
    samples = datasets.ImageFolder(imagefolder_dir).samples

    def _ocr_text(image_path):
        # Best effort: one unreadable image must not abort the whole cache build.
        try:
            pieces = reader.readtext(image_path, detail=0)
            return " ".join(pieces) if len(pieces) > 0 else "empty"
        except Exception as e:
            print(f"[OCR ERROR] {image_path}: {e}")
            return "empty"

    records = [
        {"path": image_path, "label": class_idx, "text": _ocr_text(image_path)}
        for image_path, class_idx in samples
    ]
    table = pd.DataFrame(records)
    table.to_csv(csv_path, index=False, encoding="utf-8")
    print(f"[+] OCR CSV saved: {csv_path} (rows={len(table)})")

# Build the OCR caches for both splits. build_ocr_csv returns early when the
# CSV already exists, so re-running this cell does not repeat the OCR pass.
build_ocr_csv(train_dir, train_csv)
build_ocr_csv(test_dir,  test_csv)

# ======= Utils: metrics, plotting, saving =======
def ensure_dir(p):
    """Create directory ``p`` (including parents) if needed and return ``p``.

    Returning the path lets callers create and bind a directory in one
    expression.
    """
    os.makedirs(p, exist_ok=True)
    return p

def save_fig_dual(fig, path_no_ext):
    """Save ``fig`` as both PNG and PDF next to each other, then close it.

    Args:
        fig: Matplotlib figure to write.
        path_no_ext: Output path without extension; ``.png`` and ``.pdf``
            are appended.
    """
    for extension in (".png", ".pdf"):
        fig.savefig(path_no_ext + extension, dpi=300, bbox_inches="tight")
    # Close the figure so repeated plotting does not accumulate open figures.
    plt.close(fig)

from sklearn.metrics import precision_recall_fscore_support

def compute_metrics(y_true, y_pred, num_classes):
    """Compute aggregate and per-class classification metrics.

    Accuracy, macro-F1 and macro-recall (sensitivity) come from the label
    arrays; specificity, G-mean and all per-class values are derived from
    the confusion matrix built over ``range(num_classes)``.

    Args:
        y_true: Sequence of ground-truth class indices.
        y_pred: Sequence of predicted class indices.
        num_classes: Number of classes; fixes the confusion-matrix size.

    Returns:
        Tuple ``(acc, f1, sens, spec, gmean, cm, per_class)`` where ``cm`` is
        the confusion matrix (rows = true, columns = predicted) and
        ``per_class`` is a list of dicts with keys ``class``, ``acc``, ``f1``,
        ``sensitivity``, ``specificity`` and ``gmean``.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if len(y_true) > 0:
        acc = (y_true == y_pred).mean()
        # zero_division=0 matches the recall call below and keeps the numeric
        # result of the default while silencing UndefinedMetricWarning.
        f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
        sens = recall_score(y_true, y_pred, average="macro", zero_division=0)
    else:
        acc = f1 = sens = 0.0

    cm = confusion_matrix(y_true, y_pred, labels=list(range(num_classes)))

    # One-vs-rest counts for every class, computed once from the matrix.
    eps = 1e-8  # guards against division by zero for empty classes
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (tp + fp + fn)

    acc_c = (tp + tn) / (tp + tn + fp + fn + eps)
    sens_c = tp / (tp + fn + eps)          # recall
    spec_c = tn / (tn + fp + eps)
    f1_c = 2 * tp / (2 * tp + fp + fn + eps)
    gmean_c = np.sqrt(sens_c * spec_c)

    # specificity (macro) and G-mean of macro sensitivity / specificity
    spec = float(np.mean(spec_c)) if num_classes > 0 else 0.0
    gmean = float(np.sqrt(max(sens, 0.0) * max(spec, 0.0)))

    per_class = [
        {
            "class": i,
            "acc": acc_c[i],
            "f1": f1_c[i],
            "sensitivity": sens_c[i],
            "specificity": spec_c[i],
            "gmean": gmean_c[i],
        }
        for i in range(num_classes)
    ]

    return acc, f1, sens, spec, gmean, cm, per_class



def plot_confusion(cm, title="Confusion Matrix"):
    """Draw a confusion matrix as an annotated heat map.

    Args:
        cm: 2-D array of counts (rows = true class, columns = predicted).
        title: Figure title.

    Returns:
        The Matplotlib figure; saving and closing is left to the caller.
    """
    fig, ax = plt.subplots(figsize=(6,5))
    heat = ax.imshow(cm, interpolation='nearest')
    ax.set_title(title)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")

    ticks = np.arange(cm.shape[0])
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)

    # Cells above half of the maximum get white text so the count stays readable.
    cutoff = 0.5
    if cm.size > 0:
        cutoff = cm.max() / 2.

    for row, col in np.ndindex(cm.shape[0], cm.shape[1]):
        count = cm[row, col]
        text_color = "white" if count > cutoff else "black"
        ax.text(col, row, int(count), ha="center", va="center",
                color=text_color, fontsize=8)

    fig.colorbar(heat, ax=ax, fraction=0.046, pad=0.04)
    return fig

def plot_curve(epochs, train_vals, test_vals, title, ylabel):
    """Plot one train/test metric pair against the epoch number.

    Args:
        epochs: X-axis values.
        train_vals: Metric values on the training split.
        test_vals: Metric values on the test split.
        title: Figure title.
        ylabel: Label of the y-axis.

    Returns:
        The Matplotlib figure; saving and closing is left to the caller.
    """
    fig, ax = plt.subplots(figsize=(7,4))
    for series, split_name in ((train_vals, "train"), (test_vals, "test")):
        ax.plot(epochs, series, label=split_name)
    ax.set_title(title)
    ax.set_xlabel("epoch")
    ax.set_ylabel(ylabel)
    ax.legend()
    return fig

def run_tsne_and_plot(features, labels, title, out_no_ext, max_points=2000, perplexity=30, use_all=False):
    """Project features to 2-D with t-SNE and save a class-coloured scatter plot.

    Args:
        features: List or array of feature vectors, one row per sample.
        labels: List or array of labels, used to colour the points.
        title: Plot title.
        out_no_ext: Output path without extension (PNG and PDF are written).
        max_points: Maximum number of samples used, to bound the runtime.
        perplexity: Requested t-SNE perplexity; lowered for small inputs.
        use_all: If True every sample is used; otherwise at most
            ``max_points`` samples are drawn with a fixed random state.

    Returns:
        None. Nothing is written when fewer than two samples are given,
        because t-SNE needs a perplexity strictly below the sample count.
    """
    X = np.asarray(features)
    y = np.asarray(labels)
    if len(X) < 2:
        if len(X) == 1:
            print(f"[t-SNE] skipped '{title}': need at least 2 samples")
        return

    # Subsample only when use_all=False.
    if not use_all and len(X) > max_points:
        idx = np.random.RandomState(0).choice(len(X), size=max_points, replace=False)
        X = X[idx]
        y = y[idx]

    # Keep the original heuristic but never reach the sample count, which
    # TSNE rejects (perplexity must be less than n_samples).
    perplexity = min(perplexity, max(5, len(X)//3), len(X) - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=0, init='pca', learning_rate='auto')
    emb = tsne.fit_transform(X)

    fig, ax = plt.subplots(figsize=(6,5))
    ax.scatter(emb[:,0], emb[:,1], c=y, s=10)
    ax.set_title(title)
    ax.set_xticks([]); ax.set_yticks([])
    save_fig_dual(fig, out_no_ext)


# ======= EMA =======
class EMA:
    """Exponential moving average of a model's trainable parameters.

    ``shadow`` holds the averaged tensors keyed by parameter name; ``backup``
    temporarily holds the live weights while the averaged ones are swapped in.
    Only parameters with ``requires_grad`` at construction time are tracked.
    """

    def __init__(self, model, decay=0.98):
        self.decay = decay
        self.shadow = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.clone().detach()
        self.backup = {}

    def _tracked(self, model):
        """Yield ``(name, param)`` for trainable parameters that have a shadow."""
        for name, param in model.named_parameters():
            if param.requires_grad and name in self.shadow:
                yield name, param

    @torch.no_grad()
    def update(self, model):
        """Blend the current weights into the shadow: s = decay*s + (1-decay)*p."""
        keep = self.decay
        blend = 1 - self.decay
        for name, param in self._tracked(model):
            self.shadow[name].mul_(keep).add_(param.data, alpha=blend)

    def apply_shadow(self, model):
        """Swap the averaged weights into ``model``, saving the live ones."""
        self.backup = {}
        for name, param in self._tracked(model):
            self.backup[name] = param.data.clone()
            param.data.copy_(self.shadow[name])

    def restore(self, model):
        """Put back the live weights saved by ``apply_shadow``."""
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = {}

# ======= Vision Stage =======
# Transforms
# Training pipeline: resize, mild colour jitter, tensor conversion, then
# normalisation with the mean/std values listed below. Flip and rotation
# augmentations are deliberately left disabled.
# NOTE(review): Resize(224) with a single int is documented to resize the
# shorter side only; batching presumably relies on all images sharing one
# aspect ratio — confirm against the dataset.
train_transform_img = transforms.Compose([  # (translated note) why is this done twice?
                        transforms.Resize(224),
                        #transforms.RandomHorizontalFlip(p=0.5),
                        #transforms.RandomRotation(10),
                        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])     ])
# Evaluation pipeline: same resize and normalisation, no augmentation.
test_transform_img = transforms.Compose([
                        transforms.Resize(224),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])
                    ])                         
   

# Loaders
# Only the training loader shuffles; both splits use the same batch size.
batch_size_img = 32
train_dataset_img = datasets.ImageFolder(train_dir, transform=train_transform_img)
test_dataset_img  = datasets.ImageFolder(test_dir, transform=test_transform_img)
train_loader_img = DataLoader(train_dataset_img, batch_size=batch_size_img, shuffle=True, num_workers=2, pin_memory=True)
test_loader_img  = DataLoader(test_dataset_img,  batch_size=batch_size_img, shuffle=False, num_workers=2, pin_memory=True)

# Vision model: backbone + head
class VisionBackbone(nn.Module):
    """Pretrained DenseNet-121 with its classifier replaced by an identity.

    ``out_dim`` is read from the removed classifier's ``in_features``; the
    ``out_dim`` constructor argument is accepted but not used.
    """

    def __init__(self, out_dim=1024):
        super().__init__()
        densenet = models.densenet121(pretrained=True)
        # Width of the feature vector that used to feed the classifier (1024).
        self.out_dim = densenet.classifier.in_features
        densenet.classifier = nn.Identity()
        self.backbone = densenet
        # Placeholder for an optional post-processing layer (e.g. LayerNorm).
        self.post = nn.Identity()

    def forward(self, x):
        """Return the backbone features for an image batch: [B, 1024]."""
        return self.post(self.backbone(x))

class VisionClassifier(nn.Module):
    """Image classifier: ``VisionBackbone`` features followed by a linear head."""

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = VisionBackbone()
        feature_dim = self.backbone.out_dim
        self.head = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return class logits for an image batch."""
        return self.head(self.backbone(x))

@torch.no_grad()
def evaluate_img_full(model, loader, criterion):
    """Evaluate a vision classifier over ``loader`` without gradients.

    Args:
        model: Object exposing ``backbone`` and ``head`` sub-modules.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.

    Returns:
        Tuple ``(avg_loss, y_true, y_pred, feats)``: sample-weighted mean
        loss, label and prediction lists, and the stacked backbone features
        as a NumPy array.
    """
    model.eval()
    loss_sum, seen = 0.0, 0
    targets, predictions, feature_chunks = [], [], []

    for imgs, labels in loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        feats = model.backbone(imgs)
        logits = model.head(feats)

        batch_n = labels.size(0)
        # Weight the batch loss by its size so the final mean is per sample.
        loss_sum += float(criterion(logits, labels).item()) * batch_n
        seen += batch_n

        targets.extend(labels.cpu().numpy().tolist())
        predictions.extend(logits.argmax(1).cpu().numpy().tolist())
        feature_chunks.append(feats.cpu().numpy())

    if len(feature_chunks) > 0:
        feats_np = np.vstack(feature_chunks)
    else:
        feats_np = np.zeros((0, model.backbone.out_dim))
    return loss_sum / max(seen, 1), targets, predictions, feats_np

def _vision_ema_state(model, ema):
    """Return a ``state_dict`` snapshot of ``model`` with EMA weights substituted.

    The snapshot also carries the model's buffers (e.g. BatchNorm running
    statistics), which the EMA shadow does not track.
    """
    state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    for name, averaged in ema.shadow.items():
        if name in state:
            state[name] = averaged.detach().clone()
    return state


def train_vision(num_epochs=40, lr=1e-4, weight_decay=0.0):
    """Train the vision classifier with EMA and export the best EMA model.

    After every epoch the EMA weights are evaluated on the train and test
    loaders. The epoch with the highest test accuracy is kept, its per-class
    metrics are written to CSV, and at the end the best model is saved
    together with metric curves, confusion matrices and t-SNE plots under
    ``RESULTS_DIR/vision``.

    Args:
        num_epochs: Number of training epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.

    Returns:
        Path of the saved checkpoint, ``"checkpoints/vision_best_ema.pt"``.
    """
    stage = "vision"
    stage_dir = ensure_dir(os.path.join(RESULTS_DIR, stage))
    model = VisionClassifier(num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    ema = EMA(model, decay=0.98)
    best_acc = 0.0
    best_state = None

    # histories needed for the curves; everything else lives in metrics_rows
    train_losses, test_losses = [], []
    train_accs,   test_accs   = [], []
    metrics_rows = []

    for epoch in range(num_epochs):
        model.train()
        running, total = 0.0, 0
        for imgs, labels in train_loader_img:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            logits = model(imgs)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            ema.update(model)

            running += loss.item()*imgs.size(0)
            total += labels.size(0)

        train_loss_epoch = running/max(total,1)

        # eval with EMA weights on train+test (for metrics); always swap the
        # live weights back, even if evaluation raises
        ema.apply_shadow(model)
        try:
            tr_loss, tr_y, tr_p, _ = evaluate_img_full(model, train_loader_img, criterion)
            te_loss, te_y, te_p, _ = evaluate_img_full(model, test_loader_img,  criterion)
        finally:
            ema.restore(model)

        tr_acc, tr_f1, tr_sens, tr_spec, tr_gm, _, _     = compute_metrics(tr_y, tr_p, num_classes)
        te_acc, te_f1, te_sens, te_spec, te_gm, _, te_pc = compute_metrics(te_y, te_p, num_classes)

        # keep histories (use eval losses for consistency)
        train_losses.append(tr_loss)
        test_losses.append(te_loss)
        train_accs.append(tr_acc); test_accs.append(te_acc)

        metrics_rows.append([epoch+1,"train",tr_loss,tr_acc,tr_f1,tr_sens,tr_spec,tr_gm])
        metrics_rows.append([epoch+1,"test", te_loss,te_acc,te_f1,te_sens,te_spec,te_gm])

        print(f"[Vision] Epoch {epoch+1}/{num_epochs} - TrainLoss {train_loss_epoch:.4f} - TrainAcc(EvalEMA) {tr_acc:.4f} - TestAcc(EvalEMA) {te_acc:.4f}")

        if te_acc>best_acc:
            best_acc = te_acc
            # Snapshot EMA weights AND the current buffers: evaluation above
            # used the trained BatchNorm statistics, so the exported model
            # must carry them too in order to reproduce this accuracy.
            best_state = _vision_ema_state(model, ema)

            # save the per-class test metrics of the best epoch
            rows = [[stage, pc["class"], pc["acc"], pc["f1"],
                     pc["sensitivity"], pc["specificity"], pc["gmean"]]
                    for pc in te_pc]
            df = pd.DataFrame(rows, columns=["stage","class","acc","f1","sensitivity","specificity","gmean"])
            df.to_csv(os.path.join(stage_dir,"best_per_class.csv"), index=False)

    # No epoch improved on 0.0 accuracy (or num_epochs == 0): fall back to
    # the final EMA state instead of failing on a missing snapshot.
    if best_state is None:
        best_state = _vision_ema_state(model, ema)

    # export final model with best EMA
    final_model = VisionClassifier(num_classes).to(device)
    final_model.load_state_dict(best_state)
    os.makedirs("checkpoints", exist_ok=True)
    torch.save(final_model.state_dict(), "checkpoints/vision_best_ema.pt")
    print(f"[Vision] Best TestAcc(EMA): {best_acc:.4f}  | saved to checkpoints/vision_best_ema.pt")

    # ---- Save epoch CSV and curves ----
    df = pd.DataFrame(metrics_rows, columns=["epoch","split","loss","acc","f1","sensitivity","specificity","gmean"])
    df.to_csv(os.path.join(stage_dir,"metrics.csv"), index=False)

    epochs = list(range(1, len(train_losses)+1))
    fig = plot_curve(epochs, train_losses, test_losses, "Vision Loss", "loss")
    save_fig_dual(fig, os.path.join(stage_dir, "curve_loss"))
    fig = plot_curve(epochs, train_accs, test_accs, "Vision Accuracy", "accuracy")
    save_fig_dual(fig, os.path.join(stage_dir, "curve_acc"))

    # ---- Final CM & TSNE using best model ----
    final_model.eval()
    _, tr_y, tr_p, tr_feats = evaluate_img_full(final_model, train_loader_img, criterion)
    _, te_y, te_p, te_feats = evaluate_img_full(final_model, test_loader_img,  criterion)
    tr_cm = compute_metrics(tr_y, tr_p, num_classes)[5]
    te_cm = compute_metrics(te_y, te_p, num_classes)[5]

    # CM save
    pd.DataFrame(tr_cm).to_csv(os.path.join(stage_dir,"cm_train.csv"), index=False)
    pd.DataFrame(te_cm).to_csv(os.path.join(stage_dir,"cm_test.csv"),  index=False)
    fig = plot_confusion(tr_cm, "Vision Confusion Matrix (Train)")
    save_fig_dual(fig, os.path.join(stage_dir,"cm_train"))
    fig = plot_confusion(te_cm, "Vision Confusion Matrix (Test)")
    save_fig_dual(fig, os.path.join(stage_dir,"cm_test"))

    # TSNE
    run_tsne_and_plot(tr_feats, tr_y, "Vision t-SNE (Train)", os.path.join(stage_dir,"tsne_train"))
    run_tsne_and_plot(te_feats, te_y, "Vision t-SNE (Test)",  os.path.join(stage_dir,"tsne_test"))

    return "checkpoints/vision_best_ema.pt"

# ======= Text Stage =======
class OCRTextDataset(Dataset):
    """Text-only dataset built from an OCR CSV with ``path``/``label``/``text``.

    Texts are cleaned with ``clean_ocr_text``; rows whose cleaned text has
    fewer than three words are repeated ``oversample_short`` times.
    """

    def __init__(self, csv_file, tokenizer, max_len=128, oversample_short=3):
        self.df = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.oversample_short = oversample_short
        self.data = self._oversample_short_texts()

    def _oversample_short_texts(self):
        """Return the expanded list of ``{"path", "label", "text"}`` records."""
        columns = zip(self.df["path"], self.df["label"], self.df["text"])
        expanded = []
        for path, label, raw_text in columns:
            cleaned = clean_ocr_text(str(raw_text))
            is_short = len(cleaned.split()) < 3
            repeats = self.oversample_short if is_short else 1
            expanded.extend(
                {"path": path, "label": int(label), "text": cleaned}
                for _ in range(repeats)
            )
        return expanded

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenise record ``idx`` and return ids, attention mask and label."""
        record = self.data[idx]
        encoded = self.tokenizer(record["text"], padding='max_length', truncation=True,
                                 max_length=self.max_len, return_tensors='pt')
        # The tokenizer returns a leading batch axis of size 1; drop it.
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "label": record["label"],
        }

class TextBackbone(nn.Module):
    """BERT encoder returning the layer-normalised [CLS] embedding.

    The weights are loaded from the local directory ``E:/bertmodel``; the
    ``model_name`` argument is accepted but not used for loading.
    """

    def __init__(self, model_name="bert-base-multilingual-cased"):
        super().__init__()
        self.bert = BertModel.from_pretrained('E:/bertmodel')
        hidden = self.bert.config.hidden_size  # 768
        self.out_dim = hidden
        self.norm = nn.LayerNorm(hidden)

    def forward(self, input_ids, attention_mask):
        """Return normalised [CLS] features: [B, 768]."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        # Position 0 of the last hidden state is the [CLS] token.
        cls_vec = encoded.last_hidden_state[:, 0, :]
        return self.norm(cls_vec)

class TextClassifier(nn.Module):
    """Text classifier: ``TextBackbone`` features, dropout, then a linear layer."""

    def __init__(self, num_classes):
        super().__init__()
        self.backbone = TextBackbone()
        self.drop = nn.Dropout(0.5)
        feature_dim = self.backbone.out_dim
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return class logits for a batch of tokenised texts."""
        features = self.backbone(input_ids, attention_mask)
        regularised = self.drop(features)
        return self.fc(regularised)

@torch.no_grad()
def evaluate_text_full(model, loader, criterion):
    """Evaluate a text classifier over ``loader`` without gradients.

    Args:
        model: Object exposing ``backbone``, ``drop`` and ``fc`` sub-modules.
        loader: Iterable of dict batches with ``input_ids``,
            ``attention_mask`` and ``label``.
        criterion: Loss called as ``criterion(logits, labels)``.

    Returns:
        Tuple ``(avg_loss, y_true, y_pred, feats)``: sample-weighted mean
        loss, label and prediction lists, and the stacked backbone features
        as a NumPy array.
    """
    model.eval()
    loss_sum, seen = 0.0, 0
    targets, predictions, feature_chunks = [], [], []

    for batch in loader:
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)

        labels = batch["label"]
        # Labels may arrive as a plain sequence rather than a tensor.
        if not torch.is_tensor(labels):
            labels = torch.tensor(labels, dtype=torch.long)
        labels = labels.to(device)

        feats = model.backbone(ids, mask)
        logits = model.fc(model.drop(feats))

        batch_n = labels.size(0)
        # Weight the batch loss by its size so the final mean is per sample.
        loss_sum += float(criterion(logits, labels).item()) * batch_n
        seen += batch_n

        targets.extend(labels.cpu().numpy().tolist())
        predictions.extend(logits.argmax(1).cpu().numpy().tolist())
        feature_chunks.append(feats.cpu().numpy())

    if len(feature_chunks) > 0:
        feats_np = np.vstack(feature_chunks)
    else:
        feats_np = np.zeros((0, model.backbone.out_dim))
    return loss_sum / max(seen, 1), targets, predictions, feats_np

def train_text(num_epochs=20, lr=3e-5, weight_decay=0.01, patience=5,
               batch_size=8, max_len=128, oversample_short=3):
    """Fine-tune the BERT text classifier with EMA and early stopping.

    After every epoch the EMA weights are evaluated on the train and test
    loaders. The epoch with the highest test accuracy is kept and its
    per-class metrics written to CSV; training stops after ``patience``
    epochs without improvement. The best EMA model, metric curves, confusion
    matrices and t-SNE plots are written under ``RESULTS_DIR/text``.

    Args:
        num_epochs: Maximum number of epochs.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        patience: Epochs without test-accuracy improvement before stopping.
        batch_size: Batch size of both loaders.
        max_len: Tokenizer maximum sequence length.
        oversample_short: Repeat factor for short texts (training split only).

    Returns:
        Path of the saved checkpoint, ``"checkpoints/text_best_ema.pt"``.
    """
    stage = "text"
    stage_dir = ensure_dir(os.path.join(RESULTS_DIR, stage))

    tokenizer = BertTokenizer.from_pretrained("E:/bert")
    train_ds = OCRTextDataset(train_csv, tokenizer, max_len=max_len, oversample_short=oversample_short)
    # The test split is never oversampled.
    test_ds  = OCRTextDataset(test_csv,  tokenizer, max_len=max_len, oversample_short=1)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False)

    model = TextClassifier(num_classes).to(device)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    total_steps = len(train_loader)*num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, int(0.1*total_steps), total_steps)
    ce = nn.CrossEntropyLoss()
    ema = EMA(model, decay=0.98)

    best_acc, best_shadow = 0.0, None
    patience_counter = 0

    # histories needed for the curves; everything else lives in metrics_rows
    train_losses, test_losses = [], []
    train_accs, test_accs = [], []
    metrics_rows = []

    for epoch in range(num_epochs):
        model.train()
        run_loss = 0.0
        for batch in train_loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            labels = batch["label"]
            if not torch.is_tensor(labels):
                labels = torch.tensor(labels, dtype=torch.long)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(ids, mask)
            loss = ce(logits, labels)
            loss.backward()
            optimizer.step()
            scheduler.step()
            ema.update(model)

            run_loss += loss.item()

        # evaluate with EMA on train+test; always swap the live weights
        # back, even if evaluation raises
        ema.apply_shadow(model)
        try:
            tr_loss, tr_y, tr_p, _ = evaluate_text_full(model, train_loader, ce)
            te_loss, te_y, te_p, _ = evaluate_text_full(model, test_loader,  ce)
        finally:
            ema.restore(model)

        tr_acc, tr_f1, tr_sens, tr_spec, tr_gm, _, _     = compute_metrics(tr_y, tr_p, num_classes)
        te_acc, te_f1, te_sens, te_spec, te_gm, _, te_pc = compute_metrics(te_y, te_p, num_classes)

        train_losses.append(tr_loss); test_losses.append(te_loss)
        train_accs.append(tr_acc); test_accs.append(te_acc)

        metrics_rows.append([epoch+1,"train",tr_loss,tr_acc,tr_f1,tr_sens,tr_spec,tr_gm])
        metrics_rows.append([epoch+1,"test", te_loss,te_acc,te_f1,te_sens,te_spec,te_gm])

        print(f"[Text ] Epoch {epoch+1}/{num_epochs} - TrainLoss {run_loss/len(train_loader):.4f} - TrainAcc(EvalEMA) {tr_acc:.4f} - TestAcc(EvalEMA) {te_acc:.4f}")

        if te_acc > best_acc:
            best_acc = te_acc
            best_shadow = copy.deepcopy(ema.shadow)
            patience_counter = 0

            # save the per-class test metrics of the best epoch
            rows = [[stage, pc["class"], pc["acc"], pc["f1"],
                     pc["sensitivity"], pc["specificity"], pc["gmean"]]
                    for pc in te_pc]
            df = pd.DataFrame(rows, columns=["stage","class","acc","f1","sensitivity","specificity","gmean"])
            df.to_csv(os.path.join(stage_dir,"best_per_class.csv"), index=False)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"[Text ] Early stopping at epoch {epoch+1}")
                break

    # No epoch improved on 0.0 accuracy (or num_epochs == 0): fall back to
    # the final EMA weights instead of failing on a missing snapshot.
    if best_shadow is None:
        best_shadow = copy.deepcopy(ema.shadow)

    # export EMA-averaged best
    final_model = TextClassifier(num_classes).to(device)
    with torch.no_grad():
        for n,p in final_model.named_parameters():
            if p.requires_grad and n in best_shadow:
                p.data.copy_(best_shadow[n].to(p.data.device))
    os.makedirs("checkpoints", exist_ok=True)
    torch.save(final_model.state_dict(), "checkpoints/text_best_ema.pt")
    print(f"[Text ] Best TestAcc(EMA): {best_acc:.4f}  | saved to checkpoints/text_best_ema.pt")

    # save epoch CSV + curves
    df = pd.DataFrame(metrics_rows, columns=["epoch","split","loss","acc","f1","sensitivity","specificity","gmean"])
    df.to_csv(os.path.join(stage_dir,"metrics.csv"), index=False)
    epochs = list(range(1, len(train_losses)+1))
    fig = plot_curve(epochs, train_losses, test_losses, "Text Loss", "loss")
    save_fig_dual(fig, os.path.join(stage_dir, "curve_loss"))
    fig = plot_curve(epochs, train_accs, test_accs, "Text Accuracy", "accuracy")
    save_fig_dual(fig, os.path.join(stage_dir, "curve_acc"))

    # Final CM & t-SNE with final model
    final_model.eval()
    _, tr_y, tr_p, tr_feats = evaluate_text_full(final_model, train_loader, ce)
    _, te_y, te_p, te_feats = evaluate_text_full(final_model, test_loader,  ce)
    tr_cm = compute_metrics(tr_y, tr_p, num_classes)[5]
    te_cm = compute_metrics(te_y, te_p, num_classes)[5]

    pd.DataFrame(tr_cm).to_csv(os.path.join(stage_dir,"cm_train.csv"), index=False)
    pd.DataFrame(te_cm).to_csv(os.path.join(stage_dir,"cm_test.csv"),  index=False)
    fig = plot_confusion(tr_cm, "Text Confusion Matrix (Train)")
    save_fig_dual(fig, os.path.join(stage_dir,"cm_train"))
    fig = plot_confusion(te_cm, "Text Confusion Matrix (Test)")
    save_fig_dual(fig, os.path.join(stage_dir,"cm_test"))

    run_tsne_and_plot(tr_feats, tr_y, "Text t-SNE (Train)", os.path.join(stage_dir,"tsne_train"))
    run_tsne_and_plot(te_feats, te_y, "Text t-SNE (Test)",  os.path.join(stage_dir,"tsne_test"))

    return "checkpoints/text_best_ema.pt"

# ======= Fusion Stage =======
# Dataset that returns (image, text) together using OCR CSV
class FusionDataset(Dataset):
    """Paired image/text dataset driven by an OCR CSV.

    Each CSV row provides the image ``path``, the integer ``label`` and the
    OCR ``text``. Items are dicts with the transformed image, the tokenised
    text (ids and attention mask) and the label.
    """

    def __init__(self, csv_file, tokenizer, img_transform, max_len=128):
        self.df = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.img_transform = img_transform
        self.max_len = max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load, transform and tokenise the sample at position ``idx``."""
        row = self.df.iloc[idx]
        path = row["path"]
        label = int(row["label"])

        # image: the context manager closes the underlying file handle as
        # soon as the pixels have been copied by convert().
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        img = self.img_transform(img)

        # text
        text = clean_ocr_text(str(row["text"]))
        enc = self.tokenizer(text, padding='max_length', truncation=True,
                             max_length=self.max_len, return_tensors='pt')
        return {
            "image": img,
            # The tokenizer returns a leading batch axis of size 1; drop it.
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "label": label
        }


class DIDFusionHead(nn.Module):
    """
    Dynamic Information-Density Fusion.

    Both modalities are projected to a shared bottleneck width. A small MLP
    per modality produces one scalar score; a temperature-scaled softmax over
    the two scores gives per-sample mixing weights. The classifier sees the
    two projections concatenated with their weighted sum.
    """

    def __init__(self, v_dim, t_dim, num_classes, bottleneck=256, p_drop=0.6, temperature=2.0):
        super().__init__()
        self.v_proj = self._projector(v_dim, bottleneck)
        self.t_proj = self._projector(t_dim, bottleneck)
        self.temp = temperature

        # density estimators (scalar per modality)
        self.v_density = self._scorer(bottleneck)
        self.t_density = self._scorer(bottleneck)

        # fusion + classifier over [v_proj, t_proj, weighted_sum]
        fused_width = 3 * bottleneck
        self.fusion = nn.Sequential(
            nn.LayerNorm(fused_width),
            nn.Linear(fused_width, bottleneck),
            nn.BatchNorm1d(bottleneck),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(bottleneck, num_classes),
        )

    @staticmethod
    def _projector(in_dim, out_dim):
        """LayerNorm -> Linear -> GELU projection into the bottleneck."""
        return nn.Sequential(nn.LayerNorm(in_dim), nn.Linear(in_dim, out_dim), nn.GELU())

    @staticmethod
    def _scorer(dim):
        """Two-layer MLP mapping a projected feature to one scalar."""
        return nn.Sequential(nn.Linear(dim, 64), nn.GELU(), nn.Linear(64, 1))

    def forward(self, v, t):
        """Fuse vision and text features.

        Returns:
            ``(logits, weights, v_proj, t_proj, weighted_sum)``; everything
            except ``logits`` is detached from the graph.
        """
        vis = self.v_proj(v)   # [B, b]
        txt = self.t_proj(t)   # [B, b]

        # One score per modality, stacked as [B, 2].
        scores = torch.cat([self.v_density(vis), self.t_density(txt)], dim=1)

        # adaptive weights over {vision, text}
        weights = F.softmax(scores / self.temp, dim=1)
        w_vis = weights[:, 0:1]
        w_txt = weights[:, 1:2]

        blended = w_vis * vis + w_txt * txt
        logits = self.fusion(torch.cat([vis, txt, blended], dim=1))
        return logits, weights.detach(), vis.detach(), txt.detach(), blended.detach()




class FusionModel(nn.Module):
    """Late-fusion classifier over pretrained vision and text backbones.

    Backbone weights are taken from the stage checkpoints (saved as full
    ``VisionClassifier`` / ``TextClassifier`` state dicts) and combined by a
    ``DIDFusionHead``.
    """

    def __init__(self, num_classes, vision_ckpt, text_ckpt, unfreeze_backbones=False, model_name="bert-base-multilingual-cased"):
        super().__init__()
        # Vision backbone
        self.vision_backbone = self._restore_backbone(
            VisionBackbone(), VisionClassifier(num_classes), vision_ckpt)

        # Text backbone
        self.text_backbone = self._restore_backbone(
            TextBackbone(model_name=model_name), TextClassifier(num_classes), text_ckpt)

        # freeze or not: backbone parameters train only when requested
        for encoder in (self.vision_backbone, self.text_backbone):
            for param in encoder.parameters():
                param.requires_grad = unfreeze_backbones

        # DID-Fusion head with a 256-wide bottleneck and strong dropout
        bottleneck_dim = 256
        self.head = DIDFusionHead(
            self.vision_backbone.out_dim,
            self.text_backbone.out_dim,
            num_classes,
            bottleneck=bottleneck_dim,
            p_drop=0.6,
            temperature=2.0,
        )

    @staticmethod
    def _restore_backbone(backbone, donor, ckpt_path):
        """Load ``ckpt_path`` into ``donor`` and copy its backbone weights over."""
        donor.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
        backbone.load_state_dict(donor.backbone.state_dict(), strict=True)
        return backbone

    def forward(self, image, input_ids, attention_mask):
        """Return the fusion head's outputs for one image/text batch."""
        vision_feats = self.vision_backbone(image)                    # [B,1024]
        text_feats = self.text_backbone(input_ids, attention_mask)    # [B,768]
        logits, weights, vp, tp, fused = self.head(vision_feats, text_feats)
        return logits, weights, vp, tp, fused


def _stack_or_empty(chunks, fallback_width):
    """Stack per-batch arrays row-wise, or return a ``(0, width)`` array."""
    if len(chunks) > 0:
        return np.vstack(chunks)
    return np.zeros((0, fallback_width))


@torch.no_grad()
def evaluate_fusion_full(model, loader, criterion):
    """Evaluate the fusion model over ``loader`` without gradients.

    Args:
        model: Fusion model returning
            ``(logits, weights, v_proj, t_proj, weighted_sum)``.
        loader: Iterable of dict batches with ``image``, ``input_ids``,
            ``attention_mask`` and ``label``.
        criterion: Loss called as ``criterion(logits, labels)``.

    Returns:
        Tuple ``(avg_loss, y_true, y_pred, v, t, fused, w)``: sample-weighted
        mean loss, label and prediction lists, and the stacked projected
        vision features, projected text features, weighted sums and
        modality weights as NumPy arrays.
    """
    model.eval()
    loss_sum, seen = 0.0, 0
    targets, predictions = [], []
    vis_chunks, txt_chunks, fused_chunks, weight_chunks = [], [], [], []

    for batch in loader:
        img = batch["image"].to(device)
        ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)
        labels = torch.as_tensor(batch["label"], dtype=torch.long, device=device)

        logits, weights, vp, tp, fused = model(img, ids, mask)

        batch_n = labels.size(0)
        # Weight the batch loss by its size so the final mean is per sample.
        loss_sum += float(criterion(logits, labels).item()) * batch_n
        seen += batch_n

        targets.extend(labels.cpu().numpy().tolist())
        predictions.extend(logits.argmax(1).cpu().numpy().tolist())
        vis_chunks.append(vp.cpu().numpy())
        txt_chunks.append(tp.cpu().numpy())
        fused_chunks.append(fused.cpu().numpy())
        weight_chunks.append(weights.cpu().numpy())

    avg_loss = loss_sum / max(seen, 1)
    # Fallback widths apply only when the loader yielded no batches.
    v_np = _stack_or_empty(vis_chunks, model.vision_backbone.out_dim)
    t_np = _stack_or_empty(txt_chunks, model.text_backbone.out_dim)
    f_np = _stack_or_empty(fused_chunks, 0)
    w_np = _stack_or_empty(weight_chunks, 2)
    return avg_loss, targets, predictions, v_np, t_np, f_np, w_np


def train_fusion(vision_ckpt, text_ckpt, num_epochs=20, lr_head=1e-3, lr_backbone=1e-5,
                 batch_size=16, max_len=128, unfreeze_backbones=False, weight_decay=0.01, patience=6):
    """Train the late-fusion classifier and export the best EMA weights.

    Each epoch trains with cross-entropy (label smoothing 0.05), then evaluates
    the EMA weights on both the train and the test loader. The EMA snapshot with
    the highest test accuracy is kept, copied into the model at the end and
    saved to ``checkpoints/fusion_best_ema.pt``. Per-epoch metrics, loss/accuracy
    curves, confusion matrices and t-SNE plots of the fused features are written
    under ``RESULTS_DIR/fusion``.

    Args:
        vision_ckpt: passed to ``FusionModel`` as the vision checkpoint.
        text_ckpt: passed to ``FusionModel`` as the text checkpoint.
        num_epochs: maximum number of training epochs.
        lr_head: learning rate of the fusion head parameter group.
        lr_backbone: learning rate of the backbone group (only used when
            ``unfreeze_backbones`` is true).
        batch_size: batch size of both data loaders.
        max_len: ``max_len`` forwarded to ``FusionDataset``.
        unfreeze_backbones: forwarded to ``FusionModel``; when true the backbone
            parameters are added to the optimizer as a second group.
        weight_decay: currently NOT used -- the optimizer groups hard-code
            5e-4 (head) and 1e-4 (backbones). Kept for signature compatibility.
        patience: number of consecutive non-improving epochs before stopping.

    Returns:
        The checkpoint path ``"checkpoints/fusion_best_ema.pt"``.
    """
    stage = "fusion"
    stage_dir = ensure_dir(os.path.join(RESULTS_DIR, stage))
    ckpt_path = "checkpoints/fusion_best_ema.pt"

    tokenizer = BertTokenizer.from_pretrained("E:/bert")
    train_ds = FusionDataset(train_csv, tokenizer, img_transform=train_transform_img, max_len=max_len)
    test_ds  = FusionDataset(test_csv,  tokenizer, img_transform=test_transform_img, max_len=max_len)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=0, pin_memory=True)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

    model = FusionModel(num_classes, vision_ckpt, text_ckpt, unfreeze_backbones=unfreeze_backbones).to(device)

    # Param groups: head always; backbones only when they are being fine-tuned.
    # NOTE(review): the `weight_decay` argument is ignored here -- confirm whether
    # the hard-coded values are intentional before wiring it in.
    param_groups = [
        {"params": list(model.head.parameters()), "lr": lr_head, "weight_decay": 5e-4},  # stronger reg on head
    ]
    if unfreeze_backbones:
        back_params = list(model.vision_backbone.parameters()) + list(model.text_backbone.parameters())
        # milder reg for backbones
        param_groups.append({"params": back_params, "lr": lr_backbone, "weight_decay": 1e-4})
    optimizer = optim.AdamW(param_groups)

    # Linear warmup over the first 6% of steps, then linear decay.
    total_steps = len(train_loader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, int(0.06 * total_steps), total_steps)
    ce = nn.CrossEntropyLoss(label_smoothing=0.05)
    ema = EMA(model, decay=0.995)

    best_acc, best_shadow, patience_counter = 0.0, None, 0

    # Histories used for the curves / CSV written after training.
    train_losses, test_losses = [], []
    train_accs, test_accs = [], []
    metrics_rows = []

    for epoch in range(num_epochs):
        model.train()
        run_loss = 0.0
        for batch in train_loader:
            img = batch["image"].to(device)
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            labels = torch.as_tensor(batch["label"], dtype=torch.long, device=device)

            optimizer.zero_grad()
            logits, _, _, _, _ = model(img, ids, mask)
            loss = ce(logits, labels)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            ema.update(model)

            run_loss += loss.item()

        # Evaluate with EMA weights on train+test; always restore the live
        # weights afterwards, even if evaluation raises.
        ema.apply_shadow(model)
        try:
            tr_loss, tr_y, tr_p, _, _, _, _ = evaluate_fusion_full(model, train_loader, ce)
            te_loss, te_y, te_p, _, _, _, _ = evaluate_fusion_full(model, test_loader,  ce)
        finally:
            ema.restore(model)

        tr_acc, tr_f1, tr_sens, tr_spec, tr_gm, _, _ = compute_metrics(tr_y, tr_p, num_classes)
        te_acc, te_f1, te_sens, te_spec, te_gm, _, te_pc = compute_metrics(te_y, te_p, num_classes)

        train_losses.append(tr_loss); test_losses.append(te_loss)
        train_accs.append(tr_acc); test_accs.append(te_acc)

        metrics_rows.append([epoch+1,"train",tr_loss,tr_acc,tr_f1,tr_sens,tr_spec,tr_gm])
        metrics_rows.append([epoch+1,"test", te_loss,te_acc,te_f1,te_sens,te_spec,te_gm])

        # max(..., 1) keeps the log line from dividing by zero on an empty loader.
        print(f"[Fuse ] Epoch {epoch+1}/{num_epochs} - TrainLoss {run_loss/max(len(train_loader), 1):.4f} - TrainAcc(EvalEMA) {tr_acc:.4f} - TestAcc(EvalEMA) {te_acc:.4f}")

        # NOTE(review): the best epoch is selected on the test split, so the
        # reported best accuracy is optimistic.
        # The first evaluated epoch always takes a snapshot, so `best_shadow` is
        # populated even when the test accuracy never rises above 0.0.
        if te_acc > best_acc or best_shadow is None:
            best_acc = te_acc
            best_shadow = copy.deepcopy(ema.shadow)
            patience_counter = 0

            # Per-class test metrics of the current best epoch (te_pc was
            # already computed above; no need to recompute).
            rows = [[stage, pc["class"], pc["acc"], pc["f1"],
                     pc["sensitivity"], pc["specificity"], pc["gmean"]]
                    for pc in te_pc]
            df = pd.DataFrame(rows, columns=["stage","class","acc","f1","sensitivity","specificity","gmean"])
            df.to_csv(os.path.join(stage_dir,"best_per_class.csv"), index=False)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"[Fuse ] Early stopping at epoch {epoch+1}")
                break

    # Export best EMA weights (only for trainable parameters tracked by the EMA).
    if best_shadow is None:
        # Happens only when no epoch ran (e.g. num_epochs == 0).
        print("[Fuse ] No epoch was evaluated; exporting current (non-EMA) weights.")
    else:
        with torch.no_grad():
            for n, p in model.named_parameters():
                if p.requires_grad and n in best_shadow:
                    p.data.copy_(best_shadow[n].to(p.data.device))
    os.makedirs("checkpoints", exist_ok=True)
    torch.save(model.state_dict(), ckpt_path)
    print(f"[Fuse ] Best TestAcc(EMA): {best_acc:.4f}  | saved to checkpoints/fusion_best_ema.pt")

    # Save epoch CSV + curves.
    df = pd.DataFrame(metrics_rows, columns=["epoch","split","loss","acc","f1","sensitivity","specificity","gmean"])
    df.to_csv(os.path.join(stage_dir,"metrics.csv"), index=False)
    epochs_ = list(range(1, len(train_losses)+1))
    fig = plot_curve(epochs_, train_losses, test_losses, "Fusion Loss", "loss")
    save_fig_dual(fig, os.path.join(stage_dir, "curve_loss"))
    fig = plot_curve(epochs_, train_accs, test_accs, "Fusion Accuracy", "accuracy")
    save_fig_dual(fig, os.path.join(stage_dir, "curve_acc"))

    # Final confusion matrices & t-SNE (on fused features) with the exported weights.
    model.eval()
    _, tr_y, tr_p, _, _, tr_fused, _ = evaluate_fusion_full(model, train_loader, ce)
    _, te_y, te_p, _, _, te_fused, _ = evaluate_fusion_full(model, test_loader,  ce)
    _, _, _, _, _, tr_cm, _ = compute_metrics(tr_y, tr_p, num_classes)
    _, _, _, _, _, te_cm, _ = compute_metrics(te_y, te_p, num_classes)

    pd.DataFrame(tr_cm).to_csv(os.path.join(stage_dir,"cm_train.csv"), index=False)
    pd.DataFrame(te_cm).to_csv(os.path.join(stage_dir,"cm_test.csv"),  index=False)
    fig = plot_confusion(tr_cm, "Fusion Confusion Matrix (Train)")
    save_fig_dual(fig, os.path.join(stage_dir,"cm_train"))
    fig = plot_confusion(te_cm, "Fusion Confusion Matrix (Test)")
    save_fig_dual(fig, os.path.join(stage_dir,"cm_test"))

    run_tsne_and_plot(tr_fused, tr_y, "Fusion t-SNE (Train)", os.path.join(stage_dir,"tsne_train"), use_all=True)
    run_tsne_and_plot(te_fused, te_y, "Fusion t-SNE (Test)",  os.path.join(stage_dir,"tsne_test"), use_all=True)

    return ckpt_path

# ======= Run all three stages =======
if __name__ == "__main__":
    # Stage 1: Vision -- the result is handed to the fusion stage as `vision_ckpt`.
    vision_ckpt = train_vision(weight_decay=0.0, lr=1e-4, num_epochs=15)

    # Stage 2: Text -- the result is handed to the fusion stage as `text_ckpt`.
    text_ckpt = train_text(
        num_epochs=10,
        batch_size=8,
        max_len=128,
        lr=3e-5,
        weight_decay=0.01,
        oversample_short=3,
        patience=20,
    )

    # Stage 3: Fusion (set unfreeze_backbones=True if you have a large dataset
    # and a suitable GPU).
    fusion_ckpt = train_fusion(
        vision_ckpt,
        text_ckpt,
        unfreeze_backbones=False,
        num_epochs=10,
        batch_size=16,
        max_len=128,
        lr_head=1e-3,
        lr_backbone=1e-5,
        weight_decay=0.01,
        patience=20,
    )


[!] CSV exists: ocr_cache\train.csv
[!] CSV exists: ocr_cache\test.csv
[Vision] Epoch 1/15 - TrainLoss 1.3715 - TrainAcc(EvalEMA) 0.5471 - TestAcc(EvalEMA) 0.3520
[Vision] Epoch 2/15 - TrainLoss 0.6808 - TrainAcc(EvalEMA) 0.9266 - TestAcc(EvalEMA) 0.6201
[Vision] Epoch 3/15 - TrainLoss 0.2892 - TrainAcc(EvalEMA) 0.9937 - TestAcc(EvalEMA) 0.6760
[Vision] Epoch 4/15 - TrainLoss 0.0982 - TrainAcc(EvalEMA) 1.0000 - TestAcc(EvalEMA) 0.7598
[Vision] Epoch 5/15 - TrainLoss 0.0434 - TrainAcc(EvalEMA) 1.0000 - TestAcc(EvalEMA) 0.7765
[Vision] Epoch 6/15 - TrainLoss 0.0265 - TrainAcc(EvalEMA) 1.0000 - TestAcc(EvalEMA) 0.7654
[Vision] Epoch 7/15 - TrainLoss 0.0169 - TrainAcc(EvalEMA) 1.0000 - TestAcc(EvalEMA) 0.7821
[Vision] Epoch 8/15 - TrainLoss 0.0133 - TrainAcc(EvalEMA) 1.0000 - TestAcc(EvalEMA) 0.7654
[Vision] Epoch 9/15 - TrainLoss 0.0078 - TrainAcc(EvalEMA) 1.0000 - TestAcc(EvalEMA) 0.8101
[Vision] Epoch 10/15 - TrainLoss 0.0069 - TrainAcc(EvalEMA) 1.0000 - TestAcc(EvalEMA) 0.8156
[Vision]