## Prepare your dataset

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import tqdm

def load_mnist_data(root_path='./data', batch_size=4):
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5))]
    )

    trainset = torchvision.datasets.MNIST(root=root_path, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    testset = torchvision.datasets.MNIST(root=root_path, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    return trainloader, testloader

## Building your neural network

In [2]:
import numpy as np
from typing import Any, Callable, Tuple

##################################
# For matrices or arbitrary size #
##################################
class MyWeightTensor:
    def __init__(self, shape: Tuple or int, init_weight_fn: Callable = np.random.randn, init_weights: 'MyWeightTensor' or np.ndarray or int or float = None):
        assert isinstance(shape, tuple) or isinstance(shape, int) or isinstance(shape, float), f'Allowed shapes: tuple, int, float, got: {type(shape)}'
        self.shape = shape

        if init_weights is not None:
            if isinstance(init_weights, MyWeightTensor):
                self.values = init_weights.values
            else:
                if isinstance(shape, tuple):
                    assert isinstance(init_weights, np.ndarray)
                else:
                    assert isinstance(init_weights, int) or isinstance(init_weights, float)
                
                self.values = init_weights
        else:
            if isinstance(shape, int):
                self.shape = (self.shape,)
                self.values = init_weight_fn(shape)
            else:
                self.values = init_weight_fn(*shape)
    
    @property
    def T(self) -> 'MyWeightTensor':
        _T = self.values.T
        return MyWeightTensor(shape=_T.shape, init_weights=_T)
    
    def __add__(self, other) -> 'MyWeightTensor':
        if isinstance(other, MyWeightTensor):
            other = other.values
        else:
            assert isinstance(other, np.ndarray) or isinstance(other, int) or isinstance(other, float)
        
        return MyWeightTensor(shape=self.values.shape, init_weights=self.values + other)

    def __mul__(self, other) -> 'MyWeightTensor':
        if isinstance(other, MyWeightTensor):
            other = other.values
        else:
            assert isinstance(other, np.ndarray) or isinstance(other, int) or isinstance(other, float)
        
        _dot = np.dot(self.values, other)

        return MyWeightTensor(shape=_dot.shape, init_weights=_dot)


###############################
# For creating a linear layer #
###############################
class MyLinearLayer:
    def __init__(self, in_features: int, out_features: int, init_weight_fn: Callable = np.random.randn) -> None:
        self.in_features = in_features
        self.out_features = out_features

        self.weights = MyWeightTensor(shape=(out_features, in_features), init_weight_fn=init_weight_fn)
        self.bias = MyWeightTensor(shape=out_features, init_weight_fn=init_weight_fn)

        self.latest_input = None
        self.latest_output = None

    def __call__(self, tensor: np.ndarray or MyWeightTensor) -> MyWeightTensor:
        self.latest_input = tensor

        bs = -1
        if len(tensor.shape) == 2:
            # batch size included
            bs = tensor.shape[0]
            _w = self.weights * tensor.T
        else:
            _w = self.weights * tensor
        
        _bias = self.bias.values
        if bs != -1:
            _bias = np.tile(_bias, bs).reshape(bs, -1)
        
        self.latest_output = (_w + _bias.T).T

        return MyWeightTensor(shape=self.latest_output.shape, init_weights=self.latest_output)
    
    def derivative(self) -> float:
        assert self.latest_output is not None, 'Cannot calculate grad without a single forward pass.'
        # Linear activation derivation
        return np.ones(shape=self.latest_output.shape)

In [3]:
####################################
# Creating a custom neural network #
####################################

def xavier_normal_init(*shape) -> np.ndarray:
    assert len(shape) <= 2, 'Can only init max 2d tensors'
    fan_in = shape[0]
    if len(shape) == 1:
        fan_out = fan_in
    else:
        fan_out = shape[1]
    gain = 1.0

    std = gain * np.sqrt(2.0 / (fan_in + fan_out))
    return np.random.normal(loc=0.0, scale=std, size=shape)


class MyNeuralNetwork:
    def __init__(self) -> None:
        # init_weight_fn = lambda *shape: np.random.randn(*shape) / 10
        init_weight_fn = lambda *shape: xavier_normal_init(*shape)
        self.layers = [
            MyLinearLayer(in_features=784, out_features=32, init_weight_fn=init_weight_fn),
            MyLinearLayer(in_features=32, out_features=32, init_weight_fn=init_weight_fn),
            MyLinearLayer(in_features=32, out_features=10, init_weight_fn=init_weight_fn)
        ]
    
    def __call__(self, tensor: np.ndarray) -> Any:
        x = tensor
        for layer in self.layers:
            x = layer(x)
        
        return x

## Implement your loss function

In [4]:
import torch.nn.functional as F


def softmax(input: np.ndarray) -> np.ndarray:
    _softmax = np.asarray([np.exp(_in) /np.sum(np.exp(_in), axis=0) for _in in input])

    return _softmax


class CrossEntropyLoss:
    def __init__(self) -> None:
        pass

    def __call__(self, predictions: MyWeightTensor or np.ndarray, targets: MyWeightTensor or np.ndarray) -> np.ndarray:
        """
        Computes cross entropy between targets snd predictions.    
        Returns: List of cross entropy losses (batch-wise)
        """
        if isinstance(predictions, MyWeightTensor):
            predictions = predictions.values
        
        if isinstance(targets, MyWeightTensor):
            targets = targets.values

        assert predictions.shape[0] == targets.shape[0]
        if len(targets.shape) == 2:
            targets = targets.reshape(-1)
        predictions = torch.as_tensor(predictions)
        targets = torch.as_tensor(targets)

        loss = np.array([F.cross_entropy(pred, t).item() for pred, t in zip(predictions, targets)])

        return loss
    
    def derivative(self) -> Callable:
        # y_hat is the prediction
        # y is the target value
        def _derivative(y_hat: MyWeightTensor or np.ndarray, y: MyWeightTensor or np.ndarray) -> np.ndarray:
            if isinstance(y_hat, MyWeightTensor):
                y_hat = y_hat.values
            
            if isinstance(y, MyWeightTensor):
                y = y.values

            _y = np.zeros(shape=y_hat.shape)
            np.put_along_axis(_y, y, 1, axis=-1)

            y_hat = softmax(y_hat)

            return y_hat - _y
        
        return _derivative

## Implement the training loop

In [5]:
def train(model: MyNeuralNetwork, batch_size: int, learning_rate: float, loss_fn: Callable, epochs: int = 10):
    train_loader, _ = load_mnist_data(batch_size=batch_size)

    for epoch in range(epochs):
        running_loss = 0.0
        running_accuracy = []
        for imgs, targets in tqdm.tqdm(train_loader, desc=f'Training iteration {epoch + 1}'):

            # for custom model
            imgs = imgs.numpy()
            targets = targets.numpy()

            if len(targets.shape) == 1:
                targets = targets.reshape(-1, 1)

            imgs = imgs.reshape(-1, 28 * 28)

            imgs = MyWeightTensor(shape=imgs.shape, init_weights=imgs)

            outputs = model(imgs).values

            loss = loss_fn(outputs, targets)
            print(loss)
            avg_loss = np.mean(loss)

            # print statistics
            running_loss += avg_loss

            # Calculate the Accuracy (how many of all samples are correctly classified?)
            max_outputs = np.argmax(outputs, axis=1)
            accuracy = (max_outputs == targets.flatten()).mean()
            running_accuracy.append(accuracy)

            #########################
            # Start backpropagation #
            #########################

            # Your code for backpropagation!


            #######################
            # End backpropagation #
            #######################

        print(f'Epoch {epoch + 1} finished with loss: {running_loss / len(train_loader):.3f} and accuracy: {torch.tensor(running_accuracy).mean():.3f}')

In [6]:
#############################
# Execute the training loop #
#############################
model = MyNeuralNetwork()
batch_size = 4
learning_rate = 0.001
epochs = 10
loss_fn = CrossEntropyLoss()

train(
    model=model,
    batch_size=batch_size,
    learning_rate=learning_rate,
    epochs=epochs,
    loss_fn=loss_fn
)

Training iteration 1:   1%|          | 107/15000 [00:00<00:37, 395.90it/s]

[2.21251621 2.62722012 4.08997173 4.16832853]
[3.75657256 0.74680998 5.84860971 9.24597745]
[8.95500462 5.16231383 3.50316857 4.13994585]
[4.44727052 9.21176537 5.53796637 4.56394185]
[3.89086537 2.32854035 0.99426957 7.06007515]
[0.33984084 8.95206481 2.05298541 3.54978461]
[0.12355097 2.79234607 2.34648956 3.27133829]
[4.21024965 5.49897468 5.10306077 5.92708084]
[5.05114888 3.92546176 2.74594077 7.17290702]
[3.21881353 2.37229205 3.29993042 3.42745593]
[0.85105006 1.5115467  9.71950615 4.0915773 ]
[4.74296926 3.9385491  4.71441908 4.18360347]
[7.20895865 4.75326218 2.66396456 4.85465451]
[8.93353747 3.64126646 6.07202875 4.19425316]
[2.28270079 9.95008363 5.60091488 4.05527051]
[5.01494175 3.68537247 4.93773897 4.38240209]
[4.5369737  6.04379049 3.41607592 3.98014592]
[1.69473025 5.95532904 3.8600434  2.30631495]
[2.08673323 4.43486478 9.15400642 3.24576811]
[6.70341074 4.52607969 5.98445369 7.05138687]
[3.97631878 3.1543811  5.00954738 3.15892566]
[2.92299136 4.66982648 5.17468497 

Training iteration 1:   1%|▏         | 218/15000 [00:00<00:30, 488.96it/s]

[2.26910432 4.07815604 2.563597   3.90034096]
[3.26947245 5.37272111 2.76111586 3.75586252]
[1.55698005 4.89375169 4.45107289 7.41604967]
[3.79804076 3.64155269 2.65599313 5.32627492]
[4.09316159 3.05420396 4.67171884 5.83847909]
[3.88010811 4.8709957  1.71572268 8.61132363]
[4.38492691 5.80100374 7.53653508 2.32972129]
[4.33599947 3.20366044 2.61545847 3.84427427]
[8.25351881 3.5541327  2.51731571 7.25655753]
[1.9032581  6.05507677 3.95437675 3.50238837]
[3.74277014 3.59863881 4.86661413 8.62556604]
[6.07263428 1.58301331 5.93911111 1.43221284]
[8.76192872 4.62529767 4.14396295 9.02668804]
[8.21846204 6.38440598 6.4394113  8.77006912]
[7.1520097  8.68653875 3.27211747 8.77738814]
[2.14952056 7.21071614 2.680151   4.08013269]
[1.28510974 2.34792663 0.93202935 3.79232315]
[4.71251804 4.70144274 0.89447577 5.94415825]
[3.57484359 5.34461987 2.99631321 0.23208882]
[4.12492184 3.88534578 4.83823093 3.12140147]
[5.44143333 4.2934927  3.61165706 3.87591066]
[3.73574595 5.4345527  5.65807546 

Training iteration 1:   2%|▏         | 268/15000 [00:00<00:38, 385.16it/s]

[3.3139425  4.68558679 4.00066335 3.8495636 ]
[3.69957522 1.74666421 6.89617518 5.01241512]
[6.37528486 7.13962644 6.4812986  8.94888796]
[4.28392868 3.79202236 3.83964352 4.53459109]
[4.6027903  4.59129675 2.44761072 1.72888246]
[2.37751882 4.96738991 4.30219247 1.48612567]
[4.68836215 6.15466656 4.76796689 4.02859964]
[3.13745799 4.03328412 9.61788615 2.95750419]
[4.52208984 8.24303932 7.5732103  4.18947323]
[4.54552204 8.65330454 6.35747746 2.92932659]
[8.07286967 5.46412008 5.01884364 5.20800962]
[5.01025873 3.3131886  1.6391706  3.85320005]
[6.04427261 5.27712811 4.31605936 5.95474073]
[4.5123319  1.30101369 5.26664121 4.2988693 ]
[4.83984252 9.19609219 0.97059964 4.83961328]
[2.69170077 4.57046021 2.64472487 4.36090915]
[7.67208192 5.50348518 7.63170312 5.23726214]
[3.09619275 2.91009848 1.1759193  3.10324138]
[7.50999712 1.47845425 2.14015052 2.21854121]
[10.39404181  2.28930482  4.49721893  4.50251747]
[3.56332941 1.83551999 3.19470185 2.76322241]
[4.13050465 3.78101323 7.05326

Training iteration 1:   3%|▎         | 376/15000 [00:00<00:32, 456.14it/s]

[6.05393125 3.28190286 4.52471041 3.45498098]
[7.29200184 5.49795501 5.13908563 2.19999275]
[4.4634497  1.70749637 1.03958174 3.7636406 ]
[9.36265515 2.6508095  8.66258726 7.98340882]
[3.13268579 3.87874287 7.71152973 4.86574739]
[5.42721217 4.67484422 0.5647835  0.66886088]
[3.24503561 2.12682859 1.49526433 4.25914242]
[7.3755665  4.56662371 2.63054735 8.1321871 ]
[7.56296717 3.17526733 6.81743432 4.76215086]
[2.39561647 3.87285443 3.8680285  3.23430281]
[6.54670702 9.28430092 4.380093   0.45855429]
[6.50328876 2.40762353 2.90475574 6.75453146]
[2.15941508 4.81583467 9.27337988 6.1477233 ]
[0.6512545  3.82470376 5.17840771 5.81624135]
[8.7575608  3.82204877 2.58998473 5.18405316]
[1.36936903 4.13432653 4.38487781 3.64731989]
[4.71128022 1.17535454 3.21823003 3.30700487]
[8.02585738 4.70266871 6.90435282 4.41894329]
[7.9129341  3.90555816 8.25778059 4.39766697]
[2.78932391 0.54324064 4.65590855 1.61593991]
[3.74206296 7.55406382 3.93656034 9.0616965 ]
[ 6.68422736 10.44882573  3.237889

Training iteration 1:   3%|▎         | 486/15000 [00:01<00:29, 500.03it/s]

[9.47484849 5.687478   0.81648282 4.82611055]
[5.45717755 5.06429869 4.5296368  4.47051497]
[4.7652929  0.79955814 5.82880363 4.32720338]
[5.11689167 4.37883156 8.63446537 4.57556225]
[0.99982883 2.71919411 3.2978776  7.25239568]
[1.57404821 2.34396683 7.31029931 3.25927064]
[8.73736104 1.50279831 2.89383774 1.70441307]
[7.98902068 2.66957674 2.1363218  0.98320944]
[7.56941758 3.54756755 5.96195012 3.75628589]
[4.00562924 2.02752175 4.00287204 2.8317077 ]
[3.23331837 2.60833159 2.44302352 4.91885652]
[4.47580886 2.95946556 7.57153126 4.95548968]
[9.00530195 0.8273893  1.09568578 3.3190475 ]
[1.41634108 4.3184055  3.59924212 4.71763938]
[4.29506794 0.22023116 0.90555566 3.98858819]
[2.30021246 5.65744697 1.84989374 3.59823247]
[6.18671057 4.06203954 1.08094052 3.21341116]
[2.41915287 4.23558641 2.45756124 1.95291981]
[7.54475019 0.52244259 2.36663656 4.66654769]
[5.0678104  3.64462089 3.48005074 6.01321007]
[5.25110695 5.09164845 2.77904159 5.77737457]
[4.06806236 4.65133046 4.42985656 

Training iteration 1:   4%|▍         | 593/15000 [00:01<00:28, 513.93it/s]

[3.29650109 5.39476016 6.18638385 4.6611573 ]
[5.00449767 7.34887595 8.33064216 6.52897424]
[4.17454247 6.06154717 4.31878669 0.8708592 ]
[4.86854156 2.99272874 6.67452606 0.23614306]
[6.39416957 7.34020803 3.27374049 7.68111519]
[3.3852873  3.48277488 5.73569545 7.61328391]
[3.68132417 2.39737132 0.96137484 3.61691397]
[1.87185501 3.6440367  2.05265407 1.87544676]
[7.79495044 4.76872938 3.72733902 8.65198761]
[3.85383333 1.86836791 7.0463123  0.20574337]
[4.67419183 3.87576903 5.42494081 4.97049049]
[1.2743193  5.67312504 4.29340017 3.88328929]
[4.67100967 1.91253964 8.0086866  5.38545967]
[2.92640339 6.89349495 4.50728151 1.07735497]
[7.73119324 4.11773726 4.27361433 4.29124653]
[5.61426879 8.67974131 4.54850213 7.48391821]
[4.21164128 2.49529365 4.83336982 2.73507532]
[8.38390347 1.15629317 1.59985494 6.617323  ]
[9.51579733 4.36267876 4.91009889 1.73429707]
[4.33394807 8.37128623 2.94483032 2.61294138]
[3.45987418 1.06684801 5.26684413 2.28418579]
[5.24227836 4.26584608 5.57761075 

Training iteration 1:   5%|▍         | 703/15000 [00:01<00:26, 530.37it/s]

[3.24836562 3.63126188 0.82983428 1.75687728]
[3.72742058 6.96575969 4.30939581 7.65996715]
[4.47118736 3.58613289 1.21533072 3.97504491]
[4.2289262  0.75359523 9.81080584 4.84145473]
[0.97917475 1.38155516 5.10372816 2.46968531]
[3.95103611 8.43636647 7.04785807 5.64959621]
[4.75583231 6.71129161 3.18445679 8.00897155]
[2.07504094 0.57178166 3.87412496 0.85629383]
[4.11649414 3.43731649 3.83680885 2.13889982]
[3.45625151 1.69724913 2.5357567  4.03959618]
[2.4754308  4.23306575 4.96532891 2.78131807]
[7.5337309  6.48735558 4.81039575 4.94454927]
[4.14994481 6.81932736 0.72146045 7.64633097]
[3.49065595 3.6830904  2.73010791 3.3425475 ]
[2.39616453 2.77357251 5.10950874 2.7693398 ]
[4.15734606 4.33851813 3.7303762  3.48843316]
[5.20098022 6.12958408 3.36024718 0.44659037]
[7.34021108 1.93683329 3.64733097 6.74954041]
[1.30929735 3.20057567 8.84700196 1.80342487]
[8.5672673  2.30958535 8.62100377 3.22329287]
[4.78966139 5.014947   4.52061403 4.21662224]
[2.63398745 5.49067255 5.22515389 

Training iteration 1:   5%|▌         | 817/15000 [00:01<00:26, 544.35it/s]

[3.9385535  5.19226668 0.61027421 4.17381064]
[3.29925873 3.27136318 4.85917197 3.36398971]
[3.18060143 2.87546156 2.92724524 7.03203326]
[3.92062281 2.98567614 3.61243167 3.67047123]
[5.20499044 3.61019506 5.98938967 1.05825692]
[4.51192396 4.38684533 7.21234103 3.21375718]
[5.7023222  2.08395514 1.7311274  4.57244106]
[4.00291334 2.52513635 1.5997498  3.43226349]
[5.49915794 3.43264563 7.12013058 4.81743044]
[5.78172984 2.26655723 2.99724958 3.28995935]
[5.30563014 3.75452293 3.49009999 4.2060421 ]
[ 5.83600188 10.08720016  4.86155085  5.11959921]
[3.69206111 5.78197504 6.80593411 4.00052641]
[5.41187541 0.2142362  6.47735217 1.85564611]
[2.39766532 8.11521336 4.16435385 2.53499   ]
[3.91314071 6.1169789  8.84913131 8.41980248]
[4.62846149 6.8055009  7.09579725 3.82901734]
[5.52819263 3.75754675 4.0587133  1.67562907]
[1.07071305 4.15998742 3.00581465 5.19812552]
[7.34586135 0.12106106 2.52623125 8.02811781]
[7.99910619 6.20119589 2.53921744 5.75092482]
[0.46493848 4.08819263 1.03277

Training iteration 1:   6%|▌         | 931/15000 [00:01<00:25, 547.73it/s]

[3.1934279  6.04631262 0.73115298 7.82479558]
[5.02044644 8.57372704 5.51429084 3.16485129]
[3.77622561 4.81123753 2.43035183 5.69330315]
[4.55392738 1.00832269 4.46571668 1.91889836]
[9.2838475  3.97052636 5.73008542 3.63870305]
[3.05408014 4.44298381 3.06264631 7.83205088]
[2.9232287  3.83961054 2.39879848 4.72978359]
[2.99161513 6.91127448 4.96971863 4.03034703]
[3.30877053 2.28786467 4.05924243 1.45123257]
[2.02037463 5.76110686 7.16382198 4.13412149]
[3.83000869 4.20280115 7.13080177 4.26751994]
[4.11450407 7.65962868 5.3345135  4.06782458]
[7.11231839 3.93162129 0.92839695 3.23636134]
[7.61844683 8.31849807 1.69379448 8.16102513]
[2.87208102 4.20047345 2.8610345  3.73367596]
[2.59244066 7.69132228 1.33232143 0.52055053]
[2.96424068 7.24341945 7.41487564 3.12776891]
[3.57552686 5.20479736 4.40092251 5.13510911]
[7.79821802 4.94437163 8.92064997 3.85017329]
[6.66540824 5.22662565 9.34903579 3.27399018]
[7.77377975 8.54324693 2.85036268 3.82071891]
[4.43362642 9.17144658 3.72999076 

Training iteration 1:   7%|▋         | 1043/15000 [00:02<00:25, 549.03it/s]

[1.38737917 0.17940722 5.50986774 6.38387959]
[3.72074847 7.24809063 3.57523194 3.1550629 ]
[4.05052669 1.55645896 2.53251786 4.45235636]
[5.33763121 0.70550807 3.98045864 2.66625647]
[4.93575146 3.98700105 2.0002759  7.39033321]
[2.37416317 0.40138782 1.82373898 7.53666474]
[5.48785644 2.546984   8.23610493 3.50711279]
[5.51943144 4.48156081 7.33724627 3.41386906]
[2.32677405 2.42285892 6.53015631 4.17521191]
[3.93642797 3.4666104  2.1112223  2.91768276]
[4.94425254 4.5083805  5.39098591 3.78804031]
[9.10595149 4.79384804 5.77686892 4.30407517]
[4.35524059 5.29688936 1.8070308  6.83147124]
[1.1604636  9.40603922 1.02881117 1.64269152]
[8.89200562 0.89037712 6.51274904 0.31981852]
[5.44767709 4.55872633 3.67113831 2.14212084]
[4.80548661 3.80948762 2.12616257 4.81251645]
[3.73669774 3.06491451 5.49011547 3.79643557]
[5.69016029 5.04511508 5.48539442 3.50129832]
[2.14770688 5.87134267 2.37134863 2.78146962]
[7.84465416 2.45840968 9.61735278 7.59870803]
[2.23208146 4.26689254 1.43194071 

Training iteration 1:   8%|▊         | 1161/15000 [00:02<00:24, 562.36it/s]

[2.4329474  1.61767134 7.56747933 7.24466792]
[2.93751707 4.50900861 1.96102571 4.18795175]
[1.54108961 2.26617036 1.45379075 1.8906833 ]
[3.30737226 1.1964928  2.82636998 3.57354061]
[0.98690491 4.85982341 5.02562793 4.13980028]
[3.47440496 2.04724639 4.29794771 1.32008869]
[0.16159523 3.99591611 2.83313808 3.1464442 ]
[3.77589696 7.01803901 5.87289992 4.14753313]
[6.07779475 8.74258587 4.83384618 5.18159241]
[3.38300363 3.48520008 3.38144769 4.62734452]
[4.80429192 1.35124838 6.94109593 4.34550576]
[7.88086906 0.11599687 6.32640947 0.29077753]
[3.22012665 5.43644134 1.72549709 2.00398482]
[2.32698753 4.16736498 2.7716084  4.0415995 ]
[3.75749398 7.77976151 4.57869209 3.10278456]
[1.15071035 3.93272447 1.79978316 5.83108345]
[0.60140452 9.18339099 3.51963047 1.67638206]
[1.66263321 4.28306848 5.23763526 9.1330234 ]
[5.79334704 3.84776541 0.95304516 4.09296626]
[5.20623633 2.18646586 4.76230495 4.19104202]
[5.04210805 4.62734963 0.82943197 2.32683295]
[2.06852212 4.95357941 3.58036292 

Training iteration 1:   9%|▊         | 1277/15000 [00:02<00:24, 564.38it/s]

[3.56510844 5.27964905 5.62877022 4.03049091]
[3.80763383 3.78117491 2.00017678 2.7408976 ]
[4.55265659 1.78339683 6.33052609 2.8693337 ]
[3.26373646 3.51251499 8.38914297 3.14852476]
[8.2663416  5.72312285 8.17485498 0.26390557]
[4.50542138 0.56511513 3.11587846 5.4266737 ]
[6.15355772 0.95425973 2.71648739 3.91366806]
[6.50050769 1.32516833 2.09475797 8.36288741]
[7.72265421 4.57017419 3.82442347 2.68365026]
[3.29799274 3.93837938 1.85648679 5.24490178]
[4.9402026  3.09526566 1.04452074 4.19220458]
[4.25122346 6.41150913 3.20283083 4.39417998]
[3.6644873  3.62740071 8.79395955 0.97903981]
[4.09154027 3.97733075 4.1820729  1.61955983]
[3.90046924 6.53261744 3.86465953 3.33282306]
[4.10605823 4.17803398 5.08804936 3.73267027]
[0.79285538 2.5669779  2.51699909 5.38487785]
[2.41258483 6.43878534 3.66906837 4.6472426 ]
[3.31416253 3.17353322 5.54499005 6.3985122 ]
[3.93479406 2.85633027 6.50933914 3.40874336]
[2.41857946 3.25850424 0.58390997 3.92269809]
[3.27677097 2.24454654 3.9978886  

Training iteration 1:   9%|▉         | 1396/15000 [00:02<00:23, 570.05it/s]

[0.70933564 3.33169368 7.15130616 3.99668977]
[4.35452314 3.55867111 3.36203013 2.29784443]
[4.37764246 4.47701816 3.67136581 3.7592878 ]
[1.30611147 6.70503752 5.45858659 4.34581809]
[3.84849309 5.38943223 0.39701125 1.53597905]
[0.73338343 4.77325581 1.24367302 4.72373853]
[4.07387522 6.21206101 2.17204669 9.22281219]
[2.04271136 7.49293679 7.8305863  2.3905288 ]
[1.08895998 2.82849488 3.46088408 2.60723229]
[4.66044145 7.85518243 4.57265957 0.7126975 ]
[8.68802678 5.21215785 5.33189297 1.03353615]
[3.63772677 3.97560194 5.54897027 4.39281287]
[2.88541342 8.72636997 5.76642498 1.016535  ]
[3.40614316 3.95518546 6.56629527 2.97736664]
[3.63341504 2.11904634 2.02259888 4.97949385]
[3.52397125 4.50103502 5.03024145 0.28499072]
[6.97998058 1.24191444 2.99889784 8.09937627]
[3.66495318 4.41839033 3.21747497 4.30881195]
[2.91372078 1.60910614 5.2659122  4.66847458]
[4.26001235 4.59203507 6.6833914  2.71628841]
[3.78649804 9.3214437  4.61026495 0.4014023 ]
[4.19835012 4.56890806 5.86053058 

Training iteration 1:  10%|▉         | 1454/15000 [00:02<00:23, 570.47it/s]

[5.12636232 4.08673241 3.01037346 3.15489144]
[3.84372688 2.8709469  3.47701154 4.7798888 ]
[2.76192876 7.64434418 6.43373509 6.43976539]
[1.17654313 1.99618873 4.4339907  6.00880222]
[ 4.08946501  6.59823955 10.04730783  4.57690172]
[7.75024315 5.17298588 8.51701096 5.63533857]
[3.95447291 2.93164683 3.18773287 4.11181758]
[2.90855683 3.11255563 4.44841554 3.49336262]
[4.81310215 1.80849773 7.12386089 1.34688963]
[7.9042068  4.43305154 6.38327347 5.09182498]
[4.7213599  4.45023661 2.49687094 4.83405998]
[1.93945948 4.83816757 3.70987273 2.01769169]
[5.45878168 2.28931475 4.52343262 2.01888542]
[2.6968496  3.07656297 4.05243771 1.37784633]
[3.75036374 4.278986   2.96557254 0.89163062]
[7.98754378 2.69207451 2.58250298 3.44990211]
[0.55315949 9.22400086 3.39987529 0.64878503]
[4.26756975 4.46465628 4.61069477 3.06642958]
[8.38807272 5.50860147 4.15647236 4.63165663]
[2.51601678 7.83891944 1.46127225 4.51841767]
[3.52335751 4.81822287 7.06381954 5.56157364]
[6.75469978 2.44867741 5.06771

Training iteration 1:  10%|█         | 1568/15000 [00:03<00:24, 540.26it/s]

[2.05100072 2.75309306 1.11628055 4.25520089]
[3.86718402 3.12528835 3.39493858 4.52595907]
[4.30840978 6.14607173 3.40618286 3.4614646 ]
[0.26880472 5.51814601 4.41119473 2.60223825]
[4.09105157 4.58092987 3.17745895 4.37076443]
[3.85284505 2.78969028 4.51817919 2.19305231]
[6.72816149 5.55712702 4.58168079 7.61825714]
[6.20817796 7.60058906 0.53280139 3.99298145]
[7.11509055 3.80777824 4.50592678 6.32158244]
[3.95629962 5.61982587 3.48630036 7.70191084]
[0.6450108  3.92658331 2.26841273 5.57897069]
[1.59168711 2.03917724 7.32463005 3.94213373]
[3.1992904  6.16243487 1.59130833 5.6047745 ]
[4.98084854 4.23896712 3.46777358 2.19118525]
[5.11839527 0.9105001  5.46212764 0.30969865]
[5.15849689 8.26814003 5.28852022 4.32852163]
[6.90742249 0.48828676 4.65797324 7.91946599]
[10.19213359  7.62842129  0.50315333  2.25123495]
[4.24391722 2.73037495 4.00192774 6.55630495]
[4.13927321 0.73112085 4.39798715 4.07223041]
[4.8920368  3.46587947 9.03652027 4.96339646]
[5.18593694 4.6821675  4.84933

Training iteration 1:  11%|█         | 1679/15000 [00:03<00:24, 545.44it/s]

[5.92825675 6.39254854 8.09535934 3.32881102]
[8.03127974 2.04833572 9.70065854 1.95700464]
[4.20594809 5.35929316 5.38353902 4.71398381]
[6.64440485 2.92597747 0.27575175 1.14282941]
[3.26012874 4.16336718 4.70084218 5.01542279]
[9.42490285 0.9241601  4.76240687 7.98984642]
[6.15105191 3.64760962 2.14480185 1.57949472]
[4.08532774 2.10462808 2.82596547 4.5851207 ]
[4.79806002 0.14842298 7.91845694 1.91257204]
[5.24038168 4.75506765 2.65308853 2.39117202]
[3.27477841 3.83697073 3.11431069 1.34932739]
[4.18121529 0.82702104 3.58949399 1.33105907]
[0.41271391 2.65493642 2.24981725 4.05397809]
[4.86193844 4.0015303  0.21537085 3.57671747]
[4.22195997 4.67007746 5.66536979 3.88287345]
[4.46913642 8.35617146 5.73946007 5.25670169]
[4.19965703 2.95960758 5.51313248 3.78263142]
[5.59915353 3.17408191 3.26580643 6.46495497]
[4.15463478 4.34209871 7.82753577 2.25987487]
[3.16779706 2.81314551 5.04758857 1.74289127]
[4.18417997 8.82230068 2.89467926 5.37275092]
[5.5435664  1.29327293 5.33640969 

Training iteration 1:  12%|█▏        | 1798/15000 [00:03<00:23, 565.89it/s]

[2.10230686 4.17832607 4.566599   5.16619896]
[3.52140473 8.26394525 0.4088267  5.13972462]
[3.9901964  2.74285047 2.2254836  2.99731982]
[2.20521226 1.3327018  2.22000033 4.5574543 ]
[3.82330824 2.60497871 4.66793127 2.58649082]
[5.39867107 2.29642344 1.05293949 3.16804815]
[6.04548878 8.3710477  3.14575257 4.0871408 ]
[4.80358393 5.16951541 4.06433086 6.63034077]
[6.14070843 5.48159438 4.85905513 4.75797221]
[4.07193325 6.80062319 3.21937388 4.233742  ]
[3.44491841 3.01353061 2.92614909 4.08534401]
[2.7907189  0.99302612 3.42565365 7.58510864]
[7.22112715 3.51581195 3.80691845 3.9379878 ]
[0.57774953 5.29855234 3.84239371 3.93813504]
[2.36281344 1.05257188 3.32485295 4.43834701]
[4.41364415 4.95373685 4.41407361 1.12191689]
[2.04212082 6.96158383 2.29888847 3.32767875]
[4.07168932 0.9340173  4.94958198 1.49647824]
[3.46063697 4.30942646 1.99116378 2.28770238]
[1.30166785 5.35361097 2.64248001 4.37900898]
[2.7445006  2.89121969 1.17244488 3.7542872 ]
[3.74258251 4.48523303 2.23347809 

Training iteration 1:  13%|█▎        | 1915/15000 [00:03<00:23, 560.09it/s]

[6.86097975 3.91339027 2.62062787 3.37914717]
[4.50308406 4.26924618 5.50188259 1.04613752]
[7.12857172 1.96834546 5.24593012 4.68955783]
[4.82406856 4.5483493  3.72622446 0.46470657]
[3.83755451 5.49937905 4.10058321 4.97810965]
[5.12297468 4.41613885 3.64424259 7.78553738]
[4.59352593 2.08903692 5.70014198 3.85670806]
[ 2.90318664  4.08463645 10.81742198  5.50157136]
[4.14538854 3.89051743 8.67319237 4.55481174]
[0.58284924 5.02533416 1.1692976  5.83626842]
[3.97469739 5.99482999 5.95029911 2.95295783]
[7.56863294 4.5538467  6.84729045 7.93499308]
[4.46031083 8.43675746 5.11405435 4.48218717]
[3.07417293 2.96512632 1.97229885 1.74112276]
[8.48929125 1.84756974 4.28089552 9.66648838]
[5.78955353 4.58106376 3.74062477 8.27240684]
[3.41916117 5.7564116  1.77408339 0.95578405]
[8.77605974 4.52821367 0.43522888 4.38829189]
[4.36905888 2.14504156 3.21780782 1.68551355]
[2.59873268 1.6349955  9.11061808 5.47230408]
[1.41548112 4.34536703 1.35227864 3.26709015]
[7.61810568 3.25731178 5.28085

Training iteration 1:  14%|█▎        | 2030/15000 [00:03<00:23, 558.04it/s]

[1.18145671 8.44400778 4.22892993 5.88543354]
[4.32745578 3.37838304 0.47300831 3.9625816 ]
[4.96325223 2.97464255 3.03472353 6.47046763]
[4.07294064 7.13946052 7.22313955 8.30887597]
[4.06003731 4.89684068 7.0656927  1.53268632]
[5.5807977  7.23431676 4.86973944 0.31144577]
[2.67225303 6.12896014 2.57282131 5.36177825]
[2.64445773 4.50048189 2.1607105  0.50646755]
[4.90567527 4.5378055  1.53257646 5.09604274]
[4.3222666  2.46068281 3.10091505 4.30631581]
[2.63351117 5.97737212 8.01550359 8.93124392]
[5.03669276 1.12978016 5.96124702 4.85644271]
[2.67756556 3.86402509 8.62767019 5.83006016]
[1.4491192  3.97939262 9.68128254 7.18027719]
[8.22304018 6.19894395 7.23131005 2.56284689]
[2.0584445  1.03022619 3.372358   3.92238448]
[4.94000594 2.20244988 2.40182923 2.94583187]
[4.18227769 1.60355914 1.74646348 3.5417724 ]
[5.29372845 4.28432344 2.56470106 8.06700942]
[5.21538096 4.71811372 7.23987799 4.3412239 ]
[3.59894244 3.67791096 1.79554948 0.91325464]
[3.22149363 4.70674657 3.15081475 

Training iteration 1:  14%|█▍        | 2144/15000 [00:04<00:23, 556.59it/s]

[2.77488858 5.8929041  4.61347619 5.34011555]
[3.79091074 2.85598535 4.41556386 4.59667262]
[3.65592658 1.2229065  8.08717007 2.42722995]
[7.41784858 8.69556984 0.44194533 7.71811984]
[3.50466095 2.84725958 3.24106747 3.16157608]
[2.75611717 1.7738527  0.45453612 3.66361645]
[3.03547204 3.58077767 7.93565067 5.8564675 ]
[5.00762771 3.15315918 4.01933816 3.52975912]
[1.81035188 3.87155804 3.20066989 0.6142667 ]
[4.60019919 3.01664195 8.60470845 6.18749987]
[6.52815197 4.61766322 8.09626669 7.37290611]
[3.12639909 6.84814433 0.58764474 5.04318226]
[5.56807675 8.10195397 5.12681572 1.70597688]
[5.94004075 3.23106202 2.90202064 3.57382784]
[3.80554814 4.59162574 6.4897377  4.28490892]
[1.7783471  4.21812213 3.48241524 5.58033018]
[6.98919525 3.10006978 0.06389094 4.2459455 ]
[4.12675805 7.38210616 3.39451464 6.3286554 ]
[3.20752363 3.8562918  8.43117261 5.61527489]
[8.74412004 8.77072911 2.47881072 3.64447468]
[2.15848698 8.10164248 3.58387172 9.30587182]
[9.04226638 4.67054202 4.54585566 

Training iteration 1:  15%|█▌        | 2260/15000 [00:04<00:22, 563.25it/s]

[9.6980901  6.11687519 5.46722652 4.50859043]
[2.92723423 3.48055549 3.45626881 3.54353585]
[2.8933511  3.69176142 2.17989299 9.0453128 ]
[2.6096289  3.64735391 8.45432889 2.69457389]
[1.64923324 3.29416988 1.95513262 3.69925387]
[3.88565392 2.56364559 3.20065412 2.4996425 ]
[2.8286932  3.88358283 6.61834044 1.61469036]
[5.19964324 5.69882301 0.55174127 1.72344787]
[3.78519926 0.89203239 8.57123824 1.32303129]
[1.99643827 0.86013305 5.48416389 1.60935879]
[1.66896727 7.14634878 3.86497544 1.92924826]
[4.06596913 0.82451971 8.4278735  4.53827777]
[11.12161102  4.36127855  8.56380493  2.62945735]
[7.86490833 6.15001545 1.04623519 1.79948948]
[7.23030009 4.76815812 1.29914253 5.49060964]
[2.52665644 4.76757543 4.90346866 5.41355435]
[1.49280012 3.60791068 3.29302015 5.09768902]
[5.05520125 8.48810165 8.95878363 4.49120326]
[7.50997178 5.95450173 8.78448258 0.95701577]
[1.28510562 4.63518054 3.98615206 4.01832579]
[5.40379672 4.72460475 3.52687725 7.85139144]
[2.75200095 2.78399406 4.60567

Training iteration 1:  16%|█▌        | 2375/15000 [00:04<00:22, 556.27it/s]

[1.8530379  0.74416412 4.34117145 5.61020336]
[3.16981988 0.22969982 2.90646639 9.72563873]
[0.72516211 2.891008   3.87323003 3.09530453]
[1.47010296 3.41457834 6.37275241 6.09655785]
[1.56927634 5.20818766 4.68550976 7.83526047]
[1.49309502 4.45955075 3.68665302 4.77936487]
[2.2338974  6.71683185 5.71576461 3.41294447]
[4.80104672 3.45600421 3.76877085 8.25893261]
[4.62643634 5.30335799 3.50956934 3.692296  ]
[8.46226861 9.65682335 4.76642186 9.7483547 ]
[4.49472309 3.33987013 7.62498374 5.47009111]
[4.82803696 8.05515866 3.71915773 3.28651203]
[2.39471041 3.60281445 4.59556031 4.31388346]
[8.67193897 6.17062694 0.70184399 4.11662131]
[6.39517247 8.77130912 4.70026795 5.22199807]
[3.06280999 1.47873211 9.23009453 3.47961714]
[2.06594366 4.18366242 4.42288831 1.3691451 ]
[0.24153328 5.54002382 0.20823705 1.46098162]
[9.22933825 3.64889129 4.60335716 3.27821796]
[4.32984033 3.3506792  4.20398643 2.66412271]
[3.99003433 2.321478   7.58315123 5.12436711]
[2.9780665  2.07625184 7.41725407 

Training iteration 1:  17%|█▋        | 2487/15000 [00:04<00:22, 552.33it/s]

[5.19329429 3.60513905 4.70088199 2.62440301]
[7.01800997 4.50501599 5.74287338 3.38908242]
[4.66542113 7.66691539 2.30787519 2.05559486]
[1.88602686 5.95420847 3.02524654 4.18187238]
[4.72365153 5.23153473 2.53537501 7.30059078]
[3.93724323 4.65119129 1.43443628 8.67765523]
[0.88000756 5.6648378  3.22192755 8.43848196]
[2.98118862 2.05288619 3.55454493 3.27567218]
[3.10669645 7.28893791 2.57340778 0.7010645 ]
[8.01954084 3.78167546 5.06916419 7.83429087]
[4.66278919 3.97162126 3.16026257 7.26005616]
[3.69569252 4.84136655 5.87251794 5.25191061]
[3.648812   3.4443452  4.95727398 5.31859775]
[3.34144865 1.77919088 8.62904112 8.10070104]
[1.33119756 5.24890036 3.66962161 6.0584593 ]
[3.16137216 2.59130872 3.35312714 1.60677625]
[3.91314117 4.30146733 1.31065379 9.14444244]
[0.77619567 1.39417425 2.10344803 5.75777979]
[2.9093227  4.30533428 4.33616658 2.04699151]
[2.42021796 2.32802332 8.38830771 4.42425454]
[1.7136748  2.33729189 6.47006665 1.76664624]
[3.82050585 8.18192314 4.68423103 

Training iteration 1:  17%|█▋        | 2603/15000 [00:04<00:22, 563.14it/s]

[0.11761024 2.75890909 4.14363704 5.03240573]
[7.70608721 3.62511715 5.31105139 3.02922897]
[9.84199806 3.5066139  9.46252239 4.33639013]
[2.3163824  7.44907636 2.19750843 3.37201189]
[2.48412049 2.79270507 5.07856522 9.28665896]
[1.87144237 0.50636993 1.36081825 2.98379483]
[7.48100516 4.08072955 5.25689709 1.81914531]
[0.35850554 7.46173452 8.53551406 4.40977122]
[8.33842371 5.38232442 4.28924182 3.17139564]
[3.63921133 3.90518951 3.84478242 1.10152819]
[4.55944054 7.9778919  5.38983265 0.14807559]
[3.53359482 1.47013488 2.49824794 7.43475317]
[5.91460424 4.23209896 4.10702657 7.16817836]
[0.30469309 2.874828   2.56241376 4.95660097]
[2.55428051 3.1601409  5.15730599 2.4922512 ]
[3.26886772 4.0765023  5.2727065  3.37275906]
[8.09722646 3.70456127 0.52605685 3.04772076]
[3.83144695 0.34738144 3.75401514 3.80958745]
[1.19899536 8.8621914  2.78029089 4.42623738]
[3.19752302 2.77130425 5.78633226 7.55511236]
[6.18645532 7.84039808 4.98716642 2.88829079]
[5.138581   1.14307296 8.11026368 

Training iteration 1:  18%|█▊        | 2716/15000 [00:05<00:22, 552.58it/s]

[4.30115508 4.15736992 2.27862657 8.72072853]
[7.1078144  3.92487085 4.23441371 1.65144986]
[5.57803509 3.72167839 0.88913455 1.95706794]
[0.89316385 4.283183   3.25403113 2.03022547]
[4.21286329 9.01108842 3.05787245 5.60213891]
[2.77960488 4.7930557  3.61495602 1.07162025]
[4.11714507 2.68909405 8.31548139 9.30307586]
[4.80641059 6.28518813 6.76248756 4.73164383]
[1.90276495 4.70121359 2.94663342 1.56936451]
[3.02197068 1.56099128 7.10817099 4.98719666]
[4.40482229 1.99112862 4.26540084 3.97657573]
[6.86799081 2.22575912 4.72777758 2.87927985]
[3.38071123 9.22641649 6.38895414 2.26741601]
[1.49194324 1.44626177 2.32724568 1.97914843]
[5.96807221 5.25102686 6.23981439 2.68522233]
[0.33860697 6.31968947 4.71528352 2.92949904]
[1.32344998 5.23043537 3.6587626  8.51585746]
[2.83858645 4.94361466 9.97514564 4.18437642]
[0.16839711 7.86732351 2.05184355 3.60985364]
[4.55022679 1.6424562  3.68217114 3.11344986]
[0.98909487 2.11779362 3.58800635 1.96088329]
[6.34648272 2.37430344 5.1562836  

Training iteration 1:  19%|█▉        | 2829/15000 [00:05<00:21, 557.19it/s]

[2.66564284 7.39581525 2.4096775  2.16181444]
[3.72702426 2.07558522 4.50597264 5.19557135]
[3.49703628 3.91840353 6.05094402 3.86019079]
[2.659509   3.07470665 4.3405511  4.34811885]
[3.48804856 1.01109751 2.14941658 7.7526345 ]
[5.46306869 4.78117476 3.81106788 4.7123642 ]
[9.48324265 2.15209575 0.34151605 3.27278738]
[2.06504098 2.27363008 3.51923925 3.1297423 ]
[4.90301677 8.60392076 8.5776334  1.17069225]
[5.95162747 0.84445257 3.66658835 3.29013204]
[5.06453647 2.4819355  8.07574001 4.88159047]
[3.83280051 7.16223527 2.42871787 5.08308848]
[0.17547611 3.04570341 1.75908995 7.62298893]
[4.3988236  1.99495681 3.11961674 6.10785808]
[4.87709657 4.28217807 4.63941393 9.67189401]
[2.06355731 9.89387649 3.9319345  4.94488326]
[7.59046076 2.21354096 3.02197514 1.91730249]
[8.31138223 7.25119088 4.03189062 3.93416342]
[4.91488434 4.23016118 5.06612064 2.68426695]
[1.6320144  8.27562527 3.70584895 3.84163849]
[7.60548467 3.40644215 8.54297426 4.84445787]
[1.67935693 8.5473832  3.88152465 

Training iteration 1:  20%|█▉        | 2943/15000 [00:05<00:21, 557.39it/s]

[3.88280848 5.20408951 7.62961454 3.52928768]
[3.53374417 5.78158403 3.53667307 6.87608251]
[4.19201945 2.63093939 4.4230765  9.28860912]
[6.63990318 3.8188101  8.13047119 6.60264749]
[6.61079096 4.74959346 1.51159292 5.37737346]
[6.00750485 3.98913313 3.90478918 3.21272934]
[4.97584802 5.89482294 3.51550399 4.15252656]
[4.47203618 1.89774046 7.68934477 3.86357328]
[3.80570932 8.75100857 4.07955971 2.180545  ]
[4.23826405 5.08801532 3.87763174 5.4385926 ]
[2.79365034 3.85593108 3.52108038 0.8431629 ]
[2.79039922 7.53839215 2.48444053 5.93199594]
[4.94839993 3.18943821 7.93546942 1.84918012]
[8.04086083 8.70388228 4.61412121 1.94040412]
[4.48600181 2.67840839 6.92665654 6.80783101]
[3.71862937 5.27297638 4.0371744  5.53433772]
[3.88581819 8.02105731 5.66675317 3.79851261]
[3.98712913 1.60638756 3.09929518 0.62523534]
[2.65318375 4.55444179 3.80710778 4.96984855]
[7.28857059 3.30617325 4.13408869 2.74166028]
[2.53742679 4.34837958 1.11492512 2.22334566]
[7.02448679 5.36914672 5.28574472 

Training iteration 1:  20%|██        | 3057/15000 [00:05<00:21, 543.50it/s]

[3.92665927 4.55485626 2.69943598 4.53327119]
[4.61523337 4.52056152 5.7946005  4.39081812]
[5.30520802 0.9233127  4.58968969 8.36876379]
[3.28207804 1.24952673 2.37355272 3.15356175]
[4.91654699 7.92449874 4.92349483 4.94827598]
[3.64347714 3.18221871 3.69353187 4.68216188]
[4.27648053 8.50356775 5.06073853 4.2715515 ]
[4.03528578 2.01292241 3.82743106 3.63018246]
[2.60774666 4.87834644 4.07241088 1.01672694]
[5.02579169 6.25349807 7.59378384 1.54809537]
[3.07926514 2.78163579 0.59523528 4.88703138]
[5.73954062 1.33982472 3.33237056 6.88749094]
[8.79034657 4.45862151 3.08145347 9.05463912]
[8.24905002 2.74873322 8.06642428 4.64498076]
[3.75339304 7.22524847 2.80565182 5.25943772]
[7.07749054 6.71212756 2.92837765 2.99175525]
[5.60596169 2.22331907 8.10978128 3.8894538 ]
[1.68812701 0.79496591 5.05689608 8.95028642]
[5.69849496 5.29828275 5.55201622 9.26470512]
[3.87917961 5.78902888 5.24813841 3.49880111]
[1.0482388  5.16192544 5.77524877 5.06481208]
[3.24675519 2.77286959 2.89091919 

Training iteration 1:  21%|██        | 3168/15000 [00:05<00:21, 543.06it/s]

[7.98905503 0.62681987 4.63169997 2.79730155]
[4.13314005 3.55985683 3.68246163 5.4463017 ]
[1.40934434 4.62622644 4.58128908 8.37516104]
[4.4700921  5.43205715 2.56624088 5.31439648]
[3.99647284 5.56862036 2.13411263 4.30745421]
[1.62919838 5.06086521 2.64753768 4.95421732]
[3.47361335 1.72618661 5.47928454 3.92395163]
[3.87491251 3.51490993 0.4996723  4.12452237]
[3.88141048 5.91282822 4.25121546 9.07364451]
[7.8059091  2.11395582 3.59598085 8.29447277]
[0.63250904 3.89849688 4.34447461 7.35148053]
[5.10620683 2.54055899 4.01116897 6.42789022]
[4.77329655 2.13882767 1.95797316 5.58518748]
[2.6469706  0.97139547 8.9131016  5.59919241]
[2.31605544 4.60498752 7.25058311 8.03293721]
[3.8300681  5.82622572 1.78180447 5.40399069]
[4.54524172 2.66372088 6.74278486 3.38449416]
[4.42458287 0.41808745 3.71326482 3.60417247]
[7.51886667 0.91171207 5.66755602 4.21010029]
[3.93498265 3.72592855 7.24375861 8.96033445]
[3.05083652 3.84597786 2.57167695 4.78324576]
[3.98765264 4.17172459 2.90710717 

Training iteration 1:  22%|██▏       | 3279/15000 [00:06<00:21, 547.92it/s]

[8.48508355 3.12999514 4.62122381 0.84129993]
[2.29118627 3.48810554 6.88031677 3.28502123]
[8.24628263 4.19078007 2.99925932 8.45864816]
[2.42103146 8.39943404 4.07446422 4.94821672]
[4.05613774 4.25685795 5.47014561 3.30214272]
[1.09642995 5.84264654 3.99211324 3.74339359]
[5.95537319 2.82223313 6.98272503 5.16019981]
[3.11209082 4.44267115 0.04933355 2.14496226]
[5.07956292 7.30940889 3.50636763 0.8499178 ]
[3.37886291 7.8883697  2.41290187 1.86266939]
[3.75094351 2.17186974 7.10009861 5.10912296]
[4.28406783 3.10594379 7.30887168 7.94516348]
[1.99113962 4.59058757 6.17691802 1.58015517]
[8.41630665 4.65339658 1.577569   2.78594993]
[5.22874492 3.40417883 4.98659514 3.86590953]
[1.49923601 5.99293978 8.80146648 4.35075395]
[5.22867043 2.64469491 4.42623826 3.64191574]
[1.5878628  4.38337153 3.51924468 4.82326081]
[7.26059379 2.42153743 4.3617092  1.52284014]
[4.16783838 7.93799444 8.32658711 3.9473849 ]
[3.32423593 5.2485079  9.14667143 3.5698741 ]
[2.00428826 6.22245919 6.06995782 

Training iteration 1:  23%|██▎       | 3396/15000 [00:06<00:20, 564.82it/s]

[6.68608485 1.58781634 2.23775911 4.24940271]
[4.97161492 2.53081098 1.52196741 6.30935487]
[3.81842927 4.98960003 3.80602679 5.18506173]
[3.30428318 5.68905269 5.16278699 4.05898174]
[8.30725922 1.90663779 8.98513917 7.81795105]
[1.21152077 2.24457771 4.05863587 2.20948755]
[2.01526097 5.49943412 7.83919713 1.38667324]
[1.18381828 1.51057247 2.07215685 6.19066129]
[1.06082419 4.39740247 3.60650432 2.89637374]
[8.78924487 1.7480939  4.20389393 7.75150242]
[3.47326687 2.8355084  5.55933892 6.14778047]
[3.92437378 9.06000215 3.41233568 4.80637185]
[4.60071339 1.30462013 5.12898439 0.97854582]
[3.45499308 2.75540323 4.13348038 3.41121786]
[2.3574422  5.75284168 4.66582498 3.35746284]
[3.3161384  3.39123185 2.03120257 3.84499378]
[3.0809447  5.62883883 3.05891409 4.60849226]
[3.60615106 0.63188469 4.16205724 3.52320578]
[7.18843039 3.71388356 2.64334639 1.57759799]
[2.89127082 3.72439122 7.52118547 2.70208554]
[5.26304412 3.28199349 1.81484554 3.64985193]
[2.35577867 2.99282246 2.27868618 

Training iteration 1:  23%|██▎       | 3510/15000 [00:06<00:20, 556.07it/s]

[4.41507253 2.42215754 5.01464674 7.13948591]
[7.06772113 4.47290738 2.68691923 3.58302841]
[7.80729421 7.14706478 1.98938894 0.16077207]
[3.94699416 5.45203953 3.37272716 1.35445834]
[8.2884849  5.54951475 0.49848065 5.40278463]
[4.01903423 5.56686438 3.94647171 6.33795902]
[1.35626469 7.13508379 3.44638422 9.14162605]
[3.5695903  2.6758116  2.86155763 9.17014014]
[1.9692067  8.9589238  3.27591018 4.53254502]
[8.87131249 6.3385215  3.82177339 1.73090873]
[4.39195626 5.726367   5.65253185 4.63498303]
[2.34194259 4.43341307 4.79639751 3.69735473]
[0.16272046 2.40258032 4.67943162 6.51166702]
[7.24641005 3.06253045 5.0624649  0.25747572]
[7.53470775 0.13475951 4.86385985 2.64922827]
[1.35565293 3.93278087 4.81905851 3.34535344]
[7.92194098 0.47965197 4.00798231 4.35547142]
[6.10187706 0.3524869  4.50575573 8.04978435]
[7.78200847 4.22547197 5.86611617 2.80495536]
[4.7950549  3.44342506 4.16154998 5.51479083]
[4.54279091 5.80448611 5.41184416 4.9221105 ]
[3.47856994 4.37695567 1.5488045  

Training iteration 1:  24%|██▍       | 3580/15000 [00:06<00:21, 533.80it/s]


[4.73547931 5.45702755 4.94186187 2.9311208 ]
[6.63777632 8.28396956 2.22856349 4.09215197]
[10.30482271  4.93193517  3.88564936  7.79688743]
[3.4351642  4.74684503 0.2259261  4.53043769]
[2.60663856 2.84742334 3.24149638 0.58788315]
[4.84554685 6.89132218 4.54685508 4.9600875 ]
[7.82849022 0.0681875  2.22044134 1.42418136]
[5.29536751 8.11960468 3.73088005 2.01884039]
[4.83762695 3.10545201 5.1328328  7.83125698]
[0.30728727 4.90426386 4.33455115 3.60816826]
[4.96071813 4.04671089 5.76946656 6.7154566 ]
[8.34517362 4.74012812 2.81596426 0.44480438]
[1.91566731 7.58986454 3.9742142  2.5942749 ]
[4.71791072 7.45447026 2.69547999 5.42725181]
[2.72884379 3.31291266 5.50094469 1.93693967]
[2.69717565 6.58967348 5.23742153 1.9165484 ]
[4.53533286 3.03925727 0.05449929 2.05135375]
[5.0296304  3.59639098 6.46566403 4.75225858]
[7.44537784 1.56497581 3.22206139 4.62951573]
[5.0976964  4.39120504 5.28929285 1.77930625]
[6.65857933 4.26235051 4.22618318 3.58842874]
[3.49170367 4.75373938 3.98453

KeyboardInterrupt: 