In [4]:
import pickle
import os
import torch
import torch.nn as nn
from transformers import RobertaTokenizerFast, Trainer, TrainingArguments
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, log_loss, accuracy_score, matthews_corrcoef
from sklearn.utils import shuffle
from matplotlib import pyplot as plt
from sklearn.metrics import roc_auc_score
import re
import tqdm
import torch.nn.functional as F
from datetime import datetime
from copy import deepcopy
from scipy.special import softmax

In [5]:
MAX_LENGTH = 1024
EPOCHS = 10
LEARNING_RATE = 2e-6
BATCH_SIZE = 1
TOKENIZER_PATH =  "./Models/ST-PRoBERTa/Tokenizer"
PRETRAINED_MODEL = "./Models/ST-PRoBERTa/Checkpoints/checkpoint-3560000"
NUM_CLASSES = 2
SCHEDULER='cosine_with_restarts'

In [39]:
df_train = pickle.load(open('./Datasets/caid_clustered_train.pkl', "rb"))
df_val = pickle.load(open('./Datasets/finetuning-IDRs-test.pickle', "rb"))

In [43]:
df_train.dropna()
df_train = df_train.reindex()

In [44]:
df_train

Unnamed: 0,Sequence,full,disprot_ID,intersect,remove_clustered
0,MELITNELLYKTYKQKPVGVEEPVYDQAGDPLFGERGAVHPQSTLK...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP00675,False,False
5,MEDINFASLAPRHGSRPFMGNWQDIGTSNMSGGAFSWGSLWSGIKN...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP00808,False,False
6,MVSSVLSIPPQTCLLPRLPISDSVNCKSKIVYCLSTSVRGSSVKRQ...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP02919,False,False
7,MAAKFEVGSVYTGKVTGLQAYGAFVALDEETQGLVHISEVTHGFVK...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP00809,False,False
9,VVYTDCTESGQNLCLCEGSNVCGQGNKCILGSDGEKNQCVTGEGTP...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP00137,False,False
...,...,...,...,...,...
2380,MGHNDSVETMDEISNPNNILLPHDGTGLDATGISGSQEPYGMVDVL...,"[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...",DP00036,False,False
2381,MLALLCSCLLLAAGASDAWTGEDSAEPNSDSAEWIRDMYAKVTEIW...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP02656,False,False
2382,MGEAEKFHYIYSCDLDINVQLKIGSLEGKREQKSYKAVLEDPMLKF...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP03710,False,False
2383,MASMRESDTGLWLHNKLGATDELWAPPSIASLLTAAVIDNIRLCFH...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP02712,False,False


In [24]:
df_caid = pickle.load(open('./Datasets/caid.pkl', "rb"))

In [25]:
df_caid

Unnamed: 0,ID,Sequence,full
0,DP00084,MSDNDDIEVESDEEQPRFQSAADKRAHHNALERKRRDHIKDSFHSL...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,DP00182,MAPTKRKGSCPGAAPKKPKEPVQVPKLVIKGGIEVLGVKTGVDSFT...,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2,DP00206,MKAAQKGFTLIELMIVVAIIGILAAIAIPAYQDYTARAQLSERMTL...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,DP00334,MCNTNMSVPTDGAVTTSQIPASEQETLVRPKPLLLKLLKSVGAQKD...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,DP00359,MMLTKSVVISRPAVRPVSTRRAVVVRASGQPAVDLNKKVQDAVKEA...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...
647,DP02330,MWLPLTVLLLAGIVSADYDHGWHVNNEYIYLVRSRTLVNLNELSDQ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
648,DP02331,MWCPLFLVLLAGAATAEHLQAWKTDTEYQYAVRGRTLSALHDVADQ...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
649,DP02332,MAGELADKKDRDASPSKEERKRSRTPDRERDRDRDRKSSPSKDRKR...,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
650,DP02333,MSHIQIPPGLTELLQGYTVEVLRQQPPDLVEFAVEYFTRLREARAP...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [8]:
# df_train = df_full[~df_full['disprot_ID'].isin(df_caid['ID'])]
df_full['intersect'] = [i in list(df_caid['ID']) for i in df_full['disprot_ID']]
df_train = df_full[df_full['intersect'] == False]
df_train

Unnamed: 0,Sequence,full,disprot_ID,remove_clustered,intersect
0,MELITNELLYKTYKQKPVGVEEPVYDQAGDPLFGERGAVHPQSTLK...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP00675,False,False
5,MEDINFASLAPRHGSRPFMGNWQDIGTSNMSGGAFSWGSLWSGIKN...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP00808,False,False
6,MVSSVLSIPPQTCLLPRLPISDSVNCKSKIVYCLSTSVRGSSVKRQ...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP02919,False,False
7,MAAKFEVGSVYTGKVTGLQAYGAFVALDEETQGLVHISEVTHGFVK...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP00809,False,False
9,VVYTDCTESGQNLCLCEGSNVCGQGNKCILGSDGEKNQCVTGEGTP...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP00137,False,False
...,...,...,...,...,...
234,MGHNDSVETMDEISNPNNILLPHDGTGLDATGISGSQEPYGMVDVL...,"[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...",DP00036,,False
235,MLALLCSCLLLAAGASDAWTGEDSAEPNSDSAEWIRDMYAKVTEIW...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP02656,,False
236,MGEAEKFHYIYSCDLDINVQLKIGSLEGKREQKSYKAVLEDPMLKF...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP03710,,False
237,MASMRESDTGLWLHNKLGATDELWAPPSIASLLTAAVIDNIRLCFH...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",DP02712,,False


In [11]:
# df_train['full'] = [[int(i) for i in j] for j in df_train['full']]
# df_caid['full'] = [[int(i) for i in j] for j in df_caid['full']]
# print(df_train), print(df_caid)

In [12]:
# df_train['disordered_content'] = [list(i).count(1)/len(i) for i in df_train['full']]
# df_caid['disordered_content'] = [list(i).count(1)/len(i) for i in df_caid['full']]
# df_train.hist('disordered_content')
# df_caid.hist('disordered_content')

In [13]:
# df_train = df_train[df_train['disordered_content'] < 0.9]
# df_train.hist('disordered_content')

In [26]:
df_val = df_train.sample(frac=0.1)
df_train = df_train.drop(df_val.index)

In [27]:
df_val.to_pickle('val_clustered_caid.pkl')
df_train.to_pickle('train_clustered_caid.pkl')

In [28]:
df_test = df_caid

In [17]:
df_test, df_train, df_val

(          ID                                           Sequence  \
 0    DP00084  MSDNDDIEVESDEEQPRFQSAADKRAHHNALERKRRDHIKDSFHSL...   
 1    DP00182  MAPTKRKGSCPGAAPKKPKEPVQVPKLVIKGGIEVLGVKTGVDSFT...   
 2    DP00206  MKAAQKGFTLIELMIVVAIIGILAAIAIPAYQDYTARAQLSERMTL...   
 3    DP00334  MCNTNMSVPTDGAVTTSQIPASEQETLVRPKPLLLKLLKSVGAQKD...   
 4    DP00359  MMLTKSVVISRPAVRPVSTRRAVVVRASGQPAVDLNKKVQDAVKEA...   
 ..       ...                                                ...   
 647  DP02330  MWLPLTVLLLAGIVSADYDHGWHVNNEYIYLVRSRTLVNLNELSDQ...   
 648  DP02331  MWCPLFLVLLAGAATAEHLQAWKTDTEYQYAVRGRTLSALHDVADQ...   
 649  DP02332  MAGELADKKDRDASPSKEERKRSRTPDRERDRDRDRKSSPSKDRKR...   
 650  DP02333  MSHIQIPPGLTELLQGYTVEVLRQQPPDLVEFAVEYFTRLREARAP...   
 651  DP02334  MAPPGMRLRSGRSTGAPLTRGSCRKRNRSPERCDLGDDLHLQPRRK...   
 
                                                   full  
 0    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...  
 1    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...  
 2    [0

In [45]:
class ProteinDegreeDataset(Dataset):

    def __init__(self, max_length, df, tokenizer, region_type):
        self.region_type = region_type
        self.df = df
        self.seqs, self.labels = self.load_dataset()
        self.tokenizer = tokenizer
        self.max_length = max_length

    def load_dataset(self):
        seq = list(self.df['Sequence'])
        label = list(self.df[self.region_type])
        return seq, label

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        seq = " ".join("".join(self.seqs[idx].split()))
        seq = re.sub(r"[UZOB]", "X", seq)

        seq_ids = self.tokenizer(seq, truncation=True, padding='max_length', max_length=self.max_length)
        sample = {key: torch.tensor(val) for key, val in seq_ids.items()}
        tens = torch.tensor(self.labels[idx], dtype=torch.long)
        sample['labels'] = F.pad(tens, (0, MAX_LENGTH - len(tens)))
        return sample

In [46]:
tokenizer = RobertaTokenizerFast.from_pretrained(TOKENIZER_PATH, do_lower_case=False )

Didn't find file ./Models/ST-PRoBERTa/Tokenizer/tokenizer.json. We won't load it.
Didn't find file ./Models/ST-PRoBERTa/Tokenizer/added_tokens.json. We won't load it.
Didn't find file ./Models/ST-PRoBERTa/Tokenizer/tokenizer_config.json. We won't load it.
loading file ./Models/ST-PRoBERTa/Tokenizer/vocab.json
loading file ./Models/ST-PRoBERTa/Tokenizer/merges.txt
loading file None
loading file None
loading file ./Models/ST-PRoBERTa/Tokenizer/special_tokens_map.json
loading file None
file ./Models/ST-PRoBERTa/Tokenizer/config.json not found
Adding [SEP] to the vocabulary
Adding [PAD] to the vocabulary
Adding [CLS] to the vocabulary
Adding [MASK] to the vocabulary
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
file ./Models/ST-PRoBERTa/Tokenizer/config.json not found
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [47]:
train_dataset = ProteinDegreeDataset(MAX_LENGTH, df_train, tokenizer, 'full')
val_dataset = ProteinDegreeDataset(MAX_LENGTH, df_val, tokenizer, 'full')

In [48]:
OUTPUT_DIR = f'./Models/DR-BERT-caid-clustered'

In [49]:
if not os.path.isdir(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)

In [50]:
def precision_recall_f1_roc_convolve(name, logits, labels, convolution):
    convolved = np.convolve(np.array(logits).flatten(), np.array(convolution / np.sum(convolution)).flatten(), 'same')
    p = [(1 - i, i) for i in convolved]
    roc = [i[1] for i in p]
    roc2 = [i[0] for i in p]
    p = np.argmax(p, axis=-1)
    precision, recall, f1, support = precision_recall_fscore_support(labels, p)
    roc_auc = roc_auc_score(labels, roc)
    mcc = matthews_corrcoef(labels, p)
    return {
        f'precision_{name}':precision[1],
        f'recall_{name}':recall[1],
        f'f1_{name}':f1[1],
        f'roc_auc_{name}':roc_auc,
        f'mcc_{name}': mcc,
    }

def compute_metrics(eval_preds):
    logits, labels = eval_preds
    logits = softmax(logits, axis=2)
    l = []
    for j, i in enumerate(labels):
        l = l + list(i[:len(df_val['Sequence'].iloc[j])])
    lg2 = []
    for k, i in enumerate(logits):
        lg2 = lg2 + [j[1] for j in i[:len(df_val['Sequence'].iloc[k])]]
    
    metrics = {}
    metrics.update(precision_recall_f1_roc_convolve('normal', lg2, l, [1]))
    metrics.update(precision_recall_f1_roc_convolve('wa5', lg2, l, [1,1,1,1,1]))
    metrics.update(precision_recall_f1_roc_convolve('wa9', lg2, l, [1,1,1,1,1,1,1]))
    metrics.update(precision_recall_f1_roc_convolve('wa15', lg2, l, [1]*15))
    metrics.update(precision_recall_f1_roc_convolve('linear5', lg2, l, [1,2,3,2,1]))
    metrics.update(precision_recall_f1_roc_convolve('linear9', lg2, l, [1,2,3,4,5,4,3,2,1]))
    metrics.update(precision_recall_f1_roc_convolve('linear15', lg2, l, [1,2,3,4,5,6,7,8,7,6,5,4,3,2,1]))
    metrics.update(precision_recall_f1_roc_convolve('quad5', lg2, l, [1,3,9,3,1]))
    metrics.update(precision_recall_f1_roc_convolve('quad9', lg2, l, [1,3,9,27,81,27,9,3,1]))
    metrics.update(precision_recall_f1_roc_convolve('quad15', lg2, l, [1,3,9,27,81,243,729,2187,729,243,81,27,9,3,1]))
    
    logits_path = OUTPUT_DIR + '/Logits/'
    if not os.path.isdir(logits_path):
        os.mkdir(logits_path)
    new_df = deepcopy(df_val)
    new_df['Logits'] = [[i[1] for i in x] for x in list(logits)]
    pickle.dump(new_df, open(logits_path + datetime.now().strftime("%H:%M:%S"), 'wb'))
    return metrics

In [51]:
training_args = TrainingArguments(
    output_dir = OUTPUT_DIR + '/Checkpoints',
    num_train_epochs = EPOCHS,
    per_device_train_batch_size = 1,
    per_device_eval_batch_size = 16,
    warmup_steps = 1000,
    learning_rate = LEARNING_RATE,
    logging_dir = OUTPUT_DIR + '/Logs',
    logging_steps = 200,
    lr_scheduler_type=SCHEDULER,
    do_train = True,
    do_eval = True,
    evaluation_strategy = 'epoch',
    gradient_accumulation_steps = BATCH_SIZE,
    fp16 = True,
    fp16_opt_level = '02',
    save_strategy = 'epoch',
    load_best_model_at_end = True
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [52]:
from transformers import AutoModelForTokenClassification
def model_init():
    model = AutoModelForTokenClassification.from_pretrained(PRETRAINED_MODEL, num_labels=NUM_CLASSES)
    return model

In [53]:
trainer = Trainer(
    model_init=model_init,
    args = training_args,
    train_dataset = train_dataset,
    eval_dataset = val_dataset,
    compute_metrics = compute_metrics,
)

loading configuration file ./Models/ST-PRoBERTa/Checkpoints/checkpoint-3560000/config.json
Model config RobertaConfig {
  "_name_or_path": "./Models/ST-PRoBERTa/Checkpoints/checkpoint-3560000",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 1026,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.16.2",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 35
}

loading weights file ./Models/ST-PRoBERTa/Checkpoints/checkpoint-3560000/pytorch_model.bin
Some weights of the model checkpoint at ./Models/ST-PRoBERTa/Checkpoints/ch

In [54]:
trainer.train()

loading configuration file ./Models/ST-PRoBERTa/Checkpoints/checkpoint-3560000/config.json
Model config RobertaConfig {
  "_name_or_path": "./Models/ST-PRoBERTa/Checkpoints/checkpoint-3560000",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 1026,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.16.2",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 35
}

loading weights file ./Models/ST-PRoBERTa/Checkpoints/checkpoint-3560000/pytorch_model.bin
Some weights of the model checkpoint at ./Models/ST-PRoBERTa/Checkpoints/ch

Epoch,Training Loss,Validation Loss,Precision Normal,Recall Normal,F1 Normal,Roc Auc Normal,Mcc Normal,Precision Wa5,Recall Wa5,F1 Wa5,Roc Auc Wa5,Mcc Wa5,Precision Wa9,Recall Wa9,F1 Wa9,Roc Auc Wa9,Mcc Wa9,Precision Wa15,Recall Wa15,F1 Wa15,Roc Auc Wa15,Mcc Wa15,Precision Linear5,Recall Linear5,F1 Linear5,Roc Auc Linear5,Mcc Linear5,Precision Linear9,Recall Linear9,F1 Linear9,Roc Auc Linear9,Mcc Linear9,Precision Linear15,Recall Linear15,F1 Linear15,Roc Auc Linear15,Mcc Linear15,Precision Quad5,Recall Quad5,F1 Quad5,Roc Auc Quad5,Mcc Quad5,Precision Quad9,Recall Quad9,F1 Quad9,Roc Auc Quad9,Mcc Quad9,Precision Quad15,Recall Quad15,F1 Quad15,Roc Auc Quad15,Mcc Quad15
1,0.5217,0.421316,0.830161,0.029482,0.056942,0.764981,0.134376,0.785888,0.016795,0.032887,0.76942,0.096888,0.774359,0.015703,0.030782,0.76974,0.092541,0.774373,0.014455,0.02838,0.769922,0.088776,0.793981,0.017835,0.034886,0.769292,0.100694,0.778351,0.015703,0.030785,0.769854,0.092934,0.772118,0.014975,0.02938,0.770097,0.090147,0.815514,0.020227,0.039474,0.768649,0.109611,0.813449,0.019499,0.038085,0.768853,0.107393,0.813043,0.019447,0.037985,0.768877,0.107206
2,0.4682,0.404651,0.5595,0.402402,0.468122,0.792719,0.37411,0.561175,0.400426,0.467365,0.794594,0.374092,0.562156,0.398087,0.466105,0.794823,0.373436,0.568177,0.389351,0.462065,0.795053,0.372373,0.560827,0.40079,0.467492,0.794546,0.374071,0.562151,0.398347,0.466281,0.794916,0.373574,0.565901,0.394707,0.465049,0.795212,0.373915,0.560875,0.401466,0.467968,0.794298,0.374466,0.561368,0.401206,0.467963,0.794426,0.374634,0.561508,0.401102,0.467941,0.794453,0.374666
3,0.4922,0.399452,0.579419,0.436824,0.498117,0.802403,0.405286,0.582615,0.432508,0.496464,0.803912,0.404959,0.583703,0.430584,0.495586,0.804126,0.404591,0.586018,0.423669,0.491791,0.804423,0.402239,0.583019,0.434536,0.497944,0.803884,0.406319,0.583216,0.43074,0.495514,0.804213,0.404371,0.585119,0.427309,0.493915,0.804529,0.403682,0.582193,0.435888,0.498528,0.803691,0.406534,0.582515,0.43516,0.49817,0.80381,0.406341,0.582515,0.43516,0.49817,0.803837,0.406341
4,0.4553,0.394838,0.584456,0.462562,0.516414,0.808546,0.422411,0.586147,0.458507,0.514529,0.809902,0.421317,0.58591,0.457103,0.513553,0.810093,0.420406,0.589456,0.450551,0.510727,0.810241,0.419118,0.586765,0.460119,0.515781,0.809898,0.422585,0.58683,0.457363,0.51407,0.810177,0.421138,0.587867,0.453983,0.512322,0.810426,0.419971,0.58659,0.461262,0.516431,0.809735,0.42309,0.586394,0.460743,0.516029,0.809836,0.422683,0.586522,0.460691,0.516046,0.809861,0.422737
5,0.4042,0.392586,0.619528,0.434172,0.510547,0.810838,0.428572,0.622008,0.4297,0.508272,0.812152,0.427501,0.621769,0.427725,0.506808,0.812353,0.426226,0.62159,0.420653,0.501752,0.812575,0.422055,0.622117,0.430584,0.508927,0.812139,0.428072,0.62262,0.428505,0.507638,0.81243,0.427179,0.6224,0.424761,0.504929,0.812721,0.424898,0.621474,0.431884,0.509617,0.811985,0.428432,0.621488,0.431312,0.509224,0.812084,0.428114,0.621535,0.431312,0.509239,0.81211,0.428141
6,0.4205,0.398888,0.597055,0.46589,0.52338,0.812101,0.43232,0.598935,0.462094,0.521691,0.813289,0.431444,0.598958,0.460275,0.520537,0.813449,0.430462,0.600658,0.455335,0.517997,0.813494,0.428822,0.599219,0.462874,0.522295,0.813292,0.432052,0.59927,0.461106,0.521187,0.813526,0.431116,0.600299,0.458715,0.520042,0.813722,0.430455,0.597988,0.463602,0.522289,0.813163,0.431666,0.598187,0.463134,0.522068,0.813251,0.431537,0.598307,0.463134,0.522114,0.813273,0.431614
7,0.4437,0.396255,0.593124,0.479045,0.530016,0.814013,0.436917,0.595547,0.475614,0.528866,0.815154,0.436634,0.596638,0.474314,0.52849,0.815283,0.436634,0.595573,0.467242,0.52366,0.815241,0.432107,0.595492,0.476653,0.529487,0.815167,0.437162,0.596447,0.474782,0.528705,0.815362,0.436764,0.595334,0.470986,0.525909,0.815507,0.433988,0.595331,0.477329,0.52984,0.815047,0.437423,0.59556,0.477017,0.529738,0.815126,0.437403,0.595532,0.476809,0.529599,0.815146,0.437272
8,0.443,0.396642,0.599552,0.473014,0.528818,0.814391,0.437802,0.601643,0.468282,0.526651,0.815532,0.436554,0.601756,0.466774,0.525739,0.815663,0.4358,0.6022,0.461106,0.522292,0.815631,0.43297,0.601787,0.46927,0.527331,0.815542,0.437187,0.601714,0.467242,0.52602,0.81574,0.43603,0.602496,0.464434,0.524532,0.815891,0.434986,0.601408,0.470882,0.528201,0.81542,0.437826,0.601437,0.469998,0.527655,0.8155,0.437362,0.601571,0.469946,0.527674,0.815521,0.437419
9,0.4091,0.39679,0.59869,0.475354,0.52994,0.81469,0.438521,0.600239,0.470778,0.527684,0.815827,0.437022,0.600985,0.469738,0.527317,0.815955,0.436931,0.601038,0.463654,0.523482,0.815911,0.433634,0.600543,0.471922,0.528519,0.815838,0.437841,0.600798,0.469998,0.527409,0.816033,0.436954,0.600749,0.467242,0.525651,0.816178,0.435415,0.600554,0.473742,0.529663,0.815718,0.438841,0.60029,0.473222,0.529236,0.815795,0.438389,0.60033,0.473066,0.529153,0.815817,0.438329
10,0.4338,0.397069,0.598028,0.47629,0.530261,0.814678,0.438603,0.599868,0.471454,0.527965,0.815812,0.437154,0.600305,0.470518,0.527546,0.81594,0.436923,0.600471,0.464278,0.523664,0.815893,0.433616,0.60004,0.47265,0.52878,0.815823,0.437916,0.600464,0.47083,0.527804,0.816017,0.437195,0.600253,0.467918,0.525888,0.816162,0.435469,0.600092,0.47447,0.529938,0.815705,0.438942,0.599947,0.47395,0.529557,0.815783,0.438565,0.599974,0.473846,0.529502,0.815803,0.438526


***** Running Evaluation *****
  Num examples = 240
  Batch size = 16
Saving model checkpoint to ./Models/DR-BERT-caid-clustered/Checkpoints/checkpoint-1721
Configuration saved in ./Models/DR-BERT-caid-clustered/Checkpoints/checkpoint-1721/config.json
Model weights saved in ./Models/DR-BERT-caid-clustered/Checkpoints/checkpoint-1721/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 240
  Batch size = 16
Saving model checkpoint to ./Models/DR-BERT-caid-clustered/Checkpoints/checkpoint-3442
Configuration saved in ./Models/DR-BERT-caid-clustered/Checkpoints/checkpoint-3442/config.json
Model weights saved in ./Models/DR-BERT-caid-clustered/Checkpoints/checkpoint-3442/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 240
  Batch size = 16
Saving model checkpoint to ./Models/DR-BERT-caid-clustered/Checkpoints/checkpoint-5163
Configuration saved in ./Models/DR-BERT-caid-clustered/Checkpoints/checkpoint-5163/config.json
Model weights saved in ./Models/DR-BERT-ca

TrainOutput(global_step=17210, training_loss=0.45348676651595016, metrics={'train_runtime': 1222.7752, 'train_samples_per_second': 14.075, 'train_steps_per_second': 14.075, 'total_flos': 4497079703101440.0, 'train_loss': 0.45348676651595016, 'epoch': 10.0})

In [96]:
x = torch.randn([10, 1, 2, 1024])

In [97]:
x.shape

torch.Size([10, 1, 2, 1024])

In [106]:
n = nn.Conv2d(1, 1, (1,7), padding=(0,3))

In [107]:
n(x).shape

torch.Size([10, 1, 2, 1024])

In [53]:
!conda list torch

# packages in environment at /home/johnmf4/.conda/envs/ProteinTransformers3:
#
# Name                    Version                   Build  Channel
_pytorch_select           2.0                  cuda10.2_1    file:///opt/apps/open-ce-v1.2.0/condabuild
pytorch                   1.7.1           cuda10.2_py37_3    file:///opt/apps/open-ce-v1.2.0/condabuild
pytorch-base              1.7.1           cuda10.2_py37_14    file:///opt/apps/open-ce-v1.2.0/condabuild
torchtext                 0.8.1                    py37_4    file:///opt/apps/open-ce-v1.2.0/condabuild
torchvision-base          0.8.2           cuda10.2_py37_6    file:///opt/apps/open-ce-v1.2.0/condabuild


In [1]:
print(df_test)

NameError: name 'df_test' is not defined