In [1]:
import pandas as pd
import numpy as np
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import lightning as L

In [118]:
class HarmonicPrior(nn.Module):
    def __init__(self, hidden_features, output_dim=256):
        super().__init__()
        self.channels=hidden_features
        self.q = nn.Linear(1, hidden_features)
        self.k = nn.Linear(1, hidden_features)
        self.v = nn.Linear(1, hidden_features)
        self.out_layer = nn.Linear(hidden_features, output_dim)
        self.orthognal_vector = nn.utils.parametrizations.orthogonal(nn.Linear(output_dim,output_dim))
        self.background = Fixed_Prior().fixed_background()
    def forward(self, x):
        h_ = x[:, :, np.newaxis]
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        w_ = torch.bmm(q,k.permute(0,2,1))
        w_ = w_ * (self.channels**(-0.5))
        w_ = torch.nn.functional.softmax(w_,dim=2)
        h_ = torch.bmm(w_,v)
        h_ = self.out_layer(h_)
        h_ = torch.matmul(x.unsqueeze(1),h_)
        h_ = h_.squeeze(1)
        h_ = nn.ReLU()(h_)
        h_inv = 1/h_
        h_inv[0] = 0 
        Q = self.orthognal_vector.weight
        return torch.matmul(Q,torch.sqrt(h_inv).T).T + self.background

class Fixed_Prior:
    def __init__(self, N = 256, a =3/(3.8**2)):
        J = torch.zeros(N, N)
        for i, j in zip(np.arange(N-1), np.arange(1, N)):
            J[i,i] += a
            J[j,j] += a
            J[i,j] = J[j,i] = -a
        D, P = torch.linalg.eigh(J)
        D_inv = 1/D
        D_inv[0] = 0
        self.P, self.D_inv = P, D_inv
        self.N = N

    def to(self, device):
        self.P = self.P.to(device)
        self.D_inv = self.D_inv.to(device)
        
    def fixed_background(self):
        return torch.matmul(self.P,torch.sqrt(self.D_inv))

    def sample(self, batch_dims=()):
        return self.P @ (torch.sqrt(self.D_inv)[:,None] * torch.randn(*batch_dims, self.N, 3, device=self.P.device))

In [105]:
training_path="/pscratch/sd/l/lemonboy/PDB70_training_ver_A/eigenvalue_training/saxs_r/"

In [106]:
class SAXSDataset(Dataset):
    def __init__(self, csv_list):
        self.csv_list = csv_list
        
    def __len__(self):
        return len(self.csv_list)
    def __getitem__(self, idx):
        data = pd.read_csv(self.csv_list[idx])
        # The first point is always zero so I didn't include it into the dataset
        features = torch.tensor(np.pad(data['P(r)'].values[1:], (0, 512-len(data['P(r)'].values[1:])),constant_values=(0,0)), dtype=torch.float32)
        return features, features

In [107]:
import glob
csv_list = glob.glob(training_path+'*.csv')
batch_size = 32
shuffle = True
validation_split = 0.2
dataset = SAXSDataset(csv_list)
num_samples = len(dataset)
num_validation_samples = int(validation_split * num_samples)
num_train_samples = num_samples - num_validation_samples
train_dataset, val_dataset = random_split(dataset, [num_train_samples, num_validation_samples])
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
device=torch.device("cuda")
for batch_id,(data,target) in enumerate(train_dataloader):
    print(batch_id,data,target)
    break

0 tensor([[0.0000e+00, 1.4657e-04, 7.2678e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 5.1920e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [2.5383e-05, 1.7276e-04, 1.4588e-03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 8.2172e-06, 2.2053e-03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 6.9484e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 2.9972e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]]) tensor([[0.0000e+00, 1.4657e-04, 7.2678e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 5.1920e-04,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [2.5383e-05, 1.7276e-04, 1.4588e-03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 8.2172e-06, 2.2053e-03,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
 

In [114]:
learning_rate = 1e-3
loss_fn = nn.CrossEntropyLoss()
model = HarmonicPrior(64,256)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [116]:
model(data[0:2]).shape

torch.Size([2, 256])
torch.Size([256, 256])
torch.Size([256])


torch.Size([2, 256])

In [123]:
test=Fixed_Prior()

In [128]:
test.fixed_background

<bound method Fixed_Prior.fixed_background of <__main__.Fixed_Prior object at 0x7fa982b89a90>>

In [127]:
test.sample().shape

torch.Size([256, 3])