In [2]:
from pathlib import Path
from pynvml import *

curdir = Path(os.getcwd())
sys.path.append(str(curdir.parent.absolute()))
os.chdir(str(curdir.parent.absolute()))
curdir = Path(os.getcwd())

from src.utils.data import (
    seed_everything,
    log_gpu_memory_usage
)
from src.utils.main_utils import get_or_generate_vocabularies,  get_or_generate_label_embeddings, get_or_generate_sequence_embeddings, validate_arguments
from src.data.datasets import ProteinDataset, create_multiple_loaders
from src.models.ProTCLTrainer import ProTCLTrainer
from src.models.ProTCL import ProTCL
from src.models.protein_encoders import ProteInfer
from src.utils.evaluation import EvalMetrics
from src.utils.models import count_parameters_by_layer, sigmoid_bias_from_prob,load_checkpoint,    load_model

from src.utils.configs import get_setup
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import torch.distributed as dist
import torch
import wandb
import os
import argparse
import json
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from src.data.collators import collate_variable_sequence_length
import mlflow
import loralib as lora
import random


  from .autonotebook import tqdm as notebook_tqdm


In [159]:
import random
from itertools import product
import numpy as np
from torch.utils.data import BatchSampler
from torch.utils.data import Sampler, WeightedRandomSampler
from typing import Optional
import math
import torch
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import math
from torch.utils.data import Dataset

class GeneralDistributedSampler(DistributedSampler):

    """
    Class to use distributed sampler with any sampler!
    """

    def __init__(self, sampler: Sampler, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 seed: int = 0, drop_last: bool = False):
        
        #Same as normal DistributedSampler with shuffle = False
        super().__init__(dataset = sampler,
                         num_replicas=num_replicas,
                         rank=rank,
                         shuffle=False,
                         seed = seed,
                         drop_last=drop_last)
        
        assert len(sampler)>num_replicas, "Total samples must be > num replicas"
        
    def __iter__(self):
        # deterministically shuffle based on epoch
        torch.manual_seed(self.epoch+self.seed)
        indices = list(self.dataset)
        
        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples
        return iter(indices)
    
class DistributedWeightedSampler(Sampler):
    def __init__(self, weights, world_size=None, rank=None, replacement=True):
        # Get the world size and rank if not provided
        if world_size is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            world_size = dist.get_world_size()
        if rank is None:
            rank = dist.get_rank()

        self.weights = weights
        self.world_size = world_size
        self.rank = rank
        self.epoch = 0
        self.replacement = replacement

        # Determine the number of samples for each GPU, rounding down to ensure it is evenly divisible
        self.num_samples = int(math.floor(len(self.weights) * 1.0 / self.world_size))
        
        # Determine the total number of samples
        self.total_size = self.num_samples * self.world_size

    def __iter__(self):
        # Shuffle based on the epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        
        # Create a weighted sample for the entire dataset
        if self.replacement:
            indices = torch.multinomial(self.weights, self.total_size, replacement=True, generator=g)
        else:
            assert len(self.weights) > self.total_size, "When sampling without replacement, number of samples to draw must be less than the number of elements in the dataset"
            indices = torch.multinomial(self.weights, self.total_size, replacement=False, generator=g)

        # Subsample for the current process
        indices_for_one_gpu = indices[self.rank:self.total_size:self.world_size]
        
        # Shuffle each epoch
        indices_for_one_gpu = indices_for_one_gpu[torch.randperm(len(indices_for_one_gpu), generator=g)].tolist()
            
        return iter(indices_for_one_gpu)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch


class GridBatchSampler(BatchSampler):

    def __init__(self,
                 observation_sampler,
                 observations_batch_size,
                 drop_last_observation_batch,
                 num_labels,
                 labels_batch_size,
                 shuffle_grid = True):
        
        self.observation_sampler = observation_sampler
        self.observations_batch_size = observations_batch_size
        self.drop_last_observation_batch = drop_last_observation_batch

        self.num_labels = num_labels
        self.labels_batch_size = labels_batch_size
        self.shuffle_grid = shuffle_grid
        self.labels_idxs = list(range(num_labels))
        self.calculate_num_batches()
        
    def __iter__(self):
        random.Random(self.epoch).shuffle(self.labels_idxs)
        print('Getting label batches...')
        observation_batches = self.get_observation_batches()
        print('Done...')

        print('Getting observation batches...')
        label_batches = self.get_label_batches()
        print('Done...')

        print('Getting combinations...')
        obs_labels_batch_combinations = list(product(observation_batches,label_batches))

        print('Done...')
        
        if self.shuffle_grid:
            print('Shuffling...')
            random.shuffle(obs_labels_batch_combinations)
        print('Done...')
        for observation_batch,label_batch in obs_labels_batch_combinations:
            yield list(product(observation_batch, [label_batch]))#[observation_batch,label_batch]
    
    def calculate_num_batches(self):
        
        num_label_batches = np.ceil(self.num_labels/self.labels_batch_size)
        num_observation_batches = (np.ceil(len(self.observation_sampler)/self.observations_batch_size)
                                   if not self.drop_last_observation_batch
                                   else len(self.observation_sampler)//self.observations_batch_size)
        print('Done...')

        self.total_num_batches = int(num_label_batches*num_observation_batches)
        print(f"num label batches = {num_label_batches}, num observation batches = {num_observation_batches}")
        print(f"total batches = {self.total_num_batches}")

    def __len__(self):
        return self.total_num_batches
    

    def get_label_batches(self):

        #n_chunks = int(np.ceil(self.num_labels/self.labels_batch_size))
        return [self.labels_idxs[i:i+self.labels_batch_size] for i in range(0,self.num_labels,self.labels_batch_size)]
        

    def get_observation_batches(self):

        batches = []

        if self.drop_last_observation_batch:
            observation_sampler_iter = iter(self.observation_sampler)
            while True:
                try:
                    batch = [next(observation_sampler_iter) for _ in range(self.observations_batch_size)]
                    batches.append(batch)
                except StopIteration:
                    break
        else:
            batch = [0] * self.observations_batch_size
            idx_in_batch = 0
            for idx in self.observation_sampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.observations_batch_size:
                    batches.append(batch)
                    idx_in_batch = 0
                    batch = [0] * self.observations_batch_size
            if idx_in_batch > 0:
                batches.append(batch[:idx_in_batch])
        return batches
    
    def set_epoch(self, epoch):
        self.epoch = epoch
    
def observation_sampler_factory(
    distribute_labels:bool,
    weighted_sampling:bool,
    dataset: Dataset = None,
    world_size: int = 1,
    rank: int = 0,
    sequence_weights: torch.Tensor = None

    ):

    if distribute_labels and not weighted_sampling:
        print("WARNING: No Sampler used for distribute labels")
        sampler = None
    elif not distribute_labels and world_size == 1 and weighted_sampling:
        # If NOT distributing labels, and not training on multiple GPU's, create a non-distributed weighted sampler with replacement
        assert sequence_weights is not None, "Weighted RandomSampler requires weights"

        sampler = WeightedRandomSampler(
            sequence_weights, 
            len(sequence_weights), 
            replacement=True
        )
    elif not distribute_labels and world_size > 1 and weighted_sampling:
        # If distributing sequences across multiple GPUs with a weighted sampler, create custom DistributedWeightedSampler
        sampler = DistributedWeightedSampler(
            sequence_weights,
            world_size=world_size,
            rank=rank,
            replacement=True,
        )
    elif not distribute_labels and not weighted_sampling:
        # If simply distributing sequences across GPU's without weighted sampling, use a distributed sampler

        assert dataset is not None, "DistributeSampler requires dataset"

        sampler = DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True
        )
    else:
        # Raise error
        raise ValueError("Invalid combination of WEIGHTED_SAMPLING, WORLD_SIZE, and DISTRIBUTE_LABELS parameters")
    
    return sampler

In [148]:
n = 12
b = 5

c=int(np.floor(n/b))
l=np.random.randint(0,100,size=(n,))
l_i = list(range(n))

splits=np.array_split(l_i,b)


In [153]:
np.array_split(l,b)

[array([94, 16, 27]),
 array([73, 71, 50]),
 array([ 0, 35]),
 array([32, 75]),
 array([18, 53])]

In [151]:
[l[list(i)] for i in splits]

[array([94, 16, 27]),
 array([73, 71, 50]),
 array([ 0, 35]),
 array([32, 75]),
 array([18, 53])]

In [131]:
sorted([i for r in range(b) for i in l[r:b:c]])

[0, 1, 2, 3, 3, 4, 4]

In [155]:

class DistributedGridBatchSampler(BatchSampler):

    def __init__(self,
                 observation_sampler,
                 observations_batch_size,
                 drop_last_observation_batch,
                 num_labels,
                 labels_batch_size,
                 world_size,
                 rank,
                 shuffle_grid = True):
        
        self.observation_sampler = observation_sampler
        self.observations_batch_size = observations_batch_size
        self.drop_last_observation_batch = drop_last_observation_batch
        self.world_size = world_size
        self.rank = rank

        self.num_labels = num_labels
        self.labels_batch_size = labels_batch_size
        self.shuffle_grid = shuffle_grid
        self.labels_idxs = list(range(num_labels))
        self.calculate_num_batches()
        
    def __iter__(self):
        random.Random(self.epoch).shuffle(self.labels_idxs)
        print('Getting label batches...')
        observation_batches = self.get_observation_batches()
        print('Done...')

        print('Getting observation batches...')
        label_batches = self.get_label_batches()
        print('Done...')

        print('Getting combinations...')
        obs_labels_batch_combinations = list(product(observation_batches,label_batches))

        print('Done...')
        
        if self.shuffle_grid:
            print('Shuffling...')
            random.shuffle(obs_labels_batch_combinations)
        print('Done...')
        for observation_batch,label_batch in obs_labels_batch_combinations:
            yield list(product(observation_batch, [label_batch]))#[observation_batch,label_batch]
    
    def calculate_num_batches(self):
        
        num_label_batches = np.ceil(self.num_labels/self.labels_batch_size)
        num_observation_batches = (np.ceil(len(self.observation_sampler)/self.observations_batch_size)
                                   if not self.drop_last_observation_batch
                                   else len(self.observation_sampler)//self.observations_batch_size)
        print('Done...')

        self.total_num_batches = int(num_label_batches*num_observation_batches)
        print(f"num label batches = {num_label_batches}, num observation batches = {num_observation_batches}")
        print(f"total batches = {self.total_num_batches}")

    def __len__(self):
        return self.total_num_batches
    

    def get_label_batches(self):

        label_batches_of_gpu = []
        for i in range(0,self.num_labels,self.labels_batch_size):
            labels = self.labels_idxs[i:i+self.labels_batch_size]
            label_batches_of_gpu.append(np.array_split(labels,self.world_size)[self.rank])

    def set_epoch(self, epoch):
        self.epoch = epoch

    def get_observation_batches(self):

        batches = []

        if self.drop_last_observation_batch:
            observation_sampler_iter = iter(self.observation_sampler)
            while True:
                try:
                    batch = [next(observation_sampler_iter) for _ in range(self.observations_batch_size)]
                    batches.append(batch)
                except StopIteration:
                    break
        else:
            batch = [0] * self.observations_batch_size
            idx_in_batch = 0
            for idx in self.observation_sampler:
                batch[idx_in_batch] = idx
                idx_in_batch += 1
                if idx_in_batch == self.observations_batch_size:
                    batches.append(batch)
                    idx_in_batch = 0
                    batch = [0] * self.observations_batch_size
            if idx_in_batch > 0:
                batches.append(batch[:idx_in_batch])
        return batches



In [None]:
if grid_sampler:
    assert label_sample_size is not None,"Provide label_sample_size when using grid sampler"
    batch_sampler=GridBatchSampler(observation_sampler=sequence_sampler,
        observations_batch_size=batch_size_for_type,
        drop_last_observation_batch=True,
        num_labels=len(dataset.label_vocabulary),
        labels_batch_size=label_sample_size,
        shuffle_grid=True
        )
    #When defining a BatchSampler, these paramters are ignored in the Dataloader. Must be set 
    #To these values to avoid pytorch error.
    batch_size_for_type = 1
    sequence_sampler = None
    drop_last = False

In [36]:

weights = torch.randint(0,100,(100,))*1.0
num_samples = 100


sampler = WeightedRandomSampler(weights, num_samples)

num_replicas = 4

dist_samplers2 = [
    DistributedWeightedSampler(
                        weights,
                        world_size=num_replicas,
                        rank=i,
                        replacement=True,
                    ) 
    for i in range(num_replicas)
]


torch.manual_seed(1)
true_indices = list(sampler)

indices_per_rank = []
for s in dist_samplers2:
    s.set_epoch(1)
    indices_per_rank += list(s)

set(indices_per_rank) == set(true_indices)

True

In [37]:
dist_samplers2

[<__main__.DistributedWeightedSampler at 0x7f80ec1aa140>,
 <__main__.DistributedWeightedSampler at 0x7f80f2af9120>,
 <__main__.DistributedWeightedSampler at 0x7f80f2afa230>,
 <__main__.DistributedWeightedSampler at 0x7f80f2af8d30>]

In [35]:
dist_samplers2

[<__main__.DistributedWeightedSampler at 0x7f80ec0d5930>,
 <__main__.DistributedWeightedSampler at 0x7f80f2964280>,
 <__main__.DistributedWeightedSampler at 0x7f80f2af8cd0>,
 <__main__.DistributedWeightedSampler at 0x7f80f2afa1d0>]

In [16]:
(32/4)*(32/16)

16.0

In [167]:
num_samples = 32
num_labels = 32
seq_batch_size = 4
label_batch_size = 16
num_replicas = 4

weights = torch.randint(0,100,(num_samples,))*1.0

labels = torch.arange(num_labels)

sampler = observation_sampler_factory(
    distribute_labels = False,
    weighted_sampling = True,
    dataset = None,
    world_size = 1,
    rank = None,
    sequence_weights=weights)

sampler_base = WeightedRandomSampler(weights, num_samples)

batch_sampler_base=GridBatchSampler(observation_sampler=sampler_base,
    observations_batch_size=seq_batch_size,
    drop_last_observation_batch=True,
    num_labels=num_labels,
    labels_batch_size=label_batch_size,
    shuffle_grid=True
    )



Done...
num label batches = 2.0, num observation batches = 8
total batches = 16


In [168]:
batch_sampler_base.set_epoch(1)
a=list(batch_sampler_base)

Getting label batches...
Done...
Getting observation batches...
Done...
Getting combinations...
Done...
Shuffling...
Done...


In [169]:
a

[[(17, [26, 17, 11, 10, 28, 1, 5, 4, 7, 16, 9, 19, 30, 13, 22, 0]),
  (1, [26, 17, 11, 10, 28, 1, 5, 4, 7, 16, 9, 19, 30, 13, 22, 0]),
  (7, [26, 17, 11, 10, 28, 1, 5, 4, 7, 16, 9, 19, 30, 13, 22, 0]),
  (2, [26, 17, 11, 10, 28, 1, 5, 4, 7, 16, 9, 19, 30, 13, 22, 0])],
 [(20, [26, 17, 11, 10, 28, 1, 5, 4, 7, 16, 9, 19, 30, 13, 22, 0]),
  (17, [26, 17, 11, 10, 28, 1, 5, 4, 7, 16, 9, 19, 30, 13, 22, 0]),
  (27, [26, 17, 11, 10, 28, 1, 5, 4, 7, 16, 9, 19, 30, 13, 22, 0]),
  (29, [26, 17, 11, 10, 28, 1, 5, 4, 7, 16, 9, 19, 30, 13, 22, 0])],
 [(29, [21, 29, 6, 12, 20, 23, 14, 15, 3, 31, 2, 24, 25, 27, 18, 8]),
  (1, [21, 29, 6, 12, 20, 23, 14, 15, 3, 31, 2, 24, 25, 27, 18, 8]),
  (6, [21, 29, 6, 12, 20, 23, 14, 15, 3, 31, 2, 24, 25, 27, 18, 8]),
  (28, [21, 29, 6, 12, 20, 23, 14, 15, 3, 31, 2, 24, 25, 27, 18, 8])],
 [(20, [21, 29, 6, 12, 20, 23, 14, 15, 3, 31, 2, 24, 25, 27, 18, 8]),
  (17, [21, 29, 6, 12, 20, 23, 14, 15, 3, 31, 2, 24, 25, 27, 18, 8]),
  (27, [21, 29, 6, 12, 20, 23, 14, 15,

In [58]:

dist_samplers2 = []
for i in range(num_replicas):
    sampler_ = observation_sampler_factory(
        distribute_labels = False,
        weighted_sampling = True,
        dataset = None,
        world_size = num_replicas,
        rank = i,
        sequence_weights=weights)
    sampler_.set_epoch(1)
    
    batch_sampler=GridBatchSampler(observation_sampler=sampler_,
        observations_batch_size=seq_batch_size,
        drop_last_observation_batch=True,
        num_labels=num_labels,
        labels_batch_size=label_batch_size,
        shuffle_grid=True
        )
    dist_samplers2.append(batch_sampler)

torch.manual_seed(1)
true_indices = list(sampler)

torch.manual_seed(1)
true_indices_base = list(sampler_base)

indices_per_rank = []
for s in dist_samplers2:
    
    indices_per_rank += list(s)

#print(set(indices_per_rank) == set(true_indices))
#print(set(true_indices_base)==set(true_indices))


'''
dist_samplers2 = [
    DistributedWeightedSampler(
                        weights,
                        world_size=num_replicas,
                        rank=i,
                        replacement=True,
                    ) 
    
]
'''


[(5, [2, 6, 31, 21, 11, 15, 22, 10, 14, 13, 27, 26, 4, 17, 1, 20]),
 (8, [2, 6, 31, 21, 11, 15, 22, 10, 14, 13, 27, 26, 4, 17, 1, 20]),
 (11, [2, 6, 31, 21, 11, 15, 22, 10, 14, 13, 27, 26, 4, 17, 1, 20]),
 (13, [2, 6, 31, 21, 11, 15, 22, 10, 14, 13, 27, 26, 4, 17, 1, 20])]

Done...
num label batches = 10.0, num observation batches = 20
total batches = 200


In [41]:
indices_per_rank

[12,
 33,
 80,
 40,
 37,
 24,
 4,
 2,
 33,
 16,
 9,
 6,
 64,
 67,
 29,
 25,
 80,
 1,
 44,
 21,
 22,
 20,
 21,
 24,
 61,
 8,
 59,
 28,
 21,
 25,
 14,
 23,
 68,
 41,
 30,
 55,
 9,
 40,
 19,
 0,
 30,
 98,
 42,
 22,
 65,
 47,
 29,
 47,
 71,
 65,
 21,
 25,
 11,
 26,
 30,
 30,
 27,
 99,
 65,
 30,
 67,
 43,
 9,
 6,
 14,
 57,
 7,
 73,
 42,
 20,
 32,
 14,
 32,
 15,
 33,
 96,
 49,
 13,
 40,
 48,
 59,
 61,
 71,
 40,
 47,
 16,
 21,
 47,
 3,
 64,
 38,
 41,
 30,
 42,
 24,
 0,
 62,
 83,
 25,
 1]

In [2]:

### SETUP ###
torch.cuda.empty_cache()

# Check if master process
is_master = True
config = "configs/base_config.yaml"
name = "Test"
train_path_name = "TRAIN_DATA_PATH"
validation_path_name = "VAL_DATA_PATH"
test_paths_names = ["TEST_DATA_PATH"]
amlt = False
gpu=0
rank=0

# Unpack and process the config file
config = get_setup(
    config_path=config,
    run_name=name,
    overrides=[],
    train_path_name=train_path_name,
    val_path_name=validation_path_name,
    test_paths_names=test_paths_names,
    amlt=amlt,
    is_master=is_master,
)
params, paths, timestamp, logger = config["params"], config[
    "paths"], config["timestamp"], config["logger"]

# Set the GPU device, if using
torch.cuda.set_device(gpu)
device = torch.device('cuda:' + str(gpu)
                        if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")


# Log the params
logger.info(json.dumps(params, indent=4))

# Initialize label tokenizer
label_tokenizer = AutoTokenizer.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)

# Initialize label encoder
label_encoder = AutoModel.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)
if params["GRADIENT_CHECKPOINTING"]:
    raise NotImplementedError(
        "Gradient checkpointing is not yet implemented.")

if params["LORA"]:
    for layer in label_encoder.layers:
        in_features, out_features = 1024, 1024
        layer.self_attn.q_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
        layer.self_attn.v_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
        layer.self_attn.k_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
        layer.self_attn.out_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
    # Mark only the LoRA parameters as trainable
    lora.mark_only_lora_as_trainable(label_encoder)

label_encoder = label_encoder.to(device)

# Load or generate the vocabularies
vocabularies = get_or_generate_vocabularies(
    paths["FULL_DATA_PATH"], paths["VOCABULARIES_DIR"], logger)

# Create datasets
datasets = ProteinDataset.create_multiple_datasets(
    paths_list=config['dataset_paths_list'],
    config=config,
    logger=logger,
    label_tokenizer=label_tokenizer,
    label_encoder=label_encoder,
    vocabularies=vocabularies,
    require_train_label_idxs=params['GRID_SAMPLER'],
    subset_fractions={
        "train": params["TRAIN_SUBSET_FRACTION"],
        "validation": params["VALIDATION_SUBSET_FRACTION"],
        "test": params["TEST_SUBSET_FRACTION"],
    },
    deduplicate=params["DEDUPLICATE"],
)

# Seed everything so we don't go crazy
seed_everything(params["SEED"], device)

# Initialize new run
logger.info(
    f"################## {timestamp} RUNNING main.py ##################")

# Define label sample sizes for train, validation, and test loaders
label_sample_sizes = {
    "train": params["TRAIN_LABEL_SAMPLE_SIZE"],
    "validation": params["VALIDATION_LABEL_SAMPLE_SIZE"],
    "test": None  # No sampling for the test set
}

# Define data loaders
loaders = create_multiple_loaders(
    datasets,
    params,
    label_sample_sizes=label_sample_sizes,
    shuffle_labels=params['SHUFFLE_LABELS'],
    in_batch_sampling=params['IN_BATCH_SAMPLING'],
    grid_sampler=params['GRID_SAMPLER'],
    num_workers=params["NUM_WORKERS"],
    world_size=1,
    rank=rank,
)

if params["LABEL_ENCODER_NUM_TRAINABLE_LAYERS"]==0:
    # Move the label encoder to CPU
    label_encoder = label_encoder.cpu()

# Initialize ProteInfer
sequence_encoder = ProteInfer.from_pretrained(
    weights_path=paths["PROTEINFER_WEIGHTS_PATH"],
    num_labels=config["embed_sequences_params"]["PROTEINFER_NUM_LABELS"],
    input_channels=config["embed_sequences_params"]["INPUT_CHANNELS"],
    output_channels=config["embed_sequences_params"]["OUTPUT_CHANNELS"],
    kernel_size=config["embed_sequences_params"]["KERNEL_SIZE"],
    activation=torch.nn.ReLU,
    dilation_base=config["embed_sequences_params"]["DILATION_BASE"],
    num_resnet_blocks=config["embed_sequences_params"]["NUM_RESNET_BLOCKS"],
    bottleneck_factor=config["embed_sequences_params"]["BOTTLENECK_FACTOR"],
)

# Generate all sequence embeddings upfront, if not training the sequence encoder
sequence_embedding_df = None
if not params["TRAIN_SEQUENCE_ENCODER"]:
    sequence_embedding_df = get_or_generate_sequence_embeddings(
        paths,
        device,
        sequence_encoder,
        datasets,
        params,
        logger,
    )
    sequence_encoder = sequence_encoder.to('cpu')

# Loop through all the datasets and set the sequence embedding df
for dataset in datasets.values():
    for subset in dataset:
        if not params["TRAIN_SEQUENCE_ENCODER"]:
            subset.set_sequence_embedding_df(sequence_embedding_df)




2023-12-18 18:32:47 PST INFO Logging to ./outputs/logs/2023-12-18_18-32-47_Test.log and console...
2023-12-18 18:32:48 PST INFO Using device: cuda:0
2023-12-18 18:32:48 PST INFO {
    "TRAIN_BATCH_SIZE": 4,
    "VALIDATION_BATCH_SIZE": 4,
    "TEST_BATCH_SIZE": 4,
    "GRID_SAMPLER": false,
    "IN_BATCH_SAMPLING": false,
    "TRAIN_LABEL_SAMPLE_SIZE": null,
    "VALIDATION_LABEL_SAMPLE_SIZE": null,
    "LABEL_BATCH_SIZE_LIMIT_NO_GRAD": 1500,
    "SEQUENCE_BATCH_SIZE_LIMIT_NO_GRAD": 128,
    "LEARNING_RATE": 0.0003,
    "OPTIMIZER": "Adam",
    "PROTEIN_EMBEDDING_DIM": 1100,
    "LABEL_EMBEDDING_DIM": 1024,
    "LATENT_EMBEDDING_DIM": 1024,
    "OUTPUT_MLP_HIDDEN_DIM_SCALE_FACTOR": 2,
    "OUTPUT_MLP_NUM_LAYERS": 2,
    "OUTPUT_NEURON_PROBABILITY_BIAS": null,
    "OUTPUT_MLP_BATCHNORM": true,
    "PROJECTION_HEAD_NUM_LAYERS": 2,
    "PROJECTION_HEAD_HIDDEN_DIM_SCALE_FACTOR": 1,
    "FEATURE_FUSION": "concatenation",
    "LABEL_EMBEDDING_POOLING_METHOD": "mean",
    "OPTIMIZATION_METRIC

In [3]:
model = ProTCL(
    # Parameters
    protein_embedding_dim=params["PROTEIN_EMBEDDING_DIM"],
    label_embedding_dim=params["LABEL_EMBEDDING_DIM"],
    latent_dim=params["LATENT_EMBEDDING_DIM"],
    label_embedding_pooling_method=params["LABEL_EMBEDDING_POOLING_METHOD"],

    # Encoders
    label_encoder=label_encoder,
    sequence_encoder=sequence_encoder,

    # Output Layer
    output_mlp_hidden_dim_scale_factor=params["OUTPUT_MLP_HIDDEN_DIM_SCALE_FACTOR"],
    output_mlp_num_layers=params["OUTPUT_MLP_NUM_LAYERS"],
    output_neuron_bias=sigmoid_bias_from_prob(params["OUTPUT_NEURON_PROBABILITY_BIAS"]) if params["OUTPUT_NEURON_PROBABILITY_BIAS"] is not None else None,
    outout_mlp_add_batchnorm=params["OUTPUT_MLP_BATCHNORM"],
    projection_head_num_layers=params["PROJECTION_HEAD_NUM_LAYERS"],
    projection_head_hidden_dim_scale_factor=params["PROJECTION_HEAD_HIDDEN_DIM_SCALE_FACTOR"],

    # Training options
    label_encoder_num_trainable_layers=params["LABEL_ENCODER_NUM_TRAINABLE_LAYERS"],
    train_sequence_encoder=params["TRAIN_SEQUENCE_ENCODER"],

    # Batch size limits
    label_batch_size_limit=params["LABEL_BATCH_SIZE_LIMIT_NO_GRAD"],
    sequence_batch_size_limit=params["SEQUENCE_BATCH_SIZE_LIMIT_NO_GRAD"],

    #
    feature_fusion=config["params"]["FEATURE_FUSION"],
    temperature=config["params"]["SUPCON_TEMP"]
).to(device)



In [3]:
a = [1,2,3,4,5,6,7,8,9,10]
a[1:10:3]

[2, 5, 8]

In [4]:
model = torch.nn.DataParallel(model)

In [5]:
import logging
from src.utils.data import load_gz_json, log_gpu_memory_usage, save_checkpoint, load_model
from src.utils.evaluation import EvalMetrics,metric_collection_to_dict_float,save_evaluation_results
from src.utils.losses import BatchWeightedBCE, FocalLoss, RGDBCE, WeightedBCE,SupCon, CBLoss
from torchmetrics import MetricCollection, Metric
from src.utils.proteinfer import normalize_confidences
import numpy as np
import torch
import wandb
import os
import json
from collections import defaultdict
from torch.cuda.amp import autocast, GradScaler
from torch.nn.utils import clip_grad_norm_
from transformers import BatchEncoding
from src.utils.models import generate_label_embeddings_from_text, biogpt_train_last_n_layers
from tqdm import tqdm
from torcheval.metrics import MultilabelAUPRC, BinaryAUPRC
from torch.utils.tensorboard import SummaryWriter

from collections import OrderedDict


def print_checkpoint(checkpoint):
    
    print("weights_sum",sum([i.sum() for i in checkpoint['model_state_dict'].values()]))
    print('epoch',checkpoint['epoch'])
    print('best_val_metric',checkpoint['best_val_metric'])

    max_step = max(checkpoint['optimizer_state_dict']['state'].keys())
    print('optimizer param groups',checkpoint['optimizer_state_dict']['param_groups'])
    print('optimizer max step',checkpoint['optimizer_state_dict']['state'][max_step])

def load_model(trainer, checkpoint_path, from_checkpoint=False):
    """
    Load the model's state from a given checkpoint.

    This function is designed to handle checkpoints from both Data Distributed Parallel (DDP) wrapped 
    and non-DDP models. If the checkpoint originates from a DDP-wrapped model, the function will adjust 
    the state dictionary keys accordingly before loading.

    Parameters:
    - trainer (object): An instance of the trainer containing the model, optimizer, and other training attributes.
    - checkpoint_path (str): The path to the checkpoint file to be loaded.
    - from_checkpoint (bool, optional): If True, the function will also load the optimizer's state, 
      epoch number, and best validation metric from the checkpoint. Defaults to False.

    Note:
    The function assumes that the model in the trainer object is DDP-wrapped.
    """

    # Load the entire checkpoint
    checkpoint = torch.load(checkpoint_path)

    # Extract the state_dict from the checkpoint
    state_dict = checkpoint['model_state_dict']

    # Check if the state_dict is from a DDP-wrapped model
    if list(state_dict.keys())[0].startswith('module.'):
        # Remove the "module." prefix
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove 'module.' prefix
            new_state_dict[name] = v
        state_dict = new_state_dict

    # Load the state_dict into the model
    trainer.model.module.load_state_dict(state_dict)

    # Load the optimizer state and epoch number if they exist in the checkpoint
    if 'optimizer_state_dict' in checkpoint and from_checkpoint:
        trainer.optimizer.load_state_dict(
            checkpoint['optimizer_state_dict'])
    if 'epoch' in checkpoint and from_checkpoint:
        trainer.epoch = checkpoint['epoch']
    if 'best_val_metric' in checkpoint and from_checkpoint:
        trainer.best_val_metric = checkpoint['best_val_metric']

    print_checkpoint(checkpoint)
    # Delete the checkpoint to save memory
    del checkpoint

def save_checkpoint(model, optimizer, epoch, best_val_metric, model_path):
    """
    Save model and optimizer states as a checkpoint.

    Args:
    - model (torch.nn.Module): The model whose state we want to save.
    - optimizer (torch.optim.Optimizer): The optimizer whose state we want to save.
    - epoch (int): The current training epoch.
    - model_path (str): The path where the checkpoint will be saved.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_val_metric': best_val_metric,
    }

    print_checkpoint(checkpoint)

    torch.save(checkpoint, model_path)


def load_checkpoint(trainer, checkpoint_path):
    """
    Load the model's state dict, optimizer's state, and epoch number from the checkpoint.

    This function handles both DDP-wrapped and non-DDP checkpoints.

    :param model: The model into which the checkpoint's state dict should be loaded.
    :param trainer: The trainer instance containing the optimizer and epoch attributes.
    :param checkpoint_path: Path to the checkpoint file.
    """
    print_checkpoint(checkpoint)
    # Load the entire checkpoint
    checkpoint = torch.load(checkpoint_path)

    # Extract the state_dict from the checkpoint
    model_state_dict = checkpoint['model_state_dict']

    # Check if the state_dict is from a DDP-wrapped model
    if list(model_state_dict.keys())[0].startswith('module.'):
        # Remove the "module." prefix
        new_model_state_dict = OrderedDict()
        for k, v in model_state_dict.items():
            name = k[7:]  # remove 'module.' prefix
            new_model_state_dict[name] = v
        model_state_dict = new_model_state_dict

    # Load the state_dict into the model
    trainer.model.module.load_state_dict(model_state_dict)

    # Load the optimizer state and epoch number if they exist in the checkpoint
    if 'optimizer_state_dict' in checkpoint:
        trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if 'epoch' in checkpoint:
        trainer.epoch = checkpoint['epoch']
    if 'best_val_metric' in checkpoint:
        trainer.best_val_metric = checkpoint['best_val_metric']


class ProTCLTrainer:
    def __init__(
        self,
        model: torch.nn.Module,
        device: str,
        config: dict,
        vocabularies: dict,
        logger: logging.Logger,
        timestamp: str,
        run_name: str,
        use_wandb: bool = False,
        bce_pos_weight: torch.Tensor = None,
        label_weights: torch.Tensor = None,
        is_master: bool = True,
        starting_epoch: int = 1,
    ):
        """
        Args:
            model (nn.Module): The PyTorch model to train.
            device (str): The device to use for training (e.g., 'cpu' or 'cuda').
            logger (logging.Logger): The logger to use for logging training progress.
            timestamp (str): The timestamp to use for naming log files and checkpoints.
            run_name (str): The name of the current training run.
            use_wandb (bool, optional): Whether to use Weights & Biases for logging. Defaults to False.
            bce_pos_weight (torch.Tensor, optional): The positive weight for binary cross-entropy loss. Defaults to None.
            is_master (bool, optional): Whether the current process is the master process. Defaults to True.
            starting_epoch (int, optional): The starting epoch number. Defaults to 1. Used for resuming training.
        """

        self.model = model
        self.is_master = is_master
        self.device = device
        self.run_name = run_name
        self.logger = logger
        self.timestamp = timestamp
        self.use_wandb = use_wandb
        self.num_epochs = config["params"]["NUM_EPOCHS"]
        self.train_sequence_encoder = config["params"]["TRAIN_SEQUENCE_ENCODER"]
        self.label_encoder_num_trainable_layers = config["params"]["LABEL_ENCODER_NUM_TRAINABLE_LAYERS"]
        self.train_projection_head = config["params"]["TRAIN_PROJECTION_HEAD"]

        self.normalize_probabilities = config["params"]["NORMALIZE_PROBABILITIES"]
        self.validations_per_epoch = config["params"]["VALIDATIONS_PER_EPOCH"]
        self.gradient_accumulation_steps = config["params"]["GRADIENT_ACCUMULATION_STEPS"]
        self.clip_value = config["params"]["CLIP_VALUE"]
        self.vocabularies = vocabularies
        self.label_normalizer = load_gz_json(
            config["paths"]["PARENTHOOD_LIB_PATH"]
        )
        self.output_model_dir = config["paths"]["OUTPUT_MODEL_DIR"]
        self.lora_params = {'rank':config["params"]["LORA_RANK"],
                            'in_features':config["params"]["LABEL_EMBEDDING_DIM"],
                            'out_features':config["params"]["LABEL_EMBEDDING_DIM"],
                            'device':self.device
                            } if config["params"]["LORA"] else None
        
        self._set_optimizer(opt_name = config["params"]["OPTIMIZER"],
                            lr = config["params"]["LEARNING_RATE"])
        
        self.bce_pos_weight = bce_pos_weight
        self.label_weights=label_weights
        self.loss_fn = self._get_loss_fn(config)
        self.model_path = self._get_saved_model_path()
        self.best_val_metric = 0.0
        self.scaler = GradScaler()
        self.starting_epoch = starting_epoch
        self.epoch = starting_epoch
        self.config = config
        self.tb = SummaryWriter(f"runs/{self.run_name}_{self.timestamp}") if self.is_master else None

    def _get_saved_model_path(self):
        # Save model to OUTPUT_MODEL_DIR. Create path if it doesn't exist.
        if not os.path.exists(self.output_model_dir) and self.is_master:
            os.makedirs(self.output_model_dir)

        model_name = (
            self.run_name if self.run_name else "best_ProTCL.pt"
        )
        model_path = os.path.join(
            self.output_model_dir, f"{self.timestamp}_{model_name}.pt"
        )
        return model_path

    def _get_loss_fn(self, config):
        if config["params"]["LOSS_FN"] == "BCE":
            assert self.bce_pos_weight is not None, "bce_pos_weight must be provided for BCE loss"
            return torch.nn.BCEWithLogitsLoss(reduction='mean', pos_weight=self.bce_pos_weight)
        elif (config["params"]["LOSS_FN"] == "WeightedBCE"):
            assert self.label_weights is not None, "label_weights must be provided for WeightedBCE Loss"
            return WeightedBCE(label_weights = self.label_weights)
        elif (config["params"]["LOSS_FN"] == "CBLoss"):
            assert self.label_weights is not None, "label_weights must be provided for CBLoss"
            return CBLoss(label_weights = self.label_weights)
        elif config["params"]["LOSS_FN"] == "BatchWeightedBCE":
            return BatchWeightedBCE()
        elif config["params"]["LOSS_FN"] == "FocalLoss":
            assert (config["params"]["FOCAL_LOSS_GAMMA"] is not None)\
                & (config["params"]["FOCAL_LOSS_ALPHA"] is not None), "gamma and gamma must be provided for FocalLoss"
            return FocalLoss(gamma=config["params"]["FOCAL_LOSS_GAMMA"], alpha=config["params"]["FOCAL_LOSS_ALPHA"])
        elif config["params"]["LOSS_FN"] == "RGDBCE":
            return RGDBCE(temp=config["params"]["RGDBCE_TEMP"])
        elif config["params"]["LOSS_FN"] == "SupCon":
            return SupCon(temperature=config["params"]["SUPCON_TEMP"])
        else:
            raise ValueError(
                f"Unknown loss function {config['params']['LOSS_FN']}")

    def _to_device(self, *args):
        processed_args = []
        for item in args:
            if isinstance(item, torch.Tensor):
                processed_args.append(item.to(self.device))
            elif isinstance(item, BatchEncoding) or isinstance(item, dict):
                processed_dict = {k: v.to(self.device) if isinstance(
                    v, torch.Tensor) else v for k, v in item.items()}
                processed_args.append(processed_dict)
            else:
                processed_args.append(item)
        return processed_args

    def _set_optimizer(self, opt_name, lr):
        trainable_params = []
        trainable_params_names = []

        # Use to unfreeze last n layers. 0 means entire model frozen.
        biogpt_train_last_n_layers(self.model.module.label_encoder,
                                   self.label_encoder_num_trainable_layers,
                                   lora_params=self.lora_params
                                   )
        
        for name, param in self.model.module.named_parameters():
            if name.startswith('sequence_encoder') and (not self.train_sequence_encoder):
                param.requires_grad = False

            if (name.startswith('W_p.weight') or name.startswith('W_l.weight')) and (not self.train_projection_head):
                param.requires_grad = False

            if name.startswith('output_layer') and (not self.train_projection_head):
                param.requires_grad = False

            if param.requires_grad:
                trainable_params.append(param)
                trainable_params_names.append(name)

        self.trainable_params_names = trainable_params_names

        if opt_name == 'Adam':
            opt = torch.optim.Adam
        elif opt_name == 'SGD':
            opt = torch.optim.SGD
        else:
            raise ValueError("Unsupported optimizer name")

        self.optimizer = opt(
            trainable_params, lr=lr
        )

    def evaluation_step(self, batch) -> tuple:
        """Perform a single evaluation step.

        :param batch: _description_
        :type batch: _type_
        :return: batch loss, logits and labels
        :rtype: tuple
        """

        # Unpack the validation or testing batch
        sequence_onehots, sequence_embeddings, sequence_lengths, sequence_ids, label_multihots, tokenized_labels, label_embeddings = (
            batch["sequence_onehots"],
            batch["sequence_embeddings"],
            batch["sequence_lengths"],
            batch["sequence_ids"],
            batch["label_multihots"],
            batch["tokenized_labels"],
            batch["label_embeddings"]
        )

        # Move all unpacked batch elements to GPU, if available
        sequence_onehots, sequence_embeddings, sequence_lengths, label_multihots, tokenized_labels, label_embeddings = self._to_device(
            sequence_onehots, sequence_embeddings, sequence_lengths, label_multihots, tokenized_labels, label_embeddings)

        # Forward pass
        inputs = {
            "sequence_onehots": sequence_onehots,
            "sequence_embeddings": sequence_embeddings,
            "sequence_lengths": sequence_lengths,
            "tokenized_labels": tokenized_labels,
            "label_embeddings": label_embeddings
        }
        with autocast():
            logits = self.model(**inputs)
            # Compute validation loss for the batch
            loss = self.loss_fn(logits, label_multihots.float())

        return loss.item(), logits, label_multihots, sequence_ids

    def validate(self,
                 val_loader: torch.utils.data.DataLoader,
                 eval_metrics: MetricCollection,
                 val_optimization_metric_name: str
                 ):

        self.logger.info("Running validation...")

        prefix = 'validation'

        val_metrics = self.evaluate(data_loader=val_loader,
                                       eval_metrics=eval_metrics,
                                       metrics_prefix=prefix)
        val_optimization_metric_name = f'{prefix}_{val_optimization_metric_name}'

        
        self.logger.info("+-------------------------------- Validation Results --------------------------------+")
        # Print memory consumption
        if self.is_master:
            log_gpu_memory_usage(self.logger, 0)
        self.logger.info(
            f"Validation metrics:\n{json.dumps(val_metrics, indent=4)}")

        if self.use_wandb and self.is_master:
            try:
                if self.use_wandb and self.is_master:
                    wandb.log(val_metrics,
                              step=self.training_step
                              )

            except Exception as e:
                self.logger.warning(
                    f"Failed to log validation metrics to wandb: {e}")

        # Save the model if it has the best validation loss so far (only on master node)
        if self.is_master and val_metrics[val_optimization_metric_name] > self.best_val_metric and self.epoch==2:
            self.logger.info(
                f"New best {val_optimization_metric_name}: {val_metrics[val_optimization_metric_name]}. Saving model..."
            )
            self.best_val_metric = val_metrics[val_optimization_metric_name]

            save_checkpoint(
                model=self.model.module,
                optimizer=self.optimizer,
                epoch=self.epoch,
                best_val_metric=self.best_val_metric,
                model_path=self.model_path
            )
            self.logger.info(f"Saved model to {self.model_path}.")

            if self.use_wandb:
                wandb.save(f"{self.timestamp}_best_ProTCL.pt")
        
        self.logger.info("+------------------------------------------------------------------------------------+") 

        return val_metrics

    def find_optimal_threshold(
        self, data_loader: torch.utils.data.DataLoader, optimization_metric_name: str
    ) -> tuple[float, float]:
        """Find the optimal threshold for the given data loader.

        :param data_loader: _description_
        :type data_loader: torch.utils.data.DataLoader
        :param average: _description_
        :type average: Literal[&#39;micro&#39;, &#39;macro&#39;, &#39;weighted&#39;]
        :param optimization_metric_name: _description_
        :type optimization_metric_name: str
        :return: _description_
        :rtype: tuple[float, float]
        """

        self.logger.info("Finding optimal threshold...")
        self.model.eval()

        best_th = 0.0
        best_score = 0.0

        with torch.no_grad():
            all_probabilities = []
            all_label_multihots = []
            for batch in data_loader:
                _, logits, label_multihots, _ = self.evaluation_step(
                    batch=batch)

                # Apply sigmoid to get the probabilities for multi-label classification
                probabilities = torch.sigmoid(logits)

                if self.normalize_probabilities:
                    probabilities = self._normalize_probabilities(probabilities)

                all_probabilities.append(probabilities)
                all_label_multihots.append(label_multihots)

            all_probabilities = torch.cat(all_probabilities)
            all_label_multihots = torch.cat(all_label_multihots)

        for th in np.arange(0.1, 1, 0.01):
            optimization_metric = EvalMetrics(device=self.device)\
                .get_metric_by_name(name=optimization_metric_name,
                                    threshold=th,
                                    num_labels=label_multihots.shape[-1])

            optimization_metric(all_probabilities, all_label_multihots)
            score = optimization_metric.compute().item()
            if score > best_score:
                best_score = score
                best_th = th
            self.logger.info("TH: {:.3f}, F1: {:.3f}".format(th, score))

        best_score = best_score
        self.logger.info(
            f"Best validation score: {best_score}, Best val threshold: {best_th}"
        )
        self.model.train()
        return best_th, best_score

    def _normalize_probabilities(self,probabilities):
        # TODO: Using original normalize_confidences implemented with numpy,
                    # but this is slow. Should be able to do this with torch tensors.
        return torch.tensor(
                    normalize_confidences(
                        predictions=probabilities.detach().cpu().numpy(),
                        label_vocab=self.vocabularies["GO_label_vocab"],
                        applicable_label_dict=self.label_normalizer,
                    ),
                    device=self.device,
                )

    def evaluate(
        self,
        data_loader: torch.utils.data.DataLoader,
        eval_metrics: MetricCollection = None,
        save_results: bool = False,
        metrics_prefix = None
    ) -> tuple[dict, dict]:
        """Evaluate the model on the given data loader.
        :param data_loader: pytorch data loader
        :type data_loader: torch.utils.data.DataLoader
        :param eval_metrics: an eval metrics class to calculate metrics like F1 score, defaults to None
        :type eval_metrics: EvalMetrics, optional
        :return: dictionary with evaluation metrics. Always return avg_loss and if eval_metrics is not None, it will return the metrics from eval_metrics.compute()
        :rtype: dict
        """
        self.model.eval()

        # Compute all label embeddings upfront, since we're not training
        if data_loader.dataset.label_embedding_matrix is None:
            logging.info(
                "Computing label embeddings for evaluation...")
            with torch.no_grad():
                label_embedding_matrix = generate_label_embeddings_from_text(
                    data_loader.dataset.label_text_list,
                    data_loader.dataset.label_tokenizer,
                    self.model.module.label_encoder,
                    self.config["params"]["LABEL_BATCH_SIZE_LIMIT_NO_GRAD"],
                ).cpu()
            data_loader.dataset.set_label_embedding_matrix(
                label_embedding_matrix)
            logging.info("Done computing label embeddings.")

        total_loss = 0
        test_results = defaultdict(list)

        if eval_metrics is not None:
            eval_metrics.reset()

        mAP_micro = BinaryAUPRC(device='cpu')
        mAP_macro = MultilabelAUPRC(device='cpu',num_labels=len(self.vocabularies["GO_label_vocab"]))

        with torch.no_grad():
            for batch in tqdm(data_loader,total=len(data_loader)):
                loss, logits, labels, sequence_ids = self.evaluation_step(
                    batch=batch)
                if eval_metrics is not None:
                    # Apply sigmoid to get the probabilities for multi-label classification
                    probabilities = torch.sigmoid(logits)

                    if self.normalize_probabilities:
                        probabilities = self._normalize_probabilities()

                    # Update eval metrics
                    eval_metrics(probabilities, labels)

                    mAP_micro.update(probabilities.cpu().flatten(), labels.cpu().flatten())
                    mAP_macro.update(probabilities.cpu(), labels.cpu())

                    # No need to save results everytime. Only need it for final evaluation.
                    if save_results:
                        test_results["sequence_ids"].append(sequence_ids)
                        test_results["logits"].append(logits.cpu())
                        test_results["labels"].append(labels.cpu())

                # Accumulate loss
                total_loss += loss

            if save_results:
                for key in test_results.keys():
                    if key == "sequence_ids":
                        test_results[key] = (
                            np.array(
                                [j for i in test_results["sequence_ids"] for j in i])
                        )
                    else:
                        test_results[key] = (
                            torch.cat(test_results[key]).numpy()
                        )
                
                self.logger.info("Saving validation results...")
                if self.is_master:
                    save_evaluation_results(results=test_results,
                                            label_vocabulary=self.vocabularies["GO_label_vocab"],
                                            run_name=self.run_name,
                                            output_dir=self.config["paths"]["RESULTS_DIR"],
                                            data_split_name=metrics_prefix
                                            )

            # Compute average validation loss
            avg_loss = total_loss / len(data_loader)

            final_metrics = eval_metrics.compute() if eval_metrics is not None else {}
            final_metrics.update({"loss": avg_loss,
                                  "map_micro":mAP_micro.compute(),
                                  "map_macro":mAP_macro.compute()
                                  })

            final_metrics = metric_collection_to_dict_float(
                final_metrics,
                prefix=metrics_prefix)           

        self.model.train()

        return final_metrics

    def train_one_epoch(self,
                        train_loader: torch.utils.data.DataLoader,
                        eval_metrics: MetricCollection
        ):
        
        avg_loss = 0
        avg_probs = 0
        avg_gt = 0
        eval_metrics.reset()
        
        ####### TRAINING LOOP #######
        for batch_idx, batch in enumerate(train_loader):
            
            self.training_step += 1

            # Unpack the training batch
            sequence_onehots, sequence_embeddings, sequence_lengths, label_multihots, tokenized_labels, label_embeddings = (
                batch["sequence_onehots"],
                batch["sequence_embeddings"],
                batch["sequence_lengths"],
                batch["label_multihots"],
                batch["tokenized_labels"],
                batch["label_embeddings"]
            )

            # Move all unpacked batch elements to GPU, if available
            sequence_onehots, sequence_embeddings, sequence_lengths, label_multihots, tokenized_labels, label_embeddings = self._to_device(
                sequence_onehots, sequence_embeddings, sequence_lengths, label_multihots, tokenized_labels, label_embeddings)

            # Forward pass
            inputs = {
                "sequence_onehots": sequence_onehots,
                "sequence_embeddings": sequence_embeddings,
                "sequence_lengths": sequence_lengths,
                "tokenized_labels": tokenized_labels,
                "label_embeddings": label_embeddings
            }

            with autocast():
                logits = self.model(**inputs)

                # Compute loss, normalized by the number of gradient accumulation steps
                loss = self.loss_fn(logits, label_multihots.float()) / \
                    self.gradient_accumulation_steps
        
            # Backward pass with mixed precision
            self.scaler.scale(loss).backward()
        

            # Gradient accumulation every GRADIENT_ACCUMULATION_STEPS
            if (self.training_step % self.gradient_accumulation_steps == 0) or (batch_idx + 1 == len(train_loader)):     
                # Unscales the gradients of optimizer's assigned params in-place
                self.scaler.unscale_(self.optimizer)
                
                # Apply gradient clipping
                if self.clip_value is not None:
                    clip_grad_norm_(self.model.module.parameters(),
                                    max_norm=self.clip_value)
                
                self.scaler.step(self.optimizer)
                self.scaler.update()

                #Log at this point to TB to have weights and gradients after a full epoch
                if (batch_idx + 1 == len(train_loader)) & self.is_master:
                    for name, weight in self.model.module.named_parameters():
                        if weight.requires_grad:
                            self.tb.add_histogram(name,weight, self.epoch)
                            self.tb.add_histogram(f'{name}.grad',weight.grad, self.epoch)

                self.optimizer.zero_grad()
            
            avg_loss+=loss.item()
            avg_probs += torch.mean(torch.sigmoid(logits).detach())
            avg_gt += torch.mean(label_multihots.float().detach())

            eval_metrics(logits, label_multihots)
            
            if self.use_wandb:
                wandb.log({"per_batch_train_loss": loss.item()},
                          step=self.training_step
                          )

            # Print memory consumption after first batch (to get the max memory consumption during training)
            if batch_idx == 1 and self.is_master:
                self.logger.info("+----------------- Train GPU Memory Usage -----------------+")
                log_gpu_memory_usage(self.logger, 0)
                self.logger.info("+----------------------------------------------------------+")

        avg_loss = avg_loss/len(train_loader) if len(train_loader)> 0 else avg_loss
        avg_probs_gt_ration = avg_probs/avg_gt

        train_metrics = eval_metrics.compute() if eval_metrics is not None else {}
        train_metrics.update({"loss": avg_loss,
                              "avg_probabilities_ground_truth_ratio":avg_probs_gt_ration,
                                })
        train_metrics = metric_collection_to_dict_float(train_metrics,prefix='train')
        
        if self.use_wandb:
            wandb.log(train_metrics,
                      step=self.training_step
                      )

        
        return train_metrics
        
        
    def train(
        self,
        train_loader: torch.utils.data.DataLoader,
        val_loader: torch.utils.data.DataLoader,
        train_eval_metrics: MetricCollection,
        val_eval_metrics: MetricCollection,
        val_optimization_metric_name: str
    ):
        """Train model
        :param train_loader: _description_
        :type train_loader: torch.utils.data.DataLoader
        :param val_loader: _description_
        :type val_loader: torch.utils.data.DataLoader
        """
        self.model.train()

        # Watch the model
        if self.use_wandb:
            wandb.watch(self.model)

        # Compute total number of training steps
        self.training_step = 0
        num_training_steps = len(train_loader) * self.num_epochs
        
        self.logger.info(f"{'='*100}")
        self.logger.info(
            f"Starting training. Total number of training steps: {num_training_steps}")
        self.logger.info(f"{'='*100}")

        for epoch in range(self.starting_epoch, self.starting_epoch + self.num_epochs):
            self.logger.info(
                f"Starting epoch {epoch}/{self.starting_epoch + self.num_epochs - 1}...")
            self.epoch = epoch

            # Set distributed loader epoch to shuffle data
            if hasattr(train_loader.sampler, "set_epoch"):
                train_loader.sampler.set_epoch(epoch)

            train_metrics = self.train_one_epoch(train_loader=train_loader,
                                                 eval_metrics=train_eval_metrics)
                

            if (epoch % self.validations_per_epoch == 0):
                ####### VALIDATION LOOP #######
                torch.cuda.empty_cache()

                # Run validation
                self.validate(val_loader=val_loader,
                                            eval_metrics=val_eval_metrics,
                                            val_optimization_metric_name=val_optimization_metric_name
                                            )

                if self.label_encoder_num_trainable_layers>0:
                    # Clear the label embedding matrix
                    val_loader.dataset.set_label_embedding_matrix(None)

                self.logger.info(
                    f"Epoch {epoch}/{self.starting_epoch + self.num_epochs - 1}, Batch {self.training_step}, Training Loss: {train_metrics['train_loss']}"
                )

        if self.is_master:
            self.logger.info("Restoring model to best validation map_micro...")
            load_model(trainer=self,
                            checkpoint_path=self.model_path)

        
        self.tb.close()

In [6]:
'''checkpoint_path =  'data/models/ProTCL/2023-11-27_17-07-08_FL_Experiments_mlp_scale2_2layer_projection_head_gamma2_lr3e4_bs32.pt'
checkpoint = torch.load(checkpoint_path)
state_dict = checkpoint['model_state_dict']
'''

"checkpoint_path =  'data/models/ProTCL/2023-11-27_17-07-08_FL_Experiments_mlp_scale2_2layer_projection_head_gamma2_lr3e4_bs32.pt'\ncheckpoint = torch.load(checkpoint_path)\nstate_dict = checkpoint['model_state_dict']\n"

In [7]:

# Initialize EvalMetrics
eval_metrics = EvalMetrics(device=device)

label_sample_sizes = {k:(v if v is not None else len(vocabularies['GO_label_vocab'])) 
                        for k,v in label_sample_sizes.items()}

# Log sizes of all datasets

In [8]:
# Load the entire checkpoint


Trainer = ProTCLTrainer(
    model=model,
    device=device,
    config=config,
    vocabularies=vocabularies,
    logger=logger,
    timestamp=timestamp,
    run_name='debugging',
    use_wandb=False,
    bce_pos_weight=None,
    label_weights=None,
    is_master=is_master,
)

Trainer.train(train_loader=loaders["train"][0],
    val_loader=loaders["validation"][0],
    train_eval_metrics=eval_metrics.get_metric_collection_with_regex(pattern="f1_m.*",
                                                                        threshold=0.5,
                                                                num_labels=label_sample_sizes["train"] if (params['IN_BATCH_SAMPLING'] or params['GRID_SAMPLER']) is False else None
                                                                ),
    val_eval_metrics=eval_metrics.get_metric_collection_with_regex(pattern="f1_m.*", threshold=0.5,
                                                        num_labels=label_sample_sizes["validation"]
                                                        ),
    val_optimization_metric_name=params["OPTIMIZATION_METRIC_NAME"])

2023-12-18 18:33:35 PST INFO Starting training. Total number of training steps: 270
2023-12-18 18:33:35 PST INFO Starting epoch 1/3...


2023-12-18 18:33:37 PST INFO +----------------- Train GPU Memory Usage -----------------+
2023-12-18 18:33:37 PST INFO GPU memory occupied: 7568 MB (9.34% of total memory 80994 MB). Device 0 [Name: NVIDIA A100 80GB PCIe]
2023-12-18 18:33:37 PST INFO +----------------------------------------------------------+
2023-12-18 18:33:46 PST INFO Running validation...


100%|██████████| 10/10 [00:01<00:00,  9.14it/s]


2023-12-18 18:33:52 PST INFO +-------------------------------- Validation Results --------------------------------+
2023-12-18 18:33:52 PST INFO GPU memory occupied: 5082 MB (6.28% of total memory 80994 MB). Device 0 [Name: NVIDIA A100 80GB PCIe]
2023-12-18 18:33:52 PST INFO Validation metrics:
{
    "validation_f1_macro": 0.0001800328609533608,
    "validation_f1_micro": 0.10717897117137909,
    "validation_loss": 0.0019823297043330967,
    "validation_map_micro": 0.269729882478714,
    "validation_map_macro": 0.007014680188149214
}
2023-12-18 18:33:52 PST INFO +------------------------------------------------------------------------------------+
2023-12-18 18:33:52 PST INFO Epoch 1/3, Batch 90, Training Loss: 0.019239841347249844
2023-12-18 18:33:52 PST INFO Starting epoch 2/3...
2023-12-18 18:33:52 PST INFO +----------------- Train GPU Memory Usage -----------------+
2023-12-18 18:33:52 PST INFO GPU memory occupied: 7592 MB (9.37% of total memory 80994 MB). Device 0 [Name: NVIDIA A1

100%|██████████| 10/10 [00:01<00:00,  9.58it/s]


2023-12-18 18:34:06 PST INFO +-------------------------------- Validation Results --------------------------------+
2023-12-18 18:34:06 PST INFO GPU memory occupied: 5082 MB (6.28% of total memory 80994 MB). Device 0 [Name: NVIDIA A100 80GB PCIe]
2023-12-18 18:34:06 PST INFO Validation metrics:
{
    "validation_f1_macro": 0.00027972637326456606,
    "validation_f1_micro": 0.1364092230796814,
    "validation_loss": 0.0018234694376587867,
    "validation_map_micro": 0.35204529762268066,
    "validation_map_macro": 0.008132383227348328
}
2023-12-18 18:34:06 PST INFO New best validation_map_micro: 0.35204529762268066. Saving model...
weights_sum tensor(6714471.5000, device='cuda:0')
epoch 2
best_val_metric 0.35204529762268066
optimizer param groups [{'lr': 0.0003, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 

100%|██████████| 10/10 [00:02<00:00,  4.36it/s]


2023-12-18 18:34:22 PST INFO +-------------------------------- Validation Results --------------------------------+
2023-12-18 18:34:22 PST INFO GPU memory occupied: 5082 MB (6.28% of total memory 80994 MB). Device 0 [Name: NVIDIA A100 80GB PCIe]
2023-12-18 18:34:22 PST INFO Validation metrics:
{
    "validation_f1_macro": 0.0006659323698841035,
    "validation_f1_micro": 0.3161482810974121,
    "validation_loss": 0.0016556073911488055,
    "validation_map_micro": 0.3641241788864136,
    "validation_map_macro": 0.009125406853854656
}
2023-12-18 18:34:22 PST INFO +------------------------------------------------------------------------------------+
2023-12-18 18:34:22 PST INFO Epoch 3/3, Batch 270, Training Loss: 0.0018972821669497838
2023-12-18 18:34:22 PST INFO Restoring model to best validation map_micro...
weights_sum tensor(6714471.5000, device='cuda:0')
epoch 2
best_val_metric 0.35204529762268066
optimizer param groups [{'lr': 0.0003, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_d

In [11]:
Trainer = ProTCLTrainer(
    model=model,
    device=device,
    config=config,
    vocabularies=vocabularies,
    logger=logger,
    timestamp=timestamp,
    run_name='debugging',
    use_wandb=False,
    bce_pos_weight=None,
    label_weights=None,
    is_master=is_master,
)


load_model(
    trainer=Trainer,
    checkpoint_path=os.path.join(config["DATA_PATH"], '../outputs/checkpoints/2023-12-18_18-32-47_debugging.pt'),
    from_checkpoint=False
)


validation_metrics = Trainer.evaluate(
            data_loader=loaders['validation'][0],
            eval_metrics=eval_metrics.get_metric_collection_with_regex(pattern="f1_m.*",
                                                                    threshold=0.5,
                                                                    num_labels=label_sample_sizes["validation"]
                                                            ),
            save_results=False,
            metrics_prefix='final_validation'
                    )

weights_sum tensor(6714471.5000, device='cuda:0')
epoch 2
best_val_metric 0.35204529762268066
optimizer param groups [{'lr': 0.0003, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]}]
optimizer max step {'step': tensor(179.), 'exp_avg': tensor([-0.0002], device='cuda:0'), 'exp_avg_sq': tensor([1.0297e-05], device='cuda:0')}


100%|██████████| 10/10 [00:01<00:00,  9.55it/s]


In [12]:
validation_metrics

{'final_validation_f1_macro': 0.00027972637326456606,
 'final_validation_f1_micro': 0.1364092230796814,
 'final_validation_loss': 0.0018234694376587867,
 'final_validation_map_micro': 0.35204529762268066,
 'final_validation_map_macro': 0.008132383227348328}

In [None]:
{
    "validation_f1_macro": 0.00027972637326456606,
    "validation_f1_micro": 0.1364092230796814,
    "validation_loss": 0.0018234694376587867,
    "validation_map_micro": 0.35204529762268066,
    "validation_map_macro": 0.008132383227348328
}

In [20]:
eval_metrics = EvalMetrics(device=device)

trainer.evaluate(
            data_loader=loaders["validation"][0],
            eval_metrics=eval_metrics.get_metric_collection_with_regex(pattern="f1_m.*",
                                                                    threshold=0.5,
                                                                    num_labels=label_sample_sizes["validation"]
                                                            ),
            save_results=True,
            metrics_prefix='final_validation'
                    )

100%|██████████| 693/693 [03:56<00:00,  2.92it/s]


2023-12-18 06:07:15 PST INFO Saving validation results...


  if _pandas_api.is_sparse(col):


{'final_validation_loss': 0.0011481819404962755,
 'final_validation_map_micro': 0.8853791952133179,
 'final_validation_map_macro': 0.30902916193008423}

In [14]:
eval_metrics = EvalMetrics(device=device)

trainer.evaluate(
            data_loader=loaders["validation"][0],
            eval_metrics=eval_metrics.get_metric_collection_with_regex(pattern="f1_m.*",
                                                                    threshold=0.5,
                                                                    num_labels=label_sample_sizes["validation"]
                                                            ),
            save_results=True,
            metrics_prefix='final_validation'
                    )

100%|██████████| 693/693 [03:27<00:00,  3.34it/s]


2023-12-18 05:47:21 PST INFO Saving validation results...


  if _pandas_api.is_sparse(col):


{'final_validation_loss': 0.0011481819404962755,
 'final_validation_map_micro': 0.8853791952133179,
 'final_validation_map_macro': 0.30902916193008423}

In [17]:
loaders["validation"][0]

TypeError: 'DataLoader' object is not subscriptable

In [None]:
import os
import logging
from typing import Literal
from torchdata.datapipes.iter import FileLister, FileOpener
import argparse
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio import SeqIO
from tqdm import tqdm
from src.utils.data import read_fasta


def process_sequence_tfrecord(record: dict, annotation_types: list):

    sequence = record['sequence'][0].decode()
    id = record['id'][0].decode()

    labels = set()

    # Some rows have no lavel column
    if 'label' not in record:
        return None

    # Add all labels from desired annotation types
    for l in record['label']:
        label = l.decode()
        label_type = label.split(':')[0]

        if (label_type in annotation_types):
            labels.add(label)

    # Sequence with no annotation from selected types
    if not labels:
        return None

    return id, (sequence, list(labels))


def process_tfrecords(input_dir: str,
                      output_dir: str,
                      annotation_types: list,
                      pattern: str,
                      pattern_name: str
                      ):
    # Load all tfrecords from desired data split
    datapipe1 = FileLister(input_dir, pattern)
    datapipe2 = FileOpener(datapipe1, mode="b")
    tfrecord_loader_dp = datapipe2.load_from_tfrecord()

    records = []
    # Iterate over records, process and write to a fasta file
    for _, record in tqdm(enumerate(tfrecord_loader_dp)):
        processed_sequence = process_sequence_tfrecord(
            record, annotation_types)

        # Skipping sequence with no labels from desired annotations
        if processed_sequence is None:
            continue

        id, (sequence, labels) = processed_sequence

        description = " ".join(labels)
        record = SeqRecord(Seq(sequence), id=f"{id}", description=description)
        records.append(record)

    with open(os.path.join(output_dir, f"{pattern_name}_{'_'.join(annotation_types)}.fasta"), "w") as output_handle:
        SeqIO.write(records, output_handle, "fasta")


dirname = os.path.dirname(__file__)
# TODO: This paths should be in config or something
# input_dir = os.path.join(dirname, 'data/swissprot/proteinfer_splits/random/')
# output_dir = os.path.join(dirname, 'data/swissprot/proteinfer_splits/random/')

patterns = {'train': 'train*.tfrecord',
            'dev': 'dev*.tfrecord',
            'test': 'test*.tfrecord',
            'full': '*.tfrecord'}

for pattern_name, pattern in patterns.items():
    logging.info(f'Processing {pattern_name}')
    process_tfrecords(input_dir=args.input_dir,
                        output_dir=args.output_dir,
                        annotation_types=args.annotation_types,
                        pattern=pattern,
                        pattern_name=pattern_name)


In [111]:
!pip install torchdata --force-reinstall

Collecting torchdata
  Obtaining dependency information for torchdata from https://files.pythonhosted.org/packages/39/18/6f0d33df4b9fe4d44a779c2c7cc7cb042535a336f051bb0e5b5387844ee6/torchdata-0.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata
  Downloading torchdata-0.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting urllib3>=1.25 (from torchdata)
  Obtaining dependency information for urllib3>=1.25 from https://files.pythonhosted.org/packages/96/94/c31f58c7a7f470d5665935262ebd7455c7e4c7782eb525658d3dbf4b9403/urllib3-2.1.0-py3-none-any.whl.metadata
  Downloading urllib3-2.1.0-py3-none-any.whl.metadata (6.4 kB)
Collecting requests (from torchdata)
  Obtaining dependency information for requests from https://files.pythonhosted.org/packages/70/8e/0e2d847013cb52cd35b38c009bb167a1a26b2ce6cd6965bf26b47bc0bf44/requests-2.31.0-py3-none-any.whl.metadata
  Using cached requests-2.31.0-py3-none-any.whl.metadata (4.6 kB)
Collecting to

In [82]:
from torchdata.datapipes.iter import FileLister, FileOpener

datapipe1 = FileLister('data/swissprot/proteinfer_splits/random/', 'dev*.tfrecord')
datapipe2 = FileOpener(datapipe1, mode="b")
tfrecord_loader_dp = datapipe2.load_from_tfrecord()

records = []
# Iterate over records, process and write to a fasta file
for _, record in tqdm(enumerate(tfrecord_loader_dp)):
    processed_sequence = process_sequence_tfrecord(
        record, "GO")
    break


ImportError: cannot import name '_check_lambda_fn' from 'torch.utils.data.datapipes.utils.common' (/anaconda/envs/protein_functions_310/lib/python3.10/site-packages/torch/utils/data/datapipes/utils/common.py)

In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
from src.utils.data import read_fasta, read_json, get_vocab_mappings, read_pickle
from src.utils.models import tokenize_labels, get_label_embeddings
from typing import Dict
import pandas as pd
import logging
from typing import List
from src.data.collators import collate_variable_sequence_length
from collections import defaultdict
from joblib import Parallel, delayed, cpu_count
from functools import partial
from collections import Counter
from torch.utils.data.distributed import DistributedSampler
from src.utils.main_utils import get_or_generate_label_embeddings


class ProteinDataset(Dataset):
    """
    Dataset class for protein sequences with GO annotations.
    """
    def __init__(
        self,
        data_paths: dict,
        config: dict,
        vocabularies: dict,
        label_tokenizer=None,
        label_encoder=None,
        logger=None,
        subset_fraction: float = 1.0,
        deduplicate: bool = False,
        is_master: bool = True,
    ):
        """
        paths (dict): Dictionary containing paths to the data and vocabularies.
            data_path (str): Path to the FASTA file containing the protein sequences and corresponding GO annotations
            dataset_type (str): One of 'train', 'validation', or 'test'
            go_descriptions_path (str): Path to the pickled file containing the GO term descriptions mapped to GO term IDs
        deduplicate (bool): Whether to remove duplicate sequences (default: False)
        """
        # Error handling: check for missing keys and invalid dataset types
        required_keys = ["data_path", "dataset_type"]
        for key in required_keys:
            if key not in data_paths:
                raise ValueError(
                    f"Missing required key in paths dictionary: {key}")

        assert data_paths["dataset_type"] in [
            "train",
            "validation",
            "test",
        ], "dataset_type must be one of 'train', 'val', or 'test'"

        # Set the dataset type and data path
        self.dataset_type = data_paths["dataset_type"]
        self.data_path = data_paths["data_path"]

        # Set and process the vocabularies
        self.amino_acid_vocabulary = vocabularies["amino_acid_vocab"]
        self.label_vocabulary = vocabularies["GO_label_vocab"]
        self.sequence_id_vocabulary = vocabularies["sequence_id_vocab"]
        self._process_vocab()

        # Initialize class variables
        self.data = read_fasta(data_paths["data_path"])
        self.label_embedding_matrix = self.sequence_embedding_df = None

        # Subset the data if subset_fraction is provided
        if subset_fraction < 1.0:
            logging.info(
                f"Subsetting {subset_fraction*100}% of the {self.dataset_type} set..."
            )
            self.data = self.data[:int(subset_fraction * len(self.data))]

        # Deduplicate the data if deduplicate is True
        if deduplicate:
            self._remove_duplicates()

        # Load the map from alphanumeric label id to text label
        self.label_annotation_map = {key: value['label'] for key, value in read_pickle(
            data_paths["go_annotations_path"]).to_dict(orient='index').items()}

        # Create ordered list of labels
        label_text_list = []
        for label_id in self.label_vocabulary:
            label_text_list.append(self.label_annotation_map[label_id])
        self.label_text_list = label_text_list

        # Loop through the label IDs and tokenize the labels if a label tokenizer is provided
        self.tokenized_labels = None
        self.label_tokenizer = None
        if label_tokenizer is not None:
            self.label_tokenizer = label_tokenizer
            self.tokenized_labels = tokenize_labels(
                label_text_list, label_tokenizer)

        # If a label encoder is provided, encode the labels
        # TODO: Move back to main to remove warning
        self.label_embedding_matrix = None
        self.label_encoder = None
        if label_encoder is not None and not config["params"]["TRAIN_LABEL_ENCODER"]:
            self.label_encoder = label_encoder
            label_embedding_matrix = get_or_generate_label_embeddings(
                label_annotations=self.label_text_list,
                label_tokenizer=label_tokenizer,
                label_encoder=label_encoder,
                label_embedding_path=config["paths"]["LABEL_EMBEDDING_PATH"],
                logger=logger,
                batch_size_limit=config["params"]["LABEL_BATCH_SIZE_LIMIT_NO_GRAD"],
                is_master=is_master,
            )
            self.label_embedding_matrix = label_embedding_matrix

    # Helper functions for setting embedding dictionaries
    def set_sequence_embedding_df(self, embedding_df: pd.DataFrame):
        self.sequence_embedding_df = embedding_df

    def set_label_embedding_matrix(self, embedding_matrix: torch.Tensor):
        self.label_embedding_matrix = embedding_matrix

    def _remove_duplicates(self):
        """
        Remove duplicate sequences from self.data, keeping only the first instance of each sequence
        Use pandas to improve performance
        """

        # Convert self.data to a DataFrame
        df = pd.DataFrame(self.data, columns=["sequence", "labels"])

        # Drop duplicate rows based on the 'sequence' column, keeping the first instance
        df = df.drop_duplicates(subset="sequence", keep="first")

        # Log the number of duplicate sequences removed
        num_duplicates = len(self.data) - len(df)
        logging.info(
            f"Removing {num_duplicates} duplicate sequences from {self.data_path}...")

        # Convert the DataFrame back to the list of tuples format
        self.data = list(df.itertuples(index=False, name=None))

    # Helper functions for processing and loading vocabularies
    def _process_vocab(self):
        self._process_amino_acid_vocab()
        self._process_label_vocab()
        self._process_sequence_id_vocab()

    def _process_amino_acid_vocab(self):
        self.aminoacid2int, self.int2aminoacid = get_vocab_mappings(
            self.amino_acid_vocabulary
        )

    def _process_label_vocab(self):
        self.label2int, self.int2label = get_vocab_mappings(
            self.label_vocabulary)

    def _process_sequence_id_vocab(self):
        self.sequence_id2int, self.int2sequence_id = get_vocab_mappings(
            self.sequence_id_vocabulary
        )

    def __len__(self) -> int:
        return len(self.data)

    def process_example(self, sequence: str, labels: list[str]) -> dict:
        sequence_id_alphanumeric, labels = labels[0], labels[1:]

        # Convert the sequence and labels to integers for one-hot encoding
        amino_acid_ints = torch.tensor(
            [self.aminoacid2int[aa] for aa in sequence], dtype=torch.long
        )

        labels_ints = torch.tensor(
            [self.label2int[label] for label in labels], dtype=torch.long
        )

        # Get the length of the sequence
        sequence_length = torch.tensor(len(amino_acid_ints))

        # Get multi-hot encoding of sequence and labels
        sequence_onehots = torch.nn.functional.one_hot(
            amino_acid_ints, num_classes=len(self.amino_acid_vocabulary)
        ).permute(1, 0)
        label_multihots = torch.nn.functional.one_hot(
            labels_ints, num_classes=len(self.label_vocabulary)
        ).sum(dim=0)

        # Set the label embeddings, if provided
        label_embeddings = self.label_embedding_matrix if self.label_embedding_matrix is not None else None

        # Get the sequence embedding, if provided
        sequence_embedding = None
        # TODO: Remove this check
        if self.sequence_embedding_df is not None:
            sequence_embedding = torch.tensor(
                self.sequence_embedding_df.loc[sequence_id_alphanumeric].values)

        # Get the tokenized labels, if provided
        tokenized_labels = self.tokenized_labels if self.tokenized_labels is not None else None

        # Return a dict containing the processed example
        return {
            "sequence_onehots": sequence_onehots,
            "sequence_id": sequence_id_alphanumeric,
            "sequence_embedding": sequence_embedding,
            "sequence_length": sequence_length,
            "label_multihots": label_multihots,
            "tokenized_labels": tokenized_labels,
            "label_embeddings": label_embeddings,
        }

    def __getitem__(self, idx) -> tuple:
        sequence, labels = self.data[idx]
        return self.process_example(sequence, labels)

    @classmethod
    def create_multiple_datasets(
        cls,
        paths_list: List[Dict[str, str]],
        config: dict,
        vocabularies: dict,
        subset_fractions: dict = None,
        label_tokenizer=None,
        label_encoder=None,
        logger=None,
        deduplicate: bool = False,
    ) -> List[Dataset]:
        """
        paths_list (List[Dict[str, str]]): List of dictionaries, each containing paths to the data and vocabularies.
        subset_fractions (dict): Dictionary containing the subset fraction for each dataset type (default: None)
        """
        datasets = defaultdict(list)
        subset_fractions = subset_fractions or {}
        for data_paths in paths_list:
            datasets[data_paths["dataset_type"]].append(
                cls(
                    data_paths,
                    config,
                    vocabularies,
                    label_tokenizer=label_tokenizer,
                    label_encoder=label_encoder,
                    logger=logger,
                    subset_fraction=subset_fractions.get(
                        data_paths["dataset_type"], 1.0),
                    deduplicate=deduplicate
                )
            )
        return datasets

In [7]:

### SETUP ###
torch.cuda.empty_cache()

# Check if master process
is_master = True
config = "configs/base_config.yaml"
name = "Test"
train_path_name = "TRAIN_DATA_PATH"
validation_path_name = "VAL_DATA_PATH"
test_paths_names = ["TEST_DATA_PATH"]
amlt = False
gpu=0
rank=0

# Unpack and process the config file
config = get_setup(
    config_path=config,
    run_name=name,
    overrides=[],
    train_path_name=train_path_name,
    val_path_name=validation_path_name,
    test_paths_names=test_paths_names,
    amlt=amlt,
    is_master=is_master,
)
params, paths, timestamp, logger = config["params"], config[
    "paths"], config["timestamp"], config["logger"]

# Set the GPU device, if using
torch.cuda.set_device(gpu)
device = torch.device('cuda:' + str(gpu)
                        if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")


# Log the params
logger.info(json.dumps(params, indent=4))

# Initialize label tokenizer
label_tokenizer = AutoTokenizer.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)

# Initialize label encoder
label_encoder = AutoModel.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)
if params["GRADIENT_CHECKPOINTING"]:
    raise NotImplementedError(
        "Gradient checkpointing is not yet implemented.")

if params["LORA"]:
    for layer in label_encoder.layers:
        in_features, out_features = 1024, 1024
        layer.self_attn.q_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
        layer.self_attn.v_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
        layer.self_attn.k_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
        layer.self_attn.out_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
    # Mark only the LoRA parameters as trainable
    lora.mark_only_lora_as_trainable(label_encoder)

label_encoder = label_encoder.to(device)

# Load or generate the vocabularies
vocabularies = get_or_generate_vocabularies(
    paths["FULL_DATA_PATH"], paths["VOCABULARIES_DIR"], logger)


2023-12-03 08:41:37 PST INFO Logging to ./outputs/logs/2023-12-03_08-41-37_Test.log and console...
2023-12-03 08:41:37 PST INFO Logging to ./outputs/logs/2023-12-03_08-41-37_Test.log and console...
2023-12-03 08:41:37 PST INFO Logging to ./outputs/logs/2023-12-03_08-41-37_Test.log and console...
2023-12-03 08:41:37 PST INFO Using device: cuda:0
2023-12-03 08:41:37 PST INFO Using device: cuda:0
2023-12-03 08:41:37 PST INFO Using device: cuda:0
2023-12-03 08:41:37 PST INFO {
    "TRAIN_BATCH_SIZE": 32,
    "VALIDATION_BATCH_SIZE": 64,
    "TEST_BATCH_SIZE": 64,
    "IN_BATCH_SAMPLING": false,
    "TRAIN_LABEL_SAMPLE_SIZE": null,
    "VALIDATION_LABEL_SAMPLE_SIZE": null,
    "LABEL_BATCH_SIZE_LIMIT_NO_GRAD": 1500,
    "SEQUENCE_BATCH_SIZE_LIMIT_NO_GRAD": 128,
    "LEARNING_RATE": 0.0003,
    "OPTIMIZER": "Adam",
    "PROTEIN_EMBEDDING_DIM": 1100,
    "LABEL_EMBEDDING_DIM": 1024,
    "LATENT_EMBEDDING_DIM": 1024,
    "OUTPUT_MLP_HIDDEN_DIM_SCALE_FACTOR": 2,
    "OUTPUT_MLP_NUM_LAYERS": 2,


In [9]:
config['dataset_paths_list']

[{'vocabularies_dir': './data/vocabularies/proteinfer',
  'go_annotations_path': './data/annotations/go_annotations_2019_07_01.pkl',
  'data_path': './data/swissprot/proteinfer_splits/random/train_GO.fasta',
  'dataset_type': 'train'},
 {'vocabularies_dir': './data/vocabularies/proteinfer',
  'go_annotations_path': './data/annotations/go_annotations_2019_07_01.pkl',
  'data_path': './data/swissprot/proteinfer_splits/random/dev_GO.fasta',
  'dataset_type': 'validation'},
 {'vocabularies_dir': './data/vocabularies/proteinfer',
  'go_annotations_path': './data/annotations/go_annotations_2019_07_01.pkl',
  'data_path': './data/swissprot/proteinfer_splits/random/test_GO.fasta',
  'dataset_type': 'test'}]

In [None]:
from itertools import product

In [None]:
"sequence_onehots": sequence_onehots,
"sequence_id": sequence_id_alphanumeric,
"sequence_embedding": sequence_embedding,
"sequence_length": sequence_length,
"label_multihots": label_multihots,
"tokenized_labels": tokenized_labels,
"label_embeddings": label_embeddings,

In [77]:


def flatten_single_seq(idx):
    seq = d.data[idx]
    pos_labels = seq[1][1:]
    neg_labels = set(d.label_vocabulary) - set(pos_labels)
    labels = [(i,1) for i in pos_labels] + [(i,0) for i in neg_labels]
    flattened = defaultdict(list)

    for label,y in labels:
        flattened['sequence'].append(seq[0])
        flattened['sequence_id'].append(seq[1][0])
        flattened['label'].append(label)
        flattened['y'].append(y)

    return flattened

In [80]:
flattened = defaultdict(list)
for i in range(len(d.data)):
    single_flattend = flatten_single_seq(i)
    for k,v in single_flattend.items():
        flattened[k].extend(v)

KeyboardInterrupt: 

In [None]:

### SETUP ###
torch.cuda.empty_cache()

# Check if master process
is_master = True
config = "configs/base_config.yaml"
name = "Test"
train_path_name = "TRAIN_DATA_PATH"
validation_path_name = "VAL_DATA_PATH"
test_paths_names = ["TEST_DATA_PATH"]
amlt = False
gpu=0
rank=0

# Unpack and process the config file
config = get_setup(
    config_path=config,
    run_name=name,
    overrides=[],
    train_path_name=train_path_name,
    val_path_name=validation_path_name,
    test_paths_names=test_paths_names,
    amlt=amlt,
    is_master=is_master,
)
params, paths, timestamp, logger = config["params"], config[
    "paths"], config["timestamp"], config["logger"]

# Set the GPU device, if using
torch.cuda.set_device(gpu)
device = torch.device('cuda:' + str(gpu)
                        if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")


# Log the params
logger.info(json.dumps(params, indent=4))

# Initialize label tokenizer
label_tokenizer = AutoTokenizer.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)

# Initialize label encoder
label_encoder = AutoModel.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)
if params["GRADIENT_CHECKPOINTING"]:
    raise NotImplementedError(
        "Gradient checkpointing is not yet implemented.")

if params["LORA"]:
    for layer in label_encoder.layers:
        in_features, out_features = 1024, 1024
        layer.self_attn.q_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
        layer.self_attn.v_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
        layer.self_attn.k_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
        layer.self_attn.out_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
    # Mark only the LoRA parameters as trainable
    lora.mark_only_lora_as_trainable(label_encoder)

label_encoder = label_encoder.to(device)

# Load or generate the vocabularies
vocabularies = get_or_generate_vocabularies(
    paths["FULL_DATA_PATH"], paths["VOCABULARIES_DIR"], logger)

# Create datasets
datasets = ProteinDataset.create_multiple_datasets(
    paths_list=config['dataset_paths_list'],
    config=config,
    logger=logger,
    label_tokenizer=label_tokenizer,
    label_encoder=label_encoder,
    vocabularies=vocabularies,
    subset_fractions={
        "train": params["TRAIN_SUBSET_FRACTION"],
        "validation": params["VALIDATION_SUBSET_FRACTION"],
        "test": params["TEST_SUBSET_FRACTION"],
    },
    deduplicate=params["DEDUPLICATE"],
)

# Seed everything so we don't go crazy
seed_everything(params["SEED"], device)

# Initialize new run
logger.info(
    f"################## {timestamp} RUNNING main.py ##################")

# Define label sample sizes for train, validation, and test loaders
label_sample_sizes = {
    "train": params["TRAIN_LABEL_SAMPLE_SIZE"],
    "validation": params["VALIDATION_LABEL_SAMPLE_SIZE"],
    "test": None  # No sampling for the test set
}

# Define data loaders
loaders = create_multiple_loaders(
    datasets,
    params,
    label_sample_sizes=label_sample_sizes,
    shuffle_labels=params['SHUFFLE_LABELS'],
    in_batch_sampling=params['IN_BATCH_SAMPLING'],
    num_workers=params["NUM_WORKERS"],
    world_size=1,
    rank=rank,
)

if not params["TRAIN_LABEL_ENCODER"]:
    # Move the label encoder to CPU
    label_encoder = label_encoder.cpu()

# Initialize ProteInfer
sequence_encoder = ProteInfer.from_pretrained(
    weights_path=paths["PROTEINFER_WEIGHTS_PATH"],
    num_labels=config["embed_sequences_params"]["PROTEINFER_NUM_LABELS"],
    input_channels=config["embed_sequences_params"]["INPUT_CHANNELS"],
    output_channels=config["embed_sequences_params"]["OUTPUT_CHANNELS"],
    kernel_size=config["embed_sequences_params"]["KERNEL_SIZE"],
    activation=torch.nn.ReLU,
    dilation_base=config["embed_sequences_params"]["DILATION_BASE"],
    num_resnet_blocks=config["embed_sequences_params"]["NUM_RESNET_BLOCKS"],
    bottleneck_factor=config["embed_sequences_params"]["BOTTLENECK_FACTOR"],
)

# Generate all sequence embeddings upfront, if not training the sequence encoder
sequence_embedding_df = None
if not params["TRAIN_SEQUENCE_ENCODER"]:
    sequence_embedding_df = get_or_generate_sequence_embeddings(
        paths,
        device,
        sequence_encoder,
        datasets,
        params,
        logger,
    )
    sequence_encoder = sequence_encoder.to('cpu')

# Loop through all the datasets and set the sequence embedding df
for dataset in datasets.values():
    for subset in dataset:
        if not params["TRAIN_SEQUENCE_ENCODER"]:
            subset.set_sequence_embedding_df(sequence_embedding_df)


loaders["train"][0]



2023-11-25 03:45:01 PST INFO Logging to ./outputs/logs/2023-11-25_03-45-01_Test.log and console...
2023-11-25 03:45:01 PST INFO Using device: cuda:0
2023-11-25 03:45:01 PST INFO {
    "TRAIN_BATCH_SIZE": 64,
    "VALIDATION_BATCH_SIZE": 64,
    "TEST_BATCH_SIZE": 64,
    "IN_BATCH_SAMPLING": false,
    "TRAIN_LABEL_SAMPLE_SIZE": null,
    "VALIDATION_LABEL_SAMPLE_SIZE": null,
    "LABEL_BATCH_SIZE_LIMIT_NO_GRAD": 1500,
    "SEQUENCE_BATCH_SIZE_LIMIT_NO_GRAD": 128,
    "LEARNING_RATE": 0.001,
    "OPTIMIZER": "Adam",
    "PROTEIN_EMBEDDING_DIM": 1100,
    "LABEL_EMBEDDING_DIM": 1024,
    "LATENT_EMBEDDING_DIM": 1024,
    "OUTPUT_MLP_HIDDEN_DIM_SCALE_FACTOR": 1,
    "OUTPUT_MLP_NUM_LAYERS": 2,
    "OUTPUT_NEURON_PROBABILITY_BIAS": null,
    "OUTPUT_MLP_BATCHNORM": true,
    "OPTIMIZATION_METRIC_NAME": "map_micro",
    "DECISION_TH_METRIC_NAME": "f1_micro",
    "NUM_EPOCHS": 15,
    "GRADIENT_ACCUMULATION_STEPS": 1,
    "GRADIENT_CHECKPOINTING": false,
    "LORA": false,
    "LORA_RANK": 

<torch.utils.data.dataloader.DataLoader at 0x7f639c7977c0>

In [11]:
d = ProteinDataset(config=config,
        logger=logger,
        label_tokenizer=label_tokenizer,
        label_encoder=label_encoder,
        vocabularies=vocabularies,
        data_paths=config['dataset_paths_list'][1],
        deduplicate=params["DEDUPLICATE"])


2023-12-03 08:44:55 PST INFO Removing 8479 duplicate sequences from ./data/swissprot/proteinfer_splits/random/dev_GO.fasta...
2023-12-03 08:44:55 PST INFO Removing 8479 duplicate sequences from ./data/swissprot/proteinfer_splits/random/dev_GO.fasta...
2023-12-03 08:44:55 PST INFO Removing 8479 duplicate sequences from ./data/swissprot/proteinfer_splits/random/dev_GO.fasta...
2023-12-03 08:45:08 PST INFO Loaded label embeddings from ./data/None
2023-12-03 08:45:08 PST INFO Loaded label embeddings from ./data/None
2023-12-03 08:45:08 PST INFO Loaded label embeddings from ./data/None


In [None]:
"sequence_onehots": sequence_onehots,
"sequence_id": sequence_id_alphanumeric,
"sequence_embedding": sequence_embedding,
"sequence_length": sequence_length,
"label_multihots": label_multihots,
"tokenized_labels": tokenized_labels,
"label_embeddings": label_embeddings,

In [43]:
tokenized_labels.keys()

dict_keys(['input_ids', 'attention_mask'])

In [109]:
sequence_onehots_list = []
sequence_id_alphanumeric_list = []
sequence_embedding_list = []
sequence_length_list = []
label_multihots_list = []

flattened = defaultdict(list)
for i in range(10):

    sequence,labels = d.data[i]

    sequence_id_alphanumeric, labels = labels[0], labels[1:]

    # Convert the sequence and labels to integers for one-hot encoding
    amino_acid_ints = torch.tensor(
        [d.aminoacid2int[aa] for aa in sequence], dtype=torch.long
    )

    labels_ints = torch.tensor(
        [d.label2int[label] for label in labels], dtype=torch.long
    )

    # Get the length of the sequence
    sequence_length = torch.tensor(len(amino_acid_ints))

    # Get multi-hot encoding of sequence and labels
    sequence_onehots = torch.nn.functional.one_hot(
        amino_acid_ints, num_classes=len(d.amino_acid_vocabulary)
    ).permute(1, 0)
    label_multihots = torch.nn.functional.one_hot(
        labels_ints, num_classes=len(d.label_vocabulary)
    ).sum(dim=0)

    # Get the sequence embedding, if provided
    sequence_embedding = None
    # TODO: Remove this check
    if d.sequence_embedding_df is not None:
        sequence_embedding = torch.tensor(
            d.sequence_embedding_df.loc[sequence_id_alphanumeric].values)
        

    sequence_onehots_list.extend([sequence_onehots.clone() for _ in range(len(d.label_vocabulary))])
    sequence_id_alphanumeric_list.append([sequence_id_alphanumeric for _ in range(len(d.label_vocabulary))])
    sequence_embedding_list.append([sequence_embedding.clone() if sequence_embedding is not None else None for _ in range(len(d.label_vocabulary))])
    sequence_length_list.append(sequence_length)
    label_multihots_list.append(label_multihots)

[tensor([[0, 0, 0,  ..., 0, 0, 1],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 1, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 tensor([[0, 0, 0,  ..., 0, 0, 1],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 1, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 tensor([[0, 0, 0,  ..., 0, 0, 1],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 1, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 tensor([[0, 0, 0,  ..., 0, 0, 1],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 1, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 tensor([[0, 0, 0,  ..., 0, 0, 1],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0

In [104]:
sequence_onehots.shape

torch.Size([20, 155])

In [33]:

# Set the label embeddings, if provided
label_embeddings = d.label_embedding_matrix if d.label_embedding_matrix is not None else None

# Get the tokenized labels, if provided
tokenized_labels = d.tokenized_labels if d.tokenized_labels is not None else None

tensor(25390)

In [31]:
label_embeddings[labels_ints[-1]]

tensor([-1.1939, -0.8254,  0.0120,  ..., -0.0249, -0.1879,  0.3359])

In [32]:
label_embeddings[labels_ints]

tensor([[-1.0434, -1.4430, -0.3326,  ..., -0.6947,  0.2159, -0.7041],
        [-1.3708, -0.5248,  0.2347,  ..., -0.1317,  0.0789,  0.1895],
        [-1.0705, -1.2185, -0.5712,  ..., -0.3016, -0.3092, -0.1650],
        ...,
        [-1.1631, -0.5377,  0.3100,  ..., -0.2499,  0.4125,  0.2718],
        [-0.7390, -0.1916,  1.3233,  ..., -0.3388,  0.4596,  0.3480],
        [-1.1939, -0.8254,  0.0120,  ..., -0.0249, -0.1879,  0.3359]])

In [5]:
def get_batch_weights_v2(label_weights, target):
    """
    Computes the weights for each sample in the batch based on the target labels
    using broadcasting.
    
    Args:
        label_weights: torch.tensor of size [no_of_classes] with the weight of each label.
        target: torch.tensor of size [batch, no_of_classes].

    Returns:
        weights_for_samples: torch.tensor of size [batch, no_of_classes].
    """

    # Ensure label_weights is a float tensor for correct broadcasting and computation
    label_weights = label_weights.float()

    # Multiply weights with target labels using broadcasting
    # This step applies the specific class weights to the corresponding labels in the target.
    weighted_targets = label_weights * target

    # Sum the weighted targets along the class dimension to get a single weight per sample
    weights_for_samples = weighted_targets.sum(dim=1, keepdim=True)

    # Use broadcasting again for expanding weights across the class dimension
    # No need to repeat the tensor explicitly.
    weights_for_samples = weights_for_samples.expand_as(target)

    return weights_for_samples


class CBLoss(torch.nn.Module):
    def __init__(self, label_weights, beta=0.99):
        super().__init__()

        self.label_weights = label_weights
        self.beta=beta

    def forward(self, input,target):
        no_of_classes = len(self.label_weights)
        effective_num = 1.0 - torch.pow(self.beta, self.label_weights)

        # Replace zeros in effective_num with 'inf' (infinity) to avoid division by zero
        effective_num = torch.where(effective_num == 0, torch.tensor(float('inf')), effective_num)

        weights = (1.0 - self.beta) / effective_num
        weights = weights / torch.sum(weights) * no_of_classes

        weights = get_batch_weights_v2(weights,target)

        return weights

In [55]:
import numpy as np
import torch.nn.functional as F

def CB_loss(labels_one_hot, samples_per_cls, no_of_classes,  beta=0.99):
    """Compute the Class Balanced Loss between `logits` and the ground truth `labels`.

    Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
    where Loss is one of the standard losses used for Neural Networks.

    Args:
      labels: A int tensor of size [batch].
      logits: A float tensor of size [batch, no_of_classes].
      samples_per_cls: A python list of size [no_of_classes].
      no_of_classes: total number of classes. int
      loss_type: string. One of "sigmoid", "focal", "softmax".
      beta: float. Hyperparameter for Class balanced loss.
      gamma: float. Hyperparameter for Focal loss.

    Returns:
      cb_loss: A float tensor representing class balanced loss
    """
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    weights = (1.0 - beta) / np.array(effective_num)
    weights = weights / np.sum(weights) * no_of_classes


    weights = torch.tensor(weights).float()
    weights = weights.unsqueeze(0)
    weights = weights.repeat(labels_one_hot.shape[0],1) * labels_one_hot
    weights = weights.sum(1)
    weights = weights.unsqueeze(1)
    weights = weights.repeat(1,no_of_classes)
    return weights

In [50]:
labels = (torch.rand(size=(10,100))>0.4)*1.0
preds = torch.rand(size=labels.shape)*10

In [52]:
samples_per_cls = labels.sum(axis=0
           )

In [53]:
samples_per_cls

tensor([7., 6., 6., 6., 4., 7., 6., 6., 6., 7., 6., 6., 6., 2., 7., 6., 4., 5.,
        7., 5., 7., 7., 6., 4., 7., 5., 5., 7., 5., 8., 6., 6., 6., 5., 5., 6.,
        5., 7., 3., 5., 4., 7., 5., 4., 5., 8., 7., 1., 7., 6., 6., 6., 5., 5.,
        6., 5., 6., 9., 5., 6., 9., 6., 4., 6., 7., 6., 6., 7., 4., 5., 5., 5.,
        8., 6., 7., 6., 6., 7., 4., 8., 5., 4., 7., 2., 6., 5., 6., 6., 6., 4.,
        5., 7., 7., 9., 6., 7., 8., 8., 3., 6.])

In [67]:
w_original = CB_loss(labels, samples_per_cls, len(samples_per_cls),  beta=0.9)

In [68]:
w_original.mean(),w_original.sum()

(tensor(53.7565), tensor(53756.4570))

In [69]:
cb=CBLoss(samples_per_cls,beta=0.9)
w_mine=cb(None,labels)

In [70]:
w_mine.mean(),w_mine.sum()

(tensor(53.7565), tensor(53756.4609))

In [None]:
w_mine = CB_loss(labels, samples_per_cls, len(samples_per_cls),  beta=0.99)

In [3]:
bsz=3
features = torch.randint(0,10,(bsz,2,1))
labels = torch.Tensor([1,2,1])

In [12]:
temperature=0.07
contrast_mode='all'
base_temperature=0.07

device = (torch.device('cuda')
            if features.is_cuda
            else torch.device('cpu'))

features = features.view(features.shape[0], features.shape[1], -1)

batch_size = features.shape[0]
labels = labels.contiguous().view(-1, 1)
if labels.shape[0] != batch_size:
    raise ValueError('Num of labels does not match num of features')
mask = torch.eq(labels, labels.T).float().to(device)


contrast_count = features.shape[1]
contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
if contrast_mode == 'one':
    anchor_feature = features[:, 0]
    anchor_count = 1
elif contrast_mode == 'all':
    anchor_feature = contrast_feature
    anchor_count = contrast_count
else:
    raise ValueError('Unknown mode: {}'.format(contrast_mode))

# compute logits
anchor_dot_contrast = torch.div(
    torch.matmul(anchor_feature, contrast_feature.T),
    temperature)
# for numerical stability
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()

# tile mask
mask = mask.repeat(anchor_count, contrast_count)
# mask-out self-contrast cases
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
    0
)
mask = mask * logits_mask

# compute log_prob
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

# compute mean of log-likelihood over positive
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

# loss
loss = - (temperature / base_temperature) * mean_log_prob_pos

loss = loss.mean()


In [13]:
logits_mask

tensor([[0., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1.],
        [1., 1., 0., 1., 1., 1.],
        [1., 1., 1., 0., 1., 1.],
        [1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 0.]])

In [10]:
(mask * log_prob).sum(0)

tensor([ -573.2203,    -1.7918,  -716.0775,  -716.0775,  -916.0775, -1144.6489])

In [159]:
loss.view(anchor_count, batch_size).mean(),loss.mean()


(tensor(126.7857), tensor(126.7857))

In [152]:
loss.view(anchor_count, batch_size).shape

torch.Size([2, 3])

In [161]:
del anchor_count

In [162]:
temperature=0.07
base_temperature=0.07

# compute logits
anchor_dot_contrast = torch.div(anchor_dot_contrast,temperature)
# for numerical stability
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()

# compute log_prob
exp_logits = torch.exp(logits) 
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

# compute mean of log-likelihood over positive
mean_log_prob_pos = (labels_multihot * log_prob).sum(1) / labels_multihot.sum(1)

# loss
loss = - (temperature / base_temperature) * mean_log_prob_pos
loss = loss.mean()


NameError: name 'labels_multihot' is not defined

In [142]:
loss

tensor(88.9891)

In [140]:
torch.logsumexp(logits,dim=1,keepdim=True)

tensor([[1.7918],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000],
        [0.0000]])

In [128]:
log_prob

tensor([[  -1.7918,   -1.7918,   -1.7918,   -1.7918,   -1.7918,   -1.7918],
        [-214.2857,  -85.7143,  -42.8571, -128.5714, -128.5714,    0.0000],
        [-285.7143, -114.2857,  -57.1429, -171.4286, -171.4286,    0.0000],
        [-142.8571,  -57.1429,  -28.5714,  -85.7143,  -85.7143,    0.0000],
        [-142.8571,  -57.1429,  -28.5714,  -85.7143,  -85.7143,    0.0000],
        [-357.1429, -142.8571,  -71.4286, -214.2857, -214.2857,    0.0000]])

In [123]:
anchor_dot_contrast

tensor([[  0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000],
        [  0.0000, 128.5714, 171.4286,  85.7143,  85.7143, 214.2857],
        [  0.0000, 171.4286, 228.5714, 114.2857, 114.2857, 285.7143],
        [  0.0000,  85.7143, 114.2857,  57.1429,  57.1429, 142.8571],
        [  0.0000,  85.7143, 114.2857,  57.1429,  57.1429, 142.8571],
        [  0.0000, 214.2857, 285.7143, 142.8571, 142.8571, 357.1429]])

In [124]:
features

tensor([[[0],
         [2]],

        [[3],
         [2]],

        [[4],
         [5]]])

In [125]:
contrast_feature

tensor([[0],
        [3],
        [4],
        [2],
        [2],
        [5]])

In [58]:
anchor_dot_contrast

tensor([[  71.4286,  157.1429,  342.8571,  128.5714],
        [ 157.1429,  371.4286,  685.7143,  214.2857],
        [ 342.8571,  685.7143, 1828.5714,  800.0000],
        [ 128.5714,  214.2857,  800.0000,  414.2857]])

In [60]:
features.shape

torch.Size([2, 2, 2])

In [40]:
mask

tensor([[0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.]])

In [20]:
a=MLP(1000,[10,10],bias=False,norm_layer=torch.nn.BatchNorm1d,activation_layer=torch.nn.Identity)

In [26]:
e=torch.nn.Embedding(100,3)

e(torch.arange(10))

torch.Size([100, 3])

Parameter containing:
tensor([[ 4.1412e-01,  2.0729e+00,  3.3877e-01],
        [ 7.4035e-01, -1.0129e+00,  1.0684e+00],
        [-7.6563e-01, -1.6943e-01, -7.2646e-01],
        [ 3.0629e-01, -5.6680e-01,  6.6975e-01],
        [-4.6175e-03, -4.8004e-01,  1.1684e+00],
        [-1.5192e-01,  4.9175e-01, -1.0614e+00],
        [-1.7002e-01,  1.8095e-01,  4.0745e-01],
        [-1.0855e+00,  1.6527e+00,  1.1391e+00],
        [ 7.1451e-01,  2.7505e+00,  5.0293e-01],
        [-7.2259e-01, -6.9784e-01,  6.9926e-01],
        [-8.0408e-01, -1.9509e+00,  1.9277e+00],
        [-1.6251e-01, -1.7948e-01,  6.0711e-01],
        [ 1.4911e-01,  3.4602e-01, -1.4749e+00],
        [-1.1428e-01,  4.2197e-01, -1.1637e+00],
        [-6.9847e-01,  1.1591e+00,  1.7230e-01],
        [-4.1416e-01, -1.2346e+00, -1.1913e+00],
        [-4.8150e-01,  1.1232e+00,  2.1309e+00],
        [ 4.2791e-01,  2.0048e+00,  1.1230e+00],
        [ 2.1412e-01,  9.4107e-01, -3.6250e-01],
        [ 3.0476e-01, -2.9366e-02,  7.1577e-01]

tensor([[ 0.4141,  2.0729,  0.3388],
        [ 0.7403, -1.0129,  1.0684],
        [-0.7656, -0.1694, -0.7265],
        [ 0.3063, -0.5668,  0.6697],
        [-0.0046, -0.4800,  1.1684],
        [-0.1519,  0.4917, -1.0614],
        [-0.1700,  0.1810,  0.4074],
        [-1.0855,  1.6527,  1.1391],
        [ 0.7145,  2.7505,  0.5029],
        [-0.7226, -0.6978,  0.6993]], grad_fn=<EmbeddingBackward0>)

In [4]:

### SETUP ###
torch.cuda.empty_cache()

# Check if master process
is_master = True
config = "configs/base_config.yaml"
name = "Test"
train_path_name = "TRAIN_DATA_PATH"
validation_path_name = "VAL_DATA_PATH"
test_paths_names = ["TEST_DATA_PATH"]
amlt = False
gpu=0
rank=0

# Unpack and process the config file
config = get_setup(
    config_path=config,
    run_name=name,
    overrides=[],
    train_path_name=train_path_name,
    val_path_name=validation_path_name,
    test_paths_names=test_paths_names,
    amlt=amlt,
    is_master=is_master,
)
params, paths, timestamp, logger = config["params"], config[
    "paths"], config["timestamp"], config["logger"]

# Set the GPU device, if using
torch.cuda.set_device(gpu)
device = torch.device('cuda:' + str(gpu)
                        if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")


# Log the params
logger.info(json.dumps(params, indent=4))

# Initialize label tokenizer
label_tokenizer = AutoTokenizer.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)

# Initialize label encoder
label_encoder = AutoModel.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)
if params["GRADIENT_CHECKPOINTING"]:
    raise NotImplementedError(
        "Gradient checkpointing is not yet implemented.")

if params["LORA"]:
    for layer in label_encoder.layers:
        in_features, out_features = 1024, 1024
        layer.self_attn.q_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
        layer.self_attn.v_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
        layer.self_attn.k_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
        layer.self_attn.out_proj = lora.Linear(
            in_features, out_features, r=params["LORA_RANK"])
    # Mark only the LoRA parameters as trainable
    lora.mark_only_lora_as_trainable(label_encoder)

label_encoder = label_encoder.to(device)

# Load or generate the vocabularies
vocabularies = get_or_generate_vocabularies(
    paths["FULL_DATA_PATH"], paths["VOCABULARIES_DIR"], logger)

# Create datasets
datasets = ProteinDataset.create_multiple_datasets(
    paths_list=config['dataset_paths_list'],
    config=config,
    logger=logger,
    label_tokenizer=label_tokenizer,
    label_encoder=label_encoder,
    vocabularies=vocabularies,
    subset_fractions={
        "train": params["TRAIN_SUBSET_FRACTION"],
        "validation": params["VALIDATION_SUBSET_FRACTION"],
        "test": params["TEST_SUBSET_FRACTION"],
    },
    deduplicate=params["DEDUPLICATE"],
)

# Seed everything so we don't go crazy
seed_everything(params["SEED"], device)

# Initialize new run
logger.info(
    f"################## {timestamp} RUNNING main.py ##################")

# Define label sample sizes for train, validation, and test loaders
label_sample_sizes = {
    "train": params["TRAIN_LABEL_SAMPLE_SIZE"],
    "validation": params["VALIDATION_LABEL_SAMPLE_SIZE"],
    "test": None  # No sampling for the test set
}

# Define data loaders
loaders = create_multiple_loaders(
    datasets,
    params,
    label_sample_sizes=label_sample_sizes,
    shuffle_labels=params['SHUFFLE_LABELS'],
    in_batch_sampling=params['IN_BATCH_SAMPLING'],
    num_workers=params["NUM_WORKERS"],
    world_size=1,
    rank=rank,
)

if not params["TRAIN_LABEL_ENCODER"]:
    # Move the label encoder to CPU
    label_encoder = label_encoder.cpu()

# Initialize ProteInfer
sequence_encoder = ProteInfer.from_pretrained(
    weights_path=paths["PROTEINFER_WEIGHTS_PATH"],
    num_labels=config["embed_sequences_params"]["PROTEINFER_NUM_LABELS"],
    input_channels=config["embed_sequences_params"]["INPUT_CHANNELS"],
    output_channels=config["embed_sequences_params"]["OUTPUT_CHANNELS"],
    kernel_size=config["embed_sequences_params"]["KERNEL_SIZE"],
    activation=torch.nn.ReLU,
    dilation_base=config["embed_sequences_params"]["DILATION_BASE"],
    num_resnet_blocks=config["embed_sequences_params"]["NUM_RESNET_BLOCKS"],
    bottleneck_factor=config["embed_sequences_params"]["BOTTLENECK_FACTOR"],
)

# Generate all sequence embeddings upfront, if not training the sequence encoder
sequence_embedding_df = None
if not params["TRAIN_SEQUENCE_ENCODER"]:
    sequence_embedding_df = get_or_generate_sequence_embeddings(
        paths,
        device,
        sequence_encoder,
        datasets,
        params,
        logger,
    )
    sequence_encoder = sequence_encoder.to('cpu')

# Loop through all the datasets and set the sequence embedding df
for dataset in datasets.values():
    for subset in dataset:
        if not params["TRAIN_SEQUENCE_ENCODER"]:
            subset.set_sequence_embedding_df(sequence_embedding_df)


loaders["train"][0]



2023-11-25 03:45:01 PST INFO Logging to ./outputs/logs/2023-11-25_03-45-01_Test.log and console...
2023-11-25 03:45:01 PST INFO Using device: cuda:0
2023-11-25 03:45:01 PST INFO {
    "TRAIN_BATCH_SIZE": 64,
    "VALIDATION_BATCH_SIZE": 64,
    "TEST_BATCH_SIZE": 64,
    "IN_BATCH_SAMPLING": false,
    "TRAIN_LABEL_SAMPLE_SIZE": null,
    "VALIDATION_LABEL_SAMPLE_SIZE": null,
    "LABEL_BATCH_SIZE_LIMIT_NO_GRAD": 1500,
    "SEQUENCE_BATCH_SIZE_LIMIT_NO_GRAD": 128,
    "LEARNING_RATE": 0.001,
    "OPTIMIZER": "Adam",
    "PROTEIN_EMBEDDING_DIM": 1100,
    "LABEL_EMBEDDING_DIM": 1024,
    "LATENT_EMBEDDING_DIM": 1024,
    "OUTPUT_MLP_HIDDEN_DIM_SCALE_FACTOR": 1,
    "OUTPUT_MLP_NUM_LAYERS": 2,
    "OUTPUT_NEURON_PROBABILITY_BIAS": null,
    "OUTPUT_MLP_BATCHNORM": true,
    "OPTIMIZATION_METRIC_NAME": "map_micro",
    "DECISION_TH_METRIC_NAME": "f1_micro",
    "NUM_EPOCHS": 15,
    "GRADIENT_ACCUMULATION_STEPS": 1,
    "GRADIENT_CHECKPOINTING": false,
    "LORA": false,
    "LORA_RANK": 

<torch.utils.data.dataloader.DataLoader at 0x7f639c7977c0>

In [67]:
from torch.utils.data import DataLoader, TensorDataset
from torch.cuda.amp import autocast
def tokenize_labels(text, tokenizer, max_length=1024):
    """
    Tokenize a list of text strings.

    Args:
        text (list): The list of text strings.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer.

    Returns:
        dict: A dictionary containing tokenized labels as 'input_ids' and 'attention_mask'.
    """
    return tokenizer(
        text, padding='longest', truncation=True, max_length=max_length, return_tensors="pt"
    )


def compute_mean_hidden_states(last_hidden_states, attention_mask):
    """Compute the mean of the last hidden state for only the relevant tokens."""
    # Compute the number of relevant tokens for each sequence
    num_relevant_tokens = attention_mask.sum(dim=1, keepdim=True)
    # Mask the last_hidden_state tensor and compute the sum
    sum_hidden_states = (last_hidden_states *
                         attention_mask.unsqueeze(-1)).sum(dim=1)
    # Compute the mean of the last hidden state
    return sum_hidden_states / num_relevant_tokens


def get_label_embeddings(tokenized_labels, model, batch_size_limit=1000):
    """
    Get embeddings for a list of tokenized labels.
    Assumes that tokenized_labels and model are on the same device, ideally GPU.
    """
    total_labels = tokenized_labels["input_ids"].shape[0]

    if total_labels <= batch_size_limit:
        with autocast():
            last_hidden_states = model(
                input_ids=tokenized_labels["input_ids"],
                attention_mask=tokenized_labels["attention_mask"]
            ).last_hidden_state
        output = compute_mean_hidden_states(
            last_hidden_states, tokenized_labels["attention_mask"])
        del last_hidden_states
        return output

    else:
        # Convert dictionary values to tensors
        tensors = [tokenized_labels["input_ids"],
                   tokenized_labels["attention_mask"]]
        # Create TensorDataset and DataLoader
        dataset = TensorDataset(*tensors)
        dataloader = DataLoader(dataset, batch_size=batch_size_limit,
                                shuffle=False, pin_memory=False, num_workers=0)

        all_label_embeddings = []
        for batch in dataloader:
            input_ids, attention_mask = batch
            with autocast():
                last_hidden_states = model(
                    input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
            mean_hidden_states = compute_mean_hidden_states(
                last_hidden_states, attention_mask)
            all_label_embeddings.append(mean_hidden_states)
            del last_hidden_states, mean_hidden_states
        # Concatenate all the label embeddings
        return torch.cat(all_label_embeddings, dim=0)


def generate_label_embeddings_from_text(label_annotations, label_tokenizer, label_encoder, batch_size_limit=1000):
    """Tokenize the labels and generate label embeddings."""
    tokenized_labels = tokenize_labels(label_annotations, label_tokenizer)

    # Move to GPU
    tokenized_labels["input_ids"] = tokenized_labels["input_ids"].to(
        label_encoder.device)
    tokenized_labels["attention_mask"] = tokenized_labels["attention_mask"].to(
        label_encoder.device)

    # Generate label embeddings
    return get_label_embeddings(tokenized_labels, label_encoder, batch_size_limit=batch_size_limit)

# Initialize label tokenizer
label_tokenizer = AutoTokenizer.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)

# Initialize label encoder
label_encoder = AutoModel.from_pretrained(
    params['LABEL_ENCODER_CHECKPOINT'],
)

In [69]:
from src.utils.data import read_pickle
annot=read_pickle('data/annotations/go_annotations_2019_07_01.pkl')

In [96]:
i=32000
annot.index[i],annot.iloc[i]['label']

('GO:0070327',
 'The directed movement of thyroid hormone into, out of or within a cell, or between cells, by means of some agent such as a transporter or pore.')

In [97]:
datasets["train"][0].label2int[annot.index[i]]

22605

In [98]:
generate_label_embeddings_from_text([annot.iloc[i]['label']],label_tokenizer=label_tokenizer,label_encoder=label_encoder)

tensor([[-0.8438,  0.1259,  0.2046,  ...,  0.4670, -0.1736,  0.8953]],
       grad_fn=<DivBackward0>)

In [5]:
loader_iter = iter(loaders["train"][0])
data_iter = iter(datasets["train"][0])

In [22]:
data_batch = next(data_iter)
loader_batch=next(loader_iter)

In [99]:
loader_batch['label_embeddings'][22605]

tensor([-0.8445,  0.1256,  0.2044,  ...,  0.4676, -0.1743,  0.8949])

In [8]:
datasets["train"][0].label2int['GO:0035639']

13652

In [17]:
sorted([datasets["train"][0].label2int[i] for i in datasets["train"][0].data[0][1][1:]])==torch.where(data_batch['label_multihots']==1)[0].tolist()

True

In [9]:
datasets["train"][0].data[0][1][1:]

['GO:0035639',
 'GO:0032553',
 'GO:0005524',
 'GO:0017076',
 'GO:0005737',
 'GO:1901265',
 'GO:1901363',
 'GO:0043168',
 'GO:0044424',
 'GO:0030554',
 'GO:0005488',
 'GO:0043167',
 'GO:0042026',
 'GO:0032559',
 'GO:0005515',
 'GO:0051082',
 'GO:0032555',
 'GO:0005575',
 'GO:0008144',
 'GO:0009987',
 'GO:0097159',
 'GO:0006457',
 'GO:0000166',
 'GO:0008150',
 'GO:0036094',
 'GO:0003674',
 'GO:0044464',
 'GO:0097367']

In [11]:
data_batch

{'sequence_onehots': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 'sequence_id': 'P60545',
 'sequence_embedding': tensor([-0.0553, -0.3441, -0.2825,  ...,  0.4497, -0.0895, -0.1504]),
 'sequence_length': tensor(538),
 'label_multihots': tensor([0, 0, 0,  ..., 0, 0, 0]),
 'tokenized_labels': {'input_ids': tensor([[   2,   18,  569,  ...,    1,    1,    1],
         [   2,   18, 1900,  ...,    1,    1,    1],
         [   2,   18,  371,  ...,    1,    1,    1],
         ...,
         [   2,   18,  919,  ...,    1,    1,    1],
         [   2,   18,  919,  ...,    1,    1,    1],
         [   2,   18,  919,  ...,    1,    1,    1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         

In [15]:
embeddings = torch.load('data/embeddings/frozen_BioGPT_label_embeddings.pkl')

In [19]:
embeddings

tensor([[-1.3426e+00,  1.9259e-01,  4.5337e-01,  ..., -6.7419e-02,
          1.7350e-01,  8.9762e-01],
        [-5.8517e-01,  2.5346e-03,  9.9431e-01,  ...,  7.3632e-01,
          1.3791e+00,  1.2030e+00],
        [-4.8449e-01, -2.6923e-01,  1.7874e-01,  ..., -3.5807e-01,
          8.9524e-01,  8.7176e-01],
        ...,
        [-2.0514e-01, -1.0103e+00,  1.2279e+00,  ...,  3.6141e-01,
         -3.4265e-01,  5.1903e-01],
        [-8.9557e-01, -5.3069e-01,  9.3757e-01,  ..., -1.8156e-01,
         -2.4020e-02, -9.7481e-04],
        [-7.9217e-01, -9.6587e-01,  1.2481e+00,  ...,  4.1990e-01,
         -3.4655e-01,  1.0383e-02]])

In [18]:
a['label_embeddings']

tensor([[-1.3426e+00,  1.9259e-01,  4.5337e-01,  ..., -6.7419e-02,
          1.7350e-01,  8.9762e-01],
        [-5.8517e-01,  2.5346e-03,  9.9431e-01,  ...,  7.3632e-01,
          1.3791e+00,  1.2030e+00],
        [-4.8449e-01, -2.6923e-01,  1.7874e-01,  ..., -3.5807e-01,
          8.9524e-01,  8.7176e-01],
        ...,
        [-2.0514e-01, -1.0103e+00,  1.2279e+00,  ...,  3.6141e-01,
         -3.4265e-01,  5.1903e-01],
        [-8.9557e-01, -5.3069e-01,  9.3757e-01,  ..., -1.8156e-01,
         -2.4020e-02, -9.7481e-04],
        [-7.9217e-01, -9.6587e-01,  1.2481e+00,  ...,  4.1990e-01,
         -3.4655e-01,  1.0383e-02]])

In [20]:
P_e = a['sequence_embeddings']
L_e = a['label_embeddings']

In [50]:
from tqdm import tqdm
joint = []
for i in tqdm(P_e):
    for j in L_e:
        joint.append(torch.concat([i,j]))

  0%|          | 0/64 [00:00<?, ?it/s]

100%|██████████| 64/64 [00:15<00:00,  4.23it/s]


In [25]:
torch.repe

AttributeError: module 'torch' has no attribute 'repeat'

In [55]:
from tqdm import tqdm
joint = []
joint_matrix = []
for i in tqdm(range(10)):
    joint_rows=[]
    for j in range(11,15):
        i_ = torch.tensor([i]*5)
        j_ = torch.tensor([j]*7)
        concat = torch.concat([i_,j_])
        joint.append(concat)
        joint_rows.append(concat)
    joint_rows = torch.stack(joint_rows)
    joint_matrix.append(joint_rows)

#joint = torch.stack(joint)

100%|██████████| 10/10 [00:00<00:00, 5448.56it/s]


In [57]:
torch.stack(joint_matrix).sum(axis=-1)

tensor([[ 77,  84,  91,  98],
        [ 82,  89,  96, 103],
        [ 87,  94, 101, 108],
        [ 92,  99, 106, 113],
        [ 97, 104, 111, 118],
        [102, 109, 116, 123],
        [107, 114, 121, 128],
        [112, 119, 126, 133],
        [117, 124, 131, 138],
        [122, 129, 136, 143]])

In [64]:
torch.stack(joint_matrix)[1][3].sum()

tensor(103)

In [36]:
joint.sum(axis=1).reshape(10,4)

tensor([[ 77,  84,  91,  98],
        [ 82,  89,  96, 103],
        [ 87,  94, 101, 108],
        [ 92,  99, 106, 113],
        [ 97, 104, 111, 118],
        [102, 109, 116, 123],
        [107, 114, 121, 128],
        [112, 119, 126, 133],
        [117, 124, 131, 138],
        [122, 129, 136, 143]])

In [35]:
joint.shape

torch.Size([40, 12])

In [52]:
joint.sum(axis=0).mean()

tensor(-80840.0156)

In [47]:
num_sequences = P_e.shape[0]
num_labels = L_e.shape[0]
sequence_embedding_dim = P_e.shape[1]
label_embedding_dim = L_e.shape[1]

# Use broadcasting so we don't have to expand the tensor dimensions
joint_embeddings = torch.cat([
    P_e[:, None, :].expand(
        num_sequences, num_labels, sequence_embedding_dim),
    L_e[None, :, :].expand(
        num_sequences, num_labels, label_embedding_dim)
], dim=2).reshape(-1, sequence_embedding_dim + label_embedding_dim)

In [49]:
joint_embeddings.sum(axis=0).mean()

tensor(-80840.0156)

In [54]:
torch.tensor([1,0,1,0,1,1,1,1,0]).reshape(3, 3)

tensor([[1, 0, 1],
        [0, 1, 1],
        [1, 1, 0]])

In [None]:
parser = argparse.ArgumentParser(
    description="Train and/or Test the ProTCL model.")
parser.add_argument("--train-path-name", type=str, default=None,
                    help="Specify the desired train path name to train the model using names from config file. If not provided, model will not be trained. If provided, must also provide --val-path.")

parser.add_argument("--validation-path-name", type=str, default=None,
                    help="Specify the desired val path name to validate the model during training using names from config file. If not provided, model will not be trained. If provided, must also provide --train-path.")

parser.add_argument("--full-path-name", type=str, default=None,
                    help="Specify the desired full path name to define the vocabularies. Defaults to the full path name in the config file.")

parser.add_argument("--test-paths-names", nargs="+", type=str, default=None,
                    help="Specify all the desired test paths names to test the model using names from config file to test. If not provided, model will not be tested.")

parser.add_argument("--use-wandb", action="store_true", default=False,
                    help="Use Weights & Biases for logging. Default is False.")

parser.add_argument("--load-model", type=str, default=None,
                    help="(Relative) path to the model to be loaded. If not provided, a new model will be initialized.")

parser.add_argument('--from-checkpoint', action="store_true", default=False,
                    help="Continue training from a previous model checkpoint (including optimizer state and epoch). Default is False.")

parser.add_argument("--name", type=str, default="ProTCL",
                    help="Name of the W&B run. If not provided, a name will be generated.")

parser.add_argument("--config", type=str, default="configs/base_config.yaml",
                    help="(Relative) path to the configuration file.")

parser.add_argument("--amlt", action="store_true", default=False,
                    help="Run job on Amulet. Default is False.")

parser.add_argument("--override", nargs="*",
                    help="Override config parameters in key-value pairs.")

parser.add_argument("--save-prediction-results", action="store_true", default=False,
                    help="Save predictions and ground truth dataframe for validation and/or test")

parser.add_argument('-n', '--nodes', default=1, type=int,
                    metavar='N', help='Number of nodes (default: 1)')

parser.add_argument('-g', '--gpus', default=1, type=int,
                    help='Number of gpus per node (default: 1)')