In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
import h5py
sys.path.insert(0, '..')
from models import PointNet, DGCNNSegBackbone
from datasets import PointCloudNormalize, ABCDataset
from torch.utils.data import Dataset, DataLoader
from utils.training_routines import RunningMetrics
from tqdm import tqdm


## Parameters

In [2]:
device = 'cuda:0'
exp_id = 'z52pya7i'
n_epochs = 50
lr = 5e-4
weight_decay = 1e-5

In [3]:
class Regressor(nn.Module):
    def __init__(self, backbone, finetune_head=False):
        super().__init__()
        self.backbone = backbone
        if finetune_head:
            for p in self.backbone.parameters():
                p.requires_grad = False
                
        self.head = nn.Sequential(
            nn.Conv1d(self.backbone.n_output_point, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(True),
            nn.Conv1d(512, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.Conv1d(256, 1, 1)
        )
        
    def forward(self, x):
        features = self.backbone.forward_features(x)
        return self.head(features).squeeze(1)
    
def compute_loss(gt, pred):
    log_gt = torch.log1p(gt)
    
    return (pred - log_gt).pow(2).mean()

def compute_mse(gt, pred):
    pred = torch.expm1(pred)
    return (pred - gt).pow(2).mean()

@torch.no_grad()
def validate(model, loader):
    model.eval()
    preds = []
    gts = []
    bar = tqdm(loader, desc='val')
    metrics = RunningMetrics()
    
    for x, gt in bar:
        pred = model(x.to(device))
        preds.append(pred.cpu())
        gts.append(gt)
        metrics.step({'loss': compute_loss(gt, pred.cpu())})
        bar.set_postfix(metrics.report())
    
    preds = torch.cat(preds, dim=0)
    gts = torch.cat(gts, dim=0)
    loss = compute_mse(gts, preds).item()
    print(metrics.report())
    
    print('val loss', loss)
    return loss

def train(model, train_loader, test_loader, optimizer, scheduler, n_epochs, val_every=1):
    val_loss_list = []
    for epoch in range(1, n_epochs + 1):
        bar = tqdm(train_loader)
        model.train()
        metrics = RunningMetrics()
        
        for x, gt in bar:
            optimizer.zero_grad()
            pred = model(x.to(device))
            loss = compute_loss(gt.to(device), pred) + 0.001 * model.backbone.reg
            loss.backward()
            optimizer.step()
            metrics.step({'loss': loss})
            report = metrics.report()
            report.update({'epoch': epoch})
            bar.set_postfix(report)
            
            scheduler.step()
        
        if epoch % val_every == 0:
            val_loss = validate(model, test_loader)
            val_loss_list.append(val_loss)
            
    return val_loss_list
            
def get_model(exp_id, n_epochs, finetune_head, lr, weight_decay):
    model = PointNet()
    # model = DGCNNSegBackbone()
    if exp_id is not None:
        checkpoint_path = f'../weights/simclr_run_{exp_id}_ckp_150.pt'
        state = torch.load(checkpoint_path, map_location='cpu')['model']
        model.load_state_dict(state)
        
    model = Regressor(model, finetune_head).to(device)
    optimizer = torch.optim.Adam(filter(lambda x: x.requires_grad,
                                        model.parameters()),
                                 lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs * len(train_loader), eta_min=0)
    return model, optimizer, scheduler

In [4]:
dataset_path = '../../datasets/hdfs/train_0.hdf5'

## Simple training

### Only head

In [5]:
train_ds = ABCDataset(dataset_path,
                      'train',
                      'distances',
                      transform=PointCloudNormalize('box'))
test_ds = ABCDataset(dataset_path,
                     'test',
                     'distances',
                     transform=PointCloudNormalize('box'))

train_loader = DataLoader(train_ds, shuffle=True, batch_size=50)
test_loader = DataLoader(test_ds, shuffle=False, batch_size=32)

In [None]:
run_results = []

for _ in range(3):
    finetune_head = True
    exp_id = None # 'z52pya7i'
    model, optimizer, scheduler = get_model(exp_id, n_epochs, finetune_head, lr, weight_decay)
    val_loss1 = train(model, train_loader, test_loader, optimizer, scheduler, n_epochs)
    run_results.append(val_loss1[-1])

100%|██████████| 230/230 [00:51<00:00,  4.45it/s, loss=0.0602, epoch=1]
val: 100%|██████████| 154/154 [00:15<00:00,  9.74it/s, loss=0.0429]


{'loss': 0.04290205746984874}
val loss 0.1440052084294357


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0378, epoch=2]
val: 100%|██████████| 154/154 [00:15<00:00,  9.89it/s, loss=0.0285]


{'loss': 0.028514715696312476}
val loss 0.07573836657483042


100%|██████████| 230/230 [00:51<00:00,  4.45it/s, loss=0.0361, epoch=3]
val: 100%|██████████| 154/154 [00:15<00:00,  9.69it/s, loss=0.0241]


{'loss': 0.024083847008247953}
val loss 0.05845543656325018


100%|██████████| 230/230 [00:51<00:00,  4.44it/s, loss=0.0348, epoch=4]
val: 100%|██████████| 154/154 [00:15<00:00, 10.05it/s, loss=0.0237]


{'loss': 0.02373604938469253}
val loss 0.056172036891923884


100%|██████████| 230/230 [00:51<00:00,  4.44it/s, loss=0.0342, epoch=5]
val: 100%|██████████| 154/154 [00:15<00:00,  9.90it/s, loss=0.0232]


{'loss': 0.02317105200309075}
val loss 0.05606002627333372


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0327, epoch=6]
val: 100%|██████████| 154/154 [00:15<00:00,  9.72it/s, loss=0.0218]


{'loss': 0.02179211543670263}
val loss 0.05216405704565748


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0327, epoch=7]
val: 100%|██████████| 154/154 [00:15<00:00,  9.68it/s, loss=0.0223]


{'loss': 0.022339994225786704}
val loss 0.05703467986033338


100%|██████████| 230/230 [00:52<00:00,  4.42it/s, loss=0.0322, epoch=8]
val: 100%|██████████| 154/154 [00:15<00:00,  9.94it/s, loss=0.0202]


{'loss': 0.020180170145031883}
val loss 0.04756004284959714


100%|██████████| 230/230 [00:51<00:00,  4.44it/s, loss=0.0317, epoch=9]
val: 100%|██████████| 154/154 [00:15<00:00,  9.84it/s, loss=0.0226]


{'loss': 0.02261083266566788}
val loss 0.05691591855244864


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0312, epoch=10]
val: 100%|██████████| 154/154 [00:15<00:00, 10.00it/s, loss=0.0199]


{'loss': 0.0199176932145804}
val loss 0.047248439634651286


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0313, epoch=11]
val: 100%|██████████| 154/154 [00:15<00:00,  9.72it/s, loss=0.022] 


{'loss': 0.021959709962310312}
val loss 0.051693054179567544


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0305, epoch=12]
val: 100%|██████████| 154/154 [00:15<00:00,  9.88it/s, loss=0.0204]


{'loss': 0.020423583918645294}
val loss 0.048828604536478144


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0305, epoch=13]
val: 100%|██████████| 154/154 [00:15<00:00,  9.67it/s, loss=0.0197]


{'loss': 0.019675404518204636}
val loss 0.04674584382822582


100%|██████████| 230/230 [00:51<00:00,  4.44it/s, loss=0.03, epoch=14]  
val: 100%|██████████| 154/154 [00:15<00:00,  9.79it/s, loss=0.0194]


{'loss': 0.01937881184628861}
val loss 0.045358671836784846


100%|██████████| 230/230 [00:51<00:00,  4.44it/s, loss=0.0298, epoch=15]
val: 100%|██████████| 154/154 [00:15<00:00,  9.87it/s, loss=0.0202]


{'loss': 0.02017727884133224}
val loss 0.04565814859259301


100%|██████████| 230/230 [00:51<00:00,  4.44it/s, loss=0.0299, epoch=16]
val: 100%|██████████| 154/154 [00:15<00:00,  9.85it/s, loss=0.0191]


{'loss': 0.019137559444403318}
val loss 0.04522634007229104


100%|██████████| 230/230 [00:51<00:00,  4.44it/s, loss=0.0295, epoch=17]
val: 100%|██████████| 154/154 [00:15<00:00, 10.16it/s, loss=0.0187]


{'loss': 0.018715980423143555}
val loss 0.043476503851420226


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0293, epoch=18]
val: 100%|██████████| 154/154 [00:15<00:00,  9.93it/s, loss=0.0201]


{'loss': 0.02013794497675849}
val loss 0.049023428268717285


100%|██████████| 230/230 [00:51<00:00,  4.42it/s, loss=0.0294, epoch=19]
val: 100%|██████████| 154/154 [00:15<00:00,  9.76it/s, loss=0.022] 


{'loss': 0.021964814896911874}
val loss 0.05436511780366882


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0291, epoch=20]
val: 100%|██████████| 154/154 [00:15<00:00,  9.72it/s, loss=0.0189]


{'loss': 0.01889996877764257}
val loss 0.04446957244219817


100%|██████████| 230/230 [00:51<00:00,  4.44it/s, loss=0.0284, epoch=21]
val: 100%|██████████| 154/154 [00:15<00:00,  9.69it/s, loss=0.0195]


{'loss': 0.019478324471858655}
val loss 0.04650225096875767


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0284, epoch=22]
val: 100%|██████████| 154/154 [00:15<00:00, 10.05it/s, loss=0.018] 


{'loss': 0.018012935766057142}
val loss 0.042033119246325266


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0285, epoch=23]
val: 100%|██████████| 154/154 [00:15<00:00, 10.17it/s, loss=0.0187]


{'loss': 0.018655908253033284}
val loss 0.042853706335661594


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0284, epoch=24]
val: 100%|██████████| 154/154 [00:15<00:00,  9.69it/s, loss=0.0181]


{'loss': 0.0180571437912031}
val loss 0.040743469012890686


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0279, epoch=25]
val: 100%|██████████| 154/154 [00:15<00:00,  9.69it/s, loss=0.0183]


{'loss': 0.018280951463073777}
val loss 0.04373123856671669


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0277, epoch=26]
val: 100%|██████████| 154/154 [00:15<00:00, 10.04it/s, loss=0.0178]


{'loss': 0.017819293858804132}
val loss 0.04094677555757151


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0275, epoch=27]
val: 100%|██████████| 154/154 [00:15<00:00, 10.08it/s, loss=0.0179]


{'loss': 0.017884624421544502}
val loss 0.04126559456216358


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0274, epoch=28]
val: 100%|██████████| 154/154 [00:15<00:00,  9.81it/s, loss=0.0179]


{'loss': 0.01790069111211753}
val loss 0.04097788594479704


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0269, epoch=29]
val: 100%|██████████| 154/154 [00:15<00:00,  9.67it/s, loss=0.0176]


{'loss': 0.01756811978578404}
val loss 0.04084003082686115


100%|██████████| 230/230 [00:52<00:00,  4.42it/s, loss=0.0268, epoch=30]
val: 100%|██████████| 154/154 [00:15<00:00,  9.80it/s, loss=0.0171]


{'loss': 0.017149314916433338}
val loss 0.03896723027713156


100%|██████████| 230/230 [00:52<00:00,  4.42it/s, loss=0.027, epoch=31] 
val: 100%|██████████| 154/154 [00:15<00:00, 10.06it/s, loss=0.0174]


{'loss': 0.017357869959372024}
val loss 0.03931403212444851


100%|██████████| 230/230 [00:51<00:00,  4.44it/s, loss=0.0269, epoch=32]
val: 100%|██████████| 154/154 [00:15<00:00,  9.88it/s, loss=0.0173]


{'loss': 0.017288265832497627}
val loss 0.039109755769899404


100%|██████████| 230/230 [00:52<00:00,  4.42it/s, loss=0.0264, epoch=33]
val: 100%|██████████| 154/154 [00:15<00:00,  9.71it/s, loss=0.0168]


{'loss': 0.0168370168437379}
val loss 0.03817643947446388


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0263, epoch=34]
val: 100%|██████████| 154/154 [00:15<00:00,  9.94it/s, loss=0.0167]


{'loss': 0.016686760343068218}
val loss 0.038242840168726616


100%|██████████| 230/230 [00:52<00:00,  4.42it/s, loss=0.0261, epoch=35]
val: 100%|██████████| 154/154 [00:15<00:00,  9.95it/s, loss=0.0174]


{'loss': 0.017436476393373312}
val loss 0.03924836792692475


100%|██████████| 230/230 [00:51<00:00,  4.44it/s, loss=0.0259, epoch=36]
val: 100%|██████████| 154/154 [00:15<00:00,  9.96it/s, loss=0.0161]


{'loss': 0.016121393775892015}
val loss 0.03635227680425539


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0257, epoch=37]
val: 100%|██████████| 154/154 [00:15<00:00, 10.09it/s, loss=0.0169]


{'loss': 0.016881654541944176}
val loss 0.038746004031992666


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0258, epoch=38]
val: 100%|██████████| 154/154 [00:15<00:00,  9.92it/s, loss=0.0163]


{'loss': 0.01632622152256567}
val loss 0.03745124870090619


100%|██████████| 230/230 [00:51<00:00,  4.42it/s, loss=0.0255, epoch=39]
val: 100%|██████████| 154/154 [00:15<00:00,  9.80it/s, loss=0.0167]


{'loss': 0.016742605525057082}
val loss 0.03768817023789165


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0255, epoch=40]
val: 100%|██████████| 154/154 [00:15<00:00,  9.67it/s, loss=0.0165]


{'loss': 0.01651209350704785}
val loss 0.037150785989732496


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0253, epoch=41]
val: 100%|██████████| 154/154 [00:15<00:00,  9.77it/s, loss=0.0162]


{'loss': 0.016207443263365318}
val loss 0.0366697920406039


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0253, epoch=42]
val: 100%|██████████| 154/154 [00:15<00:00,  9.75it/s, loss=0.0164]


{'loss': 0.016366029129950964}
val loss 0.037093024205571076


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.025, epoch=43] 
val: 100%|██████████| 154/154 [00:15<00:00,  9.78it/s, loss=0.0161]


{'loss': 0.016136524930000628}
val loss 0.036238428803410556


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0251, epoch=44]
val: 100%|██████████| 154/154 [00:15<00:00,  9.69it/s, loss=0.0159]


{'loss': 0.0159342439556472}
val loss 0.03603702937792538


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.025, epoch=45] 
val: 100%|██████████| 154/154 [00:15<00:00,  9.98it/s, loss=0.016] 


{'loss': 0.01596703342613678}
val loss 0.03597822785476616


100%|██████████| 230/230 [00:51<00:00,  4.43it/s, loss=0.0252, epoch=46]
val: 100%|██████████| 154/154 [00:15<00:00,  9.77it/s, loss=0.0162]


{'loss': 0.0162437491829253}
val loss 0.03660234335686549


 66%|██████▌   | 151/230 [00:34<00:17,  4.42it/s, loss=0.0251, epoch=47]

In [12]:
np.mean(run_results)

0.009279140784743602

In [None]:
val_loss1

In [None]:
finetune_head = True
exp_id = None
model, optimizer, scheduler = get_model(exp_id, n_epochs, finetune_head, lr, weight_decay)
val_loss2 = train(model, train_loader, test_loader, optimizer, scheduler, n_epochs)

In [None]:
val_loss2

### Finetuning

In [None]:
finetune_head = False
exp_id = 'z52pya7i'
model, optimizer, scheduler = get_model(exp_id, n_epochs, finetune_head, lr, weight_decay)
val_loss3 = train(model, train_loader, test_loader, optimizer, scheduler, n_epochs)

In [None]:
val_loss3

### From scratch

In [None]:
finetune_head = False
exp_id = None
model, optimizer, scheduler = get_model(exp_id, n_epochs, finetune_head, lr, weight_decay)
val_loss4 = train(model, train_loader, test_loader, optimizer, scheduler, 50)

In [None]:
val_loss4

## Semisupervised

In [None]:
test_ds = ABCDataset(dataset_path,
                     'test',
                     'distances',
                     transform=PointCloudNormalize('box'))

test_loader = DataLoader(test_ds, shuffle=False, batch_size=32)

In [None]:
finetune_head = False
exp_id = 'z52pya7i'
n_epochs = 200
run_results = []

for i, seed in enumerate([24234, 23214, 64645]):
    train_ds = ABCDataset(dataset_path,
                          'train',
                          'distances',
                          transform=PointCloudNormalize('box'), sample_frac=0.01, seed=seed)

    train_loader = DataLoader(train_ds, shuffle=True, batch_size=50)
    model, optimizer, scheduler = get_model(exp_id, n_epochs, finetune_head, lr, weight_decay)
    val_loss4 = train(model, train_loader, test_loader, optimizer, scheduler, n_epochs, val_every=50)
    print(f'Run {i}: {val_loss4[-1]}')
    run_results.append(val_loss4[-1])
    
np.mean(run_results), np.std(run_results, ddof=1)

In [None]:
np.mean(run_results), np.std(run_results, ddof=1)

In [None]:
finetune_head = False
exp_id = None
n_epochs = 200
run_results = []

for i, seed in enumerate([24234, 23214, 64645]):
    train_ds = ABCDataset(dataset_path,
                          'train',
                          'distances',
                          transform=PointCloudNormalize('box'), sample_frac=0.01, seed=seed)

    train_loader = DataLoader(train_ds, shuffle=True, batch_size=50)

    model, optimizer, scheduler = get_model(exp_id, n_epochs, finetune_head, lr, weight_decay)
    val_loss4 = train(model, train_loader, test_loader, optimizer, scheduler, n_epochs, val_every=50)
    print(f'Run {i}: {val_loss4[-1]}')
    run_results.append(val_loss4[-1])

In [None]:
np.mean(run_results), np.std(run_results, ddof=1)

In [None]:
finetune_head = False
exp_id = 'z52pya7i'
n_epochs = 200
run_results = []

for i, seed in enumerate([24234, 23214, 64645]):
    train_ds = ABCDataset(dataset_path,
                          'train',
                          'distances',
                          transform=PointCloudNormalize('box'), sample_frac=0.05, seed=seed)

    train_loader = DataLoader(train_ds, shuffle=True, batch_size=50)

    model, optimizer, scheduler = get_model(exp_id, n_epochs, finetune_head, lr, weight_decay)
    val_loss4 = train(model, train_loader, test_loader, optimizer, scheduler, n_epochs, val_every=50)
    print(f'Run {i}: {val_loss4[-1]}')
    run_results.append(val_loss4[-1])

In [None]:
np.mean(run_results), np.std(run_results, ddof=1)

In [None]:
finetune_head = False
exp_id = None
n_epochs = 200
run_results = []

for i, seed in enumerate([24234, 23214, 64645]):
    train_ds = ABCDataset(dataset_path,
                          'train',
                          'distances',
                          transform=PointCloudNormalize('box'), sample_frac=0.05, seed=seed)

    train_loader = DataLoader(train_ds, shuffle=True, batch_size=50)

    model, optimizer, scheduler = get_model(exp_id, n_epochs, finetune_head, lr, weight_decay)
    val_loss4 = train(model, train_loader, test_loader, optimizer, scheduler, n_epochs, val_every=50)
    print(f'Run {i}: {val_loss4[-1]}')
    run_results.append(val_loss4[-1])

In [None]:
np.mean(run_results), np.std(run_results, ddof=1)