## Using RNN for feature extraction from audio input

The usual suggestion online is to run an RNN over the spectrogram (as Omer did with a CNN). Each column of the spectrogram (one time step) is treated as the current input to the recurrent network.
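The following minimal NumPy sketch makes "each column is one time step" concrete. It steps a plain (Elman) RNN over the columns of a spectrogram. The weights and sizes are toy values chosen only for illustration. The model used later (`nn.LSTM`) does the same column-by-column stepping, but with learned weights and gating.

```python
import numpy as np

def rnn_over_columns(spec, W_xh, W_hh, b_h):
    """Run a plain RNN over a spectrogram, one column per time step.

    spec: (n_mels, n_frames) array; column spec[:, t] is the input x_t.
    Returns the final hidden state h_T with shape (hidden_size,).
    """
    h = np.zeros(W_hh.shape[0])
    for t in range(spec.shape[1]):
        x_t = spec[:, t]                          # one time step = one spectrogram column
        h = np.tanh(W_xh @ x_t + W_hh @ h + b_h)  # update the hidden state
    return h

rng = np.random.default_rng(0)
n_mels, n_frames, hidden = 64, 100, 8             # toy sizes, not the real data
spec = rng.standard_normal((n_mels, n_frames))
W_xh = rng.standard_normal((hidden, n_mels)) * 0.1
W_hh = rng.standard_normal((hidden, hidden)) * 0.1
b_h = np.zeros(hidden)
h_final = rnn_over_columns(spec, W_xh, W_hh, b_h)
print(h_final.shape)
```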

Working with the raw audio is not as simple as I originally thought. Even with an LSTM (which can handle longer sequences than a simple RNN), a raw clip is a sequence of roughly 1e6 samples, and I doubt we could train well on that naively. Truncated backpropagation through time (TBPTT) is an option, and if time allows we will try it as well. However, since the spectrogram approach is the common recommendation, we will go with it.
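For a sense of scale, a 20-second clip at 22,050 Hz is already 441,000 samples. TBPTT is usually implemented by feeding such a long sequence in fixed-size chunks and detaching the recurrent state between chunks, so that gradients flow back at most one chunk. The sketch below is a hypothetical illustration on toy data and is not something this notebook runs. The function name `tbptt_epoch`, the chunk length, and the choice to supervise every chunk with the clip-level label are all assumptions.

```python
import math
import torch
import torch.nn as nn

def tbptt_epoch(lstm, head, optimizer, x, y, chunk_len):
    """Truncated BPTT over one batch of long sequences.

    x: (batch, time, features); y: (batch,) class labels.
    Returns the number of optimizer updates performed.
    """
    criterion = nn.CrossEntropyLoss()
    state = None                                   # nn.LSTM treats None as a zero initial state
    n_updates = 0
    for start in range(0, x.shape[1], chunk_len):
        out, state = lstm(x[:, start:start + chunk_len], state)
        state = tuple(s.detach() for s in state)   # cut the graph at the chunk boundary
        # Hypothetical choice: supervise each chunk with the clip-level label.
        loss = criterion(head(out[:, -1]), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        n_updates += 1
    return n_updates

torch.manual_seed(0)
lstm = nn.LSTM(input_size=1, hidden_size=16, batch_first=True)
head = nn.Linear(16, 2)
opt = torch.optim.Adam(list(lstm.parameters()) + list(head.parameters()), lr=1e-3)
x = torch.randn(4, 1000, 1)          # 4 toy "raw audio" clips of 1,000 samples each
y = torch.tensor([0, 1, 0, 1])
n = tbptt_epoch(lstm, head, opt, x, y, chunk_len=100)
print(n, math.ceil(x.shape[1] / 100))
```

Supervising only the last chunk would leave earlier chunks without any gradient, which is the main difficulty of TBPTT for whole-clip classification.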

In [1]:
import sys
sys.path.append('..')  # make the project's src/ package importable
import torch
from torch.utils.data import random_split
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import torchaudio
from src.RNN_utils.dataset import SoundDS

AUDIO_PATH = '../data/audio'

TENSOR_PATH = '../data/specs'

METADATA_PATH = '../data/metadata.csv'

SEED = 42

torch.manual_seed(SEED)


### Data processing:

#### Playing with the data:

In [None]:
from pychorus.helpers import find_and_output_chorus
import matplotlib.pyplot as plt
from IPython.display import Audio
import numpy as np
import librosa
import librosa.display  # needed for waveshow / specshow

For each song, we will focus only on the chorus, for reasons of both performance and computation. In terms of performance, the chorus carries the song's core message in just a few lines, and it is usually the most powerful, highest-energy, loudest, catchiest, and most memorable part of a song. It therefore makes sense that most TikTokers choose this part for their videos. In terms of computation, working on a shorter audio clip (only the chorus rather than the whole song) produces a smaller spectrogram and requires less computation.

To extract the chorus, we will use the pychorus library.

In [None]:
song_path = '../data/audio/0e3CM2Fm4cpDtxjzYkdLAr.mp3'
x, sr = librosa.load(song_path)
# chorus start time in seconds (pychorus returns None if no chorus is detected)
start = int(find_and_output_chorus(input_file=song_path, output_file=None, clip_length=20))

And now let's plot the predicted chorus and hear it:

In [None]:
plt.figure(figsize=(14, 5))
librosa.display.waveshow(x[start*sr:(start+20)*sr], sr=sr)

In [None]:
chorus = x[start*sr:(start+20)*sr]
Audio(data = chorus, rate=sr)

And it does sound like the actual chorus (cut off mid-phrase because the clip is limited to 20 seconds).
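The slicing above assumes that pychorus found a chorus and that the clip fits inside the song. The hedged helper below covers both edge cases. The name `chorus_slice` and the fall-back-to-the-start policy are assumptions, not part of pychorus or of the author's `get_chorus`. It also converts the start time directly to samples instead of truncating it to whole seconds with `int()` first.

```python
import numpy as np

def chorus_slice(x, sr, start_sec, clip_sec=20):
    """Return clip_sec seconds of x starting at start_sec (seconds).

    Hypothetical policy: if no chorus was detected (start_sec is None),
    use the beginning of the song. If the clip would run past the end,
    shift it back so it ends at the last sample. Songs shorter than
    clip_sec are returned whole.
    """
    n = int(clip_sec * sr)
    start = 0 if start_sec is None else int(start_sec * sr)
    start = max(0, min(start, len(x) - n))
    return x[start:start + n]

sr = 22050
x = np.arange(30 * sr, dtype=np.float32)   # toy 30-second "song" (sample value = index)
a = chorus_slice(x, sr, 5.0)               # normal case
b = chorus_slice(x, sr, None)              # no chorus detected
c = chorus_slice(x, sr, 25.0)              # would overrun the end, so shifted back
print(len(a) // sr, int(b[0]), int(c[0]) // sr)
```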

Let's plot the mel spectrogram of the chorus.

In [None]:
S = librosa.feature.melspectrogram(y=chorus, sr=sr) #n_fft=2048, hop_length=512 by default
fig, ax = plt.subplots()
S_dB = librosa.power_to_db(S, ref=np.max(S))
img = librosa.display.specshow(S_dB, x_axis='time',
                         y_axis='mel', sr=sr,
                         fmax=8000, ax=ax)
fig.colorbar(img, ax=ax, format='%+2.0f dB')
ax.set(title='Mel-frequency spectrogram')

The high frequencies seem to carry relatively high energy (in dB), which might indicate a more rhythmic song. This is consistent with the rapid changes visible in the waveform plot above. It could in turn relate to the song's virality, but we will let the model decide.
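`power_to_db` with `ref=np.max` rescales the spectrogram so that the loudest bin sits at 0 dB. The conversion is $10\log_{10}(S/\text{ref})$ with a small floor, followed by clipping to `top_db` below the peak. The NumPy re-implementation below is for illustration only. It uses librosa's default `amin=1e-10` and `top_db=80`, and the tiny matrix `S` is a made-up input.

```python
import numpy as np

def power_to_db(S, ref=1.0, amin=1e-10, top_db=80.0):
    """Power spectrogram -> decibels relative to ref (librosa's formula and defaults)."""
    ref_value = ref(S) if callable(ref) else ref
    log_spec = 10.0 * np.log10(np.maximum(amin, S))
    log_spec -= 10.0 * np.log10(np.maximum(amin, np.abs(ref_value)))
    if top_db is not None:
        # clip everything more than top_db below the peak
        log_spec = np.maximum(log_spec, log_spec.max() - top_db)
    return log_spec

S = np.array([[1.0, 0.1], [0.01, 1e-12]])
S_dB = power_to_db(S, ref=np.max)
print(S_dB)   # values: 0, -10, -20, and -80 (the last one clipped by top_db)
```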

Let's create the pipeline of the preprocessing.

#### Preprocessing pipeline:

The following is the basic pipeline:

                Raw audio -> average the channels (mono) -> extract the chorus -> compute the mel spectrogram -> convert the spectrogram from amplitude to dB

To avoid redundant computation and speed up training, I will compute all the spectrograms once before training, save them to files, and then only load them in each epoch.

In [2]:
from torch.utils.data import random_split
from src.RNN_utils.audio_utils import rechannel, get_chorus, createSpect
import pandas as pd
import torch
from tqdm import tqdm
import torchaudio

AUDIO_PATH = '../data/audio'

TENSOR_PATH = '../data/specs'

METADATA_PATH = '../data/metadata.csv'

In [None]:
import os

os.makedirs(TENSOR_PATH, exist_ok=True)  # no error if the folder already exists

In [3]:
df = pd.read_csv(METADATA_PATH)

Let's start by applying the pipeline to every song in the metadata file and saving the resulting tensors to files:

In [None]:
for idx in tqdm(df.index):
    song_id = df.loc[idx, 'id']
    song_path = AUDIO_PATH + '/' + song_id + '.mp3'
    # load the audio file
    aud = torchaudio.load(song_path)
    # convert the audio to mono
    aud = rechannel(aud, new_channel=1)
    # keep only the chorus part of the signal (clip length 20)
    aud = get_chorus(song_path, 20, aud)
    # create the mel spectrogram
    sgram = createSpect(aud, n_mels=64)
    torch.save(sgram, TENSOR_PATH + '/' + song_id + '.pt')

Now let's use the SoundDS class to create a dataset from those tensors, and then create data loaders for both the training and validation (test) sets:

In [4]:
from src.RNN_utils.dataset import SoundDS
myds = SoundDS(pd.read_csv('../data/metadata.csv'), '../data/specs/')

# Random split of 80:20 between training and validation
num_items = len(myds)
num_train = round(num_items * 0.8)
num_val = num_items - num_train
train_ds, val_ds = random_split(myds, [num_train, num_val])

# Create training and validation data loaders
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=16, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=16, shuffle=False)
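`SoundDS` itself is not shown here. For readers reconstructing it, the following is a hypothetical minimal equivalent, not the author's class. It assumes that the metadata has an `id` column and a binary label column, named `label` here purely for illustration. It also assumes that each cached tensor is stored as `(1, n_mels, frames)` and is transposed so that the LSTM sees one row per time step.

```python
import os
import tempfile
import pandas as pd
import torch
from torch.utils.data import Dataset

class SpecDataset(Dataset):
    """Hypothetical stand-in for SoundDS: loads cached spectrograms by song id."""

    def __init__(self, df, spec_dir, label_col="label"):
        self.df = df.reset_index(drop=True)
        self.spec_dir = spec_dir
        self.label_col = label_col          # assumed column name

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        spec = torch.load(os.path.join(self.spec_dir, row["id"] + ".pt"))
        # (1, n_mels, frames) -> (frames, n_mels): one row per LSTM time step
        spec = spec.squeeze(0).transpose(0, 1)
        return spec, int(row[self.label_col])

with tempfile.TemporaryDirectory() as d:
    for song_id in ["a", "b"]:
        torch.save(torch.randn(1, 64, 2206), os.path.join(d, song_id + ".pt"))
    ds = SpecDataset(pd.DataFrame({"id": ["a", "b"], "label": [0, 1]}), d)
    x, y = ds[1]
    print(x.shape, y)
```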

Next we will check that everything is working properly:

In [5]:
inputs, labels = next(iter(train_dl))

In [6]:
print(f'Batch input shape: {inputs.shape}')
print(f'Batch label shape: {labels.shape}')

Batch input shape: torch.Size([16, 2206, 64])
Batch label shape: torch.Size([16])


As we can see, each batch has 16 samples of shape (2206, 64): 2206 time windows (frames) and 64 mel frequency bins, with a single channel. With the data loaders in place, we can now move on to the model!
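The value 2206 has a plausible origin, although this is an inference and the notebook does not confirm it. Suppose `createSpect` uses torchaudio's `MelSpectrogram` defaults (n_fft=400, hop_length=200, centered frames) on a 20-second clip at 22,050 Hz, which is librosa's default load rate. Under those assumptions the frame count comes out to exactly 2206, about 200 times shorter than the 441,000 raw samples.

```python
def n_frames(n_samples, hop_length, n_fft=None, center=True):
    """Number of STFT frames for a signal of n_samples samples."""
    if center:
        # centered frames: the signal is padded by n_fft // 2 on both sides
        return 1 + n_samples // hop_length
    return 1 + (n_samples - n_fft) // hop_length

samples = 20 * 22050                          # assumed: 20-second clip at 22,050 Hz
print(samples)
print(n_frames(samples, hop_length=200))      # torchaudio default hop (n_fft=400)
print(n_frames(samples, hop_length=512))      # librosa default hop, as in the plot above
```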

### The Model:

We will use an RNN-based model in this notebook.

Because our input sequences have length 2206, which is fairly long, we won't use the basic RNN unit. Instead we will use an LSTM (Long Short-Term Memory). The advantage of the LSTM over the basic RNN is its ability to "remember" information from much earlier inputs. It also mitigates the vanishing-gradient problem, which a basic RNN would likely suffer from on such long input sequences.
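For reference, these are the standard LSTM cell equations, in the form documented for PyTorch's `nn.LSTM`. Here $x_t$ is the input, $h_t$ the hidden state, $c_t$ the cell state, $\sigma$ the sigmoid, and $\odot$ the elementwise product. The cell state is updated additively through the forget gate $f_t$. This additive path is what lets gradients survive across many more time steps than in a plain RNN.

$$
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$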

In [2]:
import torch.nn as nn

In [3]:
class viralCls(nn.Module):
    """LSTM feature extractor followed by an MLP classifier."""
    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0.0, num_classes=2):
        super().__init__()
        # batch_first=True: inputs are (batch, time, features), e.g. (B, 2206, 64)
        self.feature_extractor = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                                         num_layers=num_layers, batch_first=True, dropout=dropout)
        self.clf = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.LeakyReLU(),
            nn.Linear(64, 64),
            nn.LeakyReLU(),
            nn.Linear(64, num_classes),
            # Note: nn.CrossEntropyLoss applies log-softmax internally, so with this
            # layer softmax is applied twice; returning raw logits is the usual setup.
            nn.Softmax(dim=1)
        )
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes = num_classes

    def forward(self, X, h0=None, c0=None):
        batch_size = X.shape[0]
        if h0 is None or c0 is None:
            # random standard-normal initial states (nn.LSTM would default to zeros),
            # created on the same device as the input
            h0 = torch.randn(self.num_layers, batch_size, self.hidden_size, device=X.device)
            c0 = torch.randn(self.num_layers, batch_size, self.hidden_size, device=X.device)

        # extract features from the spectrogram: out is (batch, time, hidden_size)
        out, _ = self.feature_extractor(X, (h0, c0))

        # classify using the features of the last time step
        prob = self.clf(out[:, -1, :])
        return prob

Let's see if the new classifier is working on random input:

In [18]:
model = viralCls(5,10)
X = torch.rand(10,20,5)
model(X).shape

torch.Size([10, 2])

The input is 10 samples, each a sequence of length 20 with 5 features per time step. The output is a probability distribution over 2 classes for each of the 10 samples. Success!

### The training loop:

As before, I will first create the data loaders, this time with a batch size of 32:

In [7]:
from src.RNN_utils.dataset import SoundDS
myds = SoundDS(pd.read_csv('../data/metadata.csv'), '../data/specs/')

# Random split of 80:20 between training and validation
num_items = len(myds)
num_train = round(num_items * 0.8)
num_val = num_items - num_train
train_ds, val_ds = random_split(myds, [num_train, num_val])

# Create training and validation data loaders
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=32, shuffle=False)

In [8]:
b_size, seq_len, input_size = next(iter(train_dl))[0].shape
num_batches = len(train_dl)
hidden_size = 64

#### Overfitting the model:

Let's create the classification model:

In [6]:
model = viralCls(input_size, hidden_size)

We will use the cross-entropy loss and the Adam optimizer:

In [27]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=5e-3)

epochs = 10

Let's start by training the model to overfit the first batch, as a sanity check that it can learn at all:

In [28]:
(X, y) = next(iter(train_dl))
for epoch in range(epochs):
    optimizer.zero_grad()
    y_prob = model(X)
    loss = criterion(y_prob, y)
    loss.backward()
    optimizer.step()
    loss = loss.item()
    # fraction of correct predictions in the batch
    acc = torch.sum(torch.argmax(y_prob, dim=1) == y).item() / len(y)
    print(f'Epoch #{epoch}: Loss - {loss}, Accuracy - {acc}')

Epoch #0: Loss - 0.6927495002746582, Accuracy - 0.53125
Epoch #1: Loss - 0.6708220839500427, Accuracy - 0.875
Epoch #2: Loss - 0.6415773630142212, Accuracy - 0.96875
Epoch #3: Loss - 0.5991092920303345, Accuracy - 0.9375
Epoch #4: Loss - 0.5533318519592285, Accuracy - 0.9375
Epoch #5: Loss - 0.5023677945137024, Accuracy - 0.96875
Epoch #6: Loss - 0.4437624514102936, Accuracy - 0.96875
Epoch #7: Loss - 0.38921135663986206, Accuracy - 1.0
Epoch #8: Loss - 0.35495197772979736, Accuracy - 1.0
Epoch #9: Loss - 0.3279470205307007, Accuracy - 1.0
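The loss starts near ln 2 ≈ 0.693, as expected for two classes, and stalls around 0.33 even though accuracy is already 1.0. One likely reason is that `viralCls` ends with `nn.Softmax`, while `nn.CrossEntropyLoss` expects raw logits and applies log-softmax itself. Feeding it probabilities in [0, 1] caps how far the loss can fall. For two classes, the floor is $\log(1 + e^{-1}) \approx 0.313$. The small pure-Python check below computes that floor. The helper name `ce_on_softmax_outputs` is made up for illustration.

```python
import math

def ce_on_softmax_outputs(p_correct):
    """Cross-entropy computed by nn.CrossEntropyLoss when it is fed the
    probabilities (p_correct, 1 - p_correct) instead of logits (2 classes)."""
    z = [p_correct, 1.0 - p_correct]                 # treated as logits by the loss
    log_sum_exp = math.log(sum(math.exp(v) for v in z))
    return log_sum_exp - z[0]                        # -log softmax(z)[correct class]

print(round(ce_on_softmax_outputs(0.5), 4))   # uninformative prediction: ln 2
print(round(ce_on_softmax_outputs(1.0), 4))   # perfect, fully confident prediction
```

Removing the final `nn.Softmax` from `clf`, and applying `torch.softmax` only when probabilities are needed at inference time, removes this floor and the extra squashing of gradients.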


#### Cross validation:

And now for the real training:

In [11]:
from src.RNN_utils.trainer import trainer
from src.RNN_utils.cross_val import crossValidate

model = viralCls(input_size, hidden_size)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3, weight_decay=3e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=20,gamma=0.5)

train_model = trainer(model,criterion,optimizer,scheduler)

epochs = 50

In [12]:
cv_obj = crossValidate(train_ds=train_ds)
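The internals of `crossValidate` are not shown. The log below contains five folds, each with its own training and validation passes, which is consistent with standard k-fold splitting of `train_ds`. The following hypothetical sketch shows only the index bookkeeping. The function name, the shuffling and the seed are assumptions. Each index list could then be wrapped with `torch.utils.data.Subset(train_ds, indices)` and passed to a `DataLoader`.

```python
import random

def kfold_indices(n_items, k=5, seed=42):
    """Shuffle indices once, cut them into k near-equal folds, and yield
    (train_idx, val_idx) so that each fold is used once for validation."""
    idx = list(range(n_items))
    random.Random(seed).shuffle(idx)
    fold_sizes = [n_items // k + (1 if i < n_items % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val_idx = idx[start:start + size]
        train_idx = idx[:start] + idx[start + size:]
        yield train_idx, val_idx
        start += size

folds = list(kfold_indices(11, k=5))          # toy size for illustration
print([len(v) for _, v in folds])
```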

In [13]:
cv_obj.runCV(train_model, epochs=10)

Fold #0:


100%|██████████| 73/73 [00:35<00:00,  2.06it/s]


Epoch #0: Loss - 47.92689663171768, Accuracy - 0.6383526383526383


100%|██████████| 19/19 [00:02<00:00,  8.27it/s]


Val results: Loss - 12.640501260757446, Accuracy - 0.6295025728987993


100%|██████████| 73/73 [00:40<00:00,  1.81it/s]


Epoch #1: Loss - 47.36117070913315, Accuracy - 0.6392106392106393


100%|██████████| 19/19 [00:02<00:00,  7.55it/s]


Val results: Loss - 12.471888303756714, Accuracy - 0.6295025728987993


100%|██████████| 73/73 [00:43<00:00,  1.66it/s]


Epoch #2: Loss - 46.93905144929886, Accuracy - 0.6392106392106393


100%|██████████| 19/19 [00:02<00:00,  7.32it/s]


Val results: Loss - 12.352991223335266, Accuracy - 0.6295025728987993


100%|██████████| 73/73 [00:52<00:00,  1.40it/s]


Epoch #3: Loss - 46.8009198307991, Accuracy - 0.6392106392106393


100%|██████████| 19/19 [00:02<00:00,  7.15it/s]


Val results: Loss - 12.555815577507019, Accuracy - 0.6295025728987993


100%|██████████| 73/73 [00:47<00:00,  1.55it/s]


Epoch #4: Loss - 46.5096520781517, Accuracy - 0.6422136422136422


100%|██████████| 19/19 [00:02<00:00,  7.31it/s]


Val results: Loss - 12.481537103652954, Accuracy - 0.6260720411663808


100%|██████████| 73/73 [00:45<00:00,  1.60it/s]


Epoch #5: Loss - 45.80665421485901, Accuracy - 0.6640926640926641


100%|██████████| 19/19 [00:02<00:00,  6.96it/s]


Val results: Loss - 12.31297242641449, Accuracy - 0.6226415094339622


100%|██████████| 73/73 [00:44<00:00,  1.63it/s]


Epoch #6: Loss - 45.71187913417816, Accuracy - 0.6670956670956671


100%|██████████| 19/19 [00:02<00:00,  6.98it/s]


Val results: Loss - 12.576425433158875, Accuracy - 0.6174957118353345


100%|██████████| 73/73 [00:40<00:00,  1.81it/s]


Epoch #7: Loss - 45.42164787650108, Accuracy - 0.6752466752466753


100%|██████████| 19/19 [00:02<00:00,  6.92it/s]


Val results: Loss - 12.748926341533661, Accuracy - 0.6072041166380789


100%|██████████| 73/73 [00:54<00:00,  1.35it/s]


Epoch #8: Loss - 44.56422120332718, Accuracy - 0.6894036894036895


100%|██████████| 19/19 [00:02<00:00,  7.07it/s]


Val results: Loss - 12.936583757400513, Accuracy - 0.6123499142367067


100%|██████████| 73/73 [00:54<00:00,  1.33it/s]


Epoch #9: Loss - 43.91723781824112, Accuracy - 0.700986700986701


100%|██████████| 19/19 [00:02<00:00,  6.83it/s]


Val results: Loss - 12.831981539726257, Accuracy - 0.6209262435677531
Fold #1:


100%|██████████| 73/73 [00:45<00:00,  1.61it/s]


Epoch #0: Loss - 48.4468719959259, Accuracy - 0.6314886314886314


100%|██████████| 19/19 [00:02<00:00,  6.81it/s]


Val results: Loss - 12.215252697467804, Accuracy - 0.6569468267581475


100%|██████████| 73/73 [00:42<00:00,  1.71it/s]


Epoch #1: Loss - 47.78664708137512, Accuracy - 0.6323466323466324


100%|██████████| 19/19 [00:02<00:00,  7.01it/s]


Val results: Loss - 12.236363708972931, Accuracy - 0.6569468267581475


100%|██████████| 73/73 [00:45<00:00,  1.62it/s]


Epoch #2: Loss - 47.53398811817169, Accuracy - 0.6323466323466324


100%|██████████| 19/19 [00:02<00:00,  6.87it/s]


Val results: Loss - 12.19677734375, Accuracy - 0.6569468267581475


100%|██████████| 73/73 [00:45<00:00,  1.59it/s]


Epoch #3: Loss - 47.25171595811844, Accuracy - 0.6323466323466324


100%|██████████| 19/19 [00:02<00:00,  7.07it/s]


Val results: Loss - 12.252007961273193, Accuracy - 0.6569468267581475


100%|██████████| 73/73 [00:42<00:00,  1.72it/s]


Epoch #4: Loss - 47.041736483573914, Accuracy - 0.6323466323466324


100%|██████████| 19/19 [00:02<00:00,  6.95it/s]


Val results: Loss - 12.138647735118866, Accuracy - 0.6569468267581475


100%|██████████| 73/73 [00:43<00:00,  1.68it/s]


Epoch #5: Loss - 46.66292345523834, Accuracy - 0.6323466323466324


100%|██████████| 19/19 [00:02<00:00,  6.84it/s]


Val results: Loss - 12.250790357589722, Accuracy - 0.6586620926243568


100%|██████████| 73/73 [00:53<00:00,  1.38it/s]


Epoch #6: Loss - 46.57298409938812, Accuracy - 0.6374946374946375


100%|██████████| 19/19 [00:02<00:00,  6.88it/s]


Val results: Loss - 12.272039830684662, Accuracy - 0.6535162950257289


100%|██████████| 73/73 [00:58<00:00,  1.25it/s]


Epoch #7: Loss - 45.93080198764801, Accuracy - 0.6602316602316602


100%|██████████| 19/19 [00:02<00:00,  6.97it/s]


Val results: Loss - 12.056087255477905, Accuracy - 0.6672384219554031


100%|██████████| 73/73 [00:53<00:00,  1.36it/s]


Epoch #8: Loss - 45.63599467277527, Accuracy - 0.6580866580866581


100%|██████████| 19/19 [00:02<00:00,  7.01it/s]


Val results: Loss - 12.608319997787476, Accuracy - 0.614065180102916


100%|██████████| 73/73 [00:59<00:00,  1.23it/s]


Epoch #9: Loss - 45.34006389975548, Accuracy - 0.6726726726726727


100%|██████████| 19/19 [00:02<00:00,  6.81it/s]


Val results: Loss - 12.023483216762543, Accuracy - 0.6552315608919382
Fold #2:


100%|██████████| 73/73 [00:45<00:00,  1.61it/s]


Epoch #0: Loss - 48.16758930683136, Accuracy - 0.6374946374946375


100%|██████████| 19/19 [00:02<00:00,  6.95it/s]


Val results: Loss - 12.351842999458313, Accuracy - 0.6432246998284734


100%|██████████| 73/73 [00:43<00:00,  1.67it/s]


Epoch #1: Loss - 47.55285292863846, Accuracy - 0.6357786357786358


100%|██████████| 19/19 [00:02<00:00,  6.86it/s]


Val results: Loss - 12.306352078914642, Accuracy - 0.6432246998284734


100%|██████████| 73/73 [00:40<00:00,  1.78it/s]


Epoch #2: Loss - 47.14934378862381, Accuracy - 0.6357786357786358


100%|██████████| 19/19 [00:02<00:00,  6.79it/s]


Val results: Loss - 12.296163380146027, Accuracy - 0.6432246998284734


100%|██████████| 73/73 [00:42<00:00,  1.70it/s]


Epoch #3: Loss - 46.72218769788742, Accuracy - 0.6392106392106393


100%|██████████| 19/19 [00:02<00:00,  6.75it/s]


Val results: Loss - 12.324399948120117, Accuracy - 0.6397941680960549


100%|██████████| 73/73 [00:48<00:00,  1.51it/s]


Epoch #4: Loss - 46.408556163311005, Accuracy - 0.6417846417846418


100%|██████████| 19/19 [00:02<00:00,  6.85it/s]


Val results: Loss - 12.421982944011688, Accuracy - 0.6483704974271012


100%|██████████| 73/73 [00:45<00:00,  1.61it/s]


Epoch #5: Loss - 46.37836191058159, Accuracy - 0.6482196482196482


100%|██████████| 19/19 [00:02<00:00,  6.94it/s]


Val results: Loss - 12.457918882369995, Accuracy - 0.6260720411663808


100%|██████████| 73/73 [00:46<00:00,  1.56it/s]


Epoch #6: Loss - 45.57149004936218, Accuracy - 0.6628056628056628


100%|██████████| 19/19 [00:02<00:00,  6.66it/s]


Val results: Loss - 12.481024503707886, Accuracy - 0.6174957118353345


100%|██████████| 73/73 [00:45<00:00,  1.60it/s]


Epoch #7: Loss - 44.823200047016144, Accuracy - 0.6821106821106822


100%|██████████| 19/19 [00:02<00:00,  7.06it/s]


Val results: Loss - 12.552795708179474, Accuracy - 0.6054888507718696


100%|██████████| 73/73 [00:56<00:00,  1.30it/s]


Epoch #8: Loss - 44.17198997735977, Accuracy - 0.6894036894036895


100%|██████████| 19/19 [00:02<00:00,  7.02it/s]


Val results: Loss - 12.628711462020874, Accuracy - 0.6106346483704974


100%|██████████| 73/73 [00:57<00:00,  1.28it/s]


Epoch #9: Loss - 44.32896310091019, Accuracy - 0.6855426855426855


100%|██████████| 19/19 [00:02<00:00,  6.88it/s]


Val results: Loss - 12.935725808143616, Accuracy - 0.5934819897084048
Fold #3:


100%|██████████| 73/73 [00:40<00:00,  1.78it/s]


Epoch #0: Loss - 47.95809131860733, Accuracy - 0.640926640926641


100%|██████████| 19/19 [00:02<00:00,  6.70it/s]


Val results: Loss - 12.693602740764618, Accuracy - 0.6174957118353345


100%|██████████| 73/73 [00:41<00:00,  1.75it/s]


Epoch #1: Loss - 47.21484726667404, Accuracy - 0.6422136422136422


100%|██████████| 19/19 [00:02<00:00,  7.01it/s]


Val results: Loss - 12.679374992847443, Accuracy - 0.6174957118353345


100%|██████████| 73/73 [00:39<00:00,  1.84it/s]


Epoch #2: Loss - 46.936330795288086, Accuracy - 0.6422136422136422


100%|██████████| 19/19 [00:02<00:00,  7.01it/s]


Val results: Loss - 12.677088975906372, Accuracy - 0.6174957118353345


100%|██████████| 73/73 [00:41<00:00,  1.74it/s]


Epoch #3: Loss - 46.7567241191864, Accuracy - 0.6422136422136422


100%|██████████| 19/19 [00:02<00:00,  6.86it/s]


Val results: Loss - 12.687081813812256, Accuracy - 0.6174957118353345


100%|██████████| 73/73 [00:41<00:00,  1.74it/s]


Epoch #4: Loss - 46.61616778373718, Accuracy - 0.6422136422136422


100%|██████████| 19/19 [00:02<00:00,  6.84it/s]


Val results: Loss - 12.621210634708405, Accuracy - 0.6174957118353345


100%|██████████| 73/73 [00:37<00:00,  1.92it/s]


Epoch #5: Loss - 46.18419596552849, Accuracy - 0.6495066495066495


100%|██████████| 19/19 [00:02<00:00,  6.65it/s]


Val results: Loss - 12.762828469276428, Accuracy - 0.6072041166380789


100%|██████████| 73/73 [00:42<00:00,  1.71it/s]


Epoch #6: Loss - 45.74505829811096, Accuracy - 0.6623766623766624


100%|██████████| 19/19 [00:02<00:00,  6.99it/s]


Val results: Loss - 12.63919734954834, Accuracy - 0.6123499142367067


100%|██████████| 73/73 [00:40<00:00,  1.79it/s]


Epoch #7: Loss - 45.36144983768463, Accuracy - 0.6735306735306735


100%|██████████| 19/19 [00:02<00:00,  6.80it/s]


Val results: Loss - 12.951039254665375, Accuracy - 0.6037735849056604


100%|██████████| 73/73 [00:40<00:00,  1.82it/s]


Epoch #8: Loss - 44.668718576431274, Accuracy - 0.6876876876876877


100%|██████████| 19/19 [00:02<00:00,  6.63it/s]


Val results: Loss - 12.677857667207718, Accuracy - 0.5883361921097771


100%|██████████| 73/73 [00:41<00:00,  1.76it/s]


Epoch #9: Loss - 44.78965723514557, Accuracy - 0.6851136851136851


100%|██████████| 19/19 [00:02<00:00,  6.68it/s]


Val results: Loss - 13.183808386325836, Accuracy - 0.58147512864494
Fold #4:


100%|██████████| 73/73 [00:41<00:00,  1.76it/s]


Epoch #0: Loss - 48.33587747812271, Accuracy - 0.6350771869639794


100%|██████████| 19/19 [00:02<00:00,  6.81it/s]


Val results: Loss - 12.401379525661469, Accuracy - 0.6391752577319587


100%|██████████| 73/73 [00:46<00:00,  1.59it/s]


Epoch #1: Loss - 47.42316693067551, Accuracy - 0.6367924528301887


100%|██████████| 19/19 [00:02<00:00,  6.59it/s]


Val results: Loss - 12.338246583938599, Accuracy - 0.6391752577319587


100%|██████████| 73/73 [00:39<00:00,  1.87it/s]


Epoch #2: Loss - 47.25936049222946, Accuracy - 0.6367924528301887


100%|██████████| 19/19 [00:02<00:00,  6.81it/s]


Val results: Loss - 12.439183354377747, Accuracy - 0.6391752577319587


100%|██████████| 73/73 [00:39<00:00,  1.86it/s]


Epoch #3: Loss - 47.109014213085175, Accuracy - 0.6367924528301887


100%|██████████| 19/19 [00:02<00:00,  6.86it/s]


Val results: Loss - 12.515631794929504, Accuracy - 0.6391752577319587


100%|██████████| 73/73 [00:42<00:00,  1.71it/s]


Epoch #4: Loss - 46.753248035907745, Accuracy - 0.6367924528301887


100%|██████████| 19/19 [00:02<00:00,  6.55it/s]


Val results: Loss - 12.468236267566681, Accuracy - 0.6391752577319587


100%|██████████| 73/73 [00:39<00:00,  1.87it/s]


Epoch #5: Loss - 46.63458603620529, Accuracy - 0.6367924528301887


100%|██████████| 19/19 [00:02<00:00,  6.76it/s]


Val results: Loss - 12.51987510919571, Accuracy - 0.6391752577319587


100%|██████████| 73/73 [00:39<00:00,  1.83it/s]


Epoch #6: Loss - 46.28244370222092, Accuracy - 0.6513722126929674


100%|██████████| 19/19 [00:02<00:00,  6.83it/s]


Val results: Loss - 12.696490466594696, Accuracy - 0.6323024054982818


100%|██████████| 73/73 [00:35<00:00,  2.03it/s]


Epoch #7: Loss - 45.98974308371544, Accuracy - 0.6680960548885078


100%|██████████| 19/19 [00:02<00:00,  6.76it/s]


Val results: Loss - 12.516859471797943, Accuracy - 0.6099656357388317


100%|██████████| 73/73 [00:43<00:00,  1.67it/s]


Epoch #8: Loss - 45.87077909708023, Accuracy - 0.6642367066895368


100%|██████████| 19/19 [00:02<00:00,  6.77it/s]


Val results: Loss - 12.76969462633133, Accuracy - 0.5876288659793815


100%|██████████| 73/73 [00:45<00:00,  1.62it/s]


Epoch #9: Loss - 45.11893633008003, Accuracy - 0.6818181818181818


100%|██████████| 19/19 [00:02<00:00,  6.83it/s]

Val results: Loss - 12.714339286088943, Accuracy - 0.627147766323024





({'loss': [48.167065346240996,
   47.467736983299254,
   47.16361492872238,
   46.92811236381531,
   46.66587210893631,
   46.333344316482545,
   45.97677105665207,
   45.50536856651306,
   44.98234070539475,
   44.698971676826474],
  'accuracy': [0.6366679470453056,
   0.6372684004759477,
   0.6372684004759477,
   0.6379548011623484,
   0.6390702022777495,
   0.6461916093991567,
   0.6562289684931195,
   0.6718431492016397,
   0.6777636862542524,
   0.6852267852267853]},
 {'loss': [12.46051584482193,
   12.406445133686066,
   12.392440855503082,
   12.466987419128419,
   12.426322937011719,
   12.46087704896927,
   12.533035516738892,
   12.565141606330872,
   12.724233502149582,
   12.73786764740944],
  'accuracy': [0.6372690138105427,
   0.6372690138105427,
   0.6372690138105427,
   0.636582907464059,
   0.6376120669837846,
   0.6307510035189475,
   0.6266320076862774,
   0.6187341220019686,
   0.6026029601598557,
   0.6156525378272121]})
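The sketch below helps with reading the averaged histories above. It finds the epoch with the lowest mean validation loss and the epoch with the highest mean validation accuracy, and it computes the train-minus-validation accuracy gap per epoch. The values are copied verbatim from the output. Reading the returned tuple as (training history, validation history) is an assumption, based on the loss sums (about 47 over 73 training batches versus about 12 over 19 validation batches).

```python
# Fold-averaged histories copied from the runCV output above.
val_loss = [12.46051584482193, 12.406445133686066, 12.392440855503082,
            12.466987419128419, 12.426322937011719, 12.46087704896927,
            12.533035516738892, 12.565141606330872, 12.724233502149582,
            12.73786764740944]
val_acc = [0.6372690138105427, 0.6372690138105427, 0.6372690138105427,
           0.636582907464059, 0.6376120669837846, 0.6307510035189475,
           0.6266320076862774, 0.6187341220019686, 0.6026029601598557,
           0.6156525378272121]
train_acc = [0.6366679470453056, 0.6372684004759477, 0.6372684004759477,
             0.6379548011623484, 0.6390702022777495, 0.6461916093991567,
             0.6562289684931195, 0.6718431492016397, 0.6777636862542524,
             0.6852267852267853]

best_loss_epoch = min(range(len(val_loss)), key=val_loss.__getitem__)
best_acc_epoch = max(range(len(val_acc)), key=val_acc.__getitem__)
# positive gap = training accuracy above validation accuracy
gap = [round(t - v, 3) for t, v in zip(train_acc, val_acc)]
print(best_loss_epoch, best_acc_epoch)
print(gap)
```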