In [None]:
%matplotlib inline
# import nnunet
import numpy as np
import matplotlib.pyplot as plt 

In [None]:
def extract(dp, gt, keep_dims=True):
    """Split per-voxel vectors of ``dp`` into positive and negative sets.

    For every index ``i`` along the first axis of ``dp``, the label volume
    ``gt[i, 0]`` selects the voxels whose label is ``> 0`` (positive) and the
    voxels whose label is ``== 0`` (negative); voxels with a negative label
    land in neither set. The vectors ``dp[i, :, x, y, z]`` of the selected
    voxels are gathered and stacked over all ``i``.

    Args:
        dp: numpy array indexed as ``dp[i, :, x, y, z]``. Presumably laid out
            as (batch, channels, X, Y, Z) -- TODO confirm against the caller.
        gt: array whose ``gt[i, 0]`` slice is a label volume; its first three
            axes are used as voxel coordinates into ``dp``. Presumably the
            same spatial shape as ``dp`` -- TODO confirm.
        keep_dims: when True, return the full per-voxel vectors; when False,
            return the mean of every vector over axis 1.

    Returns:
        ``(positives, negatives)``. With ``keep_dims`` these are 2-D arrays
        holding one row per selected voxel; otherwise 1-D arrays holding one
        mean value per selected voxel.
    """

    def _voxel_vectors(sample_idx, mask):
        coords = np.argwhere(mask)
        # The integer index and the coordinate arrays are separated by a
        # slice, so numpy places the voxel axis first: one row per voxel.
        return dp[sample_idx, :, coords[:, 0], coords[:, 1], coords[:, 2]]

    positives, negatives = [], []
    for sample_idx in range(dp.shape[0]):
        labels = gt[sample_idx, 0]
        positives.append(_voxel_vectors(sample_idx, labels > 0))
        negatives.append(_voxel_vectors(sample_idx, labels == 0))

    positives = np.vstack(positives)
    negatives = np.vstack(negatives)

    if not keep_dims:
        return positives.mean(axis=1), negatives.mean(axis=1)
    return positives, negatives

In [None]:
from sklearn import mixture 

In [None]:
def fit_sep(pos, neg, n_components=2):
    """Fit a single Gaussian mixture to the pooled positive and negative samples.

    Args:
        pos: array of positive samples. An array with fewer than two
            dimensions is turned into a column (one scalar feature per row).
        neg: array of negative samples, handled the same way as ``pos``.
        n_components: number of mixture components handed to
            ``mixture.GaussianMixture``. Passing a string (intended to mean
            "infer the number of components") is not implemented yet.

    Returns:
        ``(clf, clf.means_, clf.covariances_)`` where ``clf`` is the fitted
        ``GaussianMixture`` (full covariance).

    Raises:
        NotImplementedError: if ``n_components`` is a string.
    """
    if isinstance(n_components, str):
        # Fail here with a clear message instead of letting the string reach
        # GaussianMixture and blow up deep inside sklearn.
        raise NotImplementedError(
            "inferring the number of components is not implemented; "
            "pass an integer n_components"
        )

    def _as_column(samples):
        # Promote each input on its own: previously ``neg`` was only expanded
        # when ``pos`` happened to be 1-D, so mixed-rank inputs could not be
        # concatenated.
        if len(samples.shape) < 2:
            return np.expand_dims(samples, 1)
        return samples

    pooled = np.concatenate((_as_column(pos), _as_column(neg)))

    clf = mixture.GaussianMixture(n_components=n_components, covariance_type="full")
    clf.fit(pooled)

    return clf, clf.means_, clf.covariances_

In [None]:
def evaluate_sep(clf, preds):
    """Unimplemented placeholder: ignores ``clf`` and ``preds`` and returns None."""
    return None

In [None]:
from scipy.interpolate import UnivariateSpline

def visualize_scalar_dist(s, nbins=30, c='b', ax=None):
    """Plot a spline-smoothed histogram of the values in ``s``.

    The values are binned with ``np.histogram``, a ``UnivariateSpline`` is
    fitted through the (bin center, count) pairs and the spline is drawn,
    evaluated at the bin centers.

    Args:
        s: values to histogram (anything ``np.histogram`` accepts).
        nbins: number of histogram bins; the same number is also used as the
            spline's smoothing factor.
        c: line colour passed to ``ax.plot``.
        ax: axes to draw on. When None a new figure and axes are created.

    Returns:
        ``(fig, ax)``: the axes that were drawn on and the figure owning them.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        # Use the figure that owns the caller's axes; without this branch
        # ``fig`` was unbound and the return statement raised.
        fig = ax.figure

    counts, edges = np.histogram(s, bins=nbins)
    # Convert bin edges to bin centers (the bins are equally wide).
    centers = edges[:-1] + (edges[1] - edges[0]) / 2
    smooth = UnivariateSpline(centers, counts, s=nbins)
    ax.plot(centers, smooth(centers), c=c)

    return fig, ax

def visualize_multivar_dist(s=None, *args, **kwargs):
    """Placeholder for plotting a multivariate distribution (not implemented).

    Accepts the same arguments ``visualize_dist`` forwards to it so the call
    does not fail on the signature. Nothing is plotted; a warning is emitted
    so the missing plot is not silent.

    Returns:
        None.
    """
    import warnings

    warnings.warn("visualize_multivar_dist is not implemented; nothing was plotted")
    return None


def visualize_dist(s, *args, **kwargs):
    """Dispatch ``s`` to the scalar or the multivariate visualiser.

    A 1-D array, or a 2-D array with a single column, is treated as scalar
    samples and goes to ``visualize_scalar_dist``; every other shape goes to
    ``visualize_multivar_dist``. Extra positional and keyword arguments are
    forwarded unchanged.

    Args:
        s: array-like object exposing ``.shape``.

    Returns:
        Whatever the selected visualiser returns.
    """
    shape = s.shape
    # The original test was inverted: it sent N-D data to the scalar plot and
    # 1-D data to the multivariate stub.
    is_scalar = len(shape) == 1 or (len(shape) == 2 and shape[1] == 1)
    if is_scalar:
        return visualize_scalar_dist(s, *args, **kwargs)
    return visualize_multivar_dist(s, *args, **kwargs)

In [None]:
import torch
import torch.nn.functional as F

In [None]:
class LinearClassifier(torch.nn.Module):
    """Per-voxel linear classifier built from a 1x1x1 ``Conv3d``.

    The constructor stores the channel count, class count and the three grid
    extents; ``forward`` reshapes its input onto that grid and applies the
    convolution.
    """

    def __init__(self, in_channels, num_classes, w, h, d):
        """Store the configuration and create the 1x1x1 convolution.

        Args:
            in_channels: size of the last axis after reshaping in ``forward``
                and the convolution's input channel count.
            num_classes: number of output channels of the convolution.
            w: stored as ``self.width``.
            h: stored as ``self.height``.
            d: stored as ``self.depth``.
        """
        super().__init__()

        self.in_channels, self.num_classes = in_channels, num_classes
        self.width, self.height, self.depth = w, h, d
        self.classifier = torch.nn.Conv3d(
            in_channels, num_classes, kernel_size=(1, 1, 1)
        )

    def forward(self, embeddings):
        """Return logits for ``embeddings``.

        The input is reshaped to (-1, height, width, depth, in_channels), the
        channel axis is moved to position 1, and the convolution is applied,
        giving a tensor of shape (N, num_classes, height, width, depth).
        """
        grid_shape = (-1, self.height, self.width, self.depth, self.in_channels)
        channels_last = embeddings.reshape(*grid_shape)
        # Conv3d expects the channel axis right after the batch axis.
        channels_first = channels_last.permute(0, 4, 1, 2, 3)
        return self.classifier(channels_first)


def train(model, train_dataloader, epochs, optimizer, criterion):
    """Run a plain training loop over ``train_dataloader`` for ``epochs`` epochs.

    For every batch the model output is computed, resized to the target's
    spatial size if the two differ, scored with ``criterion`` and used for one
    optimizer step.

    Args:
        model: ``torch.nn.Module`` to train; it is switched to training mode.
        train_dataloader: iterable yielding ``(dp, gt)`` pairs of tensors.
        epochs: number of passes over ``train_dataloader``.
        optimizer: optimizer holding the model's parameters.
        criterion: callable ``criterion(logits, gt)`` returning a scalar loss
            tensor.

    Returns:
        None.
    """
    # put model in training mode
    model.train()

    # Batches are moved to wherever the model's parameters live, so a model
    # that was moved to the GPU does not receive CPU tensors.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    for epoch in range(epochs):
        print("Epoch:", epoch)
        for dp, gt in train_dataloader:
            if device is not None:
                dp, gt = dp.to(device), gt.to(device)

            # Clear gradients before the backward pass so nothing left over
            # from earlier work leaks into this update.
            optimizer.zero_grad()

            # forward pass
            out_logits = model(dp)

            # Compare spatial axes only (the target may carry a different
            # channel count, or none), and resize in a single step. The former
            # loop doubled the logits until the full sizes matched, which never
            # terminated when they could not match.
            n_spatial = out_logits.dim() - 2
            if n_spatial > 0:
                target_size = tuple(gt.shape[-n_spatial:])
                if tuple(out_logits.shape[2:]) != target_size:
                    out_logits = F.interpolate(
                        out_logits,
                        size=target_size,
                        mode='trilinear',
                        align_corners=False,
                    )

            dice_loss = criterion(out_logits, gt)
            dice_loss.backward()
            optimizer.step()

In [None]:
def test(model, test_dataloader):
    pass

In [None]:
from torch.optim import AdamW
from Upstream.nnunet.training.loss_functions.dice_loss import SoftDiceLoss

# NOTE(review): LinearClassifier.__init__ requires in_channels, num_classes,
# w, h and d; calling it without arguments raises TypeError. Fill these in
# for the embeddings being classified.
model = LinearClassifier()

# NOTE(review): placeholder. train() iterates over this object, so it must be
# replaced by an iterable of (dp, gt) batches before running this cell.
train_dataloader = None


# training hyperparameters
# NOTE: I've just put some random ones here, not optimized at all
# feel free to experiment, see also DINOv2 paper
learning_rate = 5e-5
epochs = 10

optimizer = AdamW(model.parameters(), lr=learning_rate)
criterion = SoftDiceLoss()

# put model on GPU (set runtime to GPU in Google Colab)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# train() takes the loss as a required argument; it was previously omitted,
# which raised TypeError before any training happened.
train(model, train_dataloader, epochs=epochs, optimizer=optimizer, criterion=criterion)