In [1]:
!pip install git+https://github.com/jordanIAxelrod/ShapeModel

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/jordanIAxelrod/ShapeModel
  Cloning https://github.com/jordanIAxelrod/ShapeModel to /tmp/pip-req-build-y68yu7_n
  Running command git clone --filter=blob:none --quiet https://github.com/jordanIAxelrod/ShapeModel /tmp/pip-req-build-y68yu7_n
  Resolved https://github.com/jordanIAxelrod/ShapeModel to commit 2ecb240a96d1453df187c07ba35853eb47970c63
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: ShapeModelIMCP
  Building wheel for ShapeModelIMCP (pyproject.toml) ... done
  Created wheel for ShapeModelIMCP: filename=ShapeModelIMCP-0.0.1-py3-none-any.whl size=14497 sha256=2184cf3e68c6cfc4d867c29d944a7336ace53f23cace652c76ad185c99f626d0

In [5]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import shapeModel as ShapeModel
import torch
import nibabel as nib
import skimage.measure

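Our data consists of segmentation masks of several heart valves. The function below downsamples each mask by a factor of 4 along every axis and returns the coordinates of its boundary voxels: non-zero voxels that have at least one zero-valued neighbour.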

In [None]:
def get_outline(shape, block_size=(4, 4, 4)):
    """Return the voxel coordinates of a mask's boundary after block downsampling.

    The volume is first reduced with ``skimage.measure.block_reduce`` (whose
    default reduction sums each block). A voxel belongs to the outline when its
    value is > 0 and at least one voxel in its 3x3x3 neighbourhood, clipped to
    the volume, is == 0.

    Args:
        shape: 3-D array, e.g. the ``get_fdata()`` of a nibabel mask.
        block_size: block size passed to ``block_reduce``. ``None`` skips
            the downsampling step.

    Returns:
        Integer array of shape ``(n_points, 3)`` holding ``[i, j, k]`` indices
        into the (downsampled) volume, in row-major order.
    """
    if block_size is not None:
        shape = skimage.measure.block_reduce(shape, block_size)
    volume = np.asarray(shape)
    nx, ny, nz = volume.shape

    # Edge padding clamps each 3x3x3 window at BOTH ends of every axis. A raw
    # `i - 1` slice start wraps to -1 at i == 0 and yields an empty window,
    # which used to hide every boundary voxel on the low faces of the volume.
    background = np.pad(volume == 0, 1, mode='edge')

    # OR together the 27 shifted copies: True where any neighbour is background.
    touches_background = np.zeros(volume.shape, dtype=bool)
    for di in range(3):
        for dj in range(3):
            for dk in range(3):
                touches_background |= background[di:di + nx, dj:dj + ny, dk:dk + nz]

    return np.argwhere(touches_background & (volume > 0))

In [None]:
# Reads the data from the file system. Get the outline

def read_data(folder, leave_out=1):
    dataframe = []
    cwd = os.getcwd()
    os.chdir(folder)
    curr_dir = os.listdir()
    curr_dir = curr_dir[:leave_out] + curr_dir[leave_out + 1:]

    for direct in curr_dir:
        os.chdir(direct)
        file = os.listdir()[0]

        shape_cloud = nib.load(file).get_fdata()
        os.chdir('..')
        shape_cloud = get_outline(shape_cloud)

        dataframe.append(shape_cloud)
    min_len = min(dataframe, key=lambda x: x.shape[0]).shape[0]
    for i, data in enumerate(dataframe):
        choice = np.random.choice(data.shape[0], size=(min_len,), replace=False)
        dataframe[i] = torch.Tensor(data[choice]).unsqueeze(0)

    os.chdir(cwd)
    return torch.cat(dataframe, dim=0)

In [None]:
# Build a ShapeModel, fit it, and optionally persist it.
def create_ICMP_Model(data, verbose=True, save=False):
    """Construct a ``ShapeModel``, fit it to ``data`` and optionally save it.

    Args:
        data: stacked training shapes (e.g. the tensor returned by ``read_data``).
        verbose: forwarded unchanged to the model's fit call.
        save: when True, ``save()`` is called on the model after fitting.

    Returns:
        The fitted ``ShapeModel`` instance.
    """
    fitted = ShapeModel.ShapeModel()
    # Calling the model instance performs the fit.
    fitted(data, verbose=verbose)
    if not save:
        return fitted
    fitted.save()
    return fitted

In [None]:
# Leave-one-out generality test: for every shape, fit a model on all of the
# *other* shapes, then measure how well that model reconstructs the held-out
# shape as more principal components are used.

# Expect this to take a few minutes

# Override with the HEART_MASKS_DIR environment variable on other machines.
PATH = os.environ.get("HEART_MASKS_DIR", r"C:\Users\jda_s\Box\bone_project\heart_dataset\masks")
# Number of points sampled from each held-out outline.
# NOTE(review): hard-coded; presumably the smallest outline size in the
# dataset. Confirm it matches the point count the model expects.
N_POINTS = 927
SEED = 0

# read_data and the held-out sampling below both draw from the global np.random state.
np.random.seed(SEED)


def _load_held_out_shape(folder, index, n_points):
    """Return ``n_points`` random outline points of shape ``index`` as a (1, n_points, 3) tensor.

    Uses the same ``os.listdir`` ordering as ``read_data``, so ``index`` names
    the shape that ``read_data(folder, leave_out=index)`` skipped.
    """
    subdir = os.path.join(folder, os.listdir(folder)[index])
    mask_file = os.path.join(subdir, os.listdir(subdir)[0])
    outline = get_outline(nib.load(mask_file).get_fdata())
    choice = np.random.choice(outline.shape[0], size=(n_points,), replace=False)
    return torch.Tensor(outline[choice]).unsqueeze(0)


generality = {}
for i in range(len(os.listdir(PATH))):
    # Hold shape i out of the training set so it is evaluated as unseen data.
    ssm = create_ICMP_Model(read_data(PATH, leave_out=i), verbose=(i == 0))
    ssm.get_explained_variance()
    print(ssm.eig_vecs)
    print(ssm.mean_shape, ssm.mean_shape.shape)

    reg_shape = ssm.register_new_shapes(_load_held_out_shape(PATH, i, N_POINTS))
    generality[i] = []
    for j in range(1, ssm.eig_vals.shape[0]):
        new_shape1 = ssm.create_shape_approx(reg_shape, j + 1)
        # RMS error over the N_POINTS * 3 coordinates of registered vs reconstructed shape.
        dist = torch.sqrt(torch.sum(torch.square(reg_shape - new_shape1)) / (N_POINTS * 3))
        generality[i].append(dist.item())

    # Overlay the last (most components) reconstruction on the registered held-out shape.
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.set_title('Reconstructions')
    ax.scatter(new_shape1[0, :, 0], new_shape1[0, :, 1], new_shape1[0, :, 2])
    ax.scatter(reg_shape[0, :, 0], reg_shape[0, :, 1], reg_shape[0, :, 2])
    # NOTE(review): same file name every fold, so only the last fold's figure is kept on disk.
    fig.savefig('../img/Reconstruction.png')
    plt.show()
    plt.close(fig)

# Average each error curve across folds. Entry m was computed with m + 2
# components (j + 1, with j starting at 1), so label the x-axis accordingly.
n_folds = len(generality)
averages = [sum(errors[m] for errors in generality.values()) / n_folds
            for m in range(len(generality[0]))]
n_components = list(range(2, 2 + len(averages)))

fig, ax = plt.subplots()
ax.set_title("Generality")
ax.set_xlabel("Number of PCs")
ax.set_ylabel("RMS reconstruction error")
ax.plot(n_components, averages)
fig.savefig('../img/Generality.png')
plt.show()