In [1]:
import torch
import pickle
import pandas as pd
from torch_geometric.data import Data, DataLoader
from sklearn.model_selection import train_test_split
import networkx as nx

# Cargar los grafos
with open('/Users/medinils/Desktop/IMC_Spatial_predictions/graph/graphs_dic_batched.pkl', 'rb') as f:
    patient_graphs = pickle.load(f)

# Cargar los datos clínicos
clinical_data = pd.read_csv("/Users/medinils/Desktop/IMC_Spatial_predictions/data/raw_data/METABRIC_IMC/metabric_clinical_data.tsv", sep='\t')

# Limpiar los datos eliminando filas con NaN en 'Cellularity'
clinical_data = clinical_data.dropna(subset=['Cellularity'])



In [2]:
patient_graphs

{'MB-0000': <networkx.classes.graph.Graph at 0x12135f150>,
 'MB-0002': <networkx.classes.graph.Graph at 0x129676950>,
 'MB-0005': <networkx.classes.graph.Graph at 0x1303cc050>,
 'MB-0010': <networkx.classes.graph.Graph at 0x1320d4110>,
 'MB-0014': <networkx.classes.graph.Graph at 0x138160490>,
 'MB-0020': <networkx.classes.graph.Graph at 0x13c1b27d0>,
 'MB-0022': <networkx.classes.graph.Graph at 0x13c1e3b10>,
 'MB-0028': <networkx.classes.graph.Graph at 0x13c7a5c50>,
 'MB-0035': <networkx.classes.graph.Graph at 0x13eb2c450>,
 'MB-0045': <networkx.classes.graph.Graph at 0x13fe54710>,
 'MB-0050': <networkx.classes.graph.Graph at 0x1436c10d0>,
 'MB-0060': <networkx.classes.graph.Graph at 0x14541b090>,
 'MB-0064': <networkx.classes.graph.Graph at 0x15e181290>,
 'MB-0081': <networkx.classes.graph.Graph at 0x1670f8a90>,
 'MB-0095': <networkx.classes.graph.Graph at 0x178d81310>,
 'MB-0099': <networkx.classes.graph.Graph at 0x17d0ca410>,
 'MB-0107': <networkx.classes.graph.Graph at 0x17e582250

In [3]:
# Convertir las etiquetas categóricas de cellularity a índices numéricos
clinical_data['Cellularity'], _ = pd.factorize(clinical_data['Cellularity'])

In [4]:
import random

# Seleccionar aleatoriamente 20 IDs de paciente que estén tanto en los grafos como en los datos clínicos
sampled_patient_ids = random.sample(list(set(patient_graphs.keys()).intersection(clinical_data['Patient ID'])), 10)

# Crear un subconjunto de datos clínicos para los pacientes seleccionados
sampled_clinical_data = clinical_data[clinical_data['Patient ID'].isin(sampled_patient_ids)]

In [18]:
# Crear un diccionario de Patient ID a Cellularity
cellularity_labels = clinical_data.set_index('Patient ID')['Cellularity'].to_dict()

# Añadir cellularity como atributo 'label' en cada grafo
for patient_id, graph in patient_graphs.items():
    # Asignar el label de cellularity al grafo
    if patient_id in cellularity_labels:
        graph.graph['label'] = cellularity_labels[patient_id]
    else:
        graph.graph['label'] = None  # O asignar un valor default o manejar de alguna otra forma si no se encuentra el ID

# Ahora cada grafo en patient_graphs tiene un atributo 'label' con su cellularity correspondiente


In [19]:
from torch_geometric.utils import from_networkx
import torch

data_list = []
for patient_id, graph in patient_graphs.items():
    # Convertir el grafo, asegurándonos de incluir las características 'CD68', 'CD3', y 'CD20'
    data = from_networkx(graph, group_node_attrs=['CD68', 'CD3', 'CD20'])

    # Establecer el label del grafo si existe, de lo contrario usar un valor default
    if graph.graph['label'] is not None:
        data.y = torch.tensor([graph.graph['label']], dtype=torch.long)
    else:
        data.y = torch.tensor([0], dtype=torch.long)  # Asumiendo 0 como valor default para cellularity si no está definido
    
    data_list.append(data)

# data_list ahora contiene los objetos Data de PyTorch Geometric con las características y labels apropiados.


In [21]:
#guardar lista
import pickle

# Guardar data_list en un archivo
with open('/Users/medinils/Desktop/IMC_Spatial_predictions/graph/data_list.pkl', 'wb') as f:
    pickle.dump(data_list, f)

print("data_list ha sido guardado exitosamente.")


data_list ha sido guardado exitosamente.


In [20]:
# Obtener el grafo del paciente MB-0893
patient_graph = patient_graphs.get('MB-0893')

# Verificar si el grafo existe y obtener su label
if patient_graph is not None:
    # Asumiendo que el atributo 'label' ha sido previamente asignado como mostramos en el código anterior
    patient_label = patient_graph.graph.get('label', 'Label no definido')
    print(f"El label de cellularity para el paciente MB-0893 es: {patient_label}")
else:
    print("No existe un grafo para el paciente MB-0893")


El label de cellularity para el paciente MB-0893 es: 1


In [22]:
import torch
from torch_geometric.data import Dataset, DataLoader

class SimpleDataset(Dataset):
    def __init__(self, data_list):
        super(SimpleDataset, self).__init__()
        self.data_list = data_list
    
    def len(self):
        return len(self.data_list)
    
    def get(self, idx):
        return self.data_list[idx]

# Crear el dataset utilizando la lista de datos
my_dataset = SimpleDataset(data_list)

In [23]:
# Usar DataLoader para manejar los datos en batches durante el entrenamiento
loader = DataLoader(my_dataset, batch_size=32, shuffle=True)

# Opcionalmente, imprimir detalles del dataset
print(f'Dataset: {my_dataset}:')
print(f'Number of graphs: {len(my_dataset)}')
print(f'Number of features: {my_dataset[0].num_node_features}')
print(f'Number of classes: {len(set([data.y.item() for data in my_dataset]))}')

# Comprobar el primer grafo
data = my_dataset[0]
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

Dataset: SimpleDataset(404):
Number of graphs: 404
Number of features: 3
Number of classes: 3
Number of nodes: 319
Number of edges: 1256
Average node degree: 3.94
Has isolated nodes: True
Has self-loops: False
Is undirected: True




In [24]:
# Guardar el dataset en un archivo
with open('/Users/medinils/Desktop/IMC_Spatial_predictions/graph/my_dataset.pkl', 'wb') as f:
    pickle.dump(my_dataset, f)



In [25]:
# Cargar el dataset desde un archivo
with open('/Users/medinils/Desktop/IMC_Spatial_predictions/graph/my_dataset.pkl', 'rb') as f:
    loaded_dataset = pickle.load(f)

print("Dataset ha sido cargado exitosamente.")


Dataset ha sido cargado exitosamente.
