In [1]:
import sys

# Put the project checkout at the front of the import path; presumably this is
# what makes the `workspace.*` packages imported in later cells resolvable.
# NOTE(review): this is an absolute, machine-specific path - adjust it to the
# local checkout before running.
# Alternative checkout location kept for reference:
#sys.path.insert(0, '/workspace/3d-shapes-embeddings/contrib/sharp_features/')
sys.path.insert(0, '/workspace/dense-self-supervised-representation-learning-for-3D-shapes/')

import h5py
import torch
import numpy as np
from tqdm import tqdm
import k3d

In [2]:
import os

import neptune.new as neptune
from workspace.utils.train_loop import *

# Run configuration. The whole dict is uploaded to Neptune below and is also
# handed to `train_model` and `contrastive_loss` in later cells.
params = {
    'name': 'meshcnn global pretrained',
    'dataset': 'abc',
    'batch_size': 5,
    'tau': 0.07,
    'n_output': 512,
    'result_dim': 128,
    'hidden_dim': 256,
    'total_epochs': 300,
    'lr': 5e-4,
    'weight_decay': 1e-5,
    'save_every': 100,
    'weights_root': '../weights/'
}

# tags
tags = ['abc']

# The API token is a secret and must not live in the notebook. It is read from
# the NEPTUNE_API_TOKEN environment variable; when the variable is unset,
# `None` is passed and Neptune reports the missing credentials itself.
# NOTE(review): the token that used to be hard-coded here has been exposed and
# should be revoked in the Neptune UI.
logger = neptune.init(
    project="igor3661/crossmodal",
    name=params['name'],
    tags=tags,
    api_token=os.environ.get('NEPTUNE_API_TOKEN'),
)

# Store the hyperparameters alongside the run.
logger['parameters'] = params

# Device used by the collate functions and the model in later cells.
device = 'cuda:1'



https://app.neptune.ai/igor3661/crossmodal/e/CROSS-12
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


In [3]:
from torch.utils.data import DataLoader
from workspace.crossmodal.data.datasets import *
from workspace.datasets.transforms import *
from workspace.crossmodal.utils.collates import collate_clouds, collate_meshcnn, multicollate

In [4]:
def _make_cloud_view_dataset(hdf5_path):
    """Build the point-cloud `DoubleDataset` for one split.

    A fresh augmentation pipeline (normalise, random rotation, random jitter)
    is created on every call, so datasets never share transform instances.
    """
    augmentations = Compose(
        PointCloudNormalize(),
        RandomRotation(low=-45, high=45, axis='xyz'),
        RandomJitter(std=0.01, clip_bound=0.05)
    )
    return DoubleDataset(
        data_path=hdf5_path,
        modality=Modality.POINT_CLOUD,
        transform=augmentations
    )


def _make_mesh_view_dataset(hdf5_path, flip_edges):
    """Build the MeshCNN `DoubleDataset` for one split.

    Only `flip_edges` differs between the splits; every other MeshCNN option
    is shared.
    """
    mesh_options = AttrDict({
        'normalize': True,
        'num_aug': 1,
        'scale_verts': True,
        'slide_verts': 0.3,
        'flip_edges': flip_edges,
        'is_train': True,
        'ninput_edges': 700
    })
    return DoubleDataset(
        data_path=hdf5_path,
        modality=Modality.MESHCNN,
        meshcnn_opt=mesh_options
    )


# Training split: mesh dataset first, point-cloud dataset second.
pdataset_train = _make_cloud_view_dataset('abc_train.hdf5')
mdataset_train = _make_mesh_view_dataset('abc_train.hdf5', flip_edges=0.3)
train = DoubleModalityDataset(mdataset_train, pdataset_train)

# Test split: same pipeline, but without edge flips on the meshes.
# NOTE(review): the test split keeps the random point-cloud augmentations and
# `is_train: True` - confirm this is intended for validation.
pdataset_test = _make_cloud_view_dataset('abc_test.hdf5')
mdataset_test = _make_mesh_view_dataset('abc_test.hdf5', flip_edges=0)
test = DoubleModalityDataset(mdataset_test, pdataset_test)


def _fixed_len_of_two(dataset):
    """Replacement `__len__` that always reports two items."""
    return 2


# NOTE(review): this patches the class, so every CrossmodalDataset reports a
# length of 2. It looks like a leftover overfitting/debug switch (the training
# log shows a single batch per epoch) - remove it for a real run.
CrossmodalDataset.__len__ = _fixed_len_of_two


def _mesh_collate(samples):
    """Collate one mesh view onto the globally configured `device`."""
    return collate_meshcnn(samples, device=device)


def _cloud_collate(samples):
    """Collate one point-cloud view onto the globally configured `device`."""
    return collate_clouds(samples, device=device)


def collate(data):
    """Collate a batch: two mesh views first, then two point-cloud views."""
    return multicollate(
        data,
        _mesh_collate,
        _mesh_collate,
        _cloud_collate,
        _cloud_collate,
    )


# Both loaders share batch size, ordering (no shuffling) and collate function.
_loader_kwargs = dict(
    batch_size=params['batch_size'],
    shuffle=False,
    collate_fn=collate,
)

train_loader = DataLoader(train, **_loader_kwargs)

val_loader = DataLoader(test, **_loader_kwargs)

In [5]:
class Transpose(torch.nn.Module):
    """Layer form of ``Tensor.transpose`` so it can sit inside ``torch.nn.Sequential``.

    Args:
        *dims: the dimension indices to swap; stored as given and forwarded
            unchanged to the transpose call on every forward pass.
    """

    def __init__(self, *dims):
        super(Transpose, self).__init__()
        self.dims = dims

    def forward(self, data):
        """Return `data` with the configured dimensions swapped."""
        return torch.transpose(data, *self.dims)
    

class MultiModalModel(torch.nn.Module):
    """Two encoders, each followed by its own projection head.

    `model1` is called as ``model1(*inputs)`` and `model2` through
    ``model2.forward_features(inputs)``. Every encoder output goes through a
    head made of Linear -> BatchNorm1d -> ReLU -> Linear, applied along dim 1
    (the Transpose layers move that dim last for the Linear layers and back
    again for BatchNorm1d).

    Args:
        model1: first encoder (called with unpacked positional inputs).
        model2: second encoder (must expose ``forward_features``).
        model_output_dim: width of dim 1 of `model2`'s features.
        result_dim: width of the projected output of BOTH heads, so the two
            modalities can be compared directly.
        hidden_dim: hidden width of the `model2` head.
        model1_output_dim: width of dim 1 of `model1`'s output.
        model1_hidden_dim: hidden width of the `model1` head.

    The defaults reproduce the layer sizes that used to be hard-coded for
    `head1` (64 -> 256 -> 64), and the layer order is unchanged, so existing
    checkpoints keep the same state_dict keys.
    """

    def __init__(self, model1, model2, model_output_dim=512, result_dim=64, hidden_dim=1024,
                 model1_output_dim=64, model1_hidden_dim=256):
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        # head1 is built before head2 so parameter initialisation order is unchanged.
        self.head1 = self._projection_head(model1_output_dim, model1_hidden_dim, result_dim)
        self.head2 = self._projection_head(model_output_dim, hidden_dim, result_dim)

    @staticmethod
    def _projection_head(in_dim, hidden_dim, out_dim):
        """Build one head mapping dim 1 from `in_dim` to `out_dim`."""
        return torch.nn.Sequential(
            Transpose(1, 2),
            torch.nn.Linear(in_dim, hidden_dim),
            Transpose(1, 2),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            Transpose(1, 2),
            torch.nn.Linear(hidden_dim, out_dim),
            Transpose(1, 2),
        )

    def forward(self, input1_1, input1_2, input2_1, input2_2):
        """Encode both views of both modalities and project them.

        Returns:
            Tuple ``(head1(view 1), head1(view 2), head2(view 1), head2(view 2))``.
        """
        v1_1_emb, v1_2_emb, v2_1_emb, v2_2_emb = self.get_embeddings(
            input1_1, input1_2, input2_1, input2_2
        )

        return (
            self.head1(v1_1_emb),
            self.head1(v1_2_emb),
            self.head2(v2_1_emb),
            self.head2(v2_2_emb),
        )

    def get_embeddings(self, input1_1, input1_2, input2_1, input2_2):
        """Return the raw encoder outputs, without the projection heads."""
        v1_1_emb = self.model1(*input1_1)
        v1_2_emb = self.model1(*input1_2)
        v2_1_emb = self.model2.forward_features(input2_1)
        v2_2_emb = self.model2.forward_features(input2_2)

        return v1_1_emb, v1_2_emb, v2_1_emb, v2_2_emb

In [6]:
from workspace.crossmodal.models.meshcnn.networks import *
from workspace.models.dgcnn import DGCNN


# Mesh branch configuration.
# NOTE(review): the last `up_convs` entry (64) presumably sets the channel
# width that MultiModalModel's first head consumes - confirm against
# MeshEncoderDecoder.
_meshcnn_config = dict(
    pools=[0, 700, 600, 500, 400, 200],
    down_convs=[5, 64, 128, 256, 512, 1024],
    up_convs=[1024, 1024, 512, 256, 128, 64],
    blocks=1,
)
meshcnn = MeshEncoderDecoder(**_meshcnn_config)

# Point-cloud branch.
dgcnn = DGCNN(n_patches=5)

# NOTE(review): the model is built with its default sizes; the 'n_output',
# 'result_dim' and 'hidden_dim' entries logged in `params` are not passed here.
model = MultiModalModel(model1=meshcnn, model2=dgcnn)
model = model.to(device)

In [7]:
from workspace.crossmodal.utils.losses import *

def forward(
    model,
    batch,  # raw data from dataloader
    logger,  # neptune run
    mode  # 'train'/'val'
):  # -> dict of losses
    """Compute the contrastive losses for one collated batch.

    `batch[0]` and `batch[1]` are the two mesh views (dicts with
    'edge_features' and 'mesh'); `batch[2]` and `batch[3]` are the two
    point-cloud views. The mesh views are fed to the model's first branch and
    the point-cloud views to its second branch.

    `logger` and `mode` are part of the callback signature but are not used
    here. The temperature and other loss settings come from the module-level
    `params` dict, which is passed straight to `contrastive_loss`.

    Returns:
        dict with
        'loss'            - 0.25 * cross-modal term + mesh term + point-cloud term,
        'pc_loss'         - point-cloud view 1 vs point-cloud view 2,
        'mesh_loss'       - mesh view 1 vs mesh view 2,
        'crossmodal_loss' - 0.25 * sum of the four mesh-vs-point-cloud terms.
    """
    mesh_view1, mesh_view2 = batch[0], batch[1]
    cloud_view1, cloud_view2 = batch[2], batch[3]

    mesh_out1, mesh_out2, cloud_out1, cloud_out2 = model(
        (mesh_view1['edge_features'], mesh_view1['mesh']),
        (mesh_view2['edge_features'], mesh_view2['mesh']),
        cloud_view1,
        cloud_view2
    )

    # Average over the last axis to get one descriptor per sample.
    mesh_g1 = mesh_out1.mean(-1)
    mesh_g2 = mesh_out2.mean(-1)
    cloud_g1 = cloud_out1.mean(-1)
    cloud_g2 = cloud_out2.mean(-1)

    # crossmodal: every mesh view against every point-cloud view
    crossmodal_loss = contrastive_loss(mesh_g1, cloud_g1, params) +\
           contrastive_loss(mesh_g2, cloud_g2, params) +\
           contrastive_loss(mesh_g1, cloud_g2, params) +\
           contrastive_loss(mesh_g2, cloud_g1, params)

    # Within-modality terms. These labels used to be swapped: the first model
    # branch receives the mesh inputs, so its two views form the mesh loss.
    mesh_loss = contrastive_loss(mesh_g1, mesh_g2, params)
    pc_loss = contrastive_loss(cloud_g1, cloud_g2, params)

    return {
        'loss': (0.25 * crossmodal_loss + mesh_loss + pc_loss),
        'pc_loss': pc_loss,
        'mesh_loss': mesh_loss,
        # The key used to carry a stray trailing colon ('crossmodal_loss:').
        'crossmodal_loss': 0.25 * crossmodal_loss
    }

In [8]:
# Optimise only the parameters that are not frozen.
_trainable_parameters = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.AdamW(
    _trainable_parameters,
    lr=params['lr'],
    weight_decay=params['weight_decay'],
)

# One cosine half-period spanning (batches per epoch) * (number of epochs) steps.
# NOTE(review): this assumes the training loop steps the scheduler once per
# batch - confirm in `train_model`.
_total_steps = len(train_loader) * params['total_epochs']
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=_total_steps)

In [9]:
# Start training. `train_model` is presumably provided by the star import from
# `workspace.utils.train_loop`; the `forward` function defined above is passed
# as the per-batch loss callback. The captured log below shows one training and
# one validation progress bar per epoch.
train_model(model, params, logger, train_loader, val_loader, optimizer, scheduler, forward)

100%|██████████| 1/1 [00:01<00:00,  1.59s/it, Epoch=1, Loss=13]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=13.2]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=2, Loss=4.23]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.28it/s, Loss=11.7]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=3, Loss=2.58]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.28it/s, Loss=10.6]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, Epoch=4, Loss=1.73]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=8.62]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=5, Loss=1.58]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=7.47]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=6, Loss=2.08]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.28it/s, Loss=6.8]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=7, Loss=1.85]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.27it/s, Loss=5.8]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=8, Loss

100%|██████████| 1/1 [00:01<00:00,  1.10s/it, Epoch=61, Loss=0.00537]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.28it/s, Loss=1.78]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, Epoch=62, Loss=0.000463]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.28it/s, Loss=3.03]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=63, Loss=0.408]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.30it/s, Loss=4]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, Epoch=64, Loss=0.823]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=1.39]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=65, Loss=0.0749]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=0.508]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, Epoch=66, Loss=0.0132]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.28it/s, Loss=0.434]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, Epoch=67, Loss=0.00783]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.28it/s, Loss=2.63]
100%|██████████| 1/1 [00:01<00:00

100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=120, Loss=0.00077]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.30it/s, Loss=1.9]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=121, Loss=0.199]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.30it/s, Loss=1.32]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=122, Loss=0.0136]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=0.306]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=123, Loss=0.00133]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=2.42]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, Epoch=124, Loss=0.00103]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=1.9]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=125, Loss=0.000426]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.27it/s, Loss=3.7]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=126, Loss=0.000436]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.33it/s, Loss=3.1]
100%|██████████| 1/1 [00

100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=178, Loss=0.000362]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.30it/s, Loss=0.0814]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=179, Loss=0.0107]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.28it/s, Loss=0.845]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=180, Loss=0.000913]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.20it/s, Loss=0.0159]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=181, Loss=0.202]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=0.0551]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, Epoch=182, Loss=0.000411]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.26it/s, Loss=0.0401]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=183, Loss=0.000416]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=0.631]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, Epoch=184, Loss=0.00226]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=0.415]
100%|███

100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=236, Loss=0.000186]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.30it/s, Loss=1.97]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, Epoch=237, Loss=0.000219]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=0.0873]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=238, Loss=0.000212]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.30it/s, Loss=0.767]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=239, Loss=0.544]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=0.35]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, Epoch=240, Loss=0.000408]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=7.11]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=241, Loss=0.0022]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=0.0403]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=242, Loss=0.000196]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=0.211]
100%|███████

100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=294, Loss=0.000257]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.32it/s, Loss=6.69]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=295, Loss=0.000171]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.32it/s, Loss=0.883]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, Epoch=296, Loss=0.00021]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.33it/s, Loss=3.91]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, Epoch=297, Loss=0.000261]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.28it/s, Loss=5.38]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, Epoch=298, Loss=0.00289]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.29it/s, Loss=0.279]
100%|██████████| 1/1 [00:01<00:00,  1.09s/it, Epoch=299, Loss=0.149]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.32it/s, Loss=1.24]
100%|██████████| 1/1 [00:01<00:00,  1.08s/it, Epoch=300, Loss=0.000186]
Validation: 100%|██████████| 1/1 [00:00<00:00,  1.27it/s, Loss=0.408]


In [10]:
# Switch to inference mode so the BatchNorm layers in the projection heads use
# their running statistics (and do not update them) while inspecting embeddings.
model.eval()

# First batch of the training loader.
batch = next(iter(train_loader))
with torch.no_grad():
    out1_1, out1_2, out2_1, out2_2 = model(
        (batch[0]['edge_features'], batch[0]['mesh']),
        (batch[1]['edge_features'], batch[1]['mesh']),
        batch[2],
        batch[3]
    )
    # Average over the last axis to get one descriptor per sample, then move it
    # to the CPU. No detach is needed: nothing here tracks gradients.
    # fm* come from the mesh branch, fp* from the point-cloud branch.
    fm1 = out1_1.mean(-1).cpu()
    fm2 = out1_2.mean(-1).cpu()
    fp1 = out2_1.mean(-1).cpu()
    fp2 = out2_2.mean(-1).cpu()

In [26]:
# Debug probe: shape of the `edges` array of the second mesh in the second mesh view.
batch[1]['mesh'][1].edges.shape

(37, 2)

In [29]:
# `F` was never imported explicitly in this notebook; import it here so the
# cell also works on a fresh kernel.
import torch.nn.functional as F

# L2-normalise each descriptor along its last axis so the matrix products in
# the following cells are cosine similarities.
fm1 = F.normalize(fm1, dim=-1)
fm2 = F.normalize(fm2, dim=-1)
fp1 = F.normalize(fp1, dim=-1)
fp2 = F.normalize(fp2, dim=-1)

In [31]:
# Within-modality similarity between the two views: point clouds first, then
# meshes, separated by a blank line. Entry (i, j) compares sample i of view 1
# with sample j of view 2 (cosine similarity if the normalisation cell ran first).
print(fp1 @ fp2.T, fm1 @ fm2.T, sep='\n\n')

tensor([[0.9996, 0.7278],
        [0.2787, 0.8076]])

tensor([[ 0.9097, -0.2041],
        [-0.0129,  0.7249]])


In [30]:
# Cross-modal similarity: each point-cloud view against each mesh view, printed
# as four matrices separated by blank lines.
print(
    fp1 @ fm1.T,
    fp1 @ fm2.T,
    fp2 @ fm1.T,
    fp2 @ fm2.T,
    sep='\n\n',
)

tensor([[ 0.8658,  0.3005],
        [-0.0173,  0.6856]])

tensor([[ 0.8362,  0.0696],
        [-0.0327,  0.7769]])

tensor([[0.8637, 0.3025],
        [0.4774, 0.5768]])

tensor([[0.8334, 0.0727],
        [0.4559, 0.5078]])


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Draw an annotated heatmap of the matrix `res`.
# NOTE(review): `res` is not assigned in any cell visible in this notebook and
# this cell has no execution count - define `res` (presumably one of the
# similarity matrices printed above; TODO confirm) before running.
plt.figure(figsize=(15, 15))
sns.heatmap(res, annot=True, fmt="f", )