In [2]:
from pathlib import Path

import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl
import numpy as np
from sklearn.metrics import recall_score, accuracy_score

from DGCNN import DGCNN_reg
from utils.data_depth import DepthDataset
from utils.dataloader import DataLoader
from utils.plot import plot_3d_pointcloud

### Datasets and dataloaders

In [3]:
def _depth_dataset(partition, **extra):
    """Build a DepthDataset split with the preprocessing options shared by all splits."""
    return DepthDataset(partition=partition, outputs=['edge_label', 'edge_index'], normalize=True,
                        pc_mean=0.5, repeat=1, shuffle_pixels=False, **extra)

# Loader options common to both splits; only shuffling differs.
_loader_opts = dict(batch_size=None, drop_last=True, pin_memory=True)

train_set    = _depth_dataset('train', preload=True)  # only the training split is preloaded
train_loader = DataLoader(train_set, shuffle=True, **_loader_opts)
valid_set    = _depth_dataset('valid')
valid_loader = DataLoader(valid_set, shuffle=False, **_loader_opts)

### Models and PyTorch Lightning system

In [8]:
from torch_geometric.nn import GAE
from torch_geometric.nn import GCNConv

class GCNEncoder(nn.Module):
    """Four stacked GCNConv layers producing one embedding per node.

    A LeakyReLU (slope 0.1) follows each of the first three convolutions; the last
    convolution's output is returned without an activation.
    """

    def __init__(self, in_channels=2, hidden_channels=8, out_channels=8):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.conv4 = GCNConv(hidden_channels, out_channels)
        self.act = nn.LeakyReLU(negative_slope=0.1)

    def forward(self, x, edge_index):
        """Encode node features ``x`` (``[N, in_channels]``) over the graph ``edge_index``."""
        hidden = x
        for layer in (self.conv1, self.conv2, self.conv3):
            hidden = self.act(layer(hidden, edge_index))
        # Final projection to out_channels, left linear.
        return self.conv4(hidden, edge_index)
    
class MLPDecoder(nn.Module):
    """Per-position linear classifier implemented as a 1x1 convolution.

    ``forward`` applies the convolution and then ``squeeze(0)``, so a leading batch
    axis of size 1 is dropped from the ``[B, out_channels, H, W]`` result.
    """

    def __init__(self, in_channels=8, out_channels=2):
        super().__init__()
        # A 1x1 kernel applies the same linear map independently at every (H, W) position.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv1(x).squeeze(0)
    
class EdgeClassifier(nn.Module):
    """Classify every edge from the GCN embeddings of its two endpoint nodes.

    Nodes are embedded by ``GCNEncoder``. For each edge, the source node's embedding
    is concatenated with the target node's embedding, and ``MLPDecoder`` scores the pair.
    """

    def __init__(self, in_channels=2, hidden_gcn=8, emb_dims=8, out_channels=2):
        super().__init__()
        self.encoder = GCNEncoder(in_channels, hidden_gcn, emb_dims)
        # The decoder sees source and target embeddings stacked along channels.
        self.decoder = MLPDecoder(2 * emb_dims, out_channels)

    def forward(self, x, edge_index):
        """Return edge logits of shape ``[out_channels, N, K]``.

        ``x`` is channel-first (``[C, N]``). K is taken as ``round(E / N)``. The pairing
        assumes ``edge_index`` holds exactly K edges per node, grouped by source node
        in node order (TODO confirm against the dataset). Only the target row
        ``edge_index[1]`` is used to build the pairs.
        """
        node_emb = self.encoder(x.permute(1, 0), edge_index)  # GCNConv needs node features as [N, C]
        num_nodes, emb_dim = node_emb.size()
        k = round(edge_index.size(1) / num_nodes)

        # Source half: each node's own embedding, broadcast over its K edges -> [C, N, K].
        src = node_emb.permute(1, 0).unsqueeze(2).expand(-1, -1, k)
        # Target half: embedding of each edge's target node, regrouped to [C, N, K].
        dst = node_emb[edge_index[1]].reshape(num_nodes, k, emb_dim).permute(2, 0, 1)

        pairs = torch.cat([src, dst], dim=0)  # [2C, N, K]
        return self.decoder(pairs.unsqueeze(0))

class TopologyGCN(pl.LightningModule):
    """Lightning system that trains an ``EdgeClassifier`` with weighted cross-entropy.

    Args:
        in_channels: Number of input features per node.
        hidden_gcn: Width of the hidden GCN layers.
        emb_dims: Size of the node embedding produced by the encoder.
        out_channels: Number of edge classes predicted by the decoder.
        weight: Per-class weights for ``nn.CrossEntropyLoss``. ``None`` (the default)
            uses a fresh ``torch.tensor([0.5, 0.5])``.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, in_channels=2, hidden_gcn=8, emb_dims=8, out_channels=2, weight=None, lr=1e-3):
        super().__init__()
        # Build the default per call. A tensor default argument is one object shared
        # by every instance, so an in-place update of one model's loss weight
        # (e.g. when loading a state dict) would leak into the others.
        if weight is None:
            weight = torch.tensor([0.5, 0.5])
        self.model = EdgeClassifier(in_channels, hidden_gcn, emb_dims, out_channels)
        self.criterion = nn.CrossEntropyLoss(weight=weight)
        # Stored on the instance so configure_optimizers does not depend on notebook globals.
        self.lr = lr

    def forward(self, batch):
        """Return edge logits ``[out_channels, N, K]`` for a ``(pixels, edge_label, edge_index)`` batch."""
        pixels, _, edge_index = batch
        return self.model(pixels, edge_index)

    def _shared_step(self, batch):
        """Compute the loss and the hard per-edge predictions for one batch."""
        _, edge_label, _ = batch
        edge_label_pred = self(batch)
        # CrossEntropyLoss expects the class axis at dim 1, so [C, N, K] -> [N, C, K].
        # edge_label is presumably [N, K] class indices (TODO confirm with the dataset).
        loss = self.criterion(edge_label_pred.permute(1, 0, 2), edge_label)
        return loss, edge_label, edge_label_pred.argmax(0)

    def training_step(self, batch, batch_idx):
        loss, edge_label, edge_label_pred = self._shared_step(batch)
        self.log('loss', {'train': loss}, on_step=False, on_epoch=True)
        return {'loss': loss, 'edge_label': edge_label, 'edge_label_pred': edge_label_pred}

    def validation_step(self, batch, batch_idx):
        loss, edge_label, edge_label_pred = self._shared_step(batch)
        self.log('loss', {'valid': loss}, on_step=False, on_epoch=True)
        return {'edge_label': edge_label, 'edge_label_pred': edge_label_pred}

    @staticmethod
    def _edge_metrics(step_outputs, suffix):
        """Accuracy, TPR (recall of class 1) and TNR (recall of class 0) over an epoch's outputs."""
        # torch.stack needs every step's tensor to have the same shape; flattening
        # and concatenating also works when samples differ in size.
        labels = torch.cat([out['edge_label'].detach().cpu().flatten() for out in step_outputs]).numpy()
        labels_pred = torch.cat([out['edge_label_pred'].detach().cpu().flatten() for out in step_outputs]).numpy()
        return {f'acc_{suffix}': accuracy_score(labels, labels_pred),
                f'tpr_{suffix}': recall_score(labels, labels_pred, pos_label=1),
                f'tnr_{suffix}': recall_score(labels, labels_pred, pos_label=0)}

    # NOTE(review): the *_epoch_end hooks belong to the Lightning 1.x API; confirm
    # against the installed pytorch_lightning version.
    def training_epoch_end(self, training_step_outputs):
        if not training_step_outputs:
            return
        self.log('metrics', self._edge_metrics(training_step_outputs, 'train'), on_step=False, on_epoch=True)

    def validation_epoch_end(self, validation_step_outputs):
        if not validation_step_outputs:
            return
        self.log('metrics', self._edge_metrics(validation_step_outputs, 'val'), on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        # Same loss layout and prediction axis as training/validation; logged under 'test'.
        loss, edge_label, edge_label_pred = self._shared_step(batch)
        self.log('loss', {'test': loss}, on_step=False, on_epoch=True)
        return {'edge_label': edge_label, 'edge_label_pred': edge_label_pred}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

### Launch TensorBoard

In [2]:
# Start TensorBoard inside the notebook, reading the runs Lightning writes under lightning_logs/.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


### Initialize and train the model

In [13]:
import time

# Time one full pass over train_loader while tallying edge labels:
# n = total number of edge labels, p = number of label-1 (positive) edges.
t1 = time.perf_counter()

n, p = 0, 0
for _, edge_label, _ in train_loader:
    n += edge_label.numel()
    p += edge_label.sum()

elapsed = time.perf_counter() - t1
print(elapsed)

0.4712864000000039


In [None]:
# Training configuration
lr = 1e-3          # Adam learning rate
epochs = 4000
pos_weight = 1.0   # extra multiplier applied to the positive-class loss weight
seed = 42

# Count positive (label 1) and total edge labels over the training set to derive
# inverse-frequency class weights for the cross-entropy loss.
n, p = 0, 0
for _, edge_label, _ in train_loader:
    p += int(edge_label.sum())
    n += edge_label.numel()

# With no positives (or no negatives) one weight would be infinite and the
# normalized weights NaN, so fail loudly instead of training on them.
if n == 0 or p in (0, n):
    raise ValueError(f'cannot build class weights: {p} positive edges out of {n}')

pos_frac = p / n
weight = torch.tensor([1.0 / (1.0 - pos_frac), pos_weight / pos_frac])
weight = weight / weight.sum()
print(weight)

pl.seed_everything(seed)  # reproducible weight init and data shuffling
tgcn = TopologyGCN(in_channels=2, hidden_gcn=64, emb_dims=64, out_channels=2, weight=weight)
tgcn.lr = lr  # also expose the learning rate above on the model as `tgcn.lr`

trainer = pl.Trainer(gpus=1, max_epochs=epochs, precision=32)

# Validation is currently off; pass valid_loader as a third argument to enable it.
trainer.fit(tgcn, train_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | EdgeClassifier   | 12.9 K
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
12.9 K    Trainable params
0         Non-trainable params
12.9 K    Total params
0.052     Total estimated model params size (MB)


tensor([0.0513, 0.9487])


Training: 0it [00:00, ?it/s]

### Test model

In [10]:
# Iterate over the training samples one at a time. batch_size=None (as for
# train_loader) keeps each sample unbatched, because EdgeClassifier.forward
# permutes a 2-D [C, N] feature tensor.
test_loader = DataLoader(train_set, num_workers=1, batch_size=None, shuffle=False, drop_last=False, pin_memory=True)
test_loader_iter = iter(test_loader)

# Send the trained model to the GPU and switch it to inference mode
device = torch.device("cuda")
tgcn.to(device)
tgcn.eval();

In [16]:
# Get next pointcloud and send it to device
pixels, depth, fn = next(test_loader_iter)
pixels, fn = pixels.to(device), fn.to(device)

# Predict
depth_pred = dgcnn((pixels, depth, fn))

# Send prediction to cpu and plot it
pixels = pixels.detach().cpu().squeeze().numpy()
depth_pred = depth_pred.detach().cpu().squeeze().numpy()
plot_3d_pointcloud(train_set.denormalize(pixels), depth_pred, train_set.im_shape)

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0079492…

### Load model

In [None]:
# TODO: replace version_xx with the run to restore.
checkpoint_dir = Path('lightning_logs/version_xx/checkpoints/')

# Pick the most recently written checkpoint instead of an arbitrary directory entry.
checkpoints = sorted(checkpoint_dir.glob('*.ckpt'), key=lambda path: path.stat().st_mtime)
if not checkpoints:
    raise FileNotFoundError(f'no .ckpt files found in {checkpoint_dir}')

# TopologyGCN never calls save_hyperparameters, so the architecture used for
# training must be passed explicitly (these values must match the training cell).
model = TopologyGCN.load_from_checkpoint(str(checkpoints[-1]), in_channels=2, hidden_gcn=64,
                                         emb_dims=64, out_channels=2)