[![Binder](https://mybinder.org/badge_logo.svg)](https://lab.mlpack.org/v2/gh/mlpack/examples/master?urlpath=lab%2Ftree%2Fcontact_tracing_clustering_with_dbscan%2F.ipynb)

In [1]:
/**
 * @file contact-tracing-dbscan-cpp.ipynb
 *
 * A simple contact tracing method using DBSCAN.
 * 
 * Once a person is tested positive for the virus,
 * it is very important to identify others who may
 * have been infected by the diagnosed patients.
 * To identify the infected people, a process called
 * contact tracing is often used. In this example, we
 * applied DBSCAN to perform pseudo location-based
 * contact tracing using GPS.
 */

In [2]:
!wget -q https://lab.mlpack.org/data/contact-tracing.csv

In [3]:
#include <mlpack/xeus-cling.hpp>

#include <mlpack/core.hpp>
#include <mlpack/methods/dbscan/dbscan.hpp>

#include <sstream>

In [4]:
// Header files to create and show the plot.
#define WITHOUT_NUMPY 1
#include "matplotlibcpp.h"
#include "xwidgets/ximage.hpp"

#include "plot3d.hpp"

namespace plt = matplotlibcpp;

In [5]:
using namespace mlpack;

In [6]:
using namespace mlpack::dbscan;

In [7]:
using namespace mlpack::data;

In [8]:
// Load the pseudo location-based dataset for the contact tracing.
// The dataset has 4 columns: timestamp, latitude, longitude, id.
arma::mat input;
DatasetInfo info;
data::Load("contact-tracing.csv", input, info);

In [9]:
// Print the first ten columns of the input data.
std::cout << "timestamp\t"
          << "latitude\t"
          << "longitude\t"
          << "id\t" << std::endl;
std::cout << input.cols(0, 10).t() << std::endl;

timestamp	latitude	longitude	id	
   1.5934e+09   1.2880e+01   7.7785e+01            0
   1.5934e+09   1.2993e+01   7.7597e+01            0
   1.5934e+09   1.2976e+01   7.7464e+01            0
   1.5934e+09   1.2975e+01   7.7615e+01            0
   1.5934e+09   1.2998e+01   7.7706e+01            0
   1.5934e+09   1.3021e+01   7.7511e+01            0
   1.5934e+09   1.2993e+01   7.7647e+01            0
   1.5934e+09   1.3032e+01   7.7568e+01            0
   1.5934e+09   1.2940e+01   7.7641e+01            0
   1.5934e+09   1.2910e+01   7.7649e+01            0
   1.5934e+09   1.2984e+01   7.7455e+01            0



In [10]:
// Helper function to generate the data for the 3D plot.
void Data3DPlot(std::stringstream& xData,
                std::stringstream& yData,
                std::stringstream& time,
                std::stringstream& label,
                const std::vector<int>& filter)
{
    xData.clear();
    yData.clear();
    time.clear();
    label.clear();
    
    for (size_t i = 0; i < info.NumMappings(3); ++i)
    {
        if (filter.size() != 0 &&
            std::find(filter.begin(), filter.end(), i) == filter.end())
            continue;

        // Get the indices for the current label.
        arma::mat dataset = input.cols(arma::find(input.row(3) == (double) i));

        // Get the data for the indices.
        std::vector<double> t = arma::conv_to<std::vector<double>>::from(dataset.row(0));
        std::vector<double> x = arma::conv_to<std::vector<double>>::from(dataset.row(1));
        std::vector<double> y = arma::conv_to<std::vector<double>>::from(dataset.row(2));

        // Build the strings for the plot.
        label << info.UnmapString(i, 3);
        for (size_t j = 0; j < t.size(); ++j)
        {
            xData << x[j] << ";";
            yData << y[j] << ";";
            // Scale time to make the plot easier to read.
            time << t[j] / 1000 << ";";
        }

        // Prepare for the next row.
        xData << "\n";
        yData << "\n";
        time << "\n";
        label << "\n";
    }
}

In [11]:
// Plot ids with their latitudes and longitudes across the x-axis and y-axis respectively.
std::stringstream xData, yData, time, label;

std::vector<int> filter;
// Uncomment the lines below to filter for id 0 and 3.
// filter.push_back(0);
// filter.push_back(3);

Data3DPlot(xData, yData, time, label, filter);

Plot3D(xData.str(),
       yData.str(),
       time.str(),
       label.str(),
       "x",
       "y",
       "time",
       2, // Mode: 0 = line, 1 = scatter, 2 = line + scatter.
       "output.png",
       10, // Plot width.
       10); // Plot height.

auto im = xw::image_from_file("output.png").finalize();
im

A Jupyter widget

Plotting all ids can be confusing, so it might be useful to only plot certain ids.
See the comment above to filter and plot certain ids.

In [12]:
// Generate clusters, and identify the infections by filtering the data in the clusters.

// Radial distance of 6 feet in kilometers.
const double epsilon = 0.0018288;

// Perform Density-Based Spatial Clustering of Applications with Noise
// (DBSCAN).
//
// For more information checkout https://mlpack.org/doc/mlpack-git/doxygen/classmlpack_1_1dbscan_1_1DBSCAN.html
// or uncomment the line below.
// ?DBSCAN<>
DBSCAN<> model(epsilon, 2 /* Minimum number of points for each cluster. */);

// We only use the latitude and longitude attribute.
const arma::mat points = input.submat(
    1, 0, input.n_rows - 2 , input.n_cols - 1);

// Perform clustering using DBSCAN, an return the number of clusters. 
arma::Row<size_t> assignments;
const size_t numCluster = model.Cluster(points, assignments);

In [13]:
// The model was able to generate 29 clusters, out of which cluster
// 0 to cluster 29 represents data points with neighboring nodes.
std::cout << "Number of clusters: " << numCluster << std::endl;

Number of clusters: 29


In [14]:
// Plot cluster with their latitudes and longitudes across the x-axis and y-axis respectively.
plt::figure_size(800, 800);

for (size_t i = 0; i < numCluster; ++i)
{
    // Get the indices for the current label.
    arma::mat dataset = input.cols(arma::find(assignments == i));
    
    // Get the data for the indices.
    std::vector<double> x = arma::conv_to<std::vector<double>>::from(dataset.row(1));
    std::vector<double> y = arma::conv_to<std::vector<double>>::from(dataset.row(2));
    
    // Set the label for the legend.
    std::map<std::string, std::string> m;
    m.insert(std::pair<std::string, std::string>("label", std::to_string(i)));
    
    plt::scatter(x, y, 10, m);
}

plt::xlabel("X");
plt::ylabel("y");
plt::title("ids with their latitudes and longitudes");
plt::legend();

plt::save("./plot.png");
auto im = xw::image_from_file("plot.png").finalize();
im

A Jupyter widget

In [15]:
// Check for people who had been in contact with the infected patient.
void PrintInfected(const std::string& infected /* Infected id e.g. Judy. */,
                   DatasetInfo& info /* The dataset info object to map between ids and names. */,
                   const arma::Row<size_t>& assignments /* The generated cluster. */,
                   const size_t numCluster /* The number of found cluster. */)
{
    // Get id from name.
    double infectedId = info.MapString<double>(infected, 3);
    
    // Get infected clusters.
    arma::Mat<size_t> assignmentsTemp = assignments;
    arma::Mat<size_t> cluster = assignmentsTemp.cols(
        arma::find(input.row(3) == infectedId));
    
    // Filter out noise cluster.
    cluster = cluster.cols(arma::find(cluster <= numCluster));
    
    std::cout << "Infected: " << infected << std::endl;
    
    // Find all names that are in the same infected cluster.
    for (size_t c = 0; c < cluster.n_elem; ++c)
    {       
        arma::mat infectedIdsFromCluster = input.cols(
            arma::find(assignments == cluster(c)));

        if (infectedIdsFromCluster.n_cols <= 0)
            std::cout << "No people in the same cluster." << std::endl;
        else
            std::cout << "Maybe infected others in the cluster: ";
        
        for (size_t n = 0, g = 0; n < infectedIdsFromCluster.n_cols; ++n)
        {
            size_t id = infectedIdsFromCluster.col(n)(3);
            
            // Skip the name if it's the same as the infected person.
            if (info.UnmapString(id, 3) == infected)
                continue;

            if (g == 0)
                std::cout << info.UnmapString(id, 3);
            else
                std::cout << "," << info.UnmapString(id, 3);
            
            g++;
        }
        
        std::cout << std::endl;
    }
}

In [16]:
// Check for the people who might be potentially infected from the patient.
PrintInfected("Heidi", info, assignments, numCluster)

Infected: Heidi
Maybe infected others in the cluster: David
Maybe infected others in the cluster: Judy


In [17]:
// Plot the data for Carol, Frank and Grace, to check the contact over time.
std::vector<int> filterHeidiDavidJudy;
filterHeidiDavidJudy.push_back((int) info.MapString<double>("Heidi", 3));
filterHeidiDavidJudy.push_back((int) info.MapString<double>("David", 3));
filterHeidiDavidJudy.push_back((int) info.MapString<double>("Judy", 3));

std::stringstream xData, yData, time, label;
Data3DPlot(xData, yData, time, label, filterHeidiDavidJudy);

Plot3D(xData.str(),
       yData.str(),
       time.str(),
       label.str(),
       "x",
       "y",
       "time",
       2, // Mode: 0 = line, 1 = scatter, 2 = line + scatter.
       "contact-heidi-david-judy.png",
       10, // Plot width.
       10); // Plot height.

auto im = xw::image_from_file("contact-heidi-david-judy.png").finalize();
im

A Jupyter widget

In [18]:
// Check for the people who might be potentially infected from the patient.
PrintInfected("Alice", info, assignments, numCluster)

Infected: Alice
Maybe infected others in the cluster: Judy


In [19]:
// Plot the data for Alice and Judy, to check the contact over time.
std::vector<int> filterAliceJudy;
filterAliceJudy.push_back((int) info.MapString<double>("Alice", 3));
filterAliceJudy.push_back((int) info.MapString<double>("Judy", 3));

std::stringstream xData, yData, time, label;
Data3DPlot(xData, yData, time, label, filterAliceJudy);

Plot3D(xData.str(),
       yData.str(),
       time.str(),
       label.str(),
       "x",
       "y",
       "time",
       2, // Mode: 0 = line, 1 = scatter, 2 = line + scatter.
       "contact-alice-judy.png",
       10, // Plot width.
       10); // Plot height.

auto im = xw::image_from_file("contact-alice-judy.png").finalize();
im

A Jupyter widget

In [20]:
// Check for the people who might be potentially infected from the patient.
PrintInfected("David", info, assignments, numCluster)

Infected: David
Maybe infected others in the cluster: Heidi


In [21]:
// Check for the people who might be potentially infected from the patient.
PrintInfected("Judy", info, assignments, numCluster)

Infected: Judy
Maybe infected others in the cluster: Heidi
Maybe infected others in the cluster: Alice


In [22]:
// Check for the people who might be potentially infected from the patient.
PrintInfected("Carol", info, assignments, numCluster)

Infected: Carol
Maybe infected others in the cluster: Frank,Grace
