In [8]:
import numpy as np
from scipy.sparse import rand as sprand
import torch

# Make up some random explicit feedback ratings
# and convert to a numpy array
n_users = 1_000
n_items = 1_000
ratings = sprand(n_users, n_items, density=0.01, format="csr")
ratings.data = np.random.randint(1, 5, size=ratings.nnz).astype(np.float64)
ratings = ratings.toarray()
print(ratings.data)

<memory at 0x7ffeb420a330>


In [2]:
class MatrixFactorization(torch.nn.Module):
    def __init__(self, n_users, n_items, n_factors=20):
        super().__init__()
        self.user_factors = torch.nn.Embedding(n_users, n_factors, sparse=True)
        self.item_factors = torch.nn.Embedding(n_items, n_factors, sparse=True)

    def forward(self, user, item):
        return (self.user_factors(user) * self.item_factors(item)).sum(1)


In [9]:
model = MatrixFactorization(n_users, n_items, n_factors=20)
print(model)

MatrixFactorization(
  (user_factors): Embedding(1000, 20, sparse=True)
  (item_factors): Embedding(1000, 20, sparse=True)
)


In [10]:
loss_func = torch.nn.MSELoss()
print(loss_func)

MSELoss()


In [5]:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)  # learning rate

In [12]:
# Sort our data
rows, cols = ratings.nonzero()
p = np.random.permutation(len(rows))
rows, cols = rows[p], cols[p]

for row, col in zip(*(rows, cols)):
    # Set gradients to zero
    optimizer.zero_grad()
    
    # Turn data into tensors
    rating = torch.FloatTensor([ratings[row, col]])
    row = torch.LongTensor([row])
    col = torch.LongTensor([col])

    # Predict and calculate loss
    prediction = model(row, col)
    loss = loss_func(prediction, rating)
    
    # Backpropagate
    loss.backward()

    # Update the parameters
    optimizer.step()
    
    print(loss)


tensor(67.9551, grad_fn=<MseLossBackward0>)
tensor(145.9759, grad_fn=<MseLossBackward0>)
tensor(4.5712, grad_fn=<MseLossBackward0>)
tensor(80.7261, grad_fn=<MseLossBackward0>)
tensor(36.3909, grad_fn=<MseLossBackward0>)
tensor(33.3313, grad_fn=<MseLossBackward0>)
tensor(11.7467, grad_fn=<MseLossBackward0>)
tensor(11.5460, grad_fn=<MseLossBackward0>)
tensor(6.4448, grad_fn=<MseLossBackward0>)
tensor(22.4857, grad_fn=<MseLossBackward0>)
tensor(1.7918, grad_fn=<MseLossBackward0>)
tensor(4.8370, grad_fn=<MseLossBackward0>)
tensor(25.2307, grad_fn=<MseLossBackward0>)
tensor(15.0787, grad_fn=<MseLossBackward0>)
tensor(46.5298, grad_fn=<MseLossBackward0>)
tensor(40.6491, grad_fn=<MseLossBackward0>)
tensor(2.3588, grad_fn=<MseLossBackward0>)
tensor(1.4067, grad_fn=<MseLossBackward0>)
tensor(0.6612, grad_fn=<MseLossBackward0>)
tensor(0.7031, grad_fn=<MseLossBackward0>)
tensor(3.2516, grad_fn=<MseLossBackward0>)
tensor(9.7676, grad_fn=<MseLossBackward0>)
tensor(18.3991, grad_fn=<MseLossBackward0

tensor(1.1514, grad_fn=<MseLossBackward0>)
tensor(53.0814, grad_fn=<MseLossBackward0>)
tensor(1.7508, grad_fn=<MseLossBackward0>)
tensor(139.5892, grad_fn=<MseLossBackward0>)
tensor(2.0743, grad_fn=<MseLossBackward0>)
tensor(17.0744, grad_fn=<MseLossBackward0>)
tensor(54.1711, grad_fn=<MseLossBackward0>)
tensor(3.7961, grad_fn=<MseLossBackward0>)
tensor(21.5180, grad_fn=<MseLossBackward0>)
tensor(1.7483, grad_fn=<MseLossBackward0>)
tensor(51.9576, grad_fn=<MseLossBackward0>)
tensor(69.3541, grad_fn=<MseLossBackward0>)
tensor(6.5722, grad_fn=<MseLossBackward0>)
tensor(11.0528, grad_fn=<MseLossBackward0>)
tensor(14.5260, grad_fn=<MseLossBackward0>)
tensor(66.2259, grad_fn=<MseLossBackward0>)
tensor(134.0064, grad_fn=<MseLossBackward0>)
tensor(5.8471, grad_fn=<MseLossBackward0>)
tensor(13.1153, grad_fn=<MseLossBackward0>)
tensor(19.6417, grad_fn=<MseLossBackward0>)
tensor(0.1861, grad_fn=<MseLossBackward0>)
tensor(11.3742, grad_fn=<MseLossBackward0>)
tensor(7.8517, grad_fn=<MseLossBackwar

tensor(82.2528, grad_fn=<MseLossBackward0>)
tensor(15.1465, grad_fn=<MseLossBackward0>)
tensor(52.2624, grad_fn=<MseLossBackward0>)
tensor(102.5791, grad_fn=<MseLossBackward0>)
tensor(87.7643, grad_fn=<MseLossBackward0>)
tensor(6.7966, grad_fn=<MseLossBackward0>)
tensor(14.4544, grad_fn=<MseLossBackward0>)
tensor(2.3462, grad_fn=<MseLossBackward0>)
tensor(0.0892, grad_fn=<MseLossBackward0>)
tensor(5.0853, grad_fn=<MseLossBackward0>)
tensor(23.3332, grad_fn=<MseLossBackward0>)
tensor(70.4764, grad_fn=<MseLossBackward0>)
tensor(0.4556, grad_fn=<MseLossBackward0>)
tensor(3.7240, grad_fn=<MseLossBackward0>)
tensor(18.6469, grad_fn=<MseLossBackward0>)
tensor(48.9545, grad_fn=<MseLossBackward0>)
tensor(6.6233, grad_fn=<MseLossBackward0>)
tensor(57.8871, grad_fn=<MseLossBackward0>)
tensor(1.0980, grad_fn=<MseLossBackward0>)
tensor(0.3017, grad_fn=<MseLossBackward0>)
tensor(51.0291, grad_fn=<MseLossBackward0>)
tensor(34.9041, grad_fn=<MseLossBackward0>)
tensor(30.2566, grad_fn=<MseLossBackward

tensor(0.1204, grad_fn=<MseLossBackward0>)
tensor(15.3634, grad_fn=<MseLossBackward0>)
tensor(114.6341, grad_fn=<MseLossBackward0>)
tensor(12.0905, grad_fn=<MseLossBackward0>)
tensor(50.1005, grad_fn=<MseLossBackward0>)
tensor(44.9202, grad_fn=<MseLossBackward0>)
tensor(0.0403, grad_fn=<MseLossBackward0>)
tensor(14.9864, grad_fn=<MseLossBackward0>)
tensor(4.2309, grad_fn=<MseLossBackward0>)
tensor(53.5223, grad_fn=<MseLossBackward0>)
tensor(144.4804, grad_fn=<MseLossBackward0>)
tensor(27.9882, grad_fn=<MseLossBackward0>)
tensor(32.4339, grad_fn=<MseLossBackward0>)
tensor(15.7681, grad_fn=<MseLossBackward0>)
tensor(10.3989, grad_fn=<MseLossBackward0>)
tensor(2.3140, grad_fn=<MseLossBackward0>)
tensor(76.8702, grad_fn=<MseLossBackward0>)
tensor(58.8127, grad_fn=<MseLossBackward0>)
tensor(29.3640, grad_fn=<MseLossBackward0>)
tensor(20.2756, grad_fn=<MseLossBackward0>)
tensor(6.5688, grad_fn=<MseLossBackward0>)
tensor(39.2631, grad_fn=<MseLossBackward0>)
tensor(5.9588, grad_fn=<MseLossBack

tensor(193.0443, grad_fn=<MseLossBackward0>)
tensor(49.6970, grad_fn=<MseLossBackward0>)
tensor(21.6890, grad_fn=<MseLossBackward0>)
tensor(26.6622, grad_fn=<MseLossBackward0>)
tensor(18.0587, grad_fn=<MseLossBackward0>)
tensor(34.5333, grad_fn=<MseLossBackward0>)
tensor(0.0143, grad_fn=<MseLossBackward0>)
tensor(16.6358, grad_fn=<MseLossBackward0>)
tensor(13.6399, grad_fn=<MseLossBackward0>)
tensor(40.1785, grad_fn=<MseLossBackward0>)
tensor(12.8095, grad_fn=<MseLossBackward0>)
tensor(6.9877, grad_fn=<MseLossBackward0>)
tensor(12.2058, grad_fn=<MseLossBackward0>)
tensor(39.8788, grad_fn=<MseLossBackward0>)
tensor(24.6978, grad_fn=<MseLossBackward0>)
tensor(47.1926, grad_fn=<MseLossBackward0>)
tensor(54.6135, grad_fn=<MseLossBackward0>)
tensor(60.5435, grad_fn=<MseLossBackward0>)
tensor(0.2972, grad_fn=<MseLossBackward0>)
tensor(123.8064, grad_fn=<MseLossBackward0>)
tensor(20.3119, grad_fn=<MseLossBackward0>)
tensor(19.0148, grad_fn=<MseLossBackward0>)
tensor(11.3568, grad_fn=<MseLossB

tensor(71.3567, grad_fn=<MseLossBackward0>)
tensor(30.2790, grad_fn=<MseLossBackward0>)
tensor(5.5400, grad_fn=<MseLossBackward0>)
tensor(18.6327, grad_fn=<MseLossBackward0>)
tensor(11.7158, grad_fn=<MseLossBackward0>)
tensor(31.6824, grad_fn=<MseLossBackward0>)
tensor(10.5815, grad_fn=<MseLossBackward0>)
tensor(41.2834, grad_fn=<MseLossBackward0>)
tensor(26.7948, grad_fn=<MseLossBackward0>)
tensor(58.8139, grad_fn=<MseLossBackward0>)
tensor(12.9934, grad_fn=<MseLossBackward0>)
tensor(5.9510, grad_fn=<MseLossBackward0>)
tensor(5.5396, grad_fn=<MseLossBackward0>)
tensor(32.7854, grad_fn=<MseLossBackward0>)
tensor(1.8680, grad_fn=<MseLossBackward0>)
tensor(28.2097, grad_fn=<MseLossBackward0>)
tensor(36.8840, grad_fn=<MseLossBackward0>)
tensor(36.7505, grad_fn=<MseLossBackward0>)
tensor(5.1797, grad_fn=<MseLossBackward0>)
tensor(2.5634, grad_fn=<MseLossBackward0>)
tensor(1.0611, grad_fn=<MseLossBackward0>)
tensor(22.1075, grad_fn=<MseLossBackward0>)
tensor(20.6207, grad_fn=<MseLossBackwar

tensor(3.2124, grad_fn=<MseLossBackward0>)
tensor(6.0272, grad_fn=<MseLossBackward0>)
tensor(11.8111, grad_fn=<MseLossBackward0>)
tensor(24.6086, grad_fn=<MseLossBackward0>)
tensor(14.3964, grad_fn=<MseLossBackward0>)
tensor(19.3004, grad_fn=<MseLossBackward0>)
tensor(21.0464, grad_fn=<MseLossBackward0>)
tensor(17.1829, grad_fn=<MseLossBackward0>)
tensor(3.9437, grad_fn=<MseLossBackward0>)
tensor(13.0465, grad_fn=<MseLossBackward0>)
tensor(183.4137, grad_fn=<MseLossBackward0>)
tensor(0.6079, grad_fn=<MseLossBackward0>)
tensor(0.6322, grad_fn=<MseLossBackward0>)
tensor(0.0074, grad_fn=<MseLossBackward0>)
tensor(6.0319, grad_fn=<MseLossBackward0>)
tensor(17.3068, grad_fn=<MseLossBackward0>)
tensor(31.8804, grad_fn=<MseLossBackward0>)
tensor(4.7254, grad_fn=<MseLossBackward0>)
tensor(1.1960, grad_fn=<MseLossBackward0>)
tensor(92.4682, grad_fn=<MseLossBackward0>)
tensor(23.7630, grad_fn=<MseLossBackward0>)
tensor(169.5923, grad_fn=<MseLossBackward0>)
tensor(64.2079, grad_fn=<MseLossBackwar

tensor(1.5822, grad_fn=<MseLossBackward0>)
tensor(42.3308, grad_fn=<MseLossBackward0>)
tensor(22.2968, grad_fn=<MseLossBackward0>)
tensor(31.8013, grad_fn=<MseLossBackward0>)
tensor(111.0336, grad_fn=<MseLossBackward0>)
tensor(16.8167, grad_fn=<MseLossBackward0>)
tensor(4.2202, grad_fn=<MseLossBackward0>)
tensor(7.8338, grad_fn=<MseLossBackward0>)
tensor(10.1008, grad_fn=<MseLossBackward0>)
tensor(176.3932, grad_fn=<MseLossBackward0>)
tensor(30.8566, grad_fn=<MseLossBackward0>)
tensor(10.4134, grad_fn=<MseLossBackward0>)
tensor(8.3411, grad_fn=<MseLossBackward0>)
tensor(11.9392, grad_fn=<MseLossBackward0>)
tensor(6.5342, grad_fn=<MseLossBackward0>)
tensor(0.0349, grad_fn=<MseLossBackward0>)
tensor(121.7197, grad_fn=<MseLossBackward0>)
tensor(5.1057, grad_fn=<MseLossBackward0>)
tensor(0.3064, grad_fn=<MseLossBackward0>)
tensor(2.3757, grad_fn=<MseLossBackward0>)
tensor(8.0425, grad_fn=<MseLossBackward0>)
tensor(0.1776, grad_fn=<MseLossBackward0>)
tensor(0.3585, grad_fn=<MseLossBackward0

tensor(0.0038, grad_fn=<MseLossBackward0>)
tensor(37.3783, grad_fn=<MseLossBackward0>)
tensor(11.9165, grad_fn=<MseLossBackward0>)
tensor(0.3358, grad_fn=<MseLossBackward0>)
tensor(21.4014, grad_fn=<MseLossBackward0>)
tensor(22.5213, grad_fn=<MseLossBackward0>)
tensor(32.5348, grad_fn=<MseLossBackward0>)
tensor(13.5212, grad_fn=<MseLossBackward0>)
tensor(35.8678, grad_fn=<MseLossBackward0>)
tensor(64.1717, grad_fn=<MseLossBackward0>)
tensor(100.7392, grad_fn=<MseLossBackward0>)
tensor(9.2725, grad_fn=<MseLossBackward0>)
tensor(76.9320, grad_fn=<MseLossBackward0>)
tensor(2.9236, grad_fn=<MseLossBackward0>)
tensor(10.0528, grad_fn=<MseLossBackward0>)
tensor(0.5060, grad_fn=<MseLossBackward0>)
tensor(1.8750, grad_fn=<MseLossBackward0>)
tensor(42.2083, grad_fn=<MseLossBackward0>)
tensor(0.0927, grad_fn=<MseLossBackward0>)
tensor(73.8835, grad_fn=<MseLossBackward0>)
tensor(6.5389, grad_fn=<MseLossBackward0>)
tensor(90.4332, grad_fn=<MseLossBackward0>)
tensor(7.7362, grad_fn=<MseLossBackward

tensor(97.3260, grad_fn=<MseLossBackward0>)
tensor(44.4827, grad_fn=<MseLossBackward0>)
tensor(0.0025, grad_fn=<MseLossBackward0>)
tensor(3.7944, grad_fn=<MseLossBackward0>)
tensor(0.2149, grad_fn=<MseLossBackward0>)
tensor(7.7252, grad_fn=<MseLossBackward0>)
tensor(75.1537, grad_fn=<MseLossBackward0>)
tensor(31.7590, grad_fn=<MseLossBackward0>)
tensor(106.1367, grad_fn=<MseLossBackward0>)
tensor(27.4329, grad_fn=<MseLossBackward0>)
tensor(0.4578, grad_fn=<MseLossBackward0>)
tensor(11.5035, grad_fn=<MseLossBackward0>)
tensor(18.2292, grad_fn=<MseLossBackward0>)
tensor(9.3345, grad_fn=<MseLossBackward0>)
tensor(10.5850, grad_fn=<MseLossBackward0>)
tensor(13.3134, grad_fn=<MseLossBackward0>)
tensor(8.6339, grad_fn=<MseLossBackward0>)
tensor(201.1362, grad_fn=<MseLossBackward0>)
tensor(110.5344, grad_fn=<MseLossBackward0>)
tensor(33.7601, grad_fn=<MseLossBackward0>)
tensor(4.3425, grad_fn=<MseLossBackward0>)
tensor(42.2661, grad_fn=<MseLossBackward0>)
tensor(0.3691, grad_fn=<MseLossBackwa

tensor(0.2422, grad_fn=<MseLossBackward0>)
tensor(6.2656, grad_fn=<MseLossBackward0>)
tensor(4.3000, grad_fn=<MseLossBackward0>)
tensor(48.5238, grad_fn=<MseLossBackward0>)
tensor(35.7283, grad_fn=<MseLossBackward0>)
tensor(30.2194, grad_fn=<MseLossBackward0>)
tensor(25.3213, grad_fn=<MseLossBackward0>)
tensor(85.0151, grad_fn=<MseLossBackward0>)
tensor(0.3480, grad_fn=<MseLossBackward0>)
tensor(52.8575, grad_fn=<MseLossBackward0>)
tensor(174.0555, grad_fn=<MseLossBackward0>)
tensor(107.8873, grad_fn=<MseLossBackward0>)
tensor(163.9895, grad_fn=<MseLossBackward0>)
tensor(2.8180, grad_fn=<MseLossBackward0>)
tensor(75.9817, grad_fn=<MseLossBackward0>)
tensor(1.7249, grad_fn=<MseLossBackward0>)
tensor(17.3649, grad_fn=<MseLossBackward0>)
tensor(124.1537, grad_fn=<MseLossBackward0>)
tensor(60.9111, grad_fn=<MseLossBackward0>)
tensor(55.9939, grad_fn=<MseLossBackward0>)
tensor(40.5879, grad_fn=<MseLossBackward0>)
tensor(58.1097, grad_fn=<MseLossBackward0>)
tensor(25.3880, grad_fn=<MseLossBa

tensor(92.7347, grad_fn=<MseLossBackward0>)
tensor(0.5442, grad_fn=<MseLossBackward0>)
tensor(16.6205, grad_fn=<MseLossBackward0>)
tensor(16.6294, grad_fn=<MseLossBackward0>)
tensor(26.9138, grad_fn=<MseLossBackward0>)
tensor(18.7207, grad_fn=<MseLossBackward0>)
tensor(0.2408, grad_fn=<MseLossBackward0>)
tensor(9.7665, grad_fn=<MseLossBackward0>)
tensor(65.8477, grad_fn=<MseLossBackward0>)
tensor(5.3251, grad_fn=<MseLossBackward0>)
tensor(5.8185, grad_fn=<MseLossBackward0>)
tensor(0.4201, grad_fn=<MseLossBackward0>)
tensor(11.0867, grad_fn=<MseLossBackward0>)
tensor(67.5781, grad_fn=<MseLossBackward0>)
tensor(0.0468, grad_fn=<MseLossBackward0>)
tensor(3.8609, grad_fn=<MseLossBackward0>)
tensor(12.9739, grad_fn=<MseLossBackward0>)
tensor(1.6309, grad_fn=<MseLossBackward0>)
tensor(15.4262, grad_fn=<MseLossBackward0>)
tensor(81.7072, grad_fn=<MseLossBackward0>)
tensor(41.0340, grad_fn=<MseLossBackward0>)
tensor(0.0693, grad_fn=<MseLossBackward0>)
tensor(0.9987, grad_fn=<MseLossBackward0>)

tensor(14.4521, grad_fn=<MseLossBackward0>)
tensor(23.6071, grad_fn=<MseLossBackward0>)
tensor(3.2597, grad_fn=<MseLossBackward0>)
tensor(14.4284, grad_fn=<MseLossBackward0>)
tensor(68.3327, grad_fn=<MseLossBackward0>)
tensor(27.7717, grad_fn=<MseLossBackward0>)
tensor(0.1501, grad_fn=<MseLossBackward0>)
tensor(0.2713, grad_fn=<MseLossBackward0>)
tensor(2.8417, grad_fn=<MseLossBackward0>)
tensor(17.4431, grad_fn=<MseLossBackward0>)
tensor(12.3836, grad_fn=<MseLossBackward0>)
tensor(32.2852, grad_fn=<MseLossBackward0>)
tensor(1.9620, grad_fn=<MseLossBackward0>)
tensor(0.4216, grad_fn=<MseLossBackward0>)
tensor(0.2690, grad_fn=<MseLossBackward0>)
tensor(0.9283, grad_fn=<MseLossBackward0>)
tensor(95.3672, grad_fn=<MseLossBackward0>)
tensor(11.6965, grad_fn=<MseLossBackward0>)
tensor(37.1575, grad_fn=<MseLossBackward0>)
tensor(2.3494, grad_fn=<MseLossBackward0>)
tensor(4.2804, grad_fn=<MseLossBackward0>)
tensor(30.5943, grad_fn=<MseLossBackward0>)
tensor(3.0361, grad_fn=<MseLossBackward0>)

tensor(5.0394, grad_fn=<MseLossBackward0>)
tensor(3.9485, grad_fn=<MseLossBackward0>)
tensor(7.4743, grad_fn=<MseLossBackward0>)
tensor(98.0048, grad_fn=<MseLossBackward0>)
tensor(193.3756, grad_fn=<MseLossBackward0>)
tensor(17.7899, grad_fn=<MseLossBackward0>)
tensor(3.0657, grad_fn=<MseLossBackward0>)
tensor(174.0266, grad_fn=<MseLossBackward0>)
tensor(50.4245, grad_fn=<MseLossBackward0>)
tensor(2.6347, grad_fn=<MseLossBackward0>)
tensor(62.5725, grad_fn=<MseLossBackward0>)
tensor(1.9294, grad_fn=<MseLossBackward0>)
tensor(0.0382, grad_fn=<MseLossBackward0>)
tensor(0.2675, grad_fn=<MseLossBackward0>)
tensor(5.9448, grad_fn=<MseLossBackward0>)
tensor(33.0840, grad_fn=<MseLossBackward0>)
tensor(120.3036, grad_fn=<MseLossBackward0>)
tensor(3.7597, grad_fn=<MseLossBackward0>)
tensor(45.6877, grad_fn=<MseLossBackward0>)
tensor(0.0185, grad_fn=<MseLossBackward0>)
tensor(0.0051, grad_fn=<MseLossBackward0>)
tensor(27.8141, grad_fn=<MseLossBackward0>)
tensor(0.8647, grad_fn=<MseLossBackward0>

tensor(1.7695, grad_fn=<MseLossBackward0>)
tensor(12.5448, grad_fn=<MseLossBackward0>)
tensor(29.6079, grad_fn=<MseLossBackward0>)
tensor(10.8641, grad_fn=<MseLossBackward0>)
tensor(5.5421, grad_fn=<MseLossBackward0>)
tensor(4.1171, grad_fn=<MseLossBackward0>)
tensor(0.0110, grad_fn=<MseLossBackward0>)
tensor(17.5113, grad_fn=<MseLossBackward0>)
tensor(14.0049, grad_fn=<MseLossBackward0>)
tensor(0.4256, grad_fn=<MseLossBackward0>)
tensor(0.1530, grad_fn=<MseLossBackward0>)
tensor(47.5491, grad_fn=<MseLossBackward0>)
tensor(18.3403, grad_fn=<MseLossBackward0>)
tensor(19.2363, grad_fn=<MseLossBackward0>)
tensor(10.6867, grad_fn=<MseLossBackward0>)
tensor(9.8577, grad_fn=<MseLossBackward0>)
tensor(57.4610, grad_fn=<MseLossBackward0>)
tensor(2.2978, grad_fn=<MseLossBackward0>)
tensor(3.9462, grad_fn=<MseLossBackward0>)
tensor(67.4466, grad_fn=<MseLossBackward0>)
tensor(52.7336, grad_fn=<MseLossBackward0>)
tensor(16.4606, grad_fn=<MseLossBackward0>)
tensor(5.0538, grad_fn=<MseLossBackward0>

tensor(0.3331, grad_fn=<MseLossBackward0>)
tensor(16.3646, grad_fn=<MseLossBackward0>)
tensor(26.2130, grad_fn=<MseLossBackward0>)
tensor(1.2470, grad_fn=<MseLossBackward0>)
tensor(20.2146, grad_fn=<MseLossBackward0>)
tensor(6.6455, grad_fn=<MseLossBackward0>)
tensor(19.2004, grad_fn=<MseLossBackward0>)
tensor(25.5511, grad_fn=<MseLossBackward0>)
tensor(63.1829, grad_fn=<MseLossBackward0>)
tensor(3.3069, grad_fn=<MseLossBackward0>)
tensor(2.9489, grad_fn=<MseLossBackward0>)
tensor(6.6265, grad_fn=<MseLossBackward0>)
tensor(51.6511, grad_fn=<MseLossBackward0>)
tensor(0.0079, grad_fn=<MseLossBackward0>)
tensor(27.2967, grad_fn=<MseLossBackward0>)
tensor(78.0746, grad_fn=<MseLossBackward0>)
tensor(7.6638, grad_fn=<MseLossBackward0>)
tensor(92.1220, grad_fn=<MseLossBackward0>)
tensor(0.0098, grad_fn=<MseLossBackward0>)
tensor(1.6879, grad_fn=<MseLossBackward0>)
tensor(1.0835, grad_fn=<MseLossBackward0>)
tensor(45.7108, grad_fn=<MseLossBackward0>)
tensor(146.1382, grad_fn=<MseLossBackward0>

tensor(19.5172, grad_fn=<MseLossBackward0>)
tensor(9.9445, grad_fn=<MseLossBackward0>)
tensor(6.0673, grad_fn=<MseLossBackward0>)
tensor(59.2418, grad_fn=<MseLossBackward0>)
tensor(17.6534, grad_fn=<MseLossBackward0>)
tensor(1.7772, grad_fn=<MseLossBackward0>)
tensor(1.0911, grad_fn=<MseLossBackward0>)
tensor(12.8575, grad_fn=<MseLossBackward0>)
tensor(52.3967, grad_fn=<MseLossBackward0>)
tensor(24.0052, grad_fn=<MseLossBackward0>)
tensor(0.1215, grad_fn=<MseLossBackward0>)
tensor(78.8136, grad_fn=<MseLossBackward0>)
tensor(1.0407, grad_fn=<MseLossBackward0>)
tensor(29.7625, grad_fn=<MseLossBackward0>)
tensor(58.8361, grad_fn=<MseLossBackward0>)
tensor(0.0999, grad_fn=<MseLossBackward0>)
tensor(24.2514, grad_fn=<MseLossBackward0>)
tensor(6.3016, grad_fn=<MseLossBackward0>)
tensor(20.9463, grad_fn=<MseLossBackward0>)
tensor(99.1576, grad_fn=<MseLossBackward0>)
tensor(19.4351, grad_fn=<MseLossBackward0>)
tensor(51.2106, grad_fn=<MseLossBackward0>)
tensor(42.4835, grad_fn=<MseLossBackward

tensor(42.8997, grad_fn=<MseLossBackward0>)
tensor(5.9432, grad_fn=<MseLossBackward0>)
tensor(0.9115, grad_fn=<MseLossBackward0>)
tensor(43.9109, grad_fn=<MseLossBackward0>)
tensor(7.7037, grad_fn=<MseLossBackward0>)
tensor(70.0102, grad_fn=<MseLossBackward0>)
tensor(28.4221, grad_fn=<MseLossBackward0>)
tensor(26.8177, grad_fn=<MseLossBackward0>)
tensor(0.1381, grad_fn=<MseLossBackward0>)
tensor(2.0235, grad_fn=<MseLossBackward0>)
tensor(27.0388, grad_fn=<MseLossBackward0>)
tensor(9.5326e-05, grad_fn=<MseLossBackward0>)
tensor(21.0865, grad_fn=<MseLossBackward0>)
tensor(2.0934, grad_fn=<MseLossBackward0>)
tensor(79.6799, grad_fn=<MseLossBackward0>)
tensor(2.7181, grad_fn=<MseLossBackward0>)
tensor(3.7776, grad_fn=<MseLossBackward0>)
tensor(32.5342, grad_fn=<MseLossBackward0>)
tensor(11.7195, grad_fn=<MseLossBackward0>)
tensor(17.9235, grad_fn=<MseLossBackward0>)
tensor(57.7397, grad_fn=<MseLossBackward0>)
tensor(0.3111, grad_fn=<MseLossBackward0>)
tensor(13.6033, grad_fn=<MseLossBackwa

tensor(23.3153, grad_fn=<MseLossBackward0>)
tensor(80.1667, grad_fn=<MseLossBackward0>)
tensor(68.9977, grad_fn=<MseLossBackward0>)
tensor(0.2416, grad_fn=<MseLossBackward0>)
tensor(11.2558, grad_fn=<MseLossBackward0>)
tensor(39.0098, grad_fn=<MseLossBackward0>)
tensor(57.2153, grad_fn=<MseLossBackward0>)
tensor(158.9219, grad_fn=<MseLossBackward0>)
tensor(6.0612, grad_fn=<MseLossBackward0>)
tensor(12.7560, grad_fn=<MseLossBackward0>)
tensor(32.1253, grad_fn=<MseLossBackward0>)
tensor(16.3109, grad_fn=<MseLossBackward0>)
tensor(35.6311, grad_fn=<MseLossBackward0>)
tensor(35.4324, grad_fn=<MseLossBackward0>)
tensor(1.1957, grad_fn=<MseLossBackward0>)
tensor(19.0952, grad_fn=<MseLossBackward0>)
tensor(3.2989, grad_fn=<MseLossBackward0>)
tensor(63.2111, grad_fn=<MseLossBackward0>)
tensor(36.4894, grad_fn=<MseLossBackward0>)
tensor(0.0404, grad_fn=<MseLossBackward0>)
tensor(17.2296, grad_fn=<MseLossBackward0>)
tensor(113.3850, grad_fn=<MseLossBackward0>)
tensor(1.6861, grad_fn=<MseLossBack

tensor(56.3456, grad_fn=<MseLossBackward0>)
tensor(1.6324, grad_fn=<MseLossBackward0>)
tensor(34.9214, grad_fn=<MseLossBackward0>)
tensor(30.0785, grad_fn=<MseLossBackward0>)
tensor(34.0635, grad_fn=<MseLossBackward0>)
tensor(0.0657, grad_fn=<MseLossBackward0>)
tensor(138.3155, grad_fn=<MseLossBackward0>)
tensor(12.9167, grad_fn=<MseLossBackward0>)
tensor(6.7409, grad_fn=<MseLossBackward0>)
tensor(0.4786, grad_fn=<MseLossBackward0>)
tensor(5.4786, grad_fn=<MseLossBackward0>)
tensor(3.0660, grad_fn=<MseLossBackward0>)
tensor(60.8412, grad_fn=<MseLossBackward0>)
tensor(33.9542, grad_fn=<MseLossBackward0>)
tensor(9.2301, grad_fn=<MseLossBackward0>)
tensor(36.9418, grad_fn=<MseLossBackward0>)
tensor(25.2967, grad_fn=<MseLossBackward0>)
tensor(8.6494, grad_fn=<MseLossBackward0>)
tensor(60.1326, grad_fn=<MseLossBackward0>)
tensor(48.4318, grad_fn=<MseLossBackward0>)
tensor(57.6464, grad_fn=<MseLossBackward0>)
tensor(0.3189, grad_fn=<MseLossBackward0>)
tensor(32.7258, grad_fn=<MseLossBackward

tensor(9.4925, grad_fn=<MseLossBackward0>)
tensor(5.3558, grad_fn=<MseLossBackward0>)
tensor(1.2671, grad_fn=<MseLossBackward0>)
tensor(79.1498, grad_fn=<MseLossBackward0>)
tensor(12.6716, grad_fn=<MseLossBackward0>)
tensor(35.4664, grad_fn=<MseLossBackward0>)
tensor(23.6029, grad_fn=<MseLossBackward0>)
tensor(2.8356, grad_fn=<MseLossBackward0>)
tensor(27.5378, grad_fn=<MseLossBackward0>)
tensor(12.0541, grad_fn=<MseLossBackward0>)
tensor(0.3751, grad_fn=<MseLossBackward0>)
tensor(12.2490, grad_fn=<MseLossBackward0>)
tensor(0.1905, grad_fn=<MseLossBackward0>)
tensor(6.8815, grad_fn=<MseLossBackward0>)
tensor(35.1926, grad_fn=<MseLossBackward0>)
tensor(30.5935, grad_fn=<MseLossBackward0>)
tensor(36.2530, grad_fn=<MseLossBackward0>)
tensor(69.3398, grad_fn=<MseLossBackward0>)
tensor(4.6642, grad_fn=<MseLossBackward0>)
tensor(16.9497, grad_fn=<MseLossBackward0>)
tensor(4.5402, grad_fn=<MseLossBackward0>)
tensor(1.3734, grad_fn=<MseLossBackward0>)
tensor(1.2742, grad_fn=<MseLossBackward0>)

tensor(3.4340, grad_fn=<MseLossBackward0>)
tensor(4.7507, grad_fn=<MseLossBackward0>)
tensor(42.2732, grad_fn=<MseLossBackward0>)
tensor(4.1817, grad_fn=<MseLossBackward0>)
tensor(136.0182, grad_fn=<MseLossBackward0>)
tensor(165.3613, grad_fn=<MseLossBackward0>)
tensor(38.1646, grad_fn=<MseLossBackward0>)
tensor(2.4807, grad_fn=<MseLossBackward0>)
tensor(213.0913, grad_fn=<MseLossBackward0>)
tensor(5.8119, grad_fn=<MseLossBackward0>)
tensor(8.4690, grad_fn=<MseLossBackward0>)
tensor(12.8557, grad_fn=<MseLossBackward0>)
tensor(13.9147, grad_fn=<MseLossBackward0>)
tensor(3.6635, grad_fn=<MseLossBackward0>)
tensor(33.9375, grad_fn=<MseLossBackward0>)
tensor(29.4235, grad_fn=<MseLossBackward0>)
tensor(1.9244, grad_fn=<MseLossBackward0>)
tensor(4.0647, grad_fn=<MseLossBackward0>)
tensor(17.1617, grad_fn=<MseLossBackward0>)
tensor(0.4327, grad_fn=<MseLossBackward0>)
tensor(1.2385, grad_fn=<MseLossBackward0>)
tensor(0.2528, grad_fn=<MseLossBackward0>)
tensor(23.8324, grad_fn=<MseLossBackward0

tensor(12.3126, grad_fn=<MseLossBackward0>)
tensor(21.1403, grad_fn=<MseLossBackward0>)
tensor(2.7449, grad_fn=<MseLossBackward0>)
tensor(21.3853, grad_fn=<MseLossBackward0>)
tensor(33.0142, grad_fn=<MseLossBackward0>)
tensor(0.6882, grad_fn=<MseLossBackward0>)
tensor(0.1306, grad_fn=<MseLossBackward0>)
tensor(115.0601, grad_fn=<MseLossBackward0>)
tensor(0.3263, grad_fn=<MseLossBackward0>)
tensor(14.2130, grad_fn=<MseLossBackward0>)
tensor(92.0408, grad_fn=<MseLossBackward0>)
tensor(53.7495, grad_fn=<MseLossBackward0>)
tensor(3.1526, grad_fn=<MseLossBackward0>)
tensor(0.0557, grad_fn=<MseLossBackward0>)
tensor(0.9222, grad_fn=<MseLossBackward0>)
tensor(0.3532, grad_fn=<MseLossBackward0>)
tensor(1.6460, grad_fn=<MseLossBackward0>)
tensor(30.3417, grad_fn=<MseLossBackward0>)
tensor(0.0197, grad_fn=<MseLossBackward0>)
tensor(74.6050, grad_fn=<MseLossBackward0>)
tensor(9.1193, grad_fn=<MseLossBackward0>)
tensor(25.1928, grad_fn=<MseLossBackward0>)
tensor(42.8760, grad_fn=<MseLossBackward0>

tensor(11.3346, grad_fn=<MseLossBackward0>)
tensor(1.8049, grad_fn=<MseLossBackward0>)
tensor(11.9853, grad_fn=<MseLossBackward0>)
tensor(33.9245, grad_fn=<MseLossBackward0>)
tensor(25.3602, grad_fn=<MseLossBackward0>)
tensor(2.1103, grad_fn=<MseLossBackward0>)
tensor(0.1218, grad_fn=<MseLossBackward0>)
tensor(70.6119, grad_fn=<MseLossBackward0>)
tensor(20.5058, grad_fn=<MseLossBackward0>)
tensor(32.9792, grad_fn=<MseLossBackward0>)
tensor(3.6053, grad_fn=<MseLossBackward0>)
tensor(24.0632, grad_fn=<MseLossBackward0>)
tensor(4.2592, grad_fn=<MseLossBackward0>)
tensor(6.3765, grad_fn=<MseLossBackward0>)
tensor(58.7433, grad_fn=<MseLossBackward0>)
tensor(8.5295, grad_fn=<MseLossBackward0>)
tensor(1.3971, grad_fn=<MseLossBackward0>)
tensor(2.5161, grad_fn=<MseLossBackward0>)
tensor(120.3008, grad_fn=<MseLossBackward0>)
tensor(1.2903, grad_fn=<MseLossBackward0>)
tensor(14.5179, grad_fn=<MseLossBackward0>)
tensor(24.0781, grad_fn=<MseLossBackward0>)
tensor(19.3643, grad_fn=<MseLossBackward0

tensor(114.9867, grad_fn=<MseLossBackward0>)
tensor(32.0056, grad_fn=<MseLossBackward0>)
tensor(7.8338, grad_fn=<MseLossBackward0>)
tensor(8.5070, grad_fn=<MseLossBackward0>)
tensor(47.6139, grad_fn=<MseLossBackward0>)
tensor(4.8782, grad_fn=<MseLossBackward0>)
tensor(156.3526, grad_fn=<MseLossBackward0>)
tensor(13.9844, grad_fn=<MseLossBackward0>)
tensor(0.7525, grad_fn=<MseLossBackward0>)
tensor(56.4775, grad_fn=<MseLossBackward0>)
tensor(105.9069, grad_fn=<MseLossBackward0>)
tensor(11.5639, grad_fn=<MseLossBackward0>)
tensor(54.5277, grad_fn=<MseLossBackward0>)
tensor(8.8292, grad_fn=<MseLossBackward0>)
tensor(0.0981, grad_fn=<MseLossBackward0>)
tensor(1.4173, grad_fn=<MseLossBackward0>)
tensor(12.1485, grad_fn=<MseLossBackward0>)
tensor(1.7815, grad_fn=<MseLossBackward0>)
tensor(57.5867, grad_fn=<MseLossBackward0>)
tensor(31.1428, grad_fn=<MseLossBackward0>)
tensor(20.0759, grad_fn=<MseLossBackward0>)
tensor(5.1655, grad_fn=<MseLossBackward0>)
tensor(132.6742, grad_fn=<MseLossBackw

tensor(1.5441, grad_fn=<MseLossBackward0>)
tensor(56.0491, grad_fn=<MseLossBackward0>)
tensor(6.7871, grad_fn=<MseLossBackward0>)
tensor(3.5334, grad_fn=<MseLossBackward0>)
tensor(9.3507, grad_fn=<MseLossBackward0>)
tensor(19.6010, grad_fn=<MseLossBackward0>)
tensor(46.9683, grad_fn=<MseLossBackward0>)
tensor(20.4427, grad_fn=<MseLossBackward0>)
tensor(140.2402, grad_fn=<MseLossBackward0>)
tensor(15.7653, grad_fn=<MseLossBackward0>)
tensor(20.5933, grad_fn=<MseLossBackward0>)
tensor(2.2507, grad_fn=<MseLossBackward0>)
tensor(1.9460, grad_fn=<MseLossBackward0>)
tensor(0.4010, grad_fn=<MseLossBackward0>)
tensor(0.4504, grad_fn=<MseLossBackward0>)
tensor(22.4418, grad_fn=<MseLossBackward0>)
tensor(2.1779, grad_fn=<MseLossBackward0>)
tensor(40.3746, grad_fn=<MseLossBackward0>)
tensor(85.9667, grad_fn=<MseLossBackward0>)
tensor(20.4500, grad_fn=<MseLossBackward0>)
tensor(10.0865, grad_fn=<MseLossBackward0>)
tensor(0.6332, grad_fn=<MseLossBackward0>)
tensor(2.4421, grad_fn=<MseLossBackward0>

tensor(74.9508, grad_fn=<MseLossBackward0>)
tensor(1.0215, grad_fn=<MseLossBackward0>)
tensor(16.8300, grad_fn=<MseLossBackward0>)
tensor(0.2814, grad_fn=<MseLossBackward0>)
tensor(70.0185, grad_fn=<MseLossBackward0>)
tensor(19.8629, grad_fn=<MseLossBackward0>)
tensor(49.1724, grad_fn=<MseLossBackward0>)
tensor(15.3725, grad_fn=<MseLossBackward0>)
tensor(29.6144, grad_fn=<MseLossBackward0>)
tensor(13.4246, grad_fn=<MseLossBackward0>)
tensor(1.8265, grad_fn=<MseLossBackward0>)
tensor(55.7498, grad_fn=<MseLossBackward0>)
tensor(9.9750, grad_fn=<MseLossBackward0>)
tensor(1.4956, grad_fn=<MseLossBackward0>)
tensor(9.3972, grad_fn=<MseLossBackward0>)
tensor(15.8865, grad_fn=<MseLossBackward0>)
tensor(45.7286, grad_fn=<MseLossBackward0>)
tensor(23.3690, grad_fn=<MseLossBackward0>)
tensor(0.6517, grad_fn=<MseLossBackward0>)
tensor(47.7871, grad_fn=<MseLossBackward0>)
tensor(24.9481, grad_fn=<MseLossBackward0>)
tensor(5.2030, grad_fn=<MseLossBackward0>)
tensor(7.3752, grad_fn=<MseLossBackward0

tensor(122.8831, grad_fn=<MseLossBackward0>)
tensor(0.7227, grad_fn=<MseLossBackward0>)
tensor(0.0318, grad_fn=<MseLossBackward0>)
tensor(0.4712, grad_fn=<MseLossBackward0>)
tensor(0.0130, grad_fn=<MseLossBackward0>)
tensor(8.2868, grad_fn=<MseLossBackward0>)
tensor(33.5225, grad_fn=<MseLossBackward0>)
tensor(214.1796, grad_fn=<MseLossBackward0>)
tensor(5.6002, grad_fn=<MseLossBackward0>)
tensor(97.2495, grad_fn=<MseLossBackward0>)
tensor(5.1923, grad_fn=<MseLossBackward0>)
tensor(105.6685, grad_fn=<MseLossBackward0>)
tensor(0.1595, grad_fn=<MseLossBackward0>)
tensor(105.4074, grad_fn=<MseLossBackward0>)
tensor(2.3545, grad_fn=<MseLossBackward0>)
tensor(3.1703, grad_fn=<MseLossBackward0>)
tensor(81.4070, grad_fn=<MseLossBackward0>)
tensor(26.0209, grad_fn=<MseLossBackward0>)
tensor(58.8792, grad_fn=<MseLossBackward0>)
tensor(9.4876, grad_fn=<MseLossBackward0>)
tensor(69.5343, grad_fn=<MseLossBackward0>)
tensor(7.3379, grad_fn=<MseLossBackward0>)
tensor(209.6597, grad_fn=<MseLossBackwar