In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import get_dataloaders, MAX_SEQ_LENGTH, vocab_size
import time

# Global switch for persisting artefacts. When True, the main block creates
# ./model/<time_stamp>/ and writes hyperparameters.json, and train_model writes
# the best checkpoint plus loss.json into the same directory.
save_model = False

In [2]:
class RNAPairTransformer(nn.Module):
    """Encoder-decoder Transformer conditioned on a per-sequence feature vector.

    Source and target tokens share one embedding table. The feature vector is
    projected to ``hidden_dim``, repeated along the source length, concatenated
    with the source token embeddings and projected back to ``hidden_dim``
    before it enters the encoder.

    Args:
        input_dim: Number of rows in the token embedding table.
        hidden_dim: Transformer model width (``d_model``).
        output_dim: Number of classes produced by the final linear layer.
        feature_dim: Length of the per-sequence feature vector.
        num_layers: Number of encoder layers and of decoder layers.
        nhead: Number of attention heads.
        device: Stored on the instance for backward compatibility only; the
            positional encoding is a buffer and follows ``model.to(...)``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, feature_dim, num_layers=2, nhead=8, device='cpu'):
        super().__init__()

        self.input_dim = input_dim      # size of the token embedding table
        self.hidden_dim = hidden_dim    # width of every Transformer layer
        self.output_dim = output_dim    # number of output classes
        self.feature_dim = feature_dim  # length of the per-sequence feature vector
        self.num_layers = num_layers    # encoder depth == decoder depth
        self.device = device            # kept so existing callers can still read it

        # Token embedding shared by source and target sequences.
        self.embedding = nn.Embedding(input_dim, hidden_dim)
        # Projects the per-sequence feature vector into the model width.
        self.feature_embedding = nn.Linear(feature_dim, hidden_dim)
        # Maps [token embedding ; feature embedding] back to hidden_dim.
        self.concat_projection = nn.Linear(hidden_dim * 2, hidden_dim)
        # Registered as a non-persistent buffer: it moves with the module on
        # .to()/.cuda() but stays out of state_dict, so checkpoints written by
        # the previous plain-attribute version still load unchanged.
        self.register_buffer(
            "positional_encoding",
            self._generate_positional_encoding(MAX_SEQ_LENGTH, hidden_dim),
            persistent=False,
        )

        # Transformer encoder-decoder (pre-norm, batch-first; dropout left at
        # the nn.Transformer default).
        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_dim * 4,
            batch_first=True,
            norm_first=True,
        )

        # Output layer: hidden state -> class logits.
        self.fc = nn.Linear(hidden_dim, output_dim)

    def _generate_positional_encoding(self, seq_length, hidden_dim):
        """Return sinusoidal positional encodings of shape (1, seq_length, hidden_dim)."""
        position = torch.arange(0, seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float() * -(torch.log(torch.tensor(10000.0)) / hidden_dim)
        )
        angles = position * div_term
        positional_encoding = torch.zeros(seq_length, hidden_dim)
        positional_encoding[:, 0::2] = torch.sin(angles)
        # For an odd hidden_dim there is one cosine column fewer than sine columns.
        positional_encoding[:, 1::2] = torch.cos(angles)[:, : hidden_dim // 2]
        return positional_encoding.unsqueeze(0)

    def _generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) float mask: 0 on and below the diagonal, -inf above it."""
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def forward(self, src, tgt, src_features):
        """Compute class logits for every target position.

        Args:
            src: Source token ids, shape (batch, src_len).
            tgt: Decoder input token ids, shape (batch, tgt_len).
            src_features: Per-sequence features, shape (batch, feature_dim).

        Returns:
            Logits of shape (batch, tgt_len, output_dim).

        Raises:
            ValueError: If a sequence is longer than the positional encoding table.
        """
        max_positions = self.positional_encoding.size(1)
        if src.size(1) > max_positions or tgt.size(1) > max_positions:
            raise ValueError(
                f"sequence length exceeds the positional encoding table ({max_positions})"
            )

        # Causal mask so a target position cannot attend to later positions.
        # Still stored on the instance, as before, in case callers read it.
        self.tgt_mask = self._generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

        src_emb = self.embedding(src)
        tgt_emb = self.embedding(tgt)
        # No-op when the model has been moved with .to(); keeps forward working
        # even if the buffer and the inputs ended up on different devices.
        pos = self.positional_encoding.to(src_emb.device)
        src_emb = src_emb + pos[:, : src.size(1), :]
        tgt_emb = tgt_emb + pos[:, : tgt.size(1), :]

        # Broadcast the feature embedding over the source length and fuse it in.
        src_feat_emb = self.feature_embedding(src_features).unsqueeze(1).expand(-1, src_emb.size(1), -1)
        src_emb = self.concat_projection(torch.cat([src_emb, src_feat_emb], dim=-1))

        transformer_output = self.transformer(src_emb, tgt_emb, tgt_mask=self.tgt_mask)
        return self.fc(transformer_output)


def train_model(model, train_loader, criterion, optimizer, num_epochs, device, time_stamp,
                dev_loader=None, scheduler=None, eval_every=10):
    """Train ``model`` with teacher forcing and track the best dev-set checkpoint.

    Args:
        model: Module called as ``model(src, tgt_input, features)``.
        train_loader: Iterable of ``(seq1, feature1, seq2, _)`` batches, where
            ``feature1`` is a sequence of tensors stacked along dim 1.
        criterion: Loss taking ``(logits[N, C], targets[N])``.
        optimizer: Optimizer over ``model.parameters()``.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to.
        time_stamp: Sub-directory of ``./model`` used when ``save_model`` is set.
        dev_loader: Optional validation loader. When omitted, the module-level
            ``dev_loader`` is used if one exists (the previous implicit
            behaviour); if none is available, evaluation is skipped.
        scheduler: Optional learning-rate scheduler, stepped once per epoch.
        eval_every: Evaluate on the dev set every this many epochs.

    The function returns nothing; results are printed and, when the global
    ``save_model`` flag is set, written under ``./model/<time_stamp>/``.
    """
    import copy
    import json

    if dev_loader is None:
        # Backward compatibility: earlier callers relied on a global dev_loader.
        dev_loader = globals().get('dev_loader')

    best_state = None  # snapshot of the weights at the best dev loss
    best_dev_loss = float('inf')
    best_train_loss = 0
    best_epoch = 0

    loss_arr = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        model.train()  # evaluate_model switches to eval mode, so reset every epoch
        total_loss = 0
        for seq1, feature1, seq2, _ in train_loader:
            seq1, seq2 = seq1.to(device), seq2.to(device)
            feature1 = torch.stack(feature1, dim=1).to(device).float()

            # Teacher forcing: the decoder sees tokens [0, n-1) and predicts [1, n).
            tgt_input = seq2[:, :-1]
            tgt_output = seq2[:, 1:]

            outputs = model(seq1, tgt_input, feature1)
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1).long())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        total_loss /= max(len(train_loader), 1)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}')
        loss_arr.append(total_loss)

        if scheduler is not None:
            scheduler.step()

        # Periodic dev-set evaluation.
        if dev_loader is not None and eval_every and (epoch + 1) % eval_every == 0:
            dev_loss = evaluate_model(model, dev_loader, criterion, device)
            if dev_loss < best_dev_loss:
                best_dev_loss = dev_loss
                # Deep-copy the weights: keeping a reference to ``model`` would
                # silently track every later update and save the final model.
                best_state = copy.deepcopy(model.state_dict())
                best_train_loss = total_loss
                best_epoch = epoch

    # Save best model
    if save_model:
        if best_state is None:
            # No evaluation ran (e.g. fewer epochs than eval_every).
            best_state = model.state_dict()
        torch.save(best_state, './model/'+time_stamp+'/transformer_model_best.pth')

        with open('./model/'+time_stamp+'/loss.json', 'w') as f:
            json.dump({'loss': loss_arr,
                       'best_epoch': best_epoch,
                       'best_dev_loss': best_dev_loss,
                       'best_train_loss': best_train_loss}, f, indent=4)

    print(f'Best dev loss: {best_dev_loss:.4f}, training loss: {best_train_loss:.4f}')
    print('Training finished')

def evaluate_model(model, dev_loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``dev_loader``.

    The model is switched to eval mode and left there. Each batch is scored
    with teacher forcing: the decoder input is the target without its last
    token and the expected output is the target without its first token.
    The mean is printed as ``Dev Loss`` before being returned.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for source, feature_columns, target, _ in dev_loader:
            source = source.to(device)
            target = target.to(device)
            features = torch.stack(feature_columns, dim=1).to(device).float()

            decoder_input, expected = target[:, :-1], target[:, 1:]
            logits = model(source, decoder_input, features)

            flat_logits = logits.reshape(-1, logits.size(-1))
            flat_expected = expected.reshape(-1).long()
            batch_losses.append(criterion(flat_logits, flat_expected).item())

    mean_loss = sum(batch_losses) / len(dev_loader)
    print(f'Dev Loss: {mean_loss:.4f}')
    return mean_loss


if __name__ == "__main__":
    # Hyperparameters
    input_dim = vocab_size
    hidden_dim = 128
    feature_dim = 4
    output_dim = vocab_size
    num_layers = 4
    nhead = 8
    num_epochs = 20
    learning_rate = 1e-3
    batch_size = 32
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

    if save_model:
        import json
        import os

        # makedirs creates ./model as well; exist_ok avoids a crash on re-runs.
        os.makedirs('./model/'+time_stamp, exist_ok=True)

        hyperparameters = {
            'time_stamp': time_stamp,
            'model': 'transformer',
            'input_dim': input_dim,
            'hidden_dim': hidden_dim,
            'feature_dim': feature_dim,
            'output_dim': output_dim,
            'num_layers': num_layers,
            'nhead': nhead,
            'num_epochs': num_epochs,
            'learning_rate': learning_rate,
            'batch_size': batch_size
        }
        with open('./model/'+time_stamp+'/hyperparameters.json', 'w') as f:
            json.dump(hyperparameters, f, indent=4)

    # Device configuration: prefer CUDA, then Apple MPS, then CPU.
    # torch.backends.mps is looked up defensively for builds that lack it.
    _mps_backend = getattr(torch.backends, 'mps', None)
    if torch.cuda.is_available():
        device = "cuda"
    elif _mps_backend is not None and _mps_backend.is_available():
        device = "mps"
    else:
        device = "cpu"
    print('Using ' + device)

    # Load data
    train_loader, dev_loader, test_loader = get_dataloaders(batch_size=batch_size, one_hot_encode=False, start_token=True, get_feature=True)

    # Initialize model, criterion and optimizer
    model = RNAPairTransformer(input_dim, hidden_dim, output_dim, feature_dim, num_layers, nhead, device).to(device)
    # Per-class loss weights, indexed by token id: id 4 is up-weighted and
    # id 5 is almost ignored (presumably end and padding tokens — confirm
    # against the dataloader vocabulary).
    weight = torch.tensor([1,1,1,1,2,0.01,1,1], dtype=torch.float32, requires_grad=False).to(device)
    if weight.numel() != output_dim:
        # Fail here with a readable message instead of inside the first loss call.
        raise ValueError(
            f'class weight vector has {weight.numel()} entries but output_dim is {output_dim}'
        )
    criterion = nn.CrossEntropyLoss(weight=weight)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # NOTE(review): this scheduler is created but not handed to train_model in
    # the call below, so it is never stepped and the learning rate stays at
    # `learning_rate` for the whole run.
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.9 ** epoch)

    # Train the model
    train_model(model, train_loader, criterion, optimizer, num_epochs, device, time_stamp)


Using cuda
['ADAR1_seq.txt', 'ADAR2_seq.txt', 'ADAR3_seq.txt', 'Endogenous_ADAR1_seq.txt']




Epoch 1/20
Epoch [1/20], Loss: 1.3836
Epoch 2/20
Epoch [2/20], Loss: 1.3295
Epoch 3/20
Epoch [3/20], Loss: 1.2895
Epoch 4/20
Epoch [4/20], Loss: 1.2657
Epoch 5/20
Epoch [5/20], Loss: 1.2461
Epoch 6/20
Epoch [6/20], Loss: 1.2288
Epoch 7/20
Epoch [7/20], Loss: 1.2139
Epoch 8/20
Epoch [8/20], Loss: 1.2011
Epoch 9/20
Epoch [9/20], Loss: 1.1894
Epoch 10/20
Epoch [10/20], Loss: 1.1784
Dev Loss: 1.1686
Epoch 11/20
Epoch [11/20], Loss: 1.1703
Epoch 12/20
Epoch [12/20], Loss: 1.1621
Epoch 13/20
Epoch [13/20], Loss: 1.1530
Epoch 14/20
Epoch [14/20], Loss: 1.1474
Epoch 15/20
Epoch [15/20], Loss: 1.1387
Epoch 16/20
Epoch [16/20], Loss: 1.1323
Epoch 17/20
Epoch [17/20], Loss: 1.1250
Epoch 18/20
Epoch [18/20], Loss: 1.1175
Epoch 19/20
Epoch [19/20], Loss: 1.1119
Epoch 20/20
Epoch [20/20], Loss: 1.1057
Dev Loss: 1.1028
Best dev loss: 1.1028, training loss: 1.1057
Training finished


In [3]:
# Autoregressive decoding helpers: greedy decoding, temperature sampling and top-k sampling
def generate_sequence(model, src, feature1, start_token, max_len, device):
    """Greedily decode a target sequence for every source sequence in the batch.

    Args:
        model: Object exposing ``embedding``, ``positional_encoding``,
            ``feature_embedding``, ``concat_projection``, ``transformer``,
            ``fc`` and ``_generate_square_subsequent_mask``.
        src: Source token ids, shape (batch, src_len).
        feature1: Per-sequence features, shape (batch, feature_dim).
        start_token: Token id used as the first decoder input.
        max_len: Length of the returned sequences, start token included.
        device: Device used for decoding.

    Returns:
        Long tensor of shape (batch, max_len) whose first column is ``start_token``.
    """
    model.eval()
    src = src.to(device)
    feature1 = feature1.to(device)
    batch_size = src.size(0)

    # Every sequence starts from the start token.
    tgt = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)

    with torch.no_grad():
        pos = model.positional_encoding.to(device)

        # The source does not change while decoding, so encode it only once.
        src_emb = model.embedding(src) + pos[:, : src.size(1), :]
        src_feat_emb = model.feature_embedding(feature1).unsqueeze(1).expand(-1, src_emb.size(1), -1)
        src_emb = model.concat_projection(torch.cat([src_emb, src_feat_emb], dim=-1))
        memory = model.transformer.encoder(src_emb)

        for _ in range(max_len - 1):
            tgt_emb = model.embedding(tgt) + pos[:, : tgt.size(1), :]
            # Use the same causal mask as in training; without it the earlier
            # positions attend to later ones and the last-step logits change.
            tgt_mask = model._generate_square_subsequent_mask(tgt.size(1)).to(device)
            decoded = model.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
            logits = model.fc(decoded[:, -1, :])

            # Greedy choice: most likely next token.
            next_token = torch.argmax(logits, dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)

    return tgt

def greedy_decode_features(model, src, src_features, start_token, max_length, device):
    """Greedy autoregressive decoding with a causal target mask.

    Args:
        model: Object exposing ``embedding``, ``positional_encoding``,
            ``feature_embedding``, ``concat_projection``, ``transformer``,
            ``fc`` and ``_generate_square_subsequent_mask``.
        src: Source token ids, shape (batch, src_len).
        src_features: Per-sequence features, shape (batch, feature_dim).
        start_token: Token id used as the first decoder input.
        max_length: Length of the returned sequences, start token included.
        device: Device used for decoding.

    Returns:
        Long tensor of shape (batch, max_length). The start token is kept as
        the first column; callers that do not want it must slice it off.
    """
    model.eval()
    src = src.to(device)
    src_features = src_features.to(device)
    with torch.no_grad():
        batch_size = src.size(0)
        tgt = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)
        pos = model.positional_encoding.to(device)

        # The encoder input never changes during decoding: run it once.
        src_emb = model.embedding(src) + pos[:, :src.size(1), :]
        src_feat_emb = model.feature_embedding(src_features).unsqueeze(1).expand(-1, src_emb.size(1), -1)
        src_emb = model.concat_projection(torch.cat([src_emb, src_feat_emb], dim=-1))
        memory = model.transformer.encoder(src_emb)

        for _ in range(max_length - 1):
            tgt_emb = model.embedding(tgt) + pos[:, :tgt.size(1), :]
            tgt_mask = model._generate_square_subsequent_mask(tgt.size(1)).to(device)
            decoded = model.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask)

            # Only the last position is needed to pick the next token.
            logits = model.fc(decoded[:, -1, :])
            next_token = logits.argmax(dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)

        return tgt

def greedy_decode_and_compute_loss(model, src, src_features, target, start_token, max_length, device):
    """Greedily decode and score each step against the reference sequence.

    The loss is measured on the free-running decoder: logits at every step are
    conditioned on the tokens the model itself produced so far, not on the
    reference prefix.

    Args:
        model: Object exposing ``embedding``, ``positional_encoding``,
            ``feature_embedding``, ``concat_projection``, ``transformer``,
            ``fc`` and ``_generate_square_subsequent_mask``.
        src: Source token ids, shape (batch, src_len).
        src_features: Per-sequence features, shape (batch, feature_dim).
        target: Reference token ids, shape (batch, tgt_len), whose first column
            is the start token (the layout used for training targets).
        start_token: Token id used as the first decoder input.
        max_length: Length of the decoded sequences, start token included.
        device: Device used for decoding.

    Returns:
        ``(decoded, losses)`` where ``decoded`` has shape (batch, max_length)
        and ``losses`` has shape (batch,): the class-weighted cross entropy
        averaged over the scored decoding steps.
    """
    model.eval()
    # Same class weights as the training criterion; reduction='none' keeps one
    # loss value per sequence at each step.
    weight = torch.tensor([1,1,1,1,2,0.01,1,1], dtype=torch.float32, requires_grad=False).to(device)
    criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')

    src = src.to(device)
    src_features = src_features.to(device)
    target = target.to(device)

    with torch.no_grad():
        batch_size = src.size(0)
        tgt = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)
        sequence_losses = torch.zeros(batch_size, device=device)
        scored_steps = 0
        pos = model.positional_encoding.to(device)

        # The encoder input never changes during decoding: run it once.
        src_emb = model.embedding(src) + pos[:, :src.size(1), :]
        src_feat_emb = model.feature_embedding(src_features).unsqueeze(1).expand(-1, src_emb.size(1), -1)
        src_emb = model.concat_projection(torch.cat([src_emb, src_feat_emb], dim=-1))
        memory = model.transformer.encoder(src_emb)

        for step in range(max_length - 1):
            tgt_emb = model.embedding(tgt) + pos[:, :tgt.size(1), :]
            tgt_mask = model._generate_square_subsequent_mask(tgt.size(1)).to(device)
            decoded = model.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask)
            step_logits = model.fc(decoded[:, -1, :])  # (batch, vocab)

            # The logits at decoder position `step` predict the token at
            # position `step + 1` of the reference (position 0 is the start
            # token), exactly like tgt_output = seq2[:, 1:] during training.
            gold_index = step + 1
            if gold_index < target.size(1):
                sequence_losses += criterion(step_logits, target[:, gold_index].long())
                scored_steps += 1

            # Greedy choice for the next decoder input.
            next_token = step_logits.argmax(dim=-1, keepdim=True)
            tgt = torch.cat([tgt, next_token], dim=1)

        # Average over the steps that were actually scored.
        return tgt, sequence_losses / max(scored_steps, 1)


def greedy_decode_with_temperature(model, src, feature1, start_token, max_len,device, temperature=1.0):
    """Sample a target sequence from the temperature-scaled output distribution.

    Despite the name this is stochastic sampling, not greedy decoding: every
    step draws the next token from ``softmax(logits / temperature)``.

    Args:
        model: Object exposing ``embedding``, ``positional_encoding``,
            ``feature_embedding``, ``concat_projection``, ``transformer``,
            ``fc`` and ``_generate_square_subsequent_mask``.
        src: Source token ids, either (src_len,) for one sequence or
            (batch, src_len).
        feature1: Features, either (feature_dim,) or (batch, feature_dim).
        start_token: Token id used as the first decoder input.
        max_len: Number of tokens to sample after the start token.
        device: Device used for decoding.
        temperature: Positive scaling factor; values above 1 flatten the
            distribution, values below 1 sharpen it.

    Returns:
        Long tensor of shape (batch, max_len + 1) starting with ``start_token``.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    src = src.to(device)
    if src.dim() == 1:
        # Only add the batch dimension when the caller passed one sequence;
        # an already batched input must not be unsqueezed again.
        src = src.unsqueeze(0)
    feature1 = feature1.to(device)
    if feature1.dim() == 1:
        feature1 = feature1.unsqueeze(0)

    model.eval()
    with torch.no_grad():
        pos = model.positional_encoding.to(device)
        tgt = torch.full((src.size(0), 1), start_token, dtype=torch.long, device=device)

        # Encode the source once; it does not change during decoding.
        src_emb = model.embedding(src) + pos[:, :src.size(1), :]
        src_feat_emb = model.feature_embedding(feature1).unsqueeze(1).expand(-1, src_emb.size(1), -1)
        src_emb = model.concat_projection(torch.cat([src_emb, src_feat_emb], dim=-1))
        encoder_output = model.transformer.encoder(src_emb)

        for _ in range(max_len):
            tgt_emb = model.embedding(tgt) + pos[:, :tgt.size(1), :]
            tgt_mask = model._generate_square_subsequent_mask(tgt.size(1)).to(device)
            decoder_output = model.transformer.decoder(tgt_emb, encoder_output, tgt_mask=tgt_mask)
            logits = model.fc(decoder_output[:, -1, :])  # last time step only
            probs = torch.softmax(logits / temperature, dim=-1)

            # One draw per sequence; keeping the (batch, 1) tensor instead of
            # calling .item() lets this work for any batch size.
            next_token = torch.multinomial(probs, num_samples=1)
            tgt = torch.cat([tgt, next_token], dim=1)

        return tgt

def top_k_sampling(model, src, feature1, start_token, max_len,device, k=5):
    """Sample a target sequence, restricting each step to the k most likely tokens.

    Args:
        model: Object exposing ``embedding``, ``positional_encoding``,
            ``feature_embedding``, ``concat_projection``, ``transformer``,
            ``fc`` and ``_generate_square_subsequent_mask``.
        src: Source token ids, either (src_len,) for one sequence or
            (batch, src_len).
        feature1: Features, either (feature_dim,) or (batch, feature_dim).
        start_token: Token id used as the first decoder input.
        max_len: Number of tokens to sample after the start token.
        device: Device used for decoding.
        k: Number of candidates kept per step; clamped to [1, vocabulary size].
            With ``k=1`` this is greedy decoding.

    Returns:
        Long tensor of shape (batch, max_len + 1) starting with ``start_token``.
    """
    src = src.to(device)
    if src.dim() == 1:
        src = src.unsqueeze(0)
    feature1 = feature1.to(device)
    if feature1.dim() == 1:
        feature1 = feature1.unsqueeze(0)

    model.eval()
    with torch.no_grad():
        pos = model.positional_encoding.to(device)
        tgt = torch.full((src.size(0), 1), start_token, dtype=torch.long, device=device)

        # Encode the source once; it does not change during decoding.
        src_emb = model.embedding(src) + pos[:, :src.size(1), :]
        src_feat_emb = model.feature_embedding(feature1).unsqueeze(1).expand(-1, src_emb.size(1), -1)
        src_emb = model.concat_projection(torch.cat([src_emb, src_feat_emb], dim=-1))
        encoder_output = model.transformer.encoder(src_emb)

        for _ in range(max_len):
            tgt_emb = model.embedding(tgt) + pos[:, :tgt.size(1), :]
            tgt_mask = model._generate_square_subsequent_mask(tgt.size(1)).to(device)
            decoder_output = model.transformer.decoder(tgt_emb, encoder_output, tgt_mask=tgt_mask)
            logits = model.fc(decoder_output[:, -1, :])
            probs = torch.softmax(logits, dim=-1)

            # Keep the k best candidates per sequence and renormalise row-wise.
            k_eff = max(1, min(k, probs.size(-1)))
            top_k_probs, top_k_indices = torch.topk(probs, k_eff, dim=-1)
            top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

            # multinomial returns a position inside the top-k list; it must be
            # mapped back through top_k_indices to get the vocabulary id.
            choice = torch.multinomial(top_k_probs, num_samples=1)
            next_token = top_k_indices.gather(-1, choice)

            tgt = torch.cat([tgt, next_token], dim=1)
        return tgt


In [None]:
from dataloader import get_dataloaders, MAX_SEQ_LENGTH, vocab_size, vocabulary
import torch.nn as nn
import random
# Rebuild the loaders with one sequence per batch so that every sampled item is
# a single example. This rebinds the module-level batch_size and the three
# loader names used by the training cell.
batch_size = 1
train_loader, dev_loader, test_loader = get_dataloaders(batch_size=batch_size, one_hot_encode=False, start_token=True, get_feature=True)
# Seed Python's RNG so random.sample picks the same positions on every run.
# NOTE(review): if the loaders shuffle internally, their order is governed by a
# different RNG and the selected examples may still change — confirm.
random.seed(0)
# Randomly select 10 training samples (list(...) materialises the whole loader).
train_samples = random.sample(list(train_loader), 10)
# Randomly select 10 dev samples.
dev_samples = random.sample(list(dev_loader), 10)
# Randomly select 10 test samples.
test_samples = random.sample(list(test_loader), 10)

vocab = list(vocabulary.keys())
def outputs_to_seq(outputs, flag=False):
    """Convert token ids (or logits) to a list of vocabulary symbols.

    Args:
        outputs: 1-D sequence of token ids, or, when ``flag`` is True, a
            (length, vocab) tensor of scores that is reduced with argmax.
        flag: Set to True when ``outputs`` holds scores rather than ids.

    Returns:
        The symbols looked up in the module-level ``vocab`` list, with a leading
        'S' removed and everything from the first 'P' or 'E' onwards cut off
        (presumably the start, padding and end markers — confirm against the
        dataloader vocabulary).
    """
    if flag:
        outputs = outputs.argmax(dim=-1)
    symbols = [vocab[int(token_id)] for token_id in outputs]

    # Strip the start marker only when it really is the first symbol; testing
    # membership would drop a genuine first token whenever 'S' appears later.
    if symbols and symbols[0] == 'S':
        symbols = symbols[1:]

    # Truncate at the first padding or end marker, whichever comes first.
    for position, symbol in enumerate(symbols):
        if symbol in ('P', 'E'):
            return symbols[:position]
    return symbols

def _print_teacher_forced_report(samples, label):
    """Print the unweighted teacher-forced loss and argmax prediction per sample.

    ``samples`` holds ``(seq1, feature1, seq2, _)`` batches and ``label`` is the
    prefix of the printed loss line ('train' or 'dev').
    """
    # Local, unweighted criterion: keeps the module-level training `criterion`
    # (which carries class weights) from being overwritten by this report.
    unweighted_ce = nn.CrossEntropyLoss()
    model.eval()
    # Reporting only: no gradients are needed, so do not build autograd graphs.
    with torch.no_grad():
        for seq1, feature1, seq2, _ in samples:
            seq1 = seq1.to(device)
            seq2 = seq2.to(device)
            feature1 = torch.stack(feature1, dim=1).to(device).float()
            tgt_input = seq2[:, :-1]
            tgt_output = seq2[:, 1:]
            outputs = model(seq1, tgt_input, feature1)
            sample_loss = unweighted_ce(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1).long())
            print(f"{label} loss: ", sample_loss.item())
            print("seq1: ", outputs_to_seq(seq1[0][1:]))
            print("seq2: ", outputs_to_seq(seq2[0][1:]))
            print("pred: ", outputs_to_seq(outputs[0], True))


# Print the original seq1 and seq2 together with the predicted seq2.
_print_teacher_forced_report(train_samples, "train")

print("dev samples")
_print_teacher_forced_report(dev_samples, "dev")

# Test the model with free-running generation.
print("test samples")
# One of: 'features_loss', 'features', 'greedy', 'greedy_temp', 'k_sampling'.
generation_method = 'features_loss'
_START_TOKEN = 7  # token id fed to the decoder first (presumably 'S' — confirm against the vocabulary)
_KNOWN_METHODS = ('greedy_temp', 'greedy', 'k_sampling', 'features', 'features_loss')
if generation_method not in _KNOWN_METHODS:
    raise ValueError(f"unknown generation_method: {generation_method!r}")

model.eval()
with torch.no_grad():
    for seq1, feature1, seq2, _ in test_samples:
        seq1, seq2 = seq1.to(device), seq2.to(device)
        feature1 = torch.stack(feature1, dim=1).to(device).float()

        # Only 'features_loss' yields per-sequence losses; reset every iteration
        # so a stale value from an earlier computation is never reported.
        losses = None
        if generation_method == 'greedy_temp':
            # Pass the single unbatched sequence (batch size is 1 here); the
            # sampler adds the batch dimension itself for 1-D input.
            generated_seq = greedy_decode_with_temperature(model, seq1[0], feature1, _START_TOKEN, MAX_SEQ_LENGTH, device, 1.2)
        elif generation_method == 'greedy':
            generated_seq = generate_sequence(model, seq1, feature1, _START_TOKEN, MAX_SEQ_LENGTH, device)
        elif generation_method == 'k_sampling':
            generated_seq = top_k_sampling(model, seq1, feature1, _START_TOKEN, MAX_SEQ_LENGTH, device, 2)
        elif generation_method == 'features':
            generated_seq = greedy_decode_features(model, seq1, feature1, _START_TOKEN, MAX_SEQ_LENGTH, device)
        else:  # 'features_loss'
            generated_seq, losses = greedy_decode_and_compute_loss(model, seq1, feature1, seq2, _START_TOKEN, MAX_SEQ_LENGTH, device)

        for j, (sequence, tmp_seq1, tmp_seq2) in enumerate(zip(generated_seq, seq1, seq2)):
            print("test loss:", losses[j].item() if losses is not None else "n/a")
            print("seq1: ", outputs_to_seq(tmp_seq1[1:]))
            print("seq2: ", outputs_to_seq(tmp_seq2[1:]))
            print("pred: ", outputs_to_seq(sequence))



['ADAR1_seq.txt', 'ADAR2_seq.txt', 'ADAR3_seq.txt', 'Endogenous_ADAR1_seq.txt']
train loss:  0.9240440130233765
seq1:  ['G', 'G', 'C', 'C', 'C', 'A', 'G', 'A', 'C', 'T', 'G', 'G', 'C', 'A', 'C', 'C', 'T', 'G', 'A', 'A', 'G', 'A', 'T', 'G', 'C', 'C', 'T', 'A', 'A', 'G', 'A', 'T', 'A', 'A', 'A', 'A', 'A', 'T', 'G', 'C', 'C', 'C', 'A', 'A', 'G', 'A', 'T', 'C']
seq2:  ['G', 'A', 'T', 'C', 'T', 'C', 'C', 'A', 'T', 'G', 'C', 'C', 'T', 'G', 'A', 'C', 'A', 'T', 'T', 'G', 'A', 'T', 'C', 'T', 'T', 'A', 'A', 'C', 'C', 'T', 'G', 'A', 'A', 'A', 'G', 'G', 'A', 'C', 'C', 'C', 'A', 'A', 'A', 'G', 'T', 'G', 'A', 'A', 'G', 'G', 'G', 'T', 'G', 'A', 'T', 'A', 'T', 'G', 'G', 'A', 'T', 'G', 'T', 'G', 'T', 'C', 'T', 'C', 'T', 'G', 'C', 'C']
pred:  ['G', 'A', 'T', 'C', 'T', 'G', 'T', 'T', 'G', 'G', 'G', 'T', 'T', 'G', 'C', 'G', 'A', 'G', 'G', 'G', 'G', 'A', 'G', 'T', 'G', 'G', 'A', 'A', 'A', 'T', 'G', 'G', 'A', 'G', 'A', 'A', 'A', 'A', 'A']
train loss:  0.5361067652702332
seq1:  ['C', 'A', 'G', 'G', 'A', 'G',