# PrDiMP50 Training and Evaluation on GOT-10k

This notebook implements PrDiMP50 (klcedimpnet50) with ResNet50 backbone, trains it on GOT-10k dataset, and evaluates the results.

## ⚠️ Important: NumPy Compatibility

**If you encounter NumPy 2.x compatibility errors**, the notebook will automatically fix this by downgrading to NumPy <2.0. You may need to restart the kernel/runtime after the first cell runs.


## 0. Setup and Installation


In [None]:
# CRITICAL: Fix NumPy 2.x compatibility issue first
# NumPy 2.x causes import errors with matplotlib and other packages
# This must run BEFORE any other imports
import subprocess
import sys

def fix_numpy_version():
    """Ensure NumPy <2.0 is installed for compatibility"""
    try:
        import numpy as np
        numpy_version = np.__version__
        print(f"Detected NumPy version: {numpy_version}")
        
        if numpy_version.startswith('2.'):
            print("\n⚠️  NumPy 2.x detected - this causes compatibility issues!")
            print("Downgrading to NumPy <2.0...")
            subprocess.run([
                sys.executable, "-m", "pip", "install", 
                "--upgrade", "--force-reinstall", "numpy<2.0", "--no-deps"
            ], check=True)
            print("✅ NumPy downgraded. Please RESTART the kernel/runtime now!")
            print("   (In Colab: Runtime > Restart runtime)")
            print("   (In Kaggle: Click 'Restart' button)")
            return True
        else:
            print(f"✅ NumPy version {numpy_version} is compatible")
            return False
    except ImportError:
        print("Installing NumPy <2.0...")
        subprocess.run([
            sys.executable, "-m", "pip", "install", "numpy<2.0"
        ], check=True)
        return False

needs_restart = fix_numpy_version()
if needs_restart:
    raise RuntimeError(
        "NumPy was downgraded. Please RESTART the kernel/runtime and run this cell again.\n"
        "In Colab: Runtime > Restart runtime\n"
        "In Kaggle: Click the 'Restart' button"
    )


In [None]:
import os
import sys
import subprocess
import shutil
from pathlib import Path

# Check if we're in Kaggle
KAGGLE = os.path.exists('/kaggle')

if KAGGLE:
    # Kaggle paths
    BASE_DIR = Path('/kaggle/working')
    DATA_DIR = Path('/kaggle/input')
else:
    # Local paths
    BASE_DIR = Path.cwd()
    DATA_DIR = BASE_DIR / 'data'

print(f"Working directory: {BASE_DIR}")
print(f"Data directory: {DATA_DIR}")


In [None]:
# Clone the repository
REPO_URL = "https://github.com/visionml/pytracking.git"
REPO_DIR = BASE_DIR / "pytracking"

if not REPO_DIR.exists():
    print("Cloning pytracking repository...")
    subprocess.run(["git", "clone", REPO_URL, str(REPO_DIR)], check=True)
    
    # Initialize submodules
    print("Initializing submodules...")
    os.chdir(REPO_DIR)
    subprocess.run(["git", "submodule", "update", "--init"], check=True)
    os.chdir(BASE_DIR)
else:
    print("Repository already exists, skipping clone...")

print(f"Repository cloned to: {REPO_DIR}")


In [None]:
# Add to Python path
if str(REPO_DIR) not in sys.path:
    sys.path.insert(0, str(REPO_DIR))

# Change to repository directory
os.chdir(REPO_DIR)
print(f"Current working directory: {os.getcwd()}")


In [None]:
# Install required packages (for Kaggle, most are pre-installed)
# IMPORTANT: Fix NumPy compatibility issue - ensure NumPy <2.0 for compatibility
import importlib

# First, check and fix NumPy version if needed
try:
    import numpy as np
    numpy_version = np.__version__
    print(f"Current NumPy version: {numpy_version}")
    
    # Check if NumPy 2.x is installed (causes compatibility issues with matplotlib)
    if numpy_version.startswith('2.'):
        print("NumPy 2.x detected - downgrading to NumPy <2.0 for compatibility...")
        subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "--force-reinstall", "numpy<2.0"], check=True)
        print("NumPy downgraded successfully. Please restart the kernel/runtime after this cell.")
        import importlib
        importlib.reload(importlib.import_module('numpy'))
except ImportError:
    print("NumPy not found, installing NumPy <2.0...")
    subprocess.run([sys.executable, "-m", "pip", "install", "numpy<2.0"], check=True)

required_packages = [
    'torch', 'torchvision', 'opencv-python', 'Pillow',
    'matplotlib', 'pandas', 'scipy', 'tqdm'
]

missing_packages = []
for pkg in required_packages:
    try:
        importlib.import_module(pkg.replace('-', '_'))
    except ImportError:
        missing_packages.append(pkg)

if missing_packages:
    print(f"Installing missing packages: {missing_packages}")
    subprocess.run([sys.executable, "-m", "pip", "install", "-q"] + missing_packages)
else:
    print("All required packages are available")

# Verify NumPy version one more time
try:
    import numpy as np
    print(f"Final NumPy version: {np.__version__}")
    if np.__version__.startswith('2.'):
        print("WARNING: NumPy 2.x still detected. You may need to restart the kernel.")
except:
    pass


## 1. Environment Configuration


In [None]:
# Create local.py for environment settings (both ltr and pytracking)
LTR_LOCAL_PY_PATH = REPO_DIR / "ltr" / "admin" / "local.py"
PYTRACKING_LOCAL_PY_PATH = REPO_DIR / "pytracking" / "evaluation" / "local.py"

# Set paths - user will provide GOT-10k path
# For Kaggle, you can set this as an environment variable
GOT10K_PATH = os.environ.get('GOT10K_PATH', '')
if not GOT10K_PATH:
    # Try default locations
    if KAGGLE:
        # Try common Kaggle input locations
        possible_paths = [
            DATA_DIR / "got10k" / "train",
            DATA_DIR / "got-10k" / "train",
            Path("/kaggle/input/got10k/train"),
        ]
        for path in possible_paths:
            if path.exists():
                GOT10K_PATH = str(path)
                break
        if not GOT10K_PATH:
            GOT10K_PATH = str(DATA_DIR / "got10k" / "train")
            print(f"Warning: GOT-10k path not found. Using default: {GOT10K_PATH}")
            print("Please set GOT10K_PATH environment variable or update the path in this cell")
    else:
        # Local default
        GOT10K_PATH = str(BASE_DIR / "data" / "got10k" / "train")
        print(f"Using default path: {GOT10K_PATH}")
        print("To change this, set GOT10K_PATH environment variable or modify this cell")

WORKSPACE_DIR = str(BASE_DIR / "workspace")
os.makedirs(WORKSPACE_DIR, exist_ok=True)

# LTR local.py
ltr_local_py_content = f"""
class EnvironmentSettings:
    def __init__(self):
        self.workspace_dir = '{WORKSPACE_DIR}'
        self.tensorboard_dir = self.workspace_dir + '/tensorboard/'
        self.pretrained_networks = self.workspace_dir + '/pretrained_networks/'
        self.pregenerated_masks = ''
        self.lasot_dir = ''
        self.got10k_dir = '{GOT10K_PATH}'
        self.trackingnet_dir = ''
        self.coco_dir = ''
        self.lvis_dir = ''
        self.sbd_dir = ''
        self.imagenet_dir = ''
        self.imagenetdet_dir = ''
        self.ecssd_dir = ''
        self.hkuis_dir = ''
        self.msra10k_dir = ''
        self.davis_dir = ''
        self.youtubevos_dir = ''
        self.lasot_candidate_matching_dataset_path = ''
"""

with open(LTR_LOCAL_PY_PATH, 'w') as f:
    f.write(ltr_local_py_content)

# PyTracking evaluation local.py
pytracking_local_py_content = f"""
from pytracking.evaluation.environment import EnvSettings

def local_env_settings():
    settings = EnvSettings()
    
    # Set paths
    settings.got10k_path = '{GOT10K_PATH}'
    settings.results_path = '{WORKSPACE_DIR}/tracking_results/'
    settings.segmentation_path = '{WORKSPACE_DIR}/segmentation_results/'
    settings.network_path = '{WORKSPACE_DIR}/networks/'
    settings.result_plot_path = '{WORKSPACE_DIR}/result_plots/'
    settings.dataspec_path = '{str(REPO_DIR / "ltr" / "data_specs")}'
    
    return settings
"""

os.makedirs(PYTRACKING_LOCAL_PY_PATH.parent, exist_ok=True)
with open(PYTRACKING_LOCAL_PY_PATH, 'w') as f:
    f.write(pytracking_local_py_content)

print(f"Environment configured:")
print(f"  GOT-10k path: {GOT10K_PATH}")
print(f"  Workspace: {WORKSPACE_DIR}")


## 2.1. Fix PreciseRoI Pooling Extension


In [None]:
# Fix PreciseRoI Pooling compilation issue
# This extension is required for PrDiMP training
import os
import shutil
import torch
import sys
import subprocess
import time

def fix_prroi_pooling():
    """Clean cache and force rebuild of PreciseRoI Pooling extension"""
    print("=" * 60)
    print("Fixing PreciseRoI Pooling extension...")
    print("=" * 60)
    
    # Set environment variable to force rebuild
    os.environ['TORCH_EXTENSIONS_DIR'] = os.path.expanduser('~/.cache/torch_extensions')
    # Force PyTorch to rebuild by setting this
    os.environ['FORCE_CUDA'] = '1'
    
    # Clean PyTorch extension cache COMPLETELY to force rebuild
    cache_dir = os.path.expanduser('~/.cache/torch_extensions')
    
    # Remove the entire cache directory, not just the specific extension
    if os.path.exists(cache_dir):
        print(f"Removing entire PyTorch extension cache: {cache_dir}")
        try:
            shutil.rmtree(cache_dir)
            print("✅ Cache cleaned successfully")
            # Give it a moment to ensure deletion is complete
            time.sleep(1)
        except Exception as e:
            print(f"⚠️  Warning: Could not clean cache: {e}")
            # Try to remove just the prroi_pooling cache
            prroi_cache = os.path.join(cache_dir, 'py311_cu124', '_prroi_pooling')
            if os.path.exists(prroi_cache):
                try:
                    shutil.rmtree(prroi_cache)
                    print("✅ Removed prroi_pooling cache specifically")
                except:
                    pass
    
    # Install ninja if needed (required for compilation)
    try:
        import ninja
        print("✅ ninja is available")
    except ImportError:
        print("Installing ninja (required for compilation)...")
        subprocess.run([sys.executable, "-m", "pip", "install", "ninja", "-q"], check=False)
        print("✅ ninja installed")
    
    # Force compilation by importing functional and actually using it
    print("\nAttempting to compile PreciseRoI Pooling extension...")
    print("This may take 1-2 minutes on first run...")
    
    try:
        # Find the source files to potentially touch them
        try:
            prroi_src_path = os.path.join(REPO_DIR, 'ltr', 'external', 'PreciseRoIPooling', 'pytorch', 'prroi_pool', 'src')
            if os.path.exists(prroi_src_path):
                # Touch source files to force rebuild detection
                for root, dirs, files in os.walk(prroi_src_path):
                    for file in files:
                        if file.endswith(('.cpp', '.cu', '.cuh')):
                            file_path = os.path.join(root, file)
                            try:
                                # Touch the file to update modification time
                                os.utime(file_path, None)
                            except:
                                pass
        except:
            pass
        
        # Import functional module which handles compilation
        from ltr.external.PreciseRoIPooling.pytorch.prroi_pool import functional as prroi_func
        
        # CRITICAL: Reset the cached module to force recompilation
        if hasattr(prroi_func, '_prroi_pooling'):
            prroi_func._prroi_pooling = None
        
        # Also clear any module-level cache
        import importlib
        modules_to_clear = [
            'ltr.external.PreciseRoIPooling.pytorch.prroi_pool.functional',
            'ltr.external.PreciseRoIPooling.pytorch.prroi_pool.prroi_pool'
        ]
        for mod_name in modules_to_clear:
            if mod_name in sys.modules:
                del sys.modules[mod_name]
        
        # Now import again to trigger fresh compilation
        from ltr.external.PreciseRoIPooling.pytorch.prroi_pool import functional as prroi_func_new
        
        # Actually call the function to trigger compilation
        if torch.cuda.is_available():
            print("Testing compilation with CUDA tensors...")
            dummy_feat = torch.randn(1, 64, 10, 10, device='cuda')
            dummy_roi = torch.tensor([[0, 0, 0, 5, 5]], dtype=torch.float32, device='cuda')
            
            # This will trigger actual compilation - it may take a minute
            print("   Calling prroi_pool2d (this triggers compilation)...")
            result = prroi_func_new.prroi_pool2d(dummy_feat, dummy_roi, 2, 2, 1.0)
            print("   ✅ Function call succeeded!")
            
            # Verify the .so file was created - check multiple possible locations
            cache_dir_check = os.path.expanduser('~/.cache/torch_extensions')
            possible_locations = [
                os.path.join(cache_dir_check, 'py311_cu124', '_prroi_pooling', '_prroi_pooling.so'),
                os.path.join(cache_dir_check, f'py{sys.version_info.major}{sys.version_info.minor}_cu{torch.version.cuda.replace(".", "")}', '_prroi_pooling', '_prroi_pooling.so'),
            ]
            
            # Also search for any .so file in the cache
            so_file = None
            for loc in possible_locations:
                if os.path.exists(loc):
                    so_file = loc
                    break
            
            # If not found, search the cache directory
            if so_file is None and os.path.exists(cache_dir_check):
                for root, dirs, files in os.walk(cache_dir_check):
                    for file in files:
                        if file == '_prroi_pooling.so':
                            so_file = os.path.join(root, file)
                            break
                    if so_file:
                        break
            
            if so_file and os.path.exists(so_file):
                print(f"✅ PreciseRoI Pooling extension compiled and verified!")
                print(f"   Extension file: {so_file}")
                print(f"   File size: {os.path.getsize(so_file) / 1024 / 1024:.2f} MB")
                return True
            else:
                print(f"⚠️  Extension function worked but .so file location unclear")
                print("   This is okay - the extension is loaded in memory")
                print("   Training should work now. If it fails, restart kernel and try again.")
                return True
        else:
            print("⚠️  CUDA not available - extension requires CUDA")
            print("   Training will fail without CUDA. Please enable GPU.")
            return False
            
    except Exception as e:
        print(f"❌ Failed to compile PreciseRoI Pooling: {e}")
        import traceback
        print("\nFull error traceback:")
        traceback.print_exc()
        print("\nTroubleshooting steps:")
        print("1. Ensure CUDA is available: torch.cuda.is_available()")
        print("2. Check CUDA version: torch.version.cuda")
        print("3. Verify PyTorch CUDA compatibility")
        print("4. Try restarting the kernel and running this cell again")
        print("5. On Kaggle, ensure GPU accelerator is enabled")
        return False

# Fix the extension
print("CUDA Status:")
print(f"  Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  Device: {torch.cuda.get_device_name(0)}")
    print(f"  CUDA Version: {torch.version.cuda}")

success = fix_prroi_pooling()
if not success:
    print("\n" + "=" * 60)
    print("⚠️  WARNING: PreciseRoI Pooling extension setup incomplete.")
    print("=" * 60)
    print("Training will attempt to compile it on first use.")
    print("If training fails with the same error:")
    print("  1. Restart the kernel")
    print("  2. Run all cells from the beginning")
    print("  3. Ensure GPU is enabled (Kaggle: Settings > Accelerator > GPU)")
else:
    print("\n" + "=" * 60)
    print("✅ PreciseRoI Pooling extension is ready!")
    print("=" * 60)


## 2.2. Import Required Modules


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image
import json
from pathlib import Path
from tqdm import tqdm

# PyTracking imports
from ltr.dataset import Got10k
from ltr.data import processing, sampler, LTRLoader
from ltr.models.tracking import dimpnet
import ltr.models.loss as ltr_losses
import ltr.models.loss.kl_regression as klreg_losses
import ltr.actors.tracking as tracking_actors
from ltr.trainers import LTRTrainer
import ltr.data.transforms as tfm
from ltr.admin.environment import env_settings

# Evaluation imports
from pytracking.evaluation.datasets import get_dataset
from pytracking.evaluation.tracker import Tracker
from pytracking.evaluation.running import run_dataset
from pytracking.analysis.plot_results import plot_results, print_results

print("All modules imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


## 3. Load and Visualize Dataset Samples


In [None]:
# Load GOT-10k dataset for training
settings_env = env_settings()
got10k_train = Got10k(root=settings_env.got10k_dir, split='vottrain')

print(f"GOT-10k training sequences: {len(got10k_train.sequence_list)}")
print(f"Sample sequences: {got10k_train.sequence_list[:5]}")


In [None]:
def visualize_samples(dataset, num_samples=4, seq_ids=None):
    """Visualize dataset samples with bounding boxes"""
    if seq_ids is None:
        seq_ids = np.random.choice(len(dataset.sequence_list), num_samples, replace=False)
    
    fig, axes = plt.subplots(2, 2, figsize=(15, 15))
    axes = axes.flatten()
    
    for idx, seq_id in enumerate(seq_ids):
        # Get sequence info
        seq_info = dataset.get_sequence_info(seq_id)
        
        # Get first frame
        frame_ids = [0]
        frames, anno, meta = dataset.get_frames(seq_id, frame_ids)
        
        # Get image and bounding box
        img = np.array(frames[0])
        bbox = anno['bbox'][0].numpy()  # [x, y, w, h]
        
        # Draw bounding box
        img_with_bbox = img.copy()
        x, y, w, h = bbox.astype(int)
        cv2.rectangle(img_with_bbox, (x, y), (x + w, y + h), (0, 255, 0), 3)
        
        # Add text
        seq_name = dataset.sequence_list[seq_id]
        class_name = dataset.get_class_name(seq_id)
        text = f"{seq_name}\n{class_name}"
        cv2.putText(img_with_bbox, text, (x, max(y-10, 20)), 
                   cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
        
        # Display
        axes[idx].imshow(img_with_bbox)
        axes[idx].set_title(f"Sequence {seq_id}: {seq_name}", fontsize=12)
        axes[idx].axis('off')
        
        print(f"Sequence {seq_id} ({seq_name}):")
        print(f"  Class: {class_name}")
        print(f"  BBox: [{x}, {y}, {w}, {h}]")
        print(f"  Image size: {img.shape}")
    
    plt.tight_layout()
    plt.show()

# Visualize 4 random samples
print("Visualizing dataset samples...")
visualize_samples(got10k_train, num_samples=4)


## 4. Create PrDiMP50 Model


In [None]:
# Create PrDiMP50 model (klcedimpnet50) with ResNet50 backbone
print("Creating PrDiMP50 model...")

# Model parameters (matching prdimp50 training settings)
filter_size = 4
output_sigma_factor = 1/4
feature_sz = 18
output_sz = feature_sz * 16
output_sigma = output_sigma_factor / 5.0  # search_area_factor = 5.0

net = dimpnet.klcedimpnet50(
    filter_size=filter_size,
    backbone_pretrained=True,  # Use pretrained ResNet50
    optim_iter=5,
    clf_feat_norm=True,
    clf_feat_blocks=0,
    final_conv=True,
    out_feature_dim=512,
    optim_init_step=1.0,
    optim_init_reg=0.05,
    optim_min_reg=0.05,
    gauss_sigma=output_sigma * feature_sz,
    alpha_eps=0.05,
    normalize_label=True,
    init_initializer='zero'
)

# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = net.to(device)

print(f"Model created and moved to {device}")
print(f"Total parameters: {sum(p.numel() for p in net.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad):,}")


## 5. Setup Training Configuration


In [None]:
# Training settings
class TrainingSettings:
    def __init__(self):
        self.description = 'PrDiMP50 training on GOT-10k'
        self.batch_size = 8 if torch.cuda.is_available() else 2  # Adjust for Kaggle GPU
        self.num_workers = 4 if KAGGLE else 8
        self.multi_gpu = False
        self.print_interval = 200
        self.normalize_mean = [0.485, 0.456, 0.406]
        self.normalize_std = [0.229, 0.224, 0.225]
        self.search_area_factor = 5.0
        self.output_sigma_factor = 1/4
        self.target_filter_sz = 4
        self.feature_sz = 18
        self.output_sz = self.feature_sz * 16
        self.center_jitter_factor = {'train': 3, 'test': 4.5}
        self.scale_jitter_factor = {'train': 0.25, 'test': 0.5}
        self.hinge_threshold = 0.05
        self.print_stats = ['Loss/total', 'Loss/bb_ce', 'ClfTrain/clf_ce']
        
        # Project path for saving checkpoints
        self.module_name = 'dimp'
        self.script_name = 'prdimp50_got10k'
        self.project_path = f'ltr/{self.module_name}/{self.script_name}'
        
        # Training epochs
        self.num_epochs = 50

settings = TrainingSettings()
print(f"Training settings configured:")
print(f"  Batch size: {settings.batch_size}")
print(f"  Workers: {settings.num_workers}")
print(f"  Epochs: {settings.num_epochs}")
print(f"  Log interval (batches): {settings.print_interval}")


In [None]:
# Data transforms
transform_joint = tfm.Transform(tfm.ToGrayscale(probability=0.05))

transform_train = tfm.Transform(
    tfm.ToTensorAndJitter(0.2),
    tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std)
)

transform_val = tfm.Transform(
    tfm.ToTensor(),
    tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std)
)

# Data processing
output_sigma = settings.output_sigma_factor / settings.search_area_factor
proposal_params = {
    'boxes_per_frame': 128,
    'gt_sigma': (0.05, 0.05),
    'proposal_sigma': [(0.05, 0.05), (0.5, 0.5)]
}
label_params = {
    'feature_sz': settings.feature_sz,
    'sigma_factor': output_sigma,
    'kernel_sz': settings.target_filter_sz
}
label_density_params = {
    'feature_sz': settings.feature_sz,
    'sigma_factor': output_sigma,
    'kernel_sz': settings.target_filter_sz,
    'normalize': True
}

data_processing_train = processing.KLDiMPProcessing(
    search_area_factor=settings.search_area_factor,
    output_sz=settings.output_sz,
    center_jitter_factor=settings.center_jitter_factor,
    scale_jitter_factor=settings.scale_jitter_factor,
    mode='sequence',
    proposal_params=proposal_params,
    label_function_params=label_params,
    label_density_params=label_density_params,
    transform=transform_train,
    joint_transform=transform_joint
)

data_processing_val = processing.KLDiMPProcessing(
    search_area_factor=settings.search_area_factor,
    output_sz=settings.output_sz,
    center_jitter_factor=settings.center_jitter_factor,
    scale_jitter_factor=settings.scale_jitter_factor,
    mode='sequence',
    proposal_params=proposal_params,
    label_function_params=label_params,
    label_density_params=label_density_params,
    transform=transform_val,
    joint_transform=transform_joint
)

print("Data processing configured")


In [None]:
# Create data loaders
env_settings_obj = env_settings()

# Training dataset - only GOT-10k
got10k_train_dataset = Got10k(root=env_settings_obj.got10k_dir, split='vottrain')

# Validation dataset
got10k_val = Got10k(root=env_settings_obj.got10k_dir, split='votval')

# Train sampler
dataset_train = sampler.DiMPSampler(
    [got10k_train_dataset],
    [1],  # Only GOT-10k
    samples_per_epoch=26000,
    max_gap=200,
    num_test_frames=3,
    num_train_frames=3,
    processing=data_processing_train
)

loader_train = LTRLoader(
    'train',
    dataset_train,
    training=True,
    batch_size=settings.batch_size,
    num_workers=settings.num_workers,
    shuffle=True,
    drop_last=True,
    stack_dim=1
)

# Validation sampler
dataset_val = sampler.DiMPSampler(
    [got10k_val],
    [1],
    samples_per_epoch=5000,
    max_gap=200,
    num_test_frames=3,
    num_train_frames=3,
    processing=data_processing_val
)

loader_val = LTRLoader(
    'val',
    dataset_val,
    training=False,
    batch_size=settings.batch_size,
    num_workers=settings.num_workers,
    shuffle=False,
    drop_last=True,
    epoch_interval=5,
    stack_dim=1
)

print(f"Training samples per epoch: {len(dataset_train)}")
print(f"Validation samples per epoch: {len(dataset_val)}")


In [None]:
# Create actor and optimizer
objective = {
    'bb_ce': klreg_losses.KLRegression(),
    'clf_ce': klreg_losses.KLRegressionGrid()
}

loss_weight = {
    'bb_ce': 0.0025,
    'clf_ce': 0.25,
    'clf_ce_init': 0.25,
    'clf_ce_iter': 1.0
}

actor = tracking_actors.KLDiMPActor(
    net=net,
    objective=objective,
    loss_weight=loss_weight
)

# Optimizer with different learning rates for different components
optimizer = optim.Adam([
    {'params': actor.net.classifier.parameters(), 'lr': 1e-3},
    {'params': actor.net.bb_regressor.parameters(), 'lr': 1e-3},
    {'params': actor.net.feature_extractor.parameters(), 'lr': 2e-5}
], lr=2e-4)

lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.2)

print("Actor and optimizer created")


## 6. Train the Model


In [None]:
# Create trainer
from ltr.admin.settings import Settings

train_settings = Settings()
train_settings.description = settings.description
train_settings.batch_size = settings.batch_size
train_settings.num_workers = settings.num_workers
train_settings.multi_gpu = settings.multi_gpu
train_settings.print_interval = settings.print_interval
train_settings.project_path = settings.project_path
train_settings.print_stats = settings.print_stats

trainer = LTRTrainer(
    actor,
    [loader_train, loader_val],
    optimizer,
    train_settings,
    lr_scheduler
)

print("Starting training...")
print(f"Will train for {settings.num_epochs} epochs")
print(f"Checkpoints will be saved to: {train_settings.project_path}")


In [None]:
# Train the model
trainer.train(settings.num_epochs, load_latest=True, fail_safe=True)

print("Training completed!")


## 7. Evaluate on GOT-10k


In [None]:
# Load the best model checkpoint
checkpoint_path = os.path.join(
    env_settings().workspace_dir,
    settings.project_path,
    'checkpoints',
    'checkpoint.pth.tar'
)

# If checkpoint doesn't exist, try latest
if not os.path.exists(checkpoint_path):
    checkpoint_path = os.path.join(
        env_settings().workspace_dir,
        settings.project_path,
        'checkpoints',
        'checkpoint_latest.pth.tar'
    )

print(f"Loading checkpoint: {checkpoint_path}")

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint['net'], strict=False)
    print("Model loaded from checkpoint")
else:
    print("Warning: Checkpoint not found, using current model state")


In [None]:
# Setup evaluation
from pytracking.parameter.dimp.prdimp50 import parameters

# Create tracker
tracker_params = parameters()
tracker_params.net = net  # Use our trained model

# Load GOT-10k test/val dataset for evaluation
got10k_eval_dataset = get_dataset('got10k_ltrval')  # or 'got10k_val' for official val

print(f"Evaluation dataset: {len(got10k_eval_dataset)} sequences")

# Create tracker
tracker = Tracker(
    'prdimp50',
    'prdimp50',
    run_id=None,
    tracker_params=tracker_params
)

trackers = [tracker]

print("Running evaluation...")
run_dataset(got10k_eval_dataset, trackers, debug=0, threads=0)

print("Evaluation completed!")


## 8. Plot Results and Metrics


In [None]:
# Plot results
report_name = 'prdimp50_got10k_eval'

print("Generating plots...")
plot_results(
    trackers,
    got10k_eval_dataset,
    report_name,
    merge_results=False,
    plot_types=('success', 'prec'),
    force_evaluation=False
)

plt.show()


In [None]:
# Print detailed results
print("\n" + "="*80)
print("DETAILED RESULTS")
print("="*80)
print_results(
    trackers,
    got10k_eval_dataset,
    report_name,
    merge_results=False,
    plot_types=('success', 'prec')
)


In [None]:
# Extract and display key metrics
from pytracking.analysis.extract_results import extract_results

results = extract_results(trackers, got10k_eval_dataset, report_name, force_evaluation=False)

if results:
    print("\n" + "="*80)
    print("KEY METRICS")
    print("="*80)
    
    # Success plot metrics
    if 'success' in results:
        success_data = results['success']
        print(f"\nSuccess Plot (AUC): {success_data.get('AUC', 'N/A')}")
        print(f"Success at 0.5 overlap: {success_data.get('OP50', 'N/A')}")
        print(f"Success at 0.75 overlap: {success_data.get('OP75', 'N/A')}")
    
    # Precision metrics
    if 'precision' in results:
        prec_data = results['precision']
        print(f"\nPrecision (20px threshold): {prec_data.get('precision_score', 'N/A')}")
    
    # Normalized precision
    if 'norm_precision' in results:
        norm_prec_data = results['norm_precision']
        print(f"\nNormalized Precision: {norm_prec_data.get('norm_precision_score', 'N/A')}")
else:
    print("Results extraction failed. Check if evaluation completed successfully.")


In [None]:
# Save summary to file
summary_path = os.path.join(env_settings().workspace_dir, 'training_summary.txt')

with open(summary_path, 'w') as f:
    f.write("PrDiMP50 Training Summary\n")
    f.write("="*80 + "\n\n")
    f.write(f"Model: PrDiMP50 (klcedimpnet50) with ResNet50 backbone\n")
    f.write(f"Dataset: GOT-10k\n")
    f.write(f"Training epochs: {settings.num_epochs}\n")
    f.write(f"Batch size: {settings.batch_size}\n")
    f.write(f"\nResults:\n")
    if results:
        f.write(json.dumps(results, indent=2))

print(f"Summary saved to: {summary_path}")


## Summary

This notebook has:
1. ✅ Cloned the pytracking repository
2. ✅ Set up the environment for Kaggle
3. ✅ Implemented PrDiMP50 (klcedimpnet50) with ResNet50 backbone
4. ✅ Visualized dataset samples with bounding boxes
5. ✅ Trained the model on GOT-10k dataset
6. ✅ Evaluated the model on GOT-10k
7. ✅ Generated plots and metrics

The trained model checkpoints are saved in the workspace directory and can be used for inference or further training.
