In [21]:
import copy
import os

from mesh_dataset_utils import *
from MLP import MLP
import torch
from torch import nn

In [3]:
# Load the texture source mesh and the blank base shape (sphere / cube / any
# simple untextured mesh) that the learned texture will be transferred onto.
texture = trimesh.load('/Users/maxperozek/GINR-texture/Spikeyball1.stl')
blank = get_sphere(size=3)
# Vertex count of each mesh: texture first, then blank.
for mesh in (texture, blank):
    print(mesh.vertices.shape[0])

42787
40962


In [36]:
# Visual check of the blank base mesh returned by get_sphere(size=3).
blank.show()

In [37]:
# (n_vertices, 3) array of blank-mesh vertex positions.
np.shape(blank.vertices)

(40962, 3)

In [4]:
# Visual check of the texture mesh loaded from Spikeyball1.stl.
texture.show()

In [7]:
# Build the training dataset from the texture mesh. The returned `dataset` is
# indexed by 'fourier', 'points' and 'target' (see the save cell below), and
# `smooth` is the smoothed mesh.
# NOTE: slow at this smoothing setting -- the recorded run was interrupted
# (KeyboardInterrupt), and later cells reload a cached .npz instead.
SMOOTH_ITER = 1000
dataset, smooth = build_offset_dataset(texture, smooth_iter=SMOOTH_ITER)

KeyboardInterrupt: 

In [None]:
# Visual check of the smoothed mesh returned by build_offset_dataset.
# Only defined if the build cell above ran to completion.
smooth.show()

In [None]:
# Cache the built dataset so later runs can skip build_offset_dataset.
# DIRECTORY_NAME is really the file stem: np.savez appends ".npz", producing
# npz_train/spikey_1.npz relative to the current working directory.
# NOTE(review): the load cell below reads an absolute path; it is the same file
# only when the notebook runs from /Users/maxperozek/GINR-texture.
DIRECTORY_NAME = 'spikey_1'
# np.savez does not create missing parent directories.
os.makedirs('npz_train', exist_ok=True)
np.savez(f'npz_train/{DIRECTORY_NAME}', fourier=dataset['fourier'], points=dataset['points'], target=dataset['target'])

In [8]:
# Load the cached training arrays into a plain dict and close the .npz file
# right away. np.load returns a lazily-reading NpzFile that otherwise keeps
# the file handle open for the rest of the session.
with np.load('/Users/maxperozek/GINR-texture/npz_train/spikey_1.npz') as npz:
    dataset = {key: npz[key] for key in npz.files}

In [14]:
# Model inputs ('fourier' array) and regression targets ('target' array).
# Both have one row per texture-mesh vertex (42787 rows in the recorded run).
X, y = dataset['fourier'], dataset['target']

In [15]:
# (n_vertices, n_input_features)
np.shape(X)

(42787, 100)

In [16]:
# (n_vertices, 3)
np.shape(y)

(42787, 3)

In [105]:
# Seed torch's global RNG so that two things are reproducible when cells run
# in order: the DataLoader's shuffle order (its sampler draws from the global
# generator when none is given) and the weight init in the model cell below.
SEED = 0
torch.manual_seed(SEED)

# Wrap the numpy arrays as float32 tensors and serve shuffled mini-batches.
batch_size = 1000
train_data = torch.utils.data.TensorDataset(torch.tensor(X).float(), torch.tensor(y).float())
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

In [106]:
# Hyper-parameters for the project MLP. The boolean flags are passed straight
# through; their meaning is defined in MLP.py.
mlp_config = dict(
    input_dim=X.shape[1],   # width of the input feature vector
    output_dim=y.shape[1],  # one output per target column (3 here)
    hidden_dim=512,
    n_layers=6,
    geometric_init=True,
    beta=True,
    sine=True,
    all_sine=True,
    skip=True,
    bn=True,
    dropout=0.0,
)
model = MLP(**mlp_config)

In [107]:
# Optimisation settings: plain MSE regression with Adam.
epochs = 60
lr = 1e-4
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [108]:
# Train for `epochs` passes over the data and checkpoint the parameters of the
# epoch with the lowest mean training loss to 'best-model-parameters.pt'.
# best_loss / best_model are initialised ONCE, before the loop. The previous
# version reset them every epoch, so every epoch looked like "the best" and the
# checkpoint was simply the last epoch.
best_loss = float('inf')
best_model = copy.deepcopy(model.state_dict())
for i in range(epochs):
    model.train()  # an earlier cell may have left the model in eval mode
    running_loss = 0.0
    n_seen = 0
    for batchX, batchY in train_loader:
        optimizer.zero_grad()
        batch_preds = model(batchX)
        loss = loss_fn(batch_preds, batchY)
        loss.backward()
        optimizer.step()
        # Weight each batch by its size so the short final batch does not skew
        # the epoch mean. MSELoss averages within the batch, so this gives the
        # exact per-sample mean over the epoch.
        batch_n = batchX.shape[0]
        running_loss += loss.item() * batch_n
        n_seen += batch_n

    mean_loss = running_loss / max(n_seen, 1)
    if mean_loss < best_loss:
        best_loss = mean_loss
        # state_dict() returns live references to the parameters, so deep-copy
        # to freeze this epoch's weights.
        best_model = copy.deepcopy(model.state_dict())
        torch.save(best_model, 'best-model-parameters.pt')
    if i % 5 == 0:
        print(f'epoch {i} loss: {mean_loss}')
    

epoch 0 loss: 0.8412933349609375
epoch 5 loss: 0.0737447738647461
epoch 10 loss: 0.002808394646921823
epoch 15 loss: 0.0019111610775770142
epoch 20 loss: 0.0017478375933891119
epoch 25 loss: 0.0014887509304423664
epoch 30 loss: 0.0015225289172904437
epoch 35 loss: 0.0013408848019533378
epoch 40 loss: 0.0012928737803947093
epoch 45 loss: 0.0011390601826268573
epoch 50 loss: 0.00117141840069793
epoch 55 loss: 0.0011468099993328716


In [109]:
# Load the input features for the blank sphere and close the .npz file
# immediately; np.load's NpzFile otherwise keeps the handle open.
# NOTE(review): the predictions are added row-by-row to blank.vertices, so this
# file's rows must follow the same vertex order as get_sphere(size=3) -- TODO confirm.
with np.load('/Users/maxperozek/GINR-texture/npz_train/blank_sphere.npz') as npz:
    inf_dataset = {key: npz[key] for key in npz.files}
inference_X = inf_dataset['fourier']

In [110]:
# Fresh model for inference. The hyper-parameters must match the training cell
# exactly, or load_state_dict below will fail on mismatched keys/shapes.
inference_mlp_config = {
    'input_dim': X.shape[1],
    'output_dim': y.shape[1],
    'hidden_dim': 512,
    'n_layers': 6,
    'geometric_init': True,
    'beta': True,
    'sine': True,
    'all_sine': True,
    'skip': True,
    'bn': True,
    'dropout': 0.0,
}
model = MLP(**inference_mlp_config)

In [111]:
# Restore the weights checkpointed by the training loop.
checkpoint_state = torch.load('best-model-parameters.pt')
model.load_state_dict(checkpoint_state)

<All keys matched successfully>

In [112]:
# Predict one output row per blank-mesh vertex.
# eval() switches train/inference-dependent layers to inference behaviour.
# NOTE(review): this assumes MLP's bn=True adds batch-norm layers, which in
# training mode would normalise with this single large batch's statistics
# (and update running stats) instead of the stored running statistics -- TODO confirm.
# no_grad() skips building an autograd graph over every vertex.
model.eval()
with torch.no_grad():
    preds = model(torch.tensor(inference_X).float())

In [113]:
# Should be (n_blank_vertices, 3).
preds.size()

torch.Size([40962, 3])

In [114]:
# Convert the prediction tensor to a numpy array so it can be added to mesh vertices.
preds = preds.detach().cpu().numpy()

In [115]:
# Alternative (currently unused): new_mesh = generate_pred_mesh(preds, blank); the next cell builds the mesh manually instead.

In [116]:
# Displace each blank-mesh vertex by its predicted offset and keep the blank
# mesh's connectivity unchanged.
base_vertices = blank.vertices
pred_verts = base_vertices + preds
inference = trimesh.Trimesh(vertices=pred_verts, faces=blank.faces)
# Optional repair step. trimesh.repair.fill_holes works in place and returns a
# bool, so call it without reassigning `inference`:
# trimesh.repair.fill_holes(inference)


In [117]:
# Visual check of the predicted mesh: blank-mesh vertices displaced by the model's outputs.
inference.show()