In [None]:
# Standard library
import os
import pickle  # loading the processed sequence lists
import random

# Third-party: manipulation, training and visualization
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

# Local: the created transformer model
import transformer_model

# Reproducibility: seed every RNG the notebook relies on (Python's `random`, NumPy and
# PyTorch) so data shuffling, dropout and weight initialisation are repeatable on
# Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

### Importing data
Load the processed PISCES and AlphaFold datasets. Depending on preference (and available GPU memory), each dataset can optionally be moved to the device.

In [None]:
# Load the processed PISCES data: angle and embedding tensors plus the sequence list.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
pisces_angles = torch.load('data_processed/pisces_training/angles.pt')
pisces_embeddings = torch.load('data_processed/pisces_training/embeddings.pt')
with open('data_processed/pisces_training/sequences.pkl', 'rb') as f:
    pisces_sequences = pickle.load(f)

# The PISCES tensors are intentionally left on the CPU; call `.to(device)` on them
# if they are needed on the GPU.

print("Size of the angles: ", pisces_angles.size())
print("Size of the embeddings: ", pisces_embeddings.size())
print("Number of sequences in total: ", len(pisces_sequences))

In [None]:
# Load the processed AlphaFold data and keep only the first N_ALPHAFOLD samples.
# Tensors and the sequence list are truncated together so all reported sizes agree.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
N_ALPHAFOLD = 1000  # subset used for training; set to None to use every sample

alphafold_angles = torch.load('data_processed/alphafold_training/angles.pt')
alphafold_embeddings = torch.load('data_processed/alphafold_training/embeddings.pt')
with open('data_processed/alphafold_training/sequences.pkl', 'rb') as f:
    alphafold_sequences = pickle.load(f)

# Truncate to the subset and move the tensors to the device.
alphafold_angles = alphafold_angles[:N_ALPHAFOLD].to(device)
alphafold_embeddings = alphafold_embeddings[:N_ALPHAFOLD].to(device)
alphafold_sequences = alphafold_sequences[:N_ALPHAFOLD]

print("Size of the angles: ", alphafold_angles.size())
print("Size of the embeddings: ", alphafold_embeddings.size())
print("Number of sequences in total: ", len(alphafold_sequences))

In [None]:
#Collecting the whole data


In [None]:
# Per-protein sequence lengths for the AlphaFold data; the `length` column is turned
# into the tensor that drives the padding masks.
alphafold_lengths_csv = "data_processed/alphafold_length.csv"
df = pd.read_csv(alphafold_lengths_csv)


In [None]:
# Hyper-parameters used for the training
feed_forward_dim1 = 512  # FFNN layer 1
feed_forward_dim2 = 256  # FFNN layer 2
num_epochs = 200
dropout_rate = 0.1

# Embedding dimension coming from ProtBert
D = 1024

# Build the project transformer with the parameters above and move it to the device.
model = transformer_model.TransformerModel(
    embed_dim=D,
    feed_forward_dim1=feed_forward_dim1,
    feed_forward_dim2=feed_forward_dim2,
    dropout_rate=dropout_rate,
)
model = model.to(device)
# Custom angular loss
criterion = transformer_model.AngularLoss()

# Sequence lengths of the chosen dataset, used to build the padding masks.
lengths = torch.tensor(df['length'].values)
print('Size of the tensor created using lengths: ', lengths.shape)

# Reverse the axis order of the angle tensor (exactly what `.T` did) and convert
# degrees to radians. `.T` on tensors with ndim != 2 is deprecated in PyTorch, so the
# reversal is spelled out with `permute`.
angles_tensor = alphafold_angles.permute(*reversed(range(alphafold_angles.ndim))) * (np.pi / 180)

def create_mask(indices, max_len: int = 129):
    """Boolean padding mask for the sample at ``indices`` (True = real residue).

    Compares positions ``0 .. max_len-1`` against the notebook-level ``lengths``
    tensor and moves the result to the notebook-level ``device``.

    Args:
        indices: index of the sample in ``lengths``.
        max_len: padded sequence width (default 129, the width used throughout
            this notebook).

    Returns:
        Bool tensor of shape ``(max_len,)`` for a single index, on ``device``.
    """
    mask = torch.arange(max_len) < lengths[indices]
    return mask.to(device)  # loading the mask to the device

In [None]:
class TransformerTrainer:
    '''
    Train a transformer that maps per-sample embeddings to angles.

    Every call to ``train`` shuffles the sample indices once and splits them 80/10/10
    into training, validation and test indices. After every epoch the model weights
    are saved and the epoch's mean training / validation loss is printed.

    Input =>
    model : transformer model used for the training, called as ``model(sample, mask)``
    criterion : loss function, called as ``criterion(idx, predictions, angles, loss_mask)``
    num_epochs : number of epochs
    sequence : inputs (e.g. embeddings) of the chosen dataset, indexed as ``sequence[idx, :, :]``
    angles : angles from the chosen dataset, passed unchanged to ``criterion``
    lr : Adam learning rate (default 0.001)
    mask_fn : callable ``idx -> bool padding mask``; defaults to the notebook-level ``create_mask``

    Output (of ``train``) =>
    val_loss : mean validation loss of every epoch (one entry per epoch)
    test_indices : held-out indices that are used neither for training nor validation
    '''
    def __init__(self, model: nn.Module, criterion: nn.Module, num_epochs: int, sequence: torch.Tensor,
                 angles: torch.Tensor, lr: float = 0.001, mask_fn=None):
        self.model = model
        self.criterion = criterion
        self.num_epochs = num_epochs
        self.sequence = sequence
        self.angles_tensor = angles
        self.mask_fn = mask_fn
        self.optimizer = optim.Adam(model.parameters(), lr=lr)  # optimizer for the back-propagation

    def _mask(self, idx):
        """Boolean padding mask for sample ``idx`` from ``mask_fn`` (or ``create_mask``)."""
        mask_fn = self.mask_fn if self.mask_fn is not None else create_mask
        return mask_fn(idx)

    @staticmethod
    def _loss_mask(mask: torch.Tensor) -> torch.Tensor:
        """Expand an (L,) mask into the (L, L) float matrix given to the criterion (row i is all mask[i])."""
        mask = mask.type(torch.float32)
        return torch.matmul(mask.reshape(-1, 1), torch.ones((1, mask.shape[0]), device=mask.device))

    def _split_indices(self):
        """Shuffle all sample indices and split them 80% / 10% / 10% into train / val / test."""
        n_samples = self.sequence.shape[0]
        indices = np.random.permutation(n_samples)
        train_end, val_end = int(0.8 * n_samples), int(0.9 * n_samples)
        return indices[:train_end], indices[train_end:val_end], indices[val_end:]

    def train(self, save_path: str = 'predictions_alphafold/500d_200ep_01/model_postraining_500d_200ep_01.pt'):
        """Run the training loop; return ``(val_loss, test_indices)`` (see class docstring).

        ``save_path`` is where the model's ``state_dict`` is written after every epoch;
        its directory is created if needed.
        """
        train_indices, val_indices, test_indices = self._split_indices()
        val_loss = []  # one mean validation loss per epoch

        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        for epoch in range(self.num_epochs):
            # The validation phase switches the model to eval(); switch back so dropout
            # is active for every training epoch, not just the first one.
            self.model.train()
            epoch_train_losses = []
            for idx in train_indices:
                train_mask = self._mask(idx)
                self.optimizer.zero_grad()
                predictions = self.model(self.sequence[idx, :, :], train_mask)
                loss = self.criterion(idx, predictions.squeeze(), self.angles_tensor, self._loss_mask(train_mask))
                # NOTE(review): retain_graph kept from the original; a fresh graph is built
                # for every sample, so it is likely unnecessary. Drop it once confirmed.
                loss.backward(retain_graph=True)
                self.optimizer.step()
                epoch_train_losses.append(loss.item())

            # Validation: evaluate each validation sample (not the last training sample)
            # with the same loss mask as training, so the two losses are comparable.
            self.model.eval()
            epoch_val_losses = []
            with torch.no_grad():
                for idx_val in val_indices:
                    val_mask = self._mask(idx_val)
                    predictions = self.model(self.sequence[idx_val, :, :], val_mask)
                    epoch_loss = self.criterion(idx_val, predictions.squeeze(), self.angles_tensor,
                                                self._loss_mask(val_mask))
                    epoch_val_losses.append(epoch_loss.item())

            mean_train = float(np.mean(epoch_train_losses)) if epoch_train_losses else float('nan')
            mean_val = float(np.mean(epoch_val_losses)) if epoch_val_losses else float('nan')
            val_loss.append(mean_val)

            # Saving model
            torch.save(self.model.state_dict(), save_path)
            print(f"Epoch {epoch + 1}/{self.num_epochs}, Training Loss: {mean_train}, Validation Loss: {mean_val}")

        return val_loss, test_indices

In [None]:
# Train on the AlphaFold embeddings; keep the validation curve and held-out test indices.
trainer = TransformerTrainer(
    model,
    criterion,
    num_epochs,
    alphafold_embeddings,
    angles_tensor,
)
loss, test_indices = trainer.train()

In [None]:
# Validation-loss curve returned by the trainer.
fig, ax = plt.subplots()
ax.plot(loss)
ax.set_xlabel('Total_number of epochs')
ax.set_ylabel('Validation loss')
ax.set_title('Validation loss during the training')

In [None]:
# Spot-check one AlphaFold sample: predicted angles in degrees, trimmed to its true length.
# Inference runs in eval mode (dropout off) and without building an autograd graph.
n0 = 133
model.eval()
with torch.no_grad():
    predicted_angles = model(alphafold_embeddings[n0, :, :], create_mask(indices=n0))[:lengths[n0], :]
angles_pred = predicted_angles.T * (180 / np.pi)
angles_pred

In [None]:
# Persist the held-out test indices (one per line) so the same split can be re-evaluated.
# TODO: consider plotting a per-protein loss for these test indices.
with open('predictions_alphafold/500d_200ep_01/test.txt', 'w') as f:
    f.writelines("%s\n" % test_idx for test_idx in test_indices)

In [None]:
# Predict angles for every held-out test protein and store them next to the reference
# angles, both in degrees and zero-padded to the fixed width.
# `test_indices` come from the AlphaFold split made by the trainer (and `lengths` /
# `create_mask` describe AlphaFold proteins), so the AlphaFold tensors must be used here.
# Indexing the PISCES tensors with them would pair unrelated proteins and masks.
MAX_LEN = 129  # padded sequence width used throughout the notebook
angles_new = torch.zeros(len(test_indices), 2, MAX_LEN)
angles_org = torch.zeros(len(test_indices), 2, MAX_LEN)

model.eval()
with torch.no_grad():
    for i, n in enumerate(test_indices):
        predicted_angles = model(alphafold_embeddings[n, :, :], create_mask(indices=n))[:lengths[n], :]
        angles_current = (predicted_angles.T * (180 / np.pi)).cpu()
        angle_length = min(angles_current.size(1), MAX_LEN)
        angles_new[i, :, :angle_length] = angles_current[:, :angle_length]
        angles_org[i, :, :angle_length] = alphafold_angles[n, :, :angle_length].cpu()

# Save once, after every test protein has been processed.
torch.save(angles_org, "predictions_alphafold/500d_200ep_01/original_angles_test_500d_200ep.pt")
torch.save(angles_new, "predictions_alphafold/500d_200ep_01/predicted_angles_test_500d_200ep.pt")

In [None]:
def is_tensor(obj):
    """Return True when ``obj`` is a ``torch.Tensor`` (NumPy arrays and lists are not)."""
    return torch.is_tensor(obj)

def angle_difference(angle1, angle2):
    """Wrapped difference ``angle1 - angle2`` in degrees, mapped into [-180, 180].

    Both arguments are angles in degrees (tensors); the raw difference is turned into
    radians, wrapped via atan2(sin, cos), and converted back to degrees.
    """
    delta_rad = (angle1 - angle2) * np.pi / 180
    wrapped_rad = torch.atan2(torch.sin(delta_rad), torch.cos(delta_rad))
    return wrapped_rad * 180 / np.pi

def calculate_angle_loss(angles_pred, angles_original, mask=None):
    """Compare predicted and reference phi/psi angles (both in degrees).

    The angle type sits on the second-to-last axis (index 0 = phi, index 1 = psi)
    and positions on the last axis. Inputs can therefore be a single ``(2, L)``
    sample or a batched ``(N, 2, L)`` stack. Differences are wrapped into
    [-180, 180] with ``angle_difference`` before averaging.

    Args:
        angles_pred: predicted angles (tensor or array-like).
        angles_original: reference angles, same layout.
        mask: optional boolean tensor broadcastable to the per-angle shape
            (``(L,)`` or ``(N, L)``). Only True positions are averaged; use it to
            skip zero padding. An all-False mask yields NaN results.

    Returns:
        ``(loss, mae_phi, mae_psi)``: ``loss`` is sqrt(MSE_phi + MSE_psi) and the
        other two are mean absolute errors, all in degrees (0-dim tensors).
    """
    # Ensure inputs are tensors
    if not is_tensor(angles_pred):
        angles_pred = torch.tensor(angles_pred)
    if not is_tensor(angles_original):
        angles_original = torch.tensor(angles_original)

    # `...` selects the angle axis for both (2, L) and (N, 2, L) inputs. A plain
    # [0, :] on a batch would pick the first *sample* instead of phi.
    phi_diff = angle_difference(angles_pred[..., 0, :], angles_original[..., 0, :])
    psi_diff = angle_difference(angles_pred[..., 1, :], angles_original[..., 1, :])

    if mask is not None:
        keep = torch.as_tensor(mask, dtype=torch.bool, device=phi_diff.device).expand_as(phi_diff)
        phi_diff = phi_diff[keep]
        psi_diff = psi_diff[keep]

    # Root of the summed mean squared errors
    mse_phi = torch.mean(phi_diff ** 2)
    mse_psi = torch.mean(psi_diff ** 2)
    loss = (mse_phi + mse_psi) ** 0.5

    # Mean absolute angle errors
    mae_phi = torch.mean(torch.abs(phi_diff))
    mae_psi = torch.mean(torch.abs(psi_diff))

    return loss, mae_phi, mae_psi

# Only the first min(length, width) positions of each test protein hold real angles.
# The rest are zeros in both tensors and would pull the averages towards 0, so mask them out.
padded_width = angles_new.shape[-1]
test_lengths = lengths[torch.as_tensor(test_indices)].clamp(max=padded_width)
valid_positions = torch.arange(padded_width) < test_lengths.unsqueeze(1)

# Compute the angle-based metrics. A distinct name keeps the training curve `loss` intact.
angle_loss, mae_phi, mae_psi = calculate_angle_loss(angles_new, angles_org, mask=valid_positions)
print(f'Angle-based loss: {angle_loss.item()}')
print(f'Mean absolute error for phi: {mae_phi.item()} degrees')
print(f'Mean absolute error for psi: {mae_psi.item()} degrees')
# Write plain numbers (not tensor reprs) so the file is easy to parse.
with open('predictions_alphafold/500d_200ep_01/accuracy.txt', 'w') as f:
    for value in (angle_loss, mae_phi, mae_psi):
        f.write("%s\n" % value.item())

In [None]:
#torch.save(angles_new2 ,'predictions/full/predict_1234_03.pt' )