In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import random
import numpy as np

In [8]:
# Placeholder sample data -- replace with the actual datasets.
# Note: two of the 'optimized' strings are 8 nt long, so codonize() discards
# their trailing two bases (it keeps complete 3-letter codons only).
data = dict(
    original=['ATGGCC', 'ATGCGT', 'ATGAAA'],
    optimized=['ATGGCAGC', 'ATGCAG', 'ATGAAAGC'],
)
df = pd.DataFrame(data)

In [9]:
# Tokenization: split a nucleotide sequence into 3-letter codons.
def codonize(seq):
    """Split ``seq`` into consecutive, non-overlapping chunks of length 3.

    A trailing remainder of one or two items is dropped, so only complete
    codons are returned. An input shorter than 3 yields an empty list.
    """
    # Stop at the last index that still starts a full triplet.
    complete_len = len(seq) - len(seq) % 3
    codons = []
    for start in range(0, complete_len, 3):
        codons.append(seq[start:start + 3])
    return codons
# Build the codon vocabulary from every sequence in both columns of df.
all_codons = {
    codon
    for column in ('original', 'optimized')
    for sequence in df[column]
    for codon in codonize(sequence)
}

# Codons get ids 4, 5, ... in sorted order; ids 0-3 are reserved for the
# special tokens added right after.
codon2idx = dict(zip(sorted(all_codons), range(4, 4 + len(all_codons))))
codon2idx.update({'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3})

# Reverse lookup (id -> token) and total vocabulary size, specials included.
idx2codon = {index: token for token, index in codon2idx.items()}
vocab_size = len(codon2idx)

# Encoding / decoding between nucleotide strings and vocabulary ids.
def encode(seq):
    """Turn a nucleotide string into a list of vocabulary ids.

    The sequence is split with ``codonize``; each codon is looked up in
    ``codon2idx`` and codons missing from the vocabulary map to the
    ``'<unk>'`` id. No ``<sos>``/``<eos>`` markers are added here.
    """
    unknown_id = codon2idx['<unk>']
    token_ids = []
    for codon in codonize(seq):
        token_ids.append(codon2idx.get(codon, unknown_id))
    return token_ids

def decode(indices):
    """Convert vocabulary ids back into a nucleotide string.

    Args:
        indices: Iterable of token ids. Elements may be plain ints or
            integer scalars such as the 0-dim tensors produced by iterating
            a 1-D ``torch`` tensor.

    Returns:
        The decoded tokens joined into one string. Ids 0, 1 and 2
        (``<pad>``, ``<sos>``, ``<eos>``) are skipped; id 3 is not skipped,
        so an unknown token appears as the literal text ``'<unk>'``.

    Raises:
        KeyError: If an id is not present in ``idx2codon``.
    """
    skipped_ids = (0, 1, 2)  # <pad>, <sos>, <eos>
    # Normalise every element to a built-in int first: a torch tensor hashes
    # by object identity, so using one directly as a dict key would raise
    # KeyError even when its value is a valid vocabulary id.
    token_ids = (int(raw_id) for raw_id in indices)
    return ''.join(idx2codon[token_id] for token_id in token_ids if token_id not in skipped_ids)


In [10]:
class GeneDataset(Dataset):
    """Paired (original, optimized) sequences encoded as vocabulary ids.

    Each item is a ``(src, trg)`` tuple of Python lists. Both lists start
    with the ``<sos>`` id and end with the ``<eos>`` id; padding is not
    applied here.
    """

    def __init__(self, df):
        """Encode every row of ``df`` once, up front.

        Args:
            df: DataFrame with string columns ``'original'`` and
                ``'optimized'``.
        """
        sos_id = codon2idx['<sos>']
        eos_id = codon2idx['<eos>']
        self.pairs = []
        for original, optimized in zip(df['original'], df['optimized']):
            src = [sos_id] + encode(original) + [eos_id]
            # Terminate the target with <eos> too, matching the source.
            # Without it the end marker never appears in any target sequence.
            trg = [sos_id] + encode(optimized) + [eos_id]
            self.pairs.append((src, trg))

    def __len__(self):
        """Return the number of sequence pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the ``(src, trg)`` id lists stored at position ``idx``."""
        return self.pairs[idx]
    
def collate_fn(batch):
    """Pad a batch of ``(src, trg)`` id lists into two rectangular tensors.

    Sources are right-padded with the ``'<pad>'`` id to the longest source
    in the batch, and targets to the longest target. Returns
    ``(src_tensor, trg_tensor)``, each built with ``torch.tensor`` from the
    padded lists.
    """
    pad_id = codon2idx['<pad>']

    def pad_to_longest(sequences):
        # Right-pad every sequence up to the longest one in this group.
        width = max(map(len, sequences))
        return [seq + [pad_id] * (width - len(seq)) for seq in sequences]

    sources, targets = zip(*batch)
    padded_sources = pad_to_longest(sources)
    padded_targets = pad_to_longest(targets)
    return torch.tensor(padded_sources), torch.tensor(padded_targets)

# Encode the frame once, then serve shuffled mini-batches of two pairs,
# padded per batch by collate_fn.
dataset = GeneDataset(df)
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
)
