In [54]:
#!wget https://services.healthtech.dtu.dk/suppl/immunology/NAR_NetMHCpan_NetMHCIIpan/NetMHCpan_train.tar.gz 
#!mv NetMHCpan_train.tar.gz ../data/
#!tar -xvzf ../data/NetMHCpan_train.tar.gz -C ../data/
#!pip install epitopepredict

Collecting epitopepredict
  Downloading epitopepredict-0.5.0.tar.gz (11.0 MB)
[K     |████████████████████████████████| 11.0 MB 30.9 MB/s eta 0:00:01
Collecting future
  Downloading future-1.0.0-py3-none-any.whl (491 kB)
[K     |████████████████████████████████| 491 kB 84.3 MB/s eta 0:00:01
Building wheels for collected packages: epitopepredict
  Building wheel for epitopepredict (setup.py) ... [?25ldone
[?25h  Created wheel for epitopepredict: filename=epitopepredict-0.5.0-py3-none-any.whl size=5973754 sha256=e3119146fc513b650782e84cb897fe7886c596d85bffb8c36332beed33ba5efb
  Stored in directory: /home/jrouhana/.cache/pip/wheels/33/0d/a5/5b2802337ae05b248638603e88da786b68227295df42f3da31
Successfully built epitopepredict
Installing collected packages: future, epitopepredict
Successfully installed epitopepredict-0.5.0 future-1.0.0


In [1]:
import pandas as pd
import os

In [20]:
base_dir = "../data/NetMHCpan_train/"

el_files = []

for file in os.listdir(base_dir):
    if file.endswith('el'):
        full_file_path = os.path.join(base_dir, file)
        el_files.append(full_file_path)
        
print(el_files)

['../data/NetMHCpan_train/c002_el', '../data/NetMHCpan_train/c001_el', '../data/NetMHCpan_train/c003_el', '../data/NetMHCpan_train/c004_el', '../data/NetMHCpan_train/c000_el']


In [23]:
dfs = []
for file in el_files:
    new_df = pd.read_csv(file, sep="\s+", header=None)
    dfs.append(new_df)

concat_df = pd.concat(dfs)
concat_df

Unnamed: 0,0,1,2
0,KTAVIDHHNY,1,HLA-B57:01
1,VYIDQTMVL,1,Mel-16
2,ATVDIVQKK,1,RA957
3,GKKHGITEL,1,A20-A20
4,FMFDEKLVTV,1,HLA-A02:07
...,...,...,...
2574787,YVTQAIGKWHMGE,0,pat-ST
2574788,YVVEHQFTHIE,0,pat-ST
2574789,YYALYVHPV,0,pat-ST
2574790,YYKNIDKTHY,0,pat-ST


In [65]:
#Subset by positives, drop duplicates, and populate cell-lines
concat_df.columns = ['epitope', 'label', 'allele']
cleaned_df = concat_df.query('label == 1').copy()
cleaned_df = cleaned_df.drop_duplicates().copy()

allele_mapper = pd.read_csv('../data/NetMHCpan_train/allelelist', sep='\s+', header=None)
allele_mapper_s = pd.Series(list(allele_mapper[1]), index=allele_mapper[0])
allele_mapper_s.drop_duplicates(inplace=True)

cleaned_df['MHC_expanded'] = cleaned_df['allele'].map(allele_mapper_s)

#Drop anything with any NAs. Don't wanna deal
cleaned_df = cleaned_df.dropna().copy()

#Subset by obviously human
cleaned_df = cleaned_df.loc[cleaned_df['MHC_expanded'].apply(lambda x: x.startswith('HLA'))].copy()
cleaned_df['MHC_expanded'] = cleaned_df['MHC_expanded'].apply(lambda x: x.split(','))
cleaned_df['peptide_length'] = cleaned_df.epitope.apply(lambda x: len(x))
cleaned_df.reset_index(drop=True, inplace=True)
cleaned_df['key_binder'] = cleaned_df.index 
cleaned_df = cleaned_df.explode('MHC_expanded').copy()
cleaned_df['allele'] = cleaned_df['MHC_expanded']
cleaned_df.drop('allele', axis=1, inplace=True)
cleaned_df.drop_duplicates(inplace=True)
cleaned_df.reset_index(drop=True, inplace=True)
cleaned_df['MHC_expanded'] = cleaned_df['MHC_expanded'].apply(lambda x: x.replace(
    'HLA-', '').replace(':', ''))
cleaned_df

Unnamed: 0,epitope,label,MHC_expanded,peptide_length,key_binder
0,KTAVIDHHNY,1,B5701,10,0
1,VYIDQTMVL,1,A0101,9,1
2,VYIDQTMVL,1,A2402,9,1
3,VYIDQTMVL,1,B0702,9,1
4,VYIDQTMVL,1,B0801,9,1
...,...,...,...,...,...
2010897,GQYENFRVQY,1,A2301,10,543926
2010898,GQYENFRVQY,1,B0702,10,543926
2010899,GQYENFRVQY,1,B1501,10,543926
2010900,GQYENFRVQY,1,C1203,10,543926


In [72]:
from Bio import SeqIO
from tqdm import tqdm

#Do an HLA-mapping to get the format into what our model knows
#Get fasta files, parse for alleles of interest
#For simplicity of demonstration, will not include HLA-G
hla_a_file="https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/fasta/A_prot.fasta"
hla_b_file="https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/fasta/B_prot.fasta"
hla_c_file="https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/fasta/C_prot.fasta"
#hla_g_file="https://raw.githubusercontent.com/ANHIG/IMGTHLA/Latest/fasta/G_prot.fasta"

allele_mapper = {}

hla_files = [hla_a_file, hla_b_file, hla_c_file]#, hla_g_file]
#Loop through fasta files, get record of each HLA and add it to mapper if not present
for i, hla_file in enumerate(tqdm(hla_files)):
    output_file = hla_file.split('/')[-1]
    write_dir = f"../data/{output_file}"
    if not os.path.isfile(write_dir):
        urllib.request.urlretrieve(hla_file, write_dir)
    fasta_sequences = SeqIO.parse(open(write_dir),'fasta')
    for fasta in fasta_sequences:
        name, sequence = fasta.id, str(fasta.seq)
        description = fasta.description
        allele = description.split(" ")[1]

        allele_name = ''.join(allele.split(':')[:2]).replace('*', '')
        allele_suffix = ''.join(allele.split(':')[2:])

#        if ((allele_suffix in ['0101', '01', '']) and (allele_name not in allele_mapper)):
        #Naively assume that first entry is best entry. I'm sure this'll come back to bite me
        if ((allele_name in list(cleaned_df['MHC_expanded'].drop_duplicates())) and
            (allele_name not in allele_mapper)): 
            allele_mapper[allele_name] = sequence
        else:
            continue

assert(len(allele_mapper) == len(cleaned_df.MHC_expanded.drop_duplicates())) 

100%|████████████████████████████████████████████████████████████████████████████████████| 3/3 [32:21<00:00, 647.20s/it]


In [75]:
import pickle

with open("../data/allele_mapper.pkl", 'wb') as f:
    pickle.dump(allele_mapper,f)

In [79]:
cleaned_df['hla_sequence'] = cleaned_df['MHC_expanded'].map(allele_mapper)
cleaned_df['predict_on'] = cleaned_df['hla_sequence'] + cleaned_df['epitope'] 

cleaned_df.to_csv('../data/validation_data_ma.csv', index=False)

In [1]:
import pickle
import pandas as pd
cleaned_df = pd.read_csv('../data/validation_data_ma.csv')
cleaned_df

Unnamed: 0,epitope,label,MHC_expanded,peptide_length,key_binder,hla_sequence,predict_on
0,KTAVIDHHNY,1,B5701,10,0,MRVTAPRTVLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,MRVTAPRTVLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...
1,VYIDQTMVL,1,A0101,9,1,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...
2,VYIDQTMVL,1,A2402,9,1,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRF...,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRF...
3,VYIDQTMVL,1,B0702,9,1,MLVMAPRTVLLLLSAALALTETWAGSHSMRYFYTSVSRPGRGEPRF...,MLVMAPRTVLLLLSAALALTETWAGSHSMRYFYTSVSRPGRGEPRF...
4,VYIDQTMVL,1,B0801,9,1,MLVMAPRTVLLLLSAALALTETWAGSHSMRYFDTAMSRPGRGEPRF...,MLVMAPRTVLLLLSAALALTETWAGSHSMRYFDTAMSRPGRGEPRF...
...,...,...,...,...,...,...,...
2010897,GQYENFRVQY,1,A2301,10,543926,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRF...,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRF...
2010898,GQYENFRVQY,1,B0702,10,543926,MLVMAPRTVLLLLSAALALTETWAGSHSMRYFYTSVSRPGRGEPRF...,MLVMAPRTVLLLLSAALALTETWAGSHSMRYFYTSVSRPGRGEPRF...
2010899,GQYENFRVQY,1,B1501,10,543926,MRVTAPRTVLLLLSGALALTETWAGSHSMRYFYTAMSRPGRGEPRF...,MRVTAPRTVLLLLSGALALTETWAGSHSMRYFYTAMSRPGRGEPRF...
2010900,GQYENFRVQY,1,C1203,10,543926,MRVMAPRTLILLLSGALALTETWACSHSMRYFYTAVSRPGRGEPRF...,MRVMAPRTLILLLSGALALTETWACSHSMRYFYTAVSRPGRGEPRF...


In [4]:
from transformers import AutoTokenizer
from tqdm import tqdm

model_checkpoint = "esm2_t12_35M_UR50D_MHCI_classification/checkpoint-310935/"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

for i, key in enumerate(tqdm(cleaned_df.key_binder.drop_duplicates())):
    sub_df = cleaned_df.query('key_binder == @key').copy()
    test_sequences = sub_df['predict_on'].tolist()
    test_labels = sub_df['label'].tolist()


    #print("Tokenizing validation data...")
    test_tokenized = tokenizer(test_sequences)
    with open (f"../data/MA_dsets/validation_positives_need_deconvolve_{key}.pkl", 'wb') as f:
        pickle.dump([sub_df, test_tokenized, test_labels], f)

100%|██████████████████████████████████████████████████████████████████████████| 543927/543927 [52:21<00:00, 173.14it/s]


In [92]:
import pickle
from datasets import Dataset
import pandas as pd
import numpy as np
import pickle
import transformers
from tqdm import tqdm
import torch
from evaluate import load
import os

from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
import tensorflow as tf

model_checkpoint = "esm2_t12_35M_UR50D_MHCI_classification/checkpoint-310935/"
device = torch.device('cuda')
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2).to(device)


model.eval(); #enable for prediction
#final_predictions = []
#final_logits = []
base_dir = "../data/MA_dsets/"

predicted_dfs = []

with torch.no_grad():

    for i, file in tqdm(enumerate(os.listdir(base_dir))):
        full_file_path = os.path.join(base_dir, file)
        with open(full_file_path, 'rb') as f:
            sub_df, test_tokenized, test_labels = pickle.load(f)

        sub_df.reset_index(drop=True, inplace=True)

        sub_df['logit_1'] = np.nan
        sub_df['logit_2'] = np.nan
        sub_df['final_predictions'] = np.nan

        test_dset = Dataset.from_dict(test_tokenized)
        test_dset = test_dset.add_column("labels", test_labels)
        test_dset = test_dset.add_column("hla", sub_df['MHC_expanded'].tolist())
        test_dset = test_dset.add_column("epitope", sub_df['epitope'].to_list())

        for index in range(0, len(test_dset['input_ids'])):
            item = test_dset[index]
            sequence = [item['input_ids']]
            label = [item['labels']]
            attention_mask = [item['attention_mask']]
            # Generate prediction
            predictions = model(torch.IntTensor(sequence).to(device), torch.IntTensor(attention_mask).to(device))
            #final_logits = final_logits + [x for x in predictions['logits'].cpu()]
            ret = [x[1].numpy() for x in tf.nn.sigmoid(predictions['logits'].cpu())]
            #final_predictions = final_predictions + ret  

           # print(predictions['logits'].cpu())

            logits = [x for x in predictions.logits.cpu().numpy()[0]]
            sub_df.at[ index, 'logit_1'  ] = logits[0]
            sub_df.at[ index, 'logit_2' ] = logits[1]
            sub_df.at[ index, 'final_predictions' ] = ret[0]

        key = str(sub_df['key_binder'].values[0])

       # predicted_dfs.append(sub_df)
        sub_df.to_csv(f"../data/MA_dsets_with_predictions/validation_positives_deconvolved_{key}.csv", index=False)


#        print(sub_df)
#        break
                

        # Predicted class value using argmax
        # predicted_class = np.argmax(prediction)

543927it [9:45:55, 15.47it/s]


In [107]:
import numpy as np
import pandas 

base_dir = "../data/MA_dsets_with_predictions/"

import_dfs = []
for i, file in tqdm(enumerate(os.listdir(base_dir))):
    full_path = os.path.join(base_dir, file)
    tiny_df = pd.read_csv(full_path)
    if len(tiny_df) == 1:
        import_dfs.append(tiny_df)
    else:
        tiny_df = tiny_df.iloc[tiny_df['final_predictions'].idxmax()].to_frame().T.copy()
        import_dfs.append(tiny_df)

543927it [11:47, 769.19it/s]


In [108]:
deconv_df = pd.concat(import_dfs)
deconv_df.reset_index(drop=True, inplace=True)
deconv_df

Unnamed: 0,epitope,label,MHC_expanded,peptide_length,key_binder,hla_sequence,predict_on,logit_1,logit_2,final_predictions
0,LSTVGPRL,1,C0304,8,537649,MRVMAPRTLILLLSGALALTETWAGSHSMRYFYTAVSRPGRGEPHF...,MRVMAPRTLILLLSGALALTETWAGSHSMRYFYTAVSRPGRGEPHF...,-1.635385,1.839033,0.862834
1,AETTTLFQF,1,B4403,9,317702,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,-3.080864,3.342883,0.965871
2,LPMKVRALGL,1,B0702,10,518233,MLVMAPRTVLLLLSAALALTETWAGSHSMRYFYTSVSRPGRGEPRF...,MLVMAPRTVLLLLSAALALTETWAGSHSMRYFYTSVSRPGRGEPRF...,-2.684186,2.932315,0.949421
3,DIQSSGRAK,1,A0301,9,12400,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,3.40641,-3.716567,0.02374
4,TMGHHTVGLK,1,A0301,10,308039,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1.666246,-1.93759,0.125913
...,...,...,...,...,...,...,...,...,...,...
543922,FPAPGKPGNYQY,1,B3501,12,159676,MRVTAPRTVLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,MRVTAPRTVLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,-2.815482,3.057637,0.955111
543923,RLMKHDVNL,1,A0201,9,519647,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,-1.801216,2.017691,0.882642
543924,ESLFVSNHAY,1,B1501,10,460875,MRVTAPRTVLLLLSGALALTETWAGSHSMRYFYTAMSRPGRGEPRF...,MRVTAPRTVLLLLSGALALTETWAGSHSMRYFYTAMSRPGRGEPRF...,-1.179181,1.330194,0.790873
543925,SEEFAQVY,1,B1801,8,226026,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFHTSVSRPGRGEPRF...,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFHTSVSRPGRGEPRF...,-3.024541,3.275008,0.963561


In [109]:
deconv_df.to_csv("../data/full_deconvolved_positive_validation_set.csv", index=False)

In [112]:
#Remove peptides that overlap with training/testing set from validation
filter_cell_line_peptides = "../data/NIHMS1541512-supplement-Sup_Tab1.xlsx" #table e
filter_patient_peptides = "../data/NIHMS1541512-supplement-Sup_Tab6.xlsx" #table a

filtration_dict = {}

filter_in_df = pd.read_excel(filter_cell_line_peptides, engine="openpyxl", sheet_name="e. Filtered peptide list")
alleles = list(filter_in_df['Allele'].drop_duplicates())

#populate sets
for allele in alleles:
    filtration_dict[allele] = set(filter_in_df.query("Allele == @allele")['Peptide'])


df2 = pd.read_excel(filter_patient_peptides, engine="openpyxl", sheet_name="a. Filtered peptide list") 

#filter patient peptides ubiquitously 
filtration_dict["patient_data"] = set(df2['Peptide'])
filter_in_df = filter_in_df.loc[~filter_in_df.Allele.apply(lambda x: x.startswith('G'))].copy()
filter_in_df

Unnamed: 0,Allele,Length,Peptide
0,A0101,8,ADMGHLKY
1,A0101,8,ELDDTLKY
2,A0101,8,FSDNIEFY
3,A0101,8,FTELAILY
4,A0101,8,GLDEPLLK
...,...,...,...
184178,C1701,35,NPARSFGPAVIMGNWENHWIYWVGPIIGAVLAGGL
184179,C1701,35,TLEGTPMNVNNLENPLDAPQNFKCMQGLSVEKPYK
184180,C1701,36,GILGLGIAAPNLAFSQMVTTFGLAGIVGYHTVWGVT
184181,C1701,38,NIAEHLGGIGGDLYSAAIVWEEQAYSNCSGPGFSIHSG


In [115]:
deconv_df['filter_out'] = deconv_df['MHC_expanded']+ deconv_df['epitope']
filter_in_df['filter_out'] = filter_in_df['Allele'] + filter_in_df['Peptide']

deconv_df_filtered = deconv_df.loc[~(deconv_df.filter_out.isin(set(filter_in_df.filter_out)))].copy()
deconv_df_filtered.drop('filter_out', axis=1, inplace=True)
deconv_df_filtered.to_csv("../data/filtered_deconvolved_positive_validation_set.csv", index=False)
deconv_df_filtered

Unnamed: 0,epitope,label,MHC_expanded,peptide_length,key_binder,hla_sequence,predict_on,logit_1,logit_2,final_predictions
0,LSTVGPRL,1,C0304,8,537649,MRVMAPRTLILLLSGALALTETWAGSHSMRYFYTAVSRPGRGEPHF...,MRVMAPRTLILLLSGALALTETWAGSHSMRYFYTAVSRPGRGEPHF...,-1.635385,1.839033,0.862834
1,AETTTLFQF,1,B4403,9,317702,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,-3.080864,3.342883,0.965871
2,LPMKVRALGL,1,B0702,10,518233,MLVMAPRTVLLLLSAALALTETWAGSHSMRYFYTSVSRPGRGEPRF...,MLVMAPRTVLLLLSAALALTETWAGSHSMRYFYTSVSRPGRGEPRF...,-2.684186,2.932315,0.949421
3,DIQSSGRAK,1,A0301,9,12400,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,3.40641,-3.716567,0.02374
4,TMGHHTVGLK,1,A0301,10,308039,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,MAVMAPRTLLLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,1.666246,-1.93759,0.125913
...,...,...,...,...,...,...,...,...,...,...
543921,HRFWKTELF,1,B2703,9,110624,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFHTSVSRPGRGEPRF...,MRVTAPRTLLLLLWGAVALTETWAGSHSMRYFHTSVSRPGRGEPRF...,-2.83538,3.089154,0.956443
543922,FPAPGKPGNYQY,1,B3501,12,159676,MRVTAPRTVLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,MRVTAPRTVLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRF...,-2.815482,3.057637,0.955111
543923,RLMKHDVNL,1,A0201,9,519647,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,MAVMAPRTLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRF...,-1.801216,2.017691,0.882642
543924,ESLFVSNHAY,1,B1501,10,460875,MRVTAPRTVLLLLSGALALTETWAGSHSMRYFYTAMSRPGRGEPRF...,MRVTAPRTVLLLLSGALALTETWAGSHSMRYFYTAMSRPGRGEPRF...,-1.179181,1.330194,0.790873


In [None]:
# test_dset = Dataset.from_dict(test_tokenized)
# test_dset = test_dset.add_column("labels", test_labels)
# test_dset = test_dset.add_column("hla", cleaned_df['MHC_expanded'].tolist())
# test_dset = test_dset.add_column("epitope", cleaned_df['epitope'].to_list())

# with open("../data/validation_positives_need_deconvolve_dset.pkl", 'wb') as f:
#     pickle.dump(test_dset, f)

In [None]:
# import pickle
# from datasets import Dataset
# import pandas as pd
# import numpy as np
# import pickle
# import transformers
# from tqdm import tqdm
# import torch
# from evaluate import load

# from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
# import tensorflow as tf

# model_checkpoint = "esm2_t12_35M_UR50D_MHCI_classification/checkpoint-310935/"
# device = torch.device('cuda')
# model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2).to(device)


# model.eval(); #enable for prediction
# final_predictions = []
# final_logits = []

# with torch.no_grad():
#     iterator = 64 #Number of predictions to get with each loop
#     for i, n in enumerate(tqdm(range(0, train_dset.num_rows, iterator))):
#         try:
#             # Retrieve item
#             item = test_dset[n:n+iterator]
#             sequence = item['input_ids']
#             label = item['labels']
#             attention_mask = item['attention_mask']
    
#             # Generate prediction
#             predictions = model(torch.IntTensor(sequence).to(device), torch.IntTensor(attention_mask).to(device))
#             final_logits = final_logits + [x for x in predictions['logits'].cpu()]
#             ret = [x[1].numpy() for x in tf.nn.sigmoid(predictions['logits'].cpu())]
#             final_predictions = final_predictions + ret
#         except: #Different lengths occurred... Drat
#             for index in range(n, n+iterator):
#                 item = train_dset[index]
#                 sequence = [item['input_ids']]
#                 label = [item['labels']]
#                 attention_mask = [item['attention_mask']]

#                 # Generate prediction
#                 predictions = model(torch.IntTensor(sequence).to(device), torch.IntTensor(attention_mask).to(device))
#                 final_logits = final_logits + [x for x in predictions['logits'].cpu()]
#                 ret = [x[1].numpy() for x in tf.nn.sigmoid(predictions['logits'].cpu())]
#                 final_predictions = final_predictions + ret  
                

#         # Predicted class value using argmax
#         # predicted_class = np.argmax(prediction)