In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score, precision_recall_curve
from sklearn.metrics import roc_auc_score, f1_score, average_precision_score, precision_score, recall_score, accuracy_score
from copy import deepcopy
torch.manual_seed(1)
from tdc.single_pred import Epitope, Paratope

In [2]:
data_class, name, X = Paratope, 'SAbDab_Liberis', 'Antibody'

In [3]:
data = data_class(name = name)
split = data.get_split()
train_data = split['train']
valid_data = split['valid']
test_data = split['test']
vocab_set = set()


Downloading...
100%|██████████| 150k/150k [00:00<00:00, 619kiB/s] 
Loading...
Done!


In [4]:
print(train_data)

    Antibody_ID                                           Antibody  \
0        2hh0_H  LEQSGAELVKPGASVKLSCTASGFNIEDSYIHWVKQRPEQGLEWIG...   
1        1u8q_B  ITLKESGPPLVKPTQTLTLTCSFSGFSLSDFGVGVGWIRQPPGKAL...   
2        4ydl_H  EVRLVQSGNQVRKPGASVRISCEASGYKFIDHFIHWVRQVPGHGLE...   
3        4ydk_H  EVQLSESGGGFVKPGGSLRLSCEASGFTFNNYAMGWVRQAPGKGLE...   
4        4ydj_H  QVQLVQPGTAMKSLGSSLTITCRVSGDDLGSFHFGTYFMIWVRQAP...   
..          ...                                                ...   
711      3ztn_H  QVQLVESGGGVVQPGRSLRLSCAASGFTFSTYAMHWVRQAPGKGLE...   
712      3ztn_L  DIVMTQSPDSLAVSLGERATINCKSSQSVTFNYKNYLAWYQQKPGQ...   
713      4hf5_H  EVQLVESGGGVVQPGRSLRLSCAASGFMFSSYVMHWVRQPPGKGLE...   
714      4hf5_L  QSVLTQSPSASGTPGQAITISCSGSSSNIGSNPVNWYQQLPGAAPK...   
715      3ab0_B  EVKLVESGGGLVKPGGSLKLSCSASGFTFSSYAMSWVRQTPEKRLE...   

                                                     Y  
0                                [49, 80, 81, 82, 101]  
1    [30, 31, 53, 83, 84, 85, 104, 105, 106, 

In [5]:
from sklearn import metrics
import matplotlib.pyplot as plt

from sklearn.metrics import roc_curve, auc
from sklearn.metrics import roc_auc_score


In [7]:
def data2vocab(data):
	length = len(data)
	vocab_set = set()
	total_length, positive_num = 0, 0
	for i in range(length):
		antigen = train_data[X][i]
		vocab_set = vocab_set.union(set(antigen))
		Y = train_data['Y'][i]
		assert len(antigen) > max(Y)
		total_length += len(antigen)
		positive_num += len(Y)
	return vocab_set, positive_num / total_length

In [8]:
train_vocab, train_positive_ratio = data2vocab(train_data)
valid_vocab, valid_positive_ratio = data2vocab(valid_data)
test_vocab, test_positive_ratio = data2vocab(test_data)

In [10]:
vocab_set = train_vocab.union(valid_vocab)
vocab_set = vocab_set.union(test_vocab)
vocab_lst = list(vocab_set)

In [11]:
vocab_lst 

['L',
 'D',
 'C',
 'K',
 'F',
 'T',
 'Q',
 'Y',
 'M',
 'E',
 'H',
 'W',
 'V',
 'G',
 'R',
 'P',
 'A',
 'S',
 'I',
 'N']

In [12]:
def onehot(idx, length):
	lst = [0 for i in range(length)]
	lst[idx] = 1
	return lst

def zerohot(length):
	return [0 for i in range(length)]

def standardize_data(data, vocab_lst, maxlength = 300):
	length = len(data)
	standard_data = []
	for i in range(length):
		antigen = data[X][i]
		Y = data['Y'][i]
		sequence = [onehot(vocab_lst.index(s), len(vocab_lst)) for s in antigen]
		labels = [0 for i in range(len(antigen))]
		mask = [True for i in range(len(labels))]
		sequence += (maxlength-len(sequence)) * [zerohot(len(vocab_lst))]
		labels += (maxlength-len(labels)) * [0]
		mask += (maxlength-len(mask)) * [False]
		for y in Y:
			labels[y] = 1
		sequence, labels, mask = sequence[:maxlength], labels[:maxlength], mask[:maxlength]
		sequence, labels, mask = torch.FloatTensor(sequence), torch.FloatTensor(labels), torch.BoolTensor(mask)
		# print(sequence.shape, labels.shape, mask.shape)
		standard_data.append((sequence, labels, mask))
	return standard_data

In [13]:
train_data = standardize_data(train_data, vocab_lst)
valid_data = standardize_data(valid_data, vocab_lst)
test_data = standardize_data(test_data, vocab_lst)



tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 