In [1]:
#Module imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import os
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Notebook-level configuration values.
dataset_folder = "../google_prot_fns/"  # NOTE(review): not referenced by the visible cells; later paths are hard-coded
filters = 64  # NOTE(review): not referenced by the visible cells
epochs = 60  # NOTE(review): not referenced by the visible cells
batch_size = 128  # NOTE(review): the DataLoader / learning_to_split calls below do not read this value
protein_len = 200  # same value as ESMFnDataset's default max_len, but never passed to it explicitly

In [3]:
# Read the toy CSV and collect the distinct family accessions
# (kept in order of first appearance).
toy_dataset = pd.read_csv('../data/2_class.csv')
classes = toy_dataset['family_accession'].unique()

In [4]:
# Draw a fixed number of proteins from each family to build a small,
# class-balanced dataset.
#
# GroupBy.sample is used instead of groupby(...).apply(lambda x: x.sample(...)):
#   * it keeps every column, including 'family_accession', which is read
#     again when labels are looked up (apply-on-grouping-columns is
#     deprecated in recent pandas releases);
#   * random_state makes the draw reproducible under Restart & Run All.
SAMPLES_PER_CLASS = 10
SAMPLE_SEED = 0

minimal_dataset = (
    toy_dataset
    .groupby('family_accession')
    .sample(n=SAMPLES_PER_CLASS, random_state=SAMPLE_SEED)
    .reset_index(drop=True)
)

In [5]:
len(classes)  # Sanity check: the toy CSV ('2_class.csv') should yield exactly 2 families

2

In [6]:
# from utils.gpu_utils import *

In [7]:
# chosen_gpu = get_free_gpu()
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook can still be executed on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
# esm_model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm2_t33_650M_UR50D") #Downloads cache to AFS. Limited space on AFS...
# batch_converter = alphabet.get_batch_converter()
# esm_model.to(device)
# esm_model.eval()
# print('done')

In [9]:
# Select the compute device: GPU when available, CPU otherwise, so the cells
# below do not fail on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [10]:
class ESMFnDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-computed embeddings with family labels.

    Each row of ``dataset`` must provide a ``sequence_name`` (used to locate
    ``<embeddings_dir>/<sequence_name>.pt``, with ``/`` replaced by ``-``) and
    a ``family_accession`` (mapped to an integer class index).

    Args:
        dataset: pandas DataFrame with ``sequence_name`` and
            ``family_accession`` columns.
        classes: Indexable collection of family accessions; the entry at
            position ``i`` is assigned class index ``i``.
        device: Device on which the returned embedding tensor is placed.
        max_len: Embeddings with fewer than ``max_len`` positions along
            dim 1 are zero-padded up to ``max_len``. Longer embeddings are
            returned unchanged (they are NOT truncated).
        embeddings_dir: Directory holding one ``.pt`` file per sequence.
    """

    def __init__(self, dataset, classes, device='cuda', max_len=200,
                 embeddings_dir='../data/toy_esm_embeddings/'):
        self.dataset = dataset
        self.classes = classes
        self.max_len = max_len
        self.device = device
        self.embeddings_dir = embeddings_dir
        # Bidirectional lookup between family accession and class index.
        self.class_to_idx = {name: i for i, name in enumerate(classes)}
        self.idx_to_class = {i: name for i, name in enumerate(classes)}
        # Amino-acid letter <-> integer lookup tables (not used by
        # __getitem__; kept as public attributes).
        self.letter_to_num = {'C': 4, 'D': 3, 'S': 15, 'Q': 5, 'K': 11, 'I': 9,
                'P': 14, 'T': 16, 'F': 13, 'A': 0, 'G': 7, 'H': 8,
                'E': 6, 'L': 10, 'R': 1, 'W': 17, 'V': 19, 
                'N': 2, 'Y': 18, 'M': 12}
        self.num_to_letter = {v: k for k, v in self.letter_to_num.items()}

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load one embedding and its label.

        Returns:
            A tuple ``(embeddings, class_idx)``: ``embeddings`` is a float32
            tensor on ``self.device`` with the leading dimension squeezed
            away; ``class_idx`` is a scalar integer tensor.
        """
        row = self.dataset.iloc[idx]
        sequence_name = row['sequence_name'].replace('/', '-')
        path = os.path.join(self.embeddings_dir, f'{sequence_name}.pt')

        # map_location lets files that were saved from a GPU be read on
        # CPU-only machines.
        embeddings = torch.load(path, map_location=self.device)
        # Honour self.device (previously hard-coded to 'cuda', which could
        # disagree with the device used for the padding below).
        embeddings = torch.as_tensor(
            embeddings, dtype=torch.float32, device=self.device
        ).detach()

        # Zero-pad along the sequence dimension (dim 1) up to max_len.
        seq_len = embeddings.size(1)
        if seq_len < self.max_len:
            B, _, h = embeddings.size()
            pad = torch.zeros(
                (B, self.max_len - seq_len, h),
                dtype=embeddings.dtype,
                device=embeddings.device,
            )
            embeddings = torch.cat((embeddings, pad), dim=1)

        class_idx = torch.tensor(self.class_to_idx[row['family_accession']])
        # label = F.one_hot(class_idx, num_classes=len(self.classes))
        embeddings = embeddings.squeeze(0) #NOTE: LS adds its own batch dimension, so we need to remove it here
        return embeddings, class_idx

In [11]:
# Wrap the sampled rows in the embedding dataset, placing tensors on `device`.
dataset = ESMFnDataset(dataset=minimal_dataset, classes=classes, device=device)

In [12]:
# Class-index tensor of the first sample.
dataset[0][1]

tensor(1)

In [13]:
# One sample per batch (the DataLoader default), in dataset order.
esmdl = torch.utils.data.DataLoader(dataset, batch_size=1)

In [14]:
# Peek at the first batch only: show the embedding matrix of its first sample.
for batch in esmdl:
    batch_embeddings = batch[0]
    print(batch_embeddings[0])
    break

tensor([[ 0.0265,  0.0639,  0.0099,  ..., -0.2129,  0.2000, -0.0129],
        [ 0.1280, -0.0193, -0.0608,  ...,  0.2596, -0.1939, -0.1739],
        [ 0.1197, -0.0697,  0.0128,  ...,  0.0020,  0.1907, -0.2165],
        ...,
        [ 0.0513, -0.0610,  0.1865,  ...,  0.0707,  0.0377, -0.0421],
        [ 0.0319, -0.0693,  0.1435,  ...,  0.0333,  0.0819,  0.0220],
        [ 0.1033, -0.0212,  0.1197,  ...,  0.0796,  0.0763, -0.0332]],
       device='cuda:0')


# Learning to Split

In [15]:
import ls

In [16]:
# python scripts/extract.py esm2_t33_650M_UR50D examples/data/some_proteins.fasta \
#   examples/data/some_proteins_emb_esm2 --repr_layers 33 --include per_tok

In [17]:
from ls.models.build import ModelFactory


In [18]:
@ModelFactory.register("esm_transformer")
class TransformerEncoder(torch.nn.Module):
    """Transformer-encoder classifier over fixed-length embedding sequences.

    The input is passed through a stack of ``nn.TransformerEncoderLayer``s,
    projected per position to 256 and then 32 features, flattened across the
    sequence, and mapped to ``classes`` outputs followed by a softmax.

    Args:
        input_size: Feature size of each position (``d_model``).
        nheads: Number of attention heads.
        num_layers: Number of encoder layers.
        device: Stored on the instance; not otherwise used by this class.
        classes: Number of output classes.
        max_len: Sequence length the final linear layer is sized for; inputs
            must have exactly this many positions.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, input_size=1280, nheads=8, num_layers=6, device='cuda', classes=2, max_len=202, **kwargs):
        super().__init__()
        self.input_size = input_size
        self.classes = classes
        self.max_len = max_len
        # NOTE(review): nn.TransformerEncoder appears to work on its own
        # copies of this layer, so this attribute's parameters may be unused
        # in forward. Kept because removing it would change state_dict keys.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=self.input_size, nhead=nheads, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.fc1 = torch.nn.Linear(self.input_size, 256)
        self.fc2 = torch.nn.Linear(256, 32)
        # The classifier consumes the flattened (max_len * 32) representation.
        self.fc3 = torch.nn.Linear(32 * max_len, self.classes)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.device = device

    def forward(self, embedding):
        """Classify a batch of embedding sequences.

        Args:
            embedding: Tensor of shape ``(batch, max_len, input_size)``. A
                tensor with one extra leading dimension of size 1 is also
                accepted and squeezed.

        Returns:
            Tensor of shape ``(batch, classes)`` holding softmax
            probabilities.

        Raises:
            ValueError: If the (squeezed) input is not 3-dimensional or its
                sequence length differs from ``max_len``.
        """
        # Drop an extra leading singleton dimension (presumably added by an
        # outer loader on top of an already-batched tensor).
        if embedding.dim() > 3 and embedding.size(0) == 1:
            embedding = embedding.squeeze(0)
        # Validate unconditionally so malformed inputs fail with a clear
        # message instead of an unpacking or matmul error further down.
        if embedding.dim() != 3:
            raise ValueError(
                'Expected embedding of shape (batch, seq_len, hidden), '
                f'got shape {tuple(embedding.shape)}'
            )

        B, N, h = embedding.shape
        if N != self.max_len:
            raise ValueError(
                f'Expected sequence length {self.max_len}, got {N}; '
                'the final linear layer is sized for max_len positions'
            )

        hidden = self.transformer_encoder(embedding)
        hidden = self.fc1(hidden)
        hidden = self.fc2(hidden)
        # Flatten the per-position features into one vector per sample.
        hidden = hidden.flatten(start_dim=1)
        hidden = self.fc3(hidden)
        # NOTE(review): probabilities are returned here, whereas the
        # 'esm_mlp' model in this notebook returns raw logits. Confirm which
        # one the training loss expects.
        hidden = self.softmax(hidden)
        return hidden

In [29]:
@ModelFactory.register("esm_mlp")
class LinearLayer(torch.nn.Module):
    """Small MLP classifier over fixed-length embedding sequences.

    Each position is projected to 32 features and passed through a ReLU; the
    result is flattened across the sequence and mapped to ``classes`` raw
    scores (no softmax is applied in ``forward``).

    Args:
        input_size: Feature size of each position.
        device: Stored on the instance; not otherwise used by this class.
        classes: Number of output classes.
        max_len: Sequence length the final linear layer is sized for; inputs
            must have exactly this many positions.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, input_size=1280, device='cuda', classes=2, max_len=202, **kwargs):
        super().__init__()
        self.input_size = input_size
        self.classes = classes
        self.max_len = max_len
        self.fc1 = torch.nn.Linear(self.input_size, 32)
        self.relu1 = torch.nn.ReLU()
        # The classifier consumes the flattened (max_len * 32) representation.
        self.fc2 = torch.nn.Linear(32 * max_len, self.classes)
        # Not applied in forward; kept as an attribute for callers that want
        # to convert the returned scores to probabilities.
        self.softmax = torch.nn.Softmax(dim=-1)
        self.device = device

    def forward(self, embedding):
        """Score a batch of embedding sequences.

        Args:
            embedding: Tensor of shape ``(batch, max_len, input_size)``. A
                tensor with one extra leading dimension of size 1 is also
                accepted and squeezed.

        Returns:
            Tensor of shape ``(batch, classes)`` holding unnormalised scores.

        Raises:
            ValueError: If the (squeezed) input is not 3-dimensional or its
                sequence length differs from ``max_len``.
        """
        # Drop an extra leading singleton dimension (presumably added by an
        # outer loader on top of an already-batched tensor).
        if embedding.dim() > 3 and embedding.size(0) == 1:
            embedding = embedding.squeeze(0)
        # Validate unconditionally so malformed inputs fail with a clear
        # message instead of an unpacking or matmul error further down.
        if embedding.dim() != 3:
            raise ValueError(
                'Expected embedding of shape (batch, seq_len, hidden), '
                f'got shape {tuple(embedding.shape)}'
            )

        B, N, h = embedding.shape
        if N != self.max_len:
            raise ValueError(
                f'Expected sequence length {self.max_len}, got {N}; '
                'the final linear layer is sized for max_len positions'
            )

        hidden = self.fc1(embedding)
        hidden = self.relu1(hidden)
        # Flatten the per-position features into one vector per sample.
        hidden = hidden.flatten(start_dim=1)
        hidden = self.fc2(hidden)
        return hidden

Model esm_mlp already exists. Will replace it


In [20]:
# Build the transformer classifier, then move its parameters to `device`.
model = TransformerEncoder(
    input_size=1280,
    nheads=8,
    num_layers=6,
    device=device,
    classes=len(classes),
)
model = model.to(device)

In [21]:
len(classes)  # the value passed as `classes=` when the model was built in the preceding cell

2

In [22]:
# Smoke test: run a single batch through the model and inspect the output shape.
for batch in esmdl:
    embedding = batch[0]
    label = batch[1]
    pred = model(embedding)
    # Expected shape: [batch_size, number of classes]
    print(pred.shape)
    break

torch.Size([1, 2])


In [36]:
# Reference run: learn a split of the Tox21 dataset, using a simple MLP as
# the model backbone and ROC-AUC as the evaluation metric.
data = ls.datasets.Tox21()

train_data, test_data, train_indices, test_indices, splitter = ls.learning_to_split(
    data,
    model={'name': 'mlp'},
    metric='roc_auc',
    return_order=['train_data', 'test_data', 'train_indices', 'test_indices', 'splitter'],
    num_outer_loop=1,
)

KeyboardInterrupt: 

In [24]:
# data.__getitem__(0)

In [25]:
# type(data)

In [26]:
# Shape of the first sample's embedding tensor.
dataset[0][0].shape

torch.Size([202, 1280])

In [35]:
# Learn a split of the embedding dataset with the registered 'esm_mlp' model.
train_data, test_data, train_indices, test_indices, splitter = ls.learning_to_split(
    dataset,
    model={'name': 'esm_mlp', 'args': {}},
    metric='roc_auc',
    num_workers=0,
    return_order=['train_data', 'test_data', 'train_indices', 'test_indices', 'splitter'],
    batch_size=20,
    patience=1,
    num_outer_loop=2,
)

KeyboardInterrupt: 