In [1]:
# Implementation of AST model, mostly adapted from https://github.com/YuanGongND/AST

import torch
import torch.nn as nn
import os
from timm.models.layers import to_2tuple, trunc_normal_
import timm

In [2]:
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# override the timm package to relax the input shape constraint.
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

In [4]:
image = torch.randn(1, 1, 128, 1024).to(DEVICE)

PE = PatchEmbed(patch_size=(16, 16), in_chans=1, embed_dim=768).to(DEVICE)
patches = PE(image)
print(patches.shape)  # torch.Size([1, 512, 768])

torch.Size([1, 512, 768])


In [5]:
class ASTModel(nn.Module):
    """
    The AST model.
    :param label_dim: the label dimension, i.e., the number of total classes, it is 527 for AudioSet, 50 for ESC-50, and 35 for speechcommands v2-35
    :param fstride: the stride of patch spliting on the frequency dimension, for 16*16 patchs, fstride=16 means no overlap, fstride=10 means overlap of 6
    :param tstride: the stride of patch spliting on the time dimension, for 16*16 patchs, tstride=16 means no overlap, tstride=10 means overlap of 6
    :param input_fdim: the number of frequency bins of the input spectrogram
    :param input_tdim: the number of time frames of the input spectrogram
    :param imagenet_pretrain: if use ImageNet pretrained model
    :param audioset_pretrain: if use full AudioSet and ImageNet pretrained model
    :param audioset_ckpt_path: path to AudioSet pretrained checkpoint
    :param model_size: the model size of AST, should be in [tiny224, small224, base224, base384]
    :param verbose: print model summary
    :param backbone_only: if True, do not create classification head
    :param dropout_p: dropout probability for regularization
    """
    def __init__(
            self, 
            label_dim=4, 
            fstride=10, 
            tstride=10, 
            input_fdim=128, 
            input_tdim=1024,
            imagenet_pretrain=True,
            audioset_pretrain=False,
            audioset_ckpt_path='',
            model_size='base384', 
            verbose=True, 
            backbone_only=False,
            dropout_p=0.3,
        ):

        super(ASTModel, self).__init__()
        assert timm.__version__ == '0.4.5', 'Please use timm == 0.4.5, the code might not be compatible with newer versions.'

        if verbose == True:
            print('---------------AST Model Summary---------------')
            print('ImageNet pretraining: {:s}, AudioSet pretraining: {:s}'.format(str(imagenet_pretrain),str(audioset_pretrain)))
        
        # override timm input shape restriction
        timm.models.vision_transformer.PatchEmbed = PatchEmbed
        self.label_dim = label_dim
        self.reg_dropout = nn.Dropout(dropout_p)
        
        # if AudioSet pretraining is not used (but ImageNet pretraining may still apply)
        if audioset_pretrain == False:
            if model_size == 'tiny224':
                self.v = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'small224':
                self.v = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base224':
                self.v = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base384':
                self.v = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=imagenet_pretrain)
            else:
                raise Exception('Model size must be one of tiny224, small224, base224, base384.')
            
            if verbose:
                print('Vision transformer model size {:s} created.'.format(model_size))
            
            self.original_num_patches = self.v.patch_embed.num_patches
            self.oringal_hw = int(self.original_num_patches ** 0.5)
            self.original_embedding_dim = self.v.pos_embed.shape[2]

            self.mlp_head = None
            if not backbone_only:
                print('Creating classification head with hidden dim 64 and dropout p={:.2f}'.format(dropout_p))
                self.mlp_head = nn.Sequential(
                    nn.LayerNorm(self.original_embedding_dim), 
                    self.reg_dropout,
                    nn.Linear(self.original_embedding_dim, 64),  # Hidden layer
                    nn.ReLU(),
                    self.reg_dropout,
                    nn.Linear(64, self.label_dim)  # label_dim outputs, no activation here
                )

            # automatcially get the intermediate shape
            f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
            num_patches = f_dim * t_dim
            self.v.patch_embed.num_patches = num_patches
            if verbose == True:
                print('frequency stride={:d}, time stride={:d}'.format(fstride, tstride))
                print('number of patches={:d}'.format(num_patches))

            # the linear projection layer
            new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
            if imagenet_pretrain == True:
                new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1))
                new_proj.bias = self.v.patch_embed.proj.bias
            self.v.patch_embed.proj = new_proj

            # the positional embedding
            if imagenet_pretrain == True:
                # get the positional embedding from deit model, skip the first two tokens (cls token and distillation token), reshape it to original 2D shape (24*24).
                new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, self.original_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, self.oringal_hw, self.oringal_hw)
                # cut (from middle) or interpolate the second dimension of the positional embedding
                if t_dim <= self.oringal_hw:
                    new_pos_embed = new_pos_embed[:, :, :, int(self.oringal_hw / 2) - int(t_dim / 2): int(self.oringal_hw / 2) - int(t_dim / 2) + t_dim]
                else:
                    new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(self.oringal_hw, t_dim), mode='bilinear')
                # cut (from middle) or interpolate the first dimension of the positional embedding
                if f_dim <= self.oringal_hw:
                    new_pos_embed = new_pos_embed[:, :, int(self.oringal_hw / 2) - int(f_dim / 2): int(self.oringal_hw / 2) - int(f_dim / 2) + f_dim, :]
                else:
                    new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
                # flatten the positional embedding
                new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1,2)
                # concatenate the above positional embedding with the cls token and distillation token of the deit model.
                self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))
            else:
                # if not use imagenet pretrained model, just randomly initialize a learnable positional embedding
                # TODO can use sinusoidal positional embedding instead
                new_pos_embed = nn.Parameter(torch.zeros(1, self.v.patch_embed.num_patches + 2, self.original_embedding_dim))
                self.v.pos_embed = new_pos_embed
                trunc_normal_(self.v.pos_embed, std=.02)

        # now load a model that is pretrained on both ImageNet and AudioSet
        elif audioset_pretrain == True:
            if audioset_pretrain == True and imagenet_pretrain == False:
                raise ValueError('currently model pretrained on only audioset is not supported, please set imagenet_pretrain = True to use audioset pretrained model.')
            if model_size != 'base384':
                raise ValueError('currently only has base384 AudioSet pretrained model.')
            
            if not os.path.exists(audioset_ckpt_path):
                raise FileNotFoundError(f"Pretrained AudioSet model not found at '{audioset_ckpt_path}'.")
            
            if verbose:
                print(f"Loading AudioSet pretrained model from {audioset_ckpt_path}")
            
            # Load checkpoint
            sd = torch.load(audioset_ckpt_path, map_location=DEVICE)
            
            # Create temporary model to load AudioSet weights
            audio_model = ASTModel(
                label_dim=label_dim,
                fstride=fstride, 
                tstride=tstride, 
                input_fdim=input_fdim, 
                input_tdim=input_tdim, 
                imagenet_pretrain=False, 
                audioset_pretrain=False,
                backbone_only=backbone_only,
                model_size='base384', 
                verbose=False
            )
            
            # Handle different checkpoint formats (with or without DataParallel)
            model_dict = audio_model.state_dict()
            pretrained_dict = {}
            
            for k, v in sd.items():
                # Remove various prefixes
                if k.startswith('module.'):
                    key = k[7:]  # Remove 'module.' prefix (DataParallel)
                else:
                    key = k
                
                # Keep trained weights only for the vision transformer backbone
                # Skip mlp_head weights since dimensions differ
                if key in model_dict:
                    if model_dict[key].shape == v.shape:
                        pretrained_dict[key] = v
                    elif verbose:
                        print(f"Shape mismatch for {key}: model={model_dict[key].shape}, ckpt={v.shape}")
                print("No mismatch for key:", key) if verbose else None
                print("No mlp_head weights loaded from AudioSet checkpoint.") if verbose else None
            if len(pretrained_dict) > 0:
                model_dict.update(pretrained_dict)
                audio_model.load_state_dict(model_dict)
                if verbose:
                    print(f"Loaded {len(pretrained_dict)}/{len(model_dict)} tensors from AudioSet checkpoint")
                    print(f"The remaining {len(model_dict) - len(pretrained_dict)} tensors are randomly initialized and will be trained from scratch.")
            else:
                raise RuntimeError("No matching weights found in checkpoint!")
            
            # Extract the vision transformer backbone
            self.v = audio_model.v
            self.original_num_patches = self.v.patch_embed.num_patches
            self.oringal_hw = int(self.original_num_patches ** 0.5)
            self.original_embedding_dim = self.v.pos_embed.shape[2]
            
            # Create new classification head for your task
            self.mlp_head = None
            if not backbone_only:
                print('Creating classification head with hidden dim 64 and dropout p={:.2f}'.format(dropout_p))
                self.mlp_head = nn.Sequential(
                    nn.LayerNorm(self.original_embedding_dim), 
                    self.reg_dropout,
                    nn.Linear(self.original_embedding_dim, 64),  # Hidden layer
                    nn.ReLU(),
                    self.reg_dropout,
                    nn.Linear(64, self.label_dim)  # 2 outputs (no sigmoid here, apply in loss)
                )

            # Get patch dimensions for your input
            f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
            num_patches = f_dim * t_dim
            self.v.patch_embed.num_patches = num_patches
            if verbose == True:
                print('frequency stride={:d}, time stride={:d}'.format(fstride, tstride))
                print('number of patches={:d}'.format(num_patches))

            # Adapt positional embeddings from AudioSet (12x101) to your dimensions
            new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, 1212, 768).transpose(1, 2).reshape(1, 768, 12, 101)
            # if the input sequence length is larger than the original audioset (10s), then cut the positional embedding
            if t_dim < 101:
                new_pos_embed = new_pos_embed[:, :, :, 50 - int(t_dim/2): 50 - int(t_dim/2) + t_dim]
            # otherwise interpolate
            else:
                new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(12, t_dim), mode='bilinear')
            if f_dim < 12:
                new_pos_embed = new_pos_embed[:, :, 6 - int(f_dim/2): 6 - int(f_dim/2) + f_dim, :]
            # otherwise interpolate
            elif f_dim > 12:
                new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
            new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
            self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))

    def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=1024):
        test_input = torch.randn(1, 1, input_fdim, input_tdim)
        test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
        test_out = test_proj(test_input)
        f_dim = test_out.shape[2]
        t_dim = test_out.shape[3]
        return f_dim, t_dim

    def forward_features(self, x):
        """
        Extract features from the backbone.
        :param x: input spectrogram, shape (B, 1, F, T)
        :return: feature embeddings
        """
        B = x.shape[0]
        x = self.v.patch_embed(x)
        print(f"Images size after patch embedding: {x.shape}")
        print(f"CLS token shape: {self.v.cls_token.shape}")
        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        print(f"Expanded CLS token shape: {cls_tokens.shape}")
        dist_token = self.v.dist_token.expand(B, -1, -1)
        print(f"Expanded distillation token shape: {dist_token.shape}")
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        print(f"Shape after concatenating tokens: {x.shape}")
        x = x + self.v.pos_embed
        print(f"Positional embedding shape: {self.v.pos_embed.shape}")
        x = self.v.pos_drop(x)
        print(x.shape)

        for blk in self.v.blocks:
            x = blk(x)
        print(f"x.shape after blocks: {x.shape}")

        x = self.v.norm(x)
        print(f"x.shape after norm: {x.shape}") # (B, Patches+2, Embedding_dim)
        x = (x[:, 0, :] + x[:, 1, :]) / 2  # average CLS and distillation tokens -> (B, Embedding_dim)
        print(f"x.shape after averaging tokens: {x.shape}")
        return x

    @torch.amp.autocast(device_type=DEVICE.type)
    def forward(self, x):
        """
        :param x: input spectrogram
                  Can be either (B, T, F) [original format] or (B, 1, F, T) [preprocessed]
        :return: prediction logits or features if backbone_only
        """
        # Handle both input formats
        if x.dim() == 3:  # (B, T, F) - original format
            x = x.unsqueeze(1)  # (B, 1, T, F)
            x = x.transpose(2, 3)  # (B, 1, F, T)
        
        # Now x is (B, 1, F, T)
        x = self.forward_features(x) # (B, Embedding_dim)

        if self.mlp_head is not None:
            x = self.mlp_head(x)

        return x
    
    def freeze_backbone(self, until_block=None):
        """
        Freeze transformer blocks for fine-tuning.
        :param until_block: freeze blocks up to and including this index (None = freeze all)
        """
        # Freeze patch embedding
        for p in self.v.patch_embed.parameters():
            p.requires_grad = False
        
        # Freeze positional embeddings
        self.v.pos_embed.requires_grad = False
        self.v.cls_token.requires_grad = False
        self.v.dist_token.requires_grad = False
        
        # Freeze transformer blocks
        for i, blk in enumerate(self.v.blocks):
            if until_block is None or i <= until_block:
                for p in blk.parameters():
                    p.requires_grad = False

    def unfreeze_all(self):
        """Unfreeze all parameters."""
        for p in self.parameters():
            p.requires_grad = True

In [6]:
dummy_input = torch.randn(2, 1, 128, 1024).to(DEVICE)
model = ASTModel(
    label_dim=2,
    fstride=10,
    tstride=10,
    input_fdim=128,
    input_tdim=1024,
    imagenet_pretrain=True,
    audioset_pretrain=True,
    audioset_ckpt_path='/Users/gkont/Documents/Code/pretrained_models/audioset_10_10_0.4593.pth',
    model_size='base384',
    verbose=True,
    backbone_only=True,
    dropout_p=0.3,
).to(DEVICE)
output = model(dummy_input)
print(output.shape)  # torch.Size([2, 2])

---------------AST Model Summary---------------
ImageNet pretraining: True, AudioSet pretraining: True
Loading AudioSet pretrained model from /Users/gkont/Documents/Code/pretrained_models/audioset_10_10_0.4593.pth
No mismatch for key: v.cls_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.pos_embed
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.dist_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.attn.qkv.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.at

## Baseline, Naive Concatenation & Projection Fusion

In [7]:
class ASTBaseline(nn.Module):
    """
    Baseline: AST encoder -> MLP -> 2 logits (crackle, wheeze).
    Matches Section 2.1 (Baseline Model) in your Methods.
    """
    def __init__(
        self,
        ast_kwargs: dict,
        hidden_dim: int = 64,
        dropout_p: float = 0.3,
        num_labels: int = 2,      # crackle, wheeze
    ):
        super().__init__()
        # AST as backbone-only encoder (returns h_CLS)
        self.ast = ASTModel(
            backbone_only=True,
            **ast_kwargs
        )
        print(self.ast)
        D = self.ast.original_embedding_dim

        self.classifier = nn.Sequential(
            nn.LayerNorm(D),
            nn.Dropout(dropout_p),
            nn.Linear(D, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_dim, num_labels)   # logits (no sigmoid)
        )

    def forward(self, x):
        """
        x: (B, T, F) or (B, 1, F, T)
        returns: (B, 2) logits (crackle, wheeze)
        """
        h_cls = self.ast(x)           # (B, D), your h_CLS
        print(f"h_cls shape: {h_cls.shape}")  # (B, D)
        logits = self.classifier(h_cls)
        print(f"logits shape: {logits.shape}")  # (B, 2)
        return logits
    
    

In [8]:
ast_kwargs = dict(
    label_dim=2,          # unused since backbone_only=True
    fstride=10,
    tstride=10,
    input_fdim=128,
    input_tdim=1024,
    imagenet_pretrain=True,
    audioset_pretrain=True,
    audioset_ckpt_path='/Users/gkont/Documents/Code/pretrained_models/audioset_10_10_0.4593.pth',
    model_size='base384',
    verbose=True,
)

baseline_model = ASTBaseline(ast_kwargs, hidden_dim=64, dropout_p=0.3, num_labels=2).to(DEVICE)


---------------AST Model Summary---------------
ImageNet pretraining: True, AudioSet pretraining: True
Loading AudioSet pretrained model from /Users/gkont/Documents/Code/pretrained_models/audioset_10_10_0.4593.pth
No mismatch for key: v.cls_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.pos_embed
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.dist_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.attn.qkv.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.at

In [9]:
out = baseline_model(dummy_input)  # (2, 2)
print(out.shape)  # torch.Size([2, 2])

Images size after patch embedding: torch.Size([2, 1212, 768])
CLS token shape: torch.Size([1, 1, 768])
Expanded CLS token shape: torch.Size([2, 1, 768])
Expanded distillation token shape: torch.Size([2, 1, 768])
Shape after concatenating tokens: torch.Size([2, 1214, 768])
Positional embedding shape: torch.Size([1, 1214, 768])
torch.Size([2, 1214, 768])
x.shape after blocks: torch.Size([2, 1214, 768])
x.shape after norm: torch.Size([2, 1214, 768])
x.shape after averaging tokens: torch.Size([2, 768])
h_cls shape: torch.Size([2, 768])
logits shape: torch.Size([2, 2])
torch.Size([2, 2])


## class ASTWithNaiveMetadataConcat(nn.Module)

In [None]:
class ASTWithNaiveMetadataConcat(nn.Module):
    """
    Early fusion: [h_CLS; m] -> MLP -> 2 logits.

    m is assumed to be a normalized, preprocessed metadata vector
    (e.g. z-scored continuous + one-hot categorical).
    """
    def __init__(
        self,
        ast_kwargs: dict,
        metadata_dim: int,
        hidden_dim: int = 128,
        dropout_p: float = 0.3,
        num_labels: int = 2,
    ):
        super().__init__()
        self.ast = ASTModel(
            backbone_only=True,
            **ast_kwargs
        )
        D = self.ast.original_embedding_dim
        self.metadata_dim = metadata_dim

        self.classifier = nn.Sequential(
            nn.LayerNorm(D + metadata_dim),
            nn.Dropout(dropout_p),
            nn.Linear(D + metadata_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_dim, num_labels)
        )

    def forward(self, x, m):
        """
        x: (B, T, F) or (B, 1, F, T)
        m: (B, M) metadata vector
        """
        h_cls = self.ast(x)              # (B, D)
        z = torch.cat([h_cls, m], dim=-1)  # (B, D+M)
        logits = self.classifier(z)
        return logits

In [None]:
astnaive = ASTWithNaiveMetadataConcat(ast_kwargs, metadata_dim=10, hidden_dim=64, dropout_p=0.3, num_labels=2).to(DEVICE)
print(astnaive)

### Projected added metadata

In [None]:
class ASTWithMetadataProjection(nn.Module):
    """
    Metadata projection fusion:
        m -> m' in R^D
        h_tilde = h_CLS + m'
        h_tilde -> MLP -> logits
    """
    def __init__(
        self,
        ast_kwargs: dict,
        metadata_dim: int,
        hidden_dim: int = 64,
        dropout_p: float = 0.3,
        num_labels: int = 2,
    ):
        super().__init__()
        self.ast = ASTModel(
            backbone_only=True,
            **ast_kwargs
        )
        D = self.ast.original_embedding_dim
        self.metadata_dim = metadata_dim

        # Linear projection m -> m' in R^D
        self.metadata_proj = nn.Linear(metadata_dim, D)

        # Same style classifier as baseline
        self.classifier = nn.Sequential(
            nn.LayerNorm(D),
            nn.Dropout(dropout_p),
            nn.Linear(D, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_dim, num_labels)
        )

    def forward(self, x, m):
        """
        x: (B, T, F) or (B, 1, F, T)
        m: (B, M)
        """
        h_cls = self.ast(x)               # (B, D)
        m_prime = self.metadata_proj(m)   # (B, D)
        h_tilde = h_cls + m_prime         # (B, D)  -- Eq. (11)
        logits = self.classifier(h_tilde)
        return logits

In [None]:
astproj = ASTWithMetadataProjection(ast_kwargs, metadata_dim=10, hidden_dim=64, dropout_p=0.3, num_labels=2).to(DEVICE)
print(astproj)
out = astproj(dummy_input, metadata)  # (2, 2)
print(out.shape)  # torch.Size([2, 2])

### FiLM: Metadata conditioning inside the Transformer

In [18]:
class ASTFiLM(ASTModel):
    """
    AST with FiLM conditioning on selected Transformer layers.

    - Uses ASTModel's __init__ to build the DeiT backbone.
    - Adds a metadata encoder g(m).
    - For a set of layers B, generates layer-specific gamma_l, beta_l.
    - Applies FiLM after the MHSA residual, before the FFN (as per your equations).
    """
    def __init__(
        self,
        metadata_dim: int,
        conditioned_layers=(0, 1, 2, 3),   # indices into self.v.blocks
        metadata_hidden_dim: int = 64,
        film_hidden_dim: int = 64,
        num_labels: int = 2,
        dropout_p: float = 0.3,
        ast_kwargs: dict = None,
    ):
        ast_kwargs = ast_kwargs or {}
        super().__init__(backbone_only=True, **ast_kwargs)

        self.metadata_dim = metadata_dim
        self.conditioned_layers = sorted(list(conditioned_layers))
        self.conditioned_layers_set = set(self.conditioned_layers)

        D = self.original_embedding_dim
        self.num_layers = len(self.v.blocks)

        # Metadata encoder h_m = g(m)
        self.metadata_encoder = nn.Sequential(
            nn.LayerNorm(metadata_dim),
            nn.Linear(metadata_dim, metadata_hidden_dim),
            nn.ReLU(),
            nn.Linear(metadata_hidden_dim, film_hidden_dim),
            nn.ReLU(),
        )

        # One FiLM generator per conditioned layer:
        # f_l: h_m -> [gamma_l || beta_l] in R^{2D}
        self.film_generators = nn.ModuleDict()
        for l in self.conditioned_layers:
            self.film_generators[str(l)] = nn.Linear(film_hidden_dim, 2 * D)

        # Classification head (same as baseline)
        self.classifier = nn.Sequential(
            nn.LayerNorm(D),
            nn.Dropout(dropout_p),
            nn.Linear(D, 64),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(64, num_labels)
        )

    def forward_features(self, x, m):
        """
        x: (B, 1, F, T)  (we'll handle 3D in forward())
        m: (B, M) metadata
        returns: h_CLS in R^D (FiLM-conditioned)
        """
        B = x.shape[0]

        # Patch embedding: (B, N, D)
        x = self.v.patch_embed(x)

        # CLS + dist tokens and positional embeddings
        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        dist_token = self.v.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)

        # Encode metadata once, then generate FiLM params per layer
        h_m = self.metadata_encoder(m)   # (B, film_hidden_dim)

        gamma = {}
        beta = {}
        for l in self.conditioned_layers:
            film = self.film_generators[str(l)](h_m)  # (B, 2D)
            g_l, b_l = film.chunk(2, dim=-1)          # (B, D), (B, D)
            gamma[l] = g_l
            beta[l] = b_l

        # Manually unroll each Transformer block
        for layer_idx, blk in enumerate(self.v.blocks):
            # MHSA sublayer with pre-norm
            attn_out = blk.attn(blk.norm1(x))
            x = x + blk.drop_path(attn_out)   # Z_tilde_l

            # Apply FiLM after MHSA for selected layers
            if layer_idx in self.conditioned_layers_set:
                print(f"Applying FiLM at layer {layer_idx}")
                g_l = gamma[layer_idx].unsqueeze(1)   # (B, 1, D) -> broadcast over tokens
                b_l = beta[layer_idx].unsqueeze(1)    # (B, 1, D)
                x = g_l * x + b_l                     # Eq. (14â€“15): hat{Z}_l

            # FFN + residual
            x = x + blk.drop_path(blk.mlp(blk.norm2(x)))

        x = self.v.norm(x)
        h_cls = (x[:, 0] + x[:, 1]) / 2  # pooled token
        return h_cls

    def forward(self, x, m):
        """
        x: (B, T, F) or (B, 1, F, T)
        m: (B, M) metadata
        """
        if x.dim() == 3:
            x = x.unsqueeze(1).transpose(2, 3)  # (B, 1, F, T)

        h_cls = self.forward_features(x, m)
        logits = self.classifier(h_cls)
        return logits

In [20]:
astfilm = ASTFiLM(
    ast_kwargs=ast_kwargs, 
    metadata_dim=10,
    conditioned_layers=range(12), # all layers
    metadata_hidden_dim=64, 
    film_hidden_dim=64,
    dropout_p=0.3, 
    num_labels=2
).to(DEVICE)
print(astfilm)
metadata  = torch.randn(2, 10).to(DEVICE)  # (B, M)
out = astfilm(dummy_input, metadata)  # (2, 2)
print(out.shape)  # torch.Size([2, 2])

---------------AST Model Summary---------------
ImageNet pretraining: True, AudioSet pretraining: True
Loading AudioSet pretrained model from /Users/gkont/Documents/Code/pretrained_models/audioset_10_10_0.4593.pth
No mismatch for key: v.cls_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.pos_embed
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.dist_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.attn.qkv.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.at

In [21]:
class ASTFiLMPlusPlus(ASTModel):
    """
    FiLM++: factor-aligned grouped FiLM with K=3 groups:
      - device group    (D_dev)
      - site group      (D_site)
      - rest group      (D_rest)

    Metadata inputs are passed separately as (m_dev, m_site, m_rest).
    """
    def __init__(
        self,
        dev_metadata_dim: int,
        site_metadata_dim: int,
        rest_metadata_dim: int,
        D_dev: int,
        D_site: int,
        conditioned_layers=(0, 1, 2, 3),
        metadata_hidden_dim: int = 64,
        film_hidden_dim: int = 64,
        num_labels: int = 2,
        dropout_p: float = 0.3,
        ast_kwargs: dict = None,
    ):
        ast_kwargs = ast_kwargs or {}
        super().__init__(backbone_only=True, **ast_kwargs)

        self.conditioned_layers = sorted(list(conditioned_layers))
        self.conditioned_layers_set = set(self.conditioned_layers)

        D_total = self.original_embedding_dim
        assert D_dev + D_site <= D_total, "D_dev + D_site must be <= total embedding dim"
        D_rest = D_total - D_dev - D_site

        self.D_dev = D_dev
        self.D_site = D_site
        self.D_rest = D_rest
        self.D_total = D_total

        # --- Metadata encoders for each factor ---
        self.dev_encoder = nn.Sequential(
            nn.LayerNorm(dev_metadata_dim),
            nn.Linear(dev_metadata_dim, metadata_hidden_dim),
            nn.ReLU(),
            nn.Linear(metadata_hidden_dim, film_hidden_dim),
            nn.ReLU(),
        )
        self.site_encoder = nn.Sequential(
            nn.LayerNorm(site_metadata_dim),
            nn.Linear(site_metadata_dim, metadata_hidden_dim),
            nn.ReLU(),
            nn.Linear(metadata_hidden_dim, film_hidden_dim),
            nn.ReLU(),
        )
        self.rest_encoder = nn.Sequential(
            nn.LayerNorm(rest_metadata_dim),
            nn.Linear(rest_metadata_dim, metadata_hidden_dim),
            nn.ReLU(),
            nn.Linear(metadata_hidden_dim, film_hidden_dim),
            nn.ReLU(),
        )

        # --- FiLM generators per group and per conditioned layer ---
        self.dev_film = nn.ModuleDict()
        self.site_film = nn.ModuleDict()
        self.rest_film = nn.ModuleDict()
        for l in self.conditioned_layers:
            self.dev_film[str(l)] = nn.Linear(film_hidden_dim, 2 * D_dev)
            self.site_film[str(l)] = nn.Linear(film_hidden_dim, 2 * D_site)
            self.rest_film[str(l)] = nn.Linear(film_hidden_dim, 2 * D_rest)

        # Classification head (same style as baseline)
        self.classifier = nn.Sequential(
            nn.LayerNorm(D_total),
            nn.Dropout(dropout_p),
            nn.Linear(D_total, 64),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.Linear(64, num_labels),
        )

    def _apply_filmpp_grouped(self, x, gammas, betas):
        """
        x:      (B, T, D_total)
        gammas: dict with 'dev','site','rest' tensors (B, D_group)
        betas:  same as above

        Applies group-wise FiLM and returns (B, T, D_total).
        """
        B, T, D = x.shape
        D_dev, D_site, D_rest = self.D_dev, self.D_site, self.D_rest

        x_dev, x_site, x_rest = torch.split(
            x, [D_dev, D_site, D_rest], dim=-1
        )  # each (B, T, D_group)

        # Broadcast gammas/betas over tokens
        g_dev = gammas["dev"].unsqueeze(1)   # (B, 1, D_dev)
        b_dev = betas["dev"].unsqueeze(1)
        g_site = gammas["site"].unsqueeze(1) # (B, 1, D_site)
        b_site = betas["site"].unsqueeze(1)
        g_rest = gammas["rest"].unsqueeze(1) # (B, 1, D_rest)
        b_rest = betas["rest"].unsqueeze(1)

        x_dev_hat = g_dev * x_dev + b_dev
        x_site_hat = g_site * x_site + b_site
        x_rest_hat = g_rest * x_rest + b_rest

        x_hat = torch.cat([x_dev_hat, x_site_hat, x_rest_hat], dim=-1)  # (B, T, D_total)
        return x_hat

    def forward_features(self, x, m_dev, m_site, m_rest):
        """
        x:      (B, 1, F, T)
        m_dev:  (B, dev_metadata_dim)
        m_site: (B, site_metadata_dim)
        m_rest: (B, rest_metadata_dim)
        """
        B = x.shape[0]

        # Patch embedding
        x = self.v.patch_embed(x)
        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        dist_token = self.v.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)

        # Encode metadata for each factor
        h_dev = self.dev_encoder(m_dev)
        h_site = self.site_encoder(m_site)
        h_rest = self.rest_encoder(m_rest)

        # Precompute gamma,beta for each conditioned layer
        gamma_dev, beta_dev = {}, {}
        gamma_site, beta_site = {}, {}
        gamma_rest, beta_rest = {}, {}

        for l in self.conditioned_layers:
            dev_params = self.dev_film[str(l)](h_dev)   # (B, 2*D_dev)
            site_params = self.site_film[str(l)](h_site) # (B, 2*D_site)
            rest_params = self.rest_film[str(l)](h_rest) # (B, 2*D_rest)

            g_dev, b_dev = dev_params.chunk(2, dim=-1)
            g_site, b_site = site_params.chunk(2, dim=-1)
            g_rest, b_rest = rest_params.chunk(2, dim=-1)

            gamma_dev[l], beta_dev[l] = g_dev, b_dev
            gamma_site[l], beta_site[l] = g_site, b_site
            gamma_rest[l], beta_rest[l] = g_rest, b_rest

        # Unroll ViT blocks with FiLM++ after MHSA
        for layer_idx, blk in enumerate(self.v.blocks):
            attn_out = blk.attn(blk.norm1(x))
            x = x + blk.drop_path(attn_out)   # (B, T, D_total)

            if layer_idx in self.conditioned_layers_set:
                gammas = {
                    "dev": gamma_dev[layer_idx],
                    "site": gamma_site[layer_idx],
                    "rest": gamma_rest[layer_idx],
                }
                betas = {
                    "dev": beta_dev[layer_idx],
                    "site": beta_site[layer_idx],
                    "rest": beta_rest[layer_idx],
                }
                x = self._apply_filmpp_grouped(x, gammas, betas)

            x = x + blk.drop_path(blk.mlp(blk.norm2(x)))

        x = self.v.norm(x)
        h_cls = (x[:, 0] + x[:, 1]) / 2
        return h_cls

    def forward(self, x, m_dev, m_site, m_rest):
        """
        x:      (B, T, F) or (B, 1, F, T)
        m_dev:  (B, dev_metadata_dim)
        m_site: (B, site_metadata_dim)
        m_rest: (B, rest_metadata_dim)
        """
        if x.dim() == 3:
            x = x.unsqueeze(1).transpose(2, 3)

        h_cls = self.forward_features(x, m_dev, m_site, m_rest)
        logits = self.classifier(h_cls)
        return logits

In [24]:
astfilmpp = ASTFiLMPlusPlus(
    dev_metadata_dim=4,
    site_metadata_dim=7,
    rest_metadata_dim=3,
    D_dev=128,
    D_site=128,
    ast_kwargs=ast_kwargs,
    conditioned_layers=(10, 11, 12), # last 3 layers
    metadata_hidden_dim=64, 
    film_hidden_dim=64,
    dropout_p=0.3, 
    num_labels=2
).to(DEVICE)
print(astfilmpp)

dev_metadata = torch.randn(2, 4).to(DEVICE)
site_metadata = torch.randn(2, 7).to(DEVICE)
rest_metadata = torch.randn(2, 3).to(DEVICE)
out = astfilmpp(dummy_input, dev_metadata, site_metadata, rest_metadata)  # (2, 2)
print(out.shape)  # torch.Size([2, 2])

---------------AST Model Summary---------------
ImageNet pretraining: True, AudioSet pretraining: True
Loading AudioSet pretrained model from /Users/gkont/Documents/Code/pretrained_models/audioset_10_10_0.4593.pth
No mismatch for key: v.cls_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.pos_embed
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.dist_token
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.patch_embed.proj.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.norm1.bias
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.attn.qkv.weight
No mlp_head weights loaded from AudioSet checkpoint.
No mismatch for key: v.blocks.0.at