## Helper Functions and Classes

In [None]:
from google.colab import drive
import os
import glob
import random
import numpy as np

# Mount Google Drive into the Colab filesystem so that files under "My Drive"
# are reachable as ordinary paths below /content/drive.
drive.mount('/content/drive')
# Project folder on Drive; later cells build checkpoint and output paths from it.
DATA_ROOT = '/content/drive/My Drive/Deep Fake Dataset'
# 'processed_faces' sub-folder of the Drive project folder.
# NOTE(review): the visible cells read frames from the local
# "/content/processed_faces" and never reference this constant — confirm
# whether it is still needed.
PROCESSED_ROOT = os.path.join(DATA_ROOT, 'processed_faces')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class DeepFakeHybridModel(nn.Module):
    """EfficientNet-B3 backbone followed by a 1536 -> 512 -> 128 -> 2 dense head.

    ``forward`` returns ``(logits, feat_512)``: the 2-class logits used for
    training, and the 512-d activation of the first dense block, which the
    notebook feeds to the classical ensemble classifiers.

    Args:
        pretrained: When True (default, same as before) the backbone is
            initialised with the ImageNet weights shipped by torchvision.
            Pass False to skip the download when a checkpoint is loaded
            right after construction anyway.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()

        # 1. Base: EfficientNet-B3.
        # torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`;
        # use the weights enum when it exists and fall back otherwise so the
        # class keeps working on older installs.
        weights_enum = getattr(models, "EfficientNet_B3_Weights", None)
        if weights_enum is None:
            self.backbone = models.efficientnet_b3(pretrained=pretrained)
        else:
            self.backbone = models.efficientnet_b3(
                weights=weights_enum.IMAGENET1K_V1 if pretrained else None
            )
        # Drop the ImageNet classification head; the backbone now emits the
        # pooled feature vector.
        self.backbone.classifier = nn.Identity()

        # EfficientNet-B3 output is 1536
        self.dense_512 = nn.Sequential(
            nn.Linear(1536, 512),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        self.dense_128 = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        # Final linear layer producing 2-class logits. No softmax is applied
        # here; the training code uses nn.CrossEntropyLoss on the raw logits.
        self.classifier = nn.Linear(128, 2)

    def forward(self, x):
        """Return ``(logits, feat_512)`` for a batch of images ``x``."""
        x = self.backbone(x)          # 1536
        feat_512 = self.dense_512(x)  # 512 (for the ensemble)
        x = self.dense_128(feat_512)
        out = self.classifier(x)      # classification logits (for training)
        return out, feat_512


class FakeQuantWrapper(nn.Module):
    """Wrap a Conv2d/Linear layer with simulated (fake) weight and activation quantization.

    The forward pass runs the wrapped layer with symmetric per-output-channel
    quantized weights while gradients flow to the full-precision weights via a
    straight-through estimator (STE). The layer output is optionally
    fake-quantized per tensor as well. Modules that are neither ``nn.Conv2d``
    nor ``nn.Linear`` are called unchanged (only activation quantization is
    applied to their output).

    Args:
        module: layer to wrap.
        weight_bits: bit width for the weight quantizer.
        act_bits: bit width for the output quantizer; values >= 32 disable it.
        eps: lower bound for quantization scales.

    Buffers ``_w_q_cached`` / ``_scale_cached`` hold the most recent quantized
    weight and its scale and are part of the module's ``state_dict``.
    """

    def __init__(self, module: nn.Module, weight_bits: int = 8, act_bits: int = 8, eps: float = 1e-8):
        super().__init__()
        self.module = module
        self.weight_bits = int(weight_bits)
        self.act_bits = int(act_bits)
        self.eps = float(eps)

        self.register_buffer('_w_q_cached', torch.empty(0))
        self.register_buffer('_scale_cached', torch.empty(0))
        self._cache_initialized = False
        # Autograd version counter of the weight at the time the cache was
        # built; in-place updates (optimizer.step, load_state_dict) bump it.
        self._cached_weight_version = None

    def update_cached_weight(self, device=None):
        """Re-quantize the wrapped weight and store the result in the cache buffers.

        Args:
            device: optional device on which to run the quantization; the
                results are always stored on the weight's own device.
        """
        weight = self.module.weight
        w = weight.data
        if device is not None:
            w = w.to(device)
        # The helper itself falls back to a per-tensor scale for 1-D weights.
        w_q, scale = quantize_tensor_symmetric_perchannel(
            w, self.weight_bits, eps=self.eps, channel_dim=0
        )
        self._w_q_cached = w_q.detach().to(weight.device)
        self._scale_cached = scale.detach().to(weight.device)
        self._cache_initialized = True
        self._cached_weight_version = weight._version

    def _cache_is_stale(self) -> bool:
        """True when the cache is missing or the weight changed in place since it was built."""
        return (
            not self._cache_initialized
            or self._w_q_cached.numel() == 0
            or self._cached_weight_version != self.module.weight._version
        )

    def forward(self, x):
        if isinstance(self.module, (nn.Conv2d, nn.Linear)):
            # The cache used to be filled once and never refreshed, so every
            # forward kept using the *initial* quantized weights no matter how
            # the optimizer moved the real ones. Re-quantize on every training
            # step, and in eval whenever the weight was modified in place.
            if self.training or self._cache_is_stale():
                self.update_cached_weight()

            w = self.module.weight
            # STE: forward value is the quantized weight, gradient goes to w.
            w_ste = (self._w_q_cached - w).detach() + w
            bias = self.module.bias
            # nn.functional is used via `nn` so this cell does not depend on
            # the `F` alias that is only imported in a later cell.
            if isinstance(self.module, nn.Conv2d):
                out = nn.functional.conv2d(
                    x,
                    w_ste,
                    bias=bias,
                    stride=self.module.stride,
                    padding=self.module.padding,
                    dilation=self.module.dilation,
                    groups=self.module.groups,
                )
            else:
                out = nn.functional.linear(x, w_ste, bias)
        else:
            out = self.module(x)

        # Activation quantization (symmetric, per tensor, STE backward).
        if self.act_bits < 32:
            qmax = 2 ** (self.act_bits - 1) - 1
            min_scale = self.eps
            if out.is_floating_point():
                # Under autocast `out` can be fp16, where 1e-8 rounds to 0 and
                # an all-zero output would then divide by a zero scale.
                min_scale = max(min_scale, torch.finfo(out.dtype).tiny)
            scale = (out.detach().abs().amax() / qmax).clamp(min=min_scale)
            out_q = torch.round(out / scale).clamp(-qmax, qmax) * scale
            out = (out_q - out).detach() + out

        return out


In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from tqdm.auto import tqdm
import os
import torch.nn.functional as F
import glob
import numpy as np
import os
from PIL import Image
from tqdm.auto import tqdm
import torch
import torchvision.transforms as T
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, classification_report

# Tensor conversion + ImageNet mean/std normalisation.
# NOTE(review): there is no resize here, so frames must already share one size
# on disk; extract_features_from_disk resizes to 224x224 — confirm they agree.
train_tf = T.Compose([
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Frames are laid out as <root>/<split>/<class>/<file>.jpg (see
# extract_features_from_disk). ImageFolder must therefore point at the *train*
# split: pointing it at the root turns "train"/"val" into the class labels and
# mixes validation frames into training.
# NOTE(review): ImageFolder assigns class indices from the sorted folder names,
# while extract_features_from_disk hard-codes 'real' -> 0; confirm which
# convention the CNN checkpoint was trained with.
train_dir = os.path.join("/content/processed_faces", "train")
train_dataset = ImageFolder(train_dir, transform=train_tf)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Loader used for the Hessian-diagonal calibration in the training cell, which
# references `calib_loader` but no visible cell defined it. The batch is kept
# small because the estimator back-propagates twice through the model.
calib_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# 2. Setup Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import os
import glob
from collections import OrderedDict
from tqdm.auto import tqdm
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image

# -----------------------------
# Feature extraction
# -----------------------------
def extract_features_from_disk(split_name, model, dataset_root="/content/processed_faces", device='cpu'):
    """Compute one averaged feature vector per video for a dataset split.

    Every ``<dataset_root>/<split_name>/<folder>/*.jpg`` frame is resized to
    224x224, normalised and passed through ``model`` (switched to eval mode).
    When the model returns a tuple/list with at least two items, the second
    item is used as the feature; otherwise the (first) output is used. Frames
    are grouped by the file name up to its last underscore, and the features
    of each group are averaged.

    Returns:
        ``(X, y)`` numpy arrays: per-video mean features and labels, where a
        frame folder named ``'real'`` maps to 0 and anything else to 1.
    """
    model.eval()

    split_dir = os.path.join(dataset_root, split_name)
    frame_paths = glob.glob(os.path.join(split_dir, '*/*.jpg'))
    print(f"Extracting features from {len(frame_paths)} frames in '{split_name}'...")

    preprocess = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
    ])

    feats_by_video, label_by_video = {}, {}
    with torch.no_grad():
        for frame_path in tqdm(frame_paths):
            folder, frame_file = os.path.split(frame_path)
            video_id = frame_file.rsplit('_', 1)[0]
            frame_label = 1 if os.path.basename(folder) != 'real' else 0

            frame = Image.open(frame_path).convert('RGB')
            batch = preprocess(frame).unsqueeze(0).to(device)
            outputs = model(batch)

            # Pick the feature tensor out of whatever the model returned.
            if not isinstance(outputs, (tuple, list)):
                feature = outputs
            elif len(outputs) >= 2:
                feature = outputs[1]
            else:
                feature = outputs[0]
            flat = feature.cpu().numpy().reshape(-1)

            # The first frame seen for a video fixes its label.
            label_by_video.setdefault(video_id, frame_label)
            feats_by_video.setdefault(video_id, []).append(flat)

    X = [np.mean(frames, axis=0) for frames in feats_by_video.values()]
    y = [label_by_video[video_id] for video_id in feats_by_video]
    return np.array(X), np.array(y)


# -----------------------------
# Parameter helpers
# -----------------------------
def named_weight_params(model):
    """Yield ``(name, parameter)`` for every trainable non-bias parameter.

    A parameter is skipped when its name contains ``'bias'``, when it does not
    require gradients, or when it is zero-dimensional.
    """
    for param_name, param in model.named_parameters():
        if 'bias' in param_name:
            continue
        if not param.requires_grad or param.ndim < 1:
            continue
        yield param_name, param

def is_weight_param(name, param):
    """Return True for a trainable parameter of rank >= 1 whose name does not contain 'bias'."""
    if 'bias' in name:
        return False
    return param.requires_grad and param.ndim >= 1

def quantizable_param_names(model, exclude_prefixes=('dense_128','classifier')):
    """List the names of weight parameters that are eligible for quantization.

    A parameter is kept when ``is_weight_param`` accepts it and its name does
    not start with any entry of ``exclude_prefixes``. Order follows
    ``model.named_parameters``.
    """
    def _excluded(param_name):
        return any(param_name.startswith(prefix) for prefix in exclude_prefixes)

    return [
        param_name
        for param_name, param in model.named_parameters(recurse=True)
        if is_weight_param(param_name, param) and not _excluded(param_name)
    ]


# -----------------------------
# Hutchinson Hessian diagonal estimator
# -----------------------------
def hutchinson_diag(model, loss_fn, dataloader, device='cuda', num_vectors=8, max_batches=2):
    """Estimate the diagonal of the loss Hessian with Hutchinson's method.

    For each processed batch the loss gradient is built with
    ``create_graph=True`` and ``num_vectors`` Rademacher (+1/-1) vectors ``v``
    are drawn; ``v * (H v)`` is accumulated per parameter and finally averaged
    over all draws.

    Args:
        model: network to analyse; it is moved to ``device`` and left in eval mode.
        loss_fn: callable ``loss_fn(logits, targets)``.
        dataloader: yields ``(inputs, targets)`` batches.
        device: device used for the computation.
        num_vectors: random probe vectors per batch.
        max_batches: number of batches to use; ``None`` uses the whole loader.

    Returns:
        dict mapping each name from ``named_weight_params(model)`` to a CPU
        tensor with the parameter's shape.

    Raises:
        RuntimeError: if the model has no eligible parameters or no probe was drawn.
    """
    model.to(device); model.eval()
    param_items = list(named_weight_params(model))
    if not param_items:
        raise RuntimeError("No quantizable parameters found for Hutchinson estimator.")
    param_list = [p for _, p in param_items]
    param_names = [n for n, _ in param_items]
    diag_accum = [torch.zeros_like(p, device=device) for p in param_list]
    total_draws = 0

    for batch_idx, batch in enumerate(dataloader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        model.zero_grad()
        inputs, targets = batch
        inputs = inputs.to(device); targets = targets.to(device)
        outputs = model(inputs)
        # Models in this notebook return (logits, features); plain tensors are accepted too.
        logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        loss = loss_fn(logits, targets)
        # create_graph=True keeps the graph so the gradient can be differentiated again.
        grads = torch.autograd.grad(loss, param_list, create_graph=True)

        for _ in range(num_vectors):
            # Rademacher probe: entries are +1 or -1.
            vs = [(torch.randint_like(p, low=0, high=2, device=device).float()*2 - 1) for p in param_list]
            inner = sum((g * v).sum() for g, v in zip(grads, vs))
            Hv = torch.autograd.grad(inner, param_list, retain_graph=True)
            for acc, v, hvi in zip(diag_accum, vs, Hv):
                acc += (v * hvi).detach()
            total_draws += 1
            del inner, Hv

        # Drop every reference to this batch's double-backward graph so its
        # memory can be reclaimed before the next batch. (The previous
        # `for g in grads: del g` only unbound the loop variable.)
        del grads, loss, logits, outputs

    if total_draws == 0:
        raise RuntimeError("No batches processed in Hutchinson estimator.")
    return {name: (d / total_draws).cpu() for name, d in zip(param_names, diag_accum)}

def get_diag_estimates(model, loss_fn, calib_loader, device='cuda', hutchinson_vectors=8, hutchinson_batches=2):
    """Announce and run ``hutchinson_diag`` on ``calib_loader``, returning its result unchanged."""
    print("Running Hutchinson Hessian diag estimator.")
    estimator_options = dict(
        device=device,
        num_vectors=hutchinson_vectors,
        max_batches=hutchinson_batches,
    )
    diag_by_param = hutchinson_diag(model, loss_fn, calib_loader, **estimator_options)
    return diag_by_param


# -----------------------------
# Quantize helpers
# -----------------------------
def quantize_tensor_symmetric_perchannel(w: torch.Tensor, nbits:int, eps:float = 1e-8, channel_dim:int = 0):
    """Fake-quantize ``w`` with a symmetric signed grid of ``nbits`` bits.

    Tensors with two or more dimensions get one scale per index along
    ``channel_dim`` (negative values count from the end); tensors with fewer
    dimensions get a single per-tensor scale. Values are rounded to the
    integer grid ``[-qmax, qmax]`` with ``qmax = 2**(nbits-1) - 1`` and mapped
    back to floating point.

    Args:
        w: tensor to quantize (not modified).
        nbits: bit width, at least 2.
        eps: lower bound applied to the scale.
        channel_dim: dimension that keeps its own scale for >= 2-D tensors.

    Returns:
        ``(w_q, scale)``. ``scale`` has ``keepdim`` shape for >= 2-D inputs
        and is 0-dimensional otherwise. If the result contains non-finite
        values, an unquantized copy of ``w`` is returned instead of ``w_q``.

    Raises:
        ValueError: if ``nbits < 2``. One bit gives ``qmax == 0``, which
            previously ended in the non-finite fallback and silently returned
            the unquantized weights.
    """
    if nbits < 2:
        raise ValueError("nbits must be >= 2 for symmetric signed quantization")
    qmax = 2 ** (nbits - 1) - 1

    # In half precision the default eps rounds to zero, so an all-zero channel
    # would yield a zero scale; never go below the dtype's smallest normal.
    min_scale = eps
    if w.is_floating_point():
        min_scale = max(eps, torch.finfo(w.dtype).tiny)

    with torch.no_grad():
        if w.ndim >= 2:
            # Accept negative channel_dim instead of silently reducing over
            # every dimension (which produced a per-tensor scale).
            ch = channel_dim + w.ndim if channel_dim < 0 else channel_dim
            reduce_dims = [d for d in range(w.ndim) if d != ch]
            max_abs = w.abs().amax(dim=reduce_dims, keepdim=True)
        else:
            max_abs = w.abs().amax()

        scale = (max_abs / qmax).clamp(min=min_scale)
        w_q = torch.round(w / scale).clamp(-qmax, qmax) * scale

        # Best-effort fallback: hand back the original values rather than NaN/inf.
        if not torch.isfinite(w_q).all() or not torch.isfinite(scale).all():
            return w.detach().clone(), scale
    return w_q, scale

def estimated_delta_loss_from_quant(w: torch.Tensor, diag_h: torch.Tensor, w_q: torch.Tensor):
    """Second-order estimate of the loss change caused by replacing ``w`` with ``w_q``.

    Computes ``0.5 * sum(h * (w_q - w)**2)`` on the CPU, where ``h`` is
    ``diag_h`` broadcast to the weight shape. If ``diag_h`` cannot be expanded
    to that shape, its mean value is used for every element instead.

    Returns:
        The estimate as a Python float.
    """
    quant_error = (w_q - w).detach().cpu()
    curvature = diag_h.cpu()
    if curvature.shape != quant_error.shape:
        try:
            curvature = curvature.expand_as(quant_error)
        except Exception:
            # Shapes are incompatible: fall back to a uniform curvature.
            curvature = torch.full_like(quant_error, float(curvature.mean().item()))
    weighted_sq_error = (curvature * quant_error.pow(2)).sum().item()
    return 0.5 * weighted_sq_error

def compute_costs_filtered(model, diag_h_dict, include_names=None, base_bits=4, max_bits=8):
    """Estimate the quantization cost of each selected weight at every bit width.

    Args:
        model: network whose ``named_weight_params`` are considered.
        diag_h_dict: parameter name -> Hessian-diagonal tensor; parameters
            missing from it use a constant 1e-6 curvature.
        include_names: optional collection of parameter names to keep;
            ``None`` keeps all of them.
        base_bits, max_bits: inclusive range of bit widths to evaluate.

    Returns:
        ``{name: {bits: cost}}`` where ``cost`` comes from
        ``estimated_delta_loss_from_quant`` applied to the weight quantized
        with ``quantize_tensor_symmetric_perchannel``.
    """
    # Set lookup instead of scanning a list for every parameter.
    wanted = None if include_names is None else set(include_names)

    costs = {}
    for name, p in named_weight_params(model):
        if wanted is not None and name not in wanted:
            continue
        w = p.data
        diag = diag_h_dict.get(name)
        if diag is None:
            # Built only when needed; the eager `.get` default used to copy
            # every weight to the CPU even when an estimate was available.
            diag = torch.full(w.shape, 1e-6, dtype=w.dtype)
        costs[name] = {}
        for b in range(base_bits, max_bits+1):
            w_q, _ = quantize_tensor_symmetric_perchannel(w, b)
            costs[name][b] = estimated_delta_loss_from_quant(w, diag, w_q)
    return costs

def allocate_bits_greedy(costs, base_bits=4, max_bits=8, bit_budget=None):
    """Assign a bit width to every layer in ``costs``.

    Args:
        costs: ``{layer_name: {bits: cost}}`` covering ``base_bits..max_bits``.
        base_bits: starting (minimum) bit width for every layer.
        max_bits: maximum bit width a layer may receive.
        bit_budget: ``None`` picks the cheapest bit width per layer. A value
            ``<= max_bits`` is read as the average bits per layer; a larger
            value is read as the absolute total number of bits. Note the
            budget counts one unit per layer, regardless of layer size.

    Returns:
        ``{layer_name: bits}``. With a budget, every layer starts at
        ``base_bits`` and single bits are handed, one at a time, to the layer
        whose cost drops the most, until the budget is used up or every layer
        is at ``max_bits``.
    """
    layer_names = list(costs.keys())

    if bit_budget is None:
        candidate_bits = range(base_bits, max_bits + 1)
        return {
            layer: min(candidate_bits, key=lambda b, layer=layer: costs[layer][b])
            for layer in layer_names
        }

    assignment = dict.fromkeys(layer_names, base_bits)

    if bit_budget <= max_bits:
        allowed_total = int(bit_budget * len(layer_names))
    else:
        allowed_total = int(bit_budget)

    used_total = sum(assignment.values())
    while used_total < allowed_total:
        winner, winner_gain = None, -float('inf')
        for layer in layer_names:
            bits_now = assignment[layer]
            if bits_now < max_bits:
                improvement = costs[layer][bits_now] - costs[layer][bits_now + 1]
                # Strict comparison: on ties the earlier layer keeps the bit.
                if improvement > winner_gain:
                    winner, winner_gain = layer, improvement
        if winner is None:
            break
        assignment[winner] += 1
        used_total += 1

    return assignment


# -----------------------------
# Wrap model with FakeQuantWrapper
# -----------------------------
def wrap_model_with_fakequant(model, bits_assignment, act_bits=8):
    """Replace selected Conv2d/Linear layers of ``model`` with ``FakeQuantWrapper``.

    A layer is wrapped when the name of its ``weight`` parameter is a key of
    ``bits_assignment``; the mapped value becomes the wrapper's weight bit
    width. Other module types listed in ``bits_assignment`` (for example
    normalisation layers) are left untouched. The model is modified in place.

    Args:
        model: network to modify.
        bits_assignment: ``{parameter_name: bits}`` using the names reported
            by ``model.named_parameters()`` before wrapping.
        act_bits: activation bit width passed to every wrapper.

    Returns:
        The same ``model`` object.
    """
    # Snapshot the names first: wrapping renames parameters to
    # "<layer>.module.weight" while we iterate.
    target_names = [
        p_name
        for p_name, _ in model.named_parameters(recurse=True)
        if p_name in bits_assignment and p_name.endswith('weight')
    ]

    wrapped_ids = set()
    for p_name in target_names:
        module_path = p_name.split('.')[:-1]
        if not module_path:
            # Parameter lives directly on `model`; there is no attribute to swap.
            continue

        owner = model
        for part in module_path[:-1]:
            owner = getattr(owner, part)
        attr = module_path[-1]
        orig_mod = getattr(owner, attr)

        if id(orig_mod) in wrapped_ids:
            continue
        if isinstance(orig_mod, (nn.Conv2d, nn.Linear)):
            wbits = bits_assignment.get(p_name, 8)
            setattr(owner, attr, FakeQuantWrapper(orig_mod, weight_bits=wbits, act_bits=act_bits))
            # Record the module that was wrapped. The old code stored the
            # wrapper's id here, which the check above could never match.
            wrapped_ids.add(id(orig_mod))
    return model


# -----------------------------
# QAT training loop
# -----------------------------
def qat_train(
    model,
    train_loader,
    val_loader,
    device,
    epochs: int = 6,
    lr: float = 1e-4,
    max_grad_norm: float = 1.0,
    log_every: int = 50,
    recompute_bn: bool = True,
    recompute_bn_batches: int = 200,
):
    """Fine-tune a fake-quantized model (quantization-aware training).

    Uses AdamW with cross-entropy loss, gradient-norm clipping and, on CUDA
    devices, automatic mixed precision with a gradient scaler. Batches whose
    forward pass raises, whose loss is non-finite, or whose gradient norm is
    non-finite are skipped. After each optimizer step every sub-module that
    exposes ``update_cached_weight()`` is asked to refresh its cached
    quantized weight.

    Args:
        model: network to train; may return ``logits`` or a tuple/list whose
            first item is the logits.
        train_loader: yields ``(inputs, targets)`` batches.
        val_loader: accepted for interface compatibility; not used here.
        device: device to train on.
        epochs: number of passes over ``train_loader``.
        lr: AdamW learning rate.
        max_grad_norm: clipping threshold for the total gradient norm.
        log_every: update the progress-bar loss every this many batches.
        recompute_bn: if True, run extra no-grad forward passes in train mode
            afterwards so normalisation running statistics are refreshed.
        recompute_bn_batches: number of batches used for that refresh.

    Returns:
        The trained ``model`` (same object). It is left in train mode.
    """
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4, eps=1e-8)
    loss_fn = nn.CrossEntropyLoss()

    # Mixed precision only makes sense (and only works) on CUDA.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # Duck-typed so this function does not depend on a specific wrapper class.
    fake_quant_modules = [
        m for m in model.modules() if callable(getattr(m, 'update_cached_weight', None))
    ]

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        n = 0
        pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}")
        for i, (xb, yb) in pbar:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            try:
                with torch.cuda.amp.autocast(enabled=use_amp):
                    out = model(xb)
                    logits = out[0] if isinstance(out, (tuple, list)) else out
                    loss = loss_fn(logits, yb)
            except Exception as e:
                # Deliberate best-effort: report and move on to the next batch.
                print(f"Forward error (epoch {epoch+1} batch {i}): {e}. Skipping batch.")
                continue

            if not torch.isfinite(loss):
                print(f"Non-finite loss at epoch {epoch+1} batch {i}; skipping batch.")
                continue

            scaler.scale(loss).backward()

            # Unscale BEFORE inspecting/clipping: scaled fp16 gradients overflow
            # routinely, and the scaler can only lower its scale if it is told
            # about the overflow through update().
            scaler.unscale_(opt)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            if not torch.isfinite(grad_norm):
                print(f"Non-finite gradient at epoch {epoch+1} batch {i}. Skipping batch.")
                opt.zero_grad()
                scaler.update()  # lets the loss scale back off; no-op without AMP
                continue

            scaler.step(opt)
            scaler.update()

            # Weights just changed: refresh the quantized copies used in forward.
            for fq_module in fake_quant_modules:
                fq_module.update_cached_weight()

            running_loss += float(loss.item())
            n += 1
            if (i + 1) % log_every == 0:
                pbar.set_postfix({'avg_loss': f"{(running_loss / max(1, n)):.4f}"})

        if n > 0:
            print(f"Epoch {epoch+1} avg loss: {(running_loss / n):.4f}")
        else:
            print(f"Epoch {epoch+1} had no successful batches.")

    # recompute BN running stats
    if recompute_bn:
        def _recompute_bn_stats(model_, loader_, device_, num_batches_=recompute_bn_batches):
            """Forward `num_batches_` batches in train mode without gradients."""
            model_.train()
            with torch.no_grad():
                it = iter(loader_)
                for _i in range(num_batches_):
                    try:
                        xb_, _ = next(it)
                    except StopIteration:
                        # Loader exhausted: start another pass.
                        it = iter(loader_)
                        xb_, _ = next(it)
                    xb_ = xb_.to(device_)
                    _ = model_(xb_)
        try:
            print("Recomputing BN running statistics...")
            _recompute_bn_stats(model, train_loader, device, recompute_bn_batches)
        except Exception as e:
            print("Warning: recompute BN failed:", e)

    return model


# -----------------------------
# Dequantize loader
# -----------------------------
def dequantize_state_dict_from_disk(quant_path: str, device='cpu'):
    """Load a mixed-bit checkpoint and rebuild a float32 state dict from it.

    The file is expected to hold a dict with optional ``'fp32_fallback'``
    (name -> tensor) and ``'quant'`` (name -> ``{'q': int tensor,
    'scale': float tensor, ...}``) sections. Fallback tensors are cast to
    float32; quantized tensors are reconstructed as ``q * scale``.

    Returns:
        ``OrderedDict`` of float32 tensors on ``device``: fallback entries
        first, then the dequantized entries.
    """
    checkpoint = torch.load(quant_path, map_location='cpu')

    restored = OrderedDict()
    restored.update(
        (key, tensor.to(torch.float32).clone().to(device))
        for key, tensor in checkpoint.get('fp32_fallback', {}).items()
    )
    for key, entry in checkpoint.get('quant', {}).items():
        # entry['q'] holds the integer codes, entry['scale'] the float scale(s).
        dequantized = entry['q'].to(torch.float32) * entry['scale']
        restored[key] = dequantized.to(torch.float32).clone().to(device)
    return restored


# -----------------------------
#Mixed-bit integer saver
# -----------------------------
def save_quantized_checkpoint_mixedbits(
    fp32_state_dict: dict,
    bits_assignment: dict,
    out_path: str,
    exclude_prefixes: tuple = ('dense_128', 'classifier'),
    eps: float = 1e-8
):
    """Quantize a float state dict with per-tensor bit widths and save it.

    Tensors are stored as float32 ("fallback") when their name ends with
    ``"bias"``, starts with one of ``exclude_prefixes``, or has no entry in
    ``bits_assignment``. Every other tensor is quantized symmetrically (one
    scale per index of dim 0 for >= 2-D tensors, one scale otherwise) and
    stored as integer codes plus float scale(s). Codes are NOT bit-packed:
    up to 8 bits use int8, up to 16 bits int16, anything wider int32.

    The saved file is a dict with keys ``'quant'``, ``'fp32_fallback'`` and
    ``'meta'`` — the layout read by ``dequantize_state_dict_from_disk``.

    Args:
        fp32_state_dict: name -> tensor.
        bits_assignment: name -> bit width (>= 2) for tensors to quantize.
        out_path: destination file.
        exclude_prefixes: name prefixes that are always kept in float32.
        eps: lower bound for the quantization scale.

    Returns:
        ``out_path``.

    Raises:
        ValueError: if an assigned bit width is below 2 (``qmax`` would be 0).
    """
    quant_store = {}
    fp32_fallback = {}
    original_fp32_bytes = 0   # whole model if everything were float32
    fallback_bytes = 0        # tensors kept in float32
    quant_bytes = 0           # integer codes + scales

    for name, tensor in fp32_state_dict.items():
        original_fp32_bytes += tensor.numel() * 4

        bits = bits_assignment.get(name, None)
        keep_fp32 = (
            name.endswith("bias")
            or any(name.startswith(pref) for pref in exclude_prefixes)
            or bits is None
        )
        if keep_fp32:
            fp32_fallback[name] = tensor.clone().to(torch.float32)
            fallback_bytes += tensor.numel() * 4
            continue

        if bits < 2:
            raise ValueError(f"bits for '{name}' must be >= 2, got {bits}")

        t = tensor.clone().to(torch.float32)
        qmax = 2 ** (bits - 1) - 1
        if t.ndim >= 2:
            reduce_dims = [i for i in range(t.ndim) if i != 0]
            max_abs = t.abs().amax(dim=reduce_dims, keepdim=True)
        else:
            max_abs = t.abs().max()
        scale = (max_abs / qmax).clamp(min=eps)
        q = torch.round(t / scale).clamp(-qmax, qmax)

        # Choose the narrowest integer container that can hold [-qmax, qmax].
        if bits <= 8:
            q_int = q.to(torch.int8)
            dtype_str = "int8"
            byte_size = 1
        elif bits <= 16:
            q_int = q.to(torch.int16)
            dtype_str = "int16"
            byte_size = 2
        else:
            # Wider than 16 bits would overflow int16. float32 arithmetic is
            # only exact up to about 24 bits, so very wide settings are approximate.
            q_int = q.to(torch.int32)
            dtype_str = "int32"
            byte_size = 4

        quant_store[name] = {'q': q_int.cpu(), 'scale': scale.cpu(), 'bits': int(bits),
                             'dtype': dtype_str, 'shape': tuple(t.shape)}
        quant_bytes += q_int.numel() * byte_size
        quant_bytes += scale.numel() * 4

    # The ratio used to divide the float32 *fallback* size by the quantized
    # payload alone, which compares two disjoint sets of tensors. Compare the
    # whole model before and after instead.
    total_after_bytes = quant_bytes + fallback_bytes
    meta = {'num_quantized': len(quant_store), 'num_fallback_fp32': len(fp32_fallback),
            'total_fp32_bytes_preexisting': original_fp32_bytes,
            'total_quant_bytes_after': total_after_bytes,
            'quantized_payload_bytes': quant_bytes,
            'fallback_fp32_bytes': fallback_bytes}

    out = {'quant': quant_store, 'fp32_fallback': fp32_fallback, 'meta': meta}
    torch.save(out, out_path)
    approx_ratio = (original_fp32_bytes / total_after_bytes) if total_after_bytes > 0 else float('inf')
    print(f"Saved mixed-bit quantized checkpoint to {out_path}.")
    print("Meta:", meta)
    print(f"Compression ratio (FP32 bytes / quant bytes): {approx_ratio:.2f}x")
    return out_path


## Start Training

In [None]:
# This cell used `joblib` and `confusion_matrix` without importing them, which
# raised NameError as soon as the first classifier had been fitted.
import joblib
from sklearn.metrics import confusion_matrix

# ----- Config -----
CHECKPOINT_PATH = os.path.join(DATA_ROOT, 'trained_cnn', 'efficientnet_deepfake_v1_30epoch.pth')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
exclude_prefixes = ('dense_128','classifier')   # layers that stay in full precision
hutch_vectors = 16        # Hutchinson probe vectors per batch
hutch_batches = 2         # calibration batches used for the Hessian estimate
base_bits = 4
max_bits = 16
bit_budget = 8            # avg bits (or absolute if > max_bits)
activation_bits = 8
qat_epochs = 5
qat_lr = 2e-5
save_ensemble = False     # gates only the metadata file/reports below; classifiers are always saved
DATASET_ROOT = "/content/processed_faces"
OUT_DIR = os.path.join(DATA_ROOT, 'trained_cnn')
OUT_CLF_DIR = os.path.join(OUT_DIR, 'ensemble_classifiers')
os.makedirs(OUT_CLF_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)

# ----- Build / Load model -----
model = DeepFakeHybridModel().to(device)
if os.path.exists(CHECKPOINT_PATH):
    print(f"Loading pretrained checkpoint from {CHECKPOINT_PATH}...")
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
    print("Checkpoint loaded.")
else:
    print(f"Warning: CHECKPOINT_PATH not found at {CHECKPOINT_PATH}; starting from random init.")

# -----Hessian diag -----
# NOTE(review): `calib_loader` must be defined by an earlier cell.
diag_h = get_diag_estimates(model, nn.CrossEntropyLoss(), calib_loader,
                            device=device, hutchinson_vectors=hutch_vectors, hutchinson_batches=hutch_batches)

# ---- choose quantizable params -----
include_names = quantizable_param_names(model, exclude_prefixes=exclude_prefixes)
print("Sample of params to quantize:", include_names[:12])

# -----  compute costs & allocate bits -----
costs = compute_costs_filtered(model, diag_h, include_names=include_names, base_bits=base_bits, max_bits=max_bits)
bits_assignment_partial = allocate_bits_greedy(costs, base_bits=base_bits, max_bits=max_bits, bit_budget=bit_budget)

print("Assigned bits:")
for i,(k,v) in enumerate(bits_assignment_partial.items()):
    if i >= 25: break
    print(f"  {k} -> {v}")

# ----  Wrap model & run QAT -----
# `qat_model` is the same object as `model`; wrapping modifies it in place.
qat_model = model.train()
qat_model = wrap_model_with_fakequant(qat_model, bits_assignment_partial, act_bits=activation_bits)
qat_model.to(device)

# Run simplified QAT
qat_model = qat_train(qat_model, train_loader, calib_loader, device, epochs=qat_epochs, lr=qat_lr)

# -----  Save FP32 QAT state_dict-----
FP32_QAT_PATH = os.path.join(OUT_DIR, f'qat_model_fp32_bitsbudget_{bit_budget}.pth')
torch.save(qat_model.state_dict(), FP32_QAT_PATH)
print("Saved FP32 QAT model state_dict to:", FP32_QAT_PATH)

# -----  Train ensemble classifiers and validate on extracted features -----
print("\nExtracting features for classifier training")
X_train, y_train = extract_features_from_disk('train', qat_model, dataset_root=DATASET_ROOT, device=device)
X_val, y_val = extract_features_from_disk('val', qat_model, dataset_root=DATASET_ROOT, device=device)

print("Feature shapes:", X_train.shape, X_val.shape)

clfs = {
    'SVM': SVC(probability=True, kernel='rbf', random_state=42),
    'DT': DecisionTreeClassifier(random_state=42),
    'KNN': KNeighborsClassifier(n_neighbors=5),
    'NB': GaussianNB()
}

preds_stack = []
results = {}
print("\nTraining classifiers...")
for name, clf in clfs.items():
    clf.fit(X_train, y_train)
    preds = clf.predict(X_val)
    acc = accuracy_score(y_val, preds)
    results[name] = {
        'acc': acc,
        'report': classification_report(y_val, preds, digits=4, output_dict=True),
        'confusion': confusion_matrix(y_val, preds).tolist()
    }
    preds_stack.append(preds)
    # save classifier
    joblib.dump(clf, os.path.join(OUT_CLF_DIR, f"clf_{name}.joblib"))
    print(f"  {name} val acc: {acc*100:.2f}% (saved)")

# Ensemble vote: predict class 1 when at least 2 of the 4 classifiers do.
# With four voters a 2-2 tie therefore resolves to class 1.
preds_stack = np.array(preds_stack)  # shape (n_classifiers, n_samples)
votes = preds_stack.sum(axis=0)
ensemble_pred = (votes >= 2).astype(int)
ensemble_acc = accuracy_score(y_val, ensemble_pred)
print(f"\nEnsemble (majority >=2) accuracy: {ensemble_acc*100:.2f}%")

if save_ensemble:
    meta = {
        'classifiers': list(clfs.keys()),
        'results': results,
        'ensemble': {'acc': ensemble_acc}
    }
    joblib.dump(meta, os.path.join(OUT_CLF_DIR, 'clf_meta.joblib'))
    print("Saved classifier metadata to", os.path.join(OUT_CLF_DIR, 'clf_meta.joblib'))

    # Use the fitted classifiers that are already in memory instead of
    # un-pickling each one from disk just to print its report.
    for name, clf in clfs.items():
        print(f"\nClassification report for {name}:")
        print(classification_report(y_val, clf.predict(X_val), digits=4))





Loading weights from /content/drive/My Drive/Deep Fake Dataset/trained_cnn/efficientnet_deepfake_v1_30epoch.pth...
Model loaded successfully.
Running Hutchinson Hessian diag estimator (may be slow)...
Will quantize these params (sample): ['backbone.features.0.0.weight', 'backbone.features.0.1.weight', 'backbone.features.1.0.block.0.0.weight', 'backbone.features.1.0.block.0.1.weight', 'backbone.features.1.0.block.1.fc1.weight', 'backbone.features.1.0.block.1.fc2.weight', 'backbone.features.1.0.block.2.0.weight', 'backbone.features.1.0.block.2.1.weight', 'backbone.features.1.1.block.0.0.weight', 'backbone.features.1.1.block.0.1.weight']
Assigned bits (sample):
  backbone.features.0.0.weight -> 4
  backbone.features.0.1.weight -> 4
  backbone.features.1.0.block.0.0.weight -> 4
  backbone.features.1.0.block.0.1.weight -> 8
  backbone.features.1.0.block.1.fc1.weight -> 16
  backbone.features.1.0.block.1.fc2.weight -> 4
  backbone.features.1.0.block.2.0.weight -> 4
  backbone.features.1.0.bl

Epoch 1/5:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 1 training loss 5.1307


Epoch 2/5:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 2 training loss 0.7005


Epoch 3/5:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 3 training loss 0.5687


Epoch 4/5:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 4 training loss 0.5464


Epoch 5/5:   0%|          | 0/625 [00:00<?, ?it/s]

Epoch 5 training loss 0.5372
Recomputing BN running statistics...
Extracting features from 16000 frames in train set...


  0%|          | 0/16000 [00:00<?, ?it/s]

Extracting features from 4000 frames in val set...


  0%|          | 0/4000 [00:00<?, ?it/s]


--- Ensemble Classifier Performance ---
Ensemble accuracy: 92.50%
{8: 92.5}
Loading FP32 checkpoint: /content/qat_model_bits_8.pth
Saved quantized checkpoint to /content/drive/My Drive/Deep Fake Dataset/trained_cnn/quantized8_model_10epoch.pth
Meta: {'num_quantized': 78, 'num_fallback_fp32': 762, 'total_fp32_bytes_preexisting': 92129408, 'total_quant_bytes_after': 52296}
Approx compression ratio (FP32 bytes / quant bytes): 1761.69x (higher is better)


'/content/drive/My Drive/Deep Fake Dataset/trained_cnn/quantized8_model_10epoch.pth'

## Save in Mixed Precision

In [None]:
# Output folder on Drive (same value the training cell assigns to OUT_DIR).
OUT_DIR = os.path.join(DATA_ROOT, 'trained_cnn')

# FP32 QAT checkpoint path, rebuilt from `bit_budget` exactly as in the training
# cell; `bit_budget` must therefore still be defined in the kernel.
FP32_QAT_PATH = os.path.join(OUT_DIR, f'qat_model_fp32_bitsbudget_{bit_budget}.pth')
# Destination of the mixed-bit file: same file name with '.pth' replaced by '_mixedbits.pth'.
quant_out_path = os.path.join(OUT_DIR, os.path.basename(FP32_QAT_PATH).replace('.pth','_mixedbits.pth'))

def canonicalize_state_dict(sd):
    """Map a state dict saved from a fake-quant-wrapped model onto unwrapped key names.

    Wrapping a layer renames its parameters to ``<layer>.module.weight`` /
    ``<layer>.module.bias`` and adds ``<layer>._w_q_cached`` and
    ``<layer>._scale_cached`` buffers. This function:

    * renames ``.module.weight`` -> ``.weight`` and ``.module.bias`` -> ``.bias``
      (bias keys used to be left under their wrapped name, so they did not
      match the unwrapped model);
    * uses ``._w_q_cached`` as ``.weight`` only when no real weight is present;
    * drops ``._scale_cached`` buffers, which have no counterpart in the
      unwrapped model;
    * leaves every other key unchanged.

    When several source keys map to the same canonical key, the one with the
    highest priority wins (wrapped parameter > plain ``.weight`` > cached
    quantized weight > anything else).

    Returns:
        dict of canonical key -> float32 copy of the chosen tensor.
    """
    # (suffix in the wrapped state dict, canonical suffix, priority)
    rename_rules = (
        (".module.weight", ".weight", 3),
        (".module.bias", ".bias", 3),
        ("._w_q_cached", ".weight", 1),
    )

    best = {}  # canonical key -> (tensor, priority)
    for key, tensor in sd.items():
        if key.endswith("._scale_cached"):
            continue

        canon_key = key
        priority = 2 if key.endswith(".weight") else 0
        for suffix, canon_suffix, rule_priority in rename_rules:
            if key.endswith(suffix):
                canon_key = key[:-len(suffix)] + canon_suffix
                priority = rule_priority
                break

        previous = best.get(canon_key)
        if previous is None or priority > previous[1]:
            best[canon_key] = (tensor.clone().to(torch.float32), priority)

    return {canon_key: entry[0] for canon_key, entry in best.items()}



# Convert the FP32 QAT checkpoint into the mixed-bit integer format.
# Relies on `bits_assignment_partial` from the training cell.
if os.path.exists(FP32_QAT_PATH):
    raw = torch.load(FP32_QAT_PATH, map_location='cpu')
    # Accept either a bare state dict or a {'state_dict': ...} container.
    sd = raw['state_dict'] if (isinstance(raw, dict) and 'state_dict' in raw) else raw
    canon_sd = canonicalize_state_dict(sd)
    save_quantized_checkpoint_mixedbits(canon_sd, bits_assignment_partial, quant_out_path,
                                        exclude_prefixes=('dense_128','classifier'), eps=1e-8)
    # This print used to be indented deeper than the rest of the block, which
    # made the whole cell fail with an IndentationError.
    print("Saved mixed-bit quant file to:", quant_out_path)
else:
    print(f"FP32 QAT checkpoint not found at {FP32_QAT_PATH}; nothing to quantize.")