In [176]:
import pandas as pd
import os
import torch
from torch.utils.data import Dataset, DataLoader
import math

script_path = '.'
data_dir = os.path.join(script_path, '../data')
df_path = os.path.join(data_dir, 'processed/03_filtered_enzyme_dat.csv')
fasta_dir = os.path.join(data_dir, 'fastas')

df = pd.read_csv(df_path)
df

keep = [f[:-3] for f in os.listdir(os.path.join(data_dir, 'embeddings/proteins')) if f.endswith('.pt')]
df = df[df['protein'].isin(keep)]
keep = [f[3:-3].replace('_','.') for f in os.listdir(os.path.join(data_dir, 'embeddings/descriptions')) if f.endswith('.pt')]
df = df[df['EC'].isin(keep)]

In [177]:
df

Unnamed: 0,protein,EC,DE,AN,CA,CF,CC,PR
12,P06525,1.1.1.1,alcohol dehydrogenase.,aldehyde reductase.,(1) a primary alcohol + NAD(+) = an aldehyde +...,,-!- Acts on primary or secondary alcohols or h...,
33,Q64413,1.1.1.1,alcohol dehydrogenase.,aldehyde reductase.,(1) a primary alcohol + NAD(+) = an aldehyde +...,,-!- Acts on primary or secondary alcohols or h...,
43,Q9P6C8,1.1.1.1,alcohol dehydrogenase.,aldehyde reductase.,(1) a primary alcohol + NAD(+) = an aldehyde +...,,-!- Acts on primary or secondary alcohols or h...,
81,O13309,1.1.1.1,alcohol dehydrogenase.,aldehyde reductase.,(1) a primary alcohol + NAD(+) = an aldehyde +...,,-!- Acts on primary or secondary alcohols or h...,
89,P42328,1.1.1.1,alcohol dehydrogenase.,aldehyde reductase.,(1) a primary alcohol + NAD(+) = an aldehyde +...,,-!- Acts on primary or secondary alcohols or h...,
...,...,...,...,...,...,...,...,...
55543,Q97CM5,3.1.1.96,D-aminoacyl-tRNA deacylase.,D-tyrosyl-tRNA(Tyr) aminoacylhydrolase. D-Tyr-...,(1) a D-aminoacyl-tRNA + H2O = a D-alpha-amino...,,"-!- The enzyme, found in all domains of life, ...",
55566,B5FFD2,3.1.1.96,D-aminoacyl-tRNA deacylase.,D-tyrosyl-tRNA(Tyr) aminoacylhydrolase. D-Tyr-...,(1) a D-aminoacyl-tRNA + H2O = a D-alpha-amino...,,"-!- The enzyme, found in all domains of life, ...",
55579,Q9BTV6,3.1.1.97,methylated diphthine methylhydrolase.,diphthine methylesterase.,diphthine methyl ester-[translation elongation...,,-!- The protein is only present in eukaryotes.,
55620,E1PL40,3.1.1.106,O-acetyl-ADP-ribose deacetylase.,,(1) 3''-O-acetyl-ADP-D-ribose + H2O = acetate ...,,"-!- The enzyme, characterized from the bacteri...",


In [178]:
# Split the data into train/val and test datasets
train_df = df.sample(frac=0.8, random_state=42)
val_df = df.drop(index=train_df.index)

# Reset indices
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

batch_size = 10

In [179]:
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        n = self.data['protein'].iloc[index]
        ec = self.data['EC'].iloc[index]
        x = torch.load(f'../data/embeddings/proteins/{n}.pt', map_location=self.device).detach()
        y = torch.load(f'../data/embeddings/descriptions/EC_{ec.replace(".","_")}.pt', map_location=self.device).detach()
        return x, y

In [180]:
train_data = CustomDataset(train_df)
val_data = CustomDataset(val_df)

In [181]:
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True)

In [182]:
train_iter = iter(train_loader)

In [183]:
x,y = next(train_iter)
print(x.shape, y.shape)

torch.Size([10, 1280]) torch.Size([10, 1280])


In [186]:
import torch.nn.functional as F
import torch.nn as nn

class CFG:
    debug = False
    protein_path = "C:/Moein/AI/Datasets/Flicker-8k/Images"
    captions_path = "C:/Moein/AI/Datasets/Flicker-8k"
    batch_size = 32
    num_workers = 4
    head_lr = 1e-3
    protein_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 1
    factor = 0.8
    epochs = 100
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = 'resnet50'
    protein_embedding = 1280
    text_encoder_model = "ESM 1v"
    text_embedding = 1280
    text_tokenizer = "GPT2-large"
    max_length = 200

    pretrained = True # for both protein encoder and text encoder
    trainable = True # for both protein encoder and text encoder
    temperature = 1.0

    # for projection head; used for both protein and text encoders
    num_projection_layers = 1
    projection_dim = 256 
    dropout = 0.1
    
class AvgMeter:
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        self.avg, self.sum, self.count = [0] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        text = f"{self.name}: {self.avg:.4f}"
        return text

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]


class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)
    
    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x


class CLIPModel(nn.Module):
    def __init__(
        self,
        temperature=CFG.temperature,
        protein_embedding=CFG.protein_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.protein_projection = ProjectionHead(embedding_dim=protein_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, protein_features, text_features):
        # Getting Protein and Text Embeddings (with same dimension)
        protein_embeddings = self.protein_projection(protein_features)
        text_embeddings = self.text_projection(text_features)

        # Calculating the Loss
        logits = (text_embeddings @ protein_embeddings.T) / self.temperature
        protein_similarity = protein_embeddings @ protein_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (protein_similarity + texts_similarity) / 2 * self.temperature, dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        protein_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss =  (protein_loss + texts_loss) / 2.0 # shape: (batch_size)
        return loss.mean()


def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()

In [187]:
import os
import gc
import numpy as np
import pandas as pd
import itertools
from tqdm.autonotebook import tqdm
import torch
from torch import nn
import torch.nn.functional as F


def train_epoch(model, train_loader, optimizer, lr_scheduler):
    loss_meter = AvgMeter()
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    for batch in tqdm_object:
        x = batch[0]
        y = batch[1]
        loss = model(x,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        count = x.size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
    return loss_meter


def val_epoch(model, val_loader):
    loss_meter = AvgMeter()

    tqdm_object = tqdm(val_loader, total=len(val_loader))
    for batch in tqdm_object:
        x = batch[0]
        y = batch[1]
        loss = model(x,y)

        count = x.size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(val_loss=loss_meter.avg)
    return loss_meter



model = CLIPModel().to(CFG.device)
params = [
    {"params": itertools.chain(
        model.protein_projection.parameters(), model.text_projection.parameters()
    ), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay}
]
optimizer = torch.optim.AdamW(params, weight_decay=0.)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
)

best_loss = float('inf')
for epoch in range(CFG.epochs):
    print(f"Epoch: {epoch + 1}")
    model.train()
    train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler)
    model.eval()
    with torch.no_grad():
        val_loss = val_epoch(model, val_loader)

    if val_loss.avg < best_loss:
        best_loss = val_loss.avg
        torch.save(model.state_dict(), "best.pt")
        print("Saved Best Model!")

    lr_scheduler.step(val_loss.avg)

Epoch: 1


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Saved Best Model!
Epoch: 2


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Saved Best Model!
Epoch: 3


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Saved Best Model!
Epoch: 4


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Saved Best Model!
Epoch: 5


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Saved Best Model!
Epoch: 6


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Saved Best Model!
Epoch: 7


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Saved Best Model!
Epoch: 8


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 9


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 10


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Saved Best Model!
Epoch: 11


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 12


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 13


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 14


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Saved Best Model!
Epoch: 15


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Saved Best Model!
Epoch: 16


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 17


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Saved Best Model!
Epoch: 18


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 19


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 20


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 21


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 22


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Saved Best Model!
Epoch: 23


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 24


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 25


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Saved Best Model!
Epoch: 26


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 27


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 28


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 29


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 30


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 31


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 32


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 33


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 34


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 35


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 36


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 37


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 38


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 39


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 40


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 41


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 42


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 43


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Saved Best Model!
Epoch: 44


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 45


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Saved Best Model!
Epoch: 46


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 47


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 48


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 49


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 50


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 51


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 52


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 53


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 54


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 55


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 56


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 57


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 58


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 59


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 60


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 61


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 62


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 63


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 64


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 65


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 66


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 67


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 68


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 69


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 70


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 71


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 72


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 73


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 74


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 75


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 76


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 77


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 78


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 79


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 80


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 81


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 82


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 83


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 84


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 85


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 86


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 87


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 88


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 89


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 90


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 91


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 92


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 93


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 94


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 95


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 96


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 97


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 98


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 99


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 100


  0%|          | 0/154 [00:00<?, ?it/s]

  0%|          | 0/39 [00:00<?, ?it/s]

In [188]:
import numpy as np

In [211]:
x = np.array([[1,2,2]])
y = np.array([
    [1], 
    [2], 
    [2]
])

In [208]:
x*y

array([[1, 2, 2],
       [2, 4, 4],
       [2, 4, 4]])

In [209]:
x.shape

(1, 3)

In [210]:
y.shape

(3, 1)