In [1]:
from pathlib import Path
from pynvml import *

# `os` and `sys` are imported explicitly here. Before, they were used ahead of
# any explicit import and were presumably only in scope because
# `from pynvml import *` leaks them into the namespace.
import os
import sys

# The project code is imported below as the top-level `src` package, so the
# repository root has to be both the working directory and on `sys.path`.
# Only step up to the parent when the current directory does not already
# contain `src`; this keeps the cell idempotent (previously every re-run
# climbed one more directory and appended another path entry).
curdir = Path(os.getcwd())
if not (curdir / "src").is_dir():
    os.chdir(str(curdir.parent.absolute()))
curdir = Path(os.getcwd())
if str(curdir) not in sys.path:
    sys.path.append(str(curdir))

from src.utils.data import (
    load_model,
    seed_everything,
    log_gpu_memory_usage
)
from src.utils.main_utils import get_or_generate_vocabularies,  get_or_generate_label_embeddings, get_or_generate_sequence_embeddings, validate_arguments
from src.data.datasets import ProteinDataset, calculate_pos_weight, create_multiple_loaders, calculate_label_weights
from src.models.ProTCLTrainer import ProTCLTrainer
from src.models.ProTCL import ProTCL
from src.models.protein_encoders import ProteInfer
from src.utils.evaluation import EvalMetrics
from src.utils.models import count_parameters_by_layer, sigmoid_bias_from_prob,load_checkpoint
from src.utils.configs import get_setup
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import torch.distributed as dist
import torch
import wandb
import os
import argparse
import json
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from src.data.collators import collate_variable_sequence_length
import mlflow
import loralib as lora
import random


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def get_batch_weights_v2(label_weights, target):
    """
    Build a per-sample weight matrix from per-class weights.

    The weight of a sample is the sum of the class weights of every label that
    is active in its target row. That single value is then shared across all
    class columns of the row.

    Args:
        label_weights: torch.tensor of size [no_of_classes], one weight per label.
        target: torch.tensor of size [batch, no_of_classes].

    Returns:
        torch.tensor of size [batch, no_of_classes] where each row is filled
        with that sample's weight. The result is an expanded (broadcast) view.
    """
    per_class = label_weights.float()

    # [no_of_classes] * [batch, no_of_classes] broadcasts over the batch axis;
    # summing over the class axis leaves one weight per sample: [batch, 1].
    per_sample = (per_class * target).sum(dim=1, keepdim=True)

    # expand_as spreads the [batch, 1] column over the class axis without
    # materialising repeated copies.
    return per_sample.expand_as(target)


class CBLoss(torch.nn.Module):
    """
    Class-balanced sample weighting based on the "effective number" of samples.

    For a class with ``n`` samples the raw weight is
    ``(1 - beta) / (1 - beta ** n)``. The raw weights are normalised so that
    they sum to the number of classes and are then turned into one weight per
    sample by ``get_batch_weights_v2``.

    Note: despite the name, ``forward`` returns the weight matrix, not a loss
    value, and its ``input`` argument is not used.

    Args:
        label_weights: tensor of size [no_of_classes]; it is used as the
            per-class sample count ``n`` in the formula above.
        beta: float hyper-parameter of the class-balanced weighting.
    """

    def __init__(self, label_weights, beta=0.99):
        super().__init__()

        self.label_weights = label_weights
        self.beta = beta

    def forward(self, input, target):
        """
        Compute the class-balanced weights for a batch.

        Args:
            input: unused; kept so the module can be called like a loss.
            target: tensor of size [batch, no_of_classes].

        Returns:
            Tensor of size [batch, no_of_classes]; every entry of a row holds
            the weight of that sample.
        """
        no_of_classes = len(self.label_weights)
        effective_num = 1.0 - torch.pow(self.beta, self.label_weights)

        # Classes with zero samples give effective_num == 0. Replace those with
        # inf so their weight becomes 0 instead of dividing by zero.
        # masked_fill keeps the dtype and device of effective_num, whereas the
        # previous torch.where(..., torch.tensor(float('inf')), ...) mixed in a
        # default-dtype CPU tensor.
        effective_num = effective_num.masked_fill(
            effective_num == 0, float('inf'))

        weights = (1.0 - self.beta) / effective_num
        # Normalise so the class weights sum to no_of_classes.
        weights = weights / torch.sum(weights) * no_of_classes

        return get_batch_weights_v2(weights, target)

In [55]:
import numpy as np
import torch.nn.functional as F

def CB_loss(labels_one_hot, samples_per_cls, no_of_classes,  beta=0.99):
    """Compute class-balanced sample weights (not the loss itself).

    Per-class weight: (1 - beta) / (1 - beta^n), where n is the number of
    samples of the class. The class weights are normalised to sum to
    `no_of_classes`; the weight of a sample is the sum of the class weights of
    its active labels.

    Args:
      labels_one_hot: A tensor of size [batch, no_of_classes] with the labels
        of each sample.
      samples_per_cls: A python list, numpy array or tensor of size
        [no_of_classes] with the number of samples of each class.
      no_of_classes: total number of classes. int
      beta: float. Hyperparameter for Class balanced loss.

    Returns:
      A float tensor of size [batch, no_of_classes]; every entry of a row holds
      the weight of that sample.
    """
    # Tensors are converted explicitly so counts living on another device or
    # attached to a graph do not break the numpy computation below.
    if torch.is_tensor(samples_per_cls):
        samples_per_cls = samples_per_cls.detach().cpu().numpy()
    counts = np.asarray(samples_per_cls, dtype=np.float64)

    effective_num = 1.0 - np.power(beta, counts)
    # A class with zero samples has effective_num == 0; map it to inf so its
    # weight is 0 rather than inf/NaN (same convention as CBLoss).
    effective_num = np.where(effective_num == 0, np.inf, effective_num)

    weights = (1.0 - beta) / effective_num
    weights = weights / np.sum(weights) * no_of_classes

    weights = torch.tensor(weights).float()
    if torch.is_tensor(labels_one_hot):
        weights = weights.to(labels_one_hot.device)

    # Broadcasting [1, C] * [batch, C] replaces the explicit row-wise repeat.
    per_sample = (weights.unsqueeze(0) * labels_one_hot).sum(1, keepdim=True)
    # repeat (not expand) so the caller still receives an independent tensor.
    return per_sample.repeat(1, no_of_classes)

In [50]:
# Synthetic multi-label batch used to compare the two class-balanced weight
# implementations: 10 samples x 100 classes, each label active with
# probability 0.6. A dedicated, seeded generator makes the recorded numbers
# reproducible on a fresh kernel without touching the global RNG state.
_toy_rng = torch.Generator().manual_seed(0)
labels = (torch.rand(size=(10,100), generator=_toy_rng)>0.4)*1.0
# Random scores in [0, 10) with the same shape as `labels`.
preds = torch.rand(size=labels.shape, generator=_toy_rng)*10

In [52]:
# Number of positive samples per class: sum over the batch axis.
samples_per_cls = torch.sum(labels, dim=0)

In [53]:
samples_per_cls

tensor([7., 6., 6., 6., 4., 7., 6., 6., 6., 7., 6., 6., 6., 2., 7., 6., 4., 5.,
        7., 5., 7., 7., 6., 4., 7., 5., 5., 7., 5., 8., 6., 6., 6., 5., 5., 6.,
        5., 7., 3., 5., 4., 7., 5., 4., 5., 8., 7., 1., 7., 6., 6., 6., 5., 5.,
        6., 5., 6., 9., 5., 6., 9., 6., 4., 6., 7., 6., 6., 7., 4., 5., 5., 5.,
        8., 6., 7., 6., 6., 7., 4., 8., 5., 4., 7., 2., 6., 5., 6., 6., 6., 4.,
        5., 7., 7., 9., 6., 7., 8., 8., 3., 6.])

In [67]:
w_original = CB_loss(labels, samples_per_cls, len(samples_per_cls),  beta=0.9)

In [68]:
w_original.mean(),w_original.sum()

(tensor(53.7565), tensor(53756.4570))

In [69]:
cb=CBLoss(samples_per_cls,beta=0.9)
w_mine=cb(None,labels)

In [70]:
w_mine.mean(),w_mine.sum()

(tensor(53.7565), tensor(53756.4609))

In [None]:
w_mine = CB_loss(labels, samples_per_cls, len(samples_per_cls),  beta=0.99)

In [3]:
# Toy inputs for the contrastive-loss walkthrough in the next cell.
bsz = 3
# Random integer features in [0, 10) of shape [bsz, 2, 1]
# (presumably [batch, views, feature_dim] - confirm against the loss code).
features = torch.randint(low=0, high=10, size=(bsz, 2, 1))
# One class id per sample; samples 0 and 2 share the same class.
labels = torch.tensor([1., 2., 1.])

In [12]:
# Step-by-step walkthrough of a contrastive loss with label-derived positives
# (the structure follows the supervised contrastive / SupCon formulation).
# Reads `features` ([batch, views, ...]) and `labels` from an earlier cell and
# leaves every intermediate tensor in the namespace for inspection.
temperature=0.07
contrast_mode='all'
base_temperature=0.07

# Work on whichever device the features already live on.
device = (torch.device('cuda')
            if features.is_cuda
            else torch.device('cpu'))

# Flatten everything after the first two axes into a single feature axis.
features = features.view(features.shape[0], features.shape[1], -1)

batch_size = features.shape[0]
# Column vector [batch, 1] so that comparing with its transpose broadcasts to
# a [batch, batch] matrix. NOTE: this rebinds the global `labels`.
labels = labels.contiguous().view(-1, 1)
if labels.shape[0] != batch_size:
    raise ValueError('Num of labels does not match num of features')
# mask[i, j] == 1 when samples i and j carry the same label (incl. i == j).
mask = torch.eq(labels, labels.T).float().to(device)


# Number of entries along axis 1 of `features`.
contrast_count = features.shape[1]
# Stack the slices along axis 1 on top of each other:
# [batch, views, dim] -> [views * batch, dim].
contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
if contrast_mode == 'one':
    # Only the first slice of each sample acts as anchor.
    anchor_feature = features[:, 0]
    anchor_count = 1
elif contrast_mode == 'all':
    # Every slice of every sample acts as anchor.
    anchor_feature = contrast_feature
    anchor_count = contrast_count
else:
    raise ValueError('Unknown mode: {}'.format(contrast_mode))

# compute logits: pairwise dot products scaled by 1 / temperature
anchor_dot_contrast = torch.div(
    torch.matmul(anchor_feature, contrast_feature.T),
    temperature)
# for numerical stability: subtract the row-wise max before exponentiating
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()

# tile mask so it lines up with the stacked anchor / contrast rows
mask = mask.repeat(anchor_count, contrast_count)
# mask-out self-contrast cases: ones everywhere except a zero on the diagonal
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
    0
)
# Positives are same-label pairs excluding the anchor itself.
mask = mask * logits_mask

# compute log_prob: log-softmax over all non-self entries of each row
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

# compute mean of log-likelihood over positive
# NOTE(review): a row of `mask` without any positive would make this 0 / 0.
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

# loss: one value per anchor, rescaled by temperature / base_temperature
loss = - (temperature / base_temperature) * mean_log_prob_pos

# Reduce to a scalar by averaging over all anchors.
loss = loss.mean()


In [13]:
logits_mask

tensor([[0., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1.],
        [1., 1., 0., 1., 1., 1.],
        [1., 1., 1., 0., 1., 1.],
        [1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 0.]])

In [10]:
(mask * log_prob).sum(0)

tensor([ -573.2203,    -1.7918,  -716.0775,  -716.0775,  -916.0775, -1144.6489])

In [159]:
loss.view(anchor_count, batch_size).mean(),loss.mean()


(tensor(126.7857), tensor(126.7857))

In [152]:
loss.view(anchor_count, batch_size).shape

torch.Size([2, 3])

In [161]:
del anchor_count

In [162]:
# Variant of the contrastive computation above that weights log-probabilities
# with `labels_multihot` instead of a pairwise same-label mask.
# NOTE(review): the recorded output of this cell is a NameError because
# `labels_multihot` is not defined in any cell visible here; it must be
# created before this cell can run on a fresh kernel.
temperature=0.07
base_temperature=0.07

# compute logits
# NOTE(review): this rebinds `anchor_dot_contrast` to itself divided by
# `temperature`, so every re-run of the cell rescales it again; the earlier
# cell already divides the dot products by `temperature` once - confirm which
# input is intended here.
anchor_dot_contrast = torch.div(anchor_dot_contrast,temperature)
# for numerical stability
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()

# compute log_prob (unlike the earlier cell, no self-contrast mask is applied)
exp_logits = torch.exp(logits) 
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

# compute mean of log-likelihood over positive
# NOTE(review): a row of `labels_multihot` summing to zero would give 0 / 0.
mean_log_prob_pos = (labels_multihot * log_prob).sum(1) / labels_multihot.sum(1)

# loss
loss = - (temperature / base_temperature) * mean_log_prob_pos
loss = loss.mean()


NameError: name 'labels_multihot' is not defined

In [142]:
loss

tensor(88.9891)

In [140]:
torch.logsumexp(logits,dim=1,keepdim=True)

tensor([[1.7918],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000]])

In [128]:
log_prob

tensor([[  -1.7918,   -1.7918,   -1.7918,   -1.7918,   -1.7918,   -1.7918],
        [-214.2857,  -85.7143,  -42.8571, -128.5714, -128.5714,    0.0000],
        [-285.7143, -114.2857,  -57.1429, -171.4286, -171.4286,    0.0000],
        [-142.8571,  -57.1429,  -28.5714,  -85.7143,  -85.7143,    0.0000],
        [-142.8571,  -57.1429,  -28.5714,  -85.7143,  -85.7143,    0.0000],
        [-357.1429, -142.8571,  -71.4286, -214.2857, -214.2857,    0.0000]])

In [123]:
anchor_dot_contrast

tensor([[  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000],
        [  0.0000, 128.5714, 171.4286,  85.7143,  85.7143, 214.2857],
        [  0.0000, 171.4286, 228.5714, 114.2857, 114.2857, 285.7143],
        [  0.0000,  85.7143, 114.2857,  57.1429,  57.1429, 142.8571],
        [  0.0000,  85.7143, 114.2857,  57.1429,  57.1429, 142.8571],
        [  0.0000, 214.2857, 285.7143, 142.8571, 142.8571, 357.1429]])

In [124]:
features

tensor([[[0],
         [2]],

        [[3],
         [2]],

        [[4],
         [5]]])

In [125]:
contrast_feature

tensor([[0],
        [3],
        [4],
        [2],
        [2],
        [5]])

In [58]:
anchor_dot_contrast

tensor([[  71.4286,  157.1429,  342.8571,  128.5714],
        [ 157.1429,  371.4286,  685.7143,  214.2857],
        [ 342.8571,  685.7143, 1828.5714,  800.0000],
        [ 128.5714,  214.2857,  800.0000,  414.2857]])

In [60]:
features.shape

torch.Size([2, 2, 2])

In [40]:
mask

tensor([[0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.]])

In [20]:
# NOTE(review): `MLP` is not imported in any cell visible here (the keyword
# arguments look like torchvision.ops.MLP - confirm and add the import to the
# top import cell). The name `a` is also reused further down as a batch dict.
a=MLP(1000,[10,10],bias=False,norm_layer=torch.nn.BatchNorm1d,activation_layer=torch.nn.Identity)

In [26]:
e=torch.nn.Embedding(100,3)

e(torch.arange(10))

torch.Size([100, 3])

Parameter containing:
tensor([[ 4.1412e-01,  2.0729e+00,  3.3877e-01],
        [ 7.4035e-01, -1.0129e+00,  1.0684e+00],
        [-7.6563e-01, -1.6943e-01, -7.2646e-01],
        [ 3.0629e-01, -5.6680e-01,  6.6975e-01],
        [-4.6175e-03, -4.8004e-01,  1.1684e+00],
        [-1.5192e-01,  4.9175e-01, -1.0614e+00],
        [-1.7002e-01,  1.8095e-01,  4.0745e-01],
        [-1.0855e+00,  1.6527e+00,  1.1391e+00],
        [ 7.1451e-01,  2.7505e+00,  5.0293e-01],
        [-7.2259e-01, -6.9784e-01,  6.9926e-01],
        [-8.0408e-01, -1.9509e+00,  1.9277e+00],
        [-1.6251e-01, -1.7948e-01,  6.0711e-01],
        [ 1.4911e-01,  3.4602e-01, -1.4749e+00],
        [-1.1428e-01,  4.2197e-01, -1.1637e+00],
        [-6.9847e-01,  1.1591e+00,  1.7230e-01],
        [-4.1416e-01, -1.2346e+00, -1.1913e+00],
        [-4.8150e-01,  1.1232e+00,  2.1309e+00],
        [ 4.2791e-01,  2.0048e+00,  1.1230e+00],
        [ 2.1412e-01,  9.4107e-01, -3.6250e-01],
        [ 3.0476e-01, -2.9366e-02,  7.1577e-01]

tensor([[ 0.4141,  2.0729,  0.3388],
        [ 0.7403, -1.0129,  1.0684],
        [-0.7656, -0.1694, -0.7265],
        [ 0.3063, -0.5668,  0.6697],
        [-0.0046, -0.4800,  1.1684],
        [-0.1519,  0.4917, -1.0614],
        [-0.1700,  0.1810,  0.4074],
        [-1.0855,  1.6527,  1.1391],
        [ 0.7145,  2.7505,  0.5029],
        [-0.7226, -0.6978,  0.6993]], grad_fn=<EmbeddingBackward0>)

In [4]:

### SETUP ###
# Notebook stand-in for the command-line entry point: the values that would
# normally come from CLI arguments are hard-coded below.
torch.cuda.empty_cache()

# Check if master process
is_master = True
# Kept under its own name so `config` only ever refers to the setup dict.
config_path = "configs/base_config.yaml"
name = "Test"
train_path_name = "TRAIN_DATA_PATH"
validation_path_name = "VAL_DATA_PATH"
test_paths_names = ["TEST_DATA_PATH"]
amlt = False
gpu=0
rank=0

# Unpack and process the config file
config = get_setup(
    config_path=config_path,
    run_name=name,
    overrides=[],
    train_path_name=train_path_name,
    val_path_name=validation_path_name,
    test_paths_names=test_paths_names,
    amlt=amlt,
    is_master=is_master,
)
params, paths, timestamp, logger = config["params"], config[
    "paths"], config["timestamp"], config["logger"]

# Set the GPU device, if using. `set_device` is only called when CUDA is
# actually available; calling it unconditionally fails on a CPU-only machine
# even though the code is otherwise prepared to fall back to the CPU.
if torch.cuda.is_available():
    torch.cuda.set_device(gpu)
    device = torch.device('cuda:' + str(gpu))
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")


# Log the params (pretty-printed JSON dump of the whole params dict)
logger.info(json.dumps(params, indent=4))

# Initialize label tokenizer from the checkpoint named in the params
label_tokenizer = AutoTokenizer.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)

# Initialize label encoder from the same checkpoint
label_encoder = AutoModel.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)
# Fail fast: the flag exists in the params but is not supported here.
if params["GRADIENT_CHECKPOINTING"]:
    raise NotImplementedError(
        "Gradient checkpointing is not yet implemented.")

if params["LORA"]:
    # Swap the four attention projections of every layer for LoRA linears.
    # NOTE(review): the 1024 x 1024 size is hard-coded and the new modules are
    # constructed fresh; no copy of the existing projection weights is visible
    # here - confirm the pretrained weights are restored elsewhere.
    for layer in label_encoder.layers:
        in_features, out_features = 1024, 1024
        layer.self_attn.q_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
        layer.self_attn.v_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
        layer.self_attn.k_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
        layer.self_attn.out_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
    # Mark only the LoRA parameters as trainable
    lora.mark_only_lora_as_trainable(label_encoder)

# Move the encoder to the device selected earlier in this cell.
label_encoder = label_encoder.to(device)

# Load or generate the vocabularies
vocabularies = get_or_generate_vocabularies(
    paths["FULL_DATA_PATH"], paths["VOCABULARIES_DIR"], logger)

# Create datasets for every path listed in the setup config; the subset
# fractions for train / validation / test come from the params.
datasets = ProteinDataset.create_multiple_datasets(
    paths_list=config['dataset_paths_list'],
    config=config,
    logger=logger,
    label_tokenizer=label_tokenizer,
    label_encoder=label_encoder,
    vocabularies=vocabularies,
    subset_fractions={
        "train": params["TRAIN_SUBSET_FRACTION"],
        "validation": params["VALIDATION_SUBSET_FRACTION"],
        "test": params["TEST_SUBSET_FRACTION"],
    },
    deduplicate=params["DEDUPLICATE"],
)

# Seed everything so we don't go crazy
# NOTE(review): seeding happens after the datasets are created, so any
# randomness inside dataset creation would not be covered by this seed -
# confirm that is intended.
seed_everything(params["SEED"], device)

# Initialize new run
logger.info(
    f"################## {timestamp} RUNNING main.py ##################")

# Define label sample sizes for train, validation, and test loaders
label_sample_sizes = {
    "train": params["TRAIN_LABEL_SAMPLE_SIZE"],
    "validation": params["VALIDATION_LABEL_SAMPLE_SIZE"],
    "test": None  # No sampling for the test set
}

# Define data loaders (single-process setup: world_size is fixed to 1)
loaders = create_multiple_loaders(
    datasets,
    params,
    label_sample_sizes=label_sample_sizes,
    shuffle_labels=params['SHUFFLE_LABELS'],
    in_batch_sampling=params['IN_BATCH_SAMPLING'],
    num_workers=params["NUM_WORKERS"],
    world_size=1,
    rank=rank,
)

# Move the label encoder to the CPU when it is not going to be trained.
train_label_encoder = params["TRAIN_LABEL_ENCODER"]
if not train_label_encoder:
    label_encoder = label_encoder.cpu()

# Build the ProteInfer sequence encoder from its pretrained weights. All
# architecture hyper-parameters come from the "embed_sequences_params" section
# of the setup config.
proteinfer_params = config["embed_sequences_params"]
sequence_encoder = ProteInfer.from_pretrained(
    weights_path=paths["PROTEINFER_WEIGHTS_PATH"],
    num_labels=proteinfer_params["PROTEINFER_NUM_LABELS"],
    input_channels=proteinfer_params["INPUT_CHANNELS"],
    output_channels=proteinfer_params["OUTPUT_CHANNELS"],
    kernel_size=proteinfer_params["KERNEL_SIZE"],
    activation=torch.nn.ReLU,
    dilation_base=proteinfer_params["DILATION_BASE"],
    num_resnet_blocks=proteinfer_params["NUM_RESNET_BLOCKS"],
    bottleneck_factor=proteinfer_params["BOTTLENECK_FACTOR"],
)

# When the sequence encoder is not trained, its embeddings are produced (or
# loaded) once up front and the encoder itself is parked on the CPU.
train_sequence_encoder = params["TRAIN_SEQUENCE_ENCODER"]
if train_sequence_encoder:
    sequence_embedding_df = None
else:
    sequence_embedding_df = get_or_generate_sequence_embeddings(
        paths,
        device,
        sequence_encoder,
        datasets,
        params,
        logger,
    )
    sequence_encoder = sequence_encoder.to('cpu')

# Hand the precomputed embeddings to every dataset object of every split.
for dataset in datasets.values():
    for subset in dataset:
        if not train_sequence_encoder:
            subset.set_sequence_embedding_df(sequence_embedding_df)


# Display the first training loader as the cell output.
loaders["train"][0]



2023-11-25 03:45:01 PST INFO Logging to ./outputs/logs/2023-11-25_03-45-01_Test.log and console...
2023-11-25 03:45:01 PST INFO Using device: cuda:0
2023-11-25 03:45:01 PST INFO {
    "TRAIN_BATCH_SIZE": 64,
    "VALIDATION_BATCH_SIZE": 64,
    "TEST_BATCH_SIZE": 64,
    "IN_BATCH_SAMPLING": false,
    "TRAIN_LABEL_SAMPLE_SIZE": null,
    "VALIDATION_LABEL_SAMPLE_SIZE": null,
    "LABEL_BATCH_SIZE_LIMIT_NO_GRAD": 1500,
    "SEQUENCE_BATCH_SIZE_LIMIT_NO_GRAD": 128,
    "LEARNING_RATE": 0.001,
    "OPTIMIZER": "Adam",
    "PROTEIN_EMBEDDING_DIM": 1100,
    "LABEL_EMBEDDING_DIM": 1024,
    "LATENT_EMBEDDING_DIM": 1024,
    "OUTPUT_MLP_HIDDEN_DIM_SCALE_FACTOR": 1,
    "OUTPUT_MLP_NUM_LAYERS": 2,
    "OUTPUT_NEURON_PROBABILITY_BIAS": null,
    "OUTPUT_MLP_BATCHNORM": true,
    "OPTIMIZATION_METRIC_NAME": "map_micro",
    "DECISION_TH_METRIC_NAME": "f1_micro",
    "NUM_EPOCHS": 15,
    "GRADIENT_ACCUMULATION_STEPS": 1,
    "GRADIENT_CHECKPOINTING": false,
    "LORA": false,
    "LORA_RANK": 

<torch.utils.data.dataloader.DataLoader at 0x7f639c7977c0>

In [67]:
from torch.utils.data import DataLoader, TensorDataset
from torch.cuda.amp import autocast
def tokenize_labels(text, tokenizer, max_length=1024):
    """
    Run a tokenizer over a batch of label strings.

    Args:
        text (list): The list of text strings.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer; it is
            called directly with `text` and the keyword options below.
        max_length (int): Truncation limit handed to the tokenizer.
            Defaults to 1024.

    Returns:
        The tokenizer output: labels padded to the longest entry of the batch,
        truncated to `max_length`, returned as PyTorch tensors
        ('input_ids' and 'attention_mask').
    """
    tokenizer_options = dict(
        padding='longest',
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return tokenizer(text, **tokenizer_options)


def compute_mean_hidden_states(last_hidden_states, attention_mask):
    """
    Compute the mean of the last hidden state for only the relevant tokens.

    Args:
        last_hidden_states: tensor indexed as [sequence, token, hidden].
        attention_mask: tensor indexed as [sequence, token]; non-zero entries
            mark the tokens that take part in the mean.

    Returns:
        Tensor indexed as [sequence, hidden] with the masked mean of every
        sequence. A sequence whose mask is entirely zero yields a zero vector
        instead of NaN.
    """
    # Compute the number of relevant tokens for each sequence
    num_relevant_tokens = attention_mask.sum(dim=1, keepdim=True)
    # Guard against 0 / 0: a fully masked sequence has a zero sum below, so
    # dividing by 1 leaves it at zero. Only exact zeros are replaced.
    num_relevant_tokens = num_relevant_tokens.masked_fill(
        num_relevant_tokens == 0, 1)
    # Mask the last_hidden_state tensor and compute the sum over tokens
    sum_hidden_states = (last_hidden_states *
                         attention_mask.unsqueeze(-1)).sum(dim=1)
    # Compute the mean of the last hidden state
    return sum_hidden_states / num_relevant_tokens


def get_label_embeddings(tokenized_labels, model, batch_size_limit=1000):
    """
    Get embeddings for a list of tokenized labels.
    Assumes that tokenized_labels and model are on the same device, ideally GPU.

    Args:
        tokenized_labels: mapping with 'input_ids' and 'attention_mask'
            tensors whose first axis indexes the labels.
        model: callable taking `input_ids` / `attention_mask` keywords and
            returning an object with a `last_hidden_state` attribute.
        batch_size_limit (int): maximum number of labels sent through the
            model in one forward pass.

    Returns:
        Tensor with one mean-pooled embedding per label, in input order.

    Raises:
        ValueError: if the labels have to be chunked but `batch_size_limit`
            is not a positive number.
    """
    input_ids = tokenized_labels["input_ids"]
    attention_mask = tokenized_labels["attention_mask"]

    def _embed(ids, mask):
        # One forward pass under autocast followed by masked mean pooling.
        # The hidden states go out of scope as soon as this returns.
        with autocast():
            last_hidden_states = model(
                input_ids=ids, attention_mask=mask).last_hidden_state
        return compute_mean_hidden_states(last_hidden_states, mask)

    total_labels = input_ids.shape[0]

    # Everything fits into a single forward pass.
    if total_labels <= batch_size_limit:
        return _embed(input_ids, attention_mask)

    if batch_size_limit < 1:
        raise ValueError(
            f"batch_size_limit must be positive, got {batch_size_limit}")

    # Otherwise walk over contiguous slices. Slicing yields the same batches,
    # in the same order, as an unshuffled DataLoader over a TensorDataset,
    # without indexing and re-stacking every row.
    all_label_embeddings = [
        _embed(input_ids[start:start + batch_size_limit],
               attention_mask[start:start + batch_size_limit])
        for start in range(0, total_labels, batch_size_limit)
    ]
    # Concatenate all the label embeddings
    return torch.cat(all_label_embeddings, dim=0)


def generate_label_embeddings_from_text(label_annotations, label_tokenizer, label_encoder, batch_size_limit=1000):
    """
    Tokenize the labels and generate label embeddings.

    Args:
        label_annotations: label text strings, passed to `tokenize_labels`.
        label_tokenizer: tokenizer used by `tokenize_labels`.
        label_encoder: model producing the embeddings; its `.device` decides
            where the tokenized inputs are moved.
        batch_size_limit (int): forwarded to `get_label_embeddings`.

    Returns:
        The result of `get_label_embeddings` for the tokenized labels.
    """
    tokenized_labels = tokenize_labels(label_annotations, label_tokenizer)

    # Put both model inputs on the encoder's device (not necessarily a GPU).
    for key in ("input_ids", "attention_mask"):
        tokenized_labels[key] = tokenized_labels[key].to(label_encoder.device)

    return get_label_embeddings(
        tokenized_labels, label_encoder, batch_size_limit=batch_size_limit)

# Initialize label tokenizer
# NOTE(review): this repeats the tokenizer / encoder initialisation from the
# setup cell and rebinds both names. The encoder created here is not moved to
# `device` and gets no LoRA layers, so any later cell sees this fresh copy
# rather than the one configured during setup - confirm that is intended.
label_tokenizer = AutoTokenizer.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)

# Initialize label encoder
label_encoder = AutoModel.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)

In [69]:
from src.utils.data import read_pickle
annot=read_pickle('data/annotations/go_annotations_2019_07_01.pkl')

In [96]:
i=32000
annot.index[i],annot.iloc[i]['label']

('GO:0070327',
 'The directed movement of thyroid hormone into, out of or within a cell, or between cells, by means of some agent such as a transporter or pore.')

In [97]:
datasets["train"][0].label2int[annot.index[i]]

22605

In [98]:
generate_label_embeddings_from_text([annot.iloc[i]['label']],label_tokenizer=label_tokenizer,label_encoder=label_encoder)

tensor([[-0.8438,  0.1259,  0.2046,  ...,  0.4670, -0.1736,  0.8953]],
       grad_fn=<DivBackward0>)

In [5]:
loader_iter = iter(loaders["train"][0])
data_iter = iter(datasets["train"][0])

In [22]:
data_batch = next(data_iter)
loader_batch=next(loader_iter)

In [99]:
loader_batch['label_embeddings'][22605]

tensor([-0.8445,  0.1256,  0.2044,  ...,  0.4676, -0.1743,  0.8949])

In [8]:
datasets["train"][0].label2int['GO:0035639']

13652

In [17]:
sorted([datasets["train"][0].label2int[i] for i in datasets["train"][0].data[0][1][1:]])==torch.where(data_batch['label_multihots']==1)[0].tolist()

True

In [9]:
datasets["train"][0].data[0][1][1:]

['GO:0035639',
 'GO:0032553',
 'GO:0005524',
 'GO:0017076',
 'GO:0005737',
 'GO:1901265',
 'GO:1901363',
 'GO:0043168',
 'GO:0044424',
 'GO:0030554',
 'GO:0005488',
 'GO:0043167',
 'GO:0042026',
 'GO:0032559',
 'GO:0005515',
 'GO:0051082',
 'GO:0032555',
 'GO:0005575',
 'GO:0008144',
 'GO:0009987',
 'GO:0097159',
 'GO:0006457',
 'GO:0000166',
 'GO:0008150',
 'GO:0036094',
 'GO:0003674',
 'GO:0044464',
 'GO:0097367']

In [11]:
data_batch

{'sequence_onehots': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 'sequence_id': 'P60545',
 'sequence_embedding': tensor([-0.0553, -0.3441, -0.2825,  ...,  0.4497, -0.0895, -0.1504]),
 'sequence_length': tensor(538),
 'label_multihots': tensor([0, 0, 0,  ..., 0, 0, 0]),
 'tokenized_labels': {'input_ids': tensor([[   2,   18,  569,  ...,    1,    1,    1],
         [   2,   18, 1900,  ...,    1,    1,    1],
         [   2,   18,  371,  ...,    1,    1,    1],
         ...,
         [   2,   18,  919,  ...,    1,    1,    1],
         [   2,   18,  919,  ...,    1,    1,    1],
         [   2,   18,  919,  ...,    1,    1,    1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         

In [15]:
embeddings = torch.load('data/embeddings/frozen_BioGPT_label_embeddings.pkl')

In [19]:
embeddings

tensor([[-1.3426e+00,  1.9259e-01,  4.5337e-01,  ..., -6.7419e-02,
          1.7350e-01,  8.9762e-01],
        [-5.8517e-01,  2.5346e-03,  9.9431e-01,  ...,  7.3632e-01,
          1.3791e+00,  1.2030e+00],
        [-4.8449e-01, -2.6923e-01,  1.7874e-01,  ..., -3.5807e-01,
          8.9524e-01,  8.7176e-01],
        ...,
        [-2.0514e-01, -1.0103e+00,  1.2279e+00,  ...,  3.6141e-01,
         -3.4265e-01,  5.1903e-01],
        [-8.9557e-01, -5.3069e-01,  9.3757e-01,  ..., -1.8156e-01,
         -2.4020e-02, -9.7481e-04],
        [-7.9217e-01, -9.6587e-01,  1.2481e+00,  ...,  4.1990e-01,
         -3.4655e-01,  1.0383e-02]])

In [18]:
a['label_embeddings']

tensor([[-1.3426e+00,  1.9259e-01,  4.5337e-01,  ..., -6.7419e-02,
          1.7350e-01,  8.9762e-01],
        [-5.8517e-01,  2.5346e-03,  9.9431e-01,  ...,  7.3632e-01,
          1.3791e+00,  1.2030e+00],
        [-4.8449e-01, -2.6923e-01,  1.7874e-01,  ..., -3.5807e-01,
          8.9524e-01,  8.7176e-01],
        ...,
        [-2.0514e-01, -1.0103e+00,  1.2279e+00,  ...,  3.6141e-01,
         -3.4265e-01,  5.1903e-01],
        [-8.9557e-01, -5.3069e-01,  9.3757e-01,  ..., -1.8156e-01,
         -2.4020e-02, -9.7481e-04],
        [-7.9217e-01, -9.6587e-01,  1.2481e+00,  ...,  4.1990e-01,
         -3.4655e-01,  1.0383e-02]])

In [20]:
P_e = a['sequence_embeddings']
L_e = a['label_embeddings']

In [50]:
from tqdm import tqdm
# Naive reference: concatenate every row of P_e with every row of L_e, one
# pair at a time. Rows of P_e vary slowest, so the list is ordered
# pair-by-pair as (P_e[0], L_e[0]), (P_e[0], L_e[1]), ...
joint = []
for sequence_embedding in tqdm(P_e):
    joint.extend(
        torch.concat([sequence_embedding, label_embedding])
        for label_embedding in L_e
    )

  0%|          | 0/64 [00:00<?, ?it/s]

100%|██████████| 64/64 [00:15<00:00,  4.23it/s]


In [25]:
torch.repe

AttributeError: module 'torch' has no attribute 'repeat'

In [55]:
from tqdm import tqdm
# Toy version of the pairwise concatenation with recognisable values: the
# "sequence" part is the value i repeated 5 times, the "label" part the value
# j repeated 7 times. `joint` collects the pairs as a flat list, while
# `joint_matrix` holds one stacked [4, 12] tensor per i.
joint, joint_matrix = [], []
for seq_value in tqdm(range(10)):
    pairs_for_seq = [
        torch.concat([torch.tensor([seq_value] * 5),
                      torch.tensor([label_value] * 7)])
        for label_value in range(11, 15)
    ]
    joint.extend(pairs_for_seq)
    joint_matrix.append(torch.stack(pairs_for_seq))

# Stacking the flat list is left disabled, as in the original cell:
#joint = torch.stack(joint)

100%|██████████| 10/10 [00:00<00:00, 5448.56it/s]


In [57]:
torch.stack(joint_matrix).sum(axis=-1)

tensor([[ 77,  84,  91,  98],
        [ 82,  89,  96, 103],
        [ 87,  94, 101, 108],
        [ 92,  99, 106, 113],
        [ 97, 104, 111, 118],
        [102, 109, 116, 123],
        [107, 114, 121, 128],
        [112, 119, 126, 133],
        [117, 124, 131, 138],
        [122, 129, 136, 143]])

In [64]:
torch.stack(joint_matrix)[1][3].sum()

tensor(103)

In [36]:
joint.sum(axis=1).reshape(10,4)

tensor([[ 77,  84,  91,  98],
        [ 82,  89,  96, 103],
        [ 87,  94, 101, 108],
        [ 92,  99, 106, 113],
        [ 97, 104, 111, 118],
        [102, 109, 116, 123],
        [107, 114, 121, 128],
        [112, 119, 126, 133],
        [117, 124, 131, 138],
        [122, 129, 136, 143]])

In [35]:
joint.shape

torch.Size([40, 12])

In [52]:
joint.sum(axis=0).mean()

tensor(-80840.0156)

In [47]:
# Vectorised equivalent of the pairwise concatenation loop.
num_sequences, sequence_embedding_dim = P_e.shape
num_labels, label_embedding_dim = L_e.shape

# Insert a singleton axis and expand: both parts become broadcast views of
# shape [num_sequences, num_labels, dim] without copying any data.
sequence_part = P_e.unsqueeze(1).expand(
    num_sequences, num_labels, sequence_embedding_dim)
label_part = L_e.unsqueeze(0).expand(
    num_sequences, num_labels, label_embedding_dim)

# Concatenate along the feature axis, then flatten the (sequence, label) grid
# into rows: [num_sequences * num_labels, seq_dim + label_dim].
joint_embeddings = torch.cat([sequence_part, label_part], dim=2).reshape(
    -1, sequence_embedding_dim + label_embedding_dim)

In [49]:
joint_embeddings.sum(axis=0).mean()

tensor(-80840.0156)

In [54]:
torch.tensor([1,0,1,0,1,1,1,1,0]).reshape(3, 3)

tensor([[1, 0, 1],
        [0, 1, 1],
        [1, 1, 0]])

In [None]:
# Command-line interface of the training / testing entry point, rebuilt here
# for experimentation. The help texts of the train / validation options used
# to point at flags that do not exist (`--val-path`, `--train-path`); they now
# name the real options.
parser = argparse.ArgumentParser(
    description="Train and/or Test the ProTCL model.")
parser.add_argument("--train-path-name", type=str, default=None,
                    help="Specify the desired train path name to train the model using names from config file. If not provided, model will not be trained. If provided, must also provide --validation-path-name.")

parser.add_argument("--validation-path-name", type=str, default=None,
                    help="Specify the desired val path name to validate the model during training using names from config file. If not provided, model will not be trained. If provided, must also provide --train-path-name.")

parser.add_argument("--full-path-name", type=str, default=None,
                    help="Specify the desired full path name to define the vocabularies. Defaults to the full path name in the config file.")

parser.add_argument("--test-paths-names", nargs="+", type=str, default=None,
                    help="Specify all the desired test paths names to test the model using names from config file to test. If not provided, model will not be tested.")

parser.add_argument("--use-wandb", action="store_true", default=False,
                    help="Use Weights & Biases for logging. Default is False.")

parser.add_argument("--load-model", type=str, default=None,
                    help="(Relative) path to the model to be loaded. If not provided, a new model will be initialized.")

parser.add_argument('--from-checkpoint', action="store_true", default=False,
                    help="Continue training from a previous model checkpoint (including optimizer state and epoch). Default is False.")

parser.add_argument("--name", type=str, default="ProTCL",
                    help="Name of the W&B run. If not provided, a name will be generated.")

parser.add_argument("--config", type=str, default="configs/base_config.yaml",
                    help="(Relative) path to the configuration file.")

parser.add_argument("--amlt", action="store_true", default=False,
                    help="Run job on Amulet. Default is False.")

# No default is given, so this is None when the flag is absent.
parser.add_argument("--override", nargs="*",
                    help="Override config parameters in key-value pairs.")

parser.add_argument("--save-prediction-results", action="store_true", default=False,
                    help="Save predictions and ground truth dataframe for validation and/or test")

parser.add_argument('-n', '--nodes', default=1, type=int,
                    metavar='N', help='Number of nodes (default: 1)')

parser.add_argument('-g', '--gpus', default=1, type=int,
                    help='Number of gpus per node (default: 1)')