In [33]:
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import get_dataloaders, MAX_SEQ_LENGTH, vocab_size
import time

class RNAPairTransformer(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, nhead=8, device='cpu'):
        super(RNAPairTransformer, self).__init__()

        self.input_dim = input_dim # 输入维度，vocab_size
        self.hidden_dim = hidden_dim # Transformer模型中每一层的特征向量维度
        self.output_dim = output_dim  # 输出维度，vocab_size
        self.num_layers = num_layers # Transformer模型中encoder和decoder的层数
        self.device = device

        # Embedding layer for one-hot encoded input
        self.embedding = nn.Embedding(input_dim, hidden_dim) # 每个碱基都有一个固定的特征向量表示
        self.positional_encoding = self._generate_positional_encoding(MAX_SEQ_LENGTH, hidden_dim) # 位置编码

        # Transformer Encoder-Decoder
        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=nhead, # 多头注意力机制的头数
            num_encoder_layers=num_layers, # encoder层数
            num_decoder_layers=num_layers, # decoder层数
            dim_feedforward=hidden_dim * 4, # 前馈网络中隐层的维度
            batch_first=True, # 输入数据的形状为(batch_size, seq_length, feature_dim)。
            # dropout=0.1, # dropout概率
        )

        # Output layer
        self.fc = nn.Linear(hidden_dim, output_dim) # 利用一个全连接层将隐藏层的特征向量映射到输出维度

    def _generate_positional_encoding(self, seq_length, hidden_dim):
        position = torch.arange(0, seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float() * -(torch.log(torch.tensor(10000.0)) / hidden_dim)
        )
        positional_encoding = torch.zeros(seq_length, hidden_dim)
        positional_encoding[:, 0::2] = torch.sin(position * div_term)
        positional_encoding[:, 1::2] = torch.cos(position * div_term)
        positional_encoding = positional_encoding.unsqueeze(0)
        return positional_encoding
    
    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src, tgt):
        # Generate target mask
        self.tgt_mask = self._generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        # Add positional encoding to embeddings
        src_emb = self.embedding(src) + self.positional_encoding[:, : src.size(1), :].to(self.device)
        tgt_emb = self.embedding(tgt) + self.positional_encoding[:, : tgt.size(1), :].to(self.device)
        # Pass through Transformer
        transformer_output = self.transformer(src_emb, tgt_emb, tgt_mask=self.tgt_mask)

        # Output layer
        output = self.fc(transformer_output)

        return output


def train_model(model, train_loader, criterion, optimizer, num_epochs=10, device='cpu'):
    model.train()
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        for seq1, seq2 in train_loader:
            seq1, seq2 = seq1.to(device), seq2.to(device)

            # Shift target sequence for decoder input
            tgt_input = seq2[:, :-1]
            tgt_output = seq2[:, 1:]

            outputs = model(seq1, tgt_input)
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')


def evaluate_model(model, dev_loader, criterion, device):
    model.eval()
    with torch.no_grad():
        total_loss = 0
        for seq1, seq2 in dev_loader:
            seq1, seq2 = seq1.to(device), seq2.to(device)

            # Shift target sequence for decoder input
            tgt_input = seq2[:, :-1]
            tgt_output = seq2[:, 1:]

            outputs = model(seq1, tgt_input)
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1))
            total_loss += loss.item()

        print(f'Dev Loss: {total_loss / len(dev_loader):.4f}')


if __name__ == "__main__":
    # Hyperparameters
    input_dim = vocab_size
    hidden_dim = 128
    output_dim = vocab_size
    num_layers = 4
    nhead = 8
    num_epochs = 30
    learning_rate = 1e-2
    batch_size = 32

    # Device configuration
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print('Using ' + device)

    # Load data
    train_loader, dev_loader, test_loader = get_dataloaders(batch_size=batch_size, one_hot_encode=False, start_token=True)

    # Initialize model, criterion and optimizer
    model = RNAPairTransformer(input_dim, hidden_dim, output_dim, num_layers, nhead, device).to(device)
    weight = torch.tensor([1, 1, 1, 1, 1, 0.01, 1, 1], dtype=torch.float32, requires_grad=False).to(device)
    criterion = nn.CrossEntropyLoss(weight=weight)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.99 ** epoch)

    # Train the model
    train_model(model, train_loader, criterion, optimizer, num_epochs, device)

    # Save the model
    torch.save(model.state_dict(), 'transformer_model.pth')

    # Evaluate the model
    evaluate_model(model, dev_loader, criterion, device)


Using mps
Epoch 1/30
Epoch [1/30], Loss: 1.5700
Epoch 2/30
Epoch [2/30], Loss: 1.5398
Epoch 3/30
Epoch [3/30], Loss: 1.5799
Epoch 4/30
Epoch [4/30], Loss: 1.5377
Epoch 5/30
Epoch [5/30], Loss: 1.5532
Epoch 6/30
Epoch [6/30], Loss: 1.5833
Epoch 7/30
Epoch [7/30], Loss: 1.5703
Epoch 8/30
Epoch [8/30], Loss: 1.5792
Epoch 9/30
Epoch [9/30], Loss: 1.5813
Epoch 10/30
Epoch [10/30], Loss: 1.6030
Epoch 11/30
Epoch [11/30], Loss: 1.5596
Epoch 12/30
Epoch [12/30], Loss: 1.5980
Epoch 13/30
Epoch [13/30], Loss: 1.6628
Epoch 14/30
Epoch [14/30], Loss: 1.5629
Epoch 15/30
Epoch [15/30], Loss: 1.5529
Epoch 16/30
Epoch [16/30], Loss: 1.5567
Epoch 17/30
Epoch [17/30], Loss: 1.6090
Epoch 18/30
Epoch [18/30], Loss: 1.5629
Epoch 19/30
Epoch [19/30], Loss: 1.5577
Epoch 20/30
Epoch [20/30], Loss: 1.5564
Epoch 21/30
Epoch [21/30], Loss: 1.5484
Epoch 22/30
Epoch [22/30], Loss: 1.5952
Epoch 23/30
Epoch [23/30], Loss: 1.5380
Epoch 24/30
Epoch [24/30], Loss: 1.6017
Epoch 25/30
Epoch [25/30], Loss: 1.5587
Epoch 26

In [None]:
from dataloader import get_dataloaders, MAX_SEQ_LENGTH, vocab_size, vocabulary
import torch.nn as nn
import random
batch_size = 1
train_loader, dev_loader, test_loader = get_dataloaders(batch_size=batch_size, one_hot_encode=False, start_token=True)
# random select 5 training samples
random.seed(0)
train_samples = random.sample(list(train_loader), 10)
# random select 5 dev samples
dev_samples = random.sample(list(dev_loader), 10)
# random select 5 test samples
test_samples = random.sample(list(test_loader), 5)

vocab = list(vocabulary.keys())
def outputs_to_seq(outputs, flag=False):
    if flag:
        print(outputs)
        outputs = outputs.argmax(dim=-1)
    # print(outputs.shape)
    outputs = [vocab[i] for i in outputs]
    if 'P' in outputs:
        outputs = outputs[:outputs.index('P')]
    if 'E' in outputs:
        outputs = outputs[:outputs.index('E')]
    return outputs

model.eval()
# 输出原来的seq1和seq2，还有预测的seq2
for i in range(10):
    seq1, seq2 = train_samples[i]
    seq1 = seq1.to(device)
    seq2 = seq2.to(device)
    tgt_input = seq2[:, :-1]
    tgt_output = seq2[:, 1:]
    outputs = model(seq1, tgt_input)
    criterion = nn.CrossEntropyLoss()
    loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1))
    print("train loss: ", loss.item())
    print("seq1: ", outputs_to_seq(seq1[0][1:]))
    print("seq2: ", outputs_to_seq(seq2[0][1:]))
    print("pred: ", outputs_to_seq(outputs[0], True))

print("dev samples")
for i in range(10):
    seq1, seq2 = dev_samples[i]
    seq1 = seq1.to(device)
    seq2 = seq2.to(device)
    tgt_input = seq2[:, :-1]
    tgt_output = seq2[:, 1:]
    outputs = model(seq1, tgt_input)
    criterion = nn.CrossEntropyLoss()
    loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1))
    print("train loss: ", loss.item())
    print("seq1: ", outputs_to_seq(seq1[0][1:]))
    print("seq2: ", outputs_to_seq(seq2[0][1:]))
    print("pred: ", outputs_to_seq(outputs[0], True))