# Dataloading 01

In this notebook, we'll figure out how to use PyTorch's DataLoader class to load our massive files without reading the entirety of them into memory

In [75]:
import comet_ml
import dask.dataframe as dd
import pandas as pd 
import torch
import linecache 
import csv
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import torch.nn.functional as F
import sys, os
from pathlib import Path
from sklearn.utils.class_weight import compute_class_weight

here = Path().cwd()

We'll first design a custom dataset to use with PyTorch's `DataLoader` class

In [100]:
class GeneExpressionData(Dataset):
    def __init__(self, filename, labelname):
        self._filename = filename
        self._labelname = pd.read_csv(labelname)
        self._total_data = 0
        
        with open(filename, "r") as f:
            self._total_data = len(f.readlines()) - 1
    
    def __getitem__(self, idx):
        line = linecache.getline(self._filename, idx + 2)
        csv_data = csv.reader([line])
        data = [x for x in csv_data][0]
        
        label = self._labelname.loc[idx, '# label']
        return torch.from_numpy(np.array([float(x) for x in data])).float(), label
    
    def __len__(self):
        return self._total_data
    
    def num_labels(self):
        return self._labelname['# label'].nunique()
    
    def num_features(self):
        return len(self.__getitem__(0)[0])

    def compute_class_weights(self):
        weights = compute_class_weight(
            class_weight='balanced', 
            classes=np.unique(self._labelname['# label'].values), 
            y=self._labelname['# label'].values
        )    

        weights = torch.from_numpy(weights)
        return weights.float().to('cuda')

Since PyTorch loss functions require classes in $[0, C]$, we'll first add $1$ to the labels and re-write it out so we can use it for training

In [101]:
def fix_labels(file):
    labels = pd.read_csv(file)
    labels['# label'] = labels['# label'].astype(int) + 1
    labels.to_csv('fixed_' + file.split('/')[-1], index=False)

fix_labels('../data/processed/meta_primary_labels.csv')

Great, we now continue as normal

In [102]:
t = GeneExpressionData(
    filename=os.path.join(here, '../data/processed/umap/primary_reduction_neighbors_100_components_3.csv'),
    labelname=os.path.join('../data/processed/meta_primary_labels.csv')
)
t.compute_class_weights()

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29]


tensor([ 23.5583,   1.7000,  23.2975,   4.5162,   4.2203,   6.5156,   0.6397,
          1.6949,   0.6687,   0.8051,   0.3609,   6.2326,   4.8905,   1.9614,
         17.3929,   0.1903,   3.3441,   2.1526,   1.8785,   0.7911,   0.3222,
          0.7065,   0.4820,   0.9035,   1.5354,   3.3459,   0.7220,   2.0976,
        170.6387,   0.5595], dtype=torch.float64)

Let's see how fast it takes to load a minibatch of data

In [68]:
%%time 

for i in range(0,100):
    print(t.__getitem__(i))

(tensor([6.3848, 2.4303, 7.7812]), 26)
(tensor([6.3639, 2.4563, 7.7301]), 26)
(tensor([6.3428, 2.4779, 7.7295]), 8)
(tensor([6.3877, 2.4754, 7.7310]), 15)
(tensor([6.4995, 2.5715, 7.8413]), 8)
(tensor([6.4212, 2.4951, 7.7277]), 8)
(tensor([6.3750, 2.5142, 7.6449]), 29)
(tensor([6.3622, 2.5209, 7.6304]), 8)
(tensor([6.3773, 2.5858, 7.6911]), 8)
(tensor([6.4225, 2.5520, 7.7651]), 15)
(tensor([6.3794, 2.5182, 7.7005]), 29)
(tensor([6.4494, 2.6395, 7.6515]), 8)
(tensor([6.2893, 2.7212, 7.7196]), 8)
(tensor([6.4438, 2.6235, 7.7850]), 15)
(tensor([6.3955, 2.4870, 7.7330]), 8)
(tensor([6.4113, 2.5882, 7.7068]), 29)
(tensor([6.3515, 2.5249, 7.6835]), 24)
(tensor([2.5176, 5.6257, 3.4814]), 14)
(tensor([6.4297, 2.5550, 7.7791]), 8)
(tensor([6.3944, 2.6301, 7.7066]), 8)
(tensor([6.3972, 2.5236, 7.7470]), 6)
(tensor([6.4212, 2.5425, 7.7602]), 8)
(tensor([6.4844, 2.5889, 7.7253]), 8)
(tensor([6.4816, 2.5496, 7.7555]), 16)
(tensor([6.3856, 2.6333, 7.7570]), 6)
(tensor([6.3817, 2.5582, 7.7225]), 8)
(

Before we train our model, we need to split our data into training and testing sets, in order to get an unbiased evaluation of our model's performance. Likely, we will initially overfit the training set since we provide no regularization.

In [14]:
train_size = int(0.8 * len(t))
test_size = len(t) - train_size

train, test = torch.utils.data.random_split(t, [train_size, test_size])

In [15]:
traindata = DataLoader(train, batch_size = 8, num_workers = 0)
valdata = DataLoader(test, batch_size = 8, num_workers = 0)

In [17]:
for X, y in traindata:
    print(X, y)

tensor([[ 3.4116,  4.3611,  3.6923],
        [ 1.9499,  3.7356,  6.1307],
        [ 1.0966,  0.4327,  5.5884],
        [ 2.6846,  7.8154,  6.6181],
        [ 4.5364,  6.9276,  2.7043],
        [ 0.7097, -0.3879,  5.3842],
        [ 1.8299,  2.5179,  1.5972],
        [ 2.5743,  7.7902,  6.7043]]) tensor([15, 12,  4,  9, 13,  4, 15,  9])
tensor([[ 4.2022,  8.7407,  1.4905],
        [ 6.6213,  7.5119,  0.7633],
        [ 2.6400,  8.4390,  6.2129],
        [-1.3735,  8.4140,  2.3972],
        [ 4.1918,  4.3778,  3.5796],
        [ 2.8546,  8.0918,  6.0737],
        [-2.7102,  3.1529,  5.9547],
        [-2.3790,  3.5089,  6.2377]]) tensor([ 8,  8,  9, 11, 15,  0,  3,  3])
tensor([[ 2.3997,  7.2758,  3.2588],
        [ 1.4293,  5.1844,  3.4736],
        [ 2.0637,  1.4185,  2.8869],
        [ 5.0089,  5.8322, -0.9349],
        [ 5.1753,  4.2075,  5.1987],
        [ 2.3370,  9.0350,  2.4659],
        [ 1.9502,  3.6692,  6.2672],
        [ 1.3264,  9.7038,  5.3532]]) tensor([10,  0, 15,  8, 12,

tensor([[-2.7478,  3.6829,  5.8313],
        [-2.5611,  3.0203,  6.4444],
        [ 4.9854,  5.8952,  4.5249],
        [-2.5121,  3.1054,  6.2594],
        [-2.5242,  3.4251,  5.8270],
        [ 2.1844,  1.2057,  2.7340],
        [ 1.4234,  8.3443, -0.7206],
        [-0.6267,  3.9891,  4.3014]]) tensor([ 3,  3,  8,  3,  3, 15,  9,  0])
tensor([[-0.8864,  1.6219,  5.7295],
        [ 2.6921,  3.7845,  6.5563],
        [ 3.5187,  7.6315,  3.3477],
        [ 2.3783,  9.0114,  2.2746],
        [-1.5731,  8.5866,  2.4652],
        [ 2.3053,  7.3353,  3.3078],
        [ 4.0524,  4.4479,  4.7053],
        [ 3.2431,  5.2917,  3.7889]]) tensor([ 4, 12, 14,  9, 11, 10, 12,  0])
tensor([[ 1.5196,  1.8262,  2.9626],
        [ 6.8126,  7.8132,  3.8017],
        [-2.9219,  4.6821,  5.3783],
        [ 4.1805,  8.8457,  1.4872],
        [-1.9178,  8.1225,  3.3194],
        [ 2.7356,  0.4884,  2.2067],
        [ 3.3459,  2.6750,  4.5240],
        [ 1.2887,  1.3725,  1.8575]]) tensor([15,  8,  3,  8,  0,

        [ 3.4790,  6.5520,  4.8415]]) tensor([ 4, 15, 11, 15, 10,  8,  8,  9])
tensor([[ 3.8245,  5.6319,  2.7255],
        [ 2.0514, 11.3805,  5.8315],
        [ 3.4684,  6.5693,  3.4868],
        [ 3.0782,  7.6227,  4.9318],
        [ 2.5396,  7.4892,  4.4655],
        [ 2.5059,  7.3475,  3.4376],
        [-4.0559,  4.9390,  1.6056],
        [ 0.2598,  4.6911,  4.8763]]) tensor([ 8,  9, 14,  8,  8, 10, 11, 12])
tensor([[ 4.0558,  2.3200,  4.8797],
        [ 3.7718,  6.9656,  3.7377],
        [-0.1403,  0.1820,  5.1755],
        [ 2.4506,  4.3804,  2.7011],
        [-0.3305,  8.3629,  4.8598],
        [ 2.1617,  2.5490,  1.5321],
        [ 3.9690,  4.0832,  3.1315],
        [ 3.8429,  4.2248,  5.1621]]) tensor([15, 14,  4,  0, 11, 15, 15, 12])
tensor([[ 6.5845,  7.5000,  3.7077],
        [ 3.8364,  6.8131,  3.3949],
        [ 0.9454,  0.2969,  7.2429],
        [ 2.8971,  4.1229,  2.1365],
        [ 2.7127,  8.4239,  6.2937],
        [ 0.5244,  3.3007,  5.3820],
        [ 1.2955,  1.54

        [ 3.5958,  1.5807,  2.6426]]) tensor([14, 15, 12,  0, 12, 14,  8, 15])
tensor([[ 0.7062,  3.6128,  5.3429],
        [ 0.8739, 11.5135,  6.0867],
        [ 0.8082,  3.7082,  6.4030],
        [-4.2530,  4.9736,  1.8340],
        [ 7.3218,  7.7414,  0.3060],
        [ 6.0422,  7.1128,  3.0705],
        [ 3.2799,  4.4006,  3.6468],
        [ 0.7650,  8.7596,  2.4584]]) tensor([12,  9, 12, 11,  8,  8, 15, 11])
tensor([[-1.9051,  8.1231,  3.3113],
        [ 0.0092,  4.5269,  4.7555],
        [ 3.0771,  4.5031,  3.5876],
        [ 1.1553, -0.3164,  6.8872],
        [ 7.2752,  0.8497,  7.3419],
        [ 3.6179,  2.6029,  5.0710],
        [-1.9818,  7.3484,  3.6077],
        [ 3.3180,  7.2062,  3.6986]]) tensor([11, 12, 15,  4,  1, 15, 11, 10])
tensor([[ 2.9067,  5.3198,  3.5641],
        [-2.9942,  4.9738,  2.2761],
        [ 2.3580,  8.9273,  2.2548],
        [-2.4545,  3.6642,  5.9809],
        [-2.9185,  4.9526,  1.9117],
        [ 3.4120,  1.9437,  2.8388],
        [ 2.7111,  8.37

        [ 3.8167,  7.4656,  2.6617]]) tensor([15, 11,  0, 10,  0, 15,  4, 14])
tensor([[8.0244, 0.0247, 6.6550],
        [4.2083, 7.5611, 2.7179],
        [3.1999, 7.2354, 3.2173],
        [1.1601, 8.9575, 1.5272],
        [2.5677, 0.0127, 6.3376],
        [2.4828, 6.8963, 3.6811],
        [1.0582, 7.4730, 0.7487],
        [3.2882, 2.0437, 2.8710]]) tensor([ 1, 14, 14,  9,  4,  9, 10, 15])
tensor([[ 2.8305,  3.4995,  6.3617],
        [ 1.6596,  3.0123,  6.9386],
        [ 3.2547,  2.8233,  4.6956],
        [ 2.8577,  3.0561,  1.2230],
        [ 1.2129,  4.5423,  5.3098],
        [ 2.3234,  9.1082,  2.5931],
        [ 1.9967,  7.8820,  2.9852],
        [-3.3201,  5.3073,  6.0503]]) tensor([12, 12, 15, 15,  0,  9, 10, 11])
tensor([[ 4.4417,  2.4624,  4.7021],
        [ 3.2335,  1.2618,  3.7689],
        [ 0.1633,  4.5981,  4.8420],
        [-3.2934,  4.9876,  1.5665],
        [ 3.5828,  1.1932,  2.3301],
        [ 2.9990,  6.4920,  4.2890],
        [ 3.2390,  8.0341,  6.9725],
        [ 

tensor([[ 1.3266,  9.5471,  5.2860],
        [ 2.6401,  3.4329,  1.1739],
        [ 4.2004,  4.2666,  3.4415],
        [ 1.4008,  7.6805,  2.5534],
        [-0.6374,  8.4775,  4.8871],
        [ 4.1626,  4.4051,  4.8030],
        [ 1.6061,  4.4044,  4.7821],
        [ 1.5359,  3.3335,  6.4666]]) tensor([ 9, 15, 15, 11, 11, 12,  0, 12])
tensor([[ 1.0370, 11.1788,  5.1914],
        [ 3.3218,  3.0077,  5.0005],
        [ 2.0185,  2.5246,  1.5271],
        [ 0.8534,  3.8546,  6.2860],
        [ 2.7794,  6.5550,  4.0486],
        [ 4.3235,  4.2734,  5.7843],
        [ 0.3038,  3.3462,  5.5150],
        [-3.5098,  7.3942,  3.1553]]) tensor([ 9, 15, 15, 12,  9, 12, 12, 11])
tensor([[ 2.6640,  1.9442,  3.8723],
        [ 2.5807,  1.8901,  4.0087],
        [ 2.9577,  4.9107,  2.9804],
        [ 0.0574,  4.4728,  4.8235],
        [ 0.6910, 11.5839,  5.7743],
        [-0.6303,  8.4998,  3.8179],
        [ 3.8653,  8.4630,  0.0120],
        [ 2.7389,  7.2551,  3.6306]]) tensor([15, 15,  0, 12,  9,

tensor([[ 1.5292,  2.8682,  6.3162],
        [ 2.2198,  8.0763,  2.2670],
        [-3.1175,  4.9387,  5.4530],
        [-2.8611,  3.4442,  6.2557],
        [ 1.5926, -0.2972,  5.8095],
        [ 3.8489,  2.5272,  5.1077],
        [-3.6092,  7.0364,  2.4153],
        [ 1.9537,  4.9223,  5.0804]]) tensor([12,  9,  0,  3,  4, 15, 11, 12])
tensor([[ 3.4979,  7.7637,  5.0386],
        [-1.8871,  8.0808,  3.3659],
        [-0.2168,  8.8276,  5.3499],
        [ 3.4915,  6.7265,  2.7564],
        [ 1.5081,  2.6155,  1.8256],
        [ 4.3918,  6.9429, -1.1637],
        [ 1.9107,  8.0356,  2.5623],
        [-0.3841,  0.1530,  4.8508]]) tensor([ 8, 11, 11, 14, 15,  8, 10,  4])
tensor([[ 2.0598e+00,  7.9524e+00,  2.7718e+00],
        [-7.4916e-03,  4.4350e+00,  4.9705e+00],
        [ 1.6639e+00,  2.6480e+00,  5.7155e+00],
        [ 9.2576e-01,  1.1460e+01,  6.0063e+00],
        [-2.6530e+00,  7.0106e+00,  4.0591e+00],
        [-4.1781e+00,  5.0375e+00,  1.8033e+00],
        [ 3.2926e+00,  5.2111e

tensor([[ 5.1539,  6.0254,  4.7860],
        [ 0.1082, -0.0649,  5.5408],
        [ 3.4577,  6.4821,  3.5354],
        [ 3.5211,  7.4562,  3.6271],
        [ 2.6977,  2.8644,  1.2890],
        [ 1.9349,  7.8029,  3.0218],
        [ 7.9675,  0.8140,  7.3292],
        [ 1.1601,  1.5722,  2.6612]]) tensor([ 8,  4,  0, 14, 15, 10,  1, 15])
tensor([[-2.8337,  3.1766,  6.3335],
        [ 8.4582,  0.4953,  6.9599],
        [ 4.5363,  6.5913,  2.7570],
        [ 1.1028,  3.7671,  6.7309],
        [ 1.6809,  2.6301,  1.9067],
        [ 1.8853,  3.6519,  5.9604],
        [ 7.9318,  0.8069,  7.3125],
        [-2.5222,  3.2874,  6.2234]]) tensor([ 3,  1,  0, 12, 15, 12,  1,  3])
tensor([[ 4.1676,  7.0247,  3.8071],
        [ 4.6281,  8.7919,  4.2526],
        [ 2.2700,  4.2344,  2.0945],
        [ 1.3455,  2.6601,  1.8956],
        [-2.6534,  2.5901,  6.3243],
        [ 3.7930,  6.9488,  3.8299],
        [-4.4529,  4.9755,  2.1921],
        [ 3.8594,  5.6576,  3.9028]]) tensor([14,  8, 15, 15,  4,

tensor([[ 1.6482,  7.2763,  1.8412],
        [-3.5530,  7.2173,  2.4954],
        [-1.1992,  9.4259,  4.2588],
        [ 4.1000,  3.8710,  4.9501],
        [ 2.8852,  4.7440,  5.3478],
        [ 3.5164,  6.9426,  6.1391],
        [ 0.2650,  4.3137,  4.6867],
        [-3.3527,  4.8001,  2.8998]]) tensor([ 0, 11, 11, 12, 12,  9, 12, 11])
tensor([[ 3.9974,  8.6875,  0.7510],
        [ 2.9523,  5.3662,  3.5969],
        [ 3.8214,  4.3342,  4.9793],
        [ 1.6241,  4.9641,  5.2953],
        [ 2.4517, 11.0086,  4.6290],
        [ 0.9292,  7.5199,  0.6383],
        [ 4.0759,  5.6390,  2.9921],
        [ 1.7897,  9.5165,  2.2474]]) tensor([ 8,  8, 12, 12,  9, 10,  8,  9])
tensor([[-2.6947,  3.0490,  6.3543],
        [-0.4919,  8.3791,  4.9215],
        [ 2.8317,  7.3147,  3.9820],
        [ 2.7430,  7.9243,  6.5465],
        [ 1.3423,  9.0528,  4.9865],
        [-4.0733,  5.0083,  6.0193],
        [-3.6285,  5.1710,  6.2925],
        [ 2.7483,  4.1284,  2.1047]]) tensor([ 3, 11,  0,  9,  9,

        [ 3.4080,  7.7227,  5.0074]]) tensor([11, 14,  9,  8, 15, 10, 15,  8])
tensor([[ 3.0251,  5.2724,  3.7231],
        [ 4.1468,  4.4440,  3.7477],
        [-4.2372,  6.1032,  2.2923],
        [-2.8292,  2.9768,  5.7531],
        [-3.4580,  5.0147,  1.4653],
        [-3.1572,  4.9250,  1.5781],
        [-3.0337,  4.7461,  5.3121],
        [ 0.0156,  3.9157,  6.2963]]) tensor([ 8, 15, 11,  3,  0, 11,  3, 12])
tensor([[ 0.3130,  3.1692,  5.6942],
        [ 4.4524,  9.1321,  3.0129],
        [ 3.5471,  1.1841,  3.7813],
        [ 3.2884,  7.1685,  3.7668],
        [-2.4813,  3.6024,  6.0448],
        [-1.3099,  2.0111,  6.3038],
        [ 3.1506,  4.5540,  3.4586],
        [ 4.8672,  3.0252,  4.3109]]) tensor([12,  8, 15, 10,  3,  4,  0, 12])
tensor([[ 1.9660,  2.7384,  5.6331],
        [ 1.9588,  4.3627,  5.6041],
        [ 3.1987,  3.8939,  2.6020],
        [-4.1649,  6.2556,  2.2830],
        [ 3.2727,  2.0320,  3.0425],
        [ 2.3913,  7.7531,  6.7921],
        [ 0.2257,  9.31

tensor([[ 4.8247,  6.1623, -1.0668],
        [ 2.9929,  1.5095,  4.0126],
        [ 4.4563,  3.5478,  4.7498],
        [ 2.4373,  7.2560,  3.9551],
        [ 0.4922,  3.4157,  5.4838],
        [-2.7027,  2.9938,  6.1041],
        [ 6.2830,  7.2783,  3.3744],
        [ 3.1498,  5.2647,  3.8065]]) tensor([ 8, 15, 12,  8, 12,  0,  8,  8])
tensor([[ 7.7460,  0.7952,  7.3782],
        [ 4.7503,  4.1612,  5.6790],
        [ 2.4849,  4.3404,  2.7032],
        [-3.8499,  4.9143,  1.5013],
        [ 3.5467,  1.1363,  2.2751],
        [ 1.4091,  1.7841,  2.7674],
        [ 7.3573,  5.4168, -0.5010],
        [ 2.6296,  9.3774,  4.8648]]) tensor([ 1, 12,  0, 11, 15,  0,  8,  9])
tensor([[ 6.2248,  4.9671, -0.2923],
        [ 7.1595,  7.7760,  0.3855],
        [ 3.9131,  7.1287,  2.7868],
        [ 3.0915,  6.5299,  4.4561],
        [-2.5948,  3.5327,  5.7227],
        [ 1.9933,  0.5870,  6.6149],
        [ 1.4056,  4.9829,  5.2907],
        [ 8.1166,  5.9947,  2.3697]]) tensor([ 8,  8,  0,  9,  3,

tensor([[ 4.8956,  2.9882,  4.2684],
        [-1.6706,  2.2168,  6.0366],
        [ 1.9608,  7.7651,  3.0703],
        [ 2.7666,  3.0627,  1.2031],
        [-1.0504,  9.4537,  3.9606],
        [ 0.6206,  9.0057,  2.6120],
        [ 3.6295,  1.2549,  2.3610],
        [ 1.3903,  8.2475,  4.1860]]) tensor([12,  4, 10, 15, 11, 11,  0,  9])
tensor([[ 3.2561,  6.3627,  3.3923],
        [ 1.2335, -0.4762,  6.3416],
        [ 3.3583,  1.9963,  2.9195],
        [ 3.4078,  6.5523,  3.5356],
        [-3.9983,  4.7203,  3.9654],
        [ 1.8488,  1.8741,  3.1011],
        [ 2.5441, -0.0353,  6.4336],
        [ 3.2697,  0.7999,  1.9750]]) tensor([14,  0, 15, 14, 11,  0,  4, 15])
tensor([[ 3.5942,  7.4164,  3.2599],
        [-2.5306,  2.9349,  6.3258],
        [ 4.5585,  7.0383,  2.6509],
        [ 1.7836,  7.7116,  3.2227],
        [-2.3990,  3.2365,  5.7650],
        [ 6.2924,  8.2136,  3.5743],
        [ 0.9037,  8.3183,  2.3967],
        [ 4.4152,  3.6376,  4.7051]]) tensor([14,  3, 13,  9,  3,

tensor([[-0.6403,  8.9147,  3.7676],
        [ 2.5463,  8.8617,  5.2433],
        [ 3.4605,  6.6304,  2.8760],
        [-0.9418,  1.6412,  5.7800],
        [ 1.9919,  7.7472,  2.8559],
        [ 1.5673,  2.6197,  5.8221],
        [ 2.7379,  8.3785,  6.3545],
        [-0.6433,  3.9626,  5.0868]]) tensor([11,  9, 14,  4,  9, 12,  9, 12])
tensor([[ 1.6767, -0.2987,  5.7892],
        [ 2.6248,  6.6443,  3.8051],
        [ 5.1964,  4.2318,  4.9141],
        [ 3.4932,  1.7920,  2.6776],
        [ 2.7680,  1.8341,  3.9559],
        [ 4.7237,  6.3347, -1.1307],
        [ 4.2262,  7.1604,  3.6685],
        [-3.8507,  4.9666,  1.4683]]) tensor([ 4,  9, 12,  0, 15,  8, 14, 11])
tensor([[ 6.4141,  6.7957, -0.0441],
        [ 7.2347,  0.7845,  7.2756],
        [-1.1398,  8.0946,  2.3920],
        [ 5.0160,  4.1896,  5.4606],
        [ 7.0164,  5.9095, -0.8410],
        [ 4.1428,  4.3900,  3.6329],
        [ 1.5227,  4.9954,  5.4352],
        [-2.7454,  3.5838,  5.7252]]) tensor([ 8,  1, 11, 12,  8,

tensor([[ 7.9566,  0.0144,  6.6711],
        [ 7.1539,  0.7226,  7.5165],
        [-4.0356,  4.9495,  1.6306],
        [ 7.0291,  0.5684,  7.1728],
        [ 2.2523, 11.4976,  5.3597],
        [-2.7865,  3.2428,  6.2213],
        [ 2.3704, 11.2430,  4.7962],
        [ 3.5708,  2.8174,  5.1547]]) tensor([ 1,  1, 11,  1,  9,  3,  9,  0])
tensor([[-1.8592,  8.4443,  2.7664],
        [ 6.0062,  7.6633,  4.4313],
        [ 8.2450,  6.3485,  1.5269],
        [ 8.0049,  7.2256,  0.4764],
        [ 0.9472,  2.0554,  1.5066],
        [ 0.3414,  9.3062,  3.9858],
        [ 4.8595,  3.0496,  4.3671],
        [ 1.4778,  4.9405,  5.2382]]) tensor([11,  8,  8,  8, 15, 11, 12, 12])
tensor([[ 7.2462,  0.8039,  7.3856],
        [ 3.0744,  3.8328,  2.5597],
        [ 1.3285,  7.4327,  0.0790],
        [-4.3437,  5.0508,  2.1277],
        [ 1.2492,  8.8133,  1.3253],
        [ 3.2795,  2.1165,  3.0956],
        [ 2.2809,  4.2075,  2.0913],
        [ 0.7377,  3.7138,  5.4619]]) tensor([ 1, 15, 10, 11,  9,

tensor([[-2.9298,  4.9315,  2.0611],
        [ 3.7447,  1.6709,  4.0331],
        [ 2.9801,  1.4213,  4.0366],
        [-3.0853,  4.9058,  2.4620],
        [ 1.1702,  0.5531,  5.6548],
        [-3.2229,  4.9391,  1.5455],
        [ 2.9874,  1.5426,  4.0127],
        [ 1.3391,  9.1850,  5.0766]]) tensor([11, 15, 15, 11,  4, 11, 15,  0])
tensor([[ 2.3276,  0.9886,  2.6814],
        [ 1.2572, 10.7959,  5.2870],
        [ 2.6036,  8.0010,  6.9440],
        [ 1.7374,  7.7596,  2.8190],
        [ 2.1357,  7.8446,  2.9477],
        [ 1.1834,  1.4993,  2.4121],
        [ 2.6705, 10.0827,  4.5907],
        [-0.5489,  8.1963,  3.7808]]) tensor([15,  9,  0, 10,  0, 15,  9, 11])
tensor([[ 3.4358,  1.8157,  2.8612],
        [-1.8188,  8.4750,  2.7224],
        [ 2.0715,  8.6257,  3.6217],
        [ 3.5681,  2.3784,  4.3798],
        [-0.1812,  0.2810,  5.1591],
        [ 3.0170,  6.2622,  2.5697],
        [-3.4238,  7.3254,  2.8316],
        [ 2.4442, 11.0302,  4.6407]]) tensor([15, 11,  0,  0,  4,

tensor([[-3.6421,  5.0138,  1.5412],
        [ 2.7565,  8.3188,  6.3389],
        [ 1.7795,  2.6731,  1.8046],
        [-0.1502,  0.2828,  5.0863],
        [ 3.2279,  3.0184,  4.8793],
        [-2.3336,  2.3710,  6.5793],
        [ 3.9163,  4.3251,  4.8352],
        [-2.8128,  5.1461,  1.9108]]) tensor([11,  9, 15,  4, 15,  4, 12,  0])
tensor([[ 3.4123,  7.0236,  6.5539],
        [ 0.0610,  4.3144,  4.7312],
        [ 2.0012,  3.7850,  6.0863],
        [ 1.8159, 11.3057,  6.0786],
        [ 3.0620,  3.8823,  2.4851],
        [ 3.2321,  6.6532,  4.2673],
        [ 1.2985,  1.7663,  2.8148],
        [ 3.4922,  2.4693,  4.3652]]) tensor([ 9, 12, 12,  9, 15,  0, 15, 15])
tensor([[ 3.6967,  2.5444,  5.0374],
        [ 2.7649,  8.0196,  6.4944],
        [ 3.8687,  4.2967,  4.9562],
        [ 0.2372,  9.3144,  4.4444],
        [-3.1462,  4.9306,  1.5767],
        [ 0.9792, 11.1876,  5.1212],
        [-1.0875,  8.1154,  2.2948],
        [ 1.2596,  1.3855,  1.9904]]) tensor([15,  9, 12, 11, 11,

tensor([[1.2254, 9.2296, 1.7254],
        [3.6468, 4.5312, 5.7220],
        [3.3720, 1.1279, 3.6560],
        [2.5941, 2.6772, 1.4716],
        [0.6172, 4.2430, 4.0427],
        [1.6556, 2.7187, 5.6064],
        [7.1697, 0.7070, 7.5134],
        [1.5910, 7.9537, 2.5463]]) tensor([ 9, 12, 15, 15,  0, 12,  1, 10])
tensor([[ 1.6067,  7.5098,  0.5340],
        [-2.6011,  3.3546,  6.1762],
        [ 3.1836,  1.2549,  3.9066],
        [ 0.6781, 11.5652,  5.7395],
        [ 5.2183,  4.1932,  5.3057],
        [-0.0681,  0.2124,  5.2463],
        [ 6.3749,  7.3344,  3.4952],
        [ 3.6022,  7.5952,  3.3469]]) tensor([ 0,  3, 15,  9, 12,  4,  8, 14])
tensor([[ 7.0780,  0.5713,  7.3699],
        [ 2.0326,  3.8085,  6.2124],
        [ 3.0958,  1.4512,  3.9779],
        [-2.5798,  2.8680,  6.2962],
        [ 4.0854,  7.5592,  2.7093],
        [ 4.1436,  7.6037,  2.7256],
        [ 8.4464,  0.5472,  6.9981],
        [-0.4785,  4.5543,  4.1896]]) tensor([ 1, 12, 15,  3, 14, 14,  1, 12])
tensor([[ 

        [ 1.9470,  3.7512,  5.9986]]) tensor([15, 12,  1,  8, 12,  0, 15, 12])
tensor([[ 0.0952,  4.2343,  4.7986],
        [ 6.0625,  4.9647, -0.3491],
        [ 0.7669, 11.4881,  5.4203],
        [ 1.4647,  2.4639,  5.7208],
        [ 2.8624,  0.5318,  2.1328],
        [ 1.2056,  1.4257,  2.2348],
        [-3.0518,  4.8429,  5.3113],
        [-0.6118,  1.4717,  5.7546]]) tensor([12,  8,  9, 12, 15, 15,  0,  4])
tensor([[-4.4279,  5.1549,  2.2710],
        [ 3.4220,  0.9210,  2.0264],
        [ 4.4323,  2.3901,  4.6762],
        [-0.5803,  8.0810,  3.7487],
        [-2.5067,  3.0263,  6.1692],
        [-0.6566,  3.6842,  5.4910],
        [ 4.8237,  4.2266,  4.8412],
        [ 1.9252,  4.0950,  5.3849]]) tensor([11, 15, 15, 11,  3,  0, 12, 12])
tensor([[ 2.7008,  0.6521,  2.2846],
        [-4.3956,  5.0361,  2.1540],
        [ 6.3507,  7.3550,  3.4640],
        [ 7.1205,  0.4764,  7.3635],
        [ 3.4977,  4.5666,  5.6487],
        [ 1.4291,  8.3635,  4.3578],
        [-2.2497,  3.17

tensor([[1.3682, 8.6653, 0.6731],
        [3.7169, 1.6986, 3.9937],
        [3.5096, 7.2175, 3.7248],
        [4.0975, 7.5568, 2.7744],
        [2.5320, 5.3182, 3.0714],
        [8.2129, 6.2198, 1.7504],
        [1.1522, 7.7972, 2.4130],
        [3.3385, 6.6117, 4.6947]]) tensor([ 9, 15, 14, 14,  0,  8, 11,  0])
tensor([[ 3.7729,  7.3195,  2.6852],
        [ 1.3816,  7.7152,  2.3999],
        [ 2.7365,  2.0162,  3.8370],
        [-2.5766,  3.2616,  5.7822],
        [-1.7456,  2.4560,  6.2381],
        [ 4.1415,  8.0838,  4.9618],
        [-0.1798,  1.0533,  6.6912],
        [ 2.8518,  1.7874,  4.0327]]) tensor([14, 11, 15,  3,  4,  8,  4, 15])
tensor([[ 7.3092,  5.5567, -0.7475],
        [ 2.9437,  6.5406,  4.2347],
        [ 3.1434,  4.6680,  5.5993],
        [ 1.2287,  3.6485,  6.8959],
        [ 6.7894,  8.1375,  3.6622],
        [ 0.8938,  8.4139,  2.4423],
        [ 4.0471,  4.4211,  3.8190],
        [ 1.3713,  9.3355,  1.8231]]) tensor([ 8,  9, 12, 12,  8, 11, 15,  9])
tensor([[ 

tensor([[-0.4074,  4.5808,  4.5381],
        [ 3.6408,  6.6026,  3.4468],
        [-2.6692,  7.0926,  4.0555],
        [-3.5128,  5.1925,  6.1826],
        [ 4.1446,  7.1138,  3.7226],
        [ 1.0340,  7.4007,  0.5660],
        [ 0.7693,  3.8172,  5.5000],
        [ 8.2346,  0.7545,  7.2183]]) tensor([12, 14, 11, 11,  0, 10, 12,  1])
tensor([[ 3.8663,  3.9617,  2.9485],
        [ 1.8779,  4.6105,  5.6896],
        [ 2.5576,  6.6869,  3.7498],
        [ 0.7771, 11.4929,  5.3738],
        [ 3.0165,  2.0701,  3.5379],
        [ 8.1773,  6.1959,  1.8577],
        [ 7.0605,  0.5391,  7.3380],
        [ 3.5245,  1.7546,  2.7026]]) tensor([15, 12,  9,  9, 15,  8,  1, 15])
tensor([[ 3.6147,  6.7390,  5.2558],
        [-2.7767,  3.6836,  5.7640],
        [ 4.0937,  4.1073,  3.1959],
        [ 3.5128,  1.2980,  3.7164],
        [ 1.4825, 11.3327,  6.2138],
        [ 6.9843,  7.7022,  0.5570],
        [-4.2943,  4.8710,  5.5788],
        [-4.1810,  4.9305,  5.7413]]) tensor([ 9,  3, 15, 15,  9,

tensor([[2.6144, 3.4373, 1.0875],
        [0.7599, 8.8196, 2.4663],
        [3.5738, 3.8738, 2.7809],
        [8.2053, 0.1122, 6.6838],
        [3.1795, 3.7830, 2.6481],
        [2.6087, 9.3376, 4.9265],
        [1.6596, 2.6052, 1.8654],
        [3.5479, 6.8942, 6.1200]]) tensor([15, 11, 15,  1, 15,  9, 15,  9])
tensor([[ 1.2427,  2.0886,  1.7636],
        [ 3.7553,  7.3067,  2.7304],
        [ 3.5474,  7.3967,  3.7942],
        [-4.2489,  4.8900,  5.5937],
        [ 4.1170,  5.9055,  4.0861],
        [ 8.2407,  6.3496,  1.4991],
        [-0.3314,  8.1584,  3.6941],
        [-1.8834,  8.3781,  2.8328]]) tensor([15, 14,  0, 11,  8,  8,  0, 11])
tensor([[ 3.0498,  2.1332,  3.3992],
        [ 4.5457,  7.9796,  2.7505],
        [-0.1140,  9.0841,  5.2245],
        [ 2.6741,  1.9476,  3.9563],
        [-0.2006,  3.7450,  4.6292],
        [-0.0200,  0.1216,  5.3395],
        [ 4.4856,  8.4191,  4.9098],
        [-1.7275,  2.4588,  6.2817]]) tensor([15, 13, 11, 15, 12,  4,  8,  4])
tensor([[ 

        [ 3.3988,  1.9791,  2.8517]]) tensor([11, 11, 11, 14, 15,  8, 11, 15])
tensor([[ 1.3310,  8.7043,  0.8728],
        [-0.5318,  0.0281,  4.7678],
        [ 2.3618,  8.5218,  2.0781],
        [ 3.6587,  1.6279,  2.5835],
        [-0.9154,  9.2898,  3.8298],
        [ 0.6662,  3.6085,  5.4011],
        [ 0.8914,  0.3858,  7.2654],
        [ 3.8387,  6.8935,  3.2669]]) tensor([ 9,  4,  9, 15, 11, 12,  4, 14])
tensor([[-0.7326,  8.4798,  4.9335],
        [-2.5762,  6.4727,  2.9386],
        [ 1.5070,  1.8001,  2.9898],
        [ 1.0279,  1.9991,  1.4987],
        [ 2.6289, 10.0864,  4.5554],
        [ 1.8234,  2.7583,  5.5941],
        [-2.4442,  6.4339,  2.6629],
        [ 2.3852,  3.6380,  1.2854]]) tensor([11,  0, 15, 15,  9, 12,  0, 15])
tensor([[-4.1096e+00,  6.3651e+00,  2.2675e+00],
        [-1.2252e+00,  8.2870e+00,  2.4091e+00],
        [ 3.6256e+00,  7.6428e+00,  3.3538e+00],
        [ 1.4028e+00,  8.6052e+00, -6.3126e-03],
        [ 4.9050e+00,  5.8528e+00,  4.4300e+00],


        [ 3.4621,  7.8604,  7.2300]]) tensor([ 8, 14, 13, 11,  4, 15,  9,  9])
tensor([[ 2.2175,  2.8598,  5.7251],
        [ 3.3857,  7.1047,  6.7031],
        [ 6.6525,  6.4429, -0.5268],
        [ 4.2210,  7.1302,  3.6693],
        [-2.8116,  4.0606,  5.5665],
        [ 2.6770,  3.2260,  6.1737],
        [ 3.8854,  5.6516,  2.7156],
        [ 2.3606, 11.2817,  4.8736]]) tensor([12,  9,  8, 14,  3, 12,  8,  9])
tensor([[ 4.6202,  8.6897,  4.8046],
        [ 2.7266,  6.8431,  3.7701],
        [ 2.4445, 10.9814,  4.6325],
        [ 3.9157,  6.9669,  3.3348],
        [ 0.1863,  1.0812,  7.0441],
        [-3.5830,  5.1882,  6.2699],
        [-3.5046,  7.3663,  3.4375],
        [-2.4106,  7.0364,  3.9315]]) tensor([ 8,  0,  9, 14,  4, 11, 11, 11])
tensor([[ 3.9690,  8.7255,  0.6852],
        [ 6.6684,  8.2523,  3.5661],
        [ 6.0381,  7.0392,  2.9777],
        [ 2.1079,  7.9117,  2.8206],
        [ 1.5373,  1.8304,  2.8692],
        [ 6.8065,  8.0463,  3.7126],
        [-0.0440,  3.64

tensor([[-0.1762,  3.7915,  4.5008],
        [ 1.2324,  1.5513,  1.4248],
        [ 0.0973,  0.0575,  5.5079],
        [ 0.9038,  0.4056,  7.2445],
        [ 2.0958,  0.4616,  6.6723],
        [-2.4675,  3.3332,  6.0714],
        [ 2.5596,  0.0357,  6.4684],
        [ 0.4793,  7.7232,  3.2241]]) tensor([12, 15,  4,  4,  4,  3,  4,  0])
tensor([[-2.6656,  3.3082,  6.1829],
        [-0.5543,  4.3922,  4.5107],
        [ 1.1479,  1.7491,  1.4069],
        [-0.6096,  7.8419,  2.7139],
        [ 1.7532,  3.1159,  6.7129],
        [ 1.1636, -0.3698,  6.6671],
        [-4.2190,  4.8011,  4.9475],
        [ 4.1765,  6.9941,  3.7559]]) tensor([ 3,  0, 15, 11, 12,  4, 11, 14])
tensor([[2.1324, 4.1115, 2.0635],
        [3.5746, 7.3584, 3.4634],
        [5.2308, 4.1094, 4.7319],
        [2.2505, 0.5099, 6.6701],
        [1.7629, 4.8766, 5.0200],
        [4.1584, 8.0129, 4.7367],
        [1.2507, 1.3404, 2.1377],
        [0.5377, 3.2843, 5.3345]]) tensor([15, 14, 12,  4, 12,  0, 15, 12])
tensor([[ 

tensor([[3.6719, 7.3468, 3.9094],
        [1.5988, 1.8483, 3.0539],
        [3.2353, 0.7740, 1.9915],
        [3.6472, 6.8144, 5.8679],
        [1.1202, 2.1212, 1.5658],
        [2.3584, 2.9140, 1.6978],
        [3.1802, 3.8211, 2.7328],
        [1.4968, 2.6266, 1.8735]]) tensor([ 0, 15, 15,  9, 15,  0, 15, 15])
tensor([[4.9265, 2.9023, 4.3173],
        [7.9715, 5.8339, 3.5299],
        [3.9033, 4.3393, 3.7714],
        [3.3129, 6.3767, 3.4430],
        [2.2233, 8.0566, 2.3394],
        [3.4679, 6.5338, 4.3316],
        [3.6006, 6.8543, 6.0014],
        [2.6093, 3.8497, 6.4920]]) tensor([12,  8, 15, 14,  9,  0,  9, 12])
tensor([[-0.3950,  0.9824,  6.3871],
        [-2.8685,  2.9802,  5.8135],
        [-4.4457,  5.4605,  2.3327],
        [ 1.5768,  2.7167,  5.8951],
        [ 7.4632,  7.7064,  0.2377],
        [ 5.8530,  6.7416,  2.9462],
        [ 1.9879,  3.9400,  5.4166],
        [-2.6699,  3.5149,  6.1092]]) tensor([ 4,  3, 11, 12,  8,  8, 12,  3])
tensor([[ 0.9773,  0.1922,  7.2159

        [ 3.3146,  2.7122,  4.5553]]) tensor([11,  8, 14, 15, 12, 11,  9, 15])
tensor([[-0.6097,  4.2468,  4.9887],
        [-0.8500,  9.1969,  3.7170],
        [-1.0991,  9.4763,  4.0040],
        [ 3.3459,  2.0858,  2.9855],
        [ 3.0674,  5.6434,  2.7920],
        [ 2.3631,  3.9281,  6.4179],
        [-3.5785,  7.1567,  2.4277],
        [-4.0898,  6.4078,  2.2572]]) tensor([12, 11, 11, 15,  8, 12, 11, 11])
tensor([[ 1.7911,  4.8121,  5.1181],
        [ 1.5870,  4.9389,  5.3227],
        [-2.6779,  2.9739,  6.1088],
        [ 2.4114,  8.2396,  7.1303],
        [ 2.6826,  7.8676,  6.6093],
        [ 2.7283,  8.3860,  6.3454],
        [-3.0509,  5.6461,  3.7097],
        [-4.3894,  4.9494,  2.0850]]) tensor([12, 12,  3,  9,  9,  9,  0, 11])
tensor([[ 2.8408,  5.4506,  2.4989],
        [ 2.0596, 11.3493,  5.8026],
        [ 1.9117,  1.1422,  3.1401],
        [ 1.1537,  7.7992,  2.3678],
        [ 0.4051,  9.3042,  3.8147],
        [ 0.9864,  2.0278,  1.5146],
        [-3.2771,  4.84

        [3.6369, 1.4264, 2.5181]]) tensor([15,  8,  8,  8, 11, 12, 14, 15])
tensor([[ 2.7656,  0.4685,  2.1633],
        [ 0.2071,  4.7341,  4.6523],
        [ 1.3523,  1.6189,  2.4592],
        [ 0.5657,  9.1458,  2.7614],
        [ 0.1928,  9.3114,  4.5068],
        [ 7.1955,  0.6488,  7.3378],
        [ 3.4822,  7.3443,  3.4617],
        [ 1.3297, 10.4149,  5.4694]]) tensor([15, 12, 15, 11, 11,  1, 14,  9])
tensor([[ 0.2220,  3.2910,  5.8761],
        [ 1.5639,  2.6368,  5.8368],
        [ 2.0979,  0.5569,  6.6174],
        [ 8.0939,  5.9517,  2.4993],
        [ 2.2992, 11.4442,  5.1192],
        [ 1.6671,  4.8407,  5.5875],
        [-1.2491,  9.2765,  4.6441],
        [ 2.4288,  0.7011,  2.4863]]) tensor([12, 12,  4,  8,  9, 12, 11, 15])
tensor([[ 0.7458,  6.4563,  4.2185],
        [-0.3891,  4.2803,  5.9465],
        [ 7.4942,  6.2749,  4.4888],
        [ 3.9800,  7.0069,  3.7381],
        [ 3.1873,  0.6537,  1.9222],
        [ 7.6784,  6.1263,  4.5388],
        [-3.4893,  7.3169,

        [2.7345, 3.4698, 6.2647]]) tensor([ 8, 10,  0,  8, 15, 15, 15, 12])
tensor([[ 6.4186,  8.1432,  3.6030],
        [ 6.9708,  6.4562,  4.0414],
        [ 1.5050,  7.5883,  2.6690],
        [ 2.2928, 11.4472,  5.1003],
        [ 2.7941,  3.2824,  1.1666],
        [ 7.6720,  6.1235,  4.5214],
        [ 4.8976,  4.2584,  4.7634],
        [ 3.6603,  1.2438,  2.3669]]) tensor([ 8,  8,  0,  9,  0,  8, 12, 15])
tensor([[-0.5046,  7.7953,  3.5389],
        [ 2.5251,  8.6788,  5.5173],
        [ 2.1367,  2.9425,  5.6103],
        [ 3.5268,  7.7574,  5.0218],
        [ 2.2751,  4.9878,  2.8346],
        [ 3.2763,  0.9087,  2.1055],
        [-2.6748,  3.5037,  5.9924],
        [-0.8203,  9.1920,  3.7827]]) tensor([11,  9, 12,  8,  0,  0,  3, 11])
tensor([[ 0.8767,  8.3859,  2.3979],
        [ 2.5509,  4.7849,  5.3023],
        [ 3.8119,  1.3007,  4.0675],
        [ 8.0817,  0.7939,  7.2894],
        [-2.0454,  2.3804,  6.5870],
        [ 3.4684,  4.5694,  5.7380],
        [ 7.1748,  0.6723,

tensor([[-2.4913,  3.0824,  6.1805],
        [ 1.2245, -0.4780,  6.4891],
        [ 3.0024,  2.0828,  3.4633],
        [ 3.2196,  3.8018,  2.6881],
        [ 3.8760,  4.3427,  3.8263],
        [ 1.9152,  4.5460,  5.7692],
        [-0.0611,  3.5331,  6.2921],
        [-3.2463,  7.3773,  3.2861]]) tensor([ 3,  4, 15,  0, 15, 12, 12,  0])
tensor([[ 2.3918,  4.7626,  5.2681],
        [ 5.6970,  6.5629,  5.0572],
        [-0.2328,  4.3920,  6.0997],
        [-1.8764,  8.3730,  2.8357],
        [ 1.7020,  2.6739,  5.6220],
        [ 1.5804,  2.6692,  5.7615],
        [ 6.3143,  8.2325,  3.5373],
        [ 4.3404,  7.3699,  2.8099]]) tensor([12,  8, 12, 11, 12, 12,  8, 14])
tensor([[ 1.5471,  2.5915,  1.8997],
        [ 1.3090,  1.6244,  1.4363],
        [ 2.2871,  3.8058,  6.5412],
        [ 1.4146,  2.7180,  1.9313],
        [-1.1810,  9.4721,  4.2142],
        [ 4.4732,  3.5617,  4.7734],
        [ 0.0231,  4.0761,  6.3042],
        [ 3.1349,  1.3500,  3.9141]]) tensor([15, 15, 12, 15, 11,

tensor([[-0.3776,  4.1379,  5.3473],
        [ 3.3826,  2.8119,  4.9437],
        [-2.4685,  2.9439,  5.8138],
        [ 2.1397,  3.7946,  1.8817],
        [-2.5663,  3.3349,  5.8622],
        [-3.4320,  7.2767,  3.7122],
        [-2.4160,  3.2670,  5.9403],
        [ 1.8707,  0.6334,  6.5306]]) tensor([12,  0,  3, 15,  3, 11,  3,  4])
tensor([[ 5.9669,  7.0743,  4.9436],
        [ 1.1430, -0.2764,  6.9407],
        [ 1.8270,  7.7909,  3.0903],
        [ 2.9139,  4.9981,  2.8978],
        [ 2.0805,  2.9041,  5.5094],
        [ 1.9866,  1.2724,  2.8721],
        [ 2.5411,  3.5957,  1.1516],
        [ 7.6366,  7.6112,  0.2369]]) tensor([ 8,  0,  9, 14, 12, 15, 15,  8])
tensor([[ 6.4181,  8.2110,  3.5246],
        [ 3.6713,  1.3445,  2.4504],
        [ 0.9171,  8.2224,  2.4107],
        [-3.9129,  6.6602,  2.2647],
        [ 8.2858,  0.6924,  7.1560],
        [ 0.1507,  4.0427,  4.8297],
        [ 1.4161,  7.2746,  0.6831],
        [ 2.6067,  6.7467,  3.7537]]) tensor([ 8, 15, 11,  0,  1,

tensor([[ 3.6712,  1.2912,  3.8968],
        [ 1.8723,  2.5342,  1.7114],
        [ 1.2677,  8.7260,  1.0603],
        [-0.6023,  8.2094,  3.8097],
        [-3.7646,  5.1163,  6.2939],
        [ 2.9865,  7.5917,  4.9524],
        [ 4.0008,  5.5932,  2.8747],
        [ 4.4292,  9.0369,  2.8894]]) tensor([15, 15,  9, 11, 11,  8,  8,  8])
tensor([[ 6.0230,  7.2412,  4.8781],
        [-3.9811,  5.0484,  6.1008],
        [-2.5646,  3.4077,  5.6293],
        [ 2.3790,  8.8580,  2.1215],
        [ 8.0610,  5.9341,  2.5255],
        [ 2.0762, -0.2759,  5.8199],
        [ 4.0746,  4.1184,  3.2438],
        [ 1.2029,  9.1294,  1.6790]]) tensor([ 8, 11,  3,  9,  8,  4, 15,  9])
tensor([[ 2.9362,  5.7205,  2.5485],
        [-0.4957,  8.5681,  4.5862],
        [ 8.1864,  6.8472,  0.8552],
        [ 0.1870,  4.3784,  4.7019],
        [-3.4347,  7.1394,  2.5463],
        [-3.8368,  5.0969,  6.2928],
        [ 2.1758,  2.5367,  1.5526],
        [ 8.2058,  6.2047,  1.7831]]) tensor([14,  0,  8, 12, 11,

tensor([[1.3832, 2.6479, 1.8762],
        [7.7209, 0.1283, 6.7995],
        [0.2261, 3.1168, 5.7510],
        [2.4851, 8.5882, 5.6648],
        [5.9380, 6.8496, 2.9099],
        [3.7094, 4.2347, 4.5498],
        [1.1765, 1.5154, 2.4986],
        [7.1935, 0.8355, 7.1676]]) tensor([15,  1, 12,  9,  8,  0, 15,  1])
tensor([[ 4.0719,  7.4592,  2.7063],
        [ 4.3860,  4.2852,  4.8285],
        [ 0.5890,  9.0930,  2.7045],
        [ 4.6172,  3.4178,  4.6444],
        [ 0.9863,  2.1965,  1.5530],
        [-4.3152,  5.5239,  2.2978],
        [ 3.9832,  4.0606,  3.1080],
        [ 3.3948,  6.4970,  3.1460]]) tensor([14, 12, 11, 12, 15, 11, 15, 14])
tensor([[ 3.7989,  4.3112,  3.7862],
        [ 1.9857,  4.5136,  5.6806],
        [-2.5925,  3.4652,  6.1106],
        [ 6.0245,  7.1583,  3.1598],
        [ 7.6819,  7.5817,  0.2376],
        [ 7.3461,  5.4107, -0.1055],
        [ 5.5593,  5.1864, -0.5942],
        [ 3.9547,  4.0645,  5.0210]]) tensor([15, 12,  3,  8,  8,  8,  8, 12])
tensor([[-

        [-3.9433,  4.7182,  3.8354]]) tensor([15,  9, 15,  3, 10,  8, 11, 11])
tensor([[ 1.3355, 11.3412,  6.2222],
        [ 6.2817,  8.1767,  3.5805],
        [ 1.9468,  7.9369,  2.8992],
        [ 4.1980,  7.9846,  4.7251],
        [-2.6645,  3.2711,  6.2070],
        [ 1.1576,  1.5038,  2.3468],
        [ 8.2761,  0.1364,  6.6992],
        [ 8.2286,  0.7652,  7.2366]]) tensor([ 9,  8, 10,  0,  3,  0,  1,  1])
tensor([[ 2.4822,  3.5943,  1.1953],
        [ 2.4469, 10.9864,  4.6211],
        [ 6.6232,  6.4728,  3.6004],
        [ 4.0239,  4.0351,  3.0663],
        [ 2.2026,  9.2489,  2.6561],
        [ 0.1956,  1.0587,  7.0483],
        [ 1.8217,  0.6740,  6.4923],
        [ 1.3707,  1.5086,  1.5122]]) tensor([15,  9,  8, 15,  9,  4,  4, 15])
tensor([[6.7798, 8.1752, 3.6382],
        [1.8911, 3.7216, 5.4316],
        [3.8326, 5.6074, 3.8853],
        [2.1212, 3.6318, 2.1883],
        [4.5411, 8.9572, 3.7999],
        [5.1274, 4.2335, 5.3301],
        [3.5325, 7.6010, 3.4170],
       

tensor([[ 2.2716, 11.5155,  5.3664],
        [-0.3709,  1.1380,  5.8765],
        [ 3.7620,  6.9717,  3.8103],
        [-0.6476,  3.2171,  4.4998],
        [ 3.8781,  7.5818,  2.6483],
        [-3.6639,  6.9628,  2.3578],
        [ 3.4569,  6.5836,  4.8829],
        [ 2.1760,  3.7001,  1.6524]]) tensor([ 9,  4, 14,  0, 14,  0,  9, 15])
tensor([[ 3.6526,  4.5085,  5.7027],
        [ 4.5278,  9.1270,  3.4899],
        [ 2.1165,  0.5903,  6.6481],
        [-0.0384,  1.0886,  6.8529],
        [ 7.2376,  0.7638,  7.3593],
        [ 2.3723, 11.1873,  4.8637],
        [ 4.1637,  8.1684,  2.7223],
        [ 4.5201,  6.7947,  2.8059]]) tensor([12,  8,  4,  4,  1,  9,  0, 13])
tensor([[ 2.3928,  3.9776,  6.4159],
        [ 3.9607,  7.0118,  3.9170],
        [ 0.8998, 11.3563,  5.2459],
        [ 4.5543,  8.2771,  4.4630],
        [ 5.9670,  4.9340, -0.4582],
        [ 6.3783,  8.2132,  3.5434],
        [ 2.5843,  5.6980,  2.7429],
        [ 3.5861,  7.4785,  3.2967]]) tensor([12, 14,  9, 13,  8,

tensor([[ 2.5445,  0.6294,  2.3826],
        [ 1.2844, 10.7012,  5.3344],
        [ 1.3993,  1.5165,  1.9290],
        [ 7.1385,  7.7598,  0.4103],
        [ 0.8188,  0.5390,  7.2618],
        [ 7.9236,  5.8149,  3.8816],
        [ 3.4883,  6.6267,  2.8786],
        [ 4.0824,  4.4933,  4.7115]]) tensor([15,  9, 15,  8,  4,  8, 14, 12])
tensor([[-1.7609,  2.4390,  6.0841],
        [ 0.4316,  0.9398,  7.1928],
        [ 4.4344,  9.0986,  3.4357],
        [ 7.8445,  5.9675,  4.4510],
        [ 2.4949,  7.1395,  3.7424],
        [ 3.4725,  1.0961,  2.2379],
        [ 1.2543,  4.9284,  5.5804],
        [ 2.6769,  1.6635,  4.1926]]) tensor([ 0,  4,  8,  8,  9, 15,  0, 15])
tensor([[ 2.7264,  1.8468,  3.9180],
        [ 2.5392,  7.4895,  4.5020],
        [ 5.8903,  4.9605, -0.4434],
        [-1.1093,  9.4217,  4.1511],
        [ 2.6195,  3.8193,  6.4393],
        [ 2.7242,  3.3830,  1.0285],
        [ 7.9575,  5.8426,  3.4565],
        [-2.6161,  3.0681,  5.7468]]) tensor([15,  8,  8, 11, 12,

tensor([[-3.5035,  7.3270,  3.5478],
        [-1.1969,  9.3051,  4.5326],
        [-1.7276,  8.6681,  2.5307],
        [ 0.1578,  3.7772,  4.6119],
        [ 3.7647,  1.6821,  4.0715],
        [-2.3068,  3.0557,  5.7199],
        [-4.4588,  5.4067,  2.3270],
        [ 3.5670,  7.6655,  3.5000]]) tensor([11, 11, 11, 12, 15,  3, 11, 14])
tensor([[ 2.0198, 11.3185,  5.8807],
        [-0.5603,  1.4769,  5.7373],
        [-0.6628,  2.5994,  4.7560],
        [ 1.7952,  1.9187,  2.8990],
        [ 2.4726,  4.2631,  2.7416],
        [ 3.5105,  2.2252,  4.0735],
        [ 2.3614,  8.7300,  2.0587],
        [ 2.4682,  3.9059,  6.4991]]) tensor([ 9,  4,  0, 15,  0, 15,  9, 12])
tensor([[ 8.4409,  0.5191,  6.9850],
        [ 3.3912,  6.4034,  3.2419],
        [ 3.3516,  6.4578,  3.0929],
        [-3.3645,  4.9568,  1.4183],
        [ 2.7047,  3.4038,  1.0887],
        [ 3.8484,  3.9717,  2.9461],
        [ 1.4425,  8.5754,  0.1135],
        [ 7.1115,  5.3231,  0.1468]]) tensor([ 1, 14, 14, 11, 15,

tensor([[ 3.3327,  2.9768,  5.0195],
        [ 2.9678,  0.5468,  2.0817],
        [-0.5827,  4.0754,  5.6193],
        [ 3.3081,  6.4231,  3.3209],
        [ 2.6868,  9.8796,  4.6525],
        [ 7.8357,  0.0788,  6.7374],
        [ 1.0954, 11.0630,  5.1720],
        [-0.2416,  4.1286,  6.3720]]) tensor([15, 15, 12, 14,  9,  1,  9, 12])
tensor([[ 3.5198,  5.4269,  3.7876],
        [ 3.0027,  4.9263,  2.9963],
        [-3.0138,  4.9035,  2.3378],
        [-2.7144,  3.2612,  6.0969],
        [ 1.4029,  8.6344,  0.2751],
        [ 2.3802,  2.9077,  5.9034],
        [-4.1207,  5.0118,  6.0405],
        [ 7.2380,  0.8496,  7.1986]]) tensor([ 8,  0, 11,  3,  9, 12, 11,  1])
tensor([[ 6.5731,  7.4632,  3.6526],
        [ 2.5320,  2.0369,  3.8133],
        [-3.3638,  7.3770,  3.2535],
        [ 0.1610,  4.6189,  4.8294],
        [-0.0135,  3.6687,  6.2616],
        [ 1.4749,  2.6236,  1.8794],
        [-2.6550,  3.6819,  6.0335],
        [ 4.4989,  8.4899,  4.9103]]) tensor([ 8,  0, 11, 12, 12,

        [ 7.2609,  0.8007,  7.1934]]) tensor([ 9, 11,  4, 15,  9,  2, 15,  1])
tensor([[ 7.1364,  0.6010,  7.4558],
        [ 3.7553,  1.7913,  4.0878],
        [ 3.3586,  6.4693,  3.4516],
        [-0.3302,  4.6304,  4.5339],
        [ 1.3186, -0.3371,  6.0729],
        [ 1.5519, 11.3073,  6.2002],
        [-0.6102,  8.2437,  3.8228],
        [-0.9752,  1.2272,  4.9587]]) tensor([ 1, 15, 14, 12,  4,  9, 11,  0])
tensor([[ 3.1698,  2.1497,  3.2772],
        [-1.2085,  9.2089,  4.6706],
        [ 3.1268,  2.9261,  4.6306],
        [ 1.3724,  8.6782,  0.6802],
        [-0.5214,  8.4034,  4.8883],
        [-0.4457,  1.1182,  5.9549],
        [ 3.5426,  6.6177,  5.0045],
        [ 2.0935,  9.3796,  2.6136]]) tensor([15, 11, 15,  9, 11,  4,  9,  9])
tensor([[ 0.1375,  3.8787,  4.4872],
        [ 0.2846,  9.3056,  4.1452],
        [ 2.2905,  2.9293,  5.7805],
        [ 1.1590,  8.9902,  1.4902],
        [ 4.6049,  8.9269,  4.2971],
        [ 3.8685,  4.3648,  4.9279],
        [-0.0749,  1.10

        [ 1.1805,  2.6361,  1.8557]]) tensor([ 4, 10, 12, 12, 14, 15,  1, 15])
tensor([[-2.5353,  2.8108,  6.3470],
        [ 2.4724,  8.4626,  5.8964],
        [-0.1646,  1.0538,  6.7501],
        [-0.1147,  3.6358,  6.3380],
        [ 3.1144,  1.3144,  3.9208],
        [ 7.1097,  0.5892,  7.2826],
        [ 0.7231, 11.5694,  5.8546],
        [ 3.6579,  1.2401,  2.3557]]) tensor([ 3,  9,  4, 12, 15,  1,  9, 15])
tensor([[ 1.9556, -0.2795,  5.7304],
        [ 1.3131,  1.3981,  1.9431],
        [ 5.4189,  5.2384, -0.5416],
        [ 0.9896,  0.2053,  7.2085],
        [-1.2347,  9.4626,  4.2815],
        [ 2.2107,  4.9019,  5.1172],
        [ 7.1483,  6.4237,  4.2348],
        [ 4.9769,  4.1482,  5.5472]]) tensor([ 4, 15,  0,  4, 11, 12,  8, 12])
tensor([[ 3.6331,  1.1119,  2.2250],
        [ 0.1850,  3.9471,  4.4763],
        [ 2.4884,  7.3372,  3.3271],
        [-2.5083,  3.5781,  6.0013],
        [ 6.5425,  6.6006, -0.3116],
        [ 3.4716,  6.6928,  2.8155],
        [ 3.4313,  8.06

tensor([[ 1.9093,  2.5834,  5.6666],
        [ 7.8765,  5.8700,  3.4937],
        [-4.2903,  4.8479,  5.4177],
        [ 2.9523,  3.9035,  2.3652],
        [ 2.5183,  3.5753,  1.1938],
        [-2.4869,  3.7080,  6.1473],
        [ 2.2429,  4.1934,  2.1219],
        [ 0.0358,  9.2091,  4.9336]]) tensor([ 0,  8, 11, 15, 15,  3, 15, 11])
tensor([[ 2.8636,  1.4688,  4.1641],
        [ 0.1332,  4.0146,  4.7615],
        [-2.5885,  2.5990,  6.4192],
        [ 1.1164,  1.7966,  2.5521],
        [ 6.3771,  6.9679,  0.2721],
        [ 3.3643,  4.3538,  3.7663],
        [ 7.2068,  7.7479,  0.3697],
        [ 4.2881,  8.8900,  1.8946]]) tensor([ 0, 12,  4, 15,  8, 15,  8,  8])
tensor([[ 0.4414,  9.2607,  3.5528],
        [ 6.3489,  6.4629,  3.1878],
        [ 7.9829,  5.7814,  4.2152],
        [ 3.0809,  7.1308,  3.8242],
        [ 3.0532,  4.5043,  3.5482],
        [-0.2906,  1.0181,  6.5561],
        [ 4.1884,  7.0984,  3.7801],
        [ 4.8545,  3.0282,  4.3415]]) tensor([11,  8,  8, 10, 15,

tensor([[ 2.8825,  1.6328,  4.0282],
        [-0.8249,  0.1269,  4.8802],
        [ 3.1031,  7.2854,  3.1292],
        [ 1.0542,  7.5327,  0.2591],
        [ 4.5612,  5.6962,  3.5929],
        [-1.4985,  8.4561,  2.4107],
        [ 4.3848,  8.8651,  2.6158],
        [ 2.3120,  9.1424,  2.5274]]) tensor([15,  0, 14, 10,  8, 11,  0,  9])
tensor([[ 0.4712,  0.9075,  7.2136],
        [-2.9826,  4.9868,  1.7706],
        [ 2.9320,  6.7240,  2.7630],
        [ 3.8940,  7.4981,  2.7348],
        [ 3.0420,  5.3096,  3.7057],
        [ 3.1193,  0.6116,  1.9537],
        [-2.9905,  4.8025,  5.4823],
        [ 3.0599,  7.0519,  3.7401]]) tensor([ 4, 11, 14, 14,  8, 15,  0,  0])
tensor([[ 6.7572,  7.5737,  0.7548],
        [ 1.3047,  1.7237,  1.4944],
        [ 2.7583,  3.4745,  6.4016],
        [ 0.3007,  9.3344,  4.1765],
        [ 1.0415,  0.1133,  7.1628],
        [ 0.1667, -0.1127,  5.6120],
        [ 0.9685, 11.4415,  6.1371],
        [-0.6148,  7.7269,  2.9676]]) tensor([ 8, 15, 12, 11,  4,

tensor([[ 3.1784,  0.6636,  1.9017],
        [ 3.4591,  7.4902,  3.3216],
        [ 3.1315,  3.7912,  2.5840],
        [-3.3784,  4.9612,  1.4530],
        [ 3.4283,  2.5111,  4.3183],
        [ 4.1721,  7.5096,  2.7216],
        [ 3.2740,  2.1173,  3.1517],
        [ 2.9001,  4.0141,  2.1604]]) tensor([15, 14, 15, 11, 15, 14, 15, 15])
tensor([[-2.9237,  4.6748,  5.3239],
        [-3.4586,  4.9740,  1.4440],
        [ 4.1134,  7.5610,  2.7240],
        [ 2.6823,  9.9697,  4.5901],
        [ 3.6267,  1.6165,  2.5776],
        [-0.7330,  8.9847,  3.7554],
        [ 5.3397,  6.2020,  4.8730],
        [ 3.3746,  6.5659,  4.7988]]) tensor([ 3, 11, 14,  9, 15, 11,  8,  9])
tensor([[ 3.8231,  6.8377,  3.4010],
        [ 3.2122,  3.0040,  4.8118],
        [-2.7439,  3.1009,  6.5664],
        [ 4.0553,  7.5493,  2.7288],
        [ 1.5328,  2.6832,  6.3531],
        [ 5.9915,  7.8205,  4.0976],
        [-1.2373,  8.1802,  2.4101],
        [ 3.5208,  1.7639,  2.7496]]) tensor([14, 15,  3, 14, 12,

tensor([[-3.2887e-01,  9.6771e-01,  6.4920e+00],
        [ 1.4817e+00,  1.7045e+00,  2.9377e+00],
        [ 3.3208e+00,  2.9351e+00,  4.9680e+00],
        [ 7.4992e-01,  8.7197e+00,  2.4691e+00],
        [-6.5624e-03,  3.5675e+00,  6.2610e+00],
        [-2.8338e+00,  4.2007e+00,  5.5591e+00],
        [-3.3753e+00,  7.2362e+00,  2.6789e+00],
        [ 3.9656e+00,  7.8706e+00, -8.0403e-01]]) tensor([ 4, 15,  0, 11, 12,  0, 11,  8])
tensor([[ 5.4123,  6.2803,  4.9433],
        [ 2.7339,  0.6980,  2.3027],
        [-3.2557,  4.9803,  2.7987],
        [ 4.0397,  7.0002,  3.8526],
        [ 1.9619,  2.7929,  5.6200],
        [ 1.3454,  9.9827,  5.4488],
        [ 8.0211,  5.8771,  3.0394],
        [-0.0522,  9.1688,  5.0732]]) tensor([ 8, 15, 11, 14, 12,  9,  8, 11])
tensor([[ 1.3886,  3.9593,  5.5309],
        [-4.3113,  4.7935,  5.1686],
        [ 2.2144,  4.1765,  2.1142],
        [ 2.7329,  2.0948,  3.7314],
        [ 2.5426,  8.5797,  5.5589],
        [ 2.6224,  4.1822,  2.0851],
      

        [ 2.8073,  1.7565,  3.9822]]) tensor([15, 12, 15,  0,  4, 11, 15,  0])
tensor([[ 0.8765,  2.2816,  1.4813],
        [ 4.9312,  2.7435,  4.3650],
        [ 4.8972,  4.1949,  4.7748],
        [ 3.6752,  5.6434,  2.5945],
        [-3.0050,  4.6146,  5.3359],
        [ 6.3261,  6.4861,  3.1689],
        [ 1.8344,  0.7018,  6.5195],
        [ 6.6601,  6.4007, -0.4639]]) tensor([15, 12, 12,  8,  3,  8,  4,  8])
tensor([[-4.1809,  4.8193,  4.7561],
        [ 4.3734,  6.7004,  3.4458],
        [ 5.8397,  6.7769,  5.0730],
        [ 8.0031,  7.1633,  0.5494],
        [ 2.0530,  0.5308,  6.6313],
        [ 2.3454,  2.9285,  5.8195],
        [ 4.9657,  4.1574,  4.7806],
        [ 2.4094,  9.0018,  2.1883]]) tensor([11,  2,  8,  8,  4, 12, 12,  9])
tensor([[ 2.5373, 10.7335,  4.5557],
        [ 1.2970,  8.7017,  0.9476],
        [ 4.1093,  7.0130,  3.8395],
        [ 2.2940,  8.3310,  2.0967],
        [ 2.0488,  4.3223,  5.5539],
        [-1.1409,  1.5447,  6.0800],
        [ 4.2129,  4.19

tensor([[ 2.6125,  8.3990,  6.1909],
        [ 8.4014,  0.2662,  6.7835],
        [-2.6895,  3.0605,  6.5884],
        [ 3.4260,  1.9082,  2.9073],
        [ 2.0953,  2.9077,  1.8837],
        [ 3.6043,  6.6420,  5.1639],
        [ 0.9340,  8.3666,  2.4207],
        [ 3.2905,  1.1814,  3.7171]]) tensor([ 9,  1,  3, 15,  0,  9, 11, 15])
tensor([[ 1.0597,  7.9258,  2.5567],
        [ 1.7208,  3.0071,  6.7559],
        [ 2.8323,  4.7408,  5.4081],
        [ 7.3802,  5.5034, -0.6207],
        [ 7.6633,  0.1833,  6.8358],
        [ 4.3976,  8.9355,  2.4218],
        [ 1.4188,  2.0744,  3.0697],
        [ 3.8904,  4.4206,  4.8463]]) tensor([11, 12, 12,  8,  1,  8, 15, 12])
tensor([[2.3448, 3.5834, 1.4506],
        [3.5632, 2.6517, 4.9731],
        [3.1484, 5.2351, 4.1776],
        [1.2457, 0.6249, 5.8345],
        [4.0562, 4.0847, 3.1489],
        [2.7481, 3.3042, 1.1536],
        [4.5922, 9.0417, 3.8960],
        [4.5684, 9.0640, 3.7869]]) tensor([15, 15,  0,  4, 15, 15,  8,  8])
tensor([[2

tensor([[ 1.1429,  9.0788,  1.6031],
        [ 7.0587,  5.8437, -0.8444],
        [-2.6418,  2.8417,  6.2252],
        [ 7.3983,  0.7758,  7.2814],
        [ 1.4860,  7.7839,  2.2031],
        [ 1.2395,  4.8917,  5.5105],
        [-0.2788,  1.0189,  6.5891],
        [ 4.0790,  4.0817,  3.1380]]) tensor([ 9,  8,  3,  1,  0, 12,  4, 15])
tensor([[ 3.6451,  2.1173,  4.0275],
        [ 3.4231,  1.8757,  2.9451],
        [ 3.5741,  1.5206,  2.6384],
        [ 1.8690,  3.7387,  5.4728],
        [-2.2195,  3.3087,  5.8497],
        [ 2.4586, -0.2110,  6.0737],
        [ 3.1017,  8.4074,  7.2596],
        [ 1.4878,  7.8158,  2.2902]]) tensor([15, 15,  0, 12,  3,  4,  9, 11])
tensor([[ 5.9845,  7.0849,  4.9611],
        [-0.1963,  1.0667,  6.6749],
        [ 3.2318,  6.3734,  3.1964],
        [ 6.4358,  6.7947, -0.0322],
        [-2.5450,  3.2835,  5.9633],
        [ 3.7413,  1.8931,  4.1386],
        [ 7.6496,  6.1371,  4.5281],
        [ 1.3304,  8.8524,  4.8301]]) tensor([ 8,  4, 14,  8,  0,

tensor([[ 1.3381,  0.6516,  5.9666],
        [ 1.8924, 11.3552,  6.0294],
        [ 4.1698,  4.4624,  3.8384],
        [-0.1982,  0.2194,  5.0595],
        [ 2.7681,  7.9158,  6.5783],
        [ 4.5892,  7.0664,  2.5178],
        [-2.4449,  3.3947,  5.9422],
        [ 0.2976,  9.2337,  4.1971]]) tensor([ 4,  9, 15,  4,  9, 13,  3, 11])
tensor([[ 2.4622,  4.2053,  2.1639],
        [ 1.6086,  2.8339,  6.4010],
        [ 1.3122,  1.3797,  2.0200],
        [ 5.0709,  5.7860, -0.9557],
        [ 1.9423,  2.7911,  5.6400],
        [ 1.7761,  7.4761,  2.6079],
        [ 2.0509,  1.4630,  2.8855],
        [ 2.7555,  3.5304,  6.3808]]) tensor([15, 12, 15,  8, 12, 10, 15, 12])
tensor([[ 3.4439,  2.4604,  4.2737],
        [ 5.9921,  7.2734,  4.7965],
        [ 3.8390,  6.9256,  3.3503],
        [-0.2623,  4.5198,  5.6453],
        [ 3.2647,  2.0752,  3.1269],
        [-2.5210,  3.1179,  6.0391],
        [ 4.1713,  7.0315,  3.7677],
        [ 3.2988,  2.0694,  3.0780]]) tensor([15,  8, 14, 12, 15,

tensor([[ 3.2441,  5.7410,  2.5737],
        [ 1.8538,  2.5185,  2.0536],
        [ 0.9569,  8.0337,  2.3437],
        [ 1.3372,  9.3881,  5.2024],
        [ 4.3730,  8.2834,  4.9185],
        [ 6.9014,  5.2066,  0.2102],
        [ 2.9787,  6.1115,  2.5071],
        [-0.5820,  7.6989,  3.1183]]) tensor([ 8,  0, 11,  9,  8,  8, 14, 11])
tensor([[ 1.4381,  2.6151,  1.9206],
        [-0.3071,  8.0713,  3.6665],
        [ 2.4846,  7.4557,  4.3531],
        [ 1.0830,  0.3886,  5.5129],
        [ 2.0298,  3.7832,  6.1748],
        [ 6.0273,  4.9636, -0.3529],
        [ 0.2007,  3.9053,  4.7083],
        [ 1.5686,  4.8802,  5.8403]]) tensor([15,  0,  8,  4, 12,  8, 12, 12])
tensor([[ 7.2454,  0.8412,  7.2263],
        [-0.1066,  9.0417,  5.1660],
        [-0.8350,  1.6228,  5.6725],
        [ 7.3014,  5.5565, -0.6617],
        [ 1.8779,  3.7454,  5.4681],
        [ 3.1541,  3.0611,  4.7294],
        [ 1.2043,  7.7215,  2.4701],
        [ 5.9903,  6.7868,  2.8987]]) tensor([ 1, 11,  4,  8, 12,

tensor([[ 3.5303e+00,  1.3826e+00,  3.7538e+00],
        [ 9.1664e-01, -4.2101e-04,  5.2980e+00],
        [ 1.9061e+00,  2.5664e+00,  1.6244e+00],
        [ 6.5329e+00,  6.6420e+00, -2.6293e-01],
        [ 2.0072e+00,  7.7971e+00,  2.7411e+00],
        [ 1.4075e+00,  8.6382e+00,  2.5609e-01],
        [ 3.1930e+00,  1.4192e+00,  3.7895e+00],
        [ 2.4561e+00,  7.8923e-01,  2.4942e+00]]) tensor([15,  4, 15,  8, 10,  9, 15, 15])
tensor([[ 8.1510e+00,  6.0736e+00,  2.0945e+00],
        [ 1.0109e-01, -8.4057e-03,  5.5110e+00],
        [ 2.7257e+00,  4.1831e+00,  2.0362e+00],
        [-1.9261e+00,  8.0960e+00,  3.3564e+00],
        [ 2.2726e+00,  8.6175e-01,  2.6234e+00],
        [ 4.5599e-01,  9.2935e+00,  3.3465e+00],
        [ 2.4657e+00,  4.1982e+00,  2.0659e+00],
        [ 2.6383e+00,  7.1480e-01,  2.3760e+00]]) tensor([ 8,  4, 15, 11, 15, 11, 15, 15])
tensor([[-2.2078,  3.0957,  5.7837],
        [-2.7602,  3.7014,  5.6765],
        [ 7.9565,  5.8405,  3.5254],
        [-1.2864,  8.

tensor([[ 2.5651,  9.0399,  5.1581],
        [ 2.1362,  2.8919,  5.6254],
        [ 2.7838,  8.0334,  6.5776],
        [-0.7226,  4.2696,  4.4794],
        [ 5.0052,  5.8901, -0.9816],
        [-3.3717,  4.8694,  2.9009],
        [-1.2383,  8.2307,  2.4282],
        [ 1.7633,  9.4875,  2.2725]]) tensor([ 9, 12,  9,  0,  8, 11, 11,  9])
tensor([[ 1.3331,  2.6343,  1.8545],
        [ 0.9815,  0.1418,  5.3842],
        [ 3.1993,  0.7175,  1.9900],
        [ 1.8706,  7.8605,  2.4560],
        [ 2.4885, -0.0951,  6.3173],
        [-3.5552,  7.3163,  3.5754],
        [ 3.0843,  0.6628,  2.1138],
        [ 4.0473,  4.3948,  3.7927]]) tensor([15,  4, 15, 10,  4, 11, 15, 15])
tensor([[-0.5586,  7.7373,  3.4127],
        [ 1.9441,  3.8797,  5.3782],
        [ 2.9767,  1.4017,  3.9547],
        [-3.3848,  7.1618,  3.8933],
        [ 3.3355,  7.5919,  3.4534],
        [ 2.0606, -0.3048,  5.8090],
        [ 2.3482,  8.9881,  2.3989],
        [ 3.3677,  8.2291,  7.2666]]) tensor([11, 12, 15, 11, 14,

tensor([[ 2.3602,  8.1887,  7.0976],
        [ 3.8372,  6.9145,  3.3269],
        [ 3.0183,  2.1319,  3.4039],
        [ 6.8300,  8.0389,  3.7290],
        [ 1.3417,  8.6559,  4.6428],
        [ 1.3822,  2.5805,  1.9114],
        [-0.1609,  1.0886,  6.7111],
        [ 2.5964,  1.8702,  4.0222]]) tensor([ 9, 14, 15,  8,  9, 15,  4,  0])
tensor([[ 2.2783,  8.4049,  2.0817],
        [ 1.5760,  2.6844,  6.5047],
        [ 3.4614,  7.7326,  5.0026],
        [ 7.2425,  0.8095,  7.3161],
        [-0.7075,  1.6062,  5.6538],
        [ 3.7713,  6.7486,  3.4478],
        [ 0.0159,  4.4319,  4.7623],
        [ 4.4380,  8.8566,  2.5958]]) tensor([ 9, 12,  8,  1,  4, 14, 12,  8])
tensor([[ 2.6600, 10.2012,  4.5799],
        [-2.4887,  3.6003,  6.1230],
        [ 1.7631,  7.8016,  2.5291],
        [ 1.7050,  3.1443,  6.9163],
        [ 1.1347, -0.3166,  6.9070],
        [ 0.2298,  9.3019,  4.4063],
        [ 5.9967,  7.1909,  4.9322],
        [ 0.9683, 11.2898,  5.2112]]) tensor([ 9,  0, 10, 12,  4,

tensor([[ 2.1090,  8.2948,  4.2265],
        [ 4.5218,  4.3132,  4.7893],
        [ 3.5507,  2.7950,  5.1428],
        [ 7.9921,  0.8438,  7.2782],
        [ 2.2580,  8.2779,  2.1762],
        [ 0.6807,  3.6513,  5.4615],
        [ 1.3521,  7.7286,  2.4471],
        [ 4.6902,  6.3908, -1.0874]]) tensor([ 0, 12, 15,  1,  9, 12, 11,  8])
tensor([[ 2.9881,  3.9506,  2.3071],
        [-2.5501,  3.4537,  5.9681],
        [ 1.4534,  7.9905,  3.8599],
        [-4.4085,  5.5485,  2.3421],
        [-0.5185,  4.1375,  4.2596],
        [-1.2683,  9.1266,  4.8197],
        [-2.6849,  3.3111,  6.1779],
        [ 8.3657,  0.6516,  7.1131]]) tensor([15,  3,  9, 11,  0, 11,  3,  1])
tensor([[-3.4996,  7.1303,  2.4948],
        [ 1.6281,  2.7817,  6.4913],
        [-0.5231,  8.3052,  3.7991],
        [ 2.3657,  8.7454,  2.1338],
        [-4.3386,  5.1955,  2.1718],
        [ 0.7393, 11.5677,  5.9223],
        [-1.0791,  9.4295,  4.0651],
        [ 6.5268,  7.4158,  3.6414]]) tensor([11, 12, 11,  9, 11,

tensor([[ 1.4145,  8.1294,  4.0448],
        [ 0.9620,  7.5329,  0.8326],
        [-1.7635,  8.4209,  2.7477],
        [ 1.7564,  2.5883,  1.6352],
        [ 3.6350,  5.4496,  3.7896],
        [-0.0777,  9.0947,  5.1600],
        [ 6.5484,  6.6188, -0.3217],
        [ 2.0179,  3.8392,  6.1266]]) tensor([ 9, 10, 11, 15,  8, 11,  8, 12])
tensor([[ 4.1867,  6.2808,  2.6120],
        [ 6.6454,  6.4672, -0.5227],
        [ 3.6121,  2.0826,  4.0110],
        [ 7.3718,  0.4091,  7.0634],
        [-1.9737,  8.2768,  3.0628],
        [ 2.7715,  2.8383,  1.3742],
        [ 2.9920,  1.4522,  4.0212],
        [-0.5726,  8.2589,  3.8084]]) tensor([ 0,  8, 15,  1, 11, 15,  0, 11])
tensor([[1.7554, 7.8971, 2.4661],
        [5.5500, 6.3582, 5.0497],
        [3.5533, 6.5371, 3.5614],
        [1.6807, 2.8745, 6.7871],
        [2.0101, 4.0268, 5.3227],
        [7.6272, 6.1659, 4.5268],
        [7.3875, 0.4603, 7.1486],
        [1.2259, 1.5087, 2.3727]]) tensor([10,  8, 14, 12, 12,  0,  1, 15])
tensor([[ 

tensor([[-2.9560,  4.9494,  1.8480],
        [ 1.2991, -0.4661,  6.1442],
        [ 1.4609,  4.9284,  5.4714],
        [-1.2496,  9.3305,  4.5205],
        [ 4.5901,  2.2931,  4.7661],
        [ 4.4255,  4.3384,  4.7951],
        [-0.1398,  4.6289,  4.5716],
        [-4.2249,  4.9320,  5.7149]]) tensor([ 0,  4, 12, 11, 15, 12, 12, 11])
tensor([[ 6.9364,  5.2376,  0.1637],
        [ 4.5732,  4.2224,  5.7540],
        [ 1.5390,  4.9644,  5.2539],
        [-1.6965,  1.7397,  6.2086],
        [ 3.3603,  3.8272,  2.6657],
        [ 1.6232,  4.9312,  5.2519],
        [ 1.7482,  3.0272,  6.8840],
        [ 1.4796,  8.0404, -0.4910]]) tensor([ 8, 12, 12,  4, 15, 12, 12, 10])
tensor([[-1.0062,  1.5618,  6.5427],
        [ 3.0665,  3.9011,  2.4652],
        [-1.3695,  1.9378,  6.3304],
        [ 4.0039,  8.6824,  0.7300],
        [ 3.8081,  4.3505,  3.8867],
        [ 3.6863,  6.7112,  5.3516],
        [ 3.4496,  2.5177,  4.8997],
        [ 4.0433,  4.4112,  5.8431]]) tensor([ 4, 15,  4,  8, 15,

        [ 1.3230,  9.2523,  5.1025]]) tensor([ 8, 10, 15,  8,  0, 11,  9,  9])
tensor([[ 3.8957,  4.4302,  5.7576],
        [ 3.0108,  5.3769,  3.6002],
        [ 2.5476,  7.4696,  4.5762],
        [ 3.2779,  2.7031,  4.4316],
        [ 3.1548,  5.2387,  3.7978],
        [ 2.0653,  2.8876,  5.6642],
        [ 3.4716,  3.8806,  2.7486],
        [-3.0951,  5.1975,  5.4336]]) tensor([12,  8,  8, 15,  8, 12, 15,  0])
tensor([[1.2235, 9.1599, 1.6805],
        [1.1371, 7.7532, 2.4957],
        [2.0287, 7.9113, 2.9746],
        [3.7107, 1.3026, 3.9941],
        [3.6208, 7.0632, 3.7347],
        [2.3371, 0.5092, 2.5377],
        [2.5725, 3.1189, 6.0587],
        [4.9588, 2.4496, 4.5882]]) tensor([ 9, 11, 10, 15, 14,  0, 12, 15])
tensor([[ 6.7494,  5.1196,  0.1870],
        [-0.9802,  1.5115,  6.5579],
        [ 0.9862,  0.2409,  7.2309],
        [ 8.1697,  6.1871,  1.7335],
        [ 0.7245,  8.8355,  2.8128],
        [-0.1537,  4.5992,  4.4046],
        [ 2.8295,  2.1205,  3.5362],
        [ 

tensor([[ 2.1480,  1.3233,  2.8180],
        [ 7.4745,  0.3179,  6.9796],
        [-0.5524,  7.7817,  3.4431],
        [-3.5923,  7.1292,  2.4026],
        [ 3.1650,  3.0509,  5.1240],
        [ 0.9364,  2.0148,  1.4479],
        [ 1.3667,  9.3623,  1.8919],
        [ 2.5661,  4.2170,  2.0973]]) tensor([15,  1, 11, 11,  0, 15,  9, 15])
tensor([[ 3.3872,  6.4582,  3.5128],
        [-0.8574,  3.6890,  5.1556],
        [ 2.3194,  3.9062,  6.3335],
        [ 2.5448,  1.9034,  4.0339],
        [ 2.3684,  9.0643,  2.4422],
        [-0.4866,  1.3943,  5.7963],
        [ 1.5937,  7.8880,  2.3224],
        [ 0.7218,  3.8428,  5.8043]]) tensor([14, 12, 12, 15,  9,  4, 10, 12])
tensor([[ 1.1509,  4.7928,  5.3018],
        [ 7.1347,  0.6145,  7.3782],
        [ 4.3469,  9.1300,  3.1447],
        [-1.8727,  7.6122,  3.5849],
        [-2.0049,  2.5723,  6.1611],
        [ 2.6589,  6.5597,  3.9000],
        [ 4.0377,  4.4218,  4.8251],
        [ 0.6259, -0.4379,  5.4959]]) tensor([12,  1,  8, 11,  0,

tensor([[ 4.3153,  8.8879,  1.9450],
        [-2.8336,  3.1545,  6.3582],
        [ 8.1980,  6.2639,  1.6881],
        [ 4.7984,  3.1156,  4.3035],
        [ 1.9325,  3.7450,  6.0057],
        [ 0.0246,  3.8963,  6.3158],
        [ 7.8875,  0.0460,  6.7052],
        [ 1.9532,  3.7690,  6.0735]]) tensor([ 8,  3,  8, 12, 12, 12,  1, 12])
tensor([[ 0.7717,  3.8631,  5.8985],
        [ 4.2448,  2.3271,  4.7945],
        [ 4.3356,  3.6554,  4.8817],
        [ 3.2925,  1.3500,  3.7133],
        [-2.9849,  4.7102,  5.3703],
        [ 2.2912,  2.8733,  5.8896],
        [-0.4994,  7.7560,  3.1292],
        [-0.8285,  1.1381,  4.8362]]) tensor([12, 15, 12, 15,  0, 12, 11,  0])
tensor([[ 2.3734,  8.7863,  2.1226],
        [-1.8996,  8.2403,  3.1394],
        [-3.4256,  7.4029,  3.1163],
        [ 0.7818,  3.8799,  5.9448],
        [ 7.1899,  0.6935,  7.3696],
        [ 0.5509,  8.0273,  3.4188],
        [ 4.0660,  7.6444, -0.9111],
        [ 3.0278,  3.8616,  2.6039]]) tensor([ 9, 11, 11, 12,  1,

tensor([[-0.5090,  4.0191,  5.4572],
        [ 1.7669,  7.4347,  2.6217],
        [-1.7645,  7.5924,  3.4848],
        [-3.3705,  4.8086,  2.9233],
        [ 2.8377,  6.5633,  4.0745],
        [ 2.5396,  2.7048,  1.5494],
        [ 7.3463,  0.4292,  7.0957],
        [ 3.4365,  3.8699,  2.7269]]) tensor([12, 10,  0, 11,  9, 15,  1,  0])
tensor([[ 3.5837,  1.5970,  2.5998],
        [ 1.0000,  7.4578,  0.4058],
        [ 0.3238, -0.2254,  5.6730],
        [ 8.2033,  0.7629,  7.2366],
        [ 7.9577,  5.7913,  4.0576],
        [ 2.5587,  3.1777,  5.9705],
        [-3.3347,  7.1627,  3.8973],
        [ 2.6231,  4.1320,  2.7420]]) tensor([ 0, 10,  4,  1,  8, 12, 11,  0])
tensor([[ 7.4870,  0.7893,  7.3671],
        [ 7.1986,  0.7061,  7.4273],
        [-0.2658,  4.1034,  6.3165],
        [ 2.2392,  7.9594,  6.9483],
        [ 2.1199,  7.8926,  2.6665],
        [ 2.2526,  3.8901,  6.5048],
        [ 1.3029,  4.8675,  5.2410],
        [ 4.7492,  4.1518,  5.7011]]) tensor([ 1,  1, 12,  9,  9,

        [ 2.9951,  4.7186,  5.4617]]) tensor([12,  9,  8,  9, 12, 12,  9, 12])
tensor([[ 1.0655,  1.9738,  1.4944],
        [-0.5886,  8.0827,  3.7745],
        [-3.4667,  7.2234,  3.7567],
        [ 4.0357,  7.5811,  2.7130],
        [ 1.2966,  8.7062,  0.9700],
        [-0.1503,  0.2095,  5.1567],
        [-2.5938,  3.5618,  5.6814],
        [ 6.7434,  7.7813,  3.7898]]) tensor([15, 11, 11, 14,  9,  4,  3,  8])
tensor([[ 4.1158,  8.7819,  1.2778],
        [ 7.2369,  0.7375,  7.1287],
        [ 1.7498,  2.7730,  5.5539],
        [-0.9013,  1.6431,  6.6412],
        [ 3.6493,  4.3231,  3.8363],
        [ 1.1284,  9.0170,  1.5639],
        [ 3.1801,  3.0171,  4.8499],
        [-2.6994,  3.1630,  6.5463]]) tensor([ 8,  1, 12,  4, 15,  9, 15,  3])
tensor([[-1.8743,  7.5591,  3.5392],
        [ 3.6605,  2.6206,  5.1198],
        [ 1.5240,  2.8381,  5.8955],
        [ 1.4577,  0.6791,  6.0748],
        [ 3.4489,  7.1510,  4.1457],
        [ 6.6244,  7.5336,  3.7569],
        [ 1.9552,  3.76

tensor([[ 7.4014,  5.4640, -0.4814],
        [-1.5565,  8.5588,  2.4888],
        [-0.2626,  0.9821,  6.5638],
        [ 2.2007,  4.2858,  5.0693],
        [ 1.4364,  2.5586,  5.9446],
        [ 2.1743,  2.4728,  1.4513],
        [ 2.6592,  7.4829,  4.6149],
        [ 3.7645,  7.8897,  5.0079]]) tensor([ 8, 11,  4,  0, 12,  0,  8,  8])
tensor([[6.3877, 6.9094, 0.1815],
        [2.0236, 7.8193, 2.7113],
        [2.4968, 2.3601, 3.9498],
        [7.2096, 0.7664, 7.2825],
        [3.6201, 6.6620, 5.0107],
        [1.4075, 2.6435, 1.8274],
        [7.0994, 0.6113, 7.4141],
        [3.3323, 8.0312, 7.0291]]) tensor([ 8,  9,  0,  1,  0, 15,  1,  0])
tensor([[ 1.8462,  2.7412,  5.6066],
        [-2.6855,  3.3687,  6.2191],
        [ 8.2408,  6.4895,  1.3099],
        [ 3.9692,  7.8649, -0.6765],
        [-4.0904,  4.9819,  1.6597],
        [ 8.3121,  0.1542,  6.6997],
        [ 0.1022,  4.5063,  4.7271],
        [ 3.7684,  4.0138,  2.9767]]) tensor([12,  3,  8,  8, 11,  1, 12, 15])
tensor([[ 

tensor([[ 4.6008,  6.5497, -1.1489],
        [ 3.8530,  5.6528,  2.7621],
        [ 1.5094,  7.6737, -0.2355],
        [ 3.5347,  1.6925,  2.7051],
        [ 2.6739,  9.9562,  4.6438],
        [-0.8264,  7.8443,  2.5896],
        [ 3.7223,  5.6815,  2.5639],
        [-0.6328,  8.5905,  3.8101]]) tensor([ 8,  8, 10, 15,  9, 11,  8, 11])
tensor([[ 6.0259,  5.0051, -0.3222],
        [ 3.6462,  6.7486,  5.5069],
        [ 3.9488,  7.4829,  2.7593],
        [ 1.7001, -0.2966,  5.7564],
        [ 2.3299,  3.8387,  6.5671],
        [ 0.8753,  0.4216,  7.2555],
        [ 0.0250,  4.1968,  4.7977],
        [ 2.4625,  3.5305,  1.1899]]) tensor([ 8,  9, 14,  4, 12,  4, 12, 15])
tensor([[ 4.0707,  3.9200,  4.9496],
        [ 8.0522,  0.7974,  7.2880],
        [-1.1132,  9.4328,  4.1210],
        [ 2.4325, 11.0451,  4.6436],
        [ 3.3587,  6.5608,  4.7479],
        [ 0.5073,  9.2743,  3.0362],
        [ 3.8473,  5.6662,  3.8395],
        [ 3.1032,  6.7686,  3.8745]]) tensor([12,  1, 11,  9,  9,

        [ 7.9324,  0.0303,  6.6863]]) tensor([ 0,  3,  8,  0, 12,  9,  1,  1])
tensor([[ 8.1844,  0.7871,  7.2229],
        [-2.4296,  3.2569,  5.8125],
        [-4.3882,  5.6744,  2.3220],
        [ 1.8939,  3.6960,  5.9220],
        [ 7.6964,  0.1488,  6.8110],
        [ 3.6207,  4.5276,  5.7368],
        [ 1.3435,  2.0294,  3.0168],
        [ 4.3054,  8.8266,  1.9392]]) tensor([ 1,  3, 11, 12,  1, 12,  0,  8])
tensor([[ 4.7317,  6.7409,  2.7324],
        [ 2.5296,  8.7952,  5.4101],
        [ 3.2041,  3.0229,  4.8957],
        [ 0.5442,  9.1171,  3.3307],
        [ 6.1712,  7.2818,  3.1873],
        [ 3.0971,  2.9847,  4.5017],
        [-4.1358,  4.8341,  4.7653],
        [ 1.4954,  7.6923,  2.3508]]) tensor([ 0,  9, 15, 11,  8, 15, 11,  0])
tensor([[ 4.2565,  2.3289,  4.8428],
        [ 6.0247,  7.5401,  4.5667],
        [ 2.0452,  0.7087,  6.5387],
        [ 5.2240,  4.2409,  4.9204],
        [ 3.4926,  2.2729,  4.0371],
        [ 7.8212,  0.8143,  7.3590],
        [ 6.5587,  4.98

tensor([[ 7.1835,  5.6883, -0.7929],
        [ 7.3152,  0.7347,  7.3913],
        [-1.9668,  8.2043,  3.1750],
        [ 1.7957,  3.6140,  5.6928],
        [-0.5435,  7.7222,  3.3719],
        [ 3.4611,  7.7222,  7.1509],
        [ 1.2618,  4.8964,  5.4157],
        [ 3.2255,  3.8252,  2.6494]]) tensor([ 8,  1, 11, 12, 11,  9,  0, 15])
tensor([[ 2.8182,  4.9008,  2.8344],
        [ 3.0992,  7.2157,  3.7554],
        [-2.9621,  4.9807,  2.1403],
        [ 4.3506,  7.3908,  2.7452],
        [ 6.7511,  5.1447,  0.2473],
        [ 8.2198,  6.5387,  1.2727],
        [-0.9563,  9.3666,  3.9138],
        [ 2.6955,  2.0771,  3.8553]]) tensor([ 0, 10, 11, 14,  8,  8, 11, 15])
tensor([[ 3.5927,  1.3084,  3.8079],
        [ 2.7911,  8.0644,  6.4676],
        [ 0.9556,  0.2650,  7.2832],
        [ 1.6115,  3.4210,  6.9915],
        [-2.6160,  3.5312,  5.7653],
        [-4.0923,  4.9942,  1.6693],
        [ 3.2156,  0.7389,  1.9803],
        [-0.3313,  4.4772,  4.9364]]) tensor([15,  9,  4, 12,  3,

tensor([[ 2.4380,  0.2159,  6.6465],
        [-2.3474,  3.1558,  5.9187],
        [ 3.0461,  5.3834,  3.5888],
        [ 7.1105,  0.5954,  7.2713],
        [-3.9619,  5.0553,  6.2658],
        [ 3.8203,  3.9594,  2.9529],
        [ 3.3599,  6.4151,  2.7692],
        [ 0.7833,  3.7494,  5.4907]]) tensor([ 4,  3,  8,  1, 11, 15,  0, 12])
tensor([[ 4.2096,  4.1900,  3.3266],
        [ 4.2214,  2.3741,  4.6241],
        [ 1.5107, -0.3086,  5.8938],
        [ 1.4229,  1.8291,  2.9539],
        [ 2.5841,  2.7318,  1.5170],
        [-3.9412,  4.7329,  3.8526],
        [ 5.9547,  5.0096, -0.3259],
        [ 4.1697,  7.0950,  3.6019]]) tensor([15,  0,  4, 15, 15, 11,  8, 14])
tensor([[ 1.7218, -0.2148,  5.7425],
        [ 2.6280,  3.2553,  6.0614],
        [ 2.6533,  8.4163,  6.2332],
        [ 4.1962,  7.1386,  3.4939],
        [-0.5246,  7.7039,  3.0970],
        [ 3.5763,  7.4536,  3.8440],
        [-2.5563,  3.2646,  5.8130],
        [ 2.9845,  5.4482,  3.4817]]) tensor([ 4, 12,  9, 14, 11,

tensor([[ 3.1187,  2.1343,  3.3295],
        [ 4.6309,  6.4248,  2.7458],
        [-1.6738,  1.6644,  6.2899],
        [ 2.6267, 10.3850,  4.5486],
        [-2.8049,  3.9528,  5.5960],
        [ 2.1144,  4.0874,  2.0481],
        [ 3.7482,  5.5500,  3.8103],
        [ 0.9994,  2.4866,  1.6350]]) tensor([ 0,  0,  4,  9,  3, 15,  8, 15])
tensor([[-3.6082,  4.7616,  3.2874],
        [ 1.7349,  3.0246,  6.7924],
        [ 0.6328,  3.5819,  5.5510],
        [ 3.3557,  7.2697,  6.9298],
        [ 3.0304,  1.4135,  4.0541],
        [ 1.3589,  9.3996,  5.1958],
        [ 4.3246,  4.3367,  4.8415],
        [-0.3861,  8.3428,  4.7610]]) tensor([11, 12, 12,  9, 15,  9, 12, 11])
tensor([[ 0.5995,  9.0829,  2.7035],
        [ 3.1254,  2.0821,  3.2131],
        [-4.2664,  5.0276,  1.9472],
        [ 0.9401,  3.9270,  6.2495],
        [ 4.1061,  7.4929,  2.7543],
        [ 3.9420,  2.4442,  4.6611],
        [-1.9282,  7.4159,  3.5813],
        [ 2.2170,  3.8596,  6.4640]]) tensor([11,  0, 11, 12, 14,

        [ 2.0138,  4.1641,  5.5049]]) tensor([15,  4,  8, 10, 12, 12, 14, 12])
tensor([[ 2.1245, 11.3935,  5.7229],
        [ 3.1247,  7.5749,  6.5381],
        [ 8.1100,  6.0774,  2.0747],
        [ 2.0579,  7.7244,  3.0694],
        [ 3.8332,  1.7114,  4.2035],
        [ 1.3391, 10.4182,  5.4629],
        [-0.3527,  0.9821,  6.4491],
        [ 3.6635,  6.8025,  5.7333]]) tensor([ 9,  0,  8, 10, 15,  9,  4,  9])
tensor([[ 3.0931,  6.5276,  4.4640],
        [ 3.2868,  3.0101,  4.8938],
        [ 1.5009,  4.9579,  5.2196],
        [ 8.4625,  0.4239,  6.8999],
        [ 1.1299, -0.3818,  6.6522],
        [ 8.3193,  0.6827,  7.1402],
        [-0.2479,  4.3990,  6.0733],
        [ 1.2964,  1.4494,  2.3740]]) tensor([ 9, 15, 12,  1,  4,  1, 12, 15])
tensor([[ 3.0406,  6.6803,  2.7324],
        [-1.9251,  8.0824,  3.3537],
        [-0.2038,  3.9491,  4.3074],
        [ 1.6551,  2.9554,  6.8291],
        [ 1.7233,  3.1240,  6.7850],
        [ 1.0121,  2.2176,  1.5922],
        [ 1.7800,  9.49

tensor([[ 2.9491,  6.6084,  2.7602],
        [ 2.3633,  2.9992,  5.7993],
        [ 0.6535,  3.4858,  5.3699],
        [ 6.0924,  8.0491,  3.8021],
        [-3.3110,  7.1282,  3.9723],
        [ 7.9344, -0.0192,  6.6515],
        [ 5.5851,  5.1498, -0.6070],
        [ 1.5950,  7.8188,  4.0261]]) tensor([ 0,  0, 12,  8, 11,  1,  8,  0])
tensor([[0.1593, 9.2642, 4.5582],
        [3.8089, 2.4580, 5.0788],
        [3.2592, 6.3739, 3.1336],
        [2.6504, 8.3902, 6.2095],
        [3.5262, 5.4062, 3.7914],
        [3.1208, 6.2659, 2.5969],
        [0.8980, 3.8019, 6.4969],
        [1.4107, 2.5577, 6.1788]]) tensor([11,  0, 14,  9,  8, 14, 12, 12])
tensor([[7.7201, 6.0787, 4.5136],
        [1.7162, 7.7790, 3.3037],
        [8.0521, 5.8659, 3.0361],
        [2.1761, 8.0021, 2.4274],
        [0.0742, 9.2428, 4.8130],
        [2.8440, 7.5616, 4.8680],
        [4.3327, 7.3334, 2.8337],
        [3.5308, 7.6011, 3.4678]]) tensor([ 8,  9,  8,  0, 11,  8,  0, 14])
tensor([[ 3.3575,  7.0857,  6.6754

tensor([[-0.5252,  7.8236,  3.5260],
        [ 3.7884,  1.8640,  4.1160],
        [-4.3305,  5.8170,  2.3829],
        [ 3.2112,  2.9696,  4.8835],
        [ 0.6485, -0.3692,  5.4862],
        [ 3.6728,  1.4642,  3.9647],
        [ 3.5559,  7.1004,  2.6547],
        [-2.6501,  2.9673,  5.7020]]) tensor([11,  0, 11, 15,  4, 15, 14,  3])
tensor([[ 2.0325,  7.8421,  2.7801],
        [ 4.0051,  3.8198,  4.9952],
        [ 3.3123,  4.3514,  3.7149],
        [ 4.0912,  6.3999,  2.7309],
        [-2.5350,  3.0218,  5.7316],
        [ 3.3348,  6.4000,  3.2583],
        [ 1.7179,  4.8149,  5.7765],
        [-2.6492,  3.0691,  6.1536]]) tensor([10, 12, 15,  0,  3, 14, 12,  0])
tensor([[-0.4614,  1.1870,  5.9430],
        [ 4.1995,  7.3093, -1.0536],
        [ 3.9963,  4.4013,  3.8804],
        [ 0.9056, 11.5293,  6.0661],
        [-0.4357,  1.0314,  6.2805],
        [ 7.0135,  6.4549,  4.0838],
        [ 2.7676,  2.9453,  1.2643],
        [ 0.9259, 11.4832,  6.0983]]) tensor([ 4,  8, 15,  9,  4,

tensor([[ 8.0309,  0.0204,  6.6500],
        [ 3.2596,  8.3192,  7.2628],
        [ 7.6847,  6.0922,  4.5320],
        [ 2.0763,  1.3874,  2.8195],
        [ 7.4095,  0.3810,  7.0548],
        [ 4.1903,  6.3923,  3.8438],
        [ 1.1158, 11.0729,  5.2070],
        [ 0.9553,  2.3523,  1.5655]]) tensor([ 1,  9,  8, 15,  1,  8,  9, 15])
tensor([[-0.2407,  4.5862,  4.7806],
        [ 3.2625,  7.2070,  3.7126],
        [ 0.7754,  4.3631,  5.0026],
        [ 2.5699,  8.8283,  5.2811],
        [ 2.1874,  1.0257,  2.7344],
        [-2.6269,  3.2736,  5.8877],
        [-3.4647,  4.7886,  3.0326],
        [ 2.3270,  6.8741,  3.8760]]) tensor([12, 10,  6,  9, 15,  3, 11,  0])
tensor([[ 0.9226,  0.3906,  7.2685],
        [ 1.3257,  1.4850,  2.0580],
        [ 1.4749,  7.7003,  2.2251],
        [ 7.0794,  0.5421,  7.3680],
        [-2.7056,  3.6681,  5.6738],
        [ 0.4545,  0.9043,  7.2122],
        [ 2.7079,  3.7534,  6.5335],
        [-2.5158,  3.2316,  5.9110]]) tensor([ 4, 15,  0,  1,  3,

tensor([[ 2.0736,  4.1021,  2.0786],
        [ 4.1472,  8.8417,  1.3785],
        [-0.5702,  7.6716,  3.0687],
        [ 2.0217,  3.8024,  6.1897],
        [ 1.6796,  7.9705,  2.6585],
        [ 2.8873,  8.4488,  7.2571],
        [ 1.3884,  8.6456,  0.5322],
        [ 6.7473,  6.3137, -0.6643]]) tensor([15,  8, 11, 12, 10,  9,  9,  8])
tensor([[ 3.1862,  7.2583,  3.7422],
        [ 0.7977,  3.7803,  6.2993],
        [ 2.0710, 11.3829,  5.8505],
        [ 1.9115,  3.8682,  5.4030],
        [ 0.2445,  4.5571,  4.8486],
        [ 4.1665,  7.1117,  2.8723],
        [ 2.0181,  4.0729,  5.4520],
        [-0.2054,  4.5490,  5.9166]]) tensor([ 0, 12,  9, 12, 12,  0, 12, 12])
tensor([[ 2.5514,  7.4602,  4.3965],
        [-0.2774,  0.9419,  6.2487],
        [-3.4970,  7.2315,  3.7805],
        [ 1.2290,  8.7598,  1.1818],
        [ 1.6921,  1.7296,  3.0939],
        [ 2.3892,  9.0369,  2.2377],
        [ 5.1720,  4.2752,  5.0162],
        [ 2.5355, 10.7346,  4.5597]]) tensor([ 8,  4, 11,  9, 15,

        [ 1.7073,  2.8982,  6.8163]]) tensor([ 9, 13, 15, 12, 12,  9, 12, 12])
tensor([[ 3.0720,  3.0590,  4.8637],
        [ 0.0225,  1.1064,  6.9175],
        [-2.4927,  3.0774,  5.7776],
        [ 2.0038,  2.5326,  1.6880],
        [ 1.2090,  1.5212,  2.5873],
        [ 7.7927,  0.0832,  6.7519],
        [ 0.8647,  2.4724,  1.6086],
        [ 0.0648,  7.9238,  3.3460]]) tensor([15,  4,  3, 15, 15,  1, 15,  0])
tensor([[ 1.4407,  1.8744,  2.7855],
        [ 3.1487,  6.5654,  4.1706],
        [ 2.7107,  2.8735,  1.4607],
        [ 4.4501,  2.4508,  4.7285],
        [ 2.2755, -0.1074,  5.9362],
        [ 4.1943,  7.0513,  3.7539],
        [-3.8041,  4.9287,  1.4740],
        [ 2.6561,  2.7247,  1.4046]]) tensor([15,  0, 15, 15,  4, 14, 11, 15])
tensor([[ 0.7823,  3.7717,  5.2752],
        [ 2.8583,  1.4811,  4.1899],
        [ 2.0156,  3.9472,  2.0163],
        [ 1.3926,  9.1914,  1.9836],
        [ 2.6453,  5.7634,  3.0937],
        [-0.1082,  4.5590,  5.8394],
        [-2.9602,  4.96

tensor([[ 0.8818,  3.8990,  6.2105],
        [ 0.2242,  2.9456,  5.6403],
        [-3.9527,  5.0715,  6.1486],
        [-1.8972,  7.8310,  3.5093],
        [ 3.4363,  7.6535,  7.1906],
        [-3.3277,  7.3463,  3.5388],
        [ 2.3884,  8.1621,  7.0535],
        [-0.6656,  4.0891,  5.8316]]) tensor([12, 12, 11, 11,  9, 11,  9,  0])
tensor([[ 4.0757,  7.5710,  2.7036],
        [ 1.6355, 11.3308,  6.1938],
        [ 2.3972, -0.2801,  5.9692],
        [ 1.4684,  7.5705, -0.1264],
        [-0.4862,  1.3457,  5.8148],
        [-4.1671,  4.7565,  4.5907],
        [ 8.0744,  5.9022,  2.7434],
        [ 2.5438,  2.0300,  3.8747]]) tensor([14,  0,  4, 10,  4, 11,  8, 15])
tensor([[ 0.8783,  8.4995,  2.4645],
        [ 1.6469,  4.8671,  5.6425],
        [ 5.1246,  4.2531,  5.2103],
        [ 3.7099,  3.9456,  2.8904],
        [-0.1088,  1.0951,  6.7856],
        [-0.8161,  3.7322,  5.1738],
        [-1.8693,  8.1043,  3.3191],
        [ 0.9887,  0.3452,  7.3110]]) tensor([11, 12, 12, 15,  4,

        [ 5.5161,  6.3413,  5.0261]]) tensor([ 1,  4, 11,  1, 15, 11, 15,  8])
tensor([[ 0.7595,  8.8211,  2.4963],
        [ 8.3904,  0.6176,  7.0762],
        [ 2.9382,  5.5802,  2.5811],
        [ 2.1581,  3.9068,  6.2749],
        [ 4.1457,  2.6111,  4.6803],
        [ 3.3757,  3.9160,  2.8039],
        [ 7.4122,  5.4522, -0.4951],
        [-0.6384,  8.5829,  3.8310]]) tensor([11,  1,  0, 12,  0, 15,  8, 11])
tensor([[-2.5720,  3.4567,  5.8349],
        [ 6.7409,  8.1840,  3.6177],
        [ 8.3307,  0.1815,  6.7201],
        [ 0.9908,  3.8919,  6.4798],
        [ 3.5747,  6.6531,  5.0325],
        [ 3.3313,  6.4291,  3.3809],
        [-2.3959,  6.9919,  3.9545],
        [ 1.1031,  3.8246,  6.6423]]) tensor([ 3,  8,  1, 12,  9, 14, 11, 12])
tensor([[ 2.4478,  3.5411,  1.4534],
        [ 3.9932,  4.3909,  3.8182],
        [ 4.0379,  4.0700,  3.0951],
        [ 0.9251, 11.4786,  6.0826],
        [-4.0332,  4.7599,  4.0852],
        [ 3.4894,  1.8449,  2.6903],
        [ 6.6838,  7.60

tensor([[ 4.2103,  7.1063,  3.8214],
        [ 0.1104,  3.4758,  6.1385],
        [ 0.9324,  8.2694,  2.3844],
        [ 2.0218,  3.7738,  1.8128],
        [ 1.3053,  4.9320,  5.6427],
        [ 3.9817,  6.3592,  3.3233],
        [-0.7308,  8.8895,  3.7784],
        [ 6.4137,  6.4907,  3.2615]]) tensor([14, 12, 11, 15, 12,  0, 11,  8])
tensor([[-0.8843,  1.8701,  6.6974],
        [-3.4716,  7.2145,  3.8222],
        [ 1.1246,  1.7979,  1.4458],
        [ 7.7522,  6.0528,  4.4958],
        [-1.4468,  8.1984,  3.0977],
        [ 4.1320,  6.4698,  3.3325],
        [ 6.0926,  7.1682,  3.0873],
        [ 1.4490,  7.7283, -0.3011]]) tensor([ 4, 11, 15,  8,  0,  0,  8, 10])
tensor([[ 1.5912,  7.4862,  2.7746],
        [-3.4503,  7.3653,  2.8557],
        [ 1.3255,  9.3133,  1.8628],
        [ 1.3743,  1.6404,  2.5543],
        [-0.1138,  8.8220,  5.2280],
        [ 4.7838,  6.7264,  2.5163],
        [-0.5367,  7.6647,  3.1124],
        [-1.1000,  9.4571,  3.9669]]) tensor([ 0, 11,  9, 15, 11,

tensor([[ 0.5648,  9.1655,  2.9227],
        [ 3.3080,  3.1237,  4.9816],
        [ 1.1622, -0.4340,  6.4524],
        [ 1.1967, -0.3728,  6.8424],
        [ 4.3764,  8.2473,  4.8461],
        [ 6.7576,  7.6009,  0.7075],
        [ 4.0060,  4.0656,  3.0940],
        [-2.5975,  3.1751,  5.9803]]) tensor([11,  0,  4,  4,  8,  8, 15,  3])
tensor([[-3.2108,  7.0624,  4.0654],
        [-2.6305,  2.7472,  6.2907],
        [ 4.1524,  7.4173,  2.1600],
        [ 5.4591,  5.2557, -0.6071],
        [ 2.5131,  8.7337,  5.4349],
        [ 2.5417,  8.3723,  5.9511],
        [ 3.6892,  1.9471,  4.1202],
        [ 7.7767,  7.4835,  0.2944]]) tensor([11,  3,  0,  8,  9,  9, 15,  8])
tensor([[-2.5863,  3.4891,  6.0704],
        [-0.1771,  8.9503,  5.2967],
        [-3.5956,  4.9694,  1.4400],
        [ 3.5477,  1.6084,  2.7537],
        [ 0.9269,  0.3565,  7.2733],
        [ 1.1978,  2.5984,  1.8779],
        [-2.7580,  3.0833,  5.7743],
        [ 4.6058,  8.9747,  4.1980]]) tensor([ 3, 11, 11, 15,  4,

tensor([[ 4.5958,  5.6925,  3.7941],
        [ 4.3896,  8.0356,  4.6401],
        [ 3.5878,  7.1113,  2.6324],
        [ 7.8585,  5.9373,  4.3919],
        [ 2.9131,  5.1803,  3.7875],
        [ 4.2045,  8.1244,  4.9045],
        [ 0.4751,  0.8855,  7.2020],
        [ 7.3522,  5.5044, -0.6476]]) tensor([ 8, 13, 14,  8,  8,  8,  4,  8])
tensor([[ 6.8663,  0.5240,  7.3058],
        [ 0.2951,  9.3064,  4.1676],
        [-1.0747,  2.1349,  6.3262],
        [ 1.9825,  4.1316,  5.5106],
        [ 2.5802,  8.4595,  6.1971],
        [-3.7685,  4.9352,  1.4580],
        [-0.2023,  8.8700,  5.3413],
        [ 2.5113, -0.0320,  6.4223]]) tensor([ 1, 11,  4, 12,  0, 11,  0,  4])
tensor([[ 0.9820, 11.2992,  5.2808],
        [ 6.3694,  7.2115,  0.7251],
        [-2.5975,  3.8513,  5.6280],
        [-2.6600,  2.5351,  6.4162],
        [ 2.2985,  8.4360,  2.1028],
        [ 0.5843,  9.1281,  2.8219],
        [ 3.2476,  7.2741,  3.7356],
        [ 2.4919,  8.5659,  5.6426]]) tensor([ 9,  8,  3,  4,  9,

tensor([[ 8.0923, -0.0272,  6.6026],
        [-3.5595,  5.2024,  6.2702],
        [ 3.0020,  6.6231,  4.0728],
        [ 1.9935,  3.7683,  6.1278],
        [ 8.4784,  0.5198,  6.9777],
        [ 1.7253,  8.1043,  2.5585],
        [ 2.7855,  3.4922,  6.2882],
        [-3.3675,  7.3683,  3.2784]]) tensor([ 1, 11,  0, 12,  1, 10, 12, 11])
tensor([[ 1.6565,  2.7867,  6.4858],
        [ 1.2513,  7.5896,  1.6590],
        [ 2.5424,  8.8554,  5.3359],
        [ 3.6901,  1.1931,  4.0556],
        [ 3.9967,  7.1680,  3.3662],
        [ 0.7877, -0.2820,  5.3249],
        [ 1.1222,  8.9005,  1.4593],
        [ 6.1437,  7.2191,  3.1595]]) tensor([12, 10,  9,  0, 14,  4,  9,  8])
tensor([[ 1.0849,  7.6728,  1.1441],
        [ 3.0499,  8.3877,  7.2251],
        [ 5.8723,  6.8357,  4.9755],
        [ 0.8026,  8.7447,  2.4720],
        [ 3.7073,  1.3586,  4.0065],
        [ 3.3290,  2.0411,  3.0011],
        [-0.4163,  8.3529,  3.8019],
        [-2.1345,  2.3643,  6.5871]]) tensor([10,  9,  8, 11, 15,

tensor([[ 3.8743,  3.9643,  2.9626],
        [ 1.7114,  3.7028,  5.6827],
        [ 1.3332,  8.6491,  4.6522],
        [ 2.5749,  3.0634,  6.0252],
        [ 4.2265,  7.2395,  3.4806],
        [-0.4485,  7.8895,  3.6141],
        [ 4.1469,  4.3706,  3.6250],
        [ 8.3931,  0.2734,  6.7848]]) tensor([15, 12,  9, 12, 14, 11, 15,  1])
tensor([[ 4.4641,  6.8052, -1.1860],
        [ 2.6684,  1.7358,  4.0881],
        [-3.4865,  4.8048,  3.0755],
        [ 3.4231,  7.5337,  7.1107],
        [ 3.4638,  2.8831,  5.1399],
        [ 0.1531,  3.3614,  6.0614],
        [ 6.1339,  8.0797,  3.7185],
        [ 2.6704, 10.0953,  4.5669]]) tensor([ 8, 15, 11,  9, 15, 12,  8,  9])
tensor([[-1.7600,  7.8546,  3.4003],
        [ 1.8870,  3.8395,  5.4983],
        [ 0.2679, -0.2340,  5.6531],
        [ 2.7563,  3.5637,  6.4627],
        [ 4.2948,  6.4033,  4.0623],
        [ 3.9603,  4.0817,  3.1439],
        [ 6.7002,  7.6233,  3.8093],
        [ 2.6088,  9.3232,  4.9298]]) tensor([11, 12,  4, 12,  0,

tensor([[ 2.1584,  2.6183,  1.6672],
        [ 6.7700,  7.6065,  0.7123],
        [ 1.6091,  7.8261,  2.6654],
        [ 4.4015,  8.7342,  2.5526],
        [-1.6638,  8.6622,  2.4548],
        [ 0.9837,  7.8937,  2.3104],
        [ 0.8674, -0.1341,  5.3292],
        [ 1.3339,  8.5622,  4.5593]]) tensor([15,  8, 10,  8, 11, 11,  4,  9])
tensor([[-1.6653,  2.2640,  6.0823],
        [ 1.9795,  4.4791,  5.6259],
        [ 2.6123,  9.2748,  4.9430],
        [-3.4908,  7.2052,  3.8233],
        [-0.2407,  0.1709,  5.0307],
        [ 2.8553,  4.8229,  3.2448],
        [ 2.8103,  8.2788,  6.4043],
        [ 2.6365,  7.4976,  4.6680]]) tensor([ 4, 12,  9, 11,  4, 15,  9,  8])
tensor([[ 1.4109,  1.5702,  2.3334],
        [ 2.8132,  8.1539,  6.4630],
        [-1.0546,  1.5017,  5.9917],
        [ 3.1895,  2.8541,  4.6597],
        [ 1.1491,  2.0305,  1.5992],
        [-0.9950,  8.9305,  4.7737],
        [-0.9539,  9.3700,  3.8221],
        [ 3.3742,  8.1746,  7.2397]]) tensor([15,  9,  4, 15, 15,

tensor([[-0.6851,  9.1778,  3.8895],
        [ 3.6400,  1.1998,  2.3139],
        [-0.8785,  0.0846,  4.8280],
        [ 3.3415,  6.4172,  3.0739],
        [ 3.4340,  7.3221,  3.8364],
        [ 1.9958,  3.7791,  6.1027],
        [ 0.8039,  8.7987,  2.6442],
        [ 6.3348,  4.9664, -0.2144]]) tensor([11, 15,  0,  0,  0, 12,  0,  8])
tensor([[-2.4798,  3.1549,  6.0358],
        [-2.6964,  7.1030,  4.0541],
        [ 0.1140,  4.0166,  4.4187],
        [-1.1896,  8.1458,  2.3839],
        [ 3.6307,  1.4997,  2.5318],
        [ 7.8418,  0.7926,  7.3572],
        [ 2.1176,  3.8622,  1.9156],
        [ 1.2984,  8.0392,  2.5527]]) tensor([ 3, 11, 12, 11, 15,  1,  0, 11])
tensor([[ 1.5961,  2.6574,  1.8006],
        [ 1.3460,  3.6313,  6.9883],
        [ 0.1235,  3.4424,  6.0773],
        [ 1.6348,  7.9582,  2.5199],
        [-3.3106,  5.4030,  5.8178],
        [-0.4376,  8.6821,  3.8132],
        [ 7.2274,  0.7547,  7.2844],
        [-0.8787,  1.8817,  6.6316]]) tensor([15, 12, 12, 10, 11,

tensor([[ 0.8568,  8.7309,  2.2951],
        [ 0.7513, 11.5235,  5.4680],
        [-1.9053,  8.2941,  3.0156],
        [-0.1055,  3.7953,  6.3389],
        [ 1.0735,  3.7578,  6.7405],
        [ 2.4937,  3.1231,  5.8455],
        [ 2.3073,  3.6682,  1.4467],
        [ 3.9732,  7.8149, -0.7833]]) tensor([11,  9, 11, 12, 12, 12, 15,  0])
tensor([[ 4.2489,  7.5765,  2.7087],
        [ 7.3959,  5.4125, -0.1827],
        [ 2.7847,  3.1116,  1.1642],
        [ 6.7226,  8.0709,  3.6756],
        [ 2.3914,  8.8433,  2.0683],
        [-3.8443,  5.1423,  1.6711],
        [ 2.8804,  5.6154,  2.5811],
        [ 3.2548,  7.1875,  3.8059]]) tensor([14,  8, 15,  8,  9,  0, 14,  0])
tensor([[ 3.7276,  1.4021,  3.9866],
        [-4.2376,  5.1920,  2.0765],
        [ 3.1756,  4.4240,  3.6641],
        [ 2.0723,  0.5509,  6.6407],
        [ 2.7195,  3.3711,  1.1304],
        [ 8.2030,  0.0951,  6.6805],
        [ 1.3123,  1.7499,  1.5319],
        [ 0.6802,  4.1894,  4.0048]]) tensor([ 0,  0, 15,  4, 15,

tensor([[ 3.5613,  7.6074,  3.3142],
        [ 2.6550,  9.6263,  4.7759],
        [ 4.2668,  4.3795,  3.5974],
        [ 1.2657,  1.5357,  2.7807],
        [ 2.2866,  2.5255,  1.4843],
        [ 2.0386,  2.5651,  1.6223],
        [ 2.9989,  1.7355,  3.8577],
        [-2.6614,  2.8449,  6.2388]]) tensor([14,  0, 15, 15, 15, 15,  0,  3])
tensor([[-2.3152,  3.3473,  6.0700],
        [ 4.4936,  8.4374,  4.9202],
        [ 2.0408,  4.4594,  5.6122],
        [ 6.8704,  6.4749,  3.8759],
        [-0.7520,  9.0294,  3.7630],
        [-1.9981,  2.3911,  6.5210],
        [ 3.8443,  6.9192,  3.3297],
        [ 1.2925,  0.6259,  5.8349]]) tensor([ 3,  8, 12,  8, 11,  4, 14,  4])
tensor([[ 1.3495,  7.5978,  2.5611],
        [ 6.7365,  6.3133, -0.6389],
        [-2.0229,  7.3206,  3.6651],
        [ 4.3077,  6.2405,  4.2659],
        [ 3.4545,  6.5498,  2.9131],
        [ 5.4652,  6.2723,  5.0315],
        [-4.1207,  6.2955,  2.2748],
        [ 1.2734,  1.8328,  1.5004]]) tensor([11,  8, 11,  8, 14,

tensor([[1.8103, 2.9602, 6.7954],
        [1.1012, 3.8772, 6.5356],
        [3.7658, 4.2710, 5.0077],
        [7.0971, 0.5903, 7.4518],
        [1.2122, 2.9200, 2.2859],
        [3.0851, 6.8818, 3.1944],
        [0.6949, 3.5943, 5.4260],
        [3.0147, 3.8665, 2.4612]]) tensor([12, 12, 12,  1,  0,  0, 12, 15])
tensor([[ 2.2185,  3.6372,  1.5535],
        [ 0.2137,  3.2580,  5.8994],
        [ 4.2222,  4.4155,  4.7739],
        [ 0.7924,  3.9050,  5.9072],
        [-1.2185,  9.3375,  4.5545],
        [ 4.6368,  8.7359,  4.7283],
        [ 3.5578,  6.8046,  2.7546],
        [ 1.6020,  4.9105,  5.6329]]) tensor([15, 12, 12, 12, 11,  0, 14, 12])
tensor([[ 3.3877,  2.5279,  4.2962],
        [ 1.2966,  7.2275,  0.6299],
        [-1.6667,  7.7613,  3.0307],
        [-0.1144,  8.6529,  5.2855],
        [ 1.5627,  4.9431,  5.3648],
        [-0.1843,  4.4530,  6.0018],
        [-2.4466,  3.4690,  5.9043],
        [ 3.7320,  6.7097,  3.4411]]) tensor([15,  0,  0, 11, 12, 12,  3, 14])
tensor([[-

tensor([[ 4.0429,  6.0202,  3.9666],
        [ 4.0293,  4.4018,  4.7767],
        [ 7.0871,  0.5555,  7.3695],
        [ 0.1249,  4.5540,  4.7996],
        [-3.8871,  5.0708,  6.1831],
        [-2.8374,  7.0383,  4.0898],
        [-1.9530,  8.1870,  3.2372],
        [-4.1392,  6.2789,  2.2803]]) tensor([ 8, 12,  1, 12, 11, 11, 11, 11])
tensor([[-2.6999,  3.6304,  5.7394],
        [-0.6606,  2.4261,  4.7520],
        [ 2.5978,  7.7588,  6.6633],
        [ 4.1151,  4.3453,  3.6280],
        [ 6.6978,  6.3789, -0.5934],
        [ 0.6203,  3.6531,  5.6684],
        [ 1.4122, -0.3844,  6.0035],
        [ 2.0315,  7.7330,  3.0689]]) tensor([ 3, 12,  9, 15,  8, 12,  4, 10])
tensor([[ 1.6627,  7.7911,  2.7899],
        [ 2.2746,  2.5589,  3.9278],
        [-4.3497,  4.9124,  1.9751],
        [ 2.1761,  3.8666,  6.3539],
        [-2.7144,  3.0874,  6.6012],
        [-2.2435,  3.3116,  6.0302],
        [-2.2488,  7.1006,  3.8156],
        [ 1.9971,  1.5251,  2.9178]]) tensor([10,  0, 11, 12,  3,

tensor([[0.1379, 3.4180, 6.0997],
        [3.5845, 7.7309, 4.8764],
        [0.9573, 7.5607, 0.7623],
        [8.1836, 6.6562, 1.1126],
        [0.5765, 3.4551, 5.3394],
        [7.4750, 6.2793, 4.5289],
        [7.2046, 0.7466, 7.1490],
        [3.4621, 7.6260, 7.2228]]) tensor([ 0,  0, 10,  8, 12,  8,  1,  9])
tensor([[ 4.8386,  4.1262,  5.6402],
        [-1.5759,  7.5545,  3.4170],
        [ 2.2840, 11.3219,  5.1488],
        [ 3.7567,  1.5350,  4.0683],
        [ 0.1719,  4.1968,  4.7569],
        [ 7.5170,  0.3018,  6.9592],
        [-3.4138,  5.3153,  5.9513],
        [ 5.1746,  4.1721,  4.8614]]) tensor([12,  0,  9, 15, 12,  1, 11, 12])
tensor([[ 6.7234,  7.5613,  0.7237],
        [-2.3984,  3.6017,  6.1920],
        [ 5.2503,  5.5246, -0.8474],
        [-1.1129,  8.1149,  2.4650],
        [-1.8991,  7.5177,  3.5594],
        [ 4.8914,  5.8627,  4.3494],
        [ 3.5570,  1.6318,  2.7109],
        [ 3.1990,  1.2631,  3.7898]]) tensor([ 8,  3,  8, 11, 11,  8, 15, 15])
tensor([[-

tensor([[-0.5034,  1.3562,  5.7865],
        [ 6.6953,  6.4636,  3.5766],
        [ 6.1285,  4.9360, -0.3647],
        [ 1.3802,  1.7594,  2.7945],
        [ 0.8568,  3.8968,  6.1172],
        [ 3.2014,  1.1721,  3.8250],
        [ 1.3073,  1.6968,  1.9341],
        [ 1.6212,  1.8517,  2.9488]]) tensor([ 4,  8,  8, 15, 12, 15,  0, 15])
tensor([[ 4.1809,  4.4161,  3.7074],
        [ 3.7186,  1.5644,  4.0132],
        [ 2.0271,  1.4304,  2.8930],
        [-2.6890,  3.4182,  5.8884],
        [ 1.3461,  1.6278,  1.4771],
        [-3.4603,  5.2708,  6.2710],
        [ 1.5244,  5.0167,  5.3005],
        [ 1.3231,  9.6751,  5.3445]]) tensor([15, 15, 15,  3, 15, 11, 12,  9])
tensor([[ 2.1016,  2.8494,  5.6889],
        [ 3.9486,  7.9679, -0.5151],
        [-0.1410,  8.9522,  5.2732],
        [ 2.7281,  4.7698,  5.3711],
        [ 3.0253,  4.6095,  3.4979],
        [ 3.1964,  1.2144,  3.8572],
        [ 2.0737,  7.8454,  3.0200],
        [ 4.0024,  4.4259,  5.7737]]) tensor([12, 13, 11, 12, 15,

tensor([[-0.4056,  0.9822,  6.3262],
        [ 3.4361,  7.7223,  5.0519],
        [-0.3386,  8.3753,  4.8282],
        [-2.9006,  4.3372,  5.7174],
        [ 0.6629,  8.9287,  2.6351],
        [ 2.8908,  4.1100,  2.1827],
        [ 3.5057,  1.8032,  2.6987],
        [ 3.4279,  7.7467,  7.1652]]) tensor([ 4,  8, 11,  3, 11, 15, 15,  9])
tensor([[-3.4381,  7.3438,  2.9865],
        [ 6.7194,  8.2262,  3.5712],
        [ 1.2180,  1.6823,  2.1192],
        [-1.9095,  7.9776,  3.4721],
        [ 5.9363,  6.8269,  2.8984],
        [ 2.4597,  3.0380,  5.8927],
        [ 2.3008, 11.4478,  5.1041],
        [-0.4418,  1.0514,  6.0491]]) tensor([11,  8,  0, 11,  8, 12,  9,  4])
tensor([[ 2.7383,  4.9689,  2.7573],
        [-1.7543,  1.7612,  6.1520],
        [ 1.0551,  2.4855,  1.6327],
        [-0.6524,  8.4837,  3.8285],
        [ 3.1205,  7.1921,  3.6810],
        [ 2.6173,  1.9775,  3.9231],
        [ 0.9088,  3.7852,  6.4967],
        [ 7.7204,  6.0891,  4.5216]]) tensor([14,  4, 15, 11, 10,

tensor([[-4.2678,  4.8486,  5.3790],
        [ 4.8398,  4.2108,  4.7539],
        [-0.2009,  8.7382,  5.3677],
        [-0.3456,  4.4446,  4.8456],
        [ 6.8689,  6.4790,  3.8786],
        [ 1.5544,  2.6818,  5.7839],
        [ 0.0521,  9.2261,  4.8274],
        [-2.7821,  3.7824,  5.6115]]) tensor([11, 12, 11, 12,  8, 12, 11,  3])
tensor([[-2.4877,  3.7396,  6.1292],
        [ 2.1788,  1.3115,  2.7951],
        [ 1.7957,  7.9248,  2.5053],
        [ 1.0068,  7.8471,  2.4459],
        [-3.4206,  4.9202,  3.1454],
        [ 1.7569,  3.0067,  6.8892],
        [ 5.9865,  7.0658,  5.0002],
        [-2.5318,  3.2266,  5.9365]]) tensor([ 3, 15, 10, 11,  0, 12,  8,  3])
tensor([[ 0.7963, 11.4779,  5.3363],
        [-2.8115,  3.1550,  6.2192],
        [ 3.3741,  1.4373,  3.6485],
        [ 5.6382,  6.5192,  5.0585],
        [ 3.3229,  7.2075,  6.8914],
        [-2.1822,  3.3700,  5.9678],
        [ 7.9419,  0.8051,  7.3330],
        [ 3.4278,  1.2157,  3.6478]]) tensor([ 9,  3,  0,  8,  9,

tensor([[-0.7753,  9.0522,  3.7319],
        [-0.2993,  1.0095,  6.5668],
        [-3.4309,  7.2440,  3.7791],
        [ 4.8595,  3.0695,  4.2853],
        [ 4.5794,  9.0477,  3.8832],
        [ 0.9203,  3.7185,  6.6308],
        [ 1.0715,  7.8302,  2.4011],
        [-1.7194,  8.6774,  2.5449]]) tensor([11,  4, 11, 12,  8, 12, 11, 11])
tensor([[ 2.4486,  3.6057,  1.1274],
        [ 0.1088,  4.5438,  4.8784],
        [ 4.9139,  2.8374,  4.2969],
        [-3.6582,  5.1556,  6.2920],
        [-1.6466,  8.6313,  2.5602],
        [ 5.0933,  4.1558,  5.4571],
        [ 1.9689,  4.0900,  5.5763],
        [-0.0360,  4.0454,  4.7404]]) tensor([15, 12, 12, 11, 11,  0, 12, 12])
tensor([[ 1.4681,  3.5624,  6.9563],
        [ 0.8352, 11.5274,  5.9853],
        [ 1.5603,  2.7729,  6.3631],
        [ 2.3113,  8.4643,  2.0605],
        [-2.9555,  4.8984,  2.2113],
        [-1.6517,  8.0099,  3.3333],
        [ 0.3540,  9.2726,  4.0357],
        [ 1.4074,  8.0686, -0.6747]]) tensor([12,  9, 12,  9, 11,

        [ 3.5289,  6.7647,  2.7798]]) tensor([14, 12,  8,  8,  4,  1,  8, 14])
tensor([[ 1.1974,  1.8583,  1.5243],
        [-1.0739,  1.5283,  5.9988],
        [ 2.7347,  2.0905,  3.6399],
        [ 4.4195,  4.2221,  5.8639],
        [ 3.3739,  7.0799,  6.6716],
        [ 1.6930,  3.2483,  6.9683],
        [ 1.3309,  1.6274,  1.5002],
        [-2.8597,  3.5316,  5.9285]]) tensor([15,  4, 15, 12,  9, 12, 15,  0])
tensor([[ 0.7901,  3.9810,  5.8631],
        [ 2.7750,  2.0416,  3.8244],
        [ 2.3836,  0.3135,  6.6771],
        [ 3.6674,  7.4889,  3.2138],
        [ 2.6755,  2.8045,  1.3333],
        [ 7.1224,  0.5922,  7.4050],
        [ 0.0307,  3.3968,  6.2175],
        [-2.6019,  2.9728,  6.4060]]) tensor([12,  0,  4, 14, 15,  1, 12,  3])
tensor([[ 1.5925,  2.6158,  1.8610],
        [ 6.5297,  6.6517, -0.2940],
        [ 7.0952,  5.3123,  0.1688],
        [ 8.2116,  0.5887,  7.0483],
        [-0.1592,  4.0009,  6.3633],
        [-3.4564,  7.3296,  2.7495],
        [ 7.9412,  0.82

tensor([[-2.6168,  3.5495,  5.7414],
        [-0.2867,  4.6106,  4.5287],
        [ 2.5358,  7.4762,  4.4212],
        [ 2.5060,  8.7310,  5.3637],
        [ 4.2818,  7.5413,  2.7098],
        [ 3.5503,  6.9557,  2.6696],
        [-2.6052,  2.8215,  6.2527],
        [ 0.4670,  9.2794,  3.3891]]) tensor([ 3, 12,  8,  9, 14, 14,  3, 11])
tensor([[ 1.9243,  7.5080,  3.0408],
        [ 4.0912,  4.2909,  3.5491],
        [ 1.4864,  0.7120,  6.0805],
        [ 8.2116,  6.4040,  1.4544],
        [ 3.9932,  8.7090,  0.6773],
        [ 2.6244, 10.3134,  4.5692],
        [ 1.3619,  3.5828,  6.9508],
        [ 1.4863,  9.4668,  1.9632]]) tensor([10, 15,  4,  8,  8,  9, 12,  9])
tensor([[-2.6571,  3.6654,  5.6486],
        [ 4.2129,  4.2153,  3.3506],
        [ 3.2101,  2.8579,  4.6979],
        [ 4.2888,  4.3640,  4.7375],
        [ 3.3263,  4.4097,  3.6565],
        [ 6.8291,  7.9002,  3.7835],
        [-3.0948,  4.9234,  1.6215],
        [-4.4158,  5.4295,  2.3260]]) tensor([ 0, 15, 15, 12, 15,

tensor([[ 3.6759e+00,  6.6284e+00,  3.4904e+00],
        [ 8.0278e+00,  5.8577e+00,  3.0918e+00],
        [ 6.3356e+00,  7.0573e+00,  4.4382e-01],
        [ 1.6549e+00,  7.3215e-01,  6.2947e+00],
        [-1.0628e+00,  1.4265e+00,  6.3414e+00],
        [-8.5233e-03,  9.1893e+00,  4.9875e+00],
        [ 3.4831e+00,  1.8102e+00,  2.7012e+00],
        [ 3.6547e+00,  7.3409e+00,  3.8670e+00]]) tensor([14,  8,  8,  4,  4, 11,  0,  5])
tensor([[-2.7472e+00,  3.6053e+00,  5.6268e+00],
        [ 1.1979e+00,  5.0089e+00,  5.6727e+00],
        [ 8.6889e-01,  4.4240e+00,  4.9226e+00],
        [ 6.5827e+00,  4.9665e+00,  5.2966e-03],
        [ 2.0969e+00,  8.7125e+00,  3.5446e+00],
        [-1.0332e+00,  1.4741e+00,  6.4913e+00],
        [ 1.2509e+00,  2.6722e+00,  1.8367e+00],
        [-4.1787e+00,  4.9323e+00,  5.6986e+00]]) tensor([ 3, 12,  0,  8,  0,  4, 15, 11])
tensor([[ 2.6934,  1.9817,  4.0767],
        [ 5.5782,  6.4127,  5.0229],
        [ 4.0172,  5.8265,  4.0782],
        [ 3.8923,  4.

tensor([[ 4.9275,  4.1779,  4.7594],
        [ 1.2215, 10.8776,  5.2317],
        [-0.4301,  3.0782,  4.7549],
        [-3.8639,  4.7414,  3.7577],
        [ 3.5576,  7.4430,  3.3960],
        [ 2.6536,  2.0898,  3.6672],
        [ 2.6913,  1.9161,  3.9458],
        [ 3.7659,  4.4785,  5.8390]]) tensor([12,  9, 12, 11, 14, 15, 15, 12])
tensor([[ 2.6183,  8.7051,  5.3909],
        [ 2.3285,  8.4616,  2.0439],
        [ 0.7483,  3.8166,  5.4844],
        [ 3.6109,  1.4836,  2.5665],
        [ 0.5612,  0.4190,  5.9018],
        [ 2.4605, 10.9523,  4.6353],
        [-3.4623,  7.0871,  2.6429],
        [ 1.9833,  1.5453,  2.8775]]) tensor([ 0,  9, 12, 15,  0,  9, 11, 15])
tensor([[ 2.1774,  2.5964,  1.6517],
        [ 0.8174, -0.2100,  5.3450],
        [ 3.6344,  6.5865,  3.4789],
        [ 4.1337,  4.4150,  3.6802],
        [ 2.6018,  8.8763,  5.1789],
        [-0.9860,  1.5335,  6.5730],
        [ 3.1408,  5.6672,  2.7354],
        [ 3.1053,  6.5298,  4.4137]]) tensor([15,  4, 14, 15,  9,

For clarity, let's define a data generation method that simply returns the train and test split from our gene expression dataset

In [16]:
def data_gen():
    t = GeneExpressionData(
        filename=os.path.join(here, '../data/processed/umap/primary_reduction_neighbors_100_components_3.csv'),
        labelname=os.path.join(here, 'fixed_primary_labels_neighbors_50_components_50_clust_size_100.csv')
    )

    train_size = int(0.8 * len(t))
    test_size = len(t) - train_size

    train, test = torch.utils.data.random_split(t, [train_size, test_size])
    
    return train, test

Now that we've defined our `DataLoader`, let's test it when training a simple Neural Network

## Using PyTorch Lightning

PyTorch lightning seems nicer than Ignite, especially for GPU training. Let's test it out 

In [17]:
from torchmetrics import Accuracy, ConfusionMatrix
from sklearn.utils.class_weight import compute_class_weight

class GeneClassifier(pl.LightningModule):
    def __init__(self, N_features, N_labels, weights, config):
        """
        Initialize the gene classifier neural network

        Parameters:
        N_features: Number of features in the inpute matrix 
        N_labels: Number of classes 
        """
        
        self.train_data, self.test_data = data_gen()
        
        # Params for optimizer 
        self.lr = config['lr']
        self.momentum = config['momentum']
        self.weight_decay = config['weight_decay']
        
        super(GeneClassifier, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(N_features, 512),
            nn.ReLU(),
            nn.Linear(512, N_labels),
        )
        
        self.accuracy = Accuracy()
        self.confusion = ConfusionMatrix(N_labels)
        self.weights = weights

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(), 
            lr=self.lr, 
            momentum=self.momentum,
            weight_decay=self.weight_decay
            
        )
        return optimizer

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y, weight=self.weights)
        acc = self.accuracy(y_hat.softmax(dim=-1), y)
        matrix = self.confusion(y_hat.softmax(dim=-1), y)
        
        self.log("train_loss", loss, on_step=False, on_epoch=True, logger=True)
        self.log("train_accuracy", acc, on_step=False, on_epoch=True, logger=True)
        self.log("train_confusion_mat", matrix, on_step=False, on_epoch=True, logger=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        val_loss = F.cross_entropy(y_hat, y, weight=self.weights)
        acc = self.accuracy(y_hat.softmax(dim=-1), y)
        matrix = self.confusion(y_hat.softmax(dim=-1), y)
        
        self.log("val_loss", val_loss, on_step=False, on_epoch=True, logger=True)
        self.log("val_accuracy", acc, on_step=False, on_epoch=True, logger=True)
        self.log("val_confusion_mat", matrix, on_step=False, on_epoch=True, logger=True)
        return val_loss
    
    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=8, num_workers=0)

    def val_dataloader(self):
        return DataLoader(self.test_data, batch_size=8, num_workers=0)


In [18]:
from sklearn.utils.class_weight import compute_class_weight

def class_weights(label_df):
    label_df = pd.read_csv(label_df)
    
    weights = compute_class_weight(
        class_weight='balanced', 
        classes=np.unique(label_df), 
        y=label_df.values.reshape(-1)
    )

    weights = torch.from_numpy(weights)
    return weights.float()

weights = class_weights('fixed_primary_labels_neighbors_50_components_50_clust_size_100.csv')

Now let's set up RayTune

In [19]:
class UploadCallback(pl.callbacks.Callback):
    def __init__(self, path, WIDTH, LAYERS) -> None:
        super().__init__()
        self.path = path 
        self.width = WIDTH
        self.layers = LAYERS

    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        trainer.save_checkpoint(f'checkpoints/checkpoint-{epoch}-width-{self.width}-layers-{self.layers}.ckpt')
        print(os.listdir('checkpoints'))
        print ('Uploading file...')

uploadcallback = UploadCallback('checkpoints', 10, 10)

Now let's set up RayTune

In [20]:
from ray.tune.integration.pytorch_lightning import TuneReportCallback
import ray.tune as tune 

raytunecallback = TuneReportCallback(
    {
        "loss": "val_loss", 
        "mean_accuracy": "val_accuracy"
    },
    on="validation_end"
)

In [23]:
def train_with_tune(config, max_epochs):
    model = GeneClassifier(t.num_features(), t.num_labels(), weights, config)
    
    trainer = pl.Trainer(
        max_epochs=max_epochs, 
        callbacks=[
            raytunecallback,
            uploadcallback,
        ]
    )
    
    trainer.fit(model)
    
def model_search(num_samples=10, max_epochs=10):
    config = {
        "lr" : tune.loguniform(1e-4, 1e-1),
        "momentum" : tune.loguniform(0.1, 0.8),
        "weight_decay" : tune.uniform(1e-4, 1e-1)
    }

    scheduler = ASHAScheduler(
        max_t=max_epochs,
        grace_period=1,
        reduction_factor=2)

    reporter = CLIReporter(
        parameter_columns = ["lr", "momentum", "weight_decay"],
        metric_columns=["loss", "mean_accuracy", "training_iteration"])

    train_fn_with_parameters = tune.with_parameters(train_with_tune, max_epochs=max_epochs)
                                            
    resources_per_trial = {"cpu": 1}

    analysis = tune.run(train_fn_with_parameters,
        resources_per_trial=resources_per_trial,
        metric="loss",
        mode="min",
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        progress_reporter=reporter,
        name="model_search"
    )

    print("Best hyperparameters found were: ", analysis.best_config)

model_search()

== Status ==
Current time: 2021-12-15 15:56:45 (running for 00:00:00.16)
Memory usage on this node: 14.7/16.0 GiB: ***LOW MEMORY*** less than 10% of the memory on this node is available for use. This can cause unexpected crashes. Consider reducing the memory used by your application or reducing the Ray object store size by setting `object_store_memory` when calling `ray.init`.
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 0/10 CPUs, 0/0 GPUs, 0.0/8.53 GiB heap, 0.0/4.26 GiB objects
Result logdir: /Users/julian/ray_results/model_search
Number of trials: 10/10 (10 PENDING)
+-----------------------------+----------+-------+-------------+------------+----------------+
| Trial name                  | status   | loc   |          lr |   momentum |   weight_decay |
|-----------------------------+----------+-------+-------------+------------+----------------|
| train_with_tune_a8cac_00000 | PENDING  | 

[2m[36m(ImplicitFunc pid=41022)[0m GPU available: False, used: False
[2m[36m(ImplicitFunc pid=41022)[0m TPU available: False, using: 0 TPU cores
[2m[36m(ImplicitFunc pid=41022)[0m IPU available: False, using: 0 IPUs
[2m[36m(ImplicitFunc pid=41022)[0m   rank_zero_deprecation(
[2m[36m(ImplicitFunc pid=41022)[0m 
[2m[36m(ImplicitFunc pid=41022)[0m   | Name              | Type            | Params
[2m[36m(ImplicitFunc pid=41022)[0m ------------------------------------------------------
[2m[36m(ImplicitFunc pid=41022)[0m 0 | flatten           | Flatten         | 0     
[2m[36m(ImplicitFunc pid=41022)[0m 1 | linear_relu_stack | Sequential      | 10.3 K
[2m[36m(ImplicitFunc pid=41022)[0m 2 | accuracy          | Accuracy        | 0     
[2m[36m(ImplicitFunc pid=41022)[0m 3 | confusion         | ConfusionMatrix | 0     
[2m[36m(ImplicitFunc pid=41022)[0m ------------------------------------------------------
[2m[36m(ImplicitFunc pid=41022)[0m 10.3 K    Trai

[2m[36m(ImplicitFunc pid=41017)[0m   rank_zero_warn(
[2m[36m(ImplicitFunc pid=41017)[0m 2021-12-15 15:56:48,815	ERROR function_runner.py:268 -- Runner Thread raised error.
[2m[36m(ImplicitFunc pid=41017)[0m Traceback (most recent call last):
[2m[36m(ImplicitFunc pid=41017)[0m   File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/tune/function_runner.py", line 262, in run
[2m[36m(ImplicitFunc pid=41017)[0m     self._entrypoint()
[2m[36m(ImplicitFunc pid=41017)[0m   File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/tune/function_runner.py", line 330, in entrypoint
[2m[36m(ImplicitFunc pid=41017)[0m     return self._trainable_func(self.config, self._status_reporter,
[2m[36m(ImplicitFunc pid=41017)[0m   File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/util/tracing/tracing_helper.py", line 451, in _resume_span
[2m[36m(ImplicitFunc pid=41017)[0m     return met

[2m[36m(ImplicitFunc pid=41022)[0m   File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 139, in on_run_end
[2m[36m(ImplicitFunc pid=41022)[0m     self._on_evaluation_end()
[2m[36m(ImplicitFunc pid=41022)[0m   File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 201, in _on_evaluation_end
[2m[36m(ImplicitFunc pid=41022)[0m     self.trainer.call_hook("on_validation_end", *args, **kwargs)
[2m[36m(ImplicitFunc pid=41022)[0m   File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1490, in call_hook
[2m[36m(ImplicitFunc pid=41022)[0m     callback_fx(*args, **kwargs)
[2m[36m(ImplicitFunc pid=41022)[0m   File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/pytorch_lightning/trainer/callback_hook

2021-12-15 15:56:48,994	ERROR trial_runner.py:958 -- Trial train_with_tune_a8cac_00007: Error processing event.
Traceback (most recent call last):
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/tune/trial_runner.py", line 924, in _process_trial
    results = self.trial_executor.fetch_result(trial)
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/tune/ray_trial_executor.py", line 787, in fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/worker.py", line 1713, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(TuneError): [36mray::ImplicitFunc.train_buffered()[39m (pid=41015, ip=127.0.0.1, repr=<ray.

2021-12-15 15:56:49,006	ERROR trial_runner.py:958 -- Trial train_with_tune_a8cac_00006: Error processing event.
Traceback (most recent call last):
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/tune/trial_runner.py", line 924, in _process_trial
    results = self.trial_executor.fetch_result(trial)
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/tune/ray_trial_executor.py", line 787, in fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/worker.py", line 1713, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(TuneError): [36mray::ImplicitFunc.train_buffered()[39m (pid=41016, ip=127.0.0.1, repr=<ray.

2021-12-15 15:56:49,027	ERROR trial_runner.py:958 -- Trial train_with_tune_a8cac_00004: Error processing event.
Traceback (most recent call last):
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/tune/trial_runner.py", line 924, in _process_trial
    results = self.trial_executor.fetch_result(trial)
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/tune/ray_trial_executor.py", line 787, in fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/worker.py", line 1713, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(TuneError): [36mray::ImplicitFunc.train_buffered()[39m (pid=41021, ip=127.0.0.1, repr=<ray.

[2m[36m(ImplicitFunc pid=41018)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]
[2m[36m(ImplicitFunc pid=41016)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]
[2m[36m(ImplicitFunc pid=41022)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]
[2m[36m(ImplicitFunc pid=41015)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]Validation sanity check:  50%|█████     | 1/2 [00:00<00:00,  7.96it/s]
[2m[36m(ImplicitFunc pid=41017)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]Validation sanity check:  50%|█████     | 1/2 [00:00<00:00,  7.46it/s]
[2m[36m(ImplicitFunc pid=41014)[0m Validation sanity check: 0it [00:00, ?it/s]Validation sanity check:   0%|          | 0/2

2021-12-15 15:56:49,043	ERROR trial_runner.py:958 -- Trial train_with_tune_a8cac_00002: Error processing event.
Traceback (most recent call last):
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/tune/trial_runner.py", line 924, in _process_trial
    results = self.trial_executor.fetch_result(trial)
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/tune/ray_trial_executor.py", line 787, in fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "/Users/julian/miniconda3/envs/base-data-science/lib/python3.9/site-packages/ray/worker.py", line 1713, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(TuneError): [36mray::ImplicitFunc.train_buffered()[39m (pid=41020, ip=127.0.0.1, repr=<ray.



Result for train_with_tune_a8cac_00002:
  date: 2021-12-15_15-56-48
  experiment_id: bbd8789f6724464d95bc856f64c3f828
  hostname: MacBook-Pro.local
  node_ip: 127.0.0.1
  pid: 41020
  timestamp: 1639612608
  trial_id: a8cac_00002
  
Result for train_with_tune_a8cac_00000:
  date: 2021-12-15_15-56-48
  experiment_id: 4ad9589641f84c6589b47615d88292b6
  hostname: MacBook-Pro.local
  node_ip: 127.0.0.1
  pid: 41019
  timestamp: 1639612608
  trial_id: a8cac_00000
  
== Status ==
Current time: 2021-12-15 15:56:49 (running for 00:00:03.43)
Memory usage on this node: 14.5/16.0 GiB: ***LOW MEMORY*** less than 10% of the memory on this node is available for use. This can cause unexpected crashes. Consider reducing the memory used by your application or reducing the Ray object store size by setting `object_store_memory` when calling `ray.init`.
Using AsyncHyperBand: num_stopped=0
Bracket: Iter 8.000: None | Iter 4.000: None | Iter 2.000: None | Iter 1.000: None
Resources requested: 0/10 CPUs, 0/0

TuneError: ('Trials did not complete', [train_with_tune_a8cac_00000, train_with_tune_a8cac_00001, train_with_tune_a8cac_00002, train_with_tune_a8cac_00003, train_with_tune_a8cac_00004, train_with_tune_a8cac_00005, train_with_tune_a8cac_00006, train_with_tune_a8cac_00007, train_with_tune_a8cac_00008, train_with_tune_a8cac_00009])

In [None]:
t[0]

In [None]:
t = torch.from_numpy(a)
l = torch.from_numpy(l)
t.softmax(dim=-1)

In [None]:
t.softmax(dim=-1)