In [1]:
import numpy as np
import trimesh
import scipy

In [6]:
# create test and train set 
# save it as txt files
import glob
import random
def get_random_train_list(n_train=120, seed=None, pattern='data/xyz/*.xyz'):
    """Randomly split the files matching ``pattern`` into train and test names.

    Parameters
    ----------
    n_train : int
        Number of files assigned to the training split (default 120, the
        value that used to be hard-coded).
    seed : int or None
        Seed for a private ``random.Random`` so the split is reproducible.
        ``None`` keeps using the global ``random`` module state.
    pattern : str
        Glob pattern selecting the candidate files.

    Returns
    -------
    (list of str, list of str)
        Base names (directory and everything after the first '.' removed) of
        the training files and of the remaining test files, in sorted path
        order.

    Raises
    ------
    ValueError
        If fewer than ``n_train`` files match ``pattern``.
    """
    import os

    # glob order is filesystem dependent; sort so a fixed seed gives a fixed split
    files = sorted(glob.glob(pattern))
    if n_train > len(files):
        raise ValueError('requested %d training files but only %d match %r'
                         % (n_train, len(files), pattern))
    rng = random if seed is None else random.Random(seed)
    train_set = set(rng.sample(files, n_train))

    def _name(path):
        # basename() understands the native separator (e.g. '\\' on Windows),
        # unlike split('/'); keep the part before the first '.' as before
        return os.path.basename(path).split('.')[0]

    train_list = [_name(f) for f in files if f in train_set]
    test_list = [_name(f) for f in files if f not in train_set]
    return train_list, test_list
def write_list_to_file(filename,my_list):
    """Write every element of ``my_list`` to ``filename``, one per line.

    The file is created or truncated, and each item is rendered with ``%s``
    followed by a newline.
    """
    lines = ("%s\n" % item for item in my_list)
    with open(filename, 'w') as handle:
        handle.writelines(lines)
# Draw a random train/test split of the .xyz files and store the names,
# one per line, in data/trainingset.txt and data/testset.txt.
# NOTE(review): no seed is set here, so re-running this cell overwrites the
# files with a different split.
train_list, test_list = get_random_train_list()
write_list_to_file('data/trainingset.txt', train_list)
write_list_to_file('data/testset.txt', test_list)

In [6]:
from scipy.sparse.linalg import spsolve
from scipy.sparse import coo_matrix, eye


from trimesh import triangles


def filter_humphrey(mesh,
                    alpha=0.1,
                    beta=0.5,
                    iterations=10,
                    laplacian_operator=None):
    """
    Smooth a mesh in place with Laplacian smoothing followed by the
    Humphrey correction step, which pulls vertices back toward a blend of
    their original and previous positions.

    Reference
    ------------
    "Improved Laplacian Smoothing of Noisy Surface Meshes"
    J. Vollmer, R. Mencl, and H. Muller

    Parameters
    ------------
    mesh : trimesh.Trimesh
      Mesh whose ``vertices`` are replaced by the smoothed positions
    alpha : float
      Weight of the original positions in the correction target (0.0 - 1.0);
      controls shrinkage
    beta : float
      Weight of the raw correction versus its Laplacian-smoothed version;
      controls how aggressive smoothing is
    iterations : int
      Number of filter passes
    laplacian_operator : None or scipy.sparse matrix
      Sparse Laplacian operator; built with ``laplacian_calculation(mesh)``
      when None

    Returns
    ------------
    mesh : trimesh.Trimesh
      The same mesh object, modified in place
    laplacian_operator : scipy.sparse matrix
      The operator that was used, so callers can reuse it
    """
    operator = laplacian_operator
    if operator is None:
        operator = laplacian_calculation(mesh)

    # plain ndarray copies so no trimesh caching/tracking is triggered
    start = mesh.vertices.copy().view(np.ndarray)
    current = start.copy()

    for _ in range(iterations):
        previous = current.copy()
        smoothed = operator.dot(current)
        # deviation of the smoothed points from the alpha-blend of the
        # original and the previous positions
        offset = smoothed - (alpha * start + (1.0 - alpha) * previous)
        smoothed -= beta * offset + (1.0 - beta) * operator.dot(offset)
        current = smoothed

    mesh.vertices = current
    return mesh, operator
def laplacian_calculation(mesh, equal_weight=True):
    """
    Calculate a sparse matrix for laplacian operations.

    Row ``i`` holds the weights of vertex ``i``'s neighbors and sums to 1.
    Vertices without neighbors (e.g. unreferenced vertices) get an all-zero
    row; previously they raised ``ZeroDivisionError``.

    Parameters
    -------------
    mesh : trimesh.Trimesh
      Input geometry; ``vertex_neighbors`` and ``vertices`` are read
    equal_weight : bool
      If True, all neighbors will be considered equally
      If False, all neighbors will be weighted by inverse distance

    Returns
    ----------
    laplacian : scipy.sparse.coo_matrix
      Laplacian operator of shape (n_vertices, n_vertices)
    """
    neighbors = mesh.vertex_neighbors
    # avoid hitting crc checks in loops
    vertices = mesh.vertices.view(np.ndarray)

    # one (row, col) entry per (vertex, neighbor) pair, in neighbor order;
    # explicit int64 so empty/ragged neighbor lists never produce float indices
    counts = np.array([len(n) for n in neighbors], dtype=np.int64)
    row = np.repeat(np.arange(len(neighbors), dtype=np.int64), counts)
    col = np.fromiter((j for n in neighbors for j in n),
                      dtype=np.int64, count=int(counts.sum()))

    if equal_weight:
        # 1 / degree per neighbor; np.maximum only guards the division,
        # zero-degree vertices contribute no entries anyway
        data = np.repeat(1.0 / np.maximum(counts, 1), counts)
    else:
        # umbrella weights: inverse distance, normalized per row
        # NOTE(review): coincident neighbor vertices give non-finite weights
        inv_dist = 1.0 / np.sqrt(((vertices[row] - vertices[col]) ** 2).sum(axis=1))
        row_sum = np.bincount(row, weights=inv_dist, minlength=len(neighbors))
        data = inv_dist / row_sum[row]

    return coo_matrix((data, (row, col)), shape=[len(vertices)] * 2)

# mesh = trimesh.load_mesh('../2plane.obj')
# matrix,data = laplacian_calculation(mesh,equal_weight=False)
# print('laplacian matrix: ', matrix.shape,'\n',matrix)

In [5]:
from random import sample
import glob
import trimesh
def save_xyz(pts, file_name):
    """Write ``pts`` to ``file_name`` as text via ``trimesh.util.array_to_string``.

    The formatted array is followed by a single trailing newline; the file is
    created or truncated.
    """
    text = "%s\n" % trimesh.util.array_to_string(pts)
    with open(file_name, 'w') as out:
        out.write(text)
def sample_k(matrix,k, rng=None):
    """Return exactly ``k`` values for every row of ``matrix``.

    Rows with at least ``k`` values are randomly sub-sampled (without
    replacement, in random order); shorter rows are right-padded with 0.0.

    Parameters
    ----------
    matrix : scipy.sparse matrix, 2-D array, or iterable of sequences
        For a sparse matrix each row contributes only its *stored* entries
        (for the Laplacians built in this notebook: the neighbor weights).
        Sparse input used to raise ``TypeError`` because a ``coo_matrix``
        cannot be iterated row by row.
    k : int
        Number of values returned per row.
    rng : random.Random or None
        Source of randomness; ``None`` uses the global ``random`` module.

    Returns
    -------
    list of list
        One list of length ``k`` per row.
    """
    import random
    from scipy import sparse

    rng = random if rng is None else rng
    if sparse.issparse(matrix):
        csr = matrix.tocsr()
        rows = (csr.data[csr.indptr[r]:csr.indptr[r + 1]]
                for r in range(csr.shape[0]))
    else:
        rows = matrix

    result = []
    for row in rows:
        values = list(row)
        if len(values) >= k:
            # NOTE(review): sampling drops which neighbor each weight belongs
            # to; confirm that is intended for the training target.
            result.append(rng.sample(values, k))
        else:
            result.append(values + [0.0] * (k - len(values)))
    return result
def save_vertices(src_pattern='data/noisy/*.obj', dest_dir='data/xyz'):
    """Export the vertex positions of every mesh matching ``src_pattern``.

    Each mesh ``<dir>/<name>.obj`` is written to ``<dest_dir>/<name>.xyz``
    with ``save_xyz``. ``dest_dir`` is created if missing, and each source
    path is printed as progress output.
    """
    import os

    os.makedirs(dest_dir, exist_ok=True)
    for file in glob.glob(src_pattern):
        print(file)
        mesh = trimesh.load_mesh(file)
        # basename() understands the native separator, unlike split('/'),
        # so this also works with the backslash paths glob returns on Windows
        name = os.path.basename(file).split('.')[0]
        save_xyz(mesh.vertices, os.path.join(dest_dir, name + '.xyz'))
        
def create_target(src_pattern='data/noisy/*.obj', dest_dir='data/laplacian', k=5):
    """Build per-vertex Laplacian-weight targets for every mesh matching ``src_pattern``.

    For each mesh, ``laplacian_calculation(mesh, equal_weight=False)`` is
    reduced to ``k`` values per vertex with ``sample_k``. The result is written
    to ``<dest_dir>/<name>.laplacian`` via ``save_xyz``. ``dest_dir`` is
    created if missing, and each source path is printed as progress output.
    """
    import os

    os.makedirs(dest_dir, exist_ok=True)
    for file in glob.glob(src_pattern):
        print(file)
        mesh = trimesh.load_mesh(file)
        # NOTE(review): laplacian_calculation is defined twice in this
        # notebook (sparse matrix vs. list of weight arrays); whichever cell
        # ran last decides what sample_k receives here.
        matrix = laplacian_calculation(mesh, equal_weight=False)
        result = sample_k(matrix, k)
        # portable basename instead of split('/')
        name = os.path.basename(file).split('.')[0]
        save_xyz(result, os.path.join(dest_dir, name + '.laplacian'))

# Export the vertices of every data/noisy/*.obj mesh to data/xyz/<name>.xyz.
save_vertices()


data/noisy/meshes_march2_mesh_0120.obj
data/noisy/meshes_squat2_mesh_0140.obj
data/noisy/meshes_handstand_mesh_0020.obj
data/noisy/meshes_jumping_mesh_0080.obj
data/noisy/meshes_march2_mesh_0170.obj
data/noisy/meshes_march2_mesh_0020.obj
data/noisy/meshes_handstand_mesh_0060.obj
data/noisy/meshes_bouncing_mesh_0120.obj
data/noisy/meshes_squat2_mesh_0020.obj
data/noisy/meshes_march2_mesh_0080.obj
data/noisy/meshes_handstand_mesh_0170.obj
data/noisy/meshes_squat1_mesh_0000.obj
data/noisy/meshes_march1_mesh_0100.obj
data/noisy/meshes_march2_mesh_0150.obj
data/noisy/meshes_march2_mesh_0140.obj
data/noisy/meshes_march1_mesh_0030.obj
data/noisy/meshes_jumping_mesh_0020.obj
data/noisy/meshes_march1_mesh_0070.obj
data/noisy/meshes_crane_mesh_0150.obj
data/noisy/meshes_march2_mesh_0230.obj
data/noisy/meshes_squat2_mesh_0110.obj
data/noisy/meshes_march1_mesh_0050.obj
data/noisy/meshes_crane_mesh_0020.obj
data/noisy/meshes_jumping_mesh_0050.obj
data/noisy/meshes_squat1_mesh_0070.obj
data/noisy/me

In [115]:

def make_smooth(mesh, matrix= None, lamb=0.4, nu=0.5, iterations=10):
    """Load a mesh file and smooth it with trimesh's Taubin filter.

    Parameters
    ----------
    mesh : str
        Path of the mesh file, passed to ``trimesh.load_mesh``.
    matrix : scipy.sparse matrix or None
        Laplacian operator to use. When None, an inverse-distance-weighted
        operator is built with ``trimesh.smoothing.laplacian_calculation``.
    lamb, nu, iterations :
        Forwarded to ``trimesh.smoothing.filter_taubin``. The defaults are
        the values that used to be hard-coded here.

    Returns
    -------
    trimesh.Trimesh
        The mesh returned by ``filter_taubin``.
    """
    mesh = trimesh.load_mesh(mesh)
    # `is None`, not `!= None`: on arrays/sparse matrices `!=` compares
    # element-wise instead of testing for a missing argument
    if matrix is None:
        matrix = trimesh.smoothing.laplacian_calculation(mesh, equal_weight=False)
    return trimesh.smoothing.filter_taubin(mesh,
                                           lamb=lamb,
                                           nu=nu,
                                           iterations=iterations,
                                           laplacian_operator=matrix)
# smoothed_mesh =  make_smooth('data/noisy/meshes_bouncing_mesh_0020.obj')
# output = smoothed_mesh.export('output/meshes_bouncing_mesh_0020_smooth_t.obj')
# mesh = trimesh.load_mesh('output/meshes_bouncing_mesh_0020_smooth_f.obj')
# mesh.show()


In [33]:
import torch
import numpy as np
from sklearn.neighbors import NearestNeighbors
from chamferdist import ChamferDistance
def calculate_loss(mesh1, mesh2):
    """Compare the vertex sets of two mesh files.

    Both paths are loaded with ``trimesh.load_mesh``. Their vertices become
    float32 tensors with a leading batch dimension of size 1.

    Returns
    -------
    (chamfer_loss, mse_loss)
        ``ChamferDistance()`` between the two point sets, and
        ``torch.nn.MSELoss()`` between them. The MSE compares vertices index
        by index, so it assumes both meshes share a vertex order (not checked).
    """
    def to_batch(path):
        verts = trimesh.load_mesh(path).vertices.view(np.ndarray)
        return torch.tensor(verts).float().unsqueeze(0)

    chamfer_fn = ChamferDistance()
    mse_fn = torch.nn.MSELoss()
    source, target = to_batch(mesh1), to_batch(mesh2)
    return chamfer_fn(source, target), mse_fn(source, target)

# Baseline: Chamfer distance and vertex-wise MSE between the smooth mesh and
# its noisy counterpart for one frame.
mesh1 = 'data/smooth/meshes_bouncing_mesh_0020.obj'
mesh2 = 'data/noisy/meshes_bouncing_mesh_0020.obj'

chamfer, mse = calculate_loss(mesh1, mesh2)
print('===================================')
print('loss(smooth_gt,noisy_gt)')
print('ChamferDistance: ', chamfer)
print('MSE: ', mse)
print('===================================')
# Results recorded from earlier runs, kept as an unused string for reference.
'''
===================================
between Equal weight False & True: loss(smooth_gt,smooth_true)
ChamferDistance:  tensor(0.0306)
MSE:  tensor(0.1742)

===================================

loss(smooth_gt,smooth_pred_true)
ChamferDistance:  tensor(0.0948)
MSE:  tensor(0.1706)
===================================

===================================
loss(smooth_gt,smooth_pred_false)
ChamferDistance:  tensor(0.0399)
MSE:  tensor(0.1719)
===================================

===================================
loss(smooth_gt,noisy_gt)
ChamferDistance:  tensor(0.0296)
MSE:  tensor(0.1707)
===================================
'''

loss(smooth_gt,noisy_gt)
ChamferDistance:  tensor(0.0296)
MSE:  tensor(0.1707)




In [234]:
# Replace the Trimesh Laplacian with the cotangent Laplacian.
# Plan: a function that computes the cotangent weights, given a center point i
# and its k nearest neighbors:
#   1. Compute the k nearest neighbors of i.
#   2. For each of these neighbors j, check whether the mesh has an edge
#      between i and j. If not, the weight is 0.
#   3. If yes, return the cotangent weight of edge ij.
#
# Heron's formula for the area of triangle (i, j, k):
#   A_ijk = sqrt(s * (s - l_ik) * (s - l_kj) * (s - l_ij)),  s = (l_ik + l_kj + l_ij) / 2
#
# Cotangent weight of edge ij, with k and h the vertices opposite ij:
#   w_ij = (l_ik^2 + l_jk^2 - l_ij^2) / (8 * A_ijk) + (l_ih^2 + l_jh^2 - l_ij^2) / (8 * A_ijh)
import heapq
import numpy as np

def get_distance(p1, p2):
    """Euclidean distance between two points (squares summed over axis 0)."""
    squared = (p1 - p2) ** 2
    return np.sqrt(np.sum(squared, axis=0))
def get_k_nearest_neighbors(vertices, k=4):
    """Brute-force k nearest neighbors of every vertex.

    Parameters
    ----------
    vertices : array-like, shape (n, d)
        Point coordinates.
    k : int
        Neighbors per vertex (default 4, the value that used to be
        hard-coded). Capped at ``n - 1``; previously fewer than ``k + 1``
        points raised ``IndexError``.

    Returns
    -------
    dict
        ``{i: [(distance, j), ...]}``: the ``k`` closest other vertices ``j``
        in increasing distance. Ties go to the smaller index, matching the
        former heap-of-tuples implementation.
    """
    points = np.asarray(vertices, dtype=float)
    result = {}
    for i in range(len(points)):
        # one vectorized pass per vertex: O(n) memory instead of an n x n matrix
        dists = np.sqrt(((points - points[i]) ** 2).sum(axis=1))
        # stable sort keeps equal distances in index order
        order = np.argsort(dists, kind='stable')
        # exclude i by index (not by distance), as before
        order = order[order != i][:k]
        result[i] = [(float(dists[j]), int(j)) for j in order]
    return result

def _cotangent_edge_weights(mesh):
    """Map each undirected mesh edge ``(a, b)``, ``a < b``, to its cotangent weight.

    The weight is ``(cot(alpha) + cot(beta)) / 2``, where alpha and beta are
    the angles opposite the edge in its adjacent faces. This equals
    ``(l_ik^2 + l_jk^2 - l_ij^2) / (8 A_ijk)`` summed over those faces.
    Boundary edges have a single face; zero-area faces contribute 0.
    """
    vertices = mesh.vertices.view(np.ndarray)
    # opposite[(a, b)] = third vertex of every face containing edge (a, b);
    # reading faces guarantees (a, b, o) is a real triangle
    opposite = {}
    for face in np.asarray(mesh.faces, dtype=np.int64):
        a, b, c = (int(v) for v in face)
        for u, v, w in ((a, b, c), (b, c, a), (c, a, b)):
            opposite.setdefault((min(u, v), max(u, v)), []).append(w)

    weights = {}
    for (u, v), others in opposite.items():
        total = 0.0
        for o in others:
            to_u = vertices[u] - vertices[o]
            to_v = vertices[v] - vertices[o]
            # |cross| = 2 * area and dot = |to_u||to_v|cos, so the ratio is cot at o
            twice_area = np.linalg.norm(np.cross(to_u, to_v))
            if twice_area > 0:
                total += np.dot(to_u, to_v) / twice_area
        weights[(u, v)] = 0.5 * total
    return weights


def _cotangent_matrix(mesh, allowed=None):
    """Sparse cotangent matrix with one entry per (vertex, mesh neighbor) pair.

    ``allowed`` optionally maps a vertex index to the set of neighbor indices
    whose weight is kept. Every other pair is stored as an explicit 0.0, so
    the sparsity pattern always matches ``mesh.vertex_neighbors``.
    """
    neighbors = mesh.vertex_neighbors
    edge_weight = _cotangent_edge_weights(mesh)
    row, col, data = [], [], []
    for i, adjacent in enumerate(neighbors):
        keep = None if allowed is None else allowed.get(i, set())
        for j in adjacent:
            j = int(j)
            row.append(i)
            col.append(j)
            if keep is None or j in keep:
                data.append(edge_weight.get((min(i, j), max(i, j)), 0.0))
            else:
                data.append(0.0)
    n_vertices = len(mesh.vertices)
    return coo_matrix((np.asarray(data, dtype=float),
                       (np.asarray(row, dtype=np.int64),
                        np.asarray(col, dtype=np.int64))),
                      shape=[n_vertices] * 2)


def laplacian_cotangent_knn(mesh, k_nearest):
    """Cotangent Laplacian restricted to each vertex's k nearest neighbors.

    Parameters
    ----------
    mesh : trimesh.Trimesh
        Uses ``vertices``, ``faces`` and ``vertex_neighbors``.
    k_nearest : dict
        ``{i: [(distance, j), ...]}`` as returned by ``get_k_nearest_neighbors``.
        Only mesh edges ``(i, j)`` with ``j`` listed in ``k_nearest[i]`` get a
        non-zero weight; vertices missing from the dict get an all-zero row.

    Returns
    -------
    scipy.sparse.coo_matrix
        Shape (n_vertices, n_vertices), one stored entry per (vertex, mesh
        neighbor) pair, un-normalized cotangent weights.
    """
    # Fixes vs. the earlier version:
    # - the weight now follows the cotangent formula (the old expression
    #   dropped the squares and signs);
    # - the opposite vertices come from the faces instead of a search that
    #   could leave them as None;
    # - boundary edges no longer crash;
    # - the per-weight debug print is gone.
    allowed = {i: {int(j) for _, j in pairs} for i, pairs in k_nearest.items()}
    return _cotangent_matrix(mesh, allowed)


def laplacian_cotangent(mesh, k_nearest=None):
    """Cotangent Laplacian of ``mesh``.

    Parameters
    ----------
    mesh : trimesh.Trimesh
        Uses ``vertices``, ``faces`` and ``vertex_neighbors``.
    k_nearest : dict or None
        When given, behaves like ``laplacian_cotangent_knn``. When None,
        every mesh edge gets its cotangent weight.

    Returns
    -------
    scipy.sparse.coo_matrix
        Shape (n_vertices, n_vertices), un-normalized cotangent weights.
    """
    # The earlier version silently read a *global* k_nearest (hidden notebook
    # state); the restriction is now an explicit, optional argument.
    if k_nearest is None:
        return _cotangent_matrix(mesh)
    return laplacian_cotangent_knn(mesh, k_nearest)
# Try the cotangent Laplacian on '../2plane.obj': first the nearest neighbors
# of every vertex, then the sparse cotangent matrix.
mesh = trimesh.load_mesh('../2plane.obj')
vertices = mesh.vertices
k_nearest = get_k_nearest_neighbors(vertices)
        
matrix = laplacian_cotangent(mesh)
# print(matrix)

# k_nearest



    


In [146]:
def laplacian_calculation(mesh, equal_weight=True):
    """
    Per-vertex inverse-distance neighbor weights, as a list of arrays.

    NOTE(review): this re-defines ``laplacian_calculation`` from an earlier
    cell with a *different return type* (a list, not a sparse matrix). Every
    cell run after this one gets this version; consider renaming it.

    Parameters
    -------------
    mesh : trimesh.Trimesh
      Input geometry; ``vertex_neighbors`` and ``vertices`` are read
    equal_weight : bool
      Accepted for signature compatibility but ignored: the weights are
      always inverse-distance weighted

    Returns
    ----------
    weights : list of numpy.ndarray
      ``weights[i][m]`` is the weight of ``mesh.vertex_neighbors[i][m]``;
      each array sums to 1 (empty for vertices without neighbors)
    """
    neighbors = mesh.vertex_neighbors
    # avoid hitting crc checks in loops
    vertices = mesh.vertices.view(np.ndarray)

    # inverse distance from vertex i to each of its neighbors
    # (the dot with ones is a row-wise sum of squares)
    ones = np.ones(3)
    norms = [1.0 / np.sqrt(np.dot((vertices[i] - vertices[n]) ** 2, ones))
             for i, n in enumerate(neighbors)]
    # normalize each vertex's group independently; the unused row/col arrays
    # the old version built (and discarded) are gone
    return [i / i.sum() for i in norms]
    



In [151]:
# Inspect the neighbor weights of one noisy mesh.
# NOTE(review): data[0][0] assumes the list-returning laplacian_calculation
# from the previous cell is the active definition (the earlier one returns a
# coo_matrix, which does not support this indexing).
mesh = trimesh.load_mesh('data/noisy/meshes_crane_mesh_0020.obj')
# neighbors = mesh.vertex_neighbors
# vertices = mesh.vertices.view(np.ndarray)
# norms = [1.0/ np.sqrt(np.dot((vertices[i] - vertices[n]) ** 2, ones))
#                 for i, n in enumerate(neighbors)]
data = laplacian_calculation(mesh)
print(data[0][0])
# for weight in data:
#     print(ty)
#     break


0.2714832390382398


In [138]:
from scipy.sparse.linalg import spsolve
from scipy.sparse import coo_matrix, eye

def get_weight(i,neighbor,neighbors,norms):
    """Cotangent weights of the edges from vertex ``i`` to each vertex in ``neighbor``.

    For an edge (i, j) the weight is summed over the (at most two) vertices
    ``o`` adjacent to both ``i`` and ``j``:
    ``(l_io^2 + l_jo^2 - l_ij^2) / (8 * A_ijo)``, where ``A_ijo`` is the
    triangle area from Heron's formula. This equals
    ``(cot(alpha) + cot(beta)) / 2``.

    Parameters
    ----------
    i : int
        Center vertex index.
    neighbor : iterable of int
        Vertices ``j`` to compute weights for (normally ``neighbors[i]``).
    neighbors : list of list of int
        Vertex adjacency lists.
    norms : list of sequence of float
        ``norms[a][m]`` is the length of the edge from ``a`` to
        ``neighbors[a][m]``.

    Returns
    -------
    list of float
        One weight per entry of ``neighbor``, in order.

    Edge cases that used to crash or give NaN/inf:
    - a boundary edge with one shared neighbor uses only that triangle;
    - an edge with no shared neighbor gets 0.0;
    - zero-area triangles contribute 0.0.
    """
    def edge_length(a, b):
        return norms[a][neighbors[a].index(b)]

    def half_cot(a, b, o):
        # (l_ao^2 + l_bo^2 - l_ab^2) / (8 * area) for triangle (a, b, o)
        l_ab, l_bo, l_oa = edge_length(a, b), edge_length(b, o), edge_length(o, a)
        s = (l_ab + l_bo + l_oa) / 2
        area_sq = s * (s - l_ab) * (s - l_bo) * (s - l_oa)
        if area_sq <= 0:
            # degenerate triangle: no well-defined angle, skip it
            return 0.0
        return (-l_ab ** 2 + l_bo ** 2 + l_oa ** 2) / (8 * np.sqrt(area_sq))

    result = []
    for j in neighbor:
        # sorted() makes the choice deterministic when more than two vertices
        # are adjacent to both i and j (the old code took an arbitrary set order)
        opposite = sorted(set(neighbors[i]).intersection(neighbors[j]))[:2]
        result.append(sum((half_cot(i, j, o) for o in opposite), 0.0))
    return result


def laplacian(mesh):
    """Row-normalized cotangent Laplacian of ``mesh`` as a ``coo_matrix``.

    Edge lengths are computed from ``mesh.vertices`` for every entry of
    ``mesh.vertex_neighbors``. ``get_weight`` turns them into cotangent
    weights, and each vertex's weights are divided by their sum.

    NOTE(review): cotangent weights can be negative (obtuse angles), so a
    row sum may be near zero. Rows whose sum is exactly zero are left
    un-normalized instead of producing NaN/inf.
    """
    neighbors = mesh.vertex_neighbors
    vertices = mesh.vertices.view(np.ndarray)

    # length of every (vertex, neighbor) edge; dot with ones sums the squares
    ones = np.ones(3)
    norms = [np.sqrt(np.dot((vertices[i] - vertices[n]) ** 2, ones))
             for i, n in enumerate(neighbors)]

    data = []
    for i, neighbor in enumerate(neighbors):
        weights = np.asarray(get_weight(i, neighbor, neighbors, norms), dtype=float)
        total = weights.sum()
        data.append(weights / total if total != 0 else weights)

    # build the sparse matrix; explicit int64 indices so vertices with no
    # neighbors cannot turn the index arrays into floats
    counts = np.array([len(n) for n in neighbors], dtype=np.int64)
    row = np.repeat(np.arange(len(neighbors), dtype=np.int64), counts)
    col = np.fromiter((j for n in neighbors for j in n),
                      dtype=np.int64, count=int(counts.sum()))
    values = np.concatenate(data) if data else np.zeros(0)

    return coo_matrix((values, (row, col)), shape=[len(vertices)] * 2)
# Normalized cotangent Laplacian of the mesh loaded in the inspection cell above.
matrix = laplacian(mesh)
# data


In [12]:
# get knn
from sklearn.neighbors import KDTree
import matplotlib.pyplot as plt
import numpy as np
# Toy example: a KD-tree (sklearn) over six 2-D points for trying out
# nearest-neighbor queries.
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
# x,y = X[:, 0], X[:, 1]   (column slices; the earlier X[:0]/X[:1] selected rows)
# plt.scatter(x, y)
# plt.show()
kdt = KDTree(X, leaf_size=30, metric='euclidean')

In [15]:
import torch
target = torch.tensor([[0.1507, 0.2624, 0.1219, 0.0839, 0.2729],
        [0.1845, 0.1116, 0.2147, 0.2645, 0.2247],
        [0.2977, 0.0787, 0.2527, 0.0848, 0.1858],
        [0.2211, 0.0887, 0.1111, 0.0980, 0.1386],
        [0.3101, 0.2052, 0.1550, 0.3297, 0.0000],
        [0.1003, 0.3132, 0.1688, 0.1474, 0.1515],
        [0.1002, 0.2885, 0.1028, 0.2110, 0.1480],
        [0.1642, 0.1276, 0.0839, 0.1270, 0.1507],
        [0.1058, 0.1090, 0.1528, 0.1606, 0.0834],
        [0.1702, 0.1621, 0.1830, 0.1600, 0.3247],
        [0.1583, 0.2402, 0.2243, 0.1403, 0.2369],
        [0.0896, 0.1902, 0.1166, 0.1189, 0.1923],
        [0.1557, 0.1819, 0.1875, 0.1722, 0.1896],
        [0.2272, 0.1031, 0.2018, 0.2397, 0.1240],
        [0.1172, 0.1472, 0.0950, 0.0532, 0.1626],
        [0.2169, 0.3171, 0.1367, 0.1714, 0.1578],
        [0.1123, 0.0991, 0.2772, 0.0577, 0.0780],
        [0.1613, 0.2062, 0.1676, 0.3505, 0.1144],
        [0.1999, 0.1279, 0.2586, 0.0922, 0.1319],
        [0.1320, 0.1451, 0.3292, 0.1106, 0.1085],
        [0.1874, 0.1744, 0.3241, 0.1687, 0.1454],
        [0.1406, 0.1657, 0.2110, 0.1437, 0.1951],
        [0.0682, 0.2098, 0.1366, 0.1762, 0.1918],
        [0.2659, 0.1418, 0.2054, 0.1958, 0.1911],
        [0.2375, 0.1258, 0.1450, 0.1762, 0.1187],
        [0.1725, 0.1634, 0.1018, 0.1801, 0.1083],
        [0.2312, 0.1378, 0.0962, 0.1327, 0.2103],
        [0.1454, 0.1623, 0.1971, 0.1146, 0.1504],
        [0.2250, 0.1498, 0.1460, 0.1235, 0.1800],
        [0.2267, 0.3127, 0.1072, 0.1188, 0.1003],
        [0.1287, 0.0871, 0.1207, 0.1051, 0.2606],
        [0.3030, 0.1704, 0.1946, 0.1436, 0.1884],
        [0.1814, 0.2420, 0.0978, 0.1106, 0.3682],
        [0.1381, 0.0974, 0.0846, 0.2256, 0.3404],
        [0.1127, 0.4504, 0.1686, 0.1418, 0.1265],
        [0.0502, 0.1222, 0.0522, 0.1076, 0.2877],
        [0.1186, 0.1785, 0.1714, 0.1754, 0.1103],
        [0.0875, 0.2572, 0.2849, 0.2406, 0.1297],
        [0.1553, 0.2777, 0.1975, 0.1707, 0.1989],
        [0.0989, 0.0670, 0.1823, 0.2600, 0.1861],
        [0.1377, 0.2162, 0.3693, 0.1594, 0.1173],
        [0.3099, 0.1409, 0.1590, 0.1166, 0.1410],
        [0.1488, 0.0999, 0.1417, 0.1614, 0.1375],
        [0.3036, 0.1175, 0.1292, 0.2286, 0.2210],
        [0.1368, 0.1411, 0.2302, 0.0854, 0.2198],
        [0.1958, 0.2177, 0.1531, 0.2073, 0.1257],
        [0.1382, 0.1400, 0.1160, 0.0999, 0.1557],
        [0.2631, 0.2698, 0.1370, 0.1474, 0.1827],
        [0.1080, 0.1183, 0.0919, 0.0943, 0.0994],
        [0.2156, 0.3064, 0.1528, 0.1340, 0.1911],
        [0.1283, 0.1294, 0.0957, 0.1727, 0.1348],
        [0.0523, 0.3450, 0.1308, 0.2579, 0.1015],
        [0.1270, 0.1219, 0.3287, 0.0736, 0.0669],
        [0.1240, 0.0692, 0.0593, 0.3089, 0.1270],
        [0.2709, 0.2368, 0.2429, 0.1701, 0.0793],
        [0.1123, 0.1172, 0.1995, 0.1560, 0.2275],
        [0.1132, 0.2108, 0.2334, 0.1533, 0.1565],
        [0.1708, 0.1993, 0.1256, 0.1403, 0.1033],
        [0.2674, 0.1729, 0.1915, 0.1522, 0.2161],
        [0.2358, 0.2054, 0.1175, 0.2675, 0.1738],
        [0.2493, 0.1959, 0.1423, 0.2087, 0.2037],
        [0.2142, 0.1350, 0.1205, 0.1332, 0.1400],
        [0.0739, 0.1181, 0.2427, 0.0986, 0.0908],
        [0.1997, 0.1636, 0.3237, 0.1646, 0.1485]])

In [24]:
o_pred = torch.tensor([[ 0.2027, -0.0210,  0.0329],
        [-0.2031, -0.6453, -0.4977],
        [-0.6620, -0.0070, -0.2923],
        [-0.3473, -0.8854, -0.1800],
        [-0.4675,  0.2644,  0.0139],
        [-0.6349,  0.3274,  0.1666],
        [ 0.2172, -0.6724,  0.7244],
        [-0.1738, -0.6387,  0.3055],
        [ 0.0320, -0.8430, -0.2320],
        [-0.3310, -0.4930,  0.4650],
        [-0.2975, -0.6533,  0.3916],
        [ 0.1652, -0.7021,  0.4155],
        [ 0.5166,  0.3405, -0.1173],
        [-0.2124, -0.5991,  0.4536],
        [ 0.1014,  0.3320, -0.2192],
        [ 0.4256, -0.3802,  0.3118],
        [-0.2252, -0.6854, -0.3811],
        [-0.0722,  0.3057,  0.6788],
        [-0.0064, -0.4139,  0.9157],
        [ 0.1010, -0.5218,  0.4027],
        [-0.1031,  0.6556,  0.6502],
        [ 0.3284, -0.0844,  1.0147],
        [-0.5973, -1.2255,  0.2655],
        [ 0.5319, -1.0806,  0.3422],
        [-0.4849, -0.0875,  1.0492],
        [-0.1082, -0.7225,  0.7758],
        [-0.1104,  0.4969,  0.3577],
        [ 0.1114, -0.5584, -0.5236],
        [-0.1116, -0.1735, -0.2867],
        [ 0.0572, -0.3471,  0.2970],
        [-0.0948,  0.5722, -0.0161],
        [-0.7849, -1.2891,  0.4726],
        [-0.0319,  0.0611,  0.0241],
        [-0.6461,  0.1272,  0.2651],
        [-0.2338,  0.6676,  0.3022],
        [-0.7054, -0.4605,  0.1620],
        [ 1.1693, -0.6133,  1.1239],
        [-0.2697, -0.8238,  0.4291],
        [ 0.2554, -0.0892,  0.2605],
        [-0.3016, -0.4674, -0.5664],
        [-0.1995, -0.9387, -0.2751],
        [-0.0077, -0.6622,  0.1146],
        [-1.7754, -1.1141,  0.6418],
        [ 0.0556, -0.1813, -0.2745],
        [-0.3515, -1.1075,  0.8043],
        [-0.0606, -0.6252, -0.0918],
        [-0.2915, -0.7684,  0.1968],
        [-0.4240, -0.2706, -0.2895],
        [ 0.2630, -0.6591, -0.5250],
        [-0.6247, -0.2048,  0.3532],
        [-0.3409,  0.4907, -0.3040],
        [-0.1497, -1.1450,  0.8079],
        [ 0.9814, -0.3750,  0.2471],
        [ 0.0550, -0.8866,  0.4151],
        [-0.3728, -0.1157, -0.1688],
        [ 0.1023, -0.7606,  1.0742],
        [ 0.9699,  0.2054, -0.0493],
        [ 0.4119, -0.7608,  0.2472],
        [ 0.0869, -0.0955,  0.0740],
        [ 0.5423, -0.7980, -0.0446],
        [-0.0939, -0.6629,  0.6792],
        [ 0.2981, -0.2288, -0.7055],
        [-0.1444, -0.5643, -0.3630],
        [ 0.4677, -0.9841,  0.2417]], device='cuda:0')

In [46]:
# First row of the target tensor (shape shown below: torch.Size([5])).
o_target = target[0]
o_target.shape

torch.Size([5])

In [27]:
pred= torch.tensor([[ 2.0272e-01, -2.0987e-02,  3.2920e-02,  7.7162e-02, -1.4826e-02],
        [-2.0305e-01, -6.4533e-01, -4.9770e-01, -6.0329e-01, -6.7306e-01],
        [-6.6202e-01, -7.0299e-03, -2.9227e-01, -7.2969e-02, -1.5013e-01],
        [-3.4730e-01, -8.8540e-01, -1.7998e-01,  2.2393e-01, -3.8027e-01],
        [-4.6752e-01,  2.6436e-01,  1.3936e-02,  5.1311e-03,  2.5015e-01],
        [-6.3486e-01,  3.2740e-01,  1.6656e-01,  3.1997e-01,  4.3261e-01],
        [ 2.1717e-01, -6.7238e-01,  7.2435e-01, -5.7573e-01, -5.2131e-02],
        [-1.7376e-01, -6.3865e-01,  3.0547e-01, -1.1933e+00, -2.0626e-01],
        [ 3.2048e-02, -8.4302e-01, -2.3200e-01, -4.9804e-01, -1.1634e+00],
        [-3.3101e-01, -4.9297e-01,  4.6498e-01, -5.5347e-01,  8.8753e-01],
        [-2.9748e-01, -6.5326e-01,  3.9155e-01, -6.2939e-01, -1.3625e-01],
        [ 1.6521e-01, -7.0205e-01,  4.1547e-01, -9.6165e-02,  1.2866e-01],
        [ 5.1657e-01,  3.4045e-01, -1.1730e-01, -6.8295e-02,  1.6492e-01],
        [-2.1244e-01, -5.9905e-01,  4.5358e-01, -6.6637e-01, -8.5312e-01],
        [ 1.0139e-01,  3.3198e-01, -2.1919e-01, -8.1215e-01,  3.6336e-01],
        [ 4.2561e-01, -3.8025e-01,  3.1181e-01, -2.3854e-01, -3.2680e-01],
        [-2.2522e-01, -6.8542e-01, -3.8110e-01,  5.2929e-01,  1.8744e-01],
        [-7.2218e-02,  3.0568e-01,  6.7884e-01,  1.7120e-01, -5.4198e-01],
        [-6.4437e-03, -4.1392e-01,  9.1571e-01,  1.6743e-01,  1.1055e+00],
        [ 1.0097e-01, -5.2183e-01,  4.0273e-01,  3.5366e-01, -5.0834e-02],
        [-1.0311e-01,  6.5561e-01,  6.5019e-01, -5.3076e-01,  4.0606e-01],
        [ 3.2840e-01, -8.4431e-02,  1.0147e+00, -4.2752e-01,  6.7277e-01],
        [-5.9725e-01, -1.2255e+00,  2.6546e-01,  1.5877e-01,  2.2268e-02],
        [ 5.3190e-01, -1.0806e+00,  3.4223e-01, -8.0640e-01,  9.4801e-02],
        [-4.8492e-01, -8.7516e-02,  1.0492e+00,  6.8642e-01,  4.4158e-01],
        [-1.0822e-01, -7.2246e-01,  7.7577e-01, -5.5466e-02,  1.0109e-01],
        [-1.1040e-01,  4.9692e-01,  3.5772e-01,  8.2783e-02, -6.4687e-01],
        [ 1.1140e-01, -5.5843e-01, -5.2361e-01, -3.2096e-01, -3.5379e-01],
        [-1.1157e-01, -1.7351e-01, -2.8675e-01,  7.0544e-04, -2.8903e-01],
        [ 5.7186e-02, -3.4713e-01,  2.9697e-01, -8.4602e-01, -1.2777e-01],
        [-9.4835e-02,  5.7224e-01, -1.6079e-02, -2.8392e-02,  3.1948e-02],
        [-7.8494e-01, -1.2891e+00,  4.7264e-01, -2.3375e-01, -6.1836e-01],
        [-3.1873e-02,  6.1129e-02,  2.4096e-02,  3.3469e-01, -3.6369e-01],
        [-6.4613e-01,  1.2719e-01,  2.6514e-01, -5.8018e-01,  1.4411e-01],
        [-2.3376e-01,  6.6757e-01,  3.0220e-01,  8.8645e-02, -8.6534e-02],
        [-7.0540e-01, -4.6050e-01,  1.6202e-01, -3.7877e-01, -5.3171e-01],
        [ 1.1693e+00, -6.1333e-01,  1.1239e+00, -2.5604e-01, -4.3976e-01],
        [-2.6968e-01, -8.2380e-01,  4.2911e-01, -5.1997e-01, -7.5634e-02],
        [ 2.5540e-01, -8.9241e-02,  2.6047e-01, -5.3104e-01, -9.1707e-02],
        [-3.0161e-01, -4.6737e-01, -5.6640e-01, -4.5480e-01, -2.5400e-01],
        [-1.9955e-01, -9.3874e-01, -2.7505e-01,  4.0701e-01, -4.5153e-01],
        [-7.7428e-03, -6.6222e-01,  1.1462e-01,  2.5166e-01, -2.1921e-02],
        [-1.7754e+00, -1.1141e+00,  6.4183e-01, -8.8608e-01, -5.5761e-01],
        [ 5.5608e-02, -1.8134e-01, -2.7451e-01, -1.0199e+00, -3.8010e-02],
        [-3.5148e-01, -1.1075e+00,  8.0432e-01,  7.6659e-02, -5.7889e-01],
        [-6.0644e-02, -6.2516e-01, -9.1757e-02, -1.1595e-01,  1.5238e-01],
        [-2.9153e-01, -7.6836e-01,  1.9684e-01,  2.4666e-01,  2.7772e-01],
        [-4.2400e-01, -2.7063e-01, -2.8952e-01, -4.6929e-01, -4.7530e-01],
        [ 2.6301e-01, -6.5911e-01, -5.2503e-01, -5.8072e-01,  2.7260e-01],
        [-6.2466e-01, -2.0485e-01,  3.5320e-01, -5.9942e-02,  4.7029e-01],
        [-3.4091e-01,  4.9074e-01, -3.0400e-01,  4.0595e-01, -4.9785e-01],
        [-1.4967e-01, -1.1450e+00,  8.0790e-01,  5.5015e-01, -1.9482e-02],
        [ 9.8137e-01, -3.7500e-01,  2.4707e-01, -8.1400e-01,  2.3974e-01],
        [ 5.5003e-02, -8.8658e-01,  4.1505e-01,  2.1661e-01, -5.6283e-01],
        [-3.7283e-01, -1.1574e-01, -1.6878e-01, -4.1446e-01, -5.0753e-01],
        [ 1.0229e-01, -7.6057e-01,  1.0742e+00, -7.1775e-01,  8.2314e-02],
        [ 9.6987e-01,  2.0543e-01, -4.9310e-02,  1.1066e-01,  1.8075e-03],
        [ 4.1195e-01, -7.6082e-01,  2.4720e-01, -4.6309e-01, -1.9986e-02],
        [ 8.6937e-02, -9.5484e-02,  7.4039e-02, -1.0545e+00,  8.6431e-01],
        [ 5.4226e-01, -7.9796e-01, -4.4592e-02,  1.6077e-01, -1.6574e-01],
        [-9.3858e-02, -6.6292e-01,  6.7916e-01, -7.7548e-01, -6.4124e-01],
        [ 2.9814e-01, -2.2876e-01, -7.0547e-01, -1.0396e+00,  9.5195e-01],
        [-1.4441e-01, -5.6428e-01, -3.6303e-01,  1.8420e-01,  6.1630e-02],
        [ 4.6765e-01, -9.8414e-01,  2.4171e-01, -1.2552e-01, -8.1047e-02]],
       device='cuda:0')

In [48]:
# Move both tensors to the hard-coded device 'cuda:0' (requires a CUDA GPU).
pred = pred.to('cuda:0')
target = target.to('cuda:0')

In [42]:
# Take a 5-column window of pred starting at column output_pred_ind[oi]
# (columns 0-4 here). This reassigns o_pred, replacing the tensor defined earlier.
output_pred_ind = [0]
oi = 0
o_pred = pred[:, output_pred_ind[oi]:output_pred_ind[oi]+5]

In [49]:
# Element-wise mean squared error between pred and target.
loss = torch.nn.MSELoss()
loss(pred, target)

tensor(0.3293, device='cuda:0')