# Test the split
Notebook for testing `SelectSplitData` and prototyping a new way of splitting the data.

In [4]:
%load_ext autoreload
%autoreload 2

import sys

# setting path
sys.path.append('../')

from modules.TransformApplier import TransformApplier
from modules.Wav2Spec import Wav2Spec
from modules.SimpleDataset import SimpleDataset
from modules.PretrainedModel import *
from modules.OnlyXTransform import OnlyXTransform
from modules.SelectSplitData import *
from modules.SimpleAttention import *

import pandas as pd
import torch.nn as nn
import torch
import timm
import json
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np



In [72]:
def collate_fn(data):
    """
    Define how the DataLoaders should batch the data.

    Every sample is right-padded with zeros along its last axis up to the
    length of the longest sample in the batch; padded samples and their
    labels are then stacked along a new leading batch axis.

    Args:
        data: non-empty sequence of samples, where ``sample[0]`` is a tensor
            whose last axis is the variable-length one (all other axes must
            agree across the batch) and ``sample[1]`` is a label that
            ``torch.as_tensor`` accepts (all labels must share one shape).

    Returns:
        Tuple ``(x_batch, y_batch)`` with ``x_batch`` of shape
        ``(len(data), ..., max_len)`` and ``y_batch`` of shape
        ``(len(data), *label_shape)``.

    Raises:
        ValueError: if ``data`` is empty (``max`` of an empty sequence).
    """
    max_dim = max(d[0].shape[-1] for d in data)
    # The (left, right) pair given to pad applies to the last axis only, so
    # samples with leading axes (e.g. channels) work too, and the padding
    # inherits the dtype and device of the sample it extends.
    padded = [torch.nn.functional.pad(d[0], (0, max_dim - d[0].shape[-1])) for d in data]
    # as_tensor avoids the copy-construct warning when a label is already a tensor.
    labels = [torch.as_tensor(d[1]) for d in data]
    return torch.stack(padded, dim=0), torch.stack(labels)
    

In [89]:
# Folder holding the dataset files, relative to this notebook.
DATA_PATH = '../data/'

# Work on the first 4000 metadata rows only.
metadata = pd.read_csv(f'{DATA_PATH}train_metadata.csv').iloc[:4000]

# `birds` is handed to SimpleDataset as `labels` further down.
with open(f'{DATA_PATH}all_birds.json') as f:
    birds = json.loads(f.read())

In [90]:
# Hold out 5% of the rows for validation; the remainder is the training set.
# `sample(...).index` yields index *labels*, so the rows are selected with the
# label-based `.loc`. Positional `.iloc` only gave the same rows while the
# index happened to be the default 0..n-1 range.
# A fixed random_state makes the split reproducible across kernel restarts.
tts = metadata.sample(frac=.05, random_state=42).index # train test split
df_val = metadata.loc[tts]
df_train = metadata.drop(axis=0,index=tts)

In [144]:
# Sanity check on the hold-out size: 5% of the 4000 loaded rows should be 200 rows.
df_val.shape

(200, 13)

In [91]:
# Wrap both metadata frames in a SimpleDataset, training frame first.
# NOTE(review): the validation set is also built with mode='train' — confirm
# that this is intended.
train_data, val_data = (
    SimpleDataset(frame, DATA_PATH, mode='train', labels=birds)
    for frame in (df_train, df_val)
)

In [147]:
# Number of validation batches at a batch size of 16; a fractional result
# means the last batch is only partially filled.
len(val_data)/16

12.5

In [97]:
bs = 16
# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=bs, num_workers=0, collate_fn=collate_fn)
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_data, **loader_kwargs)
# One batch per iteration. With the 4000-row subset this is 13 batches for val
# (200 samples) and, assuming one sample per metadata row, 238 for train
# (3800 samples). The 47 / 882 figures noted here previously do not match
# this subset.


In [146]:
# Print the padded length of every validation batch. collate_fn pads each
# batch to its own longest sample, so the value differs from batch to batch.
for x, y in val_loader:
    print(x.size(1))

3876096
8476212
9706656
4801515
5250816
5108297
3375438
3180669
4635648
2907648
3778351
22062393
2644992


In [153]:
 # Take the first batch from val_loader (created without shuffle) to experiment on.
 # NOTE(review): this statement carries a stray leading space; the cell has been
 # executed as-is, but it is worth removing.
 inputs, classes = next(iter(val_loader))  

In [142]:
# Shape of the collated batch: (batch size, padded length of the batch).
inputs.shape

torch.Size([16, 3876096])

In [116]:
# Push the sampled batch through the validation pipeline; labels are cast to
# float after being moved to the device.
# NOTE(review): `device` and `data_pipeline_val` are only defined in a cell
# further down, so this cell fails on a fresh top-to-bottom run.
x_v, y_v = data_pipeline_val(
    (
        inputs.to(device),
        classes.to(device).float(),
    )
)

In [126]:
# Padded batch length converted to seconds, taking 16000 samples per second
# (the value assigned to `sr` in a later cell).
inputs.shape[-1]/16000

242.256

In [132]:
# Number of samples in a 30-second window at 16000 samples per second.
30*16000

480000

In [154]:
# Manually reproduce the random crop: pick a `duration`-second window at a
# uniformly random offset inside the padded batch.
duration = 30   # window length in seconds
n_splits = 5    # number of chunks the window is split into by the reshape cell below
sr = 16000      # samples per second used to convert between samples and seconds
total_duration = inputs.shape[-1] / sr
# Clamp at 0: for a batch shorter than `duration` the difference is negative,
# which would yield a negative offset and therefore a wrong slice.
max_offset = max(0.0, total_duration - duration)
# Call numpy explicitly: `uniform` is not imported by name in this notebook.
offset = np.random.uniform(low=0.0, high=max_offset)
start = int(offset*sr)
# Never run past the end of the batch.
stop = min(int((offset + duration)*sr), inputs.shape[-1])
print(start,stop)
print(stop - start )

605763 1085763
480000


In [155]:
# Keep only the [start, stop) window along the last axis of every clip.
# NOTE(review): this rebinds `inputs`, so re-running the cell slices the
# already-cropped batch rather than the original one.
inputs = inputs[..., start:stop]

In [157]:
# Shape after cropping: the last axis should now be stop - start samples long.
inputs.shape

torch.Size([16, 480000])

In [160]:
# Preview of the split step: the leading (batch) dimension is multiplied by
# n_splits, any middle dimensions are kept, and -1 lets the last axis absorb
# the rest, so each cropped clip becomes n_splits equal-length rows.
inputs.reshape((inputs.shape[0]*n_splits, *inputs.shape[1:-1], -1)).shape
# Expected to equal x_v.shape, the shape produced by the pipeline.
# x_v

torch.Size([80, 96000])

In [120]:
# Shape produced by data_pipeline_val, for comparison with the manual reshape.
x_v.shape

torch.Size([80, 96000])

In [162]:
# Scratch check of the built-in `min` on two literals; it touches no notebook state.
# NOTE(review): looks like a leftover experiment — consider removing this cell.
min(0,2)

0

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Parameters handed to SelectSplitData.
duration = 30
n_splits = 5

# Validation pipeline: an identity transform followed by SelectSplitData,
# both wrapped in a TransformApplier and moved to `device`.
# NOTE(review): an earlier cell already uses `device` and `data_pipeline_val`;
# consider moving this cell above it.
transforms1 = TransformApplier(
    [
        nn.Identity(),
        SelectSplitData(duration, n_splits),
    ]
)
data_pipeline_val = nn.Sequential(transforms1).to(device)


In [None]:
# Run every validation batch through the pipeline. After the loop, x_v and
# y_v hold the pipeline output for the last batch.
for batch_x, batch_y in val_loader:
    # The label tensor is cast to float before the pair is passed on. The
    # previous code applied `.float` to the (x, y) tuple itself, which raises
    # AttributeError.
    x_v, y_v = data_pipeline_val((batch_x.to(device), batch_y.to(device).float()))
