In [1]:
# Auto-reload edited project modules (e.g. ls.*) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
from ls.config.loader import load_config
import IPython.display as ipd
import torch

In [3]:
# --- 1. load config ---
CONFIG_PATH = "../configs/config.yaml"
cfg = load_config(CONFIG_PATH)

# Echo the two config sections that the data pipeline consumes.
for label, section in (("Dataset", "dataset"), ("Audio", "audio")):
    print(f"{label} config:", getattr(cfg, section))

Dataset config: {'name': 'icbhi', 'data_folder': '/home/AIoT04/Datasets/icbhi_dataset', 'cycle_metadata_path': '/home/AIoT04/Datasets/icbhi_dataset/icbhi_metadata.csv', 'class_split': 'lungsound', 'split_strategy': 'official', 'test_fold': 0, 'multi_label': True, 'n_cls': 4, 'weighted_sampler': True, 'batch_size': 8, 'num_workers': 0, 'h': 128, 'w': 1024}
Audio config: {'sample_rate': 16000, 'desired_length': 10.0, 'remove_dc': True, 'normalize': False, 'pad_type': 'repeat', 'use_fade': True, 'fade_samples_ratio': 64, 'n_mels': 128, 'frame_length': 40, 'frame_shift': 10, 'low_freq': 100, 'high_freq': 5000, 'window_type': 'hanning', 'use_energy': False, 'dither': 0.0, 'mel_norm': 'mit', 'resz': 1.0, 'raw_augment': 1, 'wave_aug': [{'type': 'Crop', 'sampling_rate': 16000, 'zone': [0.0, 1.0], 'coverage': 1.0, 'p': 0.0}, {'type': 'Noise', 'color': 'white', 'p': 0.1}, {'type': 'Speed', 'factor': [0.9, 1.1], 'p': 0.1}, {'type': 'Loudness', 'factor': [0.5, 2.0], 'p': 0.1}, {'type': 'VTLP', 'sa

In [4]:
from ls.data.dataloaders import build_dataloaders

# Regular training setup: a training loader and a test loader,
# both built from the dataset and audio sections of the config.
train_loader, test_loader = build_dataloaders(
    cfg.dataset,
    cfg.audio,
)

[Transforms] Input spectrogram resize factor: 1.0, target size: (128, 1024)
[Transforms] Input spectrogram resize factor: 1.0, target size: (128, 1024)
[ICBHI] Loaded cycle metadata TSV: 6898 rows
[ICBHI] #Sites=7, #Devices=4
[ICBHI] Sites Found: {'Al': 0, 'Ar': 1, 'Ll': 2, 'Lr': 3, 'Pl': 4, 'Pr': 5, 'Tc': 6}
[ICBHI] Devices Found: {'AKGC417L': 0, 'Litt3200': 1, 'LittC2SE': 2, 'Meditron': 3}


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


[ICBHI] Extracted 4142 cycles from 539 recordings
[ICBHI] Metadata join missing: 0 (strict join; should be 0)
[ICBHI] Input spectrogram shape: (997, 128, 1)
[ICBHI] 4142 cycles
  Class 0: 2063 (49.8%)
  Class 1: 1215 (29.3%)
  Class 2: 501 (12.1%)
  Class 3: 363 (8.8%)
[ICBHI] Loaded cycle metadata TSV: 6898 rows
[ICBHI] #Sites=7, #Devices=4
[ICBHI] Sites Found: {'Al': 0, 'Ar': 1, 'Ll': 2, 'Lr': 3, 'Pl': 4, 'Pr': 5, 'Tc': 6}
[ICBHI] Devices Found: {'AKGC417L': 0, 'Litt3200': 1, 'LittC2SE': 2, 'Meditron': 3}


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


[ICBHI] Extracted 2756 cycles from 381 recordings
[ICBHI] Metadata join missing: 0 (strict join; should be 0)
[ICBHI] Input spectrogram shape: (997, 128, 1)
[ICBHI] 2756 cycles
  Class 0: 1579 (57.3%)
  Class 1: 649 (23.5%)
  Class 2: 385 (14.0%)
  Class 3: 143 (5.2%)


In [5]:
def _pick_device() -> torch.device:
    """Return the preferred compute device: CUDA first, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


DEVICE = _pick_device()
DEVICE

device(type='cuda')

In [6]:
# Pull one training batch and move every field the models consume onto DEVICE.
batch = next(iter(train_loader))

# Shapes: x (B, 1, F, T); device_id, site_id (B,); m_rest (B, rest_dim); y (B, 2) for multilabel
x, device_id, site_id, m_rest, y = (
    batch[key].to(DEVICE)
    for key in ("input_values", "device_id", "site_id", "m_rest", "label")
)

for name, tensor in (
    ("x", x),
    ("device_id", device_id),
    ("site_id", site_id),
    ("m_rest", m_rest),
    ("y", y),
):
    print(f"{name}:", tensor.shape, tensor.dtype)

x: torch.Size([8, 1, 128, 1024]) torch.float32
device_id: torch.Size([8]) torch.int64
site_id: torch.Size([8]) torch.int64
m_rest: torch.Size([8, 3]) torch.float32
y: torch.Size([8, 2]) torch.float32


### Projected metadata fusion (`ASTMetaProj`)

In [7]:
from ls.models.ast_fus import ASTMetaProj

In [8]:
import torch.nn as nn

# Shared AST backbone settings; the FiLM and FiLM++ cells below reuse this dict.
# Spectrogram size and label count come from the batch so the backbone and the
# heads always match the data pipeline (x is (B, 1, F, T) and y is (B, num_labels);
# (8, 1, 128, 1024) and (8, 2) in the run recorded here).
n_mels, n_frames = x.shape[-2], x.shape[-1]
num_labels = y.shape[-1]

ast_kwargs = dict(
    label_dim=num_labels,  # AST's own classifier size; presumably ignored when the wrapper uses AST as a backbone only — confirm in ls.models
    fstride=10,
    tstride=10,
    input_fdim=n_mels,
    input_tdim=n_frames,
    imagenet_pretrain=True,
    audioset_pretrain=True,
    # NOTE(review): absolute, machine-specific path; consider moving it into the config.
    audioset_ckpt_path='/home/AIoT04/Dev/pretrained_models/audioset_10_10_0.4593.pth',
    model_size='base384',
    verbose=True,
)

# Metadata vocabulary sizes; they match the "#Sites=7, #Devices=4" reported by the ICBHI loader.
num_devices = 4
num_sites = 7
# Width of the continuous metadata vector m_rest (3 in the run recorded here).
# Which features it holds is decided by the dataset; confirm in ls.data.
rest_dim = train_loader.dataset[0]["m_rest"].numel()

# Fail early with a readable message. Otherwise out-of-range ids would surface
# inside the model (on CUDA, typically as an opaque device-side assert).
assert int(device_id.max()) < num_devices, "device_id out of range for num_devices"
assert int(site_id.max()) < num_sites, "site_id out of range for num_sites"
assert rest_dim == m_rest.shape[-1], "rest_dim disagrees with the batch's m_rest width"

meta_model = ASTMetaProj(
    ast_kwargs=ast_kwargs,
    num_devices=num_devices,
    num_sites=num_sites,
    dev_emb_dim=4,
    site_emb_dim=4,
    rest_dim=rest_dim,
    hidden_dim=64,
    dropout_p=0.3,
    num_labels=num_labels,
).to(DEVICE)

print(meta_model)

---------------AST Model Summary---------------
ImageNet pretraining: True, AudioSet pretraining: True
Loading AudioSet pretrained model from /home/AIoT04/Dev/pretrained_models/audioset_10_10_0.4593.pth
No mismatch for key: v.cls_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.pos_embed
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.dist_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.attn.qkv.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.attn.qkv.bias

In [9]:
# Smoke test: one forward pass to confirm the output shape.
# no_grad: nothing is trained here, and an autograd graph attached to the global
# `logits` would keep every backbone activation alive on the device while the
# FiLM models below are built.
with torch.no_grad():
    logits = meta_model(x, device_id, site_id, m_rest)
print("logits:", logits.shape, logits.dtype)  # (B,2)
assert logits.shape == y.shape, "logits must line up with the multi-label targets"

logits: torch.Size([8, 2]) torch.float32


### FiLM: Metadata conditioning inside the Transformer

In [10]:
from ls.models.ast_film import ASTFiLM

In [11]:
film = ASTFiLM(
    ast_kwargs=ast_kwargs,
    num_devices=num_devices,
    num_sites=num_sites,
    rest_dim=rest_dim,
    # NOTE(review): the checkpoint keys are v.blocks.0..., i.e. 0-indexed blocks.
    # If ASTFiLM also indexes from 0, a 12-block base model has no layer 12; confirm in ls.models.ast_film.
    conditioned_layers=(10, 11, 12),
).to(DEVICE)

# Forward (shape check only). no_grad keeps `logits_film` from pinning the
# activation graph in device memory.
with torch.no_grad():
    logits_film = film(x, device_id, site_id, m_rest)

print(logits_film.shape)  # (B,2)
assert logits_film.shape == y.shape, "logits must line up with the multi-label targets"

---------------AST Model Summary---------------
ImageNet pretraining: True, AudioSet pretraining: True
Loading AudioSet pretrained model from /home/AIoT04/Dev/pretrained_models/audioset_10_10_0.4593.pth
No mismatch for key: v.cls_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.pos_embed
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.dist_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.attn.qkv.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.attn.qkv.bias

In [12]:
from ls.models.ast_pp import ASTFiLMPlusPlus

In [13]:
# Metadata sizes for the FiLM++ model (the same values as in the ASTMetaProj cell),
# re-derived here so this section can be re-run on its own.
num_devices, num_sites = 4, 7
rest_dim = train_loader.dataset[0]["m_rest"].numel()

In [14]:
filmpp = ASTFiLMPlusPlus(
    ast_kwargs=ast_kwargs,
    num_devices=num_devices,
    num_sites=num_sites,
    rest_dim=rest_dim,
    D_dev=128,
    D_site=128,
    # NOTE(review): the checkpoint keys are v.blocks.0..., i.e. 0-indexed blocks.
    # If ASTFiLMPlusPlus also indexes from 0, a 12-block base model has no layer 12; confirm in ls.models.ast_pp.
    conditioned_layers=(10, 11, 12),
).to(DEVICE)

# Forward (shape check only). no_grad keeps `logits_pp` from pinning the
# activation graph in device memory.
with torch.no_grad():
    logits_pp = filmpp(x, device_id, site_id, m_rest)

print(logits_pp.shape)  # (B,2)
assert logits_pp.shape == y.shape, "logits must line up with the multi-label targets"

---------------AST Model Summary---------------
ImageNet pretraining: True, AudioSet pretraining: True
Loading AudioSet pretrained model from /home/AIoT04/Dev/pretrained_models/audioset_10_10_0.4593.pth
No mismatch for key: v.cls_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.pos_embed
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.dist_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.attn.qkv.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.attn.qkv.bias