In [None]:
import os
# Dataset root: "/data/bionets" on hosts whose node name contains "ramses", "/data_nfs/" otherwise.
# NOTE(review): os.uname() is POSIX-only, so this line fails on Windows.
base = "/data/bionets" if "ramses" in os.uname()[1] else "/data_nfs/"

import cv2 
import sys
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import pickle
import pandas as pd

# Expose only GPU 1 to CUDA-using libraries. Set before `from src import *` and before the
# torch import further down — presumably so it is in place before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"]="1"
# Make the project package `src` (presumably in the parent directory) importable.
sys.path.append("..")
# NOTE(review): star import — helpers used later but not defined in this notebook
# (get_data_csv, get_smooth_grad, get_binary, and presumably tqdm) come from here.
from src import *

In [None]:
# Automatically re-import edited modules (e.g. the project package `src`) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Names of the 50 marker channels, in ASCII-sorted order. len(markers) is passed as the
# models' input channel count (indim) when the checkpoints are built further down.
# NOTE(review): presumably this order must match the channel order used in training — confirm.
markers = [
    'ADAM10', 'Bcl-2',
    'CD10', 'CD107a', 'CD13', 'CD138', 'CD14', 'CD1a', 'CD2', 'CD25',
    'CD271', 'CD3', 'CD36', 'CD4', 'CD44', 'CD45', 'CD45RA', 'CD45RO',
    'CD5', 'CD56', 'CD6', 'CD63', 'CD66abce', 'CD7', 'CD71', 'CD8', 'CD9', 'CD95',
    'Collagen IV', 'Cytokeratin-14', 'EBF-P', 'EGFR', 'EGFR-AF488',
    'HLA-ABC', 'HLA-DR', 'KIP1', 'Ki67', 'L302', 'MCSP', 'Melan-A',
    'Nestin-AF488', 'Notch-1', 'Notch-3', 'PPARgamma', 'PPB', 'RIM3', 'TAP73',
    'Vimentin', 'p63', 'phospho-Connexin',
]

In [None]:
# Melanoma samples only, re-indexed 0..n-1 (reset_index keeps the old index as an "index"
# column), plus a binary stage label that is True when "Float tumor stage" exceeds 0.5.
data = (
    get_data_csv()
    .loc[lambda frame: frame["Group"].eq("Melanoma")]
    .reset_index()
    .assign(**{"Coarse tumor stage": lambda frame: frame["Float tumor stage"].gt(0.5)})
)

In [None]:
# Keep only samples that have a ground-truth ROI mask. Masks are stored as
# "filled_<file_path>.<ext>" (e.g. filled_<fov>.tif), so the prefix is stripped to recover
# the sample's file_path. The directory is built from `base`, like every other dataset path,
# instead of a hard-coded "/data_nfs/" prefix, so the cell also works on hosts where `base`
# differs.
ROI_PREFIX = "filled_"
roi_dir = os.path.join(base, "je30bery/melanoma_data/MAGICAL/data/ROIs/ground_truth_ROI/filled")
rois = [
    os.path.splitext(name)[0][len(ROI_PREFIX):]
    for name in os.listdir(roi_dir)
    if name.startswith(ROI_PREFIX)
]
data = data[data["file_path"].isin(rois)]

In [None]:
import torch as t
# Guard so re-running the cell does not keep appending the same entry to sys.path.
if "../model" not in sys.path:
    sys.path.append("../model")
from model import EfficientnetWithFinetuning, VGGWithFinetuning, ResNetWithFinetuning, EfficientnetWithFinetuningWithVGGClassifier

# Fine-tuned checkpoints, loaded onto the CPU; indim is the number of marker channels.
effnet = EfficientnetWithFinetuning(indim=len(markers))
effnet.load_state_dict(t.load("../model/finetuned_effnet_with_LR_reduction_on_plateau.pt", map_location="cpu"))
#resnet = ResNetWithFinetuning(indim=len(markers))
#resnet.load_state_dict(t.load("../model/finetuned_resnet_with_LR_reduction_on_pleateau.pt", map_location="cpu"))
#vgg = VGGWithFinetuning(indim=len(markers))
#vgg.load_state_dict(t.load("../model/finetuned_vgg_with_LR_reduction_on_plateau.pt", map_location="cpu"))
#effnet_regression = t.load("../data/models/regression_effnet.pt", map_location="cpu") 
effnet2 = EfficientnetWithFinetuningWithVGGClassifier(indim=len(markers))
effnet2.load_state_dict(t.load("../data/models/model_2024-03-25 11:01:19.161783_f1=0.749165120593692_acc=0.75_12.pt", map_location="cpu"))

# Compute saliency maps (via get_smooth_grad) for every sample in `data` and save them as
# model{i}_maps.npy, where i is the position in this list (0 -> effnet, 1 -> effnet2).
for i, model in enumerate([effnet, effnet2]):  # other candidates: vgg, resnet
    # nn.Module starts in training mode and load_state_dict does not change that. Switch to
    # inference mode so layers such as dropout / batch norm (if present) do not perturb the
    # gradients. This is a no-op if get_smooth_grad already does it.
    model.eval()
    maps = get_smooth_grad(data, model)
    np.save(f"model{i}_maps.npy", maps)

In [None]:
maps1 = np.load("model0_maps.npy")
maps2 = np.load("model1_maps.npy")

# Per-pixel mean saliency over all samples (axis 2 indexes samples, matching
# gradcam[:, :, d] below). Subtracting it keeps only what is specific to one sample.
mean1 = np.mean(maps1, axis=2)
mean2 = np.mean(maps2, axis=2)

# roi_cells_by_model[i][fov] -> ROI cells returned by get_binary for model i and sample fov.
roi_cells_by_model = {0: {}, 1: {}}

for d in tqdm(range(len(data))):
    fov = data.iloc[d]["file_path"]
    # The segmentation depends only on the sample, so load it once and share it between models.
    # NOTE(review): assumes get_binary does not modify seg_file in place.
    segmented = os.path.join(base, "datasets/melc/melanoma/", "segmented", f'{fov}_cells.npy')
    seg_file = np.load(segmented)
    for i, (gradcam, mean) in enumerate([(maps1, mean1), (maps2, mean2)]):
        # Keep only above-average saliency, zero the 1-pixel border, and scale to [0, 1).
        gc = np.maximum(gradcam[:, :, d] - mean, 0)
        gc[[0, -1], :] = 0
        gc[:, [0, -1]] = 0
        gc = gc / (np.max(gc) + 1e-6)

        binary, roi_cells = get_binary(gc, seg_file)
        roi_cells_by_model[i][fov] = roi_cells
        # To persist the ROI cells to disk instead:
        #path = f"../data/ROIs/model_{i}"
        #os.makedirs(path, exist_ok=True)
        #with open(os.path.join(path, fov + "_idxs.pkl"), "wb") as fp:
        #    pickle.dump(roi_cells, fp)

In [None]:
# Exclude two samples by index label.
# TODO: document why labels 36 and 40 are excluded.
#
# maps1/maps2 hold one saliency map per row of `data` along axis 2, and later cells pair
# data.iloc[d] with maps[:, :, d]. Dropping rows from `data` alone would shift that pairing
# for every sample after the first dropped row, so the same samples are removed from the
# maps too. Any statistic computed from maps1/maps2 after this cell (e.g. the per-pixel mean)
# therefore also excludes them.
#
# Masking by label also makes the cell safe to re-run: once the rows are gone, `keep` is all
# True and nothing changes.
EXCLUDED_LABELS = [36, 40]
keep = ~data.index.isin(EXCLUDED_LABELS)
assert maps1.shape[2] == maps2.shape[2] == len(data), "saliency maps are not aligned with data"
maps1 = maps1[:, :, keep]
maps2 = maps2[:, :, keep]
data = data[keep]

In [None]:
rgbs = list()
mean1 = np.mean(maps1, axis=2)
mean2 = np.mean(maps2, axis=2)

# (channel, saliency maps, per-pixel mean) for each model drawn into the overlay.
# Only model 1 (effnet2) is rendered, into channel 1. Channel 0 stays empty; channel 2
# (rgb[-1]) is where the commented-out ground-truth ROI below would go.
# Add (0, maps1, mean1) to also render model 0.
rendered_models = [(1, maps2, mean2)]

for d in tqdm(range(len(data))):
    rgb = np.zeros([3, 512, 512])
    fov = data.iloc[d]["file_path"]
    segmented = os.path.join(base, "datasets/melc/melanoma/", "segmented", f'{fov}_cells.npy')
    seg_file = np.load(segmented)
    for i, gradcam, mean in rendered_models:
        # Keep only above-average saliency and scale to [0, 1).
        # NOTE(review): unlike the previous loop, the 1-pixel border is not zeroed here —
        # confirm this is intended.
        gc = np.maximum(gradcam[:, :, d] - mean, 0)
        gc = gc / (np.max(gc) + 1e-6)

        binary, roi_cells = get_binary(gc, seg_file)
        # Draw the binary ROI mask; it must broadcast to rgb[i]'s 512x512 shape.
        # Alternative (continuous map): rgb[i] = cv2.resize((gc * 255).astype(np.uint8), (512, 512), interpolation=cv2.INTER_AREA)
        rgb[i] = binary * 255
        #path = f"../data/ROIs/riprip_{i}"
        #os.makedirs(path, exist_ok=True)
        #with open(os.path.join(path, fov + "_idxs.pkl"), "wb") as fp:
        #    pickle.dump(roi_cells, fp)

    #roi = cv2.imread(os.path.join(base, f"je30bery/melanoma_data/MAGICAL/data/ROIs/ground_truth_ROI/filled/filled_{fov}.tif"), cv2.IMREAD_GRAYSCALE)
    #roi = cv2.resize(roi, (512, 512))
    #if d == 5:
    #    roi = (roi > 15).astype(np.uint8)
    #else:
    #    roi = (roi > 0).astype(np.uint8)
    #rgb[-1] = roi * 255
    rgbs.append(rgb.astype(np.uint8))