In [None]:
import os
import random
from glob import glob

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

In [None]:
class MelSpecDataset(Dataset):
    """Two-class dataset of mel-spectrograms stored as ``.npy`` files.

    Files are collected from ``<data_dir>/mother`` (label 0) and
    ``<data_dir>/fetus`` (label 1).
    """

    def __init__(self, data_dir):
        """Index every ``*.npy`` file found in the two class folders of ``data_dir``."""
        self.samples = []
        self.label_map = {'mother': 0, 'fetus': 1}

        for label_name in ['mother', 'fetus']:
            pattern = os.path.join(data_dir, label_name, "*.npy")
            # glob() returns matches in arbitrary, filesystem-dependent order;
            # sorting keeps the index -> sample mapping stable between runs.
            for path in sorted(glob(pattern)):
                self.samples.append((path, self.label_map[label_name]))

    def __len__(self):
        """Return the number of indexed files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(mel, label)`` where ``mel`` is the stored array as float32,
            standardised with its own mean/std and given a new leading axis, and
            ``label`` is the class id as a 0-dim ``torch.long`` tensor.
        """
        path, label = self.samples[idx]
        mel = torch.tensor(np.load(path), dtype=torch.float32)

        # Per-sample standardisation; the epsilon guards against a zero std.
        mel = (mel - mel.mean()) / (mel.std() + 1e-6)

        return mel.unsqueeze(0), torch.tensor(label, dtype=torch.long)

In [None]:
# NOTE(review): this rebinds MelSpecDataset to whatever the project-local
# `dataset` module exports, shadowing any class of the same name defined
# earlier in the notebook -- confirm which implementation is intended.
from dataset import MelSpecDataset
from torch.utils.data import DataLoader

# Training split: reshuffled every epoch.
train_ds = MelSpecDataset("/content/drive/MyDrive/SUFHSDB/training_data")
train_loader = DataLoader(dataset=train_ds, batch_size=16, shuffle=True)

# Test split: fixed iteration order.
test_ds = MelSpecDataset("/content/drive/MyDrive/SUFHSDB/testing_data")
test_loader = DataLoader(dataset=test_ds, batch_size=16, shuffle=False)

### ViT

In [None]:
!pip install vit_pytorch

In [None]:
from vit_pytorch import SimpleViT

# Two-class SimpleViT for 128-pixel inputs with a patch size of 32
# (128 / 32 = 4 patches per side).
vit_model = SimpleViT(
    image_size=128,
    patch_size=32,
    num_classes=2,
    # transformer width / depth / attention heads / feed-forward width
    dim=1024,
    depth=6,
    heads=16,
    mlp_dim=2048,
)

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The training/evaluation code moves every batch to `device`, so the model's
# weights must live there too; do this before the optimizer captures the
# parameters.
vit_model = vit_model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=3e-4)

In [None]:
def train_one_epoch(model, loader, image_size=128):
    """Run one optimisation pass over ``loader`` and report loss and accuracy.

    Relies on the notebook-level ``device``, ``criterion`` and ``optimizer``.

    Args:
        model: network to train; it is called on 3-channel square inputs.
        loader: iterable of ``(x, y)`` batches. Assumes ``x`` is shaped
            ``(batch, 1, H, W)`` as produced by ``MelSpecDataset`` -- the single
            channel is tiled to three.
        image_size: side length the inputs are resized to; must match the
            size the model was built for (default 128).

    Returns:
        tuple: ``(avg_loss, accuracy)`` averaged over all samples seen.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for x, y in tqdm(loader):
        x, y = x.to(device), y.to(device)

        # Resize to the model's input resolution, then tile the single
        # channel three times.
        x = F.interpolate(x, size=(image_size, image_size), mode='bilinear')
        x = x.repeat(1, 3, 1, 1)

        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        # criterion returns a batch mean; weight it by batch size so the final
        # average is per-sample even when the last batch is smaller.
        running_loss += loss.item() * x.size(0)
        preds = torch.argmax(outputs, dim=1)
        correct += (preds == y).sum().item()
        total += y.size(0)

    avg_loss = running_loss / total
    accuracy = correct / total
    return avg_loss, accuracy

In [None]:
def evaluate(model, loader, image_size=128):
    """Compute loss and accuracy of ``model`` over ``loader`` without updating it.

    Relies on the notebook-level ``device`` and ``criterion``.

    Args:
        model: network to evaluate; it is called on 3-channel square inputs.
        loader: iterable of ``(x, y)`` batches. Assumes ``x`` is shaped
            ``(batch, 1, H, W)`` as produced by ``MelSpecDataset``.
        image_size: side length the inputs are resized to. Defaults to 128,
            the resolution the models in this notebook are built for and
            trained at (the previous hard-coded 256 did not match it).

    Returns:
        tuple: ``(avg_loss, accuracy)`` averaged over all samples seen.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in tqdm(loader):
            x, y = x.to(device), y.to(device)

            # Same preprocessing as training: resize, then tile to 3 channels.
            x = F.interpolate(x, size=(image_size, image_size), mode='bilinear')
            x = x.repeat(1, 3, 1, 1)

            outputs = model(x)
            loss = criterion(outputs, y)

            # Weight the batch-mean loss by batch size for a per-sample average.
            running_loss += loss.item() * x.size(0)
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    avg_loss = running_loss / total
    accuracy = correct / total
    return avg_loss, accuracy

In [None]:
num_epochs = 50

# One list per tracked metric, appended to once per epoch.
history = {
    metric: []
    for metric in ("train_accuracy", "train_loss", "test_accuracy", "test_loss")
}

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # Train for one pass, then score the updated model on the test loader.
    train_loss, train_acc = train_one_epoch(vit_model, train_loader)
    test_loss, test_acc = evaluate(vit_model, test_loader)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

    epoch_metrics = {
        "train_accuracy": train_acc,
        "train_loss": train_loss,
        "test_accuracy": test_acc,
        "test_loss": test_loss,
    }
    for metric, value in epoch_metrics.items():
        history[metric].append(value)


In [None]:
# Turn the per-epoch metric lists into a table (one row per epoch, one column
# per metric) and persist it without the index column.
df = pd.DataFrame.from_dict(history)
df.to_csv(
    "/content/drive/MyDrive/SUFHSDB/vit_with_specaugment.csv",
    index=False,
)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import confusion_matrix

In [None]:
# Collect ground-truth and predicted class ids over the test set for the
# confusion matrix.
vit_model.eval()
y_true = []
y_pred = []

with torch.no_grad():
    # The held-out loader in this notebook is `test_loader`; the name
    # `dataloader` was never defined.
    for x, y in tqdm(test_loader, desc="Evaluating"):
        x, y = x.to(device), y.to(device)

        # Same preprocessing as training: resize to 128x128, tile to 3 channels.
        x = torch.nn.functional.interpolate(x, size=(128, 128), mode='bilinear')
        x = x.repeat(1, 3, 1, 1)

        outputs = vit_model(x)
        preds = torch.argmax(outputs, dim=1)

        y_true.extend(y.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

In [None]:
# Three-panel summary: accuracy curves, loss curves, test confusion matrix.
train_acc = df["train_accuracy"]
val_acc = df["test_accuracy"]
train_loss = df["train_loss"]
val_loss = df["test_loss"]

# x-axis for the curves: 1-based epoch numbers, one per logged row.
epochs = range(1, len(df) + 1)

# Class ids used by the datasets are mother=0, fetus=1; confusion_matrix orders
# rows/columns by ascending label, so tick labels must follow that order.
class_names = ['Mother', 'Fetus']

fig, axs = plt.subplots(1, 3, figsize=(12,4))

axs[0].plot(epochs, train_acc, label='Train Accuracy')
axs[0].plot(epochs, val_acc, label='Validation Accuracy')
axs[0].set_xlabel("Number of Epochs", weight='bold')
axs[0].set_ylabel("Accuracy", weight='bold')
# Model-agnostic title: this cell is shared by every model section.
axs[0].set_title("Model Accuracy")
axs[0].legend()
axs[0].grid(True)

axs[1].plot(epochs, train_loss, label='Train Loss')
axs[1].plot(epochs, val_loss, label='Validation Loss')
axs[1].set_xlabel("Number of Epochs", weight='bold')
axs[1].set_ylabel("Loss", weight='bold')
axs[1].set_title("Model Loss")
axs[1].legend()
axs[1].grid(True)


cm = confusion_matrix(y_true, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Purples',
            xticklabels=class_names,
            yticklabels=class_names,
            ax=axs[2])

axs[2].set_xlabel("Predicted labels", weight='bold')
axs[2].set_ylabel("True labels", weight='bold')
axs[2].set_title("Confusion Matrix")

plt.tight_layout()
plt.show()

### Deep-ViT

In [None]:
from vit_pytorch.deepvit import DeepViT

# Two-class DeepViT for 128-pixel inputs with a patch size of 32; replaces the
# model previously bound to `vit_model`.
vit_model = DeepViT(
    image_size=128,
    patch_size=32,
    num_classes=2,
    # transformer width / depth / attention heads / feed-forward width
    dim=256,
    depth=6,
    heads=14,
    mlp_dim=128,
    dropout=0.3,
)

### CaiT

In [None]:
from vit_pytorch.cait import CaiT

# Two-class CaiT for 128-pixel inputs with a patch size of 32; replaces the
# model previously bound to `vit_model`.
vit_model = CaiT(
    image_size=128,
    patch_size=32,
    num_classes=2,
    dim=1024,
    # 12 layers configured via `depth`, plus 2 via `cls_depth`
    depth=12,
    cls_depth=2,
    heads=16,
    mlp_dim=2048,
    dropout=0.05,
)

### PiT

In [None]:
from vit_pytorch.pit import PiT

# Two-class PiT for 128-pixel inputs with a patch size of 32.
# Bound to `vit_model` like the other model sections, because the
# optimizer/training/evaluation cells all operate on `vit_model`; with the old
# name `v` they would silently keep using the previously built model.
vit_model = PiT(
    image_size = 128,
    patch_size = 32,
    dim = 1024,
    num_classes = 2,
    depth = (3, 3, 3),  # one entry per stage
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1
)

# Original name kept as an alias for any cell that still refers to `v`.
v = vit_model

### Distillation Student Model

In [None]:
!pip install efficientnet_pytorch

In [None]:
from efficientnet_pytorch import EfficientNet


class StudentModel(nn.Module):
    """Small student network built from the front of a pretrained EfficientNet-B0.

    Keeps the backbone's stem convolution, its first batch-norm and its first
    three blocks, then global-average-pools and feeds the pooled vector to two
    independent linear heads: a classifier and a feature projection.
    """

    def __init__(self, num_classes=2, proj_dim=256):
        """
        Args:
            num_classes: output width of the classification head.
            proj_dim: output width of the feature-projection head.
        """
        super().__init__()

        backbone = EfficientNet.from_pretrained('efficientnet-b0')

        # NOTE(review): only the stem conv and `_bn0` are reused here; if the
        # stock EfficientNet forward applies an activation after `_bn0`, it is
        # skipped in this model -- confirm that is intentional.
        self.stem = backbone._conv_stem
        self.bn0 = backbone._bn0
        self.blocks = nn.Sequential(*backbone._blocks[:3])

        self.pool = nn.AdaptiveAvgPool2d(1)

        # Channel count leaving the last retained block; both heads consume it.
        feat_channels = backbone._blocks[2]._project_conv.out_channels
        self.classifier = nn.Linear(feat_channels, num_classes)
        self.proj = nn.Linear(feat_channels, proj_dim)

    def forward(self, x):
        """Return ``(logits, features)`` for a batch of images ``x``."""
        fmap = self.blocks(self.bn0(self.stem(x)))
        pooled = self.pool(fmap).flatten(1)
        return self.classifier(pooled), self.proj(pooled)


In [None]:
import torch
import torch.nn as nn
# DeepViT lives in the `deepvit` submodule (same import as the Deep-ViT section).
from vit_pytorch.deepvit import DeepViT


class TeacherDeepViT(nn.Module):
    """DeepViT backbone followed by a projection layer and a classifier.

    The backbone's own classification head is replaced by an identity so that
    it returns its ``dim``-wide (1024) pooled representation, which is then
    projected to ``proj_dim`` and classified.
    """

    def __init__(self, num_classes=2, proj_dim=256):
        """
        Args:
            num_classes: output width of the classifier.
            proj_dim: width of the projected feature returned alongside the logits.
        """
        super().__init__()
        self.deepvit = DeepViT(
            image_size=128,
            patch_size=32,
            num_classes=num_classes,
            dim=1024,
            depth=6,
            heads=14,
            mlp_dim=2048
        )
        # Strip the backbone's classification head. vit_pytorch's DeepViT keeps
        # it in `mlp_head`; assigning a new `to_logits` attribute would leave
        # that head in place and the projector below would receive
        # `num_classes`-wide logits instead of 1024-wide features. Fall back to
        # the original attribute name if `mlp_head` is absent.
        head_name = 'mlp_head' if hasattr(self.deepvit, 'mlp_head') else 'to_logits'
        setattr(self.deepvit, head_name, nn.Identity())

        self.projector = nn.Linear(1024, proj_dim)
        self.classifier = nn.Linear(proj_dim, num_classes)

    def forward(self, x):
        """Return ``(logits, proj_feat)`` for a batch of images ``x``."""
        feat = self.deepvit(x)
        proj_feat = self.projector(feat)
        logits = self.classifier(proj_feat)
        return logits, proj_feat


In [None]:
def combined_loss(sm_logits, sm_feat, tm_logits, tm_pos_feat, tm_neg_feat, alpha=0.2, T=2.0):
    """Distillation loss: temperature-softened KL term plus a triplet margin term.

    Args:
        sm_logits: student logits, shape ``(batch, classes)``.
        sm_feat: student features used as the triplet anchor.
        tm_logits: teacher logits, same shape as ``sm_logits``.
        tm_pos_feat: teacher features of the positive samples.
        tm_neg_feat: teacher features of the negative samples.
        alpha: triplet margin.
        T: softmax temperature applied to both student and teacher logits.

    Returns:
        tuple: ``(total_loss, kl_value, triplet_value)`` -- the differentiable
        sum as a tensor, and the two components as Python floats for logging.
    """
    # Both distributions must be softened with the same temperature; otherwise
    # the KL term is non-zero even when student and teacher logits agree.
    sm_log_probs = F.log_softmax(sm_logits / T, dim=1)
    tm_probs = F.softmax(tm_logits / T, dim=1)
    # Scale by T^2 so gradient magnitudes stay comparable across temperatures
    # (the usual knowledge-distillation convention).
    kl_loss = F.kl_div(sm_log_probs, tm_probs, reduction='batchmean') * (T * T)

    # Hinge on (distance to positive - distance to negative + margin).
    d_pos = F.pairwise_distance(sm_feat, tm_pos_feat)
    d_neg = F.pairwise_distance(sm_feat, tm_neg_feat)
    triplet_loss = torch.clamp(d_pos - d_neg + alpha, min=0.0).mean()

    return kl_loss + triplet_loss, kl_loss.item(), triplet_loss.item()


In [None]:
# Build the models first: the optimizer below needs `student_model` to exist.

# Teacher: kept in eval mode and only used for inference.
# NOTE(review): no checkpoint is loaded here, so the teacher runs with freshly
# initialised weights -- confirm whether trained weights should be restored.
tm_model = TeacherDeepViT()
tm_model.eval()
tm_model.to(device)
# The distillation loop refers to the teacher as `teacher_model`.
teacher_model = tm_model

student_model = StudentModel().to(device)

# Only the student is optimised; the learning rate is cosine-annealed from
# 3e-4 down to 1e-5 over 50 scheduler steps.
optimizer = optim.Adam(student_model.parameters(), lr=3e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-5)

In [None]:
from glob import glob


class TripletMelDataset(Dataset):
    """Dataset yielding (anchor, label, positive, negative) mel-spectrograms.

    Every ``*.npy`` file under ``<root_dir>/mother`` and ``<root_dir>/fetus``
    is an anchor. For each anchor a positive is drawn at random from the same
    class and a negative from the other class.

    NOTE(review): unlike ``MelSpecDataset`` no per-sample standardisation is
    applied here, and the positive may be the anchor file itself -- confirm
    both are intended.
    """

    def __init__(self, root_dir):
        """Index the class folders of ``root_dir``."""
        self.anchor_samples = []
        self.class_to_paths = {'mother': [], 'fetus': []}

        for label in ['mother', 'fetus']:
            # glob() order is filesystem-dependent; sort so that indices map to
            # the same files on every run.
            paths = sorted(glob(os.path.join(root_dir, label, "*.npy")))
            self.class_to_paths[label].extend(paths)
            self.anchor_samples.extend((p, label) for p in paths)

    def __len__(self):
        """Return the number of anchors."""
        return len(self.anchor_samples)

    @staticmethod
    def _load(path):
        """Read one ``.npy`` file as a float32 tensor with a new leading axis."""
        return torch.tensor(np.load(path), dtype=torch.float32).unsqueeze(0)

    def __getitem__(self, idx):
        """Return ``(anchor, anchor_label, pos, neg)`` for anchor ``idx``.

        ``anchor_label`` is a plain int: 0 for 'mother', 1 for 'fetus'.
        """
        anchor_path, anchor_class = self.anchor_samples[idx]
        other_class = 'fetus' if anchor_class == 'mother' else 'mother'

        anchor = self._load(anchor_path)
        anchor_label = 0 if anchor_class == 'mother' else 1

        # Positive: random file of the anchor's class; negative: random file of
        # the opposite class.
        pos = self._load(random.choice(self.class_to_paths[anchor_class]))
        neg = self._load(random.choice(self.class_to_paths[other_class]))

        return anchor, anchor_label, pos, neg


In [None]:
# Triplet loaders for distillation. These rebind `train_ds`, `train_loader`,
# `test_ds` and `test_loader`, replacing the MelSpecDataset-based ones.
train_ds = TripletMelDataset('/content/drive/MyDrive/SUFHSDB/training_data')
train_loader = DataLoader(dataset=train_ds, batch_size=64, shuffle=True)

test_ds = TripletMelDataset('/content/drive/MyDrive/SUFHSDB/testing_data')
test_loader = DataLoader(dataset=test_ds, batch_size=64, shuffle=False)

In [None]:
def resize(x, size=128):
    """Resize a single-channel batch to ``size`` x ``size`` and tile it to 3 channels.

    The default of 128 matches the ``image_size`` the teacher's DeepViT is
    built with; the previous 256 did not.
    """
    x = torch.nn.functional.interpolate(x, size=(size, size), mode='bilinear')
    return x.repeat(1, 3, 1, 1)


# Distil the frozen teacher into the student for 50 epochs.
for epoch in range(50):
    student_model.train()
    running_loss = 0.0
    running_kl = 0.0
    running_triplet = 0.0

    # anchor_label is yielded by the dataset but not used by the loss.
    for anchor, anchor_label, pos, neg in tqdm(train_loader):
        anchor = resize(anchor.to(device))
        pos = resize(pos.to(device))
        neg = resize(neg.to(device))

        # The student sees the anchor; the teacher (defined as `tm_model`)
        # embeds the positive and negative without tracking gradients.
        sm_logits, sm_feat = student_model(anchor)
        with torch.no_grad():
            tm_logits_pos, tm_pos_feat = tm_model(pos)
            _, tm_neg_feat = tm_model(neg)

        loss, kl, triplet = combined_loss(sm_logits, sm_feat, tm_logits_pos, tm_pos_feat, tm_neg_feat)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Running totals are sums over batches, not averages.
        running_loss += loss.item()
        running_kl += kl
        running_triplet += triplet

    # One scheduler step per epoch.
    scheduler.step()

    print(f"Epoch {epoch+1}: Loss={running_loss:.4f} KL={running_kl:.4f} Triplet={running_triplet:.4f}")