In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import get_dataloaders, MAX_SEQ_LENGTH, vocab_size
import time

# When True, train_model() writes the best checkpoint and the loss history
# (loss.json) under ./model/<time_stamp>/ after training finishes.
save_model = True

In [None]:
class RNAPairTransformer(nn.Module):
    """Encoder-decoder Transformer conditioned on a per-sequence feature vector.

    The source token ids are embedded, a linear projection of ``src_features``
    is broadcast along the source length and concatenated to every source
    position, and the result is projected back to ``hidden_dim`` before being
    fed to ``nn.Transformer``. The decoder output is projected to
    ``output_dim`` logits per target position.

    Args:
        input_dim: Number of rows of the shared token embedding table.
        hidden_dim: Model width (``d_model``); must be divisible by ``nhead``.
        output_dim: Number of logits produced per target position.
        feature_dim: Size of the feature vector passed as ``src_features``.
        num_layers: Number of encoder layers and of decoder layers.
        nhead: Number of attention heads.
        device: Stored as ``self.device`` for backward compatibility only; the
            positional table is a buffer and follows ``model.to(...)``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, feature_dim, num_layers=2, nhead=8, device='cpu'):
        super().__init__()

        self.input_dim = input_dim      # size of the token embedding table
        self.hidden_dim = hidden_dim    # model width (d_model)
        self.output_dim = output_dim    # number of output logits per position
        self.feature_dim = feature_dim  # size of the conditioning feature vector
        self.num_layers = num_layers    # encoder depth == decoder depth
        self.device = device

        # Token embedding is shared by the source and target sequences.
        self.embedding = nn.Embedding(input_dim, hidden_dim)
        # Projects the feature vector to the model width.
        self.feature_embedding = nn.Linear(feature_dim, hidden_dim)
        # Fuses [token embedding ; feature embedding] back to the model width.
        self.concat_projection = nn.Linear(hidden_dim * 2, hidden_dim)
        # Registered as a buffer so it moves with model.to(device).
        # persistent=False keeps it out of state_dict, so checkpoints saved
        # before this change still load with strict=True.
        self.register_buffer(
            'positional_encoding',
            self._generate_positional_encoding(MAX_SEQ_LENGTH, hidden_dim),
            persistent=False,
        )

        # Transformer encoder-decoder (default dropout of nn.Transformer is used).
        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_dim * 4,
            batch_first=True,   # tensors are (batch, seq, hidden)
            norm_first=True,    # pre-LayerNorm variant
        )

        # Projects decoder states to output logits.
        self.fc = nn.Linear(hidden_dim, output_dim)

    def _generate_positional_encoding(self, seq_length, hidden_dim):
        """Return a sinusoidal position table of shape (1, seq_length, hidden_dim).

        Even columns hold sines and odd columns hold cosines. Odd values of
        ``hidden_dim`` are supported (the last cosine column is dropped).
        """
        position = torch.arange(0, seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float() * -(torch.log(torch.tensor(10000.0)) / hidden_dim)
        )
        angles = position * div_term  # (seq_length, ceil(hidden_dim / 2))
        positional_encoding = torch.zeros(seq_length, hidden_dim)
        positional_encoding[:, 0::2] = torch.sin(angles)
        # For odd hidden_dim there is one fewer odd column than even columns.
        positional_encoding[:, 1::2] = torch.cos(angles[:, : hidden_dim // 2])
        return positional_encoding.unsqueeze(0)

    def _generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) float causal mask: 0.0 on/below the diagonal, -inf above."""
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def forward(self, src, tgt, src_features):
        """Compute logits for every target position (teacher forcing).

        Args:
            src: Source token ids, (batch, src_len).
            tgt: Decoder input token ids, (batch, tgt_len).
            src_features: Feature vectors, (batch, feature_dim).

        Returns:
            Tensor of logits with shape (batch, tgt_len, output_dim).

        Both ``src_len`` and ``tgt_len`` must not exceed the length of the
        positional table (MAX_SEQ_LENGTH).
        """
        # Causal mask so position i cannot attend to positions > i.
        # Still stored on the instance because the original exposed it.
        self.tgt_mask = self._generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

        src_tokens = self.embedding(src)
        tgt_tokens = self.embedding(tgt)
        # .to() is a no-op when the buffer already lives on the right device.
        src_emb = src_tokens + self.positional_encoding[:, : src.size(1), :].to(src_tokens.device)
        tgt_emb = tgt_tokens + self.positional_encoding[:, : tgt.size(1), :].to(tgt_tokens.device)

        # Broadcast the feature embedding to every source position and fuse it.
        src_feat_emb = self.feature_embedding(src_features).unsqueeze(1).expand(-1, src_emb.size(1), -1)
        src_emb = self.concat_projection(torch.cat([src_emb, src_feat_emb], dim=-1))

        transformer_output = self.transformer(src_emb, tgt_emb, tgt_mask=self.tgt_mask)
        return self.fc(transformer_output)


def train_model(model, train_loader, criterion, optimizer, num_epochs, device, time_stamp):
    """Train ``model`` with teacher forcing and keep the weights with the best dev loss.

    Every 10th epoch the model is evaluated with ``evaluate_model`` on the
    module-level ``dev_loader``; the weights with the lowest dev loss are
    snapshotted. When the module-level ``save_model`` flag is true, the
    snapshot and the loss history are written under ``./model/<time_stamp>/``.

    Args:
        model: Module called as ``model(src, tgt_input, features)``.
        train_loader: Iterable of ``(seq1, feature1, seq2, _)`` batches, where
            ``feature1`` is a sequence of tensors stacked along dim 1.
        criterion: Loss called on ``(N, C)`` logits and ``(N,)`` targets.
        optimizer: Optimizer stepping ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to.
        time_stamp: Sub-directory name under ``./model/`` for saved artifacts.

    Returns:
        The best dev loss seen (``inf`` if no evaluation was run, i.e.
        ``num_epochs < 10``).
    """
    import copy
    import json
    import os

    best_state = None
    best_dev_loss = float('inf')
    best_train_loss = 0
    best_epoch = 0

    loss_arr = []

    for epoch in range(num_epochs):
        model.train()  # evaluate_model() leaves the model in eval mode
        total_loss = 0
        for seq1, feature1, seq2, _ in train_loader:
            seq1, seq2 = seq1.to(device), seq2.to(device)
            feature1 = torch.stack(feature1, dim=1)
            feature1 = feature1.to(device).float()

            # Shift the target: the decoder sees tokens [0, n-1) and predicts [1, n).
            tgt_input = seq2[:, :-1]
            tgt_output = seq2[:, 1:]

            outputs = model(seq1, tgt_input, feature1)
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1).long())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        total_loss /= len(train_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}')
        loss_arr.append(total_loss)

        # Evaluate on the dev set every 10 epochs.
        if (epoch + 1) % 10 == 0:
            dev_loss = evaluate_model(model, dev_loader, criterion, device)
            if dev_loss < best_dev_loss:
                best_dev_loss = dev_loss
                # Deep-copy the weights: keeping a reference to `model` would
                # silently track later updates and end up saving the final
                # weights instead of the best ones.
                best_state = copy.deepcopy(model.state_dict())
                best_train_loss = total_loss
                best_epoch = epoch  # 0-based epoch index

    # Save best model and loss history.
    if save_model:
        if best_state is None:
            # No evaluation happened (fewer than 10 epochs): fall back to the
            # final weights rather than crashing on a missing snapshot.
            best_state = model.state_dict()

        out_dir = './model/' + time_stamp
        os.makedirs(out_dir, exist_ok=True)
        torch.save(best_state, out_dir + '/transformer_model_best.pth')

        with open(out_dir + '/loss.json', 'w') as f:
            json.dump({'loss': loss_arr,
                       'best_epoch': best_epoch,
                       'best_dev_loss': best_dev_loss,
                       'best_train_loss': best_train_loss}, f, indent=4)

    return best_dev_loss

def evaluate_model(model, dev_loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``dev_loader``.

    The model is switched to eval mode and run without gradient tracking.
    Each batch is ``(seq1, feature1, seq2, _)``; the target sequence is
    shifted by one so the decoder input is ``seq2[:, :-1]`` and the expected
    output is ``seq2[:, 1:]``. The mean loss is also printed.
    """
    model.eval()
    running_loss = 0
    with torch.no_grad():
        for source, features, target, _unused in dev_loader:
            source = source.to(device)
            target = target.to(device)
            features = torch.stack(features, dim=1).to(device).float()

            # Teacher forcing: feed all but the last token, predict all but the first.
            decoder_in = target[:, :-1]
            expected = target[:, 1:]

            logits = model(source, decoder_in, features)
            flat_logits = logits.reshape(-1, logits.size(-1))
            flat_expected = expected.reshape(-1).long()
            running_loss += criterion(flat_logits, flat_expected).item()

    mean_loss = running_loss / len(dev_loader)
    print(f'Dev Loss: {mean_loss:.4f}')
    return mean_loss


In [None]:
# Autoregressive decoding helpers: greedy, temperature sampling and top-k sampling.
def generate_sequence(model, src, feature1, start_token, max_len, device):
    """Greedily decode a target sequence for every row of ``src``.

    The source is encoded once; the decoder is then run step by step with the
    same causal mask that is used during training, and the highest-scoring
    token is appended at each step.

    Args:
        model: Object exposing ``embedding``, ``positional_encoding``,
            ``feature_embedding``, ``concat_projection``, ``transformer``,
            ``fc`` and ``_generate_square_subsequent_mask``.
        src: Source token ids, (batch, src_len).
        feature1: Feature vectors, (batch, feature_dim).
        start_token: Token id used as the first decoder input.
        max_len: Length of the returned sequences, including the start token.
        device: Device on which decoding runs.

    Returns:
        Long tensor of shape (batch, max_len) whose first column is
        ``start_token``. Decoding does not stop early at an end token.
    """
    model.eval()
    src = src.to(device)
    feature1 = feature1.to(device)
    batch_size = src.size(0)

    # Every sequence starts with the start token.
    tgt = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)

    with torch.no_grad():
        # The encoder input does not depend on the decoding step: encode once.
        src_emb = model.embedding(src) + model.positional_encoding[:, : src.size(1), :].to(device)
        src_feat_emb = model.feature_embedding(feature1).unsqueeze(1).expand(-1, src_emb.size(1), -1)
        src_emb = model.concat_projection(torch.cat([src_emb, src_feat_emb], dim=-1))
        memory = model.transformer.encoder(src_emb)

        for _ in range(max_len - 1):
            tgt_emb = model.embedding(tgt) + model.positional_encoding[:, : tgt.size(1), :].to(device)
            # Without the causal mask earlier positions would attend to later
            # ones, which the model never saw during training.
            tgt_mask = model._generate_square_subsequent_mask(tgt.size(1)).to(device)
            decoded = model.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
            logits = model.fc(decoded[:, -1, :])

            # Greedy choice: most likely next token for every row.
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)

    return tgt

def greedy_decode_with_temperature(model, src, feature1, start_token, max_len,device, temperature=1.0):
    """Decode by sampling each next token from the temperature-scaled softmax.

    Despite the name, tokens are sampled with ``torch.multinomial`` rather
    than chosen by argmax. The source is encoded once and the decoder is run
    step by step with a causal mask. Works for any batch size.

    Args:
        model: Object exposing ``embedding``, ``positional_encoding``,
            ``feature_embedding``, ``concat_projection``, ``transformer``,
            ``fc`` and ``_generate_square_subsequent_mask``.
        src: Source token ids, (batch, src_len).
        feature1: Feature vectors, (batch, feature_dim).
        start_token: Token id used as the first decoder input.
        max_len: Number of tokens to sample; must not exceed the length of
            ``model.positional_encoding``.
        device: Device on which decoding runs.
        temperature: Positive softmax temperature; values above 1 flatten the
            distribution, values below 1 sharpen it.

    Returns:
        Long tensor of shape (batch, max_len + 1) whose first column is
        ``start_token``.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    src = src.to(device)
    feature1 = feature1.to(device)
    model.eval()
    with torch.no_grad():
        # One start token per source row (the original supported one row only).
        tgt = torch.full((src.size(0), 1), start_token, dtype=torch.long, device=device)

        src_emb = model.embedding(src) + model.positional_encoding[:, :src.size(1), :].to(device)
        src_feat_emb = model.feature_embedding(feature1).unsqueeze(1).expand(-1, src_emb.size(1), -1)
        src_emb = model.concat_projection(torch.cat([src_emb, src_feat_emb], dim=-1))
        encoder_output = model.transformer.encoder(src_emb)

        for _ in range(max_len):
            tgt_emb = model.embedding(tgt) + model.positional_encoding[:, :tgt.size(1), :].to(device)
            tgt_mask = model._generate_square_subsequent_mask(tgt.size(1)).to(device)
            decoder_output = model.transformer.decoder(tgt_emb, encoder_output, tgt_mask=tgt_mask)
            logits = model.fc(decoder_output[:, -1, :])
            probs = torch.softmax(logits / temperature, dim=-1)

            # Random sampling; result is (batch, 1) and already holds vocabulary ids.
            next_token = torch.multinomial(probs, num_samples=1)
            tgt = torch.cat([tgt, next_token], dim=1)

        return tgt

def top_k_sampling(model, src, feature1, start_token, max_len,device, k=5):
    """Decode by sampling each next token from the k most probable tokens.

    At every step the softmax distribution is truncated to its ``k`` largest
    entries, renormalised per row, and one entry is sampled. The sampled
    position is mapped back to its vocabulary id before being appended.
    Works for any batch size.

    Args:
        model: Object exposing ``embedding``, ``positional_encoding``,
            ``feature_embedding``, ``concat_projection``, ``transformer``,
            ``fc`` and ``_generate_square_subsequent_mask``.
        src: Source token ids, (batch, src_len).
        feature1: Feature vectors, (batch, feature_dim).
        start_token: Token id used as the first decoder input.
        max_len: Number of tokens to sample; must not exceed the length of
            ``model.positional_encoding``.
        device: Device on which decoding runs.
        k: Number of candidates kept per step; clamped to the vocabulary size.

    Returns:
        Long tensor of shape (batch, max_len + 1) whose first column is
        ``start_token``.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    src = src.to(device)
    feature1 = feature1.to(device)
    model.eval()
    with torch.no_grad():
        tgt = torch.full((src.size(0), 1), start_token, dtype=torch.long, device=device)

        src_emb = model.embedding(src) + model.positional_encoding[:, :src.size(1), :].to(device)
        src_feat_emb = model.feature_embedding(feature1).unsqueeze(1).expand(-1, src_emb.size(1), -1)
        src_emb = model.concat_projection(torch.cat([src_emb, src_feat_emb], dim=-1))
        encoder_output = model.transformer.encoder(src_emb)

        for _ in range(max_len):
            tgt_emb = model.embedding(tgt) + model.positional_encoding[:, :tgt.size(1), :].to(device)
            tgt_mask = model._generate_square_subsequent_mask(tgt.size(1)).to(device)
            decoder_output = model.transformer.decoder(tgt_emb, encoder_output, tgt_mask=tgt_mask)
            logits = model.fc(decoder_output[:, -1, :])
            probs = torch.softmax(logits, dim=-1)

            # Keep the k most probable tokens (k cannot exceed the vocabulary size).
            top_k_probs, top_k_indices = torch.topk(probs, min(k, probs.size(-1)), dim=-1)
            # Renormalise each row separately so rows do not influence each other.
            top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
            # multinomial returns a position inside the top-k list, not a
            # vocabulary id: map it back through top_k_indices.
            choice = torch.multinomial(top_k_probs, 1)
            next_token = top_k_indices.gather(-1, choice)

            tgt = torch.cat([tgt, next_token], dim=1)
        return tgt


In [None]:
from dataloader import get_dataloaders, MAX_SEQ_LENGTH, vocab_size, vocabulary
import torch.nn as nn
import random
# One example per batch, so every sampled item below is a single sequence pair.
batch_size = 1
train_loader, dev_loader, test_loader = get_dataloaders(
    batch_size=batch_size,
    one_hot_encode=False,
    start_token=True,
    get_feature=True,
)

# Fixed seed so the same batches are drawn on every run.
random.seed(0)
# Draw 10 random batches from each split, in train -> dev -> test order.
# Each loader is fully materialised because random.sample needs a sequence.
train_samples, dev_samples, test_samples = (
    random.sample(list(loader), 10)
    for loader in (train_loader, dev_loader, test_loader)
)

# Token strings in vocabulary key order; outputs_to_seq indexes this list by token id.
vocab = list(vocabulary.keys())
def outputs_to_seq(outputs, flag=False, tokens=None):
    """Convert token ids (or logits) to a list of token strings.

    A leading 'S' token is removed, and the sequence is cut at the first 'P'
    and then at the first 'E' (presumably the start, padding and end markers
    of the vocabulary; verify against the dataloader).

    Args:
        outputs: Iterable of integer token ids, or, when ``flag`` is true, a
            tensor of per-position scores whose last dimension is reduced
            with ``argmax``.
        flag: If true, ``outputs`` holds scores rather than ids.
        tokens: Optional id-to-token list; defaults to the module-level
            ``vocab``.

    Returns:
        List of token strings without the markers described above.
    """
    table = vocab if tokens is None else tokens
    if flag:
        outputs = outputs.argmax(dim=-1)
    symbols = [table[i] for i in outputs]

    # Only strip the marker when it really is the first token; the original
    # dropped the first token whenever 'S' appeared anywhere in the sequence.
    if symbols and symbols[0] == 'S':
        symbols = symbols[1:]

    # Truncate at the first 'P', then at the first 'E' of what remains.
    for stop in ('P', 'E'):
        if stop in symbols:
            symbols = symbols[:symbols.index(stop)]
    return symbols

# NOTE: `model` and `device` are not defined in the cells shown here; they
# must already exist in the notebook namespace (e.g. a loaded checkpoint).

# Token id handed to the decoding helpers as the start token.
START_TOKEN_ID = 7
# Built once instead of once per sample.
criterion = nn.CrossEntropyLoss()

model.eval()

# Teacher-forced check on the train and dev samples: print the loss, the
# source sequence, the ground-truth target and the per-position argmax.
# no_grad avoids building autograd graphs for these inference-only passes.
with torch.no_grad():
    for split_name, samples in (("train", train_samples), ("dev", dev_samples)):
        if split_name == "dev":
            print("dev samples")
        for seq1, feature1, seq2, _ in samples:
            seq1 = seq1.to(device)
            seq2 = seq2.to(device)
            feature1 = torch.stack(feature1, dim=1)
            feature1 = feature1.to(device).float()
            tgt_input = seq2[:, :-1]
            tgt_output = seq2[:, 1:]
            outputs = model(seq1, tgt_input, feature1)
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1).long())
            print(f"{split_name} loss: ", loss.item())
            print("seq1: ", outputs_to_seq(seq1[0][1:]))
            print("seq2: ", outputs_to_seq(seq2[0][1:]))
            print("pred: ", outputs_to_seq(outputs[0], True))


print("test samples")
# Free-running generation on the test samples.
generation_method = 'greedy_temp'  # one of 'greedy', 'greedy_temp', 'k_sampling'
if generation_method not in ('greedy', 'greedy_temp', 'k_sampling'):
    # Fail loudly instead of printing a stale or undefined generated_seq.
    raise ValueError(f"unknown generation_method: {generation_method!r}")

with torch.no_grad():
    for i, (seq1, feature1, seq2, _) in enumerate(test_samples):
        seq1, seq2 = seq1.to(device), seq2.to(device)
        feature1 = torch.stack(feature1, dim=1)
        feature1 = feature1.to(device).float()
        # Generate sequences
        if generation_method == 'greedy_temp':
            generated_seq = greedy_decode_with_temperature(model, seq1, feature1, START_TOKEN_ID, MAX_SEQ_LENGTH, device, 1.2)
        elif generation_method == 'greedy':
            generated_seq = generate_sequence(model, seq1, feature1, START_TOKEN_ID, MAX_SEQ_LENGTH, device)
        else:  # 'k_sampling'
            generated_seq = top_k_sampling(model, seq1, feature1, START_TOKEN_ID, MAX_SEQ_LENGTH, device, 2)
        print("test sequence:", i)
        print("seq1: ", outputs_to_seq(seq1[0][1:]))
        print("seq2: ", outputs_to_seq(seq2[0][1:]))
        print("pred: ", outputs_to_seq(generated_seq[0]))