In [None]:
import os
import os.path as osp
import pickle
import random
from glob import glob

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy.typing import NDArray
from plyfile import PlyData, PlyElement
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import global_max_pool
from torch_scatter import scatter_mean
from tqdm import tqdm

# Root folder of the chunked dataset; it is expected to contain the `ply/` and
# `graph/` sub-folders read by ChunkedDataset below.
data_dir = "./data/FWF_flat/10.0X10.0"
osp.exists(data_dir)

In [11]:
import yaml
with open(osp.join(data_dir, '../schema.yaml'), "r") as f:
    schema = yaml.safe_load(f)

# Material = last '-'-separated token of each class name. np.unique sorts the
# names, so material ids are assigned in alphabetical order.
material_names = np.unique([v.split("-")[-1] for v in schema.values()]).tolist()
material_id = {name: i for i, name in enumerate(material_names)}

# Dense lookup table indexed by the ORIGINAL label id: row k is [k, material id of k].
# The training loops use `old2new[labels][:, 1]`, which is only correct when row k
# belongs to label k. The table is therefore filled by key instead of relying on
# the YAML file listing its keys in order. int(k) also accepts keys written as strings.
num_old = max(int(k) for k in schema) + 1
old2new = np.full((num_old, 2), -1, dtype=np.int64)
old2new[:, 0] = np.arange(num_old)
for k, v in schema.items():
    old2new[int(k), 1] = material_id[v.split('-')[-1]]
# a gap in the ids would leave a -1 row, which would later be fed to cross_entropy
assert (old2new[:, 1] >= 0).all(), "schema.yaml must define every label id in [0, max_id]"
old2new = torch.tensor(old2new, device='cuda')
old2new.amax(dim=0)

tensor([88, 17], device='cuda:0')

In [None]:
# generate train_test split: every .ply chunk found under data_dir/ply
all_names = glob(osp.join(data_dir, 'ply', '*.ply'))

# only train: keep every chunk whose file name does not mention Area_06
train_names = list(filter(lambda path: 'Area_06' not in osp.basename(path), all_names))
train_names

## split train/test
# train_ratio = 0.95
# N = len(all_names)
# k = int(train_ratio * N)
# train_i = random.sample(range(N), k) 

# train_mask = np.zeros(len(all_names)).astype(bool)
# train_mask[train_i] = True


# all_names = np.array(all_names)
# train_names = all_names[train_mask].tolist()
# test_names = all_names[~train_mask].tolist()


In [None]:

def forwardstar_to_idx(first_edg: NDArray, adj_verts: NDArray) -> NDArray:
    """Convert a forward-star (offset + adjacency) edge list into a ``[2, E]`` edge index.

    Args:
        first_edg: Offsets of length ``V + 1``; the edges leaving vertex ``u`` are
            ``adj_verts[first_edg[u]:first_edg[u + 1]]``.
        adj_verts: Target vertex of every edge.

    Returns:
        Array of shape ``[2, E]``. Row 0 holds source vertices and row 1 the
        matching targets, in the order the forward star lists them. A graph
        without vertices or edges yields shape ``[2, 0]`` instead of raising.
    """
    first_edg = np.asarray(first_edg)
    adj_verts = np.asarray(adj_verts)
    if len(first_edg) < 2:
        # no vertices: nothing to expand
        return np.empty((2, 0), dtype=np.result_type(np.int64, adj_verts.dtype))

    num_vertices = len(first_edg) - 1
    # Out-degree of each vertex. The cast keeps np.repeat happy with unsigned offset dtypes.
    degrees = np.diff(first_edg).astype(np.int64)
    # vertex u repeated once per outgoing edge (vectorized replacement of the per-vertex loop)
    sources = np.repeat(np.arange(num_vertices), degrees)
    # the per-vertex slices are contiguous, so their concatenation is one slice
    targets = adj_verts[first_edg[0]:first_edg[-1]]
    return np.stack([sources, targets], axis=0)



In [5]:
from typing import List, Tuple
from torch_scatter import scatter_add

def make_directed(edge_index: torch.Tensor) -> torch.Tensor:
    """Append the reversed copy of every edge.

    Args:
        edge_index: ``[2, E]`` tensor of (source, target) pairs.

    Returns:
        ``[2, 2E]`` tensor: the original edges followed by the same edges with
        source and target swapped.
    """
    swapped = edge_index.flip(0)
    return torch.cat((edge_index, swapped), dim=1)

def labels_pt2reg(y:torch.Tensor, assignment:torch.Tensor, num_classes:int,dim_size:int|None=None)->torch.Tensor:
    assert torch.all(y>=0)
    y_oh = torch.eye(num_classes)
    y_sum = scatter_add(y_oh[y], assignment, dim=0,dim_size=dim_size)
    y_reg = torch.argmax(y_sum,dim=1)
    return y_reg

class ChunkedDataset(Dataset):
    """In-memory dataset of point-cloud chunks with their superpoint graphs.

    For every ``<data_dir>/ply/<name>.ply`` whose stem is listed in ``names``,
    the point cloud and the matching ``<data_dir>/graph/<name>.pkl`` graph are
    loaded once at construction time and kept in memory.

    Args:
        data_dir: Root directory containing the ``ply/`` and ``graph/`` folders.
        names: File stems (basename with ``.ply`` removed) to include.
        num_classes: Number of point-level classes, used for the per-superpoint
            majority vote.

    Each item is a dict with keys ``fn``, ``pos``, ``feats``, ``labels`` (one
    label per superpoint), ``superpoint_idx``, ``edge_index`` (every edge in
    both directions) and ``point_feats``.
    """

    def __init__(self,
                 data_dir:str,
                 names:List[str],
                 num_classes:int,
                 ):
        self.data_dir = data_dir
        self.data = list()
        self.feat_names = list()
        self.names = names
        self.num_classes = num_classes

        wanted = set(names)  # O(1) membership instead of scanning the list for every file
        # sorted(): item order no longer depends on the filesystem's listing order
        for pcd_fp in tqdm(sorted(glob(osp.join(data_dir, 'ply', '*.ply')))):
            fn = osp.basename(pcd_fp).replace('.ply', '')
            if fn not in wanted:
                continue
            self.data.append(self._load_chunk(pcd_fp, fn))

    def _load_chunk(self, pcd_fp: str, fn: str) -> dict:
        """Read one ``.ply`` chunk plus its graph pickle and build the item dict."""
        graph_fp = osp.join(osp.split(pcd_fp)[0], '..', 'graph', f"{fn}.pkl")

        # point cloud: xyz, the remaining per-point columns, and point labels
        pcd = pd.DataFrame(PlyData.read(pcd_fp).elements[0].data)
        pos = pcd[['x', 'y', 'z']].to_numpy(dtype=np.float32)
        if not self.feat_names:
            # the first chunk fixes the feature columns (and their order) for the whole dataset
            self.feat_names = [n for n in pcd.columns if n not in ['x', 'y', 'z', 'labels']]
        # Select by the shared column list so every item has the same feature layout.
        # A chunk missing one of these columns raises KeyError instead of silently shifting columns.
        feats = pcd[self.feat_names].to_numpy(dtype=np.float32)
        labels = pcd['labels'].to_numpy(dtype=np.int32)

        # SECURITY: pickle.load can execute arbitrary code -- only load graph files you trust.
        with open(graph_fp, 'rb') as f:
            graph = pickle.load(f)

        point_feats = graph['point_feats']

        # edges: forward star -> [2, E] index, then append the reversed copy of every edge
        efs = graph['edges_forwardstar']
        edge_index = forwardstar_to_idx(efs[0], efs[1])
        edge_index = torch.as_tensor(edge_index, dtype=torch.int64)
        edge_index = make_directed(edge_index)

        # majority vote label for each superpoint
        superpoint_idx = torch.as_tensor(graph['superpoint_idx'], dtype=torch.int64)
        labels = labels_pt2reg(torch.as_tensor(labels, dtype=torch.int64),
                               assignment=superpoint_idx,
                               num_classes=self.num_classes)

        return dict(
            fn=fn,
            pos=torch.as_tensor(pos),
            feats=torch.as_tensor(feats),
            labels=labels,
            superpoint_idx=superpoint_idx,
            edge_index=edge_index,
            point_feats=torch.as_tensor(point_feats),
        )

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

        

            


# Build the training dataset; chunks are identified by file stem (basename with '.ply' removed).
train_ds = ChunkedDataset(  # type: ignore
    data_dir,
    [osp.basename(path).replace('.ply', '') for path in train_names],
    num_classes=89,
)
# test_ds = ChunkedDataset(data_dir,[osp.basename(n).replace('.ply','') for n in test_names],num_classes=89) # type: ignore


100%|██████████| 288/288 [00:28<00:00, 10.29it/s]


In [6]:
class PointNetEncoder(nn.Module):
    """Shared per-point MLP followed by a max-pool over each superpoint.

    Each point's ``[pos, point_feats]`` vector passes through two
    Linear -> LayerNorm -> ReLU -> Dropout stages. The point embeddings are
    then max-pooled per superpoint id.

    Args:
        in_dim: Width of ``pos`` and ``point_feats`` concatenated.
        latent_dim: Width of the returned superpoint embedding.
        hidden: Width of the intermediate layer.
        dropout: Dropout probability applied after each stage.
    """

    def __init__(self,
                 in_dim: int,
                 latent_dim: int,
                 hidden: int = 64,
                 dropout: float = 0.1
                 ):
        super().__init__()
        # The norms are LayerNorms despite the "bn" names. The names (and creation
        # order) are kept so saved state_dicts and seeded initialisation still match.
        self.lin1 = nn.Linear(in_dim, hidden)
        self.bn1 = nn.LayerNorm(hidden)
        self.lin2 = nn.Linear(hidden, latent_dim)
        self.bn2 = nn.LayerNorm(latent_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self,
                pos: torch.Tensor,
                point_feats: torch.Tensor,
                super_idx: torch.Tensor
                ):
        """Return one ``latent_dim`` embedding per superpoint id in ``super_idx``."""
        h = torch.cat([pos, point_feats], dim=-1)
        for lin, norm in ((self.lin1, self.bn1), (self.lin2, self.bn2)):
            h = lin(h)
            h = norm(h)
            h = F.relu(h)
            h = self.drop(h)
        return global_max_pool(h, super_idx)
    
class SPConv(nn.Module):
    """Edge-conditioned message passing on the superpoint graph.

    For every edge (src -> dst), a message is computed from both endpoint
    embeddings and an encoding of the centroid offset. Messages are averaged
    at ``dst``, and each node is updated from ``[own embedding, mean message]``.

    Args:
        d: Node embedding width (input and output).
        d_e: Width of the edge (centroid offset) encoding.
    """

    def __init__(self,
                 d: int,
                 d_e: int):
        super().__init__()
        # encodes the 3-D centroid offset of an edge into a d_e-dim feature
        self.phi_e = nn.Sequential(
            nn.Linear(3, d_e),
            nn.ReLU(),
            nn.Linear(d_e, d_e),
            nn.ReLU(),
        )
        # message MLP over [z_src, z_dst, edge feature]
        self.phi_m = nn.Sequential(
            nn.Linear(2 * d + d_e, d),
            nn.ReLU(),
            nn.Linear(d, d),
        )
        # update MLP over [z, aggregated messages]
        self.phi_u = nn.Sequential(
            nn.Linear(2 * d, d),
            nn.ReLU(),
        )

    def forward(self,
                z,          # [R, d]
                centroids,  # [R, 3]
                edge_index  # [2, E]
                ):
        src, dst = edge_index.unbind(0)

        offsets = centroids[dst] - centroids[src]          # [E, 3]
        edge_feat = self.phi_e(offsets)                    # [E, d_e]

        pair = torch.cat((z[src], z[dst], edge_feat), dim=-1)
        messages = self.phi_m(pair)                        # [E, d]

        # nodes without incoming edges receive a zero aggregate
        aggregated = scatter_mean(messages, dst, dim=0, dim_size=z.size(0))  # [R, d]

        return self.phi_u(torch.cat((z, aggregated), dim=-1))  # [R, d]





class SimpleModel(nn.Module):
    """Per-superpoint classifier with no graph convolution.

    Points are centred on their superpoint centroid and encoded by a
    PointNetEncoder. The superpoint embedding then goes through two
    Linear -> LayerNorm -> ReLU -> Dropout stages and a linear classifier.

    ``forward`` returns ``(logits, z)``, where ``z`` is the final
    ``super_latent_dim`` embedding.
    """

    def __init__(self,
                 in_dim: int,
                 num_classes: int,

                 pn_latent_dim: int = 128,
                 pn_hidden: int = 128,
                 pn_dropout: float = 0.1,

                 super_hidden: int = 128,
                 super_latent_dim: int = 32,
                 super_dropout: float = 0.1,

                 ):
        super().__init__()
        # creation order is kept so seeded initialisation and state_dict keys match
        self.encoder = PointNetEncoder(in_dim, pn_latent_dim, pn_hidden, pn_dropout)

        self.lin1 = nn.Linear(pn_latent_dim, super_hidden)
        self.bn1 = nn.LayerNorm(super_hidden)

        self.lin2 = nn.Linear(super_hidden, super_latent_dim)
        self.bn2 = nn.LayerNorm(super_latent_dim)

        self.dropout = nn.Dropout(super_dropout)

        self.classifier = nn.Linear(super_latent_dim, num_classes)

    def forward(self,
                pos: torch.Tensor,
                point_feats: torch.Tensor,
                super_idx: torch.Tensor
                ):
        centroids = scatter_mean(pos, super_idx, dim=0)
        h = self.encoder(pos - centroids[super_idx], point_feats, super_idx)
        for lin, norm in ((self.lin1, self.bn1), (self.lin2, self.bn2)):
            h = self.dropout(F.relu(norm(lin(h))))
        return self.classifier(h), h


class SPConvModel(nn.Module):
    """Superpoint classifier with residual SPConv message passing.

    Points are centred on their superpoint centroid and encoded by a
    PointNetEncoder. ``super_num_conv`` residual blocks follow, each computing
    ``dropout(relu(LayerNorm(conv(h) + h)))``. A linear classifier is applied last.

    ``forward`` returns ``(logits, z)``, where ``z`` is the embedding fed to
    the classifier.
    """

    def __init__(
            self,
            in_dim: int,
            num_classes: int,

            pn_hidden: int = 128,
            pn_dropout: float = 0.1,

            super_hidden: int = 128,
            super_dropout: float = 0.2,

            edge_dim: int = 64,
            super_num_conv: int = 3,
                 ):
        super().__init__()
        # creation order is kept so seeded initialisation and state_dict keys match
        self.encoder = PointNetEncoder(in_dim, super_hidden, pn_hidden, pn_dropout)

        self.conv_blocks = nn.ModuleList(
            [SPConv(super_hidden, edge_dim) for _ in range(super_num_conv)]
        )
        self.ln_blocks = nn.ModuleList(
            [nn.LayerNorm(super_hidden) for _ in range(super_num_conv)]
        )

        self.classifier = nn.Linear(super_hidden, num_classes)
        self.dropout = nn.Dropout(super_dropout)

    def forward(self,
                pos: torch.Tensor,
                point_feats: torch.Tensor,
                super_idx: torch.Tensor,
                edge_idx: torch.Tensor,
                ):
        centroids = scatter_mean(pos, super_idx, dim=0)
        h = self.encoder(pos - centroids[super_idx], point_feats, super_idx)
        for conv, norm in zip(self.conv_blocks, self.ln_blocks):
            # residual connection around each graph convolution
            h = self.dropout(F.relu(norm(conv(h, centroids, edge_idx) + h)))
        return self.classifier(h), h







    

In [None]:
# Baseline: SimpleModel (per-superpoint encoder + MLP, no graph convolution),
# trained with cross-entropy only on the original 89 classes.

config_max_epochs = 360
config_device = 'cuda'
config_seed = 0
config_save_path = './experiments/simple_model.pth'

# fail fast instead of a ZeroDivisionError after the first epoch
assert len(train_ds) > 0, "training dataset is empty -- check data_dir / train_names"
# create the output folder up front so the save after a long run cannot fail
os.makedirs(osp.dirname(config_save_path), exist_ok=True)
torch.manual_seed(config_seed)  # fixes weight init and DataLoader shuffling

train_dl = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=lambda x: x[0])
# test_dl = DataLoader(test_ds,batch_size=1, collate_fn=lambda x: x[0])

model = SimpleModel(
    in_dim=3 + 12,
    num_classes=89,
    pn_hidden=64,
    super_hidden=64,
    pn_dropout=0.1,
    super_dropout=0.2,
).to(device=config_device)

optim = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

for epoch in range(1, config_max_epochs + 1):
    model.train()
    total_loss = 0.
    for d in tqdm(train_dl, total=len(train_ds), desc=f"Epoch {epoch:03d}"):
        pos         = d['pos'].to(device=config_device)
        point_feats = d['point_feats'].to(device=config_device)
        labels      = d['labels'].to(device=config_device)
        super_idx   = d['superpoint_idx'].to(device=config_device)

        logits, _ = model(pos, point_feats, super_idx)
        loss = F.cross_entropy(logits, labels)

        optim.zero_grad()
        loss.backward()
        optim.step()
        total_loss += loss.item()

    print(f"Epoch {epoch:02d}  average loss: {total_loss/len(train_dl):.4f}")

torch.save(model, config_save_path)

In [None]:
# SPConvModel on the material classes (labels remapped through old2new), trained with
# cross-entropy + config_lambda_c * contrastive loss over superpoint-graph edges.

config_max_epochs = 300
config_device = 'cuda'
config_margin = 1.0
config_lambda_c = 0.1
config_seed = 0
config_save_path = './experiments/spconv_contrastive_L0-1_matOnly.pth'

# fail fast instead of a ZeroDivisionError after the first epoch
assert len(train_ds) > 0, "training dataset is empty -- check data_dir / train_names"
# create the output folder up front so the save after a long run cannot fail
os.makedirs(osp.dirname(config_save_path), exist_ok=True)
torch.manual_seed(config_seed)  # fixes weight init and DataLoader shuffling

train_dl = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=lambda x: x[0])
# test_dl = DataLoader(test_ds,batch_size=1, collate_fn=lambda x: x[0])

model = SPConvModel(
    in_dim=3 + 12,
    num_classes=len(material_names),  # one output per material (old2new column 1)
    pn_hidden=64,
    super_hidden=64,
    pn_dropout=0.1,
    super_dropout=0.2,
    edge_dim=32,
    super_num_conv=3,
).to(device=config_device)

optim = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

for epoch in range(1, config_max_epochs + 1):
    model.train()
    total_loss = 0.
    for d in tqdm(train_dl, total=len(train_ds), desc=f"Epoch {epoch:03d}"):
        pos         = d['pos'].to(device=config_device)
        point_feats = d['point_feats'].to(device=config_device)
        labels      = d['labels'].to(device=config_device)
        super_idx   = d['superpoint_idx'].to(device=config_device)
        edge_idx    = d['edge_index'].to(device=config_device)

        # original class ids -> material ids
        labels = old2new[labels][:, 1]

        logits, z = model(pos, point_feats, super_idx, edge_idx)
        L_class = F.cross_entropy(logits, labels)

        # contrastive loss: pull same-label neighbours together,
        # push different-label neighbours beyond the margin
        src, dst = edge_idx
        if src.numel() > 0:
            dist2 = (z[src] - z[dst]).pow(2).sum(dim=1)             # [E]
            same = (labels[src] == labels[dst]).float()             # [E]
            loss_pos = same * dist2                                 # pull
            loss_neg = (1 - same) * F.relu(config_margin - dist2)   # push
            L_contrast = (loss_pos + loss_neg).mean()
        else:
            # the mean over zero edges would be NaN and poison the logged loss
            L_contrast = z.new_zeros(())

        loss = L_class + config_lambda_c * L_contrast
        optim.zero_grad()
        loss.backward()
        optim.step()
        total_loss += loss.item()

    print(f"Epoch {epoch:02d}  average loss: {total_loss/len(train_dl):.4f}")

torch.save(model, config_save_path)

Epoch 001: 100%|██████████| 283/283 [00:12<00:00, 22.82it/s]


Epoch 01  average loss: 3.4632


Epoch 002: 100%|██████████| 283/283 [00:11<00:00, 24.08it/s]


Epoch 02  average loss: 2.3886


Epoch 003: 100%|██████████| 283/283 [00:12<00:00, 23.05it/s]


Epoch 03  average loss: 2.1930


Epoch 004: 100%|██████████| 283/283 [00:13<00:00, 21.76it/s]


Epoch 04  average loss: 2.0543


Epoch 005: 100%|██████████| 283/283 [00:12<00:00, 22.17it/s]


Epoch 05  average loss: 1.8255


Epoch 006: 100%|██████████| 283/283 [00:11<00:00, 23.63it/s]


Epoch 06  average loss: 1.6814


Epoch 007: 100%|██████████| 283/283 [00:12<00:00, 22.67it/s]


Epoch 07  average loss: 1.6005


Epoch 008: 100%|██████████| 283/283 [00:12<00:00, 22.77it/s]


Epoch 08  average loss: 1.5249


Epoch 009: 100%|██████████| 283/283 [00:10<00:00, 25.82it/s]


Epoch 09  average loss: 1.4691


Epoch 010: 100%|██████████| 283/283 [00:12<00:00, 22.41it/s]


Epoch 10  average loss: 1.4221


Epoch 011: 100%|██████████| 283/283 [00:12<00:00, 22.55it/s]


Epoch 11  average loss: 1.3921


Epoch 012: 100%|██████████| 283/283 [00:12<00:00, 22.61it/s]


Epoch 12  average loss: 1.3516


Epoch 013: 100%|██████████| 283/283 [00:12<00:00, 22.41it/s]


Epoch 13  average loss: 1.3265


Epoch 014: 100%|██████████| 283/283 [00:12<00:00, 23.01it/s]


Epoch 14  average loss: 1.3071


Epoch 015: 100%|██████████| 283/283 [00:11<00:00, 23.84it/s]


Epoch 15  average loss: 1.2927


Epoch 016: 100%|██████████| 283/283 [00:12<00:00, 23.21it/s]


Epoch 16  average loss: 1.2627


Epoch 017: 100%|██████████| 283/283 [00:12<00:00, 22.46it/s]


Epoch 17  average loss: 1.2467


Epoch 018: 100%|██████████| 283/283 [00:12<00:00, 22.21it/s]


Epoch 18  average loss: 1.2166


Epoch 019: 100%|██████████| 283/283 [00:12<00:00, 22.69it/s]


Epoch 19  average loss: 1.2057


Epoch 020: 100%|██████████| 283/283 [00:12<00:00, 22.48it/s]


Epoch 20  average loss: 1.1888


Epoch 021: 100%|██████████| 283/283 [00:12<00:00, 22.55it/s]


Epoch 21  average loss: 1.1855


Epoch 022: 100%|██████████| 283/283 [00:11<00:00, 24.18it/s]


Epoch 22  average loss: 1.1632


Epoch 023: 100%|██████████| 283/283 [00:12<00:00, 23.08it/s]


Epoch 23  average loss: 1.1640


Epoch 024: 100%|██████████| 283/283 [00:12<00:00, 22.00it/s]


Epoch 24  average loss: 1.1488


Epoch 025: 100%|██████████| 283/283 [00:11<00:00, 23.76it/s]


Epoch 25  average loss: 1.1181


Epoch 026: 100%|██████████| 283/283 [00:11<00:00, 25.65it/s]


Epoch 26  average loss: 1.1106


Epoch 027: 100%|██████████| 283/283 [00:12<00:00, 22.90it/s]


Epoch 27  average loss: 1.1056


Epoch 028: 100%|██████████| 283/283 [00:11<00:00, 24.50it/s]


Epoch 28  average loss: 1.0977


Epoch 029: 100%|██████████| 283/283 [00:12<00:00, 23.38it/s]


Epoch 29  average loss: 1.0871


Epoch 030: 100%|██████████| 283/283 [00:12<00:00, 22.48it/s]


Epoch 30  average loss: 1.0791


Epoch 031: 100%|██████████| 283/283 [00:12<00:00, 23.56it/s]


Epoch 31  average loss: 1.0679


Epoch 032: 100%|██████████| 283/283 [00:11<00:00, 25.19it/s]


Epoch 32  average loss: 1.0545


Epoch 033: 100%|██████████| 283/283 [00:11<00:00, 24.32it/s]


Epoch 33  average loss: 1.0530


Epoch 034: 100%|██████████| 283/283 [00:12<00:00, 23.55it/s]


Epoch 34  average loss: 1.0466


Epoch 035: 100%|██████████| 283/283 [00:11<00:00, 23.66it/s]


Epoch 35  average loss: 1.0434


Epoch 036: 100%|██████████| 283/283 [00:11<00:00, 25.20it/s]


Epoch 36  average loss: 1.0291


Epoch 037: 100%|██████████| 283/283 [00:12<00:00, 23.05it/s]


Epoch 37  average loss: 1.0230


Epoch 038: 100%|██████████| 283/283 [00:12<00:00, 23.37it/s]


Epoch 38  average loss: 1.0034


Epoch 039: 100%|██████████| 283/283 [00:11<00:00, 23.69it/s]


Epoch 39  average loss: 0.9948


Epoch 040: 100%|██████████| 283/283 [00:11<00:00, 25.39it/s]


Epoch 40  average loss: 0.9948


Epoch 041: 100%|██████████| 283/283 [00:12<00:00, 22.97it/s]


Epoch 41  average loss: 0.9937


Epoch 042: 100%|██████████| 283/283 [00:11<00:00, 25.38it/s]


Epoch 42  average loss: 0.9753


Epoch 043: 100%|██████████| 283/283 [00:11<00:00, 25.53it/s]


Epoch 43  average loss: 0.9703


Epoch 044: 100%|██████████| 283/283 [00:12<00:00, 22.18it/s]


Epoch 44  average loss: 0.9580


Epoch 045: 100%|██████████| 283/283 [00:12<00:00, 22.38it/s]


Epoch 45  average loss: 0.9633


Epoch 046: 100%|██████████| 283/283 [00:12<00:00, 22.42it/s]


Epoch 46  average loss: 0.9534


Epoch 047: 100%|██████████| 283/283 [00:12<00:00, 22.41it/s]


Epoch 47  average loss: 0.9488


Epoch 048: 100%|██████████| 283/283 [00:12<00:00, 22.76it/s]


Epoch 48  average loss: 0.9296


Epoch 049: 100%|██████████| 283/283 [00:12<00:00, 22.26it/s]


Epoch 49  average loss: 0.9217


Epoch 050: 100%|██████████| 283/283 [00:12<00:00, 22.36it/s]


Epoch 50  average loss: 0.9195


Epoch 051: 100%|██████████| 283/283 [00:12<00:00, 22.96it/s]


Epoch 51  average loss: 0.9109


Epoch 052: 100%|██████████| 283/283 [00:11<00:00, 23.70it/s]


Epoch 52  average loss: 0.9094


Epoch 053: 100%|██████████| 283/283 [00:12<00:00, 23.28it/s]


Epoch 53  average loss: 0.9131


Epoch 054: 100%|██████████| 283/283 [00:11<00:00, 24.31it/s]


Epoch 54  average loss: 0.9006


Epoch 055: 100%|██████████| 283/283 [00:11<00:00, 23.97it/s]


Epoch 55  average loss: 0.8773


Epoch 056: 100%|██████████| 283/283 [00:12<00:00, 21.98it/s]


Epoch 56  average loss: 0.8752


Epoch 057: 100%|██████████| 283/283 [00:12<00:00, 22.27it/s]


Epoch 57  average loss: 0.8763


Epoch 058: 100%|██████████| 283/283 [00:12<00:00, 22.19it/s]


Epoch 58  average loss: 0.8710


Epoch 059: 100%|██████████| 283/283 [00:11<00:00, 25.12it/s]


Epoch 59  average loss: 0.8671


Epoch 060: 100%|██████████| 283/283 [00:12<00:00, 22.87it/s]


Epoch 60  average loss: 0.8573


Epoch 061: 100%|██████████| 283/283 [00:11<00:00, 25.13it/s]


Epoch 61  average loss: 0.8476


Epoch 062: 100%|██████████| 283/283 [00:12<00:00, 22.92it/s]


Epoch 62  average loss: 0.8531


Epoch 063: 100%|██████████| 283/283 [00:11<00:00, 23.87it/s]


Epoch 63  average loss: 0.8464


Epoch 064: 100%|██████████| 283/283 [00:12<00:00, 21.84it/s]


Epoch 64  average loss: 0.8272


Epoch 065: 100%|██████████| 283/283 [00:11<00:00, 24.41it/s]


Epoch 65  average loss: 0.8439


Epoch 066: 100%|██████████| 283/283 [00:11<00:00, 25.62it/s]


Epoch 66  average loss: 0.8383


Epoch 067: 100%|██████████| 283/283 [00:11<00:00, 23.85it/s]


Epoch 67  average loss: 0.8239


Epoch 068: 100%|██████████| 283/283 [00:11<00:00, 23.74it/s]


Epoch 68  average loss: 0.8175


Epoch 069: 100%|██████████| 283/283 [00:12<00:00, 22.47it/s]


Epoch 69  average loss: 0.8025


Epoch 070: 100%|██████████| 283/283 [00:12<00:00, 23.30it/s]


Epoch 70  average loss: 0.8073


Epoch 071: 100%|██████████| 283/283 [00:11<00:00, 25.15it/s]


Epoch 71  average loss: 0.7987


Epoch 072: 100%|██████████| 283/283 [00:12<00:00, 22.84it/s]


Epoch 72  average loss: 0.7921


Epoch 073: 100%|██████████| 283/283 [00:12<00:00, 23.18it/s]


Epoch 73  average loss: 0.7796


Epoch 074: 100%|██████████| 283/283 [00:12<00:00, 22.89it/s]


Epoch 74  average loss: 0.8019


Epoch 075: 100%|██████████| 283/283 [00:12<00:00, 22.70it/s]


Epoch 75  average loss: 0.7837


Epoch 076: 100%|██████████| 283/283 [00:11<00:00, 25.37it/s]


Epoch 76  average loss: 0.7803


Epoch 077: 100%|██████████| 283/283 [00:10<00:00, 25.78it/s]


Epoch 77  average loss: 0.7678


Epoch 078: 100%|██████████| 283/283 [00:11<00:00, 25.06it/s]


Epoch 78  average loss: 0.7650


Epoch 079: 100%|██████████| 283/283 [00:12<00:00, 22.44it/s]


Epoch 79  average loss: 0.7592


Epoch 080: 100%|██████████| 283/283 [00:11<00:00, 25.61it/s]


Epoch 80  average loss: 0.7775


Epoch 081: 100%|██████████| 283/283 [00:12<00:00, 22.65it/s]


Epoch 81  average loss: 0.7593


Epoch 082: 100%|██████████| 283/283 [00:12<00:00, 22.63it/s]


Epoch 82  average loss: 0.7619


Epoch 083: 100%|██████████| 283/283 [00:12<00:00, 22.69it/s]


Epoch 83  average loss: 0.7521


Epoch 084: 100%|██████████| 283/283 [00:11<00:00, 24.36it/s]


Epoch 84  average loss: 0.7434


Epoch 085: 100%|██████████| 283/283 [00:11<00:00, 23.88it/s]


Epoch 85  average loss: 0.7431


Epoch 086: 100%|██████████| 283/283 [00:13<00:00, 21.73it/s]


Epoch 86  average loss: 0.7268


Epoch 087: 100%|██████████| 283/283 [00:13<00:00, 21.50it/s]


Epoch 87  average loss: 0.7602


Epoch 088: 100%|██████████| 283/283 [00:11<00:00, 23.90it/s]


Epoch 88  average loss: 0.7243


Epoch 089: 100%|██████████| 283/283 [00:12<00:00, 21.83it/s]


Epoch 89  average loss: 0.7245


Epoch 090: 100%|██████████| 283/283 [00:11<00:00, 23.81it/s]


Epoch 90  average loss: 0.7241


Epoch 091: 100%|██████████| 283/283 [00:12<00:00, 22.95it/s]


Epoch 91  average loss: 0.7176


Epoch 092: 100%|██████████| 283/283 [00:12<00:00, 23.31it/s]


Epoch 92  average loss: 0.7316


Epoch 093: 100%|██████████| 283/283 [00:11<00:00, 24.54it/s]


Epoch 93  average loss: 0.7010


Epoch 094: 100%|██████████| 283/283 [00:12<00:00, 22.24it/s]


Epoch 94  average loss: 0.6920


Epoch 095: 100%|██████████| 283/283 [00:12<00:00, 22.57it/s]


Epoch 95  average loss: 0.6999


Epoch 096: 100%|██████████| 283/283 [00:12<00:00, 21.77it/s]


Epoch 96  average loss: 0.7062


Epoch 097: 100%|██████████| 283/283 [00:12<00:00, 23.31it/s]


Epoch 97  average loss: 0.6936


Epoch 098:  64%|██████▍   | 181/283 [00:08<00:04, 25.14it/s]

In [None]:
# SPConvModel on the material classes (labels remapped through old2new), trained with
# cross-entropy + config_lambda_c * contrastive loss over superpoint-graph edges.

config_max_epochs = 200
config_device = 'cuda'
config_margin = 1.0
config_lambda_c = 0.01
config_seed = 0
config_save_path = './experiments/spconv_contrastive_L0-01_matOnly.pth'

# fail fast instead of a ZeroDivisionError after the first epoch
assert len(train_ds) > 0, "training dataset is empty -- check data_dir / train_names"
# create the output folder up front so the save after a long run cannot fail
os.makedirs(osp.dirname(config_save_path), exist_ok=True)
torch.manual_seed(config_seed)  # fixes weight init and DataLoader shuffling

train_dl = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=lambda x: x[0])
# test_dl = DataLoader(test_ds,batch_size=1, collate_fn=lambda x: x[0])

model = SPConvModel(
    in_dim=3 + 12,
    num_classes=len(material_names),  # one output per material (old2new column 1)
    pn_hidden=64,
    super_hidden=64,
    pn_dropout=0.1,
    super_dropout=0.2,
    edge_dim=32,
    super_num_conv=3,
).to(device=config_device)

optim = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

for epoch in range(1, config_max_epochs + 1):
    model.train()
    total_loss = 0.
    for d in tqdm(train_dl, total=len(train_ds), desc=f"Epoch {epoch:03d}"):
        pos         = d['pos'].to(device=config_device)
        point_feats = d['point_feats'].to(device=config_device)
        labels      = d['labels'].to(device=config_device)
        super_idx   = d['superpoint_idx'].to(device=config_device)
        edge_idx    = d['edge_index'].to(device=config_device)

        # original class ids -> material ids
        labels = old2new[labels][:, 1]

        logits, z = model(pos, point_feats, super_idx, edge_idx)
        L_class = F.cross_entropy(logits, labels)

        # contrastive loss: pull same-label neighbours together,
        # push different-label neighbours beyond the margin
        src, dst = edge_idx
        if src.numel() > 0:
            dist2 = (z[src] - z[dst]).pow(2).sum(dim=1)             # [E]
            same = (labels[src] == labels[dst]).float()             # [E]
            loss_pos = same * dist2                                 # pull
            loss_neg = (1 - same) * F.relu(config_margin - dist2)   # push
            L_contrast = (loss_pos + loss_neg).mean()
        else:
            # the mean over zero edges would be NaN and poison the logged loss
            L_contrast = z.new_zeros(())

        loss = L_class + config_lambda_c * L_contrast
        optim.zero_grad()
        loss.backward()
        optim.step()
        total_loss += loss.item()

    print(f"Epoch {epoch:02d}  average loss: {total_loss/len(train_dl):.4f}")

torch.save(model, config_save_path)

In [None]:
# SPConvModel on the material classes (labels remapped through old2new), trained with
# cross-entropy + config_lambda_c * contrastive loss over superpoint-graph edges.

config_max_epochs = 200
config_device = 'cuda'
config_margin = 1.0
config_lambda_c = 1.0
config_seed = 0
config_save_path = './experiments/spconv_contrastive_L1-00_matOnly.pth'

# fail fast instead of a ZeroDivisionError after the first epoch
assert len(train_ds) > 0, "training dataset is empty -- check data_dir / train_names"
# create the output folder up front so the save after a long run cannot fail
os.makedirs(osp.dirname(config_save_path), exist_ok=True)
torch.manual_seed(config_seed)  # fixes weight init and DataLoader shuffling

train_dl = DataLoader(train_ds, batch_size=1, shuffle=True, collate_fn=lambda x: x[0])
# test_dl = DataLoader(test_ds,batch_size=1, collate_fn=lambda x: x[0])

model = SPConvModel(
    in_dim=3 + 12,
    num_classes=len(material_names),  # one output per material (old2new column 1)
    pn_hidden=64,
    super_hidden=64,
    pn_dropout=0.1,
    super_dropout=0.2,
    edge_dim=32,
    super_num_conv=3,
).to(device=config_device)

optim = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

for epoch in range(1, config_max_epochs + 1):
    model.train()
    total_loss = 0.
    for d in tqdm(train_dl, total=len(train_ds), desc=f"Epoch {epoch:03d}"):
        pos         = d['pos'].to(device=config_device)
        point_feats = d['point_feats'].to(device=config_device)
        labels      = d['labels'].to(device=config_device)
        super_idx   = d['superpoint_idx'].to(device=config_device)
        edge_idx    = d['edge_index'].to(device=config_device)

        # original class ids -> material ids
        labels = old2new[labels][:, 1]

        logits, z = model(pos, point_feats, super_idx, edge_idx)
        L_class = F.cross_entropy(logits, labels)

        # contrastive loss: pull same-label neighbours together,
        # push different-label neighbours beyond the margin
        src, dst = edge_idx
        if src.numel() > 0:
            dist2 = (z[src] - z[dst]).pow(2).sum(dim=1)             # [E]
            same = (labels[src] == labels[dst]).float()             # [E]
            loss_pos = same * dist2                                 # pull
            loss_neg = (1 - same) * F.relu(config_margin - dist2)   # push
            L_contrast = (loss_pos + loss_neg).mean()
        else:
            # the mean over zero edges would be NaN and poison the logged loss
            L_contrast = z.new_zeros(())

        loss = L_class + config_lambda_c * L_contrast
        optim.zero_grad()
        loss.backward()
        optim.step()
        total_loss += loss.item()

    print(f"Epoch {epoch:02d}  average loss: {total_loss/len(train_dl):.4f}")

torch.save(model, config_save_path)