# F. Quantization for ResNet using the LAV-DF Dataset

## Helper Functions and Classes

In [None]:
from google.colab import drive
import os
import glob
import random
import numpy as np

# Mount Google Drive so DATA_ROOT below is reachable from the Colab VM.
drive.mount('/content/drive')
# Root of the project's data on Drive (the split JSON, trained checkpoints and
# CSV/quantized outputs written by later cells live under it).
DATA_ROOT = '/content/drive/My Drive/Deep Fake Dataset'
# Folder of processed face crops; NOTE(review): not referenced by the visible
# cells (they read frames from /content/all_unzipped) — confirm it is still needed.
PROCESSED_ROOT = os.path.join(DATA_ROOT, 'processed_faces')

Mounted at /content/drive


In [None]:
# Copy the zipped frame archive from Drive to the VM's local disk
# (presumably for faster reads than going through the Drive mount).
!cp -r "/content/drive/My Drive/Deep Fake Dataset/all.zip" "/content/"

In [None]:
!unzip /content/all.zip -d /content/all_unzipped

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os
import glob
import json
from PIL import Image

# 1. Model Definition (ResNet-18 Version)
class DeepFakeHybridModel(nn.Module):
    """ResNet-18 backbone + 512-d dense head + 2-way classifier.

    ``forward`` returns ``(logits, feat_512)``: the logits are used for
    training, the 512-d features feed the downstream classifier ensemble.

    Args:
        pretrained: passed to ``torchvision.models.resnet18``; when True the
            backbone starts from torchvision's pretrained weights (this
            downloads them on first use). Defaults to True. Pass False when a
            full checkpoint is loaded right after construction, to skip the
            unnecessary download.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()

        # Base: ResNet-18
        self.backbone = models.resnet18(pretrained=pretrained)
        num_features = self.backbone.fc.in_features  # 512 for ResNet-18

        # Remove the original ImageNet classifier; the backbone now emits features.
        self.backbone.fc = nn.Identity()

        # Custom head (keeps the 512-dim latent used by the ensemble)
        self.dense_512 = nn.Sequential(
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

        # Final 2-class layer (raw logits; softmax is applied by the loss)
        self.classifier = nn.Linear(512, 2)

    def forward(self, x):
        x = self.backbone(x)              # (N, 512)
        feat_512 = self.dense_512(x)      # (N, 512) features for the ensemble
        out = self.classifier(feat_512)   # (N, 2) logits for training
        return out, feat_512

class FakeQuantWrapper(nn.Module):
    """Fake-quantization wrapper for a Conv2d / Linear layer (QAT).

    In the forward pass the wrapped layer uses symmetric per-output-channel
    quantized weights (``weight_bits``), and its output is symmetrically
    fake-quantized per tensor (``act_bits``; skipped when ``act_bits >= 32``).
    Gradients use the straight-through estimator. They reach the
    full-precision ``module.weight`` and the un-quantized output as if
    rounding were the identity. Other module types are called unchanged;
    only their output is fake-quantized.

    The quantized weight is cached in the ``_w_q_cached`` / ``_scale_cached``
    buffers. In training mode it is recomputed on every forward pass, because
    the optimizer changes ``module.weight`` after each step. In eval mode the
    cache is reused once it has been filled.
    """

    def __init__(self, module: nn.Module, weight_bits: int = 8, act_bits: int = 8, eps: float = 1e-8):
        super().__init__()
        self.module = module
        self.weight_bits = int(weight_bits)
        self.act_bits = int(act_bits)
        self.eps = float(eps)
        self.register_buffer('_w_q_cached', torch.empty(0))
        self.register_buffer('_scale_cached', torch.empty(0))
        self._cache_initialized = False

    def update_cached_weight(self, device=None):
        """Re-quantize the current ``module.weight`` and store it in the cache buffers."""
        w = self.module.weight.detach()
        if device is not None:
            w = w.to(device)
        w_q, scale = quantize_tensor_symmetric_perchannel(w, self.weight_bits, eps=self.eps, channel_dim=0)
        self._w_q_cached = w_q.detach().to(self.module.weight.device)
        self._scale_cached = scale.detach().to(self.module.weight.device)
        self._cache_initialized = True

    def _current_quantized_weight(self):
        """Return the quantized weight, refreshing the cache when it may be stale."""
        # The weight changes after every optimizer step, so a cache filled on the
        # first forward would freeze the forward pass at the initial weights.
        if self.training or not self._cache_initialized or self._w_q_cached.numel() == 0:
            self.update_cached_weight()
        return self._w_q_cached

    def _fake_quant_output(self, out):
        """Per-tensor symmetric fake-quantization of ``out`` with an STE gradient."""
        qmax = 2 ** (self.act_bits - 1) - 1
        # Work in float32. Under fp16 autocast an eps of 1e-8 underflows to 0,
        # so an all-zero output would give scale 0 and 0/0 = NaN.
        out_f = out.float()
        scale = (out_f.abs().amax() / qmax).clamp(min=self.eps)
        out_q = (torch.round(out_f / scale).clamp(-qmax, qmax) * scale).to(out.dtype)
        # Forward value is out_q; the gradient w.r.t. out is the identity.
        return (out_q - out).detach() + out

    def forward(self, x):
        if isinstance(self.module, (nn.Conv2d, nn.Linear)):
            weight = self.module.weight
            w_q = self._current_quantized_weight()
            # STE: forward sees w_q, backward differentiates w.r.t. `weight`.
            w_ste = (w_q - weight).detach() + weight
            if isinstance(self.module, nn.Conv2d):
                out = F.conv2d(x, w_ste, self.module.bias, self.module.stride,
                               self.module.padding, self.module.dilation, self.module.groups)
            else:
                out = F.linear(x, w_ste, self.module.bias)
        else:
            out = self.module(x)

        if self.act_bits < 32:
            out = self._fake_quant_output(out)
        return out

# 2. Data Loading (JSON Split Based)
split_json_path = os.path.join(DATA_ROOT, 'video_splits.json')

# Fail fast: every later cell depends on the fixed train/val split.
if not os.path.exists(split_json_path):
    raise FileNotFoundError(f"Could not find {split_json_path}")

with open(split_json_path, 'r') as f:
    splits = json.load(f)

# Each entry is iterated later as a (video_path, label) pair.
train_videos, val_videos = splits['train_videos'], splits['val_videos']
print(f"Data Split Loaded: {len(train_videos)} Train, {len(val_videos)} Val")

# Dataset for Local Frames
class DeepFakeFrameDataset(Dataset):
    """Frame-level dataset: one sample per extracted frame of each listed video.

    Frames are the ``*.jpg`` files found recursively under ``image_dir``. They
    are grouped by the part of the filename before the last underscore, e.g.
    ``014471_frame0.jpg`` -> ``014471``. A filename without an underscore is
    keyed by its full name (extension included), so it never matches a video.
    Videos with no matching frames contribute no samples.

    Args:
        video_list: iterable of ``(video_path, label)`` pairs. The video id is
            the basename of ``video_path`` without its extension
            (``.../train/034843.mp4`` -> ``034843``).
        image_dir: root directory searched for frames.
        transform: optional callable applied to each RGB PIL image.
    """

    def __init__(self, video_list, image_dir, transform=None):
        self.transform = transform

        # Index every frame once by its video id.
        frames_by_video = {}
        for frame_path in glob.glob(os.path.join(image_dir, '**', '*.jpg'), recursive=True):
            key = os.path.basename(frame_path).rsplit('_', 1)[0]
            frames_by_video.setdefault(key, []).append(frame_path)

        # Expand each listed video into (frame_path, label) samples.
        self.samples = [
            (frame_path, label)
            for vid_path, label in video_list
            for frame_path in frames_by_video.get(os.path.splitext(os.path.basename(vid_path))[0], [])
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        frame_path, label = self.samples[idx]
        with Image.open(frame_path) as raw:
            image = raw.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# Setup Loaders
# Resize every frame to 224x224 and normalize with the standard ImageNet
# mean/std expected by the torchvision ResNet backbone.
train_tf = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Frames were unzipped to local disk by the earlier cell.
TRAIN_DIR_LOCAL = '/content/all_unzipped'
train_dataset = DeepFakeFrameDataset(video_list=train_videos, image_dir=TRAIN_DIR_LOCAL, transform=train_tf)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=2)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

Loading fixed splits from /content/drive/My Drive/Deep Fake Dataset/video_splits.json...
Data Split Loaded: 1984 Train, 1984 Val
Indexing frames in /content/all_unzipped...
Matching videos to frames...
Dataset ready: 18112 frames.


In [None]:
import glob
import numpy as np
import os
from PIL import Image
from tqdm.auto import tqdm
import torch
import torchvision.transforms as T
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, classification_report

# Hessian calibration reuses the training loader; the loss is plain
# cross-entropy over the model's 2-class logits.
loss_fn, calib_loader = nn.CrossEntropyLoss(), train_loader

def extract_features_from_disk(split_name, model, dataset_root=None, device=None):
    """Compute one 512-d feature vector per video for the train or val split.

    Every ``*.jpg`` frame found recursively under ``dataset_root`` is matched
    to a video by the part of its filename before the last underscore
    (``014471_frame0.jpg`` -> ``014471``). Each frame goes through ``model``
    on its own and the model's second output (the features) is kept. A
    video's feature vector is the mean over its frames.

    Frames are fed one at a time on purpose. When the model contains
    FakeQuantWrapper layers, their activation scale is the max over the whole
    input tensor, so batching frames would change the features.

    Args:
        split_name: ``'train'`` or ``'val'``. Selects the module-level
            ``train_videos`` / ``val_videos`` list of ``(video_path, label)``.
        model: network whose forward returns ``(logits, features)``. It is
            switched to eval mode.
        dataset_root: directory holding the frames
            (default ``'/content/all_unzipped'``).
        device: inference device. Defaults to the device of the model's
            parameters, so inputs always land where the model is.

    Returns:
        ``(X, y)`` numpy arrays with one row per video that had at least one
        usable frame. An unknown ``split_name`` prints a message and returns
        two empty arrays.
    """
    model.eval()

    # Select the correct video list based on the requested split
    if split_name == 'train':
        target_list = train_videos
    elif split_name == 'val':
        target_list = val_videos
    else:
        print(f"Unknown split {split_name}")
        return np.array([]), np.array([])

    local_img_dir = '/content/all_unzipped' if dataset_root is None else dataset_root
    if device is None:
        device = next(model.parameters()).device
    print(f"Extracting features for {split_name} ({len(target_list)} videos)...")

    # Index all frames once by video id instead of globbing per video.
    frame_lookup = {}
    for f in glob.glob(os.path.join(local_img_dir, '**', '*.jpg'), recursive=True):
        vid_id = os.path.basename(f).rsplit('_', 1)[0]
        frame_lookup.setdefault(vid_id, []).append(f)

    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    final_X, final_y = [], []
    skipped = 0

    with torch.no_grad():
        for vid_path, label in tqdm(target_list):
            vid_id = os.path.splitext(os.path.basename(vid_path))[0]
            frame_paths = frame_lookup.get(vid_id)
            if not frame_paths:
                continue

            batch_feats = []
            for img_path in frame_paths:
                try:
                    img = Image.open(img_path).convert('RGB')
                    img_t = transform(img).unsqueeze(0).to(device)
                    _, feat_512 = model(img_t)  # second output is the feature vector
                    batch_feats.append(feat_512.cpu().numpy().flatten())
                except Exception:
                    # Best effort: skip unreadable/broken frames, but keep count
                    # (a bare `except:` would also swallow KeyboardInterrupt).
                    skipped += 1

            if batch_feats:
                final_X.append(np.mean(batch_feats, axis=0))
                final_y.append(label)

    if skipped:
        print(f"Skipped {skipped} frame(s) in {split_name} that could not be loaded or embedded.")
    return np.array(final_X), np.array(final_y)

# -----------------------------
# Parameter helpers
# -----------------------------
def named_weight_params(model):
    """Yield ``(name, param)`` for every trainable, non-bias weight of ``model``.

    A parameter qualifies when it requires grad, has at least one dimension
    and its name does not contain ``'bias'``. 1-d weights such as BatchNorm
    scales therefore qualify too.
    """
    yield from (
        (pname, tensor)
        for pname, tensor in model.named_parameters()
        if tensor.requires_grad and tensor.ndim >= 1 and 'bias' not in pname
    )

def is_weight_param(name, param):
    """Return True when ``param`` is a trainable weight worth quantizing.

    Names containing ``'bias'`` and 0-d parameters are rejected; otherwise
    the result is ``param.requires_grad``.
    """
    if 'bias' in name or param.ndim < 1:
        return False
    return param.requires_grad

def quantizable_param_names(model, exclude_prefixes=('dense_128','classifier')):
    """Names of the weights that ``is_weight_param`` accepts, minus excluded prefixes.

    Args:
        model: module whose ``named_parameters()`` are scanned.
        exclude_prefixes: parameter names starting with any of these are skipped.
            NOTE(review): DeepFakeHybridModel's head is called ``dense_512``, so
            the default ``'dense_128'`` prefix matches nothing in that model.

    Returns:
        List of parameter names in ``named_parameters()`` order.
    """
    skip = tuple(exclude_prefixes)
    return [
        pname
        for pname, tensor in model.named_parameters(recurse=True)
        if is_weight_param(pname, tensor) and not pname.startswith(skip)
    ]


# -----------------------------
# Hutchinson Hessian diagonal estimator
# -----------------------------
def hutchinson_diag(model, loss_fn, dataloader, device='cuda', num_vectors=8, max_batches=2):
    """Estimate the Hessian diagonal of the loss w.r.t. each weight tensor.

    Hutchinson's estimator: for Rademacher vectors v (entries +/-1),
    E[v * (H v)] = diag(H). Each Hessian-vector product H v is obtained by
    differentiating ``<grad, v>`` a second time (double back-propagation).

    Args:
        model: network; its forward may return the logits or a tuple/list
            whose first element is the logits. It is moved to ``device`` and
            put in eval mode.
        loss_fn: callable ``(logits, targets) -> scalar loss``.
        dataloader: yields ``(inputs, targets)`` batches.
        device: device used for the computation.
        num_vectors: random probe vectors drawn per batch.
        max_batches: number of batches to use (None = the whole loader).

    Returns:
        ``{param_name: diag_estimate}`` (CPU tensors shaped like the params)
        for every parameter yielded by ``named_weight_params``.

    Raises:
        RuntimeError: if there are no such parameters or no probe was drawn.
    """
    model.to(device); model.eval()
    param_items = list(named_weight_params(model))
    if not param_items:
        raise RuntimeError("No quantizable parameters found for Hutchinson estimator.")
    param_list = [p for _, p in param_items]
    param_names = [n for n, _ in param_items]
    diag_accum = [torch.zeros_like(p, device=device) for p in param_list]
    total_draws = 0

    for batch_idx, batch in enumerate(dataloader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        model.zero_grad()
        inputs, targets = batch
        inputs = inputs.to(device); targets = targets.to(device)
        outputs = model(inputs)
        logits = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        loss = loss_fn(logits, targets)
        # create_graph=True keeps the gradient's own graph so it can be differentiated again.
        grads = torch.autograd.grad(loss, param_list, create_graph=True)

        inner = hvps = None
        for _ in range(num_vectors):
            vs = [(torch.randint_like(p, low=0, high=2, device=device).float() * 2 - 1) for p in param_list]
            inner = sum((g * v).sum() for g, v in zip(grads, vs))
            # retain_graph=True: the same gradient graph is reused for every probe.
            hvps = torch.autograd.grad(inner, param_list, retain_graph=True)
            for acc, v, hv in zip(diag_accum, vs, hvps):
                acc += (v * hv).detach()  # in-place add into diag_accum
            total_draws += 1

        # Drop every reference to this batch's (double-backward) graph before the
        # next forward pass, so two graphs are never alive at the same time.
        # (`for g in grads: del g` only unbinds the loop variable.)
        grads = inner = hvps = loss = logits = outputs = None

    if total_draws == 0:
        raise RuntimeError("No batches processed in Hutchinson estimator.")
    return {name: (d / total_draws).cpu() for name, d in zip(param_names, diag_accum)}

def get_diag_estimates(model, loss_fn, calib_loader, device='cuda', hutchinson_vectors=8, hutchinson_batches=2):
    """Announce and run the Hutchinson Hessian-diagonal estimator.

    Thin wrapper around ``hutchinson_diag`` that maps ``hutchinson_vectors`` to
    ``num_vectors`` and ``hutchinson_batches`` to ``max_batches``, and returns
    its ``{param_name: diag}`` result unchanged.
    """
    print("Running Hutchinson Hessian diag estimator.")
    diag_by_name = hutchinson_diag(
        model,
        loss_fn,
        calib_loader,
        device=device,
        num_vectors=hutchinson_vectors,
        max_batches=hutchinson_batches,
    )
    return diag_by_name


# -----------------------------
# Quantize helpers
# -----------------------------
def quantize_tensor_symmetric_perchannel(w: torch.Tensor, nbits:int, eps:float = 1e-8, channel_dim:int = 0):
    """Symmetric uniform fake-quantization of ``w`` to signed ``nbits`` levels.

    Tensors with ``ndim >= 2`` get one scale per slice along ``channel_dim``
    (per output channel of a Conv/Linear weight when ``channel_dim=0``).
    0-d and 1-d tensors share a single scale.

    Args:
        w: tensor to quantize (not modified).
        nbits: bit width. Codes lie in ``[-(2**(nbits-1)-1), 2**(nbits-1)-1]``.
        eps: lower bound on the scale, so all-zero channels do not divide by zero.
        channel_dim: axis that keeps its own scale. Negative values count from
            the end, like other PyTorch dim arguments.

    Returns:
        ``(w_q, scale)``: the de-quantized tensor (same shape as ``w``) and the
        scale(s), broadcastable against ``w``. If quantization produces a
        non-finite value, a copy of ``w`` is returned in place of ``w_q``.

    Raises:
        ValueError: if ``nbits < 2``. With one bit the symmetric grid is only
            {0} (qmax == 0), which would give an infinite scale.
    """
    if nbits < 2:
        raise ValueError("nbits must be >= 2 for symmetric quantization")
    qmax = 2 ** (nbits - 1) - 1
    with torch.no_grad():
        if w.ndim >= 2:
            # Normalize so that e.g. channel_dim=-1 keeps the last axis instead of
            # silently reducing over every axis (per-tensor scaling).
            keep_dim = channel_dim % w.ndim
            reduce_dims = [i for i in range(w.ndim) if i != keep_dim]
            max_abs = w.abs().amax(dim=reduce_dims, keepdim=True)
        else:
            max_abs = w.abs().amax()
        eps_t = torch.tensor(eps, device=w.device, dtype=w.dtype)
        scale = (max_abs / qmax).clamp(min=eps_t)
        w_q = torch.round(w / scale).clamp(-qmax, qmax) * scale

        if not torch.isfinite(w_q).all() or not torch.isfinite(scale).all():
            return w.detach().clone(), scale
    return w_q, scale

def estimated_delta_loss_from_quant(w: torch.Tensor, diag_h: torch.Tensor, w_q: torch.Tensor):
    """Second-order estimate of the loss increase from replacing ``w`` with ``w_q``.

    Computes ``0.5 * sum(h * (w_q - w)**2)`` on CPU, where ``h`` is the
    Hessian-diagonal estimate ``diag_h``. If ``diag_h``'s shape differs from
    ``w``'s, it is broadcast with ``expand_as``. If that fails, its mean is
    used for every element. Returns a Python float.
    """
    err = (w_q - w).detach().cpu()
    curvature = diag_h.cpu()
    if curvature.shape != err.shape:
        try:
            curvature = curvature.expand_as(err)
        except Exception:
            curvature = torch.full_like(err, float(curvature.mean().item()))
    return 0.5 * (curvature * err.pow(2)).sum().item()

def compute_costs_filtered(model, diag_h_dict, include_names=None, base_bits=4, max_bits=8):
    """Estimated loss increase of quantizing each selected weight at each bit width.

    Considers every parameter yielded by ``named_weight_params``, optionally
    restricted to ``include_names``. For each bit width in
    ``[base_bits, max_bits]`` the tensor is quantized with
    ``quantize_tensor_symmetric_perchannel`` and scored with
    ``estimated_delta_loss_from_quant``, using its entry in ``diag_h_dict``.
    Parameters missing from ``diag_h_dict`` get a flat curvature of 1e-6.

    Returns:
        ``{param_name: {bits: estimated_delta_loss}}``.
    """
    costs = {}
    for pname, param in named_weight_params(model):
        selected = include_names is None or pname in include_names
        if not selected:
            continue
        weight = param.data
        curvature = diag_h_dict.get(pname, torch.ones_like(weight.cpu()) * 1e-6)
        per_bits = {}
        for nbits in range(base_bits, max_bits + 1):
            quantized, _scale = quantize_tensor_symmetric_perchannel(weight, nbits)
            per_bits[nbits] = estimated_delta_loss_from_quant(weight, curvature, quantized)
        costs[pname] = per_bits
    return costs

def allocate_bits_greedy(costs, base_bits=4, max_bits=8, bit_budget=None):
    """Assign a bit width in ``[base_bits, max_bits]`` to every tensor in ``costs``.

    Args:
        costs: ``{name: {bits: estimated_loss_increase}}``.
        base_bits, max_bits: allowed bit-width range.
        bit_budget: None picks, per tensor, the cheapest bit width (the first
            one on ties). A value ``<= max_bits`` is read as an average bits
            per tensor, so the total allowed is ``int(bit_budget * n_tensors)``.
            A larger value is read as the total number of bits.

    Returns:
        ``{name: bits}``. With a budget, every tensor starts at ``base_bits``.
        Single bits are then added greedily to the tensor whose next bit gives
        the largest cost reduction (first tensor wins ties), until the budget
        is used up or every tensor reaches ``max_bits``.
    """
    layer_names = list(costs.keys())

    if bit_budget is None:
        candidate_bits = range(base_bits, max_bits + 1)
        return {name: min(candidate_bits, key=costs[name].__getitem__) for name in layer_names}

    bits = dict.fromkeys(layer_names, base_bits)
    if bit_budget <= max_bits:
        allowed_total = int(bit_budget * len(layer_names))
    else:
        allowed_total = int(bit_budget)

    used = sum(bits.values())
    while used < allowed_total:
        chosen, chosen_gain = None, -float('inf')
        for name in layer_names:
            current = bits[name]
            if current < max_bits:
                gain = costs[name][current] - costs[name][current + 1]
                if gain > chosen_gain:
                    chosen, chosen_gain = name, gain
        if chosen is None:
            break
        bits[chosen] += 1
        used += 1

    return bits


# -----------------------------
# Wrap model with FakeQuantWrapper
# -----------------------------
def wrap_model_with_fakequant(model, bits_assignment, act_bits=8):
    """Replace selected Conv2d/Linear submodules of ``model`` with ``FakeQuantWrapper``.

    A submodule is wrapped when the name of its weight (ending in ``'weight'``)
    is a key of ``bits_assignment``. It is then quantized with that many bits.
    Entries whose owning module is not Conv2d/Linear (e.g. BatchNorm weights)
    are ignored, as are parameters registered directly on ``model``.
    Modification happens in place.

    Returns:
        The same ``model`` object.
    """
    # Collect targets first: replacing modules while iterating named_parameters()
    # would mutate the tree being walked.
    targets = []
    for p_name, _ in model.named_parameters(recurse=True):
        if p_name in bits_assignment and p_name.endswith('weight'):
            module_path = p_name.split('.')[:-1]
            if module_path:
                targets.append((p_name, module_path))

    wrapped_ids = set()
    for p_name, module_path in targets:
        parent = model
        for part in module_path[:-1]:
            parent = getattr(parent, part)
        attr = module_path[-1]
        orig_mod = getattr(parent, attr)
        # Track the ORIGINAL module's id (the old code stored the wrapper's id,
        # so this check could never match).
        if id(orig_mod) in wrapped_ids:
            continue
        if isinstance(orig_mod, (nn.Conv2d, nn.Linear)):
            wbits = bits_assignment.get(p_name, 8)
            setattr(parent, attr, FakeQuantWrapper(orig_mod, weight_bits=wbits, act_bits=act_bits))
            wrapped_ids.add(id(orig_mod))
    return model


# -----------------------------
# QAT training loop
# -----------------------------
def qat_train(
    model,
    train_loader,
    val_loader,
    device,
    epochs: int = 6,
    lr: float = 1e-4,
    max_grad_norm: float = 1.0,
    log_every: int = 50,
    recompute_bn: bool = True,
    recompute_bn_batches: int = 200,
):
    """Quantization-aware fine-tuning with AdamW and, on CUDA, mixed precision.

    Args:
        model: (fake-quantized) network. Its forward may return the logits or
            a tuple/list whose first element is the logits.
        train_loader: yields ``(inputs, targets)`` batches.
        val_loader: accepted for API compatibility; currently unused.
        device: training device. AMP (autocast + GradScaler) is enabled only
            when it is a CUDA device.
        epochs: number of passes over ``train_loader``.
        lr: AdamW learning rate (weight_decay=1e-4).
        max_grad_norm: gradient-norm clipping threshold.
        log_every: progress-bar update interval, in batches.
        recompute_bn: after training, run forward passes in train mode without
            gradients to refresh the BatchNorm running statistics.
        recompute_bn_batches: number of batches used for that refresh. The
            loader is restarted if it runs out.

    Batches whose forward pass fails or whose loss or gradients are non-finite
    are skipped and reported.

    Returns:
        The same model object, left in train mode.
    """
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4, eps=1e-8)
    loss_fn = nn.CrossEntropyLoss()

    # AMP only applies on CUDA. When disabled, the scaler and autocast are no-ops.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        n = 0
        pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}")
        for i, (xb, yb) in pbar:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            try:
                with torch.cuda.amp.autocast(enabled=use_amp):
                    out = model(xb)
                    logits = out[0] if isinstance(out, (tuple, list)) else out
                    loss = loss_fn(logits, yb)
            except Exception as e:
                print(f"Forward error (epoch {epoch+1} batch {i}): {e}. Skipping batch.")
                continue

            if not torch.isfinite(loss):
                print(f"Non-finite loss at epoch {epoch+1} batch {i}; skipping batch.")
                continue

            scaler.scale(loss).backward()
            # Unscale first so clipping and the finiteness check see the true gradients.
            scaler.unscale_(opt)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            grads_finite = bool(torch.isfinite(grad_norm))

            if use_amp:
                # GradScaler.step() itself skips optimizer.step() when unscale_()
                # found inf/NaN. update() must run on EVERY iteration so the loss
                # scale can shrink after an overflow. Skipping both (as a manual
                # check before them would) freezes the scale and can stall training.
                scaler.step(opt)
                scaler.update()
            elif grads_finite:
                opt.step()

            if not grads_finite:
                print(f"Non-finite gradient at epoch {epoch+1} batch {i}; optimizer step skipped.")
                opt.zero_grad()
                continue

            running_loss += float(loss.item())
            n += 1
            if (i + 1) % log_every == 0:
                pbar.set_postfix({'avg_loss': f"{(running_loss / max(1, n)):.4f}"})

        if n > 0:
            print(f"Epoch {epoch+1} avg loss: {(running_loss / n):.4f}")
        else:
            print(f"Epoch {epoch+1} had no successful batches.")

    if recompute_bn:
        def _recompute_bn_stats(model_, loader_, device_, num_batches_):
            """Run ``num_batches_`` no-grad forward passes in train mode (updates BN running stats)."""
            model_.train()
            with torch.no_grad():
                it = iter(loader_)
                for _ in range(num_batches_):
                    try:
                        xb_, _ = next(it)
                    except StopIteration:
                        it = iter(loader_)
                        xb_, _ = next(it)
                    model_(xb_.to(device_))

        try:
            print("Recomputing BN running statistics...")
            _recompute_bn_stats(model, train_loader, device, recompute_bn_batches)
        except Exception as e:
            print("Warning: recompute BN failed:", e)

    return model


# -----------------------------
# Dequantize loader
# -----------------------------
def dequantize_state_dict_from_disk(quant_path: str, device='cpu'):
    """Rebuild an FP32 state_dict from a ``save_quantized_checkpoint_mixedbits`` file.

    FP32 fallback tensors are copied as float32. Quantized entries are
    reconstructed as ``q.float() * scale``. Fallback entries come first in the
    returned OrderedDict, followed by the de-quantized ones. All tensors are
    placed on ``device``.

    Security: ``torch.load`` unpickles the file; only load trusted checkpoints.
    """
    # Local import: this cell does not import OrderedDict (only a later notebook
    # cell does), so calling this function earlier would raise NameError.
    from collections import OrderedDict

    data = torch.load(quant_path, map_location='cpu')
    out_state = OrderedDict()
    for name, t in data.get('fp32_fallback', {}).items():
        out_state[name] = t.to(torch.float32).clone().to(device)
    for name, info in data.get('quant', {}).items():
        q = info['q']              # integer codes (int8/int16)
        scale = info['scale']      # float32, broadcastable against q
        w_fp = q.to(torch.float32) * scale
        out_state[name] = w_fp.to(torch.float32).to(device)
    return out_state


# -----------------------------
# Mixed-bit integer saver
# -----------------------------
def save_quantized_checkpoint_mixedbits(
    fp32_state_dict: dict,
    bits_assignment: dict,
    out_path: str,
    exclude_prefixes: tuple = ('dense_128', 'classifier'),
    eps: float = 1e-8
):
    """Quantize selected tensors of a state_dict to signed integers and save them.

    A tensor is quantized when its name is a key of ``bits_assignment``, does
    not end with ``"bias"`` and does not start with any of ``exclude_prefixes``.
    Every other tensor is stored as float32 under ``fp32_fallback``.
    Quantization is symmetric. Tensors with ``ndim >= 2`` get one scale per
    slice along dim 0; other tensors get one scale. Codes are stored as int8
    for ``bits <= 8`` and as int16 otherwise.

    The file written to ``out_path`` is ``{'quant', 'fp32_fallback', 'meta'}``.
    Each ``quant`` entry holds ``q``, ``scale``, ``bits``, ``dtype`` and
    ``shape``.

    Raises:
        ValueError: if an assigned bit width is outside [2, 16]. Below 2 the
            symmetric grid collapses to {0}; above 16 the codes overflow int16.

    Returns:
        ``out_path``.
    """
    quant_store = {}
    fp32_fallback = {}
    total_fp32_bytes = 0       # bytes of tensors stored unquantized (fallback)
    total_quant_bytes = 0      # integer codes + float32 scales
    total_original_bytes = 0   # size of every input tensor at 4 bytes/element

    for name, tensor in fp32_state_dict.items():
        total_original_bytes += tensor.numel() * 4
        bits = bits_assignment.get(name, None)
        keep_fp32 = (
            name.endswith("bias")
            or any(name.startswith(pref) for pref in exclude_prefixes)
            or bits is None
        )
        if keep_fp32:
            fp32_fallback[name] = tensor.clone().to(torch.float32)
            total_fp32_bytes += tensor.numel() * 4
            continue
        if not 2 <= bits <= 16:
            raise ValueError(f"bits for {name!r} must be in [2, 16] (int8/int16 storage), got {bits}")

        t = tensor.clone().to(torch.float32)
        qmax = 2 ** (bits - 1) - 1
        if t.ndim >= 2:
            reduce_dims = [i for i in range(t.ndim) if i != 0]
            max_abs = t.abs().amax(dim=reduce_dims, keepdim=True)
        else:
            max_abs = t.abs().max()
        scale = (max_abs / qmax).clamp(min=eps)
        q = torch.round(t / scale).clamp(-qmax, qmax)

        if bits <= 8:
            q_int, dtype_str, byte_size = q.to(torch.int8), "int8", 1
        else:
            q_int, dtype_str, byte_size = q.to(torch.int16), "int16", 2

        quant_store[name] = {'q': q_int.cpu(), 'scale': scale.cpu(), 'bits': int(bits),
                             'dtype': dtype_str, 'shape': tuple(t.shape)}
        total_quant_bytes += q_int.numel() * byte_size + scale.numel() * 4

    total_stored_bytes = total_quant_bytes + total_fp32_bytes
    meta = {'num_quantized': len(quant_store), 'num_fallback_fp32': len(fp32_fallback),
            'total_fp32_bytes_preexisting': total_fp32_bytes, 'total_quant_bytes_after': total_quant_bytes,
            'total_original_fp32_bytes': total_original_bytes, 'total_stored_bytes': total_stored_bytes}

    out = {'quant': quant_store, 'fp32_fallback': fp32_fallback, 'meta': meta}
    torch.save(out, out_path)
    # Compression = size if everything were FP32 / size actually stored.
    # (Dividing only the fallback FP32 bytes by the quantized bytes is meaningless.)
    approx_ratio = (total_original_bytes / total_stored_bytes) if total_stored_bytes > 0 else float('inf')
    print(f"Saved mixed-bit quantized checkpoint to {out_path} (bits assignment used).")
    print("Meta:", meta)
    print(f"Compression ratio (original FP32 bytes / stored bytes): {approx_ratio:.2f}x")
    return out_path

## Start Training

In [None]:


from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
import numpy as np

CHECKPOINT_PATH = os.path.join(DATA_ROOT, 'trained_cnn', 'resnet18_lav_df.pth')

# Parameters whose names start with these prefixes are kept in FP32.
# NOTE(review): DeepFakeHybridModel's head is named `dense_512`, so 'dense_128'
# matches nothing and the dense_512 Linear gets quantized too. Confirm intent.
exclude_prefixes = ('dense_128', 'classifier')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hutchinson / quant settings
hutch_vectors = 16
hutch_batches = 2
base_bits = 6                     # minimum bits considered
max_bits = 8                      # maximum bits to consider (keep 8)
activation_bits = 8               # activation fake-quant bits
qat_epochs = 5                    # fine-tune epochs (2-8)
qat_lr = 2e-5                     # learning rate for QAT

# Values <= max_bits are average bits per quantized tensor (see allocate_bits_greedy).
bit_budgets = [8]

val_accs = {}
loss_fn = nn.CrossEntropyLoss()

for bit_budget in bit_budgets:
    model = DeepFakeHybridModel().to(device)

    # Load pretrained checkpoint if present
    if os.path.exists(CHECKPOINT_PATH):
        print(f"Loading weights from {CHECKPOINT_PATH}...")
        model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
        print("Model loaded successfully.")
    else:
        print(f"Warning: CHECKPOINT_PATH not found at {CHECKPOINT_PATH}; starting from random init.")

    # 1) Hessian diagonal (Hutchinson)
    diag_h = get_diag_estimates(model, loss_fn, calib_loader,
                                device=device,
                                hutchinson_vectors=hutch_vectors,
                                hutchinson_batches=hutch_batches)

    # 2) build list of param names to quantize
    include_names = quantizable_param_names(model, exclude_prefixes=exclude_prefixes)
    print("Will quantize these params:", include_names[:10])

    # 3) compute costs only for included params
    costs = compute_costs_filtered(model, diag_h, include_names=include_names,
                                   base_bits=base_bits, max_bits=max_bits)

    # 4) allocate bits greedily for these params
    bits_assignment_partial = allocate_bits_greedy(costs, base_bits=base_bits, max_bits=max_bits, bit_budget=bit_budget)

    print("Assigned bits:")
    for i, (k, v) in enumerate(bits_assignment_partial.items()):
        if i >= 25: break
        print(f"  {k} -> {v}")

    # 5) Wrap model (only modules whose weight param names are in bits_assignment_partial)
    qat_model = model.train()
    qat_model = wrap_model_with_fakequant(qat_model, bits_assignment_partial, act_bits=activation_bits)
    qat_model.to(device)

    # 6) Run QAT fine-tuning (returns the model)
    qat_model = qat_train(qat_model, train_loader, calib_loader, device, epochs=qat_epochs, lr=qat_lr)

    # 7) Save FP32 QAT state_dict
    out_path = f"qat_model_bits_{bit_budget}.pth"
    torch.save(qat_model.state_dict(), out_path)
    print(f"Saved QAT FP32 state_dict to {out_path}")

    # 8) Evaluate: extract features and train ensemble classifiers.
    # extract_features_from_disk takes (split_name, model); passing
    # dataset_root=/device= kwargs raised TypeError.
    print("Extracting features for classifier evaluation...")
    X_train, y_train = extract_features_from_disk('train', qat_model)
    X_val, y_val = extract_features_from_disk('val', qat_model)

    clfs = {
        'SVM': SVC(probability=True, kernel='rbf', random_state=42),
        'DT': DecisionTreeClassifier(random_state=42),
        'KNN': KNeighborsClassifier(n_neighbors=5),
        'NB': GaussianNB()
    }

    preds_stack = []
    print("\nTraining classifiers...")
    for name, clf in clfs.items():
        clf.fit(X_train, y_train)
        preds = clf.predict(X_val)
        preds_stack.append(preds)
        acc = accuracy_score(y_val, preds)
        print(f"  {name} val acc: {acc*100:.2f}%")

    # 4 voters: class 1 wins with >= 2 votes, so a 2-2 tie is predicted as 1.
    preds_stack = np.array(preds_stack)
    votes = preds_stack.sum(axis=0)
    ensemble_pred = (votes >= 2).astype(int)
    ensemble_acc = accuracy_score(y_val, ensemble_pred)
    print(f"\nEnsemble (majority >=2) accuracy: {ensemble_acc*100:.2f}%")

    val_accs[bit_budget] = ensemble_acc
    print("val_accs so far:", val_accs)




Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 244MB/s]


Loading weights from /content/drive/My Drive/Deep Fake Dataset/trained_cnn/resnet18_lav_df.pth...
Model loaded successfully.
Running Hutchinson Hessian diag estimator...
Will quantize these params (sample): ['backbone.conv1.weight', 'backbone.bn1.weight', 'backbone.layer1.0.conv1.weight', 'backbone.layer1.0.bn1.weight', 'backbone.layer1.0.conv2.weight', 'backbone.layer1.0.bn2.weight', 'backbone.layer1.1.conv1.weight', 'backbone.layer1.1.bn1.weight', 'backbone.layer1.1.conv2.weight', 'backbone.layer1.1.bn2.weight']
Assigned bits (sample):
  backbone.conv1.weight -> 8
  backbone.bn1.weight -> 8
  backbone.layer1.0.conv1.weight -> 8
  backbone.layer1.0.bn1.weight -> 8
  backbone.layer1.0.conv2.weight -> 8
  backbone.layer1.0.bn2.weight -> 8
  backbone.layer1.1.conv1.weight -> 8
  backbone.layer1.1.bn1.weight -> 8
  backbone.layer1.1.conv2.weight -> 8
  backbone.layer1.1.bn2.weight -> 8
  backbone.layer2.0.conv1.weight -> 8
  backbone.layer2.0.bn1.weight -> 8
  backbone.layer2.0.conv2.weig

Epoch 1/5:   0%|          | 0/566 [00:00<?, ?it/s]

Epoch 2/5:   0%|          | 0/566 [00:00<?, ?it/s]

Epoch 3/5:   0%|          | 0/566 [00:00<?, ?it/s]

Epoch 4/5:   0%|          | 0/566 [00:00<?, ?it/s]

Epoch 5/5:   0%|          | 0/566 [00:00<?, ?it/s]

Extracting features for val (1984 videos)...


  0%|          | 0/1984 [00:00<?, ?it/s]

Extracting features for train (1984 videos)...


  0%|          | 0/1984 [00:00<?, ?it/s]

Ensemble Acc: 97.80%
{8: 97.79989523310634}


In [None]:
import pandas as pd
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB

def save_ensemble_predictions(model, output_csv_path, local_img_dir='/content/all_unzipped'):
    """Fit a 4-classifier ensemble on ``model`` features and save val predictions as CSV.

    Train and val features are re-extracted from the frames under
    ``local_img_dir``: each video's feature is the mean of its per-frame
    ``feat_512``. SVM, decision tree, KNN and Gaussian NB are fitted on the
    train features. A val video is predicted 1 when at least 2 of the 4 vote 1.
    The CSV has one row per val video, with ``video_id``, ``true_label``,
    ``ensemble_pred`` and the four individual predictions.

    Uses the module-level ``train_videos`` / ``val_videos`` lists. If either
    split yields no features, prints an error and returns ``None`` without
    writing a file.
    """
    model.eval()
    # Feed inputs to whatever device the model lives on.
    run_device = next(model.parameters()).device

    # --- Helper: Extract Features + Video IDs ---
    def get_features_and_ids(split_name, target_list):
        """Return (X, y, video_ids) for the videos in ``target_list`` that have frames."""
        print(f"Extracting features for {split_name} ({len(target_list)} videos)...")

        # 1. Index local frames once (filename format: videoID_frameX.jpg)
        frame_lookup = {}
        for f in glob.glob(os.path.join(local_img_dir, '**', '*.jpg'), recursive=True):
            vid_id = os.path.basename(f).rsplit('_', 1)[0]
            frame_lookup.setdefault(vid_id, []).append(f)

        # 2. Transform
        transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        X, y, ids = [], [], []
        skipped = 0

        # 3. Extraction loop (one frame at a time, as in extract_features_from_disk)
        with torch.no_grad():
            for vid_path, label in tqdm(target_list):
                # Video id = basename of the path from the split JSON, minus extension
                vid_id = os.path.splitext(os.path.basename(vid_path))[0]
                frame_paths = frame_lookup.get(vid_id)
                if not frame_paths:
                    continue

                batch_feats = []
                for img_path in frame_paths:
                    try:
                        img = Image.open(img_path).convert('RGB')
                        img_t = transform(img).unsqueeze(0).to(run_device)
                        # Model returns (logits, features) -> keep features
                        _, feat_512 = model(img_t)
                        batch_feats.append(feat_512.cpu().numpy().flatten())
                    except Exception:
                        # Best effort: skip broken frames (a bare `except:` would
                        # also swallow KeyboardInterrupt).
                        skipped += 1

                if batch_feats:
                    # Average pooling per video
                    X.append(np.mean(batch_feats, axis=0))
                    y.append(label)
                    ids.append(vid_id)

        if skipped:
            print(f"Skipped {skipped} frame(s) in {split_name} that could not be loaded or embedded.")
        return np.array(X), np.array(y), ids

    # Training features must be re-extracted to fit the classifiers
    X_train, y_train, _ = get_features_and_ids('train', train_videos)
    X_val, y_val, val_ids = get_features_and_ids('val', val_videos)

    if len(X_train) == 0 or len(X_val) == 0:
        print("Error: Extraction failed (empty lists). Check image directory path.")
        return

    # random_state=42 everywhere for reproducibility (matches the QAT evaluation cell).
    clfs = {
        'SVM': SVC(probability=True, kernel='rbf', random_state=42),
        'DT': DecisionTreeClassifier(random_state=42),
        'KNN': KNeighborsClassifier(n_neighbors=5),
        'NB': GaussianNB()
    }

    preds_stack = []
    for name, clf in clfs.items():
        clf.fit(X_train, y_train)
        preds_stack.append(clf.predict(X_val))
        print(f"  {name} fitted and predicted.")

    # Stacked predictions have shape (4, N). Ensemble rule: >= 2 votes for class 1 -> 1.
    votes = np.array(preds_stack).sum(axis=0)
    ensemble_pred = (votes >= 2).astype(int)

    df_results = pd.DataFrame({
        'video_id': val_ids,
        'true_label': y_val,
        'ensemble_pred': ensemble_pred,
        # Individual classifier votes
        'svm': preds_stack[0],
        'dt': preds_stack[1],
        'knn': preds_stack[2],
        'nb': preds_stack[3]
    })

    df_results.to_csv(output_csv_path, index=False)
    print(f"✅ CSV saved successfully to: {output_csv_path}")
    print(f"Total Validation Samples: {len(df_results)}")

    # Quick accuracy check
    acc = accuracy_score(y_val, ensemble_pred)
    print(f"Calculated Ensemble Accuracy: {acc*100:.2f}%")

# --- EXECUTION ---
# Writes the per-video validation predictions of the trained QAT model
# (`qat_model` from the training cell) next to the dataset on Drive.
output_file = os.path.join(DATA_ROOT, 'ensemble_predictions_lav.csv')
save_ensemble_predictions(qat_model, output_csv_path=output_file)

--- Step 1: Extracting Training Features ---
Extracting features for train (1984 videos)...


  0%|          | 0/1984 [00:00<?, ?it/s]


--- Step 2: Extracting Validation Features ---
Extracting features for val (1984 videos)...


  0%|          | 0/1984 [00:00<?, ?it/s]


--- Step 3: Fitting Ensemble Classifiers ---
  SVM fitted and predicted.
  DT fitted and predicted.
  KNN fitted and predicted.
  NB fitted and predicted.

--- Step 4: Voting & Saving ---
✅ CSV saved successfully to: /content/drive/My Drive/Deep Fake Dataset/ensemble_predictions_lav.csv
Total Validation Samples: 1909
Calculated Ensemble Accuracy: 96.33%


In [None]:
import math
import os
from collections import Counter, OrderedDict

import torch

# FP32 state_dict of the wrapped (FakeQuantWrapper) model saved by the QAT cell.
FP32_QAT_PATH = "/content/qat_model_bits_8.pth"
# Destination of the integer (mixed-bit) checkpoint on Google Drive.
quant_out_path = "/content/drive/My Drive/Deep Fake Dataset/trained_cnn/quantized_resnet_all8_mapped_lav.pth"

def canonicalize_state_dict(sd):
    """Map a FakeQuantWrapper-wrapped state_dict back to plain-model key names.

    Wrapped layers store ``<layer>.module.weight`` / ``<layer>.module.bias`` plus
    the cache buffers ``<layer>._w_q_cached`` and ``<layer>._scale_cached``.
    The result is keyed like the unwrapped model:

    * ``<layer>.module.<weight|bias>`` -> ``<layer>.<weight|bias>`` (preferred);
    * ``<layer>._w_q_cached`` -> ``<layer>.weight``, used only if no better
      weight exists;
    * ``<layer>._scale_cached`` is dropped (not a parameter of the plain model);
    * every other key is kept unchanged.

    When several source keys map to the same name, the higher-priority one
    wins (the first one on ties). All values are returned as new float32
    tensors.
    """
    def target_and_priority(key):
        for leaf in ('weight', 'bias'):
            wrapped_suffix = '.module.' + leaf
            if key.endswith(wrapped_suffix):
                return key[:-len(wrapped_suffix)] + '.' + leaf, 3
        if key.endswith('._w_q_cached'):
            return key[:-len('._w_q_cached')] + '.weight', 1
        if key.endswith('.weight'):
            return key, 2
        return key, 0

    best = {}
    for key, tensor in sd.items():
        # Wrapper-only buffer: leaving it in breaks strict loading into the plain model.
        if key.endswith('._scale_cached'):
            continue
        target, priority = target_and_priority(key)
        current = best.get(target)
        if current is None or priority > current[1]:
            best[target] = (tensor.clone().to(torch.float32), priority)
    return {k: t for k, (t, _) in best.items()}



# Convert the wrapped QAT checkpoint into the integer mixed-bit format.
# (The original mixed 2- and 4-space indentation here was an IndentationError.)
if os.path.exists(FP32_QAT_PATH):
    raw = torch.load(FP32_QAT_PATH, map_location='cpu')
    # Accept either a bare state_dict or a {'state_dict': ...} wrapper.
    sd = raw['state_dict'] if (isinstance(raw, dict) and 'state_dict' in raw) else raw
    canon_sd = canonicalize_state_dict(sd)
    save_quantized_checkpoint_mixedbits(canon_sd, bits_assignment_partial, quant_out_path,
                                        exclude_prefixes=('dense_128', 'classifier'), eps=1e-8)
    print("Saved mixed-bit quant file to:", quant_out_path)
else:
    print(f"QAT checkpoint not found at {FP32_QAT_PATH}; nothing to quantize.")


Loaded state_dict with 166 entries.
Canonical params found: 145
Prepared quant_store entries: 41
Prepared fp32_fallback entries: 104
Bytes: quant ints+scales: 11455184 fallback fp32: 85080
Saved to: /content/drive/My Drive/Deep Fake Dataset/trained_cnn/quantized_resnet_all8_mapped_lav.pth
Original (canon) FP32 total bytes (approx): 45820504  (~43.70 MB)
Quant stored ints+scales bytes (sum): 11455184  (~11186.70 KB)
Fallback FP32 bytes stored: 85080  (~0.08 MB)
On-disk size (bytes): 11602507  (~11.07 MB)
Meta: {'num_quantized': 41, 'num_fallback_fp32': 104, 'total_quant_bytes': 11455184, 'total_fallback_bytes': 85080}

Sample dequant checks for a few largest quantized params:
 backbone.layer4.0.conv2.weight: elems=2359296, bits=8, dtype=int8, max_abs=1.918875e-03, rmse=5.646053e-04
 backbone.layer4.1.conv1.weight: elems=2359296, bits=8, dtype=int8, max_abs=2.025366e-03, rmse=5.137838e-04
 backbone.layer4.1.conv2.weight: elems=2359296, bits=8, dtype=int8, max_abs=1.779601e-03, rmse=4.734