In [None]:
import sys

# Make the project package (``workspace.*``) importable. The membership guard
# keeps re-runs of this cell from stacking duplicate entries on ``sys.path``.
WORKSPACE_ROOT = '/workspace/dense-self-supervised-representation-learning-for-3D-shapes/'
if WORKSPACE_ROOT not in sys.path:
    sys.path.insert(0, WORKSPACE_ROOT)

import h5py
import torch
import numpy as np
from tqdm import tqdm
import k3d

In [None]:
import neptune.new as neptune
from workspace.utils.train_loop import *

params = {
    'name': 'modelnet_cls_linprobe',
    'dataset': 'modelnet',
    'batch_size': 30,
    'tau': 0.07,
    'n_output': 512,
    'result_dim': 3,
    'hidden_dim': 256,
    'total_epochs': 100,
    'lr': 0.001,
    'weight_decay': 1e-5,
    'save_every': 50,
    'weights_root': 'weights/'
}

# Neptune logging is disabled for this evaluation run. To enable it, uncomment
# the block below. The API token is read from the environment and must never be
# hard-coded in the notebook.
# NOTE(review): a token used to be committed here in plain text; revoke it.
#
# import os
# tags = ['modelnet', 'classification', 'linprobe']
# logger = neptune.init(
#     project="igor3661/crossmodal",
#     name=params['name'],
#     tags=tags,
#     api_token=os.environ["NEPTUNE_API_TOKEN"],
# )
# logger['parameters'] = params

# Use the first GPU when present; fall back to CPU so the notebook still runs.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
from torch.utils.data import Dataset, DataLoader
from workspace.crossmodal.utils.meshnet_preprop import *


class PointDataset(Dataset):
    """Map-style dataset over an HDF5 file of point clouds.

    The file must contain a ``points`` dataset and a ``labels`` dataset, both
    indexed by sample. Each item is ``(points, label)``:

    * ``points`` is a float32 tensor with the sample's two axes swapped
      (``permute(1, 0)``). This is presumably ``(3, n_points)`` if stored as
      ``(n_points, 3)``; TODO confirm.
    * ``label`` is returned exactly as read from ``labels``.

    Args:
        data_path: path to the HDF5 file.
        transform: optional callable applied to the raw numpy points before
            conversion to a tensor.
    """

    def __init__(self, data_path, transform=None):
        super().__init__()
        self.data_path = data_path
        self.transform = transform
        self._file = None
        self._file_pid = None
        # Open now so a bad path still fails at construction time.
        self.file

    @property
    def file(self):
        """HDF5 handle owned by the current process, (re)opened on demand.

        DataLoader workers (``num_workers > 0``) would otherwise inherit the
        parent's h5py handle through ``fork``. HDF5 handles should not be
        shared across processes. The PID check makes every worker open its
        own handle on first access.
        """
        import os

        pid = os.getpid()
        if self._file is None or self._file_pid != pid:
            self._file = h5py.File(self.data_path, 'r')
            self._file_pid = pid
        return self._file

    def __getstate__(self):
        # h5py handles cannot be pickled (needed for the 'spawn' start method);
        # drop the handle so the unpickled copy reopens its own.
        state = self.__dict__.copy()
        state['_file'] = None
        state['_file_pid'] = None
        return state

    def __getitem__(self, index):
        h5 = self.file
        points = h5['points'][index][:]
        label = h5['labels'][index]

        if self.transform is not None:
            points = self.transform(points)

        points = torch.from_numpy(points).float()
        points = torch.permute(points, (1, 0))
        return points, label

    def __len__(self):
        return self.file['points'].shape[0]
    
class MeshnetDataset(Dataset):
    """Map-style dataset over an HDF5 file of triangle meshes for MeshNet.

    The file must provide per-sample ``faces``, ``vertices`` and ``labels``
    datasets. Faces and vertices are stored flat and reshaped to ``(-1, 3)``.
    Each item is ``(centers, corners, normals, neighbors, label)``.

    Args:
        data_path: path to the HDF5 file.
        rotation: optional callable applied to the ``(n_vertices, 3)`` vertex
            array before mesh features are computed.
        jitter: optional callable applied to the face-center tensor.
    """

    def __init__(self, data_path, rotation=None, jitter=None):
        super().__init__()
        self.data_path = data_path
        self.rotation = rotation
        self.jitter = jitter
        self._file = None
        self._file_pid = None
        # Open now so a bad path still fails at construction time.
        self.file

    @property
    def file(self):
        """HDF5 handle owned by the current process, (re)opened on demand.

        DataLoader workers (``num_workers > 0``) would otherwise inherit the
        parent's h5py handle through ``fork``. HDF5 handles should not be
        shared across processes. The PID check makes every worker open its
        own handle on first access.
        """
        import os

        pid = os.getpid()
        if self._file is None or self._file_pid != pid:
            self._file = h5py.File(self.data_path, 'r')
            self._file_pid = pid
        return self._file

    def __getstate__(self):
        # h5py handles cannot be pickled (needed for the 'spawn' start method);
        # drop the handle so the unpickled copy reopens its own.
        state = self.__dict__.copy()
        state['_file'] = None
        state['_file_pid'] = None
        return state

    def __getitem__(self, index):
        h5 = self.file
        faces = h5['faces'][index][:].reshape(-1, 3)
        vertices = h5['vertices'][index][:].reshape(-1, 3)
        label = h5['labels'][index]

        if self.rotation is not None:
            vertices = self.rotation(vertices)

        features, neighbors = process_mesh(faces, vertices)

        features = torch.from_numpy(features).float()
        neighbors = torch.from_numpy(neighbors).long()

        # After the transpose, rows 0-2 are split off as centers, rows 3-11 as
        # corners and rows 12+ as normals. This presumably means one face per
        # column; TODO confirm against process_mesh.
        features = torch.permute(features, (1, 0))
        centers, corners, normals = features[:3], features[3:12], features[12:]

        if self.jitter is not None:
            centers = self.jitter(centers).float()

        # Subtract the center from each 3-row group of corners, i.e. express
        # the three corners relative to the face center (assuming xyz triples).
        # NOTE(review): jitter is applied before this step, so corners end up
        # relative to the *jittered* center; confirm that is intended.
        corners = corners - torch.cat([centers, centers, centers], 0).float()

        return centers, corners, normals, neighbors, label

    def __len__(self):
        # Count samples via 'labels', which __getitem__ reads for every sample.
        # Mesh files are not guaranteed to contain the 'points' dataset.
        return self.file['labels'].shape[0]

In [None]:
# ModelNet point-cloud splits (presumably 1024 points per shape, per the file names).
train_data = PointDataset('modelnet/modelnet_train_1024.h5')
test_data = PointDataset('modelnet/modelnet_test_1024.h5')

# Both loaders share batch size and worker count; only shuffling differs
# (train shuffled, test in file order).
train_loader, test_loader = (
    DataLoader(dataset, shuffle=shuffle, batch_size=params['batch_size'], num_workers=5)
    for dataset, shuffle in ((train_data, True), (test_data, False))
)

In [None]:
class Transpose(torch.nn.Module):
    """Swap two tensor dimensions so the op can live inside ``nn.Sequential``.

    The constructor arguments are stored in ``dims`` and forwarded verbatim to
    ``Tensor.transpose``.
    """

    def __init__(self, *dims):
        super().__init__()
        self.dims = dims

    def forward(self, data):
        return data.transpose(*self.dims)

    def extra_repr(self):
        # Show the swapped dims when the model is printed, e.g. "Transpose(1, 2)".
        return ', '.join(map(str, self.dims))


class NormalsModel(torch.nn.Module):
    """Backbone followed by a small per-position MLP head.

    The head maps backbone features of width ``model_output_dim`` to
    ``hidden_dim`` (BatchNorm + ReLU) and then to ``out_dim`` channels.
    ``Transpose(1, 2)`` moves channels last for ``nn.Linear`` and back to
    channels-first for ``BatchNorm1d``, so the output keeps the backbone's
    channels-first layout.

    Args:
        model: backbone exposing ``forward_features(data)``. Its output is
            assumed to be ``(batch, model_output_dim, n_positions)``; TODO confirm.
        model_output_dim: channel width of the backbone features.
        result_dim: accepted for interface compatibility but NOT used; the
            final width is ``out_dim``.
            NOTE(review): the name suggests this was meant to be the output size
            (e.g. 3 for normals). However, every checkpoint built with this class
            has a 128-wide final layer.
        hidden_dim: width of the hidden layer.
        out_dim: width of the final layer. Defaults to 128, the value this class
            always used, so existing ``state_dict``s keep loading.
    """

    def __init__(self, model, model_output_dim, result_dim, hidden_dim, out_dim=128):
        super().__init__()
        self.model = model
        # Layer order fixes the state_dict keys (head.1 / head.3 / head.6).
        # Do not reorder it, or saved checkpoints stop matching.
        self.head = torch.nn.Sequential(
            Transpose(1, 2),
            torch.nn.Linear(model_output_dim, hidden_dim),
            Transpose(1, 2),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            Transpose(1, 2),
            torch.nn.Linear(hidden_dim, out_dim),
            Transpose(1, 2),
        )

    def forward(self, data):
        return self.head(self.model.forward_features(data))

In [None]:
from workspace.models.dgcnn import DGCNN
from workspace.models.meshnet import MeshNet

import os

dgcnn = DGCNN(n_patches=5)

model = NormalsModel(
    dgcnn,
    model_output_dim=params['n_output'],
    hidden_dim=params['hidden_dim'],
    result_dim=params['result_dim']
).to(device).eval()

# Resolve the checkpoint under the configured weights directory.
# map_location puts the tensors straight onto `device`, so a checkpoint saved
# on another GPU (or on CPU) still loads.
checkpoint_path = os.path.join(params['weights_root'], 'CROSS-63', '100epoch.pt')
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

In [None]:
from tqdm import tqdm

def move_to_device(data, device='cpu'):
    """Move a tensor, or each tensor in a list/tuple, to ``device``.

    Lists come back as lists and tuples as tuples. Anything else is treated as
    a single object with a ``.to`` method.
    """
    if isinstance(data, list):
        return [item.to(device) for item in data]
    if isinstance(data, tuple):
        return tuple(item.to(device) for item in data)
    return data.to(device)


def _collect_embeddings(loader, to_input, encoder, target_device, progress):
    """Shared loop: encode every batch and average-pool over the last axis.

    Each batch is unpacked as ``(*inputs, labels)``.
    ``to_input(inputs, target_device)`` builds whatever
    ``encoder.forward_features`` expects. ``encoder`` defaults to the
    notebook's global ``model.model`` and ``target_device`` to its global
    ``device``.

    Returns:
        ``(embeddings, targets)``, each concatenated along dim 0 and on CPU.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if encoder is None:
        encoder = model.model
    if target_device is None:
        target_device = device
    batches = tqdm(loader) if progress else loader

    all_embeddings = []
    all_targets = []
    # Feature extraction only: disable autograd instead of building a graph
    # for every batch and detaching afterwards.
    with torch.no_grad():
        for *inputs, labels in batches:
            features = encoder.forward_features(to_input(inputs, target_device))
            # Pool over the last axis (presumably points/faces; TODO confirm)
            # on the device, so only the pooled vectors are copied to the CPU.
            all_embeddings.append(features.mean(-1).cpu())
            # Labels are only consumed on the CPU side, so no GPU round-trip.
            all_targets.append(labels.cpu())

    if not all_embeddings:
        raise ValueError('loader yielded no batches')
    return torch.cat(all_embeddings, dim=0), torch.cat(all_targets)


def collect_embeddings_points(loader, encoder=None, target_device=None, progress=True):
    """Pool backbone features for a loader of ``(points, labels)`` batches.

    Args:
        loader: iterable of ``(points, labels)`` batches, e.g. a DataLoader
            over ``PointDataset``.
        encoder: object with ``forward_features``; defaults to the global
            ``model.model``.
        target_device: device the points are moved to; defaults to the global
            ``device``.
        progress: wrap ``loader`` in a tqdm progress bar.

    Returns:
        ``(embeddings, targets)`` as CPU tensors concatenated along dim 0.
    """
    return _collect_embeddings(
        loader,
        lambda inputs, dev: inputs[0].to(dev),
        encoder,
        target_device,
        progress,
    )


def collect_embeddings_meshnet(loader, encoder=None, target_device=None, progress=True):
    """Pool backbone features for a loader of MeshNet batches.

    Args:
        loader: iterable of ``(centers, corners, normals, neighbors, labels)``
            batches, e.g. a DataLoader over ``MeshnetDataset``.
        encoder: object with ``forward_features`` that accepts the list
            ``[centers, corners, normals, neighbors]``; defaults to the global
            ``model.model``.
        target_device: device the inputs are moved to; defaults to the global
            ``device``.
        progress: wrap ``loader`` in a tqdm progress bar.

    Returns:
        ``(embeddings, targets)`` as CPU tensors concatenated along dim 0.
    """
    return _collect_embeddings(loader, move_to_device, encoder, target_device, progress)

In [None]:
# Pooled backbone features for both splits (test split first, then train).
test_embeddings, test_labels = collect_embeddings_points(loader=test_loader)
train_embeddings, train_labels = collect_embeddings_points(loader=train_loader)

In [None]:
from sklearn.svm import SVC
from sklearn.metrics import classification_report

def train_svm(X, y, random_state=None):
    """Fit a linear-kernel SVC on ``(X, y)`` after shuffling the rows.

    Args:
        X: feature matrix indexable by an integer array (numpy array or CPU
            torch tensor).
        y: labels aligned row-for-row with ``X``.
        random_state: seed for the row shuffle. ``None`` keeps the original
            behaviour of drawing from NumPy's global RNG; pass an int for a
            reproducible run.

    Returns:
        The fitted ``SVC``.
    """
    rng = np.random if random_state is None else np.random.RandomState(random_state)
    perm = rng.permutation(X.shape[0])
    svm = SVC(kernel='linear')
    svm.fit(X[perm], y[perm])
    return svm


def train_eval(X_train, y_train, X_test, y_test, random_state=None):
    """Train the linear SVM probe on the train split and predict the test split.

    ``random_state`` is forwarded to ``train_svm``. Returns ``(y_test, y_pred)``;
    ``y_test`` is passed through unchanged so the pair can go straight into
    ``classification_report``.
    """
    print('X_train size:', X_train.shape[0], 'X_test size:', X_test.shape[0], 'dim:', X_train.shape[1])
    svm = train_svm(X_train, y_train, random_state=random_state)
    y_pred = svm.predict(X_test)
    return y_test, y_pred

In [None]:
y_test, y_pred = train_eval(
    X_train=train_embeddings,
    y_train=train_labels,
    X_test=test_embeddings,
    y_test=test_labels,
)

In [None]:
# Per-class precision / recall / F1 of the linear probe, to 4 decimals.
report = classification_report(y_true=y_test, y_pred=y_pred, digits=4)
print(report)