In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import numpy as np
from ls.config.loader import load_config
import IPython.display as ipd
import torch

In [3]:
# --- 1. load config ---
cfg = load_config("../configs/config.yaml")

print("Dataset config:", cfg.dataset)
print("Audio config:", cfg.audio)

Dataset config: {'name': 'icbhi', 'data_folder': '/Users/gkont/Downloads/Datasets/icbhi_dataset', 'cycle_metadata_path': '/Users/gkont/Downloads/Datasets/icbhi_dataset/icbhi_metadata.csv', 'class_split': 'lungsound', 'split_strategy': 'official', 'test_fold': 0, 'multi_label': True, 'n_cls': 4, 'weighted_sampler': True, 'batch_size': 8, 'num_workers': 0, 'h': 128, 'w': 1024}
Audio config: {'sample_rate': 16000, 'desired_length': 10.0, 'remove_dc': True, 'normalize': False, 'pad_type': 'repeat', 'use_fade': True, 'fade_samples_ratio': 64, 'n_mels': 128, 'frame_length': 40, 'frame_shift': 10, 'low_freq': 100, 'high_freq': 5000, 'window_type': 'hanning', 'use_energy': False, 'dither': 0.0, 'mel_norm': 'mit', 'resz': 1.0, 'raw_augment': 1, 'wave_aug': [{'type': 'Crop', 'sampling_rate': 16000, 'zone': [0.0, 1.0], 'coverage': 1.0, 'p': 0.0}, {'type': 'Noise', 'color': 'white', 'p': 0.1}, {'type': 'Speed', 'factor': [0.9, 1.1], 'p': 0.1}, {'type': 'Loudness', 'factor': [0.5, 2.0], 'p': 0.1}, 

In [4]:
# Regular training
from ls.data.dataloaders import build_dataloaders

train_loader, test_loader = build_dataloaders(cfg.dataset, cfg.audio)

[Transforms] Input spectrogram resize factor: 1.0, target size: (128, 1024)
[Transforms] Input spectrogram resize factor: 1.0, target size: (128, 1024)
[ICBHI] Loaded cycle metadata TSV: 6898 rows
[ICBHI] #Sites=7, #Devices=4
[ICBHI] Sites Found: {'Al': 0, 'Ar': 1, 'Ll': 2, 'Lr': 3, 'Pl': 4, 'Pr': 5, 'Tc': 6}
[ICBHI] Devices Found: {'AKGC417L': 0, 'Litt3200': 1, 'LittC2SE': 2, 'Meditron': 3}




[ICBHI] Extracted 4142 cycles from 539 recordings
[ICBHI] Metadata join missing: 0 (strict join; should be 0)
[ICBHI] Input spectrogram shape: (997, 128, 1)
[ICBHI] 4142 cycles
  Class 0: 2063 (49.8%)
  Class 1: 1215 (29.3%)
  Class 2: 501 (12.1%)
  Class 3: 363 (8.8%)
[ICBHI] Loaded cycle metadata TSV: 6898 rows
[ICBHI] #Sites=7, #Devices=4
[ICBHI] Sites Found: {'Al': 0, 'Ar': 1, 'Ll': 2, 'Lr': 3, 'Pl': 4, 'Pr': 5, 'Tc': 6}
[ICBHI] Devices Found: {'AKGC417L': 0, 'Litt3200': 1, 'LittC2SE': 2, 'Meditron': 3}




[ICBHI] Extracted 2756 cycles from 381 recordings
[ICBHI] Metadata join missing: 0 (strict join; should be 0)
[ICBHI] Input spectrogram shape: (997, 128, 1)
[ICBHI] 2756 cycles
  Class 0: 1579 (57.3%)
  Class 1: 649 (23.5%)
  Class 2: 385 (14.0%)
  Class 3: 143 (5.2%)


In [6]:
DEVICE = torch.device("cpu")
DEVICE

device(type='cpu')

In [7]:
batch = next(iter(train_loader))

x = batch["input_values"].to(DEVICE)      # (B, 1, F, T)
device_id = batch["device_id"].to(DEVICE) # (B,)
site_id   = batch["site_id"].to(DEVICE)   # (B,)
m_rest    = batch["m_rest"].to(DEVICE)    # (B, rest_dim)
y         = batch["label"].to(DEVICE)     # (B,2) for multilabel

print("x:", x.shape, x.dtype)
print("device_id:", device_id.shape, device_id.dtype)
print("site_id:", site_id.shape, site_id.dtype)
print("m_rest:", m_rest.shape, m_rest.dtype)
print("y:", y.shape, y.dtype)



x: torch.Size([8, 1, 128, 1024]) torch.float32
device_id: torch.Size([8]) torch.int64
site_id: torch.Size([8]) torch.int64
m_rest: torch.Size([8, 3]) torch.float32
y: torch.Size([8, 2]) torch.float32


## Baseline, Naive Concatenation & Projection Fusion

In [8]:
import torch.nn as nn
from ls.models.ast import ASTModel

In [19]:
class ASTBaseline(nn.Module):
    """
    Baseline: AST encoder -> MLP -> 2 logits (crackle, wheeze).
    Matches Section 2.1 (Baseline Model) in your Methods.
    """
    def __init__(
        self,
        ast_kwargs: dict,
        hidden_dim: int = 64,
        dropout_p: float = 0.3,
        num_labels: int = 2,      # crackle, wheeze
    ):
        super().__init__()
        # AST as backbone-only encoder (returns h_CLS)
        self.ast = ASTModel(
            backbone_only=True,
            **ast_kwargs
        )
        
        D = self.ast.original_embedding_dim

        self.classifier = nn.Sequential(
            nn.LayerNorm(D),
            nn.Dropout(dropout_p),
            nn.Linear(D, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_dim, num_labels)   # logits (no sigmoid)
        )

    def forward(self, x):
        """
        x: (B, T, F) or (B, 1, F, T)
        returns: (B, 2) logits (crackle, wheeze)
        """
        h_cls = self.ast.forward_features(x)           # (B, D), your h_CLS
        print(f"h_cls shape: {h_cls.shape}")  # (B, D)
        logits = self.classifier(h_cls)
        print(f"logits shape: {logits.shape}")  # (B, 2)
        return logits

In [20]:
ast_kwargs = dict(
    label_dim=2,          # unused since backbone_only=True
    fstride=10,
    tstride=10,
    input_fdim=128,
    input_tdim=1024,
    imagenet_pretrain=True,
    audioset_pretrain=True,
    audioset_ckpt_path='/Users/gkont/Documents/Code/pretrained_models/audioset_10_10_0.4593.pth',
    model_size='base384',
    verbose=True,
)

baseline_model = ASTBaseline(ast_kwargs, hidden_dim=64, dropout_p=0.3, num_labels=2).to(DEVICE)

---------------AST Model Summary---------------
ImageNet pretraining: True, AudioSet pretraining: True
Loading AudioSet pretrained model from /Users/gkont/Documents/Code/pretrained_models/audioset_10_10_0.4593.pth
No mismatch for key: v.cls_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.pos_embed
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.dist_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.attn.qkv.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.at

In [21]:
print(baseline_model)

ASTBaseline(
  (ast): ASTModel(
    (reg_dropout): Dropout(p=0.3, inplace=False)
    (v): DistilledVisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (blocks): ModuleList(
        (0-11): 12 x Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features=3072, out_fea

In [22]:
x.shape

torch.Size([8, 1, 128, 1024])

In [23]:
out = baseline_model(x)  # (B, 2)
print(out.shape)  # torch.Size([2, 2])

: 

### Projected added metadata

In [14]:
class ASTWithMetadataProjection(nn.Module):
    """
    Metadata projection fusion:
        m = [E_dev(device_id), E_site(site_id), m_rest]
        m' = W m
        h_tilde = h_CLS + m'
    """
    def __init__(
        self,
        ast_kwargs: dict,
        num_devices: int,
        num_sites: int,
        dev_emb_dim: int = 4,
        site_emb_dim: int = 4,
        rest_dim: int = 5,
        hidden_dim: int = 64,
        dropout_p: float = 0.3,
        num_labels: int = 2,
    ):
        super().__init__()

        self.ast = ASTModel(backbone_only=True, **ast_kwargs)
        D = self.ast.original_embedding_dim

        # categorical encoders
        self.dev_emb  = nn.Embedding(num_devices, dev_emb_dim)
        self.site_emb = nn.Embedding(num_sites, site_emb_dim)

        meta_dim = dev_emb_dim + site_emb_dim + rest_dim
        self.metadata_proj = nn.Linear(meta_dim, D)

        self.classifier = nn.Sequential(
            nn.LayerNorm(D),
            nn.Dropout(dropout_p),
            nn.Linear(D, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_dim, num_labels)
        )

    def forward(self, x, device_id, site_id, m_rest):
        """
        x:          (B, 1, F, T)
        device_id:  (B,) long
        site_id:    (B,) long
        m_rest:     (B, rest_dim) float
        """
        h_cls = self.ast(x)                    # (B, D)

        dev  = self.dev_emb(device_id)         # (B, d_dev)
        site = self.site_emb(site_id)          # (B, d_site)

        m = torch.cat([dev, site, m_rest], dim=-1)  # (B, meta_dim)
        m_prime = self.metadata_proj(m)              # (B, D)

        h_tilde = h_cls + m_prime
        logits = self.classifier(h_tilde)
        return logits

In [15]:
train_loader.dataset[0]["m_rest"].numel()

3

In [17]:
import torch.nn as nn

ast_kwargs = dict(
    label_dim=2,          # unused since backbone_only=True
    fstride=10,
    tstride=10,
    input_fdim=128,
    input_tdim=1024,
    imagenet_pretrain=True,
    audioset_pretrain=True,
    audioset_ckpt_path='/Users/gkont/Documents/Code/pretrained_models/audioset_10_10_0.4593.pth',
    model_size='base384',
    verbose=False,
)

# If your dataset returns: m_rest = [sex, age, bmi, duration, bmi_missing]
num_devices = 4
num_sites = 7
rest_dim = train_loader.dataset[0]["m_rest"].numel()

meta_model = ASTWithMetadataProjection(
    ast_kwargs=ast_kwargs,
    num_devices=num_devices,
    num_sites=num_sites,
    dev_emb_dim=4,
    site_emb_dim=4,
    rest_dim=rest_dim,
    hidden_dim=64,
    dropout_p=0.3,
    num_labels=2
).to(DEVICE)

print(meta_model)


ASTWithMetadataProjection(
  (ast): ASTModel(
    (reg_dropout): Dropout(p=0.3, inplace=False)
    (v): DistilledVisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(1, 768, kernel_size=(16, 16), stride=(10, 10))
      )
      (pos_drop): Dropout(p=0.0, inplace=False)
      (blocks): ModuleList(
        (0-11): 12 x Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=768, out_features=768, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (act): GELU(approximate='none')
            (fc2): Linear(in_features

In [18]:
logits = meta_model(x, device_id, site_id, m_rest)
print("logits:", logits.shape, logits.dtype)  # (B,2)

RuntimeError: The size of tensor a (2) must match the size of tensor b (768) at non-singleton dimension 1

In [None]:
class ASTWithMetadataProjection(nn.Module):
    """
    Metadata projection fusion:
        m -> m' in R^D
        h_tilde = h_CLS + m'
        h_tilde -> MLP -> logits
    """
    def __init__(
        self,
        ast_kwargs: dict,
        metadata_dim: int,
        hidden_dim: int = 64,
        dropout_p: float = 0.3,
        num_labels: int = 2,
    ):
        super().__init__()
        self.ast = ASTModel(
            backbone_only=True,
            **ast_kwargs
        )
        D = self.ast.original_embedding_dim
        self.metadata_dim = metadata_dim

        # Linear projection m -> m' in R^D
        self.metadata_proj = nn.Linear(metadata_dim, D)

        # Same style classifier as baseline
        self.classifier = nn.Sequential(
            nn.LayerNorm(D),
            nn.Dropout(dropout_p),
            nn.Linear(D, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_dim, num_labels)
        )

    def forward(self, x, m):
        """
        x: (B, T, F) or (B, 1, F, T)
        m: (B, M)
        """
        h_cls = self.ast(x)               # (B, D)
        m_prime = self.metadata_proj(m)   # (B, D)
        h_tilde = h_cls + m_prime         # (B, D)  -- Eq. (11)
        logits = self.classifier(h_tilde)
        return logits

In [None]:
astproj = ASTWithMetadataProjection(ast_kwargs, metadata_dim=10, hidden_dim=64, dropout_p=0.3, num_labels=2).to(DEVICE)
print(astproj)
out = astproj(dummy_input, metadata)  # (2, 2)
print(out.shape)  # torch.Size([2, 2])

### FiLM: Metadata conditioning inside the Transformer

In [None]:
class ASTFiLM(ASTModel):
    """
    AST with FiLM conditioning on selected Transformer layers.

    - Uses ASTModel's __init__ to build the DeiT backbone.
    - Adds a metadata encoder g(m).
    - For a set of layers B, generates layer-specific gamma_l, beta_l.
    - Applies FiLM after the MHSA residual, before the FFN (as per your equations).
    """
    def __init__(
        self,
        metadata_dim: int,
        conditioned_layers=(0, 1, 2, 3),   # indices into self.v.blocks
        metadata_hidden_dim: int = 64,
        film_hidden_dim: int = 64,
        num_labels: int = 2,
        dropout_p: float = 0.3,
        ast_kwargs: dict = None,
    ):
        ast_kwargs = ast_kwargs or {}
        super().__init__(backbone_only=True, **ast_kwargs)

        self.metadata_dim = metadata_dim
        self.conditioned_layers = sorted(list(conditioned_layers))
        self.conditioned_layers_set = set(self.conditioned_layers)

        D = self.original_embedding_dim
        self.num_layers = len(self.v.blocks)

        # Metadata encoder h_m = g(m)
        self.metadata_encoder = nn.Sequential(
            nn.LayerNorm(metadata_dim),
            nn.Linear(metadata_dim, metadata_hidden_dim),
            nn.ReLU(),
            nn.Linear(metadata_hidden_dim, film_hidden_dim),
            nn.ReLU(),
        )

        # One FiLM generator per conditioned layer:
        # f_l: h_m -> [gamma_l || beta_l] in R^{2D}
        self.film_generators = nn.ModuleDict()
        for l in self.conditioned_layers:
            self.film_generators[str(l)] = nn.Linear(film_hidden_dim, 2 * D)

        # Classification head (same as baseline)
        self.classifier = nn.Sequential(
            nn.LayerNorm(D),
            nn.Dropout(dropout_p),
            nn.Linear(D, 64),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(64, num_labels)
        )

    def forward_features(self, x, m):
        """
        x: (B, 1, F, T)  (we'll handle 3D in forward())
        m: (B, M) metadata
        returns: h_CLS in R^D (FiLM-conditioned)
        """
        B = x.shape[0]

        # Patch embedding: (B, N, D)
        x = self.v.patch_embed(x)

        # CLS + dist tokens and positional embeddings
        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        dist_token = self.v.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)

        # Encode metadata once, then generate FiLM params per layer
        h_m = self.metadata_encoder(m)   # (B, film_hidden_dim)

        gamma = {}
        beta = {}
        for l in self.conditioned_layers:
            film = self.film_generators[str(l)](h_m)  # (B, 2D)
            g_l, b_l = film.chunk(2, dim=-1)          # (B, D), (B, D)
            gamma[l] = g_l
            beta[l] = b_l

        # Manually unroll each Transformer block
        for layer_idx, blk in enumerate(self.v.blocks):
            # MHSA sublayer with pre-norm
            attn_out = blk.attn(blk.norm1(x))
            x = x + blk.drop_path(attn_out)   # Z_tilde_l

            # Apply FiLM after MHSA for selected layers
            if layer_idx in self.conditioned_layers_set:
                print(f"Applying FiLM at layer {layer_idx}")
                g_l = gamma[layer_idx].unsqueeze(1)   # (B, 1, D) -> broadcast over tokens
                b_l = beta[layer_idx].unsqueeze(1)    # (B, 1, D)
                x = g_l * x + b_l                     # Eq. (14â€“15): hat{Z}_l

            # FFN + residual
            x = x + blk.drop_path(blk.mlp(blk.norm2(x)))

        x = self.v.norm(x)
        h_cls = (x[:, 0] + x[:, 1]) / 2  # pooled token
        return h_cls

    def forward(self, x, m):
        """
        x: (B, T, F) or (B, 1, F, T)
        m: (B, M) metadata
        """
        if x.dim() == 3:
            x = x.unsqueeze(1).transpose(2, 3)  # (B, 1, F, T)

        h_cls = self.forward_features(x, m)
        logits = self.classifier(h_cls)
        return logits

In [None]:
astfilm = ASTFiLM(
    ast_kwargs=ast_kwargs, 
    metadata_dim=10,
    conditioned_layers=range(12), # all layers
    metadata_hidden_dim=64, 
    film_hidden_dim=64,
    dropout_p=0.3, 
    num_labels=2
).to(DEVICE)
print(astfilm)
metadata  = torch.randn(2, 10).to(DEVICE)  # (B, M)
out = astfilm(dummy_input, metadata)  # (2, 2)
print(out.shape)  # torch.Size([2, 2])

In [None]:
class ASTFiLMPlusPlus(ASTModel):
    """
    FiLM++: factor-aligned grouped FiLM with K=3 groups:
      - device group    (D_dev)
      - site group      (D_site)
      - rest group      (D_rest)

    Metadata inputs are passed separately as (m_dev, m_site, m_rest).
    """
    def __init__(
        self,
        dev_metadata_dim: int,
        site_metadata_dim: int,
        rest_metadata_dim: int,
        D_dev: int,
        D_site: int,
        conditioned_layers=(0, 1, 2, 3),
        metadata_hidden_dim: int = 64,
        film_hidden_dim: int = 64,
        num_labels: int = 2,
        dropout_p: float = 0.3,
        ast_kwargs: dict = None,
    ):
        ast_kwargs = ast_kwargs or {}
        super().__init__(backbone_only=True, **ast_kwargs)

        self.conditioned_layers = sorted(list(conditioned_layers))
        self.conditioned_layers_set = set(self.conditioned_layers)

        D_total = self.original_embedding_dim
        assert D_dev + D_site <= D_total, "D_dev + D_site must be <= total embedding dim"
        D_rest = D_total - D_dev - D_site

        self.D_dev = D_dev
        self.D_site = D_site
        self.D_rest = D_rest
        self.D_total = D_total

        # --- Metadata encoders for each factor ---
        self.dev_encoder = nn.Sequential(
            nn.LayerNorm(dev_metadata_dim),
            nn.Linear(dev_metadata_dim, metadata_hidden_dim),
            nn.ReLU(),
            nn.Linear(metadata_hidden_dim, film_hidden_dim),
            nn.ReLU(),
        )
        self.site_encoder = nn.Sequential(
            nn.LayerNorm(site_metadata_dim),
            nn.Linear(site_metadata_dim, metadata_hidden_dim),
            nn.ReLU(),
            nn.Linear(metadata_hidden_dim, film_hidden_dim),
            nn.ReLU(),
        )
        self.rest_encoder = nn.Sequential(
            nn.LayerNorm(rest_metadata_dim),
            nn.Linear(rest_metadata_dim, metadata_hidden_dim),
            nn.ReLU(),
            nn.Linear(metadata_hidden_dim, film_hidden_dim),
            nn.ReLU(),
        )

        # --- FiLM generators per group and per conditioned layer ---
        self.dev_film = nn.ModuleDict()
        self.site_film = nn.ModuleDict()
        self.rest_film = nn.ModuleDict()
        for l in self.conditioned_layers:
            self.dev_film[str(l)] = nn.Linear(film_hidden_dim, 2 * D_dev)
            self.site_film[str(l)] = nn.Linear(film_hidden_dim, 2 * D_site)
            self.rest_film[str(l)] = nn.Linear(film_hidden_dim, 2 * D_rest)

        # Classification head (same style as baseline)
        self.classifier = nn.Sequential(
            nn.LayerNorm(D_total),
            nn.Dropout(dropout_p),
            nn.Linear(D_total, 64),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(64, num_labels),
        )

    def _apply_filmpp_grouped(self, x, gammas, betas):
        """
        x:      (B, T, D_total)
        gammas: dict with 'dev','site','rest' tensors (B, D_group)
        betas:  same as above

        Applies group-wise FiLM and returns (B, T, D_total).
        """
        B, T, D = x.shape
        D_dev, D_site, D_rest = self.D_dev, self.D_site, self.D_rest

        x_dev, x_site, x_rest = torch.split(
            x, [D_dev, D_site, D_rest], dim=-1
        )  # each (B, T, D_group)

        # Broadcast gammas/betas over tokens
        g_dev = gammas["dev"].unsqueeze(1)   # (B, 1, D_dev)
        b_dev = betas["dev"].unsqueeze(1)
        g_site = gammas["site"].unsqueeze(1) # (B, 1, D_site)
        b_site = betas["site"].unsqueeze(1)
        g_rest = gammas["rest"].unsqueeze(1) # (B, 1, D_rest)
        b_rest = betas["rest"].unsqueeze(1)

        x_dev_hat = g_dev * x_dev + b_dev
        x_site_hat = g_site * x_site + b_site
        x_rest_hat = g_rest * x_rest + b_rest

        x_hat = torch.cat([x_dev_hat, x_site_hat, x_rest_hat], dim=-1)  # (B, T, D_total)
        return x_hat

    def forward_features(self, x, m_dev, m_site, m_rest):
        """
        x:      (B, 1, F, T)
        m_dev:  (B, dev_metadata_dim)
        m_site: (B, site_metadata_dim)
        m_rest: (B, rest_metadata_dim)
        """
        B = x.shape[0]

        # Patch embedding
        x = self.v.patch_embed(x)
        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        dist_token = self.v.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)

        # Encode metadata for each factor
        h_dev = self.dev_encoder(m_dev)
        h_site = self.site_encoder(m_site)
        h_rest = self.rest_encoder(m_rest)

        # Precompute gamma,beta for each conditioned layer
        gamma_dev, beta_dev = {}, {}
        gamma_site, beta_site = {}, {}
        gamma_rest, beta_rest = {}, {}

        for l in self.conditioned_layers:
            dev_params = self.dev_film[str(l)](h_dev)   # (B, 2*D_dev)
            site_params = self.site_film[str(l)](h_site) # (B, 2*D_site)
            rest_params = self.rest_film[str(l)](h_rest) # (B, 2*D_rest)

            g_dev, b_dev = dev_params.chunk(2, dim=-1)
            g_site, b_site = site_params.chunk(2, dim=-1)
            g_rest, b_rest = rest_params.chunk(2, dim=-1)

            gamma_dev[l], beta_dev[l] = g_dev, b_dev
            gamma_site[l], beta_site[l] = g_site, b_site
            gamma_rest[l], beta_rest[l] = g_rest, b_rest

        # Unroll ViT blocks with FiLM++ after MHSA
        for layer_idx, blk in enumerate(self.v.blocks):
            attn_out = blk.attn(blk.norm1(x))
            x = x + blk.drop_path(attn_out)   # (B, T, D_total)

            if layer_idx in self.conditioned_layers_set:
                gammas = {
                    "dev": gamma_dev[layer_idx],
                    "site": gamma_site[layer_idx],
                    "rest": gamma_rest[layer_idx],
                }
                betas = {
                    "dev": beta_dev[layer_idx],
                    "site": beta_site[layer_idx],
                    "rest": beta_rest[layer_idx],
                }
                x = self._apply_filmpp_grouped(x, gammas, betas)

            x = x + blk.drop_path(blk.mlp(blk.norm2(x)))

        x = self.v.norm(x)
        h_cls = (x[:, 0] + x[:, 1]) / 2
        return h_cls

    def forward(self, x, m_dev, m_site, m_rest):
        """
        x:      (B, T, F) or (B, 1, F, T)
        m_dev:  (B, dev_metadata_dim)
        m_site: (B, site_metadata_dim)
        m_rest: (B, rest_metadata_dim)
        """
        if x.dim() == 3:
            x = x.unsqueeze(1).transpose(2, 3)

        h_cls = self.forward_features(x, m_dev, m_site, m_rest)
        logits = self.classifier(h_cls)
        return logits

In [None]:
astfilmpp = ASTFiLMPlusPlus(
    dev_metadata_dim=4,
    site_metadata_dim=7,
    rest_metadata_dim=3,
    D_dev=128,
    D_site=128,
    ast_kwargs=ast_kwargs,
    conditioned_layers=(10, 11, 12), # last 3 layers
    metadata_hidden_dim=64, 
    film_hidden_dim=64,
    dropout_p=0.3, 
    num_labels=2
).to(DEVICE)
print(astfilmpp)

dev_metadata = torch.randn(2, 4).to(DEVICE)
site_metadata = torch.randn(2, 7).to(DEVICE)
rest_metadata = torch.randn(2, 3).to(DEVICE)
out = astfilmpp(dummy_input, dev_metadata, site_metadata, rest_metadata)  # (2, 2)
print(out.shape)  # torch.Size([2, 2])