In [2]:
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import lightning as L

## Dataloader

In [3]:
training_path="/pscratch/sd/l/lemonboy/PDB70_training_ver_A/eigenvalue_training/saxs_r/"

In [4]:
test=os.path.join(training_path,'6LN0_A.pdb.pr.csv')

In [5]:
pd.read_csv(test)

Unnamed: 0,r,P(r)
0,0.0,0.000000e+00
1,0.5,0.000000e+00
2,1.0,0.000000e+00
3,1.5,3.420565e-04
4,2.0,1.758611e-04
...,...,...
240,120.0,1.840159e-07
241,120.5,7.102368e-08
242,121.0,3.551184e-08
243,121.5,6.456698e-09


In [6]:
class SAXSDataset(Dataset):
    def __init__(self, csv_list):
        self.csv_list = csv_list
        
    def __len__(self):
        return len(self.csv_list)
    def __getitem__(self, idx):
        data = pd.read_csv(self.csv_list[idx])
        # The first point is always zero so I didn't include it into the dataset
        features = torch.tensor(np.pad(data['P(r)'].values[1:], (0, 512-len(data['P(r)'].values[1:])),constant_values=(0,0)), dtype=torch.float32)
        return features, features


In [7]:
import glob
csv_list = glob.glob(training_path+'*.csv')

In [8]:
#max_length=[]
#for i in csv_list:
    #pd_data=pd.read_csv(i)
    #max_length.append(len(pd_data))
    
#print(max(max_length))
#max_length.index(3179)
#indices_of_largest_10 = sorted(range(len(max_length)), key=lambda i: max_length[i], reverse=True)[:10]
#for i in indices_of_largest_10:
    #print(max_length[i])
#csv_list[42033]

In [9]:
batch_size = 32
shuffle = True
validation_split= 0.2

In [10]:
dataset = SAXSDataset(csv_list)

In [11]:
num_samples = len(dataset)
num_validation_samples = int(validation_split * num_samples)
num_train_samples = num_samples - num_validation_samples

In [12]:
train_dataset, val_dataset = random_split(dataset, [num_train_samples, num_validation_samples])

In [13]:
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [14]:
train_dataset

<torch.utils.data.dataset.Subset at 0x7efc386fb730>

In [15]:
#for batch_id, (data, target) in enumerate(train_dataloader):
#    print(batch_id)
#    print("datasize is %d" % len(data))
#    print("y size is %d" % len(target))
    

## VAE model

In [16]:
'''
class VAE(nn.Module):
    # For P(r) the latent_size should be between 6-12. Longer sequence should have a larger
    # latent. For testing purpose we will set latent_size as 10
    
    def __init__(self,input_size=245, hidden_size=20, latent_size=10):
        super(VAE, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.encoder_conv1 = nn.Conv1d(in_channels=1, out_channels=hidden_size,
                                      kernel_size=3, stride=1, padding=1)
        # The output should be 244,20
        self.encoder_avgpool = nn.AvgPool1d(kernal_size=4, stride=4)
        # The output shoule be 61,20
        self.encoder_fc_mu = nn.Linear(hidden_size, latent_size)
        self.encoder_fc_var = nn.Linear(hidden_size, latent_size)
        
        self.decoder_fc1 = nn.Linear(latent_size, hidden_size)
        self.decoder_fc2 = nn.Linear(hidden_size, input_size)
        
    def encode(self, x):
        x = F.relu(self.encoder_conv1(x.unsqueeze(1)))
        x = self.encoder_maxpool(x)
        mu = self.encoder_fc_mu(x)
        log_var self.encoder_fc_var(x)
    
    def decode(self, z):
        z = F.relu(self.decoder_fc1(z))
        return torch.sigmoid(self.decoder_fc2(z))
    
    def forward(self, x):

        x_encoded = self.encoder(x)
        mu, log_var = self.fc_mu(x_encoded), self.fc_var(x_encoded)

        std = torch.exp(log_var / 2)
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        
        x_hat = self.decoder(z)
'''     

'\nclass VAE(nn.Module):\n    # For P(r) the latent_size should be between 6-12. Longer sequence should have a larger\n    # latent. For testing purpose we will set latent_size as 10\n    \n    def __init__(self,input_size=245, hidden_size=20, latent_size=10):\n        super(VAE, self).__init__()\n        self.input_size = input_size\n        self.hidden_size = hidden_size\n        self.latent_size = latent_size\n        self.encoder_conv1 = nn.Conv1d(in_channels=1, out_channels=hidden_size,\n                                      kernel_size=3, stride=1, padding=1)\n        # The output should be 244,20\n        self.encoder_avgpool = nn.AvgPool1d(kernal_size=4, stride=4)\n        # The output shoule be 61,20\n        self.encoder_fc_mu = nn.Linear(hidden_size, latent_size)\n        self.encoder_fc_var = nn.Linear(hidden_size, latent_size)\n        \n        self.decoder_fc1 = nn.Linear(latent_size, hidden_size)\n        self.decoder_fc2 = nn.Linear(hidden_size, input_size)\n        \n

## Self-Attention

In [17]:
device=torch.device("cuda")

In [18]:
x=np.random.rand(100)

In [19]:
x.shape

(100,)

In [20]:
class SAXS_to_Eigen(nn.Module):
    def __init__(self, hidden_features, output_dim):
        super().__init__()
        self.channels=hidden_features
        #self.upscale = nn.Linear(1,hidden_features)
        self.q = nn.Linear(1, hidden_features)
        self.k = nn.Linear(1, hidden_features)
        self.v = nn.Linear(1, hidden_features)
        self.out_layer = nn.Linear(hidden_features, output_dim)
    def forward(self, x):
        h_ = x[:, :, np.newaxis]
        #print(x.shape, h_.shape)
        #h_ = self.upscale(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        w_ = torch.bmm(q,k.permute(0,2,1))
        w_ = w_ * (self.channels**(-0.5))
        w_ = torch.nn.functional.softmax(w_,dim=2)
        h_ = torch.bmm(w_,v)
        h_ = self.out_layer(h_)
        h_ = nn.ReLU()(h_)
        print(h_.shape)
        return h_

In [53]:
class SAXS_to_Eigen_Cov(nn.Module):
    def __init__(self, hidden_features, output_dim):
        super().__init__()
        self.channels=hidden_features
        self.q = nn.Conv1d(in_channels=1, out_channels=hidden_features, kernel_size=1)
        self.k = nn.Conv1d(in_channels=1, out_channels=hidden_features, kernel_size=1)
        self.v = nn.Conv1d(in_channels=1, out_channels=hidden_features, kernel_size=1)
        self.out_layer = nn.Linear(hidden_features, output_dim)
    def forward(self, x):
        h_ = x[:, np.newaxis, :]
        #print(x.shape, h_.shape)
        #h_ = self.upscale(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        print(k.shape)
        w_ = torch.bmm(q.permute(0,2,1),k)
        w_ = w_ * (self.channels**(-0.5))
        w_ = torch.nn.functional.softmax(w_,dim=2)
        print(w_.shape)
        print(v.shape)
        h_ = torch.bmm(w_,v.permute(0,2,1))
        
        h_ = self.out_layer(h_)
        h_ = nn.ReLU()(h_)
        print(h_.shape)
        return h_

In [54]:
#cross attention between sequence and p(r)

In [55]:
model = SAXS_to_Eigen_Cov(64,128)

In [56]:
model(data[0:2])

torch.Size([2, 64, 512])
torch.Size([2, 512, 512])
torch.Size([2, 64, 512])
torch.Size([2, 512, 128])


tensor([[[0.1327, 0.3985, 0.5472,  ..., 0.1652, 0.0000, 0.4235],
         [0.1327, 0.3985, 0.5472,  ..., 0.1652, 0.0000, 0.4235],
         [0.1327, 0.3985, 0.5472,  ..., 0.1652, 0.0000, 0.4235],
         ...,
         [0.1327, 0.3985, 0.5472,  ..., 0.1652, 0.0000, 0.4235],
         [0.1327, 0.3985, 0.5472,  ..., 0.1652, 0.0000, 0.4235],
         [0.1327, 0.3985, 0.5472,  ..., 0.1652, 0.0000, 0.4235]],

        [[0.1327, 0.3985, 0.5472,  ..., 0.1652, 0.0000, 0.4235],
         [0.1327, 0.3985, 0.5472,  ..., 0.1652, 0.0000, 0.4235],
         [0.1327, 0.3985, 0.5472,  ..., 0.1652, 0.0000, 0.4235],
         ...,
         [0.1327, 0.3985, 0.5472,  ..., 0.1652, 0.0000, 0.4235],
         [0.1327, 0.3985, 0.5472,  ..., 0.1652, 0.0000, 0.4235],
         [0.1327, 0.3985, 0.5472,  ..., 0.1652, 0.0000, 0.4235]]],
       grad_fn=<ReluBackward0>)

## Training

In [24]:
learning_rate = 1e-3
loss_fn = nn.CrossEntropyLoss()
model = SAXS_to_Eigen(64,128)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [25]:
for batch_id,(data,target) in enumerate(train_dataloader):
    print(batch_id,data,target)
    break

0 tensor([[0.0000e+00, 0.0000e+00, 6.4150e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 3.8886e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 5.0475e-07, 5.8927e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, 7.2609e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 2.7945e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 4.3498e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]]) tensor([[0.0000e+00, 0.0000e+00, 6.4150e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 3.8886e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 5.0475e-07, 5.8927e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, 7.2609e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
 

In [26]:
model(data[0:2])

torch.Size([2, 512, 128])


tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.3292, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.3292, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.3292, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.3292, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.3292, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.3292, 0.0000]],

        [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.3292, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.3292, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.3292, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.3292, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.3292, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.3292, 0.0000]]],
       grad_fn=<ReluBackward0>)

In [27]:
'''
size = len(train_dataloader.dataset)
model.train()
for batch_id, (data, target) in enumerate(train_dataloader):
    pred = model(data)
    loss = loss_fn(pred, target)#Shift one digit
    
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    print(f'Training Batch {batch_idx}: Data shape: {data.shape}, Target shape: {target.shape}')
    if batch % 100 == 0:
        loss, current = loss.item(), batch * batch_size + len(data)
        print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]" )
    if batch_id==1:
        break
'''

'\nsize = len(train_dataloader.dataset)\nmodel.train()\nfor batch_id, (data, target) in enumerate(train_dataloader):\n    pred = model(data)\n    loss = loss_fn(pred, target)#Shift one digit\n    \n    loss.backward()\n    optimizer.step()\n    optimizer.zero_grad()\n    \n    print(f\'Training Batch {batch_idx}: Data shape: {data.shape}, Target shape: {target.shape}\')\n    if batch % 100 == 0:\n        loss, current = loss.item(), batch * batch_size + len(data)\n        print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]" )\n    if batch_id==1:\n        break\n'

## Pytorch lightning

In [None]:
class SAXSEncoderLightning(L.LightningModule):
        def __init__(self, hidden_features ,output_dim):
        super().__init__()
        self.channels=hidden_features
        self.q = nn.Linear(1, hidden_features)
        self.k = nn.Linear(1, hidden_features)
        self.v = nn.Linear(1, hidden_features)
        self.out_layer = nn.Linear(hidden_features, output_dim)
    def forward(self, x):
        h_ = x[:, :, np.newaxis]
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        w_ = torch.bmm(q,k.permute(0,2,1))
        w_ = w_ * (self.channels**(-0.5))
        w_ = torch.nn.functional.softmax(w_,dim=2)
        h_ = torch.bmm(w_,v)
        h_ = self.out_layer(h_)
        h_ = nn.ReLU()(h_)
        print(h_.shape)
        return h_
    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        loss = loss_fn(pred, target)
        self.log("train_loss",loss)
        return loss
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer