In [1]:
import sys

# Make the project root importable (presumably for the `workspace.*` imports
# in later cells). Guarded so re-running this cell does not keep prepending
# duplicate entries to sys.path.
PROJECT_ROOT = '/workspace/dense-self-supervised-representation-learning-for-3D-shapes/'
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import os

import h5py
import k3d
import numpy as np
import torch
from tqdm import tqdm

In [2]:
import neptune.new as neptune
from workspace.utils.train_loop import *

# Run configuration; the whole dict is also logged to Neptune below.
params = {
    'name': 'meshnet_modelnet_unsupervised',
    'dataset': 'modelnet',
    'batch_size': 8,
    'tau': 0.07,
    'n_output': 512,
    'result_dim': 128,
    'hidden_dim': 256,
    'total_epochs': 100,
    'lr': 0.001,
    'weight_decay': 1e-5,
    'save_every': 20,
    'weights_root': 'weights/'
}

# tags
tags = ['modelnet', 'meshnet', 'unsupervised', 'local']

# Credentials come from the environment (export NEPTUNE_API_TOKEN before
# starting the kernel) instead of being hardcoded in the notebook. The token
# that used to be embedded here was committed in plain text and should be
# revoked.
logger = neptune.init(
    project="igor3661/crossmodal",
    name=params['name'],
    tags=tags,
    api_token=os.environ.get('NEPTUNE_API_TOKEN'),
)

logger['parameters'] = params

# NOTE(review): hardcoded GPU index; adjust for the machine running this notebook.
device = 'cuda:2'



https://app.neptune.ai/igor3661/crossmodal/e/CROSS-64
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


In [3]:
from torch.utils.data import Dataset, DataLoader
from workspace.crossmodal.utils.meshnet_preprop import *

class MeshnetDataset(Dataset):
    """MeshNet inputs read from an HDF5 file of meshes.

    Each item is ``(centers, corners, normals, neighbors)``. The first three
    are row slices of the transposed per-face features returned by
    ``process_mesh``: rows 0-2 are centers, rows 3-11 are corners (made
    relative to the centers), and rows 12+ are normals. ``neighbors`` is the
    second value returned by ``process_mesh``, converted to a long tensor.

    The HDF5 handle is opened lazily, once per process. An ``h5py.File``
    opened in the parent must not be read from forked DataLoader workers, and
    it cannot be pickled for spawn-based workers.

    Args:
        data_path: HDF5 file with ``faces``, ``vertices`` and ``points``
            datasets (``points`` is only used for the length).
        rotation: optional callable applied to the ``(V, 3)`` vertex array.
        jitter: optional callable applied to the centers tensor.
    """

    def __init__(self, data_path, rotation=None, jitter=None):
        super().__init__()
        self.data_path = data_path
        self.rotation = rotation
        self.jitter = jitter
        # Read the length up front and close the handle straight away, so no
        # open handle is inherited by DataLoader worker processes.
        with h5py.File(data_path, 'r') as h5_file:
            self._length = h5_file['points'].shape[0]
        self._file = None
        self._file_pid = None

    @property
    def file(self):
        """HDF5 handle, (re)opened on first use in each process."""
        if self._file is None or self._file_pid != os.getpid():
            self._file = h5py.File(self.data_path, 'r')
            self._file_pid = os.getpid()
        return self._file

    def __getstate__(self):
        # h5py handles are not picklable; the receiving process reopens lazily.
        state = self.__dict__.copy()
        state['_file'] = None
        state['_file_pid'] = None
        return state

    def __getitem__(self, index):
        faces = self.file['faces'][index][:].reshape(-1, 3)
        vertices = self.file['vertices'][index][:].reshape(-1, 3)

        if self.rotation is not None:
            vertices = self.rotation(vertices)

        features, neighbors = process_mesh(faces, vertices)

        # NOTE(review): assumes process_mesh returns features as
        # (num_faces, 15); after the transpose the rows are feature channels.
        features = torch.from_numpy(features).float().permute(1, 0)
        neighbors = torch.from_numpy(neighbors).long()

        centers, corners, normals = features[:3], features[3:12], features[12:]

        if self.jitter is not None:
            centers = self.jitter(centers).float()

        # Subtract the 3 center rows from each of the three 3-row corner groups.
        corners = corners - torch.cat([centers, centers, centers], 0).float()

        return centers, corners, normals, neighbors

    def __len__(self):
        return self._length
    

class DoubleDataset(MeshnetDataset):
    """Yields two independently produced views of the same mesh.

    Takes the same keyword arguments as ``MeshnetDataset``. Each item is the
    concatenation of two separate ``MeshnetDataset.__getitem__`` calls for the
    same index, so any random transforms are sampled once per view.
    """

    def __init__(self, **multimodal_dataset_kwargs):
        MeshnetDataset.__init__(self, **multimodal_dataset_kwargs)

    def __getitem__(self, idx):
        view_a = MeshnetDataset.__getitem__(self, idx)
        view_b = MeshnetDataset.__getitem__(self, idx)
        return tuple(view_a) + tuple(view_b)

    def __len__(self):
        return MeshnetDataset.__len__(self)

In [4]:
from workspace.datasets.transforms import *

# Both splits use random rotation + jitter. Every item already holds two
# augmented views, so validation pairs are built the same way as training
# pairs. Each dataset gets its own freshly constructed transform objects.
train_data, test_data = (
    DoubleDataset(
        data_path=path,
        rotation=RandomRotation(low=-45, high=45, axis='xyz'),
        jitter=RandomJitter(std=0.01, clip_bound=0.05),
    )
    for path in ('modelnet/modelnet_train_1024.h5', 'modelnet/modelnet_test_1024.h5')
)

# Only the training loader shuffles; both use 5 worker processes.
train_loader, test_loader = (
    DataLoader(
        dataset,
        batch_size=params['batch_size'],
        num_workers=5,
        shuffle=shuffle,
    )
    for dataset, shuffle in ((train_data, True), (test_data, False))
)

In [5]:
class Transpose(torch.nn.Module):
    """Module wrapper around ``torch.transpose`` so it can sit in a Sequential."""

    def __init__(self, *dims):
        super().__init__()
        self.dims = dims

    def forward(self, data):
        return torch.transpose(data, *self.dims)


class Model(torch.nn.Module):
    """Backbone features followed by a two-layer projection head.

    ``model.forward_features(data)`` is expected to return a 3-D tensor whose
    axis 1 has size ``model_output_dim``. The head maps that axis to
    ``result_dim`` and leaves the other axes as they are.
    """

    def __init__(self, model, model_output_dim, result_dim, hidden_dim):
        super().__init__()
        self.model = model

        def channelwise_linear(in_dim, out_dim):
            # nn.Linear acts on the last axis: move channels there and back.
            return [Transpose(1, 2), torch.nn.Linear(in_dim, out_dim), Transpose(1, 2)]

        self.head = torch.nn.Sequential(
            *channelwise_linear(model_output_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            *channelwise_linear(hidden_dim, result_dim),
        )

    def forward(self, data):
        features = self.model.forward_features(data)
        return self.head(features)

In [6]:
from workspace.models.meshnet import MeshNet

# MeshNet backbone wrapped with the projection head; all head sizes come from params.
meshnet = MeshNet(n_patches=5)
model = Model(
    meshnet,
    model_output_dim=params['n_output'],
    result_dim=params['result_dim'],
    hidden_dim=params['hidden_dim'],
)
model = model.to(device)

In [7]:
from workspace.crossmodal.utils.losses import *

def move_to_device(data, device='cpu'):
    """Move tensors, possibly nested in lists, tuples or dicts, to ``device``.

    Lists come back as lists, tuples as tuples, and dicts keep their keys. A
    collated batch therefore works whether it arrives as a list or a tuple.
    Values without a ``.to`` method (e.g. ints or strings) pass through
    unchanged.

    Args:
        data: a tensor (or anything with ``.to``), or a list/tuple/dict of
            them, arbitrarily nested.
        device: any argument accepted by ``Tensor.to``.

    Returns:
        ``data`` with every movable element moved.
    """
    if isinstance(data, list):
        return [move_to_device(item, device) for item in data]
    if isinstance(data, tuple):
        return tuple(move_to_device(item, device) for item in data)
    if isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    if hasattr(data, 'to'):
        return data.to(device)
    return data


def forward(
    model,
    batch,  # raw data from dataloader
    logger,  # neptune run (not used here; part of the train_model callback signature)
    mode  # 'train'/'val' (not used here)
):  # -> dict of losses
    """Compute the combined local + global contrastive loss for one batch.

    The batch holds two views of every mesh, concatenated field-wise by the
    dataset. Each view goes through ``model``. A patch-level contrastive loss
    is computed on the per-position outputs and weighted by
    ``params.get('local_loss_weight', 0.1)``. A global contrastive loss is
    computed on the outputs averaged over the last axis.

    Returns:
        dict with 'loss' (the sum), 'local_loss' and 'global_loss'.
    """
    batch = move_to_device(batch, device)

    # First half of the fields is view 1, second half is view 2.
    n_fields = len(batch) // 2
    data1 = batch[:n_fields]
    data2 = batch[n_fields:]

    out1 = model(data1)
    out2 = model(data2)

    # Global descriptors: mean over the last (per-position) axis.
    gout1 = out1.mean(-1)
    gout2 = out2.mean(-1)

    # Every position has weight 1; allocate directly on the output's device.
    counts = torch.ones(out1.shape[0], out1.shape[2], device=out1.device)

    local_weight = params.get('local_loss_weight', 0.1)
    local_loss = patch_contrastive_loss(
        (out1, counts),
        (out2, counts),
        params
    ) * local_weight

    global_loss = contrastive_loss(gout1, gout2, params)

    return {
        'loss': local_loss + global_loss,
        'local_loss': local_loss,
        'global_loss': global_loss
    }

In [8]:
def get_warmup_schedule(
        optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
):
    """LambdaLR with linear warmup followed by cosine decay.

    The multiplier rises linearly from 0 to 1 over ``num_warmup_steps``. It
    then follows ``0.5 * (1 + cos(2*pi*num_cycles*progress))``, where progress
    runs from 0 to 1 over the remaining steps, and is clipped at 0. With the
    default ``num_cycles=0.5`` this is a half-cosine from 1 down to 0.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            progress = float(current_step - num_warmup_steps) / decay_span
            cosine = 0.5 * (1.0 + np.cos(np.pi * float(num_cycles) * 2.0 * progress))
            return max(0.0, cosine)
        return float(current_step) / warmup_span

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

# AdamW over the trainable parameters only.
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=params['lr'],
    weight_decay=params['weight_decay'],
)

# Linear warmup for the first 4 epochs, then cosine decay until the final epoch.
scheduler = get_warmup_schedule(
    optimizer,
    num_warmup_steps=4 * len(train_loader),
    num_training_steps=params['total_epochs'] * len(train_loader),
)

In [None]:
# Run training; `forward` above is passed as the per-batch loss callback.
# train_model is not defined in this notebook (presumably it comes from the
# `workspace.utils.train_loop` star import).
# NOTE(review): the Neptune run is never stopped in the visible cells; call
# logger.stop() once all logging is finished, as Neptune's startup message asks.
train_model(model, params, logger,  train_loader, test_loader, optimizer, scheduler, forward)

100%|██████████| 1231/1231 [01:47<00:00, 11.47it/s, Epoch=1, Loss=0.796]
Validation: 100%|██████████| 309/309 [00:19<00:00, 15.55it/s, Loss=0.816]
100%|██████████| 1231/1231 [01:45<00:00, 11.63it/s, Epoch=2, Loss=0.326]
Validation: 100%|██████████| 309/309 [00:19<00:00, 15.75it/s, Loss=0.765]
100%|██████████| 1231/1231 [01:46<00:00, 11.58it/s, Epoch=3, Loss=0.207]
Validation: 100%|██████████| 309/309 [00:20<00:00, 15.16it/s, Loss=0.332]
100%|██████████| 1231/1231 [01:47<00:00, 11.48it/s, Epoch=4, Loss=0.165]
Validation: 100%|██████████| 309/309 [00:20<00:00, 15.40it/s, Loss=0.279]
100%|██████████| 1231/1231 [01:47<00:00, 11.44it/s, Epoch=5, Loss=0.0772]
Validation: 100%|██████████| 309/309 [00:20<00:00, 15.35it/s, Loss=0.125]
100%|██████████| 1231/1231 [01:47<00:00, 11.46it/s, Epoch=6, Loss=0.0623]
Validation: 100%|██████████| 309/309 [00:20<00:00, 15.39it/s, Loss=0.25] 
100%|██████████| 1231/1231 [01:47<00:00, 11.42it/s, Epoch=7, Loss=0.0169]
Validation: 100%|██████████| 309/309 [00:2