In [1]:
import pickle
import torch
import numpy as np
import random

In [2]:
#path_seq = "/export/share/krausef99dm/data/data_train/train_9.0k_data.pkl"
#path_seq = "/export/share/krausef99dm/data/data_test/val_9.0k_data.pkl"
path_seq = "/export/share/krausef99dm/data/data_test/test_9.0k_data.pkl"

In [3]:
# load data
with open(path_seq, 'rb') as f:
    rna_data, tissue_ids, targets, targets_bin =  pickle.load(f)

In [4]:
def compute_utr_5_lengths(rna_data):
    seq_meta = [t[:, 1] for t in rna_data]
    
    utr_5_lengths = []
    for t in seq_meta:
        unique_vals, counts = torch.unique(t, return_counts=True)
        result = dict(zip(unique_vals.tolist(), counts.tolist()))
        try:
            utr_5_lengths.append(result[5])
        except KeyError:
            utr_5_lengths.append(0)
    return utr_5_lengths

In [5]:
max_seq_len = 9000
utr_5_lengths = compute_utr_5_lengths(rna_data)
max_utr_5_len = max(utr_5_lengths)

# TODO identify global max utr_5_len
# from train: 4255
# from val: 3447
# from test: 4695
max_utr_5_len = 4695

In [6]:
# Check longest sample
#idx = utr_5_lengths.index(max_utr_5_len)
#print("len of sequence with longest 5' utr", len(rna_data[idx]))
#torch.unique(rna_data[idx][:, 1], return_counts=True)

In [7]:
random.seed(2)
rna_data_sample = random.sample(rna_data, k=4)
utr_5_lengths = compute_utr_5_lengths(rna_data_sample)

In [8]:
front_pads = [max_utr_5_len - l for l in utr_5_lengths]

In [9]:
front_pads

[4595, 4528, 4199, 4531]

In [10]:
def aug_align_sequences(rna_data, max_utr_5_len):
    seq_meta = [t[:, 1] for t in rna_data]
    
    utr_5_lengths = []
    count_seq_without_utr = 0
    for t in seq_meta:
        unique_vals, counts = torch.unique(t, return_counts=True)
        result = dict(zip(unique_vals.tolist(), counts.tolist()))
        try:
            utr_5_lengths.append(result[5])
        except KeyError:
            utr_5_lengths.append(0)
            count_seq_without_utr += 1

    print("# seq without utr:", count_seq_without_utr)
    
    front_pads = [max_utr_5_len - l for l in utr_5_lengths]
    
    padded_rna_data = [
        torch.cat([torch.zeros(pad, t.size(1), dtype=t.dtype), t], dim=0)
        for t, pad in zip(rna_data, front_pads)
    ]
    return padded_rna_data

In [11]:
padded_rna_data = aug_align_sequences(rna_data, max_utr_5_len)

# seq without utr: 32


In [12]:
# TODO Investigate further NEW max seq len
# from train: 11535
# from val: 12164
# from test: 13358

# when using max utr5 length of all datasets
# from train: 13637
# from val: 13412
# from test: 13358

max([len(t) for t in padded_rna_data])

13358

In [16]:
max([13637, 13412, 13358])

13637

In [13]:
idx = 0
print(rna_data[idx].shape)
print(front_pads[idx])
print(padded_rna_data[idx].shape)

torch.Size([2193, 4])
4595
torch.Size([6679, 4])


In [14]:
# Verify it worked as intended
aug_positions = np.array([
    (t[:, 1] == 1).nonzero(as_tuple=True)[0][0].item() if (t[:, 1] == 1).any() else -1
    for t in padded_rna_data
])
print(aug_positions[0])
# from train: 4255
# from val: 3447
# from test: 4695
print("All equal?", np.all(aug_positions == aug_positions[0]))

4695
All equal? True


In [15]:
# PROBLEM: When padding to the front, total max seq length will also increase! BY HOW MUCH?