In [33]:
import os
import json
from pathlib import Path
from collections import defaultdict, Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import timm
from helpers.helper import Helper
from torch.optim.lr_scheduler import CosineAnnealingLR
from PIL import Image
from datetime import datetime
import pandas as pd

import cv2
import numpy as np


class ProtoInfer:
    """Nearest-prototype classifier built on an embedding backbone.

    Each image is embedded with ``get_embeddings`` and compared with every
    stored prototype by dot product; the best-matching prototype decides the
    label unless its similarity falls below the threshold ``tau``, in which
    case the image is labelled ``"unknown"``.
    """

    def __init__(self, model, device, proto_ckpt_path, class_mapping_path=None):
        """
        Load the prototype checkpoint and the optional class-name mapping.

        model: embedding backbone; put into eval mode and moved to ``device``.
        device: torch device used for the model and the prototypes.
        proto_ckpt_path: checkpoint with keys "prototypes" and "proto_labels"
            (and optionally "embedding_dim").
        class_mapping_path: path to class_mappings.json (optional)
        Supports:
          - dict index->name  (keys int or numeric strings)
          - list              (index->name)
          - dict protoLabel->name (keys are strings that match checkpoint proto_labels)

        Raises ValueError when the checkpoint's "embedding_dim" disagrees with
        the prototype matrix width.
        """
        self.model = model.eval().to(device)
        self.device = device
        ckpt = torch.load(proto_ckpt_path, map_location=device)

        self.prototypes = ckpt["prototypes"].to(device)    # [K,D] L2-normalized
        self.proto_labels = ckpt["proto_labels"]           # list[str]
        # NOTE(review): the checkpoint's "tau" is deliberately ignored in favour
        # of this constant. If embeddings and prototypes are both L2-normalized,
        # similarities cannot exceed 1, so this default labels every image
        # "unknown" unless the caller passes an explicit tau -- confirm intent.
        self.tau = 4.0  # float(ckpt["tau"])
        print("tau : ", self.tau)
        emb_dim = ckpt.get("embedding_dim", None)

        if emb_dim is not None and self.prototypes.shape[1] != emb_dim:
            raise ValueError(
                f"Embedding dim mismatch: ckpt={emb_dim}, prototypes D={self.prototypes.shape[1]}"
            )

        # ---- Load optional class mapping
        self.idx_to_name = None         # list or dict[int]->str
        self.label_to_name = None       # dict[str]->str

        if class_mapping_path is not None:
            with open(class_mapping_path, "r", encoding="utf-8") as f:
                mapping = json.load(f)

            if isinstance(mapping, list):
                # list -> index mapping
                self.idx_to_name = list(mapping)
            elif isinstance(mapping, dict):
                # dict -> either index->name or protoLabel->name
                if all(self._is_int_like(k) for k in mapping):
                    # build int-keyed dict (normalize str keys -> int)
                    self.idx_to_name = {int(k): v for k, v in mapping.items()}
                else:
                    # assume keys are proto label strings
                    self.label_to_name = dict(mapping)

        # Flags telling the lookup which mapping (if any) is available.
        self._have_idx_map = self.idx_to_name is not None
        self._have_lbl_map = self.label_to_name is not None

    @staticmethod
    def _is_int_like(key):
        """Return True for ints and for strings made only of digits."""
        if isinstance(key, int):
            return True
        return isinstance(key, str) and key.isdigit()

    def _lookup_name(self, proto_idx):
        """Translate a prototype index into a display name.

        Order of precedence: index mapping, then prototype-label mapping, then
        the raw prototype label from the checkpoint.
        """
        if self._have_idx_map:
            idx_map = self.idx_to_name
            if isinstance(idx_map, dict):
                if proto_idx in idx_map:
                    return idx_map[proto_idx]
            elif 0 <= proto_idx < len(idx_map):
                # A list maps by position; `proto_idx in idx_map` would test
                # the list's *values* and never match an integer index.
                return idx_map[proto_idx]

        proto_lbl = self.proto_labels[proto_idx]
        if self._have_lbl_map and proto_lbl in self.label_to_name:
            return self.label_to_name[proto_lbl]
        return proto_lbl

    @torch.no_grad()
    def predict_stack(self, images_tensors, tau=None):
        """
        images_tensors: list[Tensor[C,H,W]] transformed like val_transform
        tau: similarity threshold; defaults to ``self.tau`` when None
        returns: list[[class_name, count], ...] (includes 'unknown');
                 an empty list when no images are given
        """
        if tau is None:
            tau = self.tau
        if not images_tensors:
            # torch.stack raises on an empty list; there is nothing to vote on.
            return []

        x = torch.stack(images_tensors).to(self.device)  # [B,C,H,W]
        feats = get_embeddings(self.model, x)            # [B,D]
        sims = feats @ self.prototypes.T                 # [B,K]
        max_sims, idxs = sims.max(dim=1)

        labels = [
            self._lookup_name(i) if s >= tau else "unknown"
            for s, i in zip(max_sims.tolist(), idxs.tolist())
        ]
        return [[k, v] for k, v in Counter(labels).items()]


# =========================
# EMBEDDINGS / PROTOTYPES
# =========================
@torch.no_grad()
def get_embeddings(model, x):
    """
    Return unit-length (L2-normalized along dim 1) pre-logits embeddings.

    The model must expose the timm-style pair ``forward_features`` and
    ``forward_head(..., pre_logits=True)``; the first yields the
    pre-classifier features and the second pools them into one vector per
    input.
    """
    return F.normalize(
        model.forward_head(model.forward_features(x), pre_logits=True),
        dim=1,
    )
@torch.no_grad()
def extract_foreground_mask(x, pixel_thresh=4, mean=None, std=None):
    """
    Build a boolean foreground mask from a batch of image tensors.

    x: [B, 3, H, W] normalized tensor
    pixel_thresh: intensity threshold on the 0-255 scale
    mean, std: optional per-channel normalization statistics. When both are
        given, ``x`` is mapped back to pixel space before thresholding, so a
        black pixel is treated as background.
    returns: [B, 1, H, W] boolean mask

    NOTE(review): without ``mean``/``std`` the threshold is applied directly
    to the normalized values (original behaviour, kept for compatibility).
    A pixel that was 0 before mean/std normalization is strongly negative
    afterwards, so its absolute sum exceeds the threshold and it is reported
    as foreground -- pass ``mean``/``std`` if black should mean background.

    Raises ValueError when only one of ``mean``/``std`` is supplied.
    """
    if (mean is None) != (std is None):
        raise ValueError("mean and std must be provided together")

    if mean is not None:
        mean_t = torch.as_tensor(mean, dtype=x.dtype, device=x.device).view(1, -1, 1, 1)
        std_t = torch.as_tensor(std, dtype=x.dtype, device=x.device).view(1, -1, 1, 1)
        pixels = x * std_t + mean_t          # undo Normalize -> roughly [0, 1]
        mag = pixels.abs().sum(dim=1, keepdim=True)
        return mag > (pixel_thresh / 255.0)

    # Legacy path: threshold scaled by the first-channel std only.
    t = (pixel_thresh / 255.0) / 0.229
    mag = x.abs().sum(dim=1, keepdim=True)
    return mag > t


@torch.no_grad()
def get_region_embeddings_masked(
    model,
    x,
    num_regions=3,
    min_fg_ratio=0.12
):
    """
    Compute one foreground-pooled, L2-normalized embedding per region.

    The feature map from ``model.forward_features`` is cut into
    ``num_regions`` chunks along dim 2. Inside each chunk the features are
    averaged over the positions that ``extract_foreground_mask`` marks as
    foreground (the mask is resized to the feature-map resolution with
    nearest-neighbour interpolation).

    returns: (embeddings [B, R, D], valid [B, R]) where ``valid`` is True for
             regions whose foreground ratio is at least ``min_fg_ratio``.
    """
    feature_map = model.forward_features(x)   # [B, C, Hf, Wf]
    _, _, feat_h, feat_w = feature_map.shape

    # Bring the input-resolution mask down to the feature-map grid.
    mask = F.interpolate(
        extract_foreground_mask(x).float(),
        size=(feat_h, feat_w),
        mode="nearest"
    ).bool()

    embeddings = []
    validity = []
    region_pairs = zip(
        torch.chunk(feature_map, num_regions, dim=2),
        torch.chunk(mask, num_regions, dim=2),
    )
    for band_feats, band_mask in region_pairs:
        fg_count = band_mask.sum(dim=[2, 3])
        band_area = band_mask.shape[2] * band_mask.shape[3]
        validity.append((fg_count / band_area).squeeze(1) >= min_fg_ratio)

        # Masked mean; the epsilon keeps all-background regions finite.
        masked_sum = (band_feats * band_mask).sum(dim=[2, 3])
        embeddings.append(F.normalize(masked_sum / (fg_count + 1e-6), dim=1))

    stacked_embeddings = torch.stack(embeddings, dim=1)   # [B, R, D]
    stacked_validity = torch.stack(validity, dim=1)       # [B, R]
    return stacked_embeddings, stacked_validity


class CATTLE_IDENTIFICATION():
    """Identify cattle from folders of cropped tracking images.

    The backbone, prototypes and transforms are class attributes, so the
    model files under ``SAVE_PATH`` are read once, when this class statement
    is executed.
    """

    # Tracks with fewer images than this are skipped when has_len_restriction is set.
    MIN_IMAGES_PER_TRACK = 20
    # Similarity threshold handed to ProtoInfer.predict_stack; weaker matches become "unknown".
    SIMILARITY_THRESHOLD = 0.3
    # Upper bound on the number of extra 'RC_TRUE' files added to a batch.
    MAX_RECHECK_FILES = 50

    data_transforms = {
        'train': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        # No Resize here: images are used at their stored size, so every image
        # of one batch must have the same size for torch.stack to succeed.
        'validation': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    SAVE_PATH = Path("identification_models_KNP") / "KNP_CONVNEXT_ARCFACE_22_Jan_2026_fixed_02"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Built with pathlib so the separator is correct on every platform.
    class_mapping_path = str(SAVE_PATH / "class_mappings.json")

    # Number of classes, read from the mapping file (read_text closes the file).
    NUM_CLASSES = len(json.loads((SAVE_PATH / "class_mappings.json").read_text(encoding="utf-8")))

    # Backbone only, NO classifier (matches training num_classes=0)
    model = timm.create_model('convnextv2_base.fcmae_ft_in22k_in1k',
                              pretrained=False, num_classes=0).to(device)

    # Load backbone weights; strict=True since the architecture must match.
    state = torch.load(SAVE_PATH / "best_backbone.pth", map_location=device)
    model.load_state_dict(state, strict=True)
    model.eval()

    # Prototype-based classifier sharing the backbone above.
    infer = ProtoInfer(model, device, proto_ckpt_path=SAVE_PATH / "best_prototypes.pt", class_mapping_path=class_mapping_path)

    # PIL-oriented transform; not used by batch_identification.
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])

    def get_files_from_folder_scan_sort(self, path, limitX=20, Need_ReCheck=False):
        """Return the file names of ``path`` to be used for identification.

        Files are ordered by the integer before the first '_' in their name
        and the first ``limitX`` are selected. With ``Need_ReCheck`` the last
        ``MAX_RECHECK_FILES`` files whose name contains 'RC_TRUE' are appended,
        skipping any that are already selected so no image votes twice.

        Raises ValueError if a file name does not start with an integer prefix.
        """
        # os.scandir avoids a separate stat call per entry.
        with os.scandir(path) as entries:
            files = [entry.name for entry in entries if entry.is_file()]

        files.sort(key=lambda name: int(name.split('_')[0]))

        selected = files[:limitX]
        if Need_ReCheck:
            already_selected = set(selected)
            recheck = [f for f in files if 'RC_TRUE' in f][-self.MAX_RECHECK_FILES:]
            selected.extend(f for f in recheck if f not in already_selected)
        return selected

    def batch_identification(self, save_dir, tracking_id, FORCE_MISSED_DETECTION=False, path=None, batch_size=50, showCon=False, Need_ReCheck=False, has_len_restriction=True):
        """Identify the animal shown in one tracking folder.

        save_dir, tracking_id: the images are read from ``save_dir/tracking_id``
            unless ``path`` is given, in which case ``path`` is used directly.
        batch_size: number of leading images to use.
        Need_ReCheck: also include 'RC_TRUE' images (see
            get_files_from_folder_scan_sort).
        has_len_restriction: when True, folders with fewer than
            ``MIN_IMAGES_PER_TRACK`` selected images yield an empty result.
        FORCE_MISSED_DETECTION, showCon: accepted for compatibility; unused.

        returns: list of [name, count] pairs from ProtoInfer.predict_stack, or
                 an empty list when the folder is skipped or no image could be
                 read.
        """
        save_path = path if path is not None else os.path.join(save_dir, str(tracking_id))

        images = self.get_files_from_folder_scan_sort(save_path, batch_size, Need_ReCheck)
        if has_len_restriction and len(images) < self.MIN_IMAGES_PER_TRACK:
            return []

        to_tensor = self.data_transforms["validation"]
        stacked_images = []
        for image in images:
            frame = cv2.imread(os.path.join(save_path, image))
            if frame is None:
                # imread signals an unreadable or non-image file with None.
                continue
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            stacked_images.append(to_tensor(frame))   # Tensor[C,H,W]

        if not stacked_images:
            return []

        return self.infer.predict_stack(stacked_images, self.SIMILARITY_THRESHOLD)
    
# Shared identifier instance used by the cells below.
begin_identification = CATTLE_IDENTIFICATION()

tau :  4.0


In [None]:

# Identify a single tracking folder (track "11" of camera 002) and report the
# name with the most votes.
batch_result = begin_identification.batch_identification(
    r"H:\Output\runs\KUNNEPPU_identification\ConvNext_ArcFace 22 Jan 2026\KNP\Base_rtx8000_10_August_2025_20000_v1\Ch002_Camera 002_E2\2025-12-24_0220_part 1",
    "11", batch_size=50, showCon=False,
    Need_ReCheck=True, has_len_restriction=False)

print(batch_result)
if batch_result:
    max_result = max(batch_result, key=lambda x: x[1])
    print(max_result)
else:
    # max() raises ValueError on an empty sequence, so handle it explicitly.
    max_result = None
    print("No identification result for this track.")

[['4567', 29], ['3875', 1], ['4990', 1], ['A18', 7], ['A10', 12]]
['4567', 29]


: 

In [2]:
# Root folder of the run and the per-camera sub-paths to process.
base_dir = r'H:\Output\runs\KUNNEPPU_identification\ConvNext_ArcFace 22 Jan 2026\KNP\Base_rtx8000_10_August_2025_20000_v1'
cameras = [
    '/Ch001_Camera 001_E2/2025-12-24_0023_part 1/',
    '/Ch002_Camera 002_E2/2025-12-24_0023_part 1/',
    '/Ch003_Camera 003_E2/2025-12-24_0023_part 1/',
    '/Ch004_Camera 004_E2/2025-12-24_0023_part 1/',
    '/Ch005_Camera 005_E2/2025-12-24_0023_part 1/',
    '/Ch006_Camera 006_E2/2025-12-24_0023_part 1/',
    '/Ch007_Camera 007_E2/2025-12-24_0023_part 1/',
]


def resultToCsv(result, out_dir=None):
    """Write [TrackingID, Cattle ID] rows to ConvNextV2_Arcface_Result.csv.

    result: list of [tracking_id, cattle_id] pairs.
    out_dir: folder that receives the CSV; when omitted, the notebook-global
        ``path`` (the camera folder processed most recently) is used.
    """
    if out_dir is None:
        out_dir = path
    df = pd.DataFrame(result, columns=['TrackingID', 'Cattle ID'])
    csv_path = os.path.join(out_dir, 'ConvNextV2_Arcface_Result.csv')
    df.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")


for camera in cameras:
    # Plain concatenation on purpose: the camera entries start with '/'.
    path = base_dir + camera
    result = []
    for subfolder in os.listdir(path):
        subfolder_path = os.path.join(path, subfolder)
        if not os.path.isdir(subfolder_path):
            print(f"{subfolder} is not a directory.")
            continue

        print(f"Processing subfolder: {subfolder}")
        try:
            # Need_ReCheck=True also includes the 'RC_TRUE' images.
            batch_result = begin_identification.batch_identification(
                path, subfolder, batch_size=50, Need_ReCheck=True)
            print(batch_result)
            if batch_result:
                # The name with the highest vote count wins the track.
                maxResult = max(batch_result, key=lambda x: x[1])
                print(maxResult)
                result.append([subfolder, maxResult[0]])
        except Exception as e:
            # Keep going: one broken track must not abort the whole camera.
            print(f"Error processing subfolder {subfolder}: {e}")

    resultToCsv(result, path)


0023_identification.mp4 is not a directory.
Processing subfolder: 1
[['4979', 52]]
['4979', 52]
Processing subfolder: 100
[['L7', 51], ['4413', 1]]
['L7', 51]
Processing subfolder: 103
[['4554', 1], ['4307', 1], ['L5', 50]]
['L5', 50]
Processing subfolder: 105
[['4525', 12], ['L11', 6], ['4979', 7], ['4413', 2]]
['4525', 12]
Processing subfolder: 106
[['4413', 9], ['4525', 42], ['4979', 11], ['4803', 3], ['4161', 1]]
['4525', 42]
Processing subfolder: 107
[['4979', 2], ['L11', 1], ['4416', 1], ['4413', 44], ['4990', 1], ['4525', 1]]
['4413', 44]
Processing subfolder: 108
[['L5', 20]]
['L5', 20]
Processing subfolder: 109
[['4993', 52]]
['4993', 52]
Processing subfolder: 110
[]
Processing subfolder: 111
[['4979', 52]]
['4979', 52]
Processing subfolder: 112
[['1374', 45], ['L6', 5], ['L7', 6]]
['1374', 45]
Processing subfolder: 114
[]
Processing subfolder: 115
[['4815', 50]]
['4815', 50]
Processing subfolder: 119
[['4628', 1], ['L9', 2], ['4416', 1], ['4627', 1], ['4381', 19], ['4413', 7]

In [3]:
import csv
import os
from datetime import timedelta, datetime

# Frame rate used to turn frame numbers into elapsed seconds.
FPS = 5
# Clock anchors (date part is strptime's default; only the time of day is formatted later).
DAY_START = datetime.strptime("00:00:00", "%H:%M:%S")
NIGHT_START = datetime.strptime("22:00:00", "%H:%M:%S")
# NOTE(review): the name says five hours but the value is six hours -- confirm which is intended.
FIVE_HOURS = timedelta(hours=6)

def frame_to_time(frame_number):
    """Return the elapsed time of ``frame_number`` as a timedelta, at FPS frames per second."""
    return timedelta(seconds=frame_number / FPS)

def adjust_time_window(duration):
    """Map an elapsed ``duration`` onto a clock time.

    Durations up to FIVE_HOURS are counted from DAY_START; longer ones are
    counted from NIGHT_START after subtracting FIVE_HOURS.
    """
    if duration > FIVE_HOURS:
        return NIGHT_START + (duration - FIVE_HOURS)
    return DAY_START + duration

def process_tracking_folder(base_dir, tracking_id):
    """Return (start_time, end_time, duration) for one tracking folder.

    The files in ``base_dir/tracking_id`` are ordered by the integer before
    the first '_' of their name; the frame number is read from the third
    '_'-separated field of the first and last file and converted to a clock
    time with frame_to_time and adjust_time_window.

    Raises:
        IndexError: the folder contains no files, or a name has fewer than
            three '_'-separated fields.
        ValueError: a name's index or frame field is not an integer.
    """
    folder = os.path.join(base_dir, str(tracking_id))

    # Files only: sub-directories carry no frame information.
    with os.scandir(folder) as entries:
        filenames = [entry.name for entry in entries if entry.is_file()]
    if not filenames:
        raise IndexError(f"No frame files found in {folder}")

    # The name is the tie-breaker so equal indices keep a deterministic order.
    filenames.sort(key=lambda name: (int(name.split('_')[0]), name))
    first_name, last_name = filenames[0], filenames[-1]

    try:
        start_frame = int(first_name.split("_")[2])
        end_frame = int(last_name.split("_")[2])
    except (IndexError, ValueError):
        # Show the offending names before propagating the original error.
        print(f"Unexpected file name in {folder}: {first_name} / {last_name}")
        raise

    start_time = adjust_time_window(frame_to_time(start_frame))
    end_time = adjust_time_window(frame_to_time(end_frame))

    duration = end_time - start_time
    return start_time, end_time, duration

def process_csv(csv_path, base_dir, output_path):
    """Add start/end time and duration to every track of a result CSV.

    ``csv_path`` and ``output_path`` are file names appended directly to
    ``base_dir``. The input must provide the columns "TrackingID" and
    "Cattle ID"; the same ID is written to both "Ground-Truth" and
    "Predicted ID". Timing comes from process_tracking_folder.
    """
    source_file = base_dir + csv_path
    target_file = base_dir + output_path
    columns = ["TrackingID", "Ground-Truth", "Predicted ID", "StartTime", "EndTime", "Duration"]

    with open(source_file, newline='') as src:
        records = list(csv.DictReader(src))

    enriched = []
    for record in records:
        track = record["TrackingID"]
        # The identified ID is reused for both the ground-truth and predicted columns.
        cattle = record["Cattle ID"]

        began, finished, span = process_tracking_folder(base_dir, track)
        values = (
            track,
            cattle,
            cattle,
            began.strftime("%H:%M:%S"),
            finished.strftime("%H:%M:%S"),
            str(span),
        )
        enriched.append(dict(zip(columns, values)))

    # write output CSV
    with open(target_file, "w", newline='') as dst:
        out = csv.DictWriter(dst, fieldnames=columns)
        out.writeheader()
        out.writerows(enriched)

    print("Done:", target_file)


# Alternative run configuration (2025-12-26 recordings), kept for quick switching:
# base_dir = r'H:\Output\runs\KUNNEPPU_identification\ConvNext_ArcFace 22 Jan 2026\KNP\Base_rtx8000_10_August_2025_20000_v1'
# cameras = [
#     '/Ch001_Camera 001_E2/2025-12-26_0023_part 2/',
#     #  '/Ch002_Camera 002_E2/2025-12-26_0023_part 1/',
#     #  '/Ch003_Camera 003_E2/2025-12-26_0023_part 1/',
#     # '/Ch004_Camera 004_E2/2025-12-26_0023_part 1/',
#     #  '/Ch005_Camera 005_E2/2025-12-26_0023_part 1/',
#     #  '/Ch006_Camera 006_E2/2025-12-26_0023_part 1/',
#     #  '/Ch007_Camera 007_E2/2025-12-26_0023_part 1/',
#     ]
#base_dir = r'E:\Output\runs\KUNNEPPU_identification\EF_vector 8 Jan 2026\KNP\Base_rtx8000_December_2025_20000_v2'
# Active run: root folder and the per-camera sub-paths (each starts and ends with '/').
base_dir = r'H:\Output\runs\KUNNEPPU_identification\ConvNext_ArcFace 22 Jan 2026\KNP\Base_rtx8000_10_August_2025_20000_v1'
cameras = [
    '/Ch001_Camera 001_E2/2025-12-24_0023_part 1/',
     '/Ch002_Camera 002_E2/2025-12-24_0023_part 1/',
     '/Ch003_Camera 003_E2/2025-12-24_0023_part 1/',
    '/Ch004_Camera 004_E2/2025-12-24_0023_part 1/',
     '/Ch005_Camera 005_E2/2025-12-24_0023_part 1/',
     '/Ch006_Camera 006_E2/2025-12-24_0023_part 1/',
     '/Ch007_Camera 007_E2/2025-12-24_0023_part 1/',
    ]

# Input file name inside each camera folder.
csv_path = "ConvNextV2_Arcface_Result.csv"
# Placeholder only: output_path is reassigned per camera inside the loop below.
output_path = "ConvNextV2_Arcface_Result.csv"
for camera in cameras:
    # Text after the first space of the camera folder name,
    # e.g. 'Ch001_Camera 001_E2' -> '001_E2'.
    camName = camera.split('/')[1].split(' ')[1]
    output_path = f"{camName}_ConvNextV2_Arcface_Result.csv"
    #print("Processing camera: ", output_path)
    # process_csv appends the file names directly to the folder string passed here.
    process_csv(csv_path, base_dir + camera, output_path)

Done: H:\Output\runs\KUNNEPPU_identification\ConvNext_ArcFace 22 Jan 2026\KNP\Base_rtx8000_10_August_2025_20000_v1/Ch001_Camera 001_E2/2025-12-24_0023_part 1/001_E2_ConvNextV2_Arcface_Result.csv
Done: H:\Output\runs\KUNNEPPU_identification\ConvNext_ArcFace 22 Jan 2026\KNP\Base_rtx8000_10_August_2025_20000_v1/Ch002_Camera 002_E2/2025-12-24_0023_part 1/002_E2_ConvNextV2_Arcface_Result.csv
Done: H:\Output\runs\KUNNEPPU_identification\ConvNext_ArcFace 22 Jan 2026\KNP\Base_rtx8000_10_August_2025_20000_v1/Ch003_Camera 003_E2/2025-12-24_0023_part 1/003_E2_ConvNextV2_Arcface_Result.csv
Done: H:\Output\runs\KUNNEPPU_identification\ConvNext_ArcFace 22 Jan 2026\KNP\Base_rtx8000_10_August_2025_20000_v1/Ch004_Camera 004_E2/2025-12-24_0023_part 1/004_E2_ConvNextV2_Arcface_Result.csv
Done: H:\Output\runs\KUNNEPPU_identification\ConvNext_ArcFace 22 Jan 2026\KNP\Base_rtx8000_10_August_2025_20000_v1/Ch005_Camera 005_E2/2025-12-24_0023_part 1/005_E2_ConvNextV2_Arcface_Result.csv
Done: H:\Output\runs\KUNN

In [None]:
# Re-writes the CSV from the notebook globals `result` and `path` left by the
# camera loop, i.e. only for the camera that was processed last.
resultToCsv(result)