<a href="https://colab.research.google.com/github/kapoor1309/BH-25/blob/main/test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import json
import torch
from PIL import Image
import numpy as np
import cv2
from torchvision import models, transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import torch.nn as nn


In [None]:
class MultiTaskModel(nn.Module):
    """ResNet-50 trunk shared by four independent linear classification heads.

    The heads score triplets, tools (instruments), verbs and targets from the same
    pooled feature vector. Attribute names (``backbone``, ``triplet_head``, ...)
    define the state-dict keys, so they must match any checkpoint loaded into
    this class.

    Args:
        num_triplets: Output width of the triplet head.
        num_tools: Output width of the tool head.
        num_verbs: Output width of the verb head.
        num_targets: Output width of the target head.
    """

    # Width of the pooled ResNet-50 feature vector that every head consumes.
    _FEATURE_DIM = 2048

    def __init__(self, num_triplets, num_tools, num_verbs, num_targets):
        super().__init__()
        trunk = models.resnet50(pretrained=True)
        # Drop the ImageNet classifier so the trunk returns pooled features.
        trunk.fc = nn.Identity()
        self.backbone = trunk

        # Heads are created in a fixed order so parameter registration (and
        # random initialisation order) stays stable.
        feature_dim = self._FEATURE_DIM
        self.triplet_head = nn.Linear(feature_dim, num_triplets)
        self.tool_head = nn.Linear(feature_dim, num_tools)
        self.verb_head = nn.Linear(feature_dim, num_verbs)
        self.target_head = nn.Linear(feature_dim, num_targets)

    def forward(self, x):
        """Return raw logits as ``(triplet, tool, verb, target)``."""
        shared_features = self.backbone(x)
        heads = (self.triplet_head, self.tool_head, self.verb_head, self.target_head)
        return tuple(head(shared_features) for head in heads)

def _cam_to_boxes(cam, image_shape, threshold):
    """Resize a CAM to the image, threshold it, and return one (x, y, w, h) box per external contour."""
    height, width = image_shape[:2]
    # cv2.resize takes (width, height), the reverse of numpy's shape order.
    cam_resized = cv2.resize(cam, (width, height))
    binary_mask = (cam_resized >= threshold).astype(np.uint8) * 255
    contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    return [cv2.boundingRect(contour) for contour in contours]


def _classify_region(region, model_2, transform, device):
    """Run ``model_2`` on an RGB crop; return ``(triplet_id, tool_id, tool_prob)`` from the argmax scores."""
    region_tensor = transform(Image.fromarray(region)).unsqueeze(0).to(device)
    with torch.no_grad():
        triplet_logits, tool_logits, _, _ = model_2(region_tensor)
    triplet_probs = torch.sigmoid(triplet_logits).cpu().numpy()[0]
    tool_probs = torch.sigmoid(tool_logits).cpu().numpy()[0]
    tool_id = int(np.argmax(tool_probs))
    return int(np.argmax(triplet_probs)), tool_id, float(tool_probs[tool_id])


def process_single_image(image_path, model_1, model_2, transform, cam_extractor, device, frame_id,
                         cam_threshold=0.5, cam_target_class=0):
    """Process a single image and return formatted predictions.

    Produces a frame-level recognition vector from ``model_2`` and CAM-based
    detections: the CAM from ``cam_extractor`` is thresholded into regions,
    each region's bounding box is cropped and classified by ``model_2``.

    Args:
        image_path: Path of the frame image.
        model_1: Not used directly; kept for interface compatibility (the CAM
            comes from ``cam_extractor``).
        model_2: Model returning ``(triplet, tool, verb, target)`` logits.
        transform: Callable turning a PIL image into a CHW tensor.
        cam_extractor: pytorch_grad_cam CAM object.
        device: Torch device inputs are moved to.
        frame_id: Identifier used only for progress messages.
        cam_threshold: CAM value at or above which a pixel belongs to a region.
        cam_target_class: Output index of the CAM model to explain.

    Returns:
        dict with ``"recognition"`` (list of sigmoid triplet scores for the
        whole frame) and ``"detection"`` (list of
        ``{"triplet": id, "instrument": [tool_id, tool_prob, x, y, w, h]}``
        with pixel coordinates in the original image).
    """
    print(f"\nProcessing frame: {frame_id}")

    # Decode inside a context manager so the file handle is released promptly;
    # convert() returns an independent, fully loaded image.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    original_image = np.array(image)

    # Frame-level predictions (recognition).
    input_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        frame_triplet_logits, _, _, _ = model_2(input_tensor)
    recognition_probs = torch.sigmoid(frame_triplet_logits).cpu().numpy()[0].tolist()

    # GradCAM needs gradients, so it must run outside torch.no_grad().
    grayscale_cam = cam_extractor(
        input_tensor=input_tensor, targets=[ClassifierOutputTarget(cam_target_class)]
    )

    detections = []
    for x, y, w, h in _cam_to_boxes(grayscale_cam[0], original_image.shape, cam_threshold):
        cropped_image = original_image[y:y + h, x:x + w]
        if cropped_image.size == 0:
            continue
        triplet_id, tool_id, tool_prob = _classify_region(cropped_image, model_2, transform, device)
        detections.append({
            "triplet": triplet_id,
            "instrument": [tool_id, tool_prob, x, y, w, h],
        })

    print(f"Completed frame {frame_id} - Found {len(detections)} detections")

    return {
        "recognition": recognition_probs,
        "detection": detections,
    }

def process_video_folder(video_folder, model_1, model_2, transform, cam_extractor, device,
                         image_extensions=('.png', '.jpg', '.jpeg')):
    """Run ``process_single_image`` on every image frame in one video folder.

    Frames are processed in lexicographic file-name order. Processing is best
    effort: a frame that raises is reported, omitted from the result, and
    counted in a closing warning.

    Args:
        video_folder: Directory holding the frames of one video; a trailing
            path separator is tolerated.
        model_1, model_2, transform, cam_extractor, device: Forwarded unchanged
            to ``process_single_image``.
        image_extensions: File suffix or suffixes treated as frames, matched
            case-insensitively.

    Returns:
        dict mapping each frame id (file name without extension) to that
        frame's result dict.
    """
    # normpath strips a trailing separator, which would otherwise make basename() return "".
    video_folder_name = os.path.basename(os.path.normpath(video_folder))
    print(f"\nProcessing {video_folder_name}")
    video_predictions = {}

    if isinstance(image_extensions, str):
        image_extensions = (image_extensions,)
    suffixes = tuple(ext.lower() for ext in image_extensions)
    # Case-insensitive match so ".PNG"/".JPG" frames are not silently skipped;
    # isfile() keeps directories that happen to end in an image suffix out.
    image_files = sorted(
        f for f in os.listdir(video_folder)
        if f.lower().endswith(suffixes) and os.path.isfile(os.path.join(video_folder, f))
    )
    total_frames = len(image_files)

    print(f"Found {total_frames} frames in {video_folder_name}")

    failed_frames = []
    for idx, image_file in enumerate(image_files, 1):
        frame_id = os.path.splitext(image_file)[0]
        image_path = os.path.join(video_folder, image_file)

        try:
            frame_results = process_single_image(
                image_path, model_1, model_2, transform, cam_extractor, device, frame_id
            )
        except Exception as e:
            # Best effort: keep going, but remember the frame so the gap is reported below.
            print(f"Error processing {image_file}: {str(e)}")
            failed_frames.append(image_file)
            continue

        video_predictions[frame_id] = frame_results
        print(f"Progress: {idx}/{total_frames} frames processed in {video_folder_name}")

    if failed_frames:
        print(
            f"Warning: {len(failed_frames)} of {total_frames} frames failed in "
            f"{video_folder_name} and are missing from the predictions"
        )

    return video_predictions

def _load_state_dict(checkpoint_path, device):
    """Load a tensor-only state dict from ``checkpoint_path`` onto ``device``.

    ``map_location`` lets a checkpoint saved on GPU load on a CPU-only runtime.
    ``weights_only=True`` refuses arbitrary pickled objects. Opting in
    explicitly also avoids torch.load's warning about that default changing.
    """
    return torch.load(checkpoint_path, map_location=device, weights_only=True)


def main():
    """Predict recognition scores and CAM-derived detections for the held-out videos.

    Loads the GradCAM classifier (``model_1``) and the multi-task model
    (``model_2``) and processes every folder in ``video_ids`` under
    ``base_dir``. Missing folders are skipped with a warning. The result is
    written as ``{video_id: {frame_id: frame_results}}`` to
    ``<base_dir>/all_video_predictions.json``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_dir = "/content/CholecT50/videos"
    video_ids = ["VID92", "VID96", "VID103", "VID110", "VID111"]

    # Fixed 256x448 (H, W) resize, then ImageNet mean/std normalisation.
    transform = transforms.Compose([
        transforms.Resize((256, 448)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    print("Loading models...")
    # Task 1 model: ResNet-50 with a 6-way head; it drives the GradCAM box proposals.
    # Built without ImageNet weights because the strict load_state_dict below
    # overwrites every parameter anyway, so downloading them is wasted work.
    model_1 = models.resnet50()
    model_1.fc = torch.nn.Linear(model_1.fc.in_features, 6)
    model_1.load_state_dict(_load_state_dict("/content/trained_model_fold_.pth", device))
    model_1 = model_1.to(device)
    model_1.eval()

    # Task 2 model: MultiTaskModel's constructor requests pretrained backbone
    # weights; the checkpoint below overwrites them.
    model_2 = MultiTaskModel(num_triplets=100, num_tools=6, num_verbs=10, num_targets=15)
    model_2.load_state_dict(_load_state_dict("/content/multi_task_model.pth", device))
    model_2 = model_2.to(device)
    model_2.eval()

    # GradCAM on the last block of model_1's layer4.
    cam_extractor = GradCAM(model=model_1, target_layers=[model_1.layer4[-1]])

    # Process all video folders
    all_predictions = {}
    for video_id in video_ids:
        video_folder_path = os.path.join(base_dir, video_id)
        if not os.path.exists(video_folder_path):
            print(f"Warning: {video_id} not found, skipping...")
            continue

        all_predictions[video_id] = process_video_folder(
            video_folder_path, model_1, model_2, transform, cam_extractor, device
        )

    # Save results
    output_file = os.path.join(base_dir, 'all_video_predictions.json')
    with open(output_file, 'w') as f:
        json.dump(all_predictions, f)

    print(f"\nProcessing complete! Results saved to: {output_file}")

if __name__ == "__main__":
    main()

Loading models...


  model_1.load_state_dict(torch.load("/content/trained_model_fold_.pth"))
  model_2.load_state_dict(torch.load("/content/multi_task_model.pth", map_location=device))


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Completed frame 000896 - Found 4 detections
Progress: 897/2146 frames processed in VID111

Processing frame: 000897
Completed frame 000897 - Found 4 detections
Progress: 898/2146 frames processed in VID111

Processing frame: 000898
Completed frame 000898 - Found 3 detections
Progress: 899/2146 frames processed in VID111

Processing frame: 000899
Completed frame 000899 - Found 5 detections
Progress: 900/2146 frames processed in VID111

Processing frame: 000900
Completed frame 000900 - Found 6 detections
Progress: 901/2146 frames processed in VID111

Processing frame: 000901
Completed frame 000901 - Found 4 detections
Progress: 902/2146 frames processed in VID111

Processing frame: 000902
Completed frame 000902 - Found 3 detections
Progress: 903/2146 frames processed in VID111

Processing frame: 000903
Completed frame 000903 - Found 3 detections
Progress: 904/2146 frames processed in VID111

Processing frame: 000904
Complet

In [None]:
import json

def convert_frame_ids(filepath):
    """Rewrite a predictions JSON file so frame keys lose their zero padding.

    Expects ``{video_id: {frame_id: frame_result}}`` with digit-string frame
    ids such as ``"000123"``. Each id is parsed with ``int()``; an empty id maps
    to 0. JSON object keys are always strings, so ``"000123"`` is stored back as
    ``"123"``, not as a JSON number. Re-running on an already converted file
    leaves it unchanged.

    The file is replaced atomically: the new content is written to
    ``<filepath>.tmp`` and then moved over the original. An interruption
    therefore cannot truncate the predictions.

    Args:
        filepath: Path of the JSON file to rewrite in place.

    Raises:
        ValueError: if a frame id is not an integer string, or if two frame ids
            of the same video map to the same integer (e.g. ``"01"`` and
            ``"001"``). Silently keeping only one would drop predictions. The
            file is left untouched in that case.
    """
    import os  # local import keeps this cell independent of the first cell's imports

    with open(filepath, 'r') as f:
        data = json.load(f)

    converted_data = {}
    for video_id, frames in data.items():
        video_frames = {}
        for frame_id, content in frames.items():
            # int() accepts zero-padded digit strings directly ("000123" -> 123).
            new_frame_id = int(frame_id) if frame_id else 0
            if new_frame_id in video_frames:
                raise ValueError(
                    f"Frame IDs collide in {video_id}: {frame_id!r} maps to "
                    f"{new_frame_id}, which another frame already uses"
                )
            video_frames[new_frame_id] = content
        converted_data[video_id] = video_frames

    tmp_path = f"{filepath}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(converted_data, f)
    os.replace(tmp_path, filepath)

    print("Frame IDs have been converted to integers successfully!")

# Run the conversion on the file main() writes:
# os.path.join("/content/CholecT50/videos", 'all_video_predictions.json').
# The previous "/content/all_video_predictions.json" pointed at a different
# location, so a fresh Restart & Run All read a stale or missing file.
filepath = "/content/CholecT50/videos/all_video_predictions.json"
convert_frame_ids(filepath)

Frame IDs have been converted to integers successfully!
