## Experiment: single species plasmodium perplexities

__Mail Henrik 14/08/2020__: We realized that _P. vivax_ has a relatively high GC content, compared to _P. falciparum_ and most others.  
We split the test set per species, and evaluate the perplexities of the fine-tuned model.

In [2]:
import sys
sys.path.append('..')
from models.awd_lstm import ProteinAWDLSTMForLM, ProteinAWDLSTMConfig
from tape import TAPETokenizer
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from tape.datasets import pad_sequences

from tqdm import tqdm_notebook as tqdm


In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Split the data

In [5]:
test_data = pd.read_csv('plasmodium_test_full.tsv', sep = '\t')
test_data.head()

organisms = test_data['Organism'].value_counts()
organisms = organisms[organisms>10]

### Util functions for dataloading + model evaluation

In [6]:
class FullSeqSeriesDataSet(torch.utils.data.Dataset):
    def __init__(self, pd_series):
        super().__init__
        self.data = pd_series.values
        self.tokenizer = TAPETokenizer()
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        seq = self.data[idx]
        tokenized = self.tokenizer.tokenize(seq) + [self.tokenizer.stop_token]
        token_ids = self.tokenizer.convert_tokens_to_ids(tokenized)
        input_ids = token_ids[:-1]
        target_ids =token_ids[1:]
        assert len(target_ids) == len(input_ids)
        return np.array(input_ids), np.array(target_ids) 
    
    @staticmethod
    def collate_fn(batch):
        data, targets = tuple(zip(*batch))   
        torch_data = torch.from_numpy(pad_sequences(data, 0)) #0 is tokenizer pad token
        torch_targets = torch.from_numpy(pad_sequences(targets, -1)) #pad with -1 to ignore loss

        return torch_data.permute(1,0), torch_targets.permute(1,0)  # type: ignore
    
def run_model_on_data(model, dataset) -> float:
    dl = torch.utils.data.DataLoader(ds, batch_size=100, collate_fn=ds.collate_fn)
    total_loss = 0
    for i, batch in enumerate(dl):
        data, targets = batch
        data = data.to(device).contiguous()
        targets = targets.to(device).contiguous()
        with torch.no_grad():
            loss, _, _ = model(data, targets = targets) #loss, output, hidden states        
        #NOTE this is the mean loss over all dimensions: batch_size, seq_len
        #Here I am using datasets with a low number of sequences, so the last batch will be much less than 100 presumably.
        #To compare, need to make sure I do not ignore this -> no sum(avg_loss)/n_batches mean.
        #For the model comparison this was no issue, as all models were run on the same dataset, so the error would be systematic
        #If it is even an error. No best practice on how to really average perplexity.
        total_loss += loss.item()*len(data)
        
    return total_loss / len(dataset)

### Set up model and evaluate on species-level data

In [7]:
model = ProteinAWDLSTMForLM.from_pretrained('../model_checkpoints/best_models_31072020/best_euk_finetuned_10_epochs/')

In [None]:
results_dict = {}
for org in tqdm(organisms.index):
    print(org)
    data =  test_data.loc[test_data['Organism'] ==org, 'Sequence']
    ds= FullSeqSeriesDataSet(data)
    loss = run_model_on_data(model, ds)
    results_dict[org] = np.exp(loss)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  


HBox(children=(FloatProgress(value=0.0, max=55.0), HTML(value='')))

Plasmodium ovale wallikeri
Plasmodium malariae
Plasmodium gonderi
Plasmodium ovale curtisi
Plasmodium vivax
Plasmodium gallinaceum
Plasmodium relictum
Plasmodium inui San Antonio 1
Plasmodium fragile
Plasmodium coatneyi


In [4]:
testdict = {'Plasmodium ovale wallikeri': 1231, 'Plasmodium ovale': 554.3}
lendict = {'Plasmodium ovale wallikeri': 23, 'Plasmodium ovale': 18}

In [12]:
pd.DataFrame([pd.Series(testdict), pd.Series(lendict)], index = ['perplexity', 'n_sequences']).T

Unnamed: 0,perplexity,n_sequences
Plasmodium ovale wallikeri,1231.0,23.0
Plasmodium ovale,554.3,18.0
