In [1]:
import pickle
import numpy as np
import os
import torch
import json
import shlex
import pickle
import time
import subprocess
import numpy as np
import torch.utils.data as data
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import datasets.data_utils as d_utils
import pandas as pd

from sklearn.metrics import roc_auc_score,confusion_matrix
from sklearn.metrics import accuracy_score
from utils.util import AverageMeter, shapenetpart_metrics
from models.backbones import ResNet
from models.heads import ClassifierResNet, MultiPartSegHeadResNet, SceneSegHeadResNet
from models.losses import LabelSmoothingCrossEntropyLoss, MultiShapeCrossEntropy, MaskedCrossEntropy
from utils.config import config, update_config
from utils.lr_scheduler import get_scheduler
from torchvision import transforms
from sklearn.metrics import roc_curve, auc
from IPython.display import clear_output
import pickle

def config_seting(cfg = 'cfgs/brain/brain.yaml'):
    update_config(cfg)
    return config


# Usefull functional for dataprocesing, train and validation

In [2]:

class BrainDataSeg():
    def __init__(self, data_type = 'train', num_points = 2048, 
                 transforms=None,
                 data_post = '',
                 datafolder = 'BrainData'):
        self.num_points = num_points
        self.transforms = transforms
        if data_type == 'test':
            filename = f'data/{datafolder}/test_data{data_post}.pkl'
        if data_type == 'train':
            filename = f'data/{datafolder}/trainval_data{data_post}.pkl'
        with open(filename, 'rb') as f:
            self.points, self.points_labels, self.labels = pickle.load(f)
        print(f"{filename} loaded successfully")

    def __getitem__(self, idx):
        current_points = self.points[idx]
        current_points_labels = self.points_labels[idx]
        cur_num_points = current_points.shape[0]
        if cur_num_points >= self.num_points:
            choice = np.random.choice(cur_num_points, self.num_points)
            current_points = current_points[choice, :]
            current_points_labels = current_points_labels[choice]
            mask = torch.ones(self.num_points).type(torch.int32)
        else:
            padding_num = self.num_points - cur_num_points
            shuffle_choice = np.random.permutation(np.arange(cur_num_points))
            padding_choice = np.random.choice(cur_num_points, padding_num)
            choice = np.hstack([shuffle_choice, padding_choice])
            current_points = current_points[choice, :]
            current_points_labels = current_points_labels[choice]
            mask = torch.cat([torch.ones(cur_num_points), torch.zeros(padding_num)]).type(torch.int32)
        if self.transforms is not None:
            current_points = self.transforms(current_points)
        label = torch.from_numpy(self.labels[idx]).type(torch.int64)
        current_points_labels = torch.from_numpy(current_points_labels).type(torch.int64)

        return current_points, mask, current_points_labels, label

    def __len__(self):
        return len(self.points)
def get_loader(num_points,batch_size = 16,data_post = '', datafolder = 'BrainData'):
    trans = transforms.Compose([d_utils.PointcloudToTensor()])
    train_dataset = BrainDataSeg(num_points=num_points,
                                 data_type = 'train',
                                 transforms = trans, 
                                 data_post = data_post,
                                 datafolder = datafolder
                                )
    test_dataset = BrainDataSeg(num_points=num_points,
                                data_type = 'test',
                                transforms = trans, 
                                data_post = data_post,
                               datafolder = datafolder)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=False,
                                               pin_memory=True,
                                               drop_last=True)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              drop_last=False)

    return train_loader, test_loader

class MultiPartSegmentationModel(nn.Module):
    def __init__(self, config, backbone, head, num_classes, num_parts,
                 input_features_dim, radius, sampleDl, nsamples, npoints,
                 width=144, depth=2, bottleneck_ratio=2):
        super(MultiPartSegmentationModel, self).__init__()
        if backbone == 'resnet':
            self.backbone = ResNet(config, input_features_dim, radius, sampleDl, nsamples, npoints,
                                   width=width, depth=depth, bottleneck_ratio=bottleneck_ratio)
        else:
            raise NotImplementedError(f"Backbone {backbone} not implemented in Multi-Part Segmentation Model")

        if head == 'resnet_part_seg':
            self.segmentation_head = MultiPartSegHeadResNet(num_classes, width, radius, nsamples, num_parts)
        else:
            raise NotImplementedError(f"Head {backbone} not implemented in Multi-Part Segmentation Model")

    def forward(self, xyz, mask, features):
        end_points = self.backbone(xyz, mask, features)
        return self.segmentation_head(end_points)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
                torch.nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)
                    
def build_multi_part_segmentation(config, weights = None):
    model = MultiPartSegmentationModel(config, config.backbone, config.head, config.num_classes, config.num_parts,
                                       config.input_features_dim,
                                       config.radius, config.sampleDl, config.nsamples, config.npoints,
                                       config.width, config.depth, config.bottleneck_ratio)
    criterion = MultiShapeCrossEntropy(config.num_classes, weights)
    return model, criterion
def Find_Optimal_Cutoff(target, predicted):
    """ Find the optimal probability cutoff point for a classification model related to event rate  """
    fpr, tpr, threshold = roc_curve(target, predicted)
    i = np.arange(len(tpr)) 
    roc = pd.DataFrame({'tf' : pd.Series(tpr-(1-fpr), index=i), 'threshold' : pd.Series(threshold, index=i)})
    roc_t = roc.iloc[(roc.tf-0).abs().argsort()[:1]]

    return list(roc_t['threshold'])

def train(epoch, train_loader, model, criterion, optimizer, scheduler, config):
    """ One epoch training """
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    end = time.time()
    pred_soft_flats = []
    points_labels_flats = []
    for idx, (points, mask, points_labels, shape_labels) in enumerate(train_loader):
        data_time.update(time.time() - end)
        bsz = points.size(0)
        # forward
        features = points
        features = features.transpose(1, 2).contiguous()

        points = points[:,:,:3].cuda(non_blocking=True)
        mask = mask.cuda(non_blocking=True)
        features = features.cuda(non_blocking=True)
        points_labels = points_labels.cuda(non_blocking=True)
        shape_labels = shape_labels.cuda(non_blocking=True)
        pred = model(points, mask, features)
        loss = criterion(pred, points_labels, shape_labels)
        
        m = torch.nn.Softmax(dim=1)
        pred_soft_flats += list(np.array(m(pred[0])[:,1,:].reshape(-1).detach().cpu()))
        points_labels_flats += list(np.array(points_labels.reshape(-1).detach().cpu()))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # update meters
        loss_meter.update(loss.item(), bsz)
        batch_time.update(time.time() - end)
        end = time.time()

    print('roc train',roc_auc_score(points_labels_flats,pred_soft_flats))
    opt = Find_Optimal_Cutoff(points_labels_flats, pred_soft_flats)
    print('optimal cut',opt)

    return loss_meter.avg, opt

def validate(epoch, test_loader, model, criterion, config, num_votes=10, is_conf = False):
    """ One epoch validating """
    batch_time = AverageMeter()
    losses = AverageMeter()

    model.eval()
    with torch.no_grad():
        all_logits = []
        all_points_labels = []
        all_shape_labels = []
        all_masks = []
        end = time.time()
        TS = d_utils.BatchPointcloudScaleAndJitter(scale_low=config.scale_low,
                                                   scale_high=config.scale_high,
                                                   std=config.noise_std,
                                                   clip=config.noise_clip)
        pred_soft_flats = []
        points_labels_flats = []
        for idx, (points_orig, mask, points_labels, shape_labels) in enumerate(test_loader):
            vote_logits = None
            vote_points_labels = None
            vote_shape_labels = None
            vote_masks = None
            for v in range(num_votes):
                batch_logits = []
                batch_points_labels = []
                batch_shape_labels = []
                batch_masks = []
                # augment for voting
                if v > 0:
                    points = TS(points_orig)
                else:
                    points = points_orig
                # forward
                features = points
                features = features.transpose(1, 2).contiguous()
                points = points[:,:,:3].cuda(non_blocking=True)
                mask = mask.cuda(non_blocking=True)
                features = features.cuda(non_blocking=True)
                points_labels = points_labels.cuda(non_blocking=True)
                shape_labels = shape_labels.cuda(non_blocking=True)

                pred = model(points, mask, features)
                loss = criterion(pred, points_labels, shape_labels)
                losses.update(loss.item(), points.size(0))
                
                m = torch.nn.Softmax(dim=1)
                pred_soft_flats += list(np.array(m(pred[0])[:,1,:].reshape(-1).detach().cpu()))
                points_labels_flats += list(np.array(points_labels.reshape(-1).detach().cpu()))
                

                # collect
                bsz = points.shape[0]
                for ib in range(bsz):
                    sl = shape_labels[ib]
                    logits = pred[sl][ib]
                    pl = points_labels[ib]
                    pmk = mask[ib]
                    batch_logits.append(logits.cpu().numpy())
                    batch_points_labels.append(pl.cpu().numpy())
                    batch_shape_labels.append(sl.cpu().numpy())
                    batch_masks.append(pmk.cpu().numpy().astype(np.bool))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if vote_logits is None:
                    vote_logits = batch_logits
                    vote_points_labels = batch_points_labels
                    vote_shape_labels = batch_shape_labels
                    vote_masks = batch_masks
                else:
                    for i in range(len(vote_logits)):
                        vote_logits[i] = vote_logits[i] + (batch_logits[i] - vote_logits[i]) / (v + 1)

            all_logits += vote_logits
            all_points_labels += vote_points_labels
            all_shape_labels += vote_shape_labels
            all_masks += vote_masks
            if idx % config.print_freq == 0:
                print(
                    f'V{num_votes} Test: [{idx}/{len(test_loader)}]\t'
                    f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    f'Loss {losses.val:.4f} ({losses.avg:.4f})')
                

        print('roc test',roc_auc_score(points_labels_flats,pred_soft_flats))
            
        if is_conf:
            acc, shape_ious, msIoU, mIoU, confs = shapenetpart_metrics(config.num_classes,
                                                            config.num_parts,
                                                            all_shape_labels,
                                                            all_logits,
                                                            all_points_labels,
                                                            all_masks, is_conf = is_conf)
        else:
            acc, shape_ious, msIoU, mIoU = shapenetpart_metrics(config.num_classes,
                                                            config.num_parts,
                                                            all_shape_labels,
                                                            all_logits,
                                                            all_points_labels,
                                                            all_masks, is_conf = is_conf)
            
    if is_conf:
        return losses.avg,acc, msIoU, mIoU, confs
    else:
        return losses.avg,acc, msIoU, mIoU

# Hyperparameters

In [3]:
WEIGHTS = [1,167]
CFG = 'cfgs/brain/brain4exp.yaml'
IS_CONF = True
DATAFOLDER = 'BrainData'

# Data processing, train and validation

In [4]:
config = config_seting(CFG)

  exp_config = edict(yaml.load(f))


In [5]:
model_, criterion = build_multi_part_segmentation(config, WEIGHTS)
criterion.cuda()

MultiShapeCrossEntropy()

In [11]:
torch.cuda.empty_cache()

In [8]:
res_dict = {}
for e in range(15):
    DATA_POSTFIX = f'_4exp_{e}'
    train_loader, test_loader = get_loader(num_points = config.num_points, 
                                       data_post = DATA_POSTFIX, 
                                       datafolder = DATAFOLDER )
    
    
    n_data = len(train_loader.dataset)
    print(f"length of training dataset: {n_data}")
    n_data = len(test_loader.dataset)
    print(f"length of testing dataset: {n_data}")
    
    model = model_
    model.cuda()
    if config.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=config.batch_size * dist.get_world_size() / 16 * config.base_learning_rate,
                                    momentum=config.momentum,
                                    weight_decay=config.weight_decay)
    elif config.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config.base_learning_rate,
                                     weight_decay=config.weight_decay)
    elif config.optimizer == 'adamW':
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=config.base_learning_rate,
                                      weight_decay=config.weight_decay)
    else:
        raise NotImplementedError(f"Optimizer {config.optimizer} not supported")
    scheduler = get_scheduler(optimizer, len(train_loader), config)
    
    res_dict[e] = {}
    train_losses = []
    test_iou = []
    test_losses = []

    for epoch in range(config.start_epoch, config.epochs + 1):

        tic = time.time()
        loss, opt = train(epoch, train_loader, model, criterion, optimizer, scheduler, config)

        print('epoch {}, total time {:.2f}, lr {:.5f}'.format(epoch,
                                                                    (time.time() - tic),
                                                                    optimizer.param_groups[0]['lr']))
        tmp = validate(epoch, test_loader, model, criterion, config, num_votes=1, is_conf = IS_CONF)
        if IS_CONF:
            loss_test,acc, msIoU, mIoU,confs = tmp
        else:
            loss_test,acc, msIoU, mIoU = tmp


        print('test_loss', loss_test)
        print('ins_loss', loss)
        print('test_acc', acc)
        print('msIoU', msIoU)
        print('mIoU', mIoU)
        if IS_CONF:
            print(confs)
        test_losses.append(loss_test)
        train_losses.append(loss)
        test_iou.append(msIoU)

        res_dict[e] = {'test_loss':test_losses,'train_loss':train_losses,'iou':test_iou}
        with open('4_exp_output.pkl', 'wb') as handle:
            pickle.dump(res_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

        with open('4_exp_output.pkl', 'rb') as handle:
            res_dict = pickle.load(handle)
        
        if epoch > 1:
            if res_dict[e]['test_loss'][-1] < np.min(res_dict[e]['test_loss'][:-1]):
                torch.save(model.state_dict(), f'4_exp_{e}_model.pkl')

data/BrainData/trainval_data_4exp_0.pkl loaded successfully
data/BrainData/test_data_4exp_0.pkl loaded successfully
length of training dataset: 140
length of testing dataset: 10
roc train 0.8618589526294252
optimal cut [0.5462836027145386]
epoch 1, total time 27.72, lr 0.00199
V1 Test: [0/1]	Time 2.031 (2.031)	Loss 0.6985 (0.6985)
roc test 0.7905033026965896
test_loss 0.6984949707984924
ins_loss 0.5162891522049904
test_acc 0.006298828125
msIoU 0.006298828124692439
mIoU 0.006298828124692439
[[     0 203510]
 [     0   1290]]
roc train 0.9314778995217297
optimal cut [0.5272753238677979]
epoch 2, total time 27.36, lr 0.00196
V1 Test: [0/1]	Time 1.977 (1.977)	Loss 0.6857 (0.6857)
roc test 0.892115275621501
test_loss 0.6857215166091919
ins_loss 0.3486314043402672
test_acc 0.006201171875
msIoU 0.006201171874697208
mIoU 0.006201171874697208
[[     0 203530]
 [     0   1270]]
roc train 0.9525862244266041
optimal cut [0.5705159306526184]
epoch 3, total time 30.09, lr 0.00190
V1 Test: [0/1]	Time

test_loss 1.3914514780044556
ins_loss 0.18140088580548763
test_acc 0.9830322265625
msIoU 0.003940362083130604
mIoU 0.003940362083130604
[[201288   2402]
 [  1073     37]]
roc train 0.9725846620725492
optimal cut [0.5335146188735962]
epoch 4, total time 28.03, lr 0.00183
V1 Test: [0/1]	Time 3.263 (3.263)	Loss 2.0807 (2.0807)
roc test 0.8945453026092864
test_loss 2.0807361602783203
ins_loss 0.19241540879011154
test_acc 0.98462890625
msIoU 0.0
mIoU 0.0
[[201652   2002]
 [  1146      0]]
roc train 0.9764379727020285
optimal cut [0.5027206540107727]
epoch 5, total time 27.55, lr 0.00173
V1 Test: [0/1]	Time 1.962 (1.962)	Loss 1.4088 (1.4088)
roc test 0.9150443237595184
test_loss 1.4088284969329834
ins_loss 0.17121003940701485
test_acc 0.975048828125
msIoU 0.007578676937971306
mIoU 0.007578676937971306
[[199572   4053]
 [  1057    118]]
roc train 0.9791428363842218
optimal cut [0.5334902405738831]
epoch 6, total time 27.26, lr 0.00162
V1 Test: [0/1]	Time 1.944 (1.944)	Loss 1.1659 (1.1659)
roc

test_loss 1.197222113609314
ins_loss 0.16214179247617722
test_acc 0.9559814453125
msIoU 0.01958521495061618
mIoU 0.01958521495061618
[[195492   7813]
 [  1202    293]]
roc train 0.9840954518221131
optimal cut [0.4708273410797119]
epoch 7, total time 39.83, lr 0.00150
V1 Test: [0/1]	Time 1.935 (1.935)	Loss 1.0355 (1.0355)
roc test 0.8920473384646281
test_loss 1.0355199575424194
ins_loss 0.14775159023702145
test_acc 0.9287109375
msIoU 0.022688957820034848
mIoU 0.022688957820034848
[[189769  13584]
 [  1016    431]]
roc train 0.9861577629808581
optimal cut [0.5013614892959595]
epoch 8, total time 32.76, lr 0.00137
V1 Test: [0/1]	Time 1.909 (1.909)	Loss 1.4250 (1.4250)
roc test 0.8602334693483485
test_loss 1.425013780593872
ins_loss 0.13843812234699726
test_acc 0.93861328125
msIoU 0.004416877632536514
mIoU 0.004416877632536514
[[192166  11225]
 [  1347     62]]
roc train 0.9884366378176178
optimal cut [0.5396597981452942]
epoch 9, total time 31.19, lr 0.00122
V1 Test: [0/1]	Time 1.915 (1.9

roc train 0.9880664779527214
optimal cut [0.5548218488693237]
epoch 10, total time 32.44, lr 0.00108
V1 Test: [0/1]	Time 1.963 (1.963)	Loss 0.2302 (0.2302)
roc test 0.8839763259657264
test_loss 0.23015077412128448
ins_loss 0.11912655271589756
test_acc 0.937314453125
msIoU 0.0015340894500274834
mIoU 0.0015340894500274834
[[191939  12786]
 [    52     23]]
roc train 0.991092174475103
optimal cut [0.599894106388092]
epoch 11, total time 30.07, lr 0.00093
V1 Test: [0/1]	Time 2.029 (2.029)	Loss 0.2362 (0.2362)
roc test 0.8989489186877854
test_loss 0.23623883724212646
ins_loss 0.10084963869303465
test_acc 0.9698583984375
msIoU 0.0006652946344297646
mIoU 0.0006652946344297646
[[198620   6103]
 [    70      7]]
roc train 0.9915477903337974
optimal cut [0.5641732811927795]
epoch 12, total time 27.36, lr 0.00078
V1 Test: [0/1]	Time 2.007 (2.007)	Loss 0.2309 (0.2309)
roc test 0.8980576067330203
test_loss 0.23093119263648987
ins_loss 0.10012912284582853
test_acc 0.9594921875
msIoU 0.00034364261138

roc train 0.9923243487384897
optimal cut [0.5673220753669739]
epoch 14, total time 28.38, lr 0.00050
V1 Test: [0/1]	Time 1.913 (1.913)	Loss 0.2494 (0.2494)
roc test 0.8268707203907204
test_loss 0.24937529861927032
ins_loss 0.09123507514595985
test_acc 0.9766748046875
msIoU 0.0
mIoU 0.0
[[200023   4727]
 [    50      0]]
roc train 0.9937527047626954
optimal cut [0.6530836820602417]
epoch 15, total time 26.60, lr 0.00038
V1 Test: [0/1]	Time 1.903 (1.903)	Loss 0.3124 (0.3124)
roc test 0.7966068581562203
test_loss 0.31237202882766724
ins_loss 0.07879805937409401
test_acc 0.9758837890625
msIoU 0.0
mIoU 0.0
[[199861   4880]
 [    59      0]]
roc train 0.9941518972569019
optimal cut [0.6341311931610107]
epoch 16, total time 26.46, lr 0.00027
V1 Test: [0/1]	Time 1.854 (1.854)	Loss 0.2800 (0.2800)
roc test 0.7514468864468864
test_loss 0.27997514605522156
ins_loss 0.07731112651526928
test_acc 0.97865234375
msIoU 0.0
mIoU 0.0
[[200428   4322]
 [    50      0]]
roc train 0.9932749295469437
optimal

V1 Test: [0/1]	Time 1.886 (1.886)	Loss 0.1288 (0.1288)
roc test 0.9877283408236431
test_loss 0.1288488358259201
ins_loss 0.07460298109799623
test_acc 0.921630859375
msIoU 0.06492923574767968
mIoU 0.06492923574767968
[[187683  16037]
 [    13   1067]]
roc train 0.9957753955562614
optimal cut [0.5923251509666443]
epoch 18, total time 27.02, lr 0.00010
V1 Test: [0/1]	Time 1.870 (1.870)	Loss 0.1549 (0.1549)
roc test 0.9818826953532562
test_loss 0.15494288504123688
ins_loss 0.06714335829019547
test_acc 0.9231005859375
msIoU 0.06430662725908756
mIoU 0.06430662725908756
[[187996  15703]
 [    46   1055]]
roc train 0.9947724802162353
optimal cut [0.5579079985618591]
epoch 19, total time 26.63, lr 0.00005
V1 Test: [0/1]	Time 1.873 (1.873)	Loss 0.1446 (0.1446)
roc test 0.9823469314704958
test_loss 0.14457659423351288
ins_loss 0.07102936552837491
test_acc 0.9191064453125
msIoU 0.06470484253216925
mIoU 0.06470484253216925
[[187140  16554]
 [    13   1093]]
roc train 0.9949511102356853
optimal cut 

optimal cut [0.6068517565727234]
epoch 20, total time 26.33, lr 0.00001
V1 Test: [0/1]	Time 1.869 (1.869)	Loss 0.0246 (0.0246)
roc test 0.9994240322302859
test_loss 0.02462790347635746
ins_loss 0.06705457530915737
test_acc 0.9940185546875
msIoU 0.45014379625148704
mIoU 0.45014379625148704
[[202996   1224]
 [     1    579]]
data/BrainData/trainval_data_4exp_7.pkl loaded successfully
data/BrainData/test_data_4exp_7.pkl loaded successfully
length of training dataset: 140
length of testing dataset: 10
roc train 0.9917319392568889
optimal cut [0.4917180836200714]
epoch 1, total time 26.79, lr 0.00199
V1 Test: [0/1]	Time 1.950 (1.950)	Loss 0.0469 (0.0469)
roc test 0.9981597379997419
test_loss 0.04685313627123833
ins_loss 0.09113972540944815
test_acc 0.977490234375
msIoU 0.12652248112233946
mIoU 0.12652248112233946
[[199637   4610]
 [     0    553]]
roc train 0.9923682566717489
optimal cut [0.4780583679676056]
epoch 2, total time 27.23, lr 0.00196
V1 Test: [0/1]	Time 1.948 (1.948)	Loss 0.0904

KeyboardInterrupt: 