In [None]:
import pandas as pd
import os

from nemo.collections.asr.parts.numba.rnnt_loss.utils.rnnt_helper import threshold
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import random
import numpy as np
import pickle
from tqdm import tqdm


In [None]:
seed = 44

def set_seed(seed: int):
    """
    Makes process of training more deterministic
    and allows to get reproducible results
    Arguments:
        seed (int): Random seed to be used in fixing
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

set_seed(seed)

# Data preprocessing

In [None]:
with open('embeddings.pkl', 'rb') as f:
    embeddings = pickle.load(f)
    
df = pd.read_csv('data.csv')

In [None]:
# Get rid of videos without embeddings

for i in range(len(df)):
    name = f"{df['source_video'][i]}_{df['source_start'][i]}_{df['source_end'][i]}"
    if name not in embeddings:
        df.drop(i, inplace=True)
        continue
    if df['is_duplicate'][i]:
        name = f"{df['target_video'][i]}_{df['target_start'][i]}_{df['target_end'][i]}"
        if name not in embeddings:
            df.drop(i, inplace=True)
            
df.reset_index(inplace=True, drop=True)

print(len(df))

In [None]:
df.head()

## Take all the unique groups and divide on train and test. Non-duplicate parts do not have a group -> divide them separately. 

In [None]:
# Step 1: Get unique groups from the 'group' column
unique_groups = df['group'].dropna().unique()

# Step 2: Split the groups into train and test (80% train, 20% test)
train_groups, test_groups = train_test_split(unique_groups, test_size=0.1, random_state=42)

# Step 3: Assign rows in the DataFrame to train and test sets based on the 'group' column
train_df = df[df['group'].isin(train_groups)]
test_df = df[(df['group'].isin(test_groups))]

# Find not duplicates
not_duplicate_df = df[~df['is_duplicate']]

# Split not duplicates into train and test
train_not_duplicates, test_not_duplicates = train_test_split(not_duplicate_df, test_size=0.2, random_state=42)

train_df = pd.concat([train_df, train_not_duplicates])
test_df = pd.concat([test_df, test_not_duplicates])

train_df.shape, test_df.shape, test_df.shape[0] / train_df.shape[0]

In [None]:
test_df.head()

### Retrieve connected components

In [None]:
def get_connected_components(df):
    connected_components = {}
    
    for _, row in df.iterrows():
        if row['is_duplicate']:
            source = f"{row['source_video']}_{row['source_start']}_{row['source_end']}"
            target = f"{row['target_video']}_{row['target_start']}_{row['target_end']}"
            group = int(row['group'])
            
            if group not in connected_components:
                connected_components[group] = set()
            
            connected_components[group].add(source)
            connected_components[group].add(target)
    
    return connected_components

test_connected_components = get_connected_components(test_df)
train_connected_components = get_connected_components(train_df)

## Prepare test data to easily run qdrant evaluation further

In [None]:
test_parts = []
for component in test_connected_components.values():
    for video in component:
        test_parts.append(video)

# iterate over the dataframe and check if the video is in the connected component

for _, row in test_df.iterrows():
    if not row['is_duplicate']:
        source = f"{row['source_video']}_{row['source_start']}_{row['source_end']}"
        test_parts.append(source)
        
test_parts[0]

## Triplets generation

In [None]:
def create_triplets(df, is_test):
    triplets = []
    
    duplicate_rows = df[df['is_duplicate']]
    
    bar = tqdm(total=len(duplicate_rows))
    for _, row in duplicate_rows.iterrows():
        anchor = f"{row['source_video']}_{row['source_start']}_{row['source_end']}"
        positive = f"{row['target_video']}_{row['target_start']}_{row['target_end']}"        

        negatives = df[df['group'] != row['group']].copy()
        
        if not is_test:
            negatives = negatives.sample(frac=0.1, random_state=seed)
        else:
            negatives = negatives.sample(frac=0.2, random_state=seed)
        negatives.reset_index(inplace=True, drop=True)
        
        for _, n in negatives.iterrows():
            if n['is_duplicate']:
                negative = f"{n['source_video']}_{n['source_start']}_{n['source_end']}"
                triplets.append((anchor, positive, negative))
                negative = f"{n['target_video']}_{n['target_start']}_{n['target_end']}"
                triplets.append((anchor, positive, negative))
            else:
                negative = f"{n['source_video']}_{n['source_start']}_{n['source_end']}"
                triplets.append((anchor, positive, negative))
    
        bar.update(1)
    bar.close()
    triplets_df = pd.DataFrame(triplets, columns=["anchor", "positive", "negative"])
    return triplets_df

train_triplets = create_triplets(train_df, False)
test_triplets = create_triplets(test_df, True)

In [None]:
for a, n, p in train_triplets.values:
    assert a in embeddings and n in embeddings and p in embeddings

In [None]:
import torch.utils
import torch.utils.data
from torchvision.transforms import v2

# 1. Define a Triplet Dataset with Late Fusion
class TripletDataset(Dataset):
    """
    Dateset for storing triplets
    """
    def __init__(self, data, embeddings):
        self.data = data
        self.embeddings = embeddings

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """ 
        Gets triplet (anchor, positive, negative) by index
        Arguments:
            idx (int): Index to get triple by
        Returns:
            triplet ((torch.temsor, torch.tensor, torch.tensor)): Retrieved tripet
        """
        # If late fusion is not used, use the raw triplet data
        anchor, positive, negative = self.data.anchor[idx], self.data.positive[idx], self.data.negative[idx]

        return self.embeddings[anchor], self.embeddings[positive], self.embeddings[negative]
        

train_dataset = TripletDataset(train_triplets, embeddings)
test_dataset = TripletDataset(test_triplets, embeddings)

In [None]:
len(train_dataset), len(test_dataset)

# Model

In [None]:

# 2. Define a Simple Neural Network for Embedding Generation
class EmbeddingNet(nn.Module):
    def __init__(self):
        super(EmbeddingNet, self).__init__()
        self.fc1 = nn.Linear(1152, 1500)  # Assuming input images are 28x28
        self.fc2 = nn.Linear(1500, 1000)
        self.fc3 = nn.Linear(1000, 500)  # Output embedding of size 500

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)  # Output embedding
        x = F.normalize(x, p=2, dim=1)  # Normalize embeddings to have unit norm
        return x

# 3. Triplet Loss with Regularization (Custom)
class CustomTripletLoss(nn.Module):
    def __init__(self, margin=1.0, lambda_reg=1e-3):
        super(CustomTripletLoss, self).__init__()
        self.margin = margin
        self.lambda_reg = lambda_reg
        self.triplet_loss = nn.TripletMarginLoss(margin=self.margin)

    def forward(self, anchor, positive, negative, model_params):
        # Compute the triplet loss
        loss = self.triplet_loss(anchor, positive, negative)

        # L2 regularization on model parameters
        reg_loss = 0
        for param in model_params:
            reg_loss += torch.sum(param ** 2)

        reg_loss = self.lambda_reg * reg_loss
        total_loss = loss + reg_loss
        return total_loss


# 5. Training Loop
def train_triplet_model(train_loader: torch.utils.data.DataLoader, validation_loader: torch.utils.data.DataLoader, model: nn.Module, optimizer: torch.optim, criterion: nn.Module, num_epochs: int = 10, device: str = "cpu"):
    """
    Trains model to encode videos to embeddings
    Arguments:
        train_loader (torch.utils.data.DataLoader): dataloader for training
        model (nn.Module): model to train
        optimizer (torch.optim) optimizer for training
        criterion (nn.Module): loss function for optimization
        num_epochs (int): number of epoch to train the model for
        device (str): 'cuda' or 'cpu' depending on the machine and/or choice
    """
    
    min_loss = float('inf')
    
    model.train()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for (anchor, positive, negative) in tqdm(train_loader):
            # Move data to the correct device
            anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
            # Forward pass: Compute embeddings
            anchor_emb = model(anchor)
            positive_emb = model(positive)
            negative_emb = model(negative)

            # Compute the triplet loss with regularization
            loss = criterion(anchor_emb, positive_emb, negative_emb, model.parameters())

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate batch loss
            epoch_loss += loss.item()

        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss / len(train_loader):.6f}')
        
        # validate
        val_loss = validate_triplet_model(validation_loader, model, criterion, device)
        
        if val_loss < min_loss:
            min_loss = val_loss
            torch.save(model.state_dict(), 'model.pt')

def validate_triplet_model(val_loader: torch.utils.data.DataLoader, model: nn.Module, criterion: nn.Module, device: str = "cpu"):
    """
    Validates the model by computing the average loss over the validation dataset.
    Arguments:
        val_loader (torch.utils.data.DataLoader): dataloader for validation
        model (nn.Module): trained model to validate
        criterion (nn.Module): loss function for evaluation
        device (str): 'cuda' or 'cpu' depending on the machine and/or choice
    Returns:
        float: Average validation loss
    """
    model.eval()  # Set model to evaluation mode
    val_loss = 0.0

    with torch.no_grad():  # No gradient computation
        for batch_idx, (anchor, positive, negative) in enumerate(val_loader):
            # Move data to the correct device
            anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

            # Forward pass: Compute embeddings
            anchor_emb = model(anchor)
            positive_emb = model(positive)
            negative_emb = model(negative)

            # Compute the triplet loss
            loss = criterion(anchor_emb, positive_emb, negative_emb, model.parameters())

            # Accumulate batch loss
            val_loss += loss.item()

    # Compute average loss
    avg_val_loss = val_loss / len(val_loader)
    print(f'Validation Loss: {avg_val_loss:.6f}')
    return avg_val_loss


In [None]:
# Open dataset and create dataloader
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

# Initialize base model and late fusion model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model = EmbeddingNet().to(device)

In [34]:
# Loss function and optimizer
criterion = CustomTripletLoss(margin=1.0, lambda_reg=1e-3)
optimizer = optim.AdamW(base_model.parameters(), lr=1e-4, weight_decay=2e-5)

# Train the model
train_triplet_model(train_loader, test_loader, base_model, optimizer, criterion, num_epochs=6, device=device)


 19%|█▉        | 668/3496 [00:18<01:15, 37.63it/s][A
 19%|█▉        | 672/3496 [00:18<01:15, 37.52it/s][A
 19%|█▉        | 676/3496 [00:18<01:15, 37.50it/s][A
 19%|█▉        | 680/3496 [00:18<01:15, 37.50it/s][A
 20%|█▉        | 684/3496 [00:18<01:14, 37.50it/s][A
 20%|█▉        | 688/3496 [00:18<01:14, 37.52it/s][A
 20%|█▉        | 692/3496 [00:18<01:14, 37.53it/s][A
 20%|█▉        | 696/3496 [00:18<01:14, 37.53it/s][A
 20%|██        | 700/3496 [00:18<01:14, 37.48it/s][A
 20%|██        | 704/3496 [00:19<01:14, 37.53it/s][A
 20%|██        | 708/3496 [00:19<01:14, 37.52it/s][A
 20%|██        | 712/3496 [00:19<01:14, 37.54it/s][A
 20%|██        | 716/3496 [00:19<01:14, 37.51it/s][A
 21%|██        | 720/3496 [00:19<01:14, 37.42it/s][A
 21%|██        | 724/3496 [00:19<01:13, 37.50it/s][A
 21%|██        | 728/3496 [00:19<01:13, 37.47it/s][A
 21%|██        | 732/3496 [00:19<01:13, 37.49it/s][A
 21%|██        | 736/3496 [00:19<01:13, 37.48it/s][A
 21%|██        | 740/3496 [

Epoch [2/6], Loss: 0.020375
Validation Loss: 0.159474



  0%|          | 0/3496 [00:00<?, ?it/s][A
  0%|          | 1/3496 [00:00<08:41,  6.70it/s][A
  0%|          | 4/3496 [00:00<03:21, 17.30it/s][A
  0%|          | 8/3496 [00:00<02:18, 25.26it/s][A
  0%|          | 12/3496 [00:00<01:57, 29.77it/s][A
  0%|          | 16/3496 [00:00<01:46, 32.54it/s][A
  1%|          | 20/3496 [00:00<01:41, 34.17it/s][A
  1%|          | 24/3496 [00:00<01:38, 35.26it/s][A
  1%|          | 28/3496 [00:00<01:36, 35.95it/s][A
  1%|          | 32/3496 [00:01<01:35, 36.41it/s][A
  1%|          | 36/3496 [00:01<01:34, 36.75it/s][A
  1%|          | 40/3496 [00:01<01:33, 36.98it/s][A
  1%|▏         | 44/3496 [00:01<01:32, 37.13it/s][A
  1%|▏         | 48/3496 [00:01<01:32, 37.20it/s][A
  1%|▏         | 52/3496 [00:01<01:32, 37.32it/s][A
  2%|▏         | 56/3496 [00:01<01:32, 37.26it/s][A
  2%|▏         | 60/3496 [00:01<01:31, 37.43it/s][A
  2%|▏         | 64/3496 [00:01<01:31, 37.44it/s][A
  2%|▏         | 68/3496 [00:01<01:31, 37.44it/s][A
  2%

KeyboardInterrupt: 

In [44]:
base_model.load_state_dict(torch.load('model_77.pt', map_location=device))

      base_model.load_state_dict(torch.load('model_77.pt', map_location=device))
    


<All keys matched successfully>

# Model evaluation

In [45]:
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from qdrant_client.models import PointStruct
from datetime import datetime
from tqdm import tqdm

# Connect to Qdrant and create collection
client = QdrantClient(url="http://localhost:6333")

client.delete_collection('video')

video_emb_dim = 500
distance = Distance.EUCLID

client.create_collection(
    collection_name="video",
    vectors_config=VectorParams(size=video_emb_dim, distance=distance)
)

d = []

test_videos_parts = list(test_parts)
for i in tqdm(range(len(test_videos_parts))):
    name = test_parts[i]
    v = embeddings[name]
    v = base_model(v.unsqueeze(0).to(device)).detach().cpu().numpy()[0]

    # search for the closest vector
    search_result = client.query_points(
        collection_name="video",
        query=v,
        with_payload=True,
        limit=1
    ).points

    # if the db is not empty, put the closest found vector into array for further evaluation
    if len(search_result) > 0:
        id2 = search_result[0].id
        name2 = search_result[0].payload['name']
        d.append((name, name2, search_result[0].score))

    # insert the vector
    client.upsert(
        collection_name="video",
        points=[PointStruct(id=int(i), vector=v, payload={'name': name})]
    )
    # if i > 10:
    #     break
len(d)


  0%|          | 0/616 [00:00<?, ?it/s][A
  2%|▏         | 14/616 [00:00<00:04, 131.10it/s][A
  5%|▍         | 28/616 [00:00<00:04, 136.16it/s][A
  7%|▋         | 44/616 [00:00<00:03, 143.36it/s][A
 10%|▉         | 59/616 [00:00<00:03, 145.31it/s][A
 12%|█▏        | 74/616 [00:00<00:03, 144.55it/s][A
 14%|█▍        | 89/616 [00:00<00:03, 144.55it/s][A
 17%|█▋        | 104/616 [00:00<00:03, 144.13it/s][A
 19%|█▉        | 119/616 [00:00<00:03, 144.87it/s][A
 22%|██▏       | 134/616 [00:00<00:03, 145.45it/s][A
 24%|██▍       | 149/616 [00:01<00:03, 145.69it/s][A
 27%|██▋       | 165/616 [00:01<00:03, 147.54it/s][A
 29%|██▉       | 180/616 [00:01<00:02, 148.16it/s][A
 32%|███▏      | 195/616 [00:01<00:02, 147.64it/s][A
 34%|███▍      | 210/616 [00:01<00:02, 145.85it/s][A
 37%|███▋      | 225/616 [00:01<00:02, 144.50it/s][A
 39%|███▉      | 240/616 [00:01<00:02, 143.38it/s][A
 41%|████▏     | 255/616 [00:01<00:02, 143.08it/s][A
 44%|████▍     | 270/616 [00:01<00:02, 140.0

615

### For each test video part, find the proper component

In [46]:
right_components = []

for video in test_videos_parts:
    found = False
    # Find the component that contains the video
    for component in test_connected_components.keys():
        if video in test_connected_components[component]:
            right_components.append(component)
            found = True
            
    if not found:
        # if the video is not in any component then it is not a duplicate
        # print(f"Video {video} is not in any component")
        right_components.append(None)

In [47]:
def get_metrics(threshold, d, right_components):
    tp = 0
    fn = 0
    fp = 0
    tn = 0

    for i in range(1, len(test_videos_parts) -1):
        # If it has a pair
        if right_components[i] is not None:
            # If the threshold is passed
            if d[i-1][2] < threshold:
                # If they are actually in the same component
                if d[i-1][1] in test_connected_components[right_components[i]]:
                    tp += 1
        else:
            if d[i-1][2] > threshold:
                tn += 1
            else:
                fp += 1


    fn = len(test_videos_parts) -1 - tp - fp - tn
    accuracy = (tp + tn)/(tp + tn + fp + fn)
    recall = tp/(tp + fn)
    try:
        precision = tp/(tp + fp)
        f1 = 2 * (precision * recall) / (precision + recall)
    except:
        precision = 0
        f1 = 0
    

    return tp, fn, fp, tn, accuracy, recall, precision, f1

thresholds = np.linspace(0, 1, 100)

# find the best threshold on f1
best_threshold = 0
best_f1 = 0
for threshold in thresholds:
    tp, fn, fp, tn, accuracy, recall, precision, f1 = get_metrics(threshold, d, right_components)
    if f1 > best_f1:
        best_f1 = f1
        best_threshold = threshold

print(f"Best threshold: {best_threshold}, F1: {best_f1}")

tp, fn, fp, tn, accuracy, recall, precision, f1 = get_metrics(best_threshold, d, right_components)
print('                    actual')
print('               positive  negative')
print('predicted pos   ', tp, '    ', fp)
print('          neg   ',fn, '    ', tn)
print()


print(f"Accuracy: {accuracy:.4f}, Recall: {recall:.4f}, Precision: {precision:.4f}, F1: {f1:.4f}")

Best threshold: 0.32323232323232326, F1: 0.772348033373063
                    actual
               positive  negative
predicted pos    324      29
          neg    162      100

Accuracy: 0.6894, Recall: 0.6667, Precision: 0.9178, F1: 0.7723
