Add dataset domain identification protocols
jmccrae committed Mar 23, 2018
1 parent 17104c7 commit be79468
Showing 7 changed files with 236 additions and 5 deletions.
19 changes: 19 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -10,3 +10,4 @@ serde = "*"
serde_json = "*"
serde_derive = "*"
htmlescape = "*"
noisy_float = "*"
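
noisy_float supplies the ordered float wrapper (r64) used by the new src/ident.rs; a minimal sketch of why it is needed (plain f64 does not implement Ord, so it cannot be used as a max_by_key key):

use noisy_float::prelude::*;

// r64 panics on NaN, so the resulting R64 values have a total order (Ord)
// and can be used directly as a comparison key.
let best = vec![0.1f64, 0.7, 0.3].into_iter().max_by_key(|p| r64(*p));
assert_eq!(best, Some(0.7));
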
9 changes: 8 additions & 1 deletion src/data.rs
@@ -38,7 +38,8 @@ pub struct Dataset {
pub links : Vec<Link>,
pub identifier : String,
pub domain : String,
pub triples : IntLike
pub triples : IntLike,
pub keywords : Vec<String>
}

/// A link from a dataset to a target dataset
@@ -51,6 +51,12 @@ pub struct Link {
#[derive(Debug,Clone)]
pub struct IntLike(Option<i32>);

impl From<i32> for IntLike {
fn from(x : i32) -> Self {
IntLike(Some(x))
}
}

impl IntLike {
pub fn get(&self) -> i32 {
self.0.unwrap_or(0).clone()
20 changes: 17 additions & 3 deletions src/graph.rs
@@ -1,6 +1,6 @@
//! The graph is a set of vertices and links between these vertices
use data::Dataset;
use std::collections::HashMap;
use std::collections::{HashSet, HashMap};

/// The parameters of the model
#[derive(Default)]
@@ -302,14 +302,28 @@ fn relu(x : f64) -> f64 {
}
}


/// Build the graph from the dataset
pub fn build_graph(data : &HashMap<String, Dataset>) -> Graph {
let mut g = Graph::new();
let mut linked_datasets = HashSet::new();

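// First pass (below): record every dataset that takes part in a link whose
// source and target are both present in the data. Second pass: only those
// datasets become vertices, and only links between them become edges.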
for dataset in data.values() {
if !dataset.links.is_empty() {
let v1 = g.add_vertex(&dataset.identifier);
if dataset.links.iter().any(|d| data.contains_key(&d.target)) {
linked_datasets.insert(dataset.identifier.clone());
for link in dataset.links.iter() {
if data.contains_key(&link.target) {
linked_datasets.insert(link.target.clone());
}
}
}
}

for dataset in data.values() {
if linked_datasets.contains(&dataset.identifier) {
let v1 = g.add_vertex(&dataset.identifier);
for link in dataset.links.iter() {
if linked_datasets.contains(&link.target) {
let v2 = g.add_vertex(&link.target);
g.edges.push(Edge::new(v1,v2));
}
167 changes: 167 additions & 0 deletions src/ident.rs
@@ -0,0 +1,167 @@
//! Code used to identify the domains (bubble colours) of the datasets
use data::Dataset;
use std::collections::HashMap;
use noisy_float::prelude::*;

/// Find the domain of a dataset from its neighbours: the domain will be
/// set to the most frequent domain among the neighbours
pub fn domain_by_most_neighbours(datasets : &mut HashMap<String, Dataset>) {
let mut incoming = HashMap::new();
for (_, dataset) in datasets.iter() {
for link in dataset.links.iter() {
incoming.entry(link.target.clone())
.or_insert_with(|| Vec::new()).push(dataset.identifier.clone());
}
}
let mut ds2domain : HashMap<String, String> = datasets.iter().map(|k| {
(k.0.clone(), k.1.domain.clone())
}).collect();
let mut last_fails = -1;
let mut fails = 0;
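// Sweep repeatedly over the unlabelled datasets, assigning each one the most
// frequent domain among its already-labelled neighbours (counting both
// outgoing and incoming links); stop when a full sweep leaves the number of
// still-unlabelled datasets unchanged.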
while fails != last_fails {
last_fails = fails;
fails = 0;
for (_, dataset) in datasets.iter_mut() {
if dataset.domain == "" {
let mut counts = HashMap::new();
for link in dataset.links.iter() {
match ds2domain.get(&link.target) {
Some(d) if d != "" => {
let c : i32 = *counts.get(d).unwrap_or(&0);
counts.insert(d.clone(), c + 1);
},
_ => {}
}
}

let empty = Vec::new();
for link in incoming.get(&dataset.identifier).unwrap_or_else(|| &empty).iter() {
match ds2domain.get(link) {
Some(d) if d != "" => {
let c : i32 = *counts.get(d).unwrap_or(&0);
counts.insert(d.clone(), c + 1);
},
_ => {}
}
}

let mut best_domain = String::new();
let mut best_count = -1;
for (k, v) in counts.iter() {
if *v > best_count {
best_domain = k.clone();
best_count = *v;
}
}
if best_domain == "" {
fails += 1;
} else {
dataset.domain = best_domain.to_string();
ds2domain.insert(dataset.identifier.to_string(),
best_domain.to_string());
}
}
}
// eprintln!("Fails: {} ({})", fails, last_fails);
}
}
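
A minimal sketch (not part of the commit) of how the propagation behaves: link_to is a hypothetical helper that builds a Link whose target is the given identifier, and make_dataset is the helper from the test module at the end of this file.

let mut datasets = HashMap::new();
let mut x = make_dataset("x", "");              // unlabelled dataset
x.links = vec![link_to("a"), link_to("b")];     // links to two labelled datasets
datasets.insert("x".to_string(), x);
datasets.insert("a".to_string(), make_dataset("a", "geography"));
datasets.insert("b".to_string(), make_dataset("b", "geography"));
domain_by_most_neighbours(&mut datasets);
// "x" takes the majority domain of its neighbours
assert_eq!(datasets["x"].domain, "geography");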

const ALPHA : f64 = 0.0001f64;

/// Find the domain of a dataset by the set of keywords. A naive Bayes classifier
/// is created from the labelled datasets and this is applied to all the
/// unlabelled datasets
pub fn domain_by_keywords(datasets : &mut HashMap<String, Dataset>) {
let mut tag_cat_freq = HashMap::new();
let mut cat_freq = HashMap::new();
let mut tag_freq = HashMap::new();
let mut total = 0;

for (_, dataset) in datasets.iter() {
let cat = dataset.domain.clone();
if cat != "" {
let c = *cat_freq.get(&cat).unwrap_or(&0);
for tag in dataset.keywords.iter() {
let x = tag_cat_freq.entry(tag.clone()).or_insert_with(|| HashMap::new());
let c2 = *x.get(&cat).unwrap_or(&0);
x.insert(cat.clone(), c2 + 1);
let c3 = *tag_freq.get(tag).unwrap_or(&0);
tag_freq.insert(tag.clone(), c3 + 1);
}
cat_freq.insert(cat, c + 1);
total += 1;
}
}

let n_alpha = (cat_freq.len() as f64) * ALPHA;

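// Smoothed log-probabilities for the classifier: each category is scored as
//   score(cat) = ln P(cat) + sum over tags of [ ln P(tag|cat) - ln P(tag) ]
// with Laplace smoothing ALPHA. The ln P(tag) term is the same for every
// category, so subtracting it does not change which category wins.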
let tag_cat_prob : HashMap<(String, String), f64> =
tag_freq.keys().flat_map(|_tag| {
let tag = _tag.clone();
let v : Vec<((String, String), f64)> =
cat_freq.keys().map(|_cat| {
let cat = _cat.clone();
let tcf = *tag_cat_freq[&tag].get(&cat).unwrap_or(&0) as f64 + ALPHA;
let tf = *tag_freq.get(&tag).unwrap_or(&0) as f64 + ALPHA;
let cf = *cat_freq.get(&cat).unwrap_or(&0) as f64 + n_alpha;
let p = (tcf / cf).ln() - (tf / (total as f64 + n_alpha)).ln();
((tag.clone(), cat.clone()), p)
}).collect();
v
}).collect();

let cat_prob : HashMap<String, f64> = cat_freq.iter().map(|cf| {
(cf.0.clone(), ((*cf.1 as f64 + ALPHA) / (total as f64 + n_alpha)).ln())
}).collect();

let cats : Vec<String> = cat_prob.keys().map(|x| x.clone()).collect();

for (_, dataset) in datasets.iter_mut() {
if dataset.domain == "" {
if let Some((cat, _p)) = cats.iter().map(|c| {
let mut prob = *cat_prob.get(c).unwrap_or(&(ALPHA / (total as f64 + n_alpha)));
for tag in dataset.keywords.iter() {
prob += *tag_cat_prob.get(&(tag.to_string(), c.to_string())).unwrap_or(&0.0);
}
(c.clone(), r64(prob))
}).max_by_key(|c| c.1) {
dataset.domain = cat;
}
}
}
}

#[cfg(test)]
mod tests{
use std::collections::HashMap;
use data::{Dataset,IntLike};
use ident::*;

fn make_dataset(s : &str, d : &str) -> Dataset {
Dataset {
description: HashMap::new(),
title: None,
links: Vec::new(),
identifier: s.to_string(),
domain: d.to_string(),
triples: IntLike::from(0),
keywords : s.chars().map(|x| x.to_string()).collect()
}
}

#[test]
fn test_domain_by_keywords() {
let mut datasets = HashMap::new();
datasets.insert("foo".to_string(), make_dataset("foo", "a"));
datasets.insert("bar".to_string(), make_dataset("bar", "b"));
datasets.insert("baz".to_string(), make_dataset("baz", "b"));
datasets.insert("bao".to_string(), make_dataset("bao", ""));
datasets.insert("fod".to_string(), make_dataset("fod", ""));

domain_by_keywords(&mut datasets);

assert_eq!(datasets["bao"].domain, "b");
assert_eq!(datasets["fod"].domain, "a");
}
}

24 changes: 23 additions & 1 deletion src/main.rs
@@ -5,9 +5,11 @@ extern crate serde_json;
extern crate serde_derive;
extern crate clap;
extern crate htmlescape;
extern crate noisy_float;

mod data;
mod graph;
mod ident;
mod settings;
mod svg;
mod tree;
@@ -107,6 +109,10 @@ Gradient or lbfgsb = Limited BFGS)")
.help("Apply an n x n blocking method to speed up the algorithm
(default=1, no blocking)")
.takes_value(true))
.arg(Arg::with_name("ident")
.long("ident")
.value_name("none|neighbour|tags")
.help("The algorithm used to identify domain (bubble colours) of unidentified datasets"))
.get_matches();

let mut model : graph::Model = Default::default();
@@ -151,6 +157,15 @@ Gradient or lbfgsb = Limited BFGS)")
None => "lbfgsb"
};

let ident_algorithm = match args.value_of("ident") {
Some("none") => "none",
Some("tags") => "tags",
Some("neighbour") => "neighbour",
Some("neighbor") => "neighbour", // For Americans
Some(a) => panic!(format!("{} is not a supported identification algorithm", a)),
None => "none"
};

let max_iters = args.value_of("max_iters")
.map(|s| { s.parse::<u32>().expect("Iterations is not an integer") })
.unwrap_or(10000);
@@ -161,7 +176,14 @@

let settings = Settings::default();

let data : HashMap<String,Dataset> = serde_json::from_reader(data_file).expect("JSON error");
let mut data : HashMap<String,Dataset> = serde_json::from_reader(data_file).expect("JSON error");

match ident_algorithm {
"none" => {},
"neighbour" => ident::domain_by_most_neighbours(&mut data),
"tags" => ident::domain_by_keywords(&mut data),
_ => panic!("Unreachable")
};

let graph = graph::build_graph(&data);

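With this change the identification step can be chosen on the command line; a hypothetical invocation (the binary name and the remaining arguments follow the existing CLI, which is not shown in full here):

cargo run --release -- --ident neighbour <other existing arguments>
cargo run --release -- --ident tags <other existing arguments>
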
1 change: 1 addition & 0 deletions src/settings.rs
@@ -9,6 +9,7 @@
//! "id": "cross-domain",
//! "color": "#c8a788"
//! },{
//! ...
//!
//! }]
//! }
