## Imports

In [1]:
from model_fct import ProteinClassifier, ProteinDataModule, ProteinSequenceDataset
import os
import torch
from torch import nn
import torch.utils.data
import torch.utils.data.distributed
from torch.utils.data import Dataset, DataLoader, RandomSampler, TensorDataset
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torchmetrics
from pytorch_lightning.accelerators import MPSAccelerator
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy, MultilabelF1Score
from torchmetrics import Recall, Precision

from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import Trainer, seed_everything
import datetime
from datetime import datetime
#from pytorch_lightning.metrics.sklearns import Accuracy

import torchvision

%load_ext autoreload
%autoreload 2

In [2]:
import platform
platform.processor()

'arm'

In [3]:
train_df = pd.read_pickle('train_df.pkl')
test_df = pd.read_pickle('test_df.pkl')
val_df = pd.read_pickle('val_df.pkl')
blind_df = pd.read_pickle('blind_df.pkl')

## Logger and checkpoint

In [4]:
def setup_testube_logger() -> CSVLogger:
    """ Function that sets the TestTubeLogger to be used. """
    now = datetime.now()
    dt_string = now.strftime("%d-%m-%Y--%H-%M-%S")

    return CSVLogger(
        save_dir="experiments/",
        version=dt_string,
        name="lightning_logs",
    )

logger = setup_testube_logger()

In [5]:
ckpt_path = os.path.join(
    logger.save_dir,
    logger.name,
    f"version_{logger.version}",
    "checkpoints",
)

c = ModelCheckpoint(
    dirpath=ckpt_path + "/" + "tanh_3epochs",
    verbose=True,
    monitor='val_acc',
    mode="max",
)

## Set up experiment

In [6]:
TARGETS = ['cyto', 'mito', 'nucleus','other', 'secreted']
PRE_TRAINED_MODEL_NAME = 'Rostlab/prot_bert_bfd_localization'
#PRE_TRAINED_MODEL_NAME = 'Rostlab/prot_bert_bfd'
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME, do_lower_case=False)

EPOCHS = 1
BATCH_SIZE = 1
MAX_LENGTH = 1500

In [7]:
dm = ProteinDataModule(
    train_df, 
    test_df,
    val_df,
    blind_df,
    tokenizer, 
    target_list=TARGETS,
    batch_size=BATCH_SIZE,
    max_len = MAX_LENGTH
)

model = ProteinClassifier(
    n_classes=5,
    target_list=TARGETS,
    steps_per_epoch=len(train_df)//BATCH_SIZE, 
    n_epochs=EPOCHS
)

Some weights of the model checkpoint at Rostlab/prot_bert_bfd were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
trainer = pl.Trainer(max_epochs=EPOCHS,
                     logger=logger,
                     accelerator='mps',
                     #callbacks = checkpoint_callback
                     default_root_dir='experiments/lightning_logs'
                    )

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
trainer.fit(model, dm)


  | Name       | Type             | Params
------------------------------------------------
0 | bert       | BertModel        | 419 M 
1 | classifier | Sequential       | 5.1 K 
2 | criterion  | CrossEntropyLoss | 0     
------------------------------------------------
419 M     Trainable params
0         Non-trainable params
419 M     Total params
1,679.745 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  input = module(input)


the accuracy is 0.00
the precision is 0.00
the recall is 0.00
the f1 is 0.00
   precision  recall   f1  accuracy  num_samples
0        0.0     0.0  0.0       1.0            0
1        0.0     0.0  0.0       0.0            1
2        0.0     0.0  0.0       0.0            1
[[0 0 0]
 [1 0 0]
 [1 0 0]]


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Training: 0it [00:00, ?it/s]

  input = module(input)


Validation: 0it [00:00, ?it/s]

the accuracy is 0.80
the precision is 0.82
the recall is 0.81
the f1 is 0.82
   precision    recall        f1  accuracy  num_samples
0   0.669078  0.734127  0.700095  0.734127          504
1   0.787565  0.763819  0.775510  0.763819          199
2   0.796000  0.766859  0.781158  0.766859          519
3   0.935484  0.914826  0.925040  0.914826          317
4   0.929752  0.868726  0.898204  0.868726          259
[[370  24  93   9   8]
 [ 38 152   4   1   4]
 [100  11 398   6   4]
 [ 20   5   1 290   1]
 [ 25   1   4   4 225]]


`Trainer.fit` stopped: `max_epochs=1` reached.


In [10]:
trainer.test(dataloaders=dm)

  rank_zero_warn(
Restoring states from the checkpoint path at experiments/lightning_logs/22-02-2023--01-20-59/checkpoints/epoch=0-step=7183.ckpt
Loaded model weights from checkpoint at experiments/lightning_logs/22-02-2023--01-20-59/checkpoints/epoch=0-step=7183.ckpt


Testing: 0it [00:00, ?it/s]

  input = module(input)
  nonzero_finite_vals = torch.masked_select(


tensor([[9.6617e-04, 6.0482e-04, 7.4642e-04, 9.9728e-01, 4.0329e-04]],
       device='mps:0')
tensor([[9.8257e-01, 6.4221e-04, 1.6066e-02, 3.8707e-04, 3.3912e-04]],
       device='mps:0')
tensor([[9.9002e-01, 2.3752e-03, 5.8748e-03, 7.6090e-04, 9.6957e-04]],
       device='mps:0')
tensor([[1.2960e-03, 8.5366e-04, 7.1625e-04, 4.3510e-04, 9.9670e-01]],
       device='mps:0')
tensor([[9.9276e-01, 1.5711e-03, 4.6022e-03, 6.0705e-04, 4.5514e-04]],
       device='mps:0')
tensor([[2.8561e-03, 9.9470e-01, 3.5098e-04, 1.0107e-03, 1.0789e-03]],
       device='mps:0')
tensor([[9.9085e-01, 1.7624e-03, 6.2517e-03, 5.2997e-04, 6.0383e-04]],
       device='mps:0')
tensor([[3.7275e-03, 9.6520e-05, 9.9522e-01, 3.2948e-04, 6.3111e-04]],
       device='mps:0')
tensor([[0.8765, 0.0995, 0.0096, 0.0078, 0.0066]], device='mps:0')
tensor([[7.3900e-04, 5.0271e-04, 8.1267e-04, 9.9757e-01, 3.7761e-04]],
       device='mps:0')
tensor([[2.0120e-03, 1.9968e-03, 8.8040e-04, 9.9476e-01, 3.5046e-04]],
       device='m

tensor([[3.7437e-03, 8.6483e-05, 9.9539e-01, 3.2393e-04, 4.5886e-04]],
       device='mps:0')
tensor([[3.5470e-03, 8.7847e-05, 9.9550e-01, 3.2481e-04, 5.4533e-04]],
       device='mps:0')
tensor([[3.7561e-03, 1.0185e-04, 9.9540e-01, 3.7348e-04, 3.6394e-04]],
       device='mps:0')
tensor([[1.0500e-03, 6.3856e-04, 9.8106e-04, 9.9690e-01, 4.3520e-04]],
       device='mps:0')
tensor([[4.8170e-03, 1.1812e-04, 9.9437e-01, 3.8703e-04, 3.0483e-04]],
       device='mps:0')
tensor([[3.4902e-03, 9.0883e-05, 9.9557e-01, 3.4242e-04, 5.1073e-04]],
       device='mps:0')
tensor([[9.9120e-01, 1.5869e-03, 6.0701e-03, 5.8418e-04, 5.6144e-04]],
       device='mps:0')
tensor([[2.8685e-03, 9.9470e-01, 3.7228e-04, 8.9374e-04, 1.1614e-03]],
       device='mps:0')
tensor([[9.9043e-01, 4.5972e-03, 2.9958e-03, 1.4087e-03, 5.6636e-04]],
       device='mps:0')
tensor([[1.5607e-02, 1.5795e-04, 9.8353e-01, 4.0203e-04, 3.0283e-04]],
       device='mps:0')
tensor([[3.4714e-03, 8.6356e-05, 9.9563e-01, 3.2750e-04, 4.8

tensor([[9.9295e-01, 1.2349e-03, 4.7111e-03, 7.0890e-04, 3.9425e-04]],
       device='mps:0')
tensor([[1.1324e-03, 6.1283e-04, 9.6141e-04, 9.9696e-01, 3.3046e-04]],
       device='mps:0')
tensor([[3.4665e-03, 8.3882e-05, 9.9572e-01, 3.2264e-04, 4.0459e-04]],
       device='mps:0')
tensor([[3.7418e-03, 8.8912e-05, 9.9542e-01, 3.5763e-04, 3.8966e-04]],
       device='mps:0')
tensor([[9.9246e-01, 1.3274e-03, 5.1887e-03, 6.5082e-04, 3.7011e-04]],
       device='mps:0')
tensor([[2.3251e-02, 2.3731e-04, 9.7545e-01, 6.2489e-04, 4.3474e-04]],
       device='mps:0')
tensor([[5.3819e-03, 9.3032e-05, 9.9377e-01, 3.2848e-04, 4.3010e-04]],
       device='mps:0')
tensor([[3.8237e-03, 8.7030e-05, 9.9542e-01, 3.8134e-04, 2.9068e-04]],
       device='mps:0')
tensor([[9.9189e-01, 9.6234e-04, 6.1881e-03, 5.7121e-04, 3.8934e-04]],
       device='mps:0')
tensor([[9.9122e-01, 2.2336e-03, 5.1317e-03, 6.8311e-04, 7.3638e-04]],
       device='mps:0')
tensor([[2.8813e-03, 9.9464e-01, 3.7525e-04, 8.9432e-04, 1.2

tensor([[3.0128e-03, 9.9450e-01, 3.9156e-04, 8.9157e-04, 1.2011e-03]],
       device='mps:0')
tensor([[8.6144e-04, 6.2820e-04, 6.9738e-04, 9.9745e-01, 3.6643e-04]],
       device='mps:0')
tensor([[9.1137e-04, 5.0891e-04, 6.9643e-04, 9.9751e-01, 3.6929e-04]],
       device='mps:0')
tensor([[3.8250e-03, 9.8806e-05, 9.9520e-01, 3.7439e-04, 5.0437e-04]],
       device='mps:0')
tensor([[9.9056e-01, 1.9425e-03, 6.2500e-03, 6.7037e-04, 5.7342e-04]],
       device='mps:0')
tensor([[9.9146e-01, 1.5032e-03, 5.6484e-03, 7.2548e-04, 6.6475e-04]],
       device='mps:0')
tensor([[1.9629e-03, 9.5825e-04, 7.6505e-04, 4.4602e-04, 9.9587e-01]],
       device='mps:0')
tensor([[4.5510e-01, 4.6461e-04, 5.4351e-01, 4.7566e-04, 4.5224e-04]],
       device='mps:0')
tensor([[4.3205e-03, 9.7435e-05, 9.9489e-01, 4.3916e-04, 2.5346e-04]],
       device='mps:0')
tensor([[9.9228e-01, 2.6395e-03, 3.8963e-03, 6.9146e-04, 4.8867e-04]],
       device='mps:0')
tensor([[9.9286e-01, 1.1340e-03, 4.9279e-03, 6.9448e-04, 3.8

tensor([[9.9305e-01, 1.5630e-03, 4.3636e-03, 6.0056e-04, 4.1938e-04]],
       device='mps:0')
tensor([[1.1658e-03, 7.3750e-04, 6.8942e-04, 9.9708e-01, 3.2679e-04]],
       device='mps:0')
tensor([[9.9058e-01, 9.8656e-04, 7.5040e-03, 4.5653e-04, 4.6923e-04]],
       device='mps:0')
tensor([[1.3172e-03, 7.8657e-04, 8.2156e-04, 4.2193e-04, 9.9665e-01]],
       device='mps:0')
tensor([[9.9301e-01, 1.7631e-03, 4.0812e-03, 6.7067e-04, 4.7520e-04]],
       device='mps:0')
tensor([[4.1479e-03, 9.1502e-05, 9.9506e-01, 4.0340e-04, 2.9865e-04]],
       device='mps:0')
tensor([[6.0350e-03, 1.1542e-04, 9.9316e-01, 3.8815e-04, 3.0615e-04]],
       device='mps:0')
tensor([[9.9150e-01, 2.4602e-03, 5.0180e-03, 5.0284e-04, 5.1911e-04]],
       device='mps:0')
tensor([[1.0444e-03, 5.2938e-04, 6.8066e-04, 9.9740e-01, 3.4877e-04]],
       device='mps:0')
tensor([[9.9237e-01, 1.9142e-03, 4.5439e-03, 6.1353e-04, 5.5702e-04]],
       device='mps:0')
tensor([[0.9871, 0.0059, 0.0039, 0.0012, 0.0020]], device='m

tensor([[6.3801e-03, 8.1822e-04, 1.4662e-03, 9.9096e-01, 3.7783e-04]],
       device='mps:0')
tensor([[5.7559e-03, 1.4813e-03, 1.5515e-03, 5.8786e-04, 9.9062e-01]],
       device='mps:0')
tensor([[3.0120e-03, 9.9453e-01, 3.9617e-04, 8.6080e-04, 1.1990e-03]],
       device='mps:0')
tensor([[1.0160e-03, 6.2634e-04, 7.9486e-04, 9.9718e-01, 3.7972e-04]],
       device='mps:0')
tensor([[1.3999e-03, 9.6772e-04, 5.8975e-04, 4.3484e-04, 9.9661e-01]],
       device='mps:0')
tensor([[1.3104e-03, 8.6196e-04, 6.9288e-04, 4.1764e-04, 9.9672e-01]],
       device='mps:0')
tensor([[9.8533e-01, 8.0697e-04, 1.2995e-02, 4.5412e-04, 4.1396e-04]],
       device='mps:0')
tensor([[4.2730e-02, 3.8162e-04, 9.5552e-01, 8.0266e-04, 5.6283e-04]],
       device='mps:0')
tensor([[9.9293e-01, 1.5593e-03, 4.5111e-03, 5.7917e-04, 4.1778e-04]],
       device='mps:0')
tensor([[1.6251e-03, 1.2454e-03, 5.4926e-04, 5.1298e-04, 9.9607e-01]],
       device='mps:0')
tensor([[5.1656e-03, 1.2160e-03, 1.5255e-03, 4.7933e-04, 9.9

tensor([[9.9260e-01, 1.8559e-03, 4.3384e-03, 6.8938e-04, 5.1586e-04]],
       device='mps:0')
tensor([[4.4286e-03, 1.1806e-04, 9.9452e-01, 5.9477e-04, 3.3553e-04]],
       device='mps:0')
tensor([[9.9058e-01, 1.4446e-03, 6.8510e-03, 5.7685e-04, 5.4800e-04]],
       device='mps:0')
tensor([[9.9120e-01, 1.3680e-03, 6.3237e-03, 5.5085e-04, 5.5694e-04]],
       device='mps:0')
tensor([[0.9816, 0.0047, 0.0095, 0.0032, 0.0010]], device='mps:0')
tensor([[3.0945e-03, 9.6615e-04, 1.2535e-03, 4.4988e-04, 9.9424e-01]],
       device='mps:0')
tensor([[5.3887e-03, 9.9166e-01, 3.7821e-04, 1.0099e-03, 1.5646e-03]],
       device='mps:0')
tensor([[9.9310e-01, 1.1881e-03, 4.6951e-03, 6.2361e-04, 3.9596e-04]],
       device='mps:0')
tensor([[3.0286e-03, 9.9432e-01, 3.9125e-04, 9.8933e-04, 1.2702e-03]],
       device='mps:0')
tensor([[1.2402e-03, 1.4003e-03, 6.4769e-04, 9.9631e-01, 4.0488e-04]],
       device='mps:0')
tensor([[3.2692e-03, 1.0339e-03, 1.4922e-03, 9.9381e-01, 3.9702e-04]],
       device='m

tensor([[2.8759e-03, 9.9464e-01, 3.9227e-04, 9.4499e-04, 1.1437e-03]],
       device='mps:0')
tensor([[1.3766e-03, 7.5034e-04, 1.0060e-03, 9.9649e-01, 3.8063e-04]],
       device='mps:0')
tensor([[1.2370e-03, 7.5506e-04, 8.1308e-04, 4.3896e-04, 9.9676e-01]],
       device='mps:0')
tensor([[0.1298, 0.0093, 0.0063, 0.0046, 0.8500]], device='mps:0')
tensor([[9.9282e-01, 1.4146e-03, 4.3133e-03, 9.7070e-04, 4.7966e-04]],
       device='mps:0')
tensor([[9.9249e-01, 1.2032e-03, 5.2813e-03, 6.1660e-04, 4.0728e-04]],
       device='mps:0')
tensor([[9.2617e-04, 4.9715e-04, 8.3709e-04, 9.9736e-01, 3.8209e-04]],
       device='mps:0')
tensor([[1.0973e-03, 8.1926e-04, 6.0835e-04, 9.9707e-01, 4.0154e-04]],
       device='mps:0')
tensor([[9.9284e-01, 2.1074e-03, 3.8097e-03, 7.7293e-04, 4.7456e-04]],
       device='mps:0')
tensor([[9.6063e-04, 6.7352e-04, 6.6293e-04, 9.9737e-01, 3.3480e-04]],
       device='mps:0')
tensor([[1.7429e-03, 6.0007e-04, 1.1400e-03, 9.9622e-01, 2.9567e-04]],
       device='m

tensor([[0.8856, 0.1081, 0.0030, 0.0014, 0.0018]], device='mps:0')
tensor([[3.4427e-03, 8.5730e-05, 9.9578e-01, 3.7823e-04, 3.1343e-04]],
       device='mps:0')
tensor([[9.7313e-01, 5.5273e-04, 2.5624e-02, 4.1405e-04, 2.7780e-04]],
       device='mps:0')
tensor([[1.2353e-03, 9.3863e-04, 6.7558e-04, 4.4894e-04, 9.9670e-01]],
       device='mps:0')
tensor([[9.7652e-01, 5.7065e-04, 2.2252e-02, 3.9527e-04, 2.6646e-04]],
       device='mps:0')
tensor([[2.9957e-03, 9.9427e-01, 4.3574e-04, 9.5917e-04, 1.3420e-03]],
       device='mps:0')
tensor([[3.1226e-03, 9.9410e-01, 3.6050e-04, 1.1093e-03, 1.3077e-03]],
       device='mps:0')
tensor([[1.5111e-03, 8.0775e-04, 7.7113e-04, 4.4711e-04, 9.9646e-01]],
       device='mps:0')
tensor([[9.8969e-01, 2.0603e-03, 6.6040e-03, 7.7681e-04, 8.7305e-04]],
       device='mps:0')
tensor([[9.9259e-01, 1.3392e-03, 5.1021e-03, 5.6054e-04, 4.0855e-04]],
       device='mps:0')
tensor([[8.6782e-04, 1.0922e-03, 6.0884e-04, 9.9705e-01, 3.7874e-04]],
       device='m

tensor([[3.8292e-03, 8.6769e-05, 9.9542e-01, 3.3196e-04, 3.3400e-04]],
       device='mps:0')
tensor([[9.9252e-01, 1.7432e-03, 4.6019e-03, 5.8340e-04, 5.5617e-04]],
       device='mps:0')
tensor([[9.6326e-01, 3.1887e-02, 2.7695e-03, 8.8598e-04, 1.1948e-03]],
       device='mps:0')
tensor([[9.9225e-01, 1.1704e-03, 5.3990e-03, 8.2330e-04, 3.6224e-04]],
       device='mps:0')
tensor([[2.8077e-03, 9.9473e-01, 3.9076e-04, 8.9963e-04, 1.1754e-03]],
       device='mps:0')
tensor([[2.8908e-03, 9.9465e-01, 3.7961e-04, 9.6757e-04, 1.1113e-03]],
       device='mps:0')
tensor([[9.9197e-01, 2.1761e-03, 4.7238e-03, 6.6638e-04, 4.6008e-04]],
       device='mps:0')
tensor([[1.9553e-03, 1.2355e-03, 8.4810e-04, 9.9561e-01, 3.5144e-04]],
       device='mps:0')
tensor([[3.9695e-03, 9.1717e-05, 9.9523e-01, 3.8401e-04, 3.2317e-04]],
       device='mps:0')
tensor([[6.6707e-03, 1.6517e-03, 1.4065e-03, 5.9878e-04, 9.8967e-01]],
       device='mps:0')
tensor([[9.9146e-01, 1.8056e-03, 5.4037e-03, 7.2601e-04, 6.0

tensor([[9.8350e-01, 6.8952e-04, 1.5089e-02, 3.8973e-04, 3.3006e-04]],
       device='mps:0')
tensor([[4.3115e-03, 9.0515e-05, 9.9486e-01, 3.3223e-04, 4.0461e-04]],
       device='mps:0')
tensor([[5.6901e-03, 1.1887e-04, 9.9334e-01, 4.9601e-04, 3.5810e-04]],
       device='mps:0')
tensor([[1.9896e-02, 1.7274e-04, 9.7915e-01, 4.4871e-04, 3.3490e-04]],
       device='mps:0')
tensor([[5.4749e-03, 1.4834e-03, 1.5414e-03, 6.1323e-04, 9.9089e-01]],
       device='mps:0')
tensor([[4.0284e-03, 8.7498e-05, 9.9514e-01, 3.4131e-04, 4.0000e-04]],
       device='mps:0')
tensor([[9.9199e-01, 1.3612e-03, 5.7120e-03, 4.6742e-04, 4.7115e-04]],
       device='mps:0')
tensor([[3.5290e-03, 9.0851e-05, 9.9548e-01, 3.2508e-04, 5.7165e-04]],
       device='mps:0')
tensor([[9.9242e-01, 1.4784e-03, 4.9288e-03, 6.4137e-04, 5.3304e-04]],
       device='mps:0')
tensor([[3.4369e-03, 8.8227e-05, 9.9562e-01, 3.3386e-04, 5.2037e-04]],
       device='mps:0')
tensor([[5.4774e-03, 9.8979e-05, 9.9365e-01, 3.3851e-04, 4.4

tensor([[3.6407e-03, 8.4805e-05, 9.9554e-01, 3.4623e-04, 3.8768e-04]],
       device='mps:0')
tensor([[3.0178e-03, 9.9452e-01, 3.7037e-04, 9.7437e-04, 1.1208e-03]],
       device='mps:0')
tensor([[9.9289e-01, 1.8864e-03, 3.9293e-03, 8.9054e-04, 4.0675e-04]],
       device='mps:0')
tensor([[0.0089, 0.0024, 0.0034, 0.0011, 0.9842]], device='mps:0')
tensor([[9.9289e-01, 1.4491e-03, 4.4633e-03, 7.6486e-04, 4.2826e-04]],
       device='mps:0')
tensor([[4.8173e-03, 1.4237e-03, 9.4464e-04, 9.9227e-01, 5.4806e-04]],
       device='mps:0')
tensor([[3.5857e-03, 8.9333e-05, 9.9542e-01, 3.2279e-04, 5.7773e-04]],
       device='mps:0')
tensor([[9.9251e-01, 1.9233e-03, 4.4101e-03, 5.5556e-04, 5.9884e-04]],
       device='mps:0')
tensor([[1.5830e-03, 1.0196e-03, 6.0932e-04, 4.1900e-04, 9.9637e-01]],
       device='mps:0')
tensor([[6.6052e-03, 1.3940e-03, 1.3659e-03, 5.5574e-04, 9.9008e-01]],
       device='mps:0')
tensor([[9.9244e-01, 1.6018e-03, 4.9002e-03, 5.9460e-04, 4.6309e-04]],
       device='m

tensor([[3.6825e-03, 1.0215e-04, 9.9544e-01, 4.4611e-04, 3.2566e-04]],
       device='mps:0')
tensor([[8.7407e-04, 9.4809e-04, 6.9954e-04, 9.9709e-01, 3.9294e-04]],
       device='mps:0')
tensor([[0.7559, 0.0182, 0.0175, 0.0091, 0.1992]], device='mps:0')
tensor([[9.9184e-01, 3.2883e-03, 3.4429e-03, 9.3214e-04, 4.9780e-04]],
       device='mps:0')
tensor([[9.7204e-01, 7.1996e-04, 2.6507e-02, 3.6636e-04, 3.6619e-04]],
       device='mps:0')
tensor([[4.2468e-03, 9.3741e-04, 9.5691e-04, 9.9346e-01, 4.0252e-04]],
       device='mps:0')
tensor([[9.1797e-04, 6.0562e-04, 6.8732e-04, 9.9744e-01, 3.4486e-04]],
       device='mps:0')
tensor([[4.0073e-03, 9.9330e-01, 3.8572e-04, 9.8492e-04, 1.3247e-03]],
       device='mps:0')
tensor([[9.9268e-01, 1.0645e-03, 5.0871e-03, 8.5791e-04, 3.1298e-04]],
       device='mps:0')
tensor([[2.9292e-03, 9.9466e-01, 3.4416e-04, 8.8842e-04, 1.1830e-03]],
       device='mps:0')
tensor([[1.4247e-01, 3.0741e-04, 8.5649e-01, 3.6721e-04, 3.6128e-04]],
       device='m

tensor([[2.8761e-03, 9.9471e-01, 3.5285e-04, 9.2596e-04, 1.1317e-03]],
       device='mps:0')
tensor([[8.5270e-04, 4.8608e-04, 7.2205e-04, 9.9757e-01, 3.7095e-04]],
       device='mps:0')
tensor([[1.0263e-03, 7.7449e-04, 9.1641e-04, 9.9683e-01, 4.5195e-04]],
       device='mps:0')
tensor([[9.8600e-01, 2.4348e-03, 9.6577e-03, 8.3491e-04, 1.0714e-03]],
       device='mps:0')
tensor([[1.3837e-03, 9.7269e-04, 7.4273e-04, 9.9644e-01, 4.6228e-04]],
       device='mps:0')
tensor([[3.1950e-03, 9.9433e-01, 3.5719e-04, 9.9515e-04, 1.1214e-03]],
       device='mps:0')
tensor([[1.2407e-03, 7.4873e-04, 8.1464e-04, 4.4425e-04, 9.9675e-01]],
       device='mps:0')
tensor([[9.9086e-01, 1.9181e-03, 3.5733e-03, 3.2603e-03, 3.9167e-04]],
       device='mps:0')
tensor([[5.6771e-03, 1.0922e-04, 9.9348e-01, 3.8762e-04, 3.4117e-04]],
       device='mps:0')
tensor([[9.8452e-01, 1.0438e-02, 3.7821e-03, 6.2319e-04, 6.3528e-04]],
       device='mps:0')
tensor([[0.5256, 0.3279, 0.0129, 0.0908, 0.0429]], device='m

tensor([[9.9283e-01, 1.5622e-03, 4.5062e-03, 5.8251e-04, 5.1609e-04]],
       device='mps:0')
tensor([[3.1359e-03, 9.9434e-01, 3.6454e-04, 9.9373e-04, 1.1653e-03]],
       device='mps:0')
tensor([[9.7810e-01, 1.7377e-02, 2.9160e-03, 9.4210e-04, 6.5988e-04]],
       device='mps:0')
tensor([[8.3597e-04, 6.4877e-04, 7.6847e-04, 9.9740e-01, 3.4647e-04]],
       device='mps:0')
tensor([[6.8906e-03, 1.3764e-04, 9.9213e-01, 4.9161e-04, 3.4672e-04]],
       device='mps:0')
tensor([[3.4467e-03, 8.7858e-05, 9.9561e-01, 3.2682e-04, 5.2958e-04]],
       device='mps:0')
tensor([[4.1611e-03, 9.4435e-05, 9.9506e-01, 3.8480e-04, 3.0364e-04]],
       device='mps:0')
tensor([[9.9341e-01, 1.6008e-03, 3.7365e-03, 8.1423e-04, 4.3964e-04]],
       device='mps:0')
tensor([[3.6246e-03, 9.0902e-05, 9.9536e-01, 3.2253e-04, 6.0121e-04]],
       device='mps:0')
tensor([[9.9301e-01, 1.2906e-03, 4.7042e-03, 5.9489e-04, 4.0188e-04]],
       device='mps:0')
tensor([[3.2137e-03, 1.0855e-03, 1.1598e-03, 4.5594e-04, 9.9

tensor([[1.3050e-03, 1.0197e-03, 9.4886e-04, 9.9628e-01, 4.4350e-04]],
       device='mps:0')
tensor([[9.9180e-01, 1.4023e-03, 5.8030e-03, 5.4557e-04, 4.4513e-04]],
       device='mps:0')
tensor([[5.2157e-03, 1.0895e-04, 9.9382e-01, 3.5712e-04, 4.9679e-04]],
       device='mps:0')
tensor([[3.1023e-03, 9.9442e-01, 3.8813e-04, 9.0465e-04, 1.1835e-03]],
       device='mps:0')
tensor([[3.8881e-03, 8.7678e-05, 9.9536e-01, 3.2296e-04, 3.3658e-04]],
       device='mps:0')
tensor([[9.9264e-01, 1.7407e-03, 4.4784e-03, 6.2068e-04, 5.2513e-04]],
       device='mps:0')
tensor([[2.1878e-03, 9.3753e-04, 8.6523e-04, 9.9556e-01, 4.4625e-04]],
       device='mps:0')
tensor([[7.9927e-04, 5.7211e-04, 6.9417e-04, 9.9757e-01, 3.6709e-04]],
       device='mps:0')
tensor([[1.7711e-03, 9.9122e-04, 6.1288e-04, 4.6427e-04, 9.9616e-01]],
       device='mps:0')
tensor([[9.9134e-01, 1.7553e-03, 5.7286e-03, 5.9140e-04, 5.8600e-04]],
       device='mps:0')
tensor([[1.4701e-03, 1.1032e-03, 5.8711e-04, 4.7425e-04, 9.9

tensor([[1.3704e-03, 8.1202e-04, 7.4329e-04, 4.1762e-04, 9.9666e-01]],
       device='mps:0')
tensor([[3.8021e-03, 9.2740e-05, 9.9538e-01, 4.0332e-04, 3.2675e-04]],
       device='mps:0')
tensor([[3.5654e-03, 8.7947e-05, 9.9550e-01, 3.2064e-04, 5.2341e-04]],
       device='mps:0')
tensor([[9.9219e-01, 1.6254e-03, 4.9765e-03, 6.9149e-04, 5.1497e-04]],
       device='mps:0')
tensor([[3.5996e-03, 8.8900e-05, 9.9554e-01, 3.3969e-04, 4.3222e-04]],
       device='mps:0')
tensor([[9.4805e-01, 6.4766e-04, 5.0560e-02, 3.9618e-04, 3.4227e-04]],
       device='mps:0')
tensor([[2.8684e-03, 9.9467e-01, 3.8048e-04, 8.4828e-04, 1.2344e-03]],
       device='mps:0')
tensor([[3.6289e-03, 9.0688e-05, 9.9537e-01, 3.2304e-04, 5.9005e-04]],
       device='mps:0')
tensor([[4.7852e-03, 1.4342e-03, 1.3616e-03, 5.6735e-04, 9.9185e-01]],
       device='mps:0')
tensor([[3.5920e-03, 9.5409e-05, 9.9549e-01, 3.7101e-04, 4.5361e-04]],
       device='mps:0')
tensor([[3.3790e-03, 8.5864e-05, 9.9578e-01, 3.4429e-04, 4.0

tensor([[9.8858e-01, 8.8666e-04, 9.6829e-03, 4.6010e-04, 3.8835e-04]],
       device='mps:0')
tensor([[3.5705e-03, 8.4423e-05, 9.9554e-01, 3.1515e-04, 4.8722e-04]],
       device='mps:0')
tensor([[7.5712e-04, 5.1883e-04, 7.0751e-04, 9.9763e-01, 3.9115e-04]],
       device='mps:0')
tensor([[3.3420e-03, 8.0534e-05, 9.9589e-01, 3.3346e-04, 3.5821e-04]],
       device='mps:0')
tensor([[7.0151e-03, 1.3881e-04, 9.9161e-01, 5.1346e-04, 7.2119e-04]],
       device='mps:0')
tensor([[3.5925e-03, 8.9874e-05, 9.9541e-01, 3.2162e-04, 5.8377e-04]],
       device='mps:0')
tensor([[3.4410e-03, 8.5457e-05, 9.9573e-01, 3.1842e-04, 4.2276e-04]],
       device='mps:0')
tensor([[9.0814e-04, 7.6811e-04, 6.3062e-04, 9.9730e-01, 3.8914e-04]],
       device='mps:0')
tensor([[5.0942e-03, 1.2840e-03, 1.5824e-03, 5.3576e-04, 9.9150e-01]],
       device='mps:0')
tensor([[9.8860e-01, 1.7735e-03, 8.1981e-03, 6.7727e-04, 7.5004e-04]],
       device='mps:0')
tensor([[7.4989e-04, 4.9874e-04, 7.0498e-04, 9.9768e-01, 3.7

tensor([[3.6237e-03, 9.0574e-05, 9.9537e-01, 3.2488e-04, 5.9428e-04]],
       device='mps:0')
tensor([[9.9241e-01, 1.3857e-03, 5.0573e-03, 6.2164e-04, 5.2119e-04]],
       device='mps:0')
tensor([[9.9313e-01, 1.3128e-03, 4.4764e-03, 6.7895e-04, 3.9718e-04]],
       device='mps:0')
tensor([[8.2234e-04, 4.2692e-04, 9.2850e-04, 9.9746e-01, 3.5857e-04]],
       device='mps:0')
tensor([[3.5474e-03, 8.9760e-05, 9.9553e-01, 3.3427e-04, 5.0142e-04]],
       device='mps:0')
tensor([[1.2765e-03, 9.3071e-04, 7.4421e-04, 9.9669e-01, 3.5735e-04]],
       device='mps:0')
tensor([[9.9026e-01, 2.2883e-03, 5.9437e-03, 7.2707e-04, 7.8128e-04]],
       device='mps:0')
tensor([[1.1840e-03, 9.7132e-04, 5.5274e-04, 9.9687e-01, 4.1916e-04]],
       device='mps:0')
tensor([[8.3977e-03, 1.6356e-04, 9.9059e-01, 5.5118e-04, 2.9252e-04]],
       device='mps:0')
tensor([[9.9449e-03, 2.0294e-04, 9.8876e-01, 6.7620e-04, 4.1271e-04]],
       device='mps:0')
tensor([[9.9278e-01, 1.0120e-03, 5.1937e-03, 6.7067e-04, 3.4

tensor([[2.9310e-03, 9.9462e-01, 3.8263e-04, 9.3244e-04, 1.1346e-03]],
       device='mps:0')
tensor([[9.9313e-01, 1.3694e-03, 4.3285e-03, 6.7142e-04, 5.0344e-04]],
       device='mps:0')
tensor([[7.7348e-04, 5.8806e-04, 7.2212e-04, 9.9756e-01, 3.5329e-04]],
       device='mps:0')
tensor([[9.9236e-01, 1.9913e-03, 4.5287e-03, 5.8318e-04, 5.3739e-04]],
       device='mps:0')
tensor([[3.4937e-03, 9.0101e-05, 9.9558e-01, 3.2218e-04, 5.1487e-04]],
       device='mps:0')
tensor([[0.0250, 0.0159, 0.0032, 0.9490, 0.0069]], device='mps:0')
tensor([[5.0268e-03, 9.7150e-05, 9.9421e-01, 3.2089e-04, 3.4045e-04]],
       device='mps:0')
tensor([[9.9164e-01, 1.0619e-03, 6.4027e-03, 5.1712e-04, 3.8191e-04]],
       device='mps:0')
tensor([[9.0449e-04, 5.5717e-04, 7.7345e-04, 9.9742e-01, 3.4271e-04]],
       device='mps:0')
tensor([[1.3381e-03, 6.9110e-04, 9.4729e-04, 9.9658e-01, 4.3945e-04]],
       device='mps:0')
tensor([[4.7249e-03, 1.1996e-03, 1.6026e-03, 5.2390e-04, 9.9195e-01]],
       device='m

tensor([[8.8284e-04, 8.1790e-04, 6.2712e-04, 9.9732e-01, 3.5219e-04]],
       device='mps:0')
tensor([[8.0887e-04, 5.7663e-04, 6.9589e-04, 9.9753e-01, 3.8531e-04]],
       device='mps:0')
tensor([[4.7486e-03, 9.2686e-05, 9.9446e-01, 3.5931e-04, 3.4253e-04]],
       device='mps:0')
tensor([[1.1014e-03, 5.7678e-04, 8.0653e-04, 9.9699e-01, 5.2075e-04]],
       device='mps:0')
tensor([[9.9302e-01, 1.3974e-03, 4.5020e-03, 6.9572e-04, 3.8090e-04]],
       device='mps:0')
tensor([[6.5769e-04, 5.3890e-04, 6.6188e-04, 9.9775e-01, 3.8671e-04]],
       device='mps:0')
tensor([[2.9698e-03, 9.9462e-01, 3.6568e-04, 8.5447e-04, 1.1905e-03]],
       device='mps:0')
tensor([[3.5583e-03, 8.8357e-05, 9.9555e-01, 3.2407e-04, 4.7692e-04]],
       device='mps:0')
tensor([[1.6719e-03, 1.2340e-03, 6.3044e-04, 5.1804e-04, 9.9595e-01]],
       device='mps:0')
tensor([[1.3222e-03, 7.7768e-04, 7.5110e-04, 4.4044e-04, 9.9671e-01]],
       device='mps:0')
tensor([[1.2972e-03, 7.4839e-04, 7.9724e-04, 9.9675e-01, 4.0

tensor([[5.5476e-03, 1.0606e-04, 9.9361e-01, 4.7040e-04, 2.6465e-04]],
       device='mps:0')
tensor([[8.4550e-04, 7.3133e-04, 8.7244e-04, 9.9714e-01, 4.0830e-04]],
       device='mps:0')
tensor([[3.5669e-03, 9.3895e-05, 9.9551e-01, 3.9158e-04, 4.3699e-04]],
       device='mps:0')
tensor([[1.1854e-03, 8.1737e-04, 6.8398e-04, 4.3929e-04, 9.9687e-01]],
       device='mps:0')
tensor([[3.6840e-03, 8.6702e-05, 9.9549e-01, 3.3394e-04, 4.0899e-04]],
       device='mps:0')
tensor([[9.1457e-04, 5.3907e-04, 7.8976e-04, 9.9737e-01, 3.9098e-04]],
       device='mps:0')
tensor([[9.9278e-01, 1.8437e-03, 4.0798e-03, 7.5093e-04, 5.4658e-04]],
       device='mps:0')
tensor([[9.9229e-01, 2.3733e-03, 4.2098e-03, 6.9891e-04, 4.2355e-04]],
       device='mps:0')
tensor([[3.4064e-03, 8.6298e-05, 9.9572e-01, 3.3408e-04, 4.5658e-04]],
       device='mps:0')
tensor([[1.4033e-02, 1.8508e-03, 1.5456e-03, 6.7228e-04, 9.8190e-01]],
       device='mps:0')
tensor([[1.2052e-03, 1.0441e-03, 8.7639e-04, 9.9647e-01, 4.0

tensor([[4.0935e-03, 8.6167e-05, 9.9504e-01, 3.0789e-04, 4.6821e-04]],
       device='mps:0')
tensor([[0.9833, 0.0066, 0.0067, 0.0011, 0.0023]], device='mps:0')
tensor([[3.1722e-03, 9.9367e-01, 5.1223e-04, 1.1146e-03, 1.5315e-03]],
       device='mps:0')
tensor([[1.2860e-03, 9.3319e-04, 6.3145e-04, 5.4351e-04, 9.9661e-01]],
       device='mps:0')
tensor([[3.5632e-03, 8.5283e-05, 9.9554e-01, 3.1587e-04, 5.0004e-04]],
       device='mps:0')
tensor([[9.9266e-01, 1.2353e-03, 5.0530e-03, 6.6207e-04, 3.9050e-04]],
       device='mps:0')
tensor([[9.8941e-01, 2.3942e-03, 6.5336e-03, 7.8524e-04, 8.7352e-04]],
       device='mps:0')
tensor([[9.7421e-04, 1.0296e-03, 7.7433e-04, 9.9678e-01, 4.4161e-04]],
       device='mps:0')
tensor([[9.9073e-01, 2.0696e-03, 5.5435e-03, 7.8789e-04, 8.7130e-04]],
       device='mps:0')
tensor([[7.9601e-04, 5.8496e-04, 6.5628e-04, 9.9759e-01, 3.7264e-04]],
       device='mps:0')
tensor([[9.9260e-01, 2.2710e-03, 3.9242e-03, 8.1176e-04, 3.9738e-04]],
       device='m

tensor([[1.7693e-03, 1.0212e-03, 6.4784e-04, 4.6406e-04, 9.9610e-01]],
       device='mps:0')
tensor([[2.0153e-03, 1.1656e-03, 1.0554e-03, 9.9540e-01, 3.6400e-04]],
       device='mps:0')
tensor([[9.9279e-01, 1.7084e-03, 4.2665e-03, 7.9061e-04, 4.4027e-04]],
       device='mps:0')
tensor([[1.3139e-03, 5.8723e-04, 8.5435e-04, 9.9691e-01, 3.2959e-04]],
       device='mps:0')
tensor([[1.4208e-03, 7.7483e-04, 7.8198e-04, 4.0919e-04, 9.9661e-01]],
       device='mps:0')
tensor([[9.9217e-01, 1.7052e-03, 5.0248e-03, 6.8078e-04, 4.2170e-04]],
       device='mps:0')
tensor([[0.9713, 0.0063, 0.0173, 0.0018, 0.0032]], device='mps:0')
tensor([[3.5538e-03, 9.0888e-05, 9.9545e-01, 3.3240e-04, 5.7013e-04]],
       device='mps:0')
tensor([[3.3246e-03, 8.4027e-05, 9.9591e-01, 3.4269e-04, 3.3876e-04]],
       device='mps:0')
tensor([[1.9572e-03, 8.5055e-04, 8.3317e-04, 4.6757e-04, 9.9589e-01]],
       device='mps:0')
tensor([[5.1103e-03, 1.2397e-04, 9.9396e-01, 4.8151e-04, 3.2548e-04]],
       device='m

tensor([[3.5208e-03, 8.5752e-05, 9.9559e-01, 3.2606e-04, 4.8064e-04]],
       device='mps:0')
tensor([[3.5339e-03, 9.9393e-01, 3.5101e-04, 1.0941e-03, 1.0954e-03]],
       device='mps:0')
tensor([[1.7135e-03, 1.1839e-03, 6.5189e-04, 5.4103e-04, 9.9591e-01]],
       device='mps:0')
tensor([[2.7743e-03, 9.9473e-01, 3.9375e-04, 9.2450e-04, 1.1746e-03]],
       device='mps:0')
tensor([[9.9331e-01, 1.3055e-03, 4.1774e-03, 7.8394e-04, 4.2562e-04]],
       device='mps:0')
tensor([[1.0038e-03, 7.8051e-04, 6.8059e-04, 9.9723e-01, 3.0069e-04]],
       device='mps:0')
tensor([[1.6204e-03, 3.8589e-03, 6.5702e-04, 9.9332e-01, 5.4180e-04]],
       device='mps:0')
tensor([[9.9196e-01, 9.6123e-04, 6.0842e-03, 6.5989e-04, 3.3513e-04]],
       device='mps:0')
tensor([[9.8990e-01, 2.5061e-03, 5.9008e-03, 7.3260e-04, 9.5571e-04]],
       device='mps:0')
tensor([[1.9620e-03, 1.0503e-03, 7.7763e-04, 9.9588e-01, 3.3396e-04]],
       device='mps:0')
tensor([[3.5029e-03, 9.0085e-05, 9.9555e-01, 3.2738e-04, 5.2

tensor([[9.3853e-04, 5.3250e-04, 7.3996e-04, 9.9743e-01, 3.5468e-04]],
       device='mps:0')
tensor([[9.9278e-01, 1.5399e-03, 4.4586e-03, 7.0592e-04, 5.1874e-04]],
       device='mps:0')
tensor([[8.6685e-04, 7.0181e-04, 8.5288e-04, 9.9718e-01, 3.9797e-04]],
       device='mps:0')
tensor([[0.3530, 0.0178, 0.0150, 0.0115, 0.6027]], device='mps:0')
tensor([[8.8947e-03, 2.0974e-03, 8.5759e-04, 5.8263e-04, 9.8757e-01]],
       device='mps:0')
tensor([[1.3372e-03, 8.1161e-04, 7.6902e-04, 4.3427e-04, 9.9665e-01]],
       device='mps:0')
tensor([[8.8196e-04, 6.6911e-04, 6.8344e-04, 9.9739e-01, 3.7339e-04]],
       device='mps:0')
tensor([[6.8625e-02, 3.7236e-04, 9.2985e-01, 7.3140e-04, 4.2495e-04]],
       device='mps:0')
tensor([[9.9204e-01, 1.5974e-03, 5.1650e-03, 5.7867e-04, 6.1459e-04]],
       device='mps:0')
tensor([[9.9203e-01, 1.4075e-03, 5.5403e-03, 5.1757e-04, 5.0695e-04]],
       device='mps:0')
tensor([[4.3752e-03, 9.8498e-05, 9.9487e-01, 3.7697e-04, 2.7551e-04]],
       device='m

[{}]

In [58]:
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support


# example target and output lists
targets = [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
outputs = [0, 1, 1, 3, 4, 0, 2, 2, 3, 4]

def classification_metrics(targets, outputs):
    # compute confusion matrix
    cm = confusion_matrix(targets, outputs)
    
    # compute total number of samples for each class
    total_per_class = np.sum(cm, axis=1)
    
    # compute number of correctly classified samples for each class
    correct_per_class = np.diagonal(cm)
    
    # compute precision, recall, and f1 score for each class
    p, r, f1, _ = precision_recall_fscore_support(targets, outputs, average=None)
    
    # compute accuracy for each class
    accuracy_per_class = np.divide(correct_per_class, total_per_class, where=total_per_class!=0)
    
    # create a dataframe to hold the results
    df = pd.DataFrame({
        'precision': p,
        'recall': r,
        'f1': f1,
        'accuracy': accuracy_per_class,
        'num_samples': total_per_class
    })
    
    print(df)
    print(cm)
    #return df

In [59]:
classification_metrics(targets, outputs)

   precision  recall   f1  accuracy  num_samples
0        1.0     1.0  1.0       1.0            2
1        0.5     0.5  0.5       0.5            2
2        0.5     0.5  0.5       0.5            2
3        1.0     1.0  1.0       1.0            2
4        1.0     1.0  1.0       1.0            2
[[2 0 0 0 0]
 [0 1 1 0 0]
 [0 1 1 0 0]
 [0 0 0 2 0]
 [0 0 0 0 2]]


In [26]:
from sklearn.metrics import confusion_matrix, classification_report

# Define your target and output lists
targets = [0, 1, 2, 3, 4]
outputs = [1, 1, 2, 3, 4]

# Create a confusion matrix
cm = confusion_matrix(targets, outputs)

# Print the confusion matrix
print("Confusion Matrix:\n", cm)

# Calculate classification report
report = classification_report(targets, outputs)

# Print classification report
print("Classification Report:\n", report)

Confusion Matrix:
 [[0 1 0 0 0]
 [0 1 0 0 0]
 [0 0 1 0 0]
 [0 0 0 1 0]
 [0 0 0 0 1]]
Classification Report:
               precision    recall  f1-score   support

           0       0.00      0.00      0.00         1
           1       0.50      1.00      0.67         1
           2       1.00      1.00      1.00         1
           3       1.00      1.00      1.00         1
           4       1.00      1.00      1.00         1

    accuracy                           0.80         5
   macro avg       0.70      0.80      0.73         5
weighted avg       0.70      0.80      0.73         5



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [19]:
df.groupby('target')

<pandas.core.groupby.generic.DataFrameGroupBy object at 0x104ac80a0>

## Testing and predicting

In [15]:
#change for best one - manually check which one is the best
best_checkpoint_path = '/Users/pierredemetz/UCL_work/COMP0082-CW/code/experiments/lightning_logs/20-02-2023--21-58-19/checkpoints/epoch=1-step=14366.ckpt'




In [17]:
trainer = Trainer(resume_from_checkpoint=best_checkpoint_path)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [53]:
outputs = trainer.predict(model, dm)
results = []
for item in outputs:
    tensor = item[1]
    max_prob, max_target_idx = torch.max(tensor, dim=1)
    max_target = TARGETS[max_target_idx]
    results.append((max_prob.item(), max_target))

print(results)

Predicting: 20it [00:00, ?it/s]

  input = module(input)


[(0.23777472972869873, 'secreted'), (0.2390023022890091, 'secreted'), (0.23801231384277344, 'secreted'), (0.23591041564941406, 'secreted'), (0.23733216524124146, 'secreted'), (0.23813442885875702, 'secreted'), (0.23768913745880127, 'secreted'), (0.23749719560146332, 'secreted'), (0.23838873207569122, 'secreted'), (0.23754209280014038, 'secreted'), (0.2393353134393692, 'secreted'), (0.23688319325447083, 'secreted'), (0.23795557022094727, 'secreted'), (0.23580197989940643, 'secreted'), (0.23750881850719452, 'secreted'), (0.2364635318517685, 'secreted'), (0.2366548478603363, 'secreted'), (0.2381429672241211, 'secreted'), (0.23695126175880432, 'secreted'), (0.2360747903585434, 'secreted')]


In [54]:
outputs

[(0, tensor([[0.2052, 0.1628, 0.1809, 0.2133, 0.2378]])),
 (0, tensor([[0.2051, 0.1622, 0.1803, 0.2135, 0.2390]])),
 (0, tensor([[0.2053, 0.1629, 0.1811, 0.2128, 0.2380]])),
 (0, tensor([[0.2055, 0.1642, 0.1823, 0.2121, 0.2359]])),
 (0, tensor([[0.2056, 0.1630, 0.1808, 0.2132, 0.2373]])),
 (0, tensor([[0.2053, 0.1627, 0.1808, 0.2131, 0.2381]])),
 (0, tensor([[0.2051, 0.1631, 0.1810, 0.2131, 0.2377]])),
 (0, tensor([[0.2053, 0.1632, 0.1812, 0.2128, 0.2375]])),
 (0, tensor([[0.2051, 0.1627, 0.1810, 0.2128, 0.2384]])),
 (0, tensor([[0.2053, 0.1631, 0.1812, 0.2130, 0.2375]])),
 (0, tensor([[0.2054, 0.1619, 0.1801, 0.2134, 0.2393]])),
 (0, tensor([[0.2050, 0.1639, 0.1820, 0.2122, 0.2369]])),
 (0, tensor([[0.2054, 0.1628, 0.1807, 0.2132, 0.2380]])),
 (0, tensor([[0.2056, 0.1643, 0.1821, 0.2122, 0.2358]])),
 (0, tensor([[0.2053, 0.1630, 0.1810, 0.2132, 0.2375]])),
 (0, tensor([[0.2055, 0.1638, 0.1819, 0.2123, 0.2365]])),
 (0, tensor([[0.2054, 0.1635, 0.1813, 0.2132, 0.2367]])),
 (0, tensor([[

## LEGACY

In [17]:
target_list = ['cyto', 'mito', 'nucleus','other', 'secreted']
n_classes = 5

protein_classifier = ProteinClassifier(n_classes, target_list)
protein_classifier = protein_classifier.load_from_checkpoint(
    checkpoint_path=best_checkpoint_path,
    n_classes=n_classes,
    target_list=target_list
)

protein_classifier.eval()
protein_classifier.freeze()

Some weights of the model checkpoint at Rostlab/prot_bert_bfd were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


FileNotFoundError: [Errno 2] No such file or directory: '/Users/pierredemetz/UCL_work/COMP0082-CW/code/experiments/lightning_logs/18-02-2023--00-13-33/checkpoints/epoch=1-step=4.ckpt'

In [111]:
sample = {
  "seq": "M S T D T G V S L P S Y E E D Q G S K L I R K A K E A P F V P V G I A G F A A I V A Y G L Y K L K S R G N T K M S I H L I H M R V A A Q G F V V G A M T V G M G Y S M Y R E F W A K P K P",
}

predictions = protein_classifier.predict_step(sample, batch_idx=0)

print("Sequence Localization Ground Truth is: {} - prediction is: {}".format('Mitochondrion',predictions['predicted_label']))



TypeError: forward() missing 1 required positional argument: 'attention_mask'

## MISC

In [59]:
import re
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False )
model = BertModel.from_pretrained("Rostlab/prot_bert")
sequence_Example = "A E T C Z A O"
sequence_Example = re.sub(r"[UZOB]", "X", sequence_Example)
encoded_input = tokenizer(sequence_Example, return_tensors='pt')
output = model(**encoded_input)

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)
else:
    print ("MPS device not found.")

tensor([1.], device='mps:0')


  nonzero_finite_vals = torch.masked_select(


In [60]:
accelerator_registry=torch.device("mps")

In [61]:
accelertorch.backends.mps

NameError: name 'accelertorch' is not defined

In [62]:
MPSAccelerator.register_accelerators(device='mps')

TypeError: register_accelerators() got an unexpected keyword argument 'device'

In [59]:
!pip install tensorflow-metal

Collecting tensorflow-metal
  Downloading tensorflow_metal-0.7.1-cp38-cp38-macosx_12_0_x86_64.whl (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: tensorflow-metal
Successfully installed tensorflow-metal-0.7.1


In [63]:
import transformers

In [66]:
!exit

In [67]:
!arch

i386
