In this notebook I implement the new decoder module discussed with _Christoph_. On top of the encodings produced by the __ESM__ encoder we attach two layers of standard _multi-head self-attention_, which we then train by minimising the negative _pseudo-log-likelihood_ (TODO: find package).
Remember that from here on we work with tensors rather than graphs. We hope that all the relevant information coming from the graph structure has already been captured in the encoder's embeddings.

In [1]:
import torch 
from torch.nn import TransformerEncoderLayer, Linear
from torch import Tensor
from torch.nn.functional import one_hot
import pickle
import os
from torch.utils.data import Dataset, DataLoader, RandomSampler, Subset
import matplotlib.pyplot as plt
import numpy as np
#from torch_geometric.loader import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## Be careful with the dimensions of the objects below.
class Potts_Decoder(torch.nn.Module):
    """Decode per-residue embeddings into Potts-model fields and couplings.

    A stack of ``n_layers`` ``TransformerEncoderLayer`` blocks (each followed by a
    ReLU) processes the residue embeddings. A Linear head then lifts every residue
    to an ``(n_cat, embed_dim)`` matrix, i.e. one ``embed_dim`` vector per
    (site, category) pair. The couplings are the Gram matrix of those
    ``L * n_cat`` vectors, and the fields are its diagonal.

    Args:
        n_cat: number of categories per site (21 for proteins: amino acids plus a
            skip character).
        n_layers: number of ``TransformerEncoderLayer`` blocks.
        atten_dim: input feature size (``d_model``) of the attention layers.
        embed_dim: length of each per-(site, category) vector in the Gram product.
        n_heads: number of attention heads.
        dropout: dropout probability passed to every attention layer.
    """

    def __init__(self, n_cat:int, n_layers:int, atten_dim:int, embed_dim:int, n_heads:int, dropout=0.0):
        super().__init__()
        self.n_cat = n_cat
        self.n_layers = n_layers
        self.atten_dim = atten_dim
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        self.dropout = dropout

        # Attention layers are built before the Linear head so that parameter
        # initialisation order (and hence seeded results) stays the same.
        self.attentions = torch.nn.ModuleList(
            TransformerEncoderLayer(atten_dim, n_heads, dropout=dropout)
            for _ in range(n_layers)
        )
        # Attribute name ``Linear`` is kept so existing state_dict keys stay valid.
        self.Linear = Linear(atten_dim, n_cat * embed_dim)

    def forward(self, x, mask):
        """Compute ``(fields, coupling)`` for one protein.

        Args:
            x: residue embeddings; ``x.shape[0]`` is taken as the number of sites L
                and the last dimension must equal ``atten_dim``. Presumably an
                unbatched ``(L, atten_dim)`` tensor -- TODO confirm against the
                encoder output.
            mask: multiplied element-wise into the ``(L*n_cat, L*n_cat)`` Gram
                matrix before symmetrisation. The training loop passes a strictly
                upper-triangular boolean mask with zeroed within-site blocks.

        Returns:
            fields: ``(L, n_cat)`` tensor, the diagonal of the unmasked Gram matrix.
            coupling: ``M + M.T`` where ``M = gram * mask``.
        """
        n_sites = x.shape[0]

        hidden = x
        for layer in self.attentions:
            hidden = layer(hidden).relu()

        # One embed_dim vector per (site, category): shape (L, n_cat, embed_dim).
        per_site = self.Linear(hidden).reshape(n_sites, self.n_cat, self.embed_dim)
        # Rows are indexed by site * n_cat + category in both factors.
        left = torch.flatten(per_site, end_dim=1)                       # (L*n_cat, embed_dim)
        right = torch.flatten(per_site.permute(2, 0, 1), start_dim=1)   # (embed_dim, L*n_cat)
        gram = left @ right

        fields = torch.diag(gram).reshape(n_sites, self.n_cat)
        upper = gram * mask
        # Mirror the masked (upper) part to recover the full symmetric coupling matrix.
        return fields, upper + upper.T
    
    
    
def Pseudo_Likelihood(model, data:Tensor, fields:Tensor, coupling:Tensor, one_hot_input:bool = False) -> Tensor:
    """Negative pseudo-log-likelihood of each sequence, averaged over sites.

    ``energy_diffs`` yields ``h_{i,a} - h_{i,observed}`` for every site ``i`` and
    category ``a``. Taking ``logsumexp`` over ``a`` gives ``-log softmax(h_i)[observed]``,
    the per-site negative conditional log-likelihood. That value is then averaged over
    sites.

    Args:
        model: object exposing ``n_cat`` (categories per site).
        data: integer sequences ``(ndata, L)``, or flattened one-hot
            ``(ndata, L*n_cat)`` when ``one_hot_input`` is True.
        fields: site fields as produced by ``Potts_Decoder``.
        coupling: ``(L*n_cat, L*n_cat)`` coupling matrix.
        one_hot_input: whether ``data`` is already one-hot encoded.

    Returns:
        Tensor of shape ``(ndata,)``; lower is better, so it can be minimised directly.
    """
    # The former unused ``seq = data[0]`` is dropped: it also raised IndexError on an empty batch.
    if not one_hot_input:
        data = one_hot(data, num_classes = model.n_cat).float().view(data.shape[0], -1)
    diffs = energy_diffs(model, data, fields, coupling, one_hot_input=True)
    return torch.mean(torch.logsumexp(diffs, dim=-1), dim=-1)

def Local_Fields(model, data: Tensor, fields:Tensor, coupling:Tensor, one_hot_input: bool = False) -> Tensor:
    """Local field of every (site, category) pair for each sequence.

    ``h[n, i*n_cat + a] = fields[i, a] + sum_{k} x[n, k] * coupling[k, i*n_cat + a]``,
    where ``x`` is the flattened one-hot encoding of sequence ``n``.

    Args:
        model: object exposing ``n_cat``.
        data: integer sequences ``(ndata, L)``, or flattened one-hot
            ``(ndata, L*n_cat)`` when ``one_hot_input`` is True.
        fields: ``(L, n_cat)`` fields, as returned by ``Potts_Decoder``, or already
            flat with ``L*n_cat`` entries in the same site-major order.
        coupling: ``(L*n_cat, L*n_cat)`` coupling matrix.
        one_hot_input: whether ``data`` is already one-hot encoded.

    Returns:
        Tensor of shape ``(ndata, L*n_cat)``.
    """
    if not one_hot_input:
        data = one_hot(data, num_classes = model.n_cat).float().view(data.shape[0], -1)
    # Flatten fields row-major so entry i*n_cat + a matches the one-hot layout.
    # The previous ``fields.T.reshape(1, -1)`` produced order a*L + i, which mixed
    # up sites and categories. Broadcasting replaces the explicit expand.
    return data.float() @ coupling + fields.reshape(1, -1)

def energy_diffs(model, data: Tensor, fields:Tensor, coupling:Tensor, one_hot_input: bool = False) -> Tensor:
    """Local-field differences relative to the observed category at each site.

    Args:
        model: object exposing ``n_cat``.
        data: integer sequences ``(ndata, L)``, or flattened one-hot
            ``(ndata, L*n_cat)`` when ``one_hot_input`` is True.
        fields: site fields, forwarded to ``Local_Fields``.
        coupling: ``(L*n_cat, L*n_cat)`` coupling matrix; L is inferred as
            ``coupling.shape[0] // model.n_cat``.
        one_hot_input: whether ``data`` is already one-hot encoded.

    Returns:
        Tensor of shape ``(ndata, L, n_cat)`` holding
        ``local_field(i, a) - local_field(i, observed_i)``. The entry for the
        observed category is zero.
    """
    n_seq: int = data.shape[0]
    n_sites = coupling.shape[0] // model.n_cat
    if not one_hot_input:
        data = one_hot(data, num_classes = model.n_cat).float().view(n_seq, -1)

    grid = (n_seq, n_sites, model.n_cat)
    h: Tensor = Local_Fields(model, data, fields, coupling, one_hot_input=True)
    # Because data is one-hot, masking and summing over categories picks out the
    # observed category's local field at each site; keepdim lets it broadcast.
    h_observed: Tensor = (h * data).reshape(grid).sum(dim=-1, keepdim=True)
    return -(h_observed - h.reshape(grid))

def energy_diffs_position(model, position: int, data: Tensor, fields:Tensor, coupling:Tensor, one_hot_input: bool = False) -> Tensor:
    """Energy differences at a single site.

    Args:
        model: object exposing ``n_cat``.
        position: index of the site (amino-acid position) to extract.
        data, fields, coupling, one_hot_input: as in ``energy_diffs``.

    Returns:
        ``energy_diffs(...)[:, position]``, a tensor of shape ``(ndata, n_cat)``.
    """
    if not one_hot_input:
        data = one_hot(data, num_classes = model.n_cat).float().view(data.shape[0], -1)

    # Use a distinct local name. Assigning to ``energy_diffs`` here would make it
    # a local for the whole function, so the call would raise UnboundLocalError.
    diffs: Tensor = energy_diffs(model, data, fields, coupling, one_hot_input=True)
    return diffs[:, position]

In [3]:
def load_obj(file):
    """Unpickle and return the object stored at path ``file``.

    NOTE(review): ``pickle`` can execute arbitrary code while loading; only use
    this on trusted files.
    """
    with open(file, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj

In [4]:
## Amino-acid letter -> integer category index (conversion taken from the GVP GitHub repo).
## 'X' maps to 20, the 21st category; the author's guess is that it stands for the
## gap '-' -- TODO confirm against GVP.
letter_to_num = dict(
    C=4, D=3, S=15, Q=5, K=11, I=9,
    P=14, T=16, F=13, A=0, G=7, H=8,
    E=6, L=10, R=1, W=17, V=19,
    N=2, Y=18, M=12, X=20,
)

In [5]:
class Encoded_Proteins(Dataset):
    """Dataset over a directory of pickled, pre-encoded proteins (one file per protein).

    Each file is loaded with ``load_obj`` and must be a dict containing the keys
    ``'Encoded_Protein'`` (the encoder output) and ``'Num_Seq'`` (the numerically
    encoded native sequence). Items are ``(encoded_protein, num_seq)``.

    Args:
        path_dir: directory containing the pickle files.
        transform: optional callable applied to the encoded protein.
        target_transform: optional callable applied to the numerical sequence.
    """

    def __init__(self, path_dir, transform=None, target_transform=None):
        self.path_dir = path_dir
        self.transform = transform
        self.target_transform = target_transform
        # os.listdir order is arbitrary. Sort once so the index -> file mapping is
        # stable, which the index-based train/test Subset split relies on.
        self.files = sorted(os.listdir(path_dir))

    def __len__(self):
        # Uses this instance's directory (previously read the global ``path_dir``).
        return len(self.files)

    def __getitem__(self, idx):
        protein_file = os.path.join(self.path_dir, self.files[idx])
        d = load_obj(protein_file)
        encoded_protein = d['Encoded_Protein']
        num_seq = d['Num_Seq']
        if self.transform is not None:
            encoded_protein = self.transform(encoded_protein)
        if self.target_transform is not None:
            # Apply to the target that is actually returned (previously referenced
            # an undefined ``native_seq`` and discarded the result).
            num_seq = self.target_transform(num_seq)
        return encoded_protein, num_seq

In [6]:
def default_collate(batch):
    """Split a batch of ``(encoded_protein, num_seq)`` samples into two lists.

    Samples are not stacked into a tensor (presumably because proteins differ in
    length); each protein is processed individually downstream.

    Returns:
        ``(inputs, targets)`` where ``inputs[k] == batch[k][0]`` and
        ``targets[k] == batch[k][1]``. The targets are already numerical.
    """
    def _column(position):
        return [sample[position] for sample in batch]

    return _column(0), _column(1)


In [10]:
# Directory holding one pickled, pre-encoded protein per file.
path_dir = "./Encoded_Proteins_Toy/"
dataset = Encoded_Proteins(path_dir)

In [13]:
# Fixed split by index: proteins 0-999 train, 1000-1299 test.
dataset_train = Subset(dataset, np.arange(1000))
dataset_test = Subset(dataset, np.arange(start=1000, stop=1300))
dataloader_train, dataloader_test = (
    DataLoader(split, batch_size=4, collate_fn=default_collate, shuffle=True)
    for split in (dataset_train, dataset_test)
)

In [45]:
#DataLoader(sample_ds, batch_size=4, collate_fn=default_collate, shuffle=True)

In [14]:
## Hyper-parameters of the decoder and the optimisation.
n_cat:int = 21
n_layers:int = 2
atten_dim:int = 512
embed_dim:int = 5
n_heads:int = 16
batch_size:int = 4
# Loader over the whole dataset, kept for interactive use; training below uses
# dataloader_train / dataloader_test.
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=default_collate, shuffle=True)

lr = 0.001
logging = []   # per-epoch records of train / test loss
#device = 'cuda'
device = 'cpu'
decoder = Potts_Decoder(n_cat, n_layers, atten_dim, embed_dim, n_heads).to(device)
optimizer = torch.optim.Adam(decoder.parameters(), lr=lr)
loss_f = Pseudo_Likelihood


def _coupling_mask(system_size: int, n_cat: int, device) -> Tensor:
    """Boolean (L*n_cat, L*n_cat) mask: strictly upper triangle, within-site blocks zeroed."""
    size = system_size * n_cat
    site = torch.arange(size, device=device) // n_cat
    upper = torch.triu(torch.ones(size, size, dtype=torch.bool, device=device), 1)
    return upper & (site.unsqueeze(1) != site.unsqueeze(0))


def _protein_loss(model, data_x, data_y, device) -> Tensor:
    """Loss (``loss_f``) of one native sequence ``data_y`` given its encoding ``data_x``."""
    data_x = data_x.to(device)
    mask = _coupling_mask(data_x.shape[0], model.n_cat, device)
    fields, couplings = model(data_x, mask)
    hot = one_hot(data_y, num_classes=model.n_cat)
    hot = torch.flatten(hot).unsqueeze(dim=0).to(device)
    return loss_f(model, hot, fields, couplings, one_hot_input=True)


num_epochs = 10
for epoch in range(1, num_epochs + 1):
    # Re-enter training mode every epoch: the evaluation pass below switches to eval().
    decoder.train()
    total_loss = 0.0
    for iterator, (data_xs, data_ys) in enumerate(dataloader_train):
        print(f"iterator:{iterator}", end="\r")
        # Clear the previous step's gradients so batches do not accumulate.
        optimizer.zero_grad()
        n_items = len(data_xs)   # actual batch length (the last batch may be smaller)
        for data_x, data_y in zip(data_xs, data_ys):
            # Proteins have different lengths, so backprop one at a time; dividing
            # by n_items makes the accumulated gradient a batch mean.
            loss = _protein_loss(decoder, data_x, data_y, device) / n_items
            total_loss += float(loss)
            loss.backward()
        optimizer.step()

    #################################### TESTING #########################################
    decoder.eval()
    total_loss_test = 0.0
    with torch.no_grad():   # no autograd graph needed for evaluation
        for data_xs, data_ys in dataloader_test:
            n_items = len(data_xs)
            for data_x, data_y in zip(data_xs, data_ys):
                total_loss_test += float(_protein_loss(decoder, data_x, data_y, device) / n_items)

    print(f"We are at epoch:{epoch}, loss is:{total_loss}, test loss is:{total_loss_test}")
    logging.append({"Epoch": epoch, "value": total_loss, "kind": "Loss"})
    logging.append({"Epoch": epoch, "value": total_loss_test, "kind": "Test_Loss"})


iterator:2

KeyboardInterrupt: 