# CAMP II Surgical Workflow Analysis Exercise

It is highly recommended to complete the CAMP II exercises on classification and segmentation before this one!  
Feel free to reuse code you wrote for those exercises here; it might be useful. Make sure to activate the GPU in Colab before you start this exercise.

In this exercise, we will take first steps towards workflow analysis from endoscopic data. Specifically, we will be recognizing the surgical phases, such as "Preparation" or "CalotTriangleDissection".

We will be working with an extended version of the dataset used in the classification and segmentation exercises. The extension allows us to observe the full workflow steps in every video.

This notebook has many code blocks already in place to help you get started. Places where you have to add your own code are clearly marked with "TASK" and lines ("-----"). When a variable you have to implement is used later on, we placed a name and description in the task bracket (see example below). These markings are only there to guide you toward what you have to implement to complete the exercise; feel free to experiment beyond them.






In [None]:
# TASK: description of the task you need to do ---------------------------------
# my_variable_name: a variable that is used later on, so the name should be right

# ------------------------------------------------------------------------------

When looking for solutions to the tasks below, definitely consider online resources such as the ones linked below:

Pytorch (torch) documentation: <br>
https://pytorch.org/docs/stable/index.html

Pytorch Vision (torchvision) documentation: <br>
https://pytorch.org/vision/stable/index.html

In [1]:
# download the dataset to your notebook
# if you get access denied, retry after a minute
!gdown 1MwtcHceqj8FchmsIR92pRb0tR-QZ9uh4
!unzip -qq liver_endoscopy_dataset_workflow.zip 
!rm liver_endoscopy_dataset_workflow.zip
# this block should take 1 min



In [None]:
!pip -qq install pytorch_lightning==1.6.2

We import the libraries that will be useful.

In [None]:
import json
import random
from collections import defaultdict
from pathlib import Path
from typing import Optional

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from sklearn.metrics import classification_report
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import BoundaryNorm, ListedColormap

## 0. Dataset
We load the liver endoscopy dataset here. This cell defines the video splits and a dataset class that returns either single frames or windows of frames, depending on the task.

In [None]:
video_splits = {'train': ['01', '02', '05', '13', '15', '18', '22'], 'val': ['08', '29', '50'], 'test': ['06', '10', '42']}


class LiverEndoscopy(Dataset):
    def __init__(self, split: str = 'train', balance_data: bool = False, temporal: bool = False,
                 pil_transform: Optional[transforms.Compose] = None, tensor_transform: Optional[transforms.Compose] = None, augmentation=None):
        assert split in ['train', 'val', 'test']
        self.split = split
        self.balance_data = balance_data
        self.temporal = temporal
        self.pil_transform = pil_transform
        self.tensor_transform = tensor_transform
        self.augmentation = augmentation

        self.phases_to_indices = {'Preparation': 0, 'CalotTriangleDissection': 1, 'ClippingCutting': 2, 'GallbladderDissection': 3, 'GallbladderPackaging': 4,
                                  'CleaningCoagulation': 5, 'GallbladderRetraction': 6}
        self.indices_to_phases = {value: key for key, value in self.phases_to_indices.items()}

        export_dataset_path = Path('data_workflow')
        self.images_path = export_dataset_path / 'images'
        with open(export_dataset_path / 'phase_annotations.json', 'r') as f:
            self.workflow_phase_annotations = json.load(f)

        if not temporal:
            self.image_names = []
            for image_path in sorted(self.images_path.glob('*.png')):
                video_id = image_path.name.split('_')[0].replace('video', '')
                if video_id in video_splits[split]:
                    self.image_names.append(image_path.name.replace('.png', ''))
            self.image_names = sorted(self.image_names)

        else:
            self.window_size = 8
            self.downsample_factor = 25
            all_image_names = sorted(self.images_path.glob('*.png'))
            all_video_ids = {image_name.name.split('_')[0].replace('video', '') for image_name in all_image_names}
            all_split_video_ids = {video_id for video_id in all_video_ids if video_id in video_splits[split]}
            self.windows = []
            for video_id in all_split_video_ids:
                sequence_images = sorted(self.images_path.glob(f'video{video_id}_*.png'))
                sequence_image_indices = [int(image_name.name.split('_')[1].replace('.png', '')) for image_name in sequence_images]
                for i in range(len(sequence_images) - self.window_size + 1):
                    self.windows.append((video_id, sequence_image_indices[i:i + self.window_size]))

        if balance_data:
            self.do_balance_data(temporal)

    def do_balance_data(self, temporal):
        print('Balancing data by oversampling under-represented classes...')
        class_to_samples = defaultdict(list)
        if not temporal:
            for image_name in self.image_names:
                label = self.workflow_phase_annotations[image_name]
                class_to_samples[label].append(image_name)
            max_number = max([len(elem) for elem in class_to_samples.values()])
            self.image_names = []
            for key, value in class_to_samples.items():
                if len(value) < max_number:
                    self.image_names += random.choices(value, k=max_number)
                else:
                    self.image_names += value
            random.shuffle(self.image_names)
        else:
            for video_id, window in self.windows:
                label = self.workflow_phase_annotations[f'video{video_id}_{str(window[-1]).zfill(6)}']
                class_to_samples[label].append((video_id, window))
            max_number = max([len(elem) for elem in class_to_samples.values()])
            self.windows = []
            for key, value in class_to_samples.items():
                if len(value) < max_number:
                    self.windows += random.choices(value, k=max_number)
                else:
                    self.windows += value
            random.shuffle(self.windows)

    def __len__(self):
        if self.temporal:
            return len(self.windows)
        else:
            return len(self.image_names)

    def phase_label_to_number(self, phase_label):
        return self.phases_to_indices[phase_label]

    def number_to_phase_label(self, phase_number):
        return self.indices_to_phases[phase_number]

    def __getitem__(self, index):
        if self.temporal:
            video_id, window = self.windows[index]
            image_names = []
            for frame_number in window:
                image_names.append(f'video{video_id}_{str(frame_number).zfill(6)}.png')

            phase = self.phase_label_to_number(self.workflow_phase_annotations[image_names[-1].replace('.png', '')])
            return {'image_names': image_names, 'phase': phase}
        else:
            image_name = self.image_names[index]
            image_path = self.images_path / f'{image_name}.png'
            image = Image.open(image_path)
            if self.pil_transform is not None:
                image = self.pil_transform(image)
            if self.augmentation is not None:
                image = self.augmentation(image)
            image_tensor = transforms.ToTensor()(image)
            if self.tensor_transform is not None:
                image_tensor = self.tensor_transform(image_tensor)

            phase = self.phase_label_to_number(self.workflow_phase_annotations[image_name])

            return {'image': image_tensor, 'phase': phase, 'image_name': image_name}
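
The `do_balance_data` method above oversamples under-represented classes by drawing samples with replacement until every class is as large as the biggest one. The following is a minimal standalone sketch of that idea on made-up toy data; the function name `oversample` and the toy frame names are hypothetical and not part of the dataset class.

```python
import random
from collections import Counter, defaultdict


def oversample(samples, labels, seed=0):
    """Return shuffled (sample, label) pairs in which every class is as large as the largest class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for sample, label in zip(samples, labels):
        by_class[label].append(sample)

    target = max(len(group) for group in by_class.values())
    balanced = []
    for label, group in by_class.items():
        if len(group) < target:
            # Draw with replacement until the class reaches the target size.
            balanced += [(s, label) for s in rng.choices(group, k=target)]
        else:
            balanced += [(s, label) for s in group]
    rng.shuffle(balanced)
    return balanced


# Hypothetical toy data: 7 frames of one phase, 3 of another.
toy_samples = [f'frame_{i}' for i in range(10)]
toy_labels = ['Preparation'] * 7 + ['ClippingCutting'] * 3
balanced = oversample(toy_samples, toy_labels)
counts = Counter(label for _, label in balanced)
print(counts)
```

Because the minority class is redrawn with replacement, some of its original samples may appear several times and others not at all.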


## A. Workflow Recognition without Temporal Modelling

First, we will try to recognize the different phases of the surgical workflow from each frame of the video. This is equivalent to a multi-class classification problem on each frame.

### A.1 Load the data as images with phase labels


In [None]:
# Fix pytorch_lightning seed
pl.seed_everything(42)

# load data
pil_transform = transforms.Compose([transforms.Resize((224, 224))])
augmentation = transforms.Compose([transforms.RandomHorizontalFlip(),
                                    transforms.ColorJitter(brightness=.1, hue=.1),
                                    transforms.RandomRotation(degrees=(0, 30)),
                                    transforms.RandomResizedCrop((224, 224), scale=(0.7, 1.0))])

train_dataset = LiverEndoscopy(split='train', balance_data=False, temporal=False, pil_transform=None, augmentation=augmentation)
val_dataset = LiverEndoscopy(split='val', balance_data=False, temporal=False, pil_transform=pil_transform)
test_dataset = LiverEndoscopy(split='test', balance_data=False, temporal=False, pil_transform=pil_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)

#### Explore the data

In [None]:
# TASK: Load a random sample from the dataset, visualize the image and its phase label


# ------------------------------------------------------------------------------

### A.2 Load and train a workflow recognition model
We already provide most of the code skeleton in this part. To finish this model, you will need to make the following modifications:
1. Load an image processing model (resnet18, pretrained=True), which you can find in torchvision.
2. Disable the final layer (fc) of the image model.
3. Define a new linear layer that takes the output of the image model as input and outputs 7 nodes. These correspond to the 7 classes.
4. Define the forward function, which uses the image model and the linear layer.
5. Complete the training, validation and test steps, where the model's forward function is called to get a prediction, and then a cross-entropy loss between the correct phase and the prediction is computed. Tip: F.cross_entropy()
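
To make step 5 more concrete, the sketch below computes by hand what a cross-entropy loss over 7 class scores amounts to: a log-softmax over the raw scores (logits), followed by the negative log-probability of the correct class, averaged over the batch (the default `mean` reduction of `F.cross_entropy`). The logits and targets are made-up values, and the helper names are hypothetical; in the exercise itself you would call `F.cross_entropy` directly.

```python
import math


def cross_entropy(logits, targets):
    """Mean negative log-likelihood of the target classes under softmax(logits)."""
    losses = []
    for row, target in zip(logits, targets):
        # Numerically stable log-sum-exp of the scores.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        # -log softmax(row)[target]
        losses.append(log_sum_exp - row[target])
    return sum(losses) / len(losses)


def predicted_class(row):
    """Index of the largest score, i.e. the predicted phase."""
    return max(range(len(row)), key=lambda i: row[i])


# An undecided model: all 7 phases get the same score.
uniform = [[0.0] * 7]
loss_uniform = cross_entropy(uniform, [3])

# A confident model that strongly favours phase 2.
confident = [[0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0]]
loss_confident_right = cross_entropy(confident, [2])
loss_confident_wrong = cross_entropy(confident, [5])

print(loss_uniform, loss_confident_right, loss_confident_wrong)
print(predicted_class(confident[0]))
```

An undecided model over 7 classes has a loss of log(7), which is a useful reference value when you look at your loss curve early in training.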


In [None]:
class ModelWrapper(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # TASK: define image model, disable final layer, add linear layer ------


        # ----------------------------------------------------------------------       

        self.train_preds = []
        self.train_gts = []
        self.val_preds = []
        self.val_gts = []
        self.test_preds = []
        self.test_gts = []
        self.reset_metrics()

        self.train_loss = []
        self.val_loss = []
        self.test_loss = []

        self.phase_names = ['Preparation', 'CalotTriangleDissection', 'ClippingCutting', 'GallbladderDissection', 'GallbladderPackaging', 'CleaningCoagulation',
                            'GallbladderRetraction']

    def forward(self, x):
        # TASK: use the image model and linear layer to get a prediction -------


        # ----------------------------------------------------------------------
        return x

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # TASK: get a prediction from the model and calculate cross_entropy loss.
        # y_hat: prediction from the model
        # loss: calculated loss from the model

        # ----------------------------------------------------------------------       
        self.update_metrics(batch['phase'], y_hat, split='train')
        self.train_loss.append(loss.item())
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        # validation_step defines the validation loop.
        # TASK: get a prediction from the model and calculate cross_entropy loss.
        # y_hat: prediction from the model
        # loss: calculated loss from the model

        # ----------------------------------------------------------------------
        self.update_metrics(batch['phase'], y_hat, split='val')
        self.val_loss.append(loss.item())
        return {'val_loss': loss}

    def test_step(self, batch, batch_idx):
        # test_step defines the test loop.
        # TASK: get a prediction from the model and calculate cross_entropy loss.
        # y_hat: prediction from the model
        # loss: calculated loss from the model

        # ----------------------------------------------------------------------
        self.update_metrics(batch['phase'], y_hat, split='test')
        self.test_loss.append(loss.item())
        return {'test_loss': loss}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer

    def reset_metrics(self, split=None):
        if split == 'train':
            self.train_preds = []
            self.train_gts = []
        elif split == 'val':
            self.val_preds = []
            self.val_gts = []
        elif split == 'test':
            self.test_preds = []
            self.test_gts = []
        else:
            self.train_preds = []
            self.train_gts = []
            self.val_preds = []
            self.val_gts = []
            self.test_preds = []
            self.test_gts = []

    def update_metrics(self, gt, pred, split='train'):
        if split == 'train':
            self.train_preds.extend(pred.detach().cpu().numpy().argmax(1))
            self.train_gts.extend(gt.detach().cpu().numpy())
        elif split == 'val':
            self.val_preds.extend(pred.detach().cpu().numpy().argmax(1))
            self.val_gts.extend(gt.detach().cpu().numpy())
        elif split == 'test':
            self.test_preds.extend(pred.detach().cpu().numpy().argmax(1))
            self.test_gts.extend(gt.detach().cpu().numpy())
        else:
            raise NotImplementedError()

    def training_epoch_end(self, outputs):
        self.evaluate_predictions(split='train')
        self.reset_metrics(split='train')

    def validation_epoch_end(self, outputs):
        self.evaluate_predictions(split='val')
        self.reset_metrics(split='val')
    
    def test_epoch_end(self, outputs):
        self.evaluate_predictions(split='test')
        self.reset_metrics(split='test')

    def evaluate_predictions(self, split):
        if split == 'train':
            preds = self.train_preds
            gts = self.train_gts
        elif split == 'val':
            preds = self.val_preds
            gts = self.val_gts
        elif split == 'test':
            preds = self.test_preds
            gts = self.test_gts
        else:
            raise NotImplementedError()

        cls_report = classification_report(gts, preds, labels=list(range(len(self.phase_names))),
                                           target_names=self.phase_names)
        print(split)
        print(cls_report)

#### Create the model

In [None]:
# TASK: create the model -------------------------------------------------------
# model: your model


# ------------------------------------------------------------------------------

#### Train the model

In [None]:
# TASK: train the model --------------------------------------------------------
# TIP: use pl.Trainer, make sure to use the gpu and train for 4 epochs, -> max_epochs=4. Make sure to call the fit function with both the train and validation loaders to get correct evaluations during training
# TIP: Training should take around 10 mins

# trainer: your trainer

# ------------------------------------------------------------------------------

#### Save the model!

Training a neural network takes some time. If you don't want to lose this progress, e.g. because you need to take a break from this exercise, make sure to save the model with the code below. Then download the created file '\<myModel\>.pt'!

If you want to continue with your model later, you can use the code provided below to load your model after uploading it to this notebook.

In [None]:
# save model
torch.save(model.state_dict(), 'model_state_dict.pt')

# load model (instead of training if you have a saved model)
# model.load_state_dict(torch.load('model_state_dict.pt'))

### A.3 Evaluate the results
#### Test the model on the unseen test set



In [None]:
# TASK: Test the model on the unseen test set ----------------------------------


# ------------------------------------------------------------------------------

#### Plot the loss to see the training progress

In [None]:
# TASK: plot loss --------------------------------------------------------------


# ------------------------------------------------------------------------------

#### Visualize Predictions

Similar to before, get a random sample from the test_dataset, run it through the model in evaluation mode to get a prediction, then visualize the image, the prediction and the ground truth. Get an example for every phase from the ground truth and look at the image and network prediction.

In [None]:
# TASK: visualize prediction results -------------------------------------------


# ------------------------------------------------------------------------------

#### Visualize the predictions for a whole video

To see a complete predicted workflow we want to plot the predictions for a whole video as a sequence. To achieve that we first have to get the predictions for a whole video in sequence from the model and then plot them with the ground truth for comparison.
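
When comparing a predicted sequence with the ground truth, it can also help to compress the per-frame phase indices into contiguous segments. The following is a minimal sketch on a made-up sequence; `to_segments` is a hypothetical helper and not part of the notebook.

```python
from itertools import groupby


def to_segments(frame_phases):
    """Compress a per-frame phase sequence into (phase, first_frame, last_frame) runs."""
    segments = []
    start = 0
    for phase, run in groupby(frame_phases):
        length = len(list(run))
        segments.append((phase, start, start + length - 1))
        start += length
    return segments


# Hypothetical per-frame predictions with one isolated misprediction (phase 3 at frame 5).
toy_preds = [0, 0, 0, 1, 1, 3, 1, 1, 2, 2]
segments = to_segments(toy_preds)
for phase, first, last in segments:
    print(f'phase {phase}: frames {first}-{last}')
```

Short, isolated segments like the single frame of phase 3 here are typical for a frame-wise model that has no temporal context.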

In [None]:
device = torch.device("cuda")
# TASK: set up the model for evaluation ----------------------------------------
# TIP: freeze the model, set to eval mode, transfer to gpu


# ------------------------------------------------------------------------------

video_name = "video06" # "video10" , "video42"
video_preds = []
video_labels = []
for batch in tqdm(test_loader):
  if batch['image_name'][0].startswith(video_name):
    with torch.no_grad():
      # TASK: transfer the input to the gpu and evaluate with your model -------
      # batch_preds: Predictions of your model for the current batch


      # ------------------------------------------------------------------------
      for image_name, pred, label in zip(batch['image_name'], batch_preds, batch['phase']):
          if image_name.startswith(video_name):
            video_preds.append(pred.detach().cpu().numpy().argmax())
            video_labels.append(label.detach().cpu().numpy())

In [None]:
# Plot the predicted workflow
fig = plt.figure()
cmap= ListedColormap(sns.color_palette("muted", as_cmap=True))
levels= list(range(len(model.phase_names) + 1))
norm = BoundaryNorm(levels, ncolors=cmap.N, clip=True)

ax1 = fig.add_axes([0, 0.4, 2, 0.3])
barprops1 = dict(aspect='auto', cmap=cmap, norm=norm, interpolation='nearest')
im1 = ax1.imshow(np.array(video_preds).reshape(1,-1), **barprops1)
ax1.set_axis_off()
ax1.set_title('Prediction')

ax2 = fig.add_axes([0, 0, 2, 0.3])
barprops2 = dict(aspect='auto', cmap=cmap, norm=norm, interpolation='nearest')
im2 = ax2.imshow(np.array(video_labels).reshape(1,-1), **barprops2)
ax2.set_axis_off()
ax2.set_title('Ground Truth')

cbar_ax = fig.add_axes([2.1, 0, 0.05, 0.7])
cbar = fig.colorbar(im1, cax=cbar_ax)
cbar.set_ticks([x + 0.5 for x in range(7)])
cbar.set_ticklabels(model.phase_names)
cbar.ax.invert_yaxis()

## B. Workflow Recognition with Temporal Modelling

Recognizing the current phase of the surgery from a single image can be very difficult, even for a human expert. Including the temporal context in your phase prediction can be very helpful, so we will try to create a machine learning model that can do just that.

### B.1 Load the data as video clips with a phase label

For this exercise we want to include the temporal context in each prediction. We accomplish this by looking at the surgical videos contained in the dataset not as individual frames, but as short video clips.
Load the data so that you have sequences of 8 <b>consecutive</b> frames and the phase label of the last frame of each such sequence.
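
The windowing done by the dataset class in temporal mode can be sketched in a few lines: every run of 8 consecutive extracted frames becomes one sample, and the sample is labelled with the phase of its last frame. The frame indices below are made up (the spacing of 25 is an assumption for illustration only), and `build_windows` is a hypothetical helper.

```python
def build_windows(frame_indices, window_size=8):
    """Return all runs of `window_size` consecutive entries of `frame_indices`."""
    return [frame_indices[i:i + window_size]
            for i in range(len(frame_indices) - window_size + 1)]


# Hypothetical video with 12 extracted frames.
frame_indices = list(range(0, 300, 25))
windows = build_windows(frame_indices)

# The label of each window is taken from its last frame.
label_frames = [window[-1] for window in windows]

print(len(frame_indices), 'frames ->', len(windows), 'windows')
print('first window:', windows[0])
print('frames that provide labels:', label_frames)
```

A video with n frames yields n - 8 + 1 windows, so the first 7 frames of every video never act as the labelled frame.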

In [None]:
# load data
train_dataset_temp = LiverEndoscopy(split='train', balance_data=False, temporal=True)
val_dataset_temp = LiverEndoscopy(split='val', balance_data=False, temporal=True)
test_dataset_temp = LiverEndoscopy(split='test', balance_data=False, temporal=True)
train_loader_temp = DataLoader(train_dataset_temp, batch_size=32, shuffle=True, num_workers=2)
val_loader_temp = DataLoader(val_dataset_temp, batch_size=32, shuffle=False, num_workers=2)
test_loader_temp = DataLoader(test_dataset_temp, batch_size=32, shuffle=False, num_workers=2)

### B.2 Extract Image Features

We will use the model you previously trained as a feature extractor. This means we will run every image in all the datasets through the model and get the image features (not the phase predictions).
TASK: To this end, you first need to disable the linear layer of the model.
Make sure to set it to evaluation mode, transfer it to the GPU, and freeze it as well.

Then we create a dictionary called all_features, where the keys will be the image names, and the values will be the corresponding image features.
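
The general pattern can be sketched on a small stand-in network. Everything in this block is hypothetical: `TinyClassifier`, its `backbone` and `head` attributes, and the random input batch only mimic the structure of the real model and data loader, and the sketch stays on the CPU so that it runs anywhere.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for the trained model: a backbone followed by a linear head."""

    def __init__(self, feature_dim=512, num_classes=7):
        super().__init__()
        self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, feature_dim), nn.ReLU())
        self.head = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        return self.head(self.backbone(x))


torch.manual_seed(0)
toy_model = TinyClassifier()

# "Disable" the classification head so the forward pass returns features.
toy_model.head = nn.Identity()
# Evaluation mode and frozen parameters: no dropout/batch-norm updates, no gradients.
toy_model.eval()
for parameter in toy_model.parameters():
    parameter.requires_grad_(False)

# Hypothetical batch in the same dictionary layout as the non-temporal dataset.
toy_batch = {'image': torch.rand(4, 3, 8, 8),
             'image_name': [f'video01_{i:06d}' for i in range(4)]}

toy_features = {}
with torch.no_grad():
    features = toy_model(toy_batch['image'])
for name, feature in zip(toy_batch['image_name'], features):
    # Store features on the CPU so the dictionary does not occupy GPU memory.
    toy_features[name] = feature.cpu()

print(len(toy_features), toy_features['video01_000000'].shape)
```

Replacing a layer with `nn.Identity()` is one common way to "disable" it; the rest of the forward pass stays unchanged.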

In [None]:
# TASK: set model up to extract features ---------------------------------------
# TIP: freeze the model, set to eval mode, transfer to gpu and disable the linear layer


# ------------------------------------------------------------------------------

# recreate the train dataset without augmentations
train_dataset = LiverEndoscopy(split='train', balance_data=False, temporal=False, pil_transform=pil_transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

# TASK: iterate through the data, extract features with your model, write to all_features
# TIP: feature extraction should take 5 min to run
# TIP: extract features for all sets (train, val, test)
all_features = {} # Keys: image names, Values: corresponding image features



# ------------------------------------------------------------------------------

In [None]:
# Simple test for checking the validity of all_features
assert all_features['video01_000276'].shape[0] == 512

### B.3 Simple Temporal model

Now you will design a temporal model that takes the extracted image features as input and predicts the surgical phase of the last frame in the sequence.

Many types of models can be used here. We will go for a very simple option: we concatenate the image features of 8 neighbouring images and feed them through 2 linear layers to make a final prediction over the 7 phases again.
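
The tensor shapes involved can be checked with random stand-in features. This is only a sketch under stated assumptions: the feature values are random, the batch size of 4 is arbitrary, and the ReLU between the two linear layers is an assumption, since the description above only fixes the layer sizes.

```python
import torch
from torch import nn

window_size, feature_dim, num_phases = 8, 512, 7
batch_size = 4

# Stand-in for the precomputed features: one (batch, feature) tensor per frame position.
per_frame = [torch.rand(batch_size, feature_dim) for _ in range(window_size)]

# Stacking gives (window, batch, feature); the transpose turns it into (batch, window, feature).
x = torch.stack(per_frame).transpose(0, 1)

# Two linear layers; the ReLU in between is an assumption of this sketch.
mlp = nn.Sequential(
    nn.Linear(window_size * feature_dim, 256),
    nn.ReLU(),
    nn.Linear(256, num_phases),
)

# Concatenate the 8 feature vectors of each window into a single 4096-dimensional vector.
# reshape (rather than view) is used because x is not contiguous after the transpose.
flat = x.reshape(batch_size, -1)
logits = mlp(flat)

print(x.shape, flat.shape, logits.shape)
```

After flattening, the first 512 entries of each row are the features of the first frame in the window, the next 512 those of the second frame, and so on.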

In [None]:
# build mlp
class TemporalModelWrapper(pl.LightningModule):
    def __init__(self, features):
        super().__init__()
        # TASK: Define your model layers ---------------------------------------
        # TIP: First layer: Input size: 512x8 -> output 256. 
        # TIP: Second Layer: Input size: 256 -> output 7


        # ----------------------------------------------------------------------
        self.features = features

        self.train_preds = []
        self.train_gts = []
        self.val_preds = []
        self.val_gts = []
        self.test_preds = []
        self.test_gts = []
        self.reset_metrics()

        self.train_loss = []
        self.val_loss = []
        self.test_loss = []

        self.phase_names = ['Preparation', 'CalotTriangleDissection', 'ClippingCutting', 'GallbladderDissection', 'GallbladderPackaging', 'CleaningCoagulation',
                            'GallbladderRetraction']

    def forward(self, x):
        # TASK: define your network --------------------------------------------
        # TIP: x has the shape Batch_size x 8 x 512 -> You need to reshape it to batch_size x 4096 and then pass it through the fully connected layers
        

        # ----------------------------------------------------------------------
        return x

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # We get the pre computed image features
        all_image_features = []
        for image_names in batch['image_names']:
            batch_image_features = [self.features[image_name.replace('.png', '')] for image_name in image_names]
            all_image_features.append(torch.stack(batch_image_features))
        all_image_features = torch.stack(all_image_features).transpose(0, 1)
        all_image_features = all_image_features.to(self.device)
        # TASK: get a prediction from the model and calculate cross_entropy loss.
        # y_hat: prediction from the model
        # loss: calculated loss from the model

        # ----------------------------------------------------------------------
        self.update_metrics(batch['phase'], y_hat, split='train')
        self.train_loss.append(loss.item())
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        # validation_step defines the validation loop.
        all_image_features = []
        for image_names in batch['image_names']:
            batch_image_features = [self.features[image_name.replace('.png', '')] for image_name in image_names]
            all_image_features.append(torch.stack(batch_image_features))
        all_image_features = torch.stack(all_image_features).transpose(0, 1)
        all_image_features = all_image_features.to(self.device)
        # TASK: get a prediction from the model and calculate cross_entropy loss.
        # y_hat: prediction from the model
        # loss: calculated loss from the model

        # ----------------------------------------------------------------------
        self.update_metrics(batch['phase'], y_hat, split='val')
        self.val_loss.append(loss.item())
        return {'val_loss': loss}

    def test_step(self, batch, batch_idx):
        # test_step defines the test loop.
        all_image_features = []
        for image_names in batch['image_names']:
            batch_image_features = [self.features[image_name.replace('.png', '')] for image_name in image_names]
            all_image_features.append(torch.stack(batch_image_features))
        all_image_features = torch.stack(all_image_features).transpose(0, 1)
        all_image_features = all_image_features.to(self.device)
        # TASK: get a prediction from the model and calculate cross_entropy loss.
        # y_hat: prediction from the model
        # loss: calculated loss from the model

        # ----------------------------------------------------------------------
        self.update_metrics(batch['phase'], y_hat, split='test')
        self.test_loss.append(loss.item())
        return {'test_loss': loss}

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-5)
        return optimizer

    def reset_metrics(self, split=None):
        if split == 'train':
            self.train_preds = []
            self.train_gts = []
        elif split == 'val':
            self.val_preds = []
            self.val_gts = []
        elif split == 'test':
            self.test_preds = []
            self.test_gts = []
        else:
            self.train_preds = []
            self.train_gts = []
            self.val_preds = []
            self.val_gts = []
            self.test_preds = []
            self.test_gts = []

    def update_metrics(self, gt, pred, split='train'):
        if split == 'train':
            self.train_preds.extend(pred.detach().cpu().numpy().argmax(1))
            self.train_gts.extend(gt.detach().cpu().numpy())
        elif split == 'val':
            self.val_preds.extend(pred.detach().cpu().numpy().argmax(1))
            self.val_gts.extend(gt.detach().cpu().numpy())
        elif split == 'test':
            self.test_preds.extend(pred.detach().cpu().numpy().argmax(1))
            self.test_gts.extend(gt.detach().cpu().numpy())
        else:
            raise NotImplementedError()

    def training_epoch_end(self, outputs):
        self.evaluate_predictions(split='train')
        self.reset_metrics(split='train')

    def validation_epoch_end(self, outputs):
        self.evaluate_predictions(split='val')
        self.reset_metrics(split='val')

    def test_epoch_end(self, outputs):
        self.evaluate_predictions(split='test')
        self.reset_metrics(split='test')

    def evaluate_predictions(self, split):
        if split == 'train':
            preds = self.train_preds
            gts = self.train_gts
        elif split == 'val':
            preds = self.val_preds
            gts = self.val_gts
        elif split == 'test':
            preds = self.test_preds
            gts = self.test_gts
        else:
            raise NotImplementedError()

        cls_report = classification_report(gts, preds, labels=list(range(len(self.phase_names))),
                                           target_names=self.phase_names)
        print(split)
        print(cls_report)



#### Create and Train the Model

In [None]:
# TASK: create the temporal model and train it ---------------------------------
# TIP: again use pl.Trainer. Train for max_epochs=10 this time.
# temporal_model: your temporal model

# ------------------------------------------------------------------------------

### B.4 Evaluate the results

#### Test the model

If everything worked correctly, you should observe an improved performance when using temporal modeling.

In [None]:
# TASK: Test the model on the unseen test set ----------------------------------


# ------------------------------------------------------------------------------

#### Plot the loss


In [None]:
# TASK: plot loss --------------------------------------------------------------


# ------------------------------------------------------------------------------

#### Visualize the predictions for a whole video

To see a complete predicted workflow we want to plot the predictions for a whole video as a sequence like we did above. We can also plot the results from our normal model to compare against the temporal model.
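
One detail to keep in mind for this comparison: the temporal model predicts only for frames that have 7 predecessors, so its sequence is 7 entries shorter per video than a frame-wise sequence covering every frame. A minimal sketch of aligning the two, assuming the frame-wise list really contains one prediction per frame of the video; `align_to_windows` and the toy values are hypothetical.

```python
def align_to_windows(frame_preds, window_preds, window_size=8):
    """Drop the leading frame-wise predictions that have no temporal counterpart."""
    offset = window_size - 1
    trimmed = frame_preds[offset:]
    if len(trimmed) != len(window_preds):
        raise ValueError('sequences do not describe the same frames')
    return trimmed, window_preds


# Hypothetical video with 12 frames: 12 frame-wise and 12 - 8 + 1 = 5 temporal predictions.
toy_frame_preds = [0] * 5 + [1] * 7
toy_window_preds = [1] * 5

aligned_frame, aligned_window = align_to_windows(toy_frame_preds, toy_window_preds)
agreement = sum(a == b for a, b in zip(aligned_frame, aligned_window)) / len(aligned_window)
print(len(aligned_frame), len(aligned_window), agreement)
```

For plotting with `imshow` and `aspect='auto'` the length difference is barely visible, but it matters as soon as you compare the two sequences frame by frame.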

In [None]:
device = torch.device("cuda")
# TASK: set up the temporal model for evaluation -------------------------------
# TIP: freeze the model, set to eval mode, transfer to gpu


# ------------------------------------------------------------------------------
video_name = "video06" # "video10" , "video42"
video_preds_temp = []
video_labels = []
for batch in tqdm(test_loader_temp):
  if batch['image_names'][0][-1].startswith(video_name):
    with torch.no_grad():
      # Get the pre computed image features
      batch_image_features = []
      for image_names in batch['image_names']:
          image_features = [temporal_model.features[image_name.replace('.png', '')] for image_name in image_names]
          batch_image_features.append(torch.stack(image_features))
      batch_image_features = torch.stack(batch_image_features).transpose(0, 1)
      # TASK: transfer the input to the gpu and evaluate with your model -------
      # batch_preds: Predictions of your model for the current batch


      # ------------------------------------------------------------------------
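      for image_name, pred, label in zip(batch['image_names'][-1], batch_preds, batch['phase']):
        # keep only windows whose last (labelled) frame belongs to the selected video
        if image_name.startswith(video_name):
          video_preds_temp.append(pred.detach().cpu().numpy().argmax())
          video_labels.append(label.detach().cpu().numpy())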

In [None]:
fig = plt.figure()
cmap= ListedColormap(sns.color_palette("muted", as_cmap=True))
levels= list(range(len(model.phase_names) + 1))
norm = BoundaryNorm(levels, ncolors=cmap.N, clip=True)
barprops = dict(aspect='auto', cmap=cmap, norm=norm, interpolation='nearest')

ax1 = fig.add_axes([0, 0.8, 2, 0.3])
# predictions for the normal model are not computed again
im1 = ax1.imshow(np.array(video_preds).reshape(1,-1), **barprops)
ax1.set_axis_off()
ax1.set_title('Prediction')

ax2 = fig.add_axes([0, 0.4, 2, 0.3])
im2 = ax2.imshow(np.array(video_preds_temp).reshape(1,-1), **barprops)
ax2.set_axis_off()
ax2.set_title('Prediction with Temporal Model')

ax3 = fig.add_axes([0, 0, 2, 0.3])
im3 = ax3.imshow(np.array(video_labels).reshape(1,-1), **barprops)
ax3.set_axis_off()
ax3.set_title('Ground Truth')

cbar_ax = fig.add_axes([2.1, 0.25, 0.05, 0.7])
cbar = fig.colorbar(im1, cax=cbar_ax)
cbar.set_ticks([x + 0.5 for x in range(7)])
cbar.set_ticklabels(model.phase_names)
cbar.ax.invert_yaxis()