In [1]:
import h5py
import numpy as np
from collections import Counter
from sklearn.model_selection import train_test_split
import imblearn
from imblearn.over_sampling import SMOTE
from imblearn.over_sampling import RandomOverSampler
import os
import warnings
import yaml
import argparse
import pandas as pd
import csv
import os
import pandas as pd
from Feature_extract import feature_transform
from Datagenerator import Datagen
import numpy as np
import torch
from torch.utils.data import DataLoader
from Model import ProtoNet,ResNet
from tqdm import tqdm
from collections import Counter
from batch_sampler import EpisodicBatchSampler
from torch.nn import functional as F
from util import prototypical_loss as loss_fn
from util import evaluate_prototypes
from glob import glob
import hydra
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
from omegaconf import DictConfig, OmegaConf
import pathlib
import pprint
import SepTr
# Clear any previously initialised global Hydra instance (presumably so this cell
# can be re-run in the same kernel before hydra.initialize below).
hydra.core.global_hydra.GlobalHydra.instance().clear()
import gc
# Release Python garbage and cached CUDA memory left over from earlier cell runs.
gc.collect()
torch.cuda.empty_cache()

# Hydra config search path: the notebook's working directory.
config_dir = pathlib.Path('.')
hydra.initialize(config_path=config_dir)

# Compose config_online.yaml with the training flag forced on.
conf = hydra.compose(config_name='config_online.yaml', overrides=["set.train=True",])

def init_seed(seed=0):
    """Seed every global RNG the pipeline may draw from, for reproducible runs.

    The previous version seeded only PyTorch. Any sampling that uses NumPy's or
    Python's global RNG was therefore not reproducible (for example, a batch
    sampler built on ``np.random``; verify which samplers do this).

    Args:
        seed: integer seed applied to ``random``, NumPy and PyTorch (CPU and all
            CUDA devices). Defaults to 0, the value used before.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Lazily applied; safe on machines without CUDA.
    torch.cuda.manual_seed_all(seed)

# In this notebook we will:
- build a dataloader that loads raw audio
- apply transformations to the audio
- build a dataloader that uses these transformations
- TODO: further steps still to be decided

# Load a balanced raw-audio dataset
- How should the raw audio be loaded?
- How should the classes be balanced?

In [2]:
import os
import librosa
import h5py
import pandas as pd
import numpy as np
from scipy import signal
from glob import glob
from itertools import chain

# Disable pandas' SettingWithCopyWarning globally (it fires when assigning into
# filtered dataframe slices).
pd.set_option('mode.chained_assignment', None)



def create_dataset(df_pos,pcen,glob_cls_name,file_name,hf,seg_len,hop_seg,fps):
    '''Chunk the time-frequency representation into fixed-length segments and append them to hf['features'].

    Args:
        -df_pos : dataframe of positive annotations with 'Starttime'/'Endtime' columns (seconds)
        -pcen : PCEN features with frames along axis 0
        -glob_cls_name: Name of the class used in audio files where only one class is present
        -file_name : Name of the csv file (not used here; kept for interface compatibility)
        -hf: h5py object holding a resizable 'features' dataset
        -seg_len : fixed segment length (frames)
        -hop_seg : hop between consecutive segments of one event (frames)
        -fps: frames per second
    Out:
        - label_list: list of labels for the extracted patches, one per stored segment

    Raises:
        ValueError: if an event longer than seg_len has to be windowed while
            hop_seg <= 0 (the sliding-window loop would never terminate).
    '''
    features = hf['features']
    label_list = []
    # Number of rows already stored. Reading the shape avoids len(hf['features'][:]),
    # which loaded the entire dataset from disk on every call.
    file_index = features.shape[0]

    def _append(patch, label):
        '''Grow the dataset by one row, store ``patch`` there and record its label.'''
        nonlocal file_index
        # Resize with this patch's own shape (the old code reused the previous window's shape).
        features.resize((file_index + 1, patch.shape[0], patch.shape[1]))
        features[file_index] = patch
        label_list.append(label)
        file_index += 1

    start_time, end_time = time_2_frame(df_pos, fps)

    # For csv files with a column named CALL, every event gets the global class name;
    # otherwise each row contributes the names of the columns marked 'POS'.
    if 'CALL' in df_pos.columns:
        cls_list = [glob_cls_name] * len(start_time)
    else:
        pos_mask = (df_pos == 'POS').to_numpy()
        cls_list = list(chain.from_iterable(df_pos.columns[row].values for row in pos_mask))

    assert len(start_time) == len(end_time)
    assert len(cls_list) == len(start_time)

    for str_ind, end_ind, label in zip(start_time, end_time, cls_list):
        # Onsets are shifted earlier by a margin and may become negative; a negative
        # index would slice from the end of the array, so clamp to the first frame.
        str_ind = max(str_ind, 0)

        if end_ind - str_ind > seg_len:
            if hop_seg <= 0:
                raise ValueError("hop_seg must be positive to window events longer than seg_len")
            # Slide a seg_len window forward by hop_seg ...
            shift = 0
            while end_ind - (str_ind + shift) > seg_len:
                _append(pcen[int(str_ind + shift):int(str_ind + shift + seg_len)], label)
                shift += hop_seg
            # ... then add one window ending exactly at the event end so the tail is covered.
            _append(pcen[end_ind - seg_len:end_ind], label)
        else:
            # If the event is shorter than the segment length, tile it until it
            # reaches the segment length.
            pcen_patch = pcen[str_ind:end_ind]
            if pcen_patch.shape[0] == 0:
                print(pcen_patch.shape[0])
                print("The patch is of 0 length")
                continue
            repeat_num = int(seg_len / (pcen_patch.shape[0])) + 1
            pcen_patch_new = np.tile(pcen_patch, (repeat_num, 1))[0:int(seg_len)]
            _append(pcen_patch_new, label)

    print("Total files created : {}".format(file_index))
    return label_list

class Feature_Extractor():
    '''Computes PCEN-normalised mel spectrograms using the settings in ``conf.features``.'''

    def __init__(self, conf):
        self.sr = conf.features.sr
        self.n_fft = conf.features.n_fft
        self.hop = conf.features.hop_mel
        self.n_mels = conf.features.n_mels
        self.fmax = conf.features.fmax
        #self.win_length = conf.features.win_length

    def extract_feature(self, audio):
        '''Return PCEN(mel spectrogram of ``audio``) as float32 with shape (n_mels, frames).

        Args:
            audio: 1-D waveform sampled at ``self.sr``.
        '''
        # librosa >= 0.10 makes the signal keyword-only, so pass it as ``y=``.
        mel_spec = librosa.feature.melspectrogram(y=audio, sr=self.sr, n_fft=self.n_fft,
                                                  hop_length=self.hop, n_mels=self.n_mels, fmax=self.fmax)
        # Use the configured sample rate. The old hard-coded sr=22050 mis-scaled the
        # PCEN time constant whenever conf.features.sr differed.
        # NOTE(review): pcen also accepts hop_length (librosa default 512), which may
        # differ from self.hop; left at the default to keep features comparable with
        # already-trained models. Confirm whether hop_length=self.hop is intended.
        pcen = librosa.pcen(mel_spec, sr=self.sr)
        return pcen.astype(np.float32)

def extract_feature(audio_path,feat_extractor,conf):
    '''Load an audio file and return its PCEN features with frames along the first axis.

    The waveform is loaded and resampled to ``conf.features.sr``, multiplied by
    2**32, passed to ``feat_extractor.extract_feature``, and the result is transposed.
    '''
    waveform, _ = librosa.load(audio_path, sr=conf.features.sr)
    # Scale up the float waveform before PCEN, following the librosa documentation's
    # suggestion. NOTE(review): librosa's pcen example uses 2**31; kept at 2**32 here
    # for compatibility with existing features.
    scaled_waveform = waveform * (2 ** 32)
    features = feat_extractor.extract_feature(scaled_waveform)
    return features.T



def time_2_frame(df,fps):
    '''Convert annotation onsets/offsets (seconds) to frame indices, widened by a 25 ms margin.

    Args:
        df: dataframe with 'Starttime' and 'Endtime' columns in seconds. It is NOT
            modified. The previous version shifted these columns in place, so calling
            it twice on the same dataframe applied the margin twice.
        fps: frames per second of the feature representation.
    Returns:
        (start_time, end_time): lists of Python ints. Start frames are clamped at 0,
        so events beginning within the first 25 ms do not yield negative indices.
        A negative index would silently slice from the end of the feature array.
    '''
    margin = 0.025  # seconds added before each onset and after each offset
    start_frames = np.floor((df['Starttime'] - margin) * fps)
    end_frames = np.floor((df['Endtime'] + margin) * fps)

    start_time = [max(int(start), 0) for start in start_frames]
    end_time = [int(end) for end in end_frames]
    return start_time, end_time

def feature_transform(conf=None,mode=None):
    '''
       Training:
          Extract PCEN features and slice each annotated event into segments of length
          conf.features.seg_len. Each segment inherits the clip-level label.
       Evaluation:
           Currently uses the validation set for evaluation.

           For each audio file, extract the time-frequency representation and create 3 subsets
           using a per-file adaptive segment length (derived from the longest support event):
           a) Positive set - segments from the first conf.train.n_shot positive annotations.
           b) Negative set - no negative annotations are provided, so the entire audio file
                             is treated as the negative class and cut into patches.
           c) Query set - from the end time of the last support annotation to the end of the
                          audio file. Onset-offset prediction is made on this subset.

       Args:
       - conf: config object
       - mode: 'train' processes the training set; any other value processes the validation set

       Out:
       - train: (num_extract, data_shape) - number of stored segments and the shape of 'features'
       - validation: num_extract_eval - total number of query segments over all files
    '''
    pcen_extractor = Feature_Extractor(conf)

    fps = conf.features.sr / conf.features.hop_mel
    # Convert the fixed segment length / hop (seconds) to frames. Only the training set
    # uses these; the evaluation set chooses a per-file length.
    seg_len = int(round(conf.features.seg_len * fps))
    hop_seg = int(round(conf.features.hop_seg * fps))

    if mode == 'train':
        return _transform_train(conf, pcen_extractor, fps, seg_len, hop_seg)
    return _transform_eval(conf, pcen_extractor, fps)


def _find_csv_files(root):
    '''Return every ``*.csv`` file below ``root`` (recursive, in os.walk order).'''
    return [file
            for path_dir, _subdirs, _files in os.walk(root)
            for file in glob(os.path.join(path_dir, "*.csv"))]


def _wav_path(csv_path):
    '''Return the audio path paired with an annotation csv: same location and stem, ``.wav`` extension.'''
    # Only the extension is swapped; str.replace('csv', 'wav') also rewrote any
    # 'csv' occurring in directory or file names.
    return os.path.splitext(csv_path)[0] + '.wav'


def _append_patch(dset, idx, patch):
    '''Grow the resizable h5py dataset ``dset`` to ``idx + 1`` rows, store ``patch`` at row ``idx``; return ``idx + 1``.'''
    # Resize with this patch's own shape (the old code sometimes reused the shape of
    # the previously written patch).
    dset.resize((idx + 1, patch.shape[0], patch.shape[1]))
    dset[idx] = patch
    return idx + 1


def _write_windows(dset, pcen, start, stop, seg_len, hop_seg, idx=0):
    '''Append sliding windows over ``pcen[start:stop]`` to ``dset``, starting at row ``idx``.

    A window of ``seg_len`` frames is written every ``hop_seg`` frames while more than
    ``seg_len`` frames remain before ``stop``. A final window ending exactly at ``stop``
    is always written. Returns the next free row index.
    '''
    offset = start
    while stop - offset > seg_len:
        idx = _append_patch(dset, idx, pcen[int(offset):int(offset + seg_len)])
        offset += hop_seg
    return _append_patch(dset, idx, pcen[stop - seg_len:stop])


def _adaptive_seg_len(max_len):
    '''Choose the evaluation segment length (frames) from the longest support event.'''
    # These thresholds were tuned to fit a 12GB GPU, since some segments are quite long.
    # The ranges are contiguous. The old test `max_len > 100 and max_len < 500`
    # sent an event of exactly 100 frames to max_len // 8.
    if max_len < 100:
        return max_len
    if max_len < 500:
        return max_len // 4
    return max_len // 8


def _transform_train(conf, pcen_extractor, fps, seg_len, hop_seg):
    '''Extract fixed-length training segments for every csv into <feat_train>/Mel_train.h5.'''
    print("=== Processing training set ===")
    # NOTE(review): only the first 100 csv files are processed. This looks like a
    # debugging limit; confirm before a full training run.
    all_csv_files = _find_csv_files(conf.path.train_dir)[:100]
    hdf_tr = os.path.join(conf.path.feat_train, 'Mel_train.h5')
    label_tr = []

    with h5py.File(hdf_tr, 'w') as hf:
        hf.create_dataset('features', shape=(0, seg_len, conf.features.n_mels),
                          maxshape=(None, seg_len, conf.features.n_mels))
        for file in all_csv_files:
            # Paths are expected to look like .../Training_Set/<class>/<file>.csv
            split_list = os.path.normpath(file).split(os.sep)
            glob_cls_name = split_list[split_list.index('Training_Set') + 1]
            file_name = split_list[split_list.index('Training_Set') + 2]
            df = pd.read_csv(file, header=0, index_col=False)
            audio_path = _wav_path(file)
            print("Processing file name {}".format(audio_path))
            pcen = extract_feature(audio_path, pcen_extractor, conf)
            df_pos = df[(df == 'POS').any(axis=1)]
            label_tr.extend(create_dataset(df_pos, pcen, glob_cls_name, file_name, hf, seg_len, hop_seg, fps))
        print(" Feature extraction for training set complete")
        num_extract = len(hf['features'])
        hf.create_dataset('labels', data=[s.encode() for s in label_tr], dtype='S20')
        data_shape = hf['features'].shape
    return num_extract, data_shape


def _transform_eval(conf, pcen_extractor, fps):
    '''Write one <name>.h5 per validation csv containing positive, negative and query segments.'''
    print("=== Processing Validation set ===")
    n_mels = conf.features.n_mels
    num_extract_eval = 0

    for file in _find_csv_files(conf.path.eval_dir):
        name = str(os.path.basename(file).split('.')[0])
        hdf_eval = os.path.join(conf.path.feat_eval, name + '.h5')
        audio_path = _wav_path(file)

        df_eval = pd.read_csv(file, header=0, index_col=False)
        Q_list = df_eval['Q'].to_numpy()
        start_time, end_time = time_2_frame(df_eval, fps)

        # Support set: the first n_shot positive events of the file.
        index_sup = np.where(Q_list == 'POS')[0][:conf.train.n_shot]
        if len(index_sup) == 0:
            raise ValueError("No 'POS' annotations found in {}".format(file))

        # Adaptive segment length based on the longest support event.
        max_len = max(end_time[i] - start_time[i] for i in index_sup)
        seg_len = _adaptive_seg_len(max_len)
        if seg_len < 1:
            raise ValueError("Non-positive segment length {} for {}".format(seg_len, file))
        print(f"Segment length for file is {seg_len}")
        # At least one frame, otherwise the sliding-window loops never advance.
        hop_seg = max(seg_len // 2, 1)

        pcen = extract_feature(audio_path, pcen_extractor, conf)
        # End of the negative and query ranges (the final frame is excluded, as before).
        end_idx_neg = pcen.shape[0] - 1
        # Query segments start where the last support event ends.
        strt_indx_query = end_time[index_sup[-1]]

        with h5py.File(hdf_eval, 'w') as hf:
            for key in ('feat_pos', 'feat_query', 'feat_neg'):
                hf.create_dataset(key, shape=(0, seg_len, n_mels), maxshape=(None, seg_len, n_mels))
            # Scalar metadata stored as fixed-size 1-element datasets.
            for key, value in (('start_index_query', strt_indx_query), ('seg_len', seg_len), ('hop_seg', hop_seg)):
                hf.create_dataset(key, shape=(1,))
                hf[key][:] = value

            print("Creating negative dataset")
            _write_windows(hf['feat_neg'], pcen, 0, end_idx_neg, seg_len, hop_seg)

            print("Creating Positive dataset")
            idx_pos = 0
            for index in index_sup:
                str_ind = int(start_time[index])
                end_ind = int(end_time[index])
                if end_ind - str_ind > seg_len:
                    idx_pos = _write_windows(hf['feat_pos'], pcen, str_ind, end_ind, seg_len, hop_seg, idx=idx_pos)
                    continue
                # Short event: tile it until it fills seg_len frames.
                patch_pos = pcen[str_ind:end_ind]
                if patch_pos.shape[0] == 0:
                    print(patch_pos.shape[0])
                    print("The patch is of 0 length")
                    continue
                repeat_num = int(seg_len / (patch_pos.shape[0])) + 1
                patch_new = np.tile(patch_pos, (repeat_num, 1))[0:int(seg_len)]
                idx_pos = _append_patch(hf['feat_pos'], idx_pos, patch_new)

            print("Creating query dataset")
            _write_windows(hf['feat_query'], pcen, strt_indx_query, end_idx_neg, seg_len, hop_seg)
            num_extract_eval += len(hf['feat_query'])

    return num_extract_eval


In [3]:
# chunk up the original and store in dataset

In [4]:
# Training cell: build episodic loaders over the Datagen splits and fit the encoder.
# NOTE: train_protonet is defined in a later cell (In [5]). Run that cell first,
# otherwise the final call raises NameError (see the recorded output below).
if conf.set.train:
    os.makedirs(conf.path.Model, exist_ok=True)

    init_seed()

    # Raw arrays from the data generator, converted to tensors for TensorDataset.
    gen_train = Datagen(conf)
    X_train, Y_train, X_val, Y_val = gen_train.generate_train()
    X_tr, Y_tr = torch.tensor(X_train), torch.LongTensor(Y_train)
    X_val, Y_val = torch.tensor(X_val), torch.LongTensor(Y_val)

    # Episode size: presumably k_way classes x (2 * n_shot) samples each; confirm against EpisodicBatchSampler.
    samples_per_cls = conf.train.n_shot * 2
    batch_size_tr = samples_per_cls * conf.train.k_way
    batch_size_vd = batch_size_tr

    if conf.train.num_episodes is None:
        # Fall back to as many full episodes as the data allows.
        num_episodes_tr = len(Y_train) // batch_size_tr
        num_episodes_vd = len(Y_val) // batch_size_vd
    else:
        num_episodes_tr = num_episodes_vd = conf.train.num_episodes

    samplr_train = EpisodicBatchSampler(Y_train, num_episodes_tr, conf.train.k_way, samples_per_cls)
    samplr_valid = EpisodicBatchSampler(Y_val, num_episodes_vd, conf.train.k_way, samples_per_cls)

    train_dataset = torch.utils.data.TensorDataset(X_tr, Y_tr)
    valid_dataset = torch.utils.data.TensorDataset(X_val, Y_val)

    # The batch samplers decide episode composition, so DataLoader shuffling stays off.
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_sampler=samplr_train,
                                               num_workers=0, pin_memory=True, shuffle=False)
    valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_sampler=samplr_valid,
                                               num_workers=0, pin_memory=True, shuffle=False)

    encoder = ResNet() if conf.train.encoder == 'Resnet' else ProtoNet()

    # NOTE: best_acc is the best *validation* accuracy, despite the message wording.
    best_acc, model = train_protonet(encoder, train_loader, valid_loader, conf, num_episodes_tr, num_episodes_vd)
    print("Best accuracy of the model on training set is {}".format(best_acc))

NameError: name 'train_protonet' is not defined

In [5]:
def train_protonet(encoder,train_loader,valid_loader,conf,num_batches_tr,num_batches_vd):
    '''Train a prototypical-network encoder, checkpointing the best validation accuracy.

    Args:
    -encoder: model mapping a batch of inputs to embeddings
    -train_loader: episodic training loader yielding (x, y)
    -valid_loader: episodic validation loader yielding (x, y)
    -conf: configuration object (reads conf.train.* and conf.path.best_model / last_model)
    -num_batches_tr: number of training batches per epoch. Kept for backward
        compatibility; averages now use the batches actually seen in each epoch.
    -num_batches_vd: number of validation batches per epoch (same note)
    Out:
    -best_val_acc: best epoch-mean validation accuracy (0.0 if it never exceeded 0)
    -encoder: the trained encoder (weights after the final epoch)

    Side effects: saves {'encoder': state_dict} to conf.path.best_model whenever the
    validation accuracy improves, and to conf.path.last_model after the final epoch.
    '''
    device = torch.device('cuda' if conf.train.device == 'cuda' else 'cpu')

    optim = torch.optim.Adam([{'params': encoder.parameters()}], lr=conf.train.lr_rate)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optim, gamma=conf.train.scheduler_gamma,
                                                   step_size=conf.train.scheduler_step_size)
    best_model_path = conf.path.best_model
    last_model_path = conf.path.last_model
    best_val_acc = 0.0
    encoder.to(device)

    for epoch in range(conf.train.epochs):
        print("Epoch {}".format(epoch))

        # Per-epoch accumulators. Slicing a global history with [-num_batches:] mixed
        # epochs when the count was off, and took the whole history when it was 0.
        encoder.train()
        epoch_loss_tr, epoch_acc_tr = [], []
        for x, y in tqdm(train_loader):
            optim.zero_grad()
            x, y = x.to(device), y.to(device)
            tr_loss, tr_acc = loss_fn(encoder(x), y, conf.train.n_shot)
            epoch_loss_tr.append(tr_loss.item())
            epoch_acc_tr.append(tr_acc.item())
            tr_loss.backward()
            optim.step()

        avg_loss_tr = np.mean(epoch_loss_tr)
        avg_acc_tr = np.mean(epoch_acc_tr)
        print('Average train loss: {}  Average training accuracy: {}'.format(avg_loss_tr, avg_acc_tr))
        lr_scheduler.step()

        encoder.eval()
        epoch_loss_vd, epoch_acc_vd = [], []
        # Validation needs no gradients; without no_grad every batch built an
        # autograd graph and wasted memory.
        with torch.no_grad():
            for x, y in tqdm(valid_loader):
                # Targets follow the inputs to the device, matching the training loop.
                x, y = x.to(device), y.to(device)
                valid_loss, valid_acc = loss_fn(encoder(x), y, conf.train.n_shot)
                epoch_loss_vd.append(valid_loss.item())
                epoch_acc_vd.append(valid_acc.item())
        avg_loss_vd = np.mean(epoch_loss_vd)
        avg_acc_vd = np.mean(epoch_acc_vd)

        print('Epoch {}, Validation loss {:.4f}, Validation accuracy {:.4f}'.format(epoch, avg_loss_vd, avg_acc_vd))
        if avg_acc_vd > best_val_acc:
            print("Saving the best model with valdation accuracy {}".format(avg_acc_vd))
            best_val_acc = avg_acc_vd
            torch.save({'encoder': encoder.state_dict()}, best_model_path)
    torch.save({'encoder': encoder.state_dict()}, last_model_path)

    return best_val_acc, encoder
# train_protonet(encoder,train_loader,valid_loader,conf,num_episodes_tr,num_episodes_vd)

In [6]:
# Regenerate the raw train/validation arrays and convert them to tensors.
# NOTE(review): duplicates the data-loading part of the training cell above.
gen_train = Datagen(conf)
X_train, Y_train, X_val, Y_val = gen_train.generate_train()
X_tr, Y_tr = torch.tensor(X_train), torch.LongTensor(Y_train)
X_val, Y_val = torch.tensor(X_val), torch.LongTensor(Y_val)


In [7]:
# hdf_path = '/home/asalimi/dcase-few-shot-bioacoustic/baselines/deep_learning/outputs/2022-06-06/11-10-25/Features/feat_train/Mel_train.h5'
# hdf_train = h5py.File(hdf_path, 'r+')
# hdf_train['features'][:]

In [9]:
# Pull a single training episode onto the CPU for interactive inspection.
x, y = (tensor.to("cpu") for tensor in next(iter(train_loader)))
# model = ResNet()


In [10]:
# Shape of the sample episode batch. Axis meaning is presumably (samples, frames, mel bins); confirm against Datagen.
x.shape

torch.Size([50, 17, 128])

# Separable transformer (SepTr) test

In [18]:
# This code is released under the CC BY-SA 4.0 license.

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange


# helpers
def pair(t):
    """Return ``t`` unchanged if it is already a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)


# classes
class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension, then the wrapped module ``fn``.

    Extra keyword arguments given to ``forward`` are passed through to ``fn``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Attribute names form the state_dict keys; keep them stable for checkpoints.
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)


class FeedForward(nn.Module):
    """Token-wise MLP: Linear(dim -> hidden_dim), GELU, Dropout, Linear(hidden_dim -> dim), Dropout."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        # A single Sequential named `net` keeps the state_dict keys (net.0.*, net.3.*) unchanged.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a (batch, tokens, dim) input.

    Args:
        dim: input/output feature width.
        heads: number of attention heads.
        dim_head: feature width per head; queries/keys/values are heads * dim_head wide.
        dropout: dropout applied after the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = heads * dim_head
        # The output projection is only skippable when one head already has width `dim`.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        # One projection, split into q/k/v, heads moved forward: (b, n, h*d) -> (b, h, n, d).
        q, k, v = (rearrange(part, 'b n (h d) -> b h n d', h=self.heads)
                   for part in self.to_qkv(x).chunk(3, dim=-1))

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)

        # Weighted values, heads merged back: (b, h, n, d) -> (b, n, h*d).
        merged = rearrange(torch.matmul(weights, v), 'b h n d -> b n (h d)')
        return self.to_out(merged)


class Transformer(nn.Module):
    """Stack of `depth` pre-norm layers; each applies attention then a feed-forward MLP, both residual."""

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attention_block, feedforward_block in self.layers:
            x = x + attention_block(x)
            x = x + feedforward_block(x)
        return x


class Scale(nn.Module):
    """Multiply the input by a fixed, non-learnable factor ``val``."""

    def __init__(self, val):
        super().__init__()
        self.val = val

    def forward(self, x):
        return self.val * x


class SepTrBlock(nn.Module):
    """One separable-transformer block: attention along the height axis, then along the width axis.

    The input ``x`` is a 4-D tensor laid out as (b, c, h, w), which the einops
    patterns below fix. ``cls_token`` has shape (1, 1, n, dim) or (b, 1, n, dim).
    ``forward`` returns the transformed feature map and an updated class token
    of shape (b, 1, 1, dim).

    Args:
        channels: number of input channels ``c``; only used by the optional
            projection / reconstruction layers.
        input_size: (h, w) of the input, or a single int for a square input.
        heads, mlp_dim, dim_head: forwarded to both inner Transformers (depth 1 each).
        down_sample_input: optional (kh, kw) average-pooling factor applied to the
            input; the output is upsampled back by the same factor.
        project: if True, linearly project channels -> dim before attention.
        reconstruct: if True, map dim -> channels after attention and scale by
            dim ** -0.5. NOTE(review): this Linear acts on the last axis of the
            output map, which appears to be the width axis rather than dim. Verify
            before enabling reconstruct=True.
        dim: token width inside the transformers.
        dropout_tr: dropout used inside the transformers.
    """
    def __init__(self, channels, input_size, heads=3, mlp_dim=128, dim_head=32,
                 down_sample_input=None, project=False, reconstruct=False, dim=128, dropout_tr=0.0):
        super().__init__()
        patch_height, patch_width = pair(input_size)
        # Identity defaults; replaced below depending on the constructor flags.
        self.avg_pool = nn.Identity()
        self.upsample = nn.Identity()
        self.projection = nn.Identity()
        self.reconstruction = nn.Identity()

        if down_sample_input is not None:
            # Attention runs on the pooled map, so the token counts shrink accordingly.
            patch_height = patch_height // down_sample_input[0]
            patch_width = patch_width // down_sample_input[1]

            self.avg_pool = nn.AvgPool2d(kernel_size=down_sample_input)
            self.upsample = nn.UpsamplingNearest2d(scale_factor=down_sample_input)

        if project:
            self.projection = nn.Linear(channels, dim)
        if reconstruct:
            self.reconstruction = nn.Sequential(
                nn.Linear(dim, channels),
                Scale(dim ** -0.5)
            )

        # Channels-last layouts so each column (h pass) or row (w pass) becomes a token sequence.
        self.rearrange_patches_h = Rearrange('b c h w -> b w h c')
        self.rearrange_patches_w = Rearrange('b c h w -> b h w c')

        # Fold the second axis into the batch so the transformer sees (batch', tokens, dim).
        self.rearrange_in_tr = Rearrange('b c h w -> (b c) h w')
        self.rearrange_out_tr_h = Rearrange('(b c) h w -> b w h c', c=patch_width)
        self.rearrange_out_tr_w = Rearrange('(b c) h w -> b w c h', c=patch_height)

        # Learned positional embeddings; the +1 slot is for the prepended class token.
        self.pos_embedding_w = nn.Parameter(torch.randn(1, 1, patch_width + 1, dim))
        self.pos_embedding_h = nn.Parameter(torch.randn(1, 1, patch_height + 1, dim))
        self.transformer_w = Transformer(dim, 1, heads, dim_head, mlp_dim, dropout_tr)
        self.transformer_h = Transformer(dim, 1, heads, dim_head, mlp_dim, dropout_tr)

    def forward(self, x, cls_token):
        x = self.avg_pool(x)

        # H inference
        # (b, c, h, w) -> (b, w, h, c), optionally projected to dim on the last axis.
        h = self.rearrange_patches_h(x)
        h = self.projection(h)

        # Broadcast the class token to one copy per (batch, column).
        dim1, dim2, _, _ = h.shape
        if cls_token.shape[0] == 1:
            cls_token = repeat(cls_token, '() () n d -> b w n d', b=dim1, w=dim2)
        else:
            cls_token = repeat(cls_token, 'b () n d -> b w n d', w=dim2)

        # Prepend the class token along the height axis and add height positional embeddings.
        h = torch.cat((cls_token, h), dim=2)
        h += self.pos_embedding_h

        # Each column becomes a sequence of h + 1 tokens, so attention runs along the height axis.
        h = self.rearrange_in_tr(h)
        h = self.transformer_h(h)
        h = self.rearrange_out_tr_h(h)

        # W inference
        # Drop the class-token slot and lay the map out so rows become token sequences.
        w = self.rearrange_patches_w(h[:, :, 1:, :])

        # Average the class-token outputs of the height pass, then repeat once per row.
        cls_token = h[:, :, 0, :].unsqueeze(2)
        cls_token = repeat(cls_token.mean((-1, -2)).unsqueeze(1).unsqueeze(1), 'b () d2 e -> b d1 d2 e', d1=w.shape[1])

        w = torch.cat((cls_token, w), dim=2)
        w += self.pos_embedding_w

        # Attention along the width axis.
        w = self.rearrange_in_tr(w)
        w = self.transformer_w(w)
        w = self.rearrange_out_tr_w(w)

        # Feature map without the class-token slot, restored to the pre-pooling resolution.
        x = self.upsample(w[:, :, :, 1:])
        x = self.reconstruction(x)

        # New class token: mean of the width-pass class-token outputs, shaped (b, 1, 1, dim).
        cls_token = w[:, :, :, 0].mean(2).unsqueeze(1).unsqueeze(1)
        return x, cls_token

class SeparableTr(nn.Module):
    """Separable-transformer classifier: a stack of SepTrBlocks plus a linear head on the class token.

    Args:
        channels: number of input channels; the input is laid out as (b, channels, h, w).
        input_size: (h, w) of the input.
        num_classes: number of output logits.
        depth: number of SepTrBlocks (must be >= 1). Only the first projects channels -> dim.
        heads, mlp_dim, dim_head, down_sample_input, dim: forwarded to every SepTrBlock.

    Raises:
        ValueError: if depth < 1 (a subclass of the Exception raised previously).
    """

    def __init__(self, channels=1, input_size=(128, 128), num_classes=35, depth=3, heads=5, mlp_dim=256, dim_head=256,
                 down_sample_input=None, dim=256):
        super().__init__()
        if depth < 1:
            raise ValueError("Depth cannot be smaller than 1!")

        # Only the first block projects the raw channels to `dim`; later blocks
        # receive dim-wide features. Construction order is unchanged, so
        # parameter initialisation and state_dict keys match the original.
        self.transformer = nn.ModuleList([
            SepTrBlock(channels=channels, input_size=input_size, heads=heads, mlp_dim=mlp_dim,
                       dim_head=dim_head, down_sample_input=down_sample_input, dim=dim, project=(i == 0))
            for i in range(depth)
        ])

        self.cls_token = nn.Parameter(torch.randn(1, 1, 1, dim))
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):
        # The first block starts from the learned class token ...
        x, cls_token = self.transformer[0](x, self.cls_token)

        # ... and every later block refines the token it receives. The old loop started
        # at index 0, so the first block (and its channels -> dim projection) ran twice.
        for block in self.transformer[1:]:
            x, cls_token = block(x, cls_token)

        # cls_token is (b, 1, 1, dim) -> logits of shape (b, num_classes).
        return self.fc(cls_token[:, 0, 0, :])

In [29]:
# Separable transformer sized for 17 x 128 inputs.
# NOTE(review): dim=1 makes every transformer token 1-dimensional; this looks like a debugging value. Confirm the intended width.
model = SeparableTr(input_size=(17,128),dim=1)
# SeparableTr expects a channel axis: (b, h, w) -> (b, 1, h, w).
x.unsqueeze(-3).shape

torch.Size([50, 1, 17, 128])

In [None]:
# Forward pass on the sample batch with an explicit channel axis; yields (batch, num_classes) logits.
model(x.unsqueeze(-3))