In [1]:
# Install third-party dependencies used below (lmdb for the TAPE-style data files,
# transformers/accelerate for ESM-2, evaluate, pandas).
# NOTE(review): nothing is pinned; the captured output shows lmdb 1.4.1,
# pytorch-crf 0.7.2 and evaluate 0.4.0 — pin them (`%pip install pkg==x.y.z`) so a
# Restart & Run All reproduces the same environment.
# NOTE(review): pytorch-crf is installed but no CRF layer is used in the visible model code.
!pip install lmdb
!pip install transformers
!pip install pytorch-crf
!pip install evaluate
# Remove the preinstalled transformers/accelerate (4.30.1 / 0.12.0 per the output)
# and reinstall the latest versions; this also makes the `!pip install transformers`
# above redundant.
!pip uninstall -y transformers accelerate
!pip install transformers accelerate
!pip install pandas

Collecting lmdb
  Downloading lmdb-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (299 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m299.2/299.2 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lmdb
Successfully installed lmdb-1.4.1
[0mCollecting pytorch-crf
  Downloading pytorch_crf-0.7.2-py3-none-any.whl (9.5 kB)
Installing collected packages: pytorch-crf
Successfully installed pytorch-crf-0.7.2
[0mCollecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: evaluate
Successfully installed evaluate-0.4.0
[0mFound existing installation: transformers 4.30.1
Uninstalling transformers-4.30.1:
  Successfully uninstalled transformers-4.30.1
Found existing installation: accelerate 0.12.0
Uninstalling accelerate-0.12.0:
  Successfully uninstalled accelerat

In [2]:
import lmdb
from typing import Union, List, Tuple, Sequence, Dict, Any, Optional, Collection
from torch.utils.data import Dataset
from pathlib import Path
import pickle as pkl

class LMDBDataset(Dataset):
    """Map-style dataset that reads pickled records from an LMDB database.

    The database stores the record count under ``b'num_examples'`` and each
    record under its decimal index (``b'0'``, ``b'1'``, ...). A record that has
    no ``'id'`` entry gets its index, as a string, filled in when it is read.

    Args:
        data_file (Union[str, Path]): Path to the lmdb file.
        in_memory (bool, optional): Keep every decoded record after its first
            read so later lookups skip LMDB and unpickling. Default: False.
    """

    def __init__(self,
                 data_file: Union[str, Path],
                 in_memory: bool = False):
        db_path = Path(data_file)
        if not db_path.exists():
            raise FileNotFoundError(db_path)

        # Read-only, lock-free environment with a single reader slot.
        environment = lmdb.open(str(db_path), max_readers=1, readonly=True,
                                lock=False, readahead=False, meminit=False)

        # NOTE(review): pickle.loads can execute arbitrary code stored in the
        # database — only open trusted LMDB files.
        with environment.begin(write=False) as txn:
            total = pkl.loads(txn.get(b'num_examples'))

        if in_memory:
            # One slot per record, filled lazily by __getitem__.
            self._cache = [None] * total

        self._env = environment
        self._in_memory = in_memory
        self._num_examples = total

    def __len__(self) -> int:
        return self._num_examples

    def __getitem__(self, index: int):
        if not (0 <= index < self._num_examples):
            raise IndexError(index)

        if self._in_memory:
            cached = self._cache[index]
            if cached is not None:
                return cached

        with self._env.begin(write=False) as txn:
            record = pkl.loads(txn.get(str(index).encode()))
        if 'id' not in record:
            record['id'] = str(index)
        if self._in_memory:
            self._cache[index] = record
        return record

In [3]:
def dataset_factory(data_file: Union[str, Path], *args, **kwargs) -> Dataset:
    """Open ``data_file`` with the dataset class matching its extension.

    Only ``.lmdb`` files are supported; extra positional and keyword arguments
    are forwarded to ``LMDBDataset``.

    Raises:
        FileNotFoundError: if ``data_file`` does not exist.
        ValueError: if the extension is not ``.lmdb``.
    """
    path = Path(data_file)
    if not path.exists():
        raise FileNotFoundError(path)
    suffix = path.suffix
    if suffix != '.lmdb':
        raise ValueError(f"Unrecognized datafile type {suffix}")
    return LMDBDataset(path, *args, **kwargs)

In [4]:
import torch
import numpy as np
class SecondaryStructureDataset(Dataset):
    """One split of the secondary-structure LMDB data, as model-ready samples.

    Each sample is a dict with:
      * ``input_ids``: the raw ``item['primary']`` value (tokenized later);
      * ``attention_mask``: ``np.ones_like(item['primary'])``;
      * ``labels``: ``[psi, phi, ss3, ss8]``, each converted to ``np.int64``.

    Args:
        data_path: Directory containing
            ``secondary_structure/secondary_structure_<split>.lmdb``.
        split: One of ``'train', 'valid', 'casp12', 'ts115', 'cb513', 'casp14_32'``.
        in_memory: Forwarded to the underlying LMDB dataset.

    Raises:
        ValueError: if ``split`` is not a known split name.
    """

    _SPLITS = ('train', 'valid', 'casp12', 'ts115', 'cb513', 'casp14_32')
    # Per-residue label tracks, in the order the collator and model index them.
    _LABEL_KEYS = ('psi', 'phi', 'ss3', 'ss8')

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 in_memory: bool = False):

        if split not in self._SPLITS:
            # List every accepted split (the old message omitted 'casp14_32').
            raise ValueError(f"Unrecognized split: {split}. Must be one of "
                             f"{list(self._SPLITS)}")

        data_path = Path(data_path)
        data_file = f'secondary_structure/secondary_structure_{split}.lmdb'
        self.data = dataset_factory(data_path / data_file, in_memory)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int):
        item = self.data[index]
        token_ids = item['primary']
        # NOTE(review): 'primary' is fed straight to the tokenizer, so it is
        # presumably a string; np.ones_like on a string gives a 0-d array, not a
        # per-residue mask. This field is not used by the later cells shown here.
        input_mask = np.ones_like(token_ids)

        # No padding is applied here. NOTE(review): the tokenizer presumably adds
        # special tokens (e.g. <cls>) in front of the residues, while the collator
        # right-pads labels only — confirm that labels and tokens are aligned.
        # NOTE(review): casting to int64 truncates fractional psi/phi angles.
        labels = [np.asarray(item[key], np.int64) for key in self._LABEL_KEYS]
        return {'input_ids': token_ids,
                'attention_mask': input_mask,
                'labels': labels}

In [5]:
class SecondaryStructureDataset_cb513(SecondaryStructureDataset):
    """CB513 wrapper that reports at most ``max_examples`` records.

    This class previously duplicated ``SecondaryStructureDataset`` line by line.
    It now inherits the constructor, split validation and ``__getitem__``, and
    overrides only ``__len__``. The old ``__len__`` always returned 496, even
    when fewer records were stored (iteration then hit IndexError). It now
    returns ``min(max_examples, number of stored records)``.

    NOTE(review): the reason for the 496 cap is not recorded here — confirm which
    CB513 entries are meant to be excluded. Indexing past ``len(self)`` is not
    blocked; the underlying dataset decides.

    Args:
        data_path: Directory containing the ``secondary_structure`` LMDB files.
        split: Split name (validated by the parent class).
        in_memory: Forwarded to the underlying LMDB dataset.
        max_examples: Upper bound on the reported length; ``None`` disables the cap.
    """

    def __init__(self,
                 data_path: Union[str, Path],
                 split: str,
                 in_memory: bool = False,
                 max_examples: Optional[int] = 496):
        super().__init__(data_path, split, in_memory)
        self.max_examples = max_examples

    def __len__(self) -> int:
        available = len(self.data)
        if self.max_examples is None:
            return available
        return min(self.max_examples, available)

In [6]:
# Open every secondary-structure split from the Kaggle-mounted dataset directory.
data_dir = '/kaggle/input/dataset'
train_dataset = SecondaryStructureDataset(data_dir, split='train')
test_ts115 = SecondaryStructureDataset(data_dir, split='ts115')
test_casp12 = SecondaryStructureDataset(data_dir, split='casp12')
# CB513 goes through its dedicated wrapper class, which defines its own __len__.
test_cb513 = SecondaryStructureDataset_cb513(data_dir, split='cb513')
test_casp14_32 = SecondaryStructureDataset(data_dir, split='casp14_32')

In [7]:
def _collect_sequences_and_labels(dataset):
    """Return ``(sequences, labels)`` lists with one entry per sample of ``dataset``.

    Each sample is fetched exactly once. The previous loops indexed the dataset
    twice per sample, and each lookup re-reads LMDB and re-decodes the labels.
    """
    sequences, labels = [], []
    for index in range(len(dataset)):
        sample = dataset[index]
        sequences.append(sample['input_ids'])
        labels.append(sample['labels'])
    return sequences, labels


train_sequences, train_labels = _collect_sequences_and_labels(train_dataset)
test_ts115_sequences, test_ts115_labels = _collect_sequences_and_labels(test_ts115)
test_casp12_sequences, test_casp12_labels = _collect_sequences_and_labels(test_casp12)
test_cb513_sequences, test_cb513_labels = _collect_sequences_and_labels(test_cb513)
test_casp14_32_sequences, test_casp14_32_labels = _collect_sequences_and_labels(test_casp14_32)

In [8]:
# Local ESM-2 checkpoint directory mounted as a Kaggle input (t33 / 650M per its name).
check_point = '/kaggle/input/pretrain-t33/esm2_t33_650M_UR50D'

In [9]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(check_point)

# Tokenize every split's raw sequences, one tokenizer call per split. No padding
# is requested here; the data collator pads each batch later.
(
    train_tokenized,
    test_ts115_tokenized,
    test_casp12_tokenized,
    test_cb513_tokenized,
    test_casp14_32_tokenized,
) = [
    tokenizer(sequences)
    for sequences in (
        train_sequences,
        test_ts115_sequences,
        test_casp12_sequences,
        test_cb513_sequences,
        test_casp14_32_sequences,
    )
]

In [10]:
from datasets import Dataset

def _to_hf_dataset(tokenized, labels):
    """Wrap tokenizer output plus its per-sample label tracks in a Hugging Face ``Dataset``."""
    return Dataset.from_dict(tokenized).add_column("labels", labels)


# NOTE: `train_dataset` is rebound here; it now holds the HF dataset, not the LMDB wrapper.
train_dataset = _to_hf_dataset(train_tokenized, train_labels)
ts115_dataset = _to_hf_dataset(test_ts115_tokenized, test_ts115_labels)
casp12_dataset = _to_hf_dataset(test_casp12_tokenized, test_casp12_labels)
cb513_dataset = _to_hf_dataset(test_cb513_tokenized, test_cb513_labels)
casp14_32_dataset = _to_hf_dataset(test_casp14_32_tokenized, test_casp14_32_labels)
# Keep the original (misspelled) name in case later cells refer to it.
asp14_32_dataset = casp14_32_dataset

In [11]:
from transformers import DataCollatorForTokenClassification
class DataCollatorForTokenClassificationMultiLabel(DataCollatorForTokenClassification):
    """Token-classification collator for features that carry several per-token label tracks.

    Each feature's ``labels`` (or ``label``) entry is a sequence of tracks; in this
    notebook it is ``[psi, phi, ss3, ss8]``. Inputs are padded with
    ``tokenizer.pad``. Each label track is then padded to the padded sequence
    length, on the tokenizer's padding side, as follows:

    * the first ``num_angle_tracks`` tracks with ``angle_pad_value`` (360), which
      the model masks out of its angle losses;
    * the remaining tracks with ``class_pad_value`` (-1), the ``ignore_index`` of
      the model's cross-entropy losses.

    The returned ``batch[label_name]`` has shape ``(batch, num_tracks, seq_len)``
    for every batch size, including 1. ``label_pad_token_id`` is not used by this
    collator.
    """

    # Pad value for angle tracks (the model ignores positions whose label == 360).
    angle_pad_value = 360
    # Pad value for class tracks (matches CrossEntropyLoss(ignore_index=-1)).
    class_pad_value = -1
    # Number of leading tracks that hold angles rather than class ids.
    num_angle_tracks = 2

    def torch_call(self, features):
        import torch

        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
        batch = self.tokenizer.pad(
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        pad_on_left = self.tokenizer.padding_side != "right"

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        def pad_track(track, pad_value):
            # Pad one label track to `sequence_length` on the tokenizer's padding side.
            values = to_list(track)
            padding = [pad_value] * (sequence_length - len(values))
            return padding + values if pad_on_left else values + padding

        # Build (batch, num_tracks, seq_len) directly. No `.squeeze()` is needed,
        # so a single-example batch keeps its batch dimension.
        batch[label_name] = torch.tensor([
            [
                pad_track(track, self.angle_pad_value if k < self.num_angle_tracks else self.class_pad_value)
                for k, track in enumerate(label)
            ]
            for label in labels
        ])
        return batch


data_collator = DataCollatorForTokenClassificationMultiLabel(tokenizer, label_pad_token_id=360)

caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io_plugins.so: undefined symbol: _ZN3tsl6StatusC1EN10tensorflow5error4CodeESt17basic_string_viewIcSt11char_traitsIcEENS_14SourceLocationE']
caused by: ['/opt/conda/lib/python3.10/site-packages/tensorflow_io/python/ops/libtensorflow_io.so: undefined symbol: _ZTVN10tensorflow13GcsFileSystemE']


In [12]:
def dihedral_to_radians(angle):
    """Convert dihedral angle(s) from degrees to radians.

    Args:
        angle: Angle value(s) in degrees. Anything supporting ``*`` and ``/``
            with floats works (Python numbers, numpy arrays, torch tensors).

    Returns:
        ``angle * pi / 180``, evaluated in that order.
    """
    scaled = angle * np.pi
    return scaled / 180
    
def arctan_dihedral(sin, cos):
    """Recover dihedral angles in degrees from their sine and cosine.

    Uses the two-argument arctangent, so the quadrant comes from the signs of
    both inputs and the result lies in [-180, 180]. This replaces the earlier
    ``arctan(sin / cos)`` formulation, which returned -450 for (sin < 0, cos == 0),
    -270 for (sin < 0, cos == -0.0) and NaN for (0, 0).

    Args:
        sin: Sine values (tensor or array-like).
        cos: Cosine values, broadcastable against ``sin``.

    Returns:
        torch.Tensor of angles in degrees.
    """
    sin = torch.as_tensor(sin)
    cos = torch.as_tensor(cos)
    return torch.atan2(sin, cos) * 180 / np.pi

In [13]:
from transformers import EsmForTokenClassification, TrainingArguments, Trainer, EsmModel
from transformers import TrainingArguments, Trainer, EsmModel
from transformers import AutoModel
from transformers import modeling_outputs
from transformers.modeling_outputs import TokenClassifierOutput
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, PackedSequence
import torch.nn as nn
 
##

class EsmForTokenClassificationCRF(EsmForTokenClassification):
    """Frozen ESM encoder with a CNN + BiLSTM head for multi-task per-residue prediction.

    NOTE(review): despite the name, no CRF layer is used anywhere in this model.

    Forward pipeline:
      1. ESM encoder (parameters frozen) -> ``(batch, seq_len, hidden_size)``.
      2. Two dropout + Conv1d (kernel 129 / 257, 32 channels) + ReLU branches,
         concatenated with the encoder features -> ``hidden_size + 64`` channels.
      3. BatchNorm1d using batch statistics only (``track_running_stats=False``).
      4. 2-layer bidirectional LSTM over the unpadded lengths, then dropout.
      5. Four linear heads: psi (2, tanh), phi (2, tanh), ss3 (3), ss8 (8).

    The angle heads are regressed onto (sin, cos) of the label angle; ss3/ss8 use
    cross-entropy.
    """

    # Angle-label value marking padded/missing positions (the collator pads angle tracks with it).
    _ANGLE_PAD = 360
    # Secondary-structure label value ignored by the cross-entropy losses.
    _CLASS_PAD = -1

    def __init__(self, config, n_hidden=1024):
        super(EsmForTokenClassificationCRF, self).__init__(config)
        # BiLSTM input = ESM features + 2 x 32 convolution channels.
        self.lstm = nn.LSTM(input_size=config.hidden_size+64, hidden_size=n_hidden, batch_first=True, num_layers=2, bidirectional=True, dropout=0.5)
        self.lstm_dropout = nn.Dropout(p=0.5)
        self.batch_norm = nn.BatchNorm1d(config.hidden_size+64, track_running_stats=False)
        # Length-preserving convolutions: padding = (kernel_size - 1) / 2.
        self.conv1 = nn.Sequential(*[
            nn.Dropout(p=0.5),
            nn.Conv1d(in_channels=config.hidden_size, out_channels=32, kernel_size=129, padding=64),
            nn.ReLU(),
        ])
        self.conv2 = nn.Sequential(*[
            nn.Dropout(p=0.5),
            nn.Conv1d(in_channels=config.hidden_size, out_channels=32, kernel_size=257, padding=128),
            nn.ReLU(),
        ])
        # NOTE(review): the parent constructor presumably already builds `self.esm`
        # (and an unused `self.classifier`); this rebuilds the encoder before
        # from_pretrained loads its weights.
        self.esm = EsmModel(config, add_pooling_layer=False)
        # Heads stay wrapped in nn.Sequential so parameter names (e.g. `classifier_ss8.0.weight`)
        # match previously saved checkpoints.
        self.classifier_psi = nn.Sequential(*[
            nn.Linear(in_features=n_hidden*2, out_features=2),
            nn.Tanh()
        ])
        self.classifier_phi = nn.Sequential(*[
            nn.Linear(in_features=n_hidden*2, out_features=2),
            nn.Tanh()
        ])
        self.classifier_ss8 = nn.Sequential(*[
            nn.Linear(in_features=n_hidden*2, out_features=8),
        ])
        self.classifier_ss3 = nn.Sequential(*[
            nn.Linear(in_features=n_hidden*2, out_features=3),
        ])
        # Freeze the pretrained encoder; only the head is trained.
        for param in self.esm.parameters():
            param.requires_grad = False

    def mse(self, outputs, labels, mask):
        """Masked mean squared error: ``sum((outputs - labels)^2 * mask) / sum(mask)``."""
        loss = torch.square(outputs - labels) * mask
        return torch.sum(loss) / torch.sum(mask)

    def _angle_loss(self, outputs, labels, mask):
        """Masked MSE between ``outputs`` ``(B, L, 2)`` and (sin, cos) of the degree ``labels`` ``(B, L)``.

        Positions where ``mask == 0`` or ``labels == 360`` do not contribute.
        """
        labels = labels.unsqueeze(2)
        outputs = outputs.squeeze(2)
        mask = mask * (labels != self._ANGLE_PAD).squeeze(2).int()
        # Same mask for the sin and cos channels.
        mask = torch.cat(2*[mask.unsqueeze(2)], dim=2)
        radians = dihedral_to_radians(labels)
        target = torch.cat((torch.sin(radians), torch.cos(radians)), dim=2).squeeze(2)
        return self.mse(outputs, target, mask)

    def psi(self, outputs, labels, mask):
        """Masked (sin, cos) regression loss for the psi head."""
        return self._angle_loss(outputs, labels, mask)

    def phi(self, outputs, labels, mask):
        """Masked (sin, cos) regression loss for the phi head."""
        return self._angle_loss(outputs, labels, mask)

    def _class_loss(self, outputs, labels, num_classes):
        """Cross-entropy over flattened positions, ignoring label -1."""
        loss_fct = CrossEntropyLoss(ignore_index=self._CLASS_PAD)
        labels = labels.to(outputs.device)
        return loss_fct(outputs.contiguous().view(-1, num_classes), labels.contiguous().view(-1))

    def ss3(self, outputs, labels):
        """3-state secondary-structure cross-entropy loss."""
        return self._class_loss(outputs, labels, 3)

    def ss8(self, outputs, labels):
        """8-state secondary-structure cross-entropy loss."""
        return self._class_loss(outputs, labels, 8)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, 4, sequence_length)`, *optional*):
            Stacked per-position targets `[psi, phi, ss3, ss8]`. Angle tracks are
            in degrees with 360 marking ignored positions; class tracks use -1 for
            ignored positions.

        Returns:
            `TokenClassifierOutput` whose `logits` have shape `(batch, 15, seq_len)`:
            channels 0-1 psi, 2-3 phi, 4-6 ss3, 7-14 ss8. `loss` is the weighted sum
            of the four task losses, or `None` when `labels` is not given. With
            `return_dict=False` a tuple `([loss,] logits, *extra_encoder_outputs)`
            is returned instead.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.esm(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        _, length, _ = sequence_output.size()
        if attention_mask is None:
            # Without a mask, treat every position as a real token.
            attention_mask = torch.ones(sequence_output.shape[:2], dtype=torch.long, device=sequence_output.device)
        # pack_padded_sequence needs CPU int64 lengths.
        lengths = torch.sum(attention_mask, dim=1).cpu().long()

        # Conv1d/BatchNorm1d expect channels first: (B, C, L).
        x = sequence_output.permute(0, 2, 1)
        r1 = self.conv1(x)
        r2 = self.conv2(x)
        x = torch.cat([x, r1, r2], dim=1)
        x = self.batch_norm(x)
        x = x.permute(0, 2, 1)
        x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        x, _ = self.lstm(x)
        x, _ = pad_packed_sequence(x, total_length=length, batch_first=True)

        x = self.lstm_dropout(x)
        logits_psi = self.classifier_psi(x)
        logits_phi = self.classifier_phi(x)
        logits_ss3 = self.classifier_ss3(x)
        logits_ss8 = self.classifier_ss8(x)
        logits = torch.cat((logits_psi, logits_phi, logits_ss3, logits_ss8), dim=2).permute(0, 2, 1)

        loss = None
        if labels is not None:
            # Task weights: angles and ss3 x5, ss8 x1.
            loss_psi = self.psi(logits_psi, labels[:, 0, :], attention_mask) * 5
            loss_phi = self.phi(logits_phi, labels[:, 1, :], attention_mask) * 5
            loss_ss3 = self.ss3(logits_ss3, labels[:, 2, :]) * 5
            loss_ss8 = self.ss8(logits_ss8, labels[:, 3, :]) * 1
            loss = torch.stack([loss_phi, loss_psi, loss_ss3, loss_ss8]).sum()

        if not return_dict:
            output = (logits,) + tuple(outputs[2:])
            return ((loss,) + output) if loss is not None else output
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


model = EsmForTokenClassificationCRF.from_pretrained(check_point)

Some weights of the model checkpoint at /kaggle/input/pretrain-t33/esm2_t33_650M_UR50D were not used when initializing EsmForTokenClassificationCRF: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing EsmForTokenClassificationCRF from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForTokenClassificationCRF from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForTokenClassificationCRF were not initialized from the model checkpoint at /kaggle/input/pretrain-t33/esm2_t33_650M_UR50D and are newly initialized: ['batch_norm.weight', 'classifier_phi.0.weight', 'lstm.weight_hh_l0_reve

In [14]:
def mae(pred, labels):
    """Mean absolute angular error in degrees, with 360-degree wrap-around.

    The per-element error is ``min(|labels - pred|, 360 - |labels - pred|)``,
    so a prediction of 350 against a label of 0 counts as 10.

    Args:
        pred (Tensor): predicted angles in degrees.
        labels (Tensor): reference angles in degrees.

    Returns:
        float: mean of the wrapped errors.
    """
    abs_err = (labels - pred).abs()
    wrapped = torch.fmin(abs_err, 360 - abs_err)
    return wrapped.mean().item()

In [15]:
import time
import math
from typing import Dict,NamedTuple, Optional, Tuple, Union
from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset
from collections.abc import Mapping

In [16]:
def speed_metrics(split, start_time, num_samples=None, num_steps=None):
    """Build throughput metrics for an operation that began at ``start_time``.

    Call this right after the measured operation finishes; the runtime is
    ``time.time() - start_time``.

    Args:
        split: Prefix for the metric names (e.g. "train", "eval", "test").
        start_time: ``time.time()`` snapshot taken before the operation.
        num_samples: Optional number of samples processed.
        num_steps: Optional number of steps processed.

    Returns:
        dict with ``<split>_runtime`` (rounded to 4 decimals) and, for each
        count that is given, ``<split>_samples_per_second`` /
        ``<split>_steps_per_second`` (rounded to 3 decimals).
    """
    elapsed = time.time() - start_time
    metrics = {f"{split}_runtime": round(elapsed, 4)}
    if num_samples is not None:
        metrics[f"{split}_samples_per_second"] = round(num_samples / elapsed, 3)
    if num_steps is not None:
        metrics[f"{split}_steps_per_second"] = round(num_steps / elapsed, 3)
    return metrics

In [17]:
# Result of an evaluation loop: predictions, optional label ids, optional metric
# dict and optional sample count (fields in that order).
EvalLoopOutput = NamedTuple(
    "EvalLoopOutput",
    [
        ("predictions", Union[np.ndarray, Tuple[np.ndarray]]),
        ("label_ids", Optional[Union[np.ndarray, Tuple[np.ndarray]]]),
        ("metrics", Optional[Dict[str, float]]),
        ("num_samples", Optional[int]),
    ],
)

In [18]:
def has_length(dataset):
    """Return True when ``len(dataset)`` succeeds, and False when it raises TypeError.

    A TypeError covers unsized objects such as generators or iterable datasets
    without ``__len__``.
    """
    try:
        size = len(dataset)
    except TypeError:
        return False
    return size is not None

In [19]:
def find_batch_size(tensors):
    """Return the first-dimension size of the first tensor/array found in a nested structure.

    Lists/tuples are searched in order and mappings by value order. 0-d tensors
    and arrays, and any other leaf type, yield nothing. Returns None when no
    suitable tensor is found.
    """
    if isinstance(tensors, (torch.Tensor, np.ndarray)):
        return tensors.shape[0] if tensors.ndim >= 1 else None
    if isinstance(tensors, Mapping):
        children = tensors.values()
    elif isinstance(tensors, (list, tuple)):
        children = tensors
    else:
        return None
    for child in children:
        size = find_batch_size(child)
        if size is not None:
            return size
    return None

In [20]:
def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
    """Promote 0-d inputs to 1-D; inputs that are already at least 1-D are returned as-is.

    Torch tensors stay torch tensors (using ``torch.atleast_1d`` when available).
    Anything else goes through ``np.atleast_1d``.
    """
    if not isinstance(tensor_or_array, torch.Tensor):
        return np.atleast_1d(tensor_or_array)
    if hasattr(torch, "atleast_1d"):
        return torch.atleast_1d(tensor_or_array)
    return tensor_or_array[None] if tensor_or_array.ndim < 1 else tensor_or_array


In [21]:
def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
    """Concatenate ``tensor1`` and ``tensor2`` along dim 0, padding the other dims when they differ.

    Each non-batch dimension of the result is the larger of the two inputs'
    sizes, and cells not covered by either input are filled with the padding
    value. 0-d inputs (e.g. scalar losses) are promoted to 1-D first. Inputs of
    any rank >= 1 are supported; previously 2-D inputs crashed, and so did
    inputs whose dim 1 differed.

    NOTE: ``padding_index`` is accepted for signature compatibility but is
    overridden with 360. That is the value this notebook's collator uses to pad
    angle labels and the value the model masks out.

    Args:
        tensor1: First tensor (batch along dim 0).
        tensor2: Second tensor, same number of dimensions as ``tensor1``.
        padding_index: Ignored (see note).

    Returns:
        torch.Tensor of shape ``(n1 + n2, max(d1, e1), max(d2, e2), ...)``.
    """
    padding_index = 360
    tensor1 = torch.atleast_1d(tensor1)
    tensor2 = torch.atleast_1d(tensor2)

    if tensor1.ndim == 1 or tensor1.shape[1:] == tensor2.shape[1:]:
        return torch.cat((tensor1, tensor2), dim=0)

    trailing = tuple(max(a, b) for a, b in zip(tensor1.shape[1:], tensor2.shape[1:]))
    result = tensor1.new_full((tensor1.shape[0] + tensor2.shape[0],) + trailing, padding_index)
    result[tuple(slice(0, size) for size in tensor1.shape)] = tensor1
    result[(slice(tensor1.shape[0], None),) + tuple(slice(0, size) for size in tensor2.shape[1:])] = tensor2
    return result


def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
    """Concatenate ``array1`` and ``array2`` along axis 0, padding the other axes when they differ.

    Each non-batch axis of the result is the larger of the two inputs' sizes,
    and cells not covered by either input are filled with ``padding_index``.
    Previously only axis 1 was padded, so 3-D arrays with equal axis 1 but a
    different axis 2 (e.g. per-channel logits with different sequence lengths)
    failed in ``np.concatenate``. 0-d inputs are promoted to 1-D first.

    NOTE(review): unlike ``torch_pad_and_concatenate``, which always pads with
    360, this function honours ``padding_index`` (default -100).

    Args:
        array1: First array (batch along axis 0).
        array2: Second array, same number of dimensions as ``array1``.
        padding_index: Fill value for padded cells.

    Returns:
        np.ndarray of shape ``(n1 + n2, max(d1, e1), max(d2, e2), ...)`` with ``array1``'s dtype.
    """
    array1 = np.atleast_1d(array1)
    array2 = np.atleast_1d(array2)

    if array1.ndim == 1 or array1.shape[1:] == array2.shape[1:]:
        return np.concatenate((array1, array2), axis=0)

    trailing = tuple(max(a, b) for a, b in zip(array1.shape[1:], array2.shape[1:]))
    result = np.full_like(array1, padding_index, shape=(array1.shape[0] + array2.shape[0],) + trailing)
    result[tuple(slice(0, size) for size in array1.shape)] = array1
    result[(slice(array1.shape[0], None),) + tuple(slice(0, size) for size in array2.shape[1:])] = array2
    return result

In [22]:
def nested_concat(tensors, new_tensors, padding_index=-100):
    """Concatenate ``new_tensors`` onto ``tensors`` along the first dim, recursing through containers.

    Lists/tuples are combined element-wise and mappings key-wise (keys taken
    from ``tensors``). Torch tensors and numpy arrays are joined by the matching
    ``*_pad_and_concatenate`` helper, which pads the other dims when needed.

    Raises:
        AssertionError: if the two arguments have different types.
        TypeError: for leaf values that are neither torch tensors nor numpy arrays.
    """
    assert type(tensors) == type(
        new_tensors
    ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."

    def recurse(old, new):
        return nested_concat(old, new, padding_index=padding_index)

    if isinstance(tensors, (list, tuple)):
        return type(tensors)(recurse(old, new) for old, new in zip(tensors, new_tensors))
    if isinstance(tensors, Mapping):
        return type(tensors)({key: recurse(old, new_tensors[key]) for key, old in tensors.items()})
    if isinstance(tensors, torch.Tensor):
        return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
    if isinstance(tensors, np.ndarray):
        return numpy_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
    raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")

In [23]:
def nested_numpify(tensors):
    "Convert torch tensors to numpy arrays, recursing through lists, tuples and mappings."
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(map(nested_numpify, tensors))
    if isinstance(tensors, Mapping):
        return type(tensors)({key: nested_numpify(value) for key, value in tensors.items()})

    host_tensor = tensors.cpu()
    # NumPy has no bfloat16 dtype, so upcast to float32 before converting.
    if host_tensor.dtype == torch.bfloat16:
        host_tensor = host_tensor.to(torch.float32)
    return host_tensor.numpy()

In [24]:
class IterableDatasetShard(IterableDataset):
    """Shard an iterable dataset across processes in whole-batch chunks.

    The wrapped iterable is consumed in "global batches" of
    ``batch_size * num_processes`` elements. From each global batch, process
    ``process_index`` receives the ``process_index``-th run of ``batch_size``
    elements, so every process sees a multiple of ``batch_size`` samples.

    A trailing incomplete global batch is handled as follows:
    * with ``drop_last=True`` it is dropped;
    * otherwise it is topped up by repeating the first complete global batch;
    * if there was no complete global batch, the partial batch repeats itself.

    Example: 12 elements, ``batch_size=2``, two processes -> process 0 yields
    ``[0, 1, 4, 5, 8, 9]`` and process 1 yields ``[2, 3, 6, 7, 10, 11]``.

    Randomness: some datasets have no ``set_epoch`` method but expose a
    ``torch.Generator`` as ``generator``. For those, the generator is re-seeded
    with ``seed + epoch`` at the start of every iteration, so all processes draw
    the same random numbers. Call :meth:`set_epoch` between epochs.

    Args:
        dataset (`torch.utils.data.IterableDataset`): The iterable to shard.
        batch_size (`int`, defaults to 1): Per-process batch size.
        drop_last (`bool`, defaults to `False`): Drop the trailing incomplete global batch.
        num_processes (`int`, defaults to 1): Number of concurrent processes.
        process_index (`int`, defaults to 0): Index of this process.
        seed (`int`, defaults to 0): Base seed used when re-seeding the dataset's generator.
    """

    def __init__(
        self,
        dataset: IterableDataset,
        batch_size: int = 1,
        drop_last: bool = False,
        num_processes: int = 1,
        process_index: int = 0,
        seed: int = 0,
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.num_processes = num_processes
        self.process_index = process_index
        self.seed = seed
        self.epoch = 0
        # Number of elements pulled from `dataset` during the current/last iteration.
        self.num_examples = 0

    def set_epoch(self, epoch):
        """Record ``epoch`` and forward it to the wrapped dataset when it supports ``set_epoch``."""
        self.epoch = epoch
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def __iter__(self):
        self.num_examples = 0
        dataset = self.dataset
        if not hasattr(dataset, "set_epoch") and isinstance(getattr(dataset, "generator", None), torch.Generator):
            dataset.generator.manual_seed(self.seed + self.epoch)

        global_batch_size = self.batch_size * self.num_processes
        my_positions = range(self.process_index * self.batch_size, (self.process_index + 1) * self.batch_size)

        first_full_batch = None
        pending = []
        for element in dataset:
            self.num_examples += 1
            pending.append(element)
            if len(pending) == global_batch_size:
                for position in my_positions:
                    yield pending[position]
                if first_full_batch is None:
                    first_full_batch = pending.copy()
                pending = []

        if self.drop_last or not pending:
            return
        filler = pending.copy() if first_full_batch is None else first_full_batch
        while len(pending) < global_batch_size:
            pending += filler
        for position in my_positions:
            yield pending[position]

    def __len__(self):
        # len(self.dataset) raises TypeError when the wrapped dataset is unsized.
        global_batch_size = self.batch_size * self.num_processes
        if self.drop_last:
            global_batches = len(self.dataset) // global_batch_size
        else:
            global_batches = math.ceil(len(self.dataset) / global_batch_size)
        return global_batches * self.batch_size

In [25]:
def nested_truncate(tensors, limit):
    "Slice every leaf to ``[:limit]``, recursing through lists, tuples and mappings."
    if isinstance(tensors, Mapping):
        return type(tensors)({key: nested_truncate(value, limit) for key, value in tensors.items()})
    if isinstance(tensors, (list, tuple)):
        return type(tensors)(nested_truncate(item, limit) for item in tensors)
    return tensors[:limit]

In [26]:
class EvalPrediction:
    """
    Evaluation output handed to ``compute_metrics`` (always contains labels).

    Acts like a 2-tuple ``(predictions, label_ids)``. When ``inputs`` is given, it acts like a
    3-tuple ``(predictions, label_ids, inputs)``. Unpacking goes through ``__iter__`` and
    indexing through ``__getitem__``, which accepts only 0, 1 and (when inputs exist) 2.

    Parameters:
        predictions (`np.ndarray`): Predictions of the model.
        label_ids (`np.ndarray`): Targets to be matched.
        inputs (`np.ndarray`, *optional*): Model inputs, when they were requested for metrics.
    """

    def __init__(
        self,
        predictions: Union[np.ndarray, Tuple[np.ndarray]],
        label_ids: Union[np.ndarray, Tuple[np.ndarray]],
        inputs: Optional[Union[np.ndarray, Tuple[np.ndarray]]] = None,
    ):
        self.predictions = predictions
        self.label_ids = label_ids
        self.inputs = inputs

    def __iter__(self):
        fields = (self.predictions, self.label_ids)
        if self.inputs is not None:
            fields += (self.inputs,)
        return iter(fields)

    def __getitem__(self, idx):
        out_of_range = idx < 0 or idx > 2 or (idx == 2 and self.inputs is None)
        if out_of_range:
            raise IndexError("tuple index out of range")
        return {0: self.predictions, 1: self.label_ids, 2: self.inputs}.get(idx)

In [27]:
import importlib.util

# Probe once at import time; later checks just read the cached flag.
_torch_available = importlib.util.find_spec("torch") is not None


def is_torch_available():
    """Return True when the ``torch`` package can be found in this environment."""
    return bool(_torch_available)

In [28]:
def denumpify_detensorize(metrics):
    """
    Recursively turn numpy scalars and single-element torch tensors into Python scalars.

    Lists/tuples and dicts are rebuilt with their original container type; anything
    else is returned untouched. Used to make metric dicts JSON-serializable.
    """
    container = type(metrics)
    if isinstance(metrics, (list, tuple)):
        return container(denumpify_detensorize(item) for item in metrics)
    if isinstance(metrics, dict):
        converted = {key: denumpify_detensorize(value) for key, value in metrics.items()}
        return container(converted)
    if isinstance(metrics, np.generic):
        return metrics.item()
    if is_torch_available() and isinstance(metrics, torch.Tensor) and metrics.numel() == 1:
        return metrics.item()
    return metrics

In [29]:
from typing import Dict, List, Optional

from torch.utils.data import Dataset


class CustomTrainer(Trainer):
    """
    `Trainer` subclass that takes the loss from the model and uses a patched evaluation loop.

    - `compute_loss` returns the loss the model's own forward pass produces; presumably the
      model computes it from `inputs["labels"]` (TODO confirm against the model definition).
    - `evaluate` / `evaluation_loop` mirror the upstream Hugging Face implementation
      (apparently the transformers ~4.30 release installed at the top of this notebook)
      without the DeepSpeed / TPU branches. They use this notebook's `nested_truncate`,
      `EvalPrediction` and `denumpify_detensorize` helpers.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Return the loss produced by the model's forward pass.

        Args:
            model: the (possibly wrapped) model.
            inputs (dict): batch from the data collator; forwarded as keyword arguments.
            return_outputs (bool): when True, also return the raw model outputs.
            num_items_in_batch: accepted and ignored. Newer `Trainer.training_step` passes it as
                a keyword argument, and without it the call raises TypeError.

        Returns:
            `loss`, or `(loss, outputs)` when `return_outputs` is True.
        """
        outputs = model(**inputs)
        # Dict-like outputs (e.g. ModelOutput) expose the loss by key; tuple outputs put it first.
        loss = outputs.get("loss") if hasattr(outputs, "get") else outputs[0]
        return (loss, outputs) if return_outputs else loss

    def evaluate(
        self,
        eval_dataset: Optional[Dataset] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
        (pass it to the init `compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (`Dataset`, *optional*):
                Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
                not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
                method.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
                An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
                "eval_bleu" if the prefix is "eval" (default)

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
            dictionary also contains the epoch number which comes from the training state.
        """
        # memory metrics - must set up as early as possible
        self._memory_tracker.start()

        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        start_time = time.time()

        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        output = eval_loop(
            eval_dataloader,
            description="Evaluation",
            # No point gathering the predictions if there are no metrics, otherwise we defer to
            # self.args.prediction_loss_only
            prediction_loss_only=True if self.compute_metrics is None else None,
            ignore_keys=ignore_keys,
            metric_key_prefix=metric_key_prefix,
        )

        total_batch_size = self.args.eval_batch_size * self.args.world_size
        # Shift the start time forward so the reported runtime excludes JIT compilation.
        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        self.log(output.metrics)

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)

        self._memory_tracker.stop_and_update_metrics(output.metrics)

        return output.metrics

    @staticmethod
    def _concat_losses_to_cpu(all_losses, losses_host):
        """Numpify `losses_host` and append it along axis 0 to `all_losses`; no-op if it is None."""
        if losses_host is None:
            return all_losses
        losses = nested_numpify(losses_host)
        return losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)

    @staticmethod
    def _concat_host_to_cpu(accumulated, host):
        """Numpify `host` and append it to `accumulated`, padding with -100; no-op if it is None."""
        if host is None:
            return accumulated
        host_np = nested_numpify(host)
        return host_np if accumulated is None else nested_concat(accumulated, host_np, padding_index=-100)

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Works both with or without labels. The loop collects losses, predictions, labels and,
        optionally, `input_ids` on the device. It moves them to numpy every
        `args.eval_accumulation_steps` steps, or once at the end. Tensors of different lengths
        are concatenated with -100 padding, and the results are truncated to the number of
        samples. The upstream DeepSpeed and TPU code paths are intentionally not included.
        """
        args = self.args

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        # If full fp16/bf16 eval is requested and we are not inside `train()`, cast the model to
        # that dtype and move it to the device first.
        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        model.eval()

        self.callback_handler.eval_dataloader = dataloader
        # Underlying dataset, used below to work out the number of samples.
        eval_dataset = getattr(dataloader, "dataset", None)

        if args.past_index >= 0:
            self._past = None

        # Device-side containers, flushed to CPU every `eval_accumulation_steps` steps.
        losses_host = None
        preds_host = None
        labels_host = None
        inputs_host = None

        # CPU/numpy containers holding the final results.
        all_losses = None
        all_preds = None
        all_labels = None
        all_inputs = None

        # Examples actually seen; the fallback sample count when the dataset length is unknown.
        observed_num_examples = 0
        for step, inputs in enumerate(dataloader):
            observed_batch_size = find_batch_size(inputs)
            if observed_batch_size is not None:
                observed_num_examples += observed_batch_size
                # For batch samplers, batch_size is not known by the dataloader in advance.
                if batch_size is None:
                    batch_size = observed_batch_size

            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)

            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None

            # Update containers on host
            if loss is not None:
                # One copy of the batch loss per sample; the surplus is truncated to num_samples below.
                losses = self._nested_gather(loss.repeat(batch_size))
                losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
            if labels is not None:
                labels = self._pad_across_processes(labels)
                labels = self._nested_gather(labels)
                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
            if inputs_decode is not None:
                inputs_decode = self._pad_across_processes(inputs_decode)
                inputs_decode = self._nested_gather(inputs_decode)
                inputs_host = (
                    inputs_decode
                    if inputs_host is None
                    else nested_concat(inputs_host, inputs_decode, padding_index=-100)
                )
            if logits is not None:
                logits = self._pad_across_processes(logits)
                logits = self._nested_gather(logits)
                if self.preprocess_logits_for_metrics is not None:
                    logits = self.preprocess_logits_for_metrics(logits, labels)
                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
            self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
                all_losses = self._concat_losses_to_cpu(all_losses, losses_host)
                all_preds = self._concat_host_to_cpu(all_preds, preds_host)
                all_inputs = self._concat_host_to_cpu(all_inputs, inputs_host)
                all_labels = self._concat_host_to_cpu(all_labels, labels_host)

                # Set back to None to begin a new accumulation
                losses_host, preds_host, inputs_host, labels_host = None, None, None, None

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        # Gather all remaining tensors and put them back on the CPU
        all_losses = self._concat_losses_to_cpu(all_losses, losses_host)
        all_preds = self._concat_host_to_cpu(all_preds, preds_host)
        all_inputs = self._concat_host_to_cpu(all_inputs, inputs_host)
        all_labels = self._concat_host_to_cpu(all_labels, labels_host)

        # Number of samples
        if has_length(eval_dataset):
            num_samples = len(eval_dataset)
        # Iterable shards count the examples they yielded in `num_examples`; only trust a positive count.
        elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
            num_samples = eval_dataset.num_examples
        elif has_length(dataloader):
            num_samples = self.num_examples(dataloader)
        else:  # both len(dataloader.dataset) and len(dataloader) fail
            num_samples = observed_num_examples
        if num_samples == 0 and observed_num_examples > 0:
            num_samples = observed_num_examples

        # Losses were repeated per batch and the last batch may have been padded/completed,
        # so truncate every container to the true number of samples.
        if all_losses is not None:
            all_losses = all_losses[:num_samples]
        if all_preds is not None:
            all_preds = nested_truncate(all_preds, num_samples)
        if all_labels is not None:
            all_labels = nested_truncate(all_labels, num_samples)
        if all_inputs is not None:
            all_inputs = nested_truncate(all_inputs, num_samples)

        # Metrics!
        if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
            if args.include_inputs_for_metrics:
                metrics = self.compute_metrics(
                    EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
                )
            else:
                metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if all_losses is not None:
            metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
        if hasattr(self, "jit_compilation_time"):
            metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)

In [30]:
def accuracy(pred, labels):
    """Fraction of positions where `pred` equals `labels`.

    Args:
        pred (1D Tensor): predicted integer class ids.
        labels (1D Tensor): reference class ids, same length as `pred`.

    Returns:
        float: accuracy in [0, 1], or ``nan`` when `labels` is empty (e.g. every
        position was masked out), instead of raising ZeroDivisionError.
    """
    if labels.numel() == 0:
        return float("nan")
    # A tensor reduction instead of Python's builtin sum(), which would iterate element by element.
    return (pred == labels).float().mean().item()
def evaluate_psi(pred, labels):
    """Psi-angle error over positions whose label is not the 360 sentinel.

    `pred[:, :, 0]` and `pred[:, :, 1]` are combined by `arctan_dihedral` into an angle;
    presumably sin/cos components (TODO confirm). Returns `mae(predicted, target)` on the
    valid positions.
    """
    valid = labels != 360
    target_angles = labels[valid]
    predicted_angles = arctan_dihedral(pred[:, :, 0], pred[:, :, 1])[valid]
    return mae(predicted_angles, target_angles)
def evaluate_phi(pred, labels):
    """Phi-angle error over positions whose label is not the 360 sentinel.

    Same computation as for psi: the two channels of `pred` go through `arctan_dihedral`
    (presumably sin/cos components, TODO confirm), then `mae` is taken on valid positions.
    """
    valid = labels != 360
    target_angles = labels[valid]
    predicted_angles = arctan_dihedral(pred[:, :, 0], pred[:, :, 1])[valid]
    return mae(predicted_angles, target_angles)

def evaluate_ss3(pred, labels):
    """SS3 accuracy over positions whose label is neither 360 nor -1.

    `pred` holds per-class scores on its last axis; the predicted class is the arg-max along dim 2.
    """
    valid = (labels != 360) & (labels != -1)
    target_classes = labels[valid]
    predicted_classes = torch.argmax(pred, dim=2)[valid]
    return accuracy(predicted_classes, target_classes)

def evaluate_ss8(pred, labels):
    """SS8 accuracy over positions whose label is neither 360 nor -1.

    The predicted class is the arg-max of `pred` along dim 2.
    """
    valid = (labels != 360) & (labels != -1)
    target_classes = labels[valid]
    predicted_classes = torch.argmax(pred, dim=2)[valid]
    return accuracy(predicted_classes, target_classes)


def compute_metrics(eval_pred):
    """Per-task validation metrics for the multi-task head.

    Args:
        eval_pred: an `EvalPrediction` (or plain tuple) whose first two fields are rank-3
            numpy arrays laid out as (batch, channels, seq). Labels carry 4 channels
            (psi, phi, SS3, SS8). Predictions carry at least 15 channels: 0:2 psi, 2:4 phi,
            4:7 SS3 scores, 7:15 SS8 scores.

    Returns:
        dict with "psi"/"phi" (from `evaluate_psi`/`evaluate_phi`) and "ss3"/"ss8"
        (accuracies from `evaluate_ss3`/`evaluate_ss8`).
    """
    # Index rather than tuple-unpack. With args.include_inputs_for_metrics=True, EvalPrediction
    # yields a third `inputs` field and unpacking into two names raises ValueError.
    predictions = torch.tensor(eval_pred[0]).permute(0, 2, 1)
    labels = torch.tensor(eval_pred[1]).permute(0, 2, 1)

    # After the permute the channel axis is last.
    return {
        "psi": evaluate_psi(predictions[:, :, 0:2], labels[:, :, 0]),
        "phi": evaluate_phi(predictions[:, :, 2:4], labels[:, :, 1]),
        "ss3": evaluate_ss3(predictions[:, :, 4:7], labels[:, :, 2]),
        "ss8": evaluate_ss8(predictions[:, :, 7:15], labels[:, :, 3]),
    }

In [31]:
# Backbone checkpoint name; in this cell it only feeds the run/output directory name.
model_name = "esm2_t30_150M_UR50D"
batch_size = 16

args = TrainingArguments(
    # NOTE(review): "psi-psi" looks like a typo for "psi-phi"; left unchanged so the
    # output directory name stays the same.
    f"{model_name}-cnn-lstm-psi-psi-ss3-ss8-230628",
    # Evaluate once per epoch; never write checkpoints.
    evaluation_strategy="epoch",
    save_strategy="no",
    load_best_model_at_end=False,
    # Optimisation
    num_train_epochs=2,
    learning_rate=1e-4,
    weight_decay=0.001,
    # Per-device batch sizes
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Full-precision training and evaluation
    fp16=False,
    fp16_full_eval=False,
    push_to_hub=False,
)

In [32]:
# NOTE(review): `asp14_32_dataset` is saved later as "casp14_32_..."; the name presumably
# lost its leading "c". It is defined in an earlier cell.
trainer = CustomTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=asp14_32_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune for args.num_train_epochs epochs, evaluating on the eval set after every epoch
# (evaluation_strategy="epoch"); no checkpoints are written because save_strategy="no".
trainer.train()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Epoch,Training Loss,Validation Loss,Psi,Phi,Ss3,Ss8
1,4.4466,4.806728,43.688942,18.260643,0.86017,0.745023


In [None]:
def mae1(pred, labels):
    """Per-element absolute angular error in degrees (no reduction).

    Despite the name, this does not average. It returns a tensor of the inputs' (broadcast)
    shape, so callers can mask or aggregate it later. The error wraps around the
    360-degree circle, e.g. 170 vs -170 gives 20.

    Args:
        pred (Tensor): predicted angles in degrees.
        labels (Tensor): reference angles in degrees.

    Returns:
        Tensor: element-wise min(|labels - pred|, |360 - |labels - pred||).
    """
    err = torch.abs(labels - pred)
    # fmin ignores NaN in one operand, returning the other value.
    return torch.fmin(err, torch.abs(360 - err))

In [None]:
def evaluate_psi_sep(pred, labels):
    """Unmasked psi evaluation: return (decoded angles, per-position angular error from `mae1`)."""
    angles = arctan_dihedral(pred[:, :, 0], pred[:, :, 1])
    errors = mae1(angles, labels)
    return angles, errors
def evaluate_phi_sep(pred, labels):
    """Unmasked phi evaluation: return (decoded angles, per-position angular error from `mae1`)."""
    angles = arctan_dihedral(pred[:, :, 0], pred[:, :, 1])
    errors = mae1(angles, labels)
    return angles, errors

def evaluate_ss3_sep(pred, labels):
    """Unmasked SS3 evaluation: return (`labels` unchanged, arg-max class along dim 2 of `pred`)."""
    predicted_classes = pred.argmax(dim=2)
    return labels, predicted_classes

def evaluate_ss8_sep(pred, labels):
    """Unmasked SS8 evaluation: return (`labels` unchanged, arg-max class along dim 2 of `pred`)."""
    predicted_classes = pred.argmax(dim=2)
    return labels, predicted_classes

In [None]:
def process_label(label, reference):
    """Select the elements of `label` at positions where `reference` is neither 360 nor -1.

    360 and -1 are the sentinel values the metric functions also mask out. Using the
    reference labels keeps labels and predictions aligned after filtering.
    """
    keep = torch.ones_like(label, dtype=torch.bool) & (reference != 360) & (reference != -1)
    return label[keep]


In [None]:
# Evaluate on the eval_dataset the trainer was constructed with; the cell displays the metrics dict.
trainer.evaluate()

In [None]:
# Raw predictions/labels on TS115; the third field (metrics) is discarded into `_`.
predictions_original, labels_original, _ = trainer.predict(ts115_dataset)

In [None]:
# NOTE(review): this cell is repeated for every test set; a helper function + loop would remove the duplication.
# Move channels last -> (batch, seq, channels) and convert to tensors.
predictions = torch.tensor(predictions_original.transpose(0, 2, 1))
labels = torch.tensor(labels_original.transpose(0, 2, 1))

# Per-task outputs. Channels: 0:2 psi, 2:4 phi, 4:7 SS3 scores, 7:15 SS8 scores.
psi_result, psi_loss = evaluate_psi_sep(predictions[:, :, 0:2], labels[:, :, 0])
phi_result, phi_loss = evaluate_phi_sep(predictions[:, :, 2:4], labels[:, :, 1])
ss3_origi_label, ss3_outputs = evaluate_ss3_sep(predictions[:, :, 4:7], labels[:, :, 2])
ss8_origi_label, ss8_outputs = evaluate_ss8_sep(predictions[:, :, 7:15], labels[:, :, 3])

# Flatten to 1-D.
psi_loss_view, phi_loss_view = psi_loss.reshape(-1), phi_loss.reshape(-1)
ss3_view, ss3_res_view = ss3_origi_label.reshape(-1), ss3_outputs.reshape(-1)
ss8_view, ss8_res_view = ss8_origi_label.reshape(-1), ss8_outputs.reshape(-1)

# Drop positions whose SS label is a 360/-1 sentinel; labels and predictions stay aligned.
ss3_label_view = process_label(ss3_view, ss3_view)
ss8_label_view = process_label(ss8_view, ss8_view)
ss3_res_post = process_label(ss3_res_view, ss3_view)
ss8_res_post = process_label(ss8_res_view, ss8_view)

In [None]:
# Sanity check: shape of the decoded TS115 psi angles.
psi_result.shape

In [None]:
# Sanity check: shape of the TS115 arg-max SS3 class predictions.
ss3_outputs.shape

In [None]:
# torch.save(psi_result,"./psi_result.pt")

In [None]:
# torch.save(phi_result,"./phi_result.pt")

In [None]:
# torch.save(ss3_origi_label,"./ss3_origi_label.pt")

In [None]:
# torch.save(ss8_origi_label,"./ss8_origi_label.pt")

In [None]:
# Masked TS115 SS3/SS8 labels and predictions, bundled for saving.
aa = dict(
    ss3_label=ss3_label_view,
    ss3_result=ss3_res_post,
    ss8_label=ss8_label_view,
    ss8_result=ss8_res_post,
)

In [None]:
# Persist the masked TS115 SS3/SS8 labels and predictions.
torch.save(aa,"./ts115_ss3_ss8_label_result.pt")

In [None]:
# Raw predictions/labels on the CASP14 set (variable presumably misnamed `asp14_32_dataset`); metrics discarded into `_`.
predictions_original, labels_original, _ = trainer.predict(asp14_32_dataset)

In [None]:
# NOTE(review): this cell is repeated for every test set; a helper function + loop would remove the duplication.
# Move channels last -> (batch, seq, channels) and convert to tensors.
predictions = torch.tensor(predictions_original.transpose(0, 2, 1))
labels = torch.tensor(labels_original.transpose(0, 2, 1))

# Per-task outputs. Channels: 0:2 psi, 2:4 phi, 4:7 SS3 scores, 7:15 SS8 scores.
psi_result, psi_loss = evaluate_psi_sep(predictions[:, :, 0:2], labels[:, :, 0])
phi_result, phi_loss = evaluate_phi_sep(predictions[:, :, 2:4], labels[:, :, 1])
ss3_origi_label, ss3_outputs = evaluate_ss3_sep(predictions[:, :, 4:7], labels[:, :, 2])
ss8_origi_label, ss8_outputs = evaluate_ss8_sep(predictions[:, :, 7:15], labels[:, :, 3])

# Flatten to 1-D.
psi_loss_view, phi_loss_view = psi_loss.reshape(-1), phi_loss.reshape(-1)
ss3_view, ss3_res_view = ss3_origi_label.reshape(-1), ss3_outputs.reshape(-1)
ss8_view, ss8_res_view = ss8_origi_label.reshape(-1), ss8_outputs.reshape(-1)

# Drop positions whose SS label is a 360/-1 sentinel; labels and predictions stay aligned.
ss3_label_view = process_label(ss3_view, ss3_view)
ss8_label_view = process_label(ss8_view, ss8_view)
ss3_res_post = process_label(ss3_res_view, ss3_view)
ss8_res_post = process_label(ss8_res_view, ss8_view)

In [None]:
# Masked CASP14 SS3/SS8 labels and predictions, bundled for saving.
bb = dict(
    ss3_label=ss3_label_view,
    ss3_result=ss3_res_post,
    ss8_label=ss8_label_view,
    ss8_result=ss8_res_post,
)

In [None]:
# Persist the masked CASP14 SS3/SS8 labels and predictions.
torch.save(bb,"./casp14_32_ss3_ss8_label_result.pt")

In [None]:
# Raw predictions/labels on CASP12; the metrics field is discarded into `_`.
predictions_original, labels_original, _ = trainer.predict(casp12_dataset)

In [None]:
# NOTE(review): this cell is repeated for every test set; a helper function + loop would remove the duplication.
# Move channels last -> (batch, seq, channels) and convert to tensors.
predictions = torch.tensor(predictions_original.transpose(0, 2, 1))
labels = torch.tensor(labels_original.transpose(0, 2, 1))

# Per-task outputs. Channels: 0:2 psi, 2:4 phi, 4:7 SS3 scores, 7:15 SS8 scores.
psi_result, psi_loss = evaluate_psi_sep(predictions[:, :, 0:2], labels[:, :, 0])
phi_result, phi_loss = evaluate_phi_sep(predictions[:, :, 2:4], labels[:, :, 1])
ss3_origi_label, ss3_outputs = evaluate_ss3_sep(predictions[:, :, 4:7], labels[:, :, 2])
ss8_origi_label, ss8_outputs = evaluate_ss8_sep(predictions[:, :, 7:15], labels[:, :, 3])

# Flatten to 1-D.
psi_loss_view, phi_loss_view = psi_loss.reshape(-1), phi_loss.reshape(-1)
ss3_view, ss3_res_view = ss3_origi_label.reshape(-1), ss3_outputs.reshape(-1)
ss8_view, ss8_res_view = ss8_origi_label.reshape(-1), ss8_outputs.reshape(-1)

# Drop positions whose SS label is a 360/-1 sentinel; labels and predictions stay aligned.
ss3_label_view = process_label(ss3_view, ss3_view)
ss8_label_view = process_label(ss8_view, ss8_view)
ss3_res_post = process_label(ss3_res_view, ss3_view)
ss8_res_post = process_label(ss8_res_view, ss8_view)

In [None]:
# Masked CASP12 SS3/SS8 labels and predictions, bundled for saving.
cc = dict(
    ss3_label=ss3_label_view,
    ss3_result=ss3_res_post,
    ss8_label=ss8_label_view,
    ss8_result=ss8_res_post,
)

In [None]:
# Persist the masked CASP12 SS3/SS8 labels and predictions.
torch.save(cc,"./casp12_ss3_ss8_label_result.pt")

In [None]:
# Raw predictions/labels on CB513; the metrics field is discarded into `_`.
predictions_original, labels_original, _ = trainer.predict(cb513_dataset)

In [None]:
# NOTE(review): this cell is repeated for every test set; a helper function + loop would remove the duplication.
# Move channels last -> (batch, seq, channels) and convert to tensors.
predictions = torch.tensor(predictions_original.transpose(0, 2, 1))
labels = torch.tensor(labels_original.transpose(0, 2, 1))

# Per-task outputs. Channels: 0:2 psi, 2:4 phi, 4:7 SS3 scores, 7:15 SS8 scores.
psi_result, psi_loss = evaluate_psi_sep(predictions[:, :, 0:2], labels[:, :, 0])
phi_result, phi_loss = evaluate_phi_sep(predictions[:, :, 2:4], labels[:, :, 1])
ss3_origi_label, ss3_outputs = evaluate_ss3_sep(predictions[:, :, 4:7], labels[:, :, 2])
ss8_origi_label, ss8_outputs = evaluate_ss8_sep(predictions[:, :, 7:15], labels[:, :, 3])

# Flatten to 1-D.
psi_loss_view, phi_loss_view = psi_loss.reshape(-1), phi_loss.reshape(-1)
ss3_view, ss3_res_view = ss3_origi_label.reshape(-1), ss3_outputs.reshape(-1)
ss8_view, ss8_res_view = ss8_origi_label.reshape(-1), ss8_outputs.reshape(-1)

# Drop positions whose SS label is a 360/-1 sentinel; labels and predictions stay aligned.
ss3_label_view = process_label(ss3_view, ss3_view)
ss8_label_view = process_label(ss8_view, ss8_view)
ss3_res_post = process_label(ss3_res_view, ss3_view)
ss8_res_post = process_label(ss8_res_view, ss8_view)

In [None]:
# Masked CB513 SS3/SS8 labels and predictions, bundled for saving.
dd = dict(
    ss3_label=ss3_label_view,
    ss3_result=ss3_res_post,
    ss8_label=ss8_label_view,
    ss8_result=ss8_res_post,
)

In [None]:
# Persist the masked CB513 SS3/SS8 labels and predictions.
torch.save(dd,"./cb513_ss3_ss8_label_result.pt")

In [None]:
# psi_ultra = process_label(psi_loss_view,ss3_view)
# phi_ultra = process_label(phi_loss_view,ss3_view)
# ss3_ultra = process_label(ss3_view,ss3_view)
# ss8_ultra = process_label(ss8_view,ss3_view)

In [None]:
# d = {"psi_loss":psi_ultra,
#      "phi_loss":phi_ultra,
#      "ss3_origi_label":ss3_ultra,
#      "ss8_origi_label":ss8_ultra,
# }

In [None]:
# from pandas.core.frame import DataFrame

In [None]:
# data=DataFrame(d)

In [None]:
# df = data.to_csv('./ts115-violinplot.csv')

In [None]:
# trainer = CustomTrainer(
#     model,
#     args,
#     eval_dataset=cb513_dataset,
#     tokenizer=tokenizer,
#     compute_metrics=compute_metrics,
#     data_collator=data_collator,
# )

In [None]:
# trainer.evaluate()

In [None]:
# trainer = CustomTrainer(
#     model,
#     args,
#     eval_dataset=casp12_dataset,
#     tokenizer=tokenizer,
#     compute_metrics=compute_metrics,
#     data_collator=data_collator,
# )

In [None]:
# trainer.evaluate()

In [None]:
# trainer = CustomTrainer(
#     model,
#     args,
#     eval_dataset=asp14_32_dataset,
#     tokenizer=tokenizer,
#     compute_metrics=compute_metrics,
#     data_collator=data_collator,
# )

In [None]:
# trainer.evaluate()