Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
cangermueller committed Apr 14, 2017
2 parents 9e68888 + 24d1754 commit 3ad30ec
Show file tree
Hide file tree
Showing 11 changed files with 1,570 additions and 581 deletions.
179 changes: 179 additions & 0 deletions R/dcpg_eval_perf.Rmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
---
title: "Prediction performance evaluation"
date: "`r format(Sys.time(), '%Y-%m-%d')`"
output:
  html_document:
    toc: yes
---

```{r, include=F}
library(knitr)
opts_chunk$set(echo=F, fig.width=12, warning=F, message=F)
```

```{r, include=F}
library(ggplot2)
library(dplyr)
library(tidyr)
library(xtable)
library(grid)
```

<style>
img {
max-width: none;
}
</style>

```{r}
options(xtable.type='html')
```

```{r}
# Options: where the evaluation files live and which annotation label
# denotes genome-wide results.
opts <- list(data_dir = './')
# The trailing '*' in the glob patterns also matches suffixed file names
# such as 'metrics.tsv.gz'.
opts[['metrics_file']] <- Sys.glob(file.path(opts[['data_dir']], 'metrics.tsv*'))
opts[['curves_file']] <- Sys.glob(file.path(opts[['data_dir']], 'curves.tsv*'))
opts[['anno_global']] <- 'global'
```

```{r}
# ggplot theme
#' Shared ggplot2 theme for the figures in this report.
#'
#' Sets enlarged black axis text, solid axis lines, transparent panel and
#' plot backgrounds, thin grey grid lines and a right-hand legend.
#'
#' @return A ggplot2 theme object, to be added to a plot with `+`.
my_theme <- function() {
  # NOTE: the former `axis.ticks.margin = unit(.3, 'cm')` entry was dropped.
  # ggplot2 deprecated that element in 2.0.0 and discards it with a warning;
  # this function already needs >= 2.0.0 because it uses `margin=` in
  # element_text(). Removing it keeps the rendering and silences the warning.
  theme(
    axis.text = element_text(size = rel(1.2), color = 'black'),
    axis.title.y = element_text(size = rel(1.8), margin = margin(0, 10, 0, 0)),
    axis.title.x = element_text(size = rel(1.8), margin = margin(10, 0, 0, 0)),
    axis.line = element_line(colour = "black", size = 1),
    axis.ticks.length = unit(.3, 'cm'),
    legend.position = 'right',
    legend.text = element_text(size = rel(1.3)),
    legend.title = element_text(size = rel(1.3), face = 'bold'),
    legend.key = element_rect(fill = 'transparent'),
    strip.text = element_text(size = rel(1.3)),
    panel.border = element_blank(),
    panel.grid.major = element_line(colour = "grey60", size = 0.1, linetype = 'solid'),
    panel.grid.minor = element_line(colour = "grey60", size = 0.1, linetype = 'dotted'),
    panel.background = element_rect(fill = "transparent", colour = NA),
    plot.background = element_rect(fill = "transparent", colour = NA)
  )
}
```

```{r}
# Remove the first 'cpg/' occurrence from each output name and return the
# shortened names as a factor.
format_output <- function(d) {
  short_names <- sub('cpg/', '', d)
  factor(short_names)
}
#' Read the per-output performance metrics table.
#'
#' @param filename Path to a tab-separated file with a header row that
#'   contains at least the columns `anno`, `metric`, `output` and `value`.
#'   The file is opened through `gzfile()`.
#' @return A table with columns `anno`, `metric` (upper-cased), `output`
#'   (passed through `format_output()`) and `value`.
read_metrics <- function(filename) {
  # `header = TRUE` is spelled out: the previous `head=T` relied on partial
  # argument matching and on `T`, which is an ordinary reassignable variable.
  raw <- read.table(gzfile(filename), sep = '\t', header = TRUE)
  raw %>%
    tbl_df %>%
    select(anno, metric, output, value) %>%
    mutate(metric = toupper(metric), output = format_output(output))
}
#' Read the performance curves table.
#'
#' @param filename Path to a tab-separated file with a header row that
#'   contains at least the columns `anno`, `curve`, `output`, `x`, `y` and
#'   `thr`.
#' @return A table with columns `anno`, `curve` (upper-cased), `output`
#'   (passed through `format_output()`), `x`, `y` and `thr`.
read_curves <- function(filename) {
  # `header = TRUE` is spelled out: the previous `head=T` relied on partial
  # argument matching and on `T`, which is an ordinary reassignable variable.
  raw <- read.table(filename, sep = '\t', header = TRUE)
  raw %>%
    tbl_df %>%
    select(anno, curve, output, x, y, thr) %>%
    mutate(curve = toupper(curve), output = format_output(output))
}
# Read Data: load both input tables into one named list used by all
# later chunks.
dat <- list(
  metrics = read_metrics(opts$metrics_file),
  curves = read_curves(opts$curves_file)
)
```


## Genome-wide performances

```{r results='asis'}
# Performance table: genome-wide metrics, one row per output and one column
# per metric, sorted by decreasing AUC.
d <- dat$metrics %>%
  filter(anno == opts$anno_global) %>%
  select(-anno)
d <- spread(d, metric, value)
d <- arrange(d, desc(AUC))
# Left as a bare top-level expression so knitr prints the table.
xtable(d, digits = 4)
```

```{r fig.width=10, fig.height=10}
# Bar plots: one panel per genome-wide metric (the 'N' metric is excluded),
# with outputs ordered along the x axis by decreasing AUC.
d <- dat$metrics %>%
  filter(anno == opts$anno_global, metric != 'N')
# Output names sorted by AUC; used as factor levels to fix the bar order.
tmp <- d %>%
  filter(metric == 'AUC') %>%
  arrange(desc(value)) %>%
  select(output) %>%
  unlist %>%
  as.vector
d <- d %>% mutate(output = factor(output, levels = tmp))
p <- ggplot(d, aes(x = output, y = value)) +
  geom_bar(aes(fill = metric), stat = 'identity') +
  scale_fill_brewer(palette = 'Set1') +
  # `scales` is spelled out; `scale=` only worked through partial matching.
  facet_wrap(~metric, scales = 'free', ncol = 2) +
  my_theme() +
  theme(
    axis.text.x = element_text(size = rel(0.9), angle = 30, hjust = 1),
    axis.title.x = element_blank(),
    axis.title.y = element_blank(),
    legend.position = 'top'
  )
print(p)
```

```{r fig.width=10, fig.height=6}
# Performance curves: one panel per curve type, one line per output.
# ROC curves are kept in full; every other curve type is restricted to
# points with x > 0.5.
d <- dat$curves %>%
  filter(anno == opts$anno_global, (curve == 'ROC') | (x > 0.5))
p <- ggplot(d, aes(x = x, y = y, color = output)) +
  geom_line() +
  my_theme() +
  theme(legend.position = 'top') +
  # `scales` is spelled out; `scale=` only worked through partial matching.
  facet_wrap(~curve, ncol = 2, scales = 'free') +
  xlab('') +
  ylab('')
print(p)
```

## Context-specific performances

```{r results='asis'}
# Performance table: each metric averaged over outputs within every
# annotation, one row per annotation, sorted by decreasing mean AUC.
d <- dat$metrics %>%
  group_by(anno, metric) %>%
  summarise(value = mean(value)) %>%
  ungroup
d <- spread(d, metric, value)
d <- arrange(d, desc(AUC))
# Left as a bare top-level expression so knitr prints the table.
xtable(d, digits = 4)
```

```{r fig.width=10, fig.height=25}
# Boxplots annotations

#' Box plots of metric values per annotation, one panel per metric.
#'
#' Annotations are ordered along the x axis by decreasing mean AUC.
#'
#' @param d Table with columns `anno`, `metric` and `value`.
#' @param metrics Optional vector of metric names to keep; `NULL` keeps all.
#' @param points If `TRUE`, overlay the individual values as jittered points.
#' @return A ggplot object.
plot_annos <- function(d, metrics = NULL, points = TRUE) {
  # Annotation names sorted by mean AUC; used as factor levels for ordering.
  annos <- d %>%
    filter(metric == 'AUC') %>%
    group_by(anno) %>%
    summarise(value = mean(value)) %>%
    arrange(desc(value)) %>%
    select(anno) %>%
    unlist
  d <- d %>% mutate(anno = factor(anno, levels = annos))
  if (!is.null(metrics)) {
    d <- d %>% filter(metric %in% metrics) %>% droplevels
  }
  p <- ggplot(d, aes(x = anno, y = value)) +
    # Outliers are hidden in the boxes; they show up as points when
    # `points` is enabled.
    geom_boxplot(aes(fill = metric), outlier.shape = NA) +
    scale_fill_brewer(palette = 'Set1') +
    my_theme() +
    theme(
      panel.grid.major = element_line(colour = "grey60", size = 0.1, linetype = 'solid'),
      panel.grid.minor = element_line(colour = "grey60", size = 0.1, linetype = 'dotted'),
      axis.text.x = element_text(angle = 30, hjust = 1),
      axis.title.x = element_blank(),
      legend.position = 'top'
    ) +
    # `scales` is spelled out; `scale=` only worked through partial matching.
    facet_wrap(~metric, ncol = 1, scales = 'free')
  if (points) {
    # Horizontal jitter only, so the plotted values stay exact.
    p <- p + geom_point(
      size = 0.3,
      position = position_jitter(width = 0.1, height = 0)
    )
  }
  p
}
plot_annos(dat$metrics %>% filter(metric != 'N'))
```
2 changes: 1 addition & 1 deletion deepcpg/data/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def join_overlapping(s, e):


def join_overlapping_frame(d):
d = d.sort(['chromo', 'start', 'end'])
d = d.sort_values(['chromo', 'start', 'end'])
e = []
for chromo in d.chromo.unique():
dc = d.loc[d.chromo == chromo]
Expand Down
48 changes: 48 additions & 0 deletions deepcpg/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,54 @@ def evaluate_outputs(outputs, preds):
return perf


def is_binary_output(output_name):
    """Return `True` if `output_name` denotes a binary output.

    The name is split on `OUTPUT_SEP`. It counts as binary when its first
    component is ``'cpg'`` or its last component is one of ``'diff'``,
    ``'mode'`` or ``'cat2_var'``.
    """
    parts = output_name.split(OUTPUT_SEP)
    return parts[0] == 'cpg' or parts[-1] in ('diff', 'mode', 'cat2_var')


def evaluate_curve(outputs, preds, fun=skm.roc_curve, mask=CPG_NAN,
                   nb_point=None):
    """Evaluate a performance curve for every binary output.

    Parameters
    ----------
    outputs: dict
        Maps output names to arrays of true labels. Labels are rounded
        before evaluation.
    preds: dict
        Maps output names to arrays of predictions.
    fun: callable
        Called as `fun(output, pred)`; must return three arrays
        `(x, y, thr)`, e.g. `skm.roc_curve`.
    mask: scalar
        Label value marking entries to exclude. `None` disables masking.
    nb_point: int
        If set and a curve has more points, it is subsampled to `nb_point`
        evenly spaced points.

    Returns
    -------
    :class:`pandas.DataFrame`
        Concatenated curves with columns `output`, `x`, `y` and `thr`, or
        `None` if `outputs` contains no binary output.
    """
    curves = []
    for output_name in outputs.keys():
        if not is_binary_output(output_name):
            continue

        output = outputs[output_name].round().squeeze()
        pred = preds[output_name].squeeze()
        if mask is not None:
            # Use the `mask` argument; it was previously ignored in favour
            # of the hard-coded CPG_NAN (still the default, so default
            # calls behave as before).
            idx = output != mask
            output = output[idx]
            pred = pred[idx]

        x, y, thr = fun(output, pred)
        # The three arrays may differ in length; truncate to the shortest
        # so they fit into one frame.
        length = min(len(x), len(y), len(thr))
        if nb_point and length > nb_point:
            idx = np.linspace(0, length - 1, nb_point).astype(np.int32)
        else:
            idx = slice(0, length)
        x = x[idx]
        y = y[idx]
        thr = thr[idx]

        curve = OrderedDict()
        curve['output'] = output_name
        curve['x'] = x
        curve['y'] = y
        curve['thr'] = thr
        curves.append(pd.DataFrame(curve))

    if not curves:
        return None
    return pd.concat(curves)


def unstack_report(report):
index = list(report.columns[~report.columns.isin(['metric', 'value'])])
report = pd.pivot_table(report, index=index, columns='metric',
Expand Down
31 changes: 26 additions & 5 deletions deepcpg/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ def get_sample_weights(y, class_weights=None):
Parameters
----------
y: :cla:`numpy.ndarray`
y: :class:`numpy.ndarray`
1d numpy array of output labels.
class_weights: dict
Weight of output classes, e.g. methylation states.
Returns
-------
:cla:`numpy.ndarray`
:class:`numpy.ndarray`
Sample weights of size `y`.
"""
y = y[:]
Expand Down Expand Up @@ -486,7 +486,7 @@ class DataReader(object):
"""Read data from `dcpg_data.py` output files.
Generator to read data batches from `dcpg_data.py` output files. Reads data
using :fun:`hdf.reader` and pre-processes data.
using :func:`hdf.reader` and pre-processes data.
Parameters
----------
Expand Down Expand Up @@ -568,9 +568,9 @@ def __call__(self, data_files, class_weights=None, *args, **kwargs):
class_weights: dict
dict of dict with class weights of individual outputs.
*args: list
Unnamed arguments passed to :fun:`hdf.reader`
Unnamed arguments passed to :func:`hdf.reader`
*kwargs: dict
Named arguments passed to :fun:`hdf.reader`
Named arguments passed to :func:`hdf.reader`
Returns
-------
Expand Down Expand Up @@ -631,6 +631,24 @@ def __call__(self, data_files, class_weights=None, *args, **kwargs):


def data_reader_from_model(model, outputs=True, replicate_names=None):
"""Return :class:`DataReader` from `model`.
Builds a :class:`DataReader` for reading data for `model`.
Parameters
----------
model: :class:`Model`.
:class:`Model`.
outputs: bool
If `True`, return output labels.
replicate_names: list
Name of input cells of `model`.
Returns
-------
:class:`DataReader`
Instance of :class:`DataReader`.
"""
use_dna = False
dna_wlen = None
cpg_wlen = None
Expand All @@ -640,6 +658,7 @@ def data_reader_from_model(model, outputs=True, replicate_names=None):
input_shapes = to_list(model.input_shape)
for input_name, input_shape in zip(model.input_names, input_shapes):
if input_name == 'dna':
# Read DNA sequences.
use_dna = True
dna_wlen = input_shape[1]
elif input_name.startswith('cpg/state/'):
Expand All @@ -650,6 +669,7 @@ def data_reader_from_model(model, outputs=True, replicate_names=None):
cpg_wlen = input_shape[2]
encode_replicates = True
elif input_name == 'cpg/state':
# Read neighboring CpG sites.
if not replicate_names:
raise ValueError('Replicate names required!')
if len(replicate_names) != input_shape[1]:
Expand All @@ -661,6 +681,7 @@ def data_reader_from_model(model, outputs=True, replicate_names=None):
cpg_wlen = input_shape[2]

if outputs:
# Return output labels.
output_names = model.output_names

return DataReader(output_names=output_names,
Expand Down

0 comments on commit 3ad30ec

Please sign in to comment.