In [None]:
!pip install lightning transformers torchaudio nnAudio scikit-learn pandas tqdm

Collecting lightning
  Downloading lightning-2.6.0-py3-none-any.whl.metadata (44 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m44.9/44.9 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
Collecting nnAudio
  Downloading nnaudio-0.3.4-py3-none-any.whl.metadata (771 bytes)
Downloading lightning-2.6.0-py3-none-any.whl (845 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m846.0/846.0 kB[0m [31m20.3 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hDownloading nnaudio-0.3.4-py3-none-any.whl (43 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m43.8/43.8 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: nnAudio, lightning
Successfully installed lightning-2.6.0 nnAudio-0.3.4


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

import lightning as L

from transformers import AutoModel, Wav2Vec2FeatureExtractor

from typing import List, Tuple, Dict, Any, Union, Optional


class ResidualBlock(nn.Module):
    """1-D residual block: two Conv1d + BatchNorm1d stages plus a skip connection.

    Input is ``(batch, in_channels, length)``; output is ``(batch, out_channels, length')``
    where ``length' == length`` for the default ``kernel_size=3, stride=1, padding=1``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        kernel_size: kernel size of both 3-tap convolutions (default 3).
        stride: applied by the first convolution only (standard ResNet layout) and
            mirrored by the 1x1 projection on the skip path, so both paths agree in length.
        padding: padding of both convolutions (default 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        # Stride 1 here: applying `stride` to both convs downsampled the main path twice
        # while the shortcut downsampled once, so any stride != 1 crashed on the addition.
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, 1, padding)
        self.bn2 = nn.BatchNorm1d(out_channels)
        # Identity skip unless the channel count or the length changes. The projection keeps
        # the submodule names `shortcut.0` / `shortcut.1`, so existing checkpoints still load.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)


class SiameseNet(nn.Module):
    """Embedding network shared by every branch of the triplet / pair comparison.

    Maps ``(batch, in_channels, frames)`` features to ``(batch, embedding_dim)``
    embeddings: two residual blocks, global average pooling over frames, then a Linear.

    Args:
        embedding_dim: size of the output embedding.
        in_channels: input feature channels. The default 3072 matches the 4 MERT layers
            x 768 dims concatenated by ``PlagiarismDetectionSystem.inference_pairs``.
    """

    def __init__(self, embedding_dim: int, in_channels: int = 3072):
        super().__init__()
        self.layer1 = ResidualBlock(in_channels, 512)
        self.layer2 = ResidualBlock(512, 256)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(256, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.layer2(self.layer1(x))
        x = self.global_pool(x).flatten(1)  # (batch, 256, 1) -> (batch, 256)
        return self.fc(x)

    def similarity_score(self, sample1, sample2, metric='euclidean'):
        """Row-wise *distance* between two embedding batches (lower = more similar).

        Args:
            sample1, sample2: embedding tensors of matching shape.
            metric: ``'euclidean'`` (pairwise L2 distance) or ``'cosine'``
                (``1 - cosine_similarity``).

        Raises:
            ValueError: for any other metric (previously this silently returned None).
        """
        if metric == 'euclidean':
            return torch.nn.functional.pairwise_distance(sample1, sample2)
        if metric == 'cosine':
            return 1 - torch.nn.functional.cosine_similarity(sample1, sample2)
        raise ValueError(f"Unknown metric {metric!r}; expected 'euclidean' or 'cosine'")

class PlagiarismDetectionSystem(L.LightningModule):
    """Lightning module: a triplet-trained SiameseNet plus a pair-classification head.

    The SiameseNet maps feature tensors to embeddings and is trained with a triplet
    margin loss. The classifier head takes ``|emb_a - emb_b|`` and outputs one logit where
    label 1 means "different song" and label 0 means "same song". It is trained on detached
    embeddings, so its loss does not update the SiameseNet.

    Args:
        config: dict with ``"siamese_emb_dim"`` (embedding size) and optionally
            ``"train_classifier_gap"``. ``None`` (the default) or a value <= 1 trains the
            head on every step; ``k`` trains it on one step out of every ``k``.
    """

    def __init__(self, config: Dict):
        super().__init__()
        embedding_dim = config["siamese_emb_dim"]

        # Feature extractor trained by triplet loss
        self.siamese_net = SiameseNet(embedding_dim=embedding_dim)

        # Classification head: |emb_a - emb_b| -> one logit
        self.classifier = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, 1),
        )

        # loss functions
        self.criterion_triplet = nn.TripletMarginLoss(margin=2, p=2)
        self.criterion_classification = nn.BCEWithLogitsLoss()

        # how often to train the classification head (see class docstring)
        self.train_classifier_gap = config.get("train_classifier_gap")

        # MERT front end, used only by inference_pairs.
        # NOTE(review): on a fresh instance these hasattr guards are always False; they are
        # kept so an attribute pre-set by a subclass is not overwritten.
        if not hasattr(self, "audio_processor"):
            # NOTE(review): inference_pairs never calls audio_processor; it feeds raw
            # waveforms straight to audio_model, skipping any normalization the extractor
            # would apply. Confirm this matches how the training features were produced.
            self.audio_processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-95M")
        if not hasattr(self, "audio_model"):
            # Registered as a submodule, so it is moved with this module by .to()/Lightning
            # (self.device is still CPU inside __init__, so the old .to(self.device) was a no-op).
            # It is saved in checkpoints but excluded from configure_optimizers.
            self.audio_model = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)

    def forward_siamese_net(self, anchors, positives, negatives):
        """Embed the three triplet members; returns ``(batch, 3, embedding_dim)``."""
        return torch.stack(
            [self.siamese_net(anchors), self.siamese_net(positives), self.siamese_net(negatives)],
            dim=1,
        )

    def forward_classifier(self, triplet_embeddings):
        """Return (anchor-vs-positive, anchor-vs-negative) logits, each of shape ``(batch,)``.

        ``squeeze(-1)`` (not ``squeeze()``) keeps a batch of one as shape ``(1,)`` so it
        still matches the label tensors and can be concatenated.
        """
        anchor = triplet_embeddings[:, 0]
        diff_same = torch.abs(anchor - triplet_embeddings[:, 1])
        diff_diff = torch.abs(anchor - triplet_embeddings[:, 2])
        logit_same = self.classifier(diff_same).squeeze(-1)
        logit_diff = self.classifier(diff_diff).squeeze(-1)
        return logit_same, logit_diff

    def _should_train_classifier(self) -> bool:
        """True on steps where the classifier head should receive a loss."""
        gap = self.train_classifier_gap
        if gap is None or gap <= 1:
            return True
        # Modulo: one step in every `gap`. The old `//` only matched a single window of steps.
        return self.global_step % gap == gap - 1

    def training_step(self, batch, batch_idx):
        anchors, positives, negatives = batch
        batch_size = anchors.shape[0]

        # train siamese net
        triplet_embeddings = self.forward_siamese_net(anchors, positives, negatives)
        loss_triplet = self.criterion_triplet(
            triplet_embeddings[:, 0], triplet_embeddings[:, 1], triplet_embeddings[:, 2]
        )  # anchor_embeddings, positive_embeddings, negative_embeddings

        # train classifier on detached embeddings so its loss does not reach the siamese net
        if self._should_train_classifier():
            device = triplet_embeddings.device
            labels_same = torch.zeros(batch_size, device=device)
            labels_diff = torch.ones(batch_size, device=device)
            logit_same, logit_diff = self.forward_classifier(triplet_embeddings.detach())
            loss_same = self.criterion_classification(logit_same, labels_same)
            loss_diff = self.criterion_classification(logit_diff, labels_diff)
            loss_classification = (loss_same + loss_diff) / 2
        else:
            loss_classification = 0

        final_loss = loss_triplet + loss_classification
        self.log("triplet_loss", loss_triplet)
        self.log("classification_loss", loss_classification)
        self.log("total_loss", final_loss)
        return final_loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        anchors, positives, negatives = batch
        batch_size = anchors.shape[0]

        triplet_embeddings = self.forward_siamese_net(anchors, positives, negatives)
        loss_triplet = self.criterion_triplet(
            triplet_embeddings[:, 0], triplet_embeddings[:, 1], triplet_embeddings[:, 2]
        ).item()

        device = triplet_embeddings.device
        labels_same = torch.zeros(batch_size, device=device)
        labels_diff = torch.ones(batch_size, device=device)
        logit_same, logit_diff = self.forward_classifier(triplet_embeddings)
        loss_same = self.criterion_classification(logit_same, labels_same)
        loss_diff = self.criterion_classification(logit_diff, labels_diff)
        loss_classification = ((loss_same + loss_diff) / 2).item()

        # Decision rule: prob("different") > 0.5 -> different song (label 1).
        preds_same = (torch.sigmoid(logit_same) > 0.5).float()
        preds_diff = (torch.sigmoid(logit_diff) > 0.5).float()
        correct_same = preds_same == labels_same
        correct_diff = preds_diff == labels_diff

        # Batch accuracy equals overall accuracy only when batch_size == dataset size.
        accuracy = torch.cat([correct_same, correct_diff]).float().mean()
        accuracy_positive = correct_same.float().mean()  # accuracy on same-song pairs only

        self.log("val_triplet_loss", loss_triplet, prog_bar=True)
        self.log("val_classification_loss", loss_classification, prog_bar=True)
        self.log("val_accuracy", accuracy, prog_bar=True)
        self.log("val_accuracy_positive", accuracy_positive, prog_bar=True)

        return {
            "val_triplet_loss": loss_triplet,
            "val_classification_loss": loss_classification,
        }

    @torch.no_grad()
    def _inference_step(self, sample1: torch.Tensor, sample2: torch.Tensor):
        """Return sigmoid(classifier logit) per pair, i.e. prob("different"), shape ``(batch,)``."""
        assert sample1.shape[0] == sample2.shape[0]
        diff = torch.abs(self.siamese_net(sample1) - self.siamese_net(sample2))
        return torch.sigmoid(self.classifier(diff).squeeze(-1))

    @torch.no_grad()
    def _extract_mert_features(self, waveforms: torch.Tensor) -> torch.Tensor:
        """Run MERT and build the ``(batch, 4 * 768, frames // 10)`` SiameseNet input.

        Takes every third hidden state starting at index 2, average-pools each over time
        (window 10, stride 10), and concatenates the layers along channels (layer-major).
        """
        hidden_states = self.audio_model(waveforms, output_hidden_states=True).hidden_states
        pooled = [
            # (batch, frames, dim) -> (batch, dim, frames // 10)
            torch.nn.functional.avg_pool1d(h.transpose(1, 2), kernel_size=10, stride=10)
            for h in hidden_states[2::3]
        ]
        features = torch.stack(pooled, dim=1)  # (batch, num_layers, dim, frames')
        assert features.shape[1] == 4 and features.shape[2] == 768
        return features.flatten(1, 2)  # (batch, 3072, frames')

    @torch.no_grad()
    def inference_pairs(
        self,
        waveforms1: Union[List[torch.Tensor], torch.Tensor],
        waveforms2: Union[List[torch.Tensor], torch.Tensor],
    ):
        """Score waveform pairs; returns a similarity per pair (higher = more similar).

        Args:
            waveforms1, waveforms2: either two equal-length lists of waveform tensors
                (stacked into a batch), or two 2-D tensors with the same batch size.
                Inputs are moved to the model's device.
        """
        self.eval()

        if isinstance(waveforms1, list) and isinstance(waveforms2, list):
            assert len(waveforms1) == len(waveforms2)
            waveforms1 = torch.stack(waveforms1)
            waveforms2 = torch.stack(waveforms2)
        elif torch.is_tensor(waveforms1) and torch.is_tensor(waveforms2):
            assert waveforms1.shape[0] == waveforms2.shape[0]
            assert waveforms1.dim() == 2 and waveforms2.dim() == 2
        else:
            raise AssertionError("waveforms1 and waveforms2 must both be lists of tensors or both be tensors")

        # Previously only the list branch moved inputs, so CPU tensors failed on a GPU model.
        waveforms1 = waveforms1.to(self.device)
        waveforms2 = waveforms2.to(self.device)

        scores = self._inference_step(
            self._extract_mert_features(waveforms1),
            self._extract_mert_features(waveforms2),
        )
        return 1 - scores  # similarity, the higher the more similar (distance smaller)

    def configure_optimizers(self):
        # Only the siamese net and classifier are optimized; audio_model stays frozen in practice.
        return optim.Adam(
            list(self.siamese_net.parameters()) + list(self.classifier.parameters()),
            lr=1e-3,
        )


# Confirmation marker: reaching this line means every class in this cell was defined.
print("‚úÖ System Classes Defined", end="\n")

‚úÖ System Classes Defined


In [None]:
import torch
import numpy as np
import os
import random
from sklearn import metrics
from tqdm import tqdm
import glob

# Hyper-parameters passed to PlagiarismDetectionSystem (also when loading a checkpoint).
config = dict(
    siamese_emb_dim=128,        # output size of SiameseNet's final Linear layer
    train_classifier_gap=None,  # None -> train the classifier head on every step
)

def _track_sort_key(path):
    """Sort key for feature files: numeric order for ``pair_<n>.npy`` names, else by path.

    Numeric names sort before non-numeric ones, so one odd filename no longer disables
    ordering for the whole list (the old bare ``except: pass`` did exactly that).
    """
    stem = os.path.basename(path).replace('pair_', '').replace('.npy', '')
    try:
        return (0, int(stem), path)
    except ValueError:
        return (1, 0, path)


def _prepare_segment_tensor(seg, device):
    """Convert one stored feature array into a ``(1, channels, frames)`` SiameseNet input.

    Layouts handled:
    * 4-D: after the permute below, flattened to ``(3072, Seg*Time)`` (segments joined in time).
    * 3-D ``(Layers=4, Time, Feat=768)``: becomes ``(3072, Time)``, layer-major.
    * anything else is passed through, transposed if its second dim is 3072
      (presumably a ``(Time, 3072)`` array — TODO confirm).
    """
    t = torch.from_numpy(seg).float().to(device)

    if t.ndim == 4:
        # permute(0, 2, 1, 3) then reshape(-1, 3072) gives layer-major channels, the same
        # layout inference_pairs builds, ONLY if the stored layout is (Seg, Layers, Time, Feat).
        # NOTE(review): an earlier comment described the input as (Seg, Time, Layers, Feat);
        # for that layout this permute would interleave layers and frames. Confirm against
        # the script that wrote the .npy files.
        t = t.permute(0, 2, 1, 3).reshape(-1, 4 * 768)  # (Seg*Time, 3072)
        t = t.transpose(0, 1)  # (3072, Seg*Time)
    elif t.ndim == 3 and t.shape[0] == 4 and t.shape[2] == 768:
        t = t.permute(0, 2, 1)  # (4, 768, T)
        t = t.reshape(-1, t.shape[-1])  # (3072, T)

    if t.shape[0] != 3072 and t.shape[1] == 3072:
        t = t.transpose(0, 1)

    return t.unsqueeze(0)


def final_kaggle_evaluation(input_dir="/kaggle/input", seed=0, num_thresholds=100):
    """Evaluate the trained SiameseNet on same-track vs different-track pairs.

    Loads the first checkpoint (in sorted path order) and every ``.npy`` feature file
    found under ``input_dir``. It builds one positive pair per track with at least two
    versions (version 0 vs 1), plus an equal number of random negative pairs (version 0 of
    two different tracks). It then reports how well a Euclidean embedding-distance
    threshold separates them. Results are printed; the function returns None.

    Args:
        input_dir: root searched recursively for ``*.ckpt`` and ``*.npy`` files.
        seed: seed for negative-pair sampling so runs are reproducible; ``None`` gives
            unseeded sampling.
        num_thresholds: number of evenly spaced distance thresholds tried.

    Caveat: the threshold is tuned on the same pairs it is scored on, so "Max Accuracy"
    is an optimistic in-sample figure. ROC-AUC is printed as a threshold-free complement.
    """
    print("üöÄ Starting Evaluation (Robust Tensor Reshape)...")

    # Setup: sorted so the chosen checkpoint does not depend on filesystem order.
    model_files = sorted(glob.glob(os.path.join(input_dir, "**", "*.ckpt"), recursive=True))
    if not model_files:
        print(f"No .ckpt checkpoint found under {input_dir}; nothing to evaluate.")
        return
    checkpoint_path = model_files[0]
    if len(model_files) > 1:
        print(f"Found {len(model_files)} checkpoints; using the first in sorted order.")

    # NOTE(review): every .npy under input_dir is treated as a track feature file.
    track_files = glob.glob(os.path.join(input_dir, "**", "*.npy"), recursive=True)
    if not track_files:
        print(f"No .npy feature files found under {input_dir}; nothing to evaluate.")
        return
    track_files.sort(key=_track_sort_key)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load Model. strict=False tolerates missing/unexpected state-dict keys silently.
    print(f"üì• Loading Model: {checkpoint_path}")
    model = PlagiarismDetectionSystem.load_from_checkpoint(
        checkpoint_path, config=config, map_location=device, strict=False
    )
    model.eval().to(device)

    # Load Data
    all_track_data = [np.load(f) for f in track_files]
    num_tracks = len(all_track_data)

    # Pairs: positives are versions 0 and 1 of the same track.
    positive_pairs = [(i, 0, i, 1) for i in range(num_tracks) if all_track_data[i].shape[0] >= 2]
    valid_indices = [i for i in range(num_tracks) if all_track_data[i].shape[0] >= 1]
    if not positive_pairs or len(valid_indices) < 2:
        # Guard: with < 2 usable tracks the negative-sampling loop below never terminates.
        print("Need at least one track with two versions and two usable tracks; nothing to evaluate.")
        return

    rng = random.Random(seed)
    negative_pairs = []
    while len(negative_pairs) < len(positive_pairs):
        idx1 = rng.choice(valid_indices)
        idx2 = rng.choice(valid_indices)
        if idx1 != idx2:
            negative_pairs.append((idx1, 0, idx2, 0))

    all_pairs = positive_pairs + negative_pairs
    all_labels = [1] * len(positive_pairs) + [0] * len(negative_pairs)

    print(f"‚öñÔ∏è Evaluating on {len(all_pairs)} pairs...")

    # Distances
    print("Computing Distances...")
    distances = []
    with torch.no_grad():
        for t1, v1, t2, v2 in tqdm(all_pairs):
            emb1 = model.siamese_net(_prepare_segment_tensor(all_track_data[t1][v1], device))
            emb2 = model.siamese_net(_prepare_segment_tensor(all_track_data[t2][v2], device))
            distances.append(torch.dist(emb1, emb2, p=2).item())

    # Tuning: predict "similar" (1) when the distance is below the threshold.
    print("\nTuning Threshold...")
    best_acc, best_thresh, best_preds = 0, 0, []
    for thresh in np.linspace(0, max(distances), num_thresholds):
        preds = [1 if d < thresh else 0 for d in distances]
        acc = metrics.accuracy_score(all_labels, preds)
        if acc > best_acc:
            best_acc, best_thresh, best_preds = acc, thresh, preds

    # Smaller distance means more similar, so negate distances to score the positive class.
    auc = metrics.roc_auc_score(all_labels, [-d for d in distances])

    print("\n" + "="*40)
    print(f"üèÜ KAGGLE FINAL RESULT")
    print(f"‚öôÔ∏è Best Distance Threshold: < {best_thresh:.4f}")
    print("="*40)
    print(metrics.classification_report(all_labels, best_preds, target_names=["Different", "Similar/Plagiarism"]))
    print(f"‚úÖ Max Accuracy: {best_acc:.2%}")
    print(f"Confusion Matrix:\n{metrics.confusion_matrix(all_labels, best_preds)}")
    print(f"ROC-AUC (threshold-free): {auc:.4f}")

# Run the evaluation end to end; results are printed and the function returns None.
final_kaggle_evaluation()

üöÄ Starting Evaluation (Robust Tensor Reshape)...
üì• Loading Model: /kaggle/input/thesis-complete/best_model_continued-epoch7.ckpt
‚öñÔ∏è Evaluating on 96 pairs...
Computing Distances...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 96/96 [00:01<00:00, 53.36it/s]


Tuning Threshold...

üèÜ KAGGLE FINAL RESULT
‚öôÔ∏è Best Distance Threshold: < 1.1499
                    precision    recall  f1-score   support

         Different       0.79      0.79      0.79        48
Similar/Plagiarism       0.79      0.79      0.79        48

          accuracy                           0.79        96
         macro avg       0.79      0.79      0.79        96
      weighted avg       0.79      0.79      0.79        96

‚úÖ Max Accuracy: 79.17%
Confusion Matrix:
[[38 10]
 [10 38]]



