In [8]:
import yaml
with open('../config/dcase20_proteacher.yaml') as f:
    cfg = yaml.safe_load(f)

In [6]:
ls .

 dcase19_baseline.yaml               dcase20_jointformer.yaml
'dcase19_jointformer copy.yaml'      dcase20_proteacher.yaml
 dcase19_jointformer_feat_rec.yaml   dcase21_baseline.yaml
 dcase19_jointformer.yaml            dcase21_jointformer.yaml
 dcase20_baseline.yaml               Untitled-1.ipynb


In [11]:
# %load src/models/pro_teacher.py
import torch
import math
from models.conformer.conformer_encoder import ConformerPromptedEncoder
from models.conformer.downsampler import CNNLocalDownsampler

    
class SEDModel(torch.nn.Module):
    def __init__(
        self,
        n_class,
        cnn_kwargs=None,
        encoder_kwargs=None,
        pooling="token",
        layer_init="pytorch",
    ):
        super(SEDModel, self).__init__()
        self.cnn_downsampler = CNNLocalDownsampler(n_in_channel=1, **cnn_kwargs)
        input_dim = self.cnn_downsampler.cnn.nb_filters[-1]
        adim = encoder_kwargs["adim"]
        self.pooling = pooling
        self.encoder = ConformerPromptedEncoder(input_dim, **encoder_kwargs)
        self.pred_head = torch.nn.Linear(adim, n_class)

        if self.pooling == "attention":
            self.dense = torch.nn.Linear(adim, n_class)
            self.sigmoid = torch.sigmoid
            self.softmax = torch.nn.Softmax(dim=-1)

        elif self.pooling == "token":
            # self.cls_token = torch.nn.Linear(1, input_dim)
            self.tag_token = torch.nn.Parameter(torch.zeros(1, 1, input_dim))
        self.dropout = torch.nn.Dropout(0.1)
        self.reset_parameters(layer_init)

    def forward(self, x, mask=None, prompt_tuning=True):
        x = self.cnn_downsampler(x)
        # x = x.squeeze(-1).permute(0, 2, 1)
        seq_len = x.size(1)
        if self.pooling == "token":
            x = torch.cat([self.tag_token.expand(x.size(0), -1, -1), x], dim=1)
            
        x, _ = self.encoder(x, mask, prompt_tuning)
        
        # clip head input
        x = torch.cat([x[:, 0:1], x[:, - seq_len:]], dim=1)
        if self.pooling == "attention":
            strong = self.pred_head(x)
            sof = self.dense(x)  # [bs, frames, nclass]
            sof = self.softmax(sof)
            sof = torch.clamp(sof, min=1e-7, max=1)
            weak = (torch.sigmoid(strong) * sof).sum(1) / sof.sum(1)  # [bs, nclass]
            # Convert to logit to calculate loss with bcelosswithlogits
            weak = torch.log(weak / (1 - weak))
            
        elif self.pooling == "token":
            x = self.pred_head(x)
            weak = x[:, 0, :]
            strong = x[:, 1:, :]
            
        return {"strong": strong, "weak": weak}

    def reset_parameters(self, initialization: str = "pytorch"):
        if initialization.lower() == "pytorch":
            return
        # weight init
        for p in self.parameters():
            if p.dim() > 1:
                if initialization.lower() == "xavier_uniform":
                    torch.nn.init.xavier_uniform_(p.data)
                elif initialization.lower() == "xavier_normal":
                    torch.nn.init.xavier_normal_(p.data)
                elif initialization.lower() == "kaiming_uniform":
                    torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
                elif initialization.lower() == "kaiming_normal":
                    torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
                else:
                    raise ValueError(f"Unknown initialization: {initialization}")
        # bias init
        for p in self.parameters():
            if p.dim() == 1:
                p.data.zero_()
        # reset some modules with default init
        for m in self.modules():
            if isinstance(m, (torch.nn.Embedding, LayerNorm)):
                m.reset_parameters()
                
    def get_masked_parameters(self, mask_param=None):
        if mask_param is not None:
            return self.parameters()

In [9]:
cfg

{'feature': {'audio_root': '/data0/gaolj/sed_data/DCASE2020/audio',
  'feat_root': '/data0/gaolj/sed_data/DCASE2020/features',
  'sample_rate': 16000,
  'gain': -3,
  'highpass': 10,
  'mel_spec': {'n_mels': 64, 'n_fft': 1024, 'hop_size': 323}},
 'model': {'cnn': {'activation': 'Relu',
   'conv_dropout': 0.1,
   'kernel_size': [3, 3, 3, 3],
   'padding': [1, 1, 1, 1],
   'stride': [1, 1, 1, 1],
   'nb_filters': [16, 32, 64, 128],
   'pooling': [[2, 4], [2, 2], [2, 2], [1, 2]],
   'patchsize': 8},
  'encoder_type': 'Conformer',
  'encoder': {'adim': 144,
   'aheads': 4,
   'dropout_rate': 0.1,
   'elayers': 3,
   'eunits': 576,
   'kernel_size': 7,
   'prompt_nums': 10,
   'prompt_layers': 1},
  'decoder': {'idim': 144,
   'adim': 144,
   'fdim': 64,
   'aheads': 4,
   'dropout_rate': 0.1,
   'elayers': 2,
   'eunits': 256,
   'kernel_size': 7,
   'cnn_upsampler': {'activation': 'Relu',
    'conv_dropout': 0.1,
    'kernel_size': [2, 2, 2],
    'padding': [0, 0, 0],
    'stride': [2, 2,

In [10]:
model = SEDModel(n_class=10, cnn_kwargs=cfg["model"]["cnn"],
                     encoder_kwargs=cfg["model"]["encoder"],)

In [14]:
x = torch.ones(2, 486, 64)
_ = model(x)

RuntimeError: shape '[2, 60, 1, 8, 64]' is invalid for input of size 62208