In [1]:
from os import sys
# Path to workspace
sys.path.insert(0, '/workspace/dense-self-supervised-representation-learning-for-3D-shapes/')

import h5py
import torch
import numpy as np
from tqdm import tqdm
import k3d

In [2]:
import neptune.new as neptune
from workspace.utils.train_loop import *

params = {
    'name': 'meshnet_dgcnn_modelnet',
    'dataset': 'modelnet',
    'batch_size': 10,
    'tau': 0.07,
    'n_output': 512,
    'result_dim': 128,
    'hidden_dim': 256,
    'total_epochs': 100,
    'lr': 2e-4,
    'betas': (0.5, 0.999),
    'weight_decay': 1e-5,
    'save_every': 20,
    'weights_root': 'weights/'
}

# tags
tags = ['modelnet', 'meshnet', 'dgcnn', 'translation', 'mse']

logger = neptune.init(
    project="igor3661/crossmodal",
    name=params['name'],
    tags=tags,
    api_token='eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcG'\
              'lfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJiN'\
              'zcxMGNkOS04ZjU3LTRmNDMtOWFjMS1kNDNkZDZlNDI4YWYifQ==',
)  # your credentials


logger['parameters'] = params

device = 'cuda:2'



https://app.neptune.ai/igor3661/crossmodal/e/CROSS-51
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


In [3]:
from torch.utils.data import Dataset, DataLoader
from workspace.crossmodal.utils.meshnet_preprop import *

class MeshnetDataset(Dataset):
    def __init__(self, data_path, rotation=None, jitter=None):
        super().__init__()
        self.rotation = rotation
        self.jitter = jitter
        self.file = h5py.File(data_path, 'r')

    def __getitem__(self, index):
        faces = self.file['faces'][index][:].reshape(-1, 3)
        vertices = self.file['vertices'][index][:].reshape(-1, 3)
        
        if self.rotation is not None:
            vertices = self.rotation(vertices)
        
        features, neighbors = process_mesh(faces, vertices)
        

        features = torch.from_numpy(features).float()
        neighbors = torch.from_numpy(neighbors).long()

        features = torch.permute(features, (1, 0))
        centers, corners, normals = features[:3], features[3:12], features[12:]
        
        if self.jitter is not None:
            centers = self.jitter(centers).float()
        
        corners = corners - torch.cat([centers, centers, centers], 0).float()

        return centers, corners, normals, neighbors #, normals
        
    def __len__(self):
        return self.file['points'].shape[0]
    
    
class PointDataset(Dataset):
    def __init__(self, data_path, transform=None):
        super().__init__()
        self.transform = transform
        self.file = h5py.File(data_path, 'r')

    def __getitem__(self, index):
        points = self.file['points'][index][:]

        if self.transform is not None:
            points = self.transform(points)

        points = torch.from_numpy(points).float()
        points = torch.permute(points, (1, 0))
        return points
        
    def __len__(self):
        return self.file['points'].shape[0]
    

class DoubleDataset(Dataset):
    def __init__(self, kwargs):
        super().__init__()
        self.mesh = MeshnetDataset(**kwargs['mesh'])
        self.point = PointDataset(**kwargs['point'])

    def __getitem__(self, idx):
        face_index = torch.from_numpy(self.mesh.file['face_index'][idx][:]).long()
        return (*self.mesh.__getitem__(idx), self.point.__getitem__(idx), face_index)

    def __len__(self):
        return self.mesh.__len__()

In [4]:
from workspace.datasets.transforms import *


train_kwargs = {
    'mesh': {
        'data_path': 'modelnet/modelnet_train_1024.h5',
        'rotation': RandomRotation(low=-45, high=45, axis='xyz'),
        'jitter': RandomJitter(std=0.01, clip_bound=0.05)
    },
    'point': {
        'data_path': 'modelnet/modelnet_train_1024.h5',
        'transform': Compose(
            RandomRotation(low=-45, high=45, axis='xyz'),
            RandomJitter(std=0.01, clip_bound=0.05)
        )
    }
}

test_kwargs = {
    'mesh': {
        'data_path': 'modelnet/modelnet_test_1024.h5',
        'rotation': RandomRotation(low=-45, high=45, axis='xyz'),
        'jitter': RandomJitter(std=0.01, clip_bound=0.05)
    },
    'point': {
        'data_path': 'modelnet/modelnet_test_1024.h5',
        'transform': Compose(
            RandomRotation(low=-45, high=45, axis='xyz'),
            RandomJitter(std=0.01, clip_bound=0.05)
        )
    }
}


train_data = DoubleDataset(train_kwargs)
test_data = DoubleDataset(test_kwargs)

train_loader = DataLoader(
    train_data,
    batch_size=params['batch_size'],
    num_workers=10,
    shuffle=True
)
test_loader = DataLoader(
    test_data,
    batch_size=params['batch_size'],
    shuffle=False,
    num_workers=10
)

In [5]:
class Transpose(torch.nn.Module):
    def __init__(self, *dims):
        super().__init__()
        self.dims = dims

    def forward(self, data):
        return data.transpose(*self.dims)
    

class Model(torch.nn.Module):
    def __init__(self, model, model_output_dim, result_dim, hidden_dim):
        super().__init__()
        self.model = model
        self.head = torch.nn.Sequential(
            Transpose(1, 2),
            torch.nn.Linear(model_output_dim, hidden_dim),
            Transpose(1, 2),
            torch.nn.BatchNorm1d(hidden_dim),
            torch.nn.ReLU(),
            Transpose(1, 2),
            torch.nn.Linear(hidden_dim, result_dim),
            Transpose(1, 2),
        )
        
        
    def forward(self, data):
        return self.head(self.model.forward_features(data))

In [6]:
from workspace.models.meshnet import MeshNet
from workspace.models.dgcnn import DGCNN

meshnet = MeshNet(n_patches=5)
dgcnn = DGCNN(n_patches=5)

mesh_model = Model(
    meshnet,
    model_output_dim=params['n_output'],
    hidden_dim=params['hidden_dim'],
    result_dim=params['result_dim']
).to(device).eval()

point_model = Model(
    dgcnn,
    model_output_dim=params['n_output'],
    hidden_dim=params['hidden_dim'],
    result_dim=params['result_dim']
).to(device).eval()

mesh_model.load_state_dict(torch.load('weights/CROSS-32/100epoch.pt'))
point_model.load_state_dict(torch.load('weights/CROSS-33/100epoch.pt'))

<All keys matched successfully>

In [7]:
class Projector(torch.nn.Module):
    def __init__(self, model_output_dim):
        super().__init__()
        self.head = torch.nn.Sequential(
                Transpose(1, 2),
                torch.nn.Linear(model_output_dim, model_output_dim * 2),
                Transpose(1, 2),
                torch.nn.BatchNorm1d(model_output_dim * 2),
                torch.nn.ReLU(),
                Transpose(1, 2),
                torch.nn.Linear(model_output_dim * 2, model_output_dim),
                Transpose(1, 2),
            )
    
    def forward(self, data):
        return self.head(data)
    
class Discriminator(torch.nn.Module):
    def __init__(self, model_output_dim):
        super().__init__()
        self.head = torch.nn.Sequential(
                Transpose(1, 2),
                torch.nn.Linear(model_output_dim, model_output_dim // 2),
                Transpose(1, 2),
                torch.nn.BatchNorm1d(model_output_dim // 2),
                torch.nn.ReLU(),
                Transpose(1, 2),
                torch.nn.Linear(model_output_dim // 2, 1),
                Transpose(1, 2),
            )
    
    def forward(self, data):
        return self.head(data)

In [8]:
projMP = Projector(params['n_output']).to(device)
projPM = Projector(params['n_output']).to(device)
discM = Discriminator(params['n_output']).to(device)
discP = Discriminator(params['n_output']).to(device)

In [9]:
criterion_L1 = torch.nn.L1Loss()

def LOSS_D(real, fake):
    return (torch.mean((real - 1) ** 2) + torch.mean(fake ** 2))

def LOSS_G(fake):
    return  torch.mean((fake - 1) ** 2)

optimizer_MP = torch.optim.Adam(
    params=projMP.parameters(), lr=params['lr'], betas=params['betas'])
optimizer_PM = torch.optim.Adam(
    params=projPM.parameters(), lr=params['lr'], betas=params['betas'])
optimizer_DM = torch.optim.Adam(
    params=discM.parameters(), lr=params['lr'], betas=params['betas'])
optimizer_DP = torch.optim.Adam(
    params=discP.parameters(), lr=params['lr'], betas=params['betas'])

def linear_lambda_rule(epoch):
    lr_l = 1.0 - max(0, epoch - params['total_epochs']) / float(params['total_epochs'] + 1)#n_epochs_decay + 1)
    return lr_l

scheduler_MP = torch.optim.lr_scheduler.LambdaLR(optimizer_MP, lr_lambda=linear_lambda_rule)
scheduler_PM = torch.optim.lr_scheduler.LambdaLR(optimizer_PM, lr_lambda=linear_lambda_rule)
scheduler_DM = torch.optim.lr_scheduler.LambdaLR(optimizer_DM, lr_lambda=linear_lambda_rule)
scheduler_DP = torch.optim.lr_scheduler.LambdaLR(optimizer_DP, lr_lambda=linear_lambda_rule)

In [13]:
def move_to_device(data, device='cpu'):
    if isinstance(data, list):
        return [item.to(device) for item in data]
    else:
        return data.to(device)
    
lambda_M = 10
lambda_P = 10
lambda_Idt = 0.5


def forward( 
    models,
    batch, # raw data from dataloader
    logger, # neptune run
    mode # 'train'/'val'
): # -> loss
    
    projPM, projMP, discM, discP = models

    batch = move_to_device(batch, device)
    
    meshes = batch[0:4]
    points = batch[4]
    face_index = batch[5]
    
    mout = mesh_model.model.forward_features(meshes).detach()
    pout = point_model.model.forward_features(points).detach()
    
    real_meshes = torch.gather(mout, 2, face_index.unsqueeze(1).expand((-1, 512, -1)))
    real_points = pout
    
    fake_meshes = projPM(real_points)
    fake_points = projMP(real_meshes)
    
    rec_meshes = projPM(fake_points)
    rec_points = projMP(fake_meshes)
    
    disc_loss_M = LOSS_D(discM(real_meshes), discM(fake_meshes)) * 0.5
    disc_loss_P = LOSS_D(discP(real_meshes), discP(fake_meshes)) * 0.5
    
    fool_disc_loss_M2P = LOSS_G(discP(fake_points))
    fool_disc_loss_P2M = LOSS_G(discM(fake_meshes))
    
    cycle_loss_M = criterion_L1(rec_meshes, real_meshes) * lambda_M
    cycle_loss_P = criterion_L1(rec_points, real_points) * lambda_P

    id_loss_M2P = criterion_L1(projPM(real_meshes), real_meshes) * lambda_M * lambda_Idt
    id_loss_P2M = criterion_L1(projMP(real_points), real_points) * lambda_P * lambda_Idt
    
    rec_M = criterion_L1(real_meshes, fake_meshes)
    rec_P = criterion_L1(real_points, fake_points)
    
    gen_loss = fool_disc_loss_M2P + fool_disc_loss_P2M + cycle_loss_M + cycle_loss_P + id_loss_M2P + id_loss_P2M
        
    return {
        'loss': gen_loss + disc_loss_M + disc_loss_P + rec_M + rec_P,
        'rec_loss': rec_M + rec_P,
        'gen_loss': gen_loss,
        'discM': disc_loss_M,
        'discP': disc_loss_P
    }

In [11]:
from typing import *
from tqdm import tqdm
from pathlib import Path
from collections import defaultdict


def log_dict(dict_loss, logger, mode):
    for k, v in dict_loss.items():
        logger[mode + '/' + k].log(v)


@torch.no_grad()
def calc_val_loss(models, loader, logger, forward):
    (model.train() for model in models)
    progress_bar = tqdm(loader, leave=True, position=0, desc='Validation')
    dict_loss = defaultdict(int)
    cur_iter = 0
    
    
    for batch in progress_bar:        
        loss = forward(models, batch, logger, 'val')

        for k, v in loss.items():
            dict_loss[k] += v.item()
        cur_iter += 1

        progress_bar.set_postfix({
            'Loss': dict_loss['loss'] / cur_iter
        })

    return {k: v / cur_iter for k, v in dict_loss.items()}


def train_model(
    models: torch.nn.Module,
    params: Dict[str, Any],
    logger: Any,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    optimizers: torch.optim.Optimizer,
    schedulers: Any,
    forward: Callable[[torch.nn.Module, Any, Any, str], Dict[str, torch.Tensor]],
):
    '''
    :param model: torch.nn.Module model
    :param params: experiment parameters
    :param logger: logger, neptune instance in our case
    :param train_loader: train loader
    :param val_loader: val loader
    :param optimizer: optimizer
    :param scheduler: scheduler
    :param forward: forward(model, batch, logger, 'val'/'train') -> loss
    loss is dict with Tensors, .backward() is called on 'loss' key, other keys are only logged
    '''
    exp_id = logger['sys/id'].fetch()
    save_dir = Path('{}/{}'.format(params['weights_root'], exp_id))
    save_dir.mkdir(parents=True, exist_ok=True)
    
    best_val_loss = float('inf')

    for epoch in range(1, params['total_epochs'] + 1):
        progress_bar = tqdm(train_loader, leave=True, position=0)

        dict_loss = defaultdict(int)
        cur_iter = 0


        (model.train() for model in models)
        for batch in progress_bar:
            (optimizer.zero_grad() for optimizer in optimizers)

            loss = forward(models, batch, logger, 'train')
            
            for k, v in loss.items():
                dict_loss[k] += v.item()
        
            cur_iter += 1

            loss['loss'].backward()
            (optimizer.step() for optimizer in optimizers)
            (scheduler.step() for scheduler in schedulers)
            progress_bar.set_postfix({
                'Epoch': epoch,
                'Loss': dict_loss['loss'] / cur_iter
            }) 

        dict_loss = {k: v / cur_iter for k, v in dict_loss.items()}

        log_dict(dict_loss, logger, 'train')
        
        val_dict_loss = calc_val_loss(models, val_loader, logger, forward)
        log_dict(val_dict_loss, logger, 'val')

        if val_dict_loss['loss'] < best_val_loss:
            best_val_loss = val_dict_loss['loss']
            (torch.save(model.state_dict(), save_dir / f'{i}_{epoch}val_best.pt')
             for i, model in enumerate(models))
            
                    
        if epoch % params['save_every'] == 0:
            (torch.save(model.state_dict(), save_dir / f'{i}_{epoch}epoch.pt') for i, model in enumerate(models))
    
    if params['total_epochs'] % params['save_every'] != 0:
        (torch.save(model.state_dict(), save_dir / f'{i}_{epoch}epoch.pt') for i, model in enumerate(models))

In [14]:
train_model(
    (projPM, projMP, discM, discP),
     params,
     logger,
     train_loader,
     test_loader,
     (optimizer_MP, optimizer_PM, optimizer_DM, optimizer_DP),
     (scheduler_MP, scheduler_PM, scheduler_DM, scheduler_DP),
     forward
)

100%|██████████| 985/985 [02:28<00:00,  6.63it/s, Epoch=1, Loss=76.7]
Validation: 100%|██████████| 247/247 [00:29<00:00,  8.52it/s, Loss=76.3]
100%|██████████| 985/985 [02:29<00:00,  6.60it/s, Epoch=2, Loss=76.7]
Validation: 100%|██████████| 247/247 [00:28<00:00,  8.62it/s, Loss=76.3]
100%|██████████| 985/985 [02:28<00:00,  6.62it/s, Epoch=3, Loss=76.7]
Validation: 100%|██████████| 247/247 [00:28<00:00,  8.59it/s, Loss=76.2]
100%|██████████| 985/985 [02:28<00:00,  6.62it/s, Epoch=4, Loss=76.7]
Validation: 100%|██████████| 247/247 [00:28<00:00,  8.58it/s, Loss=76.2]
100%|██████████| 985/985 [02:29<00:00,  6.61it/s, Epoch=5, Loss=76.7]
Validation: 100%|██████████| 247/247 [00:28<00:00,  8.58it/s, Loss=76.3]
100%|██████████| 985/985 [02:28<00:00,  6.61it/s, Epoch=6, Loss=76.7]
Validation: 100%|██████████| 247/247 [00:28<00:00,  8.53it/s, Loss=76.3]
100%|██████████| 985/985 [02:29<00:00,  6.61it/s, Epoch=7, Loss=76.7]
Validation: 100%|██████████| 247/247 [00:28<00:00,  8.56it/s, Loss=76.2]

KeyboardInterrupt: 