# Making DataLoader

In [None]:
import torch
import os
import os
import json
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
class PoseTrackDataset(Dataset):
  """Index of videos stored as one sub-directory of frames each.

  Items are path pairs only; nothing is read from disk in __getitem__.

  Args:
    main_folder: directory whose immediate sub-directories are the videos.
    json_folder: directory expected to hold one "<sub-directory name>.json"
      file per video (existence is not checked here).
  """

  def __init__(self, main_folder, json_folder):
    self.main_folder = main_folder
    self.json_folder = json_folder
    # The first tuple produced by os.walk describes main_folder itself;
    # its second field lists the immediate sub-directory names.
    _, child_dirs, _ = next(os.walk(main_folder))
    self.subdirectories = sorted(child_dirs)

  def __len__(self):
    """Number of video sub-directories found under main_folder."""
    return len(self.subdirectories)

  def __getitem__(self, idx):
    """Return (frame directory path, annotation json path) for video `idx`."""
    name = self.subdirectories[idx]
    return (
        os.path.join(self.main_folder, name),
        os.path.join(self.json_folder, name + ".json"),
    )

In [None]:
# Frame folders: each sub-directory is treated as one video by PoseTrackDataset.
train_folder = '/content/drive/MyDrive/PoseTrack2/d1/images/train'
val_folder = '/content/drive/MyDrive/PoseTrack2/d1/images/val'

# Annotation folders: PoseTrackDataset pairs each video with "<video>.json" from here.
train_json_folder = '/content/drive/MyDrive/PoseTrack2/d1/PoseTrack21/posetrack_data/train'
val_json_folder = '/content/drive/MyDrive/PoseTrack2/d1/PoseTrack21/posetrack_data/val'

# The datasets only yield path pairs; the loaders hand them out one video per batch.
train_dataset = PoseTrackDataset(train_folder, train_json_folder)
train_dataloader = DataLoader(train_dataset, batch_size=1)

val_dataset = PoseTrackDataset(val_folder, val_json_folder)
val_dataloader = DataLoader(val_dataset, batch_size=1)

# Creating Model

In [None]:
!git clone https://github.com/facebookresearch/co-tracker
%cd co-tracker
!pip install -e .
!pip install opencv-python einops timm matplotlib moviepy flow_vis
!mkdir checkpoints
%cd checkpoints
!wget https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth

fatal: destination path 'co-tracker' already exists and is not an empty directory.
/content/co-tracker
Obtaining file:///content/co-tracker
  Preparing metadata (setup.py) ... [?25l[?25hdone
Installing collected packages: cotracker
  Attempting uninstall: cotracker
    Found existing installation: cotracker 2.0
    Uninstalling cotracker-2.0:
      Successfully uninstalled cotracker-2.0
  Running setup.py develop for cotracker
Successfully installed cotracker-2.0
mkdir: cannot create directory ‘checkpoints’: File exists
/content/co-tracker/checkpoints
--2024-06-08 17:50:13--  https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth
Resolving huggingface.co (huggingface.co)... 18.172.134.88, 18.172.134.4, 18.172.134.24, ...
Connecting to huggingface.co (huggingface.co)|18.172.134.88|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/99/80/9980d78abf629546bc23aee3918967b68e7f1fe3a1bbbb105ecb0e930cf49b5b/362f5

In [None]:
%cd /content/co-tracker
import os
import torch

from base64 import b64encode
from cotracker.utils.visualizer import Visualizer, read_video_from_path
from IPython.display import HTML

/content/co-tracker


In [None]:
from cotracker.predictor1 import CoTrackerPredictor

# Build the predictor from the cotracker2 checkpoint. The path is relative to
# the current working directory, so this must run from the co-tracker folder.
model = CoTrackerPredictor(checkpoint='./checkpoints/cotracker2.pth')

# Loading Annotations and Video

In [None]:
def load_video(subdir_path):
  """Load every ".jpg" frame of a directory into a single video tensor.

  Frames are read in sorted file-name order with cv2.imread (BGR channel
  order) and the channel axis is reversed so the result is RGB.

  Args:
    subdir_path: directory containing the frames of one video.

  Returns:
    torch.Tensor of dtype float32 and shape (1, T, 3, H, W), where T is the
    number of frames.

  Raises:
    ValueError: if the directory has no ".jpg" file, if a frame cannot be
      decoded, or if the frames do not all share the same shape.
  """
  frame_names = sorted(name for name in os.listdir(subdir_path) if name.endswith(".jpg"))
  if not frame_names:
    raise ValueError(f"No .jpg frames found in {subdir_path!r}")

  frames = []
  for name in frame_names:
    frame_path = os.path.join(subdir_path, name)
    frame = cv2.imread(frame_path)
    # cv2.imread signals failure by returning None instead of raising.
    if frame is None:
      raise ValueError(f"Could not read image {frame_path!r}")
    frames.append(frame)

  # np.stack (unlike np.array) refuses frames of differing shapes outright.
  stacked = np.stack(frames)  # (T, H, W, 3), BGR

  # (T, H, W, 3) -> (T, 3, H, W) -> (1, T, 3, H, W), then BGR -> RGB.
  video = torch.from_numpy(stacked).permute(0, 3, 1, 2)[None].float()
  return video[:, :, [2, 1, 0], :, :]

In [None]:
def load_anno(json_path):
  """Read an annotation JSON file and group keypoints by person.

  The file must contain an 'annotations' list whose entries have 'image_id',
  'person_id' and a flat 'keypoints' list of (x, y, visibility) triples.

  For every person, the keypoints whose visibility flag equals 1 in that
  person's FIRST annotation are selected; the same keypoints are then
  collected from all of the person's annotations, in file order.

  Args:
    json_path: path of the annotation file.

  Returns:
    Tuple (inputs, persons, frames) of dicts keyed by person_id:
      inputs[pid]:  float32 tensor (K, 3) of [frame, x, y] for the K selected
        keypoints, taken from the first annotation.
      persons[pid]: float32 tensor (F * K, 3) of [x, y, visibility], the K
        selected keypoints of each of the F annotations stacked in order.
      frames[pid]:  list of the F frame numbers.
    K may be 0; the tensors then have shape (0, 3).
  """

  def frame_index(annotation):
    # Frame number is taken as the last three decimal digits of image_id.
    # NOTE(review): assumes fewer than 1000 frames per video - confirm
    # against the dataset's image_id layout.
    return annotation['image_id'] % 1000

  def as_rows(rows):
    # Explicit dtype and reshape keep the result float32 and 2-D even when
    # the JSON holds only integers or when no keypoint was selected.
    return torch.tensor(rows, dtype=torch.float32).reshape(-1, 3)

  def create_input_tensor(annotation):
    keypoints = annotation['keypoints']
    frame_no = frame_index(annotation)
    # Flat offsets of the triples flagged visible (== 1); a trailing partial
    # triple is ignored.
    non_mask = [i for i in range(0, len(keypoints) - 2, 3) if keypoints[i + 2] == 1]
    queries = [[frame_no, keypoints[i], keypoints[i + 1]] for i in non_mask]
    return as_rows(queries), non_mask

  def create_keypoints_tensor(annotation, non_mask):
    keypoints = annotation['keypoints']
    return as_rows([[keypoints[i], keypoints[i + 1], keypoints[i + 2]] for i in non_mask])

  with open(json_path, 'r') as file:
    data = json.load(file)

  inputs = {}
  mask = {}
  chunks = {}
  frames = {}
  for annotation in data['annotations']:
    person_id = annotation['person_id']
    if person_id not in chunks:
      # First sighting fixes both the query points and the keypoint selection.
      inputs[person_id], mask[person_id] = create_input_tensor(annotation)
      chunks[person_id] = []
      frames[person_id] = []
    chunks[person_id].append(create_keypoints_tensor(annotation, mask[person_id]))
    frames[person_id].append(frame_index(annotation))

  # Concatenate once per person instead of growing a tensor per annotation.
  persons = {person_id: torch.cat(parts) for person_id, parts in chunks.items()}

  return (inputs, persons, frames)


# Training Model

In [None]:
def loss_fn(actual, pred):
  """Mean Euclidean error over the points flagged as visible.

  Args:
    actual: tensor whose rows are [x, y, visibility].
    pred: tensor of predicted [x, y] rows, aligned with `actual`.

  Returns:
    Scalar tensor: average distance between prediction and target over the
    rows with visibility > 0 (NaN when no row qualifies).
  """
  target_xy = actual[:, :2]
  visibility = actual[:, 2].reshape((-1))

  # Per-point L2 distance between target and prediction.
  distances = torch.norm(target_xy - pred, dim=-1)

  # Average only over the points marked visible.
  return torch.mean(distances[visibility > 0])

In [None]:
import torch
import numpy as np

# Added to every denominator so an all-zero mask yields 0 instead of a division by zero.
EPS = 1e-6

def reduce_masked_mean(x, mask, dim=None, keepdim=False):
    """
    Average x over the entries selected by mask.

    The result is sum(x * mask) / (EPS + sum(mask)). Shapes are not checked;
    x and mask are simply multiplied element-wise.

    Args:
    x (torch.Tensor): Values to average.
    mask (torch.Tensor): Weights, 1 for entries to keep and 0 for entries to ignore.
    dim (int or tuple of int, optional): Dimension(s) to reduce; all of them when None.
    keepdim (bool, optional): Keep the reduced dimensions (only used when dim is given).

    Returns:
    torch.Tensor: The masked mean.
    """
    masked = x * mask
    if dim is not None:
        total = torch.sum(masked, dim=dim, keepdim=keepdim)
        count = torch.sum(mask, dim=dim, keepdim=keepdim)
    else:
        total = torch.sum(masked)
        count = torch.sum(mask)
    return total / (EPS + count)

def evaluate_trajectories(trajs_e, trajs_g, valids, W, H):
    """
    Evaluate the predicted trajectories against ground truth trajectories.

    Coordinates are rescaled as if the frame were 256 x 256 and, for each pixel
    threshold in (1, 2, 4, 8, 16), the percentage of valid points whose error is
    below the threshold is reported. The first row of every input is skipped.

    Args:
    trajs_e (torch.Tensor): Predicted trajectories of shape (S, 2).
    trajs_g (torch.Tensor): Ground truth trajectories of shape (S, 2).
    valids (torch.Tensor): Validity mask of shape (S) with 1 for valid and 0 for invalid.
    W (float): Width of the evaluation space.
    H (float): Height of the evaluation space.

    Returns:
    dict: 'd_1', 'd_2', 'd_4', 'd_8', 'd_16' (percentages in [0, 100]) and
    'd_avg', their mean.
    """
    # Distance thresholds
    thrs = [1, 2, 4, 8, 16]

    # All arithmetic happens on the device of the predictions, so CPU and GPU
    # inputs are both accepted.
    device = trajs_e.device
    trajs_g = trajs_g.to(device)
    valid_mask = valids.to(device)[1:]

    # Scaling factors mapping (W, H) pixel coordinates to a 256 x 256 frame.
    scale = torch.tensor([W / 256.0, H / 256.0], dtype=torch.float32, device=device).reshape(1, 2)

    # The distances do not depend on the threshold: compute them once. Shape: (S-1)
    dists = torch.norm(trajs_e[1:] / scale - trajs_g[1:] / scale, dim=-1)

    metrics = {}
    d_sum = 0.0
    for thr in thrs:
        within = (dists < thr).float()
        # Percentage of valid points that fall inside the threshold.
        pct = reduce_masked_mean(within, valid_mask).item() * 100.0
        d_sum += pct
        metrics['d_%d' % thr] = pct

    # Average over all thresholds.
    metrics['d_avg'] = d_sum / len(thrs)

    return metrics


In [None]:
# AdamW over the parameters of the network wrapped by the predictor (model.model).
optimizer = torch.optim.AdamW(model.model.parameters(), lr=5e-4)


In [None]:
def compute_val_loss(model, val_dataloader, val_batches):
  """Evaluate the predictor on validation videos and print loss and metrics.

  For each video, every annotated person is tracked from the query points
  returned by load_anno; the loss (loss_fn) and the threshold metrics
  (evaluate_trajectories) are averaged over the persons of the video and then
  over the videos. The wrapped network is put in eval mode for the duration of
  the call and switched back to train mode afterwards, even on error.

  Args:
    model: predictor exposing `.model` (the network) and callable as
      model(video, queries=...) -> (pred_tracks, visibility).
    val_dataloader: loader yielding batches of (frame_dir, annotation_path)
      pairs.
    val_batches: maximum number of batches to evaluate; None evaluates the
      whole loader.

  Returns:
    Tuple (final_loss, metrics): the mean loss and a dict with keys 'd_1',
    'd_2', 'd_4', 'd_8', 'd_16', 'd_avg'. All values are NaN when no video
    could be evaluated.
  """
  metric_keys = ('d_1', 'd_2', 'd_4', 'd_8', 'd_16', 'd_avg')
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  if device.type == 'cuda':
    model = model.cuda()

  def _mean(values):
    return sum(values) / len(values) if values else float('nan')

  losses = []
  video_metrics = {key: [] for key in metric_keys}

  model.model.eval()
  try:
    with torch.no_grad():
      for batch_idx, k in enumerate(val_dataloader):
        if val_batches is not None and batch_idx >= val_batches:
          break
        # With the default collate function a batch of (dir, json) pairs
        # arrives as two parallel sequences, not as a sequence of pairs.
        video_dirs, anno_files = k
        for video_dir, anno_file in zip(video_dirs, anno_files):
          if not (os.path.exists(video_dir) and os.path.exists(anno_file)):
            continue
          video = load_video(video_dir).to(device)
          queries, targets, frames = load_anno(anno_file)
          H, W = video.shape[-2], video.shape[-1]

          person_losses = []
          person_metrics = {key: [] for key in metric_keys}
          for person_id in queries:
            if queries[person_id].numel() == 0:
              continue  # no visible keypoint to start tracking from
            query = queries[person_id].float().to(device)
            target = targets[person_id].float().to(device)
            pred_tracks, visibility = model(video, queries=query[None])
            # Keep only the annotated frames, flattened to one [x, y] row per
            # (frame, keypoint) so the rows line up with `target`.
            pred = pred_tracks[0][frames[person_id]].reshape((-1, 2))

            # Metrics are cheap; computing them on CPU keeps them independent
            # of the device the model runs on.
            pred_cpu, target_cpu = pred.cpu(), target.cpu()
            result = evaluate_trajectories(pred_cpu, target_cpu[:, :2], target_cpu[:, 2], W, H)
            for key in metric_keys:
              person_metrics[key].append(result[key])
            person_losses.append(loss_fn(target, pred).item())

          if not person_losses:
            continue
          losses.append(_mean(person_losses))
          for key in metric_keys:
            video_metrics[key].append(_mean(person_metrics[key]))
  finally:
    model.model.train()

  final_loss = _mean(losses)
  metrics = {key: _mean(video_metrics[key]) for key in metric_keys}

  print(f"Validation loss: {final_loss:.4f}")
  for key in metrics:
    print(f"{key}: {metrics[key]:.4f}")

  return final_loss, metrics

In [None]:
def train_model(model, optimizer, train_dataloader, val_freq, save_freq, num_epochs=10, ckpt_dir='checkpoints', use_augs=True, val_loader=None, val_batches=1):
  """Fine-tune the predictor on keypoint tracks, one optimizer step per batch.

  For every video of a batch, each annotated person is tracked from the query
  points returned by load_anno; the loss is averaged over persons, then over
  the videos of the batch, and back-propagated once.

  Args:
    model: predictor exposing `.model` (the network) and callable as
      model(video, queries=...) -> (pred_tracks, visibility).
      NOTE(review): gradients only reach the network if the predictor's
      forward pass is not wrapped in no_grad - confirm for predictor1.
    optimizer: optimizer over the network parameters.
    train_dataloader: loader yielding batches of (frame_dir, annotation_path)
      pairs.
    val_freq: run validation every `val_freq` batches (0 disables it).
    save_freq: save a checkpoint every `save_freq` batches (0 disables it).
    num_epochs: number of passes over train_dataloader.
    ckpt_dir: directory receiving the checkpoints; created when missing.
    use_augs: when True, each video is transposed (x and y swapped) with
      probability 0.5, together with its query and target coordinates.
    val_loader: loader used for validation; when None the notebook-level
      `val_dataloader` is used.
    val_batches: number of batches passed to compute_val_loss.

  Returns:
    The model after training, or False if the loss became NaN.
  """
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  if device.type == 'cuda':
    model = model.cuda()
  os.makedirs(ckpt_dir, exist_ok=True)

  step = 0
  for epoch in range(num_epochs):
    for batch_idx, k in enumerate(train_dataloader):
      # With the default collate function a batch of (dir, json) pairs arrives
      # as two parallel sequences, not as a sequence of pairs.
      video_dirs, anno_files = k
      videos = []
      annos = []
      for video_dir, anno_file in zip(video_dirs, anno_files):
        if os.path.exists(video_dir) and os.path.exists(anno_file):
          videos.append(load_video(video_dir))
          annos.append(load_anno(anno_file))

      total_loss = 0
      num_videos = 0
      for (video, y) in zip(videos, annos):
        queries, targets, frames = y
        video = video.to(device)
        swap_xy = use_augs and np.random.rand() < 0.5
        if swap_xy:
          video = video.permute(0, 1, 2, 4, 3)  # swap the H and W axes

        temp_loss = 0
        num_persons = 0
        for person_id in queries:
          if queries[person_id].numel() == 0:
            continue  # no visible keypoint to start tracking from
          query = queries[person_id].float().to(device)
          target = targets[person_id].float().to(device)
          if swap_xy:
            # Keep the coordinates consistent with the transposed frames.
            query = query[:, [0, 2, 1]]    # [frame, x, y] -> [frame, y, x]
            target = target[:, [1, 0, 2]]  # [x, y, vis]   -> [y, x, vis]
          pred_tracks, visibility = model(video, queries=query[None])
          # Keep only the annotated frames, flattened to one [x, y] row per
          # (frame, keypoint) so the rows line up with `target`.
          reshaped_output = pred_tracks[0][frames[person_id]].reshape((-1, 2))
          temp_loss = temp_loss + loss_fn(target, reshaped_output)
          num_persons += 1
          print("Processed person")

        if num_persons:
          total_loss = total_loss + temp_loss / num_persons
          num_videos += 1

      if num_videos == 0:
        continue  # nothing usable in this batch, so there is nothing to optimize

      total_loss = total_loss / num_videos
      if torch.isnan(total_loss):
        print('nan in loss; quitting')
        return False
      optimizer.zero_grad()
      total_loss.backward()
      optimizer.step()
      print('Optimized')

      if val_freq and batch_idx % val_freq == 0:
        loader = val_loader if val_loader is not None else val_dataloader
        compute_val_loss(model, loader, val_batches)

      if save_freq and batch_idx % save_freq == 0:
        torch.save(model.model.state_dict(), os.path.join(ckpt_dir, f'model_{batch_idx}_{step}.pth'))
        step += 1

  return model


