In [2]:
import numpy as np
import tqdm
import torch 
import torch.nn.functional as func

import misc.faust as faust
import models

FAUST = "../../Downloads/Mesh-Datasets/MyFaustDataset"

dataset = faust.FaustDataset(FAUST)
D = dataset.downscale_matrices
E = dataset.downscaled_edges
F = dataset.downscaled_faces
num_classes = 10
lastlayernodes = D[-1].shape[0] #NOTE this is the number of nodes after the last convolutional layer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = models.ChebyNet(lastlayernodes, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

In [16]:
import plotly
import plotly.io as pio
import plotly.graph_objects as go

def visualize(mesh, D=None, F=None, scale_index=0,title=""):
    """
    visualize input mesh.
    
    The input lists D and F are respectively the downscale matrices and the faces of the downscaled
    mesh. The scale_index parameter indicates which downscaling matrix should be used, if scale_index=0 then
    no downscaling is applied (scale_index must be less then or equal to the length of D and F).
    
    NOTE: if scale_index is 0 then the parameters F and D are not necessary and can simply have value to None.
    """
    
    pos = mesh.pos
    if scale_index == 0:
        f = mesh.face
    else:
        f = F[scale_index-1]
        for j in range(scale_index):
            pos = torch.sparse.mm(D[j], pos) #transform graph signal (vertex positions) to downscaled graph

    x,y,z = pos[:,0], pos[:,1], pos[:,2]
    i,j,k = f[0,:], f[1,:], f[2,:]
    mesh = go.Mesh3d(x=x, y=z, z=y,
              color='lightpink', 
              opacity=1,
              i=i, j=j, k=k,
              showscale=True)
    layout = go.Layout(scene=go.layout.Scene(aspectmode="data")) 

    #pio.renderers.default="plotly_mimetype"
    fig = go.Figure(data=[mesh],
                   layout=layout)
    fig.update_layout(
        autosize=True,
        margin=dict(l=20, r=20, t=20, b=20),
        paper_bgcolor="LightSteelBlue")
    fig.show()
    

Training of the model

In [17]:
# train module
model.train()

# meters
loss_values = []

#Train process
for epoch in range(1):
    print("epoch number: "+str(epoch))
    for data in tqdm.tqdm(dataset):
        visualize(data, F=F, D=D,scale_index=1)
        optimizer.zero_grad()
        out = model(data, E, D)
        out = out.view(1,-1)
        
        
        loss = criterion(out, data.y)
        loss_values.append(loss.item())

        loss.backward()
        optimizer.step()

epoch number: 0


  0%|                                                                                          | 0/100 [00:02<?, ?it/s]


TypeError: invalid Figure property: title