In [8]:
import torch
import torch.nn as nn
import cv2
import numpy as np
from PIL import Image
import json
import os
from pathlib import Path
import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
import seaborn as sns
from torchvision import transforms
from transformers import ViTModel
import warnings

# Install a blanket "ignore" filter: every warning category is silenced for the
# rest of this kernel session.
# NOTE(review): this also hides deprecation notices from the imported libraries —
# consider narrowing it to specific categories/modules once the notebook is stable.
warnings.filterwarnings("ignore")

In [9]:
class TTI_ViT_Classifier(nn.Module):
    """TTI classifier: 1x1 channel-reduction conv -> ViT backbone -> linear head.

    The attribute names (``first``, ``pre_conv``, ``backbone``, ``fc``) are kept
    so that state-dict keys line up with the key names used by the saved model.

    NOTE(review): the output recorded in this notebook reports both
    ``first.weight`` and ``pre_conv.weight`` being skipped while loading the
    checkpoint, so the shapes declared here may not match the trained model —
    confirm against the training code.
    """

    def __init__(self, num_classes=2):
        super().__init__()

        # Two 1x1 convolutions (5 input channels -> 3 output channels, no bias).
        # Only `pre_conv` is applied in forward(); `first` is registered so the
        # "first.weight" key exists, but it is never called.
        self.first = nn.Conv2d(5, 3, kernel_size=1, bias=False)
        self.pre_conv = nn.Conv2d(5, 3, kernel_size=1, bias=False)

        # Pretrained ViT encoder; its parameters live under "backbone.*".
        self.backbone = ViTModel.from_pretrained("google/vit-base-patch16-224")

        # Single linear head from the 768-wide token embedding to class logits ("fc.*").
        self.fc = nn.Linear(768, num_classes)

    def forward(self, x):
        """Return class logits for a batch of 5-channel inputs."""
        three_channel = self.pre_conv(x)
        token_states = self.backbone(pixel_values=three_channel).last_hidden_state
        # Index 0 along the token axis is the [CLS] position.
        return self.fc(token_states[:, 0])


class DepthEstimator:
    """Stand-in depth estimator that uses a blurred grayscale image as a depth proxy.

    Meant to be replaced by a real model (e.g. Depth-Anything-V2-Small) when available.
    """

    def __init__(self):
        # No depth network is loaded here; only a resize + to-tensor transform is stored.
        steps = [transforms.Resize((224, 224)), transforms.ToTensor()]
        self.transform = transforms.Compose(steps)

    def estimate_depth(self, rgb_image):
        """Return a float32 pseudo-depth map derived from ``rgb_image``.

        A 3-dimensional array is converted from RGB to grayscale; anything else
        is used as-is. The result is Gaussian-blurred (15x15 kernel) and divided
        by 255, which yields values in [0, 1] for 8-bit input.
        """
        has_channel_axis = len(rgb_image.shape) == 3
        gray = (
            cv2.cvtColor(rgb_image, cv2.COLOR_RGB2GRAY)
            if has_channel_axis
            else rgb_image
        )

        # Smoothing is the whole "depth" simulation in this placeholder.
        blurred = cv2.GaussianBlur(gray, (15, 15), 0)

        return blurred.astype(np.float32) / 255.0


class TTI_VideoInference:
    """TTI Model Inference Pipeline for Videos"""

    def __init__(
        self, model_path, device="cuda" if torch.cuda.is_available() else "cpu"
    ):
        """Create the pipeline: choose the device, load the model, set up helpers."""
        # The device has to be stored first because load_model() reads self.device.
        self.device = device
        self.model = self.load_model(model_path)
        self.depth_estimator = DepthEstimator()

        # Resize-to-224 + tensor conversion.
        resize_to_224 = transforms.Resize((224, 224))
        self.transform = transforms.Compose([resize_to_224, transforms.ToTensor()])

        # Empty containers for accumulated results.
        self.results = dict(video_results=[], frame_predictions=[], summary_stats={})

    def load_model(self, model_path):
        """Load the ViT classifier weights stored at ``model_path``.

        Loading strategy:
          1. Read the checkpoint; if it is a dict wrapping the weights under
             ``"state_dict"`` or ``"model_state_dict"``, unwrap it.
          2. Try a strict ``load_state_dict``.
          3. On ``RuntimeError`` fall back to loading only the entries whose key
             and shape match the model, and print a warning listing every model
             parameter that therefore keeps its initial (non-checkpoint) value.
          4. On any other failure, fall back to ``load_efficientnet_model`` with
             ``"ViT"`` replaced by ``"EffNet_B3"`` in the path.

        Returns:
            The model moved to ``self.device`` in eval mode, or whatever
            ``load_efficientnet_model`` returns (possibly ``None``).
        """
        model = TTI_ViT_Classifier(num_classes=2)

        try:
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code — only load checkpoints from a trusted source.
            checkpoint = torch.load(model_path, map_location=self.device)

            # Training scripts often save {"state_dict": ..., "epoch": ...}; a bare
            # state dict maps names to tensors, so this never misfires on one.
            if isinstance(checkpoint, dict):
                for wrapper_key in ("state_dict", "model_state_dict"):
                    if isinstance(checkpoint.get(wrapper_key), dict):
                        checkpoint = checkpoint[wrapper_key]
                        break

            try:
                model.load_state_dict(checkpoint)
                print(f"✓ Loaded model from {model_path}")
            except RuntimeError:
                print("Direct loading failed, trying flexible loading...")

                model_dict = model.state_dict()

                # Keep only checkpoint entries that fit the model exactly, and say
                # precisely why each rejected entry was dropped.
                filtered_dict = {}
                for k, v in checkpoint.items():
                    if k not in model_dict:
                        print(f"Skipping layer {k} (not found in model)")
                    elif model_dict[k].shape != v.shape:
                        print(
                            f"Skipping layer {k} (shape mismatch: checkpoint "
                            f"{tuple(v.shape)} vs model {tuple(model_dict[k].shape)})"
                        )
                    else:
                        filtered_dict[k] = v

                model.load_state_dict(filtered_dict, strict=False)
                print(
                    f"✓ Loaded model with {len(filtered_dict)}/{len(checkpoint)} layers matched"
                )

                # Anything the model defines but the checkpoint did not supply still
                # holds its construction-time value; predictions that depend on it
                # do not reflect the trained model.
                not_loaded = [k for k in model_dict if k not in filtered_dict]
                if not_loaded:
                    print(
                        f"⚠ WARNING: {len(not_loaded)} model parameter(s) were NOT "
                        f"loaded from the checkpoint and keep their initial values: "
                        f"{not_loaded}"
                    )

        except Exception as e:
            print(f"Error loading model: {e}")
            print("Attempting to use one of the EfficientNet models instead...")

            # Fallback to EfficientNet-B3 (best performing according to docs)
            return self.load_efficientnet_model(model_path.replace("ViT", "EffNet_B3"))

        model.to(self.device)
        model.eval()
        return model

    def load_efficientnet_model(self, model_path):
        """Fallback to load EfficientNet model"""
        try:
            from torchvision import models

            # Create EfficientNet-B3 model
            model = models.efficientnet_b3(pretrained=False)

            # Modify for 5-channel input and 2-class output
            model.features[0][0] = nn.Conv2d(
                5, 40, kernel_size=3, stride=2, padding=1, bias=False
            )
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, 2)

            checkpoint = torch.load(model_path, map_location=self.device)
            model.load_state_dict(checkpoint, strict=False)
            print(f"✓ Loaded EfficientNet-B3 fallback model from {model_path}")

            model.to(self.device)
            model.eval()
            return model

        except Exception as e:
            print(f"Failed to load fallback model: {e}")
            return None

    def extract_roi_from_frame(self, frame, tool_mask=None, tissue_mask=None):
        """
        Extract ROI based on tool-tissue interaction
        For now, using the entire frame as ROI (modify based on actual mask availability)
        """
        if tool_mask is not None and tissue_mask is not None:
            # Calculate intersection
            intersection_mask = cv2.bitwise_and(tool_mask, tissue_mask)

            # Find bounding box of intersection
            contours, _ = cv2.findContours(
                intersection_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )

            if contours:
                # Get largest contour
                largest_contour = max(contours, key=cv2.contourArea)
                x, y, w, h = cv2.boundingRect(largest_contour)

                # Ensure minimum size (64x64 as per documentation)
                w = max(w, 64)
                h = max(h, 64)

                # Extract ROI
                roi = frame[y : y + h, x : x + w]
                mask_roi = intersection_mask[y : y + h, x : x + w]
            else:
                # Fallback to center crop
                roi = self.center_crop(frame, 224)
                mask_roi = np.zeros((224, 224), dtype=np.uint8)
        else:
            # Use entire frame as ROI
            roi = frame
            mask_roi = np.ones((frame.shape[0], frame.shape[1]), dtype=np.uint8) * 255

        return roi, mask_roi

    def center_crop(self, image, size):
        """Center crop image to specified size"""
        h, w = image.shape[:2]
        start_h = max(0, (h - size) // 2)
        start_w = max(0, (w - size) // 2)

        cropped = image[start_h : start_h + size, start_w : start_w + size]

        # Pad if necessary
        if cropped.shape[0] < size or cropped.shape[1] < size:
            cropped = cv2.resize(cropped, (size, size))

        return cropped

    def prepare_5channel_input(self, roi, mask_roi):
        """Build the (1, 5, 224, 224) float32 model input from an ROI and its mask.

        Channel layout as produced by this code:
          * 0-2: the ROI resized to 224x224, divided by 255
          * 3:   pseudo-depth from ``self.depth_estimator`` (taken as-is)
          * 4:   the mask resized to 224x224, divided by 255

        NOTE(review): confirm that this scaling matches what the model saw
        during training.
        """
        # Bring both inputs to the network resolution.
        roi_224 = cv2.resize(roi, (224, 224))
        mask_224 = cv2.resize(mask_roi, (224, 224))

        # Anything that is not already H x W x 3 is treated as grayscale.
        already_three_channel = len(roi_224.shape) == 3 and roi_224.shape[2] == 3
        rgb = (
            roi_224
            if already_three_channel
            else cv2.cvtColor(roi_224, cv2.COLOR_GRAY2RGB)
        )

        depth = self.depth_estimator.estimate_depth(rgb)

        # Assemble channel-first planes: RGB, depth, mask.
        planes = np.zeros((5, 224, 224), dtype=np.float32)
        planes[0:3] = np.transpose(rgb.astype(np.float32) / 255.0, (2, 0, 1))
        planes[3] = depth.astype(np.float32)
        planes[4] = mask_224.astype(np.float32) / 255.0

        batched = torch.tensor(planes).unsqueeze(0)
        return batched.to(self.device)

    def predict_frame(self, frame, frame_idx, tool_mask=None, tissue_mask=None):
        """Predict TTI for a single frame"""
        if self.model is None:
            return None

        # Extract ROI
        roi, mask_roi = self.extract_roi_from_frame(frame, tool_mask, tissue_mask)

        # Prepare 5-channel input
        input_tensor = self.prepare_5channel_input(roi, mask_roi)

        # Model inference
        with torch.no_grad():
            logits = self.model(input_tensor)
            probabilities = torch.softmax(logits, dim=1)
            prediction = torch.argmax(logits, dim=1).item()
            confidence = probabilities[0, prediction].item()

        result = {
            "frame_idx": frame_idx,
            "prediction": prediction,  # 0: No-TTI, 1: TTI
            "confidence": confidence,
            "probabilities": {
                "No_TTI": probabilities[0, 0].item(),
                "TTI": probabilities[0, 1].item(),
            },
        }

        return result

    def process_video(self, video_path, sample_rate=30):
        """Run frame-level inference over one video and return its compiled result.

        Args:
            video_path: Path of the video file, passed to ``cv2.VideoCapture``.
            sample_rate: Target number of frames to analyse per second; every
                ``fps // sample_rate``-th frame is used when the video's FPS
                exceeds it, otherwise every frame.

        Returns:
            The dict built by ``compile_video_results``, or ``None`` when the
            video cannot be opened, reports a non-positive FPS, or produces no
            predictions.

        Notes:
            * Duration and timestamps use the exact FPS value reported by
              OpenCV; the integer-truncated FPS is kept only for the sampling
              interval and for reporting (so 29.97 fps no longer skews times).
            * The capture handle is released even if inference raises.
        """
        print(f"\nProcessing video: {video_path}")

        cap = cv2.VideoCapture(video_path)
        frame_predictions = []
        try:
            if not cap.isOpened():
                print(f"Error opening video: {video_path}")
                return None

            # Get video properties
            fps_exact = cap.get(cv2.CAP_PROP_FPS)
            # `not x > 0` also rejects NaN, which some backends report.
            if not fps_exact > 0:
                print(f"Error: invalid FPS ({fps_exact}) reported for video: {video_path}")
                return None

            fps = int(fps_exact)
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            duration = total_frames / fps_exact

            print(
                f"Video properties: {fps} FPS, {total_frames} frames, {duration:.2f}s duration"
            )

            # Process frames at specified sample rate
            frame_interval = max(1, fps // sample_rate) if fps > sample_rate else 1

            frame_idx = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_idx % frame_interval == 0:
                    # OpenCV decodes to BGR; the rest of the pipeline expects RGB.
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                    prediction = self.predict_frame(frame_rgb, frame_idx)
                    if prediction:
                        prediction["timestamp"] = frame_idx / fps_exact
                        frame_predictions.append(prediction)

                        if len(frame_predictions) % 100 == 0:
                            print(f"Processed {len(frame_predictions)} frames...")

                frame_idx += 1
        finally:
            cap.release()

        # Compile video results
        return self.compile_video_results(
            video_path, frame_predictions, fps, total_frames, duration
        )

    def compile_video_results(
        self, video_path, frame_predictions, fps, total_frames, duration
    ):
        """Compile results for a single video"""
        if not frame_predictions:
            return None

        # Convert to DataFrame for easier analysis
        df = pd.DataFrame(frame_predictions)

        # Calculate statistics
        tti_predictions = df[df["prediction"] == 1]
        no_tti_predictions = df[df["prediction"] == 0]

        total_predictions = len(df)
        tti_count = len(tti_predictions)
        no_tti_count = len(no_tti_predictions)

        tti_percentage = (tti_count / total_predictions) * 100
        avg_tti_confidence = (
            tti_predictions["confidence"].mean() if tti_count > 0 else 0
        )
        avg_no_tti_confidence = (
            no_tti_predictions["confidence"].mean() if no_tti_count > 0 else 0
        )

        # Time-based analysis
        tti_duration = tti_count * (duration / total_predictions)

        video_result = {
            "video_path": video_path,
            "video_name": Path(video_path).name,
            "video_properties": {
                "fps": fps,
                "total_frames": total_frames,
                "duration_seconds": duration,
                "frames_processed": total_predictions,
            },
            "predictions": {
                "total_predictions": total_predictions,
                "tti_count": tti_count,
                "no_tti_count": no_tti_count,
                "tti_percentage": tti_percentage,
                "no_tti_percentage": 100 - tti_percentage,
            },
            "confidence_stats": {
                "avg_tti_confidence": avg_tti_confidence,
                "avg_no_tti_confidence": avg_no_tti_confidence,
                "overall_avg_confidence": df["confidence"].mean(),
            },
            "temporal_analysis": {
                "tti_duration_seconds": tti_duration,
                "tti_duration_percentage": (tti_duration / duration) * 100,
            },
            "frame_predictions": frame_predictions,
        }

        return video_result

    def process_multiple_videos(self, video_folder, output_folder="results"):
        """Process multiple videos and generate comprehensive report"""
        video_folder = Path(video_folder)
        output_folder = Path(output_folder)
        output_folder.mkdir(exist_ok=True)

        # Find video files
        video_extensions = [".mp4", ".avi", ".mov", ".mkv"]
        video_files = []
        for ext in video_extensions:
            video_files.extend(video_folder.glob(f"*{ext}"))

        print(f"Found {len(video_files)} video files")

        all_results = []

        # Process each video
        for video_file in video_files:
            try:
                result = self.process_video(str(video_file))
                if result:
                    all_results.append(result)
                    print(f"✓ Processed: {video_file.name}")
            except Exception as e:
                print(f"✗ Error processing {video_file.name}: {e}")

        # Generate comprehensive report
        if all_results:
            self.generate_comprehensive_report(all_results, output_folder)

        return all_results

    def generate_comprehensive_report(self, all_results, output_folder):
        """Generate detailed analysis report"""
        output_folder = Path(output_folder)

        # Create summary DataFrame
        summary_data = []
        for result in all_results:
            summary_data.append(
                {
                    "Video Name": result["video_name"],
                    "Duration (s)": result["video_properties"]["duration_seconds"],
                    "Total Frames Processed": result["predictions"][
                        "total_predictions"
                    ],
                    "TTI Count": result["predictions"]["tti_count"],
                    "TTI Percentage": result["predictions"]["tti_percentage"],
                    "Avg TTI Confidence": result["confidence_stats"][
                        "avg_tti_confidence"
                    ],
                    "TTI Duration (s)": result["temporal_analysis"][
                        "tti_duration_seconds"
                    ],
                    "TTI Time Percentage": result["temporal_analysis"][
                        "tti_duration_percentage"
                    ],
                }
            )

        summary_df = pd.DataFrame(summary_data)

        # Save summary to CSV
        summary_df.to_csv(output_folder / "video_analysis_summary.csv", index=False)

        # Generate detailed report
        report_path = output_folder / "detailed_analysis_report.txt"

        with open(report_path, "w") as f:
            f.write("TTI MODEL INFERENCE RESULTS - BILE DUCT INJURY VIDEOS\n")
            f.write("=" * 60 + "\n\n")
            f.write(
                f"Report Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n"
            )
            f.write(f"Total Videos Processed: {len(all_results)}\n\n")

            # Overall Statistics
            f.write("OVERALL STATISTICS\n")
            f.write("-" * 20 + "\n")
            f.write(
                f"Total Processing Time: {summary_df['Duration (s)'].sum():.2f} seconds\n"
            )
            f.write(
                f"Total Frames Processed: {summary_df['Total Frames Processed'].sum():,}\n"
            )
            f.write(
                f"Average TTI Percentage: {summary_df['TTI Percentage'].mean():.2f}%\n"
            )
            f.write(
                f"TTI Percentage Range: {summary_df['TTI Percentage'].min():.2f}% - {summary_df['TTI Percentage'].max():.2f}%\n"
            )
            f.write(
                f"Average TTI Confidence: {summary_df['Avg TTI Confidence'].mean():.3f}\n\n"
            )

            # Per-video detailed results
            f.write("PER-VIDEO DETAILED RESULTS\n")
            f.write("-" * 30 + "\n")

            for i, result in enumerate(all_results, 1):
                f.write(f"\n{i}. {result['video_name']}\n")
                f.write(
                    f"   Duration: {result['video_properties']['duration_seconds']:.2f}s\n"
                )
                f.write(
                    f"   Frames Processed: {result['predictions']['total_predictions']:,}\n"
                )
                f.write(
                    f"   TTI Interactions: {result['predictions']['tti_count']} ({result['predictions']['tti_percentage']:.2f}%)\n"
                )
                f.write(
                    f"   No-TTI: {result['predictions']['no_tti_count']} ({result['predictions']['no_tti_percentage']:.2f}%)\n"
                )
                f.write(
                    f"   Average TTI Confidence: {result['confidence_stats']['avg_tti_confidence']:.3f}\n"
                )
                f.write(
                    f"   TTI Duration: {result['temporal_analysis']['tti_duration_seconds']:.2f}s ({result['temporal_analysis']['tti_duration_percentage']:.2f}%)\n"
                )

        # Generate visualizations
        self.generate_visualizations(all_results, summary_df, output_folder)

        print(f"\n✓ Comprehensive report saved to: {output_folder}")
        print(f"✓ Summary CSV: {output_folder / 'video_analysis_summary.csv'}")
        print(f"✓ Detailed report: {report_path}")

    def generate_visualizations(self, all_results, summary_df, output_folder):
        """Save the 2x2 overview figure plus temporal plots for the first three videos."""
        plt.style.use("default")

        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        video_positions = range(len(summary_df))

        # Top-left: TTI percentage per video.
        ax_pct = axes[0, 0]
        ax_pct.bar(video_positions, summary_df["TTI Percentage"])
        ax_pct.set_title("TTI Percentage by Video")
        ax_pct.set_xlabel("Video Index")
        ax_pct.set_ylabel("TTI Percentage (%)")
        ax_pct.tick_params(axis="x", rotation=45)

        # Top-right: histogram of average TTI confidence (missing values dropped).
        ax_conf = axes[0, 1]
        ax_conf.hist(summary_df["Avg TTI Confidence"].dropna(), bins=15, alpha=0.7)
        ax_conf.set_title("TTI Confidence Distribution")
        ax_conf.set_xlabel("Average TTI Confidence")
        ax_conf.set_ylabel("Frequency")

        # Bottom-left: TTI duration against total video duration.
        ax_dur = axes[1, 0]
        ax_dur.scatter(summary_df["Duration (s)"], summary_df["TTI Duration (s)"])
        ax_dur.set_title("TTI Duration vs Total Video Duration")
        ax_dur.set_xlabel("Video Duration (s)")
        ax_dur.set_ylabel("TTI Duration (s)")

        # Bottom-right: TTI share of time per video.
        ax_time = axes[1, 1]
        ax_time.bar(video_positions, summary_df["TTI Time Percentage"])
        ax_time.set_title("TTI Time Percentage by Video")
        ax_time.set_xlabel("Video Index")
        ax_time.set_ylabel("TTI Time Percentage (%)")

        fig.tight_layout()
        fig.savefig(
            output_folder / "analysis_overview.png", dpi=300, bbox_inches="tight"
        )
        plt.close(fig)

        # Per-video timelines, limited to the first three results.
        for video_result in all_results[:3]:
            self.plot_temporal_analysis(video_result, output_folder)

    def plot_temporal_analysis(self, video_result, output_folder):
        """Save a two-panel timeline (predictions and confidence) for one video."""
        timeline = pd.DataFrame(video_result["frame_predictions"])

        fig, (ax_pred, ax_conf) = plt.subplots(2, 1, figsize=(15, 8))

        # Upper panel: predicted class per sampled frame (blue = 0, red = anything else).
        point_colors = [
            "red" if label != 0 else "blue" for label in timeline["prediction"]
        ]
        ax_pred.scatter(
            timeline["timestamp"],
            timeline["prediction"],
            c=point_colors,
            alpha=0.6,
            s=10,
        )
        ax_pred.set_title(f'TTI Predictions Over Time - {video_result["video_name"]}')
        ax_pred.set_ylabel("Prediction (0: No-TTI, 1: TTI)")
        ax_pred.set_ylim(-0.1, 1.1)

        # Lower panel: confidence of the predicted class over time.
        ax_conf.plot(
            timeline["timestamp"], timeline["confidence"], alpha=0.7, linewidth=0.8
        )
        ax_conf.set_title("Prediction Confidence Over Time")
        ax_conf.set_xlabel("Time (seconds)")
        ax_conf.set_ylabel("Confidence")
        ax_conf.set_ylim(0, 1)

        fig.tight_layout()

        # Dots in the video name are replaced so the only extension is ".png".
        safe_name = video_result["video_name"].replace(".", "_")
        fig.savefig(
            output_folder / f"temporal_analysis_{safe_name}.png",
            dpi=300,
            bbox_inches="tight",
        )
        plt.close(fig)

In [None]:
# Driver cell: load the first usable model, run inference on every video in
# `video_folder`, and print a short summary. Reports are written to `output_folder`.

# Checkpoints to try, in order of preference.
model_candidates = [
    "./models/ViT/best_model.pt",
    # "./EffNet_B3/best_model.pt",
    # "./EffNet_B0/best_model.pt",
    # "./EffNet_B1/best_model.pt",
]

video_folder = "./bdi_videos"
output_folder = "./bdi_vit_results"

print("TTI Model Inference on Bile Duct Injury Videos")
print("=" * 50)

# Try loading models in order of preference
inference_pipeline = None
for model_path in model_candidates:
    if os.path.exists(model_path):
        print(f"\nAttempting to load: {model_path}")
        try:
            inference_pipeline = TTI_VideoInference(model_path)
            if inference_pipeline.model is not None:
                print(f"✓ Successfully loaded model: {model_path}")
                break
        except Exception as e:
            print(f"✗ Failed to load {model_path}: {e}")
            continue
    else:
        print(f"✗ Model not found: {model_path}")

if inference_pipeline is None or inference_pipeline.model is None:
    # Do not call exit() here: inside a notebook it shuts down the kernel.
    # Reporting the problem and skipping the processing branch is enough.
    print("❌ Could not load any model. Please check model paths and files.")
else:
    # Process all videos
    results = inference_pipeline.process_multiple_videos(video_folder, output_folder)

    if results:
        print(f"\n✅ Successfully processed {len(results)} videos")
        print(f"📊 Results saved to: {output_folder}/")

        # Print quick summary
        total_tti = sum(r["predictions"]["tti_count"] for r in results)
        total_frames = sum(r["predictions"]["total_predictions"] for r in results)
        avg_tti_percentage = (total_tti / total_frames) * 100 if total_frames > 0 else 0

        print("\n📈 QUICK SUMMARY:")
        print(f"   Total frames analyzed: {total_frames:,}")
        print(f"   Total TTI detections: {total_tti:,}")
        print(f"   Overall TTI percentage: {avg_tti_percentage:.2f}%")

    else:
        print("❌ No videos were successfully processed")

TTI Model Inference on Bile Duct Injury Videos

Attempting to load: ./models/ViT/best_model.pt


Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Direct loading failed, trying flexible loading...
Skipping layer first.weight (shape mismatch or not found)
Skipping layer pre_conv.weight (shape mismatch or not found)
✓ Loaded model with 202/204 layers matched
✓ Successfully loaded model: ./models/ViT/best_model.pt
Found 1 video files

Processing video: bdi_videos/V15_Trimmed.mp4
Video properties: 25 FPS, 3547 frames, 141.88s duration
Processed 100 frames...
Processed 200 frames...
Processed 300 frames...
Processed 400 frames...
Processed 500 frames...
Processed 600 frames...
Processed 700 frames...
Processed 800 frames...
Processed 900 frames...
Processed 1000 frames...
Processed 1100 frames...
Processed 1200 frames...
Processed 1300 frames...
Processed 1400 frames...
Processed 1500 frames...
Processed 1600 frames...
Processed 1700 frames...
Processed 1800 frames...
Processed 1900 frames...
Processed 2000 frames...
Processed 2100 frames...
Processed 2200 frames...
Processed 2300 frames...
Processed 2400 frames...
Processed 2500 fram