# Institute for Behavioral Genetics International Statistical Genetics 2021 Workshop 

## Inferring Population Labels

Learning objectives:

1. Plot principal components obtained by running PCA on our dataset.
2. Use principal components to identify population clusters.
3. Reidentify the populations of samples whose labels are missing, using the population clusters.

In [None]:
import os
os.environ['PYSPARK_SUBMIT_ARGS'] = '--driver-memory 6G pyspark-shell'

import hail as hl
hl.plot.output_notebook()
hl.init()

## Read in QC'ed data and PCA scores

First, we'll need to read back in the sample annotations and the PCA scores from the previous practical.

In [None]:
pca_scores = hl.read_table('resources/pca_scores.ht')

In [None]:
sa = hl.import_table('resources/HGDP_sample_data.tsv', 
                     impute=True, 
                     key='sample_id')

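Let's randomly throw away some of our population information: each sample keeps its `continental_pop` label with probability 0.9, so roughly 10% of the labels become missing.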

In [None]:
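# Keep each label with probability 0.9; mask the rest as missing.
sa = sa.annotate(
    continental_pop = hl.if_else(
        hl.rand_bool(0.9),
        sa.continental_pop,
        hl.missing(hl.tstr)
    )
)
# Write the table and read it back so the masked set is fixed on disk
# rather than re-drawn each time the table is evaluated.
sa.write('output/censored_hgdp_data.ht', overwrite=True)
sa = hl.read_table('output/censored_hgdp_data.ht')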
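
To get a feel for how much information this masking step removes, here is a plain-Python sketch (not Hail) of the same idea: each label is kept with probability 0.9 and otherwise replaced with `None`. The toy label list, the seed, and the `mask_labels` helper are hypothetical and exist only for illustration.

```python
import random

def mask_labels(labels, keep_prob=0.9, seed=0):
    """Keep each label with probability keep_prob; otherwise replace it with None."""
    rng = random.Random(seed)
    return [label if rng.random() < keep_prob else None for label in labels]

# Hypothetical toy labels, not the real HGDP annotations.
labels = ['afr', 'eas', 'sas', 'nfe'] * 250   # 1000 labels in total
masked = mask_labels(labels)

n_masked = sum(label is None for label in masked)
print(f'Masked {n_masked} of {len(labels)} labels.')
```

With 1000 labels the masked count should land near 100; the exact number depends on the random draws.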

Now, we'll take the first 4 PCs from the PCA table, and add the population information for each sample from our dataset.

In [None]:
ht = pca_scores.select(PC1=pca_scores.scores[0],
                       PC2=pca_scores.scores[1],
                       PC3=pca_scores.scores[2],
                       PC4=pca_scores.scores[3])
ht = ht.annotate(**sa[ht.s])

In [None]:
ht.aggregate(hl.agg.collect_as_set(ht.continental_pop))

The populations present in this dataset are `afr`, `amr`, `eas`, `fin`, `mid`, `nfe`, `oth`, and `sas`. They are three-letter codes from the 1000 Genomes Project denoting the [super population of each sample](https://www.internationalgenome.org/category/population/).

## Visualize!

Let's plot all combinations of the first three principal components (PCs) against one another. Perhaps we can identify clear cluster boundaries.

In [None]:
import bokeh.layouts  # import the submodule explicitly so bokeh.layouts.gridplot is available

p12 = hl.plot.scatter(ht.PC1, ht.PC2, xlabel='PC1', ylabel='PC2', label=ht.continental_pop, size=3, width=400, height=400)
p13 = hl.plot.scatter(ht.PC1, ht.PC3, xlabel='PC1', ylabel='PC3', label=ht.continental_pop, size=3, width=400, height=400)

p23 = hl.plot.scatter(ht.PC2, ht.PC3, xlabel='PC2', ylabel='PC3', label=ht.continental_pop, size=3, width=400, height=400)

hl.plot.show(bokeh.layouts.gridplot([[p12],
                                     [p13, p23]]))

## Reidentify samples with missing ancestry based on PCA scores

Now that we can see how the populations are decomposed by the PCs, let's try to reidentify the masked samples.

First, we'll define a grading scheme to check against the true populations of each masked sample. (The `check` function will see how many masked samples you have correctly identified.)

In [None]:
true_labels = hl.import_table('resources/HGDP_sample_data.tsv', key='sample_id').cache()

def check(ht):
    ht = ht.annotate(true_pop = true_labels[ht.s].continental_pop)
    c = ht.aggregate(hl.agg.filter(hl.is_missing(ht.continental_pop), 
                                   hl.agg.counter((ht.unmasked, ht.true_pop))))
    n_correct = sum(count for k, count in c.items() if k[0] == k[1])
    n_wrong = sum(count for k, count in c.items() if k[0] != k[1])
    print(f'Correctly identified {n_correct} / {n_correct + n_wrong} masked samples.')
    print()
    
    for (unm, true), n in c.items():
        if unm != true:
            if unm is not None:
                print(f'Incorrectly assigned {n} {true} samples as {unm}.')
            else:
                print(f'Left {n} {true} samples unassigned.')
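
The tallying logic inside `check` can be tried out without Hail. The sketch below builds a `collections.Counter` by hand, using the same kind of `(unmasked, true_pop)` keys that the aggregation in `check` produces; the counts are invented for illustration.

```python
from collections import Counter

# Hypothetical counts keyed by (assigned label, true label); None means unassigned.
c = Counter({
    ('eas', 'eas'): 20,
    ('eas', 'sas'): 2,
    (None, 'afr'): 5,
})

# A sample counts as correct only when the assigned label equals the true label.
n_correct = sum(count for (unm, true), count in c.items() if unm == true)
n_wrong = sum(count for (unm, true), count in c.items() if unm != true)

print(f'Correctly identified {n_correct} / {n_correct + n_wrong} masked samples.')
```

Note that unassigned samples count as wrong here, because `None` never equals a population code.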

## Fill in the below

Your job is to expand the code at the end of this section so that it reidentifies the population labels. One of the populations has already been provided as an example.

### `case().when()` in Hail

The `case` / `when` / `default` motif you see below is a nice way to write `if` / `else if` / `else`. The returned `unmasked` will be equal to the result of the first `when` whose predicate is `True`.
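
As a rough plain-Python analogy (not Hail code), the same first-match-wins behaviour looks like an `if` / `elif` / `else` chain. The function name `assign_label`, the thresholds, and the labels below are made up for illustration.

```python
def assign_label(pc1, pc2, fallback=None):
    """Return the label of the first matching rule, or the fallback if none match."""
    if pc1 > 0.5:          # plays the role of the first .when(...)
        return 'first'
    elif pc2 > 0.5:        # only checked if the first rule did not match
        return 'second'
    else:                  # plays the role of .default(...)
        return fallback

print(assign_label(0.9, 0.9))   # both rules match; the first one wins
print(assign_label(0.1, 0.9))   # only the second rule matches
print(assign_label(0.1, 0.1))   # no rule matches; the fallback is returned
```

Because the first matching rule wins, the order of the `when` clauses matters whenever two predicates can both be true for the same sample.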

### A note on `&` and `|`

Python uses `and` and `or` for logical operators. Hail expressions use `&` for 'and' and `|` for 'or'.

This can lead to some confusion, because in Python `&` and `|` bind more tightly than the comparison operators `>`, `<`, `==`, and `!=`. Whenever a comparison is combined with `&` or `|`, you will need to wrap each comparison in parentheses.

Suppose we want to write code that returns true when "PC1 is greater than 0.1 or PC2 is less than 0.2":

**correct**:

```
(ht.PC1 > 0.1) | (ht.PC2 < 0.2)
```

**incorrect**:

```
ht.PC1 > 0.1 or ht.PC2 < 0.2
ht.PC1 > 0.1 | ht.PC2 < 0.2
(ht.PC1 > 0.1) or (ht.PC2 < 0.2)
```
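
The effect of leaving out the parentheses can be seen with plain integers, no Hail needed. This is only a sketch of Python's operator precedence, and the numbers are arbitrary.

```python
# Without parentheses, Python evaluates 0 | 2 first, so the expression becomes
# the chained comparison 1 > 2 < 3, which is False.
unparenthesized = 1 > 0 | 2 < 3

# With parentheses, each comparison is evaluated first and the results are OR-ed.
parenthesized = (1 > 0) | (2 < 3)

print(unparenthesized)
print(parenthesized)
```

Both comparisons are true on their own, yet only the parenthesized form evaluates to `True`.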

### `hl.all` and `hl.any`

You might also find it easier to use `hl.all` (which is "and") and `hl.any` (which is "or"). For example, this

```
(ht.PC1 > 0.1) | (ht.PC2 < 0.2) | (ht.PC3 >= 0.1)
```

could also be written as

```
hl.any(ht.PC1 > 0.1,
       ht.PC2 < 0.2,
       ht.PC3 >= 0.1)
```

### To think about

Which population is hardest to reidentify? Why?

### Extras

Try plotting the PCs again with the re-identified population labels.

Try plotting the PCs again, highlighting the ones that you missed. (The true population labels are in the table `true_labels`.)

In [None]:
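unmasked = ht.annotate(
    unmasked = hl.case()
        # NOTE: both conditions below are identical as given, so this is equivalent to
        # `ht.PC2 > 0.2`; swap in another PC or threshold if you need a second condition.
        .when((ht.PC2 > 0.2) & (ht.PC2 > 0.2), 'eas')
        # Labels must match the lowercase codes in `continental_pop` exactly,
        # otherwise `check` will count them as wrong.
#         .when(..., 'afr')
#         .when(..., 'amr')
#         .when(..., 'nfe')
#         .when(..., 'sas')
        .default(ht.continental_pop)
)
check(unmasked)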