In [20]:
import os
import time
import torch
import numpy as np

import numpy as np
import matplotlib.pyplot as plt
import os
import vedo
import time
import torch
import torch.nn.functional as F
import napari

from collections import OrderedDict
from modules.click_attention import ClickAttention
from modules.decoder import Decoder

In [21]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
def load_state_dict(model, state_dict):
    """Load ``state_dict`` into ``model``, retrying with re-keyed entries on failure.

    The dictionary is first loaded as-is. If that raises ``RuntimeError``, every
    key has its first dot-separated component dropped and the load is attempted
    once more; a failure of the second attempt propagates to the caller.

    Args:
        model: Object exposing ``load_state_dict``.
        state_dict: Mapping from parameter names to values.
    """
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Trained on multiple GPUs and tested on a single GPU: remove the
        # leading 'module' component from the variable names.
        stripped = OrderedDict()
        for key, value in state_dict.items():
            # Everything after the first '.'; empty string when there is none.
            stripped[key.partition(".")[2]] = value
        model.load_state_dict(stripped)

def save_loss(loss, dir, name=None):
    """Plot ``loss`` with a log-scaled y axis and save the figure as a JPEG.

    Args:
        loss: Sequence of values handed to ``plt.plot``.
        dir: Output directory; created if it does not exist.
        name: Optional curve name. When given, the title is ``"<name> over time"``
            and the file is ``<name>.jpg``; otherwise the title is
            ``"Loss over time"`` and the file is ``loss.jpg``.
    """
    plt.figure()
    plt.plot(loss)
    plt.yscale('log')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')

    # The output directory is created on demand.
    os.makedirs(dir, exist_ok=True)

    # Pick the title and file name once, then save and close in one place.
    if name is None:
        title, filename = 'Loss over time', 'loss.jpg'
    else:
        title, filename = name+' over time', name+'.jpg'

    plt.title(title)
    plt.savefig(os.path.join(dir, filename))
    plt.close()


def load_features(encoder_f_path, map_location=None):
    """Load a saved feature object from ``encoder_f_path`` with ``torch.load``.

    Args:
        encoder_f_path: Path of the file to load.
        map_location: Forwarded to ``torch.load``. The default ``None`` keeps
            the previous behaviour (storages are restored to the device they
            were saved from); pass e.g. ``"cpu"`` or a ``torch.device`` to open
            a file saved on a GPU on a machine without one.

    Returns:
        Whatever object was stored in the file.
    """
    # NOTE: torch.load unpickles the file, so only load files you trust.
    return torch.load(encoder_f_path, map_location=map_location)


def masked_bce_loss(y_pred, y_true, mask):
    """Binary cross-entropy averaged over the entries selected by ``mask``.

    ``y_true`` is one-hot encoded into two classes and permuted to a
    channel-first layout, and ``mask`` is duplicated along a new dimension 1 so
    that it weights both class channels of the element-wise BCE.

    Args:
        y_pred: Predicted probabilities with the two classes on dimension 1.
            Assumes shape (B, 2, H, W) - TODO confirm against the caller.
        y_true: Integer class labels (0 or 1). Assumes shape (B, 1, H, W) or
            (B, H, W) - TODO confirm against the caller.
        mask: Per-element weights without the class dimension. Assumed to be
            non-negative.

    Returns:
        Scalar tensor: the masked BCE sum divided by the sum of the duplicated
        mask. When the mask selects nothing the result is 0 rather than the NaN
        that 0/0 would produce.
    """
    # Two-class one-hot target, rearranged to match y_pred's channel-first layout.
    target = F.one_hot(y_true.long(), num_classes=2).squeeze(1).permute(0, 3, 1, 2).float()

    # Element-wise BCE; the reduction is done manually so the mask can be applied.
    bce = F.binary_cross_entropy(y_pred, target, reduction='none')

    # The same mask applies to both class channels.
    channel_mask = torch.cat((mask.unsqueeze(1), mask.unsqueeze(1)), dim=1)
    total = (bce * channel_mask).sum()

    # Guard the denominator: with an all-zero mask the numerator is 0 as well,
    # so clamping to the smallest positive float yields 0 instead of NaN while
    # leaving every non-empty mask's result unchanged.
    denom = channel_mask.sum().to(total.dtype)
    return total / denom.clamp_min(torch.finfo(total.dtype).tiny)

def fix_state_dict(state_dict):
    """Return a copy of ``state_dict`` with ``module`` wrapper components removed.

    Every dot-separated path component that is exactly ``module`` (other than
    the final one) is dropped from each key, e.g. ``module.fc.weight`` becomes
    ``fc.weight``. Only whole components are matched, so names that merely end
    in "module" (``submodule.bias``) are left intact instead of being mangled
    into ``subbias`` by a substring replacement.

    Args:
        state_dict: Mapping from string parameter names to values.

    Returns:
        A new plain ``dict`` with the rewritten keys and the original values.
    """
    new_state_dict = {}
    for key, value in state_dict.items():
        # The last component is the parameter name itself and is never dropped.
        *parents, leaf = key.split('.')
        kept = [part for part in parents if part != 'module']
        new_state_dict['.'.join(kept + [leaf])] = value
    return new_state_dict

In [13]:
def test_decoder(test_index,
                 depth,
                 width,
                 device,
                 positional_encoding,
                 sigma,
                 save_dir,
                 model_name,
                 encoder_f_path,
                 feature_dim=256,
                 use_attention_k=True,
                 use_attention_q=True,
                 use_attention_v=True,
                 residual_attention=True,
                 scale_attention=True):
    """Run the trained attention + decoder pair on saved features.

    Builds a ``ClickAttention`` module and a ``Decoder``, restores both from the
    checkpoint at ``save_dir/model_name``, loads the features stored at
    ``encoder_f_path`` and evaluates the decoder on the features concatenated
    with the attention output.

    Args:
        test_index: Tensor passed to the attention module; a 1-D tensor is
            treated as a single batch entry.
        depth: Decoder depth; also the number of 256-wide layers requested.
        width: Currently unused (see the note next to the ``Decoder`` call).
        device: Device the modules and features are placed on.
        positional_encoding: Forwarded to ``Decoder``.
        sigma: Forwarded to ``Decoder``.
        save_dir: Directory containing the checkpoint.
        model_name: Checkpoint file name inside ``save_dir``.
        encoder_f_path: Path of the saved feature tensor.
        feature_dim: Forwarded to ``ClickAttention``.
        use_attention_k, use_attention_q, use_attention_v, residual_attention,
        scale_attention: Forwarded to ``ClickAttention``.

    Returns:
        The decoder output as a ``float32`` NumPy array.
    """
    print('test decoder')

    # Create an instance of ClickAttention with the new keyword arguments
    attention = ClickAttention(
        feature_dim=feature_dim,
        use_attention_q=use_attention_q,
        use_attention_k=use_attention_k,
        use_attention_v=use_attention_v,
        residual_attention=residual_attention,
        scale_attention=scale_attention
    ).to(device)

    # NOTE(review): `width` is accepted but not used; the layer widths and
    # input_dim are hard-coded below. Confirm they match the training setup
    # before wiring the parameter through, since the checkpoint must fit.
    mlp = Decoder(depth=depth, width=[512] + [256] * depth, out_dim=2, input_dim=512, positional_encoding=positional_encoding,
                  sigma=sigma).to(device)

    save_path = os.path.join(save_dir, model_name)
    print('decoder checkpoint path: %s' % save_path)

    checkpoint = torch.load(save_path, map_location=device)
    attention.load_state_dict(fix_state_dict(checkpoint['attention_state_dict']))
    mlp.load_state_dict(fix_state_dict(checkpoint['model_state_dict']))

    # Inference only: switch both modules to evaluation mode.
    attention.eval()
    mlp.eval()

    # Read learned 3D features and place them on the same device as the
    # modules; otherwise features saved from another device fail in the forward pass.
    pred_f = load_features(encoder_f_path).to(device)

    # A single 1-D index tensor is treated as a batch of one.
    if test_index.dim() == 1:
        test_index = test_index.unsqueeze(0)

    batch_size = test_index.shape[0]

    # No gradients are needed for testing; skipping autograd bookkeeping avoids
    # holding activations of the whole network in memory.
    with torch.no_grad():
        # NOTE(review): test_index is passed on whatever device the caller
        # created it on - confirm ClickAttention accepts that.
        feature_field_batch = pred_f.unsqueeze(0).expand(batch_size, -1, -1)
        weighted_vals = attention(feature_field_batch, test_index)
        input_tensor = torch.cat((feature_field_batch, weighted_vals), dim=-1)

        # MLP
        prob_tensor = mlp(input_tensor)

    # Return the decoder output as a float32 NumPy array on the host.
    return prob_tensor.detach().cpu().numpy().astype(np.float32)

# Append to the existing files or create them if they don't exist
def save_or_append(filename, data):
    """Save ``data`` to ``filename``, appending to any tensor already stored there.

    If the file does not exist, ``data`` is saved as-is. Otherwise the stored
    tensor is loaded, ``data`` is concatenated to it along dimension 0 (gaining
    a leading dimension first when it has fewer dimensions than the stored
    tensor), and the result overwrites the file.

    Args:
        filename: Path of the ``torch.save`` file to create or extend.
        data: Tensor to store.
    """
    if not os.path.exists(filename):
        torch.save(data, filename)
        return

    # Load onto the device `data` lives on so torch.cat never mixes devices
    # (and so this helper does not depend on a module-level `device`).
    existing_data = torch.load(filename, map_location=data.device)

    # A single sample is promoted to a batch of one before concatenation.
    if data.dim() < existing_data.dim():
        data = data.unsqueeze(0)

    torch.save(torch.cat([existing_data, data]), filename)


In [17]:
# Define parameters directly
# NOTE(review): several values below (e.g. seed, views_per_vert, learning_rate)
# are not used by anything in this cell.
seed = 0
obj_path = './meshes/hammer.obj'
encoder_f_path = './demo/hammer/encoder/pred_f.pth'
decoder_data_dir = './data/hammer/decoder_data'
save_dir = './demo/hammer/decoder/'
model_name = 'decoder_checkpoint.pth'
use_positive_click = 0
use_negative_click = 0
name = 'hammer'
data_percentage = 1.0
views_per_vert = 100
background = [1., 1., 1.]
n_views = 1
frontview_std = 4
frontview_center = [3.14, 0.]
render_res = 224
use_attention_q = 1
use_attention_k = 1
use_attention_v = 1
redsidual_attention = 0  # (sic) spelling kept so any other cell using this name still works
scale_attention = 1
continue_train = 0
depth = 14
width = 256
n_classes = 256  # 256 channels for SAM embedding feature
positional_encoding = False
sigma = 5.0
batch_size = 16
learning_rate = 0.0001
num_epochs = 5
save_interval = 100
return_original = 0
use_data_parallel = 0
mode = 'test'
select_vertices = [5000]
show = 0
base_color = [180, 180, 180]
show_seg = 1
seg_color = [60, 160, 250]
show_spheres = 0
sphere_radius = 0.025
pos_color = [0, 255, 0]
neg_color = [255, 0, 0]

# Load mesh object
mesh = vedo.load(obj_path)

# Create decoder directory if it does not exist
# (exist_ok=True already makes this a no-op when the directory is present)
os.makedirs(save_dir, exist_ok=True)

test_index = torch.tensor(select_vertices)

# test_decoder has no `use_data_parallel` parameter, so that flag is not passed:
# supplying it raises TypeError (unexpected keyword argument).
probabilities = test_decoder(
    test_index=test_index,
    depth=depth,
    width=width,
    device=device,
    positional_encoding=positional_encoding,
    sigma=sigma,
    save_dir=save_dir,
    model_name=model_name,
    encoder_f_path=encoder_f_path,
    feature_dim=n_classes,
    use_attention_k=use_attention_k,
    use_attention_q=use_attention_q,
    use_attention_v=use_attention_v,
    residual_attention=redsidual_attention,
    scale_attention=scale_attention
)
    

test decoder
ModuleList(
  (0): Linear(in_features=512, out_features=512, bias=True)
  (1): ReLU()
  (2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (3): Linear(in_features=512, out_features=256, bias=True)
  (4): ReLU()
  (5): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (6): Linear(in_features=256, out_features=256, bias=True)
  (7): ReLU()
  (8): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (9): Linear(in_features=256, out_features=256, bias=True)
  (10): ReLU()
  (11): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (12): Linear(in_features=256, out_features=256, bias=True)
  (13): ReLU()
  (14): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (15): Linear(in_features=256, out_features=256, bias=True)
  (16): ReLU()
  (17): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (18): Linear(in_features=256, out_features=256, bias=True)
  (19): ReLU()
  (20): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (21): Linear(in_f

In [7]:
# Create the napari viewer used to display the mesh.
viewer = napari.Viewer()

In [9]:
# Add the loaded mesh to the viewer as a surface layer built from its vertices and cells.
viewer.add_surface((mesh.vertices, np.asarray(mesh.cells)))

<Surface layer 'Surface' at 0x1c502f405b0>

In [18]:
# Inspect the shape of the array returned by test_decoder.
probabilities.shape

(1, 11595, 2)

In [12]:
# Count the cells of the loaded mesh.
len(mesh.cells)

23128

In [16]:
# Inspect the shape of the mesh's vertex array.
# NOTE(review): the recorded outputs show 69384 vertices here but only 11595
# rows in `probabilities` - confirm how predictions map onto these vertices.
mesh.vertices.shape

(69384, 3)

In [23]:
# Count the entries in the mesh's edge list.
len(mesh.edges)

69384