In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup,BertForSequenceClassification
from Bio import SeqIO
from skbio import Sequence
import pandas as pd

2022-11-17 02:57:53.886844: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-17 02:57:54.030166: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-17 02:57:54.673341: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2022-11-17 02:57:54.673400: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or 

In [2]:
fold=1
cuda_id=0
kmer=1

In [3]:
#logging
import logging
import transformers
transformers.logging.set_verbosity_info()

logger = logging.getLogger(__name__)
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO)

In [4]:
# 事前学習モデル
MODEL_NAME = '../prot_bert_bfd'

# カテゴリーのリスト
l1_list = [
    'antibiotic target alteration' ,
    'antibiotic target replacement',
    'antibiotic target protection',
    'antibiotic inactivation',
    'antibiotic efflux',
    'others'
]

# トークナイザのロード
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME,do_lower_case=False)

# 各データの形式を整える
max_length = 1576

device = torch.device("cuda:"+str(cuda_id ))#if torch.cuda.is_available() else "cpu"

cols=['ID','FEATURES','database','target','subclass','mechanism','transferable','sequence']

Didn't find file ../prot_bert_bfd/added_tokens.json. We won't load it.
loading file ../prot_bert_bfd/vocab.txt
loading file None
loading file ../prot_bert_bfd/special_tokens_map.json
loading file ../prot_bert_bfd/tokenizer_config.json
loading configuration file ../prot_bert_bfd/config.json
Model config BertConfig {
  "_name_or_path": "../prot_bert_bfd",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 40000,
  "model_type": "bert",
  "num_attention_heads": 16,
  "num_hidden_layers": 30,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.19.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30
}



In [5]:
def read_input_l1(input_file):

    n = 0
    df = pd.DataFrame(index=[], columns=cols)
    index = []
    
    with open(input_file) as f:
        for entry in SeqIO.parse(f, 'fasta'):
            seq = ''
            desc = entry.description
            seq_col = desc.split('|')
            for k in Sequence(str(entry.seq)).iter_kmers(kmer, overlap=True):
                seq += str(k)+' '
            seq_col.append(seq)
            
            df.loc[str(n),:]=seq_col
            n += 1
        print('l1_length:',n)  
    
    df=df.replace({'reduced permeability to antibiotic':'others'})
    df=df.replace({'resistance by absence':'others'})
    print(df['mechanism'].unique())
    
    dataset_for_loader=[]
    for i in range(n-1):
        encoding = tokenizer(
                                        df.iat[i,7],
                                        max_length=max_length, 
                                        padding='max_length',
                                        truncation=True
                                          )
        encoding['labels'] = l1_list.index(df.at[str(i),'mechanism'])
        index.append(encoding['input_ids'])
        encoding = { k: torch.tensor(v, device=device) for k, v in encoding.items()}
        dataset_for_loader.append(encoding)
        
    return dataset_for_loader,index,df

In [6]:
dataset_train_l1,l1_train_index,df_l1_train = read_input_l1('../protein_bert-master/protein_bert-master/fold_0_train_0.9.fasta')#../5_fold_data/level_1/fold_'+ str(fold) +'_train.fasta'
dataset_test_l1,l1_test_index,df_l1_test = read_input_l1('../protein_bert-master/protein_bert-master/fold_0_test_0.9.fasta')#../5_fold_data/level_1/fold_'+ str(fold) +'_test.fasta'

l1_length: 7150
['antibiotic inactivation' 'antibiotic efflux'
 'antibiotic target protection' 'antibiotic target alteration'
 'antibiotic target replacement' 'others']
l1_length: 1788
['antibiotic inactivation' 'antibiotic efflux'
 'antibiotic target protection' 'antibiotic target alteration'
 'antibiotic target replacement' 'others']


In [7]:
dataloader_train_l1 = DataLoader(dataset_train_l1, batch_size=8, shuffle=True) 
dataloader_test_l1 = DataLoader(dataset_test_l1, batch_size=1,shuffle=True)

In [8]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torchmetrics import MetricCollection, Accuracy, Precision, Recall

In [9]:
class BertForSequenceClassification_pl(pl.LightningModule):
        
    def __init__(self, model_name, num_labels, lr):
        # model_name: Transformersのモデルの名前
        # num_labels: ラベルの数
        # lr: 学習率

        super().__init__()
        
        # 引数のnum_labelsとlrを保存。
        # 例えば、self.hparams.lrでlrにアクセスできる。
        # チェックポイント作成時にも自動で保存される。
        self.save_hyperparameters()
        self.bert_sc = BertForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_labels
        )
        
        self.bert_sc = self.bert_sc.cuda(cuda_id)
        
    # 学習データのミニバッチ(`batch`)が与えられた時に損失を出力する関数を書く。
    # batch_idxはミニバッチの番号であるが今回は使わない。
    def training_step(self, batch, batch_idx):
        output = self.bert_sc(**batch)
        train_loss = output.loss
        
        labels = batch.pop('labels') # バッチからラベルを取得
        labels_predicted = output.logits.argmax(-1)
        num_correct = ( labels_predicted == labels ).sum().item()
        accuracy= num_correct/labels.size(0) #精度
       
        performance = {'loss':train_loss,
                'batch_preds': labels_predicted,
                'batch_labels': labels}
        
        return performance
    
    def training_epoch_end(self,outputs,mode="train"):
        # batch毎のlossの平均を計算
        loss = torch.tensor([x['loss'] for x in outputs])
        epoch_loss = torch.div(loss.sum(),torch.tensor(loss.size()))
        self.log(f"{mode}_loss", epoch_loss, logger=True)

        # accuracy計算
        epoch_preds = torch.cat([x['batch_preds'] for x in outputs])
        epoch_labels = torch.cat([x['batch_labels'] for x in outputs])
        num_correct = (epoch_preds == epoch_labels).sum().item()
        epoch_accuracy = num_correct / len(epoch_labels)
        self.log(f"{mode}_accuracy", epoch_accuracy, logger=True,prog_bar = True)
        
        metric_macro = MetricCollection([
            Accuracy(),
            Precision(num_classes=6, average='macro'),
            Recall(num_classes=6, average='macro')
        ]).cuda(cuda_id)#.type_as(batch)
        macro = metric_macro(epoch_preds, epoch_labels)
        self.log_dict(macro) # Accuracy,Precision,Recallのログをとる。
        
    # 検証データのミニバッチが与えられた時に、
    # 検証データを評価する指標を計算する関数を書く。
    def validation_step(self, batch, batch_idx):
        output = self.bert_sc(**batch)
        val_loss = output.loss
        self.log('val_loss', val_loss) # 損失を'val_loss'の名前でログをとる。
        
        labels = batch.pop('labels') # バッチからラベルを取得
        labels_predicted = output.logits.argmax(-1)
        num_correct = ( labels_predicted == labels ).sum().item()
        accuracy= num_correct/labels.size(0) #精度
        self.log('val_accuracy', accuracy) # 精度を'val_accuracy'の名前でログをとる。
    
    # テストデータのミニバッチが与えられた時に、
    # テストデータを評価する指標を計算する関数を書く。
    def test_step(self, batch, batch_idx):
        index = batch['input_ids']
        output = self.bert_sc(**batch)
        test_loss = output.loss
        
        labels = batch.pop('labels') # バッチからラベルを取得
        labels_predicted = output.logits.argmax(-1)
        performance = {'loss':test_loss,
                'batch_preds': labels_predicted,
                'batch_labels': labels,
                'index':index}
        
        return performance
        
        
    def test_epoch_end(self, outputs):
        # accuracy計算
        epoch_preds = torch.cat([x['batch_preds'] for x in outputs])
        epoch_labels = torch.cat([x['batch_labels'] for x in outputs])
        epoch_index = torch.cat([x['index'] for x in outputs])
        pd.DataFrame([epoch_preds.tolist(),epoch_labels.tolist(),epoch_index.tolist()],index=['preds','labels','index']).to_csv('fold_'+str(fold)+'.csv')
        
        return self.training_epoch_end(outputs, "test")
    
    # 学習に用いるオプティマイザを返す関数を書く。
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

In [10]:
# PyTorch Lightningモデルのロード
model = BertForSequenceClassification_pl(
    MODEL_NAME, num_labels=6, lr=1e-6
)

11/17/2022 02:58:48 - INFO - torch.distributed.nn.jit.instantiator -   Created a temporary directory at /tmp/tmpkgp96j87
11/17/2022 02:58:48 - INFO - torch.distributed.nn.jit.instantiator -   Writing /tmp/tmpkgp96j87/_remote_module_non_sriptable.py
loading configuration file ../prot_bert_bfd/config.json
Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5"
  },
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2,
    "LABEL_3": 3,
    "LABEL_4": 4,
    "LABEL_5": 5
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 40000,
  "model_type": "bert",
  "num_attention_heads": 16,
  "num_hidden_layers": 30,
  "pad

In [11]:
# 学習時にモデルの重みを保存する条件を指定
checkpoint_l1 = pl.callbacks.ModelCheckpoint(
    monitor='train_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
    dirpath='model_l1/',
)

In [12]:
# 学習の方法を指定
trainer_l1 = pl.Trainer(
    #early_stop_callback=True,
    gpus=[cuda_id],
    max_epochs=12,
    callbacks = [checkpoint_l1]
)

11/17/2022 02:58:52 - INFO - pytorch_lightning.utilities.rank_zero -   GPU available: True, used: True
11/17/2022 02:58:52 - INFO - pytorch_lightning.utilities.rank_zero -   TPU available: False, using: 0 TPU cores
11/17/2022 02:58:52 - INFO - pytorch_lightning.utilities.rank_zero -   IPU available: False, using: 0 IPUs
11/17/2022 02:58:52 - INFO - pytorch_lightning.utilities.rank_zero -   HPU available: False, using: 0 HPUs


In [13]:
# ファインチューニングを行う。
trainer_l1.fit(model, train_dataloaders = dataloader_train_l1)

  rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
11/17/2022 02:58:52 - INFO - pytorch_lightning.accelerators.gpu -   LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
11/17/2022 02:58:52 - INFO - pytorch_lightning.callbacks.model_summary -   
  | Name    | Type                          | Params
----------------------------------------------------------
0 | bert_sc | BertForSequenceClassification | 419 M 
----------------------------------------------------------
419 M     Trainable params
0         Non-trainable params
419 M     Total params
1,679.749 Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

In [14]:
best_model_path = checkpoint_l1.best_model_path # ベストモデルのファイル
print('l1ベストモデルのファイル: ', checkpoint_l1.best_model_path)
print('l1ベストモデルの検証データに対する損失: ', checkpoint_l1.best_model_score)

l1ベストモデルのファイル:  /home/__bert/__new/model_l1/epoch=11-step=10728.ckpt
l1ベストモデルの検証データに対する損失:  tensor(0.0315, device='cuda:0')


In [19]:
test = trainer_l1.test(dataloaders=dataloader_test_l1)
print('Level1: ',test)

11/17/2022 09:58:38 - INFO - pytorch_lightning.utilities.rank_zero -   Restoring states from the checkpoint path at /home/__bert/__new/model_l1/epoch=11-step=10728.ckpt
11/17/2022 09:58:39 - INFO - pytorch_lightning.accelerators.gpu -   LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
11/17/2022 09:58:39 - INFO - pytorch_lightning.utilities.rank_zero -   Loaded model weights from checkpoint at /home/__bert/__new/model_l1/epoch=11-step=10728.ckpt


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        Accuracy            0.9462786912918091
        Precision           0.7192845344543457
         Recall             0.6732664108276367
      test_accuracy         0.9462786912918091
        test_loss           0.19166414439678192
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Level1:  [{'test_loss': 0.19166414439678192, 'test_accuracy': 0.9462786912918091, 'Accuracy': 0.9462786912918091, 'Precision': 0.7192845344543457, 'Recall': 0.6732664108276367}]


In [16]:
# PyTorch Lightningモデルのロード
model_l1 = BertForSequenceClassification_pl.load_from_checkpoint(
    best_model_path
) 

loading configuration file ../prot_bert_bfd/config.json
Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5"
  },
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2,
    "LABEL_3": 3,
    "LABEL_4": 4,
    "LABEL_5": 5
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 40000,
  "model_type": "bert",
  "num_attention_heads": 16,
  "num_hidden_layers": 30,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.19.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30
}

loading weights file ../prot_bert_bfd/pytorch_model.bin
Some weights of the model checkpoi

In [17]:
# Transformers対応のモデルを./model_transformersに保存
model_l1.bert_sc.save_pretrained('./model_transformers_l1' )

Configuration saved in ./model_transformers_l1/config.json
Model weights saved in ./model_transformers_l1/pytorch_model.bin


In [20]:
result = pd.read_csv('fold_'+str(fold)+'.csv',index_col=0).T
result

Unnamed: 0,preds,labels,index
0,0,0,"[2, 21, 15, 13, 12, 5, 17, 18, 13, 17, 9, 21, ..."
1,0,0,"[2, 6, 6, 5, 13, 14, 12, 22, 16, 6, 13, 5, 6, ..."
2,1,1,"[2, 21, 11, 8, 10, 5, 5, 8, 6, 21, 14, 9, 12, ..."
3,4,4,"[2, 21, 9, 7, 15, 10, 12, 6, 9, 8, 18, 12, 18,..."
4,0,0,"[2, 21, 12, 18, 20, 5, 14, 5, 21, 12, 22, 8, 5..."
...,...,...,...
1782,1,1,"[2, 21, 11, 11, 10, 6, 11, 8, 6, 21, 10, 18, 1..."
1783,4,4,"[2, 21, 10, 6, 11, 6, 16, 10, 16, 6, 6, 16, 7,..."
1784,1,1,"[2, 21, 12, 17, 5, 5, 8, 6, 20, 14, 18, 17, 22..."
1785,3,3,"[2, 21, 12, 5, 16, 5, 15, 15, 16, 13, 5, 5, 5,..."


In [21]:
import ast

In [22]:
# predとlabelsを数字から該当するメカニズムに置換
def mechanism(n):
    return l1_list[int(n)]

result['preds'] = list(map(mechanism, result['preds']))
result['labels'] = list(map(mechanism, result['labels']))

# str型になっているindex列の要素をlist型に
result['index'] = list(map(lambda s : ast.literal_eval(s), result['index']))
result

Unnamed: 0,preds,labels,index
0,antibiotic target alteration,antibiotic target alteration,"[2, 21, 15, 13, 12, 5, 17, 18, 13, 17, 9, 21, ..."
1,antibiotic target alteration,antibiotic target alteration,"[2, 6, 6, 5, 13, 14, 12, 22, 16, 6, 13, 5, 6, ..."
2,antibiotic target replacement,antibiotic target replacement,"[2, 21, 11, 8, 10, 5, 5, 8, 6, 21, 14, 9, 12, ..."
3,antibiotic efflux,antibiotic efflux,"[2, 21, 9, 7, 15, 10, 12, 6, 9, 8, 18, 12, 18,..."
4,antibiotic target alteration,antibiotic target alteration,"[2, 21, 12, 18, 20, 5, 14, 5, 21, 12, 22, 8, 5..."
...,...,...,...
1782,antibiotic target replacement,antibiotic target replacement,"[2, 21, 11, 11, 10, 6, 11, 8, 6, 21, 10, 18, 1..."
1783,antibiotic efflux,antibiotic efflux,"[2, 21, 10, 6, 11, 6, 16, 10, 16, 6, 6, 16, 7,..."
1784,antibiotic target replacement,antibiotic target replacement,"[2, 21, 12, 17, 5, 5, 8, 6, 20, 14, 18, 17, 22..."
1785,antibiotic inactivation,antibiotic inactivation,"[2, 21, 12, 5, 16, 5, 15, 15, 16, 13, 5, 5, 5,..."


In [23]:
# 該当する配列を元のデータセットから取り出してくる

#該当する配列の元のデータセットにおいてのindexを取得
result['seq_num'] = 0
for i in result.index:
    result.at[str(i),'seq_num'] = l1_test_index.index(result.at[str(i),'index'])

result.index = result['seq_num'].to_list()
del result['index']
del result['seq_num']
result

Unnamed: 0,preds,labels
1243,antibiotic target alteration,antibiotic target alteration
1293,antibiotic target alteration,antibiotic target alteration
1730,antibiotic target replacement,antibiotic target replacement
678,antibiotic efflux,antibiotic efflux
1058,antibiotic target alteration,antibiotic target alteration
...,...,...
1607,antibiotic target replacement,antibiotic target replacement
609,antibiotic efflux,antibiotic efflux
1633,antibiotic target replacement,antibiotic target replacement
481,antibiotic inactivation,antibiotic inactivation


In [24]:
# 元のデータセットから該当する配列についての情報を取得
cols = ['ID','FEATURES','database','target','subclass','mechanism','transferable','sequence']
result[cols] = ''
for i in result.index:
    result.loc[i,cols] = df_l1_test.iloc[i,0:8]
result

Unnamed: 0,preds,labels,ID,FEATURES,database,target,subclass,mechanism,transferable,sequence
1243,antibiotic target alteration,antibiotic target alteration,AIA16137.1,FEATURES,farme,glycopeptide,vanG,antibiotic target alteration,1,M T R K L N Q R N E M M R T T I L F G G T N K ...
1293,antibiotic target alteration,antibiotic target alteration,AIA13354.1,FEATURES,farme,glycopeptide,vanD,antibiotic target alteration,0,A A L R D K H P A R L A L P S G T N E S A P L ...
1730,antibiotic target replacement,antibiotic target replacement,AIA12221.1,FEATURES,farme,trimethoprim,dfrD,antibiotic target replacement,1,M I V S L L V A M D E K R G I G K D G G L P W ...
678,antibiotic efflux,antibiotic efflux,AIA12814.1,FEATURES,farme,multidrug,Mdr,antibiotic efflux,1,M E G T S K A E V Q K Q D T E K P R L E M S S ...
1058,antibiotic target alteration,antibiotic target alteration,AIA14118.1,FEATURES,farme,trimethoprim,dfrE,antibiotic target alteration,0,M K Q Y L D L M K H V L D N G T K K D D R T H ...
...,...,...,...,...,...,...,...,...,...,...
1607,antibiotic target replacement,antibiotic target replacement,AIA17994.1,FEATURES,farme,trimethoprim,dfrK,antibiotic target replacement,1,M I I S A I V A M S Q N R V I G V N N Q L P W ...
609,antibiotic efflux,antibiotic efflux,AIA14437.1,FEATURES,farme,tetracycline,tetB(48),antibiotic efflux,0,M S A I A P S P A A P G R I V P P S E R R L K ...
1633,antibiotic target replacement,antibiotic target replacement,AIA17095.1,FEATURES,farme,trimethoprim,dfrA3,antibiotic target replacement,1,M K N L L V A Y D Q N H G I G A T G D L L W R ...
481,antibiotic inactivation,antibiotic inactivation,AMJ38222.1,FEATURES,farme,aminoglycoside,AAC(6')-Isa,antibiotic inactivation,0,M K L P L T T P R L L L R R F R T E D L P S F ...


In [25]:
result.to_csv('fold_'+str(fold)+'_result.csv')

In [26]:
result

Unnamed: 0,preds,labels,ID,FEATURES,database,target,subclass,mechanism,transferable,sequence
1243,antibiotic target alteration,antibiotic target alteration,AIA16137.1,FEATURES,farme,glycopeptide,vanG,antibiotic target alteration,1,M T R K L N Q R N E M M R T T I L F G G T N K ...
1293,antibiotic target alteration,antibiotic target alteration,AIA13354.1,FEATURES,farme,glycopeptide,vanD,antibiotic target alteration,0,A A L R D K H P A R L A L P S G T N E S A P L ...
1730,antibiotic target replacement,antibiotic target replacement,AIA12221.1,FEATURES,farme,trimethoprim,dfrD,antibiotic target replacement,1,M I V S L L V A M D E K R G I G K D G G L P W ...
678,antibiotic efflux,antibiotic efflux,AIA12814.1,FEATURES,farme,multidrug,Mdr,antibiotic efflux,1,M E G T S K A E V Q K Q D T E K P R L E M S S ...
1058,antibiotic target alteration,antibiotic target alteration,AIA14118.1,FEATURES,farme,trimethoprim,dfrE,antibiotic target alteration,0,M K Q Y L D L M K H V L D N G T K K D D R T H ...
...,...,...,...,...,...,...,...,...,...,...
1607,antibiotic target replacement,antibiotic target replacement,AIA17994.1,FEATURES,farme,trimethoprim,dfrK,antibiotic target replacement,1,M I I S A I V A M S Q N R V I G V N N Q L P W ...
609,antibiotic efflux,antibiotic efflux,AIA14437.1,FEATURES,farme,tetracycline,tetB(48),antibiotic efflux,0,M S A I A P S P A A P G R I V P P S E R R L K ...
1633,antibiotic target replacement,antibiotic target replacement,AIA17095.1,FEATURES,farme,trimethoprim,dfrA3,antibiotic target replacement,1,M K N L L V A Y D Q N H G I G A T G D L L W R ...
481,antibiotic inactivation,antibiotic inactivation,AMJ38222.1,FEATURES,farme,aminoglycoside,AAC(6')-Isa,antibiotic inactivation,0,M K L P L T T P R L L L R R F R T E D L P S F ...


In [32]:
from sklearn.metrics import precision_score
from sklearn.metrics import confusion_matrix

In [31]:
precision_score(result['labels'], result['preds'],average=None)

  _warn_prf(average, modifier, msg_start, len(result))


array([0.87043189, 0.98127341, 0.97807018, 0.6       , 0.88593156,
       0.        ])

In [34]:
confusion_matrix(result['labels'], result['preds'])

array([[262,   1,   0,   0,   0,   0],
       [ 22, 524,   9,   0,   4,   0],
       [  7,   8, 669,   0,  26,   0],
       [ 10,   1,   3,   3,   0,   0],
       [  0,   0,   3,   0, 233,   0],
       [  0,   0,   0,   2,   0,   0]])