In [108]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [187]:
class SphereDatasetAnchor(Dataset):
    def __init__(self, extra_modes, num_points, transform_fn=None):
        self.n = num_points
        self.m = extra_modes + 1
        self.points = np.random.uniform(-1,1,size=(self.n, 3))
        mags = np.linalg.norm(self.points, axis=-1)
        self.points /= mags[...,None]
        self.disks = np.array([np.random.uniform(-1,1,size=(3)) for i in range(extra_modes)])
        self.disks = np.array([disk / np.linalg.norm(disk) for disk in self.disks])
        self.disk_points = torch.tensor([self.project_points(self.points, disk) for disk in self.disks]).float()
        self.points = torch.tensor(self.points).float()
        self.disks = torch.tensor(self.disks).float()
        self.transform_fn = transform_fn
        
    def project_points(self, points, disk):
        output_points = []
        for point in points:
            arrow = point*1/np.dot(point, disk) - disk
            arrow /= np.linalg.norm(arrow)
            output_points.append(arrow)
        return np.array(output_points)
    
    def __len__(self):
        return self.m-1
    
    def __getitem__(self, idx):
        e_points = self.disk_points[idx]
        if self.transform_fn is not None:
            return self.transform_fn(self.points), self.transform_fn(e_points)
        return self.points, e_points
    
    def all_points(self):
        return [self.points] + list(self.disk_points)

In [188]:
class Mapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 128)
    
    def forward(self, X):
        return self.linear(X)

mapper = Mapper()

In [189]:
dataset = SphereDatasetAnchor(3, 1000, transform_fn=mapper)

In [195]:
from torch.nn.functional import cross_entropy
class ContrastiveModel(nn.Module):
    def __init__(self, modes):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(128, 3) for i in range(modes)])
    
    def loss(self, corr_mat):
        losses_row = cross_entropy(corr_mat, torch.arange(0, len(corr_mat)).long())
        losses_col = cross_entropy(torch.transpose(corr_mat,0,1), torch.arange(0, len(corr_mat)).long())
        return torch.mean(losses_row) + torch.mean(losses_col)
    
    def forward(self, a, b, a_points, b_points):
        a_points = self.linears[a](a_points)
        b_points = self.linears[b](b_points)
        corr_mat = torch.sum(a_points[:,None] * b_points[None,:], dim=-1)
        return corr_mat
    
    def run_evaluate(self, i, points):
        return self.linears[i](points)

In [196]:
model = ContrastiveModel(4)

In [197]:
from torch.optim import Adam
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
optim = Adam(model.parameters(), lr=0.01)

In [198]:
EPOCH = 1000
for epoch in range(EPOCH):
    avg_loss = 0
    ptrs1, ptrs2 = dataset[0]
    optim.zero_grad()
    corr_mat = model(0, 1, ptrs1, ptrs2)
    print(corr_mat)
    loss = model.loss(corr_mat)
    print(loss)
    loss.backward()
    optim.step()
    avg_loss += loss.item()/len(dataloader)
    if epoch % 10 == 0:
        print(f'epoch: {epoch}, loss: {avg_loss}')

tensor([[-0.2147, -0.2226, -0.1209,  ..., -0.2117, -0.2361, -0.2359],
        [-0.2432, -0.2517, -0.1367,  ..., -0.2398, -0.2672, -0.2671],
        [-0.0854, -0.0601, -0.1434,  ..., -0.0877, -0.0579, -0.0656],
        ...,
        [-0.2057, -0.2108, -0.1224,  ..., -0.2031, -0.2234, -0.2240],
        [ 0.1946,  0.2656, -0.1119,  ...,  0.1838,  0.2961,  0.2788],
        [-0.1187, -0.0873, -0.1648,  ..., -0.1207, -0.0889, -0.0993]],
       grad_fn=<SumBackward1>)
tensor(13.8340, grad_fn=<AddBackward0>)
epoch: 0, loss: 4.611321449279785
tensor([[ 0.1925,  0.1165,  0.0604,  ...,  0.1903,  0.1618,  0.1941],
        [ 0.2416,  0.1518,  0.0885,  ...,  0.2391,  0.2048,  0.2428],
        [ 0.0992,  0.0571,  0.0102,  ...,  0.0975,  0.0849,  0.1032],
        ...,
        [ 0.1794,  0.1075,  0.0541,  ...,  0.1773,  0.1504,  0.1810],
        [-0.2039, -0.1505, -0.1602,  ..., -0.2038, -0.1740, -0.1952],
        [ 0.0710,  0.0360,  0.0068,  ...,  0.0699,  0.0574,  0.0724]],
       grad_fn=<SumBackward

tensor([[ 0.0977,  0.0272,  0.0096,  ...,  0.0966,  0.0633,  0.0922],
        [ 0.1223,  0.0368,  0.0136,  ...,  0.1210,  0.0810,  0.1161],
        [ 0.0198, -0.0021,  0.0040,  ...,  0.0198,  0.0072,  0.0158],
        ...,
        [ 0.0831,  0.0230,  0.0058,  ...,  0.0822,  0.0542,  0.0789],
        [-0.1650, -0.0747, -0.0139,  ..., -0.1625, -0.1275, -0.1656],
        [-0.0748, -0.0198, -0.0381,  ..., -0.0749, -0.0426, -0.0642]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.0890,  0.0256, -0.0010,  ...,  0.0878,  0.0599,  0.0863],
        [ 0.1112,  0.0354,  0.0019,  ...,  0.1097,  0.0767,  0.1083],
        [ 0.0193, -0.0035,  0.0005,  ...,  0.0193,  0.0066,  0.0156],
        ...,
        [ 0.0760,  0.0211, -0.0036,  ...,  0.0748,  0.0511,  0.0740],
        [-0.1465, -0.0762, -0.0046,  ..., -0.1438, -0.1214, -0.1518],
        [-0.0659, -0.0245, -0.0335,  ..., -0.0659, -0.0424, -0.0589]],
       grad_fn=<SumBackward1>)
tensor(13.8099, grad_fn=<AddBa

tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1004,  0.0525, -0.0017,  ...,  0.0984,  0.0842,  0.1051],
        [ 0.1233,  0.0581,  0.0046,  ...,  0.1213,  0.0978,  0.1257],
        [ 0.0172,  0.0102, -0.0049,  ...,  0.0167,  0.0161,  0.0193],
        ...,
        [ 0.0869,  0.0489, -0.0054,  ...,  0.0850,  0.0760,  0.0929],
        [-0.1698, -0.0660, -0.0289,  ..., -0.1679, -0.1211, -0.1640],
        [-0.0608,  0.0076, -0.0430,  ..., -0.0618, -0.0160, -0.0420]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1040,  0.0528, -0.0019,  ...,  0.1019,  0.0861,  0.1084],
        [ 0.1279,  0.0584,  0.0050,  ...,  0.1257,  0.1001,  0.1297],
        [ 0.0181,  0.0105, -0.0059,  ...,  0.0175,  0.0168,  0.0204],
        ...,
        [ 0.0899,  0.0493, -0.0059,  ...,  0.0879,  0.0777,  0.0957],
        [-0.1757, -0.0660, -0.0326,  ..., -0.1739, -0.1233, -0.1684],
        [-0.0640,  0.0082, -0.0476,  ..., -0.0651, -0.0163, -0.0437]],
       grad_fn=<SumBa

tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1033,  0.0439,  0.0048,  ...,  0.1017,  0.0785,  0.1036],
        [ 0.1254,  0.0515,  0.0088,  ...,  0.1235,  0.0935,  0.1245],
        [ 0.0184,  0.0085, -0.0047,  ...,  0.0179,  0.0154,  0.0198],
        ...,
        [ 0.0898,  0.0401,  0.0014,  ...,  0.0883,  0.0700,  0.0912],
        [-0.1678, -0.0652, -0.0291,  ..., -0.1659, -0.1196, -0.1620],
        [-0.0591, -0.0016, -0.0360,  ..., -0.0597, -0.0228, -0.0449]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1048,  0.0445,  0.0052,  ...,  0.1032,  0.0795,  0.1050],
        [ 0.1273,  0.0520,  0.0094,  ...,  0.1254,  0.0946,  0.1262],
        [ 0.0186,  0.0087, -0.0049,  ...,  0.0181,  0.0156,  0.0200],
        ...,
        [ 0.0911,  0.0407,  0.0017,  ...,  0.0895,  0.0710,  0.0924],
        [-0.1707, -0.0656, -0.0304,  ..., -0.1688, -0.1209, -0.1643],
        [-0.0604, -0.0008, -0.0376,  ..., -0.0610, -0.0226, -0.0455]],
       grad_fn=<SumBa

tensor([[ 0.1046,  0.0446,  0.0052,  ...,  0.1030,  0.0795,  0.1049],
        [ 0.1268,  0.0515,  0.0102,  ...,  0.1250,  0.0939,  0.1255],
        [ 0.0192,  0.0080, -0.0040,  ...,  0.0187,  0.0153,  0.0202],
        ...,
        [ 0.0911,  0.0408,  0.0018,  ...,  0.0895,  0.0711,  0.0925],
        [-0.1681, -0.0648, -0.0314,  ..., -0.1664, -0.1191, -0.1617],
        [-0.0583, -0.0016, -0.0349,  ..., -0.0589, -0.0226, -0.0445]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1038,  0.0446,  0.0043,  ...,  0.1022,  0.0793,  0.1043],
        [ 0.1260,  0.0514,  0.0092,  ...,  0.1241,  0.0937,  0.1249],
        [ 0.0190,  0.0080, -0.0042,  ...,  0.0186,  0.0153,  0.0200],
        ...,
        [ 0.0903,  0.0408,  0.0010,  ...,  0.0888,  0.0708,  0.0920],
        [-0.1673, -0.0646, -0.0302,  ..., -0.1655, -0.1187, -0.1611],
        [-0.0583, -0.0015, -0.0350,  ..., -0.0588, -0.0226, -0.0444]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor([[ 0.1047,  0.0450,  0.0052,  ...,  0.1031,  0.0798,  0.1051],
        [ 0.1272,  0.0522,  0.0100,  ...,  0.1253,  0.0946,  0.1260],
        [ 0.0193,  0.0086, -0.0041,  ...,  0.0188,  0.0158,  0.0204],
        ...,
        [ 0.0911,  0.0412,  0.0019,  ...,  0.0896,  0.0713,  0.0926],
        [-0.1688, -0.0652, -0.0309,  ..., -0.1670, -0.1198, -0.1625],
        [-0.0586, -0.0009, -0.0353,  ..., -0.0592, -0.0222, -0.0444]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1043,  0.0450,  0.0050,  ...,  0.1026,  0.0797,  0.1047],
        [ 0.1266,  0.0522,  0.0097,  ...,  0.1248,  0.0944,  0.1257],
        [ 0.0191,  0.0086, -0.0041,  ...,  0.0187,  0.0157,  0.0203],
        ...,
        [ 0.0907,  0.0412,  0.0017,  ...,  0.0891,  0.0712,  0.0923],
        [-0.1684, -0.0654, -0.0305,  ..., -0.1666, -0.1197, -0.1623],
        [-0.0586, -0.0010, -0.0353,  ..., -0.0592, -0.0223, -0.0445]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

epoch: 130, loss: 4.603254318237305
tensor([[ 0.1044,  0.0453,  0.0049,  ...,  0.1028,  0.0800,  0.1050],
        [ 0.1268,  0.0524,  0.0097,  ...,  0.1250,  0.0946,  0.1259],
        [ 0.0191,  0.0086, -0.0041,  ...,  0.0187,  0.0157,  0.0203],
        ...,
        [ 0.0909,  0.0415,  0.0015,  ...,  0.0893,  0.0715,  0.0926],
        [-0.1686, -0.0651, -0.0308,  ..., -0.1668, -0.1196, -0.1623],
        [-0.0586, -0.0005, -0.0356,  ..., -0.0592, -0.0219, -0.0442]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1044,  0.0454,  0.0048,  ...,  0.1027,  0.0800,  0.1050],
        [ 0.1268,  0.0524,  0.0096,  ...,  0.1249,  0.0946,  0.1258],
        [ 0.0191,  0.0086, -0.0041,  ...,  0.0187,  0.0157,  0.0203],
        ...,
        [ 0.0908,  0.0416,  0.0014,  ...,  0.0892,  0.0715,  0.0926],
        [-0.1686, -0.0652, -0.0307,  ..., -0.1668, -0.1197, -0.1623],
        [-0.0586, -0.0005, -0.0356,  ..., -0.0593, -0.0219, -0.0443]],
       grad_fn=<SumBackwa

tensor([[ 0.1040,  0.0449,  0.0040,  ...,  0.1023,  0.0796,  0.1046],
        [ 0.1263,  0.0519,  0.0087,  ...,  0.1244,  0.0942,  0.1254],
        [ 0.0191,  0.0085, -0.0043,  ...,  0.0186,  0.0157,  0.0203],
        ...,
        [ 0.0905,  0.0412,  0.0008,  ...,  0.0889,  0.0712,  0.0923],
        [-0.1678, -0.0646, -0.0294,  ..., -0.1659, -0.1191, -0.1617],
        [-0.0583, -0.0003, -0.0352,  ..., -0.0590, -0.0217, -0.0440]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1049,  0.0458,  0.0058,  ...,  0.1033,  0.0804,  0.1054],
        [ 0.1274,  0.0531,  0.0109,  ...,  0.1256,  0.0952,  0.1264],
        [ 0.0192,  0.0087, -0.0040,  ...,  0.0187,  0.0158,  0.0204],
        ...,
        [ 0.0913,  0.0420,  0.0023,  ...,  0.0897,  0.0719,  0.0929],
        [-0.1695, -0.0662, -0.0324,  ..., -0.1677, -0.1205, -0.1631],
        [-0.0591, -0.0010, -0.0364,  ..., -0.0597, -0.0223, -0.0447]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 170, loss: 4.603262901306152
tensor([[ 0.1039,  0.0422, -0.0014,  ...,  0.1021,  0.0786,  0.1047],
        [ 0.1253,  0.0476,  0.0012,  ...,  0.1233,  0.0920,  0.1247],
        [ 0.0215,  0.0108, -0.0026,  ...,  0.0210,  0.0181,  0.0228],
        ...,
        [ 0.0909,  0.0394, -0.0033,  ...,  0.0892,  0.0709,  0.0929],
        [-0.1595, -0.0505, -0.0117,  ..., -0.1575, -0.1083, -0.1534],
        [-0.0519,  0.0090, -0.0248,  ..., -0.0525, -0.0139, -0.0375]],
       grad_fn=<SumBackward1>)
tensor(13.8099, grad_fn=<AddBackward0>)
tensor([[ 0.1081,  0.0525,  0.0120,  ...,  0.1065,  0.0856,  0.1092],
        [ 0.1308,  0.0606,  0.0179,  ...,  0.1290,  0.1008,  0.1304],
        [ 0.0217,  0.0120, -0.0010,  ...,  0.0213,  0.0188,  0.0231],
        ...,
        [ 0.0944,  0.0482,  0.0080,  ...,  0.0929,  0.0767,  0.0966],
        [-0.1685, -0.0701, -0.0368,  ..., -0.1668, -0.1221, -0.1627],
        [-0.0565,  0.0002, -0.0359,  ..., -0.0572, -0.02

tensor([[ 0.1128,  0.0544,  0.0133,  ...,  0.1111,  0.0888,  0.1136],
        [ 0.1349,  0.0614,  0.0180,  ...,  0.1331,  0.1033,  0.1343],
        [ 0.0274,  0.0170,  0.0040,  ...,  0.0270,  0.0241,  0.0287],
        ...,
        [ 0.0993,  0.0506,  0.0099,  ...,  0.0977,  0.0804,  0.1012],
        [-0.1596, -0.0573, -0.0223,  ..., -0.1578, -0.1113, -0.1536],
        [-0.0498,  0.0082, -0.0272,  ..., -0.0504, -0.0131, -0.0354]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 190, loss: 4.603254000345866
tensor([[ 0.1138,  0.0561,  0.0160,  ...,  0.1121,  0.0900,  0.1144],
        [ 0.1361,  0.0635,  0.0213,  ...,  0.1343,  0.1048,  0.1353],
        [ 0.0276,  0.0173,  0.0044,  ...,  0.0272,  0.0243,  0.0289],
        ...,
        [ 0.1001,  0.0520,  0.0122,  ...,  0.0985,  0.0813,  0.1019],
        [-0.1613, -0.0603, -0.0274,  ..., -0.1595, -0.1134, -0.1550],
        [-0.0506,  0.0068, -0.0294,  ..., -0.0512, -0.0141, -0.0361]],
       grad_fn=<SumBackwa

tensor([[ 0.1142,  0.0552,  0.0142,  ...,  0.1125,  0.0899,  0.1148],
        [ 0.1364,  0.0621,  0.0188,  ...,  0.1345,  0.1044,  0.1356],
        [ 0.0293,  0.0188,  0.0060,  ...,  0.0289,  0.0259,  0.0305],
        ...,
        [ 0.1007,  0.0514,  0.0110,  ...,  0.0991,  0.0814,  0.1025],
        [-0.1572, -0.0541, -0.0187,  ..., -0.1553, -0.1086, -0.1511],
        [-0.0479,  0.0101, -0.0247,  ..., -0.0485, -0.0113, -0.0336]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 210, loss: 4.603254318237305
tensor([[ 0.1146,  0.0559,  0.0155,  ...,  0.1130,  0.0903,  0.1152],
        [ 0.1370,  0.0630,  0.0205,  ...,  0.1351,  0.1050,  0.1360],
        [ 0.0293,  0.0188,  0.0061,  ...,  0.0289,  0.0259,  0.0305],
        ...,
        [ 0.1010,  0.0520,  0.0121,  ...,  0.0995,  0.0818,  0.1028],
        [-0.1582, -0.0556, -0.0213,  ..., -0.1565, -0.1097, -0.1521],
        [-0.0485,  0.0094, -0.0258,  ..., -0.0491, -0.0119, -0.0341]],
       grad_fn=<SumBackwa

tensor([[ 0.1146,  0.0555,  0.0148,  ...,  0.1129,  0.0902,  0.1152],
        [ 0.1369,  0.0626,  0.0195,  ...,  0.1350,  0.1048,  0.1360],
        [ 0.0297,  0.0191,  0.0064,  ...,  0.0292,  0.0263,  0.0309],
        ...,
        [ 0.1010,  0.0517,  0.0115,  ...,  0.0995,  0.0818,  0.1028],
        [-0.1573, -0.0542, -0.0191,  ..., -0.1555, -0.1086, -0.1512],
        [-0.0479,  0.0101, -0.0247,  ..., -0.0485, -0.0113, -0.0336]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 230, loss: 4.603254318237305
tensor([[ 0.1148,  0.0558,  0.0152,  ...,  0.1132,  0.0904,  0.1154],
        [ 0.1372,  0.0629,  0.0200,  ...,  0.1353,  0.1051,  0.1363],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0292,  0.0263,  0.0309],
        ...,
        [ 0.1012,  0.0520,  0.0118,  ...,  0.0997,  0.0819,  0.1030],
        [-0.1578, -0.0547, -0.0199,  ..., -0.1559, -0.1090, -0.1516],
        [-0.0481,  0.0099, -0.0251,  ..., -0.0487, -0.0114, -0.0338]],
       grad_fn=<SumBackwa

tensor([[ 0.1148,  0.0558,  0.0151,  ...,  0.1132,  0.0904,  0.1154],
        [ 0.1372,  0.0628,  0.0199,  ...,  0.1353,  0.1051,  0.1363],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0292,  0.0263,  0.0309],
        ...,
        [ 0.1013,  0.0520,  0.0118,  ...,  0.0997,  0.0820,  0.1030],
        [-0.1577, -0.0545, -0.0197,  ..., -0.1559, -0.1089, -0.1515],
        [-0.0480,  0.0101, -0.0250,  ..., -0.0486, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 250, loss: 4.603254000345866
tensor([[ 0.1149,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0629,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0119,  ...,  0.0998,  0.0820,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1561, -0.1091, -0.1517],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackwa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0629,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0119,  ...,  0.0998,  0.0820,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 270, loss: 4.603254318237305
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0547, -0.0201,  ..., -0.1561, -0.1091, -0.1517],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackwa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1091, -0.1517],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 290, loss: 4.603254000345866
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0547, -0.0201,  ..., -0.1561, -0.1091, -0.1517],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackwa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1561, -0.1091, -0.1517],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 310, loss: 4.603254318237305
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1091, -0.1517],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackwa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 330, loss: 4.603254000345866
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1091, -0.1517],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackwa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1091, -0.1517],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 350, loss: 4.603254000345866
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1091, -0.1517],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackwa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1091, -0.1517],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 370, loss: 4.603254318237305
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1091, -0.1517],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackwa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1091, -0.1517],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 390, loss: 4.603254318237305
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1517],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackwa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 410, loss: 4.603254318237305
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackwa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1579, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBa

epoch: 580, loss: 4.603254318237305
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackwa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 600, loss: 4.603254318237305
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackwa

tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0201,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0309],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBa

tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 730, loss: 4.603254318237305
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0905,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.01

tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0906,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
epoch: 750, loss: 4.603254318237305
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0906,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackwa

tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0906,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0153,  ...,  0.1133,  0.0906,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBa

tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1155],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0114, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBa

tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0297,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBa

tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0100, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBa

tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0101, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0101, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBa

tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0101, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0101, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0101, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0101, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0101, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0101, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0101, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBackward0>)
tensor([[ 0.1150,  0.0559,  0.0154,  ...,  0.1133,  0.0906,  0.1156],
        [ 0.1373,  0.0630,  0.0202,  ...,  0.1355,  0.1052,  0.1364],
        [ 0.0298,  0.0192,  0.0064,  ...,  0.0293,  0.0263,  0.0310],
        ...,
        [ 0.1014,  0.0521,  0.0120,  ...,  0.0998,  0.0821,  0.1031],
        [-0.1578, -0.0546, -0.0200,  ..., -0.1560, -0.1090, -0.1516],
        [-0.0481,  0.0101, -0.0251,  ..., -0.0487, -0.0113, -0.0337]],
       grad_fn=<SumBackward1>)
tensor(13.8098, grad_fn=<AddBa

In [199]:
orig_points = dataset.all_points()

In [200]:
import plotly.express as px
import pandas as pd

In [201]:
df = pd.DataFrame(np.concatenate(orig_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = ['0']*1000 + ['1']*1000 + ['2']*1000 + ['3']*1000
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.show()

In [202]:
pred_points = [model.run_evaluate(i, mapper(points)).detach().numpy() for i, points in enumerate(orig_points)]

In [203]:
df = pd.DataFrame(np.concatenate(pred_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = ['0']*1000 + ['1']*1000 + ['2']*1000 + ['3']*1000
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.show()