## GRU

Попытка сделать монофонический выход из сетки. Аналогично LSTM

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np

Сделаем также пользовательский импорт

In [11]:
from decode_patterns import data_conversion

Загружаем датасет

In [12]:
# import dataset
drum, bass = data_conversion.make_lstm_dataset(height=16, limit=30000, patterns_file="decode_patterns/patterns.pairs.tsv", mono=True)


# define shuffling of dataset
def shuffle(A, B, p=0.8):
    # take 80% to training, other to testing
    L = len(A)
    idx = np.arange(L) < p*L
    np.random.shuffle(idx)
    yield A[idx]
    yield B[idx]
    yield A[np.logical_not(idx)]
    yield B[np.logical_not(idx)]
    
    
# we can select here a validation set
drum, bass, drum_validation, bass_validation = shuffle(drum, bass)
    
# and we can shuffle train and test set like this:
# drum_train, bass_train, drum_test, bass_test = shuffle(drum, bass)

In [13]:
bass_validation[10]

array([ 4,  1,  4,  4,  4,  1,  0,  0,  0, 14,  0,  0,  0,  0,  0,  0],
      dtype=int64)

In [14]:
# попробуем определить модель LSTM как конечный автомат
class DrumNBassGRU(nn.Module):
    def __init__(self):
        super(DrumNBassGRU, self).__init__()
        # one input neuron, one output neuron, one layer in LSTM block
        self.input_size = 14
        self.hidden_size = 34
        self.layer_count = 1
        self.lstm = nn.GRU(self.input_size, self.hidden_size, self.layer_count)
        self.embed_layer = nn.Linear(self.hidden_size, 1)
        self.sigm = nn.Sigmoid()
    
    def forward(self, input):
        # пусть в input у нас приходит вектор размерности (64, 32, 14)
        # то есть 64 отсчёта, тридцать два примера (минибатч), 14 значение в каждом (барабанная партия)
        output, _ = self.lstm(input)
        output = self.sigm(self.embed_layer(output))*37
        return output

In [15]:
# часть обучения
dnb_gru = DrumNBassGRU()

criterion = nn.MSELoss()

# оценим также и разнообразие мелодии по её.. дисперсии?)
# def melody_variety(melody):
#     return 1/(1 + (melody.sum(axis=2) > 1).int())
    
# criterion = nn.NLLLoss() # -- этот товарищ требует, чтобы LSTM выдавал классы,
# criterion = nn.CrossEntropyLoss() # и этот тоже
# (числа от 0 до C-1), но как всё-таки его заставить это делать?...
# optimizer = optim.SGD(dnb_lstm.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(dnb_gru.parameters(), lr=0.001)

Найденные баги и их решения:

https://stackoverflow.com/questions/56741087/how-to-fix-runtimeerror-expected-object-of-scalar-type-float-but-got-scalar-typ

https://stackoverflow.com/questions/49206550/pytorch-error-multi-target-not-supported-in-crossentropyloss/49209628

https://stackoverflow.com/questions/56243672/expected-target-size-50-88-got-torch-size50-288-88

In [16]:
epoch_count = 500
batch_size = 128
shuffle_every_epoch = True
    
if shuffle_every_epoch:
    print(f"shuffle_every_epoch is on")
else:
    print(f"shuffle_every_epoch is off")
    # shuffle train and test set:
    drum_train, bass_train, drum_test, bass_test = shuffle(drum, bass)
    drum_train = torch.tensor(drum_train, dtype=torch.float)
    bass_train = torch.tensor(bass_train, dtype=torch.float)
    drum_test = torch.tensor(drum_test, dtype=torch.float)
    drum_test = torch.tensor(drum_test, dtype=torch.float)
        
for epoch in range(epoch_count):  # loop over the dataset multiple times
    print(f"Epoch #{epoch}")
    if shuffle_every_epoch:
        # shuffle train and test set:
        drum_train, bass_train, drum_test, bass_test = shuffle(drum, bass)
        drum_train = torch.tensor(drum_train, dtype=torch.float)
        bass_train = torch.tensor(bass_train, dtype=torch.float)
        drum_test = torch.tensor(drum_test, dtype=torch.float)
        bass_test = torch.tensor(bass_test, dtype=torch.float)
        
    examples_count = drum_train.size()[0]
    examples_id = 0
    
    running_loss = 0.0
    runnint_count = 0
    batch_id = 0
    while examples_id < examples_count:
        batch_drum_train = drum_train[examples_id:examples_id + batch_size,:,:].transpose(0,1)
        batch_bass_train = bass_train[examples_id:examples_id + batch_size,].transpose(0,1)
        # transpose нужен для обмена размерности батча и размерности шагов
#         print(f"batch_drum_train:{batch_drum_train.size()}, batch_bass_train:{batch_bass_train.size()}")

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        bass_outputs = dnb_gru(batch_drum_train).squeeze()
#         bass_outputs = bass_outputs.reshape(bass_outputs.size()[0], -1)
#         batch_bass_train = batch_bass_train.reshape(batch_bass_train.size()[0], -1)
#         print(f"bass_outputs:{bass_outputs.size()} batch_bass_train: {batch_bass_train.size()}")
#         print(f"bass_outputs:{bass_outputs} batch_bass_train: {batch_bass_train}")
        
        # loss = criterion(bass_outputs, batch_bass_train.long())
        loss = criterion(bass_outputs, batch_bass_train)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        runnint_count += 1
        period = 5
        if batch_id % period == 0 or examples_id + batch_size >= examples_count:
            print('[%d, %5d] train loss: %.7f' %
                  (epoch + 1, batch_id + 1, running_loss / runnint_count))
            running_loss = 0.0
            runnint_count = 1
            
        # update batch info
        examples_id += batch_size
        batch_id += 1
        
    # here we can insert measure error on test set

#should check accuracy on validation set
print('Finished Training')

shuffle_every_epoch is on
Epoch #0
[1,     1] train loss: 168.5638733
[1,     6] train loss: 127.1132634
[1,    11] train loss: 101.4188716
[1,    16] train loss: 92.2979444
[1,    21] train loss: 71.6840566
[1,    26] train loss: 71.7772776
[1,    31] train loss: 67.8496424
[1,    36] train loss: 62.4713039
[1,    41] train loss: 53.1411381
[1,    46] train loss: 54.5914408
[1,    51] train loss: 55.8497251
[1,    56] train loss: 52.2756551
[1,    61] train loss: 56.6614424
[1,    66] train loss: 49.4436226
[1,    71] train loss: 55.7260723
[1,    76] train loss: 45.9068794
[1,    81] train loss: 53.6488190
[1,    86] train loss: 53.5256379
[1,    91] train loss: 50.2961597
[1,    96] train loss: 42.8001207
[1,   101] train loss: 41.6890640
[1,   106] train loss: 58.1738612
[1,   111] train loss: 50.8313313
[1,   116] train loss: 46.1922646
[1,   121] train loss: 52.9857845
[1,   126] train loss: 44.0465450
[1,   131] train loss: 42.9023240
[1,   136] train loss: 47.5815996
[1,   141]

[8,   106] train loss: 52.6692117
[8,   111] train loss: 48.7638054
[8,   116] train loss: 43.4371694
[8,   121] train loss: 52.5173804
[8,   126] train loss: 37.8901126
[8,   131] train loss: 40.1063528
[8,   136] train loss: 40.6526419
[8,   141] train loss: 44.2088273
[8,   146] train loss: 50.7108065
[8,   150] train loss: 49.4578819
Epoch #8
[9,     1] train loss: 41.1341629
[9,     6] train loss: 43.4163367
[9,    11] train loss: 41.2518616
[9,    16] train loss: 39.8513254
[9,    21] train loss: 40.7009061
[9,    26] train loss: 45.2932657
[9,    31] train loss: 46.6525574
[9,    36] train loss: 43.6890831
[9,    41] train loss: 42.1405659
[9,    46] train loss: 50.2783813
[9,    51] train loss: 48.2332948
[9,    56] train loss: 43.3110034
[9,    61] train loss: 48.8320325
[9,    66] train loss: 44.9844227
[9,    71] train loss: 46.5210292
[9,    76] train loss: 36.9104258
[9,    81] train loss: 44.9857203
[9,    86] train loss: 44.8374926
[9,    91] train loss: 46.6656195
[9,  

[23,   106] train loss: 51.2264290
[23,   111] train loss: 46.4646072
[23,   116] train loss: 42.8978189
[23,   121] train loss: 51.8405024
[23,   126] train loss: 35.3677591
[23,   131] train loss: 38.0064443
[23,   136] train loss: 39.2328014
[23,   141] train loss: 41.8497645
[23,   146] train loss: 48.2690935
[23,   150] train loss: 50.5425629
Epoch #23
[24,     1] train loss: 35.4852676
[24,     6] train loss: 43.1181952
[24,    11] train loss: 42.2005151
[24,    16] train loss: 39.0810483
[24,    21] train loss: 38.6086826
[24,    26] train loss: 43.1419601
[24,    31] train loss: 47.5860589
[24,    36] train loss: 45.0622311
[24,    41] train loss: 39.6489347
[24,    46] train loss: 48.4742444
[24,    51] train loss: 46.8980319
[24,    56] train loss: 44.7761580
[24,    61] train loss: 48.1600164
[24,    66] train loss: 42.3307317
[24,    71] train loss: 44.9482746
[24,    76] train loss: 35.2784243
[24,    81] train loss: 41.7341118
[24,    86] train loss: 40.7522119
[24,    91

[38,    96] train loss: 36.5464611
[38,   101] train loss: 42.9617545
[38,   106] train loss: 50.8736159
[38,   111] train loss: 44.6010329
[38,   116] train loss: 41.4568971
[38,   121] train loss: 50.9660225
[38,   126] train loss: 34.1654714
[38,   131] train loss: 39.3129972
[38,   136] train loss: 39.8412247
[38,   141] train loss: 41.7290306
[38,   146] train loss: 47.8474181
[38,   150] train loss: 50.2671761
Epoch #38
[39,     1] train loss: 36.6905899
[39,     6] train loss: 40.1473325
[39,    11] train loss: 40.7774684
[39,    16] train loss: 40.3445886
[39,    21] train loss: 36.5641155
[39,    26] train loss: 42.9654274
[39,    31] train loss: 40.9887740
[39,    36] train loss: 46.4860910
[39,    41] train loss: 41.1435254
[39,    46] train loss: 47.7815603
[39,    51] train loss: 49.4640903
[39,    56] train loss: 44.6908258
[39,    61] train loss: 48.7498474
[39,    66] train loss: 42.0077934
[39,    71] train loss: 42.1444956
[39,    76] train loss: 34.6411648
[39,    81

[53,    86] train loss: 40.2620201
[53,    91] train loss: 45.3875287
[53,    96] train loss: 37.9410807
[53,   101] train loss: 40.5827249
[53,   106] train loss: 48.2353268
[53,   111] train loss: 47.1033669
[53,   116] train loss: 42.4475505
[53,   121] train loss: 50.0535749
[53,   126] train loss: 34.2914480
[53,   131] train loss: 38.2386621
[53,   136] train loss: 41.0038096
[53,   141] train loss: 39.3134149
[53,   146] train loss: 47.6220163
[53,   150] train loss: 49.3088577
Epoch #53
[54,     1] train loss: 38.2416267
[54,     6] train loss: 42.3386854
[54,    11] train loss: 41.6874104
[54,    16] train loss: 39.6638597
[54,    21] train loss: 39.0793107
[54,    26] train loss: 45.3977604
[54,    31] train loss: 44.6543433
[54,    36] train loss: 43.6545010
[54,    41] train loss: 40.4382394
[54,    46] train loss: 46.2518304
[54,    51] train loss: 46.9591643
[54,    56] train loss: 43.2355982
[54,    61] train loss: 49.2183901
[54,    66] train loss: 42.4130395
[54,    71

[68,    76] train loss: 33.7451143
[68,    81] train loss: 39.9552046
[68,    86] train loss: 38.7349834
[68,    91] train loss: 44.2352320
[68,    96] train loss: 40.5419191
[68,   101] train loss: 37.6388197
[68,   106] train loss: 46.9577878
[68,   111] train loss: 47.0631186
[68,   116] train loss: 44.0947501
[68,   121] train loss: 46.7871520
[68,   126] train loss: 34.3590275
[68,   131] train loss: 37.8227615
[68,   136] train loss: 38.4868336
[68,   141] train loss: 40.8496463
[68,   146] train loss: 46.8761883
[68,   150] train loss: 49.5888100
Epoch #68
[69,     1] train loss: 40.6602554
[69,     6] train loss: 41.8390198
[69,    11] train loss: 41.3555101
[69,    16] train loss: 39.3194415
[69,    21] train loss: 38.9591058
[69,    26] train loss: 41.1886660
[69,    31] train loss: 48.7585716
[69,    36] train loss: 40.5413240
[69,    41] train loss: 38.5261358
[69,    46] train loss: 48.8913263
[69,    51] train loss: 47.4253044
[69,    56] train loss: 43.8978812
[69,    61

[83,    66] train loss: 41.8611787
[83,    71] train loss: 42.5110283
[83,    76] train loss: 33.9629993
[83,    81] train loss: 40.4825923
[83,    86] train loss: 38.1718896
[83,    91] train loss: 44.9022051
[83,    96] train loss: 37.9014867
[83,   101] train loss: 39.9290555
[83,   106] train loss: 48.6645851
[83,   111] train loss: 44.1894925
[83,   116] train loss: 41.4666176
[83,   121] train loss: 47.8520203
[83,   126] train loss: 33.4965839
[83,   131] train loss: 37.1212069
[83,   136] train loss: 39.7507350
[83,   141] train loss: 41.8157558
[83,   146] train loss: 42.4869127
[83,   150] train loss: 48.8605247
Epoch #83
[84,     1] train loss: 39.9018250
[84,     6] train loss: 41.1516825
[84,    11] train loss: 40.0103906
[84,    16] train loss: 38.0654424
[84,    21] train loss: 38.4403534
[84,    26] train loss: 44.8194714
[84,    31] train loss: 43.2967358
[84,    36] train loss: 43.0789248
[84,    41] train loss: 38.4665044
[84,    46] train loss: 46.2323030
[84,    51

[98,    56] train loss: 40.7532679
[98,    61] train loss: 46.8813508
[98,    66] train loss: 41.5663891
[98,    71] train loss: 40.7265549
[98,    76] train loss: 34.2116652
[98,    81] train loss: 39.8094501
[98,    86] train loss: 38.6724466
[98,    91] train loss: 43.9073041
[98,    96] train loss: 37.1450837
[98,   101] train loss: 39.4023959
[98,   106] train loss: 46.6619174
[98,   111] train loss: 43.9598885
[98,   116] train loss: 43.9738839
[98,   121] train loss: 47.0459512
[98,   126] train loss: 32.9388046
[98,   131] train loss: 36.8063157
[98,   136] train loss: 39.3635076
[98,   141] train loss: 40.4785252
[98,   146] train loss: 44.0332321
[98,   150] train loss: 49.1243248
Epoch #98
[99,     1] train loss: 35.2478867
[99,     6] train loss: 42.3118134
[99,    11] train loss: 39.9074116
[99,    16] train loss: 38.8581956
[99,    21] train loss: 38.0886122
[99,    26] train loss: 44.5829919
[99,    31] train loss: 45.5311673
[99,    36] train loss: 40.8780333
[99,    41

[112,   141] train loss: 39.1807493
[112,   146] train loss: 41.5680701
[112,   150] train loss: 49.0180740
Epoch #112
[113,     1] train loss: 38.9581566
[113,     6] train loss: 42.5771739
[113,    11] train loss: 41.1414897
[113,    16] train loss: 38.5655613
[113,    21] train loss: 38.7829819
[113,    26] train loss: 42.7836323
[113,    31] train loss: 47.8582757
[113,    36] train loss: 41.1165619
[113,    41] train loss: 37.3257529
[113,    46] train loss: 47.6180064
[113,    51] train loss: 43.2377027
[113,    56] train loss: 42.0634588
[113,    61] train loss: 47.5977713
[113,    66] train loss: 40.8988085
[113,    71] train loss: 46.0123177
[113,    76] train loss: 33.4892734
[113,    81] train loss: 39.9720167
[113,    86] train loss: 37.7639370
[113,    91] train loss: 43.6727854
[113,    96] train loss: 39.8915056
[113,   101] train loss: 40.0343806
[113,   106] train loss: 46.8197136
[113,   111] train loss: 43.4416739
[113,   116] train loss: 42.9995569
[113,   121] trai

[127,    66] train loss: 42.3571002
[127,    71] train loss: 41.8799324
[127,    76] train loss: 32.9553417
[127,    81] train loss: 38.5020733
[127,    86] train loss: 37.6383750
[127,    91] train loss: 43.7117697
[127,    96] train loss: 36.2269440
[127,   101] train loss: 40.7969888
[127,   106] train loss: 47.1977272
[127,   111] train loss: 44.3955453
[127,   116] train loss: 43.2365271
[127,   121] train loss: 45.6031494
[127,   126] train loss: 33.5589523
[127,   131] train loss: 36.3047848
[127,   136] train loss: 37.3412765
[127,   141] train loss: 37.8610878
[127,   146] train loss: 40.4477676
[127,   150] train loss: 48.3379288
Epoch #127
[128,     1] train loss: 38.5055084
[128,     6] train loss: 41.6947441
[128,    11] train loss: 39.7100016
[128,    16] train loss: 38.8882898
[128,    21] train loss: 35.6551558
[128,    26] train loss: 44.4690291
[128,    31] train loss: 42.5209039
[128,    36] train loss: 43.1037515
[128,    41] train loss: 38.6002076
[128,    46] trai

[141,   146] train loss: 39.9493771
[141,   150] train loss: 47.9250481
Epoch #141
[142,     1] train loss: 36.2177887
[142,     6] train loss: 40.4821078
[142,    11] train loss: 40.8519274
[142,    16] train loss: 36.9495538
[142,    21] train loss: 37.6327960
[142,    26] train loss: 43.7424335
[142,    31] train loss: 44.8711815
[142,    36] train loss: 40.5077070
[142,    41] train loss: 38.2346274
[142,    46] train loss: 46.4393756
[142,    51] train loss: 45.5261065
[142,    56] train loss: 41.2624842
[142,    61] train loss: 47.0685844
[142,    66] train loss: 41.7551047
[142,    71] train loss: 39.5309324
[142,    76] train loss: 33.6086877
[142,    81] train loss: 40.0126966
[142,    86] train loss: 38.0723616
[142,    91] train loss: 42.7957548
[142,    96] train loss: 37.7829119
[142,   101] train loss: 40.6441180
[142,   106] train loss: 47.9362907
[142,   111] train loss: 44.0944786
[142,   116] train loss: 41.7090263
[142,   121] train loss: 44.4371815
[142,   126] trai

[156,    71] train loss: 43.0223103
[156,    76] train loss: 32.1943960
[156,    81] train loss: 38.4340611
[156,    86] train loss: 36.5763321
[156,    91] train loss: 44.1933174
[156,    96] train loss: 36.8537693
[156,   101] train loss: 39.3020360
[156,   106] train loss: 44.8519516
[156,   111] train loss: 44.2187926
[156,   116] train loss: 40.6967882
[156,   121] train loss: 44.6629353
[156,   126] train loss: 32.6144199
[156,   131] train loss: 36.5472247
[156,   136] train loss: 38.5852992
[156,   141] train loss: 38.4802386
[156,   146] train loss: 39.0307372
[156,   150] train loss: 45.8980797
Epoch #156
[157,     1] train loss: 37.2738647
[157,     6] train loss: 41.0808360
[157,    11] train loss: 40.5098349
[157,    16] train loss: 39.7851499
[157,    21] train loss: 37.2976468
[157,    26] train loss: 43.5134462
[157,    31] train loss: 41.3442001
[157,    36] train loss: 42.4714572
[157,    41] train loss: 38.3300889
[157,    46] train loss: 47.1535962
[157,    51] trai

[170,   150] train loss: 46.1024254
Epoch #170
[171,     1] train loss: 38.8398666
[171,     6] train loss: 42.3891989
[171,    11] train loss: 40.2881190
[171,    16] train loss: 39.3442313
[171,    21] train loss: 37.5077712
[171,    26] train loss: 44.1960812
[171,    31] train loss: 43.9121437
[171,    36] train loss: 41.2550634
[171,    41] train loss: 38.7670333
[171,    46] train loss: 44.9876029
[171,    51] train loss: 43.7517236
[171,    56] train loss: 40.2386843
[171,    61] train loss: 45.6169103
[171,    66] train loss: 41.2210712
[171,    71] train loss: 42.3309905
[171,    76] train loss: 32.4704119
[171,    81] train loss: 38.3477739
[171,    86] train loss: 36.4541054
[171,    91] train loss: 42.8459670
[171,    96] train loss: 35.7204882
[171,   101] train loss: 38.2451639
[171,   106] train loss: 46.9133987
[171,   111] train loss: 42.9540354
[171,   116] train loss: 42.2556934
[171,   121] train loss: 43.2742513
[171,   126] train loss: 32.5617663
[171,   131] trai

[185,    76] train loss: 32.7410533
[185,    81] train loss: 37.3064022
[185,    86] train loss: 36.8628597
[185,    91] train loss: 43.4116001
[185,    96] train loss: 36.0185725
[185,   101] train loss: 39.2396202
[185,   106] train loss: 44.9387582
[185,   111] train loss: 43.4412562
[185,   116] train loss: 41.2660179
[185,   121] train loss: 43.6724790
[185,   126] train loss: 32.3832563
[185,   131] train loss: 34.9853284
[185,   136] train loss: 36.8532823
[185,   141] train loss: 38.9526863
[185,   146] train loss: 36.9870809
[185,   150] train loss: 45.4772232
Epoch #185
[186,     1] train loss: 36.5148010
[186,     6] train loss: 40.6510092
[186,    11] train loss: 38.6480338
[186,    16] train loss: 39.0302753
[186,    21] train loss: 37.4009775
[186,    26] train loss: 42.7796593
[186,    31] train loss: 41.1940517
[186,    36] train loss: 41.5257905
[186,    41] train loss: 38.9402205
[186,    46] train loss: 44.9845479
[186,    51] train loss: 45.1861083
[186,    56] trai

[200,     1] train loss: 38.9343758
[200,     6] train loss: 40.9645716
[200,    11] train loss: 37.2317524
[200,    16] train loss: 38.6288738
[200,    21] train loss: 35.4895871
[200,    26] train loss: 42.4863389
[200,    31] train loss: 40.3485902
[200,    36] train loss: 42.2596594
[200,    41] train loss: 39.0968113
[200,    46] train loss: 43.8184210
[200,    51] train loss: 43.8247490
[200,    56] train loss: 38.9850998
[200,    61] train loss: 45.7735055
[200,    66] train loss: 42.0016123
[200,    71] train loss: 40.1910143
[200,    76] train loss: 32.5018266
[200,    81] train loss: 38.0976594
[200,    86] train loss: 35.6234608
[200,    91] train loss: 42.0265783
[200,    96] train loss: 35.9172668
[200,   101] train loss: 37.6517852
[200,   106] train loss: 45.4031118
[200,   111] train loss: 42.7579613
[200,   116] train loss: 41.1074607
[200,   121] train loss: 43.2629350
[200,   126] train loss: 33.5009877
[200,   131] train loss: 34.4796130
[200,   136] train loss: 37.

[214,    81] train loss: 39.0896263
[214,    86] train loss: 36.0534000
[214,    91] train loss: 43.2053108
[214,    96] train loss: 37.0205854
[214,   101] train loss: 37.7721659
[214,   106] train loss: 46.1424446
[214,   111] train loss: 43.5333201
[214,   116] train loss: 40.2646128
[214,   121] train loss: 42.8133500
[214,   126] train loss: 32.2767223
[214,   131] train loss: 33.9024105
[214,   136] train loss: 36.5449667
[214,   141] train loss: 39.0834401
[214,   146] train loss: 33.5558955
[214,   150] train loss: 43.3019493
Epoch #214
[215,     1] train loss: 35.9291191
[215,     6] train loss: 39.8443292
[215,    11] train loss: 38.8463236
[215,    16] train loss: 37.9375998
[215,    21] train loss: 36.2624512
[215,    26] train loss: 41.5800165
[215,    31] train loss: 39.8901494
[215,    36] train loss: 43.5226402
[215,    41] train loss: 38.7308172
[215,    46] train loss: 42.3728797
[215,    51] train loss: 43.4506626
[215,    56] train loss: 39.3601303
[215,    61] trai

[229,     6] train loss: 39.5586732
[229,    11] train loss: 37.1287708
[229,    16] train loss: 37.3090649
[229,    21] train loss: 33.2497899
[229,    26] train loss: 43.5947997
[229,    31] train loss: 41.1806761
[229,    36] train loss: 40.3619626
[229,    41] train loss: 38.8013687
[229,    46] train loss: 44.5038185
[229,    51] train loss: 44.5435200
[229,    56] train loss: 38.5064017
[229,    61] train loss: 44.4749196
[229,    66] train loss: 41.1297455
[229,    71] train loss: 40.7794685
[229,    76] train loss: 30.8887011
[229,    81] train loss: 38.1693923
[229,    86] train loss: 36.1633358
[229,    91] train loss: 42.0926603
[229,    96] train loss: 34.8169670
[229,   101] train loss: 39.5617243
[229,   106] train loss: 44.6775455
[229,   111] train loss: 43.8855089
[229,   116] train loss: 39.8725503
[229,   121] train loss: 42.5099119
[229,   126] train loss: 32.2830242
[229,   131] train loss: 33.8285789
[229,   136] train loss: 36.5214688
[229,   141] train loss: 38.

[243,    86] train loss: 35.8235747
[243,    91] train loss: 43.9451313
[243,    96] train loss: 34.1746515
[243,   101] train loss: 38.2030436
[243,   106] train loss: 43.5912259
[243,   111] train loss: 41.6165059
[243,   116] train loss: 40.7753003
[243,   121] train loss: 41.6648788
[243,   126] train loss: 30.8937111
[243,   131] train loss: 31.3122794
[243,   136] train loss: 37.9687773
[243,   141] train loss: 37.4973405
[243,   146] train loss: 35.9391502
[243,   150] train loss: 45.4361679
Epoch #243
[244,     1] train loss: 35.9040565
[244,     6] train loss: 39.4921087
[244,    11] train loss: 37.9760399
[244,    16] train loss: 37.4580968
[244,    21] train loss: 35.2216892
[244,    26] train loss: 41.1861070
[244,    31] train loss: 43.3035247
[244,    36] train loss: 39.2159805
[244,    41] train loss: 39.1830082
[244,    46] train loss: 43.0929260
[244,    51] train loss: 42.3471127
[244,    56] train loss: 40.2810465
[244,    61] train loss: 44.7081439
[244,    66] trai

[258,    11] train loss: 37.8158251
[258,    16] train loss: 35.1689116
[258,    21] train loss: 37.0102199
[258,    26] train loss: 41.2923717
[258,    31] train loss: 42.7210789
[258,    36] train loss: 39.8992322
[258,    41] train loss: 38.0967871
[258,    46] train loss: 43.8815505
[258,    51] train loss: 42.9184163
[258,    56] train loss: 39.4391988
[258,    61] train loss: 44.8509776
[258,    66] train loss: 41.3978157
[258,    71] train loss: 41.8559767
[258,    76] train loss: 31.2239396
[258,    81] train loss: 38.0370938
[258,    86] train loss: 35.7596416
[258,    91] train loss: 43.3620612
[258,    96] train loss: 35.5062447
[258,   101] train loss: 38.9425030
[258,   106] train loss: 45.2692223
[258,   111] train loss: 42.1149457
[258,   116] train loss: 41.6561680
[258,   121] train loss: 41.0895640
[258,   126] train loss: 30.7932793
[258,   131] train loss: 34.2543068
[258,   136] train loss: 37.1754039
[258,   141] train loss: 38.6803945
[258,   146] train loss: 36.

[272,    91] train loss: 43.7785403
[272,    96] train loss: 35.0800482
[272,   101] train loss: 37.5133053
[272,   106] train loss: 42.4314067
[272,   111] train loss: 43.4086202
[272,   116] train loss: 40.6450882
[272,   121] train loss: 40.9107049
[272,   126] train loss: 29.4925957
[272,   131] train loss: 32.3722569
[272,   136] train loss: 38.1982218
[272,   141] train loss: 39.0889053
[272,   146] train loss: 35.1211907
[272,   150] train loss: 43.7316269
Epoch #272
[273,     1] train loss: 36.3623924
[273,     6] train loss: 38.5890970
[273,    11] train loss: 36.5108547
[273,    16] train loss: 37.3713061
[273,    21] train loss: 34.4150985
[273,    26] train loss: 42.2535318
[273,    31] train loss: 37.5890388
[273,    36] train loss: 41.9345589
[273,    41] train loss: 38.2499180
[273,    46] train loss: 41.5770315
[273,    51] train loss: 44.2475287
[273,    56] train loss: 39.5787563
[273,    61] train loss: 44.5192521
[273,    66] train loss: 41.4507402
[273,    71] trai

[287,    16] train loss: 37.6564064
[287,    21] train loss: 34.6628761
[287,    26] train loss: 41.0700811
[287,    31] train loss: 37.5573101
[287,    36] train loss: 41.5632795
[287,    41] train loss: 39.1090088
[287,    46] train loss: 42.8218765
[287,    51] train loss: 42.8069954
[287,    56] train loss: 38.9257838
[287,    61] train loss: 45.3658492
[287,    66] train loss: 39.4735705
[287,    71] train loss: 41.2999738
[287,    76] train loss: 31.3831453
[287,    81] train loss: 37.1927312
[287,    86] train loss: 35.0937964
[287,    91] train loss: 43.7939097
[287,    96] train loss: 33.5113522
[287,   101] train loss: 39.7459186
[287,   106] train loss: 44.1749725
[287,   111] train loss: 42.9127413
[287,   116] train loss: 39.1669922
[287,   121] train loss: 40.8360348
[287,   126] train loss: 31.0541617
[287,   131] train loss: 32.5454159
[287,   136] train loss: 36.5285333
[287,   141] train loss: 39.3251044
[287,   146] train loss: 33.8832722
[287,   150] train loss: 44.

[301,    96] train loss: 37.2066301
[301,   101] train loss: 36.7946761
[301,   106] train loss: 44.6902809
[301,   111] train loss: 42.8132623
[301,   116] train loss: 40.7678839
[301,   121] train loss: 40.0360203
[301,   126] train loss: 30.9330873
[301,   131] train loss: 31.9775384
[301,   136] train loss: 36.1091042
[301,   141] train loss: 39.0907771
[301,   146] train loss: 35.2810825
[301,   150] train loss: 43.3245682
Epoch #301
[302,     1] train loss: 35.2855988
[302,     6] train loss: 39.3795033
[302,    11] train loss: 38.3470815
[302,    16] train loss: 35.9530404
[302,    21] train loss: 35.8469378
[302,    26] train loss: 41.5201969
[302,    31] train loss: 43.0841618
[302,    36] train loss: 37.5997699
[302,    41] train loss: 37.4130770
[302,    46] train loss: 42.9807707
[302,    51] train loss: 41.2843870
[302,    56] train loss: 38.4426988
[302,    61] train loss: 43.8501269
[302,    66] train loss: 40.4496803
[302,    71] train loss: 42.4636497
[302,    76] trai

[316,    21] train loss: 35.5715138
[316,    26] train loss: 38.7979488
[316,    31] train loss: 44.0443548
[316,    36] train loss: 35.9332015
[316,    41] train loss: 36.2533455
[316,    46] train loss: 43.0549793
[316,    51] train loss: 41.0181567
[316,    56] train loss: 38.6630065
[316,    61] train loss: 45.2942429
[316,    66] train loss: 40.6879444
[316,    71] train loss: 42.0439469
[316,    76] train loss: 31.3014129
[316,    81] train loss: 36.2487164
[316,    86] train loss: 35.4683285
[316,    91] train loss: 42.4602292
[316,    96] train loss: 33.2220612
[316,   101] train loss: 36.7592525
[316,   106] train loss: 45.2801450
[316,   111] train loss: 41.4949977
[316,   116] train loss: 39.8966376
[316,   121] train loss: 40.1815014
[316,   126] train loss: 31.4240411
[316,   131] train loss: 31.3695218
[316,   136] train loss: 35.1778485
[316,   141] train loss: 39.2306973
[316,   146] train loss: 34.7906332
[316,   150] train loss: 45.4976974
Epoch #316
[317,     1] trai

[330,   101] train loss: 39.6373781
[330,   106] train loss: 44.7298323
[330,   111] train loss: 41.2954648
[330,   116] train loss: 39.8762315
[330,   121] train loss: 40.9662495
[330,   126] train loss: 30.0899954
[330,   131] train loss: 33.0412248
[330,   136] train loss: 37.6079915
[330,   141] train loss: 38.6340370
[330,   146] train loss: 34.1083043
[330,   150] train loss: 43.8766769
Epoch #330
[331,     1] train loss: 33.6419373
[331,     6] train loss: 37.4596157
[331,    11] train loss: 37.4455102
[331,    16] train loss: 36.1433150
[331,    21] train loss: 34.1525024
[331,    26] train loss: 41.6361065
[331,    31] train loss: 39.7588781
[331,    36] train loss: 40.2386373
[331,    41] train loss: 38.6169389
[331,    46] train loss: 41.8406232
[331,    51] train loss: 41.1892465
[331,    56] train loss: 37.0487804
[331,    61] train loss: 43.6507441
[331,    66] train loss: 40.9441528
[331,    71] train loss: 41.8293362
[331,    76] train loss: 30.7724492
[331,    81] trai

[345,    26] train loss: 39.6224670
[345,    31] train loss: 44.7572969
[345,    36] train loss: 34.2963270
[345,    41] train loss: 37.5382843
[345,    46] train loss: 43.4861323
[345,    51] train loss: 40.5004533
[345,    56] train loss: 36.9801426
[345,    61] train loss: 43.8397497
[345,    66] train loss: 40.7917767
[345,    71] train loss: 41.4824969
[345,    76] train loss: 30.0963736
[345,    81] train loss: 37.3119990
[345,    86] train loss: 36.9827137
[345,    91] train loss: 42.4365234
[345,    96] train loss: 33.4870167
[345,   101] train loss: 36.8652515
[345,   106] train loss: 46.5374527
[345,   111] train loss: 42.3091691
[345,   116] train loss: 40.3798803
[345,   121] train loss: 38.2332579
[345,   126] train loss: 33.9418745
[345,   131] train loss: 31.8637530
[345,   136] train loss: 35.8734220
[345,   141] train loss: 39.4414711
[345,   146] train loss: 33.9437339
[345,   150] train loss: 45.8225449
Epoch #345
[346,     1] train loss: 34.5780067
[346,     6] trai

[359,   106] train loss: 44.4317792
[359,   111] train loss: 42.4580638
[359,   116] train loss: 39.3773842
[359,   121] train loss: 40.0741450
[359,   126] train loss: 30.7514639
[359,   131] train loss: 31.8012651
[359,   136] train loss: 36.7365297
[359,   141] train loss: 39.4648997
[359,   146] train loss: 34.1931238
[359,   150] train loss: 42.8507851
Epoch #359
[360,     1] train loss: 35.2469788
[360,     6] train loss: 39.0997518
[360,    11] train loss: 36.4881846
[360,    16] train loss: 35.2569695
[360,    21] train loss: 34.8477573
[360,    26] train loss: 38.3842971
[360,    31] train loss: 44.4827442
[360,    36] train loss: 34.9085903
[360,    41] train loss: 35.7432982
[360,    46] train loss: 43.6072152
[360,    51] train loss: 40.9761721
[360,    56] train loss: 37.8271281
[360,    61] train loss: 45.0303357
[360,    66] train loss: 38.4580940
[360,    71] train loss: 41.8282744
[360,    76] train loss: 31.4509300
[360,    81] train loss: 37.7014198
[360,    86] trai

[374,    31] train loss: 40.3345153
[374,    36] train loss: 37.9092795
[374,    41] train loss: 38.0162366
[374,    46] train loss: 42.7723293
[374,    51] train loss: 40.1251055
[374,    56] train loss: 38.2852017
[374,    61] train loss: 43.0628789
[374,    66] train loss: 41.0602506
[374,    71] train loss: 41.3100300
[374,    76] train loss: 31.2122857
[374,    81] train loss: 37.1090663
[374,    86] train loss: 34.9927044
[374,    91] train loss: 41.8230197
[374,    96] train loss: 35.6446190
[374,   101] train loss: 36.5571400
[374,   106] train loss: 43.0541134
[374,   111] train loss: 41.5051912
[374,   116] train loss: 39.7292875
[374,   121] train loss: 39.2986889
[374,   126] train loss: 31.3195333
[374,   131] train loss: 32.8070017
[374,   136] train loss: 34.9685748
[374,   141] train loss: 38.4947968
[374,   146] train loss: 35.7155492
[374,   150] train loss: 42.4629425
Epoch #374
[375,     1] train loss: 36.4431419
[375,     6] train loss: 39.1782373
[375,    11] trai

[388,   111] train loss: 42.4081675
[388,   116] train loss: 40.3776595
[388,   121] train loss: 38.4234187
[388,   126] train loss: 29.5998866
[388,   131] train loss: 33.1687705
[388,   136] train loss: 36.4016908
[388,   141] train loss: 38.0244605
[388,   146] train loss: 35.6987791
[388,   150] train loss: 43.1097809
Epoch #388
[389,     1] train loss: 36.8354263
[389,     6] train loss: 36.0569541
[389,    11] train loss: 35.3661950
[389,    16] train loss: 35.5817076
[389,    21] train loss: 34.3506724
[389,    26] train loss: 41.1210289
[389,    31] train loss: 41.7382749
[389,    36] train loss: 37.4337457
[389,    41] train loss: 38.1118711
[389,    46] train loss: 41.4331512
[389,    51] train loss: 42.1469453
[389,    56] train loss: 37.8698959
[389,    61] train loss: 43.7617779
[389,    66] train loss: 41.0492407
[389,    71] train loss: 39.8665320
[389,    76] train loss: 32.5789057
[389,    81] train loss: 38.1495330
[389,    86] train loss: 35.3119825
[389,    91] trai

[403,    36] train loss: 39.4309769
[403,    41] train loss: 36.6978550
[403,    46] train loss: 42.0794551
[403,    51] train loss: 42.5555855
[403,    56] train loss: 37.2055620
[403,    61] train loss: 43.2732525
[403,    66] train loss: 40.9102580
[403,    71] train loss: 40.2051992
[403,    76] train loss: 30.8262196
[403,    81] train loss: 37.7354965
[403,    86] train loss: 35.8271379
[403,    91] train loss: 40.2101758
[403,    96] train loss: 34.0395075
[403,   101] train loss: 38.9251029
[403,   106] train loss: 42.9249694
[403,   111] train loss: 41.9970620
[403,   116] train loss: 38.4020303
[403,   121] train loss: 40.9117107
[403,   126] train loss: 29.8896179
[403,   131] train loss: 31.1265917
[403,   136] train loss: 38.9941394
[403,   141] train loss: 37.4765733
[403,   146] train loss: 34.1741931
[403,   150] train loss: 41.9220055
Epoch #403
[404,     1] train loss: 36.3291512
[404,     6] train loss: 38.7295621
[404,    11] train loss: 36.3030434
[404,    16] trai

[417,   116] train loss: 38.8012994
[417,   121] train loss: 37.8690805
[417,   126] train loss: 32.2728853
[417,   131] train loss: 32.1369979
[417,   136] train loss: 36.8279692
[417,   141] train loss: 39.1599331
[417,   146] train loss: 34.3049129
[417,   150] train loss: 43.4679100
Epoch #417
[418,     1] train loss: 35.5095634
[418,     6] train loss: 38.1586647
[418,    11] train loss: 36.1052424
[418,    16] train loss: 35.1455113
[418,    21] train loss: 33.9532798
[418,    26] train loss: 39.7847137
[418,    31] train loss: 41.9397163
[418,    36] train loss: 36.3434381
[418,    41] train loss: 35.5732931
[418,    46] train loss: 40.0642764
[418,    51] train loss: 39.6072801
[418,    56] train loss: 38.3864174
[418,    61] train loss: 43.5557880
[418,    66] train loss: 40.6348209
[418,    71] train loss: 41.2031282
[418,    76] train loss: 31.6791941
[418,    81] train loss: 37.8239816
[418,    86] train loss: 34.8935210
[418,    91] train loss: 41.7235966
[418,    96] trai

[432,    41] train loss: 36.9317525
[432,    46] train loss: 42.2093894
[432,    51] train loss: 40.8376236
[432,    56] train loss: 38.4401868
[432,    61] train loss: 44.1614984
[432,    66] train loss: 41.2225297
[432,    71] train loss: 40.3474871
[432,    76] train loss: 30.8870478
[432,    81] train loss: 35.7429682
[432,    86] train loss: 35.5225054
[432,    91] train loss: 41.5269121
[432,    96] train loss: 35.1441085
[432,   101] train loss: 37.0282024
[432,   106] train loss: 43.9918067
[432,   111] train loss: 40.3673623
[432,   116] train loss: 39.3617338
[432,   121] train loss: 39.2542737
[432,   126] train loss: 29.6174800
[432,   131] train loss: 32.5503642
[432,   136] train loss: 35.0555541
[432,   141] train loss: 40.0952937
[432,   146] train loss: 34.0745672
[432,   150] train loss: 44.7086044
Epoch #432
[433,     1] train loss: 35.8857536
[433,     6] train loss: 39.4669355
[433,    11] train loss: 35.7987372
[433,    16] train loss: 36.5640678
[433,    21] trai

[446,   121] train loss: 39.3071219
[446,   126] train loss: 30.4106661
[446,   131] train loss: 31.6134243
[446,   136] train loss: 36.2109114
[446,   141] train loss: 38.7421246
[446,   146] train loss: 34.7322187
[446,   150] train loss: 43.5034813
Epoch #446
[447,     1] train loss: 33.8091011
[447,     6] train loss: 37.9426804
[447,    11] train loss: 37.5222982
[447,    16] train loss: 36.0729001
[447,    21] train loss: 34.4708856
[447,    26] train loss: 38.3555876
[447,    31] train loss: 44.3442173
[447,    36] train loss: 35.7357445
[447,    41] train loss: 34.9086933
[447,    46] train loss: 43.2924525
[447,    51] train loss: 39.3010368
[447,    56] train loss: 37.2055175
[447,    61] train loss: 43.3443845
[447,    66] train loss: 38.1876481
[447,    71] train loss: 42.4111964
[447,    76] train loss: 30.7692626
[447,    81] train loss: 37.3696378
[447,    86] train loss: 33.8858649
[447,    91] train loss: 42.2604605
[447,    96] train loss: 34.5818698
[447,   101] trai

[461,    46] train loss: 42.6957410
[461,    51] train loss: 39.1680107
[461,    56] train loss: 37.0510400
[461,    61] train loss: 45.4004529
[461,    66] train loss: 37.5486520
[461,    71] train loss: 42.6013737
[461,    76] train loss: 30.9865452
[461,    81] train loss: 37.5879002
[461,    86] train loss: 35.1998177
[461,    91] train loss: 41.1738911
[461,    96] train loss: 36.1589069
[461,   101] train loss: 36.1628990
[461,   106] train loss: 43.0035788
[461,   111] train loss: 42.1344096
[461,   116] train loss: 38.0996494
[461,   121] train loss: 40.0794789
[461,   126] train loss: 29.0783720
[461,   131] train loss: 31.1670135
[461,   136] train loss: 37.0899426
[461,   141] train loss: 38.4446672
[461,   146] train loss: 33.7176275
[461,   150] train loss: 42.6386009
Epoch #461
[462,     1] train loss: 30.8396358
[462,     6] train loss: 38.2540023
[462,    11] train loss: 35.9642849
[462,    16] train loss: 34.7790292
[462,    21] train loss: 34.3341389
[462,    26] trai

[475,   126] train loss: 29.6385120
[475,   131] train loss: 31.7447348
[475,   136] train loss: 37.0218391
[475,   141] train loss: 38.4175485
[475,   146] train loss: 33.3684514
[475,   150] train loss: 43.1613777
Epoch #475
[476,     1] train loss: 34.5736427
[476,     6] train loss: 38.2276834
[476,    11] train loss: 36.2864456
[476,    16] train loss: 34.3301385
[476,    21] train loss: 34.2724133
[476,    26] train loss: 41.3712451
[476,    31] train loss: 40.2997653
[476,    36] train loss: 37.6926848
[476,    41] train loss: 38.3649502
[476,    46] train loss: 41.3980916
[476,    51] train loss: 40.9777813
[476,    56] train loss: 38.2048721
[476,    61] train loss: 42.5690943
[476,    66] train loss: 39.7435093
[476,    71] train loss: 40.4467392
[476,    76] train loss: 31.1646662
[476,    81] train loss: 35.9543521
[476,    86] train loss: 34.5877628
[476,    91] train loss: 41.1175893
[476,    96] train loss: 33.8936183
[476,   101] train loss: 35.1313521
[476,   106] trai

[490,    51] train loss: 41.0023136
[490,    56] train loss: 38.7683372
[490,    61] train loss: 44.3377406
[490,    66] train loss: 38.0881958
[490,    71] train loss: 39.2288268
[490,    76] train loss: 30.8484097
[490,    81] train loss: 36.7132848
[490,    86] train loss: 36.1060537
[490,    91] train loss: 41.5760167
[490,    96] train loss: 33.1836898
[490,   101] train loss: 37.2258142
[490,   106] train loss: 43.6945286
[490,   111] train loss: 41.7899412
[490,   116] train loss: 38.7544918
[490,   121] train loss: 38.7323742
[490,   126] train loss: 31.0454890
[490,   131] train loss: 31.8731772
[490,   136] train loss: 36.9730644
[490,   141] train loss: 37.2105172
[490,   146] train loss: 33.8548969
[490,   150] train loss: 42.9801971
Epoch #490
[491,     1] train loss: 32.4494591
[491,     6] train loss: 39.5604127
[491,    11] train loss: 35.1621628
[491,    16] train loss: 34.1306527
[491,    21] train loss: 33.5720421
[491,    26] train loss: 40.4526971
[491,    31] trai

In [18]:
batch_drum_train = drum_train[:,:,:].transpose(0,1)
batch_bass_train = bass_train[:,:].transpose(0,1)
with torch.no_grad():
    bass_outputs = dnb_gru(batch_drum_train)

In [19]:
result = bass_outputs.squeeze().int()
result

tensor([[ 5,  5,  5,  ...,  8,  8,  8],
        [ 3,  3,  3,  ...,  7,  7,  4],
        [ 3,  3,  3,  ...,  7,  7,  8],
        ...,
        [ 8,  8,  8,  ...,  6,  6,  8],
        [ 6,  6,  6,  ...,  2,  2,  3],
        [ 6,  6,  6,  ...,  3,  3, 10]], dtype=torch.int32)

Попробуем сохранить результаты работы сети. На anaconda нет mido, поэтому сохраняем результаты работы просто в массивчик npy... Однако, как альтернатива, его можно поставить чере pip в conda:
https://github.com/mido/mido/issues/198

In [26]:
import mido
from decode_patterns.data_conversion import build_track, DrumMelodyPair, NumpyImage, Converter


converter = Converter((16,50))

batch_drum = torch.cat((drum_train, drum_test, torch.tensor(drum_validation))).transpose(0,1)
batch_bass = torch.cat((bass_train.int(), bass_test.int(), torch.tensor(bass_validation).int())).transpose(0,1)
with torch.no_grad():
    bass_outputs = dnb_gru(batch_drum)
    bass_outputs = bass_outputs.squeeze().int()
    
    for i in range(bass_outputs.size()[1]):
        bass_seq = bass_outputs[:,i]
#         bass_seq = batch_bass[:,i]
#         print(f"bass_seq:{bass_seq.size()}")
        bass_output = []
        for bass_note in bass_seq:
            bass_row = np.eye(1, 36, bass_note - 1)[0]
            bass_output.append(bass_row)
        bass_output = torch.tensor(bass_output).int().squeeze()
#         print(f"bass_output:{bass_output.size()}")
        
#         print(f"batch_drum:{batch_drum[:,i,:].size()}, bass_output:{bass_output.size()}")
            
        img_dnb = torch.cat((batch_drum[:,i,:].int(),bass_output), axis=1)
#         print(f"img_dnb:{list(bass_output)}")
        numpy_pair = NumpyImage(np.array(img_dnb), 120, 1, 1, 36)
        pair = converter.convert_numpy_image_to_pair(numpy_pair)
#         print(f"pair.melody:{pair.melody}")
        mid = build_track(pair, tempo=pair.tempo)
        mid.save(f"midi/npy/sample{i+1}.mid")
#         np.save(f"midi/npy/drum{i+1}.npy", batch_drum[:,i,:].int())
#         np.save(f"midi/npy/bass{i+1}.npy", bass_outputs[:,i,:])