## LSTM на оригинальном датасете

Попытка сделать монофонический выход из сетки

In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

Сделаем также пользовательский импорт

In [41]:
from decode_patterns import data_conversion

Загружаем датасет

In [60]:
# import dataset
drum, bass = data_conversion.make_lstm_dataset(height=16, limit=20000, patterns_file="decode_patterns/patterns.pairs.tsv", mono=True)


# define shuffling of dataset
def shuffle(A, B, p=0.8):
    # take 80% to training, other to testing
    L = len(A)
    idx = np.arange(L) < p*L
    np.random.shuffle(idx)
    yield A[idx]
    yield B[idx]
    yield A[np.logical_not(idx)]
    yield B[np.logical_not(idx)]
    
    
# we can select here a validation set
drum, bass, drum_validation, bass_validation = shuffle(drum, bass)
    
# and we can shuffle train and test set like this:
# drum_train, bass_train, drum_test, bass_test = shuffle(drum, bass)

In [62]:
bass_validation[9]

array([1, 1, 8, 1, 8, 1, 1, 1, 8, 1, 8, 1, 8, 1, 1, 6], dtype=int64)

Модель определим в самом простом варианте, который только можно себе представить -- как в примере с конечным автоматом

In [63]:
# попробуем определить модель LSTM как конечный автомат
class DrumNBassLSTM(nn.Module):
    def __init__(self):
        super(DrumNBassLSTM, self).__init__()
        # one input neuron, one output neuron, one layer in LSTM block
        self.input_size = 14
        self.hidden_size = 34
        self.layer_count = 1
        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.layer_count)
        self.embed_layer = nn.Linear(self.hidden_size, 1)
        self.sigm = nn.Sigmoid()
    
    def forward(self, input):
        # пусть в input у нас приходит вектор размерности (64, 32, 14)
        # то есть 64 отсчёта, тридцать два примера (минибатч), 14 значение в каждом (барабанная партия)
        output, _ = self.lstm(input)
        output = self.sigm(self.embed_layer(output))*37
        return output

In [64]:
# часть обучения
dnb_lstm = DrumNBassLSTM()

criterion = nn.MSELoss()

# оценим также и разнообразие мелодии по её.. дисперсии?)
# def melody_variety(melody):
#     return 1/(1 + (melody.sum(axis=2) > 1).int())
    
# criterion = nn.NLLLoss() # -- этот товарищ требует, чтобы LSTM выдавал классы,
# criterion = nn.CrossEntropyLoss() # и этот тоже
# (числа от 0 до C-1), но как всё-таки его заставить это делать?...
# optimizer = optim.SGD(dnb_lstm.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(dnb_lstm.parameters(), lr=0.001)

Найденные баги и их решения:

https://stackoverflow.com/questions/56741087/how-to-fix-runtimeerror-expected-object-of-scalar-type-float-but-got-scalar-typ

https://stackoverflow.com/questions/49206550/pytorch-error-multi-target-not-supported-in-crossentropyloss/49209628

https://stackoverflow.com/questions/56243672/expected-target-size-50-88-got-torch-size50-288-88

In [65]:
epoch_count = 500
batch_size = 128
shuffle_every_epoch = True
    
if shuffle_every_epoch:
    print(f"shuffle_every_epoch is on")
else:
    print(f"shuffle_every_epoch is off")
    # shuffle train and test set:
    drum_train, bass_train, drum_test, bass_test = shuffle(drum, bass)
    drum_train = torch.tensor(drum_train, dtype=torch.float)
    bass_train = torch.tensor(bass_train, dtype=torch.float)
    drum_test = torch.tensor(drum_test, dtype=torch.float)
    drum_test = torch.tensor(drum_test, dtype=torch.float)
        
for epoch in range(epoch_count):  # loop over the dataset multiple times
    print(f"Epoch #{epoch}")
    if shuffle_every_epoch:
        # shuffle train and test set:
        drum_train, bass_train, drum_test, bass_test = shuffle(drum, bass)
        drum_train = torch.tensor(drum_train, dtype=torch.float)
        bass_train = torch.tensor(bass_train, dtype=torch.float)
        drum_test = torch.tensor(drum_test, dtype=torch.float)
        bass_test = torch.tensor(bass_test, dtype=torch.float)
        
    examples_count = drum_train.size()[0]
    examples_id = 0
    
    running_loss = 0.0
    runnint_count = 0
    batch_id = 0
    while examples_id < examples_count:
        batch_drum_train = drum_train[examples_id:examples_id + batch_size,:,:].transpose(0,1)
        batch_bass_train = bass_train[examples_id:examples_id + batch_size,].transpose(0,1)
        # transpose нужен для обмена размерности батча и размерности шагов
#         print(f"batch_drum_train:{batch_drum_train.size()}, batch_bass_train:{batch_bass_train.size()}")

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        bass_outputs = dnb_lstm(batch_drum_train).squeeze()
#         bass_outputs = bass_outputs.reshape(bass_outputs.size()[0], -1)
#         batch_bass_train = batch_bass_train.reshape(batch_bass_train.size()[0], -1)
#         print(f"bass_outputs:{bass_outputs.size()} batch_bass_train: {batch_bass_train.size()}")
#         print(f"bass_outputs:{bass_outputs} batch_bass_train: {batch_bass_train}")
        
        # loss = criterion(bass_outputs, batch_bass_train.long())
        loss = criterion(bass_outputs, batch_bass_train)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        runnint_count += 1
        period = 5
        if batch_id % period == 0 or examples_id + batch_size >= examples_count:
            print('[%d, %5d] train loss: %.7f' %
                  (epoch + 1, batch_id + 1, running_loss / runnint_count))
            running_loss = 0.0
            runnint_count = 1
            
        # update batch info
        examples_id += batch_size
        batch_id += 1
        
    # here we can insert measure error on test set

#should check accuracy on validation set
print('Finished Training')

shuffle_every_epoch is on
Epoch #0
[1,     1] train loss: 176.1044464
[1,     6] train loss: 134.1047796
[1,    11] train loss: 112.3866145
[1,    16] train loss: 102.9875336
[1,    21] train loss: 83.1650174
[1,    26] train loss: 82.2173220
[1,    31] train loss: 80.3045235
[1,    36] train loss: 65.7691758
[1,    41] train loss: 57.2387377
[1,    46] train loss: 61.9482002
[1,    51] train loss: 59.1013546
[1,    56] train loss: 57.0269629
[1,    61] train loss: 61.0347366
[1,    66] train loss: 53.9613469
[1,    71] train loss: 57.5882696
[1,    76] train loss: 47.0909729
[1,    81] train loss: 56.3308563
[1,    86] train loss: 54.4532382
[1,    91] train loss: 54.6777617
[1,    96] train loss: 44.8776766
[1,   100] train loss: 40.2034042
Epoch #1
[2,     1] train loss: 53.6293182
[2,     6] train loss: 50.2689387
[2,    11] train loss: 48.4459623
[2,    16] train loss: 43.8645871
[2,    21] train loss: 47.1056913
[2,    26] train loss: 51.5786514
[2,    31] train loss: 56.1698812


[12,    21] train loss: 44.3402627
[12,    26] train loss: 44.6274217
[12,    31] train loss: 52.3281593
[12,    36] train loss: 41.8031629
[12,    41] train loss: 40.1266454
[12,    46] train loss: 51.8167197
[12,    51] train loss: 45.1618252
[12,    56] train loss: 46.3268846
[12,    61] train loss: 50.2223568
[12,    66] train loss: 43.5352039
[12,    71] train loss: 47.3148524
[12,    76] train loss: 37.2679571
[12,    81] train loss: 44.4215654
[12,    86] train loss: 41.8869864
[12,    91] train loss: 49.1308041
[12,    96] train loss: 38.1630936
[12,   100] train loss: 35.3949478
Epoch #12
[13,     1] train loss: 38.9863167
[13,     6] train loss: 42.0444285
[13,    11] train loss: 42.1489474
[13,    16] train loss: 40.1826706
[13,    21] train loss: 44.5478681
[13,    26] train loss: 42.2498932
[13,    31] train loss: 52.8089472
[13,    36] train loss: 40.4510848
[13,    41] train loss: 40.9620304
[13,    46] train loss: 51.5613575
[13,    51] train loss: 46.7151445
[13,    56

[23,    16] train loss: 38.1869322
[23,    21] train loss: 42.8718580
[23,    26] train loss: 43.8489615
[23,    31] train loss: 50.9373951
[23,    36] train loss: 41.1116994
[23,    41] train loss: 42.8950933
[23,    46] train loss: 50.4388205
[23,    51] train loss: 45.7246030
[23,    56] train loss: 44.6139431
[23,    61] train loss: 50.7990557
[23,    66] train loss: 43.2851473
[23,    71] train loss: 46.9392719
[23,    76] train loss: 36.2552357
[23,    81] train loss: 43.0326042
[23,    86] train loss: 40.8723571
[23,    91] train loss: 46.4214846
[23,    96] train loss: 38.0721359
[23,   100] train loss: 36.5688362
Epoch #23
[24,     1] train loss: 43.9016647
[24,     6] train loss: 42.0432288
[24,    11] train loss: 43.7694295
[24,    16] train loss: 39.1246662
[24,    21] train loss: 43.4920769
[24,    26] train loss: 44.4759998
[24,    31] train loss: 52.0731907
[24,    36] train loss: 40.9975675
[24,    41] train loss: 41.9906851
[24,    46] train loss: 50.6954937
[24,    51

[34,    11] train loss: 43.9502741
[34,    16] train loss: 39.6971773
[34,    21] train loss: 42.4660403
[34,    26] train loss: 43.0779858
[34,    31] train loss: 52.2226969
[34,    36] train loss: 40.2930298
[34,    41] train loss: 40.5722446
[34,    46] train loss: 52.1829103
[34,    51] train loss: 45.3144951
[34,    56] train loss: 45.1415869
[34,    61] train loss: 50.6950518
[34,    66] train loss: 42.7962551
[34,    71] train loss: 45.5707397
[34,    76] train loss: 34.8556226
[34,    81] train loss: 41.7032026
[34,    86] train loss: 39.7733237
[34,    91] train loss: 48.1709652
[34,    96] train loss: 36.1493905
[34,   100] train loss: 39.8361885
Epoch #34
[35,     1] train loss: 39.0134926
[35,     6] train loss: 39.6016401
[35,    11] train loss: 44.6101634
[35,    16] train loss: 39.0009193
[35,    21] train loss: 43.7380714
[35,    26] train loss: 42.8293540
[35,    31] train loss: 50.6206264
[35,    36] train loss: 39.5892251
[35,    41] train loss: 41.7027639
[35,    46

[45,     6] train loss: 41.9851418
[45,    11] train loss: 40.6379973
[45,    16] train loss: 38.2061844
[45,    21] train loss: 41.9210072
[45,    26] train loss: 44.1896063
[45,    31] train loss: 50.2721513
[45,    36] train loss: 41.7191614
[45,    41] train loss: 39.9165026
[45,    46] train loss: 50.4596526
[45,    51] train loss: 46.6064885
[45,    56] train loss: 47.9014289
[45,    61] train loss: 46.6831347
[45,    66] train loss: 39.9768457
[45,    71] train loss: 46.1108354
[45,    76] train loss: 35.2851394
[45,    81] train loss: 40.9162273
[45,    86] train loss: 39.7056160
[45,    91] train loss: 46.4463196
[45,    96] train loss: 37.2881845
[45,   100] train loss: 36.4625526
Epoch #45
[46,     1] train loss: 38.9828072
[46,     6] train loss: 41.8731276
[46,    11] train loss: 41.0472266
[46,    16] train loss: 39.5732721
[46,    21] train loss: 40.8256493
[46,    26] train loss: 43.5568937
[46,    31] train loss: 52.6444677
[46,    36] train loss: 41.0559298
[46,    41

[56,     1] train loss: 42.9066277
[56,     6] train loss: 43.1551298
[56,    11] train loss: 41.4979070
[56,    16] train loss: 37.0571353
[56,    21] train loss: 42.7663816
[56,    26] train loss: 46.2480570
[56,    31] train loss: 49.2697035
[56,    36] train loss: 43.3494161
[56,    41] train loss: 39.2005723
[56,    46] train loss: 49.6573118
[56,    51] train loss: 45.9260934
[56,    56] train loss: 45.7636528
[56,    61] train loss: 49.1933174
[56,    66] train loss: 41.1513777
[56,    71] train loss: 44.4687532
[56,    76] train loss: 35.3537747
[56,    81] train loss: 39.4744835
[56,    86] train loss: 38.0869408
[56,    91] train loss: 46.2888571
[56,    96] train loss: 39.7582277
[56,   100] train loss: 35.5189507
Epoch #56
[57,     1] train loss: 43.5114098
[57,     6] train loss: 40.5647062
[57,    11] train loss: 41.6561311
[57,    16] train loss: 37.6310507
[57,    21] train loss: 43.2455330
[57,    26] train loss: 45.0368137
[57,    31] train loss: 51.1679986
[57,    36

Epoch #66
[67,     1] train loss: 41.4937401
[67,     6] train loss: 41.4327793
[67,    11] train loss: 41.1768862
[67,    16] train loss: 38.7879772
[67,    21] train loss: 41.8574435
[67,    26] train loss: 43.0030778
[67,    31] train loss: 49.8233744
[67,    36] train loss: 38.5170638
[67,    41] train loss: 38.9431426
[67,    46] train loss: 49.3199736
[67,    51] train loss: 45.0231895
[67,    56] train loss: 47.7084961
[67,    61] train loss: 48.3343410
[67,    66] train loss: 41.0276801
[67,    71] train loss: 46.3763707
[67,    76] train loss: 33.4655590
[67,    81] train loss: 40.5979360
[67,    86] train loss: 39.2823900
[67,    91] train loss: 45.5475159
[67,    96] train loss: 38.4848601
[67,   100] train loss: 34.8379494
Epoch #67
[68,     1] train loss: 39.3452530
[68,     6] train loss: 40.1594919
[68,    11] train loss: 42.3732891
[68,    16] train loss: 39.1617565
[68,    21] train loss: 42.8513991
[68,    26] train loss: 43.3066889
[68,    31] train loss: 48.7933623


[77,   100] train loss: 35.2654572
Epoch #77
[78,     1] train loss: 42.3344727
[78,     6] train loss: 40.8760516
[78,    11] train loss: 41.8974775
[78,    16] train loss: 38.2485275
[78,    21] train loss: 41.1062190
[78,    26] train loss: 44.6449649
[78,    31] train loss: 49.4310964
[78,    36] train loss: 38.3166879
[78,    41] train loss: 39.6710663
[78,    46] train loss: 52.5446154
[78,    51] train loss: 45.8879916
[78,    56] train loss: 45.2730033
[78,    61] train loss: 47.8187567
[78,    66] train loss: 39.6974916
[78,    71] train loss: 44.6648528
[78,    76] train loss: 34.0370086
[78,    81] train loss: 40.6195456
[78,    86] train loss: 38.4775550
[78,    91] train loss: 45.8612677
[78,    96] train loss: 36.1914635
[78,   100] train loss: 35.5698212
Epoch #78
[79,     1] train loss: 42.8130341
[79,     6] train loss: 41.1931661
[79,    11] train loss: 41.9415811
[79,    16] train loss: 39.0150293
[79,    21] train loss: 41.2571166
[79,    26] train loss: 45.3307063


[88,    96] train loss: 37.4100399
[88,   100] train loss: 34.4025131
Epoch #88
[89,     1] train loss: 41.6836243
[89,     6] train loss: 41.9950619
[89,    11] train loss: 42.2760461
[89,    16] train loss: 38.0970484
[89,    21] train loss: 40.8862702
[89,    26] train loss: 43.2764028
[89,    31] train loss: 48.9513105
[89,    36] train loss: 39.2870299
[89,    41] train loss: 38.6631298
[89,    46] train loss: 51.1149279
[89,    51] train loss: 44.6563476
[89,    56] train loss: 45.9069188
[89,    61] train loss: 47.6679700
[89,    66] train loss: 40.7064432
[89,    71] train loss: 44.8231723
[89,    76] train loss: 34.8522209
[89,    81] train loss: 39.4084015
[89,    86] train loss: 37.3542290
[89,    91] train loss: 45.3989016
[89,    96] train loss: 35.8252786
[89,   100] train loss: 37.3445976
Epoch #89
[90,     1] train loss: 41.6261864
[90,     6] train loss: 41.8981355
[90,    11] train loss: 42.1669769
[90,    16] train loss: 37.6499233
[90,    21] train loss: 41.4373182


[99,    91] train loss: 45.5543480
[99,    96] train loss: 35.8877023
[99,   100] train loss: 35.1190048
Epoch #99
[100,     1] train loss: 42.2210426
[100,     6] train loss: 41.8581340
[100,    11] train loss: 41.0079473
[100,    16] train loss: 38.8873666
[100,    21] train loss: 41.4655660
[100,    26] train loss: 42.8364391
[100,    31] train loss: 48.0355612
[100,    36] train loss: 41.8046665
[100,    41] train loss: 39.0326780
[100,    46] train loss: 49.3428815
[100,    51] train loss: 44.0430304
[100,    56] train loss: 45.4120674
[100,    61] train loss: 47.4614188
[100,    66] train loss: 40.2410361
[100,    71] train loss: 44.3745022
[100,    76] train loss: 32.2832286
[100,    81] train loss: 38.5512899
[100,    86] train loss: 37.7186941
[100,    91] train loss: 46.0766385
[100,    96] train loss: 36.5698376
[100,   100] train loss: 36.0768944
Epoch #100
[101,     1] train loss: 41.3603973
[101,     6] train loss: 40.6609039
[101,    11] train loss: 40.7518406
[101,    1

[110,    56] train loss: 43.6019948
[110,    61] train loss: 47.8671265
[110,    66] train loss: 40.9165090
[110,    71] train loss: 43.7565390
[110,    76] train loss: 34.5763372
[110,    81] train loss: 39.0572084
[110,    86] train loss: 37.3502433
[110,    91] train loss: 45.0723108
[110,    96] train loss: 32.9908892
[110,   100] train loss: 37.6339523
Epoch #110
[111,     1] train loss: 42.3751755
[111,     6] train loss: 41.3689626
[111,    11] train loss: 40.7782421
[111,    16] train loss: 37.4917266
[111,    21] train loss: 41.5808283
[111,    26] train loss: 43.7987836
[111,    31] train loss: 50.0053132
[111,    36] train loss: 39.2023252
[111,    41] train loss: 38.7101123
[111,    46] train loss: 49.5157522
[111,    51] train loss: 43.6405745
[111,    56] train loss: 45.1774718
[111,    61] train loss: 50.0177301
[111,    66] train loss: 41.3259703
[111,    71] train loss: 43.1350791
[111,    76] train loss: 33.4683342
[111,    81] train loss: 38.9324570
[111,    86] trai

[121,    21] train loss: 41.7609297
[121,    26] train loss: 42.2045301
[121,    31] train loss: 48.7628644
[121,    36] train loss: 38.8007870
[121,    41] train loss: 39.5758851
[121,    46] train loss: 49.1202361
[121,    51] train loss: 44.4943469
[121,    56] train loss: 45.7024593
[121,    61] train loss: 46.4791698
[121,    66] train loss: 40.9991862
[121,    71] train loss: 46.0967509
[121,    76] train loss: 32.8548848
[121,    81] train loss: 39.4179808
[121,    86] train loss: 37.3485012
[121,    91] train loss: 45.9882533
[121,    96] train loss: 36.6560326
[121,   100] train loss: 34.0027786
Epoch #121
[122,     1] train loss: 43.2550011
[122,     6] train loss: 41.1049455
[122,    11] train loss: 40.8158913
[122,    16] train loss: 40.2309539
[122,    21] train loss: 39.0445283
[122,    26] train loss: 41.3452161
[122,    31] train loss: 50.9804395
[122,    36] train loss: 39.2206554
[122,    41] train loss: 38.2023055
[122,    46] train loss: 50.9967664
[122,    51] trai

[131,    91] train loss: 45.1123117
[131,    96] train loss: 38.6817640
[131,   100] train loss: 35.5247414
Epoch #131
[132,     1] train loss: 43.3290367
[132,     6] train loss: 41.1126499
[132,    11] train loss: 40.3384794
[132,    16] train loss: 39.1169866
[132,    21] train loss: 38.6581650
[132,    26] train loss: 44.1080805
[132,    31] train loss: 47.1195800
[132,    36] train loss: 42.1529045
[132,    41] train loss: 40.7849483
[132,    46] train loss: 47.5632979
[132,    51] train loss: 42.6454322
[132,    56] train loss: 45.8989646
[132,    61] train loss: 48.2054615
[132,    66] train loss: 39.7724816
[132,    71] train loss: 45.5617174
[132,    76] train loss: 32.3851010
[132,    81] train loss: 38.3219992
[132,    86] train loss: 36.8928369
[132,    91] train loss: 46.0283368
[132,    96] train loss: 35.4883372
[132,   100] train loss: 32.9263340
Epoch #132
[133,     1] train loss: 42.7277832
[133,     6] train loss: 39.7855053
[133,    11] train loss: 39.3834833
[133, 

[142,    56] train loss: 45.0250270
[142,    61] train loss: 46.7528019
[142,    66] train loss: 40.4597263
[142,    71] train loss: 42.1917496
[142,    76] train loss: 32.7124084
[142,    81] train loss: 39.3724613
[142,    86] train loss: 36.0049566
[142,    91] train loss: 43.8580195
[142,    96] train loss: 37.8298798
[142,   100] train loss: 34.0148106
Epoch #142
[143,     1] train loss: 43.2537308
[143,     6] train loss: 39.6155891
[143,    11] train loss: 41.7954534
[143,    16] train loss: 38.6021118
[143,    21] train loss: 41.0906188
[143,    26] train loss: 41.4223003
[143,    31] train loss: 49.3693523
[143,    36] train loss: 39.3006833
[143,    41] train loss: 39.6773364
[143,    46] train loss: 49.0559381
[143,    51] train loss: 44.4522896
[143,    56] train loss: 46.5971826
[143,    61] train loss: 45.5177511
[143,    66] train loss: 39.9156946
[143,    71] train loss: 44.8949273
[143,    76] train loss: 31.8477583
[143,    81] train loss: 39.6593806
[143,    86] trai

[153,    21] train loss: 41.5931231
[153,    26] train loss: 42.9042816
[153,    31] train loss: 49.1450837
[153,    36] train loss: 37.9878616
[153,    41] train loss: 40.1038691
[153,    46] train loss: 49.8726095
[153,    51] train loss: 42.3244276
[153,    56] train loss: 42.9412098
[153,    61] train loss: 47.8342222
[153,    66] train loss: 40.1865228
[153,    71] train loss: 42.4603259
[153,    76] train loss: 34.2166176
[153,    81] train loss: 38.9679769
[153,    86] train loss: 35.7573236
[153,    91] train loss: 44.9979509
[153,    96] train loss: 38.7162533
[153,   100] train loss: 34.8129204
Epoch #153
[154,     1] train loss: 39.9011841
[154,     6] train loss: 39.8459994
[154,    11] train loss: 40.1333847
[154,    16] train loss: 40.0499363
[154,    21] train loss: 41.3165887
[154,    26] train loss: 41.9891027
[154,    31] train loss: 50.7973690
[154,    36] train loss: 37.2520542
[154,    41] train loss: 39.3748055
[154,    46] train loss: 50.9728203
[154,    51] trai

[163,    91] train loss: 43.9666169
[163,    96] train loss: 38.1296806
[163,   100] train loss: 35.3281059
Epoch #163
[164,     1] train loss: 44.1833725
[164,     6] train loss: 40.0237878
[164,    11] train loss: 41.0788447
[164,    16] train loss: 37.9357885
[164,    21] train loss: 40.2368126
[164,    26] train loss: 44.4965954
[164,    31] train loss: 47.0142727
[164,    36] train loss: 40.9678084
[164,    41] train loss: 40.7511533
[164,    46] train loss: 47.7439505
[164,    51] train loss: 42.3839359
[164,    56] train loss: 42.7140052
[164,    61] train loss: 47.7476419
[164,    66] train loss: 41.1207790
[164,    71] train loss: 44.2105840
[164,    76] train loss: 34.1592731
[164,    81] train loss: 38.6519356
[164,    86] train loss: 36.1591288
[164,    91] train loss: 45.4047451
[164,    96] train loss: 36.6491121
[164,   100] train loss: 35.3580872
Epoch #164
[165,     1] train loss: 41.6657410
[165,     6] train loss: 40.0324796
[165,    11] train loss: 38.2350146
[165, 

[174,    56] train loss: 41.4893201
[174,    61] train loss: 47.7281068
[174,    66] train loss: 41.2920659
[174,    71] train loss: 42.8429432
[174,    76] train loss: 32.9738188
[174,    81] train loss: 37.1340822
[174,    86] train loss: 37.0428823
[174,    91] train loss: 43.8814042
[174,    96] train loss: 34.3542423
[174,   100] train loss: 34.4968491
Epoch #174
[175,     1] train loss: 45.0598106
[175,     6] train loss: 39.3443267
[175,    11] train loss: 39.1944091
[175,    16] train loss: 38.2678922
[175,    21] train loss: 39.7921244
[175,    26] train loss: 44.1517220
[175,    31] train loss: 47.9819845
[175,    36] train loss: 39.1721560
[175,    41] train loss: 39.7801367
[175,    46] train loss: 47.7730694
[175,    51] train loss: 42.8495922
[175,    56] train loss: 42.7677428
[175,    61] train loss: 47.4114208
[175,    66] train loss: 41.2938296
[175,    71] train loss: 42.6643867
[175,    76] train loss: 33.2100954
[175,    81] train loss: 38.9880759
[175,    86] trai

[185,    21] train loss: 40.9764004
[185,    26] train loss: 41.9080633
[185,    31] train loss: 50.9434210
[185,    36] train loss: 37.1501134
[185,    41] train loss: 39.1158759
[185,    46] train loss: 49.0794957
[185,    51] train loss: 43.8804747
[185,    56] train loss: 40.6227843
[185,    61] train loss: 47.5330906
[185,    66] train loss: 41.3926417
[185,    71] train loss: 41.0989723
[185,    76] train loss: 34.2012962
[185,    81] train loss: 38.6485793
[185,    86] train loss: 35.8106492
[185,    91] train loss: 42.9299844
[185,    96] train loss: 35.5907510
[185,   100] train loss: 33.5650715
Epoch #185
[186,     1] train loss: 40.4086990
[186,     6] train loss: 38.8357442
[186,    11] train loss: 39.6743895
[186,    16] train loss: 39.4262390
[186,    21] train loss: 40.3393269
[186,    26] train loss: 41.5399170
[186,    31] train loss: 47.9636695
[186,    36] train loss: 38.3578917
[186,    41] train loss: 38.2010091
[186,    46] train loss: 48.0942338
[186,    51] trai

[195,    91] train loss: 44.1959197
[195,    96] train loss: 34.8616632
[195,   100] train loss: 35.8726105
Epoch #195
[196,     1] train loss: 40.9184074
[196,     6] train loss: 40.6827386
[196,    11] train loss: 38.3923594
[196,    16] train loss: 38.9975096
[196,    21] train loss: 40.5688368
[196,    26] train loss: 41.4079717
[196,    31] train loss: 48.1461684
[196,    36] train loss: 37.8006051
[196,    41] train loss: 39.5889225
[196,    46] train loss: 46.9600442
[196,    51] train loss: 43.0572077
[196,    56] train loss: 41.1666927
[196,    61] train loss: 47.1146062
[196,    66] train loss: 42.8704929
[196,    71] train loss: 40.3519936
[196,    76] train loss: 32.9676062
[196,    81] train loss: 38.5889600
[196,    86] train loss: 36.2647870
[196,    91] train loss: 44.6227036
[196,    96] train loss: 36.5660992
[196,   100] train loss: 33.1496651
Epoch #196
[197,     1] train loss: 41.1932182
[197,     6] train loss: 39.2192008
[197,    11] train loss: 39.8894456
[197, 

[206,    56] train loss: 40.8898493
[206,    61] train loss: 47.0320841
[206,    66] train loss: 40.7218561
[206,    71] train loss: 41.5026881
[206,    76] train loss: 31.5127894
[206,    81] train loss: 38.8649648
[206,    86] train loss: 35.2192669
[206,    91] train loss: 43.5178712
[206,    96] train loss: 35.7915541
[206,   100] train loss: 33.1977169
Epoch #206
[207,     1] train loss: 40.3356934
[207,     6] train loss: 38.6620299
[207,    11] train loss: 39.1808542
[207,    16] train loss: 38.5709222
[207,    21] train loss: 39.2541650
[207,    26] train loss: 42.8873336
[207,    31] train loss: 46.2020887
[207,    36] train loss: 39.9541054
[207,    41] train loss: 39.2095877
[207,    46] train loss: 46.4920979
[207,    51] train loss: 42.3125922
[207,    56] train loss: 40.7483114
[207,    61] train loss: 46.2804667
[207,    66] train loss: 40.6117636
[207,    71] train loss: 41.7338511
[207,    76] train loss: 33.0471665
[207,    81] train loss: 38.0765368
[207,    86] trai

[217,    21] train loss: 41.0157960
[217,    26] train loss: 42.4441299
[217,    31] train loss: 48.4146818
[217,    36] train loss: 39.2802963
[217,    41] train loss: 38.2271379
[217,    46] train loss: 48.5548325
[217,    51] train loss: 42.2831707
[217,    56] train loss: 39.3064028
[217,    61] train loss: 45.4345595
[217,    66] train loss: 41.1073335
[217,    71] train loss: 41.9728260
[217,    76] train loss: 32.8075473
[217,    81] train loss: 38.7977905
[217,    86] train loss: 34.5252501
[217,    91] train loss: 43.8790995
[217,    96] train loss: 33.9766089
[217,   100] train loss: 33.1933559
Epoch #217
[218,     1] train loss: 41.0904808
[218,     6] train loss: 40.4393088
[218,    11] train loss: 37.9441497
[218,    16] train loss: 40.0497036
[218,    21] train loss: 39.7733320
[218,    26] train loss: 45.3221289
[218,    31] train loss: 44.7113450
[218,    36] train loss: 40.6804053
[218,    41] train loss: 38.2049650
[218,    46] train loss: 45.7285322
[218,    51] trai

[227,    91] train loss: 42.0051015
[227,    96] train loss: 36.8606129
[227,   100] train loss: 34.8725708
Epoch #227
[228,     1] train loss: 39.3529358
[228,     6] train loss: 39.0583471
[228,    11] train loss: 39.0416209
[228,    16] train loss: 38.7390099
[228,    21] train loss: 39.9981559
[228,    26] train loss: 42.7852472
[228,    31] train loss: 47.6709913
[228,    36] train loss: 36.0121180
[228,    41] train loss: 39.5488612
[228,    46] train loss: 47.4686877
[228,    51] train loss: 41.7203569
[228,    56] train loss: 38.9744778
[228,    61] train loss: 46.3303585
[228,    66] train loss: 41.7422091
[228,    71] train loss: 41.3434258
[228,    76] train loss: 32.8344482
[228,    81] train loss: 37.7163874
[228,    86] train loss: 35.6606986
[228,    91] train loss: 44.7377205
[228,    96] train loss: 34.7283497
[228,   100] train loss: 33.2406395
Epoch #228
[229,     1] train loss: 40.0825768
[229,     6] train loss: 38.7595170
[229,    11] train loss: 37.3030046
[229, 

[238,    56] train loss: 39.0857964
[238,    61] train loss: 47.1435719
[238,    66] train loss: 40.3007291
[238,    71] train loss: 41.6880347
[238,    76] train loss: 33.0103839
[238,    81] train loss: 36.8888353
[238,    86] train loss: 34.7399095
[238,    91] train loss: 43.8957825
[238,    96] train loss: 34.5537949
[238,   100] train loss: 33.3689041
Epoch #238
[239,     1] train loss: 39.5096893
[239,     6] train loss: 39.4580708
[239,    11] train loss: 37.5787773
[239,    16] train loss: 39.0311050
[239,    21] train loss: 40.1256809
[239,    26] train loss: 41.8593884
[239,    31] train loss: 46.8564231
[239,    36] train loss: 39.7333705
[239,    41] train loss: 36.6092866
[239,    46] train loss: 47.9218057
[239,    51] train loss: 41.4300248
[239,    56] train loss: 41.2020791
[239,    61] train loss: 45.4049390
[239,    66] train loss: 40.3771407
[239,    71] train loss: 42.2999624
[239,    76] train loss: 31.6171274
[239,    81] train loss: 37.5486469
[239,    86] trai

[249,    21] train loss: 39.2649180
[249,    26] train loss: 41.6047935
[249,    31] train loss: 47.1795177
[249,    36] train loss: 36.7048906
[249,    41] train loss: 38.0434551
[249,    46] train loss: 44.5115261
[249,    51] train loss: 41.5942949
[249,    56] train loss: 41.8800265
[249,    61] train loss: 43.1820297
[249,    66] train loss: 38.2722473
[249,    71] train loss: 43.2857291
[249,    76] train loss: 30.1393595
[249,    81] train loss: 37.2214578
[249,    86] train loss: 36.3786755
[249,    91] train loss: 44.0983957
[249,    96] train loss: 34.9992205
[249,   100] train loss: 33.6148830
Epoch #249
[250,     1] train loss: 35.4772797
[250,     6] train loss: 39.6880449
[250,    11] train loss: 37.5157553
[250,    16] train loss: 38.4597480
[250,    21] train loss: 38.5806869
[250,    26] train loss: 44.1429653
[250,    31] train loss: 44.0284545
[250,    36] train loss: 38.2576707
[250,    41] train loss: 39.9996802
[250,    46] train loss: 48.0937195
[250,    51] trai

[259,    91] train loss: 44.0104510
[259,    96] train loss: 33.8869623
[259,   100] train loss: 32.7691654
Epoch #259
[260,     1] train loss: 39.8484497
[260,     6] train loss: 39.7384078
[260,    11] train loss: 39.2868379
[260,    16] train loss: 38.2280528
[260,    21] train loss: 39.4139512
[260,    26] train loss: 40.4091600
[260,    31] train loss: 47.6912060
[260,    36] train loss: 38.0991573
[260,    41] train loss: 37.2789822
[260,    46] train loss: 47.9041265
[260,    51] train loss: 40.4348103
[260,    56] train loss: 38.8035399
[260,    61] train loss: 44.5187823
[260,    66] train loss: 41.6964931
[260,    71] train loss: 41.2111956
[260,    76] train loss: 31.3492826
[260,    81] train loss: 38.0071449
[260,    86] train loss: 34.4183661
[260,    91] train loss: 41.9835860
[260,    96] train loss: 35.9208139
[260,   100] train loss: 33.5006203
Epoch #260
[261,     1] train loss: 41.7875214
[261,     6] train loss: 39.6262385
[261,    11] train loss: 38.7196763
[261, 

[270,    56] train loss: 37.6724472
[270,    61] train loss: 44.4629498
[270,    66] train loss: 40.4482740
[270,    71] train loss: 40.6662738
[270,    76] train loss: 31.8008788
[270,    81] train loss: 37.1594238
[270,    86] train loss: 33.8311056
[270,    91] train loss: 44.0138086
[270,    96] train loss: 34.0213451
[270,   100] train loss: 33.5161316
Epoch #270
[271,     1] train loss: 37.8574829
[271,     6] train loss: 38.9069227
[271,    11] train loss: 38.2159316
[271,    16] train loss: 38.5367572
[271,    21] train loss: 40.3027945
[271,    26] train loss: 39.6949329
[271,    31] train loss: 48.2111308
[271,    36] train loss: 37.0982024
[271,    41] train loss: 37.7108059
[271,    46] train loss: 48.6774311
[271,    51] train loss: 40.1812868
[271,    56] train loss: 40.5077604
[271,    61] train loss: 43.7216600
[271,    66] train loss: 39.0013145
[271,    71] train loss: 43.5107403
[271,    76] train loss: 31.3223562
[271,    81] train loss: 37.2670549
[271,    86] trai

[281,    21] train loss: 38.6927916
[281,    26] train loss: 43.2332026
[281,    31] train loss: 38.8559316
[281,    36] train loss: 41.1329384
[281,    41] train loss: 40.9021327
[281,    46] train loss: 44.3757935
[281,    51] train loss: 42.0439428
[281,    56] train loss: 37.6326243
[281,    61] train loss: 43.6144460
[281,    66] train loss: 40.5167313
[281,    71] train loss: 39.4967683
[281,    76] train loss: 31.5728219
[281,    81] train loss: 36.4373646
[281,    86] train loss: 34.6696339
[281,    91] train loss: 42.6541335
[281,    96] train loss: 33.6926378
[281,   100] train loss: 34.8471294
Epoch #281
[282,     1] train loss: 38.2949562
[282,     6] train loss: 37.7475293
[282,    11] train loss: 37.3980503
[282,    16] train loss: 39.3771712
[282,    21] train loss: 37.1468601
[282,    26] train loss: 40.5071653
[282,    31] train loss: 47.9853458
[282,    36] train loss: 36.7806625
[282,    41] train loss: 37.9855531
[282,    46] train loss: 46.3417238
[282,    51] trai

[291,    91] train loss: 43.7195454
[291,    96] train loss: 35.1865470
[291,   100] train loss: 33.0840836
Epoch #291
[292,     1] train loss: 36.4977760
[292,     6] train loss: 36.5264600
[292,    11] train loss: 37.2683131
[292,    16] train loss: 39.4182409
[292,    21] train loss: 37.3733810
[292,    26] train loss: 41.6870778
[292,    31] train loss: 47.2021173
[292,    36] train loss: 37.0199490
[292,    41] train loss: 39.5386136
[292,    46] train loss: 46.9702390
[292,    51] train loss: 40.9074287
[292,    56] train loss: 37.3076445
[292,    61] train loss: 44.5835673
[292,    66] train loss: 40.6080170
[292,    71] train loss: 41.2069276
[292,    76] train loss: 31.4810712
[292,    81] train loss: 36.7273496
[292,    86] train loss: 34.8458754
[292,    91] train loss: 43.7720178
[292,    96] train loss: 34.0844269
[292,   100] train loss: 32.5692505
Epoch #292
[293,     1] train loss: 37.4731941
[293,     6] train loss: 38.1473675
[293,    11] train loss: 37.8015143
[293, 

[302,    56] train loss: 38.5459773
[302,    61] train loss: 44.5245603
[302,    66] train loss: 41.2552802
[302,    71] train loss: 40.7593161
[302,    76] train loss: 31.4498088
[302,    81] train loss: 37.0222422
[302,    86] train loss: 34.4805565
[302,    91] train loss: 43.3174426
[302,    96] train loss: 33.2622197
[302,   100] train loss: 32.3899185
Epoch #302
[303,     1] train loss: 35.9162292
[303,     6] train loss: 39.9600906
[303,    11] train loss: 37.9596678
[303,    16] train loss: 36.6752224
[303,    21] train loss: 38.3394216
[303,    26] train loss: 41.2426491
[303,    31] train loss: 47.1465785
[303,    36] train loss: 37.1027845
[303,    41] train loss: 38.4614194
[303,    46] train loss: 43.9585412
[303,    51] train loss: 40.6178468
[303,    56] train loss: 41.4318771
[303,    61] train loss: 41.9279804
[303,    66] train loss: 37.9740416
[303,    71] train loss: 43.9520785
[303,    76] train loss: 30.3831968
[303,    81] train loss: 35.9535675
[303,    86] trai

[313,    21] train loss: 39.8627758
[313,    26] train loss: 42.2328218
[313,    31] train loss: 45.5046107
[313,    36] train loss: 37.2295494
[313,    41] train loss: 38.8162581
[313,    46] train loss: 45.8474789
[313,    51] train loss: 40.0578918
[313,    56] train loss: 38.3576056
[313,    61] train loss: 44.2213262
[313,    66] train loss: 40.7354762
[313,    71] train loss: 41.8512071
[313,    76] train loss: 30.7670212
[313,    81] train loss: 36.1177247
[313,    86] train loss: 34.8955669
[313,    91] train loss: 44.0843175
[313,    96] train loss: 32.7604701
[313,   100] train loss: 33.5401443
Epoch #313
[314,     1] train loss: 37.8987885
[314,     6] train loss: 39.3007692
[314,    11] train loss: 39.5258567
[314,    16] train loss: 38.3480428
[314,    21] train loss: 37.5430158
[314,    26] train loss: 38.6197751
[314,    31] train loss: 46.7613767
[314,    36] train loss: 36.4986814
[314,    41] train loss: 37.6144377
[314,    46] train loss: 45.2714761
[314,    51] trai

[323,    91] train loss: 43.6977857
[323,    96] train loss: 33.8623676
[323,   100] train loss: 32.9503883
Epoch #323
[324,     1] train loss: 36.4090958
[324,     6] train loss: 37.5371882
[324,    11] train loss: 37.7289429
[324,    16] train loss: 37.7416414
[324,    21] train loss: 38.7881222
[324,    26] train loss: 41.3296185
[324,    31] train loss: 44.3359299
[324,    36] train loss: 38.3018964
[324,    41] train loss: 37.7407074
[324,    46] train loss: 44.9022331
[324,    51] train loss: 40.8416131
[324,    56] train loss: 38.6462472
[324,    61] train loss: 44.8172455
[324,    66] train loss: 41.2665958
[324,    71] train loss: 40.4347649
[324,    76] train loss: 31.4336770
[324,    81] train loss: 37.2858791
[324,    86] train loss: 33.5222174
[324,    91] train loss: 42.5941982
[324,    96] train loss: 33.8299573
[324,   100] train loss: 33.3024681
Epoch #324
[325,     1] train loss: 34.3183517
[325,     6] train loss: 38.5849965
[325,    11] train loss: 37.4226023
[325, 

[334,    56] train loss: 38.2253183
[334,    61] train loss: 44.0296510
[334,    66] train loss: 38.7512296
[334,    71] train loss: 40.7995796
[334,    76] train loss: 31.4454724
[334,    81] train loss: 35.8591563
[334,    86] train loss: 35.8219322
[334,    91] train loss: 41.0776577
[334,    96] train loss: 34.5299956
[334,   100] train loss: 30.5069229
Epoch #334
[335,     1] train loss: 36.9926949
[335,     6] train loss: 39.0916456
[335,    11] train loss: 37.9644171
[335,    16] train loss: 37.6262226
[335,    21] train loss: 37.5018368
[335,    26] train loss: 41.9092890
[335,    31] train loss: 46.9774876
[335,    36] train loss: 35.3901641
[335,    41] train loss: 38.8207887
[335,    46] train loss: 46.3021463
[335,    51] train loss: 40.0096490
[335,    56] train loss: 39.7088191
[335,    61] train loss: 44.0618108
[335,    66] train loss: 38.4591096
[335,    71] train loss: 40.9851685
[335,    76] train loss: 30.6739324
[335,    81] train loss: 36.5053883
[335,    86] trai

[345,    21] train loss: 38.8683484
[345,    26] train loss: 40.1923402
[345,    31] train loss: 43.9020863
[345,    36] train loss: 34.8539880
[345,    41] train loss: 37.4468899
[345,    46] train loss: 44.8348942
[345,    51] train loss: 39.6580578
[345,    56] train loss: 38.6986777
[345,    61] train loss: 44.8187218
[345,    66] train loss: 39.5069653
[345,    71] train loss: 41.0136534
[345,    76] train loss: 31.4198847
[345,    81] train loss: 35.4463355
[345,    86] train loss: 33.9722436
[345,    91] train loss: 43.1312574
[345,    96] train loss: 33.8831679
[345,   100] train loss: 32.9759529
Epoch #345
[346,     1] train loss: 36.7666893
[346,     6] train loss: 39.0861702
[346,    11] train loss: 36.7762877
[346,    16] train loss: 38.3374666
[346,    21] train loss: 37.5529156
[346,    26] train loss: 41.2487450
[346,    31] train loss: 44.9072037
[346,    36] train loss: 36.3471813
[346,    41] train loss: 36.8718828
[346,    46] train loss: 45.0296974
[346,    51] trai

KeyboardInterrupt: 

In [51]:
batch_drum_train = drum_train[:,:,:].transpose(0,1)
batch_bass_train = bass_train[:,:].transpose(0,1)
with torch.no_grad():
    bass_outputs = dnb_lstm(batch_drum_train)

In [59]:
result = bass_outputs.squeeze().int()
result

tensor([[ 6,  3,  3,  ...,  5,  5,  5],
        [ 3,  4,  4,  ...,  5,  5,  3],
        [ 4,  1,  1,  ...,  4,  4,  3],
        ...,
        [ 1,  7,  7,  ...,  5,  5,  1],
        [ 5,  1,  1,  ...,  2,  2,  4],
        [ 1, 10, 10,  ...,  4,  4,  1]], dtype=torch.int32)

Попробуем сохранить результаты работы сети. На anaconda нет mido, поэтому сохраняем результаты работы просто в массивчик npy... Однако, как альтернатива, его можно поставить чере pip в conda:
https://github.com/mido/mido/issues/198

In [66]:
import mido
from decode_patterns.data_conversion import build_track, DrumMelodyPair, Converter


converter = Converter((16,50))

batch_drum = torch.cat((drum_train, drum_test, torch.tensor(drum_validation))).transpose(0,1)
batch_bass = torch.cat((bass_train.int(), bass_test.int(), torch.tensor(bass_validation).int())).transpose(0,1)
with torch.no_grad():
    bass_outputs = dnb_lstm(batch_drum)
    bass_outputs = bass_outputs.squeeze().int()
    
    for i in range(bass_outputs.size()[1]):
        bass_seq = bass_outputs[:,i]
#         bass_seq = batch_bass[:,i]
#         print(f"bass_seq:{bass_seq.size()}")
        bass_output = []
        for bass_note in bass_seq:
            bass_row = np.eye(1, 36, bass_note - 1)[0]
            bass_output.append(bass_row)
        bass_output = torch.tensor(bass_output).int().squeeze()
#         print(f"bass_output:{bass_output.size()}")
        
#         print(f"batch_drum:{batch_drum[:,i,:].size()}, bass_output:{bass_output.size()}")
            
        img_dnb = torch.cat((batch_drum[:,i,:].int(),bass_output), axis=1)
#         print(f"img_dnb:{list(bass_output)}")
        pair = converter.convert_numpy_image_to_pair(np.array(img_dnb))
#         print(f"pair.melody:{pair.melody}")
        mid = build_track(pair, tempo=pair.tempo)
        mid.save(f"midi/npy/sample{i+1}.mid")
#         np.save(f"midi/npy/drum{i+1}.npy", batch_drum[:,i,:].int())
#         np.save(f"midi/npy/bass{i+1}.npy", bass_outputs[:,i,:])