In [None]:
%matplotlib inline
import os
import torch
import numpy as np
import matplotlib.pyplot as plt 

In [None]:
import torch.nn.functional as F

def extract(dp, gt, keep_dims=True):
    """Split per-voxel features into foreground and background sets.

    ``dp`` is upsampled by repeated 2x trilinear interpolation until its last
    three (spatial) dims equal those of ``gt``. Then, for every sample ``i``,
    the feature vectors at voxels where ``gt[i, 0] > 0`` form the positive set
    and those where ``gt[i, 0] == 0`` form the negative set.

    Args:
        dp: feature tensor indexed as ``dp[i, :, z, y, x]`` -- assumed to be
            (B, C, D, H, W); TODO confirm.
        gt: label tensor indexed as ``gt[i, 0]`` -- assumed (B, 1, D, H, W).
        keep_dims: if True, return the (C, N) sets; otherwise average over the
            channel axis and return (N,) vectors.

    Returns:
        ``(pos, neg)`` tensors, with voxels from all samples concatenated
        along the last axis.

    Raises:
        ValueError: if repeated doubling cannot reach ``gt``'s spatial size.
    """
    target_size = gt.shape[-3:]
    # Compare spatial dims only: dp usually has more channels than gt, so a
    # full-size comparison would never become equal and loop forever.
    while dp.shape[-3:] != target_size:
        if any(d > t for d, t in zip(dp.shape[-3:], target_size)):
            raise ValueError(
                f"cannot reach spatial size {tuple(target_size)} by 2x upsampling "
                f"of {tuple(dp.shape[-3:])}"
            )
        dp = F.interpolate(dp, scale_factor=(2, 2, 2), mode='trilinear')

    print(dp.size())
    vpos_sets, vneg_sets = [], []
    for i in range(dp.size(0)):
        # Boolean-mask indexing over the spatial dims yields (C, N_i), in the
        # same row-major voxel order as torch.argwhere.
        vpos_sets.append(dp[i][:, gt[i, 0] > 0])
        vneg_sets.append(dp[i][:, gt[i, 0] == 0])

    print([v.size() for v in vpos_sets])
    vpos_sets = torch.cat(vpos_sets, -1)
    vneg_sets = torch.cat(vneg_sets, -1)

    if keep_dims:
        return vpos_sets, vneg_sets

    return vpos_sets.mean(dim=0), vneg_sets.mean(dim=0)

In [None]:
from sklearn import mixture 

In [None]:
def fit_gmm(pos, neg, n_components=2, max_components=5, random_state=None):
    """Fit a full-covariance Gaussian mixture to the pooled ``pos`` and ``neg`` samples.

    Args:
        pos, neg: samples as NumPy arrays or torch tensors, shaped (N,) or
            (N, n_features). 1-D inputs are treated as a single feature. Each
            input is normalised independently.
        n_components: number of mixture components, or ``"auto"`` / ``"bic"``
            to pick the count in ``1..max_components`` with the lowest BIC.
        max_components: upper bound searched when ``n_components`` is a string.
        random_state: forwarded to ``GaussianMixture`` for reproducible fits.

    Returns:
        ``(clf, clf.means_, clf.covariances_)``.

    Raises:
        ValueError: for an unrecognised ``n_components`` string.
    """
    def _to_2d(arr):
        if not isinstance(arr, np.ndarray):
            arr = arr.detach().cpu().numpy()
        # sklearn expects (n_samples, n_features).
        return np.expand_dims(arr, 1) if arr.ndim < 2 else arr

    samples = np.concatenate((_to_2d(pos), _to_2d(neg)))

    if isinstance(n_components, str):
        if n_components not in ("auto", "bic"):
            raise ValueError(
                f"n_components must be an int, 'auto' or 'bic', got {n_components!r}"
            )
        # Infer the number of components: keep the model with the lowest BIC.
        candidates = [
            mixture.GaussianMixture(
                n_components=k, covariance_type="full", random_state=random_state
            ).fit(samples)
            for k in range(1, max_components + 1)
        ]
        clf = min(candidates, key=lambda model: model.bic(samples))
    else:
        clf = mixture.GaussianMixture(
            n_components=n_components, covariance_type="full", random_state=random_state
        )
        clf.fit(samples)

    return clf, clf.means_, clf.covariances_

In [None]:
from sklearn import metrics

def evaluate_gmm(clf: mixture.GaussianMixture, preds, labels_true):
    """Print and return clustering-agreement scores for a fitted GMM.

    Args:
        clf: fitted ``GaussianMixture``. If ``preds`` holds floating-point
            samples rather than cluster ids, ``clf.predict`` assigns them.
        preds: integer cluster assignments shaped (N,) or (N, 1), or
            floating-point samples shaped (N,) or (N, n_features).
        labels_true: ground-truth class per sample, shaped (N,) or (N, 1).

    Returns:
        dict with keys "Homogeneity", "Completeness" and "V-measure". These
        scores are invariant to how cluster ids are numbered.
    """
    preds = np.asarray(preds)
    labels_true = np.asarray(labels_true).reshape(-1)

    if np.issubdtype(preds.dtype, np.floating):
        # Raw feature values are not cluster labels. Scoring them directly
        # would treat every distinct float as its own cluster.
        samples = preds.reshape(-1, 1) if preds.ndim == 1 else preds
        preds = clf.predict(samples)
    preds = preds.reshape(-1)

    scores = {
        "Homogeneity": metrics.homogeneity_score(labels_true, preds),
        "Completeness": metrics.completeness_score(labels_true, preds),
        "V-measure": metrics.v_measure_score(labels_true, preds),
    }
    print("Homogeneity: %0.3f" % scores["Homogeneity"])
    print("Completeness: %0.3f" % scores["Completeness"])
    print("V-measure: %0.3f" % scores["V-measure"])
    return scores

In [None]:
from scipy.interpolate import UnivariateSpline

def visualize_scalar_dist(s, nbins=30, c='b', ax=None):
    """Plot a spline-smoothed histogram of the scalar samples ``s``.

    Args:
        s: array-like of scalar values (flattened by ``np.histogram``).
        nbins: number of histogram bins; also used as the spline smoothing factor.
        c: line colour.
        ax: axes to draw on. A new figure/axes pair is created when None.

    Returns:
        ``(fig, ax)``. When ``ax`` is supplied, ``fig`` is ``ax.figure``.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        # Previously `fig` was left unbound here, so returning it raised.
        fig = ax.figure
    counts, edges = np.histogram(s, bins=nbins)
    centers = edges[:-1] + (edges[1] - edges[0]) / 2  # bin edges -> bin centres
    spline = UnivariateSpline(centers, counts, s=nbins)
    ax.plot(centers, spline(centers), c=c)

    return fig, ax

def visualize_multivar_dist(*args, **kwargs):
    """Placeholder for plotting multivariate (N, n_features > 1) distributions.

    Accepts the same arguments ``visualize_dist`` forwards, so the dispatcher
    fails with a clear message instead of a TypeError.

    Raises:
        NotImplementedError: always; multivariate plotting is not implemented yet.
    """
    raise NotImplementedError("multivariate distribution plotting is not implemented yet")

def visualize_dist(s, *args, **kwargs):
    """Plot the distribution of ``s``, dispatching on its shape.

    1-D input and (N, 1) column vectors go to ``visualize_scalar_dist``.
    Anything else goes to ``visualize_multivar_dist``. Extra arguments are
    forwarded unchanged, and the plotting function's return value is returned.
    """
    shape = np.shape(s)
    is_scalar = len(shape) == 1 or (len(shape) == 2 and shape[1] == 1)
    # The previous condition was inverted: it sent 1-D data to the
    # multivariate path and N-D data to the scalar path.
    if is_scalar:
        return visualize_scalar_dist(s, *args, **kwargs)
    return visualize_multivar_dist(s, *args, **kwargs)

In [None]:
import torch
import torch.nn.functional as F

In [None]:
class LinearClassifier(torch.nn.Module):
    """Per-voxel linear probe implemented as a 1x1x1 3-D convolution.

    Args:
        in_channels: embedding size per voxel.
        num_classes: number of output logit channels.
        w, h, d: spatial extent used to reshape the flat embeddings
            (stored as ``width``, ``height``, ``depth``).
    """

    def __init__(self, in_channels, num_classes, w, h, d):
        super().__init__()

        self.in_channels, self.num_classes = in_channels, num_classes
        self.width, self.height, self.depth = w, h, d
        # A 1x1x1 kernel applies the same linear map over channels at every voxel.
        self.classifier = torch.nn.Conv3d(in_channels, num_classes, kernel_size=1)

    def forward(self, embeddings):
        """Map channels-last embeddings to (B, num_classes, height, width, depth) logits.

        ``embeddings`` is reshaped to (-1, height, width, depth, in_channels),
        so it must be laid out channels-last with a compatible element count.
        """
        channels_last = embeddings.reshape(
            -1, self.height, self.width, self.depth, self.in_channels
        )
        channels_first = channels_last.permute(0, 4, 1, 2, 3)
        return self.classifier(channels_first)


def train(model, train_dataloader, epochs, optimizer, criterion):
    """Optimise ``model`` on ``(dp, gt)`` batches for ``epochs`` epochs.

    Logits are upsampled by repeated 2x trilinear interpolation until their
    last three (spatial) dims match ``gt``. Then
    ``criterion(out_logits, gt).backward()`` and ``optimizer.step()`` run.

    Args:
        model: module mapping ``dp`` to 5-D logits.
        train_dataloader: iterable yielding ``(dp, gt)`` pairs.
        epochs: number of passes over ``train_dataloader``.
        optimizer: torch optimizer over ``model``'s parameters.
        criterion: callable returning a scalar loss tensor.

    Raises:
        ValueError: if doubling cannot reach ``gt``'s spatial size.
    """
    model.train()

    for epoch in range(epochs):
        print("Epoch:", epoch)
        for dp, gt in train_dataloader:
            # Clear stale gradients before this batch's backward pass.
            optimizer.zero_grad()

            out_logits = model(dp)
            target_size = gt.shape[-3:]
            # Compare spatial dims only: num_classes usually differs from gt's
            # channel count, which made the full-size comparison loop forever.
            while out_logits.shape[-3:] != target_size:
                if any(o > t for o, t in zip(out_logits.shape[-3:], target_size)):
                    raise ValueError(
                        f"cannot upsample logits {tuple(out_logits.shape)} by factors "
                        f"of 2 to match target {tuple(gt.shape)}"
                    )
                out_logits = F.interpolate(out_logits, scale_factor=(2, 2, 2), mode='trilinear')

            loss = criterion(out_logits, gt)
            loss.backward()
            optimizer.step()

def test(model, test_dataloader, criterion):
    total_loss = 0 
    for batch_idx, batch in enumerate(test_dataloader):

        dp, gt = batch 

        out_logits = model(dp)
        if out_logits.size() != gt.size():
            while out_logits.size() != gt.size():
                out_logits = F.interpolate(out_logits, scale_factor=(2, 2, 2), mode='trilinear')
            
        dice_loss = criterion(out_logits, gt)


def run_linear_clf():
    """Placeholder for an end-to-end linear-probe run; currently does nothing.

    Sketch of the intended recipe (not executed):

        from torch.optim import AdamW
        from Upstream.nnunet.training.loss_functions.dice_loss import SoftDiceLoss

        model = LinearClassifier(...)   # needs in_channels, num_classes, w, h, d
        train_dataloader = ...          # TODO: build a (features, gt) loader

        # Untuned hyperparameters; see also the DINOv2 paper.
        learning_rate = 5e-5
        epochs = 10

        optimizer = AdamW(model.parameters(), lr=learning_rate)
        criterion = SoftDiceLoss()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        # NOTE: train() also requires `criterion`.
        train(model, train_dataloader, epochs=epochs,
              optimizer=optimizer, criterion=criterion)
    """
    return None


In [None]:
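# Linear probing: fit a logistic-regression classifier on per-voxel prompt features.
# `train_features` is a prompt tensor (the cells below use task prompts; dynamic prompts also fit).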
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.metrics import f1_score, roc_auc_score, matthews_corrcoef

def cast_and_format(*args):
    """Convert each argument to a NumPy array shaped for scikit-learn.

    Torch tensors are detached and copied to CPU. Other non-ndarray inputs
    (lists, NumPy scalars, ...) go through ``np.asarray``. Arrays with more
    than two dims are flattened into a single column of shape (numel, 1).
    Arrays with at most two dims are returned unchanged.

    NOTE(review): the flatten also folds any channel axis into rows, so a
    multi-channel feature tensor yields C times more rows than its labels.

    Returns:
        Tuple of arrays, one per argument, in input order.
    """
    formatted = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            arg = arg.detach().cpu().numpy()
        elif not isinstance(arg, np.ndarray):
            arg = np.asarray(arg)

        if arg.ndim > 2:
            arg = arg.reshape([-1, 1])

        formatted.append(arg)

    return tuple(formatted)

def linear_probing(train_features, train_labels):
    """Fit a logistic-regression probe from per-voxel features to labels.

    ``train_features`` is upsampled by repeated 2x trilinear interpolation until
    its last three (spatial) dims match ``train_labels``. Both are then flattened
    with ``cast_and_format`` and a ``LogisticRegression(C=0.3, max_iter=1000,
    random_state=0)`` is fit.

    Returns:
        The fitted classifier.

    Raises:
        ValueError: if doubling cannot reach the labels' spatial size.
    """
    target_size = train_labels.shape[-3:]
    # Spatial dims only: a channel mismatch would otherwise loop forever.
    while train_features.shape[-3:] != target_size:
        if any(f > t for f, t in zip(train_features.shape[-3:], target_size)):
            raise ValueError(
                f"cannot upsample features {tuple(train_features.shape)} by factors "
                f"of 2 to match labels {tuple(train_labels.shape)}"
            )
        train_features = F.interpolate(train_features, scale_factor=(2, 2, 2), mode='trilinear')

    train_features, train_labels = cast_and_format(train_features, train_labels)

    classifier = LogisticRegression(random_state=0, C=0.3, max_iter=1000, verbose=1)
    # sklearn expects a 1-D target; an (N, 1) column only raises a DataConversionWarning.
    classifier.fit(train_features, train_labels.ravel())

    return classifier

def test_linear_classifier(classifier, features, labels):
    """Score a fitted probe on per-voxel features.

    ``features`` is upsampled by repeated 2x trilinear interpolation to the
    labels' spatial size, then both are flattened with ``cast_and_format``.

    Returns:
        ``(predictions, {"Accuracy": %, "F1": ..., "AUC": ...})``. AUC uses
        ``decision_function`` scores when the classifier provides them.

    Raises:
        ValueError: if doubling cannot reach the labels' spatial size.
    """
    target_size = labels.shape[-3:]
    # Spatial dims only: a channel mismatch would otherwise loop forever.
    while features.shape[-3:] != target_size:
        if any(f > t for f, t in zip(features.shape[-3:], target_size)):
            raise ValueError(
                f"cannot upsample features {tuple(features.shape)} by factors "
                f"of 2 to match labels {tuple(labels.shape)}"
            )
        features = F.interpolate(features, scale_factor=(2, 2, 2), mode='trilinear')

    features, labels = cast_and_format(features, labels)
    # An (N, 1) label column compared with (N,) predictions broadcasts to an
    # (N, N) matrix, so flatten labels to 1-D first.
    labels = labels.ravel()

    predictions = classifier.predict(features)

    accuracy = np.mean((labels == predictions).astype(float)) * 100
    f1 = f1_score(labels, predictions)
    # ROC-AUC needs continuous scores; hard 0/1 predictions collapse the curve.
    if hasattr(classifier, "decision_function"):
        auc_input = classifier.decision_function(features)
    else:
        auc_input = predictions
    auc_score = roc_auc_score(labels, auc_input)

    return predictions, {"Accuracy": accuracy, "F1": f1, "AUC": auc_score}


In [None]:
from nnunet.training.network_training.UniSeg_Trainer import UniSeg_Trainer
from nnunet.paths import network_training_output_dir
from nnunet.run.default_configuration import get_default_configuration

def collect_model_info_and_evaluate(model_checkpoint, exp_name = "UniSeg_Trainer", network = "3d_fullres", task = "Task097_11task", network_trainer = "UniSeg_Trainer", plans_identifier = "DoDNetPlans", fold=0, max_batches=3):
    """Load a UniSeg checkpoint and collect model outputs for a few training batches.

    Builds the trainer via ``get_default_configuration`` and loads the
    checkpoint. It then calls ``trainer.run_iteration(tr_gen, False, False, True)``
    up to ``max_batches`` times, collecting the six values each call returns.
    Despite the name, no metric is computed here.

    Args:
        model_checkpoint: checkpoint file name relative to
            ``<output_folder>/fold_<fold>``, or an absolute path
            (``os.path.join`` discards the preceding parts for absolute paths).
        exp_name, network, task, network_trainer, plans_identifier: forwarded
            to ``get_default_configuration``.
        fold: fold passed to the trainer and used in the checkpoint directory.
        max_batches: maximum number of iterations to collect. The default of 3
            matches the previous hard-coded limit. ``None`` runs
            ``len(tr_gen.generator._data)`` iterations.

    Returns:
        ``(tr_gen, trainer, outputs, inter_mediate_prompts, dynamic_prompts,
        task_prompts, features_xs, targets)``. The last six are lists with one
        entry per collected iteration.
    """
    # Get the main plans file
    plans_file, output_folder_name, dataset_directory, batch_dice, stage, \
        trainer_class = get_default_configuration(exp_name, network, task, network_trainer, plans_identifier)
    trainer = trainer_class(plans_file, fold, output_folder=output_folder_name, dataset_directory=dataset_directory,
                            batch_dice=batch_dice, stage=stage, unpack_data=True, deterministic=True, fp16=True)

    path_checkpoint = os.path.join(output_folder_name, f"fold_{fold}", model_checkpoint)
    trainer.load_checkpoint(path_checkpoint)

    tr_gen = trainer.tr_gen
    len_data = len(tr_gen.generator._data)
    n_iterations = len_data if max_batches is None else min(len_data, max_batches)

    outputs, inter_mediate_prompts, dynamic_prompts, task_prompts, features_xs, targets = [], [], [], [], [], []
    for _ in range(n_iterations):
        # NOTE(review): run_iteration's positional flags are project-specific;
        # the trailing True appears to make it return prompts/features -- confirm.
        output, inter_mediate_prompt, dynamic_prompt, task_prompt, features_x, target = trainer.run_iteration(tr_gen, False, False, True)
        # Results stay on their original device (possibly GPU).
        outputs.append(output)
        inter_mediate_prompts.append(inter_mediate_prompt)
        dynamic_prompts.append(dynamic_prompt)
        task_prompts.append(task_prompt)
        features_xs.append(features_x)
        # target is a list
        targets.append(target)

    return tr_gen, trainer, outputs, inter_mediate_prompts, dynamic_prompts, task_prompts, features_xs, targets

In [None]:
# NOTE(review): this checkpoint lives under Task091_MOTS while `task` below is
# Task097_11task. The path is absolute, so the os.path.join inside
# collect_model_info_and_evaluate uses it verbatim -- confirm the pairing is intended.
model_checkpoint = "/data/nnUNet_trained_models/UniSeg_Trainer/3d_fullres/Task091_MOTS/UniSeg_Trainer__DoDNetPlans/fold_0/model_latest.model"
(
    tr_gen,
    trainer,
    outputs,
    inter_mediate_prompts,
    dynamic_prompts,
    task_prompts,
    features_xs,
    targets,
) = collect_model_info_and_evaluate(
    model_checkpoint,
    exp_name="UniSeg_Trainer",
    network="3d_fullres",
    task="Task097_11task",
    network_trainer="UniSeg_Trainer",
    plans_identifier="DoDNetPlans",
)


In [None]:
def plot_dist(s, nbins=30, c='b'):
    """Draw a spline-smoothed histogram of ``s`` on the current pyplot axes.

    ``nbins`` sets both the number of histogram bins and the spline
    smoothing factor. ``c`` is the line colour.
    """
    counts, edges = np.histogram(s, bins=nbins)
    half_width = (edges[1] - edges[0]) / 2
    centers = edges[:-1] + half_width  # bin edges -> bin centres
    smooth = UnivariateSpline(centers, counts, s=nbins)
    plt.plot(centers, smooth(centers), c=c)

In [None]:
# Stack entry t[-2] of every collected target list into one tensor and inspect
# the shape that cast_and_format produces for it (flattened to one column if >2-D).
target_lst = [t[-2] for t in targets]
gts = torch.vstack(target_lst)
cast_and_format(gts)[0].shape

In [None]:
# Build the probe dataset: stacked task prompts as features, entry t[-2] of each
# collected target list as labels.
target_lst = [t[-2] for t in targets]

tps = torch.vstack(task_prompts)
gts = torch.vstack(target_lst)
print(gts.size())
# linear_probing upsamples tps to gts' spatial size before fitting.
classifier = linear_probing(tps, gts)

In [None]:
# NOTE(review): this scores the probe on the same (tps, gts) it was fit on, so
# the numbers are in-sample, not a held-out estimate.
predictions, metric_res = test_linear_classifier(classifier, tps, gts)

for metric_name, res in metric_res.items():
    print(f"Linear Probing {metric_name}: {res}")   

In [None]:
# Split task-prompt values into foreground (gt > 0) and background (gt == 0) voxels.
# extract returns (C, N) sets; reshape(-1, 1) flattens them to a single column,
# which mixes channels if C > 1 -- presumably C == 1 here, TODO confirm.
pos_set, neg_set = extract(tps, gts)
pos_set = pos_set.detach().cpu().numpy().reshape(-1, 1)
neg_set = neg_set.detach().cpu().numpy().reshape(-1, 1)

In [None]:
# Inspect how many foreground samples were extracted (a single column after the reshape).
pos_set.shape
# neg_set.shape

In [None]:
# Overlay spline-smoothed histograms of foreground (blue), background (red)
# and pooled (purple) values.
plt.close()
plot_dist(pos_set, c='b')
plot_dist(neg_set, c='r')
plot_dist(np.concatenate((pos_set, neg_set)), c='purple')
plt.legend(["foreground", "background", "pooled"])
plt.xlabel("feature value")
plt.ylabel("count (smoothed)")
# Save before show(): the inline backend closes figures in show(), so a
# savefig() afterwards would write a blank canvas.
plt.savefig("/workspace/dist.png")
plt.show()

In [None]:
# GMM: fit a Gaussian mixture (2 components by default) to the pooled foreground
# and background values, then show the fitted component parameters.
# NOTE(review): no random_state is passed here, so fits may vary between runs.
gmm, means, covariances = fit_gmm(pos_set, neg_set)

print(means)
print(covariances)

In [None]:
# Score how well the GMM's components separate foreground from background on
# the pooled samples. Scoring each set alone gives a single true class, which
# makes homogeneity trivially 1. Homogeneity / completeness / V-measure do not
# depend on cluster-id numbering, so no cluster-to-class mapping is needed.
gmm_samples = np.concatenate((pos_set, neg_set))
gmm_labels_true = np.concatenate(
    (np.ones(len(pos_set), dtype=int), np.zeros(len(neg_set), dtype=int))
)
gmm_scores = evaluate_gmm(gmm, gmm_samples, gmm_labels_true)


In [None]:
from sklearn.cluster import KMeans

# Two-cluster k-means on the task-prompt values (flattened to one column).
kmeans = KMeans(n_clusters=2, random_state=0, n_init="auto").fit(*cast_and_format(tps))

print(f"cluster centers: {kmeans.cluster_centers_}")
pos_pred = kmeans.predict(pos_set)
neg_pred = kmeans.predict(neg_set)

# k-means cluster ids are arbitrary. Instead of assuming cluster 0 is the
# foreground, keep whichever id-to-class mapping gives the higher balanced accuracy.
pos_hit = np.mean(pos_pred == 0)
neg_hit = np.mean(neg_pred == 1)
if pos_hit + neg_hit < 1:
    pos_hit, neg_hit = 1 - pos_hit, 1 - neg_hit
print(f"foreground assigned to foreground cluster: {pos_hit}")
print(f"background assigned to background cluster: {neg_hit}")

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X = np.concatenate((pos_set, neg_set), 0)
# 1-D targets (1 = foreground, 0 = background). An (N, 1) column only
# triggers sklearn's DataConversionWarning.
y = np.concatenate((np.ones(len(pos_set)), np.zeros(len(neg_set))))
# Score on held-out samples. Scoring on the fitting data counts each point as
# its own nearest neighbour and inflates accuracy.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y
)
neigh = KNeighborsClassifier(n_neighbors=3)
neigh.fit(X_train, y_train)
print(neigh.score(X_test, y_test))