In [None]:
from os import sys
# Path to workspace
sys.path.insert(0, '/workspace/dense-self-supervised-representation-learning-for-3D-shapes/')

import h5py
import torch
import numpy as np
from tqdm import tqdm
import k3d

In [None]:
import neptune.new as neptune
from workspace.utils.train_loop import *

import os

params = {
    'name': 'dgcnn_modelnet_unsupervised',
    'dataset': 'modelnet',
    'batch_size': 8,
    'tau': 0.07,
    'n_output': 512,
    'result_dim': 128,
    'hidden_dim': 256,
    'total_epochs': 100,
    'lr': 0.001,
    'weight_decay': 1e-5,
    'save_every': 20,
    'weights_root': 'weights/'
}

# tags
tags = ['modelnet', 'unsupervised', 'dgcnn', 'local']

# Credentials come from the environment instead of being committed with the notebook.
# NOTE(review): the token that used to be hard-coded here is exposed in version
# history and should be revoked/rotated in the Neptune UI.
neptune_api_token = os.environ.get('NEPTUNE_API_TOKEN')
if not neptune_api_token:
    raise RuntimeError(
        'Set the NEPTUNE_API_TOKEN environment variable to your Neptune API token '
        'before running this notebook.'
    )

logger = neptune.init(
    project="igor3661/crossmodal",
    name=params['name'],
    tags=tags,
    api_token=neptune_api_token,
)


logger['parameters'] = params

# Prefer the first GPU, but keep the notebook runnable on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
from torch.utils.data import Dataset, DataLoader
from workspace.crossmodal.utils.meshnet_preprop import *

class PointDataset(Dataset):
    """Samples read from the ``'points'`` dataset of an HDF5 file.

    The HDF5 handle is opened lazily, once per process, on first access to
    ``self.file``. ``h5py.File`` objects cannot be pickled, which DataLoader
    workers started with ``spawn`` require. A handle opened before ``fork`` is
    also not safe to share with worker processes. So the parent process never
    keeps a handle just to answer ``__len__``.

    Args:
        data_path: path to the HDF5 file.
        transform: optional callable applied to each sample's array before it
            is converted to a tensor.
    """

    def __init__(self, data_path, transform=None):
        super().__init__()
        self.data_path = data_path
        self.transform = transform
        self._file = None
        self._file_pid = None
        # Read the length with a short-lived handle so __len__ never needs an
        # open handle in the parent process.
        with h5py.File(data_path, 'r') as h5_file:
            self._length = h5_file['points'].shape[0]

    @property
    def file(self):
        """Read-only HDF5 handle belonging to the current process (opened on demand)."""
        import os

        pid = os.getpid()
        # Reopen if no handle exists yet or the handle was inherited via fork.
        if self._file is None or self._file_pid != pid:
            self._file = h5py.File(self.data_path, 'r')
            self._file_pid = pid
        return self._file

    def __getstate__(self):
        # Drop the (unpicklable) handle; the receiving process reopens it lazily.
        state = self.__dict__.copy()
        state['_file'] = None
        state['_file_pid'] = None
        return state

    def __getitem__(self, index):
        points = self.file['points'][index][:]

        if self.transform is not None:
            points = self.transform(points)

        points = torch.from_numpy(points).float()
        # Swap the two axes of the 2-D sample.
        # Presumably (n_points, channels) -> (channels, n_points); TODO confirm layout.
        return torch.permute(points, (1, 0))

    def __len__(self):
        return self._length
    

class DoubleDataset(PointDataset):
    """Yields two views of one sample together with its face indices.

    Each item is ``(first_view, second_view, face_index)``.

    - Both views come from separate ``PointDataset.__getitem__`` calls. With a
      random ``transform`` configured, they are therefore (presumably)
      augmented differently.
    - ``face_index`` is read from the ``'face_index'`` HDF5 dataset and cast
      to int64.
    """

    def __init__(self, **multimodal_dataset_kwargs):
        # Keyword-only forwarding to PointDataset(data_path=..., transform=...).
        super(DoubleDataset, self).__init__(**multimodal_dataset_kwargs)

    def __getitem__(self, idx):
        face_array = self.file['face_index'][idx][:]
        face_index = torch.from_numpy(face_array).long()
        first_view = super().__getitem__(idx)
        second_view = super().__getitem__(idx)
        return first_view, second_view, face_index

    def __len__(self):
        # Same number of items as the underlying point dataset.
        length = super().__len__()
        return length

In [None]:
from workspace.datasets.transforms import *

def _make_augmentation():
    """Build a fresh augmentation pipeline: random rotation, then clipped jitter.

    A new pipeline is built for every call, so the train and test datasets do
    not share transform instances.
    """
    return Compose(
        # Angle units are whatever RandomRotation uses (presumably degrees).
        RandomRotation(low=-45, high=45, axis='xyz'),
        RandomJitter(std=0.01, clip_bound=0.05),
    )


train_data = DoubleDataset(
    data_path='modelnet/modelnet_train_1024.h5',
    transform=_make_augmentation(),
)
# The test split is augmented as well: the contrastive losses compare two views.
test_data = DoubleDataset(
    data_path='modelnet/modelnet_test_1024.h5',
    transform=_make_augmentation(),
)


loader_kwargs = {
    'batch_size': params['batch_size'],
    'num_workers': 5,
}
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

In [None]:
class Transpose(torch.nn.Module):
    """Swap two tensor dimensions, so a transpose can sit inside ``nn.Sequential``.

    Args:
        *dims: dimension indices forwarded verbatim to ``torch.transpose``.
    """

    def __init__(self, *dims):
        super(Transpose, self).__init__()
        self.dims = dims

    def forward(self, data):
        return torch.transpose(data, *self.dims)
    

class Model(torch.nn.Module):
    """Backbone followed by a two-layer projection head.

    ``model.forward_features(data)`` must return a 3-D tensor ``(B, C, L)``
    with ``C == model_output_dim``. The head applies
    Linear -> BatchNorm1d -> ReLU -> Linear along the channel axis. Transposes
    put that axis last for ``nn.Linear`` and back in the middle for
    ``BatchNorm1d``. The output is ``(B, result_dim, L)``.
    """

    def __init__(self, model, model_output_dim, result_dim, hidden_dim):
        super().__init__()
        self.model = model
        # Keep this order: the Sequential indices form the head's state_dict keys.
        layers = [
            Transpose(1, 2),                                # (B, C, L) -> (B, L, C)
            torch.nn.Linear(model_output_dim, hidden_dim),
            Transpose(1, 2),                                # back to (B, hidden, L) for BatchNorm1d
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            Transpose(1, 2),
            torch.nn.Linear(hidden_dim, result_dim),
            Transpose(1, 2),                                # (B, result_dim, L)
        ]
        self.head = torch.nn.Sequential(*layers)

    def forward(self, data):
        features = self.model.forward_features(data)
        return self.head(features)

In [None]:
from workspace.models.dgcnn import DGCNN

# DGCNN backbone wrapped with the projection head configured in `params`.
dgcnn = DGCNN(n_patches=5)
# 'n_output' must equal the channel width (dim 1) of dgcnn.forward_features(...),
# since Model's first Linear layer consumes it.
projection_kwargs = {
    'model_output_dim': params['n_output'],
    'result_dim': params['result_dim'],
    'hidden_dim': params['hidden_dim'],
}
model = Model(dgcnn, **projection_kwargs)
model = model.to(device)

In [None]:
from workspace.crossmodal.utils.losses import *

def move_to_device(data, device='cpu'):
    """Recursively move ``data`` to ``device``.

    Lists, tuples (including namedtuples) and dicts are traversed and rebuilt
    with the same container type. Every other object must implement
    ``.to(device)``, e.g. a ``torch.Tensor``.

    Args:
        data: a tensor, or a (possibly nested) list/tuple/dict of tensors.
        device: target device passed to ``.to``.

    Returns:
        The same structure with every leaf moved to ``device``.
    """
    if isinstance(data, list):
        return [move_to_device(item, device) for item in data]
    if isinstance(data, tuple):
        moved = [move_to_device(item, device) for item in data]
        # namedtuples take positional fields; plain tuples take one iterable.
        return type(data)(*moved) if hasattr(data, '_fields') else tuple(moved)
    if isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    return data.to(device)


def forward(
    model,
    batch,  # raw data from dataloader
    logger,  # neptune run
    mode,  # 'train'/'val'
    local_loss_weight=0.1,
):
    """Compute the combined patch-level + global contrastive loss for one batch.

    Used as the ``forward`` callback of ``train_model``.

    Args:
        model: network producing per-element embeddings, e.g. ``Model``.
        batch: ``[view1, view2, face_index]`` as collated by the DataLoader over
            ``DoubleDataset``.
        logger: not used here; presumably kept for train_model's callback signature.
        mode: not used here; presumably kept for train_model's callback signature.
        local_loss_weight: factor applied to the patch-level loss before it is
            added to the global loss (previously hard-coded as 0.1).

    Returns:
        dict with ``'loss'`` (local + global, the value to optimise),
        ``'local_loss'`` (already weighted) and ``'global_loss'``.
    """
    # `device` and `params` are read from the notebook's global scope.
    batch = move_to_device(batch, device)

    data1 = batch[0]
    data2 = batch[1]
    face_indexes = batch[2]
    # Number of distinct patch ids, assuming ids are 0-based and contiguous.
    max_faces = face_indexes.max() + 1

    out1 = model(data1)
    out2 = model(data2)

    # Both views share the same face_index, so patch i in view 1 pairs with patch i in view 2.
    pooled1, counts1 = get_patch_embeddings(out1, face_indexes, max_faces)
    pooled2, counts2 = get_patch_embeddings(out2, face_indexes, max_faces)

    local_loss = patch_contrastive_loss(
        (pooled1, counts1),
        (pooled2, counts2),
        params
    ) * local_loss_weight

    # Global embedding: average over the last axis of the model output
    # (the per-element axis for Model's (B, result_dim, L) output).
    gout1 = out1.mean(-1)
    gout2 = out2.mean(-1)

    global_loss = contrastive_loss(gout1, gout2, params)

    return {
        'loss': local_loss + global_loss,
        'local_loss': local_loss,
        'global_loss': global_loss
    }

In [None]:
def get_warmup_schedule(
        optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1
):
    """Create a LambdaLR with linear warm-up followed by cosine decay.

    The multiplier on each param group's base learning rate has two phases.

    - Warm-up: it rises linearly from 0 to 1 over ``num_warmup_steps`` steps.
    - Decay: it equals ``max(0, 0.5 * (1 + cos(2 * pi * num_cycles * progress)))``,
      where ``progress`` runs from 0 to 1 over the remaining
      ``num_training_steps - num_warmup_steps`` steps.

    With the default ``num_cycles=0.5`` the decay is a single half-cosine from
    1 down to 0.

    Args:
        optimizer: the optimizer whose learning rates are scheduled.
        num_warmup_steps: length of the linear warm-up phase, in scheduler steps.
        num_training_steps: total number of scheduler steps.
        num_cycles: number of cosine cycles over the decay phase.
        last_epoch: index of the last step, forwarded to ``LambdaLR``.

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            progress = float(current_step - num_warmup_steps) / decay_span
            cosine = 0.5 * (1.0 + np.cos(np.pi * float(num_cycles) * 2.0 * progress))
            return max(0.0, cosine)
        return float(current_step) / warmup_span

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda, last_epoch=last_epoch)

# Only parameters that require gradients are handed to the optimizer.
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=params['lr'],
    weight_decay=params['weight_decay'],
)

# Schedule lengths are in batches: 4 epochs of linear warm-up, then cosine decay
# to 0 at the end of training (assumes train_model steps the scheduler once per batch).
steps_per_epoch = len(train_loader)
scheduler = get_warmup_schedule(
    optimizer,
    4 * steps_per_epoch,
    params['total_epochs'] * steps_per_epoch,
)

In [None]:
# Run the training/validation loop (train_model presumably comes from
# workspace.utils.train_loop via the star import above).
train_model(
    model,
    params,
    logger,
    train_loader,
    test_loader,
    optimizer,
    scheduler,
    forward,
)