In [None]:
import numpy as np
from pathlib import Path

In [None]:
import csv
import gzip

def _iterate_lines(file_handle):
    
    counts = {}
    
    header = True
    for line in file_handle:
        line = line.strip().split(',')

        if header:
            header = False
            peptide_at = line.index('peptide')
            if 'binder' in line:
                binder_at = line.index('binder')
            elif 'hit' in line:
                binder_at = line.index('hit')
            else:
                raise ValueError('could not identify bider/hit column')
            continue

        if line[binder_at] == '1': # i.e. if hit
            length = len(line[peptide_at])
            counts[length] = counts.get(length, 0) + 1
    
    return counts
    

def count_AAs(input_file):
    """Tally hit peptides by length from a CSV file, optionally gzip-compressed.

    Despite the name, this does not count amino acids. It returns whatever
    ``_iterate_lines`` produces for the file's lines: a dict mapping peptide
    length to the number of hit rows.

    Paths ending in ``.gz`` or ``.gzip`` are decompressed on the fly in text
    mode; any other path is opened as plain text.
    """
    path_text = str(input_file)
    compressed = path_text.endswith(('.gz', '.gzip'))

    handle = gzip.open(input_file, 'rt') if compressed else open(input_file, 'r')
    with handle as f:
        return _iterate_lines(f)


def combine_counts(*count_dicts):
    """Sum any number of ``{length: count}`` dicts into a single new dict.

    Keys appear in the order they are first seen across the inputs. The
    inputs are not modified. Called with no arguments, it returns ``{}``.
    """
    merged = {}
    all_items = (item for mapping in count_dicts for item in mapping.items())
    for key, amount in all_items:
        merged[key] = merged.get(key, 0) + amount
    return merged


def to_frequencies(counts, pseudocount=1):
    """Convert length counts into relative frequencies with add-k smoothing.

    Every integer length from the smallest to the largest observed length
    (inclusive) receives ``pseudocount`` extra counts. This gives lengths
    inside the observed range that never occurred a non-zero frequency when
    ``pseudocount > 0``. Lengths outside that range are not added.

    Args:
        counts: dict mapping integer peptide length -> count.
        pseudocount: count added to every length in the range (default 1).

    Returns:
        dict mapping length -> frequency, in increasing length order. The
        values sum to 1, up to floating-point rounding.

    Raises:
        ValueError: if ``counts`` is empty or the smoothed total is zero.
    """
    if not counts:
        # min()/max() below would otherwise fail with an opaque message.
        raise ValueError('cannot compute frequencies from empty counts')

    smoothed = {length: counts.get(length, 0) + pseudocount
                for length in range(min(counts), max(counts) + 1)}

    total_counts = sum(smoothed.values())
    if total_counts == 0:
        raise ValueError('total count is zero; cannot normalize')

    return {length: x / total_counts for length, x in smoothed.items()}


def print_formatted(frequencies, name='length_distribution_MS_data', decimals=8):
    """Print ``frequencies`` as a Python dict-literal assignment, sorted by key.

    With the defaults the output looks like::

        length_distribution_MS_data = {
             8: 0.12345678,
            }

    Args:
        frequencies: dict mapping length -> frequency (float).
        name: variable name on the left-hand side of the printed assignment.
        decimals: digits printed after the decimal point for each value.
    """
    print(f'{name} = {{')
    for length in sorted(frequencies):
        print(f"     {length}: {frequencies[length]:.{decimals}f},")
    print('    }')

## MHC class I

In [None]:
dataset_path = Path('/mnt/bfx/bfx_RD/Instadeep/cloud_backup/biondeep-data/datasets/mhc1/binding/MSDF_20200604/')

# Hit counts per peptide length for each partition, printed train -> tune -> test.
partition_files = {
    'train': 'MSDF_20200604_w_fixed_A0211_dedup.train.csv.gz',
    'tune': 'MSDF_20200604_w_fixed_A0211_dedup.tune.csv.gz',
    'test': 'MSDF_20200604_w_fixed_A0211_dedup.test.csv.gz',
}
counts_by_partition = {}
for partition, file_name in partition_files.items():
    counts_by_partition[partition] = count_AAs(dataset_path / file_name)
    print(counts_by_partition[partition])

counts_train = counts_by_partition['train']
counts_tune = counts_by_partition['tune']
counts_test = counts_by_partition['test']

# All three partitions combined.
counts = combine_counts(counts_tune, counts_train, counts_test)
print(counts)

In [None]:
# Relative length frequencies over the train + tune partitions only
# (the test partition is not included here).
train_tune_counts = (counts_tune, counts_train)
counts = combine_counts(*train_tune_counts)
frequencies = to_frequencies(counts)
print_formatted(frequencies)

## MHC class II

In [None]:
dataset_path = Path('/mnt/bfx/bfx_RD/Instadeep/cloud_backup/biondeep-data/datasets/mhc2/binding/')

# Hit counts per peptide length for each partition, printed train -> tune -> test.
partition_files = {
    'train': 'train_drop_1to19.csv',
    'tune': 'tune_1to19.csv',
    'test': 'test_scored_1to19.csv',
}
counts_by_partition = {}
for partition, file_name in partition_files.items():
    counts_by_partition[partition] = count_AAs(dataset_path / file_name)
    print(counts_by_partition[partition])

counts_train = counts_by_partition['train']
counts_tune = counts_by_partition['tune']
counts_test = counts_by_partition['test']

# All three partitions combined.
counts = combine_counts(counts_tune, counts_train, counts_test)
print(counts)

In [None]:
# Relative length frequencies over the train + tune partitions only
# (the test partition is not included here).
train_tune_counts = (counts_tune, counts_train)
counts = combine_counts(*train_tune_counts)
frequencies = to_frequencies(counts)
print_formatted(frequencies)