In [1]:
%load_ext autoreload
%autoreload 2

from os import sys
sys.path.insert(0, '..')

import torch
import os
import k3d
from tqdm.notebook import tqdm
from models.diffusion import *
from datasets.shapenet import *
from torch.utils.data import DataLoader
from models.latent_cond.models import *

In [2]:
import numpy as np
# Seed NumPy's global RNG so later NumPy-based random draws start from a fixed state.
np.random.seed(0)
# torch.use_deterministic_algorithms(True)
# Seed PyTorch's default generator. manual_seed returns that Generator object,
# which is why this cell echoes a `<torch._C.Generator ...>` repr as its output.
torch.manual_seed(34533)

<torch._C.Generator at 0x7f3030104790>

In [3]:
device = 'cuda:0'
# Rebuild the network and move it to the target device before loading weights.
# NOTE(review): 1000 is presumably the number of diffusion steps — confirm
# against DiffusionModel's signature.
model = DiffusionModel(NoisePredictor(3, residual=True),
                       1000,
                       time_embedding_dim=3).to(device)
# map_location remaps the stored tensors onto `device`, so the checkpoint loads
# even when it was saved from a different GPU index (or on a machine whose
# device layout differs from this one).
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
state = torch.load('model_299.pt', map_location=device)
model.load_state_dict(state['model'])
# Inference mode for the BatchNorm layers; being the last expression, this also
# prints the module tree as the cell output.
model.eval()

DiffusionModel(
  (extractor): NoisePredictor(
    (encoder): PointNetEncoder(
      (conv1): Conv1d(3, 128, kernel_size=(1,), stride=(1,))
      (conv2): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
      (conv3): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
      (conv4): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
      (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (fc1_m): Linear(in_features=512, out_features=256, bias=True)
      (fc2_m): Linear(in_features=256, out_features=128, bias=True)
      (fc3_m): Linear(in_features=128, out_features=512, bias=True)
      (fc_bn1_m): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stat

In [4]:
# Validation split of the 'airplane' class from the 'shapenetpart' dataset,
# requested with segmentation=True.
val_dataset = Dataset(
    '../../datasets',
    dataset_name='shapenetpart',
    class_choice='airplane',
    split='val',
    segmentation=True,
)
# Batches of 64 shapes, default (sequential) order.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=64,
)

In [5]:
# Only the first validation batch is used for the clustering experiment below.
batch = next(iter(val_loader))
# Swap the last two axes of the point tensor and move it to the GPU
# (presumably (B, N, 3) -> (B, 3, N) — confirm against the dataset).
x = batch[0].transpose(2, 1).to(device)
labels = batch[2]
timesteps = [50, 100, 300, 500, 700, 800]
# Both returned objects are indexed by timestep in the cells that follow
# (features[t], coords[t]).
features, coords = model.get_features(x, timesteps)

100%|██████████| 6/6 [00:04<00:00,  1.43it/s]


In [5]:
def interpolate(x, y, y_features):
    """Transfer features from support points ``y`` onto query points ``x``.

    For every query point the three support points with the smallest score are
    selected and their features are blended with normalised inverse-score
    weights. The score is the squared Euclidean distance raised to the power
    of two (i.e. distance to the fourth power).

    Args:
        x: query coordinates, shape (bs, dims, n_points).
        y: support coordinates, shape (bs, dims, m) with m >= 3.
        y_features: features attached to ``y``, shape (bs, channels, m).

    Returns:
        Tensor of shape (bs, channels, n_points) with the blended features.
    """
    n_neighbors = 3
    bs, _, n_points = x.shape
    channels = y_features.size(1)

    # ||x - y||^2 expanded as ||y||^2 - 2 x.y + ||x||^2  -> (bs, n_points, m)
    y_norm = y.pow(2).sum(dim=1, keepdim=True)
    cross = torch.bmm(x.transpose(2, 1), y)
    x_norm = x.pow(2).sum(dim=1).unsqueeze(2)
    sq_dists = y_norm - 2 * cross + x_norm

    nearest, idx = torch.topk(sq_dists.pow(2), n_neighbors, largest=False, sorted=False, dim=2)
    # The epsilon keeps the weight finite when a query coincides with a support point.
    inverse = 1 / (nearest + 1e-8)
    weights = inverse / inverse.sum(dim=2, keepdim=True)

    # idx: bs x n x 3 -> flattened and broadcast over channels for gather
    gather_idx = idx.view(bs, 1, -1).expand(-1, channels, -1)
    neighbours = torch.gather(y_features, 2, gather_idx).view(bs, channels, n_points, n_neighbors)

    return (neighbours * weights.unsqueeze(1)).sum(dim=3)

def combine_features(x, features, centroids):
    """Concatenate four feature levels into one per-point feature tensor.

    ``features[0]`` is used as-is. ``features[1]``, ``features[2]`` and
    ``features[3]`` are interpolated from ``centroids[0]``, ``centroids[1]``
    and ``centroids[2]`` respectively onto the points ``x``. All four tensors
    are then concatenated along dim 1.

    Args:
        x: point coordinates the coarser levels are interpolated onto.
        features: indexable holding at least four feature tensors.
        centroids: indexable holding at least three coordinate tensors.

    Returns:
        The concatenation (dim=1) of the four feature tensors.
    """
    upsampled = [
        interpolate(x, centroids[level - 1], features[level])
        for level in range(1, 4)
    ]
    return torch.cat([features[0], *upsampled], dim=1)

def center(points):
    """Centre a batch of point clouds and swap coordinate rows 1 and 2.

    The midpoint of the per-channel bounding interval (taken over dim 2) is
    subtracted, then channels 1 and 2 are exchanged. The input tensor is not
    modified.

    Args:
        points: tensor indexed as (batch, channel, point).

    Returns:
        A CPU tensor with the last two axes transposed, i.e. (batch, point, channel).
    """
    upper = points.max(dim=2).values
    lower = points.min(dim=2).values
    midpoint = (upper + lower) / 2

    # The subtraction allocates a new tensor, so the in-place swap below
    # never touches the caller's data.
    shifted = points - midpoint.unsqueeze(2)
    second = shifted[:, 1, :].clone()
    third = shifted[:, 2, :].clone()
    shifted[:, 1, :] = third
    shifted[:, 2, :] = second

    return shifted.cpu().transpose(2, 1)

In [7]:
# One combined per-point feature tensor per entry of `timesteps`, in the same order.
agg_features = [
    combine_features(x, features[t], coords[t])
    for t in timesteps
]

In [8]:
from sklearn.cluster import KMeans

In [9]:
def train_kmeans(features, n_clusters):
    """Cluster all points of a batch jointly with k-means on their features.

    Args:
        features: tensor of shape (batch_size, dim, n_pts).
        n_clusters: number of clusters passed to ``KMeans``.

    Returns:
        The cluster id of every point, reshaped to (batch_size, n_pts).
    """
    n_shapes, n_channels, n_points = features.shape
    # One row per point, (n_shapes * n_points, n_channels), moved to the CPU
    # before being handed to scikit-learn.
    samples = features.transpose(2, 1).reshape(-1, n_channels).cpu()
    clusterer = KMeans(n_clusters=n_clusters)
    assignments = clusterer.fit_predict(samples)
    return assignments.reshape(n_shapes, n_points)

In [10]:
# `agg_features` is ordered like `timesteps` ([50, 100, 300, 500, 700, 800]).
# Look each tensor up by timestep value instead of a hand-written position:
# the previous version used index 2 (timestep 300) for the "500" labels, so
# labels_500_* were silently clustered on the timestep-300 features.
feats_100 = agg_features[timesteps.index(100)]
feats_300 = agg_features[timesteps.index(300)]
feats_500 = agg_features[timesteps.index(500)]

# The call order is kept as before so the k-means runs consume the global
# NumPy RNG in the same sequence.
labels_100_3 = train_kmeans(feats_100, 3)
labels_300_3 = train_kmeans(feats_300, 3)
labels_500_3 = train_kmeans(feats_500, 3)

labels_100_5 = train_kmeans(feats_100, 5)
labels_300_5 = train_kmeans(feats_300, 5)
labels_500_5 = train_kmeans(feats_500, 5)

In [11]:
# Centred CPU copy of the batch with the last two axes swapped back, used for
# the k3d plots below.
x_centered = center(x)

In [17]:
# Render the first shape of the batch, coloured per point by `labels_500_5`.
k3d.points(x_centered[0], point_size=0.05, attribute=labels_500_5[0])

Output()

### Few-Shot Learning

In [6]:
import torch.nn as nn
from sklearn.metrics import jaccard_score


class MLPClassifier(nn.Module):
    """Three-layer MLP that maps feature vectors to class logits.

    The network is Linear -> ReLU -> BatchNorm1d repeated for hidden widths
    128 and 32, followed by a final Linear layer producing ``n_classes``
    logits. The modules live in ``self.layers`` (an ``nn.Sequential``) in
    exactly that order.
    """

    def __init__(self, in_channels, n_classes):
        super().__init__()
        hidden_widths = (128, 32)

        stack = []
        width_in = in_channels
        for width_out in hidden_widths:
            stack.append(nn.Linear(width_in, width_out))
            stack.append(nn.ReLU(inplace=True))
            stack.append(nn.BatchNorm1d(num_features=width_out))
            width_in = width_out
        stack.append(nn.Linear(width_in, n_classes))

        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the logits for a batch of feature vectors ``x``."""
        return self.layers(x)

def prepare_train_dataset(dataset, n_samples, timesteps):
    """Draw a random few-shot subset and pre-compute its per-point features.

    ``n_samples`` items are picked with ``torch.randperm``. Their points and
    labels are batched and moved to the global ``device``, and the global
    ``model`` extracts features for every requested timestep. Features are
    flattened to one row per point. The resulting dict is also written to
    ``dataset_<n_samples>`` in the working directory via ``torch.save``.

    Args:
        dataset: indexable whose items expose points at index 0 and labels at index 2.
        n_samples: number of items to draw.
        timesteps: iterable of timesteps passed to ``model.get_features``.

    Returns:
        ``{'features': {timestep: tensor}, 'labels': tensor}`` where both the
        feature tensors and the labels are flattened over their first two dims.
    """
    chosen = torch.randperm(len(dataset))[:n_samples].tolist()
    # Each item is fetched exactly once, so points and labels come from the same draw.
    samples = [dataset[i] for i in chosen]

    points = torch.stack([sample[0].t() for sample in samples], dim=0).to(device)
    labels = torch.stack([sample[2] for sample in samples], dim=0).to(device)

    features, coords = model.get_features(points, timesteps)
    agg_features = {
        t: combine_features(points, features[t], coords[t])
        .transpose(2, 1)
        .flatten(start_dim=0, end_dim=1)
        for t in timesteps
    }

    data = {
        'features': agg_features,
        'labels': labels.flatten(start_dim=0, end_dim=1),
    }
    torch.save(data, f'dataset_{n_samples}')

    return data

@torch.no_grad()
def validate(model, data, timestep, batch_size=128 * 2048, n_points=2048, n_classes=None):
    """Evaluate a per-point classifier, one shape at a time.

    Args:
        model: classifier mapping (n_points, channels) features to
            (n_points, n_classes) logits. It is switched to eval mode.
        data: dict with ``data['features'][timestep]`` holding flattened
            per-point features (one row per point) and ``data['labels']``
            holding the matching flattened labels.
        timestep: key selecting the feature tensor to evaluate.
        batch_size: unused; kept so existing callers keep working.
        n_points: number of points per shape (previously hard-coded to 2048).
        n_classes: number of classes to score. Defaults to the width of the
            classifier's logits (4 for the classifiers in this notebook).

    Returns:
        Tuple ``(mIoU, ious, preds, mean_loss)``: the mean of the per-class
        IoUs as a 0-dim CPU tensor, the list of per-class IoU tensors, the
        predicted class ids with shape (n_shapes, n_points), and the
        cross-entropy loss averaged over shapes as a Python float. IoUs are
        computed from intersections and unions accumulated over all shapes.

    Raises:
        ValueError: if the selected feature tensor contains no shapes.
    """
    criterion = nn.CrossEntropyLoss()
    model.eval()

    flat_features = data['features'][timestep]
    # The channel width is read from the tensor rather than hard-coded (was 1152).
    batches = flat_features.reshape(-1, n_points, flat_features.shape[-1])
    labels = data['labels'].reshape(-1, n_points)
    if len(batches) == 0:
        raise ValueError('validate() needs at least one shape to evaluate')

    running_loss = 0
    all_preds = []
    for x, l in tqdm(zip(batches, labels), total=len(batches)):
        logits = model(x)
        running_loss = running_loss + criterion(logits, l)
        all_preds.append(logits.argmax(dim=1))

    all_preds = torch.stack(all_preds, dim=0)
    if n_classes is None:
        n_classes = logits.shape[1]

    ious = []
    for cls in range(n_classes):
        pred_mask = all_preds == cls
        true_mask = labels == cls
        intersection = (pred_mask & true_mask).sum()
        union = (pred_mask | true_mask).sum()
        # The epsilon avoids 0/0 for classes that never occur.
        ious.append(intersection / (1e-8 + union))

    mean_iou = torch.stack(ious).mean().cpu()
    return mean_iou, ious, all_preds, float(running_loss) / len(batches)

def train_model(model, data, timestep, epoch_num, batch_size=128):
    """Train a per-point classifier on pre-computed features.

    Uses Adam (lr 3e-4, weight decay 1e-4) with an exponential learning-rate
    decay of 0.98 per epoch and cross-entropy loss. The points are reshuffled
    every epoch.

    Args:
        model: classifier to optimise in place; it is put into training mode.
        data: dict with ``data['features'][timestep]`` (one row per point) and
            ``data['labels']`` (one label per row).
        timestep: key selecting which feature tensor to train on.
        epoch_num: number of passes over the data.
        batch_size: number of points per optimisation step.

    Returns:
        None. Progress and the running mean loss are shown on a tqdm bar.
    """
    features = data['features'][timestep]
    labels = data['labels']

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
    criterion = nn.CrossEntropyLoss()

    # Make sure BatchNorm layers use and update batch statistics even if the
    # model was left in eval mode by an earlier validate() call.
    model.train()

    for epoch in range(epoch_num):
        perm = torch.randperm(len(features))
        batches = features[perm].split(batch_size, dim=0)
        l_batches = labels[perm].split(batch_size, dim=0)
        bar = tqdm(zip(batches, l_batches), total=len(batches))

        running_loss = 0.0
        n_steps = 0
        for x, l in bar:
            # A single-sample batch is skipped (presumably because BatchNorm1d
            # cannot compute batch statistics from one sample).
            if x.shape[0] == 1:
                continue

            optimizer.zero_grad()
            logits = model(x)
            loss_t = criterion(logits, l)
            loss_t.backward()
            optimizer.step()

            running_loss += loss_t.item()
            n_steps += 1
            # Average over the batches actually trained on, not the loop index.
            bar.set_postfix({
                'Epoch': epoch,
                'Loss': running_loss / n_steps
            })

        scheduler.step()

In [7]:
# Training split of the 'airplane' class from the 'shapenetpart' dataset,
# requested with segmentation=True. The few-shot subsets are drawn from it.
train_dataset = Dataset(
    '../../datasets',
    dataset_name='shapenetpart',
    class_choice='airplane',
    split='train',
    segmentation=True,
)

In [8]:
# Build few-shot training sets of 3, 10 and 20 shapes, each with features at
# timesteps 100, 300 and 500. Every call draws its own random subset with
# torch.randperm and also writes its result to `dataset_<n_samples>` in the
# working directory, so the order of these calls affects which shapes are drawn.
data_3samples = prepare_train_dataset(train_dataset, 3, [100, 300, 500])
data_10samples = prepare_train_dataset(train_dataset, 10, [100, 300, 500])
data_20samples = prepare_train_dataset(train_dataset, 20, [100, 300, 500])

100%|██████████| 3/3 [00:01<00:00,  1.67it/s]
100%|██████████| 3/3 [00:01<00:00,  2.07it/s]
100%|██████████| 3/3 [00:01<00:00,  1.89it/s]


In [9]:
# One classifier per feature timestep, all trained on the 3-shape set.
# MLPClassifier(1152, 4): 1152 input channels, 4 output classes.
# train_model(clf, data, timestep, 100): 100 epochs on that timestep's features.
# Each classifier is created right before it is trained; keep this interleaving,
# since both steps draw from the global torch RNG.
clf3_100 = MLPClassifier(1152, 4).to(device)
train_model(clf3_100, data_3samples, 100, 100)

clf3_300 = MLPClassifier(1152, 4).to(device)
train_model(clf3_300, data_3samples, 300, 100)

clf3_500 = MLPClassifier(1152, 4).to(device)
train_model(clf3_500, data_3samples, 500, 100)

100%|██████████| 48/48 [00:00<00:00, 140.01it/s, Epoch=0, Loss=0.63] 
100%|██████████| 48/48 [00:00<00:00, 146.53it/s, Epoch=1, Loss=0.389]
100%|██████████| 48/48 [00:00<00:00, 140.21it/s, Epoch=2, Loss=0.319]
100%|██████████| 48/48 [00:00<00:00, 136.86it/s, Epoch=3, Loss=0.275]
100%|██████████| 48/48 [00:00<00:00, 139.20it/s, Epoch=4, Loss=0.249]
100%|██████████| 48/48 [00:00<00:00, 160.02it/s, Epoch=5, Loss=0.233]
100%|██████████| 48/48 [00:00<00:00, 148.03it/s, Epoch=6, Loss=0.226]
100%|██████████| 48/48 [00:00<00:00, 205.62it/s, Epoch=7, Loss=0.217]
100%|██████████| 48/48 [00:00<00:00, 161.00it/s, Epoch=8, Loss=0.205]
100%|██████████| 48/48 [00:00<00:00, 160.61it/s, Epoch=9, Loss=0.186]
100%|██████████| 48/48 [00:00<00:00, 147.80it/s, Epoch=10, Loss=0.183]
100%|██████████| 48/48 [00:00<00:00, 143.57it/s, Epoch=11, Loss=0.181]
100%|██████████| 48/48 [00:00<00:00, 138.17it/s, Epoch=12, Loss=0.171]
100%|██████████| 48/48 [00:00<00:00, 173.32it/s, Epoch=13, Loss=0.17] 
100%|██████████|

In [10]:
# One classifier per feature timestep, all trained on the 10-shape set.
# MLPClassifier(1152, 4): 1152 input channels, 4 output classes.
# train_model(clf, data, timestep, 100): 100 epochs on that timestep's features.
# Each classifier is created right before it is trained; keep this interleaving,
# since both steps draw from the global torch RNG.
clf10_100 = MLPClassifier(1152, 4).to(device)
train_model(clf10_100, data_10samples, 100, 100)

clf10_300 = MLPClassifier(1152, 4).to(device)
train_model(clf10_300, data_10samples, 300, 100)

clf10_500 = MLPClassifier(1152, 4).to(device)
train_model(clf10_500, data_10samples, 500, 100)

100%|██████████| 160/160 [00:00<00:00, 161.41it/s, Epoch=0, Loss=0.475]
100%|██████████| 160/160 [00:01<00:00, 147.21it/s, Epoch=1, Loss=0.283]
100%|██████████| 160/160 [00:01<00:00, 148.97it/s, Epoch=2, Loss=0.24] 
100%|██████████| 160/160 [00:01<00:00, 158.27it/s, Epoch=3, Loss=0.214]
100%|██████████| 160/160 [00:01<00:00, 138.45it/s, Epoch=4, Loss=0.201]
100%|██████████| 160/160 [00:01<00:00, 157.77it/s, Epoch=5, Loss=0.184]
100%|██████████| 160/160 [00:01<00:00, 155.23it/s, Epoch=6, Loss=0.182]
100%|██████████| 160/160 [00:01<00:00, 138.71it/s, Epoch=7, Loss=0.172]
100%|██████████| 160/160 [00:01<00:00, 143.88it/s, Epoch=8, Loss=0.166]
100%|██████████| 160/160 [00:00<00:00, 171.59it/s, Epoch=9, Loss=0.159]
100%|██████████| 160/160 [00:01<00:00, 141.74it/s, Epoch=10, Loss=0.159]
100%|██████████| 160/160 [00:01<00:00, 155.86it/s, Epoch=11, Loss=0.153]
100%|██████████| 160/160 [00:01<00:00, 153.51it/s, Epoch=12, Loss=0.148]
100%|██████████| 160/160 [00:01<00:00, 141.26it/s, Epoch=13, 

In [11]:
# One classifier per feature timestep, all trained on the 20-shape set.
# MLPClassifier(1152, 4): 1152 input channels, 4 output classes.
# train_model(clf, data, timestep, 100): 100 epochs on that timestep's features.
# Each classifier is created right before it is trained; keep this interleaving,
# since both steps draw from the global torch RNG.
clf20_100 = MLPClassifier(1152, 4).to(device)
train_model(clf20_100, data_20samples, 100, 100)

clf20_300 = MLPClassifier(1152, 4).to(device)
train_model(clf20_300, data_20samples, 300, 100)

clf20_500 = MLPClassifier(1152, 4).to(device)
train_model(clf20_500, data_20samples, 500, 100)

100%|██████████| 320/320 [00:02<00:00, 149.52it/s, Epoch=0, Loss=0.444]
100%|██████████| 320/320 [00:02<00:00, 156.73it/s, Epoch=1, Loss=0.276]
100%|██████████| 320/320 [00:02<00:00, 147.88it/s, Epoch=2, Loss=0.24] 
100%|██████████| 320/320 [00:02<00:00, 156.12it/s, Epoch=3, Loss=0.218]
100%|██████████| 320/320 [00:02<00:00, 144.23it/s, Epoch=4, Loss=0.204]
100%|██████████| 320/320 [00:02<00:00, 154.13it/s, Epoch=5, Loss=0.195]
100%|██████████| 320/320 [00:02<00:00, 148.99it/s, Epoch=6, Loss=0.185]
100%|██████████| 320/320 [00:02<00:00, 152.05it/s, Epoch=7, Loss=0.178]
100%|██████████| 320/320 [00:02<00:00, 146.41it/s, Epoch=8, Loss=0.173]
100%|██████████| 320/320 [00:02<00:00, 143.28it/s, Epoch=9, Loss=0.166]
100%|██████████| 320/320 [00:02<00:00, 140.43it/s, Epoch=10, Loss=0.163]
100%|██████████| 320/320 [00:02<00:00, 147.32it/s, Epoch=11, Loss=0.162]
100%|██████████| 320/320 [00:02<00:00, 149.45it/s, Epoch=12, Loss=0.157]
100%|██████████| 320/320 [00:02<00:00, 149.96it/s, Epoch=13, 

In [12]:
# Re-create the validation split and wrap it in a loader with batches of 128
# shapes. This rebinds the `val_dataset` / `val_loader` names defined earlier in
# the notebook. drop_last=True makes the loader skip a final batch that has
# fewer than 128 shapes, so not every validation shape is necessarily evaluated.
val_dataset = Dataset('../../datasets', dataset_name='shapenetpart',
                      class_choice='airplane', split='val', segmentation=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=128, drop_last=True)

In [13]:
# Pre-compute per-point features and labels for the whole validation loader.
timesteps = [100, 300, 500]
agg_features = {t: [] for t in timesteps}
all_labels = []

for batch in val_loader:
    x = batch[0].transpose(2, 1).to(device)
    labels = batch[2]
    features, coords = model.get_features(x, timesteps)

    for t in timesteps:
        f = combine_features(x, features[t], coords[t])
        # one row per point: merge the batch and point axes
        agg_features[t].append(f.transpose(2, 1).flatten(start_dim=0, end_dim=1))

    all_labels.append(labels.flatten(start_dim=0, end_dim=1))

# Join the per-batch chunks into a single tensor per timestep.
agg_features = {
    t: torch.cat(chunks, dim=0)
    for t, chunks in agg_features.items()
}

val_data = {
    'features': agg_features,
    'labels': torch.cat(all_labels).to(device),
}

100%|██████████| 3/3 [00:03<00:00,  1.23s/it]
100%|██████████| 3/3 [00:03<00:00,  1.21s/it]
100%|██████████| 3/3 [00:03<00:00,  1.21s/it]


In [19]:
# Evaluate every (training-set size, timestep) classifier on the validation features.
timesteps = [100, 300, 500]
# The classifiers in each list are ordered like `timesteps`.
models = {
    3: [clf3_100, clf3_300, clf3_500],
    10: [clf10_100, clf10_300, clf10_500],
    20: [clf20_100, clf20_300, clf20_500],
}

table = []
predictions = {num_samples: [] for num_samples in models}

for num_samples, clfs in models.items():
    for clf, t in zip(clfs, timesteps):
        mIoU, ious, preds, loss = validate(clf, val_data, t)
        row = {
            'num_samples': num_samples,
            'timestep': t,
            'mIOU': mIoU.item(),
            'part1': ious[0].item(),
            'part2': ious[1].item(),
            'part3': ious[2].item(),
            'part4': ious[3].item(),
        }
        table.append(row)
        predictions[num_samples].append(preds)

384it [00:00, 737.20it/s]
384it [00:00, 700.58it/s]
384it [00:00, 712.85it/s]
384it [00:00, 642.43it/s]
384it [00:00, 701.38it/s]
384it [00:00, 693.99it/s]
384it [00:00, 718.98it/s]
384it [00:00, 677.20it/s]
384it [00:00, 767.36it/s]


In [17]:
! pip install pandas

Collecting pandas
  Downloading pandas-1.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.3 MB)
[K     |████████████████████████████████| 11.3 MB 1.4 MB/s eta 0:00:01
Installing collected packages: pandas
Successfully installed pandas-1.3.5


In [21]:
import pandas as pd
# Render floats with three decimals in the results table.
pd.set_option("display.precision", 3)
# One row per (num_samples, timestep) pair collected in `table`; being the last
# expression of the cell, the frame is displayed as its output.
pd.DataFrame(table)

Unnamed: 0,num_samples,timestep,mIOU,part1,part2,part3,part4
0,3,100,0.613,0.781,0.686,0.649,0.338
1,3,300,0.538,0.721,0.564,0.626,0.241
2,3,500,0.477,0.65,0.487,0.569,0.202
3,10,100,0.682,0.804,0.724,0.73,0.471
4,10,300,0.591,0.748,0.653,0.633,0.331
5,10,500,0.561,0.732,0.596,0.59,0.326
6,20,100,0.692,0.8,0.726,0.734,0.508
7,20,300,0.634,0.765,0.656,0.701,0.413
8,20,500,0.605,0.743,0.621,0.687,0.367


In [23]:
# Predictions of the 3-shape classifier for the first timestep in `timesteps`:
# one row of per-point class ids for every evaluated validation shape.
predictions[3][0].shape

torch.Size([384, 2048])

In [73]:
# Plot validation item 25, coloured by the parts predicted for row 25 by the
# 20-shape classifier of the first timestep (predictions[20][0]).
# NOTE(review): this assumes val_dataset[25] returns the same points the loader
# fed to the classifier — confirm the dataset applies no random augmentation
# or resampling on access.
k3d.points(center(val_dataset[25][0].t().unsqueeze(0))[0], point_size=0.05,
           attribute=predictions[20][0][25].cpu())

Output()