In [1]:
from os import sys
# Path to workspace
sys.path.insert(0, '/workspace/3d-shapes-embeddings/contrib/sharp_features/')
sys.path.insert(0, '/workspace/dense-self-supervised-representation-learning-for-3D-shapes/')

import h5py
import torch
import numpy as np
from tqdm import tqdm
import k3d

In [2]:
import neptune.new as neptune
from workspace.utils.train_loop import *

params = {
    'name': 'Only point clouds',
    'dataset': 'abc',
    'batch_size': 10,
    'tau': 0.07,
    'n_output': 512,
    'result_dim': 128,
    'hidden_dim': 256,
    'total_epochs': 50,
    'lr': 5e-5,
    'weight_decay': 1e-5,
    'save_every': 50,
    'weights_root': '../weights/'
}

# tags
tags = ['abc']



logger = neptune.init(project='seals5454/crossmodal-exps-igor',
                      name=params['name'],
                      tags=tags,
                      api_token='eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmU'\
                                'uYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS'\
                                '5haSIsImFwaV9rZXkiOiI2NTIwODVkNC1hOTg5LTQ4NTAtY'\
                                'WRhNS0yMGY4MmQ1YzBmZWIifQ=='
                      )

logger['parameters'] = params

device1, device2, device = 'cuda:3', 'cuda:3', 'cuda:3'



https://app.neptune.ai/seals5454/crossmodal-exps-igor/e/IGOREXP-44
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


In [3]:
from torch.utils.data import DataLoader
from workspace.crossmodal.data.datasets import *
from workspace.datasets.transforms import *
from workspace.crossmodal.utils.collates import collate_clouds, collate_meshes, multicollate

collate = lambda data: multicollate(
    data,
#     lambda x: collate_meshes(x, device=device),
#     lambda x: collate_meshes(x, device=device),
    lambda x: collate_clouds(x, device=device),
    lambda x: collate_clouds(x, device=device),
)

pdataset_train = DoubleDataset(data_path='abc_train.hdf5', modality=Modality.POINT_CLOUD,
                            transform=Compose(
        PointCloudNormalize(),
        RandomRotation(low=-45, high=45, axis='xyz'),
        RandomJitter(std=0.01, clip_bound=0.05)
    ),)
mdataset_train = DoubleDataset(data_path='abc_train.hdf5', modality=Modality.MESH,
                            transform=Compose(
        MeshNetRandomRotation(low=-45, high=45, axis='xyz'),
        MeshNetRandomJitter(std=0.01, clip_bound=0.05)
    ),)

train = DoubleModalityDataset(mdataset_train, pdataset_train)


pdataset_test = DoubleDataset(data_path='abc_test.hdf5', modality=Modality.POINT_CLOUD,
                            transform=Compose(
        PointCloudNormalize(),
        RandomRotation(low=-45, high=45, axis='xyz'),
        RandomJitter(std=0.01, clip_bound=0.05)
    ),)
mdataset_test = DoubleDataset(data_path='abc_test.hdf5', modality=Modality.MESH,
                             transform=Compose(
        MeshNetRandomRotation(low=-45, high=45, axis='xyz'),
        MeshNetRandomJitter(std=0.01, clip_bound=0.05)
    ),)

test = DoubleModalityDataset(mdataset_test, pdataset_test)



train_loader = DataLoader(pdataset_train, batch_size=params['batch_size'], shuffle=True,
               collate_fn=collate
)

test_loader = DataLoader(pdataset_test, batch_size=params['batch_size'], shuffle=False,
              collate_fn=collate
)

In [4]:
class Transpose(torch.nn.Module):
    def __init__(self, *dims):
        super().__init__()
        self.dims = dims

    def forward(self, data):
        return data.transpose(*self.dims)
    

class MultiModalModel(torch.nn.Module):
    def __init__(self, model1, model2, model_output_dim, result_dim=128, hidden_dim=256):
        super().__init__()
        self.model1 = model1.to(device1)
        self.model2 = model2.to(device2)
        self.head1 = torch.nn.Sequential(
            Transpose(1, 2),
            torch.nn.Linear(model_output_dim, hidden_dim),
            Transpose(1, 2),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            Transpose(1, 2),
            torch.nn.Linear(hidden_dim, result_dim),
            Transpose(1, 2),
        ).to(device1)
        
        self.head2 = torch.nn.Sequential(
            Transpose(1, 2),
            torch.nn.Linear(model_output_dim, hidden_dim),
            Transpose(1, 2),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            Transpose(1, 2),
            torch.nn.Linear(hidden_dim, result_dim),
            Transpose(1, 2),
        ).to(device2)
        
        
    def forward(self, input1_1, input1_2, input2_1, input2_2):
        v1_1_emb = self.model1.forward_features(input1_1)
        v1_2_emb = self.model1.forward_features(input1_2)
        
        input2_1, face_indexes = input2_1
        input2_2, face_indexes = input2_2
        v2_1_emb = self.model2.forward_features(input2_1)
        v2_2_emb = self.model2.forward_features(input2_2)
        
        
        return (
            self.head1(v1_1_emb),
            self.head1(v1_2_emb),
            self.head2(v2_1_emb),
            self.head2(v2_2_emb),
            face_indexes
        )
    

    def get_embeddings(self, input1_1, input1_2, input2_1, input2_2):
        v1_1_emb = self.model1.forward_features(input1_1)
        v1_2_emb = self.model1.forward_features(input1_2)
        v2_1_emb = self.model2.forward_features(input2_1)
        v2_2_emb = self.model2.forward_features(input2_2)
        
        return v1_1_emb, v1_2_emb, v2_1_emb, v2_2_emb
    
    
class OneModalityModel(torch.nn.Module):
    def __init__(self, model, model_output_dim, result_dim=128, hidden_dim=256):
        super().__init__()
        self.model = model.to(device1)
        self.head = torch.nn.Sequential(
            Transpose(1, 2),
            torch.nn.Linear(model_output_dim, hidden_dim),
            Transpose(1, 2),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            Transpose(1, 2),
            torch.nn.Linear(hidden_dim, result_dim),
            Transpose(1, 2),
        ).to(device1)
        
    def forward(self, input1, input2):
        v1_emb = self.model.forward_features(input1)
        v2_emb = self.model.forward_features(input2)
        
        
        return (
            self.head(v1_emb),
            self.head(v2_emb),
        )
    

    def get_embeddings(self, input1_1, input1_2, input2_1, input2_2):
        v1_emb = self.model.forward_features(input1)
        v2_emb = self.model.forward_features(input2)
        
        return v1_emb, v2_emb

In [5]:
from workspace.models.meshnet import MeshNet
from workspace.models.dgcnn import DGCNN

#mnet = MeshNet(n_patches=5)
dgcnn = DGCNN(n_patches=5)
model = OneModalityModel(
    dgcnn,
    params['n_output'],
    result_dim=params['result_dim'],
    hidden_dim=params['hidden_dim']
)
#model = MultiModalModel(mnet, dgcnn, 512)

In [6]:
from copy import deepcopy
from workspace.crossmodal.utils.losses import *

def forward( 
    model,
    batch, # raw data from dataloader
    logger, # neptune run
    mode # 'train'/'val'
): # -> loss

    (data1, face_indexes), (data2, _) = batch
    
    max_faces = face_indexes.max() + 1 #data1_1[0].shape[-1]

    out1, out2 = model(data1, data2)
    
    pooled1, counts1 = get_patch_embeddings(out1, face_indexes, max_faces)
    pooled2, counts2 = get_patch_embeddings(out1, face_indexes, max_faces)
    
    
    #local inside figures
    pc_local_loss = patch_contrastive_loss(
        (pooled1, counts1),
        (pooled2, counts2),
        params
    )
    
    gout1 = out1.mean(-1)
    gout2 = out2.mean(-1)
    
    # model level
    pc_loss = contrastive_loss(gout1, gout2, params)
    
    
    return {
        'loss': pc_loss + pc_local_loss,
        'pc_loss': pc_loss,
        'pc_local_loss': pc_local_loss,
    }

In [7]:
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=params['lr'],
    weight_decay=params['weight_decay']
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_loader) * params['total_epochs'])

In [8]:
train_model(model, params, logger,  train_loader, test_loader, optimizer, scheduler, forward)

100%|██████████| 100/100 [01:34<00:00,  1.06it/s, Epoch=1, Loss=2.75]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.23it/s, Loss=5.16]
100%|██████████| 100/100 [01:32<00:00,  1.08it/s, Epoch=2, Loss=1.83]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.26it/s, Loss=5.06]
100%|██████████| 100/100 [01:32<00:00,  1.08it/s, Epoch=3, Loss=1.6]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.28it/s, Loss=4.74]
100%|██████████| 100/100 [01:31<00:00,  1.09it/s, Epoch=4, Loss=1.52]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, Loss=4.55]
100%|██████████| 100/100 [01:32<00:00,  1.09it/s, Epoch=5, Loss=1.42]
Validation: 100%|██████████| 20/20 [00:16<00:00,  1.23it/s, Loss=4.37]
100%|██████████| 100/100 [01:33<00:00,  1.07it/s, Epoch=6, Loss=1.27]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, Loss=4.39]
100%|██████████| 100/100 [01:31<00:00,  1.09it/s, Epoch=7, Loss=1.24]
Validation: 100%|██████████| 20/20 [00:15<00:00,  1.27it/s, Loss=4.3] 
100%|█████████

In [10]:
ls ../weights/

[0m[01;34mIGOREXP-10[0m/  [01;34mIGOREXP-16[0m/  [01;34mIGOREXP-22[0m/  [01;34mIGOREXP-28[0m/  [01;34mIGOREXP-34[0m/  [01;34mIGOREXP-40[0m/
[01;34mIGOREXP-11[0m/  [01;34mIGOREXP-17[0m/  [01;34mIGOREXP-23[0m/  [01;34mIGOREXP-29[0m/  [01;34mIGOREXP-35[0m/  [01;34mIGOREXP-41[0m/
[01;34mIGOREXP-12[0m/  [01;34mIGOREXP-18[0m/  [01;34mIGOREXP-24[0m/  [01;34mIGOREXP-30[0m/  [01;34mIGOREXP-36[0m/  [01;34mIGOREXP-42[0m/
[01;34mIGOREXP-13[0m/  [01;34mIGOREXP-19[0m/  [01;34mIGOREXP-25[0m/  [01;34mIGOREXP-31[0m/  [01;34mIGOREXP-37[0m/  [01;34mIGOREXP-43[0m/
[01;34mIGOREXP-14[0m/  [01;34mIGOREXP-20[0m/  [01;34mIGOREXP-26[0m/  [01;34mIGOREXP-32[0m/  [01;34mIGOREXP-38[0m/  [01;34mIGOREXP-44[0m/
[01;34mIGOREXP-15[0m/  [01;34mIGOREXP-21[0m/  [01;34mIGOREXP-27[0m/  [01;34mIGOREXP-33[0m/  [01;34mIGOREXP-39[0m/  [01;34mIGOREXP-9[0m/
