-
Notifications
You must be signed in to change notification settings - Fork 75
/
data.py
214 lines (166 loc) · 6.49 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import polars as pl
import numpy as np
from random import randrange, random
from pathlib import Path
from pyfaidx import Fasta
# helper functions
def exists(val):
    """Return True when `val` is anything other than None."""
    return not (val is None)
def identity(t):
    """Return the argument untouched (used as a no-op default callback)."""
    return t
def cast_list(t):
    """Return `t` itself when it is a list, otherwise wrap it in a one-element list."""
    if isinstance(t, list):
        return t
    return [t]
def coin_flip():
    """Draw one sample from `random()` and return True when it exceeds 0.5."""
    draw = random()
    return 0.5 < draw
# genomic function transforms
# lookup table from byte value -> base index
# a/A = 0, c/C = 1, g/G = 2, t/T = 3, n/N = 4, '.' (padding) = -1
# every other byte value keeps the initial 0

seq_indices_embed = torch.zeros(256, dtype = torch.long)
seq_indices_embed[[ord(base) for base in 'acgtnACGTN']] = torch.arange(5).repeat(2)
seq_indices_embed[ord('.')] = -1
# lookup table from byte value -> 4-dim one hot row, channel order a, c, g, t
# n/N (and any unlisted byte) stay all-zero, '.' (padding) is uniform 0.25

one_hot_embed = torch.zeros(256, 4)
one_hot_embed[[ord(base) for base in 'acgtACGT']] = torch.eye(4).repeat(2, 1)
one_hot_embed[ord('.')] = torch.full((4,), 0.25)
# base index -> complementary base index (a <-> t, c <-> g, n stays n)
reverse_complement_map = torch.tensor([3, 2, 1, 0, 4], dtype = torch.long)
def torch_fromstring(seq_strs):
    """Convert a sequence string, or a batch of them, into uint8 byte-value tensors.

    A single `str` yields a 1-D tensor of its encoded byte values. Any other
    iterable of strings is treated as a batch and stacked into a 2-D tensor,
    so all strings in a batch must have the same length (`torch.stack`
    raises otherwise).
    """
    batched = not isinstance(seq_strs, str)
    strs = list(seq_strs) if batched else [seq_strs]

    # np.fromstring in binary mode is deprecated; np.frombuffer is the replacement.
    # Going through a bytearray gives numpy a writable buffer, which avoids the
    # torch.from_numpy warning about non-writable arrays.
    # Default (utf-8) encoding matches what np.fromstring did with a str argument.
    seq_chrs = [
        torch.from_numpy(np.frombuffer(bytearray(s, 'utf-8'), dtype = np.uint8))
        for s in strs
    ]

    return torch.stack(seq_chrs) if batched else seq_chrs[0]
def str_to_seq_indices(seq_strs):
    """Translate sequence string(s) into base indices via the `seq_indices_embed` table."""
    char_codes = torch_fromstring(seq_strs).long()
    return seq_indices_embed[char_codes]
def str_to_one_hot(seq_strs):
    """Translate sequence string(s) into one hot rows via the `one_hot_embed` table."""
    char_codes = torch_fromstring(seq_strs).long()
    return one_hot_embed[char_codes]
def seq_indices_to_one_hot(t, padding = -1):
    """Turn a tensor of base indices into a float one hot tensor with 4 channels.

    Indices are one hot encoded over 5 classes and the fifth channel is
    dropped, so index 4 becomes an all-zero row. Positions equal to
    `padding` are filled with 0.25 across all four channels.
    """
    pad_mask = (t == padding).unsqueeze(-1)

    # clamp so negative padding indices are valid input for F.one_hot;
    # those positions are overwritten by the mask fill below
    five_way = F.one_hot(t.clamp(min = 0), num_classes = 5)
    four_way = five_way[..., :4].float()

    return four_way.masked_fill(pad_mask, 0.25)
# augmentations
def seq_indices_reverse_complement(seq_indices):
    """Reverse complement a tensor of base indices along its last dimension.

    Indices 0..3 (a, c, g, t) are swapped with their complement (3 - index),
    which is the same mapping as the [3, 2, 1, 0, 4] lookup table. Every
    other value - 4 ('n') and the negative padding index - is passed through
    unchanged. Previously a padding index of -1 was used to index the 5-entry
    table directly, wrapped around to its last entry and silently turned
    padding into 'n'.
    """
    indices = seq_indices.long()

    is_base = (indices >= 0) & (indices < 4)
    complement = torch.where(is_base, 3 - indices, indices)

    return torch.flip(complement, dims = (-1,))
def one_hot_reverse_complement(one_hot):
    """Reverse complement a one hot tensor whose last two dims are (length, 4).

    Flipping the channel dimension swaps a <-> t and c <-> g, and flipping
    the length dimension reverses the sequence.
    """
    seq_len, channels = one_hot.shape[-2:]
    assert channels == 4, 'must be one hot encoding with last dimension equal to 4'
    return one_hot.flip(-1, -2)
# processing bed files
class FastaInterval():
    """Callable that reads a genomic interval from a fasta file and encodes it as a tensor.

    The fasta file is opened with `pyfaidx.Fasta`. Calling the instance with
    `(chr_name, start, end)` slices the chromosome as `chromosome[start:end]`,
    optionally applies a random shift and reverse complement augmentation,
    and pads with '.' wherever the requested window runs past either end
    of the chromosome.
    """

    def __init__(
        self,
        *,
        fasta_file,
        context_length = None,
        return_seq_indices = False,
        shift_augs = None,
        rc_aug = False
    ):
        """
        fasta_file          - path to the fasta file, must exist
        context_length      - if set, intervals shorter than this are widened (roughly symmetrically) to this length
        return_seq_indices  - return base indices instead of a one hot encoding
        shift_augs          - optional (min_shift, max_shift) pair, both bounds inclusive
        rc_aug              - randomly reverse complement the sequence (coin flip per call)
        """
        fasta_file = Path(fasta_file)
        assert fasta_file.exists(), 'path to fasta file must exist'

        self.seqs = Fasta(str(fasta_file))
        self.return_seq_indices = return_seq_indices
        self.context_length = context_length
        self.shift_augs = shift_augs
        self.rc_aug = rc_aug

    def __call__(self, chr_name, start, end, return_augs = False):
        """Fetch and encode the interval `[start:end]` of chromosome `chr_name`.

        Returns the base index tensor when `return_seq_indices` is set
        (in which case `return_augs` is not consulted). Otherwise returns the
        one hot tensor, or - with `return_augs` - the tuple
        `(one_hot, shift, was_reverse_complemented)` where the last two are
        single element tensors.
        """
        interval_length = end - start
        chromosome = self.seqs[chr_name]
        chromosome_length = len(chromosome)

        # no shift unless shift augmentation is configured - keeps `rand_shift`
        # defined for the `return_augs` branch at the bottom
        rand_shift = 0

        if exists(self.shift_augs):
            min_shift, max_shift = self.shift_augs
            max_shift += 1  # randrange excludes the upper bound

            # clip the shift range so the shifted interval stays inside the chromosome
            min_shift = max(start + min_shift, 0) - start
            max_shift = min(end + max_shift, chromosome_length) - end

            # NOTE(review): randrange raises ValueError if clipping leaves an empty range
            rand_shift = randrange(min_shift, max_shift)
            start += rand_shift
            end += rand_shift

        left_padding = right_padding = 0

        if exists(self.context_length) and interval_length < self.context_length:
            extra_seq = self.context_length - interval_length

            extra_left_seq = extra_seq // 2
            extra_right_seq = extra_seq - extra_left_seq

            start -= extra_left_seq
            end += extra_right_seq

        if start < 0:
            left_padding = -start
            start = 0

        if end > chromosome_length:
            right_padding = end - chromosome_length
            end = chromosome_length

        seq = ('.' * left_padding) + str(chromosome[start:end]) + ('.' * right_padding)

        # a single coin flip per call decides the reverse complement augmentation
        should_rc_aug = self.rc_aug and coin_flip()

        if self.return_seq_indices:
            # convert to indices first - the reverse complement operates on
            # index tensors, not on the raw string
            seq_indices = str_to_seq_indices(seq)

            if should_rc_aug:
                seq_indices = seq_indices_reverse_complement(seq_indices)

            return seq_indices

        one_hot = str_to_one_hot(seq)

        if should_rc_aug:
            one_hot = one_hot_reverse_complement(one_hot)

        if not return_augs:
            return one_hot

        # returns the shift integer as well as the bool (for whether reverse complement was activated)
        # for this particular genomic sequence

        rand_shift_tensor = torch.tensor([rand_shift])
        rand_aug_bool_tensor = torch.tensor([should_rc_aug])

        return one_hot, rand_shift_tensor, rand_aug_bool_tensor
class GenomeIntervalDataset(Dataset):
    """Dataset yielding encoded genomic sequences for the intervals listed in a .bed file.

    The .bed file is read as a header-less, tab separated table; the first
    three columns of each row are used as chromosome name, start and end.
    Each item is produced by a `FastaInterval` built from the remaining
    constructor arguments.
    """

    def __init__(
        self,
        bed_file,
        fasta_file,
        filter_df_fn = identity,
        chr_bed_to_fasta_map = None,
        context_length = None,
        return_seq_indices = False,
        shift_augs = None,
        rc_aug = False,
        return_augs = False
    ):
        """
        bed_file              - path to the .bed file, must exist
        fasta_file            - path to the fasta file, forwarded to FastaInterval
        filter_df_fn          - callable applied to the loaded dataframe, its return value is kept
        chr_bed_to_fasta_map  - optional dict from chromosome names in the .bed file to key names in the fasta
        context_length, return_seq_indices, shift_augs, rc_aug - forwarded to FastaInterval
        return_augs           - forwarded to every FastaInterval call
        """
        super().__init__()
        bed_path = Path(bed_file)
        assert bed_path.exists(), 'path to .bed file must exist'

        # NOTE(review): newer polars releases spell these keywords `separator` / `has_header`
        # - confirm against the polars version this project pins before upgrading
        df = pl.read_csv(str(bed_path), sep = '\t', has_headers = False)
        df = filter_df_fn(df)
        self.df = df

        # if the chromosome name in the bed file is different than the keyname in the fasta
        # can remap on the fly
        # (a fresh dict per instance - a `dict()` default argument would be shared by all instances)
        self.chr_bed_to_fasta_map = chr_bed_to_fasta_map if chr_bed_to_fasta_map is not None else dict()

        self.fasta = FastaInterval(
            fasta_file = fasta_file,
            context_length = context_length,
            return_seq_indices = return_seq_indices,
            shift_augs = shift_augs,
            rc_aug = rc_aug
        )

        self.return_augs = return_augs

    def __len__(self):
        """Number of intervals (rows) kept after filtering."""
        return len(self.df)

    def __getitem__(self, ind):
        """Encode the interval stored in row `ind`, remapping the chromosome name if a mapping exists."""
        interval = self.df.row(ind)
        chr_name, start, end = (interval[0], interval[1], interval[2])
        chr_name = self.chr_bed_to_fasta_map.get(chr_name, chr_name)
        return self.fasta(chr_name, start, end, return_augs = self.return_augs)