In [None]:
import os
import datetime
import sys
from pathlib import Path

# Put the parent of the current working directory on the import path so the
# project-local `src.*` modules imported below can be resolved. This has to
# run before those imports.
curdir = Path.cwd()
sys.path.append(os.fspath(curdir.parent.absolute()))

import json
import logging
from src.utils.data import read_pickle
from src.utils.losses import contrastive_loss
from src.data.collators import collate_variable_sequence_length
from src.data.datasets import ProteinDataset
from src.models.ProTCL import ProTCL
import torch
import wandb


# Data paths
TRAIN_DATA_PATH = '/home/ncorley/protein/ProteinFunctions/data/swissprot/proteinfer_splits/random/train_GO.fasta'
VAL_DATA_PATH = '/home/ncorley/protein/ProteinFunctions/data/swissprot/proteinfer_splits/random/dev_GO.fasta'
TEST_DATA_PATH = '/home/ncorley/protein/ProteinFunctions/data/swissprot/proteinfer_splits/random/test_GO.fasta'
AMINO_ACID_VOCAB_PATH = '/home/ncorley/protein/ProteinFunctions/data/vocabularies/amino_acid_vocab.json'
GO_LABEL_VOCAB_PATH = '/home/ncorley/protein/ProteinFunctions/data/vocabularies/GO_label_vocab.json'

# Embedding paths
# NOTE(review): the ".pk1" extension (digit one) may be a typo for ".pkl";
# confirm against the files on disk before relying on these constants.
LABEL_EMBEDDING_PATH = "/home/ncorley/protein/ProteinFunctions/data/embeddings/label_embeddings.pk1"
SEQUENCE_EMBEDDING_PATH = "/home/ncorley/protein/ProteinFunctions/data/embeddings/sequence_embeddings.pk1"

# Load the train / dev / test splits, all built from the same amino-acid
# vocabulary file.
train_dataset, val_dataset, test_dataset = ProteinDataset.create_multiple_datasets(
    data_paths=[TRAIN_DATA_PATH, VAL_DATA_PATH, TEST_DATA_PATH],
    sequence_vocabulary_path=AMINO_ACID_VOCAB_PATH,
)

# Create the master label vocabulary as the union of the label vocabularies of
# the train, validation, and test datasets.
# The union is sorted because iterating a set of strings yields a different
# order on each interpreter run (str hash randomization). Without sorting, the
# saved vocabulary, and any label -> index mapping derived from it, would not
# be reproducible.
# Assumes the labels are mutually comparable (e.g. all GO-term strings); TODO
# confirm against ProteinDataset.label_vocabulary.
master_label_vocabulary = sorted(
    set(train_dataset.label_vocabulary)
    | set(val_dataset.label_vocabulary)
    | set(test_dataset.label_vocabulary)
)

# Save master_label_vocabulary as a JSON list to GO_LABEL_VOCAB_PATH,
# overwriting any existing file. The constant is reused instead of repeating
# the literal path.
with open(GO_LABEL_VOCAB_PATH, 'w', encoding='utf-8') as f:
    json.dump(master_label_vocabulary, f)