In [None]:
%reload_ext autoreload
%autoreload 2

import os
import torch
import torch_geometric
import Bio.PDB as PDB
import kmbio  # fork of biopython PDB with some changes in how the structure, chain, etc. classes are defined.
import numpy as np
import proteinsolver

from proteinsolver.models.model import *
from proteinsolver.datasets import *

# custom stuff
#import proteinsolver_utils
#import proteinsolver_datasets
# Seed every RNG this notebook relies on so that reruns are reproducible.
# Seeding NumPy alone does not cover torch (e.g. the random weight
# initialisation of nn.Linear layers created further down).
torch.manual_seed(1)
np.random.seed(1)

In [None]:
#from_dir = "/home/sebastian/masters/data/210916_TCRpMHCmodels/models/"
#to_dir = "/home/sebastian/masters/data/neat_data/tcrpmhc/"
#model_suffix = "model_TCR-pMHC.pdb"
#for subdir in os.listdir(from_dir):
#    subdir_id = subdir.split("_")[0]
#    new_name = f"tcrpmhc_{subdir_id}.pdb"
#    os.system(f"mv {from_dir}/{subdir}/{model_suffix} {to_dir}/{new_name}")
#    
#from_dir = "/home/sebastian/masters/data/embedding_verification/raw_filtered_models"
#to_dir = "/home/sebastian/masters/data/neat_data/pmhc/"
#model_suffix = "model_pMHC.pdb"
#for subdir in os.listdir(from_dir):
#    subdir_id = subdir.split("_")[0]
#    new_name = f"pmhc_{subdir_id}.pdb"
#    os.system(f"mv {from_dir}/{subdir}/{model_suffix} {to_dir}/{new_name}")
#    
#from_dir = "/home/sebastian/masters/data/embedding_verification/raw_filtered_models"
#to_dir = "/home/sebastian/masters/data/neat_data/p/"
#model_suffix = "model_p.pdb"
#for subdir in os.listdir(from_dir):
#    subdir_id = subdir.split("_")[0]
#    new_name = f"p_{subdir_id}.pdb"
#    os.system(f"mv {from_dir}/{subdir}/{model_suffix} {to_dir}/{new_name}")

In [None]:
# Key selecting which entry of BEST_STATE_FILES is loaded into the GNN.
UNIQUE_ID = "191f05de"
# Maps a checkpoint key to the path of its saved state file.
BEST_STATE_FILES = {
    "191f05de": "/home/sebastian/proteinsolver/data/e53-s1952148-d93703104.state"
}
state_file = BEST_STATE_FILES[UNIQUE_ID]

# Device shared by the cells below (`data.to(device)`, `torch.load(..., map_location=device)`,
# `gnn.to(device)`). It has to be assigned before those cells run on a fresh kernel.
# NOTE(review): the commented-out training cell pinned "cuda:1"; change the string here
# if a specific GPU is required.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


#test_file = "/home/sebastian/proteinsolver/notebooks/protein_demo/inputs/1n5uA03.pdb"
#test_id = "1n5uA03.pdb"

test_file = "/home/sebastian/masters/data/test/3hfm.pdb"
test_id = "3hfm"

def load_model_paths(data_dir, model="model_TCR-pMHC.pdb"):
    """Build the expected model-file path for every entry of ``data_dir``.

    Args:
        data_dir: directory whose entries are treated as per-model sub-directories.
        model: file name appended to each sub-directory.

    Returns:
        ``np.ndarray`` of strings of the form ``"{data_dir}/{subdir}/{model}"``,
        ordered by sub-directory name. The paths are not checked for existence.

    Raises:
        OSError: propagated from ``os.listdir`` if ``data_dir`` cannot be listed.
    """
    # os.listdir returns entries in arbitrary order; sorting makes positional
    # access such as `infiles[0]` reproducible across runs and file systems.
    subdirs = sorted(os.listdir(data_dir))
    return np.array([f"{data_dir}/{subdir}/{model}" for subdir in subdirs])

# One candidate model path per entry of the models directory (default model file name).
infiles = load_model_paths(
    data_dir="/home/sebastian/masters/data/210916_TCRpMHCmodels/models/",
)

In [None]:
# Load the first path in `infiles` with kmbio.
structure_all = kmbio.PDB.load(infiles[0])
# NOTE(review): `merge_chains` is not defined or explicitly imported in the visible
# notebook; presumably it comes from one of the star imports or the commented-out
# custom modules — confirm.
structure_all = merge_chains(structure_all)
# Build a new Structure from what `extract('A')` returns for element 0 of the merged
# structure. NOTE(review): it is labelled with `test_id` ("3hfm") although the input
# comes from `infiles[0]`, not `test_file` — confirm this is intended.
structure = kmbio.PDB.Structure(test_id, structure_all[0].extract('A'))

# Turn chain 'A' into proteinsolver's graph data object via three helper calls
# (sequence/adjacency extraction, row -> data conversion, edge-attribute transform).
pdata = proteinsolver.utils.extract_seq_and_adj(structure, 'A', remove_hetatms=True)
data = proteinsolver.datasets.row_to_data(pdata)
data = proteinsolver.datasets.transform_edge_attr(data)
# Last expression of the cell: its return value is what the notebook displays.
data.to(device)

In [5]:
import torch.nn.functional as F

In [6]:
# F.softmax needs an input tensor (calling it with no arguments raises TypeError)
# and should be given an explicit `dim` to normalise over.
F.softmax(torch.tensor([1.0, 2.0, 3.0]), dim=0)

In [None]:
#train_loader = iter(torch_geometric.data.DataLoader(d_train, batch_size=batch_size))
#valid_loader = iter(torch_geometric.data.DataLoader(d_valid, batch_size=batch_size))
#d = next(train_loader)
#x = torch.ones_like(d.x)*d.x.max().item()
#out = gnn(x, d.edge_index, d.edge_attr)
#
#net.eval()
#with torch.no_grad():
#    y = net(out.T.unsqueeze(0))
#    y = F.softmax(y, dim=0)


In [None]:
def gnn_to_fnn(data, hidden_size, gnn_instance):
    """Embed a batch with the GNN and pool each sample's peptide embedding to a fixed length.

    Relies on the notebook-level global ``device``.

    Args:
        data: batched graph object exposing ``x``, ``edge_index``, ``edge_attr``, ``y``,
            ``batch`` and ``chain_map`` (presumably a torch_geometric batch — confirm).
            ``chain_map[i] == "P"`` must yield a boolean mask over the nodes of sample ``i``.
        hidden_size: output length handed to ``AdaptiveAvgPool1d``.
        gnn_instance: callable invoked as ``gnn_instance(x, edge_index, edge_attr)``;
            it is run under ``torch.no_grad()``.

    Returns:
        Tuple ``(batched_out, y)``: the pooled peptide embeddings of all samples
        concatenated along dim 1, and ``data.y`` moved to ``device``.
    """
    data = data.to(device)
    y = data.y.to(device)
    with torch.no_grad():
        out = gnn_instance(data.x, data.edge_index, data.edge_attr)

    # Referenced through `torch.nn` so the function does not depend on a star import
    # having put `nn` into the namespace.
    pool = torch.nn.AdaptiveAvgPool1d(output_size=hidden_size)
    sliced_outs = list()
    for batch_idx in torch.unique(data.batch):
        # Node indices belonging to this sample; nonzero() yields an (n, 1) index,
        # so the indexed result carries an extra singleton dimension.
        batch_slice = torch.nonzero(data.batch == batch_idx)
        chain_map = data.chain_map[batch_idx]
        out_sliced = out[batch_slice]
        out_sliced = out_sliced[chain_map == "P"]  # get peptide only
        # Reverse all dimensions (what `.T` did) without using `.T`, which is
        # deprecated for tensors that are not 2-D.
        out_sliced = out_sliced.permute(*reversed(range(out_sliced.dim())))
        sliced_outs.append(pool(out_sliced))
    batched_out = torch.cat(sliced_outs, dim=1)
    # Return the pooled peptide embeddings; returning the raw GNN output `out`
    # would discard everything computed in the loop above.
    return batched_out, y
    
def gnn_to_lstm_batch(data, gnn_instance, device, num_classes):
    """Bridge GNN output to the LSTM: per-sample peptide embeddings plus one-hot targets.

    Args:
        data: batched graph object exposing ``x``, ``edge_index``, ``edge_attr``, ``y``,
            ``batch`` and ``chain_map`` (presumably a torch_geometric batch — confirm).
            ``chain_map[i] == "P"`` must yield a boolean mask over the nodes of sample ``i``.
        gnn_instance: callable invoked as ``gnn_instance(x, edge_index, edge_attr)``;
            it is run under ``torch.no_grad()``.
        device: device the batch is moved to before the GNN is applied.
        num_classes: width of the one-hot target vectors.

    Returns:
        Tuple ``(sliced_embeddings, encoded_y)``: a list of peptide-embedding tensors
        sorted by ascending length, and a float32 tensor with one one-hot row per
        sample, in the same order as the embeddings.
    """
    data = data.to(device)
    y = data.y
    with torch.no_grad():
        out = gnn_instance(data.x, data.edge_index, data.edge_attr)

    peptide_embeddings = []
    class_indices = []
    for batch_idx in torch.unique(data.batch):
        # split sub graphs into batches
        batch_slice = torch.nonzero(data.batch == batch_idx)
        chain_map = data.chain_map[batch_idx]
        one_batch_peptide_emb = out[batch_slice][chain_map == "P"]  # get peptide only
        # nonzero() yields an (n, 1) index, so drop the singleton dim it introduces.
        peptide_embeddings.append(one_batch_peptide_emb.squeeze(1))
        class_indices.append(int(y[batch_idx].item()))

    # Sort by length with one shared permutation so that targets stay aligned with
    # their embeddings; sorting only the embeddings would pair samples with the
    # wrong labels. sorted() is stable, so ties keep their original relative order.
    order = sorted(range(len(peptide_embeddings)), key=lambda i: len(peptide_embeddings[i]))
    sliced_embeddings = [peptide_embeddings[i] for i in order]

    # one hot encode targets
    encoded_y = torch.zeros(len(order), num_classes, dtype=torch.float32)
    for row, sample_idx in enumerate(order):
        encoded_y[row, class_indices[sample_idx]] = 1

    return sliced_embeddings, encoded_y

In [None]:
class MyFNN(nn.Module):
    """Fully connected head: Linear -> ReLU -> Linear -> ReLU -> Linear(…, 1).

    The first layer expects ``hidden_size * num_features`` inputs, where
    ``num_features`` is read from the notebook's global scope at construction time.
    """

    def __init__(self, hidden_size):
        super().__init__()

        in_features = hidden_size * num_features
        layers = [
            nn.Linear(in_features, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            # Author's note kept from the original: "output dim should be 3 long".
            nn.Linear(hidden_size, 1),
        ]
        self.linear = nn.Sequential(*layers)

    def forward(self, x):
        """Transpose ``x``, flatten it to 1-D and run it through the linear stack."""
        return self.linear(x.T.flatten())

In [None]:
# Size parameters handed to the proteinsolver network below
# (`num_features` is also read by MyFNN when it is constructed).
num_features = 20
adj_input_size = 2
hidden_size = 128
#frac_present = 0.5
#frac_present_valid = frac_present
#info_size= 1024

# Build the network, restore the weights stored in `state_file`, then switch it
# to eval mode and move it to `device`.
gnn = Net(
    x_input_size=num_features + 1,
    adj_input_size=adj_input_size,
    hidden_size=hidden_size,
    output_size=num_features,
)
gnn.load_state_dict(torch.load(state_file, map_location=device))
gnn = gnn.eval().to(device)

In [None]:
#from sklearn.model_selection import KFold
#from sklearn.metrics import *
#from torch import nn, optim
#import torch.nn.functional as F
#
#
#root = Path("/home/sebastian/masters/data/")
#data_root = root / "neat_data"
#metadata_path = data_root / "embedding_dataset.csv"
#processed_dir = data_root / "processed" / "embedding_verification"
#state_file = root / "state_files" / "e53-s1952148-d93703104.state"
#out_dir = root / "state_files" / "embedding_verification" 
#
#device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
#
## load dataset
#raw_files = list()
#targets = list()
#with open(metadata_path, "r") as infile:
#    for line in infile:
#        line = line.strip().split(",")
#        raw_files.append(line[0])
#        targets.append(int(line[1]))
#
#raw_files = np.array(raw_files)
#targets = np.array(targets)
#
#dataset = ProteinDataset(processed_dir, raw_files, targets, overwrite=False)
#
## init proteinsolver gnn
#num_features = 20
#adj_input_size = 2
#hidden_size = 128
#
#gnn = Net(
#    x_input_size=num_features + 1, 
#    adj_input_size=adj_input_size, 
#    hidden_size=hidden_size, 
#    output_size=num_features
#)
#gnn.load_state_dict(torch.load(state_file, map_location=device))
#gnn.eval()
#gnn = gnn.to(device)
#
## init LSTM
#num_classes = 3
#num_layers = 2
#hidden_size = 26
#
#net = MyLSTM(num_classes, num_features, num_layers, hidden_size)
#net = net.to(device)
#
#criterion = nn.BCEWithLogitsLoss()
#optimizer = optim.Adam(net.parameters(), lr=0.0001) 
#
## training params
#epochs = 1
#n_splits = 5
#batch_size = 5
#
## touch files to ensure output
#save_dir = get_non_dupe_dir(out_dir)
#loss_paths = touch_output_files(save_dir, "loss", n_splits)
#state_paths = touch_output_files(save_dir, "state", n_splits)
#pred_paths = touch_output_files(save_dir, "pred", n_splits)
#
#CV = KFold(n_splits=n_splits, shuffle=True)
#i = 0
#for train_idx, valid_idx in CV.split(dataset):
#    
#    train_subset = dataset[torch.LongTensor(train_idx)][0:10]
#    valid_subset = dataset[torch.LongTensor(valid_idx)][0:10]
#    
#    net = MyLSTM(num_classes, num_features, num_layers, hidden_size)
#    net = net.to(device)
#    
#    # partial function - gnn arg is static, x is given later
#    gnn_transform = lambda x: gnn_to_lstm_batch(
#        x, 
#        gnn_instance=gnn, 
#        device=device,
#        num_classes=num_classes
#)
#    
#    net, train_subset_losses, valid_subset_losses = train_model(
#        model=net,
#        epochs=epochs, 
#        criterion=criterion,
#        optimizer=optimizer,
#        train_data=train_subset, 
#        valid_data=valid_subset,
#        batch_size=batch_size,
#        device=device,
#        transform=gnn_transform,
#)
#
#    torch.save({"train": train_subset_losses, "valid": valid_subset_losses}, loss_paths[i])
#    torch.save(net.state_dict(), state_paths[i])
#    
#    # perform test preds
#    y_pred, y_true = predict(
#        model=net, 
#        data=train_subset, 
#        batch_size=batch_size,
#        device=device,
#        transform=gnn_transform,
#)
#
#    torch.save({"y_pred": y_pred, "y_true": y_true,}, pred_paths[i])
#    
#    i += 1


In [None]:
    #annotations = train_subset.dataset.annotations.squeeze(1)
    #class_weights = compute_class_weight(
    #    'balanced',
    #    np.unique(annotations),
    #    annotations.numpy()
    #)
    #class_weights = torch.tensor(class_weights, dtype=torch.float)