In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import os
import random
import numpy as np
import cv2  
import glob

# ==========================================
# 1. CẤU HÌNH
# ==========================================
CONFIG = dict(
    # Dataset root; expected sub-folders: ID/, Real/, Real_Depth/, Fake/.
    root_dir='/kaggle/input/finaldifaattack',
    # Side length that ID images, frames and depth maps are resized to.
    img_size=224,
    # Number of frames sampled from each video.
    seq_len=8,
    # Lower this if the videos are heavy.
    batch_size=4,
    num_epochs=20,
    lr=0.001,
    # Probability of pairing a real clip with another person's ID (mismatch pair).
    imposter_rate=0.3,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

# ==========================================
# 2. VIDEO LOADER
# ==========================================
def load_frames_from_video(video_path, seq_len=8):
    """Sample ``seq_len`` evenly spaced frames from a video file.

    Frame positions come from ``np.linspace`` over the frame count reported by
    OpenCV. When the video has fewer frames than ``seq_len``, the same position
    is requested several times. Each such frame is repeated exactly as often as
    it was requested, so the returned clip stays evenly spread in time.

    Args:
        video_path: Path to a video readable by ``cv2.VideoCapture``.
        seq_len: Number of frames to return.

    Returns:
        A list of ``seq_len`` RGB ``PIL.Image`` objects. Returns ``None`` when
        the file cannot be opened or reports no frames.

        If fewer frames could be decoded than the header promised, the last
        decoded frame is repeated. If nothing could be decoded at all, black
        ``CONFIG['img_size']`` x ``CONFIG['img_size']`` images are returned.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames < 1:
            return None

        # frame index -> how many times it is needed. Counts > 1 only happen
        # when the video is shorter than seq_len; a plain set would drop them.
        indices = np.linspace(0, total_frames - 1, seq_len).astype(int)
        unique_idx, repeats = np.unique(indices, return_counts=True)
        wanted = dict(zip(unique_idx.tolist(), repeats.tolist()))

        frames = []
        frame_idx = 0
        while len(frames) < seq_len:
            # grab() advances without decoding; only wanted frames are decoded.
            if not cap.grab():
                break
            repeat = wanted.get(frame_idx, 0)
            if repeat:
                ok, frame = cap.retrieve()
                if not ok:
                    break
                # OpenCV decodes to BGR; PIL expects RGB.
                image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                frames.extend([image] * repeat)
            frame_idx += 1
    finally:
        cap.release()

    # Pad when fewer frames were decoded than expected (e.g. wrong header count).
    if len(frames) < seq_len:
        if frames:
            frames.extend([frames[-1]] * (seq_len - len(frames)))
        else:
            # Completely undecodable video: fall back to black frames.
            frames = [Image.new('RGB', (CONFIG['img_size'], CONFIG['img_size']))
                      for _ in range(seq_len)]

    return frames

# ==========================================
# 3. DATASET 
# ==========================================
class eKYCDataset(Dataset):
    """Video eKYC dataset pairing a face clip with an ID image.

    Layout scanned under ``root_dir``::

        ID/<person>/*.jpg|*.png|*.jpeg
        Real/<person>/*.mp4|*.avi|*.mov
        Real_Depth/<person>/<same file name as the Real video>
        Fake/Print-attack/<person>/*.mp4|*.avi|*.mov
        Fake/Video-replay/<person>/*.mp4|*.avi|*.mov

    Only people with at least one ID image produce samples. Each item is a
    dict with these keys:

    * ``id``: the transformed ID image.
    * ``video``: ``seq_len`` transformed frames stacked on dim 0.
    * ``depth``: ``seq_len`` single-channel depth maps.
    * ``label_live``: 1 for real clips, 0 for attacks.
    * ``label_match``: 0 only for a real clip paired with another person's
      ID, otherwise 1.
    """

    IMAGE_PATTERNS = ('*.jpg', '*.png', '*.jpeg')
    VIDEO_PATTERNS = ('*.mp4', '*.avi', '*.mov')
    ATTACK_TYPES = ('Print-attack', 'Video-replay')

    def __init__(self, root_dir, transform=None, seq_len=8, imposter_rate=0.3,
                 max_load_attempts=20):
        """
        Args:
            root_dir: Dataset root (layout described on the class).
            transform: Transform applied to the ID image and to every video
                frame. When falsy, a resize to ``CONFIG['img_size']`` plus
                ``ToTensor`` is used.
            seq_len: Number of frames sampled per video.
            imposter_rate: Probability that a real clip is paired with a
                different person's ID.
            max_load_attempts: How many samples ``__getitem__`` tries,
                starting at the requested index and moving forward with
                wrap-around, before giving up on unreadable media.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.seq_len = seq_len
        self.imposter_rate = imposter_rate
        self.max_load_attempts = max_load_attempts

        self.samples = []        # one dict per video clip
        self.person_id_map = {}  # person name -> list of ID image paths

        self._prepare_db()

    @staticmethod
    def _glob_sorted(folder, patterns):
        """Sorted paths in ``folder`` matching any pattern (folder name is escaped)."""
        safe_folder = glob.escape(folder)
        return sorted(
            path
            for pattern in patterns
            for path in glob.glob(os.path.join(safe_folder, pattern))
        )

    def _prepare_db(self):
        """Fill ``person_id_map`` and ``samples`` by scanning ``root_dir``.

        Directory listings are sorted so the sample order is reproducible.
        """
        # 1. ID images first: a clip is only usable if its person has an ID.
        id_root = os.path.join(self.root_dir, 'ID')
        if os.path.exists(id_root):
            for person in sorted(os.listdir(id_root)):
                person_dir = os.path.join(id_root, person)
                if not os.path.isdir(person_dir):
                    continue
                images = self._glob_sorted(person_dir, self.IMAGE_PATTERNS)
                if images:
                    self.person_id_map[person] = images

        # 2. Real videos; the depth video is expected under Real_Depth with the same name.
        real_root = os.path.join(self.root_dir, 'Real')
        if os.path.exists(real_root):
            for person in sorted(os.listdir(real_root)):
                if person not in self.person_id_map:
                    continue
                person_dir = os.path.join(real_root, person)
                for vid in self._glob_sorted(person_dir, self.VIDEO_PATTERNS):
                    self.samples.append({
                        'type': 'real',
                        'video_path': vid,
                        'person_name': person,
                        'depth_path': os.path.join(self.root_dir, 'Real_Depth', person,
                                                   os.path.basename(vid)),
                    })

        # 3. Attack videos; they must target a person that has an ID.
        fake_root = os.path.join(self.root_dir, 'Fake')
        for attack in self.ATTACK_TYPES:
            attack_dir = os.path.join(fake_root, attack)
            if not os.path.exists(attack_dir):
                continue
            for person in sorted(os.listdir(attack_dir)):
                if person not in self.person_id_map:
                    continue
                person_dir = os.path.join(attack_dir, person)
                for vid in self._glob_sorted(person_dir, self.VIDEO_PATTERNS):
                    self.samples.append({
                        'type': 'fake',
                        'video_path': vid,
                        'person_name': person,
                        'depth_path': None,
                    })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the training dict for sample ``idx``.

        If a sample's media cannot be loaded, the following samples are tried
        in turn (wrapping around), at most ``max_load_attempts`` times in total.

        Raises:
            IndexError: ``idx`` is out of range.
            RuntimeError: no sample could be loaded within the attempt budget.
                The last loading error is chained as ``__cause__``.
        """
        num_samples = len(self.samples)
        if not -num_samples <= idx < num_samples:
            raise IndexError(f"sample index {idx} out of range ({num_samples} samples)")

        attempts = max(1, self.max_load_attempts)
        last_error = None
        for offset in range(attempts):
            sample = self.samples[(idx + offset) % num_samples]
            is_imposter, target_name = self._choose_id_owner(sample)
            # A random ID image per call keeps the model from fitting one photo.
            id_path = random.choice(self.person_id_map[target_name])
            try:
                id_img, video_frames, depth_frames = self._load_media(sample, id_path)
            except Exception as err:  # unreadable media: move on to the next sample
                last_error = err
                continue
            return self._to_item(sample, is_imposter, id_img, video_frames, depth_frames)

        raise RuntimeError(
            f"Could not load any sample after {attempts} attempts starting at index {idx}"
        ) from last_error

    def _choose_id_owner(self, sample):
        """Decide whose ID image is paired with ``sample``.

        A real clip is paired with another person's ID with probability
        ``imposter_rate``, but only if more than one person exists. Attack
        clips always use the victim's own ID.

        Returns:
            ``(is_imposter, target_name)``.
        """
        person_name = sample['person_name']
        if (sample['type'] == 'real'
                and random.random() < self.imposter_rate
                and len(self.person_id_map) > 1):
            others = [name for name in self.person_id_map if name != person_name]
            return True, random.choice(others)
        return False, person_name

    def _load_media(self, sample, id_path):
        """Load the ID image, the RGB clip and its depth maps as PIL images.

        For a real clip, depth frames are read from ``sample['depth_path']``
        and converted to grayscale. All-zero depth maps the size of the first
        video frame are used instead in two cases: for attack clips, and when
        the depth file is missing or unreadable.

        Raises:
            Exception: any error while opening or decoding the media.
        """
        with Image.open(id_path) as raw_id:
            id_img = raw_id.convert('RGB')

        video_frames = load_frames_from_video(sample['video_path'], self.seq_len)
        if video_frames is None:
            raise OSError(f"Video load error: {sample['video_path']}")

        depth_frames = None
        if sample['type'] == 'real' and os.path.exists(sample['depth_path']):
            loaded = load_frames_from_video(sample['depth_path'], self.seq_len)
            if loaded:
                depth_frames = [frame.convert('L') for frame in loaded]
        if depth_frames is None:
            width, height = video_frames[0].size
            depth_frames = [Image.new('L', (width, height), 0) for _ in range(self.seq_len)]

        return id_img, video_frames, depth_frames

    def _to_item(self, sample, is_imposter, id_img, video_frames, depth_frames):
        """Turn loaded PIL media and labels into the tensor dict of ``__getitem__``."""
        size = (CONFIG['img_size'], CONFIG['img_size'])

        label_live = 1.0 if sample['type'] == 'real' else 0.0
        # Only a real clip shown with someone else's ID is a non-match;
        # attacks are presented with the victim's own ID.
        label_match = 0.0 if (sample['type'] == 'real' and is_imposter) else 1.0

        if self.transform:
            frame_transform = self.transform
        else:
            frame_transform = transforms.Compose([transforms.Resize(size), transforms.ToTensor()])
        depth_transform = transforms.Compose([transforms.Resize(size), transforms.ToTensor()])

        id_tensor = frame_transform(id_img)
        video_tensor = torch.stack([frame_transform(img) for img in video_frames])
        depth_tensor = torch.stack([depth_transform(d) for d in depth_frames])

        return {
            'id': id_tensor,
            'video': video_tensor,
            'depth': depth_tensor,
            'label_live': torch.tensor([label_live], dtype=torch.float32),
            'label_match': torch.tensor([label_match], dtype=torch.float32),
        }

# ==========================================
# 4. MODEL 
# ==========================================
class MobileKYCModel(nn.Module):
    """MobileNetV3-Small based liveness and face/ID matching model for short clips.

    ``forward(id_img, video_frames)`` returns
    ``(liveness_logit, matching_logit, pred_depth_full)``:

    * ``id_img``: ID image batch, ``(B, 3, H, W)``.
    * ``video_frames``: clip batch, ``(B, S, 3, H, W)``. Non-contiguous
      tensors are accepted.
    * Both logits are ``(B, 1)``; apply ``sigmoid`` for probabilities.
    * ``pred_depth_full`` is ``(B, S, 1, H, W)`` with values in ``[0, 1]``
      (Sigmoid output, bilinearly upsampled). It is the auxiliary depth output.
    """

    FEATURE_DIM = 576  # channel count of the encoder (MobileNetV3-Small `features`) output
    EMBED_DIM = 128    # size of the embedding produced by `reduce_dim`
    HIDDEN_DIM = 128   # LSTM hidden size, which is also the input size of both heads

    def __init__(self, pretrained=True):
        """
        Args:
            pretrained: Initialise the backbone with torchvision's pretrained
                weights. Pass ``False`` when a trained ``state_dict`` is loaded
                right afterwards, to skip the download.
        """
        super().__init__()
        backbone = models.mobilenet_v3_small(pretrained=pretrained)
        # Only the convolutional trunk (first child) is used; the classifier is dropped.
        self.encoder = backbone.features
        self.feature_dim = self.FEATURE_DIM

        # Per-frame depth head on the backbone feature map.
        self.depth_conv = nn.Conv2d(self.feature_dim, 128, kernel_size=1)
        self.depth_estimator = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )
        # Projection shared by the ID image and the video frames.
        self.reduce_dim = nn.Linear(self.feature_dim, self.EMBED_DIM)
        # Each time step receives [frame embedding, |frame - ID| difference].
        self.rnn = nn.LSTM(input_size=2 * self.EMBED_DIM, hidden_size=self.HIDDEN_DIM,
                           num_layers=1, batch_first=True)

        self.liveness_head = self._make_head(self.HIDDEN_DIM)
        self.matching_head = self._make_head(self.HIDDEN_DIM)

    @staticmethod
    def _make_head(in_features):
        """Small MLP producing a single logit."""
        return nn.Sequential(
            nn.Linear(in_features, 64), nn.ReLU(), nn.Dropout(0.5), nn.Linear(64, 1)
        )

    def forward_backbone(self, x):
        """Return the encoder feature map of ``x`` and its global-average-pooled vector."""
        feat_map = self.encoder(x)
        feat_vec = F.adaptive_avg_pool2d(feat_map, (1, 1)).flatten(1)
        return feat_map, feat_vec

    def forward(self, id_img, video_frames):
        batch_size, seq_len, c, h, w = video_frames.size()

        # ID embedding.
        _, id_vec = self.forward_backbone(id_img)
        id_emb = self.reduce_dim(id_vec)

        # All frames go through the backbone as one (B*S) batch. reshape rather
        # than view, so non-contiguous inputs (e.g. after a permute) also work.
        frames = video_frames.reshape(batch_size * seq_len, c, h, w)
        frame_maps, frame_vecs = self.forward_backbone(frames)

        # Depth map per frame; its mean (in [0, 1]) rescales the frame feature
        # by a factor in [0.5, 1.5] before projection.
        depth_small = self.depth_estimator(self.depth_conv(frame_maps))
        depth_score = depth_small.mean(dim=[2, 3]).view(batch_size * seq_len, 1)
        frame_emb = self.reduce_dim(frame_vecs * (0.5 + depth_score))

        # Compare every frame embedding with the ID embedding.
        frame_seq = frame_emb.view(batch_size, seq_len, -1)
        id_seq = id_emb.unsqueeze(1).expand(-1, seq_len, -1)
        diff_feat = torch.abs(frame_seq - id_seq)

        # Temporal aggregation; the last time step summarises the clip.
        rnn_out, _ = self.rnn(torch.cat([frame_seq, diff_feat], dim=2))
        final_feat = rnn_out[:, -1, :]

        liveness_logit = self.liveness_head(final_feat)
        matching_logit = self.matching_head(final_feat)

        # Depth output at input resolution, for the depth-supervision loss.
        pred_depth_full = F.interpolate(depth_small, size=(h, w), mode='bilinear', align_corners=False)
        pred_depth_full = pred_depth_full.view(batch_size, seq_len, 1, h, w)

        return liveness_logit, matching_logit, pred_depth_full

# ==========================================
# 5. TRAINING LOOP 
# ==========================================
def train_model(save_path="ekyc_mobile_video_model.pth"):
    """Train ``MobileKYCModel`` on the dataset at ``CONFIG['root_dir']``.

    All hyper-parameters are read from ``CONFIG``: image size, clip length,
    batch size, epochs, learning rate, imposter rate and device. The objective
    is ``BCE(liveness) + BCE(match) + 0.5 * MSE(depth)``. Loss and liveness /
    matching accuracy are printed for every epoch.

    Args:
        save_path: File the final ``state_dict`` is written to.

    Returns:
        The trained model, or ``None`` when the dataset contains no samples.
    """
    device = CONFIG['device']
    # ImageNet mean/std normalization.
    transform = transforms.Compose([
        transforms.Resize((CONFIG['img_size'], CONFIG['img_size'])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    print("Scanning dataset (Video Files)... This may take a moment.")
    dataset = eKYCDataset(
        root_dir=CONFIG['root_dir'],
        transform=transform,
        seq_len=CONFIG['seq_len'],
        imposter_rate=CONFIG['imposter_rate'],  # honour the configured mismatch rate
    )

    if len(dataset) == 0:
        print("No samples found! Check directory structure.")
        return None
    print(f"Found {len(dataset)} video samples.")
    print(f"Found IDs for {len(dataset.person_id_map)} people.")

    # Video decoding is CPU/IO heavy, so keep the worker count low.
    dataloader = DataLoader(dataset, batch_size=CONFIG['batch_size'], shuffle=True, num_workers=2)

    model = MobileKYCModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['lr'])
    criterion_cls = nn.BCEWithLogitsLoss()
    criterion_depth = nn.MSELoss()

    print("Start Training...")

    for epoch in range(CONFIG['num_epochs']):
        model.train()
        running_loss = 0.0
        acc_live = 0
        acc_match = 0
        total_samples = 0

        for i, batch in enumerate(dataloader):
            id_img = batch['id'].to(device)
            video = batch['video'].to(device)
            gt_depth = batch['depth'].to(device)
            lbl_live = batch['label_live'].to(device)
            lbl_match = batch['label_match'].to(device)

            optimizer.zero_grad()
            out_live, out_match, pred_depth = model(id_img, video)

            loss_live = criterion_cls(out_live, lbl_live)
            loss_match = criterion_cls(out_match, lbl_match)
            # Depth is supervised for every clip. Clips without a depth video,
            # including all attacks, are pushed towards an all-zero map.
            loss_depth = criterion_depth(pred_depth, gt_depth)
            total_loss = loss_live + loss_match + 0.5 * loss_depth

            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()

            # Accuracy bookkeeping needs no autograd graph.
            with torch.no_grad():
                acc_live += ((torch.sigmoid(out_live) > 0.5).float() == lbl_live).sum().item()
                acc_match += ((torch.sigmoid(out_match) > 0.5).float() == lbl_match).sum().item()
            total_samples += lbl_live.size(0)

            if i % 10 == 0:
                print(f"Ep {epoch+1} [{i}/{len(dataloader)}] Loss: {total_loss.item():.4f}")

        print(f"=== Epoch {epoch+1} ===")
        print(f"Loss: {running_loss/len(dataloader):.4f}")
        print(f"Liveness Acc: {acc_live/total_samples:.4f}")
        print(f"Matching Acc: {acc_match/total_samples:.4f}")

    torch.save(model.state_dict(), save_path)
    print("Model Saved!")
    return model


In [3]:
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils import mobile_optimizer

# 1. Rebuild the architecture and load the trained weights onto the CPU.
# NOTE(review): MobileKYCModel() also loads torchvision backbone weights
# (pretrained=True), which load_state_dict then overwrites.
# NOTE(review): torch.load unpickles the checkpoint; only load files you trust.
device = 'cpu'
model = MobileKYCModel()
model.load_state_dict(
    torch.load(
        "/kaggle/input/ekyc-training/ekyc_mobile_video_model.pth",
        map_location=device,
    )
)
model.eval()

# 2. Tạo Wrapper để loại bỏ Depth Output và thêm Sigmoid
class InferenceWrapper(torch.nn.Module):
    """Export-time wrapper around the trained model.

    Drops the depth output and turns the two logits into probabilities:
    ``forward(id_img, video_frames)`` returns
    ``(sigmoid(liveness_logit), sigmoid(matching_logit))``.
    """

    def __init__(self, original_model):
        super().__init__()
        self.model = original_model

    def forward(self, id_img, video_frames):
        # The wrapped model returns (liveness_logit, matching_logit, pred_depth);
        # the depth map is only needed for training.
        live_logit, match_logit, _depth = self.model(id_img, video_frames)
        return torch.sigmoid(live_logit), torch.sigmoid(match_logit)

wrapper_model = InferenceWrapper(model)
wrapper_model.eval()

# 3. Example inputs for tracing, shaped like the training inputs:
#    ID image (1, 3, S, S) and clip (1, T, 3, S, S), where S = CONFIG['img_size']
#    and T = CONFIG['seq_len']. The values are random; only the shapes matter.
# NOTE(review): training resized to S x S and applied ImageNet mean/std
# normalization. InferenceWrapper does no preprocessing, so the app must do it.
# NOTE(review): tracing records the ops run for these example inputs; confirm
# the on-device batch size and clip length match them.
dummy_id = torch.randn(1, 3, CONFIG['img_size'], CONFIG['img_size'])
dummy_video = torch.randn(1, CONFIG['seq_len'], 3, CONFIG['img_size'], CONFIG['img_size'])

# 4. Trace the wrapper. No autograd bookkeeping is needed for export.
with torch.no_grad():
    traced_script_module = torch.jit.trace(wrapper_model, (dummy_id, dummy_video))

# Mobile graph optimizations, via the `mobile_optimizer` module imported above.
traced_script_module_optimized = mobile_optimizer.optimize_for_mobile(traced_script_module)

# Save for the PyTorch Lite interpreter (.ptl).
MOBILE_MODEL_PATH = "ekyc_model_mobile.ptl"
traced_script_module_optimized._save_for_lite_interpreter(MOBILE_MODEL_PATH)

print(f"Convert thành công! File '{MOBILE_MODEL_PATH}' đã sẵn sàng.")

Convert thành công! File 'ekyc_model_mobile.ptl' đã sẵn sàng.
