In [2]:
%reload_ext autoreload
%autoreload 2
import sys
sys.path.append('/home/sebastian/masters/') # add my repo to python path
import os
import torch
import torch_geometric
import kmbio  # fork of biopython PDB with some changes in how the structure, chain, etc. classes are defined.
import numpy as np
import proteinsolver
import modules
from pathlib import Path
from modules.dataset import *
from modules.utils import *
from modules.model import *

np.random.seed(1)

In [3]:
root = Path("/home/sebastian/masters/data/")
data_root = root / "neat_data"
state_file = root / "state_files" / "e53-s1952148-d93703104.state"

### Load positive examples and generate different structure configurations

In [4]:
class ChainFilter(kmbio.PDB.Select):
    def __init__(self, subset):
        self.subset = subset

    def accept_chain(self, chain):
        if chain.id in self.subset:
            return 1
        else:
            return 0

overwrite = False

raw_files, targets = get_data(
    model_dir=data_root / "raw" / "tcrpmhc",
    metadata=data_root / "metadata.csv",
)
mask = np.ma.masked_array(raw_files, mask=targets)  # only get positives

pmhc_chain_subset = ["M", "P"]
p_chain_subset = ["P"]
annotated_paths = list()

outdir_1 = data_root / "raw" / "pmhc"
outdir_2 = data_root / "raw" / "p"

outdir_1.mkdir(parents=True, exist_ok=True)
outdir_2.mkdir(parents=True, exist_ok=True)
for raw_file in raw_files[mask.mask]:

    model_id = raw_file.name.split("_")[0]
    pmhc_file_name = outdir_1 / f"{model_id}_pmhc.pdb"
    p_file_name =  outdir_2/ f"{model_id}_p.pdb"
    
    if overwrite or (not pmhc_file_name.is_file() or not p_file_name.is_file()):
        structure  = kmbio.PDB.load(raw_file)
    
        io = kmbio.PDB.io.PDBIO()
        io.set_structure(structure)
        io.save(pmhc_file_name, ChainFilter(subset=pmhc_chain_subset))
    
        io = kmbio.PDB.io.PDBIO()
        io.set_structure(structure)
        io.save(p_file_name, ChainFilter(subset=p_chain_subset))

    annotated_paths.append([raw_file, "0"])
    annotated_paths.append([pmhc_file_name, "1"])
    annotated_paths.append([p_file_name, "2"])

metadata_path = data_root / "embedding_dataset.csv"
with open(metadata_path, "w") as metadata_outfile:
    for data in annotated_paths:
        print(data[0], data[1], sep=",", file=metadata_outfile)

### Preprocess data

In [5]:
raw_files = list()
targets = list()
with open(metadata_path, "r") as infile:
    for line in infile:
        line = line.strip().split(",")
        raw_files.append(line[0])
        targets.append(int(line[1]))

raw_files = np.array(raw_files)
targets = np.array(targets)

processed_dir = data_root / "processed" / "embedding_verification"
dataset = ProteinDataset(processed_dir, raw_files, targets, overwrite=False)

### Define classifier

In [32]:
from torch import nn, optim
import torch.nn.functional as F

batch_size = 1
num_features = 20
adj_input_size = 2
hidden_size = 128
#frac_present = 0.5
#frac_present_valid = frac_present
#info_size= 1024

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

gnn = Net(
    x_input_size=num_features + 1, 
    adj_input_size=adj_input_size, 
    hidden_size=hidden_size, 
    output_size=num_features
)
gnn.load_state_dict(torch.load(state_file, map_location=device))
gnn.eval()
gnn = gnn.to(device)

class TestNet(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        
        self.linear = nn.Sequential(
            nn.Linear(hidden_size * num_features, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1)
        )

    def forward(self, x):
        x = x.T.flatten()
        x = self.linear(x)
        return x
    
net = TestNet(50)
net = net.to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)
save_path = "/home/sebastian/masters/data/trained_models/embedding_verification_{}.state"

### Fine tuning classifier

In [7]:
def train_model(
    epochs, 
    model,
    criterion,
    train_loader, 
    valid_loader,
    batch_size,
    device,
    transform=None,
):
    train_losses = list()
    valid_losses = list()
    for i in range(epochs):
        train_len = len(train_loader)
        valid_len = len(valid_loader)

        train_loss = 0
        model.train()
        for j, x in enumerate(train_loader):
            if transform:
                print(x.ptr)
                x, y = transform(x)
            else:
                x, y = x.to(device), y.to(device)
            out = model(x)
            print(out.shape, y.shape)
            print(out, y)

            loss = criterion(out, y)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            display_func(j, train_len, i, train_losses, valid_losses)

        valid_loss = 0
        model.eval()
        with torch.no_grad():
            for x in valid_loader:
                x, y = x.to(device), y.to(device)
                out = model(x)
                loss = criterion(out, y)
                valid_loss += loss.item()

        train_losses.append(train_loss / train_len)
        valid_losses.append(valid_loss / valid_len)
    
    return model, train_losses, valid_losses

In [35]:
from sklearn.model_selection import KFold


def gnn_to_fnn(x, hidden_size, gnn_instance):
        x = x.to(device)
        y = x.y.to(device)
        with torch.no_grad():
            out = gnn_instance(x.x, x.edge_index, x.edge_attr)
        
        batches = torch.unique(x.batch)
        sliced_outs = list()
        pool = nn.AdaptiveAvgPool1d(output_size=hidden_size)
        for batch_idx in batches:
            batch_slice = torch.nonzero(x.batch == batch_idx)
            # pool to fixed size https://stackoverflow.com/a/63603993/11398318

            out_sliced = out[batch_slice]
            out_sliced = pool(out_sliced.T)
            sliced_outs.append(out_sliced)
        batched_out = torch.cat(sliced_outs, dim=1)
        return out, y


epochs = 3
n_splits = 5
batch_size = 3

train_losses = list()
valid_losses = list()
CV = KFold(n_splits=n_splits, shuffle=True)
i = 0
for train_idx, valid_idx in CV.split(dataset):
    
    train_subset = dataset[torch.LongTensor(train_idx)]
    valid_subset = dataset[torch.LongTensor(valid_idx)]
    
    train_loader = iter(torch_geometric.loader.DataLoader(train_subset, batch_size=batch_size))
    valid_loader = iter(torch_geometric.loader.DataLoader(valid_subset, batch_size=batch_size))
    
    # run through data with gnn and transform
    # prepare for input to generic train func
    # all is well
    
    # partial function - gnn arg is static, x is given later
    gnn_transform = lambda x: gnn_to_fnn(x, hidden_size=hidden_size, gnn_instance=gnn)
    
    net, train_subset_losses, valid_subset_losses = train_model(
        epochs=epochs, 
        model=net, 
        criterion=criterion,
        train_loader=train_loader, 
        valid_loader=valid_loader,
        batch_size=batch_size,
        device=device,
        transform=gnn_transform,
)
    
    train_losses.append(train_subset_losses)
    valid_losses.append(valid_subset_losses)
    torch.save(net.state_dict(), save_path.format(i))
    i += 1


    


tensor([  0, 188, 197, 608])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x12160 and 1000x50)

In [15]:
x = next(train_loader)
x

Batch(x=[612], edge_index=[2, 31616], edge_attr=[31616, 2], y=[3], batch=[612], ptr=[4])

In [16]:
x.batch

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [17]:
o = gnn(x.x, x.edge_index, x.edge_attr)#.T.unsqueeze(0)
print(o.shape)
o

torch.Size([612, 20])


tensor([[ 1.3244e+00, -2.6383e+00,  1.6032e+01,  ...,  4.5585e-01,
         -4.9444e+00, -1.4957e+01],
        [ 2.2892e+01, -1.9287e+00, -2.8375e+00,  ..., -3.9515e+00,
         -9.6134e-02, -8.8087e-01],
        [-5.1125e+00,  1.9370e+01, -2.6890e-01,  ..., -6.5021e+00,
         -8.5269e+00, -3.6263e+00],
        ...,
        [-1.3708e+00, -3.9037e+00,  1.5272e+01,  ...,  9.3151e-04,
         -2.1439e+00, -1.3754e+01],
        [-6.4504e+00,  6.6553e+00, -2.2171e-01,  ...,  1.1206e+00,
         -7.1116e-01, -1.9762e+00],
        [-4.6803e+00,  2.2308e+01, -2.5976e+00,  ..., -9.4172e+00,
         -9.5308e+00, -5.5607e+00]], grad_fn=<AddmmBackward>)

In [19]:
x.batch

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [20]:
o[torch.nonzero(x.batch == 0)].shape

torch.Size([415, 1, 20])

In [30]:
out = o
batches = torch.unique(x.batch)
sliced_outs = list()
pool = nn.AdaptiveAvgPool1d(output_size=hidden_size)
for batch_idx in batches:
    batch_slice = torch.nonzero(x.batch == batch_idx)
    # pool to fixed size https://stackoverflow.com/a/63603993/11398318
    
    out_sliced = out[batch_slice]
    out_sliced = pool(out_sliced.T)
    sliced_outs.append(out_sliced)
batched_out = torch.cat(sliced_outs, dim=1)

In [31]:
batched_out.shape

torch.Size([20, 3, 128])

In [28]:
pool(out_sliced).shape

torch.Size([9, 1, 128])