In [1]:
import numpy as np
import random
from tqdm import tqdm_notebook as tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn as nn
import torch.optim as optim

import reber
 
# word_to_ix = {'hello': 0, 'world': 1}
# embeds = nn.Embedding(2, 2)
# embeds.weight.data = torch.eye(2)
# hello_idx = torch.LongTensor([0,1,1])
# hello_idx = Variable(hello_idx)
# hello_embed = embeds(hello_idx)
# print(hello_embed)

In [2]:
word2idx = {'<PAD>':0, 'B':1, 'T':2, 'S':3, 'X':4, 'P':5, 'V':6, 'E': 7}

In [3]:
# valid reber strings
train_good = reber.get_n_embedded_examples(10000, minLength=10)
# decode one-hot (+1 for reserving 0 for padding)
train_good = [np.argmax(train_good[i][0], axis=1)+1 for i in range(10000)]

test_good = reber.get_n_embedded_examples(2000, minLength=10)
# decode one-hot (+1 for reserving 0 for padding)
test_good = [np.argmax(test_good[i][0], axis=1)+1 for i in range(2000)]

In [4]:
# invalid reber strings
train_bad = [np.array(list(map(word2idx.get, reber.embedded_reber_bad()))) for i in range(10000)]
test_bad = [np.array(list(map(word2idx.get, reber.embedded_reber_bad()))) for i in range(2000)]

In [5]:
train_X = train_good + train_bad
train_Y = [0]*10000 + [1]*2000

test_X = test_good + test_bad
test_Y = [0]*10000 + [1]*2000

In [6]:
# make custom datasets
class ReberDataset(Dataset):

    def __init__(self, X, Y):
        self.X = X
        self.Y = Y
    
    def __len__(self):
        return len(self.Y)
    
    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]

In [7]:
# Function to pad and transpose data (to be used in Dataloader)
def collate_fn(data):
    data.sort(key=lambda x: len(x[0]), reverse=True)
    lens = [len(sent) for sent, label in data]
    labels = []
    padded_sents = torch.zeros(len(data), max(lens)).long()
    for i, (sent, label) in enumerate(data):
        padded_sents[i,:lens[i]] = torch.LongTensor(sent)
        labels.append(label)
    
    padded_sents = padded_sents.transpose(0,1)
    return padded_sents, torch.tensor(labels).long(), lens

In [8]:
train_ds = ReberDataset(train_X, train_Y)
test_ds = ReberDataset(test_X, test_Y)

batch_size = 128
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_dl = DataLoader(train_ds, batch_size=batch_size, collate_fn=collate_fn)

In [13]:
class GRUClassifier(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, n_layers, vocab_size):
        super(GRUClassifier, self).__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.n_layers = n_layers
        self.vocab_size = vocab_size
        
        self.emb = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_dim, n_layers)
        self.out = nn.Linear(hidden_dim, 1)
        self.out_act = nn.Sigmoid()

    def init_hidden(self, batch_size):
        return torch.zeros((self.n_layers, batch_size, self.hidden_dim), requires_grad=True).to(device)
        

    def forward(self, seq, lengths):
        self.h = self.init_hidden(seq.size(1))
        print(self.emb(seq).size())
        print(self.emb(seq)[:,0,:])
        embs = pack_padded_sequence(self.emb(seq), lengths)
        gru_out, self.h = self.gru(embs, self.h)
        gru_out, lengths = pad_packed_sequence(gru_out)
        y = self.out(gru_out[-1])
        y = self.out_act(y)
        
        return y

In [14]:
def epoch_train(model, trainloader, optimizer, criterion):
    model.train()
 
    for x_train, y_train, lens in trainloader:
        x_train, y_train = x_train.to(device), y_train.to(device)        
        y_pred = model(x_train, lens)
        
        optimizer.zero_grad()
        loss = criterion(y_pred.squeeze(), y_train.float())
        loss.backward()
        # print(next(model.parameters()))
        nn.utils.clip_grad_value_(model.parameters(), 20)
        optimizer.step()


def epoch_eval(model, testloader):
    model.eval()
    
    correct = 0
    total = 0
    for x_test, y_test, lens in testloader:
        x_test, y_test = x_test.to(device), y_test.to(device)

        y_pred = model(x_test, lens)
        predicted = y_pred.data > 0.3
        total += y_test.size(0)
        correct += (predicted.squeeze() == y_test).sum().item()
        # print('yyy', y_pred.data[0], correct / total)
    
    return correct / total

In [15]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# (vocab_size + 1) is because of pad added to the vocab
model = GRUClassifier(embedding_dim = 7+1, hidden_dim = 10, n_layers = 2, vocab_size = 7+1)
model.to(device)

# specify loss function
criterion = nn.BCELoss()

# specify optimizer
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.95)

# number of epochs
n_epoch = int(300 / len(train_dl))

In [16]:
# model training
for i in range(10):
    acc_test_epochs = []
    epoch_train(model, train_dl, optimizer, criterion)
    acc_test = epoch_eval(model, test_dl)
    acc_train = epoch_eval(model, train_dl)
    print('Train Accuracy: {:.4f}  Test Accuracy: {:.4f}'.format(acc_train, acc_test))
    acc_test_epochs.append(acc_test)

torch.Size([30, 128, 8])
tensor([[ 1.7112,  1.4345,  0.0194,  0.0368,  1.6483,  0.1744,  0.9491, -1.1289],
        [-0.6134,  0.8814,  0.8256, -0.9608, -1.3140,  0.1769, -0.2799,  2.6394],
        [ 1.7112,  1.4345,  0.0194,  0.0368,  1.6483,  0.1744,  0.9491, -1.1289],
        [-0.3921,  2.1876, -1.0027, -0.7868, -1.4162, -0.3498,  0.6517, -0.1904],
        [-0.6134,  0.8814,  0.8256, -0.9608, -1.3140,  0.1769, -0.2799,  2.6394],
        [-0.6134,  0.8814,  0.8256, -0.9608, -1.3140,  0.1769, -0.2799,  2.6394],
        [-1.7368, -1.0895, -0.4092, -0.8719,  0.0866,  1.0801,  0.1282,  0.2546],
        [-0.3921,  2.1876, -1.0027, -0.7868, -1.4162, -0.3498,  0.6517, -0.1904],
        [ 0.5946,  1.4185,  0.3303, -0.0079,  1.3151,  0.0145,  0.3767, -0.4789],
        [-1.7368, -1.0895, -0.4092, -0.8719,  0.0866,  1.0801,  0.1282,  0.2546],
        [-0.3921,  2.1876, -1.0027, -0.7868, -1.4162, -0.3498,  0.6517, -0.1904],
        [ 0.5946,  1.4185,  0.3303, -0.0079,  1.3151,  0.0145,  0.3767, -

tensor([[ 1.7104,  1.4337,  0.0186,  0.0360,  1.6475,  0.1736,  0.9483, -1.1281],
        [-0.3913,  2.1868, -1.0019, -0.7860, -1.4154, -0.3490,  0.6509, -0.1896],
        [ 1.7104,  1.4337,  0.0186,  0.0360,  1.6475,  0.1736,  0.9483, -1.1281],
        [-0.3913,  2.1868, -1.0019, -0.7860, -1.4154, -0.3490,  0.6509, -0.1896],
        [-0.6126,  0.8806,  0.8248, -0.9600, -1.3132,  0.1761, -0.2791,  2.6386],
        [-1.7360, -1.0887, -0.4084, -0.8711,  0.0858,  1.0793,  0.1274,  0.2538],
        [-0.3913,  2.1868, -1.0019, -0.7860, -1.4154, -0.3490,  0.6509, -0.1896],
        [ 0.5938,  1.4177,  0.3295, -0.0071,  1.3143,  0.0137,  0.3759, -0.4781],
        [-1.7360, -1.0887, -0.4084, -0.8711,  0.0858,  1.0793,  0.1274,  0.2538],
        [-0.3913,  2.1868, -1.0019, -0.7860, -1.4154, -0.3490,  0.6509, -0.1896],
        [ 0.5938,  1.4177,  0.3295, -0.0071,  1.3143,  0.0137,  0.3759, -0.4781],
        [-0.6126,  0.8806,  0.8248, -0.9600, -1.3132,  0.1761, -0.2791,  2.6386],
        [-0.6126

tensor([[ 1.7096,  1.4329,  0.0178,  0.0352,  1.6467,  0.1728,  0.9475, -1.1273],
        [-0.6118,  0.8798,  0.8240, -0.9592, -1.3124,  0.1753, -0.2783,  2.6378],
        [ 1.7096,  1.4329,  0.0178,  0.0352,  1.6467,  0.1728,  0.9475, -1.1273],
        [-0.3905,  2.1860, -1.0011, -0.7852, -1.4146, -0.3482,  0.6501, -0.1888],
        [-0.6118,  0.8798,  0.8240, -0.9592, -1.3124,  0.1753, -0.2783,  2.6378],
        [-0.6118,  0.8798,  0.8240, -0.9592, -1.3124,  0.1753, -0.2783,  2.6378],
        [-1.7352, -1.0879, -0.4076, -0.8703,  0.0850,  1.0785,  0.1266,  0.2530],
        [-0.3905,  2.1860, -1.0011, -0.7852, -1.4146, -0.3482,  0.6501, -0.1888],
        [ 0.5930,  1.4169,  0.3287, -0.0063,  1.3135,  0.0129,  0.3751, -0.4773],
        [-1.7352, -1.0879, -0.4076, -0.8703,  0.0850,  1.0785,  0.1266,  0.2530],
        [-0.3905,  2.1860, -1.0011, -0.7852, -1.4146, -0.3482,  0.6501, -0.1888],
        [ 0.5930,  1.4169,  0.3287, -0.0063,  1.3135,  0.0129,  0.3751, -0.4773],
        [-0.6118

torch.Size([28, 128, 8])
tensor([[ 1.7087,  1.4320,  0.0169,  0.0344,  1.6458,  0.1719,  0.9466, -1.1264],
        [-0.3896,  2.1851, -1.0002, -0.7843, -1.4137, -0.3473,  0.6493, -0.1879],
        [ 1.7087,  1.4320,  0.0169,  0.0344,  1.6458,  0.1719,  0.9466, -1.1264],
        [-0.3896,  2.1851, -1.0002, -0.7843, -1.4137, -0.3473,  0.6493, -0.1879],
        [-0.6109,  0.8789,  0.8231, -0.9583, -1.3115,  0.1744, -0.2774,  2.6369],
        [-0.6109,  0.8789,  0.8231, -0.9583, -1.3115,  0.1744, -0.2774,  2.6369],
        [-0.6109,  0.8789,  0.8231, -0.9583, -1.3115,  0.1744, -0.2774,  2.6369],
        [-0.6109,  0.8789,  0.8231, -0.9583, -1.3115,  0.1744, -0.2774,  2.6369],
        [-0.6109,  0.8789,  0.8231, -0.9583, -1.3115,  0.1744, -0.2774,  2.6369],
        [-1.7343, -1.0870, -0.4067, -0.8694,  0.0841,  1.0776,  0.1257,  0.2521],
        [-0.3896,  2.1851, -1.0002, -0.7843, -1.4137, -0.3473,  0.6493, -0.1879],
        [ 0.5921,  1.4160,  0.3278, -0.0054,  1.3126,  0.0120,  0.3742, -

torch.Size([24, 128, 8])
tensor([[ 1.7078,  1.4311,  0.0161,  0.0335,  1.6449,  0.1710,  0.9457, -1.1255],
        [-0.3887,  2.1842, -0.9993, -0.7834, -1.4128, -0.3464,  0.6484, -0.1870],
        [ 1.7078,  1.4311,  0.0161,  0.0335,  1.6449,  0.1710,  0.9457, -1.1255],
        [-0.6100,  0.8780,  0.8222, -0.9574, -1.3106,  0.1735, -0.2765,  2.6360],
        [-0.3006, -1.1189,  0.3567, -1.4395, -0.0265, -2.5237,  1.2924, -1.8625],
        [ 0.5912,  1.4151,  0.3269, -0.0046,  1.3117,  0.0112,  0.3733, -0.4755],
        [ 0.5912,  1.4151,  0.3269, -0.0046,  1.3117,  0.0112,  0.3733, -0.4755],
        [-0.6100,  0.8780,  0.8222, -0.9574, -1.3106,  0.1735, -0.2765,  2.6360],
        [-1.7334, -1.0861, -0.4058, -0.8685,  0.0832,  1.0767,  0.1248,  0.2512],
        [-0.3887,  2.1842, -0.9993, -0.7834, -1.4128, -0.3464,  0.6484, -0.1870],
        [ 0.5912,  1.4151,  0.3269, -0.0046,  1.3117,  0.0112,  0.3733, -0.4755],
        [-1.7334, -1.0861, -0.4058, -0.8685,  0.0832,  1.0767,  0.1248,  

torch.Size([33, 128, 8])
tensor([[ 1.7069,  1.4302,  0.0152,  0.0326,  1.6440,  0.1701,  0.9448, -1.1246],
        [-0.3878,  2.1833, -0.9984, -0.7825, -1.4119, -0.3455,  0.6475, -0.1862],
        [ 1.7069,  1.4302,  0.0152,  0.0326,  1.6440,  0.1701,  0.9448, -1.1246],
        [-0.6091,  0.8771,  0.8213, -0.9565, -1.3097,  0.1726, -0.2756,  2.6351],
        [ 0.5903,  1.4142,  0.3260, -0.0039,  1.3108,  0.0103,  0.3724, -0.4746],
        [ 0.5903,  1.4142,  0.3260, -0.0039,  1.3108,  0.0103,  0.3724, -0.4746],
        [-0.6091,  0.8771,  0.8213, -0.9565, -1.3097,  0.1726, -0.2756,  2.6351],
        [-0.6091,  0.8771,  0.8213, -0.9565, -1.3097,  0.1726, -0.2756,  2.6351],
        [-1.7325, -1.0852, -0.4049, -0.8676,  0.0823,  1.0758,  0.1239,  0.2503],
        [-0.3878,  2.1833, -0.9984, -0.7825, -1.4119, -0.3455,  0.6475, -0.1862],
        [ 0.5903,  1.4142,  0.3260, -0.0039,  1.3108,  0.0103,  0.3724, -0.4746],
        [-0.6091,  0.8771,  0.8213, -0.9565, -1.3097,  0.1726, -0.2756,  

tensor([[ 1.7061e+00,  1.4294e+00,  1.4487e-02,  3.1835e-02,  1.6432e+00,
          1.6931e-01,  9.4403e-01, -1.1238e+00],
        [-3.8697e-01,  2.1825e+00, -9.9757e-01, -7.8174e-01, -1.4111e+00,
         -3.4470e-01,  6.4665e-01, -1.8536e-01],
        [ 1.7061e+00,  1.4294e+00,  1.4487e-02,  3.1835e-02,  1.6432e+00,
          1.6931e-01,  9.4403e-01, -1.1238e+00],
        [-3.8697e-01,  2.1825e+00, -9.9757e-01, -7.8174e-01, -1.4111e+00,
         -3.4470e-01,  6.4665e-01, -1.8536e-01],
        [-1.7317e+00, -1.0844e+00, -4.0406e-01, -8.6683e-01,  8.1543e-02,
          1.0750e+00,  1.2311e-01,  2.4953e-01],
        [-3.8697e-01,  2.1825e+00, -9.9757e-01, -7.8174e-01, -1.4111e+00,
         -3.4470e-01,  6.4665e-01, -1.8536e-01],
        [ 5.8953e-01,  1.4134e+00,  3.2519e-01, -3.2844e-03,  1.3100e+00,
          9.6295e-03,  3.7156e-01, -4.7379e-01],
        [-1.7317e+00, -1.0844e+00, -4.0406e-01, -8.6683e-01,  8.1543e-02,
          1.0750e+00,  1.2311e-01,  2.4953e-01],
        [-3.8697

tensor([[ 1.7053e+00,  1.4286e+00,  1.3773e-02,  3.1077e-02,  1.6424e+00,
          1.6852e-01,  9.4323e-01, -1.1230e+00],
        [-6.0747e-01,  8.7555e-01,  8.1966e-01, -9.5491e-01, -1.3081e+00,
          1.7100e-01, -2.7398e-01,  2.6335e+00],
        [ 1.7053e+00,  1.4286e+00,  1.3773e-02,  3.1077e-02,  1.6424e+00,
          1.6852e-01,  9.4323e-01, -1.1230e+00],
        [-3.8618e-01,  2.1817e+00, -9.9677e-01, -7.8094e-01, -1.4103e+00,
         -3.4391e-01,  6.4586e-01, -1.8457e-01],
        [-6.0747e-01,  8.7555e-01,  8.1966e-01, -9.5491e-01, -1.3081e+00,
          1.7100e-01, -2.7398e-01,  2.6335e+00],
        [-6.0747e-01,  8.7555e-01,  8.1966e-01, -9.5491e-01, -1.3081e+00,
          1.7100e-01, -2.7398e-01,  2.6335e+00],
        [-6.0747e-01,  8.7555e-01,  8.1966e-01, -9.5491e-01, -1.3081e+00,
          1.7100e-01, -2.7398e-01,  2.6335e+00],
        [-1.7309e+00, -1.0836e+00, -4.0327e-01, -8.6603e-01,  8.0761e-02,
          1.0742e+00,  1.2232e-01,  2.4874e-01],
        [-3.8618

tensor([[ 1.7046e+00,  1.4279e+00,  1.3164e-02,  3.0422e-02,  1.6417e+00,
          1.6782e-01,  9.4253e-01, -1.1223e+00],
        [-3.8548e-01,  2.1810e+00, -9.9607e-01, -7.8024e-01, -1.4096e+00,
         -3.4321e-01,  6.4516e-01, -1.8387e-01],
        [ 1.7046e+00,  1.4279e+00,  1.3164e-02,  3.0422e-02,  1.6417e+00,
          1.6782e-01,  9.4253e-01, -1.1223e+00],
        [-6.0677e-01,  8.7485e-01,  8.1896e-01, -9.5421e-01, -1.3074e+00,
          1.7031e-01, -2.7329e-01,  2.6328e+00],
        [-2.9742e-01, -1.1157e+00,  3.5348e-01, -1.4363e+00, -2.3501e-02,
         -2.5205e+00,  1.2892e+00, -1.8593e+00],
        [-2.9742e-01, -1.1157e+00,  3.5348e-01, -1.4363e+00, -2.3501e-02,
         -2.5205e+00,  1.2892e+00, -1.8593e+00],
        [ 5.8803e-01,  1.4119e+00,  3.2370e-01, -2.3162e-03,  1.3085e+00,
          8.3780e-03,  3.7007e-01, -4.7230e-01],
        [ 5.8803e-01,  1.4119e+00,  3.2370e-01, -2.3162e-03,  1.3085e+00,
          8.3780e-03,  3.7007e-01, -4.7230e-01],
        [-6.0677

tensor([[ 1.7038e+00,  1.4271e+00,  1.2486e-02,  2.9682e-02,  1.6409e+00,
          1.6704e-01,  9.4174e-01, -1.1215e+00],
        [-6.0597e-01,  8.7405e-01,  8.1817e-01, -9.5341e-01, -1.3066e+00,
          1.6952e-01, -2.7249e-01,  2.6320e+00],
        [ 1.7038e+00,  1.4271e+00,  1.2486e-02,  2.9682e-02,  1.6409e+00,
          1.6704e-01,  9.4174e-01, -1.1215e+00],
        [-6.0597e-01,  8.7405e-01,  8.1817e-01, -9.5341e-01, -1.3066e+00,
          1.6952e-01, -2.7249e-01,  2.6320e+00],
        [-2.9663e-01, -1.1149e+00,  3.5268e-01, -1.4355e+00, -2.2776e-02,
         -2.5197e+00,  1.2884e+00, -1.8585e+00],
        [ 5.8724e-01,  1.4111e+00,  3.2290e-01, -1.8857e-03,  1.3077e+00,
          7.7507e-03,  3.6927e-01, -4.7150e-01],
        [ 5.8724e-01,  1.4111e+00,  3.2290e-01, -1.8857e-03,  1.3077e+00,
          7.7507e-03,  3.6927e-01, -4.7150e-01],
        [-6.0597e-01,  8.7405e-01,  8.1817e-01, -9.5341e-01, -1.3066e+00,
          1.6952e-01, -2.7249e-01,  2.6320e+00],
        [-6.0597

tensor([[ 1.7032e+00,  1.4264e+00,  1.1910e-02,  2.9042e-02,  1.6402e+00,
          1.6635e-01,  9.4104e-01, -1.1208e+00],
        [-6.0528e-01,  8.7335e-01,  8.1747e-01, -9.5271e-01, -1.3059e+00,
          1.6883e-01, -2.7180e-01,  2.6314e+00],
        [ 1.7032e+00,  1.4264e+00,  1.1910e-02,  2.9042e-02,  1.6402e+00,
          1.6635e-01,  9.4104e-01, -1.1208e+00],
        [-6.0528e-01,  8.7335e-01,  8.1747e-01, -9.5271e-01, -1.3059e+00,
          1.6883e-01, -2.7180e-01,  2.6314e+00],
        [-2.9594e-01, -1.1142e+00,  3.5199e-01, -1.4348e+00, -2.2152e-02,
         -2.5190e+00,  1.2877e+00, -1.8578e+00],
        [-2.9594e-01, -1.1142e+00,  3.5199e-01, -1.4348e+00, -2.2152e-02,
         -2.5190e+00,  1.2877e+00, -1.8578e+00],
        [ 5.8654e-01,  1.4104e+00,  3.2221e-01, -1.5571e-03,  1.3070e+00,
          7.2253e-03,  3.6858e-01, -4.7081e-01],
        [ 5.8654e-01,  1.4104e+00,  3.2221e-01, -1.5571e-03,  1.3070e+00,
          7.2253e-03,  3.6858e-01, -4.7081e-01],
        [-1.7287

torch.Size([46, 128, 8])
tensor([[ 1.7024e+00,  1.4256e+00,  1.1271e-02,  2.8320e-02,  1.6394e+00,
          1.6556e-01,  9.4024e-01, -1.1200e+00],
        [-3.8320e-01,  2.1787e+00, -9.9378e-01, -7.7795e-01, -1.4073e+00,
         -3.4093e-01,  6.4287e-01, -1.8161e-01],
        [ 1.7024e+00,  1.4256e+00,  1.1271e-02,  2.8320e-02,  1.6394e+00,
          1.6556e-01,  9.4024e-01, -1.1200e+00],
        [ 5.8574e-01,  1.4096e+00,  3.2142e-01, -1.2324e-03,  1.3062e+00,
          6.6529e-03,  3.6779e-01, -4.7001e-01],
        [-2.9515e-01, -1.1134e+00,  3.5120e-01, -1.4340e+00, -2.1450e-02,
         -2.5182e+00,  1.2869e+00, -1.8570e+00],
        [-2.9515e-01, -1.1134e+00,  3.5120e-01, -1.4340e+00, -2.1450e-02,
         -2.5182e+00,  1.2869e+00, -1.8570e+00],
        [-2.9515e-01, -1.1134e+00,  3.5120e-01, -1.4340e+00, -2.1450e-02,
         -2.5182e+00,  1.2869e+00, -1.8570e+00],
        [-3.8320e-01,  2.1787e+00, -9.9378e-01, -7.7795e-01, -1.4073e+00,
         -3.4093e-01,  6.4287e-01, -1.81

tensor([[ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-3.8270e-01,  2.1782e+00, -9.9328e-01, -7.7745e-01, -1.4068e+00,
         -3.4043e-01,  6.4237e-01, -1.8112e-01],
        [-1.7274e+00, -1.0802e+00, -3.9979e-01, -8.6254e-01,  7.7380e-02,
          1.0707e+00,  1.1890e-01,  2.4528e-01],
        [-3.8270e-01,  2.1782e+00, -9.9328e-01, -7.7745e-01, -1.4068e+00,
         -3.4043e-01,  6.4237e-01, -1.8112e-01],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.6952e-01],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-6.0398

tensor([[ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-3.8270e-01,  2.1782e+00, -9.9328e-01, -7.7745e-01, -1.4068e+00,
         -3.4043e-01,  6.4237e-01, -1.8112e-01],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-1.7274e+00, -1.0802e+00, -3.9979e-01, -8.6254e-01,  7.7380e-02,
          1.0707e+00,  1.1890e-01,  2.4528e-01],
        [-3.8270

torch.Size([30, 128, 8])
tensor([[ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.8565e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.8565e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.8565e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.85

torch.Size([28, 128, 8])
tensor([[ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-3.8270e-01,  2.1782e+00, -9.9328e-01, -7.7745e-01, -1.4068e+00,
         -3.4043e-01,  6.4237e-01, -1.8112e-01],
        [ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.8565e+00],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.6952e-01],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.6952e-01],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.63

tensor([[ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.6952e-01],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.6952e-01],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-6.0398

tensor([[ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-3.8270e-01,  2.1782e+00, -9.9328e-01, -7.7745e-01, -1.4068e+00,
         -3.4043e-01,  6.4237e-01, -1.8112e-01],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-1.7274e+00, -1.0802e+00, -3.9979e-01, -8.6254e-01,  7.7380e-02,
          1.0707e+00,  1.1890e-01,  2.4528e-01],
        [-3.8270e-01,  2.1782e+00, -9.9328e-01, -7.7745e-01, -1.4068e+00,
         -3.4043e-01,  6.4237e-01, -1.8112e-01],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.6952e-01],
        [-6.0398

tensor([[ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-3.8270e-01,  2.1782e+00, -9.9328e-01, -7.7745e-01, -1.4068e+00,
         -3.4043e-01,  6.4237e-01, -1.8112e-01],
        [ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.6952e-01],
        [-3.8270e-01,  2.1782e+00, -9.9328e-01, -7.7745e-01, -1.4068e+00,
         -3.4043e-01,  6.4237e-01, -1.8112e-01],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-1.7274e+00, -1.0802e+00, -3.9979e-01, -8.6254e-01,  7.7380e-02,
          1.0707e+00,  1.1890e-01,  2.4528e-01],
        [-3.8270

tensor([[ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-3.8270e-01,  2.1782e+00, -9.9328e-01, -7.7745e-01, -1.4068e+00,
         -3.4043e-01,  6.4237e-01, -1.8112e-01],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-1.7274e+00, -1.0802e+00, -3.9979e-01, -8.6254e-01,  7.7380e-02,
          1.0707e+00,  1.1890e-01,  2.4528e-01],
        [-3.8270

torch.Size([28, 128, 8])
tensor([[ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-3.8270e-01,  2.1782e+00, -9.9328e-01, -7.7745e-01, -1.4068e+00,
         -3.4043e-01,  6.4237e-01, -1.8112e-01],
        [ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.8565e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.8565e+00],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.6952e-01],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.69

tensor([[ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.6952e-01],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.6952e-01],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-1.7274

torch.Size([32, 128, 8])
tensor([[ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-3.8270e-01,  2.1782e+00, -9.9328e-01, -7.7745e-01, -1.4068e+00,
         -3.4043e-01,  6.4237e-01, -1.8112e-01],
        [ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.8565e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.8565e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.8565e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.85

tensor([[ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.8565e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.8565e+00],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.6952e-01],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.6952e-01],
        [-1.7274

tensor([[ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-3.8270e-01,  2.1782e+00, -9.9328e-01, -7.7745e-01, -1.4068e+00,
         -3.4043e-01,  6.4237e-01, -1.8112e-01],
        [ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.8565e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.8565e+00],
        [-2.9465e-01, -1.1129e+00,  3.5070e-01, -1.4335e+00, -2.1018e-02,
         -2.5177e+00,  1.2864e+00, -1.8565e+00],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.6952e-01],
        [ 5.8525

tensor([[ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-3.8270e-01,  2.1782e+00, -9.9328e-01, -7.7745e-01, -1.4068e+00,
         -3.4043e-01,  6.4237e-01, -1.8112e-01],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-1.7274e+00, -1.0802e+00, -3.9979e-01, -8.6254e-01,  7.7380e-02,
          1.0707e+00,  1.1890e-01,  2.4528e-01],
        [-3.8270e-01,  2.1782e+00, -9.9328e-01, -7.7745e-01, -1.4068e+00,
         -3.4043e-01,  6.4237e-01, -1.8112e-01],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.6952e-01],
        [-6.0398

tensor([[ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [ 1.7019e+00,  1.4251e+00,  1.0882e-02,  2.7874e-02,  1.6389e+00,
          1.6507e-01,  9.3974e-01, -1.1195e+00],
        [ 5.8525e-01,  1.4091e+00,  3.2092e-01, -1.0557e-03,  1.3057e+00,
          6.3102e-03,  3.6729e-01, -4.6952e-01],
        [-3.8270e-01,  2.1782e+00, -9.9328e-01, -7.7745e-01, -1.4068e+00,
         -3.4043e-01,  6.4237e-01, -1.8112e-01],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-6.0398e-01,  8.7206e-01,  8.1617e-01, -9.5142e-01, -1.3046e+00,
          1.6756e-01, -2.7052e-01,  2.6301e+00],
        [-1.7274

torch.Size([32, 128, 8])
tensor([[ 1.7015e+00,  1.4247e+00,  1.0576e-02,  2.7520e-02,  1.6385e+00,
          1.6468e-01,  9.3934e-01, -1.1191e+00],
        [-6.0359e-01,  8.7166e-01,  8.1578e-01, -9.5102e-01, -1.3042e+00,
          1.6717e-01, -2.7012e-01,  2.6297e+00],
        [ 1.7015e+00,  1.4247e+00,  1.0576e-02,  2.7520e-02,  1.6385e+00,
          1.6468e-01,  9.3934e-01, -1.1191e+00],
        [-3.8231e-01,  2.1778e+00, -9.9288e-01, -7.7705e-01, -1.4064e+00,
         -3.4004e-01,  6.4197e-01, -1.8073e-01],
        [-6.0359e-01,  8.7166e-01,  8.1578e-01, -9.5102e-01, -1.3042e+00,
          1.6717e-01, -2.7012e-01,  2.6297e+00],
        [-6.0359e-01,  8.7166e-01,  8.1578e-01, -9.5102e-01, -1.3042e+00,
          1.6717e-01, -2.7012e-01,  2.6297e+00],
        [-6.0359e-01,  8.7166e-01,  8.1578e-01, -9.5102e-01, -1.3042e+00,
          1.6717e-01, -2.7012e-01,  2.6297e+00],
        [-6.0359e-01,  8.7166e-01,  8.1578e-01, -9.5102e-01, -1.3042e+00,
          1.6717e-01, -2.7012e-01,  2.62

KeyboardInterrupt: 

In [13]:
train_X

[array([1, 2, 1, 2, 4, 4, 2, 2, 2, 6, 5, 3, 7, 2]),
 array([1, 5, 1, 2, 4, 4, 2, 6, 5, 4, 2, 2, 2, 2, 6, 6, 7, 5]),
 array([1, 2, 1, 5, 2, 2, 6, 5, 4, 6, 5, 3, 7, 2]),
 array([1, 5, 1, 5, 2, 2, 6, 5, 4, 2, 2, 6, 5, 4, 6, 6, 7, 5]),
 array([1, 2, 1, 2, 3, 4, 4, 2, 2, 6, 5, 4, 2, 2, 2, 2, 6, 5, 3, 7, 2]),
 array([1, 5, 1, 5, 6, 5, 4, 6, 5, 4, 2, 2, 6, 6, 7, 5]),
 array([1, 5, 1, 2, 3, 3, 4, 4, 6, 5, 4, 6, 6, 7, 5]),
 array([1, 2, 1, 2, 4, 4, 6, 5, 4, 2, 6, 6, 7, 2]),
 array([1, 5, 1, 2, 3, 4, 4, 6, 5, 4, 2, 6, 5, 3, 7, 5]),
 array([1, 5, 1, 5, 2, 6, 5, 4, 2, 2, 2, 2, 6, 6, 7, 5]),
 array([1, 5, 1, 2, 3, 3, 4, 4, 2, 6, 5, 3, 7, 5]),
 array([1, 2, 1, 5, 2, 2, 2, 2, 6, 5, 4, 6, 6, 7, 2]),
 array([1, 2, 1, 2, 3, 4, 4, 6, 5, 4, 6, 6, 7, 2]),
 array([1, 5, 1, 5, 6, 5, 4, 2, 6, 5, 4, 6, 5, 3, 7, 5]),
 array([1, 5, 1, 5, 2, 6, 5, 4, 6, 5, 4, 6, 6, 7, 5]),
 array([1, 5, 1, 5, 2, 6, 5, 4, 6, 5, 4, 2, 2, 2, 6, 6, 7, 5]),
 array([1, 5, 1, 2, 3, 3, 4, 4, 2, 2, 2, 2, 2, 2, 6, 5, 4, 2, 6, 6, 7, 5]),
 a