In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from dataloader import get_dataloaders, MAX_SEQ_LENGTH, vocab_size
import time

# Master switch for writing artifacts to disk. When True, the sweep cell creates
# ./model/<time_stamp>/ and stores hyperparameters.json there, and train_model
# writes a model checkpoint plus loss.json into the same directory.
save_model = False

In [None]:
class RNAPairTransformer(nn.Module):
    """Encoder-decoder Transformer conditioned on a per-sequence feature vector.

    The source tokens are embedded, summed with a sinusoidal positional
    encoding, concatenated with an embedding of ``src_features`` (broadcast
    over every source position) and projected back to ``hidden_dim``. The
    target tokens share the same token embedding and are decoded under a
    causal mask. The result is one logit vector of size ``output_dim`` per
    target position.

    Args:
        input_dim: Number of rows in the token embedding table.
        hidden_dim: Model width (``d_model`` of the Transformer).
        output_dim: Size of the logit vector produced per target position.
        feature_dim: Length of the per-sequence feature vector.
        num_layers: Number of encoder layers and of decoder layers.
        nhead: Number of attention heads.
        device: Stored on ``self.device`` for backward compatibility only.
            ``forward`` places its helper tensors on the device of its inputs.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, feature_dim, num_layers=2, nhead=8, device='cpu'):
        super(RNAPairTransformer, self).__init__()

        self.input_dim = input_dim      # size of the token embedding table
        self.hidden_dim = hidden_dim    # model width
        self.output_dim = output_dim    # number of logits per target position
        self.feature_dim = feature_dim  # length of the per-sequence feature vector
        self.num_layers = num_layers    # encoder layers == decoder layers
        self.device = device            # informational; not used to place tensors

        # Token embedding (integer ids -> vectors) shared by source and target.
        self.embedding = nn.Embedding(input_dim, hidden_dim)
        # Projects the per-sequence feature vector into the model width.
        self.feature_embedding = nn.Linear(feature_dim, hidden_dim)
        # Fuses [token embedding ; feature embedding] back down to hidden_dim.
        self.concat_projection = nn.Linear(hidden_dim * 2, hidden_dim)
        # Registered as a buffer so it follows the module through .to()/.cuda();
        # non-persistent so state_dict keys stay identical to older checkpoints.
        self.register_buffer(
            'positional_encoding',
            self._generate_positional_encoding(MAX_SEQ_LENGTH, hidden_dim),
            persistent=False,
        )

        # Transformer encoder-decoder
        self.transformer = nn.Transformer(
            d_model=hidden_dim,
            nhead=nhead,                     # heads in every multi-head attention
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=hidden_dim * 4,  # width of the position-wise feed-forward net
            batch_first=True,                # tensors are (batch, seq, hidden)
            norm_first=True,                 # pre-norm: LayerNorm before attention / feed-forward
        )

        # Output layer: hidden state -> logits
        self.fc = nn.Linear(hidden_dim, output_dim)

    def _generate_positional_encoding(self, seq_length, hidden_dim):
        """Return the sinusoidal position table of shape (1, seq_length, hidden_dim)."""
        position = torch.arange(0, seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2).float() * -(torch.log(torch.tensor(10000.0)) / hidden_dim)
        )
        angles = position * div_term  # (seq_length, ceil(hidden_dim / 2))

        positional_encoding = torch.zeros(seq_length, hidden_dim)
        positional_encoding[:, 0::2] = torch.sin(angles)
        # An odd hidden_dim has one fewer cosine column than sine columns.
        positional_encoding[:, 1::2] = torch.cos(angles[:, : hidden_dim // 2])
        return positional_encoding.unsqueeze(0)

    def _generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) float mask: 0.0 on/below the diagonal, -inf above it."""
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def _positional_slice(self, length, like):
        """Return the first ``length`` positions, matching ``like`` in device and dtype.

        If ``length`` exceeds the cached table, a longer table is generated on
        the fly instead of failing with a broadcast error.
        """
        table = self.positional_encoding
        if length > table.size(1):
            table = self._generate_positional_encoding(length, self.hidden_dim)
        return table[:, :length, :].to(device=like.device, dtype=like.dtype)

    def forward(self, src, tgt, src_features):
        """Compute logits for every target position.

        Args:
            src: Integer token ids, (batch, src_len).
            tgt: Integer token ids fed to the decoder, (batch, tgt_len).
            src_features: Float features, (batch, feature_dim).

        Returns:
            Tensor of shape (batch, tgt_len, output_dim).
        """
        # Causal mask so position t only attends to target positions <= t.
        # Still exposed as an attribute, as before.
        self.tgt_mask = self._generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)

        # Token embeddings plus positional encoding
        src_emb = self.embedding(src)
        src_emb = src_emb + self._positional_slice(src.size(1), src_emb)
        tgt_emb = self.embedding(tgt)
        tgt_emb = tgt_emb + self._positional_slice(tgt.size(1), tgt_emb)

        # Broadcast the sequence-level features over every source position,
        # then fuse them with the token embeddings.
        src_feat_emb = self.feature_embedding(src_features).unsqueeze(1).expand(-1, src_emb.size(1), -1)
        src_emb = self.concat_projection(torch.cat([src_emb, src_feat_emb], dim=-1))

        # Pass through Transformer
        transformer_output = self.transformer(src_emb, tgt_emb, tgt_mask=self.tgt_mask)

        # Output layer
        return self.fc(transformer_output)


def train_model(model, train_loader, criterion, optimizer, num_epochs, device, time_stamp, dev_loader=None):
    """Train ``model`` with teacher forcing and track the best dev-set loss.

    Every 10th epoch the model is scored with ``evaluate_model``. When the dev
    loss improves, a copy of the weights is kept. If the module-level
    ``save_model`` flag is true, that copy and the loss history are written
    under ``./model/<time_stamp>/`` (the directory must already exist).

    Args:
        model: Module called as ``model(seq1, tgt_input, feature1)``.
        train_loader: Iterable of ``(seq1, feature1, seq2, _)`` batches, where
            ``feature1`` is a sequence of tensors that is stacked along dim 1.
        criterion: Loss taking flattened logits and flattened target ids.
        optimizer: Optimizer stepped once per batch.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to.
        time_stamp: Sub-directory name used when saving.
        dev_loader: Loader used for the periodic evaluation. If omitted, the
            notebook-level ``dev_loader`` variable is used, as before.

    Returns:
        The lowest dev loss observed, or ``float('inf')`` if no evaluation ran.

    Raises:
        ValueError: If an evaluation is due and no dev loader is available.
    """
    if dev_loader is None:
        # Backward compatibility: earlier calls relied on a global `dev_loader`.
        dev_loader = globals().get('dev_loader')

    best_state = None
    best_dev_loss = float('inf')
    best_train_loss = 0
    best_epoch = 0

    loss_arr = []

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for seq1, feature1, seq2, _ in train_loader:
            seq1, seq2 = seq1.to(device), seq2.to(device)
            feature1 = torch.stack(feature1, dim=1)
            feature1 = feature1.to(device).float()

            # Teacher forcing: the decoder sees tokens [0, T-1) and predicts [1, T).
            tgt_input = seq2[:, :-1]
            tgt_output = seq2[:, 1:]

            outputs = model(seq1, tgt_input, feature1)
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), tgt_output.reshape(-1).long())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        total_loss /= len(train_loader)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}')
        loss_arr.append(total_loss)

        # Evaluate on the dev set every 10th epoch.
        if (epoch+1) % 10 == 0:
            if dev_loader is None:
                raise ValueError('train_model needs a dev_loader to evaluate the model')
            dev_loss = evaluate_model(model, dev_loader, criterion, device)
            if dev_loss < best_dev_loss:
                best_dev_loss = dev_loss
                # Clone the tensors: state_dict() returns references that later
                # optimizer steps would overwrite, so keeping `model` itself
                # would always save the final-epoch weights.
                best_state = {name: value.detach().clone() for name, value in model.state_dict().items()}
                best_train_loss = total_loss
                best_epoch = epoch

    # Save best model
    if save_model:
        import json

        if best_state is None:
            # No evaluation happened (fewer than 10 epochs): keep the final weights.
            best_state = model.state_dict()
        torch.save(best_state, './model/'+time_stamp+'/transformer_model_best.pth')

        with open('./model/'+time_stamp+'/loss.json', 'w') as f:
            json.dump({'loss': loss_arr,
                       'best_epoch': best_epoch,
                       'best_dev_loss': best_dev_loss,
                       'best_train_loss': best_train_loss}, f, indent=4)

    return best_dev_loss

def evaluate_model(model, dev_loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``dev_loader``.

    The model is switched to eval mode and scored without gradient tracking.
    Each batch is a ``(seq1, feature1, seq2, _)`` tuple; the decoder receives
    ``seq2`` without its last token and is scored against ``seq2`` without its
    first token. The mean loss is also printed.
    """
    model.eval()
    batch_losses = []

    with torch.no_grad():
        for source, feature_columns, target, _ in dev_loader:
            source = source.to(device)
            target = target.to(device)
            features = torch.stack(feature_columns, dim=1).to(device).float()

            # Shifted views of the target: decoder input vs. expected output.
            decoder_input, expected = target[:, :-1], target[:, 1:]

            logits = model(source, decoder_input, features)
            flat_logits = logits.reshape(-1, logits.size(-1))
            flat_expected = expected.reshape(-1).long()
            batch_losses.append(criterion(flat_logits, flat_expected).item())

    mean_loss = sum(batch_losses) / len(dev_loader)
    print(f'Dev Loss: {mean_loss:.4f}')
    return mean_loss


In [None]:
# Hyperparameter sweep: train one model per (hidden_dim, num_layers) pair and
# record the best dev loss of each trial.
import json
import os

# Hyperparameters sets
hidden_dims = [64,128,256]
num_layerss = [1,2,3,4]

# keep track of the best dev loss and the corresponding hyperparameters
# (three parallel lists, one entry per trial)
loss_arrs = []
best_dvl_hidden = []
best_dvl_layer = []

# Settings shared by every trial
input_dim = vocab_size
feature_dim = 4
output_dim = vocab_size
nhead = 8
num_epochs = 60
learning_rate = 1e-3
batch_size = 32

# Per-class loss weights, indexed by token id.
class_weights = [1,1,1,1,2,0.01,1,1]
if len(class_weights) != output_dim:
    # CrossEntropyLoss would otherwise fail later with a less helpful message.
    raise ValueError(
        f'class_weights has {len(class_weights)} entries but vocab_size is {output_dim}'
    )

# Device configuration (does not change between trials)
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print('Using ' + device)

# Train model sequentially
num_trials = len(hidden_dims) * len(num_layerss)
trial = 0
for hidden_dim in hidden_dims:
    for num_layers in num_layerss:
        trial += 1
        print(f'Start trial {trial}/{num_trials} (hidden_dim={hidden_dim}, num_layers={num_layers})')

        time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())

        if save_model:
            # Creates ./model as well if needed. Deliberately fails if the
            # directory already exists, so two trials never share a folder.
            os.makedirs('./model/'+time_stamp)

            hyperparameters = {
                'time_stamp': time_stamp,
                'model': 'transformer',
                'input_dim': input_dim,
                'hidden_dim': hidden_dim,
                'feature_dim': feature_dim,
                'output_dim': output_dim,
                'num_layers': num_layers,
                'nhead': nhead,
                'num_epochs': num_epochs,
                'learning_rate': learning_rate,
                'batch_size': batch_size
            }
            with open('./model/'+time_stamp+'/hyperparameters.json', 'w') as f:
                json.dump(hyperparameters, f, indent=4)

        # Load data
        # NOTE(review): rebuilt for every trial with identical arguments; could
        # probably be created once before the loop - confirm get_dataloaders
        # has no per-call randomness that the sweep relies on.
        train_loader, dev_loader, test_loader = get_dataloaders(batch_size=batch_size, one_hot_encode=False, start_token=True, get_feature=True)

        # Initialize model, criterion and optimizer
        model = RNAPairTransformer(input_dim, hidden_dim, output_dim, feature_dim, num_layers, nhead, device).to(device)
        weight = torch.tensor(class_weights, dtype=torch.float32, requires_grad=False).to(device)
        criterion = nn.CrossEntropyLoss(weight=weight)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        # NOTE(review): this scheduler is never stepped in the code shown here,
        # so the learning rate stays at `learning_rate` for the whole run.
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.9 ** epoch)

        # Train the model
        loss_temp = train_model(model, train_loader, criterion, optimizer, num_epochs, device, time_stamp)
        loss_arrs.append(loss_temp)
        best_dvl_hidden.append(hidden_dim)
        best_dvl_layer.append(num_layers)


In [None]:
# Print Result
print(loss_arrs)

# The bare list does not say which configuration produced which loss, so pair
# every entry with the hyperparameters recorded alongside it during the sweep.
for trial_hidden, trial_layers, trial_loss in zip(best_dvl_hidden, best_dvl_layer, loss_arrs):
    print(f'hidden_dim={trial_hidden}, num_layers={trial_layers}: best dev loss {trial_loss:.4f}')

if loss_arrs:
    best_trial = min(range(len(loss_arrs)), key=loss_arrs.__getitem__)
    print(
        f'Best configuration: hidden_dim={best_dvl_hidden[best_trial]}, '
        f'num_layers={best_dvl_layer[best_trial]} '
        f'(dev loss {loss_arrs[best_trial]:.4f})'
    )