## Imports

In [1]:
import torch
from torch import nn
import os
from os import path
import torchvision
import torchvision.transforms as T
from typing import Sequence
from torchvision.transforms import functional as F
import numbers
import random
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
import torchmetrics as TM
from dataclasses import dataclass
import dataclasses

  from .autonotebook import tqdm as notebook_tqdm


## Utils

In [2]:
def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    else:
        return torch.device("cpu")

# Send the Tensor or Model (input argument x) to the right device
# for this notebook. i.e. if GPU is enabled, then send to GPU/CUDA
# otherwise send to CPU.
def to_device(x):
    if torch.cuda.is_available():
        return x.cuda()
    else:
        return x.cpu()
    
def print_title(title):
    title_len = len(title)
    dashes = ''.join(["-"] * title_len)
    print(f"\n{title}\n{dashes}")
# end def

## Set up data

In [3]:
from pathlib import Path
from natsort import natsorted
import os
import numpy as np

def find_path_to_folder(tag):
    """
    Find folder that contains given tag.
    Returns the local path to this folder.
    """
    tag = str(Path(tag))
    workdir = os.getcwd()
    print(f'Looking for {tag} in: {workdir}')
    n = len(tag)
    found = False
    for (dir_path, dir_names, file_names) in os.walk(workdir , topdown=True):
      if tag == dir_path[-n:]:
        print(f'Found {dir_path}\n')
        found = dir_path
        break

    if not found:
        raise Exception("Couldn't find the folder")
    return Path(found)

def create_slice_matrix(path_to_image, fracture, label: str):
    """
    Read in all images of a given folder. Returns a matrix with all the images 
    stacked in the 3e dimension.
    """
    im_size = 64
    path_to_slices = os.path.join(path_to_image, fracture, label)
    slices = natsorted(os.listdir(path_to_slices))
    tmp_matrix = np.zeros((im_size, im_size, len(slices)))

    for i, s in enumerate(slices):
        pts = os.path.join(path_to_slices, s) # pts is path_to_slice  
        tmp_matrix[:, :, i] = np.load(pts)

    return tmp_matrix

def readin_slices(path_to_image_folder, image_inds: list = None) -> dict:
    """
    Creates and returns a dictionary with 3d numpy arrays of stacked images.
    Works as follows:

    dict["Index of image"]["Index of fracture]["neg/pos_image/pos_label"]
    
    Where the 3e dim (e.g: im[:,:,x]) is equal to the index of the slices.
    """
    # Locate image folder containing the fractures.
    all_paths = os.listdir(path_to_image_folder)
    filenames = [(p,ind) for p in all_paths for ind in image_inds if ind in p]

    d = {}
    for name, ind in filenames:
        d[ind] = {}
        path_to_image = path_to_image_folder.joinpath(name)

        # Iterate over fractures.
        fracs = os.listdir(path_to_image)
        for f in fracs:
            neg_matrix = create_slice_matrix(path_to_image, f, 'neg')
            pos_image_matrix = create_slice_matrix(path_to_image, f, 'pos_image')
            pos_label_matrix = create_slice_matrix(path_to_image, f, 'pos_label')

            d[ind][f] = {'neg': neg_matrix, 'pos_image': pos_image_matrix, 'pos_label': pos_label_matrix}
    
    return d

# path_to_image = find_path_to_folder('dataset')
d = readin_slices(Path('../dataset'), ['422', '423', '424'])

In [4]:
inputs, targets = [], []

for im_ind in d.keys():
    for frac in d[im_ind].keys():
        inputs.append(d[im_ind][frac]['pos_image'])
        targets.append(d[im_ind][frac]['pos_label'])

## ViT

### Patch models

In [5]:
# ImageToPatches returns multiple flattened square patches from an
# input image tensor.
class ImageToPatches(nn.Module):
    def __init__(self, image_size, patch_size):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.unfold = nn.Unfold(kernel_size=patch_size, stride=patch_size)
    # end def

    def forward(self, x):
        assert len(x.size()) == 4
        y = self.unfold(x)
        y = y.permute(0, 2, 1)
        return y
    # end def
# end class

print_title("ImageToPatches")
i2p = ImageToPatches(8, 4)


# The PatchEmbedding layer takes multiple image patches in (B,T,Cin) format
# and returns the embedded patches in (B,T,Cout) format.
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels, embed_size):
        super().__init__()
        self.in_channels = in_channels
        self.embed_size = embed_size
        # A single Layer is used to map all input patches to the output embedding dimension.
        # i.e. each image patch will share the weights of this embedding layer.
        self.embed_layer = nn.Linear(in_features=in_channels, out_features=embed_size)
    # end def

    def forward(self, x):
        assert len(x.size()) == 3
        B, T, C = x.size()
        x = self.embed_layer(x)
        return x
    # end def
# end class

print_title("PatchEmbedding")
pe = PatchEmbedding(768, 256)


ImageToPatches
--------------

PatchEmbedding
--------------


### Vision transformer

In [6]:
class VisionTransformerInput(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_size):
        """in_channels is the number of input channels in the input that will be
        fed into this layer. For RGB images, this value would be 3.
        """
        super().__init__()
        self.i2p = ImageToPatches(image_size, patch_size)
        self.pe = PatchEmbedding(patch_size * patch_size * in_channels, embed_size)
        num_patches = (image_size // patch_size) ** 2
        # position_embed below is the learned embedding for the position of each patch
        # in the input image. They correspond to the cosine similarity of embeddings
        # visualized in the paper "An Image is Worth 16x16 Words"
        # https://arxiv.org/pdf/2010.11929.pdf (Figure 7, Center).
        self.position_embed = nn.Parameter(torch.randn(num_patches, embed_size))
    # end def

    def forward(self, x):
        x = self.i2p(x)
        # print(x.shape)
        x = self.pe(x)
        x = x + self.position_embed
        return x
    # end def
# end class

print_title("VisionTransformerInput")
# Original
# x = torch.randn(10, 3, 224, 224)
# Custom
# x = torch.randn(10, 1, 224, 224)

# Original
# vti = VisionTransformerInput(224, 16, 3, 256)
# Custom
vti = VisionTransformerInput(224, 16, 1, 256)
# y = vti(x)
# print(f"{x.shape} -> {y.shape}")



VisionTransformerInput
----------------------


### Vision transformer segmentation building blocks

In [7]:


# The MultiLayerPerceptron is a unit of computation. It expands the input
# to 4x the number of channels, and then contracts it back into the number
# of input channels. There's a GeLU activation in between, and the layer
# is followed by a droput layer.
class MultiLayerPerceptron(nn.Module):
    def __init__(self, embed_size, dropout):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(embed_size, embed_size * 4),
            nn.GELU(),
            nn.Linear(embed_size * 4, embed_size),
            nn.Dropout(p=dropout),
        )
    # end def

    def forward(self, x):
        return self.layers(x)
    # end def
# end class

print_title("MultiLayerPerceptron")
mlp = MultiLayerPerceptron(60, dropout=0.2)


# This is a single self-attention encoder block, which has a multi-head attention
# block within it. The MultiHeadAttention block performs communication, while the
# MultiLayerPerceptron performs computation.
class SelfAttentionEncoderBlock(nn.Module):
    def __init__(self, embed_size, num_heads, dropout):
        super().__init__()
        self.embed_size = embed_size
        self.ln1 = nn.LayerNorm(embed_size)
        # self.kqv = nn.Linear(embed_size, embed_size * 3)
        self.mha = nn.MultiheadAttention(embed_size, num_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(embed_size)
        self.mlp = MultiLayerPerceptron(embed_size, dropout)
    # end def

    def forward(self, x):
        y = self.ln1(x)
        # y = self.kqv(x)
        # (q, k, v) = torch.split(y, self.embed_size, dim=2)
        x = x + self.mha(y, y, y, need_weights=False)[0]
        x = x + self.mlp(self.ln2(x))
        return x
    # end def
# end class

print_title("SelfAttentionEncoderBlock")
attention_block = SelfAttentionEncoderBlock(256, 8, dropout=0.2)


# Similar to the PatchEmbedding class, we need to un-embed the representation
# of each patch that has been produced by our transformer network. We project
# each patch (that has embed_size) dimensions into patch_size*patch_size*output_dims
# channels, and then fold all the patches back to make it look like an image.
class OutputProjection(nn.Module):
    def __init__(self, image_size, patch_size, embed_size, output_dims):
        super().__init__()
        self.patch_size = patch_size
        self.output_dims = output_dims
        self.projection = nn.Linear(embed_size, patch_size * patch_size * output_dims)
        self.fold = nn.Fold(output_size=(image_size, image_size), kernel_size=patch_size, stride=patch_size)
    # end def

    def forward(self, x):
        B, T, C = x.shape
        x = self.projection(x)
        # x will now have shape (B, T, PatchSize**2 * OutputDims). This can be folded into
        # the desired output shape.

        # To fold the patches back into an image-like form, we need to first
        # swap the T and C dimensions to make it a (B, C, T) tensor.
        x = x.permute(0, 2, 1)
        x = self.fold(x)
        return x
    # end def
# end class

print_title("OutputProjection")
# Original
# op = OutputProjection(224, 16, 256, 3)
# Custom
op = OutputProjection(224, 16, 256, 1)


class VisionTransformerForSegmentation(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, out_channels, embed_size, num_blocks, num_heads, dropout):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.embed_size = embed_size
        self.num_blocks = num_blocks
        self.num_heads = num_heads
        self.dropout = dropout

        heads = [ SelfAttentionEncoderBlock(embed_size, num_heads, dropout) for i in range(num_blocks) ]
        self.layers = nn.Sequential(
            nn.BatchNorm2d(num_features=in_channels),
            VisionTransformerInput(image_size, patch_size, in_channels, embed_size),
            nn.Sequential(*heads),
            OutputProjection(image_size, patch_size, embed_size, out_channels),
        )
    # end def

    def forward(self, x):
        x = self.layers(x)
        return x
    # end def
# end class

@dataclass
class VisionTransformerArgs:
    """Arguments to the VisionTransformerForSegmentation."""
    image_size: int = 128
    patch_size: int = 16
    in_channels: int = 1 # Number of layers in image (RGB=3), set to 1 for grayscale
    out_channels: int = 1 # N layers that should be put out eventually by the network, set to 1 for grayscale again
    embed_size: int = 768
    num_blocks: int = 12
    num_heads: int = 8
    dropout: float = 0.2
# end class

print_title("VisionTransformerForSegmentation")
# x = torch.randn(2, 1, 128, 128)
vit_args = dataclasses.asdict(VisionTransformerArgs())

vit = VisionTransformerForSegmentation(**vit_args)
# y = vit(x)
# print(f"{x.shape} -> {y.shape}")
# print_model_parameters(vit)


MultiLayerPerceptron
--------------------

SelfAttentionEncoderBlock
-------------------------

OutputProjection
----------------

VisionTransformerForSegmentation
--------------------------------


## Loss fuction

In [8]:
# Define a custom IoU Metric for validating the model.
def IoUMetric(pred, gt, softmax=False):
    # Run softmax if input is logits.
    if softmax is True:
        pred = nn.Softmax(dim=1)(pred)
    # end if

    # Add the one-hot encoded masks for all 3 output channels
    # (for all the classes) to a tensor named 'gt' (ground truth).
    gt = torch.cat([ (gt == i) for i in range(3) ], dim=1)
    # print(f"[2] Pred shape: {pred.shape}, gt shape: {gt.shape}")

    intersection = gt * pred
    union = gt + pred - intersection

    # Compute the sum over all the dimensions except for the batch dimension.
    iou = (intersection.sum(dim=(1, 2, 3)) + 0.001) / (union.sum(dim=(1, 2, 3)) + 0.001)

    # Compute the mean over the batch dimension.
    return iou.mean()

class IoULoss(nn.Module):
    def __init__(self, softmax=False):
        super().__init__()
        self.softmax = softmax

    # pred => Predictions (logits, B, 3, H, W)
    # gt => Ground Truth Labales (B, 1, H, W)
    def forward(self, pred, gt):
        # return 1.0 - IoUMetric(pred, gt, self.softmax)
        # Compute the negative log loss for stable training.
        return -(IoUMetric(pred, gt, self.softmax).log())
    # end def
# end class

def test_custom_iou_loss():
    #               B, C, H, W
    x = torch.rand((2, 3, 2, 2), requires_grad=True)
    y = torch.randint(0, 3, (2, 1, 2, 2), dtype=torch.long)
    z = IoULoss(softmax=True)(x, y)
    return z
# end def

test_custom_iou_loss()

tensor(1.6861, grad_fn=<NegBackward0>)

## Set up training

### Actual training

In [29]:
# Train the model for a single epoch
def train_model(model, loader, optimizer):
    to_device(model.train())
    cel = True
    if cel:
        criterion = nn.CrossEntropyLoss(reduction='mean')
    else:
        criterion = IoULoss(softmax=True) # Hasn't be adjusted for 2D images jochem
    # end if

    running_loss = 0.0
    running_samples = 0

    inputs, targets = loader
    for batch_idx, (inps, targs) in enumerate(zip(inputs, targets), 0):
        optimizer.zero_grad()
        inps = torch.from_numpy(inps)
        targs = torch.from_numpy(targs)
        # targets = targets.type(torch.long)

        inps = to_device(inps)
        targs = to_device(targs)
        outputs = model(inps)

        # The ground truth labels have a channel dimension (NCHW).
        # We need to remove it before passing it into
        # CrossEntropyLoss so that it has shape (NHW) and each element
        # is a value representing the class of the pixel.
        targs = targs.type(torch.long)
        if cel:
            targs = targs.squeeze(dim=1)
        # end if
        unqs = targs.unique()
        print("UNGKS:", unqs)
        print("Targests:", targs.shape)
        print("Outputs:", outputs.shape)
        value = outputs[0,0,0,0].item()
        print("Value in the output:", value)
        print("output type:", type(value), '\n')

        value = targs[0,0,0].item()
        print("Value in the target:", value)
        print("target shape:", targs.shape)
        print("Value type:", type(value))

        # sys.exit()
        loss = criterion(outputs, targs)
        loss.backward()
        optimizer.step()

        running_samples += targs.size(0)
        running_loss += loss.item()
    # end for

    print("Trained {} samples, Loss: {:.4f}".format(
        running_samples,
        running_loss / (batch_idx+1),
    ))
# end def

### Define training loop

In [30]:
#  Define training loop. This will train the model for multiple epochs.
#
# epochs: A tuple containing the start epoch (inclusive) and end epoch (exclusive).
#         The model is trained for [epoch[0] .. epoch[1]) epochs.
#
def train_loop(model, loader, test_data, epochs, optimizer, scheduler, save_path):
    test_inputs, test_targets = test_data
    epoch_i, epoch_j = epochs
    for i in range(epoch_i, epoch_j):
        epoch = i
        print(f"Epoch: {i:02d}, Learning Rate: {optimizer.param_groups[0]['lr']}")
        train_model(model, loader, optimizer)
        with torch.inference_mode():
        #     # Display the plt in the final training epoch.
        #     # (epoch == epoch_j-1)
            custom_test = True
            if i == 15:
                print("got here")

                to_device(model.eval())
                inputs, targets = loader
                for batch_idx, (inps, targs) in enumerate(zip(inputs, targets), 0):
                    inps = torch.from_numpy(inps)
                    inps = to_device(inps)
                    preds = model(inps)
                    pred = nn.Softmax(dim=1)(preds)
                    print("PREDS:", torch.unique(pred))
                    print("PREDS", pred.shape)
                break
        #     print_test_dataset_masks(model, test_inputs, test_targets, epoch=epoch, save_path=save_path, show_plot=True)
        # # end with

        

        if scheduler is not None:
            scheduler.step()
        # end if
        print("")
    # end for
# end def

### Run training

In [31]:
m = vit
images_folder_name = "vit_training_progress_images"
save_path = os.path.join('output', images_folder_name)

In [32]:
type(m)

__main__.VisionTransformerForSegmentation

In [33]:
# Optimizer and Learning Rate Scheduler.
to_device(m)
optimizer = torch.optim.Adam(m.parameters(), lr=0.0004)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=12, gamma=0.8)

### Call training function

In [34]:
# Original:
"""
train_loop(m, pets_train_loader, (test_pets_inputs, test_pets_targets), (1, 51), optimizer, scheduler, save_path=save_path)
m = model
pets_train_loader = torch dataloader object
We will split this into a inputs and targets vars for the sake of fast implementation.

test_pets_inputs = Used for unnecesary things and removed from training loop
test_pets_targets = Used for unnecesary things and removed from training loop
"""

test_pets_inputs = None
test_pets_targets = None

# Custom:
# train_loop(m, (inputs, targets), (test_pets_inputs, test_pets_targets), (1, 51), optimizer, scheduler, save_path=save_path)


In [35]:
def surround_the_ones(m_arr):
    updated_arr = m_arr.clone()

    # Get the indices of ones in the 'test' tensor
    ones_indices = torch.nonzero(m_arr == 1)

    # Define the neighborhood indices for the surrounding elements
    neighborhood_indices = torch.tensor([
        [1, 0], [-1, 0], [0, 1], [0, -1],
        [1, 1], [-1, -1], [1, -1], [-1, 1]
    ])
    # Place ones next to twos only if there is a zero
    for index in ones_indices:
        updated_arr[tuple(index)] = 2  # Change the one to two
        for offset in neighborhood_indices:
            neighbor_index = index + offset
            if (0 <= neighbor_index[0] < m_arr.shape[0] and
                0 <= neighbor_index[1] < m_arr.shape[1] and
                updated_arr[tuple(neighbor_index)] == 0):
                updated_arr[tuple(neighbor_index)] = 1  # Place ones only if the neighbor is zero

    # Print the updated 'test' tensor
    # print("Updated 'test' tensor:")
    # print(updated_arr)
    return updated_arr

    

# Testing

In [36]:
import torchvision.transforms as transforms 
from PIL import ImageOps

img_size_custom = 128
examples = 21

inputs = Image.open("../raw_data/test/dikke_kippen/kip2.jpg")

inputs = ImageOps.grayscale(inputs) 

segmask = np.load("../dataset/RibFrac421-image/frac_0/pos_label/pos-slice-0-label.npy")
segmask = torch.from_numpy(segmask)
img2tensor = transforms.Compose([transforms.PILToTensor()])

img_tensor = img2tensor(inputs)
img_tensor = img_tensor[:, 0:128, 0:128]

# large_tensor = torch.rand(100, 100)  # A large tensor for demonstration

# # Set the print options to display the entire tensor and make it wider
# torch.set_printoptions(threshold=large_tensor.numel(), linewidth=200)

segmask= surround_the_ones(segmask)

# print("UNIQUES:", torch.unique(segmask))
# # sys.exit()
# # Set the print options to display the entire tensor
# print(torch.unique(segmask_test))
# print(segmask_test.shape)
# print(segmask_test)
# for row in segmask_test:
    # print(row)
# sys.exit()

inputs_batch = torch.zeros((examples, 1, img_size_custom, img_size_custom))
for i in range(examples):
    inputs_batch[i,:,:,:] = img_tensor

tmp = torch.zeros((img_size_custom, img_size_custom))
tmp[0:64, 0:64] = segmask
segmask_batch = torch.zeros((examples, 1, img_size_custom, img_size_custom))
for i in range(examples):
    segmask_batch[i,:,:,:] = tmp

inputs_batch = inputs_batch.numpy()
segmask_batch = segmask_batch.numpy()

print(segmask_batch.shape)

print(inputs_batch.shape)
# print(inputs_batch.shape)
# sys.exit()


train_loop(m, ([inputs_batch], [segmask_batch]), (test_pets_inputs, test_pets_targets), (1, 51), optimizer, scheduler, save_path=save_path)



(21, 1, 128, 128)
(21, 1, 128, 128)
Epoch: 01, Learning Rate: 0.0004
UNGKS: tensor([0, 1, 2])
Targests: torch.Size([21, 128, 128])
Outputs: torch.Size([21, 1, 128, 128])
Value in the output: -0.7524198293685913
output type: <class 'float'> 

Value in the target: 0
target shape: torch.Size([21, 128, 128])
Value type: <class 'int'>


IndexError: Target 1 is out of bounds.