In [1]:
# -*- coding: utf-8 -*-
"""
Created on Tuesday Feb 25 

@author: Dr. Benjamin Vien

"""

## TO INSTALL
# pip install torch torchvision imageio matplotlib opencv-python
!pip install torch torchvision
!pip install imageio
!pip install matplotlib
!pip install opencv-python-headless
!pip install imageio[ffmpeg]

Collecting opencv-python-headless
  Using cached opencv_python_headless-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Using cached opencv_python_headless-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (50.0 MB)
Installing collected packages: opencv-python-headless
Successfully installed opencv-python-headless-4.11.0.86
Collecting imageio-ffmpeg (from imageio[ffmpeg])
  Using cached imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Using cached imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl (29.5 MB)
Installing collected packages: imageio-ffmpeg
Successfully installed imageio-ffmpeg-0.6.0


In [2]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

import cv2
import numpy as np
import torch
import imageio.v3 as iio
import scipy.io as sio
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import time
from syntheticdatageneration_utils import SyntheticDataset,print_model_summary,train_finetune_updated,fetch_optimizer,generate_interpolated_video

print("CUDA available:", torch.cuda.is_available())
print("CUDA device count:", torch.cuda.device_count())
print("CUDA device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU detected")
torch.cuda.empty_cache()

#print(torch.cuda.memory_summary(device=None, abbreviated=False))

CUDA available: True
CUDA device count: 4
CUDA device name: NVIDIA A10G


In [8]:
# -----------------------------
# Main Fine-Tuning Script with Synthetic Data
# -----------------------------
def main():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print("Training on device:", device)
    
    # User-defined parameters for the video clip.
    clip_start = 170
    clip_end = 300
    step = 2
    num_aug_frames = 4
    # Number of synthetic frames per sequence.
    
    video_path = "./Input Files/IMG_7296.mp4"
    mat_path = "./matlab_files/saved_objC_ordered_NaN_fixed.mat"
    
    print("Loading video...")
    frames = iio.imread(video_path, plugin='FFMPEG')
    frames_clip = frames[clip_start:clip_end:step]
    print("Video loaded. Using frames {} to {} with step {}.".format(clip_start, clip_end, step))
    
    # Use the first frame of the clip as the base image.
    base_image = frames_clip[0]
    
    print("Loading MATLAB query points...")
    mat_data = sio.loadmat(mat_path)
    saved_objC = mat_data['data']  # Expected shape: (N, 2, Total_Frames)
    saved_objC_tensor = torch.tensor(saved_objC, dtype=torch.float32)
    saved_objC_tensor[:, :2, :] = saved_objC_tensor[:, :2, :] - 1
    print("Original MATLAB query points shape:", saved_objC_tensor.shape)
    
    # Clip query points to match video clip.
    saved_objC_tensor_clipped = saved_objC_tensor[:, :, clip_start:clip_end:step]
    print("Clipped MATLAB query points shape:", saved_objC_tensor_clipped.shape)
    
    # Extract base query points for the first frame (shape: (N,2)).
    base_query_points = saved_objC_tensor_clipped[:, :, 0].cpu().numpy()
    print("Base query points shape:", base_query_points.shape)
    
    # Generate synthetic dataset.
    num_synthetic_samples = 100
    synth_dataset = SyntheticDataset(base_image, base_query_points,
                                     num_samples=num_synthetic_samples,
                                     num_frames=num_aug_frames)
    # Set batch_size to a value >1 if desired.
    synth_loader = DataLoader(synth_dataset, batch_size=3, shuffle=False)
    
    # --- Fine-Tuning Section ---
    model = torch.hub.load('./co-tracker', 'cotracker3_offline', source='local').to(device)
    print("Model Loaded!")
    
    reload_model = True
    if reload_model:
        state_dict = torch.load("./model_saved/cotracker3_finetuned_AWS_v1SYN_fullFNET_100samplesEPOCH20.pth", map_location=device)
        model.load_state_dict(state_dict)
        print("Reloaded previous fine-tuned model.")
    
    model.train()
    # Freeze all parameters first.
    for param in model.parameters():
        param.requires_grad = True
        
    # Unfreeze parameters starting from index 44.
    #for i, param in enumerate(model.parameters()):
    #    param.requires_grad = (i >= 44)
    #    # param.requires_grad = (i >= 10)
    
    print_model_summary(model)
    
    # Control scheduler usage here.
    use_scheduler = False  # Set to True to enable scheduler, False to disable.
    optimizer, scheduler = fetch_optimizer(model, lr=1e-5, weight_decay=1e-5, num_steps=200000, use_scheduler=use_scheduler)
    num_epochs = 5
    
    use_autocast = True
    use_gradscaler = True
    
    print("Starting Fine-Tuning...")
    tic = time.time()
    model = train_finetune_updated(model, synth_loader, optimizer, scheduler, device, num_epochs=num_epochs,
                                   use_autocast=use_autocast, use_gradscaler=use_gradscaler, early_stop_patience=200)
    toc = time.time()
    print(f"Elapsed time: {toc - tic:.6f} seconds")
    print("--------------------------------------------------------------")
    
    torch.save(model.state_dict(), "./model_saved/cotracker3_finetuned_AWS_v1SYN_fullFNET_100samplesEPOCH25.pth")
    print("Fine-tuning completed and model saved.")

if __name__ == '__main__':
    main()

Training on device: cuda
Loading video...
Video loaded. Using frames 170 to 300 with step 2.
Loading MATLAB query points...
Original MATLAB query points shape: torch.Size([1024, 2, 483])
Clipped MATLAB query points shape: torch.Size([1024, 2, 65])
Base query points shape: (1024, 2)
Model Loaded!
Reloaded previous fine-tuned model.
Total number of top-level layers: 1
Layer 0: model
Starting Fine-Tuning...

Epoch: 1


  state_dict = torch.load("./model_saved/cotracker3_finetuned_AWS_v1SYN_fullFNET_100samplesEPOCH20.pth", map_location=device)


  [Autocast] [Batch 1] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 2] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 3] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 4] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 5] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 6] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 7] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 8] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 9] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 10] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 11] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 12] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 13] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 14] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 15] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 16] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 17] pred_tracks.dtype: torch.float32
  [Autocast] [Batch 18] pred_tracks.dtyp