In [9]:
%load_ext autoreload
%autoreload 2

In [1]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, WeightedRandomSampler
import numpy as np
import pandas as pd

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '6'

from tqdm import tqdm_notebook as tqdm

import sys 
sys.path.append('..')

from models import AD_SDF
from backend.datasets import SDFItemDataset
from backend.datasets.utils import get_weights

import warnings
warnings.filterwarnings('ignore')

In [2]:
batch_size = 16384
train_steps = 400
val_steps = 200

# random points for validation
val_fraction = 0.35

num_training_shapes = 200

In [3]:
model = AD_SDF(data_shape=num_training_shapes).cuda()

criterion = nn.MSELoss()

lr = 1e-3
optimizer = optim.Adam(model.parameters(), lr=lr)

In [4]:
# to clamp sdf values for both targets and model outputs
def clamp(delta, x):
    return torch.clamp(x, min=-delta, max=delta)

def mse(outputs, targets):
    return ((outputs - targets) ** 2).sum() #taking sum just to track the progress

In [20]:
class SDFTrainer:
    def __init__(self, model, criterion, optimizer, delta=0.1, checkpoints_dir='checkpoints'):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.delta = delta
        os.makedirs(checkpoints_dir, exist_ok=True)
        self.checkpoints_dir=checkpoints_dir
        
    def fit(self, train_loader, val_loader, epochs, vector_id, save=False):
        best_val_mse = 1e6
        for epoch in range(epochs):
            self._train(train_loader, epoch, vector_id)
            val_loss, val_mse = self._validate(val_loader, vector_id)
            if val_mse < best_val_mse:
                best_val_mse = val_mse
                if save:
                    self.save_weights() 
                    
    def save_weights(self, name='model.pth'):
        torch.save(self.model.state_dict(), os.path.join(self.checkpoints_dir, name))

    def load_weights(self, weights_path):
        self.model.load_state_dict(weights_path)
        
    def _validate(self, loader, vector_id):
        self.model.eval()
        running_loss = []
        running_mse = []
        
        for inputs, targets in loader:
            inputs = inputs.float().cuda()
            targets = targets.float().cuda()

            with torch.set_grad_enabled(False):
                outputs = self.model(vector_id, inputs)
                print('targets:', targets)
                print('outputs:', outputs)
                loss = criterion(clamp(self.delta, outputs), clamp(self.delta, targets))
                
                running_loss.append(loss.item())
                running_mse.append(mse(outputs, targets).detach().cpu().numpy())
        
        mean_loss = np.mean(running_loss)
        mean_mse = np.mean(running_mse)
        
        print(f'val loss: {mean_loss:.5f}, val mse: {mean_mse:.5f}')
        return mean_loss, mean_mse
        
    def _train(self, loader, epoch, vector_id):
        self.model.train()
        running_loss = []
        running_mse = []
        tq = tqdm(total=len(loader))
        tq.set_description('Epoch {}'.format(epoch))
        
        for inputs, targets in loader:
            inputs = inputs.float().cuda()
            targets = targets.float().cuda()

            self.optimizer.zero_grad()
            
            with torch.set_grad_enabled(True):
                outputs = self.model(vector_id, inputs)
                loss = self.criterion(clamp(delta, outputs), clamp(delta, targets))

            loss.backward()
            self.optimizer.step()

            running_loss.append(loss.item())
            running_mse.append(mse(outputs, targets).detach().cpu().numpy())

            mean_loss = np.mean(running_loss)
            mean_mse = np.mean(running_mse)
                
            tq.update()
            tq.set_postfix(loss='{:.3f}'.format(mean_loss), mse = '{:.3f}'.format(mean_mse))            

In [21]:
CSV_PATH = '../../data/abc_data_sigma_1.0.csv'
samples = pd.read_csv(CSV_PATH)
training_shapes = np.random.choice(range(len(samples)), num_training_shapes, replace=False)
latent_vectors_map = {i:k for k, i in enumerate(training_shapes)}

In [22]:
epochs = 8
delta = 2.5

sdf_trainer = SDFTrainer(model, criterion, optimizer, delta)

In [23]:
for index_shape in training_shapes:
    mesh = np.load(samples.iloc[index_shape, 1])
    sdf = np.load(samples.iloc[index_shape, 2])
    val_mask = np.zeros((mesh.shape[0]), dtype=np.bool)
    val_ind = np.random.choice(range(mesh.shape[0]), int(val_fraction*mesh.shape[0]))
    val_mask[val_ind] = 1

    train_dataset = SDFItemDataset(mesh[~val_mask], sdf[~val_mask])
    val_dataset = SDFItemDataset(mesh[val_mask], sdf[val_mask])

    # balanced sampling: 1:1 positive:negative 
    weights_train = get_weights(sdf[~val_mask])
    weights_val = get_weights(sdf[val_mask])

    train_sampler = WeightedRandomSampler(weights_train, batch_size*train_steps)
    val_sampler = WeightedRandomSampler(weights_val, batch_size*val_steps)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, sampler=val_sampler, num_workers=4)
    sdf_trainer.fit(train_loader, val_loader, epochs, vector_id=latent_vectors_map[index_shape], save=True)

HBox(children=(IntProgress(value=0, max=400), HTML(value='')))

Exception in thread Thread-4:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/opt/conda/lib/python3.6/site-packages/tqdm/_monitor.py", line 62, in run
    for instance in self.tqdm_cls._instances:
  File "/opt/conda/lib/python3.6/_weakrefset.py", line 60, in __iter__
    for itemref in self.data:
RuntimeError: Set changed size during iteration



targets: tensor([ 0.4245,  0.0647, -0.6281,  ...,  0.2694, -0.8806,  0.5036],
       device='cuda:0')
outputs: tensor([[-0.1553],
        [-0.1552],
        [-0.1554],
        ...,
        [-0.1554],
        [-0.1550],
        [-0.1555]], device='cuda:0')
targets: tensor([ 0.8505,  0.5129,  0.6852,  ...,  0.7760, -0.9040, -0.4758],
       device='cuda:0')
outputs: tensor([[-0.1548],
        [-0.1555],
        [-0.1553],
        ...,
        [-0.1554],
        [-0.1556],
        [-0.1552]], device='cuda:0')
targets: tensor([-1.4850, -0.8603, -0.4823,  ...,  0.0484,  0.5643, -1.4432],
       device='cuda:0')
outputs: tensor([[-0.1558],
        [-0.1560],
        [-0.1553],
        ...,
        [-0.1551],
        [-0.1551],
        [-0.1554]], device='cuda:0')
targets: tensor([-0.6947,  0.4730, -0.4870,  ..., -0.5598, -0.8744,  0.7572],
       device='cuda:0')
outputs: tensor([[-0.1554],
        [-0.1550],
        [-0.1553],
        ...,
        [-0.1551],
        [-0.1552],
        [-0.1

HBox(children=(IntProgress(value=0, max=400), HTML(value='')))

KeyboardInterrupt: 