In [1]:
# Install required packages
print("🚀 Installing dependencies...")

!pip -q install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip -q install opencv-python numpy matplotlib pillow scikit-image tqdm

print("✅ Dependencies installed.")

🚀 Installing dependencies...
✅ Dependencies installed.


In [2]:
# Clone the TrackNetV3 repo
print("📂 Cloning TrackNetV3 repository...")

!git clone https://github.com/qaz812345/TrackNetV3.git
%cd TrackNetV3

print("✅ Repo cloned.")

📂 Cloning TrackNetV3 repository...
Cloning into 'TrackNetV3'...
remote: Enumerating objects: 240, done.[K
remote: Counting objects: 100% (111/111), done.[K
remote: Compressing objects: 100% (24/24), done.[K
remote: Total 240 (delta 99), reused 87 (delta 87), pack-reused 129 (from 1)[K
Receiving objects: 100% (240/240), 2.82 MiB | 24.85 MiB/s, done.
Resolving deltas: 100% (134/134), done.
/content/TrackNetV3
✅ Repo cloned.


In [33]:
!pip install -r requirements.txt

Collecting dash==2.5.1 (from -r requirements.txt (line 1))
  Downloading dash-2.5.1-py3-none-any.whl.metadata (11 kB)
Collecting numpy==1.22.4 (from -r requirements.txt (line 2))
  Downloading numpy-1.22.4.zip (11.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.5/11.5 MB[0m [31m107.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mGetting requirements to build wheel[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Getting requirements to build wheel ... [?25l[?25herror
[1;31merror[0m: [1msubprocess-exited-with-error[0m

[31m×[0m [32mGetting requirements to build wheel[0m did not run successfully.
[31m│[0m exit code: [1;36m1[0m
[31m╰─>[0m See above for output.

[1;35mnote

In [4]:
print("🔽 Downloading checkpoints.zip from Google Drive link...")

!wget -O checkpoints.zip "https://drive.usercontent.google.com/download?id=1CfzE87a0f6LhBp0kniSl1-89zaLCZ8cA&export=download&confirm=t"
print("✅ Downloaded checkpoints.zip")

🔽 Downloading checkpoints.zip from Google Drive link...
--2025-09-21 13:39:39--  https://drive.usercontent.google.com/download?id=1CfzE87a0f6LhBp0kniSl1-89zaLCZ8cA&export=download&confirm=t
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 173.194.212.132, 2607:f8b0:400c:c11::84
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|173.194.212.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 131653952 (126M) [application/octet-stream]
Saving to: ‘checkpoints.zip’


2025-09-21 13:39:41 (142 MB/s) - ‘checkpoints.zip’ saved [131653952/131653952]

✅ Downloaded checkpoints.zip


In [6]:
print("🔍 Checking contents of 'ckpts/'...")
!ls -la ckpts/

print("\n🔍 Checking contents of 'weights/'...")
!ls -la weights/

🔍 Checking contents of 'ckpts/'...
total 139128
drwxr-xr-x 2 root root      4096 Aug  8  2023 .
drwxr-xr-x 8 root root      4096 Sep 21 13:40 ..
-rw-r--r-- 1 root root   6264451 Aug  8  2023 InpaintNet_best.pt
-rw-r--r-- 1 root root 136190877 Aug  8  2023 TrackNet_best.pt

🔍 Checking contents of 'weights/'...
total 8
drwxr-xr-x 2 root root 4096 Sep 21 13:27 .
drwxr-xr-x 8 root root 4096 Sep 21 13:40 ..
-rw-r--r-- 1 root root    0 Sep 21 13:27 tracknet_weights.pth


In [12]:
print("🧠 Loading TrackNetV3 model architecture...")

%cd /content/TrackNetV3

from model import TrackNet
import torch

# Initialize the model with correct input and output dimensions based on the error
model = TrackNet(in_dim=27, out_dim=8)
print("✅ Model architecture loaded.")

# Load the checkpoint
checkpoint_path = 'ckpts/TrackNet_best.pt'
state_dict = torch.load(checkpoint_path, map_location='cpu')

# Extract only the model state dictionary
model_state_dict = state_dict['model']

# Fix keys if they have 'module.' prefix (common in DataParallel models)
model_state_dict = {k.replace('module.', ''): v for k, v in model_state_dict.items()}

model.load_state_dict(model_state_dict)
model.eval()

print("✅ Successfully loaded weights from ckpts/TrackNet_best.pt")

🧠 Loading TrackNetV3 model architecture...
/content/TrackNetV3
✅ Model architecture loaded.
✅ Successfully loaded weights from ckpts/TrackNet_best.pt


In [39]:
%%writefile /content/TrackNetV3/predict.py
# predict.py - Fixed for Colab Compatibility

import os
import argparse
import numpy as np
import cv2

from tqdm import tqdm
import torch
from torch.utils.data import DataLoader

# Assume these exist from the repo
from test import predict_location, get_ensemble_weight, generate_inpaint_mask
from dataset import Shuttlecock_Trajectory_Dataset, Video_IterableDataset
from utils.general import get_model, write_pred_csv, write_pred_video, generate_frames, to_img_format, to_img

# 🔧 Constants (likely defined elsewhere — now added here)
WIDTH = 640   # Input width for model
HEIGHT = 360  # Input height for model
COOR_TH = 0.01  # Threshold for coordinate validity

# 🛠️ Fix: Use XVID codec for reliable video output in Colab
def create_video_writer(save_file, fps, w, h):
    fourcc = cv2.VideoWriter_fourcc(*'XVID')  # Works reliably in Colab
    return cv2.VideoWriter(save_file, fourcc, fps, (w, h))


def predict(indices, y_pred=None, c_pred=None, img_scaler=(1, 1)):
    """ Predict coordinates from heatmap or inpainted coordinates. """
    pred_dict = {'Frame': [], 'X': [], 'Y': [], 'Visibility': []}

    batch_size, seq_len = indices.shape[0], indices.shape[1]
    indices = indices.detach().cpu().numpy() if torch.is_tensor(indices) else indices.numpy()

    if y_pred is not None:
        y_pred = (y_pred > 0.5).detach().cpu().numpy()
        y_pred = to_img_format(y_pred)

    if c_pred is not None:
        c_pred = c_pred.detach().cpu().numpy()

    prev_f_i = -1
    for n in range(batch_size):
        for f in range(seq_len):
            f_i = indices[n][f][1]
            if f_i != prev_f_i:
                if c_pred is not None:
                    c_p = c_pred[n][f]
                    cx_pred = int(c_p[0] * WIDTH * img_scaler[0])
                    cy_pred = int(c_p[1] * HEIGHT * img_scaler[1])
                elif y_pred is not None:
                    y_p = y_pred[n][f]
                    bbox_pred = predict_location(to_img(y_p))
                    cx_pred = int((bbox_pred[0] + bbox_pred[2] / 2) * img_scaler[0])
                    cy_pred = int((bbox_pred[1] + bbox_pred[3] / 2) * img_scaler[1])
                else:
                    raise ValueError('Invalid input')

                vis_pred = 0 if cx_pred == 0 and cy_pred == 0 else 1
                pred_dict['Frame'].append(int(f_i))
                pred_dict['X'].append(cx_pred)
                pred_dict['Y'].append(cy_pred)
                pred_dict['Visibility'].append(vis_pred)
                prev_f_i = f_i
            else:
                break
    return pred_dict


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--video_file', type=str, required=True, help='file path of the video')
    parser.add_argument('--tracknet_file', type=str, required=True, help='file path of the TrackNet model checkpoint')
    parser.add_argument('--inpaintnet_file', type=str, default='', help='file path of the InpaintNet model checkpoint')
    parser.add_argument('--batch_size', type=int, default=1, help='batch size for inference (Colab-friendly)')
    parser.add_argument('--eval_mode', type=str, default='weight', choices=['nonoverlap', 'average', 'weight'])
    parser.add_argument('--max_sample_num', type=int, default=1800)
    parser.add_argument('--video_range', type=str, default=None, help='start,end in seconds')
    parser.add_argument('--save_dir', type=str, default='prediction', help='output directory')
    parser.add_argument('--large_video', action='store_true', help='for long videos')
    parser.add_argument('--output_video', action='store_true', help='generate output video')
    parser.add_argument('--traj_len', type=int, default=8, help='length of trajectory trail')
    args = parser.parse_args()

    # 🔽 Force low num_workers to avoid Colab crash
    num_workers = 0  # Was: args.batch_size if ... → now safe

    video_file = args.video_file
    video_name = os.path.splitext(os.path.basename(video_file))[0]
    video_range = [int(x) for x in args.video_range.split(',')] if args.video_range else None
    large_video = args.large_video
    save_dir = args.save_dir

    out_csv_file = os.path.join(save_dir, f'{video_name}_ball.csv')
    out_video_file = os.path.join(save_dir, f'{video_name}.avi')  # 🔴 Use .avi for XVID

    os.makedirs(save_dir, exist_ok=True)

    # Load models
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tracknet_ckpt = torch.load(args.tracknet_file, map_location=device)
    seq_len = tracknet_ckpt['param_dict']['seq_len']
    bg_mode = tracknet_ckpt['param_dict']['bg_mode']

    tracknet = get_model('TrackNet', seq_len, bg_mode).to(device)
    tracknet.load_state_dict(tracknet_ckpt['model'])
    tracknet.eval()

    inpaintnet = None
    if args.inpaintnet_file:
        inpaintnet_ckpt = torch.load(args.inpaintnet_file, map_location=device)
        inpaintnet = get_model('InpaintNet').to(device)
        inpaintnet.load_state_dict(inpaintnet_ckpt['model'])
        inpaintnet.eval()

    cap = cv2.VideoCapture(video_file)
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    cap.release()

    w_scaler, h_scaler = w / WIDTH, h / HEIGHT
    img_scaler = (w_scaler, h_scaler)

    # Test on TrackNet
    tracknet_pred_dict = {
        'Frame': [], 'X': [], 'Y': [], 'Visibility': [], 'Inpaint_Mask': [],
        'Img_scaler': (w_scaler, h_scaler), 'Img_shape': (w, h)
    }

    print("🚀 Starting inference...")

    if args.eval_mode == 'nonoverlap':
        if large_video:
            dataset = Video_IterableDataset(
                video_file, seq_len=seq_len, sliding_step=seq_len,
                bg_mode=bg_mode, max_sample_num=args.max_sample_num,
                video_range=video_range
            )
            data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, drop_last=False)
        else:
            frame_list = generate_frames(video_file)
            dataset = Shuttlecock_Trajectory_Dataset(
                seq_len=seq_len, sliding_step=seq_len, data_mode='heatmap',
                bg_mode=bg_mode, frame_arr=np.array(frame_list)[:, :, :, ::-1], padding=True
            )
            data_loader = DataLoader(
                dataset, batch_size=args.batch_size, shuffle=False,
                num_workers=num_workers, drop_last=False
            )

        for step, (i, x) in enumerate(tqdm(data_loader)):
            x = x.float().to(device)
            with torch.no_grad():
                y_pred = tracknet(x).detach().cpu()
            tmp_pred = predict(i, y_pred=y_pred, img_scaler=img_scaler)
            for k in tmp_pred:
                tracknet_pred_dict[k].extend(tmp_pred[k])

    # Write CSV
    pred_dict = tracknet_pred_dict.copy()
    if inpaintnet is not None:
        # InpaintNet logic can be added later
        pass

    write_pred_csv(pred_dict, save_file=out_csv_file)
    print(f"✅ CSV saved: {out_csv_file}")

    # Write Video
    if args.output_video:
        print("🎬 Generating output video...")
        writer = create_video_writer(out_video_file, 30, w, h)
        cap = cv2.VideoCapture(video_file)
        frame_idx = 0
        traj_points = []

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Find prediction for this frame
            if frame_idx in pred_dict['Frame']:
                idx = pred_dict['Frame'].index(frame_idx)
                x_pos = pred_dict['X'][idx]
                y_pos = pred_dict['Y'][idx]
                visible = pred_dict['Visibility'][idx]

                if visible:
                    cv2.circle(frame, (x_pos, y_pos), 8, (0, 0, 255), -1)
                    cv2.putText(frame, f'{frame_idx}', (x_pos + 10, y_pos),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
                    traj_points.append((x_pos, y_pos))
                else:
                    traj_points.append(None)
            else:
                traj_points.append(None)

            # Draw trajectory
            for i in range(1, min(args.traj_len, len(traj_points))):
                if traj_points[-i-1] and traj_points[-i]:
                    cv2.line(frame, traj_points[-i-1], traj_points[-i], (0, 255, 0), 2)

            writer.write(frame)
            frame_idx += 1

        cap.release()
        writer.release()
        print(f"✅ Video saved: {out_video_file}")

    print("🎉 Done.")

Overwriting /content/TrackNetV3/predict.py


In [40]:
%cd /content/TrackNetV3

!python predict.py \
  --video_file /content/clip_343.mp4 \
  --tracknet_file ckpts/TrackNet_best.pt \
  --inpaintnet_file ckpts/InpaintNet_best.pt \
  --save_dir prediction \
  --output_video \
  --batch_size 1

/content/TrackNetV3
🚀 Starting inference...
✅ CSV saved: prediction/clip_343_ball.csv
🎬 Generating output video...
✅ Video saved: prediction/clip_343.avi
🎉 Done.
