In [1]:
import os
import pickle
from collections import defaultdict
from itertools import islice

import numpy as np
import torch
from Bio import SeqIO
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from transformers import AutoTokenizer

In [2]:
fn = '../DistributionEmbeddings/data/spikeprot0430/spikeprot0430.fasta'

# 10 log-spaced ratios from 1e-3 to 1e6; noise = 1 / (ratio + 1) maps them into
# (0, 1), so the largest ratio gives the smallest noise level.
ratios = np.logspace(-3, 6, 10)
noise_levels = 1 / (ratios+1)
print("Noise levels:", noise_levels)
print("("+", ".join([f"{nl:.6f}" for nl in noise_levels])+")")

# Replacement alphabet for mutate(): the 20 standard one-letter amino-acid codes.
aas = list("ACDEFGHIKLMNPQRSTVWY")

# Seed NumPy's global RNG so the random mutations (mutate() draws from the
# global RNG) are reproducible across runs; 42 matches the split's random_state.
SEED = 42
np.random.seed(SEED)

def _month_index(description):
    """Return months since January 2020 parsed from a ``|``-delimited header, or None.

    The date is taken from the third ``|`` field and must look like
    ``YYYY-MM-DD`` with a numeric year, a month in 1..12 and a day other than
    ``00``. January 2020 maps to 0; earlier months are negative.
    """
    fields = description.split("|")
    if len(fields) < 3:
        return None
    date = fields[2]
    if len(date) != 10 or date[4] != '-' or date[-2:] == '00':
        return None
    try:
        year = int(date[:4])
        month = int(date[5:7])
    except ValueError:
        # Non-numeric year/month (e.g. placeholder text) -> not a usable date.
        return None
    if not 1 <= month <= 12:
        return None
    return (year - 2020) * 12 + month - 1


def parse_data(path, lines_to_read=10**8, max_per_month=1000):
    """Collect FASTA sequences, keeping at most ``max_per_month`` per collection month.

    Reads up to ``lines_to_read`` FASTA *records* (not text lines) from
    ``path``. Records whose header has no usable ``YYYY-MM-DD`` date in the
    third ``|`` field are skipped. For each month, the first
    ``max_per_month`` records encountered in file order are kept. This is a
    first-come cap, not a random sample.

    If the parser fails partway through the file, a message is printed and
    the sequences collected so far are returned.

    Args:
        path: Path to the FASTA file.
        lines_to_read: Maximum number of FASTA records to read.
        max_per_month: Maximum number of sequences kept per month.

    Returns:
        ``(seqs, months)``:
        - ``seqs`` is a list of sequence strings.
        - ``months`` is an integer array of shape ``(len(seqs), 1)`` giving
          months since January 2020 for each sequence.
    """
    seqs_by_month = defaultdict(list)
    n_read = 0
    # Own the handle so the file is closed even when we stop before EOF.
    with open(path) as handle:
        records = islice(SeqIO.parse(handle, "fasta"), lines_to_read)
        try:
            for record in tqdm(records, desc='collecting sequences'):
                n_read += 1
                month = _month_index(record.description)
                if month is None:
                    continue
                month_seqs = seqs_by_month[month]
                if len(month_seqs) < max_per_month:
                    month_seqs.append(str(record.seq))
        except ValueError as err:
            # A malformed record or undecodable bytes end the parse. Keep what
            # was collected, but make the early stop visible instead of silent.
            print(f"Stopped reading after {n_read} records: {err!r}")

    # Flatten to parallel lists
    seqs, months = [], []
    for month_val, month_seqs in seqs_by_month.items():
        seqs.extend(month_seqs)
        months.extend([month_val] * len(month_seqs))

    print(f"Collected {len(seqs)} sequences across {len(seqs_by_month)} unique months")
    print(f"Months: {sorted(seqs_by_month.keys())}")

    # dtype=int keeps an integer dtype even when nothing was collected.
    return seqs, np.array(months, dtype=int).reshape(-1, 1)

def mutate(seq_list, rate, rng=None, alphabet=None, progress=True):
    """Randomly resample residues of each sequence with per-position probability ``rate``.

    Each position is selected independently with probability ``rate`` and
    replaced by a uniform draw from ``alphabet``. When the original residue
    is itself in ``alphabet``, the draw may return it unchanged, so the
    fraction of positions that actually change can be below ``rate``.

    Args:
        seq_list: Iterable of sequence strings.
        rate: Per-position resampling probability. 0 leaves sequences
            unchanged; 1 resamples every position.
        rng: Optional NumPy ``Generator``/``RandomState`` for reproducible
            noise. ``None`` uses NumPy's global RNG, drawing exactly what the
            original ``np.random.rand``/``np.random.choice`` calls drew.
        alphabet: Characters to sample replacements from (a string or a
            sequence of single characters). ``None`` uses the module-level
            ``aas`` list.
        progress: Wrap the iteration in a tqdm progress bar.

    Returns:
        List of mutated strings, in input order.
    """
    choices = list(aas if alphabet is None else alphabet)
    if rng is None:
        rng = np.random  # legacy global RNG (random == rand, same choice)
    items = tqdm(seq_list, desc=f'mutating (rate={rate:.6f})', leave=False) if progress else seq_list

    noised = []
    for s in items:
        # Explicit '<U1' keeps empty sequences as string arrays.
        s_arr = np.array(list(s), dtype='<U1')
        mask = rng.random(len(s_arr)) < rate
        s_arr[mask] = rng.choice(choices, size=mask.sum())
        noised.append("".join(s_arr))
    return noised

# Parse the FASTA file (capped per month), then hold out a test split.
print("Parsing data...")
all_seqs, all_months = parse_data(
    fn,
    lines_to_read=10**8,
    max_per_month=1000,
)
# Only test-set month labels are used later; train months are discarded into `_`.
tr_seqs, te_seqs, _, te_months = train_test_split(
    all_seqs,
    all_months,
    test_size=0.25,
    random_state=42,  # fixed split across runs
)

print(f"Train: {len(tr_seqs)}, Test: {len(te_seqs)}")
print(f"Unique months in test set: {np.unique(te_months)}")

# Tokenizer for the ESM-2 (t6, 8M parameters) protein language model
tok = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Prepare data for each noise level
print("\nTokenizing and saving data for each noise level...")

# torch.save does not create missing parent directories.
OUT_DIR = "seq"
os.makedirs(OUT_DIR, exist_ok=True)

# Files are keyed by the noise level rounded to 6 decimals. Fail early if two
# levels would round to the same name and silently overwrite each other.
save_paths = [f"{OUT_DIR}/data_noise_{nl:.6f}.pt" for nl in noise_levels]
assert len(set(save_paths)) == len(save_paths), "noise levels collide at 6-decimal precision"

for nl, save_path in tqdm(zip(noise_levels, save_paths), total=len(save_paths), desc="Noise levels"):
    print(f"\nProcessing noise level: {nl:.6f}")

    # Apply noise; train and test sets are mutated independently at the same rate.
    tr_n = mutate(tr_seqs, nl)
    te_n = mutate(te_seqs, nl)

    # Tokenize. padding=True pads each split to its own longest sequence, so
    # train and test tensors may have different widths.
    # NOTE(review): truncation=True with max_length=1024 drops anything past
    # 1024 tokens (special tokens count toward the limit); confirm that
    # full-length sequences fit or that truncation is intended.
    print("  Tokenizing train data...")
    train_enc = tok(tr_n, return_tensors="pt", padding=True, truncation=True, max_length=1024)
    print("  Tokenizing test data...")
    test_enc = tok(te_n, return_tensors="pt", padding=True, truncation=True, max_length=1024)

    # Save token tensors for both splits. Noised strings and month labels are
    # stored for the test split only.
    data_dict = {
        'train_input_ids': train_enc['input_ids'],
        'train_attention_mask': train_enc['attention_mask'],
        'test_input_ids': test_enc['input_ids'],
        'test_attention_mask': test_enc['attention_mask'],
        'test_seqs': te_n,
        'test_months': te_months,
        'noise_level': nl
    }

    torch.save(data_dict, save_path)
    print(f"  Saved to {save_path}")

print("\nData preprocessing complete!")

Noise levels: [9.99000999e-01 9.90099010e-01 9.09090909e-01 5.00000000e-01
 9.09090909e-02 9.90099010e-03 9.99000999e-04 9.99900010e-05
 9.99990000e-06 9.99999000e-07]
(0.999001, 0.990099, 0.909091, 0.500000, 0.090909, 0.009901, 0.000999, 0.000100, 0.000010, 0.000001)
Parsing data...


collecting sequences:   0%|          | 0/100000000 [00:00<?, ?it/s]

Collected 63374 sequences across 71 unique months
Months: [-109, -78, -29, -10, -7, -3, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
Train: 47530, Test: 15844
Unique months in test set: [-109  -78   -1    0    1    2    3    4    5    6    7    8    9   10
   11   12   13   14   15   16   17   18   19   20   21   22   23   24
   25   26   27   28   29   30   31   32   33   34   35   36   37   38
   39   40   41   42   43   44   45   46   47   48   49   50   51   52
   53   54   55   56   57   58   59   60   61   62   63]

Tokenizing and saving data for each noise level...


Noise levels:   0%|          | 0/10 [00:00<?, ?it/s]


Processing noise level: 0.999001


mutating (rate=0.999001):   0%|          | 0/47530 [00:00<?, ?it/s]

mutating (rate=0.999001):   0%|          | 0/15844 [00:00<?, ?it/s]

  Tokenizing train data...
  Tokenizing test data...
  Saved to seq/data_noise_0.999001.pt

Processing noise level: 0.990099


mutating (rate=0.990099):   0%|          | 0/47530 [00:00<?, ?it/s]

mutating (rate=0.990099):   0%|          | 0/15844 [00:00<?, ?it/s]

  Tokenizing train data...
  Tokenizing test data...
  Saved to seq/data_noise_0.990099.pt

Processing noise level: 0.909091


mutating (rate=0.909091):   0%|          | 0/47530 [00:00<?, ?it/s]

mutating (rate=0.909091):   0%|          | 0/15844 [00:00<?, ?it/s]

  Tokenizing train data...
  Tokenizing test data...
  Saved to seq/data_noise_0.909091.pt

Processing noise level: 0.500000


mutating (rate=0.500000):   0%|          | 0/47530 [00:00<?, ?it/s]

mutating (rate=0.500000):   0%|          | 0/15844 [00:00<?, ?it/s]

  Tokenizing train data...
  Tokenizing test data...
  Saved to seq/data_noise_0.500000.pt

Processing noise level: 0.090909


mutating (rate=0.090909):   0%|          | 0/47530 [00:00<?, ?it/s]

mutating (rate=0.090909):   0%|          | 0/15844 [00:00<?, ?it/s]

  Tokenizing train data...
  Tokenizing test data...
  Saved to seq/data_noise_0.090909.pt

Processing noise level: 0.009901


mutating (rate=0.009901):   0%|          | 0/47530 [00:00<?, ?it/s]

mutating (rate=0.009901):   0%|          | 0/15844 [00:00<?, ?it/s]

  Tokenizing train data...
  Tokenizing test data...
  Saved to seq/data_noise_0.009901.pt

Processing noise level: 0.000999


mutating (rate=0.000999):   0%|          | 0/47530 [00:00<?, ?it/s]

mutating (rate=0.000999):   0%|          | 0/15844 [00:00<?, ?it/s]

  Tokenizing train data...
  Tokenizing test data...
  Saved to seq/data_noise_0.000999.pt

Processing noise level: 0.000100


mutating (rate=0.000100):   0%|          | 0/47530 [00:00<?, ?it/s]

mutating (rate=0.000100):   0%|          | 0/15844 [00:00<?, ?it/s]

  Tokenizing train data...
  Tokenizing test data...
  Saved to seq/data_noise_0.000100.pt

Processing noise level: 0.000010


mutating (rate=0.000010):   0%|          | 0/47530 [00:00<?, ?it/s]

mutating (rate=0.000010):   0%|          | 0/15844 [00:00<?, ?it/s]

  Tokenizing train data...
  Tokenizing test data...
  Saved to seq/data_noise_0.000010.pt

Processing noise level: 0.000001


mutating (rate=0.000001):   0%|          | 0/47530 [00:00<?, ?it/s]

mutating (rate=0.000001):   0%|          | 0/15844 [00:00<?, ?it/s]

  Tokenizing train data...
  Tokenizing test data...
  Saved to seq/data_noise_0.000001.pt
\data preprocessing complete!
