# Deep Learning Project

## **1.** Environment setup

In [None]:
!pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116
!pip install pyspark
!pip install sparktorch 
!pip install gdown 
!pip install torchvision
!pip install pyarrow

In [None]:
import json
import os
import time
from itertools import product, chain

import numpy as np

import torch
import torch.optim as optim
from torch.nn import TripletMarginLoss
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader

import pyspark
import pyspark.sql.functions as F

from pyspark.sql import *
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark import SparkContext, SparkConf
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.linalg import Vectors, VectorUDT

from sparktorch import (SparkTorch, serialize_torch_obj,
                        serialize_torch_obj_lazy)

from models.utils import *
from models.loss import PointNetLoss
from models.transformation import pointnet_train_transforms, pointnet_default_transform

from sklearn.metrics import roc_curve, auc, roc_auc_score

### **1.1** Parameters

In [None]:
# Device: request the GPU; the device cell falls back to CPU when CUDA is unavailable
USE_GPU = True

# Hyperparameters
LEARNING_RATE = 0.00025  # Adam learning rate
WEIGHT_DECAY = 0.001     # Adam weight decay
NUM_POINTS = 2048        # passed to PointcloudAutoencoder
NUM_EPOCHS = 200         # number of epochs given to train_classifier
BATCH_SIZE = 32          # batch size of every DataLoader
NUM_CLASSES = 10         # passed to PointCloudData and the state loaders

# Reproducibility: seed handed to set_seed
RANDOM_SEED = 42

# Spark (the float literals are cast to int when the SparkConf is built)
SPARK_MAX_RECORDS_PER_BATCH = 1e3  # spark.sql.execution.arrow.maxRecordsPerBatch
SPARK_MAX_PARTITION_BYTES = 1e8    # spark.sql.files.maxPartitionBytes
SPARK_NUM_CORES = 4                # local master cores and SparkTorch partitions

# Dataset: folder handed to download_dataset
DATASET_FOLDER = "data"

# Model: when True the stored state is loaded instead of training from scratch
USE_TRAINED_MODEL = True

In [None]:
# Use the first CUDA device when requested and available; otherwise run on the CPU.
device = torch.device('cuda:0' if USE_GPU and torch.cuda.is_available() else 'cpu')

In [None]:
print('Using device:', device)
print()

# Release cached allocator blocks that may be left over from a previous run
torch.cuda.empty_cache()

# Extra diagnostics are only meaningful on a CUDA device
if device.type == 'cuda':
    bytes_per_gib = 1024**3
    allocated_gib = torch.cuda.memory_allocated(0)/bytes_per_gib
    reserved_gib = torch.cuda.memory_reserved(0)/bytes_per_gib

    print('Device:', torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', allocated_gib, 'GB')
    print('Cached:   ', reserved_gib, 'GB')

### **1.2** Reproducibility

In [None]:
# Make the run reproducible: enable deterministic mode, then seed with RANDOM_SEED
set_deterministic()
set_seed(RANDOM_SEED)

### **1.3** Create Spark context

In [None]:
# Build the Spark configuration. The master URL is set here, on the
# configuration itself: a master given to the session builder after a context
# already exists is ignored, which would leave SPARK_NUM_CORES without effect.
conf = SparkConf() \
    .setMaster("local[{}]".format(SPARK_NUM_CORES)) \
    .set("spark.ui.port", "4050") \
    .set('spark.executor.memory', '10G') \
    .set('spark.driver.memory', '10G') \
    .set('spark.driver.maxResultSize', '10G') \
    .set("spark.sql.execution.arrow.enabled", True) \
    .set("spark.sql.execution.arrow.maxRecordsPerBatch", int(SPARK_MAX_RECORDS_PER_BATCH)) \
    .set("spark.sql.files.maxPartitionBytes", int(SPARK_MAX_PARTITION_BYTES))

# create (or reuse) the session; getOrCreate keeps this cell re-runnable,
# whereas constructing a second SparkContext directly raises an error
spark = SparkSession.builder.config(conf=conf).getOrCreate()

# expose the underlying context under the name the rest of the notebook uses
sc = spark.sparkContext
sc.setLogLevel("ERROR")

In [None]:
# Display the active session object
spark

In [None]:
# List the effective configuration through the public accessor
# (avoids reaching into the private `_conf` attribute)
sc.getConf().getAll()

### **1.4** Data retrieval

In [None]:
# Fetch the dataset into DATASET_FOLDER via the project helper
print(f"Downloading dataset into {DATASET_FOLDER} folder...")
download_dataset(DATASET_FOLDER)

In [None]:
# Load the dataset through Spark and balance it by undersampling in one step
df = undersample(get_dataset(spark))

In [None]:
# One dataset per split; only the training split uses the training transforms
train_set = PointCloudData(df, num_classes=NUM_CLASSES, split='train', transform=pointnet_train_transforms())
test_set = PointCloudData(df, num_classes=NUM_CLASSES, split='test', transform=pointnet_default_transform())
val_set = PointCloudData(df, num_classes=NUM_CLASSES, split='val', transform=pointnet_default_transform())

# Pinned host memory only speeds up host-to-GPU copies; requesting it on a
# CPU-only run just produces a warning, so enable it only for CUDA.
loader_options = dict(batch_size=BATCH_SIZE, num_workers=0, pin_memory=(device.type == 'cuda'))

# Only the training loader shuffles
train_loader = DataLoader(dataset=train_set, shuffle=True, **loader_options)
test_loader = DataLoader(dataset=test_set, **loader_options)
val_loader = DataLoader(dataset=val_set, **loader_options)

In [None]:
# Report how many samples each split contains
split_report = (
    ("No. of training samples:", train_loader),
    ("No. of testing samples:", test_loader),
    ("No. of val samples:", val_loader),
)
for caption, loader in split_report:
    print(caption, len(loader.dataset))

### **1.5** Setup network

In [None]:
# Size the classifier from the classes present in the training set, then move it to the device
num_train_classes = len(train_loader.dataset.classes)
model = PointNet(num_train_classes).to(device)

In [None]:
# Adam over all model parameters, configured from the hyperparameter cell
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

## **2.** Train PointNet

In [None]:
def plot_training_loss(minibatch_loss_list, num_epochs, iter_per_epoch,
                       results_dir=None, averaging_iterations=100):
    """Plot the per-minibatch training loss together with its running average.

    Args:
        minibatch_loss_list: sequence of loss values, one per training iteration.
        num_epochs: number of epochs used to label the secondary (epoch) axis.
        iter_per_epoch: iterations per epoch; positions the epoch ticks.
        results_dir: when not None, the figure is saved there as
            'plot_training_loss.pdf'.
        averaging_iterations: window length of the running average.
    """
    plt.figure()
    loss_ax = plt.subplot(1, 1, 1)

    n_iterations = len(minibatch_loss_list)
    loss_ax.plot(range(n_iterations), minibatch_loss_list, label='Minibatch Loss')

    # For long runs, scale the y-axis from the losses after the first 1000
    # iterations so the early values do not dominate the plot.
    if n_iterations > 1000:
        y_top = np.max(minibatch_loss_list[1000:])*1.5
        loss_ax.set_ylim([0, y_top])
    loss_ax.set_xlabel('Iterations')
    loss_ax.set_ylabel('Loss')

    # Running average: convolution with a uniform window
    window = np.ones(averaging_iterations,)/averaging_iterations
    running_average = np.convolve(minibatch_loss_list, window, mode='valid')
    loss_ax.plot(running_average, label='Running Average')
    loss_ax.legend()

    # Second x-axis, drawn below the first one, showing epochs
    epoch_ax = loss_ax.twiny()
    epoch_labels = list(range(num_epochs+1))
    epoch_positions = [epoch*iter_per_epoch for epoch in epoch_labels]

    # only every 50th epoch gets a tick
    epoch_ax.set_xticks(epoch_positions[::50])
    epoch_ax.set_xticklabels(epoch_labels[::50])

    epoch_ax.xaxis.set_ticks_position('bottom')
    epoch_ax.xaxis.set_label_position('bottom')
    epoch_ax.spines['bottom'].set_position(('outward', 45))
    epoch_ax.set_xlabel('Epochs')
    epoch_ax.set_xlim(loss_ax.get_xlim())

    plt.tight_layout()

    if results_dir is None:
        return
    plt.savefig(os.path.join(results_dir, 'plot_training_loss.pdf'))

In [None]:
def pointnetloss(outputs, labels, m3x3, m64x64, alpha = 0.0001):
    """NLL classification loss plus an orthogonality penalty on the transform matrices.

    The penalty is the norm of (I - M @ M^T) for the 3x3 and the 64x64
    matrices, summed, scaled by `alpha` and divided by the batch size.

    Args:
        outputs: per-class scores fed to `torch.nn.NLLLoss`; first dim is the batch.
        labels: target class indices.
        m3x3: batch of 3x3 matrices, shape (batch, 3, 3).
        m64x64: batch of 64x64 matrices, shape (batch, 64, 64).
        alpha: weight of the orthogonality penalty.

    Returns:
        Scalar loss tensor.
    """
    criterion = torch.nn.NLLLoss()
    bs = outputs.size(0)

    # Identity constants are created directly on the device/dtype of the matrices
    # they are compared with (a bare `.cuda()` would pick the *current* device,
    # which need not be the one holding the inputs). They are plain constants,
    # so they do not need gradients; broadcasting handles the batch dimension.
    id3x3 = torch.eye(3, device=m3x3.device, dtype=m3x3.dtype)
    id64x64 = torch.eye(64, device=m64x64.device, dtype=m64x64.dtype)

    diff3x3 = id3x3 - torch.bmm(m3x3, m3x3.transpose(1, 2))
    diff64x64 = id64x64 - torch.bmm(m64x64, m64x64.transpose(1, 2))

    crit_loss = criterion(outputs, labels)
    return crit_loss + alpha * (torch.norm(diff3x3) + torch.norm(diff64x64)) / float(bs)

In [None]:
def compute_accuracy(model, data_loader, device):
    """Return the classification accuracy of `model` over `data_loader`, in percent.

    The model is switched to eval mode and is expected to return a 3-tuple whose
    first element holds the per-class scores. Each feature batch is moved to
    `device`, cast to float and transposed on dims 1 and 2 before the forward pass.

    Args:
        model: module returning (outputs, _, _).
        data_loader: iterable of (features, labels) batches.
        device: device the batches are moved to.

    Returns:
        0-dim float tensor with the accuracy in the range [0, 100].

    Raises:
        ValueError: if `data_loader` yields no samples (accuracy is undefined).
    """
    model.eval()
    correct_pred, num_examples = 0, 0
    with torch.no_grad():
        for features, labels in data_loader:
            features, labels = features.to(device).float(), labels.to(device)
            outputs, __, __ = model(features.transpose(1,2))
            # index of the highest score per sample
            _, predicted_labels = torch.max(outputs, 1)
            num_examples += labels.size(0)
            correct_pred += (predicted_labels == labels).sum()

    # Without this guard an empty loader fails with an obscure
    # "'int' object has no attribute 'float'" error.
    if num_examples == 0:
        raise ValueError("data_loader yielded no samples; accuracy is undefined")
    return correct_pred.float()/num_examples * 100

In [None]:
def compute_epoch_loss_classifier(model, data_loader, device):
    """Return the mean per-batch `pointnetloss` of `model` over `data_loader`.

    The model is switched to eval mode and evaluated without gradients. The
    result is the sum of the batch losses divided by the number of batches
    (every batch weighs the same, regardless of its size).

    Args:
        model: module returning (outputs, m3x3, m64x64).
        data_loader: sized iterable of (features, labels) batches.
        device: device the batches are moved to.

    Returns:
        0-dim tensor holding the averaged loss.

    Raises:
        ValueError: if `data_loader` has no batches.
    """
    model.eval()

    num_batches = len(data_loader)
    if num_batches == 0:
        # previously surfaced as a bare ZeroDivisionError
        raise ValueError("data_loader has no batches; the epoch loss is undefined")

    curr_loss = 0.
    with torch.no_grad():
        for features, labels in data_loader:
            features, labels = features.to(device).float(), labels.to(device)
            outputs, m3x3, m64x64 = model(features.transpose(1,2))
            curr_loss += pointnetloss(outputs, labels, m3x3, m64x64)

    return curr_loss / num_batches

In [None]:
def train_classifier(num_epochs, model, optimizer, device, 
                     train_loader, valid_loader=None, 
                     loss_fn=None, logging_interval=100, 
                     skip_epoch_stats=False):
    """Train `model` with `pointnetloss`, checkpointing after every epoch.

    After each epoch the model state, optimizer state and the log dictionary
    are written into the `state/` directory (created if missing).

    Args:
        num_epochs: number of passes over `train_loader`.
        model: module returning (outputs, m3x3, m64x64).
        optimizer: optimizer stepping the model parameters.
        device: device the batches are moved to.
        train_loader: loader of (features, labels) training batches; its
            dataset must expose a `classes` attribute (used in file names).
        valid_loader: optional loader evaluated after every epoch.
        loss_fn: accepted for interface compatibility but currently unused;
            the training step always uses `pointnetloss`.
        logging_interval: a progress line is printed every this many batches.
        skip_epoch_stats: when True, the per-epoch accuracy/loss evaluation
            is skipped.

    Returns:
        dict with the keys 'train_loss_per_batch', 'train_acc_per_epoch',
        'train_loss_per_epoch', 'valid_acc_per_epoch', 'valid_loss_per_epoch'.
    """
    log_dict = {'train_loss_per_batch': [],
                'train_acc_per_epoch': [],
                'train_loss_per_epoch': [],
                'valid_acc_per_epoch': [],
                'valid_loss_per_epoch': []}

    num_classes = len(train_loader.dataset.classes)

    # Create the checkpoint directory up front: otherwise the first save would
    # fail only after a complete epoch of training has already been spent.
    os.makedirs("state", exist_ok=True)

    start_time = time.time()
    for epoch in range(num_epochs):

        model.train()
        for batch_idx, (features, labels) in enumerate(train_loader):
            optimizer.zero_grad()

            features, labels = features.to(device).float(), labels.to(device)
            outputs, m3x3, m64x64 = model(features.transpose(1,2))
            loss = pointnetloss(outputs, labels, m3x3, m64x64)
            loss.backward()

            # UPDATE MODEL PARAMETERS
            optimizer.step()

            # LOGGING (convert to a Python float once, reuse for log and print)
            batch_loss = loss.item()
            log_dict['train_loss_per_batch'].append(batch_loss)

            if not batch_idx % logging_interval:
                print('Epoch: %03d/%03d | Batch %04d/%04d | Loss: %.4f'
                      % (epoch+1, num_epochs, batch_idx,
                          len(train_loader), batch_loss))

        if not skip_epoch_stats:
            model.eval()
            with torch.set_grad_enabled(False):  # save memory during inference
                # compute accuracy and loss on the training set
                train_loss = compute_epoch_loss_classifier(model, train_loader, device).item()
                train_acc = compute_accuracy(model, train_loader, device).item()
                log_dict['train_loss_per_epoch'].append(train_loss)
                log_dict['train_acc_per_epoch'].append(train_acc)

                print('***Epoch: %03d/%03d | Train. Acc.: %.3f%% | Loss: %.3f' % (epoch+1, num_epochs, train_acc, train_loss))

                if valid_loader is not None:
                    valid_loss = compute_epoch_loss_classifier(model, valid_loader, device).item()
                    log_dict['valid_loss_per_epoch'].append(valid_loss)

                    valid_acc = compute_accuracy(model, valid_loader, device).item()
                    log_dict['valid_acc_per_epoch'].append(valid_acc)

                    print('***Epoch: %03d/%03d | Valid. Acc.: %.3f%% | Loss: %.3f' % (epoch+1, num_epochs, valid_acc, valid_loss))

        # save state after every epoch
        # NOTE(review): the optimizer file name is spelled "poitnet"; it is kept
        # unchanged because the state loader is defined outside this notebook and
        # presumably reads that exact name - confirm before renaming.
        torch.save(model.state_dict(), f"state/pointnet_model_{num_classes}c.pt")
        torch.save(optimizer.state_dict(),f"state/poitnet_optimizer_{num_classes}c.pt")

        with open(f"state/log_dict_pointnet_{num_classes}c.json", "w") as f:
            json.dump(log_dict, f)
        
        print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

    print('Total training time: %.2f min' % ((time.time() - start_time)/60))
    return log_dict

In [None]:
if USE_TRAINED_MODEL:
    # Restore the stored state; the helper's return value is used as the training log
    log_dict = load_pointnet_state(model, NUM_CLASSES, optimizer)
else:
    # Train from scratch, validating on the validation split after every epoch
    log_dict = train_classifier(
        num_epochs=NUM_EPOCHS,
        model=model,
        optimizer=optimizer,
        device=device,
        train_loader=train_loader,
        valid_loader=val_loader,
        logging_interval=5,
        skip_epoch_stats=False,
    )

## **3.** Evaluate Model

### **3.1.** Loss

In [None]:
model.eval()

# Both the helper (PDF) and this cell (PNG) write into "output";
# create the folder first so saving cannot fail on a fresh checkout.
os.makedirs("output", exist_ok=True)

# The running average uses a window of one epoch (len(train_loader) iterations)
plot_training_loss(minibatch_loss_list=log_dict['train_loss_per_batch'],
                num_epochs=NUM_EPOCHS,
                iter_per_epoch=len(train_loader),
                results_dir="output",
                averaging_iterations=len(train_loader))
plt.savefig('output/plot_pointnet_training_loss.png')
plt.show()

### **3.2.** Accuracy

In [None]:
def plot_accuracy(train_acc_list, valid_acc_list, results_dir):
    """Plot training and validation accuracy per epoch on the current figure.

    Args:
        train_acc_list: training accuracy, one value per epoch.
        valid_acc_list: validation accuracy, one value per epoch.
        results_dir: when not None, the figure is saved there as
            'plot_acc_training_validation.pdf'.
    """
    # epochs are numbered from 1
    epochs = np.arange(1, len(train_acc_list)+1)

    curves = ((train_acc_list, 'Training'), (valid_acc_list, 'Validation'))
    for accuracies, curve_name in curves:
        plt.plot(epochs, accuracies, label=curve_name)

    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.tight_layout()

    if results_dir is None:
        return
    plt.savefig(os.path.join(results_dir, 'plot_acc_training_validation.pdf'))

# results_dir=None: the helper does not save; the figure is saved below,
# after the y-axis has been restricted to the 60-100 range.
plot_accuracy(train_acc_list=log_dict["train_acc_per_epoch"],
              valid_acc_list=log_dict["valid_acc_per_epoch"],
              results_dir=None)
plt.ylim([60, 100])
plt.savefig('output/pointnet_accuracy.png')
plt.show()

In [None]:
# Accuracy on the held-out test split (`device` is already a torch.device)
compute_accuracy(model, test_loader, device=device)

### **3.3.** Confusion Matrix

In [None]:
def compute_confusion_matrix(model, data_loader, device):
    """Build the confusion matrix of `model` over `data_loader`.

    Rows are true labels and columns are predicted labels. The label set is the
    sorted union of the labels that occur among targets and predictions; when
    only a single label occurs, the matrix is padded to 2x2 (with label 0, or
    with label 1 if the single label is 0).

    Args:
        model: module returning (outputs, _, _); it is switched to eval mode.
        data_loader: iterable of (features, labels) batches.
        device: device the batches are moved to.

    Returns:
        Integer ndarray of shape (n_labels, n_labels); an empty (0, 0) array
        when the loader yields nothing.
    """
    # Evaluate in inference mode, consistent with compute_accuracy
    model.eval()

    target_batches, prediction_batches = [], []
    with torch.no_grad():
        for features, label in data_loader:
            inputs, labels = features.to(device).float(), label.to(device)
            outputs, __, __ = model(inputs.transpose(1,2))
            _, predicted_labels = torch.max(outputs, 1)
            target_batches.append(labels.cpu())
            prediction_batches.append(predicted_labels.cpu())

    if not target_batches:
        return np.zeros((0, 0), dtype=np.int64)

    all_targets = torch.cat(target_batches).numpy()
    all_predictions = torch.cat(prediction_batches).numpy()

    class_labels = np.unique(np.concatenate((all_targets, all_predictions)))
    if class_labels.shape[0] == 1:
        # pad a single observed label to a binary problem
        if class_labels[0] != 0:
            class_labels = np.array([0, class_labels[0]])
        else:
            class_labels = np.array([class_labels[0], 1])
    n_labels = class_labels.shape[0]

    # Map every label to its row/column index and count all pairs in one pass
    # (replaces a list.count() call per cell, which rescanned every sample).
    rows = np.searchsorted(class_labels, all_targets)
    cols = np.searchsorted(class_labels, all_predictions)
    mat = np.zeros((n_labels, n_labels), dtype=np.int64)
    np.add.at(mat, (rows, cols), 1)
    return mat

def plot_confusion_matrix(conf_mat,
                          hide_spines=False,
                          hide_ticks=False,
                          figsize=None,
                          cmap=None,
                          colorbar=False,
                          show_absolute=True,
                          show_normed=False,
                          class_names=None):
    """Draw a confusion matrix with per-cell annotations.

    Args:
        conf_mat: square integer matrix; rows are true labels, columns predictions.
        hide_spines: hide the axes spines.
        hide_ticks: remove all ticks.
        figsize: figure size; defaults to 1.25 inch per class in both directions.
        cmap: colormap; defaults to `plt.cm.Blues`.
        colorbar: add a colorbar.
        show_absolute: annotate the cells with absolute counts.
        show_normed: annotate the cells with row-normalised values (and colour
            the cells by them).
        class_names: optional iterable of tick labels, one per class.

    Returns:
        (fig, ax) of the created plot.

    Raises:
        AssertionError: if neither annotation kind is enabled, or if the number
            of class names does not match the matrix size.
    """
    if not (show_absolute or show_normed):
        raise AssertionError('Both show_absolute and show_normed are False')
    if class_names is not None:
        # accept any iterable (e.g. dict views) and make it indexable for the ticks
        class_names = list(class_names)
        if len(class_names) != len(conf_mat):
            raise AssertionError('len(class_names) should be equal to number of '
                                 'classes in the dataset')

    conf_mat = np.asarray(conf_mat)

    # Row-normalise; rows without any sample stay at 0 instead of producing NaN
    total_samples = conf_mat.sum(axis=1)[:, np.newaxis]
    normed_conf_mat = np.divide(conf_mat.astype('float'), total_samples,
                                out=np.zeros(conf_mat.shape, dtype='float'),
                                where=total_samples != 0)

    # The default size must be resolved *before* the figure is created,
    # otherwise it is computed but never applied.
    if figsize is None:
        figsize = (len(conf_mat)*1.25, len(conf_mat)*1.25)
    if cmap is None:
        cmap = plt.cm.Blues

    fig, ax = plt.subplots(figsize=figsize)
    ax.grid(False)

    if show_normed:
        matshow = ax.matshow(normed_conf_mat, cmap=cmap)
    else:
        matshow = ax.matshow(conf_mat, cmap=cmap)

    if colorbar:
        fig.colorbar(matshow)

    for i in range(conf_mat.shape[0]):
        for j in range(conf_mat.shape[1]):
            cell_text = ""
            if show_absolute:
                cell_text += format(conf_mat[i, j], 'd')
                if show_normed:
                    cell_text += "\n" + '('
                    cell_text += format(normed_conf_mat[i, j], '.2f') + ')'
            else:
                cell_text += format(normed_conf_mat[i, j], '.2f')
            # white text on cells holding more than half of their row
            ax.text(x=j,
                    y=i,
                    s=cell_text,
                    va='center',
                    ha='center',
                    color="white" if normed_conf_mat[i, j] > 0.5 else "black")
    
    if class_names is not None:
        tick_marks = np.arange(len(class_names))
        plt.xticks(tick_marks, class_names, rotation=90)
        plt.yticks(tick_marks, class_names)
        
    if hide_spines:
        for side in ('right', 'top', 'left', 'bottom'):
            ax.spines[side].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    if hide_ticks:
        ax.axes.get_yaxis().set_ticks([])
        ax.axes.get_xaxis().set_ticks([])

    plt.xlabel('predicted label')
    plt.ylabel('true label')
    return fig, ax

# Confusion matrix of the classifier on the test split
test_class_names = test_loader.dataset.id2label.values()
mat = compute_confusion_matrix(model=model, data_loader=test_loader, device=device)
plot_confusion_matrix(mat, class_names=test_class_names)
plt.savefig('output/plot_confusion_matrix.png')
plt.show()

### **3.4.** ROC Curve

In [None]:
def plot_roc_curve(model, data_loader, device=None):
    """Plot one-vs-rest ROC curves, one per output class of `model`.

    The per-class scores returned by the model are used directly as ranking
    scores for `roc_curve`. The number of curves is taken from the width of
    the model output.

    Args:
        model: module returning (outputs, _, _); it is switched to eval mode.
        data_loader: iterable of (features, labels) batches.
        device: device the batches are moved to. When None, the device of the
            model's first parameter is used (CPU if the model has no parameters).
    """
    model.eval()

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    score_batches, label_batches = [], []
    with torch.no_grad():
        for features, labels in data_loader:
            features = features.to(device).float()
            outputs, __, __ = model(features.transpose(1,2))
            score_batches.append(outputs.cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    y_test = np.concatenate(label_batches)
    y_score = np.concatenate(score_batches)
    num_classes = y_score.shape[1]

    fig, ax = plt.subplots()
    for class_idx in range(num_classes):
        # one-vs-rest: the current class is the positive class
        fpr, tpr, _ = roc_curve(y_test == class_idx, y_score[:, class_idx])
        # the label is attached to the line itself so the legend cannot drift
        ax.plot(fpr, tpr, label=f"Class {class_idx} (area = {auc(fpr, tpr):.4f})")

    # chance-level diagonal (no legend entry)
    ax.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
    ax.legend()
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('FPR')
    ax.set_ylabel('TPR')

# ROC curves of the classifier on the test split
plot_roc_curve(model, test_loader)


### **3.5.** Contrastive Learning

In [None]:
from torchvision import datasets, models, transforms
from models.transformation import ToTensor, FromFlattenToPointcloud, Normalize

class ApplyContrastiveLearning(object):
    """Transform that passes a single sample through an autoencoder.

    Args:
        autoencoder: callable applied to the batched, permuted sample.
        device: device the sample is moved to before the forward pass.
    """

    def __init__(self, autoencoder, device):
        self.autoencoder = autoencoder
        self.device = device

    def __call__(self, obj):
        """Return the autoencoder output for `obj` as a CPU tensor.

        `obj` gets a leading batch dimension, is moved to the device, cast to
        float and has its last two dims swapped before the forward pass; the
        first element of the result is returned.
        """
        features = obj.unsqueeze(dim=0).to(self.device).float()
        # Pure inference: without no_grad every transformed sample would carry
        # an autograd graph through the autoencoder.
        with torch.no_grad():
            autoencoded_obj = self.autoencoder(features.permute(0, 2, 1))
        return autoencoded_obj[0].cpu()

def autoencoder_transformer(device, contrastive_learning):
    """Build the transform pipeline that ends with a pretrained autoencoder.

    Args:
        device: device the autoencoder is placed on.
        contrastive_learning: forwarded as `contrastive` to
            `load_autoencoder_state` to select which stored state is loaded.

    Returns:
        A `transforms.Compose` applying FromFlattenToPointcloud, Normalize,
        ToTensor and finally the autoencoder.
    """
    autoencoder = PointcloudAutoencoder(NUM_POINTS)
    autoencoder.to(device)

    # restore the stored weights, then switch to inference mode
    load_autoencoder_state(autoencoder, num_classes=NUM_CLASSES, contrastive=contrastive_learning, device=device)
    autoencoder.eval()

    pipeline = [
        FromFlattenToPointcloud(),
        Normalize(),
        ToTensor(),
        ApplyContrastiveLearning(autoencoder, device),
    ]
    return transforms.Compose(pipeline)

# load dataset by the applied method
contrastive_dataset = PointCloudData(df, num_classes=NUM_CLASSES, split='test', transform=autoencoder_transformer(device, True))
chamfer_dataset = PointCloudData(df, num_classes=NUM_CLASSES, split='test', transform=autoencoder_transformer(device, False))

# setup data loader
contrastive_data_loader = DataLoader(dataset=contrastive_dataset, batch_size=BATCH_SIZE, num_workers=0, pin_memory=True)
chamfer_data_loader = DataLoader(dataset=chamfer_dataset, batch_size=BATCH_SIZE, num_workers=0, pin_memory=True)
normal_data_loader = test_loader

In [None]:
model.eval()

# Accuracy on the baseline, chamfer and contrastive loaders, in that order
for loader in (normal_data_loader, chamfer_data_loader, contrastive_data_loader):
    print(compute_accuracy(model, loader, device=device))

In [None]:
# One confusion matrix per loader: baseline, chamfer, contrastive
for loader in (normal_data_loader, chamfer_data_loader, contrastive_data_loader):
    mat = compute_confusion_matrix(model=model, data_loader=loader, device=device)
    plot_confusion_matrix(mat, class_names=loader.dataset.id2label.values())
    plt.show()

In [None]:
# ROC curves per loader: baseline, chamfer, contrastive
for loader in (normal_data_loader, chamfer_data_loader, contrastive_data_loader):
    plot_roc_curve(model, loader)


## **4.** SparkTorch Training

### **4.1.** Vectorize Features Column

In [None]:
# UDF turning the `features` sequence into a dense ML vector
seqAsVector = udf(lambda x: Vectors.dense(x), VectorUDT())

# withColumn appends the column on the first run and replaces it on later runs,
# so re-executing this cell does not create a duplicate column.
df = df.withColumn("vectorized_features", seqAsVector(F.col("features")))

### **4.2.** Build the PyTorch object

In [None]:
# create torch object
torch_obj = serialize_torch_obj_lazy(
    model=PointNet,
    criterion=PointNetLoss,
    optimizer=torch.optim.Adam,
    optimizer_params={'lr': LEARNING_RATE },
    model_parameters={ 'classes': len(train_loader.dataset.classes) }
)

In [None]:
# setup features
vector_assembler = VectorAssembler(inputCols=["vectorized_features"], outputCol="assembler_features")

In [None]:
# create spark model
spark_model = SparkTorch(
    inputCol='assembler_features',
    labelCol='class',
    predictionCol='predictions',
    torchObj=torch_obj,
    iters=10,
    verbose=1,
    miniBatch=32,
    partitions=SPARK_NUM_CORES,
    earlyStopPatience=20,
    validationPct=0,
    useVectorOut=True
)

In [None]:
# filter dataset
dataset = df.filter(df['split'] == 'train')

# embed class in df
mapping_expr = create_map([lit(x) for x in chain(*train_loader.dataset.cat2label.items())])

#lookup and replace 
dataset = dataset.withColumn('class', mapping_expr[df['label']])

# create dataset for training
spark_dataset = vector_assembler.transform(dataset)
spark_dataset = spark_dataset.select("assembler_features", "class")
spark_dataset.cache()
spark_dataset.show()

In [None]:
# Number of training rows per class
spark_dataset.groupBy("class").count().show()

In [None]:
# Fit the SparkTorch estimator on the cached training frame and extract the PyTorch model
pymodel = spark_model.fit(spark_dataset).getPytorchModel()

In [None]:
# Show the first training sample as a point cloud.
# The flat feature vector is reshaped to (points, 3); -1 infers the point count
# instead of hard-coding it, and the name no longer shadows the builtin `input`.
first = spark_dataset.first()
first_cloud = np.array(first.assembler_features.toArray()).reshape(-1, 3)
pcshow(*first_cloud.T)

In [None]:
pymodel.eval()

# Wrap the first sample's feature vector into a batch of one float tensor on the CPU
first_batch = np.array([ np.array(first.assembler_features) ])
first_tensor = torch.from_numpy(first_batch).to("cpu").float()

# Forward pass, then translate the highest-scoring class index into its label
outputs, __, __ = pymodel(first_tensor)
_, predicted_labels = torch.max(outputs.data, 1)
label = train_loader.dataset.id2label[predicted_labels[0].item()]

label