## Imports

In [6]:
from model_fct import ProteinClassifier, ProteinDataModule, ProteinSequenceDataset
import os
import torch
from torch import nn
import torch.utils.data
import torch.utils.data.distributed
from torch.utils.data import Dataset, DataLoader, RandomSampler, TensorDataset
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torchmetrics
from pytorch_lightning.accelerators import MPSAccelerator
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup
from torchmetrics.classification import MulticlassAUROC, MulticlassAccuracy, MultilabelF1Score
from torchmetrics import Recall, Precision

from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import Trainer, seed_everything
import datetime
from datetime import datetime
#from pytorch_lightning.metrics.sklearns import Accuracy

import torchvision

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import platform
# Report the processor type. On macOS this shows 'arm' on Apple Silicon — the
# machine this notebook trains on with accelerator='mps'.
platform.processor()

'arm'

In [8]:
# Load the pre-built train / test / validation / blind splits.
# NOTE(review): read_pickle can execute arbitrary code while unpickling — only
# load files produced by this project.
train_df, test_df, val_df, blind_df = map(
    pd.read_pickle,
    ('train_df.pkl', 'test_df.pkl', 'val_df.pkl', 'blind_df.pkl'),
)

## Logger and checkpoint

In [9]:
def setup_testube_logger(save_dir: str = "experiments/", name: str = "lightning_logs") -> CSVLogger:
    """Create the CSV logger for this training run.

    Despite its historical name, this builds a ``CSVLogger`` (not a TestTube
    logger). The run version is a timestamp string, so runs started at
    different times are logged to different directories.

    Args:
        save_dir: Root directory for logs (default ``"experiments/"``).
        name: Experiment sub-directory under ``save_dir`` (default ``"lightning_logs"``).

    Returns:
        A ``CSVLogger`` whose ``version`` is the current local time formatted as
        ``dd-mm-YYYY--HH-MM-SS``.
    """
    # `datetime` is the class here: `from datetime import datetime` in the import
    # cell shadows the module imported just before it.
    dt_string = datetime.now().strftime("%d-%m-%Y--%H-%M-%S")

    return CSVLogger(
        save_dir=save_dir,
        version=dt_string,
        name=name,
    )


logger = setup_testube_logger()

In [10]:
# Put checkpoints inside the logger's own run directory. CSVLogger uses a string
# `version` verbatim as the directory name (no "version_" prefix), so a hand-built
# f"version_{logger.version}" path pointed at a directory the logger never uses.
ckpt_path = os.path.join(logger.log_dir, "checkpoints")

# Keep the checkpoint with the highest validation accuracy.
# NOTE(review): requires the LightningModule to log a metric named 'val_acc' —
# confirm in model_fct.ProteinClassifier.
c = ModelCheckpoint(
    dirpath=os.path.join(ckpt_path, "tanh_3epochs"),
    verbose=True,
    monitor='val_acc',
    mode="max",
)

## Set up experiment

In [11]:
# Localization classes, in the order passed to the data module and classifier.
TARGETS = ["cyto", "mito", "nucleus", "other", "secreted"]

# Training hyper-parameters.
EPOCHS = 3
BATCH_SIZE = 1
MAX_LENGTH = 1500  # passed to ProteinDataModule as max_len

# Case-sensitive tokenizer (do_lower_case=False) for the pre-trained checkpoint.
#PRE_TRAINED_MODEL_NAME = 'Rostlab/prot_bert_bfd'
PRE_TRAINED_MODEL_NAME = "Rostlab/prot_bert_bfd_localization"
# NOTE(review): the model-construction log mentions Rostlab/prot_bert_bfd —
# confirm the encoder loaded in model_fct is compatible with this tokenizer.
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME, do_lower_case=False)

In [12]:
# Data module over the four splits, tokenized with `tokenizer`.
dm = ProteinDataModule(
    train_df,
    test_df,
    val_df,
    blind_df,
    tokenizer,
    target_list=TARGETS,
    batch_size=BATCH_SIZE,
    max_len=MAX_LENGTH,
)

# Derive the class count from TARGETS so the classifier head cannot drift out
# of sync with the label list.
# NOTE(review): steps_per_epoch uses floor division, which undercounts by one
# when the last batch is partial (exact for BATCH_SIZE=1) — confirm how
# ProteinClassifier uses it.
model = ProteinClassifier(
    n_classes=len(TARGETS),
    target_list=TARGETS,
    steps_per_epoch=len(train_df) // BATCH_SIZE,
    n_epochs=EPOCHS,
)

Some weights of the model checkpoint at Rostlab/prot_bert_bfd were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [13]:
# Attach the ModelCheckpoint `c` defined in the checkpoint cell. It used to be
# commented out under a name that doesn't exist (`checkpoint_callback`), so it
# never ran. Fall back to Lightning's automatic accelerator selection when MPS
# is unavailable, so the notebook also runs off Apple Silicon.
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    logger=logger,
    accelerator='mps' if torch.backends.mps.is_available() else 'auto',
    callbacks=[c],
    default_root_dir='experiments/lightning_logs',
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model, datamodule=dm)  # train `model` on the data module's dataloaders


  | Name       | Type             | Params
------------------------------------------------
0 | bert       | BertModel        | 419 M 
1 | classifier | Sequential       | 5.1 K 
2 | criterion  | CrossEntropyLoss | 0     
------------------------------------------------
419 M     Trainable params
0         Non-trainable params
419 M     Total params
1,679.745 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  input = module(input)


the accuracy is 0.00
the precision is 0.00
the recall is 0.00
the f1 is 0.00
   precision  recall   f1  accuracy  num_samples
0        0.0     0.0  0.0       0.0            1
1        0.0     0.0  0.0       1.0            0
2        0.0     0.0  0.0       0.0            1
[[0 1 0]
 [0 0 0]
 [0 1 0]]


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Training: 0it [00:00, ?it/s]

  input = module(input)


In [None]:
# NOTE(review): no model/ckpt_path given — confirm which weights Lightning evaluates here.
trainer.test(datamodule=dm)

In [None]:
# Show the class order (index -> label) used for the confusion-matrix axes.
TARGETS

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Confusion matrix (rows = true label, columns = predicted label) in TARGETS order.
# NOTE(review): counts are hard-coded — presumably copied from a test-set run;
# confirm they match the checkpoint being reported.
classes = TARGETS
# Named `cm_counts` rather than `confusion_matrix` so this cell does not clobber
# sklearn's `confusion_matrix` function imported in a later cell.
cm_counts = np.array([[475, 21, 94, 8, 8],
                      [46, 203, 7, 6, 1],
                      [153, 7, 473, 6, 3],
                      [18, 18, 3, 373, 1],
                      [23, 3, 1, 5, 289]], dtype=float)
assert cm_counts.shape == (len(classes), len(classes)), "matrix size must match TARGETS"

# Normalize rows (per true class) if needed
#cm_counts = cm_counts / cm_counts.sum(axis=1, keepdims=True)

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.matshow(cm_counts, cmap=plt.cm.Oranges)

# Pin exactly one tick per class before labelling. Labelling the default ticks
# (with a leading '' placeholder) only lines up if matplotlib happens to place
# a tick at -1 and one per cell.
tick_positions = np.arange(len(classes))
ax.set_xticks(tick_positions)
ax.set_yticks(tick_positions)
ax.set_xticklabels(classes, fontsize=12)
ax.set_yticklabels(classes, fontsize=12)

ax.set_title('Confusion Matrix', fontsize=16)
ax.set_xlabel('Predicted Label', fontsize=14)
ax.set_ylabel('True Label', fontsize=14)

cbar = ax.figure.colorbar(im, ax=ax)
cbar.ax.tick_params(labelsize=12)

# White text disappears on the pale low end of the Oranges colormap, so use
# black for cells below half of the maximum count.
text_threshold = cm_counts.max() / 2
for i in range(len(classes)):
    for j in range(len(classes)):
        value = cm_counts[i, j]
        ax.text(j, i, int(value), ha='center', va='center',
                color='white' if value > text_threshold else 'black', fontsize=12)

#plt.show()
plt.savefig('../Confusion_Matrix_Testing_Set.png', dpi=500, bbox_inches='tight')

In [None]:
import matplotlib.pyplot as plt


In [None]:
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support


# Toy example: 10 samples over 5 classes, with two misclassifications in `outputs`.
targets = list(range(5)) * 2
outputs = [0, 1, 1, 3, 4, 0, 2, 2, 3, 4]

def classification_metrics(targets, outputs, labels=None):
    """Print per-class precision/recall/F1/accuracy and the confusion matrix.

    Args:
        targets: Ground-truth class indices.
        outputs: Predicted class indices, same length as ``targets``.
        labels: Optional full list of class indices (e.g. ``range(len(TARGETS))``).
            When given, every listed class gets a row, even classes absent from
            this data, so rows line up with the label list. When ``None``
            (default, previous behaviour), only classes that appear in
            ``targets`` or ``outputs`` are reported.

    Per-class "accuracy" is correct / true samples of that class (numerically
    the same as recall). Classes with no true samples get 0.0.
    """
    cm = confusion_matrix(targets, outputs, labels=labels)

    # Row sums = number of true samples per class.
    total_per_class = cm.sum(axis=1)
    # Diagonal = correctly classified samples per class.
    correct_per_class = np.diagonal(cm)

    # zero_division=0 keeps the same numbers sklearn reports for ill-defined
    # scores, without emitting UndefinedMetricWarning for each one.
    p, r, f1, _ = precision_recall_fscore_support(
        targets, outputs, labels=labels, average=None, zero_division=0
    )

    # `out=` must accompany `where=`: without it, entries where the mask is False
    # are left uninitialized. That produced an arbitrary accuracy (e.g. 1.0) for
    # a class with 0 samples.
    accuracy_per_class = np.divide(
        correct_per_class,
        total_per_class,
        out=np.zeros(total_per_class.shape, dtype=float),
        where=total_per_class != 0,
    )

    df = pd.DataFrame({
        'precision': p,
        'recall': r,
        'f1': f1,
        'accuracy': accuracy_per_class,
        'num_samples': total_per_class,
    })

    print(df)
    print(cm)
    #return df

In [None]:
classification_metrics(targets=targets, outputs=outputs)  # metrics for the toy example

In [None]:
from sklearn.metrics import confusion_matrix, classification_report

# Small sanity example for sklearn's report. Uses `demo_*` names so it does not
# overwrite the `targets` / `outputs` lists used by the toy-metrics cells.
demo_targets = [0, 1, 2, 3, 4]
demo_outputs = [1, 1, 2, 3, 4]

# Create and print the confusion matrix
demo_cm = confusion_matrix(demo_targets, demo_outputs)
print("Confusion Matrix:\n", demo_cm)

# Class 0 is never predicted, so its precision is undefined. zero_division=0
# reports it as 0 without raising UndefinedMetricWarning.
report = classification_report(demo_targets, demo_outputs, zero_division=0)
print("Classification Report:\n", report)

In [None]:
# FIXME(review): `df` is not defined at notebook scope. The only `df` above is
# local to `classification_metrics` (its `return df` is commented out), so this
# raises NameError on a fresh kernel. Also, `groupby` without an aggregation only
# displays a GroupBy object — add e.g. `.size()` once the intended frame is known.
df.groupby('target')

## Testing and predicting

In [None]:
# Prefer the best checkpoint tracked by the ModelCheckpoint `c`. Its
# `best_model_path` is empty unless `c` was attached to the trainer that ran
# `fit`. Otherwise fall back to a hand-picked run.
# The fallback is relative to the notebook's working directory (the same base
# as the logger's save_dir="experiments/"), not an absolute per-user path.
best_checkpoint_path = c.best_model_path or os.path.join(
    'experiments', 'lightning_logs', '20-02-2023--21-58-19',
    'checkpoints', 'epoch=1-step=14366.ckpt',
)




In [None]:
# Fresh trainer for inference. `resume_from_checkpoint` only applied to `fit()`
# (and was removed in Lightning 2.0), so it never loaded weights for `predict()`.
# To predict with checkpoint weights, pass `ckpt_path=best_checkpoint_path`
# to `trainer.predict`.
trainer = Trainer()

In [None]:
# Load the selected checkpoint's weights before predicting. Without `ckpt_path`,
# predict() would use whatever weights `model` currently holds in memory.
outputs = trainer.predict(model, datamodule=dm, ckpt_path=best_checkpoint_path)

# Collect (top score, predicted label) for every sequence.
# NOTE(review): assumes each predict output `item` holds a (batch, n_classes)
# score tensor at index 1, with columns in TARGETS order — confirm in model_fct.
results = []
for item in outputs:
    scores = item[1]
    max_probs, max_target_idxs = torch.max(scores, dim=1)
    # Iterate the batch so this also works when BATCH_SIZE > 1. Indexing a list
    # with a multi-element tensor, or calling .item() on it, would fail.
    for prob, target_idx in zip(max_probs.tolist(), max_target_idxs.tolist()):
        results.append((prob, TARGETS[target_idx]))

print(results)

In [None]:
# Raw return value of `trainer.predict` from the previous cell.
outputs

## Legacy: manual checkpoint loading and single-sequence prediction

In [None]:
# Rebuild the classifier straight from a checkpoint. `load_from_checkpoint` is a
# classmethod that returns a new model, so call it on the class. Constructing a
# throwaway ProteinClassifier first loaded an extra copy of the BERT encoder, and
# Lightning 2.x rejects calling it on an instance.
target_list = ['cyto', 'mito', 'nucleus', 'other', 'secreted']
n_classes = len(target_list)

protein_classifier = ProteinClassifier.load_from_checkpoint(
    checkpoint_path=best_checkpoint_path,
    n_classes=n_classes,
    target_list=target_list,
)

# Inference only: eval mode and no gradient tracking.
protein_classifier.eval()
protein_classifier.freeze()

In [None]:
# One hand-written, space-separated sequence whose known localization is
# mitochondrial (the ground truth printed below).
sample = dict(
    seq="M S T D T G V S L P S Y E E D Q G S K L I R K A K E A P F V P V G I A G F A A I V A Y G L Y K L K S R G N T K M S I H L I H M R V A A Q G F V V G A M T V G M G Y S M Y R E F W A K P K P",
)

# NOTE(review): predict_step receives a raw dict with an untokenized "seq" string —
# confirm ProteinClassifier.predict_step handles that and returns 'predicted_label'.
predictions = protein_classifier.predict_step(sample, batch_idx=0)

print(
    "Sequence Localization Ground Truth is: {} - prediction is: {}".format(
        'Mitochondrion', predictions['predicted_label']
    )
)



## Misc: ProtBert smoke test and Apple Silicon (MPS) environment checks

In [None]:
import re

# Smoke test of the generic ProtBert encoder. Uses separate names so it does not
# overwrite the notebook's `tokenizer` (prot_bert_bfd_localization) or the
# trained `model` (ProteinClassifier) defined earlier.
probert_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
probert_model = BertModel.from_pretrained("Rostlab/prot_bert")

sequence_Example = "A E T C Z A O"
# Replace U, Z, O, B with X, following the ProtBert model-card usage.
sequence_Example = re.sub(r"[UZOB]", "X", sequence_Example)
encoded_input = probert_tokenizer(sequence_Example, return_tensors='pt')
# Inference only — no autograd graph needed.
with torch.no_grad():
    output = probert_model(**encoded_input)

In [None]:
# Quick check that PyTorch can allocate a tensor on the MPS (Apple GPU) backend.
if not torch.backends.mps.is_available():
    print ("MPS device not found.")
else:
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)

In [None]:
# NOTE(review): despite its name, this is only a `torch.device("mps")` handle, not an
# accelerator registry, and nothing visible in this notebook uses it.
accelerator_registry=torch.device("mps")

In [None]:
# Report whether the MPS backend is usable on this machine / PyTorch build.
torch.backends.mps.is_available()

In [None]:
# `register_accelerators` takes an accelerator registry, not a `device=` keyword,
# so the previous call raised TypeError. No manual registration is needed:
# `accelerator='mps'` already resolves, as the trainer's
# "GPU available: True (mps), used: True" log shows. Just check availability.
MPSAccelerator.is_available()

In [None]:
# NOTE(review): tensorflow-metal is a TensorFlow plugin, but nothing in this notebook
# imports TensorFlow, so this install looks unnecessary. If a package is needed,
# prefer a pinned `%pip install <pkg>==<version>` so it targets the kernel's environment.
!pip install tensorflow-metal

In [None]:
import transformers

In [None]:
# NOTE(review): `!` runs `exit` in a throwaway shell subprocess. It does not stop or
# restart the notebook kernel — use Kernel > Restart for that.
!exit

In [None]:
# Print the machine architecture as reported by the shell (e.g. arm64 on Apple Silicon).
!arch