In [1]:
import pickle
import pandas as pd
import numpy as np
import torch
import os
import glob


%cd ~/openfold

from openfold.model.jk_sidechain_model import AngleTransformer
from openfold.config import config
from openfold.utils.loss import supervised_chi_loss
import pytorch_lightning as pl


/net/pulsar/home/koes/jok120/openfold




# Load the data

In [2]:
def load_pkl(fn):
    with open(fn, "rb") as f:
        _d = pickle.load(f)
    print(_d["current_datapt_number"])
    return _d

def load_multiple_pickles(pattern):
    files = glob.glob(pattern)
    files.sort()
    all_data = (load_pkl(fn) for fn in files)
    updated_data = {}
    starting_idx = 0
    for fn, d in zip(files, all_data):
        print(fn, flush=True)
        n = d["current_datapt_number"]
        del d["current_datapt_number"]
        # Add starting index to all keys in d
        d = {k + starting_idx: v for k, v in d.items()}
        starting_idx += n
        updated_data.update(d)


    return updated_data


BASEPATH = "/net/pulsar/home/koes/jok120/openfold/out/experiments/angletransformer-make-caches-50-TrainSample/"
# %ls -hlt $BASEPATH*.pkl

In [3]:
d = load_multiple_pickles(os.path.join(BASEPATH, '*_val.pkl'))

47
/net/pulsar/home/koes/jok120/openfold/out/experiments/angletransformer-make-caches-50-TrainSample/angle_transformer_intermediates0_val.pkl
47
/net/pulsar/home/koes/jok120/openfold/out/experiments/angletransformer-make-caches-50-TrainSample/angle_transformer_intermediates1_val.pkl
47
/net/pulsar/home/koes/jok120/openfold/out/experiments/angletransformer-make-caches-50-TrainSample/angle_transformer_intermediates2_val.pkl
47
/net/pulsar/home/koes/jok120/openfold/out/experiments/angletransformer-make-caches-50-TrainSample/angle_transformer_intermediates3_val.pkl


In [11]:
class ATModuleLit(pl.LightningModule):
    def __init__(
            self,
            dataset_dict,
            c_s=384,
            c_hidden=256,
            no_blocks=2,
            no_angles=config.model.structure_module.no_angles,  # 7
            epsilon=config.globals.eps,
            dropout=0.1,
            d_ff=2048,
            no_heads=4,
            activation='relu',
            batch_size=1,
            num_workers=0):
        super().__init__()
        self.at = AngleTransformer(c_s=c_s,
                                   c_hidden=c_hidden,
                                   no_blocks=no_blocks,
                                   no_angles=no_angles,
                                   epsilon=epsilon,
                                   dropout=dropout,
                                   d_ff=d_ff,
                                   no_heads=no_heads,
                                   activation=activation)
        self.loss = supervised_chi_loss
        self.dataset_dict = dataset_dict
        self.batch_size = batch_size
        self.num_workers = num_workers

    def forward(self, s, s_initial):
        return self.at(s, s_initial)

    def training_step(self, batch, batch_idx):
        s, s_initial = batch['s'][:, -1, ...].squeeze(1), batch['s_initial'].squeeze(1)
        unnorm_ang, ang = self(s, s_initial)
        loss = self.loss(angles_sin_cos=ang,
                         unnormalized_angles_sin_cos=unnorm_ang,
                         aatype=batch['aatype'],
                         seq_mask=batch['seq_mask'],
                         chi_mask=batch['chi_mask'],
                         chi_angles_sin_cos=batch['chi_angles_sin_cos'],
                         chi_weight=config.loss.supervised_chi.chi_weight,
                         angle_norm_weight=config.loss.supervised_chi.angle_norm_weight,
                         eps=1e-6)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def train_dataloader(self):
        return torch.utils.data.DataLoader(ATDataset(self.dataset_dict),
                                           batch_size=self.batch_size,
                                           shuffle=True,
                                           num_workers=self.num_workers,
                                           collate_fn=collate_fn)


class ATDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_dict):
        self.dataset_dict = dataset_dict

    def __len__(self):
        return len(self.dataset_dict) - 1

    def __getitem__(self, idx):
        return self.dataset_dict[idx]


def collate_fn(batch):
    if len(batch) == 1:
        d = {
            k: batch[0][k][0].unsqueeze(0).float()
            for k in batch[0].keys() if k != 'name'
        }
        d['name'] = [batch[0]['name']]
        d['aatype'] = batch[0]['aatype'].long()
        return d
    else:
        # Needs work for padding to work
        max_len = max([b['s'][0].shape[-2] for b in batch])
        d = {}
        for prot in batch:
            for k, v in prot.items():
                if k not in d:
                    d[k] = []
                v = v[0]
                if k == 's_initial':
                    len_diff = max_len - v.shape[-2]
                    new_value = torch.cat([v, torch.zeros(v.shape[0], len_diff, v.shape[-1]).float()], dim=-2)
                    d[k].append(new_value)
                elif k != 'name':
                    try:
                        len_diff = max_len - v.shape[-2]
                        new_value = torch.cat([v, torch.zeros(v.shape[0], v.shape[1], len_diff, v.shape[-1]).float()], dim=-2)
                        d[k].append(new_value)
                    except Exception as e:
                        print(e)
                        print(k)
                        print(v.shape)
                        print(max_len)
                        print(len_diff)
                        raise e
                else:
                    d[k].append(prot[k][0])
        
        d = {
            k: torch.stack(d[k]).float()
            for k in d.keys() if k != 'name'
        }

        d['name'] = [b['name'] for b in batch]
        return d


In [12]:
at_lit = ATModuleLit(dataset_dict=d, batch_size=1, num_workers=1)

In [13]:
# Train the model

at_lit = ATModuleLit(dataset_dict=d, batch_size=1, num_workers=1)
trainer = pl.Trainer(max_epochs=10, gpus=1)
trainer.fit(at_lit)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name | Type             | Params
------------------------------------------
0 | at   | AngleTransformer | 2.8 M 
------------------------------------------
2.8 M     Trainable params
0         Non-trainable params
2.8 M     Total params
11.323    Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:   0%|          | 0/187 [00:00<?, ?it/s] 

  rank_zero_deprecation(


Epoch 9: 100%|██████████| 187/187 [00:04<00:00, 44.83it/s, loss=0.358, v_num=34]


In [8]:
dl = torch.utils.data.DataLoader(ATDataset(d),
                                 batch_size=1,
                                 shuffle=True,
                                 num_workers=0,
                                 collate_fn=collate_fn)


In [9]:
b = next(iter(dl))
b.keys()

dict_keys(['s', 's_initial', 'aatype', 'seq_mask', 'chi_mask', 'chi_angles_sin_cos', 'name'])

In [10]:
b['s'][:, -1, ...].cuda().shape

torch.Size([1, 1, 162, 384])

In [None]:
at_lit.cuda()
uang, ang = at_lit.at(b['s'][:, -1, ...].cuda().squeeze(1), b['s_initial'].cuda().squeeze(1))

In [None]:
ang.shape, uang.shape

In [None]:
# from openfold.np import residue_constants
# from openfold.utils.tensor_utils import (
#     tree_map,
#     tensor_tree_map,
#     masked_mean,
#     permute_final_dims,
#     batched_gather,
# )
# def supervised_chi_loss(
#     angles_sin_cos: torch.Tensor,
#     unnormalized_angles_sin_cos: torch.Tensor,
#     aatype: torch.Tensor,
#     seq_mask: torch.Tensor,
#     chi_mask: torch.Tensor,
#     chi_angles_sin_cos: torch.Tensor,
#     chi_weight: float,
#     angle_norm_weight: float,
#     eps=1e-6,
#     **kwargs,
# ) -> torch.Tensor:
#     """
#         Implements Algorithm 27 (torsionAngleLoss)

#         Args:
#             angles_sin_cos:
#                 [*, N, 7, 2] predicted angles
#             unnormalized_angles_sin_cos:
#                 The same angles, but unnormalized
#             aatype:
#                 [*, N] residue indices
#             seq_mask:
#                 [*, N] sequence mask
#             chi_mask:
#                 [*, N, 7] angle mask
#             chi_angles_sin_cos:
#                 [*, N, 7, 2] ground truth angles
#             chi_weight:
#                 Weight for the angle component of the loss
#             angle_norm_weight:
#                 Weight for the normalization component of the loss
#         Returns:
#             [*] loss tensor
#     """
#     pred_angles = angles_sin_cos[..., 3:, :]  # [8, 1, 256, 4, 2]
#     residue_type_one_hot = torch.nn.functional.one_hot(
#         aatype,
#         residue_constants.restype_num + 1,
#     )
#     chi_pi_periodic = torch.einsum(
#         "...ij,jk->ik",
#         residue_type_one_hot.type(angles_sin_cos.dtype),
#         angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
#     )

#     true_chi = chi_angles_sin_cos[None]  # [1, 1, 256, 4, 2]

#     shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
#     true_chi_shifted = shifted_mask * true_chi
#     sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
#     sq_chi_error_shifted = torch.sum(
#         (true_chi_shifted - pred_angles) ** 2, dim=-1
#     )
#     sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
    
#     # The ol' switcheroo
#     sq_chi_error = sq_chi_error.permute(
#         *range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
#     )

#     sq_chi_loss = masked_mean(
#         chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
#     )

#     loss = chi_weight * sq_chi_loss

#     angle_norm = torch.sqrt(
#         torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
#     )
#     angle_norm = angle_norm.unsqueeze(0)
#     norm_error = torch.abs(angle_norm - 1.0)
#     norm_error = norm_error.permute(
#         *range(len(norm_error.shape))[1:-2], 0, -2, -1
#     )
#     # norm_error = norm_error.unsqueeze(0)
#     angle_norm_loss = masked_mean(
#         seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
#     )

#     loss = loss + angle_norm_weight * angle_norm_loss

#     # Average over the batch dimension
#     loss = torch.mean(loss)

#     # Compute the MAE so we know exactly how good the angle prediction is in Radians
#     # print(pred_angles.shape)
#     # pred = torch.transpose(pred_angles.clone(), 0, 1)  # [1, 8, 256, 4, 2]
#     # pred = pred[:, -1, :, :, :]  # [1, 1, 256, 4, 2]
#     # pred = pred.reshape(pred.shape[0], pred.shape[-3], pred.shape[-2], pred.shape[-1])  # [1, 256, 4, 2]
#     # true_chi2 = chi_angles_sin_cos.clone()  # [1, 256, 4, 2]
#     # true_chi2 = inverse_trig_transform(true_chi2, 4)  # [1, 256, 4]
#     # pred = inverse_trig_transform(pred, 4)  # [1, 256, 4]
#     # true_chi2 = true_chi2.masked_fill_(~chi_mask.bool(), torch.nan)
#     # pred = pred.masked_fill_(~chi_mask.bool(), torch.nan)
#     # mae = angle_mae(true_chi2, pred)

#     loss_dict = {"loss": loss, "sq_chi_loss": sq_chi_loss, "angle_norm_loss": angle_norm_loss}

#     return loss_dict

In [None]:
b['aatype'][0].cuda()

In [None]:
supervised_chi_loss(angles_sin_cos=ang,
                         unnormalized_angles_sin_cos=uang,
                         aatype=b['aatype'][0].cuda(),
                         seq_mask=b['seq_mask'][0].cuda(),
                         chi_mask=b['chi_mask'][0].cuda(),
                         chi_angles_sin_cos=b['chi_angles_sin_cos'][0].cuda(),
                         chi_weight=config.loss.supervised_chi.chi_weight,
                         angle_norm_weight=config.loss.supervised_chi.angle_norm_weight,
                         eps=1e-6)