In [1]:
import os

import numpy as np
import torch
import torchvision.datasets as dset
import torch.nn as nn
import torchvision.transforms as transforms

import pyro
import pyro.distributions as dist
import pyro.contrib.examples.util  # patches torchvision
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

In [2]:
#assert pyro.__version__.startswith('1.3.0')
pyro.enable_validation(True)
pyro.distributions.enable_validation(False)
pyro.set_rng_seed(0)
# Enable smoke test - run the notebook cells on CI.
#smoke_test = 'CI' in os.environ

In [3]:
def setup_data_loaders(batch_size=128, use_cuda=False):
    root = './data'
    download = True
    trans = transforms.ToTensor()
    train_set = dset.MNIST(root=root, train=True, transform=trans,
                           download=download)
    test_set = dset.MNIST(root=root, train=False, transform=trans)

    kwargs = {'num_workers': 1, 'pin_memory': use_cuda}
    train_loader = torch.utils.data.DataLoader(dataset=train_set,
        batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset=test_set,
        batch_size=batch_size, shuffle=False, **kwargs)
    return train_loader, test_loader

In [4]:
class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        # setup the two linear transformations used
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, 784)
        # setup the non-linearities
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        # define the forward computation on the latent z
        # first compute the hidden units
        hidden = self.softplus(self.fc1(z))
        # return the parameter for the output Bernoulli
        # each is of size batch_size x 784
        loc_img = self.sigmoid(self.fc21(hidden))
        return loc_img

In [5]:
class Encoder(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        # setup the three linear transformations used
        self.fc1 = nn.Linear(784, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        # setup the non-linearities
        self.softplus = nn.Softplus()

    def forward(self, x):
        # define the forward computation on the image x
        # first shape the mini-batch to have pixels in the rightmost dimension
        x = x.reshape(-1, 784)
        # then compute the hidden units
        hidden = self.softplus(self.fc1(x))
        # then return a mean vector and a (positive) square root covariance
        # each of size batch_size x z_dim
        z_loc = self.fc21(hidden)
        z_scale = torch.exp(self.fc22(hidden))
        return z_loc, z_scale

In [6]:
# define the model p(x|z)p(z)
def model(self, x):
    # register PyTorch module `decoder` with Pyro
    pyro.module("decoder", self.decoder)
    with pyro.plate("data", x.shape[0]):
        # setup hyperparameters for prior p(z)
        z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
        # z loc torch.Size([256, 50])
        z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
        # sample from prior (value will be sampled by guide when computing the ELBO)
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
        # we sample a Z from a (0, I) normal distribution
        # then we pass it though a nn
        # mu = nn(z)
        # then this mu is used in another dist
        # p(x|z) where z is samples
        # then we sample an x from this
        # the idea is, this nn function learns a distribution
        # that is, what would it be like to sample z from P(z|X)
        # 
        # z shape torch.Size([256, 50])
        # decode the latent code z
        loc_img = self.decoder.forward(z)
        #loc img torch.Size([256, 784])
        # score against actual images
        # bern shape Independent(Bernoulli(probs: torch.Size([256, 784])), 1)
        # 784 is the batch size
        # 256 is the image size
        pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))

In [7]:
# define the guide (i.e. variational distribution) q(z|x)
def guide(self, x):
    # register PyTorch module `encoder` with Pyro
    pyro.module("encoder", self.encoder)
    with pyro.plate("data", x.shape[0]):
        # use the encoder to get the parameters used to define q(z|x)
        z_loc, z_scale = self.encoder.forward(x)
        # p(z,b) = q(b)mult(i=1 to i=N)q(zi|f(xi))
        
        # given an image, we output a distribution for z
        # then we sample a z. because the guide always gives the
        # approximate posterior, the variational inference
        # sample the latent code z
        pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

In [8]:
class VAE(nn.Module):
    # by default our latent space is 50-dimensional
    # and we use 400 hidden units
    def __init__(self, z_dim=50, hidden_dim=400, use_cuda=False):
        super().__init__()
        # create the encoder and decoder networks
        self.encoder = Encoder(z_dim, hidden_dim)
        self.decoder = Decoder(z_dim, hidden_dim)

        if use_cuda:
            # calling cuda() here will put all the parameters of
            # the encoder and decoder networks into gpu memory
            self.cuda()
        self.use_cuda = use_cuda
        self.z_dim = z_dim

    # define the model p(x|z)p(z)
    
    
    def model(self, x):
        # register PyTorch module `decoder` with Pyro
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            # setup hyperparameters for prior p(z)
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
            # sample from prior (value will be sampled by guide when computing the ELBO)
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            # decode the latent code z
            loc_img = self.decoder.forward(z)
            # score against actual images
            # decoder is where the image goes 
            pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 784))

    # define the guide (i.e. variational distribution) q(z|x)
    def guide(self, x):
        # register PyTorch module `encoder` with Pyro
        pyro.module("encoder", self.encoder)
        with pyro.plate("data", x.shape[0]):
            # use the encoder to get the parameters used to define q(z|x)
            z_loc, z_scale = self.encoder.forward(x)
            # sample the latent code z
            pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))

    # define a helper function for reconstructing images
    def reconstruct_img(self, x):
        # encode image x
        z_loc, z_scale = self.encoder(x)
        # sample in latent space
        z = dist.Normal(z_loc, z_scale).sample()
        # decode the image (note we don't sample in image space)
        loc_img = self.decoder(z)
        return loc_img

In [9]:
vae = VAE()

optimizer = Adam({"lr": 1.0e-3})

svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

def train(svi, train_loader, use_cuda=False):
    # initialize loss accumulator
    epoch_loss = 0.
    # do a training epoch over each mini-batch x returned
    # by the data loader
    for x, _ in train_loader:
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # do ELBO gradient and accumulate loss
        epoch_loss += svi.step(x)

    # return epoch loss
    normalizer_train = len(train_loader.dataset)
    total_epoch_loss_train = epoch_loss / normalizer_train
    return total_epoch_loss_train

def evaluate(svi, test_loader, use_cuda=False):
    # initialize loss accumulator
    test_loss = 0.
    # compute the loss over the entire test set
    for x, _ in test_loader:
        # if on GPU put mini-batch into CUDA memory
        if use_cuda:
            x = x.cuda()
        # compute ELBO estimate and accumulate loss
        test_loss += svi.evaluate_loss(x)
    normalizer_test = len(test_loader.dataset)
    total_epoch_loss_test = test_loss / normalizer_test
    return total_epoch_loss_test

In [10]:
LEARNING_RATE = 1.0e-3
USE_CUDA = False
smoke_test = False

# Run only for a single iteration for testing
NUM_EPOCHS = 100 if smoke_test else 100
TEST_FREQUENCY = 5
train_loader, test_loader = setup_data_loaders(batch_size=256, use_cuda=USE_CUDA)

# clear param store
pyro.clear_param_store()

# setup the VAE
vae = VAE(use_cuda=USE_CUDA)

# setup the optimizer
adam_args = {"lr": LEARNING_RATE}
optimizer = Adam(adam_args)

# setup the inference algorithm
svi = SVI(vae.model, vae.guide, optimizer, loss=Trace_ELBO())

train_elbo = []
test_elbo = []
# training loop
for epoch in range(NUM_EPOCHS):
    total_epoch_loss_train = train(svi, train_loader, use_cuda=USE_CUDA)
    train_elbo.append(-total_epoch_loss_train)
    print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))

    if epoch % TEST_FREQUENCY == 0:
        # report test diagnostics
        total_epoch_loss_test = evaluate(svi, test_loader, use_cuda=USE_CUDA)
        test_elbo.append(-total_epoch_loss_test)
        print("[epoch %03d] average test loss: %.4f" % (epoch, total_epoch_loss_test))

[epoch 000]  average training loss: 191.0216
[epoch 000] average test loss: 156.0872
[epoch 001]  average training loss: 146.8141
[epoch 002]  average training loss: 133.2540
[epoch 003]  average training loss: 124.6775
[epoch 004]  average training loss: 119.5152
[epoch 005]  average training loss: 116.1240
[epoch 005] average test loss: 113.7908
[epoch 006]  average training loss: 113.7285
[epoch 007]  average training loss: 112.0445
[epoch 008]  average training loss: 110.7292
[epoch 009]  average training loss: 109.7455
[epoch 010]  average training loss: 108.9070
[epoch 010] average test loss: 107.7720
[epoch 011]  average training loss: 108.2513
[epoch 012]  average training loss: 107.6953
[epoch 013]  average training loss: 107.2849
[epoch 014]  average training loss: 106.8870
[epoch 015]  average training loss: 106.4983
[epoch 015] average test loss: 105.9786
[epoch 016]  average training loss: 106.1872
[epoch 017]  average training loss: 105.9363
[epoch 018]  average training 

In [16]:
from pyro.nn import PyroSample, PyroModule
from pyro.distributions import Normal, Categorical

class ClassifierBnn(PyroModule):
    
    def __init__(self, num_in = 100, num_hidden = 200, num_out = 10, prior_std = 1.):
        
        # call to father constructor
        super().__init__()
        
        # define prior
        prior = Normal(0, prior_std)
        
        # Define layers
        
        # linear layer 1
        self.linear_layer = PyroModule[torch.nn.Linear](num_in, num_hidden)
        
        # linear alyer parameters as random variables
        self.linear_layer.weights = PyroSample(prior.expand([num_hidden, num_in]).to_event(2))
        self.linear_layer.bias = PyroSample(prior.expand([num_hidden]).to_event(1))
        
        # linear layer 2
        # output dimension is 3 because of the number of classes
        self.output_layer = PyroModule[torch.nn.Linear](num_hidden, num_out)
        
        # linear alyer parameters as random variables
        self.output_layer.weights = PyroSample(prior.expand([num_out, num_hidden]).to_event(2))
        self.output_layer.bias = PyroSample(prior.expand([num_out]).to_event(1))
        
        # activation function
        #self.activation = torch.nn.functional.softmax()
        
    def forward(self, x, y = None):
            
        # latent variable
        z = self.linear_layer(x)
        z = self.output_layer(z)
        z = torch.nn.functional.log_softmax(z, dim=1)
        # likelihood
        with pyro.plate("data",size = x.shape[0], dim = -1):
            # I think this means each batch is independent            
            # z is the input to the distribution (categorical)
            obs = pyro.sample("obs", Categorical(logits = z), obs=y)
        # return latent variable
        return z

In [17]:
# validate NN

pyro.enable_validation(True)

model = ClassifierBnn()
x, y = next(iter(train_loader))
z_loc, z_scale = vae.encoder(x)
combined_z = torch.cat((z_loc, z_scale), 1)


print(pyro.poutine.trace(model).get_trace(combined_z, y).format_shapes())

         Trace Shapes:            
          Param Sites:            
   linear_layer.weight  10 100    
   output_layer.weight  10  10    
         Sample Sites:            
linear_layer.bias dist       | 200
                 value       | 200
output_layer.bias dist       |  10
                 value       |  10
             data dist       |    
                 value 256   |    
              obs dist 256   |    
                 value 256   |    


In [18]:
pyro.enable_validation(True)
pyro.clear_param_store()
model = ClassifierBnn(num_hidden = 10, prior_std = 1.)

# define guide
from pyro.infer.autoguide import AutoDiagonalNormal
guide = AutoDiagonalNormal(model, init_scale=1e-1)

# define SVI (model for training)
svi = pyro.infer.SVI(model,
                    guide,
                    optim=pyro.optim.ClippedAdam({'lr':1e-2}),
                    # Define conventional ELBO
                     loss=pyro.infer.Trace_ELBO())

In [19]:
from pyro.infer import Predictive
predictive = Predictive(model, guide=guide, num_samples=20)

def predict(x):
    # for a single image, output a mean and sd for category
    yhats = predictive(x)["obs"].double()
    # yhats[0] seems to be integers 0 to 9, len 256
    # prediction for one model, for all items in batch
    # 20, 256
    mean = torch.mean(yhats, axis=0)
    std = torch.std(yhats.float(), 0).numpy()
    # yhats outputs a batch size number of predictions for 20 models
    # yhats seem to be a dictionary of weights
    return mean, std

In [20]:
# Training

num_epochs = 1000

# Define number of epochs
epoch_loss = np.zeros(shape=(num_epochs,))

# training
for epoch in range(10):
    i = 0
    for x, y in train_loader:
        # batches of size 256 are being fed in 
        z_loc, z_scale = vae.encoder(x)
        combined_z = torch.cat((z_loc, z_scale), 1)
        loss = svi.step(combined_z, y)
        mean, std = predict(combined_z)
        accuracy_per_batch = torch.sum(torch.sub(mean,y))/len(y)
        print("loss", loss)
        print("mean", mean[0], "y is", y[0])
        print(accuracy_per_batch)

loss 637.1355237960815
mean tensor(4.2000, dtype=torch.float64) y is tensor(5)
tensor(0.3443, dtype=torch.float64)
loss 626.6906805038452
mean tensor(5.3000, dtype=torch.float64) y is tensor(0)
tensor(0.3221, dtype=torch.float64)
loss 607.4282379150391
mean tensor(4.6500, dtype=torch.float64) y is tensor(4)
tensor(-0.0379, dtype=torch.float64)
loss 588.279972076416
mean tensor(4.1500, dtype=torch.float64) y is tensor(1)
tensor(-0.1799, dtype=torch.float64)
loss 580.9760131835938
mean tensor(4.6000, dtype=torch.float64) y is tensor(0)
tensor(-0.1559, dtype=torch.float64)
loss 576.6481323242188
mean tensor(3.2000, dtype=torch.float64) y is tensor(3)
tensor(-0.3453, dtype=torch.float64)
loss 557.9143877029419
mean tensor(4.5000, dtype=torch.float64) y is tensor(3)
tensor(-0.2748, dtype=torch.float64)
loss 555.7110233306885
mean tensor(4.9500, dtype=torch.float64) y is tensor(4)
tensor(0.2373, dtype=torch.float64)
loss 536.5982828140259
mean tensor(5., dtype=torch.float64) y is tensor(3)
t

loss 144.5504961013794
mean tensor(3.1000, dtype=torch.float64) y is tensor(3)
tensor(0.1932, dtype=torch.float64)
loss 150.53664875030518
mean tensor(3.2500, dtype=torch.float64) y is tensor(3)
tensor(0.0191, dtype=torch.float64)
loss 126.39814686775208
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(-0.0117, dtype=torch.float64)
loss 149.61079025268555
mean tensor(2.9500, dtype=torch.float64) y is tensor(3)
tensor(-0.0482, dtype=torch.float64)
loss 102.30500411987305
mean tensor(2.5000, dtype=torch.float64) y is tensor(3)
tensor(0.0664, dtype=torch.float64)
loss 127.4996976852417
mean tensor(4.4000, dtype=torch.float64) y is tensor(5)
tensor(0.0199, dtype=torch.float64)
loss 130.95302963256836
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(-0.0143, dtype=torch.float64)
loss 140.86411476135254
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(-0.0340, dtype=torch.float64)
loss 123.00174140930176
mean tensor(8., dtype=torch.float64) y is tensor(8)
tensor

loss 109.10559511184692
mean tensor(2.0500, dtype=torch.float64) y is tensor(2)
tensor(0.0248, dtype=torch.float64)
loss 123.87192487716675
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(-0.0682, dtype=torch.float64)
loss 97.45253252983093
mean tensor(1.1000, dtype=torch.float64) y is tensor(1)
tensor(-0.1191, dtype=torch.float64)
loss 113.89771890640259
mean tensor(8., dtype=torch.float64) y is tensor(8)
tensor(0.0318, dtype=torch.float64)
loss 117.80908286571503
mean tensor(5.2500, dtype=torch.float64) y is tensor(5)
tensor(-0.0557, dtype=torch.float64)
loss 110.97034645080566
mean tensor(5.7500, dtype=torch.float64) y is tensor(6)
tensor(-0.0205, dtype=torch.float64)
loss 102.0224871635437
mean tensor(5., dtype=torch.float64) y is tensor(5)
tensor(-0.0574, dtype=torch.float64)
loss 120.07359743118286
mean tensor(9., dtype=torch.float64) y is tensor(9)
tensor(0.0018, dtype=torch.float64)
loss 129.5422134399414
mean tensor(8.6500, dtype=torch.float64) y is tensor(9)
tensor

loss 116.36570739746094
mean tensor(1.5500, dtype=torch.float64) y is tensor(0)
tensor(-0.1789, dtype=torch.float64)
loss 135.39495873451233
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(-0.1494, dtype=torch.float64)
loss 99.11452007293701
mean tensor(8., dtype=torch.float64) y is tensor(9)
tensor(-0.0561, dtype=torch.float64)
loss 71.2914924621582
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(0.0143, dtype=torch.float64)
loss 94.58977127075195
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(-0.2006, dtype=torch.float64)
loss 133.71141815185547
mean tensor(5.3500, dtype=torch.float64) y is tensor(5)
tensor(-0.2326, dtype=torch.float64)
loss 100.2219467163086
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(-0.0359, dtype=torch.float64)
loss 96.87181681394577
mean tensor(7.2000, dtype=torch.float64) y is tensor(8)
tensor(-0.1805, dtype=torch.float64)
loss 105.25891733169556
mean tensor(5.4500, dtype=torch.float64) y is tensor(5)
tensor(-0.1

loss 120.09221839904785
mean tensor(4.1000, dtype=torch.float64) y is tensor(5)
tensor(-0.0578, dtype=torch.float64)
loss 106.0954122543335
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(-0.1184, dtype=torch.float64)
loss 97.4101185798645
mean tensor(9., dtype=torch.float64) y is tensor(9)
tensor(-0.1453, dtype=torch.float64)
loss 95.63082027435303
mean tensor(8., dtype=torch.float64) y is tensor(8)
tensor(-0.0119, dtype=torch.float64)
loss 106.55313777923584
mean tensor(2.0500, dtype=torch.float64) y is tensor(2)
tensor(-0.0688, dtype=torch.float64)
loss 97.9421854019165
mean tensor(0.2000, dtype=torch.float64) y is tensor(0)
tensor(0.0965, dtype=torch.float64)
loss 89.7940902709961
mean tensor(7.8500, dtype=torch.float64) y is tensor(8)
tensor(-0.0293, dtype=torch.float64)
loss 116.46881294250488
mean tensor(9., dtype=torch.float64) y is tensor(9)
tensor(0.0814, dtype=torch.float64)
loss 100.68982172012329
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(0.1453,

loss 84.09272718429565
mean tensor(6.1500, dtype=torch.float64) y is tensor(8)
tensor(0.0672, dtype=torch.float64)
loss 85.6529483795166
mean tensor(7.5000, dtype=torch.float64) y is tensor(8)
tensor(0.0842, dtype=torch.float64)
loss 92.2231216430664
mean tensor(1.1000, dtype=torch.float64) y is tensor(1)
tensor(-0.0498, dtype=torch.float64)
loss 79.70802974700928
mean tensor(8., dtype=torch.float64) y is tensor(8)
tensor(0.1354, dtype=torch.float64)
loss 70.49874114990234
mean tensor(5.3000, dtype=torch.float64) y is tensor(5)
tensor(0.0977, dtype=torch.float64)
loss 104.86023473739624
mean tensor(2., dtype=torch.float64) y is tensor(2)
tensor(0.2609, dtype=torch.float64)
loss 130.48406410217285
mean tensor(7.3000, dtype=torch.float64) y is tensor(7)
tensor(0.1385, dtype=torch.float64)
loss 124.89528369903564
mean tensor(5.1500, dtype=torch.float64) y is tensor(5)
tensor(-0.0605, dtype=torch.float64)
loss 136.1039843559265
mean tensor(4., dtype=torch.float64) y is tensor(4)
tensor(0.1

loss 68.08052825927734
mean tensor(4., dtype=torch.float64) y is tensor(4)
tensor(-0.0814, dtype=torch.float64)
loss 103.3351697921753
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(-0.0002, dtype=torch.float64)
loss 109.4793152809143
mean tensor(4.6500, dtype=torch.float64) y is tensor(5)
tensor(-0.0305, dtype=torch.float64)
loss 106.69689559936523
mean tensor(3.9500, dtype=torch.float64) y is tensor(6)
tensor(-0.0855, dtype=torch.float64)
loss 88.18901634216309
mean tensor(8.6000, dtype=torch.float64) y is tensor(9)
tensor(-0.0004, dtype=torch.float64)
loss 107.48770999908447
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(-0.1295, dtype=torch.float64)
loss 117.06831073760986
mean tensor(6.1000, dtype=torch.float64) y is tensor(6)
tensor(-0.0305, dtype=torch.float64)
loss 94.38990592956543
mean tensor(6.6500, dtype=torch.float64) y is tensor(7)
tensor(-0.0865, dtype=torch.float64)
loss 97.1047739982605
mean tensor(0.5000, dtype=torch.float64) y is tensor(0)
ten

loss 123.88900470733643
mean tensor(5., dtype=torch.float64) y is tensor(5)
tensor(0.0844, dtype=torch.float64)
loss 108.83456993103027
mean tensor(4.2000, dtype=torch.float64) y is tensor(5)
tensor(-0.0258, dtype=torch.float64)
loss 108.31730365753174
mean tensor(5.0500, dtype=torch.float64) y is tensor(6)
tensor(-0.0434, dtype=torch.float64)
loss 87.39028930664062
mean tensor(2.9500, dtype=torch.float64) y is tensor(3)
tensor(0.0330, dtype=torch.float64)
loss 116.7141466140747
mean tensor(0.4000, dtype=torch.float64) y is tensor(0)
tensor(-0.0377, dtype=torch.float64)
loss 91.54458689689636
mean tensor(8.7500, dtype=torch.float64) y is tensor(9)
tensor(0.0926, dtype=torch.float64)
loss 94.01378440856934
mean tensor(5.4500, dtype=torch.float64) y is tensor(7)
tensor(0.0061, dtype=torch.float64)
loss 74.52265548706055
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(-0.0084, dtype=torch.float64)
loss 80.6755895614624
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(

loss 83.49018287658691
mean tensor(4.7500, dtype=torch.float64) y is tensor(5)
tensor(-0.0111, dtype=torch.float64)
loss 126.47246742248535
mean tensor(1.1000, dtype=torch.float64) y is tensor(1)
tensor(-0.0938, dtype=torch.float64)
loss 86.64925765991211
mean tensor(4., dtype=torch.float64) y is tensor(4)
tensor(0.0314, dtype=torch.float64)
loss 95.0462999343872
mean tensor(8., dtype=torch.float64) y is tensor(8)
tensor(-0.1402, dtype=torch.float64)
loss 80.41373062133789
mean tensor(8., dtype=torch.float64) y is tensor(8)
tensor(-0.1508, dtype=torch.float64)
loss 90.74332523345947
mean tensor(3.4500, dtype=torch.float64) y is tensor(3)
tensor(-0.1562, dtype=torch.float64)
loss 88.05576992034912
mean tensor(5., dtype=torch.float64) y is tensor(5)
tensor(-0.0889, dtype=torch.float64)
loss 83.09930610656738
mean tensor(4.5000, dtype=torch.float64) y is tensor(5)
tensor(-0.0982, dtype=torch.float64)
loss 125.92725658416748
mean tensor(2., dtype=torch.float64) y is tensor(2)
tensor(-0.024

loss 112.69073152542114
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(-0.0203, dtype=torch.float64)
loss 99.81190800666809
mean tensor(5.3500, dtype=torch.float64) y is tensor(5)
tensor(0.1359, dtype=torch.float64)
loss 113.06295585632324
mean tensor(5.1500, dtype=torch.float64) y is tensor(5)
tensor(-0.0318, dtype=torch.float64)
loss 82.79591846466064
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(-0.1582, dtype=torch.float64)
loss 114.01801109313965
mean tensor(6., dtype=torch.float64) y is tensor(6)
tensor(-0.0762, dtype=torch.float64)
loss 92.36364698410034
mean tensor(4.2500, dtype=torch.float64) y is tensor(4)
tensor(-0.0629, dtype=torch.float64)
loss 105.5300645828247
mean tensor(8.7500, dtype=torch.float64) y is tensor(9)
tensor(-0.0457, dtype=torch.float64)
loss 97.91150331497192
mean tensor(8.2500, dtype=torch.float64) y is tensor(7)
tensor(-0.0264, dtype=torch.float64)
loss 97.30155038833618
mean tensor(4.1000, dtype=torch.float64) y is tensor(5)
ten

loss 93.49860048294067
mean tensor(5.1500, dtype=torch.float64) y is tensor(5)
tensor(0.0051, dtype=torch.float64)
loss 106.96262073516846
mean tensor(8., dtype=torch.float64) y is tensor(8)
tensor(-0.0049, dtype=torch.float64)
loss 106.55093908309937
mean tensor(4.1500, dtype=torch.float64) y is tensor(1)
tensor(0.0006, dtype=torch.float64)
loss 69.75099182128906
mean tensor(6., dtype=torch.float64) y is tensor(6)
tensor(0.0621, dtype=torch.float64)
loss 96.33941602706909
mean tensor(4.1000, dtype=torch.float64) y is tensor(4)
tensor(-0.0160, dtype=torch.float64)
loss 77.94640445709229
mean tensor(4.7500, dtype=torch.float64) y is tensor(3)
tensor(-0.0848, dtype=torch.float64)
loss 83.01144409179688
mean tensor(5.9500, dtype=torch.float64) y is tensor(4)
tensor(-0.0369, dtype=torch.float64)
loss 84.7215805053711
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(-0.1037, dtype=torch.float64)
loss 96.47309494018555
mean tensor(0.1500, dtype=torch.float64) y is tensor(0)
tensor(

loss 125.92034006118774
mean tensor(6., dtype=torch.float64) y is tensor(6)
tensor(0.0002, dtype=torch.float64)
loss 94.49663162231445
mean tensor(6.8500, dtype=torch.float64) y is tensor(7)
tensor(-0.0344, dtype=torch.float64)
loss 80.39834690093994
mean tensor(4., dtype=torch.float64) y is tensor(4)
tensor(0.0787, dtype=torch.float64)
loss 81.07506942749023
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(-0.0650, dtype=torch.float64)
loss 94.74077224731445
mean tensor(8.2500, dtype=torch.float64) y is tensor(9)
tensor(-0.0117, dtype=torch.float64)
loss 75.18901538848877
mean tensor(4.9000, dtype=torch.float64) y is tensor(5)
tensor(0.0492, dtype=torch.float64)
loss 80.55398178100586
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(-0.0340, dtype=torch.float64)
loss 71.19616889953613
mean tensor(0.3500, dtype=torch.float64) y is tensor(0)
tensor(0.0170, dtype=torch.float64)
loss 118.3784704208374
mean tensor(4.7500, dtype=torch.float64) y is tensor(5)
tensor(-0.03

loss 115.54618549346924
mean tensor(2., dtype=torch.float64) y is tensor(2)
tensor(0.0869, dtype=torch.float64)
loss 106.81514358520508
mean tensor(4.7500, dtype=torch.float64) y is tensor(5)
tensor(0.0383, dtype=torch.float64)
loss 108.37419128417969
mean tensor(7.7000, dtype=torch.float64) y is tensor(7)
tensor(0.0797, dtype=torch.float64)
loss 82.37391948699951
mean tensor(0.2500, dtype=torch.float64) y is tensor(0)
tensor(0.0416, dtype=torch.float64)
loss 70.6726942062378
mean tensor(2.4500, dtype=torch.float64) y is tensor(2)
tensor(0.0012, dtype=torch.float64)
loss 76.15513038635254
mean tensor(5.5000, dtype=torch.float64) y is tensor(6)
tensor(-0.1047, dtype=torch.float64)
loss 102.50265073776245
mean tensor(6., dtype=torch.float64) y is tensor(6)
tensor(0.0717, dtype=torch.float64)
loss 69.16819381713867
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(0.0578, dtype=torch.float64)
loss 86.75756788253784
mean tensor(4.4500, dtype=torch.float64) y is tensor(4)
tensor(-0

loss 70.27880001068115
mean tensor(6., dtype=torch.float64) y is tensor(6)
tensor(-0.0420, dtype=torch.float64)
loss 92.28617095947266
mean tensor(3., dtype=torch.float64) y is tensor(3)
tensor(-0.1113, dtype=torch.float64)
loss 82.50908279418945
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(-0.1313, dtype=torch.float64)
loss 93.76692867279053
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(-0.0787, dtype=torch.float64)
loss 86.86839962005615
mean tensor(9., dtype=torch.float64) y is tensor(9)
tensor(-0.0191, dtype=torch.float64)
loss 92.0901689529419
mean tensor(4.5500, dtype=torch.float64) y is tensor(4)
tensor(0.0777, dtype=torch.float64)
loss 66.24468517303467
mean tensor(9., dtype=torch.float64) y is tensor(9)
tensor(0.0170, dtype=torch.float64)
loss 87.1777458190918
mean tensor(4.5000, dtype=torch.float64) y is tensor(4)
tensor(0.1035, dtype=torch.float64)
loss 86.77987861633301
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(0.0605, dtype=torch

loss 99.81272506713867
mean tensor(2.0500, dtype=torch.float64) y is tensor(0)
tensor(0.1029, dtype=torch.float64)
loss 80.82464504241943
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(-0.0609, dtype=torch.float64)
loss 87.17094707489014
mean tensor(3.6500, dtype=torch.float64) y is tensor(4)
tensor(-0.0033, dtype=torch.float64)
loss 110.28117179870605
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(0.0498, dtype=torch.float64)
loss 89.5617322921753
mean tensor(3., dtype=torch.float64) y is tensor(3)
tensor(-0.0885, dtype=torch.float64)
loss 64.84646606445312
mean tensor(3.3000, dtype=torch.float64) y is tensor(3)
tensor(0.0439, dtype=torch.float64)
loss 98.25870895385742
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(-0.0687, dtype=torch.float64)
loss 99.8612699508667
mean tensor(4., dtype=torch.float64) y is tensor(4)
tensor(0.0459, dtype=torch.float64)
loss 88.54671096801758
mean tensor(2., dtype=torch.float64) y is tensor(2)
tensor(0.0301, dtype=t

loss 69.23097229003906
mean tensor(8.7500, dtype=torch.float64) y is tensor(9)
tensor(-0.1418, dtype=torch.float64)
loss 96.41435146331787
mean tensor(7.7500, dtype=torch.float64) y is tensor(8)
tensor(-0.0674, dtype=torch.float64)
loss 94.11279582977295
mean tensor(4.3000, dtype=torch.float64) y is tensor(5)
tensor(-0.2180, dtype=torch.float64)
loss 80.0074234008789
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(-0.0221, dtype=torch.float64)
loss 78.21426773071289
mean tensor(4., dtype=torch.float64) y is tensor(4)
tensor(-0.1166, dtype=torch.float64)
loss 115.66349935531616
mean tensor(0.3000, dtype=torch.float64) y is tensor(0)
tensor(0.0299, dtype=torch.float64)
loss 82.09143733978271
mean tensor(2., dtype=torch.float64) y is tensor(2)
tensor(0.0156, dtype=torch.float64)
loss 99.8141450881958
mean tensor(3.0500, dtype=torch.float64) y is tensor(3)
tensor(-0.0365, dtype=torch.float64)
loss 85.47571086883545
mean tensor(8., dtype=torch.float64) y is tensor(8)
tensor(0.103

loss 69.9306173324585
mean tensor(0.3500, dtype=torch.float64) y is tensor(0)
tensor(0.0467, dtype=torch.float64)
loss 39.769776344299316
mean tensor(4.9000, dtype=torch.float64) y is tensor(3)
tensor(0.1547, dtype=torch.float64)
loss 89.08648204803467
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(0.1361, dtype=torch.float64)
loss 88.15431785583496
mean tensor(3., dtype=torch.float64) y is tensor(3)
tensor(-0.0047, dtype=torch.float64)
loss 96.17287111282349
mean tensor(3., dtype=torch.float64) y is tensor(3)
tensor(0.0926, dtype=torch.float64)
loss 110.7667350769043
mean tensor(4., dtype=torch.float64) y is tensor(4)
tensor(-0.0709, dtype=torch.float64)
loss 93.73112297058105
mean tensor(4.2000, dtype=torch.float64) y is tensor(2)
tensor(-0.0072, dtype=torch.float64)
loss 101.48041725158691
mean tensor(4., dtype=torch.float64) y is tensor(4)
tensor(-0.0035, dtype=torch.float64)
loss 96.81252002716064
mean tensor(6., dtype=torch.float64) y is tensor(6)
tensor(-0.0641, dtyp

loss 143.4969825744629
mean tensor(3.5500, dtype=torch.float64) y is tensor(3)
tensor(-0.0643, dtype=torch.float64)
loss 92.18595600128174
mean tensor(2., dtype=torch.float64) y is tensor(1)
tensor(-0.0576, dtype=torch.float64)
loss 84.73680305480957
mean tensor(7.5500, dtype=torch.float64) y is tensor(8)
tensor(-0.0660, dtype=torch.float64)
loss 113.95426082611084
mean tensor(6., dtype=torch.float64) y is tensor(6)
tensor(-0.0172, dtype=torch.float64)
loss 97.98661994934082
mean tensor(7.0500, dtype=torch.float64) y is tensor(9)
tensor(-0.1820, dtype=torch.float64)
loss 102.47793102264404
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(-0.1854, dtype=torch.float64)
loss 92.79703521728516
mean tensor(7.6000, dtype=torch.float64) y is tensor(7)
tensor(-0.0998, dtype=torch.float64)
loss 117.1281270980835
mean tensor(3., dtype=torch.float64) y is tensor(3)
tensor(0.0400, dtype=torch.float64)
loss 97.06324195861816
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(0.037

loss 111.70051002502441
mean tensor(5., dtype=torch.float64) y is tensor(5)
tensor(-0.0018, dtype=torch.float64)
loss 78.1844596862793
mean tensor(5.1500, dtype=torch.float64) y is tensor(5)
tensor(0.0760, dtype=torch.float64)
loss 125.44043731689453
mean tensor(5.3000, dtype=torch.float64) y is tensor(7)
tensor(0.1082, dtype=torch.float64)
loss 94.12100601196289
mean tensor(9., dtype=torch.float64) y is tensor(9)
tensor(0.0357, dtype=torch.float64)
loss 83.11532402038574
mean tensor(5.1500, dtype=torch.float64) y is tensor(7)
tensor(0.0682, dtype=torch.float64)
loss 83.21380138397217
mean tensor(8.9000, dtype=torch.float64) y is tensor(9)
tensor(-0.1469, dtype=torch.float64)
loss 110.20901870727539
mean tensor(2., dtype=torch.float64) y is tensor(2)
tensor(0.1346, dtype=torch.float64)
loss 96.63115787506104
mean tensor(8., dtype=torch.float64) y is tensor(8)
tensor(-0.0525, dtype=torch.float64)
loss 100.81395626068115
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(-0.0176,

loss 73.5935411453247
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(0.0848, dtype=torch.float64)
loss 82.20454120635986
mean tensor(4.5000, dtype=torch.float64) y is tensor(4)
tensor(0.1494, dtype=torch.float64)
loss 69.36073875427246
mean tensor(4., dtype=torch.float64) y is tensor(4)
tensor(0.1930, dtype=torch.float64)
loss 68.85163688659668
mean tensor(0.1500, dtype=torch.float64) y is tensor(0)
tensor(0.1602, dtype=torch.float64)
loss 94.96114826202393
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(0.1842, dtype=torch.float64)
loss 66.16502094268799
mean tensor(6., dtype=torch.float64) y is tensor(6)
tensor(0.1107, dtype=torch.float64)
loss 101.19660186767578
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(0.2064, dtype=torch.float64)
loss 92.40336656570435
mean tensor(4.8000, dtype=torch.float64) y is tensor(5)
tensor(0.1229, dtype=torch.float64)
loss 80.53517150878906
mean tensor(8., dtype=torch.float64) y is tensor(8)
tensor(-0.0426, dtype=tor

loss 84.94961023330688
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(0.0182, dtype=torch.float64)
loss 86.25194644927979
mean tensor(7.7000, dtype=torch.float64) y is tensor(8)
tensor(-0.0053, dtype=torch.float64)
loss 76.57853412628174
mean tensor(8., dtype=torch.float64) y is tensor(8)
tensor(-0.0121, dtype=torch.float64)
loss 86.58802127838135
mean tensor(2., dtype=torch.float64) y is tensor(2)
tensor(0.0246, dtype=torch.float64)
loss 76.93237972259521
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(0.0980, dtype=torch.float64)
loss 86.04988098144531
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(0.0186, dtype=torch.float64)
loss 101.05918312072754
mean tensor(4.3500, dtype=torch.float64) y is tensor(4)
tensor(0.0787, dtype=torch.float64)
loss 85.93176746368408
mean tensor(5.9000, dtype=torch.float64) y is tensor(6)
tensor(-0.0129, dtype=torch.float64)
loss 100.19654750823975
mean tensor(4.8000, dtype=torch.float64) y is tensor(9)
tensor(-0.0533, 

loss 112.13638496398926
mean tensor(2.3000, dtype=torch.float64) y is tensor(2)
tensor(-0.0004, dtype=torch.float64)
loss 98.81140947341919
mean tensor(9., dtype=torch.float64) y is tensor(9)
tensor(-0.0873, dtype=torch.float64)
loss 71.30846881866455
mean tensor(3.7500, dtype=torch.float64) y is tensor(9)
tensor(-0.1004, dtype=torch.float64)
loss 71.753662109375
mean tensor(2., dtype=torch.float64) y is tensor(2)
tensor(-0.0559, dtype=torch.float64)
loss 71.01985359191895
mean tensor(1.5500, dtype=torch.float64) y is tensor(4)
tensor(-0.1518, dtype=torch.float64)
loss 93.7714958190918
mean tensor(5., dtype=torch.float64) y is tensor(5)
tensor(-0.1969, dtype=torch.float64)
loss 121.7280626296997
mean tensor(2.6500, dtype=torch.float64) y is tensor(8)
tensor(-0.1875, dtype=torch.float64)
loss 76.2218770980835
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(0.0270, dtype=torch.float64)
loss 77.85794639587402
mean tensor(9., dtype=torch.float64) y is tensor(8)
tensor(-0.0260, d

loss 91.78015899658203
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(-0.0248, dtype=torch.float64)
loss 70.12950325012207
mean tensor(8.7500, dtype=torch.float64) y is tensor(9)
tensor(0.1064, dtype=torch.float64)
loss 70.40699100494385
mean tensor(2., dtype=torch.float64) y is tensor(2)
tensor(0.1277, dtype=torch.float64)
loss 67.73260593414307
mean tensor(7.0500, dtype=torch.float64) y is tensor(7)
tensor(0.0342, dtype=torch.float64)
loss 121.36919403076172
mean tensor(5.7000, dtype=torch.float64) y is tensor(6)
tensor(0.0574, dtype=torch.float64)
loss 65.9134578704834
mean tensor(5.7000, dtype=torch.float64) y is tensor(4)
tensor(-0.0307, dtype=torch.float64)
loss 73.67630100250244
mean tensor(2.6000, dtype=torch.float64) y is tensor(2)
tensor(0.0109, dtype=torch.float64)
loss 81.60701847076416
mean tensor(1.5000, dtype=torch.float64) y is tensor(1)
tensor(-0.0043, dtype=torch.float64)
loss 83.67940711975098
mean tensor(5.6500, dtype=torch.float64) y is tensor(5)
tensor

loss 72.22289943695068
mean tensor(3., dtype=torch.float64) y is tensor(3)
tensor(-0.0879, dtype=torch.float64)
loss 86.64870548248291
mean tensor(2.3000, dtype=torch.float64) y is tensor(7)
tensor(-0.1980, dtype=torch.float64)
loss 87.3074402809143
mean tensor(6.4000, dtype=torch.float64) y is tensor(7)
tensor(-0.1549, dtype=torch.float64)
loss 121.67545509338379
mean tensor(5., dtype=torch.float64) y is tensor(5)
tensor(-0.0854, dtype=torch.float64)
loss 77.72179317474365
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(0.0078, dtype=torch.float64)
loss 88.24528217315674
mean tensor(8.9500, dtype=torch.float64) y is tensor(9)
tensor(-0.0174, dtype=torch.float64)
loss 119.12331962585449
mean tensor(2.3000, dtype=torch.float64) y is tensor(5)
tensor(0.0436, dtype=torch.float64)
loss 70.97593879699707
mean tensor(4.8500, dtype=torch.float64) y is tensor(5)
tensor(-0.1184, dtype=torch.float64)
loss 71.4744520187378
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(0.12

loss 98.0734224319458
mean tensor(9., dtype=torch.float64) y is tensor(9)
tensor(0.0252, dtype=torch.float64)
loss 96.39553833007812
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(0.0971, dtype=torch.float64)
loss 100.35106754302979
mean tensor(5.0500, dtype=torch.float64) y is tensor(0)
tensor(0.0746, dtype=torch.float64)
loss 71.07358646392822
mean tensor(2., dtype=torch.float64) y is tensor(2)
tensor(0.0695, dtype=torch.float64)
loss 117.82815361022949
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(0.0977, dtype=torch.float64)
loss 103.49953651428223
mean tensor(0.7500, dtype=torch.float64) y is tensor(0)
tensor(0.1209, dtype=torch.float64)
loss 89.4152889251709
mean tensor(9., dtype=torch.float64) y is tensor(9)
tensor(0.2340, dtype=torch.float64)
loss 88.73313999176025
mean tensor(4., dtype=torch.float64) y is tensor(4)
tensor(0.1074, dtype=torch.float64)
loss 84.14532279968262
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(0.0400, dtype=torch.f

loss 91.66657543182373
mean tensor(7.1000, dtype=torch.float64) y is tensor(4)
tensor(-0.0457, dtype=torch.float64)
loss 67.47659206390381
mean tensor(5.8500, dtype=torch.float64) y is tensor(6)
tensor(0.0582, dtype=torch.float64)
loss 93.25187110900879
mean tensor(3., dtype=torch.float64) y is tensor(3)
tensor(0.0107, dtype=torch.float64)
loss 86.6978874206543
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(0.0229, dtype=torch.float64)
loss 84.44253206253052
mean tensor(4., dtype=torch.float64) y is tensor(4)
tensor(0.0361, dtype=torch.float64)
loss 63.15731716156006
mean tensor(5.5500, dtype=torch.float64) y is tensor(7)
tensor(-0.0076, dtype=torch.float64)
loss 71.33019685745239
mean tensor(6.8000, dtype=torch.float64) y is tensor(7)
tensor(-0.1125, dtype=torch.float64)
loss 85.48558044433594
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(-0.0459, dtype=torch.float64)
loss 85.24758815765381
mean tensor(9., dtype=torch.float64) y is tensor(9)
tensor(-0.0031, dt

loss 93.00843334197998
mean tensor(6.7500, dtype=torch.float64) y is tensor(7)
tensor(0.0311, dtype=torch.float64)
loss 81.57695817947388
mean tensor(2., dtype=torch.float64) y is tensor(0)
tensor(0.0262, dtype=torch.float64)
loss 104.92869567871094
mean tensor(7.6000, dtype=torch.float64) y is tensor(9)
tensor(0.0014, dtype=torch.float64)
loss 88.16720962524414
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(0.1414, dtype=torch.float64)
loss 57.05706214904785
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(0.0977, dtype=torch.float64)
loss 79.9291877746582
mean tensor(5.2500, dtype=torch.float64) y is tensor(4)
tensor(-0.0350, dtype=torch.float64)
loss 69.9978141784668
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(0.0242, dtype=torch.float64)
loss 102.9609022140503
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(0.1766, dtype=torch.float64)
loss 118.85669708251953
mean tensor(8.9000, dtype=torch.float64) y is tensor(9)
tensor(-0.2100, dtyp

loss 111.80037593841553
mean tensor(2., dtype=torch.float64) y is tensor(2)
tensor(0.0615, dtype=torch.float64)
loss 78.57745361328125
mean tensor(4., dtype=torch.float64) y is tensor(4)
tensor(0.0633, dtype=torch.float64)
loss 89.80871677398682
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(0.0074, dtype=torch.float64)
loss 85.62249851226807
mean tensor(3., dtype=torch.float64) y is tensor(3)
tensor(-0.1143, dtype=torch.float64)
loss 118.93721199035645
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(-0.0957, dtype=torch.float64)
loss 95.19253253936768
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(0.1447, dtype=torch.float64)
loss 103.21644115447998
mean tensor(2., dtype=torch.float64) y is tensor(2)
tensor(0.0533, dtype=torch.float64)
loss 72.72949886322021
mean tensor(6., dtype=torch.float64) y is tensor(6)
tensor(-0.0172, dtype=torch.float64)
loss 79.75090789794922
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(0.0068, dtype=torch.floa

loss 73.50930213928223
mean tensor(6., dtype=torch.float64) y is tensor(6)
tensor(-0.0961, dtype=torch.float64)
loss 81.6762228012085
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(0.0775, dtype=torch.float64)
loss 95.76903629302979
mean tensor(3.9500, dtype=torch.float64) y is tensor(3)
tensor(-0.0521, dtype=torch.float64)
loss 69.15899562835693
mean tensor(7.1000, dtype=torch.float64) y is tensor(7)
tensor(0.0555, dtype=torch.float64)
loss 90.80842113494873
mean tensor(7., dtype=torch.float64) y is tensor(8)
tensor(0.0400, dtype=torch.float64)
loss 108.1658067703247
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(-0.0385, dtype=torch.float64)
loss 83.1600923538208
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(0.0789, dtype=torch.float64)
loss 64.94384670257568
mean tensor(6., dtype=torch.float64) y is tensor(6)
tensor(-0.0256, dtype=torch.float64)
loss 63.856282234191895
mean tensor(3.2000, dtype=torch.float64) y is tensor(4)
tensor(-0.0570, dtype=

loss 85.97941303253174
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(-0.0406, dtype=torch.float64)
loss 68.61635398864746
mean tensor(5., dtype=torch.float64) y is tensor(5)
tensor(0.0238, dtype=torch.float64)
loss 53.06521797180176
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(0.0994, dtype=torch.float64)
loss 117.80403995513916
mean tensor(8.1000, dtype=torch.float64) y is tensor(8)
tensor(0.1156, dtype=torch.float64)
loss 92.03231239318848
mean tensor(3., dtype=torch.float64) y is tensor(3)
tensor(0.1037, dtype=torch.float64)
loss 77.82761096954346
mean tensor(6., dtype=torch.float64) y is tensor(6)
tensor(0.0350, dtype=torch.float64)
loss 107.60990238189697
mean tensor(5., dtype=torch.float64) y is tensor(5)
tensor(0.0473, dtype=torch.float64)
loss 78.09743404388428
mean tensor(9., dtype=torch.float64) y is tensor(9)
tensor(0.0938, dtype=torch.float64)
loss 90.17822170257568
mean tensor(9., dtype=torch.float64) y is tensor(9)
tensor(0.1516, dtype=torch.flo

loss 96.48444938659668
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(0.0232, dtype=torch.float64)
loss 97.61210298538208
mean tensor(5.9500, dtype=torch.float64) y is tensor(6)
tensor(-0.0225, dtype=torch.float64)
loss 71.82934761047363
mean tensor(1., dtype=torch.float64) y is tensor(1)
tensor(0.0137, dtype=torch.float64)
loss 70.43137454986572
mean tensor(5.2000, dtype=torch.float64) y is tensor(5)
tensor(-0.0270, dtype=torch.float64)
loss 82.36676979064941
mean tensor(7., dtype=torch.float64) y is tensor(7)
tensor(-0.0922, dtype=torch.float64)
loss 75.36786413192749
mean tensor(3.1500, dtype=torch.float64) y is tensor(3)
tensor(-0.0217, dtype=torch.float64)
loss 67.90523338317871
mean tensor(6.8500, dtype=torch.float64) y is tensor(8)
tensor(-0.0855, dtype=torch.float64)
loss 67.99691772460938
mean tensor(0., dtype=torch.float64) y is tensor(0)
tensor(-0.0592, dtype=torch.float64)
loss 86.08366775512695
mean tensor(4., dtype=torch.float64) y is tensor(4)
tensor(-0.0568,

loss 80.9532585144043
mean tensor(9., dtype=torch.float64) y is tensor(9)
tensor(-0.0910, dtype=torch.float64)
loss 90.22877025604248
mean tensor(3., dtype=torch.float64) y is tensor(9)
tensor(0.0275, dtype=torch.float64)
loss 69.94841384887695
mean tensor(7., dtype=torch.float64) y is tensor(1)
tensor(0.0816, dtype=torch.float64)
loss 70.73832988739014
mean tensor(3.1000, dtype=torch.float64) y is tensor(3)
tensor(-0.0725, dtype=torch.float64)
loss 81.4617280960083
mean tensor(6., dtype=torch.float64) y is tensor(6)
tensor(-0.0191, dtype=torch.float64)
loss 77.6117115020752
mean tensor(4.7500, dtype=torch.float64) y is tensor(5)
tensor(-0.2191, dtype=torch.float64)
loss 79.60945987701416
mean tensor(4.7500, dtype=torch.float64) y is tensor(5)
tensor(-0.1225, dtype=torch.float64)
loss 70.95444011688232
mean tensor(8.1500, dtype=torch.float64) y is tensor(9)
tensor(-0.1457, dtype=torch.float64)
loss 75.4791145324707
mean tensor(7.1000, dtype=torch.float64) y is tensor(7)
tensor(0.0184, 