In [None]:
import torch
from torch import nn
import torchvision.transforms as T
from presnet import PResNet
from hybrid_encoder import HybridEncoder
from rtdetrv2_decoder import RTDETRTransformerv2
from rtdetr_postprocessor import RTDETRPostProcessor
from matcher import HungarianMatcher
import cv2
import matplotlib.pyplot as plt
import time

In [None]:
# Backbone / feature extractor.
# Three feature maps are requested (return_idx has three entries), matching
# the three input levels the HybridEncoder is configured for.
presnet = PResNet(
    depth=34,
    variant='d',
    freeze_at=-1,  # presumably "freeze no stage" -- confirm in presnet.py
    return_idx=[1, 2, 3],
    num_stages=4,
    freeze_norm=False,
    pretrained=True,
)
    

In [None]:
# Encoder placed between the backbone and the decoder. It is configured for
# three input feature levels (one channel count and one stride per level).
encoder = HybridEncoder(
    in_channels=[128, 256, 512],
    feat_strides=[8, 16, 32],
    # "intra" group (label from the original author; presumably the
    # transformer-encoder part -- see hybrid_encoder.py)
    hidden_dim=256,
    use_encoder_idx=[2],
    num_encoder_layers=1,
    nhead=8,
    dim_feedforward=1024,
    dropout=0.0,
    enc_act='gelu',
    # "cross" group (label from the original author; presumably the
    # cross-level fusion part -- see hybrid_encoder.py)
    expansion=0.5,
    depth_mult=1,
    act='silu',
)

In [None]:
# Decoder. It consumes three feature levels of 256 channels each, i.e. the
# same width as the encoder's hidden_dim.
decoder = RTDETRTransformerv2(
    feat_channels=[256, 256, 256],
    feat_strides=[8, 16, 32],
    hidden_dim=256,
    num_levels=3,
    num_layers=4,
    num_queries=300,
    num_denoising=100,
    label_noise_ratio=0.5,
    box_noise_scale=1.0,  # values tried by the author: 1.0, 0.4
    eval_idx=2,
    # Options the original author marked as "NEW"
    num_points=[4, 4, 4],  # values tried: [3, 3, 3], [2, 2, 2]
    cross_attn_method='default',  # choices: 'default', 'discrete'
    query_select_method='agnostic',  # choices: 'default', 'agnostic'
)

In [None]:
# Post-processor built with RTDETRPostProcessor's default arguments. It is
# later called as postprocessor(output, orig_target_sizes) on the raw model
# output.
postprocessor = RTDETRPostProcessor()

In [None]:
class Model(nn.Module):
    """Detector that chains backbone -> encoder -> decoder.

    Args:
        backbone: Module applied to the image batch. Defaults to the
            notebook-level ``presnet``.
        encoder: Module applied to the backbone output. Defaults to the
            notebook-level ``encoder``.
        decoder: Module called as ``decoder(features, targets)``. Defaults to
            the notebook-level ``decoder``.

    The attribute names ``backbone`` / ``encoder`` / ``decoder`` are part of
    the state-dict key layout and must not be renamed.
    """

    def __init__(self, backbone=None, encoder=None, decoder=None) -> None:
        super().__init__()
        # The parameters shadow the notebook globals of the same name, so the
        # fallbacks are looked up explicitly in the module namespace. The
        # lookup is lazy: it only happens for arguments that were omitted,
        # which keeps the zero-argument ``Model()`` call working as before.
        namespace = globals()
        self.backbone = backbone if backbone is not None else namespace["presnet"]
        self.encoder = encoder if encoder is not None else namespace["encoder"]
        self.decoder = decoder if decoder is not None else namespace["decoder"]

    def forward(self, images, targets=None):
        """Run the full pipeline.

        Args:
            images: Input passed to the backbone.
            targets: Optional targets forwarded unchanged to the decoder
                (``None`` when omitted).

        Returns:
            Whatever the decoder returns.
        """
        features = self.backbone(images)
        features = self.encoder(features)
        return self.decoder(features, targets)


In [None]:
detr = Model()

# SECURITY: torch.load unpickles the file and can therefore execute arbitrary
# code. Only load checkpoints from a trusted source (consider passing
# weights_only=True if the checkpoint contains nothing but tensors and plain
# containers).
checkpoint = torch.load('detr_checkpoint_desk.pth', map_location='cpu')

state = checkpoint['model_state_dict']

# The weights are loaded exactly as stored in the checkpoint.
# NOTE(review): the original comment mentioned converting to "deploy mode",
# but no such conversion is performed here -- confirm whether one is needed.
detr.load_state_dict(state)

# Model ready for evaluation: switch to eval mode and move it to the GPU when
# one is available, otherwise stay on the CPU instead of crashing.
detr.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
detr.to(device)

In [None]:
# Loading the dataset
import os

from dataset import PersonDataset
from torch.utils.data import random_split, DataLoader

# Dataset root. Set the PERSON_DATASET_ROOT environment variable to point the
# notebook at another location without editing this cell.
root_dir = os.environ.get('PERSON_DATASET_ROOT', '/media/enrique/Extreme SSD/person')

# Names person-1 ... person-20, kept for convenience; this run deliberately
# uses a single sequence. Assign `all_sequences` to `sequence_list` to use
# every name.
all_sequences = [f'person-{i}' for i in range(1, 21)]
sequence_list = ["person-9"]

dataset = PersonDataset(
    root_dir=root_dir,
    sequence_list=sequence_list,
    img_transform_size=(640, 640),
    template_transform_size=(256, 256),
    max_num_templates=10,
    max_detections=300,
)

# Define the lengths for training and validation sets
train_size = int(0.8 * len(dataset))  # 80% for training
val_size = len(dataset) - train_size  # The rest for validation

# Split the dataset. A seeded generator makes the split identical on every
# "Restart & Run All".
SPLIT_SEED = 0
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

batch_size = 1
# DataLoader objects for the training and validation sets. The training
# loader shuffles with its own seeded generator so the batch order is
# reproducible as well.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Get a single batch from the DataLoader
# `data_iter` is kept as a separate name so further batches can be drawn with
# next(data_iter). train_loader is created with shuffle=True, so which batch
# comes first depends on the loader's random state.
data_iter = iter(train_loader)
data = next(data_iter)

In [None]:
# Move the image batch to whichever device the model's parameters live on,
# rather than assuming CUDA.
img = data["img"].to(next(detr.parameters()).device)

In [None]:
# Inference-only forward pass: gradient tracking is disabled and no targets
# are supplied, so the decoder is called with targets=None.
with torch.no_grad():
    output = detr(images=img)

In [None]:
# Target size handed to the post-processor; 640 x 640 matches the
# img_transform_size used when building the dataset. The tensor is created
# directly on the model's device (float32 under the default dtype).
orig_target_sizes = torch.tensor(
    [640.0, 640.0],
    device=next(detr.parameters()).device,
)
# NOTE: the misspelled name `procesed_output` is kept because a later cell
# reads it.
procesed_output = postprocessor(output, orig_target_sizes)

In [None]:
# Show the post-processor result as this cell's output.
procesed_output