In [3]:
import pickle
import os
import torch
import torch.nn as nn
from transformers import RobertaTokenizerFast, Trainer, TrainingArguments
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, log_loss, accuracy_score, matthews_corrcoef
from sklearn.utils import shuffle
from matplotlib import pyplot as plt
from sklearn.metrics import roc_auc_score
import re
import tqdm
import torch.nn.functional as F
from datetime import datetime
from copy import deepcopy
from scipy.special import softmax

In [4]:
MAX_LENGTH = 1024
EPOCHS = 10
LEARNING_RATE = 2e-6
BATCH_SIZE = 1
TOKENIZER_PATH =  "./Models/ST-PRoBERTa/Tokenizer"
PRETRAINED_MODEL = "./Models/ST-PRoBERTa/Checkpoints/checkpoint-3560000"
NUM_CLASSES = 2
SCHEDULER='cosine_with_restarts'

In [5]:
df_train = pickle.load(open('./Datasets/finetuning-IDRs-train.pickle', "rb"))
df_test = pickle.load(open('./Datasets/finetuning-IDRs-test.pickle', "rb"))

In [6]:
df_full = pd.concat([df_train, df_test])

In [7]:
df_full

Unnamed: 0,Sequence,full,disprot_ID
0,MELITNELLYKTYKQKPVGVEEPVYDQAGDPLFGERGAVHPQSTLK...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP00675
1,MESAQAVAEPLDLVRLSLDEIVYVKLRGDRELNGRLHAYDEHLNMV...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP02065
2,MPYLKGAPMNLQEMEKNSAKAVVLLKAMANERRLQILCMLLDNELS...,"[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...",DP02188
3,MAEKPKLHYFNARGRMESTRWLLAAAGVEFEEKFIKSAEDLDKLRN...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP01506
4,MAGVGTPCANGCGPSAPSEAEVLHLCRSLEVGTVMTLFYSKKSQRP...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP01851
...,...,...,...
235,MLALLCSCLLLAAGASDAWTGEDSAEPNSDSAEWIRDMYAKVTEIW...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP02656
236,MGEAEKFHYIYSCDLDINVQLKIGSLEGKREQKSYKAVLEDPMLKF...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP03710
237,MASMRESDTGLWLHNKLGATDELWAPPSIASLLTAAVIDNIRLCFH...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP02712
238,MLSTQFNRDNQYQAITKPSLLAGCIALALLPSAAFAAPATEETVIV...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP03193


In [8]:
df_caid = pickle.load(open('./Datasets/caid.pkl', "rb"))

In [9]:
df_caid

Unnamed: 0,ID,Sequence,full
0,DP00084,MSDNDDIEVESDEEQPRFQSAADKRAHHNALERKRRDHIKDSFHSL...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,DP00182,MAPTKRKGSCPGAAPKKPKEPVQVPKLVIKGGIEVLGVKTGVDSFT...,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2,DP00206,MKAAQKGFTLIELMIVVAIIGILAAIAIPAYQDYTARAQLSERMTL...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,DP00334,MCNTNMSVPTDGAVTTSQIPASEQETLVRPKPLLLKLLKSVGAQKD...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,DP00359,MMLTKSVVISRPAVRPVSTRRAVVVRASGQPAVDLNKKVQDAVKEA...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...
647,DP02330,MWLPLTVLLLAGIVSADYDHGWHVNNEYIYLVRSRTLVNLNELSDQ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
648,DP02331,MWCPLFLVLLAGAATAEHLQAWKTDTEYQYAVRGRTLSALHDVADQ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
649,DP02332,MAGELADKKDRDASPSKEERKRSRTPDRERDRDRDRKSSPSKDRKR...,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
650,DP02333,MSHIQIPPGLTELLQGYTVEVLRQQPPDLVEFAVEYFTRLREARAP...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [10]:
# df_train = df_full[~df_full['disprot_ID'].isin(df_caid['ID'])]
df_full['intersect'] = [i in list(df_caid['ID']) for i in df_full['disprot_ID']]
df_train = df_full[df_full['intersect'] == False]
df_train

Unnamed: 0,Sequence,full,disprot_ID,intersect
0,MELITNELLYKTYKQKPVGVEEPVYDQAGDPLFGERGAVHPQSTLK...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP00675,False
5,MEDINFASLAPRHGSRPFMGNWQDIGTSNMSGGAFSWGSLWSGIKN...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP00808,False
6,MVSSVLSIPPQTCLLPRLPISDSVNCKSKIVYCLSTSVRGSSVKRQ...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP02919,False
7,MAAKFEVGSVYTGKVTGLQAYGAFVALDEETQGLVHISEVTHGFVK...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP00809,False
9,VVYTDCTESGQNLCLCEGSNVCGQGNKCILGSDGEKNQCVTGEGTP...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP00137,False
...,...,...,...,...
234,MGHNDSVETMDEISNPNNILLPHDGTGLDATGISGSQEPYGMVDVL...,"[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...",DP00036,False
235,MLALLCSCLLLAAGASDAWTGEDSAEPNSDSAEWIRDMYAKVTEIW...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP02656,False
236,MGEAEKFHYIYSCDLDINVQLKIGSLEGKREQKSYKAVLEDPMLKF...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP03710,False
237,MASMRESDTGLWLHNKLGATDELWAPPSIASLLTAAVIDNIRLCFH...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP02712,False


In [11]:
# df_train['full'] = [[int(i) for i in j] for j in df_train['full']]
# df_caid['full'] = [[int(i) for i in j] for j in df_caid['full']]
# print(df_train), print(df_caid)

In [12]:
# df_train['disordered_content'] = [list(i).count(1)/len(i) for i in df_train['full']]
# df_caid['disordered_content'] = [list(i).count(1)/len(i) for i in df_caid['full']]
# df_train.hist('disordered_content')
# df_caid.hist('disordered_content')

In [13]:
# df_train = df_train[df_train['disordered_content'] < 0.9]
# df_train.hist('disordered_content')

In [14]:
df_val = df_train.sample(frac=0.1)
df_train = df_train.drop(df_val.index)

In [15]:
df_val.to_pickle('val.pkl')
df_train.to_pickle('train.pkl')

In [16]:
df_test = df_caid

In [17]:
df_test, df_train, df_val

(          ID                                           Sequence  \
 0    DP00084  MSDNDDIEVESDEEQPRFQSAADKRAHHNALERKRRDHIKDSFHSL...   
 1    DP00182  MAPTKRKGSCPGAAPKKPKEPVQVPKLVIKGGIEVLGVKTGVDSFT...   
 2    DP00206  MKAAQKGFTLIELMIVVAIIGILAAIAIPAYQDYTARAQLSERMTL...   
 3    DP00334  MCNTNMSVPTDGAVTTSQIPASEQETLVRPKPLLLKLLKSVGAQKD...   
 4    DP00359  MMLTKSVVISRPAVRPVSTRRAVVVRASGQPAVDLNKKVQDAVKEA...   
 ..       ...                                                ...   
 647  DP02330  MWLPLTVLLLAGIVSADYDHGWHVNNEYIYLVRSRTLVNLNELSDQ...   
 648  DP02331  MWCPLFLVLLAGAATAEHLQAWKTDTEYQYAVRGRTLSALHDVADQ...   
 649  DP02332  MAGELADKKDRDASPSKEERKRSRTPDRERDRDRDRKSSPSKDRKR...   
 650  DP02333  MSHIQIPPGLTELLQGYTVEVLRQQPPDLVEFAVEYFTRLREARAP...   
 651  DP02334  MAPPGMRLRSGRSTGAPLTRGSCRKRNRSPERCDLGDDLHLQPRRK...   
 
                                                   full  
 0    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...  
 1    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...  
 2    [0

In [18]:
class ProteinDegreeDataset(Dataset):

    def __init__(self, max_length, df, tokenizer, region_type):
        self.region_type = region_type
        self.df = df
        self.seqs, self.labels = self.load_dataset()
        self.tokenizer = tokenizer
        self.max_length = max_length

    def load_dataset(self):
        seq = list(self.df['Sequence'])
        label = list(self.df[self.region_type])
        return seq, label

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        seq = " ".join("".join(self.seqs[idx].split()))
        seq = re.sub(r"[UZOB]", "X", seq)

        seq_ids = self.tokenizer(seq, truncation=True, padding='max_length', max_length=self.max_length)
        sample = {key: torch.tensor(val) for key, val in seq_ids.items()}
        tens = torch.tensor(self.labels[idx], dtype=torch.long)
        sample['labels'] = F.pad(tens, (0, MAX_LENGTH - len(tens)))
        return sample

In [19]:
tokenizer = RobertaTokenizerFast.from_pretrained(TOKENIZER_PATH, do_lower_case=False )

file ./Models/ST-PRoBERTa/Tokenizer/config.json not found
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
file ./Models/ST-PRoBERTa/Tokenizer/config.json not found
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [20]:
train_dataset = ProteinDegreeDataset(MAX_LENGTH, df_train, tokenizer, 'full')
val_dataset = ProteinDegreeDataset(MAX_LENGTH, df_val, tokenizer, 'full')

In [21]:
OUTPUT_DIR = f'./Models/DR-BERT'

In [22]:
if not os.path.isdir(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)

In [23]:
def precision_recall_f1_roc_convolve(name, logits, labels, convolution):
    convolved = np.convolve(np.array(logits).flatten(), np.array(convolution / np.sum(convolution)).flatten(), 'same')
    p = [(1 - i, i) for i in convolved]
    roc = [i[1] for i in p]
    roc2 = [i[0] for i in p]
    p = np.argmax(p, axis=-1)
    precision, recall, f1, support = precision_recall_fscore_support(labels, p)
    roc_auc = roc_auc_score(labels, roc)
    mcc = matthews_corrcoef(labels, p)
    return {
        f'precision_{name}':precision[1],
        f'recall_{name}':recall[1],
        f'f1_{name}':f1[1],
        f'roc_auc_{name}':roc_auc,
        f'mcc_{name}': mcc,
    }

def compute_metrics(eval_preds):
    logits, labels = eval_preds
    logits = softmax(logits, axis=2)
    l = []
    for j, i in enumerate(labels):
        l = l + list(i[:len(df_val['Sequence'].iloc[j])])
    lg2 = []
    for k, i in enumerate(logits):
        lg2 = lg2 + [j[1] for j in i[:len(df_val['Sequence'].iloc[k])]]
    
    metrics = {}
    metrics.update(precision_recall_f1_roc_convolve('normal', lg2, l, [1]))
    metrics.update(precision_recall_f1_roc_convolve('wa5', lg2, l, [1,1,1,1,1]))
    metrics.update(precision_recall_f1_roc_convolve('wa9', lg2, l, [1,1,1,1,1,1,1]))
    metrics.update(precision_recall_f1_roc_convolve('wa15', lg2, l, [1]*15))
    metrics.update(precision_recall_f1_roc_convolve('linear5', lg2, l, [1,2,3,2,1]))
    metrics.update(precision_recall_f1_roc_convolve('linear9', lg2, l, [1,2,3,4,5,4,3,2,1]))
    metrics.update(precision_recall_f1_roc_convolve('linear15', lg2, l, [1,2,3,4,5,6,7,8,7,6,5,4,3,2,1]))
    metrics.update(precision_recall_f1_roc_convolve('quad5', lg2, l, [1,3,9,3,1]))
    metrics.update(precision_recall_f1_roc_convolve('quad9', lg2, l, [1,3,9,27,81,27,9,3,1]))
    metrics.update(precision_recall_f1_roc_convolve('quad15', lg2, l, [1,3,9,27,81,243,729,2187,729,243,81,27,9,3,1]))
    
    logits_path = OUTPUT_DIR + '/Logits/'
    if not os.path.isdir(logits_path):
        os.mkdir(logits_path)
    new_df = deepcopy(df_val)
    new_df['Logits'] = [[i[1] for i in x] for x in list(logits)]
    pickle.dump(new_df, open(logits_path + datetime.now().strftime("%H:%M:%S"), 'wb'))
    return metrics

In [24]:
training_args = TrainingArguments(
    output_dir = OUTPUT_DIR + '/Checkpoints',
    num_train_epochs = EPOCHS,
    per_device_train_batch_size = 1,
    per_device_eval_batch_size = 16,
    warmup_steps = 1000,
    learning_rate = LEARNING_RATE,
    logging_dir = OUTPUT_DIR + '/Logs',
    logging_steps = 200,
    lr_scheduler_type=SCHEDULER,
    do_train = True,
    do_eval = True,
    evaluation_strategy = 'epoch',
    gradient_accumulation_steps = BATCH_SIZE,
    fp16 = True,
    fp16_opt_level = '02',
    save_strategy = 'epoch',
    load_best_model_at_end = True
)

In [25]:
from transformers import AutoModelForTokenClassification
def model_init():
    model = AutoModelForTokenClassification.from_pretrained(PRETRAINED_MODEL, num_labels=NUM_CLASSES)
    return model

In [26]:
trainer = Trainer(
    model_init=model_init,
    args = training_args,
    train_dataset = train_dataset,
    eval_dataset = val_dataset,
    compute_metrics = compute_metrics,
)

loading configuration file ./Models/ST-PRoBERTa/Checkpoints/checkpoint-3560000/config.json
Model config RobertaConfig {
  "_name_or_path": "./Models/ST-PRoBERTa/Checkpoints/checkpoint-3560000",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 1026,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.16.2",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 35
}

loading weights file ./Models/ST-PRoBERTa/Checkpoints/checkpoint-3560000/pytorch_model.bin
Some weights of the model checkpoint at ./Models/ST-PRoBERTa/Checkpoints/ch

In [27]:
trainer.train()

loading configuration file ./Models/ST-PRoBERTa/Checkpoints/checkpoint-3560000/config.json
Model config RobertaConfig {
  "_name_or_path": "./Models/ST-PRoBERTa/Checkpoints/checkpoint-3560000",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 1026,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.16.2",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 35
}

loading weights file ./Models/ST-PRoBERTa/Checkpoints/checkpoint-3560000/pytorch_model.bin
Some weights of the model checkpoint at ./Models/ST-PRoBERTa/Checkpoints/ch

Epoch,Training Loss,Validation Loss,Precision Normal,Recall Normal,F1 Normal,Roc Auc Normal,Mcc Normal,Precision Wa5,Recall Wa5,F1 Wa5,Roc Auc Wa5,Mcc Wa5,Precision Wa9,Recall Wa9,F1 Wa9,Roc Auc Wa9,Mcc Wa9,Precision Wa15,Recall Wa15,F1 Wa15,Roc Auc Wa15,Mcc Wa15,Precision Linear5,Recall Linear5,F1 Linear5,Roc Auc Linear5,Mcc Linear5,Precision Linear9,Recall Linear9,F1 Linear9,Roc Auc Linear9,Mcc Linear9,Precision Linear15,Recall Linear15,F1 Linear15,Roc Auc Linear15,Mcc Linear15,Precision Quad5,Recall Quad5,F1 Quad5,Roc Auc Quad5,Mcc Quad5,Precision Quad9,Recall Quad9,F1 Quad9,Roc Auc Quad9,Mcc Quad9,Precision Quad15,Recall Quad15,F1 Quad15,Roc Auc Quad15,Mcc Quad15
1,0.5456,0.422219,0.855556,0.004721,0.00939,0.776119,0.055004,1.0,0.003311,0.0066,0.781875,0.05183,1.0,0.003311,0.0066,0.782284,0.05183,1.0,0.00282,0.005625,0.782848,0.047835,1.0,0.003372,0.006722,0.781734,0.052308,1.0,0.00325,0.006478,0.782423,0.051348,1.0,0.003004,0.005991,0.782839,0.049371,1.0,0.00374,0.007452,0.780894,0.055089,1.0,0.003617,0.007209,0.781147,0.054178,1.0,0.003617,0.007209,0.78118,0.054178
2,0.4874,0.405703,0.587768,0.285776,0.384571,0.804499,0.323937,0.593356,0.282526,0.382788,0.806676,0.324724,0.595819,0.281361,0.382226,0.806957,0.325221,0.602387,0.272348,0.375106,0.807812,0.322686,0.593674,0.284243,0.384427,0.806633,0.325967,0.596653,0.281974,0.382963,0.807059,0.326027,0.599334,0.275966,0.377918,0.807609,0.323522,0.593093,0.285346,0.385313,0.806361,0.326369,0.593798,0.285285,0.385405,0.806506,0.326686,0.593738,0.284856,0.385001,0.806538,0.326386
3,0.4531,0.429115,0.606127,0.254752,0.358731,0.803214,0.312911,0.608696,0.248927,0.353351,0.804995,0.3102,0.609335,0.246536,0.351041,0.805222,0.308877,0.61364,0.240527,0.345593,0.805768,0.306736,0.609683,0.250153,0.354752,0.804973,0.311486,0.609209,0.246597,0.351082,0.805317,0.308861,0.612749,0.243409,0.348414,0.805725,0.30831,0.608226,0.252054,0.356409,0.804743,0.312088,0.608876,0.251502,0.355968,0.804865,0.31202,0.60909,0.251441,0.355943,0.804895,0.312078
4,0.4204,0.416325,0.579619,0.287002,0.383909,0.804223,0.320522,0.581777,0.282649,0.380457,0.805678,0.318936,0.583302,0.282281,0.380449,0.805875,0.319484,0.585251,0.274433,0.373654,0.806462,0.315549,0.582495,0.283998,0.381832,0.805673,0.32014,0.583164,0.282036,0.380197,0.805964,0.319262,0.583623,0.277928,0.376542,0.806375,0.316928,0.580702,0.285224,0.38255,0.805502,0.319981,0.581404,0.284856,0.382371,0.805604,0.320114,0.581069,0.284549,0.382022,0.805632,0.319753
5,0.4683,0.413421,0.548662,0.32802,0.410575,0.804822,0.327806,0.548491,0.324218,0.407537,0.806166,0.325563,0.549514,0.322195,0.406215,0.806364,0.325001,0.549421,0.314224,0.399797,0.806975,0.320409,0.548888,0.325261,0.40847,0.80617,0.326379,0.549227,0.322195,0.406136,0.806452,0.324837,0.550476,0.318945,0.403882,0.806875,0.323699,0.549092,0.326426,0.409444,0.806013,0.327154,0.549184,0.325874,0.409035,0.806111,0.326895,0.549064,0.325567,0.40876,0.806138,0.326653
6,0.4657,0.424295,0.554089,0.316554,0.402919,0.804076,0.324365,0.553028,0.312998,0.399749,0.805333,0.321727,0.553391,0.311711,0.398792,0.805496,0.321188,0.554626,0.304721,0.393336,0.805905,0.317822,0.553674,0.314654,0.401267,0.805348,0.32304,0.553791,0.31214,0.399247,0.805581,0.321658,0.555433,0.309319,0.397354,0.80591,0.320939,0.553916,0.315267,0.401829,0.805206,0.323528,0.554009,0.314776,0.401454,0.805291,0.323298,0.554223,0.314592,0.401361,0.805318,0.323312
7,0.4408,0.421556,0.546304,0.329859,0.411346,0.803581,0.32748,0.549809,0.327223,0.41027,0.804817,0.328015,0.549405,0.325567,0.408855,0.804974,0.326848,0.551285,0.318332,0.403607,0.805387,0.323806,0.549491,0.327774,0.410615,0.804835,0.328144,0.550181,0.32569,0.409166,0.805064,0.327362,0.55116,0.322011,0.406517,0.805393,0.325832,0.548779,0.329368,0.411663,0.804695,0.328633,0.548853,0.328571,0.411061,0.804781,0.328227,0.548807,0.32851,0.411,0.804808,0.328166
8,0.4251,0.426329,0.560408,0.309994,0.399179,0.802881,0.324075,0.560721,0.306867,0.396656,0.804105,0.322418,0.560449,0.305825,0.395716,0.804246,0.321658,0.561373,0.299755,0.390823,0.804564,0.318587,0.560965,0.308032,0.397689,0.804122,0.323234,0.560764,0.305825,0.395794,0.804331,0.321831,0.56074,0.30282,0.393264,0.804613,0.320052,0.560405,0.308584,0.398007,0.803986,0.32325,0.56158,0.308645,0.398354,0.804066,0.32393,0.561768,0.308645,0.398401,0.804091,0.324033
9,0.4417,0.425902,0.562556,0.308216,0.398241,0.80266,0.324212,0.562267,0.305334,0.395756,0.80389,0.322362,0.561551,0.30374,0.394238,0.804031,0.321034,0.563186,0.298651,0.39032,0.804343,0.31891,0.562521,0.306438,0.396745,0.80391,0.323149,0.562309,0.304047,0.394683,0.804116,0.321626,0.563232,0.30092,0.392263,0.804395,0.320279,0.563469,0.307541,0.397906,0.803773,0.324315,0.563463,0.307296,0.397699,0.803853,0.324167,0.563497,0.307419,0.39781,0.803878,0.324258
10,0.4402,0.425976,0.560827,0.309503,0.398878,0.80273,0.324019,0.561126,0.306744,0.396654,0.803958,0.322567,0.560987,0.305395,0.39549,0.804098,0.3217,0.561918,0.299632,0.390851,0.804406,0.318808,0.561612,0.307664,0.397544,0.803977,0.323372,0.561226,0.305457,0.395601,0.804183,0.321866,0.562151,0.302514,0.393351,0.80446,0.320637,0.562744,0.308768,0.398749,0.803841,0.324639,0.562437,0.308461,0.398416,0.803921,0.324291,0.562514,0.3084,0.398384,0.803946,0.324297


***** Running Evaluation *****
  Num examples = 176
  Batch size = 16
Saving model checkpoint to ./Models/DR-BERT/Checkpoints/checkpoint-1569
Configuration saved in ./Models/DR-BERT/Checkpoints/checkpoint-1569/config.json
Model weights saved in ./Models/DR-BERT/Checkpoints/checkpoint-1569/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 176
  Batch size = 16
Saving model checkpoint to ./Models/DR-BERT/Checkpoints/checkpoint-3138
Configuration saved in ./Models/DR-BERT/Checkpoints/checkpoint-3138/config.json
Model weights saved in ./Models/DR-BERT/Checkpoints/checkpoint-3138/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 176
  Batch size = 16
Saving model checkpoint to ./Models/DR-BERT/Checkpoints/checkpoint-4707
Configuration saved in ./Models/DR-BERT/Checkpoints/checkpoint-4707/config.json
Model weights saved in ./Models/DR-BERT/Checkpoints/checkpoint-4707/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 176
  Batch size = 16
Saving

TrainOutput(global_step=15690, training_loss=0.4565909267155664, metrics={'train_runtime': 1112.4017, 'train_samples_per_second': 14.105, 'train_steps_per_second': 14.105, 'total_flos': 4099894279004160.0, 'train_loss': 0.4565909267155664, 'epoch': 10.0})

In [96]:
x = torch.randn([10, 1, 2, 1024])

In [97]:
x.shape

torch.Size([10, 1, 2, 1024])

In [106]:
n = nn.Conv2d(1, 1, (1,7), padding=(0,3))

In [107]:
n(x).shape

torch.Size([10, 1, 2, 1024])

In [53]:
!conda list torch

# packages in environment at /home/johnmf4/.conda/envs/ProteinTransformers3:
#
# Name                    Version                   Build  Channel
_pytorch_select           2.0                  cuda10.2_1    file:///opt/apps/open-ce-v1.2.0/condabuild
pytorch                   1.7.1           cuda10.2_py37_3    file:///opt/apps/open-ce-v1.2.0/condabuild
pytorch-base              1.7.1           cuda10.2_py37_14    file:///opt/apps/open-ce-v1.2.0/condabuild
torchtext                 0.8.1                    py37_4    file:///opt/apps/open-ce-v1.2.0/condabuild
torchvision-base          0.8.2           cuda10.2_py37_6    file:///opt/apps/open-ce-v1.2.0/condabuild


In [1]:
print(df_test)

NameError: name 'df_test' is not defined