In [1]:
from unimol.models import UniMolModel
from unimol.tasks.unimol_pocket import UniMolPocketTask
from unicore.data import Dictionary
from unicore import checkpoint_utils
from unicore import tasks
from unicore.logging import progress_bar

import os
class Namespace:
    """Minimal attribute bag: each keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        # Write straight into the instance dictionary so `obj.key` returns the
        # value that was passed as `key=...`.
        instance_attrs = vars(self)
        for key, value in kwargs.items():
            instance_attrs[key] = value

# Hand-written options object: every keyword below becomes an attribute of `args`.
# Later cells read, among others, `args.data`, `args.dict_name`, `args.valid_subset`,
# `args.batch_size`, `args.required_batch_size_multiple`, `args.seed`,
# `args.num_workers` and `args.data_buffer_size`; the whole object is also handed to
# `UniMolPocketTask` and to the model builder.
# NOTE(review): `data` is an absolute, machine-specific path and `save_dir` /
# `tensorboard_logdir` are relative to the working directory — adjust before
# running this notebook elsewhere.
args=Namespace(no_progress_bar=False, log_interval=10, log_format='simple', tensorboard_logdir='./save_pocket//tsb',
         wandb_project='', wandb_name='', seed=1, cpu=False, fp16=True, bf16=False, bf16_sr=False, allreduce_fp32_grad=False, 
         fp16_no_flatten_grads=False, fp16_init_scale=4, fp16_scale_window=256, fp16_scale_tolerance=0.0, min_loss_scale=0.0001, 
         threshold_loss_scale=None, user_dir='./unimol', empty_cache_freq=0, all_gather_list_size=16384, suppress_crashes=False, 
         profile=False, ema_decay=-1.0, validate_with_ema=False, loss='unimol', optimizer='adam', lr_scheduler='polynomial_decay', 
         task='unimol_pocket', num_workers=4, skip_invalid_size_inputs_valid_test=False, batch_size=32, required_batch_size_multiple=1, 
         data_buffer_size=10, train_subset='train', valid_subset='val', validate_interval=1, validate_interval_updates=10000, 
         validate_after_updates=0, fixed_validation_seed=None, disable_validation=False, batch_size_valid=32, max_valid_steps=None, 
         curriculum=0, distributed_world_size=8, distributed_rank=1, distributed_backend='nccl', distributed_init_method='env://', 
         distributed_port=-1, device_id=1, distributed_no_spawn=True, ddp_backend='c10d', bucket_cap_mb=25, fix_batches_to_gpus=False, 
         find_unused_parameters=False, fast_stat_sync=False, broadcast_buffers=False, nprocs_per_node=4, arch='unimol_base', max_epoch=0, 
         max_update=1000000, stop_time_hours=0, clip_norm=1.0, per_sample_clip_norm=0, update_freq=[1], lr=[0.0001], stop_min_lr=-1, 
         save_dir='./save_pocket/', tmp_save_dir='./', restore_file='checkpoint_last.pt', finetune_from_model=None, load_from_ema=False, 
         reset_dataloader=False, reset_lr_scheduler=False, reset_meters=False, reset_optimizer=False, optimizer_overrides='{}', save_interval=1, 
         save_interval_updates=10000, keep_interval_updates=10, keep_last_epochs=-1, keep_best_checkpoints=-1, no_save=False, 
         no_epoch_checkpoints=False, no_last_checkpoints=False, no_save_optimizer_state=False, best_checkpoint_metric='loss', 
         maximize_best_checkpoint_metric=False, patience=-1, checkpoint_suffix='', mode='infer', data='/p/scratch/found/unimol_datasets/pockets/', 
         mask_prob=0.15, leave_unmasked_prob=0.05, random_token_prob=0.05, noise_type='uniform', noise=1.0, remove_hydrogen=False, 
         remove_polar_hydrogen=False, max_atoms=256, dict_name='dict_coarse.txt', adam_betas='(0.9, 0.99)', adam_eps=1e-06, weight_decay=0.0001, 
         force_anneal=None, warmup_updates=10000, warmup_ratio=-1.0, end_learning_rate=0.0, power=1.0, total_num_update=1000000, masked_token_loss=1.0, 
         masked_coord_loss=1.0, masked_dist_loss=1.0, x_norm_loss=0.01, delta_pair_repr_norm_loss=0.01, no_seed_provided=False, encoder_layers=15, 
         encoder_embed_dim=512, encoder_ffn_embed_dim=2048, encoder_attention_heads=64, dropout=0.1, emb_dropout=0.1, attention_dropout=0.1, 
         activation_dropout=0.0, pooler_dropout=0.0, max_seq_len=512, activation_fn='gelu', pooler_activation_fn='tanh', post_ln=False)

# Load the token dictionary from the file `args.dict_name` inside the `args.data` directory.
dictionary = Dictionary.load(os.path.join(args.data, args.dict_name))
        
#args=Namespace(no_progress_bar=False, log_interval=10,log_format='simple', tensorboard_logdir='./save_pocket//tsb', wandb_project='', wandb_name='')

A wrapper class that adds an optional projection layer on top of the Uni-Mol pocket encoder.

In [2]:
import re

import torch
import torch.nn as nn
from torch import TensorType
import numpy as np
import collections

#from src.models.components.layers import LearnableLogitScaling, Normalize

class PocketModel(nn.Module):
    """Uni-Mol pocket encoder followed by an optional projection layer.

    The encoder is built through ``UniMolPocketTask`` from ``args`` and
    ``dictionary``; weights can optionally be restored from a checkpoint.

    Args:
        task: Unused. Kept only so existing positional callers keep working;
            the task is always rebuilt from ``args`` and ``dictionary``.
        args: Options object passed to ``UniMolPocketTask`` and ``build_model``.
        dictionary: Token dictionary passed to ``UniMolPocketTask``.
        ckpt_path: Optional checkpoint path. Its ``"model"`` entry is loaded
            with ``strict=False``, so missing or unexpected keys do not raise.
        output_dim: Output width of the projection when ``proj == 'linear'``.
        proj: ``None`` for no projection (identity) or ``'linear'`` for a
            bias-free linear layer.
        use_logit_scale: Currently unused; the normalisation / logit-scaling
            head is still commented out below.

    Raises:
        ValueError: If ``proj`` is neither ``None`` nor ``'linear'``.
    """

    output_tokens: torch.jit.Final[bool]

    def __init__(
            self,
            #encoder: torch.nn.Module,
            task,
            args,
            dictionary,
            ckpt_path=None,
            output_dim: int = 512,
            proj: str = None,
            use_logit_scale: str = None,
    ):
        super().__init__()

        # The `task` argument is intentionally ignored (see class docstring);
        # a fresh task is created exactly as before, without shadowing it.
        pocket_task = UniMolPocketTask(args, dictionary).setup_task(args)
        encoder = pocket_task.build_model(args)

        if ckpt_path is not None:
            state = checkpoint_utils.load_checkpoint_to_cpu(ckpt_path)
            # strict=False: mismatching keys are tolerated rather than raised.
            encoder.load_state_dict(state["model"], strict=False)

        self.encoder = encoder
        self.output_dim = output_dim

        # Input width of the projection. Previously this was set to
        # `output_dim`, which made the size check below a tautology.
        # NOTE(review): assumes the encoder's feature width is
        # `args.encoder_embed_dim` — confirm against the Uni-Mol model.
        d_model = getattr(args, "encoder_embed_dim", output_dim)

        if proj is None:
            # No projection requested: pass the encoder output through as-is
            # (same runtime behaviour as before).
            self.proj = nn.Identity()
        elif proj == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        else:
            # Fail here instead of leaving `self.proj` undefined until forward().
            raise ValueError(
                f"Unsupported proj={proj!r}; expected None or 'linear'."
            )

        # Planned normalisation head (disabled; depends on
        # src.models.components.layers.{LearnableLogitScaling, Normalize}):
        # if use_logit_scale:
        #     self.norm = nn.Sequential(
        #                     Normalize(dim=-1),
        #                     LearnableLogitScaling(learnable=True)
        #             )
        # else:
        #     self.norm = nn.Sequential(
        #                     Normalize(dim=-1),
        #                     LearnableLogitScaling(learnable=False)
        #             )

    def forward(self, batch: collections.OrderedDict):
        """Encode one batch and apply the projection.

        Args:
            batch: Mapping with the keys ``src_tokens``, ``src_distance``,
                ``src_coord`` and ``src_edge_type``. Entries are looked up by
                key, so the order of the mapping does not matter.

        Returns:
            The first element of the encoder output, passed through
            ``self.proj``.
        """
        # Look up by name and pass by keyword: the previous positional
        # unpacking of `batch.values()` silently depended on the key order.
        encoder_out = self.encoder(
            src_tokens=batch["src_tokens"],
            src_distance=batch["src_distance"],
            src_coord=batch["src_coord"],
            src_edge_type=batch["src_edge_type"],
        )
        # NOTE(review): element 0 appears to be a per-token representation
        # rather than a pooled vector — confirm before using it as an embedding.
        features = encoder_out[0]
        projected = self.proj(features)
        #normed = self.norm(projected)
        #return normed
        return projected

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze the encoder parameters.

        Args:
            unlocked_layers: Only ``0`` (freeze everything) is implemented;
                any other value leaves all parameters untouched.
            freeze_layer_norm: When ``False``, parameters whose dotted name
                contains a ``LayerNorm`` component stay trainable.
        """
        if not unlocked_layers:  # full freezing
            # Fixed: the encoder lives in `self.encoder`; `self.model` never existed.
            # NOTE(review): the "LayerNorm" name match depends on how the
            # encoder names its sub-modules — verify it matches the real ones.
            for n, p in self.encoder.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return
        # Partial unlocking is not implemented: nothing is changed.

    def init_parameters(self):
        """Placeholder; no custom parameter initialisation is performed."""
        pass

Load the model from a saved checkpoint to test it.

In [3]:
# Build the encoder wrapper and restore its weights from a saved checkpoint.
model = PocketModel(
    UniMolPocketTask,
    args,
    dictionary,
    ckpt_path='save_pocket/checkpoint37.pt',
)


end of init!!!!


Load the dataset.

In [87]:
# Take the first entry of the comma-separated `args.valid_subset`, load it
# through a fresh task, and display the keys of the first sample.
subset = args.valid_subset.partition(",")[0]
task = UniMolPocketTask(args, dictionary)
task.load_dataset(subset, combine=False, epoch=1)
dataset = task.dataset(subset)
dataset[0].keys()


odict_keys(['net_input.src_tokens', 'net_input.src_coord', 'net_input.src_distance', 'net_input.src_edge_type', 'target.tokens_target', 'target.distance_target', 'target.coord_target', 'target.pdb_id'])

Build the batch iterator.

In [88]:
# Single shard; this process reads shard 0.
data_parallel_world_size, data_parallel_rank = 1, 0

# Ask the task for a batch iterator over `dataset`, then start an epoch
# iterator with shuffling turned off.
itr = task.get_batch_iterator(
    dataset=dataset,
    seed=args.seed,
    batch_size=args.batch_size,
    required_batch_size_multiple=args.required_batch_size_multiple,
    ignore_invalid_inputs=True,
    num_workers=args.num_workers,
    data_buffer_size=args.data_buffer_size,
    num_shards=data_parallel_world_size,
    shard_id=data_parallel_rank,
).next_epoch_itr(shuffle=False)

In [9]:
# src_tokens,src_distance,src_coord,src_edge_type= next(itr)['net_input'].values()
# src_tokens.shape, src_distance.shape, src_coord.shape, src_edge_type.shape
# next(itr)['target']['pdb_id'][0]

'P06766'

Pass one batch through the model.

In [89]:
# Pull one batch from the iterator and feed its `net_input` dict to the model;
# the bare `res` on the last line displays the result.
sample = next(itr)
res = model(sample["net_input"])
# To inspect the input shapes:
# sample['net_input']['src_tokens'].shape, sample['net_input']['src_distance'].shape, sample['net_input']['src_coord'].shape, sample['net_input']['src_edge_type'].shape
res

tensor([[[ 0.6868, -0.5002,  0.0629,  ...,  0.2643,  0.6296,  0.4229],
         [ 0.0705, -0.9110, -0.4445,  ..., -0.0536,  0.0799, -0.5051],
         [-1.5933, -0.6398, -0.5973,  ...,  0.5244, -0.7628, -1.6311],
         ...,
         [ 0.6576, -0.9981, -1.3797,  ..., -0.6243, -0.2075, -0.3160],
         [ 1.5873, -1.6611, -1.5267,  ..., -1.6872,  0.3415,  0.9952],
         [ 0.5869, -0.7048, -1.2732,  ..., -1.7372, -0.4952, -1.1116]],

        [[ 0.6649, -1.2847,  0.0309,  ...,  0.3215,  1.3422,  0.2808],
         [-0.8561, -0.6664, -0.6936,  ..., -1.7930,  1.3666, -0.1981],
         [-0.7612, -0.7434, -0.2390,  ..., -0.7302,  0.6408, -0.0522],
         ...,
         [ 1.2032, -1.3002, -0.9362,  ..., -0.8181, -0.4498,  0.3111],
         [ 1.4278, -0.8831, -1.1705,  ..., -1.1583, -0.3076,  0.2316],
         [ 1.1652, -1.1560, -1.5016,  ..., -0.4139,  0.0330,  0.0761]],

        [[ 0.7961, -0.5861,  1.2442,  ...,  0.4519,  1.5402,  0.2177],
         [-0.5975, -1.0323, -0.3382,  ...,  0

In [8]:
#from unicore.data import LMDBDataset
#data1=LMDBDataset('/p/scratch/found/structures/EC_test/'+args.train_subset+'.csv')

Define the dataset class; it is tested in the cells that follow.

In [77]:
from torch.utils.data import Dataset
from transformers import AutoTokenizer
import h5py
import pandas as pd
from unicore.data import LMDBDataset



class PocketDataset(Dataset):
    """Index-only dataset pairing protein sequences with Uni-Mol pocket inputs.

    ``__getitem__`` returns just the index; the actual loading (sequences from
    the HDF5 file, pocket tensors from the Uni-Mol dataset) happens per batch
    in ``collate_fn``, which must be passed to the ``DataLoader``.

    Args:
        args: Options object; ``args.valid_subset`` selects the pocket subset
            and the object is passed to ``UniMolPocketTask``.
        data_dir: Directory containing ``EC.h5`` and ``<split>.csv``.
        split: Base name of the CSV file whose ``name`` column lists the
            HDF5 entry names.
        seq_tokenizer: Name or path given to ``AutoTokenizer.from_pretrained``.
        max_length: Truncation length passed to the sequence tokenizer.

    NOTE(review): the pocket subset comes from ``args.valid_subset`` while the
    names come from ``<split>.csv``; entry ``i`` of one is paired with entry
    ``i`` of the other — confirm that both are aligned.
    NOTE(review): relies on the notebook-level ``dictionary`` variable.
    """

    def __init__(self, args, data_dir='/p/scratch/found/EC_test2/',split='train',seq_tokenizer="facebook/esm2_t33_650M_UR50D", max_length=1024):
        subset=args.valid_subset.split(",")[0]
        task=UniMolPocketTask(args, dictionary)
        task.load_dataset(subset, combine=False, epoch=1)
        self.dataset = task.dataset(subset)
        # Build both paths with os.path.join so `data_dir` works with or
        # without a trailing slash.
        self.h5_file = os.path.join(data_dir, 'EC.h5')
        meta_file = os.path.join(data_dir, f'{split}.csv')
        print(meta_file,"meta_file!!!!!!!!!")
        self.meta_data = list(pd.read_csv(meta_file)['name'])
        print(len(self.meta_data),"meta_data!!!!!!!!!")

        self.max_length = max_length
        self.seq_tokenizer = AutoTokenizer.from_pretrained(seq_tokenizer)

    def __len__(self):
        """Number of indices that are valid for both the pockets and the names."""
        # collate_fn indexes both collections with the same index, so never
        # report more items than the shorter one holds.
        return min(len(self.dataset), len(self.meta_data))

    def __getitem__(self, idx):
        """Return the index itself; loading is deferred to ``collate_fn``."""
        return idx

    @staticmethod
    def _decode_sequence(raw):
        """Turn a value read from the HDF5 file into plain text.

        Byte strings are decoded; calling ``str()`` on them would produce the
        literal ``"b'...'"`` and feed the prefix and quotes to the tokenizer.
        """
        if isinstance(raw, bytes):
            return raw.decode("utf-8")
        return str(raw)

    def collate_fn(self, data):
        """Load and batch the samples for the given indices.

        Args:
            data: Iterable of integer indices (as returned by ``__getitem__``).

        Returns:
            A tuple ``(sequence_input, pocket_input)``: the tokenized sequence
            ids as a long tensor, and a dict of stacked pocket tensors with
            the keys ``src_tokens``, ``src_distance``, ``src_coord`` and
            ``src_edge_type``.
        """
        sequences = []
        pockets = []

        # Open the HDF5 file once per batch instead of once per sample.
        with h5py.File(self.h5_file, 'r') as h5:
            for idx in data:
                chains = h5[f'{self.meta_data[idx]}']['structure']['0']
                # NOTE(review): one sequence is appended per chain, so an
                # entry with several chains yields more sequences than
                # pockets — confirm this is intended.
                for chain in chains.keys():
                    raw_sequence = chains[f'{chain}']['residues']['seq1'][()]
                    sequences.append(self._decode_sequence(raw_sequence))

                # Separate name: the original rebound `data` while looping over it.
                item = self.dataset[idx]
                pockets.append({
                    'src_tokens': item['net_input.src_tokens'],
                    'src_distance': item['net_input.src_distance'],
                    'src_coord': item['net_input.src_coord'],
                    'src_edge_type': item['net_input.src_edge_type'],
                })

        sequence_input = self.seq_tokenizer(sequences, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt").input_ids
        # NOTE(review): torch.stack requires identically shaped tensors per
        # key — confirm the pocket samples share a shape.
        pocket_input={key: torch.stack([d[key] for d in pockets]) for key in pockets[0].keys()}

        return sequence_input.long(), pocket_input
        
            




In [78]:
# Instantiate the dataset with its default data directory, split and tokenizer.
dataset_class = PocketDataset(args)

/p/scratch/found/EC_test2/train.csv meta_file!!!!!!!!!
13082 meta_data!!!!!!!!!


In [79]:
from torch.utils.data import ConcatDataset, DataLoader

# Batches are assembled by the dataset's own collate function, because
# `PocketDataset.__getitem__` only returns indices.
dataloader = DataLoader(
    dataset=dataset_class,
    collate_fn=dataset_class.collate_fn,
    batch_size=32,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    pin_memory=True,
)

In [84]:
# Smoke test: fetch only the first batch, then print its index, the number of
# tokenized sequences and the keys of the pocket-input dict.
for i, data in enumerate(dataloader):
    print(i, len(data[0]), data[1].keys())
    break

0 32 dict_keys(['src_tokens', 'src_distance', 'src_coord', 'src_edge_type'])
