In [6]:
import json

f = open("DeepTMHMM.partitions.json")
folds = json.load(f)

In [7]:
folds.keys() #ok, so these are the folds


dict_keys(['cv0', 'cv1', 'cv2', 'cv3', 'cv4'])

For prototyping models, I would pick one fixed assignment (e.g. train 0/1/2, val 3, test 4)

-> let's create a training and test partitioning instead of the folds 

Structure is dict -> list of dicts with protein name as key


In [16]:
for key in folds:
    print(key)

cv0
cv1
cv2
cv3
cv4


In [97]:
from collections import defaultdict
import pickle

train = []
val = []
test = []
with open("missing_samples/missing_prots.txt") as f:
    exclude_list = f.readlines()

exclude_list = [x.replace("\n","") for x in exclude_list]

with open("data_quality_control.pkl","rb") as f:
    data = pickle.load(f)

mismatches = data["Different sequence"]
for mismatch in mismatches:
    exclude_list.append(mismatch)

split_wise_removed = {"train":0,"val":0,"test":0}

for fold in list(folds.keys()):
    for sample in folds[fold]:#list 
        #check if alfafold download succeeded; if not, exclude from set. Save number of discarded samples in dict 
        if(sample["id"] in exclude_list):
            if(fold=="cv1" or fold=="cv2" or fold=="cv3"):
                split_wise_removed["train"] += 1
            elif(fold=="cv3"):
                split_wise_removed["val"] += 1
            else:
                split_wise_removed["test"] += 1

        else:
            if(fold!="cv4" and fold!="cv3"):
                train.append(sample)
            if(fold=="cv3"):
                val.append(sample)
            elif(fold=="cv4"):
                test.append(sample)

['Q5VT06', 'P29994', 'Q14315', 'Q9U943', 'Q9VDW6', 'Q14789', 'Q8WXX0', 'P69332', 'P36022', 'P04875', 'Q01484', 'Q05470', 'Q96Q15', 'O83774', 'Q5I6C7', 'Q96T58', 'Q9UKN1', 'Q9SMH5', 'P14217', 'P0DTC2', 'Q3KNY0', 'Q8IZQ1', 'Q9VKA4', 'Q9VC56', 'Q7TMY8', 'Q868Z9', 'Q9P2D1', 'Q6KC79', 'F8VPN2', 'P98161', 'O83276', 'Q61001']


In [101]:
import pickle
with open("splits/prototype/train.pkl", "wb") as f:
    pickle.dump(train,f)

with open("splits/prototype/val.pkl", "wb") as f:
    pickle.dump(val,f)

with open("splits/prototype/test.pkl", "wb") as f:
    pickle.dump(test,f)

with open("splits/prototype/missing_distribution.pkl", "wb") as f:
    pickle.dump(split_wise_removed, f)

In [99]:
print(split_wise_removed)

{'train': 25, 'val': 0, 'test': 7}


In [91]:
print(len(train),len(val),len(test))

2129 704 711


Ensure that there is no overlap in proteins for any of the folds

In [92]:
def assure_difference(train_li,val_li,test_li):
    """Checks that splits are not contaminated in any sort of way
        Returns bool. True = good
    """
    train_prots = [x["id"] for x in train]
    val_prots = [x["id"] for x in val]
    test_prots = [x["id"] for x in test]

    ###check that there are no repetitions
    rep_check = (len(set(train_prots))==len(train_prots) and len(set(val_prots))==len(val_prots) and len(set(test_prots))==len(test_prots))

    #check that no sets have an intersection
    s_train = set(train_prots)
    s_val = set(val_prots)
    s_test = set(test_prots)

    overlap_check = (len(s_train & s_val)==0 and len(s_train & s_test) == 0 and len(s_val & s_test) == 0)
    
    return (overlap_check and rep_check and overlap_check!=False)

print(assure_difference(train,val,test))


True


Good. Let's implement a dataset 

In [93]:
train[0]

{'id': 'P10384',
 'sequence': 'MSQKTLFTKSALAVAVALISTQAWSAGFQLNEFSSSGLGRAYSGEGAIADDAGNVSRNPALITMFDRPTFSAGAVYIDPDVNISGTSPSGRSLKADNIAPTAWVPNMHFVAPINDQFGWGASITSNYGLATEFNDTYAGGSVGGTTDLETMNLNLSGAYRLNNAWSFGLGFNAVYARAKIERFAGDLGQLVAGQIMQSPAGQTQQGQALAATANGIDSNTKIAHLNGNQWGFGWNAGILYELDKNNRYALTYRSEVKIDFKGNYSSDLNRAFNNYGLPIPTATGGATQSGYLTLNLPEMWEVSGYNRVDPQWAIHYSLAYTSWSQFQQLKATSTSGDTLFQKHEGFKDAYRIALGTTYYYDDNWTFRTGIAFDDSPVPAQNRSISIPDQDRFWLSAGTTYAFNKDASVDVGVSYMHGQSVKINEGPYQFESEGKAWLFGTNFNYAF',
 'labels': 'SSSSSSSSSSSSSSSSSSSSSSSSSPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPBBBBBBBOOOOOOOOOOOOOOOOOOOOOOOOOOOOBBBBBBBPPPPPPPPBBBBBBBOOOOOOOOOOOOOOOOOOOOOOOOOBBBBBBBBPPPPPPPPBBBBBBBBOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOBBBBBBBBBPPPPPPBBBBBBBBOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOBBBBBBBBBBPPPPBBBBBBBBBOOOOOOOOOOOOOOOOOOOOOOOOOOOOOOBBBBBBBBBBPPPBBBBBBBBOOOOOOOOOOOOOOOOOOOOOBBBBBBBBBPPPPBBBBBBBBOOOOOOOOOOOOOOOOOOOOOOOOBBBBBBBPP'}

In [103]:
import torch
from torch.utils.data import Dataset
import pdbreader
import pickle 


with open("splits/prototype/train.pkl",'rb') as f:
    train = pickle.load(f)

with open("splits/prototype/val.pkl",'rb') as f:
    val = pickle.load(f)

with open("splits/prototype/test.pkl",'rb') as f:
    test = pickle.load(f)




class transmembraneDataset(Dataset):
    def __init__(self,data_li,path):
        self.protein_names = [x["id"] for x in data_li]
        self.residue_sequnces = [x["sequence"] for x in data_li]
        self.label_sequences = [x["labels"] for x in data_li]        
        self.pdb_path = path
        

    def __len__(self):
        return len(self.label_sequences)

    def __getitem__(self,idx):
        name = self.protein_names[idx]
        pdb_file_path = self.pdb_path + name + ".pdb" #todo: read graph-tensor instead
        pdb_file = pdbreader.read_pdb(pdb_file_path)
        label = self.label_sequences[idx]
        return name, pdb_file_path, label


from torch.utils.data import DataLoader
train_dataloader = DataLoader(transmembraneDataset(train,"data/graphein_downloads/train/"),batch_size=1,shuffle=True)
val_dataloader = DataLoader(transmembraneDataset(val,"data/graphein_downloads/train/"),batch_size=1,shuffle=True)
test_dataloader = DataLoader(transmembraneDataset(test,"data/graphein_downloads/train/"),batch_size=1,shuffle=True)


In [104]:
#ensure that loading pdb works for all loaders (ie. no errors thrown at any point through this block)

for i, data in enumerate(train_dataloader):
    _, pdb_file_path, _ = data
    _ = pdb_file = pdbreader.read_pdb(pdb_file_path[0])


for i, data in enumerate(val_dataloader):
    _, pdb_file_path, _ = data
    _ = pdb_file = pdbreader.read_pdb(pdb_file_path[0])
    
for i, data in enumerate(test_dataloader):
    _, pdb_file_path, _ = data
    _ = pdb_file = pdbreader.read_pdb(pdb_file_path[0])
