[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kovacsdotgergo/szakdolgozat/blob/feature%2Fcolab_bringup/esc_notebook.ipynb)

In [None]:
!git clone https://github.com/kovacsdotgergo/szakdolgozat.git
%cd szakdolgozat
!pip install wget torch torchvision torchaudio matplotlib pandas numpy timm==0.4.5

In [None]:
# TODO: temporary — checks out the feature/colab_bringup branch so the code
# matches this notebook; remove this cell once that branch is no longer needed.
!git checkout feature/colab_bringup

In [2]:
import utils

# Resolve the dataset, checkpoint-save and workspace paths for this environment.
# The captured output ("Running in local") suggests setup_env detects whether the
# notebook runs locally or on Colab — confirm in utils.setup_env.
esc_path, save_path, workspace_path = utils.setup_env()

Running in local


In [6]:
import os

import torch
import torch.nn as nn

import esc_dataset
from src.models import ASTModel
# Import the class directly so the instance below does not shadow the module name.
from trainer import Trainer

## Configuration
SEED = 42             # fixes the data split and torch RNG so reruns are comparable
INPUT_TDIM = 512      # number of time frames the AST model is built for
BATCHSIZE = 16
TRAIN_EPOCHS = 3
LEARNING_RATE = 5e-06
save_name = 'tmp.pth'

torch.manual_seed(SEED)
have_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if have_cuda else 'cpu')

## Model
# label_dim=50: one output per ESC class.
audio_model = ASTModel(label_dim=50, input_tdim=INPUT_TDIM,
                       imagenet_pretrain=True, audioset_pretrain=True)
audio_model = torch.nn.DataParallel(audio_model, device_ids=[0])
audio_model = audio_model.to(device)
audio_model.eval()

## Dataset
# NOTE(review): 22500 Hz is not a standard sample rate; 22050 may have been
# intended — confirm before comparing results across runs.
dataset = esc_dataset.ESCdataset(esc_path, n_fft=1024, hop_length=256,
                                 n_mels=128, augment=False, log_mel=True,
                                 use_kaldi=True, target_len=INPUT_TDIM,
                                 resample_rate=22500)

# Seeded random split: 80% train, the remainder halved into validation and test
# (an odd leftover sample goes to the test set).
numtrain = int(0.8 * len(dataset))
numval = (len(dataset) - numtrain) // 2
numtest = len(dataset) - numtrain - numval
split_dataset = torch.utils.data.random_split(
    dataset, [numtrain, numval, numtest],
    generator=torch.Generator().manual_seed(SEED))
# NOTE: train-only augmentation cannot be enabled via
# `split_dataset[0].augment = True`: random_split returns Subset objects that all
# wrap the same `dataset`, so the attribute lands on the Subset (presumably
# ignored by ESCdataset) or, if set on `dataset`, affects every split.

## DataLoader
trainloader = torch.utils.data.DataLoader(split_dataset[0], batch_size=BATCHSIZE,
                                          shuffle=True, num_workers=2)
# Evaluation sets are not shuffled so validation/test runs are deterministic.
valloader = torch.utils.data.DataLoader(split_dataset[1], batch_size=BATCHSIZE, shuffle=False)
testloader = torch.utils.data.DataLoader(split_dataset[2], batch_size=BATCHSIZE, shuffle=False)

## Trainer
# Loss and optimizer are passed as classes; Trainer presumably instantiates them.
trainer = Trainer(audio_model, have_cuda, criterion=nn.CrossEntropyLoss)

## Inference (sanity check on one sample before training)
spect, label = dataset[0]
predicted_idx = trainer.inference(spect, ret_index=True).item()
print(f'trainer inference: {dataset.get_class_name(predicted_idx)}, '
      f'true label: {dataset.get_class_name(label)}')

## Training
# Learning-rate sweep (disabled; needs `import numpy as np` to re-enable):
# lrs = np.logspace(-4, -6, num=10)
# params = trainer.hyperparameter_plotting(lrs, trainloader, valloader, train_epochs=5)
# print(params)
trainer.train(trainloader, valloader, optimizer=torch.optim.AdamW,
              train_epochs=TRAIN_EPOCHS, val_interval=25, lr=LEARNING_RATE,
              save_best_model=True,
              env_save_path=os.path.join(save_path, save_name))
# Title derived from the actual epoch count (was hard-coded to '30 epoch training').
trainer.plot_train_proc(f'{TRAIN_EPOCHS} epoch training')
print(f'test accuracy: {trainer.test(testloader)}')

---------------AST Model Summary---------------
ImageNet pretraining: True, AudioSet pretraining: True
frequncey stride=10, time stride=10
number of patches=600


: 

: 