In [None]:
from pathflowai.utils import load_sql_df
import torch
import os 
import sys, os
os.environ['CUDA_VISIBLE_DEVICES']="0"
import umap, numba
from sklearn.preprocessing import LabelEncoder
from torch_cluster import knn_graph
from torch_geometric.data import Data 
import numpy as np
from torch_geometric.utils import train_test_split_edges
import os
import argparse
from torch_geometric.utils.convert import to_networkx
from torch_geometric.data import InMemoryDataset,DataLoader
import os,glob, pandas as pd
from sklearn.utils.class_weight import compute_class_weight
import pickle
import fire
import torch_geometric
import torch
import scipy.sparse as sps
from torch_cluster import radius_graph
from torch_geometric.utils import subgraph

In [None]:
class MyOwnDataset(InMemoryDataset):
    """Skeleton ``InMemoryDataset`` subclass for the extracted graphs.

    The file-name properties and ``download`` are unimplemented stubs that
    return ``None``. After the base-class constructor runs, ``data`` and
    ``slices`` are reset to ``None`` instead of being loaded from disk.
    """

    def __init__(self, root=None, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        # Nothing is read back from self.processed_paths[0]; both are left empty.
        self.data, self.slices = None, None

    @property
    def raw_file_names(self):
        """Stub: no raw file names are declared (returns ``None``)."""
        return None

    @property
    def processed_file_names(self):
        """Stub: no processed file names are declared (returns ``None``)."""
        return None

    def download(self):
        """Stub: nothing is downloaded to ``self.raw_dir``."""
        return None

    def process(self):
        """Collect graphs, apply the optional filter/transform, collate and save.

        NOTE(review): ``extract_graphs`` is called without arguments, but no
        module-level function of that name is visible in this notebook (the
        only one is nested inside ``get_graph_datasets`` and takes
        arguments) -- confirm before relying on this method.
        """
        graphs = extract_graphs()

        keep = self.pre_filter
        if keep is not None:
            graphs = list(filter(keep, graphs))

        convert = self.pre_transform
        if convert is not None:
            graphs = list(map(convert, graphs))

        collated = self.collate(graphs)
        torch.save((collated[0], collated[1]), self.processed_paths[0])


        
def _random_split_masks(n_nodes, seed=42):
    """Return boolean (train, val, test) node masks for an 80/10/10 random split.

    NumPy's *global* RNG is re-seeded with ``seed`` on every call (behaviour
    kept from the original code), so graphs with the same number of nodes
    always receive the same split.
    """
    np.random.seed(seed)
    order = np.arange(n_nodes)
    np.random.shuffle(order)
    n_train, n_val = int(0.8 * n_nodes), int(0.9 * n_nodes)
    masks = []
    for members in (order[:n_train], order[n_train:n_val], order[n_val:]):
        mask = np.zeros(n_nodes, dtype=bool)
        mask[members] = True
        masks.append(torch.tensor(mask))
    return tuple(masks)


def _graph_with_split(x, edge_index, y, pos):
    """Wrap tensors in a ``Data`` object and attach train/val/test masks."""
    graph = Data(x=x, edge_index=edge_index, edge_attr=None, y=y, pos=pos)
    graph.train_mask, graph.val_mask, graph.test_mask = _random_split_masks(x.shape[0])
    return graph


def get_graph_datasets(embedding_dir,k=8,radius=0,build_connected_components=False):
    """Build per-slide spatial graphs from saved patch embeddings.

    Every ``*.pkl`` file in ``embedding_dir`` is read with ``torch.load`` and
    must provide an ``'embeddings'`` array and a ``'patch_info'`` DataFrame.
    The code reads the ``ID``, ``x``, ``y`` and ``annotation`` columns of
    ``patch_info`` plus one column per distinct ``annotation`` value; a
    patch's label is the arg-max over those per-annotation columns.

    Parameters
    ----------
    embedding_dir : str
        Directory containing the ``*.pkl`` embedding files.
    k : int
        Number of neighbours for the k-NN graph (used when ``radius`` is falsy).
    radius : float
        If truthy, build a radius graph with ``r = radius * sqrt(2)`` instead
        of a k-NN graph.
    build_connected_components : bool
        If True, emit one graph per connected component of each slide and
        record the component id of every patch in ``df['component']``.

    Returns
    -------
    dict
        ``df`` (patch table with ``y_true`` and optionally ``component``),
        ``weight`` (balanced class weights) and ``graph_dataset`` (list of
        ``torch_geometric.data.Data`` objects).

    Raises
    ------
    ValueError
        If ``embedding_dir`` contains no ``*.pkl`` files.
    """
    # Sort so the row order of the stacked tables does not depend on the
    # arbitrary order in which the file system lists the files.
    embedding_files=sorted(glob.glob("{}/*.pkl".format(embedding_dir)))
    if not embedding_files:
        raise ValueError("No '*.pkl' embedding files found in '{}'".format(embedding_dir))
    # SECURITY: torch.load unpickles arbitrary objects -- only point this at
    # trusted files.
    # NOTE(review): newer PyTorch releases may default to weights_only=True,
    # which could reject the pickled DataFrame -- confirm for your version.
    per_file={os.path.basename(f).split('.')[0]: torch.load(f) for f in embedding_files}
    embeddings=dict(embeddings=np.vstack([per_file[name]['embeddings'] for name in per_file]),
               patch_info=pd.concat([per_file[name]['patch_info'] for name in per_file]))
    # Drop the first two columns and move the old index into a column so the
    # frame gets a fresh 0..n-1 index that lines up with the stacked embeddings.
    df=embeddings['patch_info'].iloc[:,2:].reset_index()
    z=pd.DataFrame(embeddings['embeddings']).loc[df.index]
    embeddings['patch_info']=df
    cols=np.array(df['annotation'].value_counts().index.tolist())
    le=LabelEncoder().fit(cols)
    df['y_true']=le.transform(cols[df[cols].values.argmax(1)])
    # Keyword arguments + ndarray classes: required by current scikit-learn,
    # and still accepted by older releases.
    weights=compute_class_weight(class_weight='balanced',
                                 classes=np.unique(df['y_true'].values),
                                 y=df['y_true'].values)
    # Build neighbour graphs on the GPU when one is available, otherwise on CPU.
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def get_dataset(slide,k=8,radius=0,build_connected_components=False):
        """Return (list of Data graphs, per-patch component ids) for one slide."""
        in_slide=df['ID'].values==slide
        xy=torch.tensor(df.loc[in_slide,['x','y']].values).float().to(device)
        X=torch.tensor(z[in_slide].values)
        y=torch.tensor(df.loc[in_slide,'y_true'].values)
        n_nodes=X.shape[0]
        if not radius:
            G=knn_graph(xy,k=k)
        else:
            G=radius_graph(xy, r=radius*np.sqrt(2), batch=None, loop=True)
        G=G.detach().cpu()
        # Pass the node count explicitly: inferring it from the largest edge
        # index fails for single-patch slides or trailing edge-less nodes.
        G=torch_geometric.utils.add_remaining_self_loops(G,num_nodes=n_nodes)[0]
        xy=xy.detach().cpu()
        datasets=[]
        if build_connected_components:
            edges=G.numpy().astype(int)
            adjacency=sps.coo_matrix((np.ones_like(edges[0]),(edges[0],edges[1])),
                                     shape=(n_nodes,n_nodes))
            n_components,components=sps.csgraph.connected_components(adjacency)
            components=torch.as_tensor(components,dtype=torch.long)
            for i in range(n_components):
                members=components==i
                G_new=subgraph(members,G,relabel_nodes=True)[0]
                datasets.append(_graph_with_split(X[members],G_new,y[members],xy[members]))
            components=components.numpy()
        else:
            # Placeholder ids; the caller ignores them in this mode.
            components=np.ones(n_nodes)
            datasets.append(_graph_with_split(X,G,y,xy))
        return datasets,components


    def extract_graphs(df,k=8,radius=0,build_connected_components=False):
        """Build graphs for every slide in ``df`` whose ``y_true`` sum is non-zero."""
        graphs=[]
        if build_connected_components: df['component']=-1
        for slide in df['ID'].unique():
            # Slides whose encoded labels are all 0 are skipped.
            if df.loc[df['ID']==slide,'y_true'].sum():
                G,components=get_dataset(slide,k,radius,build_connected_components)
                graphs.extend(G)
                if build_connected_components: df.loc[df['ID']==slide,"component"]=components
        return graphs,df

    graph_dataset,df=extract_graphs(df,k,radius,build_connected_components)
    return dict(df=df,weight=weights,graph_dataset=graph_dataset)

def graph_extraction(embedding_dir,save_file='graph_dataset_test.pkl',k=8,radius=0,build_connected_components=False):
    """Build the graph datasets for ``embedding_dir`` and pickle them to disk.

    Parameters
    ----------
    embedding_dir : str
        Directory passed to ``get_graph_datasets``.
    save_file : str
        Path of the pickle file to write (overwritten if it exists).
    k, radius, build_connected_components
        Forwarded unchanged to ``get_graph_datasets``.

    Returns
    -------
    None
    """
    graph_datasets=get_graph_datasets(embedding_dir,k,radius,build_connected_components)
    # Context manager guarantees the handle is flushed and closed, even if
    # pickling raises part-way through.
    with open(save_file,'wb') as handle:
        pickle.dump(graph_datasets,handle)

In [None]:
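# Prerequisite: use PathFlowAI or PathPretrain (https://github.com/jlevy44/PathPretrain) to pretrain a model / extract image features first; the next cell expects the resulting *.pkl embedding files under <dataset>/imagenet_embeddings/.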

In [None]:
# Build and pickle a graph dataset for every dataset directory listed below.
# NOTE(review): `graph_datasets` is initialised here but never filled in this cell.
graph_datasets={}
# 'your_data_set' is presumably a placeholder for a real dataset directory -- replace before running.
for k in ['your_data_set']:
    # Directory that graph_extraction searches for *.pkl embedding files.
    embedding_dir=f"{k}/imagenet_embeddings"
    out_dir=f"{k}/graph_datasets"
    # Create the output directory if it does not exist yet.
    os.makedirs(out_dir,exist_ok=True)
    # Uses the default k / radius / build_connected_components settings.
    graph_extraction(embedding_dir,save_file=f'{out_dir}/imagenet_graph_data.pkl')