# Introduction

In this notebook, we take a look at:

* Embedding data: we run inference with pretrained encoders and concatenate the resulting chunk embeddings
* Map graph construction: for imaging data, we also need to convert the embedded data into map graphs

# Embedding data -- DO LATER

This step is only needed for data that are not natively graphs. We use `networkx` to build the graph data structures.

Let's now load and set up the model $f_\theta$. Choices for Camelyon16 data include:
- `"tile2vec"`: an unsupervised learning model, a ResNet-18 trained from scratch
- `"vit_iid"`: a (weakly) supervised learning model, a ViT trained from scratch on IID fuzzy targets
- `"clip"`: a Foundation Model, specifically a Vision-Language Model (VLM), pre-trained and used out of the box
- `"plip"`: a Foundation Model/VLM, pre-trained and used out of the box; a CLIP-style model fine-tuned on pathology chunks
- `None`: skip inference
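To make "concatenation of chunk embeddings" concrete, here is a minimal NumPy sketch. It assumes each slide is available as an `(H, W, 3)` array, that it is cut into non-overlapping square chunks, and that the per-chunk embeddings are stored as an `(n_rows, n_cols, D)` grid `Z` (the layout the map-graph step presumably consumes). `embed_chunk` and `embed_slide` are hypothetical names, and `embed_chunk` is only a placeholder for $f_\theta$ (per-channel mean and standard deviation, so `D = 6`). With the CLIP-style models above, the vector would instead come from the image encoder, for example via `CLIPModel.get_image_features` in `transformers`.

```python
import numpy as np


def embed_chunk(chunk: np.ndarray) -> np.ndarray:
    """Hypothetical placeholder for f_theta: per-channel mean and std (D = 6)."""
    pixels = chunk.reshape(-1, chunk.shape[-1]).astype(np.float32)
    return np.concatenate([pixels.mean(axis=0), pixels.std(axis=0)])


def embed_slide(image: np.ndarray, chunk_size: int = 224) -> np.ndarray:
    """Tile `image` into non-overlapping chunks and stack their embeddings into a grid.

    Partial chunks at the right/bottom border are dropped.
    Returns an array of shape (n_rows, n_cols, D).
    """
    n_rows = image.shape[0] // chunk_size
    n_cols = image.shape[1] // chunk_size
    rows = []
    for r in range(n_rows):
        row = []
        for c in range(n_cols):
            chunk = image[r * chunk_size:(r + 1) * chunk_size,
                          c * chunk_size:(c + 1) * chunk_size]
            row.append(embed_chunk(chunk))
        rows.append(np.stack(row))  # (n_cols, D)
    return np.stack(rows)  # (n_rows, n_cols, D)


# Usage: a blank 500 x 700 "slide" yields a 2 x 3 grid of 6-dim embeddings
slide = np.zeros((500, 700, 3), dtype=np.uint8)
Z = embed_slide(slide, chunk_size=224)
print(Z.shape)  # (2, 3, 6)
```

Keeping the grid shape (rather than a flat list of vectors) preserves each chunk's spatial position, which is what a map graph needs to connect neighbouring chunks.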

In [1]:
%load_ext autoreload
%autoreload 2
import torch

In [2]:
modelstr = "plip"  # options: "tile2vec", "vit_iid", "clip", "plip", or None

In [3]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # needed by torch.load / model.to below
# if modelstr == "tile2vec":
#     from architectures import ResNet18
#     model = ResNet18(n_classes=2, in_channels=3, z_dim=128, supervised=False, no_relu=False, loss_type='triplet', tile_size=224, activation='relu')
#     chkpt = "/home/lofi/lofi/models/cam/to-port/ResNet18-hdf5_triplets_random_loading-224-label_selfsup-custom_loss-on_cam-cam16-filtration_background.sd"
#     checkpoint = torch.load(chkpt, map_location=device)
#     model.load_state_dict(checkpoint['model_state_dict'])
#     model.to(device)
#     prev_epoch = checkpoint['epoch']
#     loss = checkpoint['loss']
# elif modelstr == "vit_iid":
#     from vit_pytorch import ViT
#     model = ViT(image_size = 224, patch_size=16, num_classes=2, dim=1024, depth=6, heads=16, mlp_dim=2048, dropout=0.1, emb_dropout=0.1)
#     chkpt = "/home/lofi/lofi/models/cam/ViT-hdf5_random_loading-224-label_inherit-bce_loss-on_cam-cam16-filtration_background.sd"
#     checkpoint = torch.load(chkpt, map_location=device)
#     model.load_state_dict(checkpoint['model_state_dict'])
#     model.to(device)
#     prev_epoch = checkpoint['epoch']
#     loss = checkpoint['loss']
# elif modelstr == "clip":
#     from transformers import CLIPProcessor, CLIPTokenizer, CLIPModel
#     tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
#     processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
#     model_clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# elif modelstr == "plip":
#     from transformers import AutoProcessor, AutoTokenizer, AutoModelForZeroShotImageClassification
#     tokenizer = AutoTokenizer.from_pretrained("vinid/plip")
#     processor = AutoProcessor.from_pretrained("vinid/plip")
#     model_plip = AutoModelForZeroShotImageClassification.from_pretrained("vinid/plip")
# elif modelstr is None:
#     print("No model selected for inference! Skipping inference...")
# else:
#     print("Not yet supported for inference! Skipping inference...")

# Converting embedded data to map graphs

In [4]:
import os
import utils
import networkx as nx
import numpy as np

In [5]:
modelstr = "plip"

In [6]:
if modelstr == "tile2vec":
    Z_dir = "/home/data/tinycam/train/Zs_tile2vec"
    save_dir = "/home/data/tinycam/train/Gs_tile2vec"
elif modelstr == "vit_iid":
    Z_dir = "/home/data/tinycam/train/Zs_vit"
    save_dir = "/home/data/tinycam/train/Gs_vit_iid"
elif modelstr == "clip":
    Z_dir = "/home/data/tinycam/train/Zs_clip"
    save_dir = "/home/data/tinycam/train/Gs_clip"
elif modelstr == "plip":
    Z_dir = "/home/data/tinycam/train/Zs_plip"
    save_dir = "/home/data/tinycam/train/Gs_plip"
else:
    # fail loudly instead of silently reusing stale Z_dir/save_dir values
    raise ValueError(f"Unsupported modelstr: {modelstr!r}")

In [9]:
# for Z_obj in os.listdir(Z_dir):
#     sample_id = Z_obj.split(".npy")[0]
#     G_id = "G-" + sample_id.split("-")[1]

#     print("converting {}".format(Z_obj))
#     Z_path = str(os.path.join(Z_dir, Z_obj))
#     Z = np.load(Z_path)
#     G = utils.convert_arr2graph(Z)
#     save_path = os.path.join(save_dir, G_id)
#     utils.serialize(G, save_path)
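The loop above relies on `utils.convert_arr2graph`, whose source is not shown in this notebook. As a rough illustration only, here is a hypothetical `arr2graph` sketch in `networkx`, assuming `Z` is an `(n_rows, n_cols, D)` grid of chunk embeddings: each chunk becomes a node keyed by its `(row, col)` position with its embedding in an `"x"` attribute, and edges join 4-neighbouring chunks. Treating all-zero embeddings as background and dropping them is also an assumption, exposed through the `drop_empty` flag.

```python
import networkx as nx
import numpy as np


def arr2graph(Z: np.ndarray, drop_empty: bool = True) -> nx.Graph:
    """Hypothetical map-graph construction from an (n_rows, n_cols, D) embedding grid."""
    n_rows, n_cols, _ = Z.shape
    # Lattice graph: nodes are (row, col) tuples, edges connect 4-neighbours
    G = nx.grid_2d_graph(n_rows, n_cols)
    for node in G.nodes:
        r, c = node
        G.nodes[node]["x"] = Z[r, c].copy()  # chunk embedding as node feature
    if drop_empty:
        # Assumed convention: an all-zero embedding marks a background chunk
        empty = [n for n in G.nodes if not np.any(G.nodes[n]["x"])]
        G.remove_nodes_from(empty)  # also removes their incident edges
    return G


# Usage: a 2 x 3 grid with one background chunk in the top-left corner
Z = np.ones((2, 3, 4))
Z[0, 0] = 0
G = arr2graph(Z)
print(G.number_of_nodes(), G.number_of_edges())  # 5 5
```

A full 2 x 3 lattice has 6 nodes and 7 edges; removing the corner node drops it and its 2 incident edges, leaving 5 and 5.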

In [10]:
label_dict = {}
save_dir = "/home/data/tinycam/train/"
G_dir = "/home/data/tinycam/train/Gs_" + modelstr

# Derive each graph's label from its file name: "normal" slides -> 0, everything else (tumor) -> 1
for G_obj in os.listdir(G_dir):
    sample_id = G_obj.split(".npy")[0]
    G_id = "G-" + sample_id.split("-")[1]
    if "normal" in G_id:
        label_dict[G_id] = 0
    else:  # tumor
        label_dict[G_id] = 1

utils.serialize(label_dict, os.path.join(save_dir, modelstr + "-label_dict.pkl"))
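The cell above writes `label_dict` with `utils.serialize`, which is also not shown. The `.pkl` extension suggests pickling, so here is a hypothetical pickle-based `serialize`/`deserialize` pair, plus the same labeling rule as a standalone function, to show how the saved label file could be read back later. All three names are illustrative stand-ins, not the project's actual `utils` API.

```python
import os
import pickle
import tempfile


def serialize(obj, path):
    """Hypothetical stand-in for utils.serialize: pickle `obj` to `path`."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def deserialize(path):
    """Hypothetical counterpart: load an object previously written by serialize()."""
    with open(path, "rb") as f:
        return pickle.load(f)


def label_from_graph_id(G_id):
    """Labeling rule used above: 0 for "normal" slides, 1 otherwise (tumor)."""
    return 0 if "normal" in G_id else 1


# Usage: build a small label dict and round-trip it through a temporary file
labels = {g: label_from_graph_id(g) for g in ["G-normal_001", "G-tumor_001"]}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "plip-label_dict.pkl")
    serialize(labels, path)
    print(deserialize(path))  # {'G-normal_001': 0, 'G-tumor_001': 1}
```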

# Cleaning the embedded test images
- To do later, reusing code from the previous repo

# Converting test embeddings to map graphs

In [11]:
if modelstr == "tile2vec":
    Z_dir = "/home/data/tinycam/test/clean_Zs_tile2vec"
    save_dir = "/home/data/tinycam/test/Gs_tile2vec"
elif modelstr == "vit_iid":
    Z_dir = "/home/data/tinycam/test/clean_Zs_vit_iid"
    save_dir = "/home/data/tinycam/test/Gs_vit_iid"
elif modelstr == "clip":
    Z_dir = "/home/data/tinycam/test/clean_Zs_clip"
    save_dir = "/home/data/tinycam/test/Gs_clip"
elif modelstr == "plip":
    Z_dir = "/home/data/tinycam/test/clean_Zs_plip"
    save_dir = "/home/data/tinycam/test/Gs_plip"
else:
    # fail loudly instead of silently reusing the training-set paths
    raise ValueError(f"Unsupported modelstr: {modelstr!r}")

In [12]:
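os.makedirs(save_dir, exist_ok=True)  # make sure the output folder exists

for Z_obj in os.listdir(Z_dir):
    # e.g. "Z-test_129.npy" -> sample_id "Z-test_129" -> G_id "G-test_129"
    sample_id = Z_obj.split(".npy")[0]
    G_id = "G-" + sample_id.split("-")[1]

    print("converting {}".format(Z_obj))
    Z_path = os.path.join(Z_dir, Z_obj)
    Z = np.load(Z_path)
    G = utils.convert_arr2graph(Z)  # embedding grid -> map graph
    save_path = os.path.join(save_dir, G_id)
    utils.serialize(G, save_path)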

converting Z-test_129.npy
converting Z-test_128.npy
converting Z-test_001.npy
converting Z-test_087.npy
converting Z-test_067.npy
converting Z-test_009.npy
converting Z-test_100.npy
converting Z-test_091.npy
converting Z-test_040.npy
converting Z-test_003.npy
converting Z-test_076.npy
converting Z-test_022.npy
converting Z-test_035.npy
converting Z-test_017.npy
converting Z-test_093.npy
converting Z-test_039.npy
converting Z-test_066.npy
converting Z-test_122.npy
converting Z-test_012.npy
converting Z-test_057.npy
converting Z-test_032.npy
converting Z-test_061.npy
converting Z-test_082.npy
converting Z-test_108.npy
converting Z-test_038.npy
converting Z-test_106.npy
converting Z-test_030.npy
converting Z-test_080.npy
converting Z-test_058.npy
converting Z-test_075.npy
converting Z-test_096.npy
converting Z-test_127.npy
converting Z-test_004.npy
converting Z-test_056.npy
converting Z-test_068.npy
converting Z-test_054.npy
converting Z-test_110.npy
converting Z-test_036.npy
converting Z