In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttnBlock(nn.Module):
    """
    One tri-modal cross-attention block (batch_first; every tensor is (B, L, d_model)).

    The three branches run one after another. Each branch is cross-attention ->
    residual + LayerNorm -> FFN -> residual + LayerNorm:
      1. text   attends to concat(audio, vision)
      2. audio  attends to concat(text, vision)   -- text already updated in step 1
      3. vision attends to concat(text, audio)    -- both already updated
    """

    def __init__(self, d_model=512, nhead=8, ffn_mult=4, dropout=0.1):
        super().__init__()
        # Submodules are created in the same order as before (attention, norms, FFNs),
        # so a fixed RNG seed still yields the same initial weights.
        def make_attn():
            return nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)

        self.mha_t, self.mha_a, self.mha_v = make_attn(), make_attn(), make_attn()

        self.ln_t1, self.ln_a1, self.ln_v1 = (nn.LayerNorm(d_model) for _ in range(3))
        self.ln_t2, self.ln_a2, self.ln_v2 = (nn.LayerNorm(d_model) for _ in range(3))

        hidden = d_model * ffn_mult

        def make_ffn():
            return nn.Sequential(
                nn.Linear(d_model, hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden, d_model)
            )

        self.ffn_t, self.ffn_a, self.ffn_v = make_ffn(), make_ffn(), make_ffn()

        self.drop = nn.Dropout(dropout)

    @staticmethod
    def _fill(mask, x):
        """Return `mask`, or an all-False (B, L) bool mask shaped like `x` when `mask` is None."""
        if mask is None:
            return torch.zeros(x.size()[:2], dtype=torch.bool, device=x.device)
        return mask

    def _branch(self, query, kv, kpm, mha, ln_attn, ln_ffn, ffn):
        """Cross-attention + residual/LayerNorm, then FFN + residual/LayerNorm, for one modality."""
        attended, _ = mha(query=query, key=kv, value=kv, key_padding_mask=kpm)
        query = ln_attn(query + self.drop(attended))
        return ln_ffn(query + self.drop(ffn(query)))

    def forward(self, t, a, v, pad_t=None, pad_a=None, pad_v=None):
        """
        t, a, v : (B, Lt / La / Lv, d_model)
        pad_*   : optional (B, L) bool masks, True = PAD (those keys are ignored)
        Returns the updated (t, a, v), shapes unchanged.
        """
        # A missing mask becomes all-False as soon as the mask it is concatenated with
        # exists; filled masks carry over to the later branches.
        kpm = None
        if pad_a is not None or pad_v is not None:
            pad_a, pad_v = self._fill(pad_a, a), self._fill(pad_v, v)
            kpm = torch.cat([pad_a, pad_v], dim=1)
        t = self._branch(t, torch.cat([a, v], dim=1), kpm,
                         self.mha_t, self.ln_t1, self.ln_t2, self.ffn_t)

        kpm = None
        if pad_t is not None or pad_v is not None:
            pad_t, pad_v = self._fill(pad_t, t), self._fill(pad_v, v)
            kpm = torch.cat([pad_t, pad_v], dim=1)
        a = self._branch(a, torch.cat([t, v], dim=1), kpm,
                         self.mha_a, self.ln_a1, self.ln_a2, self.ffn_a)

        kpm = None
        if pad_t is not None or pad_a is not None:
            pad_t, pad_a = self._fill(pad_t, t), self._fill(pad_a, a)
            kpm = torch.cat([pad_t, pad_a], dim=1)
        v = self._branch(v, torch.cat([t, a], dim=1), kpm,
                         self.mha_v, self.ln_v1, self.ln_v2, self.ffn_v)

        return t, a, v


In [2]:
class TriModalCrossAttnClassifier(nn.Module):
    """
    Project (BERT, HuBERT, CLIP) features to a shared d_model space, prepend a learned
    CLS token per modality, run `depth` CrossAttnBlock layers, pool each modality,
    concatenate, LayerNorm and classify.

    pool:
      "cls_mean" -- take each modality's CLS output (despite the name, no averaging)
      "mean"     -- average over non-PAD positions (the CLS position is included)
    Any other value raises ValueError; it used to fall back to "mean" silently.
    Variable sequence lengths are supported through optional (B, L) masks (True = PAD).
    """

    _POOL_MODES = ("cls_mean", "mean")

    def __init__(self, d_text, d_audio, d_vision, d_model=512, nhead=8, depth=2, num_classes=2, dropout=0.1, pool="cls_mean"):
        super().__init__()
        # Validate up front so a typo such as pool="cls" cannot silently change pooling.
        if pool not in self._POOL_MODES:
            raise ValueError(f"pool must be one of {self._POOL_MODES}, got {pool!r}")

        self.proj_t = nn.Linear(d_text, d_model)
        self.proj_a = nn.Linear(d_audio, d_model)
        self.proj_v = nn.Linear(d_vision, d_model)

        self.blocks = nn.ModuleList([CrossAttnBlock(d_model, nhead, dropout=dropout) for _ in range(depth)])

        self.pool = pool
        # Learned CLS token per modality; prepended in forward and never masked.
        self.cls_t = nn.Parameter(torch.randn(1, 1, d_model))
        self.cls_a = nn.Parameter(torch.randn(1, 1, d_model))
        self.cls_v = nn.Parameter(torch.randn(1, 1, d_model))

        self.norm = nn.LayerNorm(d_model * 3)
        self.head = nn.Sequential(
            nn.Linear(d_model * 3, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, num_classes)
        )

    def _prepend_cls(self, x, cls, pad=None):
        """Prepend `cls` (1, 1, d) to each sequence of x (B, L, d); extend `pad` with a non-PAD column."""
        B = x.size(0)
        x = torch.cat([cls.expand(B, -1, -1), x], dim=1)
        if pad is not None:
            pad = torch.cat([torch.zeros(B, 1, dtype=torch.bool, device=x.device), pad], dim=1)
        return x, pad

    @staticmethod
    def _masked_mean(x, pad):
        """Mean of x (B, L, d) over dim 1, skipping positions where pad (B, L) is True."""
        if pad is None:
            return x.mean(dim=1)
        w = (~pad).to(x.dtype).unsqueeze(-1)  # (B, L, 1)
        # clamp avoids division by zero for a fully padded sequence
        return (x * w).sum(dim=1) / w.sum(dim=1).clamp_min(1.0)

    def forward(self, text_emb, audio_emb, vision_emb,
                pad_t=None, pad_a=None, pad_v=None):
        """
        text_emb  : (B, Lt, d_text)   BERT token embeddings or last hidden states
        audio_emb : (B, La, d_audio)  HuBERT frame features
        vision_emb: (B, Lv, d_vision) CLIP patch/token embeddings
        pad_*     : optional (B, L) bool masks, True = PAD.
        Returns (logits, (ft, fa, fv), z): logits (B, num_classes), per-modality pooled
        vectors (B, d_model) each, and the normalised fused vector z (B, 3 * d_model).
        """
        t = self.proj_t(text_emb)
        a = self.proj_a(audio_emb)
        v = self.proj_v(vision_emb)

        # Add CLS tokens for stable pooling.
        t, pad_t = self._prepend_cls(t, self.cls_t, pad_t)
        a, pad_a = self._prepend_cls(a, self.cls_a, pad_a)
        v, pad_v = self._prepend_cls(v, self.cls_v, pad_v)

        for blk in self.blocks:
            t, a, v = blk(t, a, v, pad_t, pad_a, pad_v)

        if self.pool == "cls_mean":
            ft, fa, fv = t[:, 0], a[:, 0], v[:, 0]  # CLS outputs
        else:  # "mean": average over real tokens only
            ft = self._masked_mean(t, pad_t)
            fa = self._masked_mean(a, pad_a)
            fv = self._masked_mean(v, pad_v)

        z = self.norm(torch.cat([ft, fa, fv], dim=-1))
        logits = self.head(z)
        return logits, (ft, fa, fv), z


The pipeline below (load features → pad → fuse → save) is repeated for each cohort and task: HC Spontaneous, HC ReadText, PD Spontaneous, and PD ReadText.

# File paths

In [3]:
import os
import numpy as np
import torch

# --- Folder paths ---
# Root of the feature store. Set the PD_HC_MULTI_ROOT environment variable to run on
# another machine (the PD cells use a different mount). Sub-folder names match the
# disk exactly, including the "Bert embedings" spelling.
DATA_ROOT  = os.environ.get("PD_HC_MULTI_ROOT", "/mnt/d/Roshidat_Msc_Project/Audio_parkinson/pd&Hc_multi")
BERT_DIR   = os.path.join(DATA_ROOT, "Bert embedings", "HC_Spontaneous_berts_feats_tokens_only")
HUBERT_DIR = os.path.join(DATA_ROOT, "Hubert embeddings", "HC_Spontaneous_hubert_features")
CLIP_FEATS = os.path.join(DATA_ROOT, "Clip embeddings", "HC_Spontaneous_Spectrogram_CLIP_features.npy")
CLIP_NAMES = os.path.join(DATA_ROOT, "Clip embeddings", "HC_Spontaneous_Spectrogram_CLIP_filenames.txt")


# Load CLIP embeddings and filenames

In [4]:
# CLIP features: one row per spectrogram, aligned line-by-line with CLIP_NAMES
clip_features = np.load(CLIP_FEATS)  # shape (num_samples, clip_dim)
with open(CLIP_NAMES, "r") as f:
    clip_filenames = [line.strip() for line in f]
# Drop trailing blank lines only; removing interior ones would shift the row alignment.
while clip_filenames and not clip_filenames[-1]:
    clip_filenames.pop()
if len(clip_filenames) != len(clip_features):
    raise ValueError(
        f"{CLIP_NAMES}: {len(clip_filenames)} names but {len(clip_features)} feature rows"
    )

# create a mapping {IDxx: feature_row}
def extract_id(path):
    """Return the part of the base name before the first "_", e.g. 'ID00_...' -> 'ID00'."""
    base = os.path.basename(path)
    return base.split("_")[0]

clip_dict = {}
for i, fname in enumerate(clip_filenames):
    sid = extract_id(fname)
    if sid in clip_dict:
        # only one row per ID can be kept; the last one wins, but report the collision
        print(f"[warn] duplicate CLIP id {sid!r} at row {i}; overwriting the earlier row")
    clip_dict[sid] = clip_features[i]
print(f"Loaded {len(clip_dict)} CLIP embeddings.")


Loaded 21 CLIP embeddings.


# Load the BERT and HuBERT embeddings from folders

In [6]:
def get_id(filename):
    """
    Return the sample ID, i.e. everything in the file's base name before the first "_".

        'ID00_hc_0_0_0s_tokens_for_selfattn.npz' -> 'ID00'

    A base name without "_" is returned unchanged.
    """
    head, _sep, _rest = os.path.basename(filename).partition("_")
    return head


In [7]:
def load_np(path):
    """
    Load a .npy or .npz file and return a single array, or None on failure.

    For a .npz archive the first stored array (archive order) is returned, and the
    archive is then closed; np.load keeps the .npz file handle open until closed.
    Load errors and empty archives are reported with a print and yield None.

    NOTE: allow_pickle=True can execute code from crafted files; only load trusted data.
    """
    try:
        data = np.load(path, allow_pickle=True)
    except Exception as e:
        print(f"Failed to load {path}: {e}")
        return None
    if isinstance(data, np.lib.npyio.NpzFile):
        with data:  # closes the underlying zip file handle
            if not data.files:
                print(f"Failed to load {path}: empty .npz archive")
                return None
            return data[data.files[0]]
    return data

# Build dicts robustly (accept .npy and .npz, skip non-files, handle load errors).
# Reuse get_id defined earlier in the notebook (do not redefine it here).
def _load_feature_dir(root):
    """
    Map sample ID -> array for every loadable .npy/.npz file directly inside `root`.

    Files are visited in sorted order, because os.listdir order is arbitrary. When
    several files share an ID, the one kept (the last) is therefore deterministic,
    and the collision is reported since only one file per ID survives.
    """
    out = {}
    for fname in sorted(os.listdir(root)):
        path = os.path.join(root, fname)
        if not os.path.isfile(path) or not fname.endswith((".npy", ".npz")):
            continue
        arr = load_np(path)
        if arr is None:
            continue
        sid = get_id(fname)
        if sid in out:
            print(f"[warn] {root}: several files map to id {sid!r}; keeping {fname}")
        out[sid] = arr
    return out

bert_dict = _load_feature_dir(BERT_DIR)
hubert_dict = _load_feature_dir(HUBERT_DIR)

common_ids = sorted(set(bert_dict) & set(hubert_dict) & set(clip_dict))
print(f"Found {len(common_ids)} samples across all three modalities: {common_ids[:20]}")


Found 21 samples across all three modalities: ['ID00', 'ID01', 'ID03', 'ID05', 'ID08', 'ID09', 'ID10', 'ID11', 'ID12', 'ID14', 'ID15', 'ID19', 'ID21', 'ID22hc', 'ID23', 'ID25', 'ID26', 'ID28', 'ID31', 'ID35']


# Convert to tensors and pad for Cross-Attention input

In [8]:
import torch
import torch.nn.functional as F

def pad_to_len(x, max_len=256):
    """
    Coerce `x` (array-like or tensor) to a float32 2-D tensor with exactly `max_len` rows.

    * 1-D input (e.g. one CLIP vector) becomes a single row.
    * Input with more than 2 dims is flattened to (-1, last_dim).
    * Longer sequences are truncated; shorter ones are zero-padded at the end.
    Padding rows share the input's dtype and device.
    """
    seq = x.float() if torch.is_tensor(x) else torch.tensor(x, dtype=torch.float32)
    if seq.dim() == 1:
        seq = seq.unsqueeze(0)
    elif seq.dim() > 2:
        seq = seq.reshape(-1, seq.size(-1))
    n_rows, width = seq.shape
    if n_rows >= max_len:
        return seq[:max_len]
    filler = seq.new_zeros((max_len - n_rows, width))
    return torch.cat((seq, filler), dim=0)

# Every sample in this cell comes from the HC Spontaneous folders, so the class label is
# a property of the cohort, not of the ID string. The old rule `"hc" in sid.lower()`
# labelled every HC subject as PD (1), except "ID22hc".
COHORT_LABEL = 0  # HC = 0, PD = 1 (same convention as the PD cells)

text_embs, audio_embs, vision_embs, labels = [], [], [], []

for sid in common_ids:
    # convert to float32 tensors and L2-normalise each row along the feature dim
    t = F.normalize(torch.as_tensor(bert_dict[sid], dtype=torch.float32), dim=-1)
    a = F.normalize(torch.as_tensor(hubert_dict[sid], dtype=torch.float32), dim=-1)
    v = F.normalize(torch.as_tensor(clip_dict[sid], dtype=torch.float32), dim=-1)

    # zero-pad / truncate to 256 rows; a 1-D CLIP vector becomes 1 real row + 255 zero rows
    t = pad_to_len(t)
    a = pad_to_len(a)
    v = pad_to_len(v)

    text_embs.append(t.unsqueeze(0))
    audio_embs.append(a.unsqueeze(0))
    vision_embs.append(v.unsqueeze(0))
    labels.append(COHORT_LABEL)

text_embs   = torch.cat(text_embs, dim=0)
audio_embs  = torch.cat(audio_embs, dim=0)
vision_embs = torch.cat(vision_embs, dim=0)
labels      = torch.tensor(labels, dtype=torch.int64)
assert text_embs.size(0) == len(common_ids) == labels.size(0)


In [9]:
class FusionTokenCrossAttn(nn.Module):
    """
    Fuse text / audio / vision sequences with one learned query, the "fusion token".

    Each modality is linearly projected to d_model. The three sequences are
    concatenated into one key/value memory. The fusion token then passes through
    `depth` layers of cross-attention -> residual + LayerNorm -> FFN -> residual + LayerNorm.
    Returns (logits (B, num_classes), fused token (B, d_model)).
    """

    @staticmethod
    def _make_layer(d_model, nhead, dropout):
        """One fusion layer with keys 'mha', 'ln1', 'ffn', 'ln2' (built in that order)."""
        return nn.ModuleDict({
            "mha": nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True),
            "ln1": nn.LayerNorm(d_model),
            "ffn": nn.Sequential(
                nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model * 4, d_model)
            ),
            "ln2": nn.LayerNorm(d_model),
        })

    def __init__(self, d_text, d_audio, d_vision, d_model=512, nhead=8, depth=2, num_classes=2, dropout=0.1):
        super().__init__()
        # Creation order is unchanged, so a fixed seed gives identical initial weights.
        self.proj_t = nn.Linear(d_text, d_model)
        self.proj_a = nn.Linear(d_audio, d_model)
        self.proj_v = nn.Linear(d_vision, d_model)
        self.fuse_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.blocks = nn.ModuleList(self._make_layer(d_model, nhead, dropout) for _ in range(depth))
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, num_classes))

    def forward(self, t, a, v, pad_t=None, pad_a=None, pad_v=None):
        """t, a, v: (B, L*, d_*) features; pad_*: optional (B, L*) bool masks, True = PAD."""
        projected = (self.proj_t(t), self.proj_a(a), self.proj_v(v))
        masks = (pad_t, pad_a, pad_v)
        memory = torch.cat(projected, dim=1)

        if all(m is None for m in masks):
            kpm = None
        else:
            # any missing mask is replaced by an all-False one of matching (B, L)
            kpm = torch.cat(
                [m if m is not None else torch.zeros(x.size()[:2], dtype=torch.bool, device=x.device)
                 for x, m in zip(projected, masks)],
                dim=1,
            )

        query = self.fuse_token.expand(projected[0].size(0), 1, -1)  # (B, 1, d)
        for layer in self.blocks:
            attn_out, _ = layer["mha"](query, memory, memory, key_padding_mask=kpm)
            query = layer["ln1"](query + attn_out)
            query = layer["ln2"](query + layer["ffn"](query))

        pooled = query.squeeze(1)
        return self.head(pooled), pooled


In [10]:
from torch import nn

# Assuming you have the FusionTokenCrossAttn class defined (from earlier)
# Fix the RNG before building the (untrained) fusion model. Use the same SEED in every
# cohort's model cell: with equal input dims the weights are then identical, so fused
# vectors saved for different cohorts share one (random) projection space.
SEED = 42
torch.manual_seed(SEED)

model = FusionTokenCrossAttn(
    d_text=text_embs.size(-1),
    d_audio=audio_embs.size(-1),
    d_vision=vision_embs.size(-1),
    d_model=512,
    nhead=8,
    depth=2,
    num_classes=2
)
# eval() disables dropout so the extracted embeddings are deterministic. Autograd is
# still active, so a later training step can backpropagate through `logits`.
model.eval()

# pad_to_len zero-fills, while real rows were L2-normalised, so all-zero rows are padding.
# Masking them stops the fusion token from attending to padded positions (e.g. the 255
# zero rows added after each single CLIP vector).
text_pad   = text_embs.eq(0).all(dim=-1)    # (B, 256), True = PAD
audio_pad  = audio_embs.eq(0).all(dim=-1)
vision_pad = vision_embs.eq(0).all(dim=-1)
# a sample whose keys are all masked would get NaN attention weights
assert not (text_pad.all(dim=1) & audio_pad.all(dim=1) & vision_pad.all(dim=1)).any()

logits, fused = model(text_embs, audio_embs, vision_embs,
                      pad_t=text_pad, pad_a=audio_pad, pad_v=vision_pad)
preds = torch.argmax(logits, dim=-1)

print("Predictions:", preds)
print("Shape of fused embedding:", fused.shape)


Predictions: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
Shape of fused embedding: torch.Size([21, 512])


In [11]:
import torch.nn.functional as F

# One illustrative AdamW step. NOTE: this batch holds a single cohort (one label), so the
# step cannot teach HC-vs-PD discrimination, and `fused` was extracted before it ran.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
optimizer.zero_grad()
# Recompute the forward pass here so the cell can be re-run: re-using `logits` would
# backpropagate through an already-freed graph the second time.
step_masks = {name: x.eq(0).all(dim=-1)  # all-zero rows are padding from pad_to_len
              for name, x in (("pad_t", text_embs), ("pad_a", audio_embs), ("pad_v", vision_embs))}
step_logits, _ = model(text_embs, audio_embs, vision_embs, **step_masks)
loss = F.cross_entropy(step_logits, labels)
loss.backward()
optimizer.step()


In [14]:
import numpy as np
import torch, os

# Inputs from the cells above (HC Spontaneous):
# fused: (B, 512)  labels: (B,)
# common_ids: list[str] of length B, in the SAME order as the batch

save_dir = "/mnt/d/Roshidat_Msc_Project/Audio_parkinson/pd&Hc_multi/Fused_embeddings/"
os.makedirs(save_dir, exist_ok=True)
npz_path = os.path.join(save_dir, "HC_Spontaneous_fused_embeddings.npz")
pt_path  = os.path.join(save_dir, "HC_Spontaneous_fused_embeddings.pt")

# rows must line up before anything is written
assert fused.shape[0] == len(common_ids) == labels.shape[0], "ids / fused / labels misaligned"

ids_arr = np.array(common_ids, dtype=str)  # fixed-length unicode; loads without allow_pickle

np.savez_compressed(
    npz_path,
    ids=ids_arr,
    fused=fused.detach().cpu().numpy(),
    labels=labels.detach().cpu().numpy()
)

# also save as torch .pt; ids stay a numpy array for consistency with the .npz
torch.save(
    {"ids": ids_arr,
     "fused": fused.detach().cpu(),
     "labels": labels.detach().cpu()},
    pt_path
)

print("Saved:", npz_path, "and .pt")


Saved: /mnt/d/Roshidat_Msc_Project/Audio_parkinson/pd&Hc_multi/Fused_embeddings/HC_Spontaneous_fused_embeddings.npz and .pt


# HC ReadText Fusion

In [35]:
import os
import numpy as np
import torch

# --- Folder paths ---
# Root of the feature store. Set the PD_HC_MULTI_ROOT environment variable to run on
# another machine. Sub-folder names match the disk, including "Bert embedings".
DATA_ROOT     = os.environ.get("PD_HC_MULTI_ROOT", "/mnt/d/Roshidat_Msc_Project/Audio_parkinson/pd&Hc_multi")
HC_BERT_DIR   = os.path.join(DATA_ROOT, "Bert embedings", "HC_transcript_berts_feats_tokens_only")
HC_HUBERT_DIR = os.path.join(DATA_ROOT, "Hubert embeddings", "HC_ReadText_hubert_features")
HC_CLIP_FEATS = os.path.join(DATA_ROOT, "Clip embeddings", "HC_ReadText_Spectrogram_CLIP_features.npy")
HC_CLIP_NAMES = os.path.join(DATA_ROOT, "Clip embeddings", "HC_ReadText_Spectrogram_CLIP_filenames.txt")


In [36]:
# CLIP features: one large array (N, D)
hc_clip_features = np.load(HC_CLIP_FEATS)  # shape (num_samples, clip_dim)
with open(HC_CLIP_NAMES, "r") as f:
    hc_clip_filenames = [line.strip() for line in f.readlines()]

# create a mapping {IDxx: feature_row}
def extract_id(path):
    hc_base = os.path.basename(path)
    return hc_base.split("_")[0]  # e.g., "ID00"
    
hc_clip_dict = {extract_id(fname): hc_clip_features[i] for i, fname in enumerate(hc_clip_filenames)}
print(f"Loaded {len(hc_clip_dict)} CLIP embeddings.")


Loaded 21 CLIP embeddings.


In [37]:
def get_id(filename):
    """
    Return the sample ID, i.e. everything in the file's base name before the first "_".

        'ID00_hc_0_0_0s_tokens_for_selfattn.npz' -> 'ID00'

    A base name without "_" is returned unchanged.
    """
    head, _sep, _rest = os.path.basename(filename).partition("_")
    return head


In [38]:
def load_np(path):
    """
    Load a .npy or .npz file and return a single array, or None on failure.

    For a .npz archive the first stored array (archive order) is returned, and the
    archive is then closed; np.load keeps the .npz file handle open until closed.
    Load errors and empty archives are reported with a print and yield None.

    NOTE: allow_pickle=True can execute code from crafted files; only load trusted data.
    """
    try:
        data = np.load(path, allow_pickle=True)
    except Exception as e:
        print(f"Failed to load {path}: {e}")
        return None
    if isinstance(data, np.lib.npyio.NpzFile):
        with data:  # closes the underlying zip file handle
            if not data.files:
                print(f"Failed to load {path}: empty .npz archive")
                return None
            return data[data.files[0]]
    return data

# Build dicts robustly (accept .npy and .npz, skip non-files, handle load errors).
# Reuse get_id defined earlier in the notebook (do not redefine it here).
def _load_feature_dir(root):
    """
    Map sample ID -> array for every loadable .npy/.npz file directly inside `root`.

    Files are visited in sorted order, because os.listdir order is arbitrary. When
    several files share an ID, the one kept (the last) is therefore deterministic,
    and the collision is reported since only one file per ID survives.
    """
    out = {}
    for fname in sorted(os.listdir(root)):
        path = os.path.join(root, fname)
        if not os.path.isfile(path) or not fname.endswith((".npy", ".npz")):
            continue
        arr = load_np(path)
        if arr is None:
            continue
        sid = get_id(fname)
        if sid in out:
            print(f"[warn] {root}: several files map to id {sid!r}; keeping {fname}")
        out[sid] = arr
    return out

hc_bert_dict = _load_feature_dir(HC_BERT_DIR)
hc_hubert_dict = _load_feature_dir(HC_HUBERT_DIR)

hc_common_ids = sorted(set(hc_bert_dict) & set(hc_hubert_dict) & set(hc_clip_dict))
print(f"Found {len(hc_common_ids)} samples across all three modalities: {hc_common_ids[:20]}")


Found 21 samples across all three modalities: ['ID00', 'ID01', 'ID03', 'ID05', 'ID08', 'ID09', 'ID10', 'ID11', 'ID12', 'ID14', 'ID15', 'ID19', 'ID21', 'ID22', 'ID23', 'ID25', 'ID26', 'ID28', 'ID31', 'ID35']


In [39]:
import torch
import torch.nn.functional as F

def pad_to_len(x, max_len=256):
    """
    Coerce `x` (array-like or tensor) to a float32 2-D tensor with exactly `max_len` rows.

    * 1-D input (e.g. one CLIP vector) becomes a single row.
    * Input with more than 2 dims is flattened to (-1, last_dim).
    * Longer sequences are truncated; shorter ones are zero-padded at the end.
    Padding rows share the input's dtype and device.
    """
    seq = x.float() if torch.is_tensor(x) else torch.tensor(x, dtype=torch.float32)
    if seq.dim() == 1:
        seq = seq.unsqueeze(0)
    elif seq.dim() > 2:
        seq = seq.reshape(-1, seq.size(-1))
    n_rows, width = seq.shape
    if n_rows >= max_len:
        return seq[:max_len]
    filler = seq.new_zeros((max_len - n_rows, width))
    return torch.cat((seq, filler), dim=0)

# Every sample in this cell comes from the HC ReadText folders, so the class label is a
# property of the cohort, not of the ID string. The old rule `"hc" in sid.lower()`
# matched none of these IDs and labelled every HC subject as PD (1).
COHORT_LABEL = 0  # HC = 0, PD = 1 (same convention as the PD cells)

hc_text_embs, hc_audio_embs, hc_vision_embs, hc_labels = [], [], [], []

# Use the HC-specific common ids (hc_common_ids), not common_ids from the other dataset.
for sid in hc_common_ids:
    # convert to float32 tensors and L2-normalise each row along the feature dim
    t = F.normalize(torch.as_tensor(hc_bert_dict[sid], dtype=torch.float32), dim=-1)
    a = F.normalize(torch.as_tensor(hc_hubert_dict[sid], dtype=torch.float32), dim=-1)
    v = F.normalize(torch.as_tensor(hc_clip_dict[sid], dtype=torch.float32), dim=-1)

    # zero-pad / truncate to 256 rows; a 1-D CLIP vector becomes 1 real row + 255 zero rows
    t = pad_to_len(t)
    a = pad_to_len(a)
    v = pad_to_len(v)

    hc_text_embs.append(t.unsqueeze(0))
    hc_audio_embs.append(a.unsqueeze(0))
    hc_vision_embs.append(v.unsqueeze(0))
    hc_labels.append(COHORT_LABEL)

hc_text_embs   = torch.cat(hc_text_embs, dim=0)
hc_audio_embs  = torch.cat(hc_audio_embs, dim=0)
hc_vision_embs = torch.cat(hc_vision_embs, dim=0)
hc_labels      = torch.tensor(hc_labels, dtype=torch.int64)
assert hc_text_embs.size(0) == len(hc_common_ids) == hc_labels.size(0)


In [40]:
class FusionTokenCrossAttn(nn.Module):
    """
    Fuse text / audio / vision sequences with one learned query, the "fusion token".

    Each modality is linearly projected to d_model. The three sequences are
    concatenated into one key/value memory. The fusion token then passes through
    `depth` layers of cross-attention -> residual + LayerNorm -> FFN -> residual + LayerNorm.
    Returns (logits (B, num_classes), fused token (B, d_model)).
    """

    @staticmethod
    def _make_layer(d_model, nhead, dropout):
        """One fusion layer with keys 'mha', 'ln1', 'ffn', 'ln2' (built in that order)."""
        return nn.ModuleDict({
            "mha": nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True),
            "ln1": nn.LayerNorm(d_model),
            "ffn": nn.Sequential(
                nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model * 4, d_model)
            ),
            "ln2": nn.LayerNorm(d_model),
        })

    def __init__(self, hc_d_text, hc_d_audio, hc_d_vision, d_model=512, nhead=8, depth=2, num_classes=2, dropout=0.1):
        super().__init__()
        # Creation order is unchanged, so a fixed seed gives identical initial weights.
        self.proj_t = nn.Linear(hc_d_text, d_model)
        self.proj_a = nn.Linear(hc_d_audio, d_model)
        self.proj_v = nn.Linear(hc_d_vision, d_model)
        self.fuse_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.blocks = nn.ModuleList(self._make_layer(d_model, nhead, dropout) for _ in range(depth))
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, num_classes))

    def forward(self, t, a, v, pad_t=None, pad_a=None, pad_v=None):
        """t, a, v: (B, L*, d_*) features; pad_*: optional (B, L*) bool masks, True = PAD."""
        projected = (self.proj_t(t), self.proj_a(a), self.proj_v(v))
        masks = (pad_t, pad_a, pad_v)
        memory = torch.cat(projected, dim=1)

        if all(m is None for m in masks):
            kpm = None
        else:
            # any missing mask is replaced by an all-False one of matching (B, L)
            kpm = torch.cat(
                [m if m is not None else torch.zeros(x.size()[:2], dtype=torch.bool, device=x.device)
                 for x, m in zip(projected, masks)],
                dim=1,
            )

        query = self.fuse_token.expand(projected[0].size(0), 1, -1)  # (B, 1, d)
        for layer in self.blocks:
            attn_out, _ = layer["mha"](query, memory, memory, key_padding_mask=kpm)
            query = layer["ln1"](query + attn_out)
            query = layer["ln2"](query + layer["ffn"](query))

        pooled = query.squeeze(1)
        return self.head(pooled), pooled


In [41]:
from torch import nn

# Assuming you have the FusionTokenCrossAttn class defined (from earlier)
# Fix the RNG before building the (untrained) fusion model. Use the same SEED in every
# cohort's model cell: with equal input dims the weights are then identical, so fused
# vectors saved for different cohorts share one (random) projection space.
SEED = 42
torch.manual_seed(SEED)

model = FusionTokenCrossAttn(
    hc_d_text=hc_text_embs.size(-1),
    hc_d_audio=hc_audio_embs.size(-1),
    hc_d_vision=hc_vision_embs.size(-1),
    d_model=512,
    nhead=8,
    depth=2,
    num_classes=2
)
# eval() disables dropout so the extracted embeddings are deterministic. Autograd is
# still active, so a later training step can backpropagate through `hc_logits`.
model.eval()

# pad_to_len zero-fills, while real rows were L2-normalised, so all-zero rows are padding.
# Masking them stops the fusion token from attending to padded positions (e.g. the 255
# zero rows added after each single CLIP vector).
hc_text_pad   = hc_text_embs.eq(0).all(dim=-1)    # (B, 256), True = PAD
hc_audio_pad  = hc_audio_embs.eq(0).all(dim=-1)
hc_vision_pad = hc_vision_embs.eq(0).all(dim=-1)
# a sample whose keys are all masked would get NaN attention weights
assert not (hc_text_pad.all(dim=1) & hc_audio_pad.all(dim=1) & hc_vision_pad.all(dim=1)).any()

hc_logits, fused = model(hc_text_embs, hc_audio_embs, hc_vision_embs,
                         pad_t=hc_text_pad, pad_a=hc_audio_pad, pad_v=hc_vision_pad)
hc_preds = torch.argmax(hc_logits, dim=-1)

print("Predictions:", hc_preds)
print("Shape of fused embedding:", fused.shape)


Predictions: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Shape of fused embedding: torch.Size([21, 512])


In [42]:
import torch.nn.functional as F

# One illustrative AdamW step. NOTE: this batch holds a single cohort (one label), so the
# step cannot teach HC-vs-PD discrimination, and `fused` was extracted before it ran.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
optimizer.zero_grad()
# Recompute the forward pass here so the cell can be re-run: re-using `hc_logits` would
# backpropagate through an already-freed graph the second time.
hc_step_masks = {name: x.eq(0).all(dim=-1)  # all-zero rows are padding from pad_to_len
                 for name, x in (("pad_t", hc_text_embs), ("pad_a", hc_audio_embs), ("pad_v", hc_vision_embs))}
hc_step_logits, _ = model(hc_text_embs, hc_audio_embs, hc_vision_embs, **hc_step_masks)
hc_loss = F.cross_entropy(hc_step_logits, hc_labels)
hc_loss.backward()
optimizer.step()


In [43]:
import numpy as np
import torch, os

# Inputs from the cells above (HC ReadText):
# fused: (B, 512)  hc_labels: (B,)
# hc_common_ids: list[str] of length B, in the SAME order as the batch

save_dir = "/mnt/d/Roshidat_Msc_Project/Audio_parkinson/pd&Hc_multi/HC_ReadText_Fused_embeddings/"
os.makedirs(save_dir, exist_ok=True)
npz_path = os.path.join(save_dir, "HC_ReadText_fused_embeddings.npz")
pt_path  = os.path.join(save_dir, "HC_ReadText_fused_embeddings.pt")

# Bug fix: the ids must be hc_common_ids, the order the hc_* tensors were built in.
# `common_ids` holds the HC Spontaneous ids (e.g. 'ID22hc' instead of 'ID22').
assert fused.shape[0] == len(hc_common_ids) == hc_labels.shape[0], "ids / fused / labels misaligned"

ids_arr   = np.array(hc_common_ids, dtype=str)  # fixed-length unicode; loads without allow_pickle
labels_np = hc_labels.detach().cpu().numpy()

# "labels" matches the key used by the other cohort files; "hc_labels" is kept for existing readers.
np.savez_compressed(
    npz_path,
    ids=ids_arr,
    fused=fused.detach().cpu().numpy(),
    labels=labels_np,
    hc_labels=labels_np
)

# also save as torch .pt; ids stay a numpy array for consistency with the .npz
torch.save(
    {"ids": ids_arr,
     "fused": fused.detach().cpu(),
     "labels": hc_labels.detach().cpu(),
     "hc_labels": hc_labels.detach().cpu()},
    pt_path
)

print("Saved:", npz_path, "and .pt")


Saved: /mnt/d/Roshidat_Msc_Project/Audio_parkinson/pd&Hc_multi/HC_ReadText_Fused_embeddings/HC_ReadText_fused_embeddings.npz and .pt


# PD Spontaneous fusion


In [2]:
import os, re
import numpy as np
import torch
import torch.nn.functional as F

ID_RE = re.compile(r'(ID\d{2,})', re.IGNORECASE)

def extract_id_any(s: str) -> str:
    """
    Return the first "ID" + two-or-more-digits token in the base name of `s`, upper-cased.

        '.../id05_pd_x.npy' -> 'ID05'      'ID22hc_spont.npz' -> 'ID22'

    Raises ValueError when the base name contains no such token.
    """
    match = ID_RE.search(os.path.basename(s))
    if match is None:
        raise ValueError(f"Could not extract ID from: {s}")
    return match.group(1).upper()


In [3]:
def load_np(path: str):
    """
    Load a .npy / .npz file.

    For a .npz archive, return the first member (archive order) with a numeric dtype,
    or None if there is none; the archive is closed before returning, since np.load
    keeps the .npz file handle open until closed. Returns None (after printing a
    "[skip]" message) when the file cannot be loaded.

    NOTE: allow_pickle=True can execute code from crafted files; only load trusted data.
    """
    try:
        data = np.load(path, allow_pickle=True)
    except Exception as e:
        print(f"[skip] {path}: {e}")
        return None
    if isinstance(data, np.lib.npyio.NpzFile):
        with data:  # closes the underlying zip file handle, also on early return
            # pick the first numeric array
            for k in data.files:
                arr = data[k]
                if isinstance(arr, np.ndarray) and np.issubdtype(arr.dtype, np.number):
                    return arr
        return None
    return data  # .npy


In [5]:
def build_dict_from_dir(root: str):
    """
    Map sample ID -> array for every loadable .npy/.npz file directly inside `root`.

    IDs come from extract_id_any, which raises ValueError on an unparseable name.
    Files are visited in sorted order, because os.listdir order is arbitrary. When
    several files share an ID, the one kept (the last) is therefore deterministic,
    and the collision is reported.
    """
    d = {}
    for f in sorted(os.listdir(root)):
        p = os.path.join(root, f)
        if not os.path.isfile(p) or not f.endswith((".npy", ".npz")):
            continue
        arr = load_np(p)
        if arr is None:
            continue
        sid = extract_id_any(f)
        if sid in d:
            print(f"[warn] {root}: several files map to {sid}; keeping {f}")
        d[sid] = arr
    return d

# --- Directories (PD Spontaneous) ---
# Set the PD_HC_MULTI_ROOT environment variable to run on another machine. Sub-folder
# names match the disk, including "Bert embedings". (The old CLIP path began with a
# stray "//".)
PD_DATA_ROOT  = os.environ.get("PD_HC_MULTI_ROOT", "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi")
PD_BERT_DIR   = os.path.join(PD_DATA_ROOT, "Bert embedings", "PD_Spontaneous_berts_feats_tokens_only")
PD_HUBERT_DIR = os.path.join(PD_DATA_ROOT, "Hubert embeddings", "PD_Spontaneous_hubert_features")
PD_CLIP_FEATS = os.path.join(PD_DATA_ROOT, "Clip embeddings", "PD_Spontaneous_Spectrogram_CLIP_features.npy")
PD_CLIP_NAMES = os.path.join(PD_DATA_ROOT, "Clip embeddings", "PD_Spontaneous_Spectrogram_CLIP_filenames.txt")

# CLIP: one feature row per line of the names file
clip_feats = np.load(PD_CLIP_FEATS)  # [N, D]
with open(PD_CLIP_NAMES, "r") as f:
    clip_names = [line.strip() for line in f]
# Drop trailing blank lines only; removing interior ones would shift the row alignment.
while clip_names and not clip_names[-1]:
    clip_names.pop()
if len(clip_names) != len(clip_feats):
    raise ValueError(f"{PD_CLIP_NAMES}: {len(clip_names)} names but {len(clip_feats)} feature rows")

clip_dict = {}
for i, name in enumerate(clip_names):
    sid = extract_id_any(name)
    if sid in clip_dict:
        print(f"[warn] duplicate CLIP id {sid} at row {i}; overwriting the earlier row")
    clip_dict[sid] = clip_feats[i]
bert_dict = build_dict_from_dir(PD_BERT_DIR)
hub_dict  = build_dict_from_dir(PD_HUBERT_DIR)

common_ids = sorted(set(bert_dict) & set(hub_dict) & set(clip_dict))
print(f"[info] PD Spont common: {len(common_ids)} ids")


[info] PD Spont common: 15 ids


In [6]:
def pad_to_len(x, max_len=256):
    """
    Return a float32 (max_len, D) tensor: `x` truncated or zero-padded along dim 0.

    1-D input (e.g. a single CLIP vector) is treated as one row. Input with more than
    2 dims is flattened to (-1, D). Padding rows are created with x's dtype and device.
    """
    x = torch.as_tensor(x, dtype=torch.float32)
    if x.dim() == 1:  # e.g., CLIP single vector
        x = x.unsqueeze(0)
    elif x.dim() > 2:
        x = x.reshape(-1, x.size(-1))
    L, D = x.shape
    if L >= max_len:
        return x[:max_len]
    # new_zeros matches x's device; plain torch.zeros is CPU-only and would make cat fail on GPU input
    return torch.cat([x, x.new_zeros(max_len - L, D)], dim=0)

# Build per-modality (N, 256, D) tensors in common_ids order.
text_list, audio_list, vision_list, ids_kept, labels = [], [], [], [], []
for sid in common_ids:
    # L2-normalise each row, then zero-pad / truncate to 256 rows -> (256, D) each
    t, a, v = (
        pad_to_len(F.normalize(torch.as_tensor(source[sid], dtype=torch.float32), dim=-1))
        for source in (bert_dict, hub_dict, clip_dict)
    )

    text_list.append(t.unsqueeze(0))
    audio_list.append(a.unsqueeze(0))
    vision_list.append(v.unsqueeze(0))
    ids_kept.append(sid)

    # Label comes from the cohort, not the ID string: PD set -> 1 (HC set -> 0)
    labels.append(1)

text   = torch.cat(text_list, dim=0)
audio  = torch.cat(audio_list, dim=0)
vision = torch.cat(vision_list, dim=0)
labels = torch.tensor(labels, dtype=torch.int64)
print(text.shape, audio.shape, vision.shape, len(ids_kept), labels.shape)


torch.Size([15, 256, 768]) torch.Size([15, 256, 768]) torch.Size([15, 256, 512]) 15 torch.Size([15])


In [9]:
import torch
from torch import nn

class FusionTokenCrossAttn(nn.Module):
    """
    Fuse text / audio / vision sequences with one learned query, the "fusion token".

    Each modality is linearly projected to d_model. The three sequences are
    concatenated into one key/value memory. The fusion token then passes through
    `depth` layers of cross-attention -> residual + LayerNorm -> FFN -> residual + LayerNorm.
    Returns (logits (B, num_classes), fused token (B, d_model)).
    """

    @staticmethod
    def _make_layer(d_model, nhead, dropout):
        """One fusion layer with keys 'mha', 'ln1', 'ffn', 'ln2' (built in that order)."""
        return nn.ModuleDict({
            "mha": nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True),
            "ln1": nn.LayerNorm(d_model),
            "ffn": nn.Sequential(
                nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_model * 4, d_model)
            ),
            "ln2": nn.LayerNorm(d_model),
        })

    def __init__(self, pd_d_text, pd_d_audio, pd_d_vision,
                 d_model=512, nhead=8, depth=2, num_classes=2, dropout=0.1):
        super().__init__()
        # Creation order is unchanged, so a fixed seed gives identical initial weights.
        self.proj_t = nn.Linear(pd_d_text, d_model)
        self.proj_a = nn.Linear(pd_d_audio, d_model)
        self.proj_v = nn.Linear(pd_d_vision, d_model)
        self.fuse_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.blocks = nn.ModuleList(self._make_layer(d_model, nhead, dropout) for _ in range(depth))
        self.head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, num_classes))

    def forward(self, t, a, v, pad_t=None, pad_a=None, pad_v=None):
        """t, a, v: (B, L*, d_*) features; pad_*: optional (B, L*) bool masks, True = PAD."""
        projected = (self.proj_t(t), self.proj_a(a), self.proj_v(v))
        masks = (pad_t, pad_a, pad_v)
        memory = torch.cat(projected, dim=1)

        if all(m is None for m in masks):
            kpm = None
        else:
            # any missing mask is replaced by an all-False one of matching (B, L)
            kpm = torch.cat(
                [m if m is not None else torch.zeros(x.size()[:2], dtype=torch.bool, device=x.device)
                 for x, m in zip(projected, masks)],
                dim=1,
            )

        query = self.fuse_token.expand(projected[0].size(0), 1, -1)  # (B, 1, d)
        for layer in self.blocks:
            attn_out, _ = layer["mha"](query, memory, memory, key_padding_mask=kpm)
            query = layer["ln1"](query + attn_out)
            query = layer["ln2"](query + layer["ffn"](query))

        pooled = query.squeeze(1)
        return self.head(pooled), pooled


In [11]:
# Fix the RNG before building the (untrained) fusion model. Use the same SEED in every
# cohort's model cell: with equal input dims the weights are then identical, so fused
# vectors saved for different cohorts share one (random) projection space.
SEED = 42
torch.manual_seed(SEED)

model = FusionTokenCrossAttn(
    pd_d_text=text.size(-1),
    pd_d_audio=audio.size(-1),
    pd_d_vision=vision.size(-1),
    d_model=512,
    nhead=8,
    depth=2,
    num_classes=2
)
# eval() disables dropout so the extracted embeddings are deterministic.
model.eval()

# pad_to_len zero-fills, while real rows were L2-normalised, so all-zero rows are padding.
# Masking them stops the fusion token from attending to padded positions (e.g. the 255
# zero rows added after each single CLIP vector).
text_pad   = text.eq(0).all(dim=-1)    # (B, 256), True = PAD
audio_pad  = audio.eq(0).all(dim=-1)
vision_pad = vision.eq(0).all(dim=-1)
# a sample whose keys are all masked would get NaN attention weights
assert not (text_pad.all(dim=1) & audio_pad.all(dim=1) & vision_pad.all(dim=1)).any()

logits, fused = model(text, audio, vision, pad_t=text_pad, pad_a=audio_pad, pad_v=vision_pad)
print(f"Pred logits shape: {logits.shape}, fused shape: {fused.shape}")


Pred logits shape: torch.Size([15, 2]), fused shape: torch.Size([15, 512])


In [12]:
import numpy as np
out_path = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Fused_embeddings/PD_Spontaneous_fused_embeddings_FIXED.npz"  # <-- set yours
os.makedirs(os.path.dirname(out_path), exist_ok=True)

ids_kept = common_ids                    # the exact order the tensors were built in
labels   = np.full((len(ids_kept),), 1, dtype=np.int64)   # PD = 1
assert fused.shape[0] == len(ids_kept), "fused rows and ids are misaligned"

np.savez_compressed(
    out_path,
    fused    = fused.detach().cpu().numpy(),   # [N, 512]
    # plain unicode, not dtype=object, so the file loads without allow_pickle like the HC files
    ids      = np.array(ids_kept, dtype=str),
    labels   = labels,
    kept_idx = np.arange(len(ids_kept), dtype=np.int64),  # optional but handy
)
# quick sanity check; loading without pickle proves the file has no object arrays
with np.load(out_path) as z:
    print(z["fused"].shape, len(z["ids"]), z["labels"].shape)
    assert z["fused"].shape[0] == len(z["ids"]) == z["labels"].shape[0]


(15, 512) 15 (15,)


# Fusing PD ReadText

In [13]:
def build_dict_from_dir(root: str):
    """Map id -> array for every ``.npy``/``.npz`` file directly inside `root`.

    Ids come from ``extract_id_any(filename)``; files for which ``load_np``
    returns None are skipped, as are sub-directories. Files are visited in
    sorted order so the result is reproducible; if two files resolve to the
    same id, the later one wins and a warning is printed instead of silently
    overwriting.
    """
    d = {}
    for fname in sorted(os.listdir(root)):
        path = os.path.join(root, fname)
        if not os.path.isfile(path):
            continue
        if not fname.endswith((".npy", ".npz")):
            continue
        arr = load_np(path)
        if arr is None:
            continue
        sid = extract_id_any(fname)
        if sid in d:
            print(f"[warn] duplicate id {sid!r} in {root}: {fname} overrides an earlier file")
        d[sid] = arr
    return d

# --- Directories: PD ReadText ---
PD_BERT_DIR   = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Bert embedings/PD_ReadText_berts_feats_tokens_only"
PD_HUBERT_DIR = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Hubert embeddings/PD_ReadText_hubert_features"
PD_CLIP_FEATS = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Clip embeddings/PD_ReadText_Spectrogram_CLIP_features.npy"
PD_CLIP_NAMES = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Clip embeddings/PD_ReadText_Spectrogram_CLIP_filenames.txt"

# CLIP: one feature matrix plus a filename list whose i-th line names row i
clip_feats = np.load(PD_CLIP_FEATS)  # [N, D]
with open(PD_CLIP_NAMES, "r", encoding="utf-8") as f:
    clip_names = [line.strip() for line in f]
assert len(clip_names) == len(clip_feats), (
    f"{len(clip_names)} CLIP filenames vs {len(clip_feats)} feature rows")

clip_dict = {extract_id_any(name): clip_feats[i] for i, name in enumerate(clip_names)}
bert_dict = build_dict_from_dir(PD_BERT_DIR)
hub_dict  = build_dict_from_dir(PD_HUBERT_DIR)

# Only ids present in all three modalities are fused.
common_ids = sorted(set(bert_dict) & set(hub_dict) & set(clip_dict))
print(f"[info] PD ReadText common: {len(common_ids)} ids")


[info] PD Spont common: 16 ids


In [14]:
def pad_to_len(x, max_len=256):
    """Return `x` as a float32 ``[max_len, D]`` tensor, truncating or zero-padding rows.

    A 1-D input (e.g. a single CLIP vector) is treated as one row; inputs with
    more than two dims are flattened to ``[-1, D]`` keeping the last dim.
    Padding rows are all zeros, created on the same device/dtype as `x`
    (so GPU inputs do not hit a CPU/GPU mismatch in ``torch.cat``).
    """
    x = torch.as_tensor(x, dtype=torch.float32)
    if x.dim() == 1:  # e.g., CLIP single vector
        x = x.unsqueeze(0)
    elif x.dim() > 2:
        x = x.reshape(-1, x.size(-1))
    n_rows, dim = x.shape
    if n_rows >= max_len:
        return x[:max_len]
    return torch.cat([x, x.new_zeros(max_len - n_rows, dim)], dim=0)

# Build padded per-modality batches for every id shared by all three modalities.
# Rows are L2-normalised first; pad_to_len then appends all-zero padding rows.
GROUP_LABEL = 1  # PD set -> 1, HC set -> 0 (set per group; don't key off the id string)
assert common_ids, "no ids shared by the BERT, HuBERT and CLIP features"

text_list, audio_list, vision_list, ids_kept, labels = [], [], [], [], []
for sid in common_ids:
    t = F.normalize(torch.as_tensor(bert_dict[sid], dtype=torch.float32), dim=-1)
    a = F.normalize(torch.as_tensor(hub_dict[sid],  dtype=torch.float32), dim=-1)
    v = F.normalize(torch.as_tensor(clip_dict[sid], dtype=torch.float32), dim=-1)

    text_list.append(pad_to_len(t).unsqueeze(0))    # [1, 256, d_t]
    audio_list.append(pad_to_len(a).unsqueeze(0))   # [1, 256, d_a]
    vision_list.append(pad_to_len(v).unsqueeze(0))  # [1, 256, d_v]
    ids_kept.append(sid)
    labels.append(GROUP_LABEL)

text = torch.cat(text_list, dim=0)
audio = torch.cat(audio_list, dim=0)
vision= torch.cat(vision_list, dim=0)
labels= torch.tensor(labels, dtype=torch.int64)
print(text.shape, audio.shape, vision.shape, len(ids_kept), labels.shape)


torch.Size([16, 256, 768]) torch.Size([16, 256, 768]) torch.Size([16, 256, 512]) 16 torch.Size([16])


In [15]:
import torch
from torch import nn

class FusionTokenCrossAttn(nn.Module):
    def __init__(self, pd_d_text, pd_d_audio, pd_d_vision,
                 d_model=512, nhead=8, depth=2, num_classes=2, dropout=0.1):
        super().__init__()
        # project each modality to same dimension
        self.proj_t = nn.Linear(pd_d_text, d_model)
        self.proj_a = nn.Linear(pd_d_audio, d_model)
        self.proj_v = nn.Linear(pd_d_vision, d_model)
        self.fuse_token = nn.Parameter(torch.randn(1, 1, d_model))

        self.blocks = nn.ModuleList([
            nn.ModuleDict({
                "mha": nn.MultiheadAttention(d_model, nhead,
                                             dropout=dropout, batch_first=True),
                "ln1": nn.LayerNorm(d_model),
                "ffn": nn.Sequential(
                    nn.Linear(d_model, d_model*4),
                    nn.GELU(),
                    nn.Dropout(dropout),
                    nn.Linear(d_model*4, d_model)
                ),
                "ln2": nn.LayerNorm(d_model),
            })
            for _ in range(depth)
        ])
        self.head = nn.Sequential(nn.LayerNorm(d_model),
                                  nn.Linear(d_model, num_classes))

    def forward(self, t, a, v, pad_t=None, pad_a=None, pad_v=None):
        # project
        t = self.proj_t(t)
        a = self.proj_a(a)
        v = self.proj_v(v)
        kv = torch.cat([t, a, v], dim=1)

        # build mask if any padding masks were passed
        if any(m is not None for m in (pad_t, pad_a, pad_v)):
            pad_t = pad_t if pad_t is not None else torch.zeros(t.size()[:2],
                                                                dtype=torch.bool,
                                                                device=t.device)
            pad_a = pad_a if pad_a is not None else torch.zeros(a.size()[:2],
                                                                dtype=torch.bool,
                                                                device=a.device)
            pad_v = pad_v if pad_v is not None else torch.zeros(v.size()[:2],
                                                                dtype=torch.bool,
                                                                device=v.device)
            kpm = torch.cat([pad_t, pad_a, pad_v], dim=1)
        else:
            kpm = None

        B = t.size(0)
        q = self.fuse_token.expand(B, 1, -1)
        for blk in self.blocks:
            attn_out, _ = blk["mha"](q, kv, kv, key_padding_mask=kpm)
            q = blk["ln1"](q + attn_out)
            q = blk["ln2"](q + blk["ffn"](q))
        logits = self.head(q.squeeze(1))
        return logits, q.squeeze(1)


In [16]:
# Fuse the three modalities into one vector per subject.
# NOTE(review): the fusion model is not trained in this notebook, so `fused` is a
# random (now reproducible) projection of the inputs. Seeding right before
# construction gives every group's cell identical weights (given the same input
# dims); without it, each group would be embedded by a different random network
# and the saved PD/HC embeddings would not be comparable.
FUSION_SEED = 0
torch.manual_seed(FUSION_SEED)
model = FusionTokenCrossAttn(
    pd_d_text=text.size(-1),
    pd_d_audio=audio.size(-1),
    pd_d_vision=vision.size(-1),
    d_model=512,
    nhead=8,
    depth=2,
    num_classes=2
)
model.eval()  # disable dropout so the embeddings are deterministic

# pad_to_len pads with all-zero rows; mask them out of attention (True = ignore).
pad_t = (text == 0).all(dim=-1)
pad_a = (audio == 0).all(dim=-1)
pad_v = (vision == 0).all(dim=-1)
assert not torch.cat([pad_t, pad_a, pad_v], dim=1).all(dim=1).any(), \
    "some sample has no non-padding rows in any modality"

with torch.no_grad():
    logits, fused = model(text, audio, vision, pad_t=pad_t, pad_a=pad_a, pad_v=pad_v)
print(f"Pred logits shape: {logits.shape}, fused shape: {fused.shape}")


Pred logits shape: torch.Size([16, 2]), fused shape: torch.Size([16, 512])


In [19]:
import numpy as np
import os

# Save PD ReadText fused embeddings together with their ids and group label.
out_path = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Fixed_PD_Fused_embeddings/PD_ReadText_fused_embeddings_FIXED.npz"
GROUP_LABEL = 1  # PD = 1, HC = 0

ids_kept = common_ids                    # the exact order you looped
labels   = np.full((len(ids_kept),), GROUP_LABEL, dtype=np.int64)

# ensure directory exists
os.makedirs(os.path.dirname(out_path), exist_ok=True)

# prepare arrays for saving (detach and move to cpu first)
fused_np = fused.detach().cpu().numpy()
assert fused_np.shape[0] == len(ids_kept), "fused rows and ids are out of sync"
ids_arr = np.array(ids_kept, dtype='<U')  # unicode array: loadable without pickle
kept_idx = np.arange(len(ids_kept), dtype=np.int64)

np.savez_compressed(
    out_path,
    fused    = fused_np,   # [B, D]
    ids      = ids_arr,
    labels   = labels,
    kept_idx = kept_idx,
)

# quick sanity check (the context manager closes the underlying zip file)
with np.load(out_path) as z:
    print(z["fused"].shape, len(z["ids"]), z["labels"].shape)
    assert z["fused"].shape[0] == len(z["ids"]) == z["labels"].shape[0]
    assert (z["labels"] == GROUP_LABEL).all()


(16, 512) 16 (16,)


# HC Spontaneous

In [22]:
def build_dict_from_dir(root: str):
    """Map id -> array for every ``.npy``/``.npz`` file directly inside `root`.

    Ids come from ``extract_id_any(filename)``; files for which ``load_np``
    returns None are skipped, as are sub-directories. Files are visited in
    sorted order so the result is reproducible; if two files resolve to the
    same id, the later one wins and a warning is printed instead of silently
    overwriting.
    """
    d = {}
    for fname in sorted(os.listdir(root)):
        path = os.path.join(root, fname)
        if not os.path.isfile(path):
            continue
        if not fname.endswith((".npy", ".npz")):
            continue
        arr = load_np(path)
        if arr is None:
            continue
        sid = extract_id_any(fname)
        if sid in d:
            print(f"[warn] duplicate id {sid!r} in {root}: {fname} overrides an earlier file")
        d[sid] = arr
    return d

# --- Directories: HC Spontaneous ---
HC_BERT_DIR   = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Bert embedings/HC_Spontaneous_berts_feats_tokens_only"
HC_HUBERT_DIR = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Hubert embeddings/HC_Spontaneous_hubert_features"
HC_CLIP_FEATS = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Clip embeddings/HC_Spontaneous_Spectrogram_CLIP_features.npy"
HC_CLIP_NAMES = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Clip embeddings/HC_Spontaneous_Spectrogram_CLIP_filenames.txt"

# CLIP: one feature matrix plus a filename list whose i-th line names row i
clip_feats = np.load(HC_CLIP_FEATS)  # [N, D]
with open(HC_CLIP_NAMES, "r", encoding="utf-8") as f:
    clip_names = [line.strip() for line in f]
assert len(clip_names) == len(clip_feats), (
    f"{len(clip_names)} CLIP filenames vs {len(clip_feats)} feature rows")

clip_dict = {extract_id_any(name): clip_feats[i] for i, name in enumerate(clip_names)}
bert_dict = build_dict_from_dir(HC_BERT_DIR)
hub_dict  = build_dict_from_dir(HC_HUBERT_DIR)

# Only ids present in all three modalities are fused.
common_ids = sorted(set(bert_dict) & set(hub_dict) & set(clip_dict))
print(f"[info] HC Spont common: {len(common_ids)} ids")


[info] HC Spont common: 21 ids


In [23]:
def pad_to_len(x, max_len=256):
    """Return `x` as a float32 ``[max_len, D]`` tensor, truncating or zero-padding rows.

    A 1-D input (e.g. a single CLIP vector) is treated as one row; inputs with
    more than two dims are flattened to ``[-1, D]`` keeping the last dim.
    Padding rows are all zeros, created on the same device/dtype as `x`
    (so GPU inputs do not hit a CPU/GPU mismatch in ``torch.cat``).
    """
    x = torch.as_tensor(x, dtype=torch.float32)
    if x.dim() == 1:  # e.g., CLIP single vector
        x = x.unsqueeze(0)
    elif x.dim() > 2:
        x = x.reshape(-1, x.size(-1))
    n_rows, dim = x.shape
    if n_rows >= max_len:
        return x[:max_len]
    return torch.cat([x, x.new_zeros(max_len - n_rows, dim)], dim=0)

# Build padded per-modality batches for every id shared by all three modalities.
# Rows are L2-normalised first; pad_to_len then appends all-zero padding rows.
GROUP_LABEL = 0  # HC set -> 0 (PD set -> 1); set per group, don't key off the id string
assert common_ids, "no ids shared by the BERT, HuBERT and CLIP features"

text_list, audio_list, vision_list, ids_kept, labels = [], [], [], [], []
for sid in common_ids:
    t = F.normalize(torch.as_tensor(bert_dict[sid], dtype=torch.float32), dim=-1)
    a = F.normalize(torch.as_tensor(hub_dict[sid],  dtype=torch.float32), dim=-1)
    v = F.normalize(torch.as_tensor(clip_dict[sid], dtype=torch.float32), dim=-1)

    text_list.append(pad_to_len(t).unsqueeze(0))    # [1, 256, d_t]
    audio_list.append(pad_to_len(a).unsqueeze(0))   # [1, 256, d_a]
    vision_list.append(pad_to_len(v).unsqueeze(0))  # [1, 256, d_v]
    ids_kept.append(sid)
    labels.append(GROUP_LABEL)

text = torch.cat(text_list, dim=0)
audio = torch.cat(audio_list, dim=0)
vision= torch.cat(vision_list, dim=0)
labels= torch.tensor(labels, dtype=torch.int64)
print(text.shape, audio.shape, vision.shape, len(ids_kept), labels.shape)


torch.Size([21, 256, 768]) torch.Size([21, 256, 768]) torch.Size([21, 256, 512]) 21 torch.Size([21])


In [24]:
import torch
from torch import nn

class FusionTokenCrossAttn(nn.Module):
    def __init__(self, pd_d_text, pd_d_audio, pd_d_vision,
                 d_model=512, nhead=8, depth=2, num_classes=2, dropout=0.1):
        super().__init__()
        # project each modality to same dimension
        self.proj_t = nn.Linear(pd_d_text, d_model)
        self.proj_a = nn.Linear(pd_d_audio, d_model)
        self.proj_v = nn.Linear(pd_d_vision, d_model)
        self.fuse_token = nn.Parameter(torch.randn(1, 1, d_model))

        self.blocks = nn.ModuleList([
            nn.ModuleDict({
                "mha": nn.MultiheadAttention(d_model, nhead,
                                             dropout=dropout, batch_first=True),
                "ln1": nn.LayerNorm(d_model),
                "ffn": nn.Sequential(
                    nn.Linear(d_model, d_model*4),
                    nn.GELU(),
                    nn.Dropout(dropout),
                    nn.Linear(d_model*4, d_model)
                ),
                "ln2": nn.LayerNorm(d_model),
            })
            for _ in range(depth)
        ])
        self.head = nn.Sequential(nn.LayerNorm(d_model),
                                  nn.Linear(d_model, num_classes))

    def forward(self, t, a, v, pad_t=None, pad_a=None, pad_v=None):
        # project
        t = self.proj_t(t)
        a = self.proj_a(a)
        v = self.proj_v(v)
        kv = torch.cat([t, a, v], dim=1)

        # build mask if any padding masks were passed
        if any(m is not None for m in (pad_t, pad_a, pad_v)):
            pad_t = pad_t if pad_t is not None else torch.zeros(t.size()[:2],
                                                                dtype=torch.bool,
                                                                device=t.device)
            pad_a = pad_a if pad_a is not None else torch.zeros(a.size()[:2],
                                                                dtype=torch.bool,
                                                                device=a.device)
            pad_v = pad_v if pad_v is not None else torch.zeros(v.size()[:2],
                                                                dtype=torch.bool,
                                                                device=v.device)
            kpm = torch.cat([pad_t, pad_a, pad_v], dim=1)
        else:
            kpm = None

        B = t.size(0)
        q = self.fuse_token.expand(B, 1, -1)
        for blk in self.blocks:
            attn_out, _ = blk["mha"](q, kv, kv, key_padding_mask=kpm)
            q = blk["ln1"](q + attn_out)
            q = blk["ln2"](q + blk["ffn"](q))
        logits = self.head(q.squeeze(1))
        return logits, q.squeeze(1)


In [25]:
# Fuse the three modalities into one vector per subject.
# NOTE(review): the fusion model is not trained in this notebook, so `fused` is a
# random (now reproducible) projection of the inputs. Seeding right before
# construction gives every group's cell identical weights (given the same input
# dims); without it, each group would be embedded by a different random network
# and the saved PD/HC embeddings would not be comparable.
FUSION_SEED = 0
torch.manual_seed(FUSION_SEED)
model = FusionTokenCrossAttn(
    pd_d_text=text.size(-1),
    pd_d_audio=audio.size(-1),
    pd_d_vision=vision.size(-1),
    d_model=512,
    nhead=8,
    depth=2,
    num_classes=2
)
model.eval()  # disable dropout so the embeddings are deterministic

# pad_to_len pads with all-zero rows; mask them out of attention (True = ignore).
pad_t = (text == 0).all(dim=-1)
pad_a = (audio == 0).all(dim=-1)
pad_v = (vision == 0).all(dim=-1)
assert not torch.cat([pad_t, pad_a, pad_v], dim=1).all(dim=1).any(), \
    "some sample has no non-padding rows in any modality"

with torch.no_grad():
    logits, fused = model(text, audio, vision, pad_t=pad_t, pad_a=pad_a, pad_v=pad_v)
print(f"Pred logits shape: {logits.shape}, fused shape: {fused.shape}")


Pred logits shape: torch.Size([21, 2]), fused shape: torch.Size([21, 512])


In [26]:
import numpy as np
import os

# Save HC Spontaneous fused embeddings together with their ids and group label.
out_path = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Fixed_HC_Fused_embeddings/HC_Spontaneous_fused_embeddings_FIXED.npz"
GROUP_LABEL = 0  # HC = 0 (PD = 1)

ids_kept = common_ids                    # the exact order you looped
labels   = np.full((len(ids_kept),), GROUP_LABEL, dtype=np.int64)

# ensure directory exists
os.makedirs(os.path.dirname(out_path), exist_ok=True)

# prepare arrays for saving (detach and move to cpu first)
fused_np = fused.detach().cpu().numpy()
assert fused_np.shape[0] == len(ids_kept), "fused rows and ids are out of sync"
ids_arr = np.array(ids_kept, dtype='<U')  # unicode array: loadable without pickle
kept_idx = np.arange(len(ids_kept), dtype=np.int64)

np.savez_compressed(
    out_path,
    fused    = fused_np,   # [B, D]
    ids      = ids_arr,
    labels   = labels,
    kept_idx = kept_idx,
)

# quick sanity check (the context manager closes the underlying zip file)
with np.load(out_path) as z:
    print(z["fused"].shape, len(z["ids"]), z["labels"].shape)
    assert z["fused"].shape[0] == len(z["ids"]) == z["labels"].shape[0]
    assert (z["labels"] == GROUP_LABEL).all()


(21, 512) 21 (21,)


# HC ReadText

In [28]:
def build_dict_from_dir(root: str):
    """Map id -> array for every ``.npy``/``.npz`` file directly inside `root`.

    Ids come from ``extract_id_any(filename)``; files for which ``load_np``
    returns None are skipped, as are sub-directories. Files are visited in
    sorted order so the result is reproducible; if two files resolve to the
    same id, the later one wins and a warning is printed instead of silently
    overwriting.
    """
    d = {}
    for fname in sorted(os.listdir(root)):
        path = os.path.join(root, fname)
        if not os.path.isfile(path):
            continue
        if not fname.endswith((".npy", ".npz")):
            continue
        arr = load_np(path)
        if arr is None:
            continue
        sid = extract_id_any(fname)
        if sid in d:
            print(f"[warn] duplicate id {sid!r} in {root}: {fname} overrides an earlier file")
        d[sid] = arr
    return d

# --- Directories: HC ReadText ---
# NOTE(review): the BERT dir is named "HC_transcript_*" while the other groups use
# "*_ReadText_*"; confirm this is the ReadText transcript set.
HC_BERT_DIR   = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Bert embedings/HC_transcript_berts_feats_tokens_only"
HC_HUBERT_DIR = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Hubert embeddings/HC_ReadText_hubert_features"
HC_CLIP_FEATS = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Clip embeddings/HC_ReadText_Spectrogram_CLIP_features.npy"
HC_CLIP_NAMES = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Clip embeddings/HC_ReadText_Spectrogram_CLIP_filenames.txt"

# CLIP: one feature matrix plus a filename list whose i-th line names row i
clip_feats = np.load(HC_CLIP_FEATS)  # [N, D]
with open(HC_CLIP_NAMES, "r", encoding="utf-8") as f:
    clip_names = [line.strip() for line in f]
assert len(clip_names) == len(clip_feats), (
    f"{len(clip_names)} CLIP filenames vs {len(clip_feats)} feature rows")

clip_dict = {extract_id_any(name): clip_feats[i] for i, name in enumerate(clip_names)}
bert_dict = build_dict_from_dir(HC_BERT_DIR)
hub_dict  = build_dict_from_dir(HC_HUBERT_DIR)

# Only ids present in all three modalities are fused.
common_ids = sorted(set(bert_dict) & set(hub_dict) & set(clip_dict))
print(f"[info] HC ReadText common: {len(common_ids)} ids")


[info] HC Spont common: 21 ids


In [29]:
def pad_to_len(x, max_len=256):
    """Return `x` as a float32 ``[max_len, D]`` tensor, truncating or zero-padding rows.

    A 1-D input (e.g. a single CLIP vector) is treated as one row; inputs with
    more than two dims are flattened to ``[-1, D]`` keeping the last dim.
    Padding rows are all zeros, created on the same device/dtype as `x`
    (so GPU inputs do not hit a CPU/GPU mismatch in ``torch.cat``).
    """
    x = torch.as_tensor(x, dtype=torch.float32)
    if x.dim() == 1:  # e.g., CLIP single vector
        x = x.unsqueeze(0)
    elif x.dim() > 2:
        x = x.reshape(-1, x.size(-1))
    n_rows, dim = x.shape
    if n_rows >= max_len:
        return x[:max_len]
    return torch.cat([x, x.new_zeros(max_len - n_rows, dim)], dim=0)

# Build padded per-modality batches for every id shared by all three modalities.
# Rows are L2-normalised first; pad_to_len then appends all-zero padding rows.
GROUP_LABEL = 0  # HC set -> 0 (PD set -> 1); set per group, don't key off the id string
assert common_ids, "no ids shared by the BERT, HuBERT and CLIP features"

text_list, audio_list, vision_list, ids_kept, labels = [], [], [], [], []
for sid in common_ids:
    t = F.normalize(torch.as_tensor(bert_dict[sid], dtype=torch.float32), dim=-1)
    a = F.normalize(torch.as_tensor(hub_dict[sid],  dtype=torch.float32), dim=-1)
    v = F.normalize(torch.as_tensor(clip_dict[sid], dtype=torch.float32), dim=-1)

    text_list.append(pad_to_len(t).unsqueeze(0))    # [1, 256, d_t]
    audio_list.append(pad_to_len(a).unsqueeze(0))   # [1, 256, d_a]
    vision_list.append(pad_to_len(v).unsqueeze(0))  # [1, 256, d_v]
    ids_kept.append(sid)
    labels.append(GROUP_LABEL)

text = torch.cat(text_list, dim=0)
audio = torch.cat(audio_list, dim=0)
vision= torch.cat(vision_list, dim=0)
labels= torch.tensor(labels, dtype=torch.int64)
print(text.shape, audio.shape, vision.shape, len(ids_kept), labels.shape)


torch.Size([21, 256, 768]) torch.Size([21, 256, 768]) torch.Size([21, 256, 512]) 21 torch.Size([21])


In [30]:
import torch
from torch import nn

class FusionTokenCrossAttn(nn.Module):
    def __init__(self, pd_d_text, pd_d_audio, pd_d_vision,
                 d_model=512, nhead=8, depth=2, num_classes=2, dropout=0.1):
        super().__init__()
        # project each modality to same dimension
        self.proj_t = nn.Linear(pd_d_text, d_model)
        self.proj_a = nn.Linear(pd_d_audio, d_model)
        self.proj_v = nn.Linear(pd_d_vision, d_model)
        self.fuse_token = nn.Parameter(torch.randn(1, 1, d_model))

        self.blocks = nn.ModuleList([
            nn.ModuleDict({
                "mha": nn.MultiheadAttention(d_model, nhead,
                                             dropout=dropout, batch_first=True),
                "ln1": nn.LayerNorm(d_model),
                "ffn": nn.Sequential(
                    nn.Linear(d_model, d_model*4),
                    nn.GELU(),
                    nn.Dropout(dropout),
                    nn.Linear(d_model*4, d_model)
                ),
                "ln2": nn.LayerNorm(d_model),
            })
            for _ in range(depth)
        ])
        self.head = nn.Sequential(nn.LayerNorm(d_model),
                                  nn.Linear(d_model, num_classes))

    def forward(self, t, a, v, pad_t=None, pad_a=None, pad_v=None):
        # project
        t = self.proj_t(t)
        a = self.proj_a(a)
        v = self.proj_v(v)
        kv = torch.cat([t, a, v], dim=1)

        # build mask if any padding masks were passed
        if any(m is not None for m in (pad_t, pad_a, pad_v)):
            pad_t = pad_t if pad_t is not None else torch.zeros(t.size()[:2],
                                                                dtype=torch.bool,
                                                                device=t.device)
            pad_a = pad_a if pad_a is not None else torch.zeros(a.size()[:2],
                                                                dtype=torch.bool,
                                                                device=a.device)
            pad_v = pad_v if pad_v is not None else torch.zeros(v.size()[:2],
                                                                dtype=torch.bool,
                                                                device=v.device)
            kpm = torch.cat([pad_t, pad_a, pad_v], dim=1)
        else:
            kpm = None

        B = t.size(0)
        q = self.fuse_token.expand(B, 1, -1)
        for blk in self.blocks:
            attn_out, _ = blk["mha"](q, kv, kv, key_padding_mask=kpm)
            q = blk["ln1"](q + attn_out)
            q = blk["ln2"](q + blk["ffn"](q))
        logits = self.head(q.squeeze(1))
        return logits, q.squeeze(1)


In [31]:
# Fuse the three modalities into one vector per subject.
# NOTE(review): the fusion model is not trained in this notebook, so `fused` is a
# random (now reproducible) projection of the inputs. Seeding right before
# construction gives every group's cell identical weights (given the same input
# dims); without it, each group would be embedded by a different random network
# and the saved PD/HC embeddings would not be comparable.
FUSION_SEED = 0
torch.manual_seed(FUSION_SEED)
model = FusionTokenCrossAttn(
    pd_d_text=text.size(-1),
    pd_d_audio=audio.size(-1),
    pd_d_vision=vision.size(-1),
    d_model=512,
    nhead=8,
    depth=2,
    num_classes=2
)
model.eval()  # disable dropout so the embeddings are deterministic

# pad_to_len pads with all-zero rows; mask them out of attention (True = ignore).
pad_t = (text == 0).all(dim=-1)
pad_a = (audio == 0).all(dim=-1)
pad_v = (vision == 0).all(dim=-1)
assert not torch.cat([pad_t, pad_a, pad_v], dim=1).all(dim=1).any(), \
    "some sample has no non-padding rows in any modality"

with torch.no_grad():
    logits, fused = model(text, audio, vision, pad_t=pad_t, pad_a=pad_a, pad_v=pad_v)
print(f"Pred logits shape: {logits.shape}, fused shape: {fused.shape}")


Pred logits shape: torch.Size([21, 2]), fused shape: torch.Size([21, 512])


In [32]:
import numpy as np
import os

# Save HC ReadText fused embeddings together with their ids and group label.
out_path = "/home/jovyan/Desktop/PD_LLM/codes/pd&Hc_multi/Fixed_HC_Fused_embeddings/HC_ReadText_fused_embeddings_FIXED.npz"
GROUP_LABEL = 0  # HC = 0 (PD = 1)

ids_kept = common_ids                    # the exact order you looped
labels   = np.full((len(ids_kept),), GROUP_LABEL, dtype=np.int64)

# ensure directory exists
os.makedirs(os.path.dirname(out_path), exist_ok=True)

# prepare arrays for saving (detach and move to cpu first)
fused_np = fused.detach().cpu().numpy()
assert fused_np.shape[0] == len(ids_kept), "fused rows and ids are out of sync"
ids_arr = np.array(ids_kept, dtype='<U')  # unicode array: loadable without pickle
kept_idx = np.arange(len(ids_kept), dtype=np.int64)

np.savez_compressed(
    out_path,
    fused    = fused_np,   # [B, D]
    ids      = ids_arr,
    labels   = labels,
    kept_idx = kept_idx,
)

# quick sanity check (the context manager closes the underlying zip file)
with np.load(out_path) as z:
    print(z["fused"].shape, len(z["ids"]), z["labels"].shape)
    assert z["fused"].shape[0] == len(z["ids"]) == z["labels"].shape[0]
    assert (z["labels"] == GROUP_LABEL).all()


(21, 512) 21 (21,)
