In [14]:
# train_and_proto.py
import os
import json
from pathlib import Path
from collections import defaultdict, Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import timm
from helpers.helper import Helper
from torch.optim.lr_scheduler import CosineAnnealingLR
from PIL import Image
from datetime import datetime
import pandas as pd

class ProtoInfer:
    """Nearest-prototype classifier on top of a feature-extracting model.

    Every image is embedded with ``get_embeddings`` and compared (dot product)
    against the prototype matrix stored in the checkpoint. The best-matching
    prototype wins when its similarity reaches ``tau``; otherwise the image is
    reported as ``"unknown"``.
    """

    def __init__(self, model, device, proto_ckpt_path, class_mapping_path=None):
        """
        Args:
            model: network passed to ``get_embeddings``; it is switched to
                eval mode and moved to ``device``.
            device: torch device used for the model and the prototypes.
            proto_ckpt_path: checkpoint readable by ``torch.load`` holding the
                keys ``"prototypes"``, ``"proto_labels"``, ``"tau"`` and,
                optionally, ``"embedding_dim"``.
            class_mapping_path: path to class_mappings.json (optional)
                Supports:
                  - dict index->name  (keys int or numeric strings)
                  - list              (index->name)
                  - dict protoLabel->name (keys are strings that match
                    checkpoint proto_labels)

        Raises:
            ValueError: if the checkpoint's ``embedding_dim`` disagrees with
                the second dimension of the prototype matrix.
        """
        self.model = model.eval().to(device)
        self.device = device
        ckpt = torch.load(proto_ckpt_path, map_location=device)

        self.prototypes = ckpt["prototypes"].to(device)    # [K,D], expected to be L2-normalized
        self.proto_labels = ckpt["proto_labels"]           # list[str], one label per prototype row
        self.tau = float(ckpt["tau"])
        emb_dim = ckpt.get("embedding_dim", None)

        if emb_dim is not None and self.prototypes.shape[1] != emb_dim:
            raise ValueError(
                f"Embedding dim mismatch: ckpt={emb_dim}, prototypes D={self.prototypes.shape[1]}"
            )

        # ---- Load optional class mapping
        self.idx_to_name = None         # list or dict[int]->str
        self.label_to_name = None       # dict[str]->str

        if class_mapping_path is not None:
            with open(class_mapping_path, "r", encoding="utf-8") as f:
                mapping = json.load(f)

            if isinstance(mapping, list):
                # list -> position is the index
                self.idx_to_name = list(mapping)
            elif isinstance(mapping, dict):
                # isdecimal() (unlike isdigit()) only accepts characters that
                # int() can actually parse, so the conversion below cannot fail.
                def _is_int_like(k):
                    return isinstance(k, int) or (isinstance(k, str) and k.isdecimal())

                if all(_is_int_like(k) for k in mapping.keys()):
                    # index -> name; normalize str keys to int
                    self.idx_to_name = {int(k): v for k, v in mapping.items()}
                else:
                    # assume keys are proto label strings
                    self.label_to_name = dict(mapping)

        # Cached flags so the per-image lookup stays cheap
        self._have_idx_map = self.idx_to_name is not None
        self._have_lbl_map = self.label_to_name is not None

    def _name_for_index(self, i):
        """Translate prototype index ``i`` into the name reported to callers.

        Resolution order: index mapping, then proto-label mapping, then the
        raw proto label from the checkpoint.
        """
        # NOTE(review): the index mapping is keyed by *prototype* index; this
        # presumably assumes one prototype per class in mapping order - confirm.
        if self._have_idx_map:
            if isinstance(self.idx_to_name, dict):
                if i in self.idx_to_name:
                    return self.idx_to_name[i]
            elif 0 <= i < len(self.idx_to_name):
                # For a list, `i in list` would test *values*, not positions,
                # so the bounds are checked explicitly.
                return self.idx_to_name[i]

        proto_lbl = self.proto_labels[i]
        if self._have_lbl_map and proto_lbl in self.label_to_name:
            return self.label_to_name[proto_lbl]
        return proto_lbl

    @torch.no_grad()
    def predict_stack(self, images_tensors, tau=None):
        """
        images_tensors: list[Tensor[C,H,W]] transformed like val_transform
        tau: similarity threshold; ``None`` uses the value stored in the checkpoint
        returns: list[[class_name, count], ...] (includes 'unknown'), in
                 first-seen order; an empty input yields an empty list
        """
        if tau is None:
            tau = self.tau

        if len(images_tensors) == 0:
            # torch.stack raises on an empty sequence
            return []

        x = torch.stack(images_tensors).to(self.device)  # [B,C,H,W]
        feats = get_embeddings(self.model, x)            # [B,D]
        sims = feats @ self.prototypes.T                 # [B,K]
        max_sims, idxs = sims.max(dim=1)

        labels = [
            self._name_for_index(i) if s >= tau else "unknown"
            for s, i in zip(max_sims.tolist(), idxs.tolist())
        ]

        counts = Counter(labels)
        return [[k, v] for k, v in counts.items()]


# =========================
# EMBEDDINGS / PROTOTYPES
# =========================
@torch.no_grad()
def get_embeddings(model, x):
    """
    Return unit-length (L2-normalized along dim 1) embeddings for batch ``x``.

    The model must expose the timm-style pair used here:
    ``forward_features(x)`` followed by ``forward_head(..., pre_logits=True)``,
    i.e. the head output taken before the classifier layer.
    Runs with gradient tracking disabled.
    """
    pooled = model.forward_head(model.forward_features(x), pre_logits=True)
    # p=2.0 is F.normalize's default norm, spelled out for clarity
    return F.normalize(pooled, p=2.0, dim=1)



In [15]:
import random
import cv2
import numpy as np

class CATTLE_IDENTIFICATION():
    """Identify cattle from folders of per-track image crops.

    Defining this class has side effects: the class body builds the ConvNeXtV2
    model, loads ``last_model.pth`` and ``prototypes.pt`` from ``SAVE_PATH`` and
    creates the shared ``ProtoInfer`` instance (``infer``), so those files must
    exist when the cell is executed.
    """

    # Tensor conversion + normalisation only. No resize/crop is applied, so all
    # images of one batch must already share the same size for torch.stack.
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomHorizontalFlip(0.2),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'validation': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    def get_files_from_folder_scan_sort(self, path, limitX=20):
        """Return up to ``limitX`` file names from ``path``, ordered by numeric prefix.

        The prefix is the part of the name before the first ``_`` and must
        parse as an integer. Files without such a prefix are skipped instead of
        aborting the whole folder.
        """
        numbered = []
        # os.scandir avoids an extra stat call per entry
        with os.scandir(path) as entries:
            for entry in entries:
                if not entry.is_file():
                    continue
                try:
                    order = int(entry.name.split('_')[0])
                except ValueError:
                    continue
                numbered.append((order, entry.name))

        # Stable sort on the numeric key only, so ties keep directory order
        numbered.sort(key=lambda item: item[0])
        return [name for _, name in numbered[:limitX]]

    def enhance_ir_fast(self, frame_bgr, *,
                        bilateral_d=7, bilateral_sigma_color=35, bilateral_sigma_space=35,
                        clahe_clip=2.0, clahe_tile=8,
                        unsharp_sigma=1.2, unsharp_amount=0.9):
        """Denoise, contrast-enhance and sharpen a BGR frame.

        Pipeline: bilateral filter and CLAHE on the luminance channel (YCrCb),
        conversion back to BGR, then an unsharp mask on the colour image.
        Returns the processed BGR image.
        """
        # Enable OpenCV CPU optimizations
        cv2.setUseOptimized(True)
        # A negative value restores OpenCV's default thread count; 0 would
        # disable threading and run everything sequentially.
        cv2.setNumThreads(-1)

        # --- Work in YCrCb so we touch only luminance ---
        ycrcb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2YCrCb)
        Y, Cr, Cb = cv2.split(ycrcb)

        # --- Fast denoise on Y (bilateral) ---
        # Author's note: ~10-30x faster than NLM for typical params
        Y = cv2.bilateralFilter(Y, d=bilateral_d,
                                sigmaColor=bilateral_sigma_color,
                                sigmaSpace=bilateral_sigma_space)

        # Optional: if you have ximgproc, guided filter is even faster/cleaner on edges
        # from cv2.ximgproc import guidedFilter
        # Y = guidedFilter(guide=Y, src=Y, radius=4, eps=1e-4)

        # --- Local contrast (CLAHE) on Y ---
        clahe = cv2.createCLAHE(clipLimit=clahe_clip,
                                tileGridSize=(clahe_tile, clahe_tile))
        Y = clahe.apply(Y)

        # --- Merge back to color ---
        ycrcb = cv2.merge([Y, Cr, Cb])
        out = cv2.cvtColor(ycrcb, cv2.COLOR_YCrCb2BGR)

        # --- Unsharp mask: original * (1 + amount) - blurred * amount ---
        blur = cv2.GaussianBlur(out, (0, 0), unsharp_sigma)
        out = cv2.addWeighted(out, 1.0 + unsharp_amount, blur, -unsharp_amount, 0)

        return out

    def unsharp_mask(self, bgr, radius=1.4, amount=1.0):
        """Sharpen ``bgr`` with an unsharp mask; returns a uint8 image.

        ``radius`` is the Gaussian sigma, ``amount`` the sharpening strength.
        """
        blurred = cv2.GaussianBlur(bgr, (0, 0), radius)
        sharp = cv2.addWeighted(bgr, 1 + amount, blurred, -amount, 0)
        return np.clip(sharp, 0, 255).astype(np.uint8)

    def apply_clahe_ycrcb(self, bgr, clip=2.5, tile=8):
        """Apply CLAHE to the luminance channel of a BGR image; chroma is untouched."""
        ycrcb = cv2.cvtColor(bgr, cv2.COLOR_BGR2YCrCb)
        y, cr, cb = cv2.split(ycrcb)
        clahe = cv2.createCLAHE(clipLimit=clip, tileGridSize=(tile, tile))
        y2 = clahe.apply(y)
        merged = cv2.merge([y2, cr, cb])
        return cv2.cvtColor(merged, cv2.COLOR_YCrCb2BGR)

    def _softmax_top1(self, images):
        """Run the classifier head and return ``[(class_index, probability), ...]``."""
        self.model.eval()
        image_tensor = torch.stack(images).to(self.device)
        with torch.no_grad():
            probabilities = torch.softmax(self.model(image_tensor), dim=1)
            max_probs, preds = torch.max(probabilities, 1)
        return list(zip(preds.tolist(), max_probs.tolist()))

    def predict_stack_cattle_id(self, images, threshold=0.85):
        """
        Predict cattle IDs from a stack of images using the softmax classifier.

        Args:
            images: List of preprocessed torch tensors (they are passed to torch.stack)
            threshold: Confidence threshold for predictions

        Returns:
            List of [class_name, count] pairs for predictions above threshold
        """
        # NOTE(review): self.class_names is not defined anywhere in this class;
        # it must be assigned (a mapping keyed by str(index)) before calling.
        filtered_classes = [
            self.class_names[str(pred)]
            for pred, prob in self._softmax_top1(images)
            if prob >= threshold
        ]
        class_counts = Counter(filtered_classes)
        return [[key, count] for key, count in class_counts.items()]

    def predict_stack_cattle_id_conf(self, images, threshold=0.7):
        """
        Like ``predict_stack_cattle_id`` but also prints every per-image prediction.

        Args:
            images: List of preprocessed torch tensors
            threshold: Confidence threshold for filtering final results

        Returns:
            List of [class_name, count] pairs for predictions above threshold.
            The per-image [class_name, confidence] list is printed, not returned.
        """
        # NOTE(review): self.class_names is not defined anywhere in this class;
        # it must be assigned (a mapping keyed by str(index)) before calling.
        predictions = []
        filtered_classes = []
        for pred, confidence in self._softmax_top1(images):
            class_name = self.class_names[str(pred)]
            predictions.append([class_name, confidence])
            if confidence >= threshold:
                filtered_classes.append(class_name)

        class_counts = Counter(filtered_classes)
        filtered_counts = [[key, count] for key, count in class_counts.items()]

        print("Predictions:", predictions)

        return filtered_counts

    @staticmethod
    def _load_model(classLen, model_path, device):
        """Build a ConvNeXtV2-base classifier with ``classLen`` outputs and load its weights.

        Declared static because it takes no ``self``; it can be called through
        the class or an instance. The returned model is moved to ``device``;
        eval mode is left to the caller.
        """
        model = timm.create_model('convnextv2_base.fcmae_ft_in22k_in1k', pretrained=False, num_classes=classLen)
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
        return model.to(device)

    @staticmethod
    def _load_model_resnet(classLen, model_path, device):
        """Build a ResNet-101 classifier with ``classLen`` outputs, load weights, return it in eval mode."""
        model = timm.create_model('resnet101', pretrained=False, num_classes=classLen)
        # Loaded once, with map_location so CPU-only hosts can read GPU checkpoints
        state = torch.load(model_path, map_location=device)
        model.load_state_dict(state, strict=True)
        model.eval()
        return model.to(device)

    # ---- Class-level configuration (evaluated once, when the class is defined)
    SAVE_PATH = Path("identification_models_KNP") / "KNP_CONVNEXT_Nighttime_01"
    #SAVE_PATH = Path("identification_models_KNP") / "KNP_CONVNEXT_Gray_03"
    #SAVE_PATH = Path("identification_models_Sumi") / "SUMI_CONVNEXT_2_03"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Built with pathlib so the separator is correct on every platform
    class_mapping_path = str(SAVE_PATH / "class_mappings.json")
    # Number of classes follows the mapping file; read_text closes the file for us
    NUM_CLASSES = len(json.loads(Path(class_mapping_path).read_text(encoding="utf-8")))
    model = timm.create_model('convnextv2_base.fcmae_ft_in22k_in1k',
                              pretrained=False, num_classes=NUM_CLASSES).to(device)
    model.load_state_dict(torch.load(SAVE_PATH / "last_model.pth", map_location=device))
    model.eval()

    infer = ProtoInfer(model, device, proto_ckpt_path=SAVE_PATH / "prototypes.pt", class_mapping_path=class_mapping_path)

    # Same pipeline as data_transforms['validation'] (ToTensor + Normalize)
    val_transform = data_transforms['validation']

    def batch_identification(self,save_dir,tracking_id,FORCE_MISSED_DETECTION = False,path = None, batch_size = 50, showCon = False, min_images = 20, tau = 0.3):
        """Identify the animal shown in one tracking folder by prototype voting.

        Args:
            save_dir: parent directory containing one sub-folder per tracking id.
            tracking_id: sub-folder name inside ``save_dir``.
            FORCE_MISSED_DETECTION: accepted for compatibility; currently unused.
            path: explicit image folder; when given, ``save_dir``/``tracking_id``
                are ignored.
            batch_size: maximum number of images taken from the folder.
            showCon: accepted for compatibility; currently unused.
            min_images: folders with fewer usable file names return ``[]``.
            tau: similarity threshold passed to ``ProtoInfer.predict_stack``;
                ``None`` falls back to the threshold stored in the checkpoint.

        Returns:
            list of ``[name, count]`` pairs (may include ``'unknown'``), or
            ``[]`` when there are too few images or none could be read.
        """
        save_path = path
        if path is None:
            save_path = os.path.join(save_dir, str(tracking_id))

        images = self.get_files_from_folder_scan_sort(save_path, batch_size)
        if len(images) < min_images:
            return []

        stacked_images = []
        for image in images:
            img_path = os.path.join(save_path, image)
            frame = cv2.imread(img_path)
            if frame is None:
                # cv2.imread signals an unreadable file by returning None
                print(f"Skipping unreadable image: {img_path}")
                continue

            frame = self.apply_clahe_ycrcb(frame, clip=2.5, tile=8)
            frame = self.unsharp_mask(frame, radius=1.4, amount=1.0)
            # OpenCV loads BGR; convert to RGB before the tensor transform
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Tensors stay (C, H, W); predict_stack stacks them and moves the
            # whole batch to the device in one transfer.
            stacked_images.append(self.data_transforms['validation'](frame))

        if not stacked_images:
            return []
        return self.infer.predict_stack(stacked_images, tau)
# Notebook-wide identifier instance used by the cells below.
begin_identification = CATTLE_IDENTIFICATION()

#identification_model_path = "identification_model|s\\resnet50V5\\best_model.pth" 
#model = _load_model(len(class_names),identification_model_path,device)


In [23]:

# Ad-hoc check of a single tracking folder ("326"); the call's return value is
# this cell's displayed output. Up to 100 images are read from the folder.
# NOTE(review): showCon is passed here, but no live code in batch_identification reads it.
begin_identification.batch_identification(r'E:\Output\runs\KUNNEPPU_identification\KNP_Merge_and_Batch_identification\29 Oct Report 2025\reuse tetris_Sumiyoshi-44\Base_rtx8000_10_August_2025_20000_v1\testris_Ch003_Camera 003_E2\2025-10-11_0023_part 2',
    "326", batch_size = 100,showCon = True)
#begin_identification.batch_identification(r'G:\Output\runs\KUNNEPPU_identification\KNP_Merge_and_Batch_identification\1 Sept Report 2025\reuse tetris_Sumiyoshi-44\Base_rtx8000_June_15000_v2\testris_Ch007_Camera 007_E2\2025-08-27_0514_part 2',"2", batch_size = 100)

[['4783', 5],
 ['4413', 2],
 ['4381', 1],
 ['4598', 57],
 ['4999', 3],
 ['3875', 31],
 ['4548', 1]]

In [6]:
# Root folder of one recording; every sub-directory holds the crops of one tracking id.
path = r'G:\Output\runs\KUNNEPPU_identification\KNP_Merge_and_Batch_identification\IR collect Oct Report 2025\reuse tetris_Sumiyoshi-44\Base_rtx8000_10_August_2025_20000_v1\testris_Ch004_Camera 004_E2\2025-10-14_0023_part 1'
#path = r'E:\Output\runs\KUNNEPPU_identification\KNP_Merge_and_Batch_identification\26 Sept Report 2025\reuse tetris_Sumiyoshi-44\Base_rtx8000_10_August_2025_20000_v1\testris_Ch005_KP-FHDD-4-05\2025-09-20_0514_part 27'

# Rows of [tracking folder, [name, count]] for the most frequent prediction per folder.
result = []
for track_name in os.listdir(path):
    # Anything that is not a directory is reported and skipped
    if not os.path.isdir(os.path.join(path, track_name)):
        print(f"{track_name} is not a directory.")
        continue

    print(f"Processing subfolder: {track_name}")
    try:
        vote_counts = begin_identification.batch_identification(path, track_name, batch_size=100)
        print(vote_counts)
        if len(vote_counts) > 0:
            # Keep the prediction with the highest count
            top_vote = max(vote_counts, key=lambda pair: pair[1])
            print(top_vote)
            result.append([track_name, top_vote])
    except Exception as err:
        # Best effort: one failing folder must not stop the sweep
        print(f"Error processing subfolder {track_name}: {err}")

def resultToCsv(result, output_dir=None, filename='manual_result_feature.csv'):
    """Write identification results to a CSV file.

    Args:
        result: iterable of two-item rows, written as the columns
            'TrackingID' and 'Cattle ID'.
        output_dir: directory for the CSV. When omitted, the notebook-level
            global ``path`` is used, which keeps the original behaviour.
        filename: name of the CSV file inside ``output_dir``.

    The function prints the destination path and returns nothing.
    """
    if output_dir is None:
        # Falls back to the notebook global; raises NameError if it was never set
        output_dir = path

    # Create a DataFrame from the result list
    df = pd.DataFrame(result, columns=['TrackingID', 'Cattle ID'])

    # Save the DataFrame to a CSV file
    csv_path = os.path.join(output_dir, filename)
    df.to_csv(csv_path, index=False)

    print(f"Results saved to {csv_path}")


# Persist the per-folder results collected by the loop above.
resultToCsv(result)


0023_identification.mp4 is not a directory.
Processing subfolder: 1
[['4533', 75], ['4842', 25]]
['4533', 75]
Processing subfolder: 10
[]
Processing subfolder: 100
[]
Processing subfolder: 1000
[['4783', 29], ['4367', 1], ['4838', 1], ['4350', 1]]
['4783', 29]
Processing subfolder: 1001
[['4815', 1], ['4635', 5], ['4987', 50], ['4554', 3], ['4548', 37], ['4527', 1], ['4142', 2], ['4551', 1]]
['4987', 50]
Processing subfolder: 1005
[['4142', 21], ['4815', 10], ['4527', 12], ['4548', 52], ['4553', 5]]
['4548', 52]
Processing subfolder: 1007
[]
Processing subfolder: 1008
[['1374', 99], ['4635', 1]]
['1374', 99]
Processing subfolder: 101
[]
Processing subfolder: 1017
[['4826', 6], ['4816', 2], ['4161', 16], ['4598', 30], ['4628', 1], ['4635', 1], ['4853', 1], ['9697', 1]]
['4598', 30]
Processing subfolder: 1018
[]
Processing subfolder: 102
[]
Processing subfolder: 1020
[]
Processing subfolder: 1021
[]
Processing subfolder: 1022
[]
Processing subfolder: 1023
[]
Processing subfolder: 1024
[]

In [None]:
import pandas as pd

In [None]:
# Re-export the current `result` list (requires the sweep cell to have been run in this kernel).
resultToCsv(result)