# In [6]:
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from saveAndLoad import *

# Embedding lookup table loaded from disk (presumably ESM-2 embeddings, one
# row per canonical mutation, going by the file name -- TODO confirm).
canonical_mut_embeddings_esm2 = np.load('../aa/canonical_mut_embeddings_esm2.npy')

# Labeled samples: column '0' is handed to the dataset below as the raw
# per-sample input, column '1' carries the string class label.
labeled_data = pd.read_csv('../data_processing/cancer_type_detailed_data.csv')
string_labels = labeled_data['1']
data = labeled_data['0'].values

# Encode the string labels as integer class ids and keep them as one long
# tensor; `label_encoder` stays around so ids can be mapped back to names.
label_encoder = LabelEncoder()
labels = torch.tensor(
    label_encoder.fit_transform(string_labels),
    dtype=torch.long,
)

class Dataset_MutationList(Dataset):
    """Map-style dataset of variable-length index lists into an embedding table.

    Each sample is described by a comma-separated string of integer indices
    into ``embeddings``. Indexing the dataset gathers the referenced entries
    into a single float32 tensor and pairs it with the sample's label.

    Args:
        data: Iterable of strings such as ``"3,17,42"``; every
            comma-separated token must parse with ``int``.
        labels: Per-sample labels, indexable by sample position (the
            ``i``-th label belongs to the ``i``-th entry of ``data``).
        embeddings: Lookup table whose first axis is addressed by the parsed
            indices. A NumPy array or torch tensor is stored as given; any
            other sequence (e.g. a list of lists) is converted to a NumPy
            array once, so that lookups can use a single gather.

    Raises:
        ValueError: If a token in ``data`` is not a valid integer literal.
    """

    def __init__(self, data, labels, embeddings):
        # Parse eagerly so malformed rows fail at construction time rather
        # than in the middle of an epoch.
        self.data = [[int(token) for token in row.split(',')] for row in data]
        self.labels = labels
        if not isinstance(embeddings, (np.ndarray, torch.Tensor)):
            embeddings = np.asarray(embeddings)
        self.embeddings = embeddings

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(emb, label)`` for sample ``idx``.

        ``emb`` is a float32 tensor whose first axis has one entry per index
        listed for the sample; ``label`` is ``labels[idx]`` unchanged.
        """
        # Indexing with a list of ints gathers all entries in one call and
        # always yields a fresh copy, so the result never aliases the table.
        rows = self.embeddings[self.data[idx]]
        # detach() only matters for tensor tables that require grad; it keeps
        # the returned sample out of any autograd graph.
        emb = torch.as_tensor(rows, dtype=torch.float32).detach()
        return emb, self.labels[idx]
    
def custom_collate(batch):
    """Collate ``(sequence, label)`` samples into one padded batch.

    Args:
        batch: List of samples; element 0 of each sample is a tensor whose
            first axis may differ in length between samples, element 1 is
            a label tensor.

    Returns:
        A ``(data, labels)`` tuple. ``data`` stacks the sequences along a new
        leading batch axis, with shorter sequences filled up with ``-inf`` to
        the longest length in the batch; ``labels`` is the stacked label
        tensors in batch order.
    """
    sequences = []
    targets = []
    for sample in batch:
        sequences.append(sample[0])
        targets.append(sample[1])

    # NOTE(review): -inf padding presumably keeps padded positions from
    # winning a later max-reduction over the sequence axis -- confirm
    # against the consuming model.
    padded = pad_sequence(
        sequences,
        batch_first=True,
        padding_value=float('-inf'),
    )
    return padded, torch.stack(targets)

# Wrap the raw index strings, encoded labels and embedding table in a dataset.
dataset = Dataset_MutationList(
    data=data,
    labels=labels,
    embeddings=canonical_mut_embeddings_esm2,
)

# Serve the samples in file order, two per batch; custom_collate pads the
# variable-length sequences so each batch forms a single tensor.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=custom_collate,
)