In [1]:
import glob
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

## Dataloader

In [5]:
# Directory holding the per-structure P(r) CSV files. Set the
# SAXS_TRAINING_PATH environment variable to point at a different copy of the
# data; otherwise the original scratch location is used.
# os.path.join(..., "") guarantees a trailing separator, so code that appends
# a file pattern by plain string concatenation keeps working.
training_path = os.path.join(
    os.environ.get(
        "SAXS_TRAINING_PATH",
        "/pscratch/sd/l/lemonboy/PDB70_training_ver_A/pdbs/saxs_r/",
    ),
    "",
)

In [6]:
# Path of one example P(r) file, used to inspect the CSV layout.
test = os.path.join(training_path, "6LN0_A.pdb.pr.csv")

In [8]:
# Preview the example file: columns are r and P(r), plus the index that was
# saved into the CSV (it shows up as "Unnamed: 0").
pd.read_csv(test)

Unnamed: 0,r,P(r)
0,0.0,0.000000e+00
1,0.5,0.000000e+00
2,1.0,0.000000e+00
3,1.5,3.420565e-04
4,2.0,1.758611e-04
...,...,...
240,120.0,1.840159e-07
241,120.5,7.102368e-08
242,121.0,3.551184e-08
243,121.5,6.456698e-09


In [11]:
class SAXSDataset(Dataset):
    """Dataset of P(r) curves, one curve per CSV file.

    Each item is read lazily from disk on access. The same tensor is returned
    as both input and target, as needed for autoencoder-style training.

    Args:
        csv_list: sequence of paths to CSV files that contain a ``P(r)`` column.
    """

    def __init__(self, csv_list):
        self.csv_list = csv_list

    def __len__(self):
        """Return the number of CSV files (i.e. samples)."""
        return len(self.csv_list)

    def __getitem__(self, idx):
        """Load file ``idx`` and return ``(features, features)``.

        ``features`` is a float32 tensor holding the ``P(r)`` column with its
        first row dropped.
        """
        frame = pd.read_csv(self.csv_list[idx])
        # NOTE(review): the resulting length depends on the number of rows in
        # each file; default batching needs all curves to be equally long --
        # confirm against the data.
        pr_values = frame['P(r)'].to_numpy()[1:]
        features = torch.tensor(pr_values, dtype=torch.float32)
        return features, features

In [12]:
# Collect every P(r) CSV in the training directory. glob returns matches in
# arbitrary (filesystem-dependent) order, so sort them to keep sample indices
# stable from run to run.
csv_list = sorted(glob.glob(os.path.join(training_path, '*.csv')))

In [14]:
# Data-loading settings.
batch_size = 32          # samples per mini-batch
shuffle = True           # shuffle flag passed to the training DataLoader
validation_split = 0.2   # fraction of samples held out for validation

In [15]:
# Wrap the list of CSV paths; files are only read when items are accessed.
dataset = SAXSDataset(csv_list=csv_list)

In [16]:
# Hold out `validation_split` of the samples (rounded down) for validation;
# everything else goes to training.
num_samples = len(dataset)
num_validation_samples = int(num_samples * validation_split)
num_train_samples = num_samples - num_validation_samples

In [17]:
# Seed the split so that train/validation membership is the same on every
# run (given the same dataset ordering) instead of changing per kernel restart.
train_dataset, val_dataset = random_split(
    dataset,
    [num_train_samples, num_validation_samples],
    generator=torch.Generator().manual_seed(42),
)

In [18]:
# Training batches follow the `shuffle` setting; validation order stays fixed.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=shuffle,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

## VAE model

In [2]:
class VAE(nn.Module):
    """1-D convolutional variational autoencoder for fixed-length curves.

    Encoder: ``Conv1d(1 -> hidden_size, kernel 3, padding 1)`` + ReLU,
    ``MaxPool1d(2)``, flatten, then two linear heads producing the mean and
    the log-variance of the latent Gaussian.
    Decoder: two linear layers; the final sigmoid keeps every reconstructed
    value in the range 0..1.

    Author's guidance: for P(r) the latent_size should be between 6 and 12;
    longer sequences should use a larger latent. 10 is used for testing.

    Args:
        input_size: number of points per input curve. Every input must have
            exactly this length.
            NOTE(review): confirm the default (245) against the length of the
            tensors the dataset actually yields.
        hidden_size: number of convolution channels and width of the decoder's
            hidden layer.
        latent_size: dimensionality of the latent space.
    """

    def __init__(self,input_size=245, hidden_size=20, latent_size=10):
        super(VAE, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.latent_size = latent_size

        self.encoder_conv1 = nn.Conv1d(in_channels=1, out_channels=hidden_size,
                                      kernel_size=3, stride=1, padding=1)
        self.encoder_maxpool = nn.MaxPool1d(kernel_size=2, stride=2)
        # The convolution preserves the length (kernel 3, padding 1) and the
        # pooling halves it (rounding down), so the flattened encoder output
        # has hidden_size * (input_size // 2) features.
        self._flat_size = hidden_size * (input_size // 2)
        self.encoder_fc_mu = nn.Linear(self._flat_size, latent_size)
        self.encoder_fc_var = nn.Linear(self._flat_size, latent_size)

        self.decoder_fc1 = nn.Linear(latent_size, hidden_size)
        self.decoder_fc2 = nn.Linear(hidden_size, input_size)

    def encode(self, x):
        """Map a batch of curves to the latent Gaussian parameters.

        Args:
            x: tensor of shape ``(batch, input_size)``.

        Returns:
            ``(mu, log_var)``, each of shape ``(batch, latent_size)``.
        """
        # Conv1d expects (batch, channels, length); add a single channel axis.
        x = torch.relu(self.encoder_conv1(x.unsqueeze(1)))
        x = self.encoder_maxpool(x)
        # Collapse (channels, length) into one feature axis for the linear heads.
        x = x.flatten(start_dim=1)
        mu = self.encoder_fc_mu(x)
        log_var = self.encoder_fc_var(x)
        return mu, log_var

    def decode(self, z):
        """Reconstruct curves of shape ``(batch, input_size)`` from latents ``z``."""
        z = torch.relu(self.decoder_fc1(z))
        return torch.sigmoid(self.decoder_fc2(z))

    def forward(self, x):
        """Encode, sample a latent with the reparameterisation trick, decode.

        Args:
            x: tensor of shape ``(batch, input_size)``.

        Returns:
            ``(x_hat, mu, log_var)``: the reconstruction plus the latent
            Gaussian parameters (needed by a KL-divergence loss term).
        """
        mu, log_var = self.encode(x)

        # log_var is log(sigma^2), so sigma = exp(log_var / 2).
        std = torch.exp(log_var / 2)
        q = torch.distributions.Normal(mu, std)
        # rsample() keeps the sample differentiable w.r.t. mu and std.
        z = q.rsample()

        x_hat = self.decode(z)
        return x_hat, mu, log_var
        

## Training

In [None]:
# Smoke-test the training loader by printing the shape of every batch.
# The loop variable must be the same name the f-string uses (batch_idx).
for batch_idx, (data, target) in enumerate(train_dataloader):
    print(f'Training Batch {batch_idx}: Data shape: {data.shape}, Target shape: {target.shape}')