In [None]:
# Install required packages
# !pip install torch torchvision scipy numpy pillow matplotlib opencv-python
# !pip install davis2017-evaluation
# !pip install --upgrade torch torchvision torchaudio
# !pip install torch torchvision

# # Download and extract DAVIS 2017 dataset
# !wget -O DAVIS-2017-trainval-480p.zip https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip
# !unzip -q DAVIS-2017-trainval-480p.zip -d /content/DAVIS

import os
import cv2
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import functional as F
from IPython.display import Image as IPImage, display
from google.colab import files
import shutil
import matplotlib.pyplot as plt
import random
from sklearn.metrics import accuracy_score, f1_score # Importing accuracy_score and f1_score


# Define the DAVIS Dataset class
class DAVISDataset(Dataset):
    """Object-detection view of DAVIS frames derived from segmentation masks.

    The dataset expects the following layout below ``davis_root``:

    * ``JPEGImages/480p/<sequence>/<frame>.jpg``
    * ``Annotations/480p/<sequence>/<frame>.png``
    * ``ImageSets/2017/<subset>.txt`` (one sequence name per line)

    Each sample is an ``(image, target)`` pair where ``target`` is a dict with
    the keys ``boxes``, ``labels``, ``image_id``, ``area`` and ``iscrowd``.
    Every box of a frame receives the class id of the frame's sequence
    (looked up in ``class_map``); sequences missing from the map fall back to
    label ``1``.
    """

    def __init__(self, davis_root, subset='val', transform=None):
        """Index every frame of ``subset`` that has a matching annotation.

        Args:
            davis_root: Root directory of the extracted DAVIS data.
            subset: Name of the split file in ``ImageSets/2017`` (no extension).
            transform: Optional callable ``(image, target) -> (image, target)``.
        """
        self.davis_root = davis_root
        self.subset = subset
        self.transform = transform
        self.img_dir = os.path.join(davis_root, 'JPEGImages', '480p')
        self.anno_dir = os.path.join(davis_root, 'Annotations', '480p')
        with open(os.path.join(davis_root, 'ImageSets', '2017', f'{subset}.txt'), 'r') as f:
            self.sequences = f.read().splitlines()
        # Sequence name -> class id. Id 0 is left unused.
        self.class_map = {
            'bear': 1, 'bmx-bumps': 2, 'boat': 3, 'boxing-fisheye': 4, 'breakdance-flare': 5,
            'bus': 6, 'car-turn': 7, 'cat-girl': 8, 'classic-car': 9, 'color-run': 10,
            'crossing': 11, 'dance-jump': 12, 'dancing': 13, 'disc-jockey': 14, 'dog-agility': 15,
            'dog-gooses': 16, 'dogs-scale': 17, 'drift-turn': 18, 'drone': 19, 'elephant': 20,
            'flamingo': 21, 'hike': 22, 'hockey': 23, 'horsejump-low': 24, 'kid-football': 25,
            'kite-walk': 26, 'koala': 27, 'lady-running': 28, 'lindy-hop': 29, 'longboard': 30,
            'lucia': 31, 'mallard-fly': 32, 'mallard-water': 33, 'miami-surf': 34, 'motocross-bumps': 35,
            'motorbike': 36, 'night-race': 37, 'paragliding': 38, 'planes-water': 39, 'rallye': 40,
            'rhino': 41, 'rollerblade': 42, 'schoolgirls': 43, 'scooter-board': 44, 'scooter-gray': 45,
            'sheep': 46, 'skate-park': 47, 'snowboard': 48, 'soccerball': 49, 'stroller': 50,
            'stunt': 51, 'surf': 52, 'swing': 53, 'tennis': 54, 'tractor-sand': 55,
            'train': 56, 'tuk-tuk': 57, 'upside-down': 58, 'varanus-cage': 59, 'walking': 60,
            'bike-packing': 61, 'blackswan': 62, 'bmx-trees': 63, 'breakdance': 64, 'camel': 65,
            'car-roundabout': 66, 'car-shadow': 67, 'cows': 68, 'dance-twirl': 69, 'dog': 70,
            'dogs-jump': 71, 'drift-chicane': 72, 'drift-straight': 73, 'goat': 74, 'gold-fish': 75,
            'horsejump-high': 76, 'india': 77, 'judo': 78, 'kite-surf': 79, 'lab-coat': 80,
            'libby': 81, 'loading': 82, 'mbike-trick': 83, 'motocross-jump': 84, 'paragliding-launch': 85,
            'parkour': 86, 'pigs': 87, 'scooter-black': 88, 'shooting': 89, 'soapbox': 90
        }
        # Reverse lookup (class id -> sequence name) used when drawing labels.
        self.class_names = {v: k for k, v in self.class_map.items()}
        # Each entry is (image path, annotation path, sequence name).
        self.samples = []
        for seq in self.sequences:
            img_seq_dir = os.path.join(self.img_dir, seq)
            anno_seq_dir = os.path.join(self.anno_dir, seq)
            for img_file in sorted(os.listdir(img_seq_dir)):
                if not img_file.endswith('.jpg'):
                    continue
                frame_num = img_file.split('.')[0]
                anno_path = os.path.join(anno_seq_dir, f"{frame_num}.png")
                # Frames without an annotation file are skipped.
                if os.path.exists(anno_path):
                    self.samples.append((os.path.join(img_seq_dir, img_file), anno_path, seq))

    def __len__(self):
        """Return the number of indexed (frame, annotation) pairs."""
        return len(self.samples)

    @staticmethod
    def _mask_to_boxes(mask_np):
        """Return one ``[xmin, ymin, xmax, ymax]`` box per non-zero mask value.

        Zero is treated as background. Regions whose extent is a single pixel
        wide or high are dropped because they would give a zero-area box.
        """
        obj_ids = np.unique(mask_np)
        # Filter background explicitly instead of assuming it is the smallest
        # value present: a frame without any background pixel must not lose
        # its first object.
        obj_ids = obj_ids[obj_ids != 0]
        boxes = []
        for obj_id in obj_ids:
            ys, xs = np.nonzero(mask_np == obj_id)
            xmin, xmax = int(xs.min()), int(xs.max())
            ymin, ymax = int(ys.min()), int(ys.max())
            if xmax > xmin and ymax > ymin:
                boxes.append([xmin, ymin, xmax, ymax])
        return boxes

    def __getitem__(self, idx):
        """Load sample ``idx`` and build its detection target.

        Returns:
            ``(image, target)``; ``image`` is a PIL RGB image unless
            ``transform`` converts it.
        """
        img_path, mask_path, seq_name = self.samples[idx]
        image = Image.open(img_path).convert('RGB')
        # NOTE(review): the annotation is reduced to grey levels, so the values
        # only separate regions; they are presumably not the original DAVIS
        # object ids - confirm if ids are ever needed.
        mask_np = np.array(Image.open(mask_path).convert('L'))

        box_list = self._mask_to_boxes(mask_np)
        class_label = self.class_map.get(seq_name, 1)
        if box_list:
            boxes = torch.as_tensor(box_list, dtype=torch.float32)
            labels = torch.full((len(box_list),), class_label, dtype=torch.int64)
        else:
            # Keep well-formed empty tensors so frames without objects still
            # produce a valid target.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros(0, dtype=torch.int64)

        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([idx]),
            'area': (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            'iscrowd': torch.zeros((len(boxes),), dtype=torch.int64)
        }
        if self.transform:
            image, target = self.transform(image, target)
        return image, target

# Define transformation classes
class Compose:
    """Chain several ``(image, target)`` transforms into one callable."""

    def __init__(self, transforms):
        # Sequence of callables, applied in the order given.
        self.transforms = transforms

    def __call__(self, image, target):
        """Feed the pair through every transform and return the final pair."""
        current_image, current_target = image, target
        for step in self.transforms:
            current_image, current_target = step(current_image, current_target)
        return current_image, current_target

class ToTensor:
    """Convert the image with ``F.to_tensor``; the target passes through as is."""

    def __call__(self, image, target):
        return F.to_tensor(image), target

class Normalize:
    """Normalize the image with fixed statistics; the target is untouched."""

    def __init__(self, mean, std):
        # Both values are forwarded unchanged to F.normalize on every call.
        self.mean, self.std = mean, std

    def __call__(self, image, target):
        normalized = F.normalize(image, mean=self.mean, std=self.std)
        return normalized, target

class RandomHorizontalFlip:
    """Mirror the image and its boxes left-to-right with probability ``prob``."""

    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        """Return the (possibly flipped) image together with its target.

        The image must expose ``.shape`` with width as the last dimension.
        When a flip happens, ``target["boxes"]`` is updated in place.
        """
        # Exactly one random draw per call decides whether to flip.
        if not random.random() < self.prob:
            return image, target

        _, img_width = image.shape[-2:]
        flipped = F.hflip(image)
        boxes = target["boxes"]
        # Mirroring swaps the roles of xmin and xmax: new_xmin = W - old_xmax.
        boxes[:, [0, 2]] = img_width - boxes[:, [2, 0]]
        target["boxes"] = boxes
        return flipped, target

# Define the Faster R-CNN model
def create_fasterrcnn_model(num_classes, pretrained_backbone=True):
    """Build a Faster R-CNN detector on a single-feature-map ResNet-50 trunk.

    Args:
        num_classes: Number of output classes passed to ``FasterRCNN``.
        pretrained_backbone: Load the ``'DEFAULT'`` ResNet-50 weights when
            true, otherwise start from random initialisation.

    Returns:
        The assembled ``FasterRCNN`` model (not moved to any device).

    NOTE(review): torchvision's ``FasterRCNN`` presumably normalises its
    inputs internally (``image_mean`` / ``image_std``); callers that also
    apply ``Normalize`` may normalise twice - confirm.
    """
    import torchvision

    backbone_weights = 'DEFAULT' if pretrained_backbone else None
    resnet = torchvision.models.resnet50(weights=backbone_weights)

    # Keep everything except the last two child modules (presumably the
    # pooling layer and classifier head) so the trunk outputs a feature map.
    trunk_layers = list(resnet.children())[:-2]
    backbone = nn.Sequential(*trunk_layers)
    # FasterRCNN reads this attribute to size its heads.
    backbone.out_channels = 2048

    # One feature map -> one tuple of anchor sizes and one of aspect ratios.
    anchors = AnchorGenerator(
        sizes=((32, 64, 128, 256, 512),),
        aspect_ratios=((0.5, 1.0, 2.0),)
    )
    # '0' is the feature-map name requested from the backbone output.
    pooler = MultiScaleRoIAlign(featmap_names=['0'], output_size=7, sampling_ratio=2)

    return FasterRCNN(
        backbone=backbone,
        num_classes=num_classes,
        rpn_anchor_generator=anchors,
        box_roi_pool=pooler,
        min_size=800,
        max_size=1333
    )

# Enhanced training function with F1, accuracy, and loss tracking
def train_model(model, data_loader, optimizer, device, num_epochs=4):
    """Train ``model`` and plot per-epoch loss, accuracy and macro F1.

    Args:
        model: Detection model that returns a dict of losses when called with
            ``(images, targets)`` in train mode and a list of prediction
            dicts (with a ``'labels'`` entry) in eval mode.
        data_loader: Iterable yielding ``(images, targets)`` batches.
        optimizer: Optimizer already bound to the model parameters.
        device: Device the model and every batch are moved to.
        num_epochs: Number of passes over ``data_loader``.

    Returns:
        The trained model.

    Side effects:
        Prints progress, writes ``/content/training_metrics.png`` and displays
        it inline.

    Note:
        Accuracy and F1 compare predicted and ground-truth labels position by
        position after truncating both to the shorter length per image; boxes
        are not matched, so these are rough indicators rather than standard
        detection metrics.
    """
    model.to(device)
    model.train()
    losses_history = []
    accuracy_history = []
    f1_history = []
    num_batches = len(data_loader)
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        running_loss = 0.0      # reset every 10 batches, used for progress output
        epoch_loss_sum = 0.0    # never reset inside the epoch
        all_preds = []
        all_targets = []
        for i, (images, targets) in enumerate(data_loader):
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            optimizer.zero_grad()
            loss_dict = model(images, targets)
            losses = sum(loss_dict.values())
            losses.backward()
            optimizer.step()
            batch_loss = losses.item()
            running_loss += batch_loss
            epoch_loss_sum += batch_loss

            # Collect predictions for the metrics; the model only returns
            # detections in eval mode, so switch briefly and switch back.
            model.eval()
            with torch.no_grad():
                outputs = model(images)
            model.train()

            for output, target in zip(outputs, targets):
                pred_labels = output['labels'].cpu().numpy()
                true_labels = target['labels'].cpu().numpy()

                # Truncate both to a common length so they can be compared
                # element-wise.
                min_len = min(len(pred_labels), len(true_labels))
                all_preds.extend(pred_labels[:min_len])
                all_targets.extend(true_labels[:min_len])

            if (i + 1) % 10 == 0:
                print(f"Batch {i+1}/{num_batches}, Loss: {running_loss/10:.4f}")
                running_loss = 0.0

        # Average over every batch of the epoch. The progress accumulator
        # cannot be used here because it is cleared every 10 batches.
        epoch_loss = epoch_loss_sum / num_batches if num_batches else 0.0
        losses_history.append(epoch_loss)

        if all_targets:
            acc = accuracy_score(all_targets, all_preds)
            f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)
        else:
            # Nothing to compare (no predictions or no ground truth).
            acc = 0.0
            f1 = 0.0
        accuracy_history.append(acc)
        f1_history.append(f1)
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f} Accuracy: {acc:.4f} F1 Score: {f1:.4f}")

    # Plotting metrics
    epochs_axis = range(1, num_epochs + 1)
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 3, 1)
    plt.plot(epochs_axis, losses_history, label='Loss')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)

    plt.subplot(1, 3, 2)
    plt.plot(epochs_axis, accuracy_history, label='Accuracy', color='green')
    plt.title('Training Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.grid(True)

    plt.subplot(1, 3, 3)
    plt.plot(epochs_axis, f1_history, label='F1 Score', color='red')
    plt.title('Training F1 Score')
    plt.xlabel('Epoch')
    plt.ylabel('F1 Score')
    plt.grid(True)

    plt.tight_layout()
    plt.savefig('/content/training_metrics.png')
    plt.close()
    display(IPImage('/content/training_metrics.png'))
    return model

# Function to extract frames
def extract_frames(video_path, output_dir):
    """Dump every frame of a video as ``frame_NNNNNN.jpg`` into ``output_dir``.

    Args:
        video_path: Path of the video file to read.
        output_dir: Directory for the JPEG frames; created when missing.

    Returns:
        ``(frame_paths, width, height, fps)`` on success, where
        ``frame_paths`` lists the written files in frame order, or ``None``
        when the video cannot be opened or extraction raises.
    """
    os.makedirs(output_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        print(f"Error: Could not open video {video_path}")
        return None
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    # Only used for the progress display below.
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_paths = []
    frame_count = 0
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1
            frame_path = os.path.join(output_dir, f"frame_{frame_count:06d}.jpg")
            cv2.imwrite(frame_path, frame)
            frame_paths.append(frame_path)
            print(f"Extracting frame {frame_count}/{total_frames}", end='\r')
    except Exception as e:
        print(f"Error during frame extraction: {str(e)}")
        return None
    finally:
        # Release on every exit path, including interruptions that are not
        # subclasses of Exception (e.g. stopping the cell).
        cap.release()
    print(f"\nExtracted {frame_count} frames to {output_dir}")
    return frame_paths, width, height, fps

# Function to visualize detections
def visualize_detections(image, prediction, class_names, threshold=0.3,
                         mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Draw boxes and class labels for confident detections on an image.

    Args:
        image: Normalised image tensor laid out as (channels, height, width).
        prediction: Dict with ``'boxes'``, ``'labels'`` and ``'scores'`` tensors.
        class_names: Mapping from class id to display name; unknown ids are
            shown as ``'Unknown'``.
        threshold: Minimum score for a detection to be drawn.
        mean: Per-channel mean that was used to normalise ``image``.
        std: Per-channel standard deviation that was used to normalise ``image``.

    Returns:
        A ``uint8`` array of shape (height, width, channels) with the
        annotations drawn in place.
    """
    image_np = image.permute(1, 2, 0).cpu().numpy()
    # Undo the normalisation and scale to the 0..255 range.
    image_np = (image_np * np.asarray(std) + np.asarray(mean)) * 255
    # Round and clip before casting: a bare cast truncates (254.9999 -> 254)
    # and out-of-range values would wrap around in uint8.
    image_np = np.clip(np.rint(image_np), 0, 255).astype(np.uint8)
    # The permuted layout may not be C-contiguous; the cv2 drawing calls
    # presumably need a contiguous, writable buffer.
    image_np = np.ascontiguousarray(image_np)

    detections = zip(prediction['boxes'].cpu().numpy(),
                     prediction['labels'].cpu().numpy(),
                     prediction['scores'].cpu().numpy())
    for box, label, score in detections:
        if score >= threshold:
            x1, y1, x2, y2 = (int(coord) for coord in box)
            class_name = class_names.get(int(label), 'Unknown')
            cv2.rectangle(image_np, (x1, y1), (x2, y2), (255, 0, 0), 2)
            cv2.putText(image_np, f'{class_name}: {score:.2f}', (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
    return image_np

# Function to process video frames
def process_video_frames(model, video_path, class_names, device, frames_dir='/content/frames', output_frames_dir='/content/output_frames', threshold=0.3):
    """Run the detector on every frame of a video and save annotated frames.

    Args:
        model: Detection model; called with a batched image tensor in eval mode.
        video_path: Path of the video to process.
        class_names: Mapping from class id to display name.
        device: Device the frame tensors are moved to (the model must already
            live there).
        frames_dir: Scratch directory for the raw frames. Deleted and recreated
            at the start, and removed again after a successful run.
        output_frames_dir: Directory for the annotated frames. Deleted and
            recreated at the start.
        threshold: Minimum score for a detection to be drawn.

    Returns:
        ``output_frames_dir`` on success, ``None`` on any failure.
    """
    if os.path.exists(frames_dir):
        shutil.rmtree(frames_dir)
    if os.path.exists(output_frames_dir):
        shutil.rmtree(output_frames_dir)
    os.makedirs(frames_dir)
    os.makedirs(output_frames_dir)

    # extract_frames returns None on failure, so check before unpacking;
    # unpacking None directly would raise TypeError instead of reporting.
    extracted = extract_frames(video_path, frames_dir)
    if extracted is None or extracted[0] is None:
        print("Error: Frame extraction failed.")
        return None
    frame_paths = extracted[0]
    total_frames = len(frame_paths)
    if total_frames == 0:
        print("Error: No frames extracted from video.")
        return None

    transform = Compose([
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    model.eval()
    sample_frames = []
    # Keep roughly every (total/5)-th frame, at most five, for inline preview.
    sample_every = total_frames // 5 + 1
    frame_count = 0
    try:
        with torch.no_grad():
            for frame_path in frame_paths:
                frame_count += 1
                print(f"Processing frame {frame_count}/{total_frames}", end='\r')
                # OpenCV reads BGR; the pipeline works in RGB.
                frame_rgb = cv2.cvtColor(cv2.imread(frame_path), cv2.COLOR_BGR2RGB)
                image = Image.fromarray(frame_rgb)
                image_tensor, _ = transform(image, {})
                image_tensor = image_tensor.unsqueeze(0).to(device)
                outputs = model(image_tensor)[0]
                annotated_frame = visualize_detections(image_tensor[0], outputs, class_names, threshold)
                output_frame_path = os.path.join(output_frames_dir, f"annotated_frame_{frame_count:06d}.jpg")
                cv2.imwrite(output_frame_path, cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR))
                if frame_count % sample_every == 0 and len(sample_frames) < 5:
                    sample_frames.append(annotated_frame)
                # Drop per-frame tensors promptly to keep memory use flat.
                del image_tensor, outputs
                if device.type == 'cuda':
                    torch.cuda.empty_cache()
    except Exception as e:
        print(f"Error during frame processing: {str(e)}")
        return None
    print(f"\nProcessed {frame_count} frames. Annotated frames saved to {output_frames_dir}")
    for i, frame in enumerate(sample_frames):
        sample_path = f'/content/sample_frame_{i}.png'
        cv2.imwrite(sample_path, cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        display(IPImage(sample_path))
    shutil.rmtree(frames_dir)
    print("Cleaned up temporary frames directory")
    return output_frames_dir

# Main function
def main():
    """Train on DAVIS, reload the saved weights and annotate an uploaded video.

    Steps: seed the RNGs, train a Faster R-CNN on the ``train`` split and save
    its weights, rebuild the model from the saved file (asking for an upload
    when the file is missing), ask for a video upload, write annotated frames
    and pack them into a zip archive next to the output directory.
    """
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    davis_root = '/content/DAVIS/DAVIS'
    transform = Compose([
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        RandomHorizontalFlip(prob=0.5)
    ])

    # Train the model (comment this section out to reuse previously saved weights)
    train_dataset = DAVISDataset(davis_root, subset='train', transform=transform)
    train_loader = DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=True,
        num_workers=2,
        # Detection targets differ in size per image, so keep each batch as
        # tuples instead of stacking into one tensor.
        collate_fn=lambda x: tuple(zip(*x))
    )
    # 90 sequence classes plus the unused id 0.
    num_classes = 91
    model = create_fasterrcnn_model(num_classes=num_classes)
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
    model = train_model(model, train_loader, optimizer, device, num_epochs=5)
    torch.save(model.state_dict(), '/content/fasterrcnn_resnet50_davis.pth')
    print("Training completed!")

    # Load dataset for class names
    dataset = DAVISDataset(davis_root, subset='val', transform=transform)
    class_names = dataset.class_names

    # Load the model
    num_classes = 91
    model = create_fasterrcnn_model(num_classes=num_classes)
    try:
        model.load_state_dict(torch.load('/content/fasterrcnn_resnet50_davis.pth', map_location=device))
    except FileNotFoundError:
        print("Error: Model weights file 'fasterrcnn_resnet50_davis.pth' not found.")
        print("Please upload the trained model or uncomment the training section to train the model.")
        print("To proceed without training, upload 'fasterrcnn_resnet50_davis.pth' now:")
        uploaded = files.upload()
        if 'fasterrcnn_resnet50_davis.pth' in uploaded:
            model.load_state_dict(torch.load('/content/fasterrcnn_resnet50_davis.pth', map_location=device))
        else:
            print("Error: Model weights not uploaded. Cannot proceed.")
            return
    model.to(device)

    # Upload video
    print("Please upload your video file (e.g., MP4 format):")
    uploaded = files.upload()
    if not uploaded:
        print("Error: No video file uploaded.")
        return
    # Only the first uploaded file is processed.
    video_path = list(uploaded.keys())[0]

    # Process video frames
    output_frames_dir = process_video_frames(
        model, video_path, class_names, device,
        frames_dir='/content/frames',
        output_frames_dir='/content/output_frames',
        threshold=0.3
    )

    if output_frames_dir:
        print(f"Processed frames saved to {output_frames_dir}")
        print("To download all frames, run the following cell to zip the output_frames directory:")
        print("Then download the zip file from the Colab file explorer.")
        # Zip the output frames for easy download. shutil is used instead of
        # a notebook shell escape so the function is plain Python; the archive
        # is written next to the directory as <output_frames_dir>.zip.
        frames_root = os.path.normpath(output_frames_dir)
        archive_path = shutil.make_archive(
            frames_root, 'zip',
            root_dir=os.path.dirname(frames_root),
            base_dir=os.path.basename(frames_root)
        )
        print(f"Created archive {archive_path}")
    else:
        print("Frame processing failed. Please check error messages.")

# Run the whole pipeline only when executed directly (script or notebook
# cell), not when this code is imported from elsewhere.
if __name__ == "__main__":
    main()

Using device: cuda
Epoch 1/5
Batch 10/2105, Loss: 3.2485
Batch 20/2105, Loss: 1.4067
Batch 30/2105, Loss: 0.8866
Batch 40/2105, Loss: 0.5841
Batch 50/2105, Loss: 0.5987
Batch 60/2105, Loss: 0.7248
Batch 70/2105, Loss: 0.5869
Batch 80/2105, Loss: 0.8629
Batch 90/2105, Loss: 0.7172
Batch 100/2105, Loss: 0.7735
Batch 110/2105, Loss: 0.6556
Batch 120/2105, Loss: 0.7140
Batch 130/2105, Loss: 0.5816
Batch 140/2105, Loss: 0.6279
Batch 150/2105, Loss: 0.7089
Batch 160/2105, Loss: 0.4773
Batch 170/2105, Loss: 0.8093
Batch 180/2105, Loss: 0.8451
Batch 190/2105, Loss: 0.8001
Batch 200/2105, Loss: 0.6046
Batch 210/2105, Loss: 0.6834
Batch 220/2105, Loss: 0.8027
Batch 230/2105, Loss: 0.4913
Batch 240/2105, Loss: 0.6495
Batch 250/2105, Loss: 0.7625
Batch 260/2105, Loss: 0.5276
Batch 270/2105, Loss: 0.6988
Batch 280/2105, Loss: 0.7230
Batch 290/2105, Loss: 0.6532
Batch 300/2105, Loss: 0.5596
Batch 310/2105, Loss: 0.7542
Batch 320/2105, Loss: 0.6159
Batch 330/2105, Loss: 0.4916
Batch 340/2105, Loss: 0