# F. Quantization for ResNet using the LAV-DF Dataset

## Helper Functions and Classes

In [None]:
from google.colab import drive
import os
import glob
import random
import numpy as np

# Mount Google Drive so the dataset folder is reachable from the Colab runtime.
drive.mount('/content/drive')
# Dataset root on Drive; later cells read the split JSON and the model checkpoints from here.
DATA_ROOT = '/content/drive/My Drive/Deep Fake Dataset'
# 'processed_faces' sub-folder of the dataset root (presumably pre-cropped faces;
# not referenced again in the cells visible here).
PROCESSED_ROOT = os.path.join(DATA_ROOT, 'processed_faces')

Mounted at /content/drive


In [None]:
# Copy the frame archive from Drive onto the local Colab disk; it is unzipped in the next cell.
!cp -r "/content/drive/My Drive/Deep Fake Dataset/all.zip" "/content/"

In [None]:
# Unpack the archive into /content/all_unzipped, the directory later cells search for *.jpg frames.
!unzip /content/all.zip -d /content/all_unzipped

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import os
import glob
import json
from PIL import Image

# 1. Model Definition (ResNet-18 Version)
class DeepFakeHybridModel(nn.Module):
    """ResNet-18 backbone followed by a 512-d dense head and a 2-class layer.

    ``forward`` returns both the class logits and the 512-d head output, so
    the same network serves for training (logits) and as a feature extractor
    for the downstream ensemble classifiers (features).
    """

    def __init__(self):
        super().__init__()

        # ImageNet-pretrained ResNet-18 whose own classifier is replaced by an
        # identity, so the backbone emits its pooled feature vector directly.
        resnet = models.resnet18(pretrained=True)
        backbone_dim = resnet.fc.in_features  # 512 for ResNet-18
        resnet.fc = nn.Identity()
        self.backbone = resnet

        # Dense head that produces the 512-d latent returned as second output.
        head_layers = [nn.Linear(backbone_dim, 512), nn.ReLU(), nn.Dropout(0.5)]
        self.dense_512 = nn.Sequential(*head_layers)

        # Two-class logits layer on top of the latent.
        self.classifier = nn.Linear(512, 2)

    def forward(self, x):
        """Return ``(logits, feat_512)`` for the input batch ``x``."""
        feat_512 = self.dense_512(self.backbone(x))
        return self.classifier(feat_512), feat_512

class FakeQuantWrapper(nn.Module):
    """Run a wrapped layer with fake-quantized weights and activations.

    For ``nn.Conv2d`` / ``nn.Linear`` the forward pass uses a cached quantized
    copy of the weight; any other module is called unchanged. In both cases
    the output is fake-quantized per tensor when ``act_bits < 32``. Gradients
    pass through both roundings unchanged (straight-through estimator).

    Args:
        module: layer to wrap.
        weight_bits: bit width used when quantizing ``module.weight``.
        act_bits: bit width for the output; ``>= 32`` disables it.
        eps: lower bound applied to every quantization scale.
    """

    def __init__(self, module: nn.Module, weight_bits:int=8, act_bits:int=8, eps:float=1e-8):
        super().__init__()
        self.module = module
        self.weight_bits = int(weight_bits)
        self.act_bits = int(act_bits)
        self.eps = float(eps)
        # Empty placeholders until update_cached_weight() fills them.
        self.register_buffer('_w_q_cached', torch.empty(0))
        self.register_buffer('_scale_cached', torch.empty(0))
        self._cache_initialized = False

    def update_cached_weight(self, device=None):
        """Re-quantize the wrapped weight and store the result in the cache buffers.

        ``device`` optionally selects where the quantization is computed; the
        cached tensors always end up on the device of the wrapped weight.
        """
        weight = self.module.weight
        source = weight.data if device is None else weight.data.to(device)
        w_q, scale = quantize_tensor_symmetric_perchannel(
            source, self.weight_bits, eps=self.eps, channel_dim=0)
        home = weight.device
        self._w_q_cached = w_q.detach().to(home)
        self._scale_cached = scale.detach().to(home)
        self._cache_initialized = True

    def forward(self, x):
        """Apply the wrapped layer with the cached quantized weight, then quantize the output."""
        cache_missing = (not self._cache_initialized) or self._w_q_cached.numel() == 0
        if cache_missing:
            self.update_cached_weight()

        layer = self.module
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            # Forward value is the cached quantized weight; the gradient
            # flows to the real weight as if no rounding had happened.
            ste_weight = (self._w_q_cached - layer.weight).detach() + layer.weight
            if isinstance(layer, nn.Conv2d):
                out = F.conv2d(x, ste_weight, layer.bias, layer.stride,
                               layer.padding, layer.dilation, layer.groups)
            else:
                out = F.linear(x, ste_weight, layer.bias)
        else:
            out = layer(x)

        if self.act_bits >= 32:
            return out

        # Symmetric per-tensor activation quantization, again straight-through.
        qmax = 2 ** (self.act_bits - 1) - 1
        eps_t = torch.tensor(self.eps, device=out.device, dtype=out.dtype)
        scale = (out.abs().amax() / qmax).clamp(min=eps_t)
        quantized = torch.round(out / scale).clamp(-qmax, qmax) * scale
        return (quantized - out).detach() + out

# 2. Data Loading (JSON Split Based)
split_json_path = os.path.join(DATA_ROOT, 'video_splits.json')
print(f"Loading fixed splits from {split_json_path}...")

# Fail fast when the split file is missing; nothing below can run without it.
if not os.path.exists(split_json_path):
    raise FileNotFoundError(f"Could not find {split_json_path}")

with open(split_json_path, 'r') as f:
    splits = json.load(f)

# The JSON stores each split under its own key.
train_videos = splits['train_videos']
val_videos = splits['val_videos']

print(f"Data Split Loaded: {len(train_videos)} Train, {len(val_videos)} Val")

# Custom Dataset for Local Frames
class DeepFakeFrameDataset(Dataset):
    """Frame-level dataset: one sample per ``.jpg`` frame of the listed videos.

    Args:
        video_list: iterable of ``(video_path, label)`` pairs. Only the file
            stem of ``video_path`` is used, as the video ID.
        image_dir: directory searched recursively for ``*.jpg`` frames.
        transform: optional callable applied to each loaded PIL image.

    A frame is attributed to a video ID by its filename
    (``<id>_<suffix>.jpg``) or, when the filename contains no underscore, by
    the name of its containing directory (``<id>/<anything>.jpg``).
    ``samples`` holds the resulting ``(frame_path, label)`` pairs; frames of
    videos that are not in ``video_list`` are ignored.
    """

    def __init__(self, video_list, image_dir, transform=None):
        self.transform = transform
        self.samples = []

        # 1. Index local frames once
        print(f"Indexing frames in {image_dir}...")
        frame_lookup = {}
        pattern = os.path.join(image_dir, '**', '*.jpg')
        # sorted() keeps the sample order independent of the filesystem's
        # listing order, so indices are reproducible between runs.
        for frame_path in sorted(glob.glob(pattern, recursive=True)):
            vid_id = self._video_id_for_frame(frame_path)
            frame_lookup.setdefault(vid_id, []).append(frame_path)

        # 2. Match JSON list to local frames
        print("Matching videos to frames...")
        for vid_path, label in video_list:
            # Extract ID from full path (e.g., .../train/034843.mp4 -> 034843)
            vid_id = os.path.splitext(os.path.basename(vid_path))[0]
            for f_path in frame_lookup.get(vid_id, ()):
                self.samples.append((f_path, label))

        print(f"Dataset ready: {len(self.samples)} frames.")

    @staticmethod
    def _video_id_for_frame(frame_path):
        """Return the video ID a frame file belongs to.

        '014471_frame0.jpg' -> '014471'; '.../014471/frame.jpg' -> '014471'.
        """
        fname = os.path.basename(frame_path)
        if '_' in fname:
            return fname.rsplit('_', 1)[0]
        # No underscore in the filename: the frame is stored in a per-video
        # directory, so the directory name is the ID.
        return os.path.basename(os.path.dirname(frame_path))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        # The context manager releases the file handle once the pixels have
        # been decoded into the converted copy.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label

# Setup Loaders
# Resize to 224x224, convert to tensor, then normalize each channel with the
# listed mean / std values.
train_tf = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

TRAIN_DIR_LOCAL = '/content/all_unzipped' # Frames must be here
train_dataset = DeepFakeFrameDataset(
    video_list=train_videos,
    image_dir=TRAIN_DIR_LOCAL,
    transform=train_tf,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Loading fixed splits from /content/drive/My Drive/Deep Fake Dataset/video_splits.json...
Data Split Loaded: 1984 Train, 1984 Val
Indexing frames in /content/all_unzipped...
Matching videos to frames...
Dataset ready: 18112 frames.


In [None]:
import glob
import numpy as np
import os
from PIL import Image
from tqdm.auto import tqdm
import torch
import torchvision.transforms as T
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, classification_report

# The curvature estimators are calibrated on the training loader itself.
calib_loader = train_loader
# Loss whose curvature the Hessian / Fisher estimators below measure.
loss_fn = nn.CrossEntropyLoss()

def extract_features_from_disk(split_name, model):
    """Return one averaged feature vector per video of the requested split.

    Uses the module-level ``train_videos`` / ``val_videos`` lists (loaded from
    the split JSON) and the ``.jpg`` frames found under
    ``/content/all_unzipped``. Each frame is passed through ``model`` on its
    own and the second model output is averaged over all frames of a video.

    Args:
        split_name: ``'train'`` or ``'val'``.
        model: network returning ``(logits, features)``; it is switched to
            eval mode.

    Returns:
        ``(X, y)`` numpy arrays with the per-video mean features and labels.
        Both are empty for an unknown ``split_name``. Videos without any
        successfully processed frame are left out.
    """
    model.eval()

    # Select the correct video list based on the requested split
    if split_name == 'train':
        target_list = train_videos
    elif split_name == 'val':
        target_list = val_videos
    else:
        print(f"Unknown split {split_name}")
        return np.array([]), np.array([])

    local_img_dir = '/content/all_unzipped'
    print(f"Extracting features for {split_name} ({len(target_list)} videos)...")

    # Pre-index frames by video ID ('<id>_<suffix>.jpg' -> '<id>') so each
    # video is a dictionary lookup instead of a directory scan.
    frame_lookup = {}
    for f in glob.glob(os.path.join(local_img_dir, '**', '*.jpg'), recursive=True):
        vid_id = os.path.basename(f).rsplit('_', 1)[0]
        frame_lookup.setdefault(vid_id, []).append(f)

    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    final_X, final_y = [], []
    n_failed = 0

    with torch.no_grad():
        for vid_path, label in tqdm(target_list):
            vid_id = os.path.splitext(os.path.basename(vid_path))[0]

            batch_feats = []
            for img_path in frame_lookup.get(vid_id, ()):
                try:
                    with Image.open(img_path) as raw:
                        img = raw.convert('RGB')
                    img_t = transform(img).unsqueeze(0).to(device)
                    _, feat_512 = model(img_t)  # Second output is feature
                    batch_feats.append(feat_512.cpu().numpy().flatten())
                except Exception as exc:
                    # Best effort: a broken frame must not abort the run, but
                    # it is reported instead of silently dropped.
                    # KeyboardInterrupt is deliberately not caught.
                    n_failed += 1
                    if n_failed <= 5:
                        print(f"Skipping frame {img_path}: {exc!r}")

            if batch_feats:
                final_X.append(np.mean(batch_feats, axis=0))
                final_y.append(label)

    if n_failed:
        print(f"Skipped {n_failed} frame(s) that could not be processed.")

    return np.array(final_X), np.array(final_y)

def named_weight_params(model):
    """Yield ``(name, parameter)`` for each trainable, non-bias parameter of ``model``.

    A parameter qualifies when it requires grad, has at least one dimension
    and its name does not contain ``'bias'``.
    """
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.ndim < 1 or 'bias' in name:
            continue
        yield name, p

def hutchinson_diag(model, loss_fn, dataloader, device='cuda', num_vectors=8, max_batches=2):
    """Estimate the diagonal of the loss Hessian with Hutchinson's method.

    For every processed batch, ``num_vectors`` random +/-1 vectors ``v`` are
    drawn and ``v * (H v)`` is accumulated, where the Hessian-vector product
    ``H v`` is obtained by differentiating ``grad . v`` a second time.

    Args:
        model: network; moved to ``device`` and put in eval mode. It may
            return a tuple, in which case the first element is the logits.
        loss_fn: callable ``loss_fn(logits, targets)`` returning a scalar.
        dataloader: iterable of ``(inputs, targets)`` batches.
        device: device used for the computation.
        num_vectors: random probes drawn per batch.
        max_batches: number of batches to use; ``None`` uses all of them.

    Returns:
        Dict mapping each name from ``named_weight_params(model)`` to a CPU
        tensor shaped like the parameter: the mean of ``v * (H v)`` over all
        draws.

    Raises:
        RuntimeError: if no draw was made (empty loader or no probes).
    """
    model.to(device)
    model.eval()
    param_items = list(named_weight_params(model))
    param_names = [n for n, _ in param_items]
    param_list = [p for _, p in param_items]
    diag_accum = [torch.zeros_like(p, device=device) for p in param_list]
    total_draws = 0

    for batch_idx, batch in enumerate(dataloader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        model.zero_grad()
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        # Handle Hybrid Tuple Output
        out = model(inputs)
        logits = out[0] if isinstance(out, tuple) else out
        loss = loss_fn(logits, targets)

        # create_graph=True keeps the gradients differentiable for the
        # second backward pass below.
        grads = torch.autograd.grad(loss, param_list, create_graph=True)
        for draw_idx in range(num_vectors):
            vs = [(torch.randint_like(p, 0, 2, device=device).float()*2-1) for p in param_list]
            inner = sum((g * v).sum() for g, v in zip(grads, vs))
            # The graph is only needed again if another probe follows.
            # Letting the last probe free it returns the double-backward
            # memory before the next batch is processed.
            Hv = torch.autograd.grad(inner, param_list,
                                     retain_graph=draw_idx < num_vectors - 1)
            for i, (v, hvi) in enumerate(zip(vs, Hv)):
                diag_accum[i] += (v * hvi).detach()
            total_draws += 1
        # Drop the references to the differentiable gradients. Deleting the
        # loop variable of a ``for g in grads`` loop would not release them.
        del grads

    if total_draws == 0:
        raise RuntimeError("No batches processed.")
    return {name: (d / total_draws).cpu() for name, d in zip(param_names, diag_accum)}

def empirical_fisher_diag(model, loss_fn, dataloader, device='cuda', max_batches=5):
    """Diagonal curvature proxy: mean over batches of the squared loss gradient.

    For each of the first ``max_batches`` batches (all when ``None``) the
    gradient of the batch loss with respect to every parameter from
    ``named_weight_params(model)`` is squared element-wise and accumulated on
    the CPU. The batch average plus ``1e-12`` is returned as a dict keyed by
    parameter name.

    Raises:
        RuntimeError: if the loader yields no batch.
    """
    model.to(device)
    model.eval()
    param_items = list(named_weight_params(model))
    params = [p for _, p in param_items]
    squared_sums = [torch.zeros_like(p, device='cpu') for p in params]
    n_batches = 0

    for batch_idx, batch in enumerate(dataloader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        model.zero_grad()
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)

        out = model(inputs)
        # The hybrid model returns (logits, features); plain models return logits.
        logits = out[0] if isinstance(out, tuple) else out
        loss = loss_fn(logits, targets)

        batch_grads = torch.autograd.grad(loss, params, retain_graph=False)
        for acc, g in zip(squared_sums, batch_grads):
            acc += g.detach().cpu() ** 2
        n_batches += 1

    if n_batches == 0:
        raise RuntimeError("No batches processed.")
    return {
        name: (acc / n_batches + 1e-12).cpu()
        for (name, _), acc in zip(param_items, squared_sums)
    }

def get_diag_estimates(model, loss_fn, calib_loader, device='cuda', try_hutchinson=True,
                       hutchinson_vectors=8, hutchinson_batches=2, fisher_batches=5):
    """Return per-parameter diagonal curvature estimates.

    The Hutchinson Hessian estimator is tried first when ``try_hutchinson``
    is set. If it is disabled or raises, the empirical-Fisher estimator is
    used instead.
    """
    if try_hutchinson:
        try:
            print("Running Hutchinson Hessian diag estimator...")
            return hutchinson_diag(
                model, loss_fn, calib_loader, device,
                num_vectors=hutchinson_vectors,
                max_batches=hutchinson_batches,
            )
        except Exception as e:
            # Any failure here is reported and triggers the fallback below.
            print("Hutchinson failed:", repr(e))

    print("Falling back to empirical Fisher.")
    return empirical_fisher_diag(model, loss_fn, calib_loader, device,
                                 max_batches=fisher_batches)

In [None]:
import glob
import numpy as np
import os
from PIL import Image
from tqdm.auto import tqdm
import torch
import torchvision.transforms as T
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, classification_report

# The curvature estimators are calibrated on the training loader itself.
calib_loader = train_loader
# Loss whose curvature the Hessian / Fisher estimators below measure.
loss_fn = nn.CrossEntropyLoss()

def extract_features_from_disk(split_name, model):
    """Return one averaged feature vector per video of the requested split.

    Uses the module-level ``train_videos`` / ``val_videos`` lists (loaded from
    the split JSON) and the ``.jpg`` frames found under
    ``/content/all_unzipped``. Each frame is passed through ``model`` on its
    own and the second model output is averaged over all frames of a video.

    Args:
        split_name: ``'train'`` or ``'val'``.
        model: network returning ``(logits, features)``; it is switched to
            eval mode.

    Returns:
        ``(X, y)`` numpy arrays with the per-video mean features and labels.
        Both are empty for an unknown ``split_name``. Videos without any
        successfully processed frame are left out.
    """
    model.eval()

    # Select the correct video list based on the requested split
    if split_name == 'train':
        target_list = train_videos
    elif split_name == 'val':
        target_list = val_videos
    else:
        print(f"Unknown split {split_name}")
        return np.array([]), np.array([])

    local_img_dir = '/content/all_unzipped'
    print(f"Extracting features for {split_name} ({len(target_list)} videos)...")

    # Pre-index frames by video ID ('<id>_<suffix>.jpg' -> '<id>') so each
    # video is a dictionary lookup instead of a directory scan.
    frame_lookup = {}
    for f in glob.glob(os.path.join(local_img_dir, '**', '*.jpg'), recursive=True):
        vid_id = os.path.basename(f).rsplit('_', 1)[0]
        frame_lookup.setdefault(vid_id, []).append(f)

    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    final_X, final_y = [], []
    n_failed = 0

    with torch.no_grad():
        for vid_path, label in tqdm(target_list):
            vid_id = os.path.splitext(os.path.basename(vid_path))[0]

            batch_feats = []
            for img_path in frame_lookup.get(vid_id, ()):
                try:
                    with Image.open(img_path) as raw:
                        img = raw.convert('RGB')
                    img_t = transform(img).unsqueeze(0).to(device)
                    _, feat_512 = model(img_t)  # Second output is feature
                    batch_feats.append(feat_512.cpu().numpy().flatten())
                except Exception as exc:
                    # Best effort: a broken frame must not abort the run, but
                    # it is reported instead of silently dropped.
                    # KeyboardInterrupt is deliberately not caught.
                    n_failed += 1
                    if n_failed <= 5:
                        print(f"Skipping frame {img_path}: {exc!r}")

            if batch_feats:
                final_X.append(np.mean(batch_feats, axis=0))
                final_y.append(label)

    if n_failed:
        print(f"Skipped {n_failed} frame(s) that could not be processed.")

    return np.array(final_X), np.array(final_y)

def named_weight_params(model):
    """Yield ``(name, parameter)`` for each trainable, non-bias parameter of ``model``.

    A parameter qualifies when it requires grad, has at least one dimension
    and its name does not contain ``'bias'``.
    """
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.ndim < 1 or 'bias' in name:
            continue
        yield name, p

def hutchinson_diag(model, loss_fn, dataloader, device='cuda', num_vectors=8, max_batches=2):
    """Estimate the diagonal of the loss Hessian with Hutchinson's method.

    For every processed batch, ``num_vectors`` random +/-1 vectors ``v`` are
    drawn and ``v * (H v)`` is accumulated, where the Hessian-vector product
    ``H v`` is obtained by differentiating ``grad . v`` a second time.

    Args:
        model: network; moved to ``device`` and put in eval mode. It may
            return a tuple, in which case the first element is the logits.
        loss_fn: callable ``loss_fn(logits, targets)`` returning a scalar.
        dataloader: iterable of ``(inputs, targets)`` batches.
        device: device used for the computation.
        num_vectors: random probes drawn per batch.
        max_batches: number of batches to use; ``None`` uses all of them.

    Returns:
        Dict mapping each name from ``named_weight_params(model)`` to a CPU
        tensor shaped like the parameter: the mean of ``v * (H v)`` over all
        draws.

    Raises:
        RuntimeError: if no draw was made (empty loader or no probes).
    """
    model.to(device)
    model.eval()
    param_items = list(named_weight_params(model))
    param_names = [n for n, _ in param_items]
    param_list = [p for _, p in param_items]
    diag_accum = [torch.zeros_like(p, device=device) for p in param_list]
    total_draws = 0

    for batch_idx, batch in enumerate(dataloader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        model.zero_grad()
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)

        # Handle Hybrid Tuple Output
        out = model(inputs)
        logits = out[0] if isinstance(out, tuple) else out
        loss = loss_fn(logits, targets)

        # create_graph=True keeps the gradients differentiable for the
        # second backward pass below.
        grads = torch.autograd.grad(loss, param_list, create_graph=True)
        for draw_idx in range(num_vectors):
            vs = [(torch.randint_like(p, 0, 2, device=device).float()*2-1) for p in param_list]
            inner = sum((g * v).sum() for g, v in zip(grads, vs))
            # The graph is only needed again if another probe follows.
            # Letting the last probe free it returns the double-backward
            # memory before the next batch is processed.
            Hv = torch.autograd.grad(inner, param_list,
                                     retain_graph=draw_idx < num_vectors - 1)
            for i, (v, hvi) in enumerate(zip(vs, Hv)):
                diag_accum[i] += (v * hvi).detach()
            total_draws += 1
        # Drop the references to the differentiable gradients. Deleting the
        # loop variable of a ``for g in grads`` loop would not release them.
        del grads

    if total_draws == 0:
        raise RuntimeError("No batches processed.")
    return {name: (d / total_draws).cpu() for name, d in zip(param_names, diag_accum)}

def empirical_fisher_diag(model, loss_fn, dataloader, device='cuda', max_batches=5):
    """Diagonal curvature proxy: mean over batches of the squared loss gradient.

    For each of the first ``max_batches`` batches (all when ``None``) the
    gradient of the batch loss with respect to every parameter from
    ``named_weight_params(model)`` is squared element-wise and accumulated on
    the CPU. The batch average plus ``1e-12`` is returned as a dict keyed by
    parameter name.

    Raises:
        RuntimeError: if the loader yields no batch.
    """
    model.to(device)
    model.eval()
    param_items = list(named_weight_params(model))
    params = [p for _, p in param_items]
    squared_sums = [torch.zeros_like(p, device='cpu') for p in params]
    n_batches = 0

    for batch_idx, batch in enumerate(dataloader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        model.zero_grad()
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)

        out = model(inputs)
        # The hybrid model returns (logits, features); plain models return logits.
        logits = out[0] if isinstance(out, tuple) else out
        loss = loss_fn(logits, targets)

        batch_grads = torch.autograd.grad(loss, params, retain_graph=False)
        for acc, g in zip(squared_sums, batch_grads):
            acc += g.detach().cpu() ** 2
        n_batches += 1

    if n_batches == 0:
        raise RuntimeError("No batches processed.")
    return {
        name: (acc / n_batches + 1e-12).cpu()
        for (name, _), acc in zip(param_items, squared_sums)
    }

def get_diag_estimates(model, loss_fn, calib_loader, device='cuda', try_hutchinson=True,
                       hutchinson_vectors=8, hutchinson_batches=2, fisher_batches=5):
    """Return per-parameter diagonal curvature estimates.

    The Hutchinson Hessian estimator is tried first when ``try_hutchinson``
    is set. If it is disabled or raises, the empirical-Fisher estimator is
    used instead.
    """
    if try_hutchinson:
        try:
            print("Running Hutchinson Hessian diag estimator...")
            return hutchinson_diag(
                model, loss_fn, calib_loader, device,
                num_vectors=hutchinson_vectors,
                max_batches=hutchinson_batches,
            )
        except Exception as e:
            # Any failure here is reported and triggers the fallback below.
            print("Hutchinson failed:", repr(e))

    print("Falling back to empirical Fisher.")
    return empirical_fisher_diag(model, loss_fn, calib_loader, device,
                                 max_batches=fisher_batches)

In [None]:
def quantize_tensor_symmetric_perchannel(w: torch.Tensor, nbits:int, eps:float = 1e-8, channel_dim:int = 0):
    """Symmetric fake-quantization of ``w`` to signed ``nbits`` levels.

    Tensors with two or more dimensions get one scale per slice along
    ``channel_dim``; lower-rank tensors get a single scale. Values are
    rounded to integers in ``[-qmax, qmax]`` with ``qmax = 2**(nbits-1) - 1``
    and multiplied by the scale again, so the result stays in floating point.

    Args:
        w: tensor to quantize.
        nbits: bit width, must be >= 1.
        eps: lower bound for the scale, guards against all-zero channels.
        channel_dim: axis holding the channels. Negative values count from
            the end, as in normal tensor indexing. An axis outside the tensor
            rank selects no channel axis, i.e. one scale for the whole tensor.

    Returns:
        ``(w_q, scale)``. If the quantized tensor contains non-finite values,
        an unquantized copy of ``w`` is returned in place of ``w_q``. This
        always happens for ``nbits == 1``, where ``qmax`` is 0.

    Raises:
        ValueError: if ``nbits < 1``.
    """
    if nbits < 1:
        raise ValueError("nbits must be >= 1")
    qmax = 2 ** (nbits - 1) - 1
    with torch.no_grad():
        if w.ndim >= 2:
            # Map e.g. -1 to the last axis. Without this, a negative index
            # matches no axis and silently yields a per-tensor scale.
            channel_axis = channel_dim + w.ndim if -w.ndim <= channel_dim < 0 else channel_dim
            reduce_dims = [i for i in range(w.ndim) if i != channel_axis]
            max_abs = w.abs().amax(dim=reduce_dims, keepdim=True)
        else:
            max_abs = w.abs().amax()
        scale = (max_abs / qmax).clamp(min=eps)
        w_q = torch.round(w / scale).clamp(-qmax, qmax) * scale
        if not torch.isfinite(w_q).all():
            return w.detach().clone(), scale
    return w_q, scale

def estimated_delta_loss_from_quant(w, diag_h, w_q):
    """Second-order estimate of the loss change when ``w`` is replaced by ``w_q``.

    Evaluates ``0.5 * sum(h * (w_q - w) ** 2)`` on the CPU, with ``h`` taken
    from ``diag_h`` and expanded to the weight shape when the shapes differ.
    Returns a Python float.
    """
    err = (w_q - w).detach().cpu()
    curvature = diag_h.cpu()
    if curvature.shape != err.shape:
        curvature = curvature.expand_as(err)
    weighted = curvature * err.pow(2)
    return 0.5 * weighted.sum().item()

def compute_costs_filtered(model, diag_h_dict, include_names=None, base_bits=4, max_bits=8):
    """Tabulate the estimated quantization cost of each weight at each bit width.

    Only parameters from ``named_weight_params(model)`` whose name is in
    ``include_names`` are considered (all of them when it is ``None``). A
    parameter without an entry in ``diag_h_dict`` uses a constant ``1e-6``
    curvature.

    Returns:
        ``{param_name: {bits: cost}}`` for ``bits`` in ``base_bits..max_bits``.
    """
    costs = {}
    bit_widths = range(base_bits, max_bits + 1)
    for name, p in named_weight_params(model):
        if include_names is not None and name not in include_names:
            continue
        w = p.data
        diag = diag_h_dict.get(name, torch.ones_like(w.cpu()) * 1e-6)
        costs[name] = {
            b: estimated_delta_loss_from_quant(
                w, diag, quantize_tensor_symmetric_perchannel(w, b)[0])
            for b in bit_widths
        }
    return costs

def allocate_bits_greedy(costs, base_bits=4, max_bits=8, bit_budget=None):
    """Choose a bit width per parameter from a ``{name: {bits: cost}}`` table.

    Without a budget every parameter gets its cheapest width in
    ``base_bits..max_bits``. With a budget, all parameters start at
    ``base_bits`` and single bits are handed out one at a time to the
    parameter whose cost drops the most, until the budget is used up or all
    parameters are at ``max_bits``.

    ``bit_budget`` is read as an average number of bits per parameter when it
    is ``<= max_bits``, otherwise as the total number of bits.

    Returns:
        ``{name: bits}``.
    """
    names = list(costs.keys())

    if bit_budget is None:
        candidate_bits = range(base_bits, max_bits + 1)
        return {
            name: min(candidate_bits, key=costs[name].__getitem__)
            for name in names
        }

    assignment = dict.fromkeys(names, base_bits)
    if bit_budget <= max_bits:
        total_allowed = int(bit_budget * len(names))
    else:
        total_allowed = int(bit_budget)
    total_bits = sum(assignment.values())

    while total_bits < total_allowed:
        best_name, best_gain = None, -float('inf')
        for name in names:
            cur = assignment[name]
            if cur < max_bits:
                gain = costs[name][cur] - costs[name][cur + 1]
                if gain > best_gain:
                    best_name, best_gain = name, gain
        if best_name is None:
            # Everything is already at max_bits.
            break
        assignment[best_name] = assignment[best_name] + 1
        total_bits += 1
    return assignment

def update_all_fakequant_caches(model, device=None):
    """Re-quantize the cached weight of every ``FakeQuantWrapper`` inside ``model``."""
    wrappers = [m for m in model.modules() if isinstance(m, FakeQuantWrapper)]
    for wrapper in wrappers:
        wrapper.update_cached_weight(device)

def wrap_model_with_fakequant(model, bits_assignment, act_bits=8):
    """Replace Conv2d / Linear layers named in ``bits_assignment`` by ``FakeQuantWrapper``.

    ``bits_assignment`` maps parameter names such as ``'backbone.conv1.weight'``
    to a weight bit width. The layer owning each listed ``...weight``
    parameter is swapped in place for a wrapper when it is an ``nn.Conv2d``
    or ``nn.Linear``; any other layer type is left untouched. A parameter
    that sits directly on ``model`` (no owning attribute) is skipped.

    Returns:
        The same ``model`` object, modified in place.
    """
    # Collect the targets before touching the model: wrapping changes the
    # parameter names ('.module' is inserted).
    targets = []
    for p_name, _ in model.named_parameters(recurse=True):
        if p_name not in bits_assignment or not p_name.endswith('weight'):
            continue
        module_path = p_name.split('.')[:-1]
        targets.append((p_name, module_path))

    for p_name, module_path in targets:
        if not module_path:
            continue
        # Walk down to the module that holds the layer as an attribute.
        owner = model
        for part in module_path[:-1]:
            owner = getattr(owner, part)
        attr = module_path[-1]
        layer = getattr(owner, attr)
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            wrapper = FakeQuantWrapper(
                layer,
                weight_bits=bits_assignment.get(p_name, 8),
                act_bits=act_bits,
            )
            setattr(owner, attr, wrapper)
    return model

def qat_train(model, train_loader, val_loader, device, epochs=6, lr=1e-4):
    """Fine-tune a fake-quantized model, then score it with an SVM on its features.

    Training: AdamW with cross-entropy on the first model output. After each
    optimizer step the cached quantized weights of all ``FakeQuantWrapper``
    modules are refreshed. Batches whose loss is not finite are skipped.

    Evaluation: per-video features for the 'val' and 'train' splits are
    extracted with ``extract_features_from_disk``. An RBF ``SVC`` is fitted
    on the train features and its accuracy on the val features is reported.

    Args:
        model: model to train; may return ``(logits, features)``.
        train_loader: iterable of ``(inputs, targets)`` batches.
        val_loader: accepted for interface compatibility but not used;
            validation reads the 'val' split from disk instead.
        device: device the batches are moved to.
        epochs: number of passes over ``train_loader``.
        lr: AdamW learning rate.

    Returns:
        ``(model, accuracy_percent)``. The accuracy is ``0.0`` when no
        validation features were found or the evaluation raised.
    """
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    cache_warning_shown = False
    for epoch in range(epochs):
        model.train()
        running_loss, n, skipped = 0.0, 0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for xb, yb in pbar:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            out = model(xb)
            logits = out[0] if isinstance(out, tuple) else out
            loss = F.cross_entropy(logits, yb)
            if not torch.isfinite(loss):
                # A NaN/inf loss would poison the weights; skip the step but
                # keep count so the problem is visible.
                skipped += 1
                continue
            loss.backward()
            opt.step()
            try:
                update_all_fakequant_caches(model, device)
            except Exception as exc:
                # Training continues (best effort), but stale quantized
                # weights mean the forward pass no longer matches the
                # updated parameters, so the failure is reported once.
                if not cache_warning_shown:
                    print("Warning: could not refresh fake-quant weight caches:", repr(exc))
                    cache_warning_shown = True
            running_loss += loss.item()
            n += 1
            pbar.set_postfix({'loss': f"{running_loss/max(1,n):.4f}"})
        if skipped:
            print(f"Epoch {epoch+1}: skipped {skipped} batch(es) with non-finite loss.")

    # Eval
    try:
        X_val, y_val = extract_features_from_disk('val', model)
        if len(X_val) > 0:
            clf = SVC(probability=True, kernel='rbf')
            X_train, y_train = extract_features_from_disk('train', model)
            clf.fit(X_train, y_train)
            p = clf.predict(X_val)
            acc = accuracy_score(y_val, p) * 100
            print(f"Ensemble Acc: {acc:.2f}%")
            return model, acc
    except Exception as e:
        print("Eval failed:", e)
    return model, 0.0

def is_weight_param(name, param):
    """Return True for a trainable parameter with >= 1 dimension whose name lacks 'bias'."""
    if not param.requires_grad:
        return False
    if param.ndim < 1:
        return False
    return 'bias' not in name

def quantizable_param_names(model, exclude_prefixes=('dense_512', 'classifier')):
    """List the names of weight parameters that do not start with an excluded prefix.

    A parameter counts as a weight when ``is_weight_param`` accepts it. The
    names keep the order of ``model.named_parameters()``.
    """
    return [
        n
        for n, p in model.named_parameters(recurse=True)
        if is_weight_param(n, p)
        and not any(n.startswith(pref) for pref in exclude_prefixes)
    ]

from collections import OrderedDict
def save_quantized_checkpoint(fp32_ckpt_path, bits_assignment, out_path, exclude_prefixes=('dense_512','classifier'), eps=1e-8):
    """Convert a floating-point checkpoint into integer weights plus scales and save it.

    Every tensor of the checkpoint's state dict (the ``'state_dict'`` entry
    when present, else the loaded object itself) goes to one of two stores:

    * ``quant``: tensors with a bit width in ``bits_assignment``, stored as
      ``{'q': integer tensor, 'scale': float tensor, 'bits': int,
      'shape': tuple}``. The scale is per slice along dim 0 for tensors with
      two or more dimensions, otherwise a single value.
    * ``fp32_fallback``: everything else, converted to float32. This covers
      biases, names starting with an entry of ``exclude_prefixes`` and
      tensors without an assigned bit width.

    Checkpoints saved from a model wrapped with ``FakeQuantWrapper`` name
    their weights ``<layer>.module.weight``. Such names are also looked up in
    ``bits_assignment`` under the pre-wrapping name ``<layer>.weight``.

    Args:
        fp32_ckpt_path: path of the checkpoint to read.
        bits_assignment: ``{parameter_name: bits}``.
        out_path: destination file.
        exclude_prefixes: name prefixes that are never quantized.
        eps: lower bound for the scales.
    """
    # NOTE(security): torch.load unpickles the file, so only use this on
    # checkpoints from a trusted source.
    ckpt = torch.load(fp32_ckpt_path, map_location='cpu')
    state_dict = ckpt['state_dict'] if isinstance(ckpt, dict) and 'state_dict' in ckpt else ckpt
    quant_store, fp32_fallback = {}, {}

    for name, t in state_dict.items():
        is_bias = name.lower().endswith('bias')
        excluded = is_bias or any(name.startswith(p) for p in exclude_prefixes)

        bits = None
        if not excluded:
            bits = bits_assignment.get(name, None)
            if bits is None:
                # Wrapped layers insert '.module' into the parameter path.
                bits = bits_assignment.get(name.replace('.module.', '.'), None)

        if bits is None:
            fp32_fallback[name] = t.clone().float()
            continue

        w = t.clone().float()
        # Inline quant logic for saving
        qmax = 2**(bits-1)-1
        if w.ndim >= 2:
            max_abs = w.abs().amax(dim=[i for i in range(w.ndim) if i != 0], keepdim=True)
        else:
            max_abs = w.abs().amax()
        scale = (max_abs/qmax).clamp(min=eps)

        # Pick the narrowest integer type that holds [-qmax, qmax]. int16
        # would overflow for widths above 16 bits.
        if bits <= 8:
            storage_dtype = torch.int8
        elif bits <= 16:
            storage_dtype = torch.int16
        else:
            storage_dtype = torch.int32
        q_int = torch.round(w/scale).clamp(-qmax, qmax).to(storage_dtype)
        quant_store[name] = {'q': q_int, 'scale': scale.float(), 'bits': int(bits), 'shape': tuple(w.shape)}

    torch.save({'quant': quant_store, 'fp32_fallback': fp32_fallback}, out_path)
    print(f"Saved to {out_path}")

## Start Training

In [None]:
# Checkpoint loaded into the model below when the file exists.
CHECKPOINT_PATH = os.path.join(DATA_ROOT, 'trained_cnn', 'resnet18_lav_df.pth')

# Parameter-name prefixes that are kept out of quantization.
# NOTE(review): DeepFakeHybridModel has no 'dense_128' attribute (its head is
# 'dense_512'), so only 'classifier' is effectively excluded — confirm intent.
exclude_prefixes = ('dense_128', 'classifier')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
try_hutchinson = True            # try true Hessian diag first
hutch_vectors = 16                # 6..12 on Colab, 12..16 on V100
hutch_batches = 2                # small number of batches for speed
fisher_batches = 6               # fallback batch count
base_bits = 6                    # minimum bits considered
max_bits = 8                     # maximum bits to consider (keep 8)
activation_bits = 8              # activation fake-quant bits (common across layers)
qat_epochs = 5                   # fine-tune epochs (2-8)
qat_lr = 2e-5                    # learning rate for QAT

# Budgets to sweep. allocate_bits_greedy reads a value <= max_bits as the
# average number of bits per quantized parameter tensor.
bit_budgets = [8]

# bit_budget -> accuracy returned by qat_train
val_accs = {}
for bit_budget in bit_budgets:
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  # NOTE: despite the variable name, this is the ResNet-18 based DeepFakeHybridModel.
  efficientnet_model = DeepFakeHybridModel().to(device)

  # Load the checkpoint when it exists; otherwise the model keeps the weights
  # it was constructed with (pretrained backbone, freshly initialised head).
  if os.path.exists(CHECKPOINT_PATH):
      print(f"Loading weights from {CHECKPOINT_PATH}...")
      efficientnet_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
      print("Model loaded successfully.")
  # 1) estimate the diagonal curvature of the loss for every weight parameter
  diag_h = get_diag_estimates(efficientnet_model, loss_fn, calib_loader, device=device,
                            try_hutchinson=try_hutchinson,
                            hutchinson_vectors=hutch_vectors,
                            hutchinson_batches=hutch_batches,
                            fisher_batches=fisher_batches)

  # 2) build a list of param names we want to quantize (backbone + dense_512)
  include_names = quantizable_param_names(efficientnet_model, exclude_prefixes=exclude_prefixes)
  print("Will quantize these params (sample):", include_names[:10])

  # 3) compute costs only for the included params
  costs = compute_costs_filtered(efficientnet_model, diag_h, include_names=include_names,
                                base_bits=base_bits, max_bits=max_bits)

  # 4) allocate bits greedily for only these params (this returns dict for included names)
  bits_assignment_partial = allocate_bits_greedy(costs, base_bits=base_bits, max_bits=max_bits, bit_budget=bit_budget)

  # 5) Excluded params stay in full precision simply by being absent from
  # bits_assignment_partial; only the listed params are wrapped below.
  print("Assigned bits (sample):")
  for i,(k,v) in enumerate(bits_assignment_partial.items()):
      if i >= 25: break
      print(f"  {k} -> {v}")

  # 6) Wrap model: wrap_model_with_fakequant only wraps Conv2d/Linear modules
  # whose weight name appears in bits_assignment (other listed params, e.g.
  # BatchNorm weights, are left unwrapped).
  qat_model = efficientnet_model.train()
  qat_model = wrap_model_with_fakequant(qat_model, bits_assignment_partial, act_bits=activation_bits)
  qat_model.to(device)

  # 7) Create optimizer AFTER wrapping so it includes both FP params (head) and wrapped params
  # NOTE(review): this optimizer is never passed to qat_train, which builds its
  # own AdamW without weight_decay=1e-4 — as written it has no effect on training.
  opt = torch.optim.AdamW(qat_model.parameters(), lr=qat_lr, weight_decay=1e-4)

  # 8) Run QAT fine-tuning
  # NOTE(review): calib_loader (the training loader) is passed as val_loader;
  # qat_train does not use that argument.
  qat_model, final_val_acc = qat_train(qat_model, train_loader, calib_loader, device, epochs=qat_epochs, lr=qat_lr)
  torch.save(qat_model.state_dict(), f"qat_model_bits_{bit_budget}.pth")
  val_accs[bit_budget] = final_val_acc

  print(val_accs)


# Optional export of the last run as integer weights + scales (call left disabled).
#fp32_ckpt_path = "/content/qat_model_bits_8.pth"
bits_assignment = bits_assignment_partial
out_path = os.path.join(DATA_ROOT, 'trained_cnn', 'quantized8_model_10epoch_lav.pth')
# save_quantized_checkpoint(fp32_ckpt_path, bits_assignment, out_path, exclude_prefixes=('dense_128','classifier'), eps=1e-8)





Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 244MB/s]


Loading weights from /content/drive/My Drive/Deep Fake Dataset/trained_cnn/resnet18_lav_df.pth...
Model loaded successfully.
Running Hutchinson Hessian diag estimator...
Will quantize these params (sample): ['backbone.conv1.weight', 'backbone.bn1.weight', 'backbone.layer1.0.conv1.weight', 'backbone.layer1.0.bn1.weight', 'backbone.layer1.0.conv2.weight', 'backbone.layer1.0.bn2.weight', 'backbone.layer1.1.conv1.weight', 'backbone.layer1.1.bn1.weight', 'backbone.layer1.1.conv2.weight', 'backbone.layer1.1.bn2.weight']
Assigned bits (sample):
  backbone.conv1.weight -> 8
  backbone.bn1.weight -> 8
  backbone.layer1.0.conv1.weight -> 8
  backbone.layer1.0.bn1.weight -> 8
  backbone.layer1.0.conv2.weight -> 8
  backbone.layer1.0.bn2.weight -> 8
  backbone.layer1.1.conv1.weight -> 8
  backbone.layer1.1.bn1.weight -> 8
  backbone.layer1.1.conv2.weight -> 8
  backbone.layer1.1.bn2.weight -> 8
  backbone.layer2.0.conv1.weight -> 8
  backbone.layer2.0.bn1.weight -> 8
  backbone.layer2.0.conv2.weig

Epoch 1/5:   0%|          | 0/566 [00:00<?, ?it/s]

Epoch 2/5:   0%|          | 0/566 [00:00<?, ?it/s]

Epoch 3/5:   0%|          | 0/566 [00:00<?, ?it/s]

Epoch 4/5:   0%|          | 0/566 [00:00<?, ?it/s]

Epoch 5/5:   0%|          | 0/566 [00:00<?, ?it/s]

Extracting features for val (1984 videos)...


  0%|          | 0/1984 [00:00<?, ?it/s]

Extracting features for train (1984 videos)...


  0%|          | 0/1984 [00:00<?, ?it/s]

Ensemble Acc: 97.80%
{8: 97.79989523310634}


In [None]:
import pandas as pd
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB

def save_ensemble_predictions(model, output_csv_path, local_img_dir='/content/all_unzipped'):
    """
    Extracts features, trains ensemble on TRAIN set, predicts on VAL set,
    and saves the results to CSV.

    Relies on notebook globals defined in earlier cells: `train_videos` and
    `val_videos` (sized iterables of (video_path, label) pairs), `device`,
    `tqdm` and `accuracy_score`.

    Args:
        model: Network whose forward pass returns (logits, features); only the
            features are used here.
        output_csv_path: Destination path of the per-video prediction CSV.
        local_img_dir: Directory searched recursively for `*.jpg` frames.

    Returns:
        None. On success a CSV with columns video_id, true_label,
        ensemble_pred, svm, dt, knn, nb is written; if either split yields no
        features an error is printed and nothing is written.
    """
    model.eval()

    # Index local frames once and share the lookup between both splits; the
    # recursive glob is the slow part and its result does not depend on the split.
    frame_lookup = {}
    for frame_path in glob.glob(os.path.join(local_img_dir, '**', '*.jpg'), recursive=True):
        # Assumes filename format: videoID_frameX.jpg
        vid_id = os.path.basename(frame_path).rsplit('_', 1)[0]
        frame_lookup.setdefault(vid_id, []).append(frame_path)

    transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # --- Helper: Extract Features + Video IDs ---
    def get_features_and_ids(split_name, target_list):
        """Return (X, y, ids): one mean-pooled feature vector per video that has frames."""
        print(f"Extracting features for {split_name} ({len(target_list)} videos)...")

        X, y, ids = [], [], []
        skipped_frames = 0

        with torch.no_grad():
            for vid_path, label in tqdm(target_list):
                # Clean video ID from the path provided in JSON/CSV
                vid_id = os.path.splitext(os.path.basename(vid_path))[0]

                batch_feats = []
                for img_path in frame_lookup.get(vid_id, ()):
                    # Best-effort only for unreadable/corrupt frames. Model or
                    # device errors are NOT swallowed: they would affect every
                    # frame and must surface instead of producing empty output.
                    try:
                        with Image.open(img_path) as raw_img:
                            img = raw_img.convert('RGB')
                        img_t = transform(img).unsqueeze(0).to(device)
                    except (OSError, ValueError):
                        skipped_frames += 1
                        continue

                    # Model returns (logits, features) -> take features [1]
                    _, feat_512 = model(img_t)
                    batch_feats.append(feat_512.cpu().numpy().flatten())

                if batch_feats:
                    # Average pooling per video
                    X.append(np.mean(batch_feats, axis=0))
                    y.append(label)
                    ids.append(vid_id)

        if skipped_frames:
            print(f"  Skipped {skipped_frames} unreadable frame(s) in {split_name}.")

        return np.array(X), np.array(y), ids

    # --- 1. Extraction Phase ---
    # We must re-extract Training features to fit the classifiers
    print("--- Step 1: Extracting Training Features ---")
    X_train, y_train, _ = get_features_and_ids('train', train_videos)

    print("\n--- Step 2: Extracting Validation Features ---")
    X_val, y_val, val_ids = get_features_and_ids('val', val_videos)

    if len(X_train) == 0 or len(X_val) == 0:
        print("Error: Extraction failed (empty lists). Check image directory path.")
        return

    # --- 2. Training Phase ---
    print("\n--- Step 3: Fitting Ensemble Classifiers ---")
    clfs = {
        'SVM': SVC(probability=True, kernel='rbf'),
        'DT': DecisionTreeClassifier(random_state=42),
        'KNN': KNeighborsClassifier(n_neighbors=5),
        'NB': GaussianNB()
    }

    # Keyed by classifier name so the CSV columns below cannot drift out of
    # sync with the order of `clfs`.
    preds_by_name = {}
    for name, clf in clfs.items():
        clf.fit(X_train, y_train)
        preds_by_name[name] = clf.predict(X_val)
        print(f"  {name} fitted and predicted.")

    # --- 3. Prediction Phase ---
    print("\n--- Step 4: Voting & Saving ---")
    # Stack predictions: Shape (4, N_samples)
    votes = np.array(list(preds_by_name.values())).sum(axis=0)
    # Ensemble Rule: If >= 2 classifiers say Fake (1), predict Fake.
    # With four voters a 2-2 tie therefore resolves to Fake.
    # NOTE(review): summing votes assumes labels are encoded as 0/1 - confirm.
    ensemble_pred = (votes >= 2).astype(int)

    # Construct DataFrame
    df_results = pd.DataFrame({
        'video_id': val_ids,
        'true_label': y_val,
        'ensemble_pred': ensemble_pred,
        # Optional: Include individual classifier votes
        'svm': preds_by_name['SVM'],
        'dt': preds_by_name['DT'],
        'knn': preds_by_name['KNN'],
        'nb': preds_by_name['NB']
    })

    # Save
    df_results.to_csv(output_csv_path, index=False)
    print(f"✅ CSV saved successfully to: {output_csv_path}")
    print(f"Total Validation Samples: {len(df_results)}")

    # Quick Accuracy Check
    acc = accuracy_score(y_val, ensemble_pred)
    print(f"Calculated Ensemble Accuracy: {acc*100:.2f}%")

# --- EXECUTION ---
# The CSV is written under DATA_ROOT (the Google Drive dataset folder set in the first cell).
output_file = os.path.join(DATA_ROOT, 'ensemble_predictions_lav.csv')

# Ensure 'qat_model' is the variable name of your trained model from the previous step
save_ensemble_predictions(qat_model, output_file)

--- Step 1: Extracting Training Features ---
Extracting features for train (1984 videos)...


  0%|          | 0/1984 [00:00<?, ?it/s]


--- Step 2: Extracting Validation Features ---
Extracting features for val (1984 videos)...


  0%|          | 0/1984 [00:00<?, ?it/s]


--- Step 3: Fitting Ensemble Classifiers ---
  SVM fitted and predicted.
  DT fitted and predicted.
  KNN fitted and predicted.
  NB fitted and predicted.

--- Step 4: Voting & Saving ---
✅ CSV saved successfully to: /content/drive/My Drive/Deep Fake Dataset/ensemble_predictions_lav.csv
Total Validation Samples: 1909
Calculated Ensemble Accuracy: 96.33%


In [None]:
# Paste & run this verification cell (adjust paths if needed)
import os, torch, math
from collections import OrderedDict

fp32_ckpt_path = "/content/qat_model_bits_8.pth"   # original FP32 checkpoint you loaded before quantizing
# NOTE(review): the remap cell later in this notebook describes this same path
# as a wrapped (QAT) checkpoint - confirm it is the intended FP32 baseline.
quant_path = "/content/drive/My Drive/Deep Fake Dataset/trained_cnn/quantized8_resnet_5epoch.pth"  # saved quant file

# 1) Load the reference state_dict and total up the bytes its tensors occupy.
orig = torch.load(fp32_ckpt_path, map_location='cpu')
# A checkpoint saved as {'state_dict': ...} is unwrapped; anything else is
# treated as the state_dict itself.
if isinstance(orig, dict) and 'state_dict' in orig:
    orig_sd = orig['state_dict']
else:
    orig_sd = orig

total_orig_bytes = 0
orig_param_counts = {}
for n, t in orig_sd.items():
    if not isinstance(t, torch.Tensor):
        continue
    # Use the tensor's real element size rather than assuming 4 bytes, so
    # non-float32 entries are not miscounted.
    total_orig_bytes += t.numel() * t.element_size()
    orig_param_counts[n] = t.numel()

# Report the number of tensors actually measured (non-tensor entries are skipped above).
print("Original checkpoint params:", len(orig_param_counts), "Total FP32 bytes (approx):", total_orig_bytes, " (~{:.2f} MB)".format(total_orig_bytes/1024**2))

# 2) Load quant file and compute stored bytes (sum of integer storage + scales + fp32 fallback)
qf = torch.load(quant_path, map_location='cpu')
quant_map = qf.get('quant', {})
fp32_fallback = qf.get('fp32_fallback', {})

total_quant_storage_bytes = 0
for name, info in quant_map.items():
    q = info['q']
    scale = info['scale']
    # Measure what is really stored: the 'dtype' metadata string may be absent
    # or disagree with the tensor, whereas element_size() cannot.
    total_quant_storage_bytes += q.numel() * q.element_size()
    total_quant_storage_bytes += scale.numel() * scale.element_size()

# count fallback bytes (same non-tensor guard as the original-checkpoint loop)
total_fallback_bytes = 0
for name, t in fp32_fallback.items():
    if not isinstance(t, torch.Tensor):
        continue
    total_fallback_bytes += t.numel() * t.element_size()

print("Quantized params stored:", len(quant_map), "Fallback FP32 params stored:", len(fp32_fallback))
print("Total quant storage (ints + scales) bytes:", total_quant_storage_bytes, " (~{:.2f} KB)".format(total_quant_storage_bytes/1024))
print("Total fallback FP32 bytes:", total_fallback_bytes, " (~{:.2f} MB)".format(total_fallback_bytes/1024**2))

# on-disk file size
disk_bytes = os.path.getsize(quant_path)
print("Quant file on-disk size:", disk_bytes, " bytes (~{:.2f} KB)".format(disk_bytes/1024))

# 3) compute corrected compression ratio: orig_total_bytes / (quant_storage + fallback_bytes)
stored_total_bytes = total_quant_storage_bytes + total_fallback_bytes
if stored_total_bytes == 0:
    print("Warning: stored_total_bytes=0 (unexpected).")
else:
    ratio = total_orig_bytes / float(stored_total_bytes)
    print("Corrected compression ratio (original fp32 bytes / stored quant bytes): {:.2f}x".format(ratio))

# 4) Sanity-check dequantization: rebuild a fp32 state_dict from quant file and compare a few stats
def dequantize_to_state_dict(qfile):
    """Rebuild a float32 state_dict from a quant-file dictionary.

    Entries under 'fp32_fallback' are copied as float32 first; entries under
    'quant' are then reconstructed as q * scale and written afterwards, so a
    name present in both sections ends up with the dequantized value.
    Either section may be missing.
    """
    fallback_entries = qfile.get('fp32_fallback', {})
    restored = OrderedDict(
        (key, tensor.to(torch.float32).clone())
        for key, tensor in fallback_entries.items()
    )

    quant_entries = qfile.get('quant', {})
    for key, entry in quant_entries.items():
        levels = entry['q'].to(torch.float32)
        step = entry['scale'].to(torch.float32)
        restored[key] = torch.mul(levels, step).clone()

    return restored

deq_sd = dequantize_to_state_dict(qf)

# Accumulate error statistics over every tensor name present in both the
# original state_dict and the dequantized one.
names_compared = []
max_abs_err = 0.0
sq_err_sum = 0.0
count = 0
for n, orig_t in orig_sd.items():
    if n in deq_sd:
        o = orig_t.to(torch.float32)
        d = deq_sd[n].to(torch.float32)
        if o.shape == d.shape:
            diff = torch.sub(o, d)
            mabs = diff.abs().max().item()
            mse = diff.pow(2).mean().item()
            if mabs > max_abs_err:
                max_abs_err = mabs
            # Weight each tensor's MSE by its element count so the global
            # RMSE below is taken over all compared elements.
            sq_err_sum += mse * o.numel()
            count += o.numel()
            names_compared.append((n, float(mabs), float(mse)))
        else:
            print("Shape mismatch for", n, o.shape, d.shape)

# Global RMSE; NaN when nothing could be compared.
if count > 0:
    rmse = math.sqrt(sq_err_sum / count)
else:
    rmse = float('nan')
print("Compared params:", len(names_compared))
print("Max absolute error across compared params:", max_abs_err)
print("Global RMSE (over all elements compared):", rmse)

# Show the ten tensors with the largest max-abs error.
names_compared_sorted = sorted(names_compared, key=lambda rec: -rec[1])[:10]
print("\nTop per-param max-abs errors (name, max_abs, mse):")
for n, mabs, mse in names_compared_sorted:
    print(f"  {n}: max_abs={mabs:.6e}, mse={mse:.6e}")


Original checkpoint params: 840 Total FP32 bytes (approx): 92304000  (~88.03 MB)
Quantized params stored: 78 Fallback FP32 params stored: 762
Total quant storage (ints + scales) bytes: 52296  (~51.07 KB)
Total fallback FP32 bytes: 92129408  (~87.86 MB)
Quant file on-disk size: 92485743  bytes (~90318.11 KB)
Corrected compression ratio (original fp32 bytes / stored quant bytes): 1.00x
Compared params: 840
Max absolute error across compared params: 1.1524181365966797
Global RMSE (over all elements compared): 0.008874270953472556

Top per-param max-abs errors (name, max_abs, mse):
  backbone.features.1.0.block.2.1.weight: max_abs=1.152418e+00, mse=4.513062e-01
  backbone.features.5.0.block.1.1.weight: max_abs=1.121962e+00, mse=5.544320e-01
  backbone.features.2.1.block.3.1.weight: max_abs=1.099119e+00, mse=4.166234e-01
  backbone.features.4.3.block.1.1.weight: max_abs=1.004616e+00, mse=3.634406e-01
  backbone.features.2.2.block.3.1.weight: max_abs=9.232476e-01, mse=3.142465e-01
  backbone

In [None]:
# Inspect current quant file and original state_dict mapping
import torch, os
from collections import Counter

quant_path = "/content/drive/My Drive/Deep Fake Dataset/trained_cnn/quantized8_model_10epoch.pth"
fp32_ckpt_path = "/content/qat_model_bits_8.pth"

qf = torch.load(quant_path, map_location='cpu')
quant_map = qf.get('quant', {})
fp32_fallback = qf.get('fp32_fallback', {})

print("Quantized params:", len(quant_map))
print("Fallback FP32 params:", len(fp32_fallback))

# top fallback by size
fallback_sizes = []
for name, t in fp32_fallback.items():
    fallback_sizes.append((name, t.numel(), t.numel()*4))
fallback_sizes.sort(key=lambda x: x[1], reverse=True)
print("\nTop 15 fallback params by element count (name, elems, bytes):")
for n,e,b in fallback_sizes[:15]:
    print(f"  {n}: elems={e}, bytes={b} (~{b/1024:.1f} KB)")

# check quant dtype distribution
dtype_counts = Counter()
for name, info in quant_map.items():
    dt = info.get('dtype', None)
    if dt is None:
        # infer
        q = info['q']
        if q.dtype == torch.int8: dt = 'int8'
        elif q.dtype == torch.int16: dt = 'int16'
        elif q.dtype == torch.int32: dt = 'int32'
        else: dt = str(q.dtype)
    dtype_counts[dt] += 1
print("\nQuant dtype counts:", dict(dtype_counts))

# check if names in bits_assignment match state_dict names (common mismatch)
# You should set bits_assignment to the mapping you used earlier (in the notebook)
try:
    missing_from_assignment = []
    present_count = 0
    for name in torch.load(fp32_ckpt_path, map_location='cpu').keys():
        if name in quant_map:
            present_count += 1
        else:
            missing_from_assignment.append(name)
    print("\nState-dict names present in quant_map:", present_count)
except Exception as e:
    print("Could not compare to FP32 ckpt automatically (structure mismatch):", e)

# report any 'module.' prefix issues
sample_names = list(fp32_fallback.keys())[:20] + list(quant_map.keys())[:20]
has_module = any(n.startswith('module.') for n in sample_names)
print("\nAny 'module.' prefixes in sample names?:", has_module)


Quantized params: 78
Fallback FP32 params: 762

Top 15 fallback params by element count (name, elems, bytes):
  backbone.features.7.1.block.0.0._w_q_cached: elems=884736, bytes=3538944 (~3456.0 KB)
  backbone.features.7.1.block.0.0.module.weight: elems=884736, bytes=3538944 (~3456.0 KB)
  backbone.features.7.1.block.3.0._w_q_cached: elems=884736, bytes=3538944 (~3456.0 KB)
  backbone.features.7.1.block.3.0.module.weight: elems=884736, bytes=3538944 (~3456.0 KB)
  dense_512.0._w_q_cached: elems=786432, bytes=3145728 (~3072.0 KB)
  dense_512.0.module.weight: elems=786432, bytes=3145728 (~3072.0 KB)
  backbone.features.8.0._w_q_cached: elems=589824, bytes=2359296 (~2304.0 KB)
  backbone.features.8.0.module.weight: elems=589824, bytes=2359296 (~2304.0 KB)
  backbone.features.7.0.block.3.0._w_q_cached: elems=534528, bytes=2138112 (~2088.0 KB)
  backbone.features.7.0.block.3.0.module.weight: elems=534528, bytes=2138112 (~2088.0 KB)
  backbone.features.6.1.block.0.0._w_q_cached: elems=322944,

In [None]:
# --- Cell: remap wrappers -> canonical names then quantize & save compact checkpoint ---
import torch, os, math
from collections import OrderedDict, Counter

# EDIT THESE
fp32_ckpt_path = "/content/qat_model_bits_8.pth"   # the checkpoint you currently have (wrapped model)
out_path = "/content/drive/My Drive/Deep Fake Dataset/trained_cnn/quantized_resnet_all8_mapped_lav.pth"
# NOTE(review): the model class in this notebook names its head `dense_512`,
# so the 'dense_128' prefix below matches nothing - confirm which layers are
# meant to stay FP32.
exclude_prefixes = ('dense_128', 'classifier')    # keep these in FP32 fallback
default_bits = 8
eps = 1e-8

# Load the wrapped model's state_dict, unwrapping a {'state_dict': ...} container.
raw = torch.load(fp32_ckpt_path, map_location='cpu')
sd = raw['state_dict'] if isinstance(raw, dict) and 'state_dict' in raw else raw

print("Loaded state_dict with", len(sd), "entries.")

# Collapse wrapper-specific keys onto canonical parameter names:
#   '<p>.module.weight' -> '<p>.weight'   (score 3: the wrapped parameter itself)
#   '<p>.weight'        -> unchanged      (score 2: a plain parameter)
#   '<p>._w_q_cached'   -> '<p>.weight'   (score 1: cached buffer, last resort)
#   anything else       -> unchanged      (score 0)
# When several keys map to one canonical name, the highest score wins.
canonical_map = {}   # canonical_name -> (source_key, float32 tensor, score)
conflicts = {}       # canonical_name -> other keys that tied on score

for k, t in sd.items():
    if k.endswith('.module.weight'):
        can, score = k[:-len('.module.weight')] + '.weight', 3
    elif k.endswith('._w_q_cached'):
        can, score = k[:-len('._w_q_cached')] + '.weight', 1
    elif k.endswith('.weight'):
        can, score = k, 2
    else:
        can, score = k, 0

    incumbent = canonical_map.get(can)
    if incumbent is None or score > incumbent[2]:
        canonical_map[can] = (k, t.clone().to(torch.float32), score)
    elif score == incumbent[2] and k != incumbent[0]:
        # Equal preference: keep the first one seen and note the clash (non-fatal).
        conflicts.setdefault(can, []).append(k)

# Drop the bookkeeping fields: canonical name -> tensor only.
canon_to_tensor = {can_key: entry[1] for can_key, entry in canonical_map.items()}

print("Canonical params found:", len(canon_to_tensor))
if conflicts:
    print("Conflicts (canonical name -> alternative keys):", dict(list(conflicts.items())[:10]))
    # you can inspect conflicts if desired

# Per-parameter bit widths: reuse the allocator's in-memory mapping when the
# notebook session still has it; otherwise every parameter gets default_bits.
try:
    ba = bits_assignment_partial  # use your in-memory mapping (original names)
except NameError:
    print("bits_assignment_partial not found. Using default_bits for all quantizable params.")
    ba = {}

# Accumulators filled in by the quantization loop that follows.
quant_store = {}
fp32_fallback = {}
total_quant_bytes = 0
total_fallback_bytes = 0

def _quantize_tensor_symmetric_perchannel_to_int(t: torch.Tensor, bits: int, eps: float = 1e-8, channel_dim:int = 0):
    qmax = 2 ** (bits - 1) - 1
    if t.ndim >= 2:
        reduce_dims = [i for i in range(t.ndim) if i != channel_dim]
        max_abs = t.abs().amax(dim=reduce_dims, keepdim=True)
        scale = (max_abs / qmax).clamp(min=eps)
    else:
        max_abs = t.abs().amax()
        scale = (max_abs / qmax).clamp(min=eps)
    q = torch.round(t / scale).clamp(-qmax, qmax)
    if bits <= 8:
        q_int = q.to(torch.int8); dtype_str = 'int8'
    elif bits <= 16:
        q_int = q.to(torch.int16); dtype_str = 'int16'
    else:
        q_int = q.to(torch.int32); dtype_str = 'int32'
    return q_int, scale.to(torch.float32), dtype_str

# Route every canonical tensor to either the quantized store or the FP32 fallback.
for can_name, tensor in canon_to_tensor.items():
    # Kept in FP32: anything that is not a '.weight' (conservative heuristic)
    # and anything under an excluded prefix (your classifier layers).
    stays_fp32 = (not can_name.endswith('.weight')) or can_name.startswith(tuple(exclude_prefixes))
    if stays_fp32:
        fp32_fallback[can_name] = tensor.clone()
        total_fallback_bytes += tensor.numel() * 4
        continue

    # Bit width: the allocator's choice when it has one, else default_bits.
    bits = ba.get(can_name, default_bits)
    q_int, scale, dtype_str = _quantize_tensor_symmetric_perchannel_to_int(tensor, bits, eps=eps, channel_dim=0)
    quant_store[can_name] = {
        'q': q_int,
        'scale': scale,
        'bits': int(bits),
        'dtype': dtype_str,
        'shape': tuple(tensor.shape),
    }
    # Integer payload plus float32 scales.
    bytes_per_level = {'int8': 1, 'int16': 2}.get(dtype_str, 4)
    total_quant_bytes += q_int.numel() * bytes_per_level
    total_quant_bytes += scale.numel() * 4

print("Prepared quant_store entries:", len(quant_store))
print("Prepared fp32_fallback entries:", len(fp32_fallback))
print("Bytes: quant ints+scales:", total_quant_bytes, "fallback fp32:", total_fallback_bytes)

# Save compact file: quantized entries, FP32 fallback, and summary metadata.
out = {
    'quant': quant_store,
    'fp32_fallback': fp32_fallback,
    'meta': {
        'num_quantized': len(quant_store),
        'num_fallback_fp32': len(fp32_fallback),
        'total_quant_bytes': total_quant_bytes,
        'total_fallback_bytes': total_fallback_bytes,
    },
}
torch.save(out, out_path)
disk_bytes = os.path.getsize(out_path)
# Baseline for comparison: every canonical tensor at 4 bytes per element.
orig_bytes = sum(4 * src.numel() for src in canon_to_tensor.values())
print("Saved to:", out_path)
print("Original (canon) FP32 total bytes (approx):", orig_bytes, " (~{:.2f} MB)".format(orig_bytes/1024**2))
print("Quant stored ints+scales bytes (sum):", total_quant_bytes, " (~{:.2f} KB)".format(total_quant_bytes/1024))
print("Fallback FP32 bytes stored:", total_fallback_bytes, " (~{:.2f} MB)".format(total_fallback_bytes/1024**2))
print("On-disk size (bytes):", disk_bytes, " (~{:.2f} MB)".format(disk_bytes/1024**2))
print("Meta:", out['meta'])

# Quick integrity check: dequantize a few large layers and print norm diffs vs original canon tensor
def dequantize_entry(info):
    """Return the float32 reconstruction q * scale of one quant-store entry."""
    levels, step = (info[field].float() for field in ('q', 'scale'))
    return torch.mul(levels, step).float()

print("\nSample dequant checks for a few largest quantized params:")
# Eight largest quantized tensors by element count.
top_quant = sorted(
    quant_store.items(),
    key=lambda item: item[1]['q'].numel(),
    reverse=True,
)[:8]
for name, info in top_quant:
    deq = dequantize_entry(info)
    orig_t = canon_to_tensor[name]
    # Round-trip error of quantize -> dequantize against the source tensor.
    err = orig_t - deq
    max_abs = err.abs().max().item()
    rmse = err.pow(2).mean().sqrt().item()
    print(f" {name}: elems={info['q'].numel()}, bits={info['bits']}, dtype={info['dtype']}, max_abs={max_abs:.6e}, rmse={rmse:.6e}")

print("\nDone. If disk size still large, inspect fp32_fallback list above and adjust exclude_prefixes or bits_assignment.")


Loaded state_dict with 166 entries.
Canonical params found: 145
Prepared quant_store entries: 41
Prepared fp32_fallback entries: 104
Bytes: quant ints+scales: 11455184 fallback fp32: 85080
Saved to: /content/drive/My Drive/Deep Fake Dataset/trained_cnn/quantized_resnet_all8_mapped_lav.pth
Original (canon) FP32 total bytes (approx): 45820504  (~43.70 MB)
Quant stored ints+scales bytes (sum): 11455184  (~11186.70 KB)
Fallback FP32 bytes stored: 85080  (~0.08 MB)
On-disk size (bytes): 11602507  (~11.07 MB)
Meta: {'num_quantized': 41, 'num_fallback_fp32': 104, 'total_quant_bytes': 11455184, 'total_fallback_bytes': 85080}

Sample dequant checks for a few largest quantized params:
 backbone.layer4.0.conv2.weight: elems=2359296, bits=8, dtype=int8, max_abs=1.918875e-03, rmse=5.646053e-04
 backbone.layer4.1.conv1.weight: elems=2359296, bits=8, dtype=int8, max_abs=2.025366e-03, rmse=5.137838e-04
 backbone.layer4.1.conv2.weight: elems=2359296, bits=8, dtype=int8, max_abs=1.779601e-03, rmse=4.734