In [1]:
from os import sys
# Path to workspace
sys.path.insert(0, '/workspace/dense-self-supervised-representation-learning-for-3D-shapes/')

import h5py
import torch
import numpy as np
from tqdm import tqdm
import k3d

In [2]:
import neptune.new as neptune
from workspace.utils.train_loop import *

# Experiment settings read by the cells below. Within this notebook,
# `batch_size` feeds both DataLoaders and `n_output` / `hidden_dim` /
# `result_dim` size the model heads.
params = dict(
    name='modelnet_cls_linprobe',
    dataset='modelnet',
    batch_size=15,
    tau=0.07,
    n_output=512,
    result_dim=3,
    hidden_dim=256,
    total_epochs=100,
    lr=0.001,
    weight_decay=1e-5,
    save_every=50,
    weights_root='weights/',
)

# Optional Neptune experiment logging (currently disabled). Uncomment the
# lines below to enable it.
# tags = ['modelnet', 'classification', 'linprobe']

# logger = neptune.init(
#     project="igor3661/crossmodal",
#     name=params['name'],
#     tags=tags,
#     # SECURITY: an API token used to be hard-coded on this line. It has been
#     # redacted here; treat the old token as leaked and revoke it. Read the
#     # token from the environment instead (requires `import os`).
#     api_token=os.environ['NEPTUNE_API_TOKEN'],
# )


# logger['parameters'] = params

# Device string used by the cells below for the models and the batches.
# NOTE(review): the GPU index is hard-coded; this fails on machines that do
# not expose a third CUDA device.
device = 'cuda:2'

In [3]:
from torch.utils.data import Dataset, DataLoader
from workspace.crossmodal.utils.meshnet_preprop import *


class PointDataset(Dataset):
    """Serve per-sample arrays from the ``'points'`` entry of an HDF5 file.

    Args:
        data_path: Path of the HDF5 file; it is opened read-only and kept
            open on ``self.file`` for the lifetime of the dataset.
        transform: Optional callable applied to the raw array of one sample
            before it is turned into a tensor.
    """

    def __init__(self, data_path, transform=None):
        super().__init__()
        self.transform = transform
        self.file = h5py.File(data_path, 'r')

    def __len__(self):
        """Return the size of the first dimension of ``'points'``."""
        return self.file['points'].shape[0]

    def __getitem__(self, index):
        """Return sample ``index`` as a float32 tensor with its two axes swapped."""
        sample = self.file['points'][index][:]
        if self.transform is not None:
            sample = self.transform(sample)

        # The stored sample is 2-D; swapping the axes puts the second stored
        # axis first in the returned tensor.
        return torch.from_numpy(sample).float().permute(1, 0)
    
class MeshnetDataset(Dataset):
    """Serve per-face mesh inputs built from the ``'faces'`` and ``'vertices'``
    entries of an HDF5 file.

    Args:
        data_path: Path of the HDF5 file; it is opened read-only and kept
            open on ``self.file``.
        rotation: Optional callable applied to the ``(-1, 3)`` vertex array
            before the mesh is processed.
        jitter: Optional callable applied to the face-centre tensor; its
            result is cast to float.
    """

    def __init__(self, data_path, rotation=None, jitter=None):
        super().__init__()
        self.rotation = rotation
        self.jitter = jitter
        self.file = h5py.File(data_path, 'r')

    def __getitem__(self, index):
        """Return ``(centers, corners, normals, neighbors)`` for one mesh.

        ``centers``, ``corners`` and ``normals`` are row slices ``[:3]``,
        ``[3:12]`` and ``[12:]`` of the transposed feature matrix returned by
        ``process_mesh``; ``neighbors`` is returned as a long tensor.
        """
        # Both arrays are stored flattened per sample and reshaped to rows of 3.
        faces = self.file['faces'][index][:].reshape(-1, 3)
        vertices = self.file['vertices'][index][:].reshape(-1, 3)

        if self.rotation is not None:
            vertices = self.rotation(vertices)

        # NOTE(review): `process_mesh` comes from a star import; presumably it
        # returns one feature row per face plus a neighbour index array.
        features, neighbors = process_mesh(faces, vertices)

        features = torch.from_numpy(features).float()
        neighbors = torch.from_numpy(neighbors).long()

        features = torch.permute(features, (1, 0))
        centers, corners, normals = features[:3], features[3:12], features[12:]

        if self.jitter is not None:
            centers = self.jitter(centers).float()

        # Express the 9 corner rows relative to the (possibly jittered)
        # centres by subtracting the 3 centre rows repeated three times.
        corners = corners - torch.cat([centers, centers, centers], 0).float()

        return centers, corners, normals, neighbors

    def __len__(self):
        """Return the number of meshes, i.e. the first dimension of ``'faces'``."""
        # Count the dataset that __getitem__ actually indexes. The previous
        # implementation read 'points', which this class never uses and which
        # a mesh-only file does not have to contain.
        return self.file['faces'].shape[0]
    

class DoubleDataset(Dataset):
    """Pair the mesh view and the point view of the same sample index.

    Args:
        kwargs: Mapping with the keys ``'mesh'`` and ``'point'``; each value
            is a dict of keyword arguments forwarded to ``MeshnetDataset``
            and ``PointDataset`` respectively.
    """

    def __init__(self, kwargs):
        super().__init__()
        mesh_options = kwargs['mesh']
        self.mesh = MeshnetDataset(**mesh_options)
        point_options = kwargs['point']
        self.point = PointDataset(**point_options)

    def __len__(self):
        """Return the length reported by the mesh dataset."""
        return self.mesh.__len__()

    def __getitem__(self, idx):
        """Return the mesh items, then the point item, the label and the face index.

        The label and the ``'face_index'`` row (as a long tensor) are read
        from the HDF5 file opened by the mesh dataset.
        """
        mesh_file = self.mesh.file
        label = mesh_file['labels'][idx]
        face_index = torch.from_numpy(mesh_file['face_index'][idx][:]).long()

        mesh_items = self.mesh.__getitem__(idx)
        point_item = self.point.__getitem__(idx)
        return (*mesh_items, point_item, label, face_index)

In [4]:
# Within a split, the mesh view and the point view are read from the same
# HDF5 file.
train_kwargs = {
    modality: {'data_path': 'modelnet/modelnet_train_1024.h5'}
    for modality in ('mesh', 'point')
}
test_kwargs = {
    modality: {'data_path': 'modelnet/modelnet_test_1024.h5'}
    for modality in ('mesh', 'point')
}

train_data = DoubleDataset(train_kwargs)
test_data = DoubleDataset(test_kwargs)

# Both loaders share the same settings; neither shuffles, so batches follow
# the order of the samples in the file.
_loader_options = {
    'batch_size': params['batch_size'],
    'num_workers': 10,
    'shuffle': False,
}
train_loader = DataLoader(train_data, **_loader_options)
test_loader = DataLoader(test_data, **_loader_options)

In [5]:
class Transpose(torch.nn.Module):
    """Module form of ``transpose`` so it can be placed inside a ``Sequential``.

    Args:
        *dims: Positional arguments stored on ``self.dims`` and forwarded
            unchanged to ``data.transpose`` on every call.
    """

    def __init__(self, *dims):
        super().__init__()
        self.dims = dims

    def forward(self, data):
        """Return ``data.transpose(*self.dims)``."""
        dims = self.dims
        return data.transpose(*dims)
    

class NormalsModel(torch.nn.Module):
    """Backbone followed by a Linear -> BatchNorm1d -> ReLU -> Linear head.

    Each ``Linear`` is wrapped in ``Transpose(1, 2)`` modules so that it acts
    on dimension 1 of the backbone output, while ``BatchNorm1d`` sees the
    untransposed layout.

    Args:
        model: Backbone; it must provide ``forward_features(data)``.
        model_output_dim: Size of dimension 1 of the backbone output.
        result_dim: Accepted for interface compatibility; not used here (the
            head always ends in 128 output features).
        hidden_dim: Width of the hidden layer of the head.
    """

    def __init__(self, model, model_output_dim, result_dim, hidden_dim):
        super().__init__()
        self.model = model

        nn = torch.nn
        # The order of the layers fixes the state_dict keys (head.1, head.3,
        # head.6), so it must match the checkpoints that are loaded later.
        layers = [
            Transpose(1, 2),
            nn.Linear(model_output_dim, hidden_dim),
            Transpose(1, 2),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            Transpose(1, 2),
            nn.Linear(hidden_dim, 128),
            Transpose(1, 2),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, data):
        """Run the backbone's ``forward_features`` and then the head."""
        features = self.model.forward_features(data)
        return self.head(features)

In [6]:
from workspace.models.dgcnn import DGCNN
from workspace.models.meshnet import MeshNet

# Backbones. The DGCNN point backbone is instantiated, but the wrapper that
# would use it is commented out below.
meshnet = MeshNet(n_patches=5)
dgcnn = DGCNN(n_patches=5)

# Mesh backbone + head, moved to `device` and switched to eval mode.
mesh_model = NormalsModel(
    meshnet,
    model_output_dim=params['n_output'],
    hidden_dim=params['hidden_dim'],
    result_dim=params['result_dim']
).to(device).eval()
# point_model = NormalsModel(
#     dgcnn,
#     model_output_dim=params['n_output'],
#     hidden_dim=params['hidden_dim'],
#     result_dim=params['result_dim']
# ).to(device).eval()

# map_location places the checkpoint tensors on `device` regardless of the
# device they were saved from; without it torch.load restores them to the
# original device, which may be missing or a different GPU.
# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
mesh_model.load_state_dict(
    torch.load('weights/CROSS-57/100epoch.pt', map_location=device)
)
# point_model.load_state_dict(
#     torch.load('weights/CROSS-56/100epoch.pt', map_location=device)
# )

<All keys matched successfully>

In [7]:
class Projector(torch.nn.Module):
    """MLP head: Linear -> BatchNorm1d -> ReLU -> Linear.

    The hidden layer is twice as wide as the input and the output has the
    same width as the input. No transposes are applied, so the ``Linear``
    layers act on the last axis of whatever tensor is passed in.

    Args:
        model_output_dim: Number of input (and output) features.
    """

    def __init__(self, model_output_dim):
        super().__init__()
        wide_dim = model_output_dim * 2
        # The layer order fixes the state_dict keys (head.0, head.1, head.3).
        layers = (
            torch.nn.Linear(model_output_dim, wide_dim),
            torch.nn.BatchNorm1d(wide_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(wide_dim, model_output_dim),
        )
        self.head = torch.nn.Sequential(*layers)

    def forward(self, data):
        """Apply the head to ``data``."""
        return self.head(data)

# Projection head in eval mode on `device`.
proj = Projector(params['n_output']).to(device).eval()
# map_location places the checkpoint tensors on `device` regardless of the
# device they were saved from.
# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
proj.load_state_dict(
    torch.load('weights/CROSS-79/100epoch.pt', map_location=device)
)

<All keys matched successfully>

In [8]:
from tqdm import tqdm

def move_to_device(data, device='cpu'):
    """Move a tensor, or every tensor in a list/tuple, to ``device``.

    Args:
        data: An object with a ``.to(device)`` method, or a list or tuple of
            such objects.
        device: Target device passed to ``.to``; defaults to ``'cpu'``.

    Returns:
        A new list of moved items when ``data`` is a list or tuple, otherwise
        the result of ``data.to(device)``.
    """
    # Tuples are handled like lists; previously a tuple fell through to
    # `data.to(...)` and raised AttributeError.
    if isinstance(data, (list, tuple)):
        return [item.to(device) for item in data]
    return data.to(device)
    
    
def collect_embeddings(loader):
    """Project per-element mesh features, gather them by face index and average.

    Uses the module-level ``mesh_model``, ``proj`` and ``device``.

    Args:
        loader: Iterable yielding ``(centers, corners, normals, neighbors,
            points, labels, face_index)`` batches.

    Returns:
        Tuple ``(all_embeddings, all_targets)``: the per-sample embeddings
        concatenated along dimension 0, and the concatenated labels, both on
        the CPU.
    """
    all_embeddings = []
    all_targets = []

    # Inference only: no_grad avoids building autograd graphs that the
    # original code created and then discarded with .detach().
    with torch.no_grad():
        for centers, corners, normals, neighbors, _points, labels, face_index in tqdm(loader):
            mesh_data = move_to_device([centers, corners, normals, neighbors], device)

            # The former `point_model` forward pass was removed: its result
            # was never used and `point_model` is not defined in this
            # notebook, so the call raised NameError.
            mesh_embs = mesh_model.model.forward_features(mesh_data)

            # NOTE(review): `proj` is built without transposes, so its Linear
            # layers act on the last axis of `mesh_embs` - confirm this is
            # the intended axis for un-pooled features.
            projected = proj(mesh_embs).cpu()

            # Select, for every channel, the entries listed in face_index
            # along the last axis, then average them.
            index = face_index.unsqueeze(1).expand((-1, projected.shape[1], -1))
            gathered = torch.gather(projected, 2, index).mean(-1)

            all_embeddings.append(gathered)
            all_targets.append(labels.cpu())

    all_embeddings = torch.cat(all_embeddings, dim=0)
    all_targets = torch.cat(all_targets)

    return all_embeddings, all_targets

def collect_embeddings_coords(loader):
    """Gather mesh features by face index, append the points, project and average.

    Uses the module-level ``mesh_model``, ``proj`` and ``device``.

    Args:
        loader: Iterable yielding ``(centers, corners, normals, neighbors,
            points, labels, face_index)`` batches.

    Returns:
        Tuple ``(all_embeddings, all_targets)``: the per-sample embeddings
        concatenated along dimension 0, and the concatenated labels, both on
        the CPU.
    """
    all_embeddings = []
    all_targets = []

    # Inference only: no_grad avoids building autograd graphs that the
    # original code created and then discarded with .detach().
    with torch.no_grad():
        for centers, corners, normals, neighbors, points, labels, face_index in tqdm(loader):
            mesh_data = move_to_device([centers, corners, normals, neighbors], device)
            points = points.to(device)
            face_index = face_index.to(device)

            mesh_embs = mesh_model.model.forward_features(mesh_data)

            # Select, for every channel, the entries listed in face_index
            # along the last axis.
            index = face_index.unsqueeze(1).expand((-1, mesh_embs.shape[1], -1))
            gathered = torch.gather(mesh_embs, 2, index)

            # Stack the point rows under the gathered feature channels.
            concated = torch.cat([gathered, points], dim=1)

            # NOTE(review): `proj` is built for `params['n_output']` input
            # features and has no transposes; confirm that the concatenated
            # tensor matches the layout this projector expects.
            projected = proj(concated).cpu().mean(-1)

            all_embeddings.append(projected)
            all_targets.append(labels.cpu())

    all_embeddings = torch.cat(all_embeddings, dim=0)
    all_targets = torch.cat(all_targets)

    return all_embeddings, all_targets

def collect_embeddings_global(loader):
    """Average the mesh features over the last axis and project them.

    Uses the module-level ``mesh_model``, ``proj`` and ``device``.

    Args:
        loader: Iterable yielding ``(centers, corners, normals, neighbors,
            points, labels, face_index)`` batches; only the four mesh inputs
            and the labels are used.

    Returns:
        Tuple ``(all_embeddings, all_targets)``: the projected per-sample
        embeddings concatenated along dimension 0, and the concatenated
        labels, both on the CPU.
    """
    all_embeddings = []
    all_targets = []

    # Inference only: no_grad avoids building autograd graphs that the
    # original code created and then discarded with .detach().
    with torch.no_grad():
        for centers, corners, normals, neighbors, _points, labels, _face_index in tqdm(loader):
            # Only the mesh inputs go to the device; points, labels and
            # face_index are not needed there for this variant.
            mesh_data = move_to_device([centers, corners, normals, neighbors], device)

            # Pool over the last axis to obtain one feature vector per sample.
            mesh_embs = mesh_model.model.forward_features(mesh_data).mean(-1)

            projected = proj(mesh_embs).cpu()

            all_embeddings.append(projected)
            all_targets.append(labels.cpu())

    all_embeddings = torch.cat(all_embeddings, dim=0)
    all_targets = torch.cat(all_targets)

    return all_embeddings, all_targets

In [9]:
# Extract the pooled-and-projected mesh embeddings and the labels for both
# splits (test split first, then train).
test_embeddings, test_labels = collect_embeddings_global(test_loader)
train_embeddings, train_labels = collect_embeddings_global(train_loader)

100%|██████████| 165/165 [00:08<00:00, 20.39it/s]
100%|██████████| 657/657 [00:26<00:00, 24.68it/s]


In [10]:
from sklearn.svm import SVC
from sklearn.metrics import classification_report

def train_svm(X, y, seed=None):
    """Fit a linear-kernel ``SVC`` on a shuffled copy of the training data.

    Args:
        X: Training features, indexable by an integer array along axis 0 and
            exposing ``shape[0]``.
        y: Training labels, indexable by the same integer array.
        seed: Optional seed for the shuffle. With ``None`` (the default) the
            global NumPy random state is used, as before; pass an integer to
            make the shuffle order repeatable.

    Returns:
        The fitted ``SVC`` instance.
    """
    n_samples = X.shape[0]
    if seed is None:
        perm = np.random.permutation(n_samples)
    else:
        # A dedicated generator keeps the shuffle reproducible without
        # touching the global random state.
        perm = np.random.RandomState(seed).permutation(n_samples)

    svm = SVC(kernel='linear')
    svm.fit(X[perm], y[perm])
    return svm


def train_eval(X_train, y_train, X_test, y_test):
    """Train an SVM on the training split and predict the test split.

    Prints the number of training samples, the number of test samples and
    the feature dimension before training.

    Returns:
        Tuple ``(y_test, y_pred)``: the test labels as passed in and the
        predictions of the fitted classifier for ``X_test``.
    """
    n_train = X_train.shape[0]
    n_test = X_test.shape[0]
    feature_dim = X_train.shape[1]
    print('X_train size:', n_train, 'X_test size:', n_test, 'dim:', feature_dim)

    classifier = train_svm(X_train, y_train)
    return y_test, classifier.predict(X_test)

In [11]:
# Fit the SVM on the train embeddings and predict labels for the test embeddings.
y_test, y_pred = train_eval(train_embeddings, train_labels, test_embeddings, test_labels)

X_train size: 9842 X_test size: 2468 dim: 512


In [12]:
# Text report comparing the test labels with the SVM predictions, printed
# with four decimal places.
report = classification_report(y_test, y_pred, digits=4)
print(report)

              precision    recall  f1-score   support

           0     0.8761    0.9900    0.9296       100
           1     0.3333    0.0800    0.1290        50
           2     0.3433    0.4600    0.3932       100
           3     0.3529    0.3000    0.3243        20
           4     0.3690    0.6900    0.4808       100
           5     0.8810    0.7400    0.8043       100
           6     0.8000    0.4000    0.5333        20
           7     0.7407    0.6000    0.6630       100
           8     0.4709    0.8100    0.5956       100
           9     0.8500    0.8500    0.8500        20
          10     0.6250    0.2500    0.3571        20
          11     0.3500    0.3500    0.3500        20
          12     0.3030    0.1163    0.1681        86
          13     0.3600    0.4500    0.4000        20
          14     0.4815    0.4535    0.4671        86
          15     0.0000    0.0000    0.0000        20
          16     0.7670    0.7900    0.7783       100
          17     0.9579    