In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import get_dataloaders, MAX_SEQ_LENGTH, vocab_size
import time

class RNAPairTransformer(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, feature_dim, num_layers=2, nhead=8, device='cpu'):
        super(RNAPairTransformer, self).__init__()

        self.input_dim = input_dim # 输入维度，vocab_size
        self.hidden_dim = hidden_dim # Transformer模型中每一层的特征向量维度
        self.output_dim = output_dim  # 输出维度，vocab_size
        self.feature_dim = feature_dim # feature维度
        self.num_layers = num_layers # Transformer模型中encoder和decoder的层数
        self.device = device

        # Embedding layer for one-hot encoded input
        self.embedding = nn.Embedding(input_dim, hidden_dim) # 每个碱基都有一个固定的特征向量表示
        self.feature_embedding = nn.Linear(feature_dim, hidden_dim) # 每个碱基对都有一个固定的特征向量表示
        self.concat_projection = nn.Linear(hidden_dim * 2, hidden_dim) # 将两个特征向量拼接后映射到hidden_dim维度
        self.positional_encoding = self._generate_positional_encoding(MAX_SEQ_LENGTH, hidden_dim) # 位置编码

        # Transformer Encoder-Decoder
        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=nhead, # 多头注意力机制的头数
            num_encoder_layers=num_layers, # encoder层数
            num_decoder_layers=num_layers, # decoder层数
            dim_feedforward=hidden_dim * 4, # 前馈网络中隐层的维度
            batch_first=True, # 输入数据的形状为(batch_size, seq_length, feature_dim)。
            norm_first=True, # 加normalization
         #   dropout=0.1, # dropout概率
        )

        # Output layer
        self.fc = nn.Linear(hidden_dim, output_dim) # 利用一个全连接层将隐藏层的特征向量映射到输出维度

    def _generate_positional_encoding(self, seq_length, hidden_dim):
        position = torch.arange(0, seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float() * -(torch.log(torch.tensor(10000.0)) / hidden_dim)
        )
        positional_encoding = torch.zeros(seq_length, hidden_dim)
        positional_encoding[:, 0::2] = torch.sin(position * div_term)
        positional_encoding[:, 1::2] = torch.cos(position * div_term)
        positional_encoding = positional_encoding.unsqueeze(0)
        return positional_encoding
    
    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src, tgt, src_features):
        # Generate target mask
        self.tgt_mask = self._generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        # Add positional encoding to embeddings
        src_emb = self.embedding(src) + self.positional_encoding[:, : src.size(1), :].to(self.device)
        tgt_emb = self.embedding(tgt) + self.positional_encoding[:, : tgt.size(1), :].to(self.device)
        src_feat_emb = self.feature_embedding(src_features).unsqueeze(1).expand(-1, src_emb.size(1), -1)
        src_emb_concat = torch.cat([src_emb, src_feat_emb], dim=-1)
        src_emb = self.concat_projection(src_emb_concat)
        # Pass through Transformer
        transformer_output = self.transformer(src_emb, tgt_emb, tgt_mask=self.tgt_mask)

        # Output layer
        output = self.fc(transformer_output)

        return output


def train_model(model, train_loader, criterion, optimizer, num_epochs=10, device='cpu'):
    model.train()
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        for seq1, feature1, seq2, _ in train_loader:
            seq1, seq2 = seq1.to(device), seq2.to(device)
            feature1 = torch.stack(feature1, dim=1)
            feature1 = feature1.to(device).float()

            # Shift target sequence for decoder input
            tgt_input = seq2[:, :-1]
            tgt_output = seq2[:, 1:]

            outputs = model(seq1, tgt_input, feature1)
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1).long())

            optimizer.zero_grad()
            loss.backward()
            #torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) #apply gradient clipping
            optimizer.step()

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')


def evaluate_model(model, dev_loader, criterion, device):
    model.eval()
    with torch.no_grad():
        total_loss = 0
        for seq1, feature1, seq2, _ in dev_loader:
            seq1, seq2 = seq1.to(device), seq2.to(device)
            feature1 = torch.stack(feature1, dim=1)
            feature1 = feature1.to(device).float()

            # Shift target sequence for decoder input
            tgt_input = seq2[:, :-1]
            tgt_output = seq2[:, 1:]

            outputs = model(seq1, tgt_input, feature1)
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1).long())
            total_loss += loss.item()

        print(f'Dev Loss: {total_loss / len(dev_loader):.4f}')


if __name__ == "__main__":
    # Hyperparameters
    input_dim = vocab_size
    hidden_dim = 128
    feature_dim = 4
    output_dim = vocab_size
    num_layers = 4
    nhead = 8
    num_epochs = 60
    learning_rate = 1e-3
    batch_size = 32

    # Device configuration
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print('Using ' + device)

    # Load data
    train_loader, dev_loader, test_loader = get_dataloaders(batch_size=batch_size, one_hot_encode=False, start_token=True, get_feature=True)

    # Initialize model, criterion and optimizer
    model = RNAPairTransformer(input_dim, hidden_dim, output_dim, feature_dim, num_layers, nhead, device).to(device)
    weight = torch.tensor([1, 1, 1, 1, 2, 0.01, 1, 1], dtype=torch.float32, requires_grad=False).to(device)
    criterion = nn.CrossEntropyLoss(weight=weight)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.9 ** epoch)

    # Train the model
    train_model(model, train_loader, criterion, optimizer, num_epochs, device)

    # Save the model
    torch.save(model.state_dict(), 'transformer_model.pth')

    # Evaluate the model
    evaluate_model(model, dev_loader, criterion, device)


Using mps




Epoch 1/60
Epoch [1/60], Loss: 1.4015
Epoch 2/60
Epoch [2/60], Loss: 1.3517
Epoch 3/60
Epoch [3/60], Loss: 1.2922
Epoch 4/60
Epoch [4/60], Loss: 1.2701
Epoch 5/60
Epoch [5/60], Loss: 1.2704
Epoch 6/60
Epoch [6/60], Loss: 1.3014
Epoch 7/60
Epoch [7/60], Loss: 1.3074
Epoch 8/60
Epoch [8/60], Loss: 1.2312
Epoch 9/60
Epoch [9/60], Loss: 1.3124
Epoch 10/60
Epoch [10/60], Loss: 1.2520
Epoch 11/60
Epoch [11/60], Loss: 1.2572
Epoch 12/60
Epoch [12/60], Loss: 1.3123
Epoch 13/60
Epoch [13/60], Loss: 1.2519
Epoch 14/60
Epoch [14/60], Loss: 1.2422
Epoch 15/60
Epoch [15/60], Loss: 1.2189
Epoch 16/60
Epoch [16/60], Loss: 1.1445
Epoch 17/60
Epoch [17/60], Loss: 1.2137
Epoch 18/60
Epoch [18/60], Loss: 1.2082
Epoch 19/60
Epoch [19/60], Loss: 1.2223
Epoch 20/60
Epoch [20/60], Loss: 1.1124
Epoch 21/60
Epoch [21/60], Loss: 1.1626
Epoch 22/60
Epoch [22/60], Loss: 1.2254
Epoch 23/60
Epoch [23/60], Loss: 1.1707
Epoch 24/60
Epoch [24/60], Loss: 1.1907
Epoch 25/60
Epoch [25/60], Loss: 1.0925
Epoch 26/60
Epoch 

In [None]:
def evaluate_model(model, dev_loader, criterion, device):
    model.eval()
    with torch.no_grad():
        total_loss = 0
        for seq1, feature1, seq2, _ in dev_loader:
            seq1, seq2 = seq1.to(device), seq2.to(device)
            feature1 = torch.stack(feature1, dim=1)
            feature1 = feature1.to(device).float()

            # Shift target sequence for decoder input
            tgt_input = seq2[:, :-1]
            tgt_output = seq2[:, 1:]

            outputs = model(seq1, tgt_input, feature1)
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1).long())
            total_loss += loss.item()

        print(f'Dev Loss: {total_loss / len(dev_loader):.4f}')

evaluate_model(model, dev_loader, criterion, device)

Dev Loss: 0.2897


In [20]:
# generate sequence with 
def generate_sequence(model, src, feature1, start_token, max_len, device):
    model.eval()
    src = src.to(device)
    batch_size = src.size(0)

    # Initialize
    tgt = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)

    with torch.no_grad():
        for _ in range(max_len - 1): 
            tgt_emb = model.embedding(tgt) + model.positional_encoding[:, : tgt.size(1), :].to(device)
            src_emb = model.embedding(src) + model.positional_encoding[:, : src.size(1), :].to(device)
            src_feat_emb = model.feature_embedding(feature1).unsqueeze(1).expand(-1, src_emb.size(1), -1)
            src_emb_concat = torch.cat([src_emb, src_feat_emb], dim=-1)
            src_emb = model.concat_projection(src_emb_concat)

            # outputs from transformer
            outputs = model.transformer(src_emb, tgt_emb)
            logits = model.fc(outputs[:, -1, :])

            # 贪心解码
            next_token = torch.argmax(logits, dim=-1).unsqueeze(1)
            tgt = torch.cat([tgt, next_token], dim=1)

    return tgt

def greedy_decode_with_temperature(model, src, feature1, start_token, max_len,device, temperature=1.0):
    src = src.to(device)
    model.eval()
    with torch.no_grad():
        tgt = torch.tensor([start_token], device=device).unsqueeze(0)
        src_emb = model.embedding(src) + model.positional_encoding[:, :src.size(1), :].to(device)
        src_feat_emb = model.feature_embedding(feature1).unsqueeze(1).expand(-1, src_emb.size(1), -1)
        src_emb_concat = torch.cat([src_emb, src_feat_emb], dim=-1)
        src_emb = model.concat_projection(src_emb_concat)
        encoder_output = model.transformer.encoder(src_emb)
        
        for _ in range(max_len):
            tgt_emb = model.embedding(tgt) + model.positional_encoding[:, :tgt.size(1), :].to(device)
            tgt_mask = model._generate_square_subsequent_mask(tgt.size(1)).to(device)
            decoder_output = model.transformer.decoder(tgt_emb, encoder_output, tgt_mask=tgt_mask)
            logits = model.fc(decoder_output[:, -1, :])  # 最后一时间步
            probs = torch.softmax(logits / temperature, dim=-1)  # 调整概率分布
            
            next_token = torch.multinomial(probs, num_samples=1).item()  # 随机采样
            tgt = torch.cat([tgt, torch.tensor([[next_token]], device=device)], dim=1)
            
        
        return tgt

def top_k_sampling(model, src, feature1, start_token, max_len,device, k=5):
    src = src.to(device)
    model.eval()
    with torch.no_grad():
        tgt = torch.tensor([start_token], device=device).unsqueeze(0)
        src_emb = model.embedding(src) + model.positional_encoding[:, :src.size(1), :].to(device)
        src_feat_emb = model.feature_embedding(feature1).unsqueeze(1).expand(-1, src_emb.size(1), -1)
        src_emb_concat = torch.cat([src_emb, src_feat_emb], dim=-1)
        src_emb = model.concat_projection(src_emb_concat)
        encoder_output = model.transformer.encoder(src_emb)
        
        for _ in range(max_len):
            tgt_emb = model.embedding(tgt) + model.positional_encoding[:, :tgt.size(1), :].to(device)
            tgt_mask = model._generate_square_subsequent_mask(tgt.size(1)).to(device)
            decoder_output = model.transformer.decoder(tgt_emb, encoder_output, tgt_mask=tgt_mask)
            logits = model.fc(decoder_output[:, -1, :])
            probs = torch.softmax(logits, dim=-1)
            
            # 取 top-k 的概率和索引
            top_k_probs, top_k_indices = torch.topk(probs, k)
            top_k_probs = top_k_probs / top_k_probs.sum()  # 归一化
            next_token = torch.multinomial(top_k_probs, 1).item()
            
            tgt = torch.cat([tgt, torch.tensor([[next_token]], device=device)], dim=1)        
        return tgt


In [None]:
from dataloader import get_dataloaders, MAX_SEQ_LENGTH, vocab_size, vocabulary
import torch.nn as nn
import random
batch_size = 1
train_loader, dev_loader, test_loader = get_dataloaders(batch_size=batch_size, one_hot_encode=False, start_token=True, get_feature=True)
# random select 5 training samples
random.seed(0)
train_samples = random.sample(list(train_loader), 10)
# random select 5 dev samples
dev_samples = random.sample(list(dev_loader), 10)
# random select 5 test samples
test_samples = random.sample(list(test_loader), 10)

vocab = list(vocabulary.keys())
def outputs_to_seq(outputs, flag=False):
    if flag:
        #print(outputs)
        outputs = outputs.argmax(dim=-1)
    # print(outputs.shape)
    outputs = [vocab[i] for i in outputs]
    if 'S' in outputs:
        outputs = outputs[1:]
    if 'P' in outputs:
        outputs = outputs[:outputs.index('P')]
    if 'E' in outputs:
        outputs = outputs[:outputs.index('E')]
    return outputs

model.eval()
# 输出原来的seq1和seq2，还有预测的seq2
for i in range(10):
    seq1, feature1, seq2, _ = train_samples[i]
    seq1 = seq1.to(device)
    seq2 = seq2.to(device)
    feature1 = torch.stack(feature1, dim=1)
    feature1 = feature1.to(device).float()
    tgt_input = seq2[:, :-1]
    tgt_output = seq2[:, 1:]
    outputs = model(seq1, tgt_input, feature1)
    criterion = nn.CrossEntropyLoss()
    loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1).long())
    print("train loss: ", loss.item())
    print("seq1: ", outputs_to_seq(seq1[0][1:]))
    print("seq2: ", outputs_to_seq(seq2[0][1:]))
    print("pred: ", outputs_to_seq(outputs[0], True))

print("dev samples")
for i in range(10):
    seq1, feature1, seq2, _ = dev_samples[i]
    seq1 = seq1.to(device)
    seq2 = seq2.to(device)
    feature1 = torch.stack(feature1, dim=1)
    feature1 = feature1.to(device).float()
    tgt_input = seq2[:, :-1]
    tgt_output = seq2[:, 1:]
    outputs = model(seq1, tgt_input, feature1)
    criterion = nn.CrossEntropyLoss()
    loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1).long())
    print("dev loss: ", loss.item())
    print("seq1: ", outputs_to_seq(seq1[0][1:]))
    print("seq2: ", outputs_to_seq(seq2[0][1:]))
    print("pred: ", outputs_to_seq(outputs[0], True))


print("test samples")
# test model
generation_method = 'greedy_temp' # change to 'greedy' or 'greedy_temp' if necessary
for i in range(10):
    seq1, feature1, seq2, _ = test_samples[i]
    model.eval()
    with torch.no_grad():
        seq1, seq2 = seq1.to(device), seq2.to(device)
        feature1 = torch.stack(feature1, dim=1)
        feature1 = feature1.to(device).float()
        # Generate sequences
        if generation_method == 'greedy_temp':
            generated_seq = greedy_decode_with_temperature(model, seq1, feature1, 7, MAX_SEQ_LENGTH, device,1.2)
        if generation_method == 'greedy':
            generated_seq = generate_sequence(model, seq1, feature1, 7, MAX_SEQ_LENGTH, device)
        if generation_method == 'k_sampling':
            generated_seq = top_k_sampling(model, seq1, feature1, 7, MAX_SEQ_LENGTH, device,2)
        print("test sequence:", i)
        print("seq1: ", outputs_to_seq(seq1[0][1:]))
        print("seq2: ", outputs_to_seq(seq2[0][1:]))
        print("pred: ", outputs_to_seq(generated_seq[0]))



train loss:  0.010388349182903767
seq1:  ['T', 'T']
seq2:  ['A', 'A']
pred:  ['A', 'A', 'A']
train loss:  0.11112502962350845
seq1:  ['C', 'A', 'G', 'G', 'A', 'G', 'G', 'A', 'T', 'C', 'A', 'C', 'T', 'T', 'G']
seq2:  ['C', 'A', 'G', 'G', 'A', 'G', 'G', 'A', 'C', 'C', 'G', 'C', 'T', 'T', 'G']
pred:  ['C', 'A', 'G', 'G', 'G', 'C', 'A', 'A', 'G', 'A', 'G', 'C', 'T', 'T', 'G']
train loss:  0.33561569452285767
seq1:  ['T', 'G', 'A', 'A', 'A', 'T', 'C', 'T', 'C', 'T', 'T', 'G', 'T', 'C', 'A', 'C', 'C', 'C', 'C', 'A', 'T', 'T', 'C', 'T', 'G', 'T', 'C', 'T', 'C', 'A', 'T', 'C', 'T', 'G', 'G', 'G', 'A', 'C', 'A', 'T', 'G', 'A']
seq2:  ['T', 'C', 'A', 'G', 'C', 'C', 'T', 'C', 'T', 'G', 'G', 'T', 'T', 'C', 'T', 'T', 'T', 'T', 'G', 'G', 'C', 'T', 'G', 'C', 'A', 'G', 'G', 'C', 'C', 'T', 'C', 'T', 'G', 'A', 'C', 'A', 'C', 'C', 'T', 'T', 'C', 'T', 'C', 'A']
pred:  ['T', 'C', 'A', 'G', 'C', 'T', 'T', 'G', 'C', 'G', 'G', 'G', 'C', 'C', 'A', 'T', 'T', 'G', 'G', 'G', 'G', 'T', 'G', 'G', 'A', 'G', 'G', 'C'