# Employ transfer learning with new LMs for IDR prediction
The dataset comes from [DisProt](https://www.disprot.org/download) (specifically, an [older version with annotations](https://idpcentral.org/caid/data/1/reference/disprot-disorder.txt)). The methods come from [ProtTrans](https://github.com/agemagician/ProtTrans).

Based on the ProtTrans [PyTorch Lightning implementation](https://github.com/agemagician/ProtTrans/blob/master/Fine-Tuning/ProtBert-BFD-FineTuning-PyTorchLightning-MS.ipynb).

In [1]:
!nvidia-smi

Mon May 16 19:07:00 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:0A:00.0 Off |                  N/A |
| 33%   49C    P8    11W / 250W |      6MiB / 11178MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  On   | 00000000:0B:00.0 Off |                  N/A |
| 27%   49C    P8    12W / 250W |      6MiB / 11178MiB |      0%      Defaul

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, RandomSampler
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, pad_sequence

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import Trainer, seed_everything
from torchmetrics import Accuracy

from transformers import T5EncoderModel, T5Tokenizer
from transformers import BertModel, BertTokenizer
from transformers import XLNetModel, XLNetTokenizer
from transformers import AlbertModel, AlbertTokenizer

from torchnlp.encoders import LabelEncoder
from torchnlp.datasets.dataset import Dataset
from torchnlp.utils import collate_tensors

from test_tube import HyperOptArgumentParser
import os
import re
import gc
from datetime import datetime
import logging as log
import glob

In [3]:
torch.cuda.is_available()

True

In [None]:
torch.cuda.device_count()

In [5]:
# Select the model
model_name = "Rostlab/prot_t5_xl_uniref50"
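The selected checkpoint expects input in the format used by the ProtTrans examples: rare residues (U, Z, O, B) mapped to X and residues separated by spaces. The sketch below shows that preprocessing step together with a loader for the tokenizer and encoder. The helper names `prepare_sequence` and `load_encoder` are hypothetical, and whether `disorder/train.py` preprocesses sequences exactly this way is an assumption.

```python
import re

model_name = "Rostlab/prot_t5_xl_uniref50"


def prepare_sequence(sequence: str) -> str:
    """Format a raw amino-acid string as in the ProtTrans examples (hypothetical helper)."""
    # Map the rare or ambiguous residues U, Z, O and B to the unknown residue X.
    cleaned = re.sub(r"[UZOB]", "X", sequence.upper())
    # The vocabulary is per residue, so residues are separated by spaces.
    return " ".join(cleaned)


def load_encoder(name: str = model_name):
    """Load tokenizer and encoder (hypothetical helper); the first call downloads the weights."""
    # Imported lazily so that prepare_sequence can be used without transformers installed.
    from transformers import T5EncoderModel, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained(name, do_lower_case=False)
    encoder = T5EncoderModel.from_pretrained(name)
    return tokenizer, encoder


print(prepare_sequence("MKTUZAB"))  # M K T X X A X
```

Calling `load_encoder()` is left out of the example on purpose, since it needs network access and enough memory for the full model.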

## Main Training

Use `disorder/train.py`.

## Predict new sequence

Use `disorder/predict.py`.

### Ensemble model
Using three separately trained instances of the same model as an ensemble, I got the following validation metrics:

| Validation | BAC   | F1    | MCC   |
|------------|-------|-------|-------|
| Ensemble   | 0.753 | 0.633 | 0.524 |
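How the three instances were combined is not stated above. One common choice, assumed here, is to average the per-residue disorder probabilities and threshold the mean at 0.5. The sketch below illustrates that, along with the standard definitions of balanced accuracy (BAC), F1, and Matthews correlation coefficient (MCC). The function names and the toy numbers are illustrative only and are not taken from the project.

```python
from math import sqrt


def ensemble_average(prob_lists):
    """Average per-residue disorder probabilities across models (assumed ensembling rule)."""
    n_models = len(prob_lists)
    return [sum(residue) / n_models for residue in zip(*prob_lists)]


def binary_metrics(y_true, y_pred):
    """Return (BAC, F1, MCC) for binary labels, where 1 means disordered."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == 1 and p == 1 for t, p in pairs)
    tn = sum(t == 0 and p == 0 for t, p in pairs)
    fp = sum(t == 0 and p == 1 for t, p in pairs)
    fn = sum(t == 1 and p == 0 for t, p in pairs)
    sensitivity = tp / (tp + fn) if tp + fn else 0.0
    specificity = tn / (tn + fp) if tn + fp else 0.0
    bac = (sensitivity + specificity) / 2
    f1 = 2 * tp / (2 * tp + fp + fn) if tp + fp + fn else 0.0
    denom = sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    return bac, f1, mcc


# Toy probabilities from three hypothetical model instances (not real data).
model_probs = [
    [0.9, 0.2, 0.6, 0.1],
    [0.8, 0.4, 0.4, 0.2],
    [0.7, 0.3, 0.7, 0.3],
]
labels = [1, 0, 1, 1]

mean_probs = ensemble_average(model_probs)
predictions = [int(p >= 0.5) for p in mean_probs]
bac, f1, mcc = binary_metrics(labels, predictions)
print(predictions, round(bac, 3), round(f1, 3), round(mcc, 3))
```

Averaging probabilities rather than majority-voting hard labels keeps each model's confidence in play, which matters when the instances disagree near the threshold.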

### Test results
#### 30 June
I ran one of the optimized versions on the test set and got the following results:

|                | BAC       | F1        | MCC       |
|----------------|-----------|-----------|-----------|
| Ours           | 0.715     | **0.616** | **0.394** |
| fIDPnn         | 0.720     | 0.483     | 0.370     |
| SPOT-Disorder2 | **0.725** | 0.469     | 0.349     |