In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from collections import Counter, defaultdict
from rich import print
from sklearn.metrics import precision_score, recall_score, mean_absolute_error


In [2]:

class BPE_Tokenization():
    """Train a byte-pair-encoding style tokenizer on siRNA sequences and encode them.

    The constructor reads three columns from ``data``:
    ``'siRNA_antisense_seq'`` (split into upper-cased characters),
    ``'modified_siRNA_antisense_seq_list'`` (split on whitespace) and
    ``'mRNA_remaining_pct'`` (regression target).

    Both columns share one vocabulary in which id 0 is ``'<pad>'``. Pair merges
    are learned from, and applied to, the character-level column only.
    """

    def __init__(self, data, vocab_size):
        """Build the vocabulary, learn the merges and pad the encoded sequences.

        Args:
            data: DataFrame-like object providing the three columns listed in
                the class docstring.
            vocab_size: Merging stops once ``vocab_stoi`` holds this many
                entries, or earlier when no adjacent pair is left to merge.
        """
        self.data = data
        # tokenized[0]: character-level ids, tokenized[1]: modified-sequence ids.
        # Lists of lists while training; replaced by one int tensor at the end.
        self.tokenized = [[], []]
        # .to_numpy() gives torch a plain positional array, so the result does not
        # depend on the Series index (which has gaps after dropna()).
        self.target = torch.tensor(self.data['mRNA_remaining_pct'].to_numpy(), dtype=torch.float)
        self.vocab_size = vocab_size
        self.vocab_itos = ['<pad>']
        self.vocab_stoi = {'<pad>': 0}

        # Base vocabulary: every distinct character of the plain sequences...
        for seq in self.data['siRNA_antisense_seq']:
            self.tokenized[0].append([self._token_id(letter) for letter in seq.upper()])

        # ...and every distinct whitespace-separated token of the modified sequences.
        for seq in self.data['modified_siRNA_antisense_seq_list']:
            self.tokenized[1].append([self._token_id(token) for token in seq.split()])

        # Merge the most frequent pair iteratively until the vocabulary size is reached.
        while len(self.vocab_stoi) < self.vocab_size:
            most_common_pair = self.compute_pair_freqs()
            if most_common_pair is None:
                # Every sequence is down to a single token: nothing left to merge.
                break
            new_token = len(self.vocab_itos)
            most_common_pair_text = self.vocab_itos[most_common_pair[0]] + self.vocab_itos[most_common_pair[1]]
            self.vocab_itos.append(most_common_pair_text)
            self.vocab_stoi[most_common_pair_text] = new_token
            self.merge_pair(most_common_pair, new_token)

        # Right-pad every sequence of both columns with id 0 up to one common length.
        self.max_len = max(
            (len(seq) for column in self.tokenized for seq in column),
            default=0,
        )
        for column in self.tokenized:
            for seq in column:
                seq.extend([0] * (self.max_len - len(seq)))
        self.tokenized = torch.tensor(self.tokenized, dtype=torch.int)

    def _token_id(self, symbol):
        """Return the id of ``symbol``, registering it in the vocabulary if new."""
        if symbol not in self.vocab_stoi:
            self.vocab_itos.append(symbol)
            self.vocab_stoi[symbol] = len(self.vocab_stoi)
        return self.vocab_stoi[symbol]

    def compute_pair_freqs(self):
        """Return the most frequent adjacent id pair in the character-level column.

        Returns:
            A ``(left_id, right_id)`` tuple, or ``None`` when no sequence has two
            or more tokens left.
        """
        pair_freqs = Counter()
        for seq in self.tokenized[0]:
            pair_freqs.update(zip(seq, seq[1:]))
        if not pair_freqs:
            return None
        return pair_freqs.most_common(1)[0][0]

    def merge_pair(self, most_common_pair, new_token):
        """Replace each occurrence of ``most_common_pair`` with ``new_token`` in place.

        Sequences are scanned left to right; a merged position is not re-examined,
        so overlapping occurrences are merged greedily from the left.
        """
        for seq in self.tokenized[0]:
            j = 0
            while j < len(seq) - 1:
                if (seq[j], seq[j + 1]) == most_common_pair:
                    # Drop the left element and overwrite the right one with the merged id.
                    seq.pop(j)
                    seq[j] = new_token
                j += 1

    def get_dataframe(self):
        """Return the padded encodings and targets as a DataFrame.

        Each cell of the two ``*_encoded`` columns holds one padded id list.
        The 2-D tensors are converted with ``tolist()`` because a 2-D tensor is
        not a valid DataFrame column.
        """
        retval = pd.DataFrame(
            {
                'siRNA_antisense_seq_encoded': self.tokenized[0].tolist(),
                'modified_siRNA_antisense_seq_list_encoded': self.tokenized[1].tolist(),
                'mRNA_remaining_pct': self.target.tolist()
            }
        )
        return retval


class SiRNADataset(Dataset):
    """Map-style dataset pairing two parallel encoded-sequence collections with a target.

    Item ``idx`` is ``([seq_0[idx], seq_1[idx]], target[idx])``.
    """

    def __init__(self, df_tokenized_0, df_tokenized_1, df_target):
        """Store the two indexable sequence collections and the indexable targets."""
        self.df_tokenized_0, self.df_tokenized_1 = df_tokenized_0, df_tokenized_1
        self.df_target = df_target

    def __len__(self):
        """Number of samples, taken from the target collection."""
        return len(self.df_target)

    def __getitem__(self, idx):
        """Return ``([first_seq, second_seq], target)`` for position ``idx``."""
        pair = [column[idx] for column in (self.df_tokenized_0, self.df_tokenized_1)]
        label = self.df_target[idx]
        return pair, label
    


class SiRNAModel(nn.Module):
    """Regressor over several token-id sequences using a shared embedding and GRU.

    Each input sequence is embedded and passed through the same bidirectional
    multi-layer GRU. The GRU outputs at the last time step are concatenated and
    mapped to one scalar per sample by a linear layer.
    """

    def __init__(self, vocab_size, embed_dim=200, hidden_dim=256, n_layers=3, dropout=0.5):
        """
        Args:
            vocab_size: Number of rows in the embedding table (id 0 is padding).
            embed_dim: Embedding width.
            hidden_dim: GRU hidden size per direction.
            n_layers: Number of stacked GRU layers.
            dropout: Dropout probability, used between GRU layers and on the
                extracted features.
        """
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.gru = nn.GRU(embed_dim, hidden_dim, n_layers, bidirectional=True, batch_first=True, dropout=dropout)
        # hidden_dim * 2 (bidirectional) * 2 (two feature columns)
        self.fc = nn.Linear(hidden_dim * 4, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Predict one value per sample.

        Args:
            x: Iterable of integer id tensors, one per feature column, each
                indexed as (batch, time). ``self.fc`` is sized for exactly two
                columns.

        Returns:
            Tensor of shape ``(batch,)``, including when ``batch == 1``.
        """
        features = []
        for seq in x:
            gru_out, _ = self.gru(self.embedding(seq))
            # NOTE(review): this reads the output at the final time step. For
            # right-padded batches that position is padding; consider packing the
            # sequences or gathering the last valid step instead.
            features.append(self.dropout(gru_out[:, -1, :]))

        combined = torch.cat(features, dim=1)
        # squeeze(-1) removes only the feature axis. A bare squeeze() would also
        # drop the batch axis for a batch of one and return a 0-d tensor.
        return self.fc(combined).squeeze(-1)


def calculate_metrics(y_true, y_pred, threshold=30):
    """Combine overall MAE with a low-range MAE weighted by F1 into one score.

    Values below ``threshold`` form the positive class. The score is::

        0.5 * (1 - mae / 100) + 0.5 * (1 - range_mae / 100) * f1

    where ``range_mae`` is the MAE over samples predicted inside
    ``[0, threshold]`` (100 when there are none) and ``f1`` comes from the
    binarised targets and predictions.

    Args:
        y_true: Array-like of true values.
        y_pred: Array-like of predicted values, same length as ``y_true``.
        threshold: Cut-off separating the positive (below) class.

    Returns:
        The combined score as a float.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    mae = np.mean(np.abs(y_true - y_pred))

    y_true_binary = (y_true < threshold).astype(int)
    y_pred_binary = (y_pred < threshold).astype(int)

    mask = (y_pred >= 0) & (y_pred <= threshold)
    range_mae = mean_absolute_error(y_true[mask], y_pred[mask]) if mask.sum() > 0 else 100

    # zero_division=0 keeps the "no positives -> 0" result without emitting a
    # warning every epoch.
    precision = precision_score(y_true_binary, y_pred_binary, average='binary', zero_division=0)
    recall = recall_score(y_true_binary, y_pred_binary, average='binary', zero_division=0)

    # Define F1 as 0 when both precision and recall are 0. Dividing would give
    # NaN (or raise), and a NaN score never compares greater than any best score.
    denominator = precision + recall
    f1 = 2 * precision * recall / denominator if denominator > 0 else 0.0

    score = (1 - mae / 100) * 0.5 + (1 - range_mae / 100) * f1 * 0.5
    return score



def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=50, device='cuda'):
    """Train ``model`` and return the weights of the best-scoring epoch.

    After every epoch the model is evaluated on ``val_loader`` and scored with
    ``calculate_metrics``; the state dict of the highest-scoring epoch is kept.

    Args:
        model: Module called as ``model(inputs)`` with a list of tensors.
        train_loader: Iterable of ``(inputs, targets)`` batches with ``len()``.
        val_loader: Iterable of ``(inputs, targets)`` batches with ``len()``.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the model and batches are moved to.

    Returns:
        An independent copy of the best epoch's ``state_dict``, or ``None`` if
        no epoch produced a score greater than ``-inf``.
    """
    import copy  # local import: only needed for the best-weights snapshot

    model.to(device)
    best_score = -float('inf')
    best_model = None

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0
        for inputs, targets in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            inputs = [x.to(device) for x in inputs]
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # ---- validation pass ----
        model.eval()
        val_loss = 0
        val_preds = []
        val_targets = []

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = [x.to(device) for x in inputs]
                targets = targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()
                # reshape(-1) keeps extend() working if a batch yields a 0-d output.
                val_preds.extend(outputs.cpu().numpy().reshape(-1))
                val_targets.extend(targets.cpu().numpy().reshape(-1))

        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        val_preds = np.array(val_preds)
        val_targets = np.array(val_targets)
        score = calculate_metrics(val_targets, val_preds)

        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
        print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
        print(f'Validation Score: {score:.4f}')

        if score > best_score:
            best_score = score
            # state_dict() returns references to the live parameter tensors, and
            # dict.copy() is shallow, so later optimizer steps would overwrite the
            # "best" weights. deepcopy takes a real snapshot.
            best_model = copy.deepcopy(model.state_dict())
            print(f'New best model found with socre: {best_score:.4f}')

    return best_model

def evaluate_model(model, test_loader, device='cuda'):
    """Score ``model`` on ``test_loader`` with ``calculate_metrics``.

    Args:
        model: Module called as ``model(inputs)`` with a list of tensors.
        test_loader: Iterable of ``(inputs, target)`` batches.
        device: Device the model and the input tensors are moved to.

    Returns:
        The score, which is also printed.
    """
    # Inputs are moved to `device` below, so the model must be there too
    # (train_model does the same before using it).
    model.to(device)
    model.eval()
    predictions = []
    targets = []

    with torch.no_grad():
        for inputs, target in test_loader:
            inputs = [x.to(device) for x in inputs]
            outputs = model(inputs)
            # reshape(-1) keeps extend() working if a batch yields a 0-d output.
            predictions.extend(outputs.cpu().numpy().reshape(-1))
            targets.extend(target.cpu().numpy().reshape(-1))

    y_pred = np.array(predictions)
    y_test = np.array(targets)

    score = calculate_metrics(y_test, y_pred)
    print(f"Test Score: {score:.4f}")
    return score




In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# One seed for the train/validation split and for torch, so weight
# initialisation and batch shuffling repeat on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Prepare dataset
columns = ['siRNA_antisense_seq', 'modified_siRNA_antisense_seq_list']
df_raw = pd.read_csv('train_data.csv')
# Select the needed columns and drop incomplete rows without mutating a slice
# in place. The index is reset so positions and labels agree downstream.
df = df_raw[columns + ['mRNA_remaining_pct']].dropna().reset_index(drop=True)

df_tokenized = BPE_Tokenization(data=df, vocab_size=100)
# Sanity check: both vocabulary views must have the same number of entries.
if len(df_tokenized.vocab_itos) == len(df_tokenized.vocab_stoi):
    print(len(df_tokenized.vocab_itos))
else:
    print('Error')

# Hold out 10% of the samples for validation.
tokenized_0_train, tokenized_0_val, tokenized_1_train, tokenized_1_val, target_train, target_val = train_test_split(
    df_tokenized.tokenized[0], df_tokenized.tokenized[1], df_tokenized.target,
    test_size=0.1,
    random_state=SEED
)

train_dataset = SiRNADataset(tokenized_0_train, tokenized_1_train, target_train)
val_dataset = SiRNADataset(tokenized_0_val, tokenized_1_val, target_val)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# Initialize model
model = SiRNAModel(len(df_tokenized.vocab_itos))
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters())


In [None]:
# Train for 50 epochs; `best_model` receives the state dict returned by train_model.
best_model = train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    num_epochs=50,
    device=device,
)


Epoch 1/50: 100%|█████████████████████████████| 726/726 [03:57<00:00,  3.06it/s]


Epoch 2/50: 100%|█████████████████████████████| 726/726 [03:49<00:00,  3.16it/s]


Epoch 3/50: 100%|█████████████████████████████| 726/726 [03:37<00:00,  3.34it/s]


Epoch 4/50: 100%|█████████████████████████████| 726/726 [03:35<00:00,  3.36it/s]


Epoch 5/50: 100%|█████████████████████████████| 726/726 [03:38<00:00,  3.32it/s]


Epoch 6/50: 100%|█████████████████████████████| 726/726 [03:39<00:00,  3.31it/s]


Epoch 7/50: 100%|█████████████████████████████| 726/726 [03:42<00:00,  3.26it/s]


Epoch 8/50: 100%|█████████████████████████████| 726/726 [03:53<00:00,  3.10it/s]


Epoch 9/50: 100%|█████████████████████████████| 726/726 [04:03<00:00,  2.98it/s]


Epoch 10/50: 100%|████████████████████████████| 726/726 [04:13<00:00,  2.86it/s]


Epoch 11/50: 100%|████████████████████████████| 726/726 [03:57<00:00,  3.06it/s]


Epoch 12/50: 100%|████████████████████████████| 726/726 [03:56<00:00,  3.07it/s]


Epoch 13/50: 100%|████████████████████████████| 726/726 [04:02<00:00,  3.00it/s]


Epoch 14/50: 100%|████████████████████████████| 726/726 [03:25<00:00,  3.54it/s]


Epoch 15/50: 100%|████████████████████████████| 726/726 [03:55<00:00,  3.09it/s]


Epoch 16/50: 100%|████████████████████████████| 726/726 [03:48<00:00,  3.18it/s]


Epoch 17/50: 100%|████████████████████████████| 726/726 [03:54<00:00,  3.10it/s]


Epoch 18/50: 100%|████████████████████████████| 726/726 [03:47<00:00,  3.19it/s]


Epoch 19/50: 100%|████████████████████████████| 726/726 [03:28<00:00,  3.49it/s]


Epoch 20/50: 100%|████████████████████████████| 726/726 [03:19<00:00,  3.64it/s]


Epoch 21/50: 100%|████████████████████████████| 726/726 [03:23<00:00,  3.56it/s]


Epoch 22/50: 100%|████████████████████████████| 726/726 [04:01<00:00,  3.00it/s]


Epoch 23/50: 100%|████████████████████████████| 726/726 [03:33<00:00,  3.41it/s]


Epoch 24/50: 100%|████████████████████████████| 726/726 [03:35<00:00,  3.37it/s]


Epoch 25/50: 100%|████████████████████████████| 726/726 [03:42<00:00,  3.26it/s]


Epoch 26/50: 100%|████████████████████████████| 726/726 [03:43<00:00,  3.25it/s]


Epoch 27/50: 100%|████████████████████████████| 726/726 [03:40<00:00,  3.29it/s]


Epoch 28/50: 100%|████████████████████████████| 726/726 [03:40<00:00,  3.29it/s]


Epoch 29/50: 100%|████████████████████████████| 726/726 [03:37<00:00,  3.34it/s]


Epoch 30/50: 100%|████████████████████████████| 726/726 [03:54<00:00,  3.09it/s]


Epoch 31/50: 100%|████████████████████████████| 726/726 [03:51<00:00,  3.13it/s]


Epoch 32/50: 100%|████████████████████████████| 726/726 [04:13<00:00,  2.86it/s]


Epoch 33/50:  51%|██████████████▎             | 370/726 [02:03<02:04,  2.86it/s]

In [None]:
# Scratch/debug cell: take the first batch from train_loader, move it to
# `device`, run one forward pass and stop. No loss is computed and no optimizer
# step is taken. `inputs`, `targets` and `outputs` stay in the notebook
# namespace for inspection.
for inputs, targets in train_loader:
    inputs = [x.to(device) for x in inputs]
    targets = targets.to(device)
            
    # Clears any gradients currently stored on the optimizer's parameters.
    optimizer.zero_grad()
    outputs = model(inputs)
    # Only the first batch is needed.
    break


In [None]:
# Show the shape of `targets` (presumably the batch left over from the
# one-batch probe cell; it is whatever the kernel last assigned to the name).
targets.shape