In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import random
import numpy as np

In [2]:
#Tanish don't jump me. It is sample data. We need to replace with actual datasets
df = pd.read_csv('optimizedextra.csv')

In [3]:
#Tokenization? idk ts is lowkai hard
def codonize(seq):
    return [seq[i:i+3] for i in range(0, len(seq), 3) if len(seq[i:i+3]) == 3]
#define the codons???
all_codons = set()
for _, row in df.iterrows():
    all_codons.update(codonize(row['original']))
    all_codons.update(codonize(row['optimized']))

codon2idx = {codon: idx+4 for idx, codon in enumerate(sorted(all_codons))}
codon2idx['<pad>'] = 0
codon2idx['<sos>'] = 1
codon2idx['<eos>'] = 2
codon2idx['<unk>'] = 3
idx2codon = {idx: codon for codon, idx in codon2idx.items()}
vocab_size = len(codon2idx)

#encoder, decoder stuff
def encode(seq):
    return [codon2idx.get(c, codon2idx['<unk>']) for c in codonize(seq)]

def decode(indices):
    return ''.join([idx2codon[i] for i in indices if i not in (0, 1, 2)])


In [4]:
class GeneDataset(Dataset):
    def __init__(self, df):
        self.pairs = []
        for _, row in df.iterrows():
            src = [codon2idx['<sos>']] + encode(row['original']) + [codon2idx['<eos>']]
            trg = [codon2idx['<sos>']] + encode(row['optimized'])
            self.pairs.append((src, trg))

    def __len__(self):
        return len(self.pairs)
    def __getitem__(self, idx):
        return self.pairs[idx]
    
def collate_fn(batch):
    src_batch, trg_batch = zip(*batch)
    max_src = max(len(s) for s in src_batch)
    max_trg = max(len(t) for t in trg_batch)
    src_padded = [s + [codon2idx['<pad>']] * (max_src - len(s)) for s in src_batch]
    trg_padded = [t + [codon2idx['<pad>']] * (max_trg - len(t)) for t in trg_batch]
    return torch.tensor(src_padded), torch.tensor(trg_padded)

dataset = GeneDataset(df)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)


In [None]:
#gooning to seq2seq encoder and decoder
class Encoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=codon2idx['<pad>'])
        self.rnn = nn.LSTM(emb_dim, hid_dim, batch_first=True)

    def forward(self, src):
        embedded = self.embedding(src)
        outputs, (hidden, cell) = self.rnn(embedded)
        return hidden, cell

class Decoder(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=codon2idx['<pad>'])
        self.rnn = nn.LSTM(emb_dim, hid_dim, batch_first=True)
        self.fc_out = nn.Linear(hid_dim, vocab_size)

    def forward(self, input, hidden, cell):
        input = input.unsqueeze(1)
        embedded = self.embedding(input)
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        return self.fc_out(output.squeeze(1)), hidden, cell

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, trg, teacher_forcing_ratio=0.3):
        batch_size = src.size(0)
        trg_len = trg.size(1)
        vocab_size = self.decoder.fc_out.out_features
        outputs = torch.zeros(batch_size, trg_len, vocab_size)
        hidden, cell = self.encoder(src)
        input = trg[:, 0]
        for t in range(1, trg_len):
            output, hidden, cell = self.decoder(input, hidden, cell)
            outputs[:, t] = output
            top1 = output.argmax(1)
            input = trg[:, t] if random.random() < teacher_forcing_ratio else top1
        return outputs


In [6]:
#training za gooner
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
enc = Encoder(vocab_size, 32, 64).to(device)
dec = Decoder(vocab_size, 32, 64).to(device)
model = Seq2Seq(enc, dec).to(device)
optimzer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=codon2idx['<pad>'])

for epoch in range(30):
    model.train()
    total_loss = 0
    for src, trg in dataloader:
        src,trg = src.to(device), trg.to(device)
        optimzer.zero_grad()
        output = model(src, trg)
        output = output[:, 1:].reshape(-1, vocab_size)
        trg = trg[:, 1:].reshape(-1)
        loss = criterion(output, trg)
        loss.backward()
        optimzer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")



Epoch 1, Loss: 4.2137
Epoch 2, Loss: 4.1788
Epoch 3, Loss: 4.1305
Epoch 4, Loss: 4.0532
Epoch 5, Loss: 3.9424
Epoch 6, Loss: 3.7971
Epoch 7, Loss: 3.6879
Epoch 8, Loss: 3.6564
Epoch 9, Loss: 3.6138
Epoch 10, Loss: 3.5936
Epoch 11, Loss: 3.6054
Epoch 12, Loss: 3.5798
Epoch 13, Loss: 3.5908
Epoch 14, Loss: 3.5861
Epoch 15, Loss: 3.5834
Epoch 16, Loss: 3.5780
Epoch 17, Loss: 3.5789
Epoch 18, Loss: 3.5678
Epoch 19, Loss: 3.5789
Epoch 20, Loss: 3.5750
Epoch 21, Loss: 3.5585
Epoch 22, Loss: 3.5742
Epoch 23, Loss: 3.5757
Epoch 24, Loss: 3.5739
Epoch 25, Loss: 3.5641
Epoch 26, Loss: 3.5731
Epoch 27, Loss: 3.5602
Epoch 28, Loss: 3.5740
Epoch 29, Loss: 3.5688
Epoch 30, Loss: 3.5574


In [None]:
#Gooner prediction
def predict(model, seq):
    model.eval()
    with torch.no_grad():
        src = [codon2idx['<sos>']] + encode(seq) + [codon2idx['<eos>']]
        src_tensor = torch.tensor([src]).to(device)
        hidden, cell = model.encoder(src_tensor)
        input = torch.tensor([codon2idx['<sos>']]).to(device)
        output_seq = []
        for _ in range(30):
            output, hidden, cell = model.decoder(input, hidden, cell)
            top1 = output.argmax(1).item()
            if top1 == codon2idx['<eos>']:
                break
            output_seq.append(top1)
            input = torch.tensor([top1]).to(device)
        return decode(output_seq)

In [None]:
#input
user_input = input("Original: ")
print("Original: ", user_input)
print("Synthetic: ", predict(model, user_input))

Original:  ATGC
Synthetic:  ATGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTGCTG


In [11]:
# Step 1: Read the FASTA file
with open("trimmed_file.fasta", "r") as infile:
    lines = [line.strip() for line in infile if line.strip()]

# Step 2: Remove first 1000 sequences (i.e. 2000 lines)
trimmed_lines = lines[:-2883]

# Step 3: Write the remaining sequences to a new FASTA file
with open("trimmed_file.fasta", "w") as outfile:
    for line in trimmed_lines:
        outfile.write(line + "\n")
