# Mount Google Drive

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Copy Files

In [3]:
import shutil
import os

os.makedirs('data', exist_ok=True)

####For PointCloud Files#####
shutil.copyfile('/content/drive/My Drive/AlzheimerStallCatcher3DConvPointCloud/point_cloud.zip' , './data/point_cloud.zip')
shutil.copyfile('/content/drive/My Drive/AlzheimerStallCatcher3DConvPointCloud/traintestlist.zip' , './data/traintestlist.zip')


!unzip data/point_cloud.zip;
!unzip data/traintestlist.zip;

os.remove("./data/point_cloud.zip")
os.remove('./data/traintestlist.zip')

shutil.move('./point_cloud', './data/point_cloud')

Archive:  data/point_cloud.zip
   creating: point_cloud/
  inflating: point_cloud/187558.h5   
  inflating: point_cloud/280599.h5   
  inflating: point_cloud/105452.h5   
  inflating: point_cloud/105668.h5   
  inflating: point_cloud/110123.h5   
  inflating: point_cloud/110210.h5   
  inflating: point_cloud/110488.h5   
  inflating: point_cloud/110498.h5   
  inflating: point_cloud/110544.h5   
  inflating: point_cloud/110787.h5   
  inflating: point_cloud/111150.h5   
  inflating: point_cloud/111350.h5   
  inflating: point_cloud/112237.h5   
  inflating: point_cloud/112324.h5   
  inflating: point_cloud/112532.h5   
  inflating: point_cloud/112558.h5   
  inflating: point_cloud/112736.h5   
  inflating: point_cloud/112804.h5   
  inflating: point_cloud/113202.h5   
  inflating: point_cloud/113351.h5   
  inflating: point_cloud/113382.h5   
  inflating: point_cloud/113389.h5   
  inflating: point_cloud/113850.h5   
  inflating: point_cloud/114732.h5   
  inflating: point_cloud/114843

'./data/point_cloud'

In [4]:
####For images files######

shutil.copyfile('/content/drive/My Drive/AlzheimerStallCatcher3DConvPointCloud/micro_frames_gray.zip' , './data/micro_frames.zip')

!unzip data/micro_frames.zip;

os.remove("./data/micro_frames.zip")

shutil.move('./micro_frames_gray', './data/micro_frames')

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 extracting: micro_frames_gray/667449/68.jpg  
 extracting: micro_frames_gray/667449/69.jpg  
 extracting: micro_frames_gray/667449/7.jpg  
 extracting: micro_frames_gray/667449/70.jpg  
 extracting: micro_frames_gray/667449/71.jpg  
 extracting: micro_frames_gray/667449/72.jpg  
 extracting: micro_frames_gray/667449/73.jpg  
 extracting: micro_frames_gray/667449/74.jpg  
 extracting: micro_frames_gray/667449/75.jpg  
 extracting: micro_frames_gray/667449/76.jpg  
 extracting: micro_frames_gray/667449/77.jpg  
 extracting: micro_frames_gray/667449/78.jpg  
 extracting: micro_frames_gray/667449/79.jpg  
 extracting: micro_frames_gray/667449/8.jpg  
 extracting: micro_frames_gray/667449/80.jpg  
 extracting: micro_frames_gray/667449/81.jpg  
 extracting: micro_frames_gray/667449/82.jpg  
 extracting: micro_frames_gray/667449/83.jpg  
 extracting: micro_frames_gray/667449/84.jpg  
 extracting: micro_frames_gray/667449/85.jpg

'./data/micro_frames'

# Import Libraries

In [3]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import glob
import gc

from sklearn.metrics import matthews_corrcoef as mcc

# PyTorch libraries and modules
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.utils.data


torch.manual_seed(100)


from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import itertools
import datetime

# Balanced Batch Sampler

In [4]:
import torchvision
import torch.utils.data
import random


class BalancedBatchSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, dataset):
        self.dataset = {}
        self.balanced_max = 0
        # Save all the indices for all the classes
        for idx in range(0, len(dataset)):
            label = self._get_label(dataset, idx)
            if label not in self.dataset:
                self.dataset[label] = []
            self.dataset[label].append(idx)
            self.balanced_max = len(self.dataset[label]) \
                if len(self.dataset[label]) > self.balanced_max else self.balanced_max
        
        # Oversample the classes with fewer elements than the max
        for label in self.dataset:
            while len(self.dataset[label]) < self.balanced_max:
                self.dataset[label].append(random.choice(self.dataset[label]))
    
        self.keys = list(self.dataset.keys())
        self.currentkey = 0

    def __iter__(self):
        while len(self.dataset[self.keys[self.currentkey]]) > 0:
            yield self.dataset[self.keys[self.currentkey]].pop()
            self.currentkey = (self.currentkey + 1) % len(self.keys)

    
    def _get_label(self, dataset, idx):
        dataset_type = type(dataset)
        if dataset_type is torchvision.datasets.MNIST:
            return dataset.train_labels[idx].item()
        elif dataset_type is torchvision.datasets.ImageFolder:
            return dataset.imgs[idx][1]
        else:
            (image_sequence, target) = dataset.__getitem__(idx)
            return target

    def __len__(self):
        return self.balanced_max*len(self.keys)

In [5]:
# Some data

batch_size = 32

split_number = 1
#num_epochs = 50

# **Dataset for PointCloud-Voxel**

### Voxel Dataset Class

In [6]:
import h5py
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd

class VoxelDataset(Dataset):
    def __init__(self, dataset_path, split_path, split_number, input_shape, training):
        self.training = training
        self.sequences, self.labels = self._extract_sequence_paths_and_labels(dataset_path, split_path, split_number,
                                                                              training)  # creating a list of directories where the extracted frames are saved
        self.label_names = ["Non-stalled", "Stalled"]  # Getting the label names or name of the class
        self.num_classes = len(self.label_names)  # Getting the number of class
        self.input_shape = input_shape

    def _extract_sequence_paths_and_labels(
            self, dataset_path, split_path="data/traintestlist", split_number=0, training=True
    ):
        """ Extracts paths to sequences given the specified train / test split """
        fn = f"fold_{split_number}_train.csv" if training else f"fold_{split_number}_test.csv"
        split_path = os.path.join(split_path, fn)
        df = pd.read_csv(split_path)
        file_name = df['filename'].values
        all_labels = df['class'].values
        sequence_paths = []
        classes = []
        for i, video_name in enumerate(file_name):
            seq_name = video_name.split(".mp4")[0]
            sequence_paths += [os.path.join(dataset_path, seq_name).replace('\\', '/')]
            classes += [all_labels[i]]
        return sequence_paths, classes

    def pc2voxel(self, cloud0, cloud1, cloud2, depth=32, height=64, width=64):

        voxel_grid = np.zeros((3, depth, height, width), dtype=np.float16)

        in_depth = max(np.max(cloud0[:, 0]), np.max(cloud1[:, 0]), np.max(cloud2[:, 0]))
        in_height = max(np.max(cloud0[:, 1]), np.max(cloud1[:, 1]), np.max(cloud2[:, 1]))
        in_width = max(np.max(cloud0[:, 2]), np.max(cloud1[:, 2]), np.max(cloud2[:, 2]))

        if in_depth >= depth:
            depth_ratio = depth / (in_depth + 1)
            cloud0[:, 0] = np.uint32(cloud0[:, 0].astype(float) * depth_ratio)
            cloud1[:, 0] = np.uint32(cloud1[:, 0].astype(float) * depth_ratio)
            cloud2[:, 0] = np.uint32(cloud2[:, 0].astype(float) * depth_ratio)
        if in_height >= height:
            height_ratio = height / (in_height + 1)
            cloud0[:, 1] = np.uint32(cloud0[:, 1].astype(float) * height_ratio)
            cloud1[:, 1] = np.uint32(cloud1[:, 1].astype(float) * height_ratio)
            cloud2[:, 1] = np.uint32(cloud2[:, 1].astype(float) * height_ratio)
        if in_width >= width:
            width_ratio = width / (in_width + 1)
            cloud0[:, 2] = np.uint32(cloud0[:, 2].astype(float) * width_ratio)
            cloud1[:, 2] = np.uint32(cloud1[:, 2].astype(float) * width_ratio)
            cloud2[:, 2] = np.uint32(cloud2[:, 2].astype(float) * width_ratio)

        voxel_grid[0, cloud0[:, 0], cloud0[:, 1], cloud0[:, 2]] = 1.0
        voxel_grid[1, cloud1[:, 0], cloud1[:, 1], cloud1[:, 2]] = 1.0
        voxel_grid[2, cloud2[:, 0], cloud2[:, 1], cloud2[:, 2]] = 1.0

        return voxel_grid

    def get_cloud(self, filename):
        depth = self.input_shape[0]
        height = self.input_shape[1]
        width = self.input_shape[2]

        hf = h5py.File(filename, 'r')
        c1 = hf['cloud1'][:]
        c2 = hf['cloud2'][:]
        c3 = hf['cloud3'][:]
        hf.close()

        X = self.pc2voxel(c1, c2, c3, depth=depth, height=height, width=width)
        X = torch.from_numpy(X).float()
        return X

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, index):
        sequence_path = self.sequences[index % len(self)]
        target = self.labels[index % len(self)]

        voxels = self.get_cloud(sequence_path + ".h5")

        return voxels, target

### Create Train and Test Datasets of Pt Cloud

In [9]:
import gc
import time

start = time.time()

dataset_path = 'data/point_cloud'
split_path = 'traintestlist'

checkpoint_model = ''

voxel_shape = [32, 64, 64]


# Define training set
train_dataset_vox = VoxelDataset(
    dataset_path=dataset_path,
    split_path=split_path,
    split_number=split_number,
    input_shape=voxel_shape,
    training=True,
)
train_dataloader_vox = DataLoader(train_dataset_vox, batch_size= batch_size,sampler=BalancedBatchSampler(train_dataset_vox),shuffle=False, num_workers=4)
# Define test set
test_dataset_vox = VoxelDataset(
    dataset_path=dataset_path,
    split_path=split_path,
    split_number=split_number,
    input_shape=voxel_shape,
    training=False,
)
#test_dataloader_vox = DataLoader(test_dataset_vox, sampler=BalancedBatchSampler(test_dataset_vox), batch_size=batch_size, shuffle=False, num_workers=4)
test_dataloader_vox = DataLoader(test_dataset_vox, batch_size=batch_size, shuffle=False, num_workers=4)

endtime = time.time()

print("Elapsed time : " + str(endtime-start))

gc.collect() 

Elapsed time : 5.681003093719482


24

# **Dataset for Images**

### Image Dataset Class

In [10]:
import glob
import random
import os
import numpy as np
import torch
import pandas as pd

from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms


class ImageDataset(Dataset):
    def __init__(self, dataset_path, split_path, split_number, input_shape, sequence_length, training):
        self.training = training
        self.sequences, self.labels = self._extract_sequence_paths_and_labels(dataset_path, split_path, split_number,
                                                                              training)  # creating a list of directories where the extracted frames are saved
        self.sequence_length = int(
            sequence_length)  # Defining how many frames should be taken per video for training and testing
        self.label_names = ["Non-stalled", "Stalled"]  # Getting the label names or name of the class
        self.num_classes = len(self.label_names)  # Getting the number of class
        self.input_shape = input_shape
        self.transform = transforms.Compose(
            [
                transforms.Grayscale(num_output_channels=3),
                transforms.Resize(input_shape[-2:], Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ]
        )  # This is to transform the datasets to same sizes, it's basically resizing -> converting the image to Tensor image -> then normalizing the image -> composing all the transformation in a single image

    def _extract_sequence_paths_and_labels(
            self, dataset_path, split_path="data/traintestlist", split_number=0, training=True
    ):
        """ Extracts paths to sequences given the specified train / test split """
        fn = f"fold_{split_number}_train.csv" if training else f"fold_{split_number}_test.csv"
        split_path = os.path.join(split_path, fn)
        df = pd.read_csv(split_path)
        file_name = df['filename'].values
        all_labels = df['class'].values
        sequence_paths = []
        classes = []
        for i, video_name in enumerate(file_name):
            seq_name = video_name.split(".mp4")[0]
            sequence_paths += [os.path.join(dataset_path, seq_name).replace('\\', '/')]
            classes += [all_labels[i]]
        return sequence_paths, classes

    def _frame_number(self, image_path):
        """ Extracts frame number from filepath """
        image_path = image_path.replace('\\', '/')
        try:
            return int(image_path.split('/')[-1].split('.jpg')[0])
        except:
            print("Got error while getting image number ....")
            exit()

    def _pad_to_length(self, sequence, path):
        """ Pads the video frames to the required sequence length for small videos"""
        try:
            left_pad = sequence[0]
        except:
            print("Got error while padding ....")
            exit()
        if self.sequence_length is not None:
            while len(sequence) < self.sequence_length:
                sequence.insert(0, left_pad)
        return sequence

    def __getitem__(self, index):
        sequence_path = self.sequences[index % len(self)]
        target = self.labels[index % len(self)]
        # Sort frame sequence based on frame number 
        image_paths = sorted(glob.glob(sequence_path + '/*.jpg'), key=lambda path: self._frame_number(path))

        # Pad frames of videos shorter than `self.sequence_length` to length

        image_paths = self._pad_to_length(image_paths, sequence_path)
        total_image = len(image_paths)
        if total_image >= self.sequence_length and total_image < (
                self.sequence_length + int(self.sequence_length // 2)):
            midpoint = (total_image // 2)
            sample_interval = 1
            start_i = (midpoint - (self.sequence_length // 2))
            end_i = start_i + self.sequence_length
        elif total_image >= (self.sequence_length + int(self.sequence_length // 2)):
            midpoint = (total_image // 2)
            sample_interval = 1
            start_i = (midpoint - (self.sequence_length // 2) + int((self.sequence_length // 2) // 2)) - 1
            end_i = start_i + self.sequence_length
        else:
            start_i = 0
            end_i = total_image
            sample_interval = 1
        # flip = np.random.random() < 0.5
        # Extract frames as tensors
        image_sequence = []
        for i in range(start_i, end_i, sample_interval):
            if self.sequence_length is None or len(image_sequence) < self.sequence_length:
                img = Image.open(image_paths[i])
                image_tensor = self.transform(img)
                # if flip:
                #     image_tensor = torch.flip(image_tensor, (-1,))
                image_sequence.append(image_tensor)
        image_sequence = torch.stack(image_sequence)
        image_sequence = image_sequence.view(3, self.sequence_length, self.input_shape[-2], self.input_shape[-2])
        return image_sequence, target

    def __len__(self):
        return len(self.sequences)
        

### Create Train and Test Datasets of Images

In [11]:
import gc
import time


start = time.time()

dataset_path = 'data/micro_frames'
split_path = 'traintestlist'
sequence_length=40
img_dim = 112
channels = 3
latent_dim = 512
checkpoint_model = ''

image_shape = (channels, img_dim, img_dim)


# Define training set
train_dataset_img = ImageDataset(
    dataset_path=dataset_path,
    split_path=split_path,
    sequence_length=sequence_length,
    split_number=split_number,
    input_shape=image_shape,
    training=True,
)
train_dataloader_img = DataLoader(train_dataset_img, batch_size= batch_size,sampler=BalancedBatchSampler(train_dataset_img),shuffle=False, num_workers=4)

# Define test set
test_dataset_img = ImageDataset(
    dataset_path=dataset_path,
    split_path=split_path,
    split_number=split_number,
    sequence_length=sequence_length,
    input_shape=image_shape,
    training=False,
)

test_dataloader_img = DataLoader(test_dataset_img, batch_size=batch_size, shuffle=False, num_workers=4)

endtime = time.time()

print("Elapsed time : " + str(endtime-start))

gc.collect() 

Elapsed time : 66.65416693687439


0

In [12]:
#length should be same for image and voxel data
print(f"Length of image loader {len(train_dataloader_img)}, Length of pt cld loader {len(train_dataloader_vox)}")
print(f"Length of test image loader {len(test_dataloader_img)}, Length of test pt cld loader {len(test_dataloader_vox)}")

Length of image loader 80, Length of pt cld loader 80
Length of test image loader 19, Length of test pt cld loader 19


# **Multimodal Model**

In [13]:
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.video import r2plus1d_18
from torchvision.models.video import r3d_18

import os
import sys



##############################
#     Encoder for Image
##############################

class Encoder(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder, self).__init__()
        resnet = r2plus1d_18(pretrained=True)
        self.dropout1 = nn.Dropout(0.2)
        self.feature_extractor = nn.Sequential(*list(resnet.children())[0:3])
        self.feature_extractor_new1 = nn.Sequential(*list(resnet.children())[3:4])
        self.feature_extractor_new2 = nn.Sequential(*list(resnet.children())[4:5])
        self.feature_extractor_new3 = nn.Sequential(*list(resnet.children())[5:6])
        self.final = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(resnet.fc.in_features, latent_dim),
            nn.BatchNorm1d(latent_dim, momentum=0.01)
        )
    

    def forward(self, x):
        with torch.no_grad():
           x = self.feature_extractor(x)
        x = self.dropout1(x)  
        x = self.feature_extractor_new1(x)
        x = self.dropout1(x)
        x = self.feature_extractor_new2(x)
        x = self.dropout1(x)
        x = self.feature_extractor_new3(x)
        x = x.view(x.size(0), -1)
        #print(self.final(x))
        
        return self.final(x)


##############################
#   Encoder For Point Cloud
##############################

class Encoder_pt(nn.Module):
    def __init__(self, latent_dim):
        super(Encoder_pt, self).__init__()
        self.feature_extractor_pt = r3d_18(pretrained = True)
        self.final_pt = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.feature_extractor_pt.fc.out_features, latent_dim),
            nn.BatchNorm1d(latent_dim, momentum=0.01)
        )
    

    def forward(self, x):
        #with torch.no_grad():
        x = self.feature_extractor_pt(x)
        x = x.view(x.size(0), -1)

        return self.final_pt(x)


##############################
#      MultiModal Model
##############################

#dim=-1 is the right most dimension

class Multimodal(nn.Module):
    def __init__(
        self, num_classes, latent_dim=512, hidden_dim=1024
    ):
        super(Multimodal, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.encoder_pt = Encoder_pt(latent_dim)
        hidden_dim2 = int(hidden_dim/2)      #if hidden_dim = 512 then dense layers : 1024->512->256->1
        hidden_dim3 = int(hidden_dim2/4)
        hidden_dim4 = int(hidden_dim3/4)
        hidden_dim5 = int(hidden_dim4/4)
        self.output_layers_final = nn.Sequential(
            nn.Linear(1024, hidden_dim),
            nn.BatchNorm1d(hidden_dim, momentum=0.01),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(hidden_dim, hidden_dim2),
            nn.BatchNorm1d(hidden_dim2, momentum=0.01),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(hidden_dim2, 2),
            #nn.Sigmoid()
        )


    ##Concatenating two stream
    
    def forward(self, x, y):
        
        x = self.encoder(x)
        y = self.encoder_pt(y)
        #print(f"img tensor size {x.shape} and pt tensor size {y.shape}")
        x = torch.cat((x, y), 1)
        #print(f"concatenated size is {x.shape}")
        return self.output_layers_final(x)

# **Model Parameters for Image model**

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Classification criterion
#cls_criterion = nn.BCEWithLogitsLoss().to(device)
cls_criterion = nn.CrossEntropyLoss().to(device) 

# Define network
model_img = Multimodal(
    num_classes=2,
    latent_dim=latent_dim,
    hidden_dim=512
)

model_img = model_img.to(device)

#for half precision

model_img.half()  
for layer in model_img.modules():
  if isinstance(layer, nn.BatchNorm2d):
    layer.float()

# Add weights from checkpoint model if specified
if checkpoint_model:
    model_img.load_state_dict(torch.load(checkpoint_model))

learning_rate = 2e-4

optimizer = torch.optim.Adam(model_img.parameters(), lr= learning_rate, eps=1e-04, weight_decay=1e-04)

step_size = 11
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.75, last_epoch=-1)


# Confusion Function 

In [15]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels
from sklearn.metrics import matthews_corrcoef as mcc


MCC_SCORE = 0


def plot_confusion_matrix(y_true, y_pred, classes,
                          normalize=False,
                          title=None,
                          cmap=plt.cm.RdYlGn):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if not title:
        if normalize:
            title = 'Normalized confusion matrix'
        else:
            title = 'Confusion matrix, without normalization'

    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    # Only use the labels that appear in the data
    classes = classes[unique_labels(y_true, y_pred)]

    del(y_pred)
    del(y_true)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    fig, ax = plt.subplots(figsize=(12, 12))
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)
    # We want to show all ticks...
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           # ... and label them with the respective list entries
           xticklabels=classes, yticklabels=classes,
           title=title,
           ylabel='True label',
           xlabel='Predicted label')

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")
    fig.tight_layout()
    return ax

In [16]:

def test_model(epoch):
    """ Evaluate the model on the test set """
    print("")
    #defining numpy array to store the true labels and predicted labels
    y_true = np.array([])
    y_pred = np.array([])

    # Preparing the model for evaluation
    model_img.eval()
    test_metrics = {"loss": [], "acc": []}
    for batch_i, ((X1, y), (X2,y2)) in enumerate(zip(test_dataloader_img, test_dataloader_vox)):
        image_sequences = Variable(X1.to(device), requires_grad=False)
        pt_cloud = Variable(X2.to(device), requires_grad=False)

        #print(f"From image loader {y} and pt loader {y2}")  
        #print(f"From image loader {y} and pt loader {y2}")
      
        labels = Variable(y.to(device), requires_grad=False)
        
        
        image_sequences = image_sequences.half()
        pt_cloud = pt_cloud.half()
        
        # Get sequence predictions
        
        y_true = np.append(y_true, labels.cpu().numpy())
        
        with torch.no_grad():
            # Get sequence predictions
            predictions = model_img(image_sequences, pt_cloud)

            predicted = torch.max(predictions.data, 1)[1]   

            y_pred = np.append(y_pred, predicted.cpu().numpy())
        
        # Compute metrics
        acc = 100 * (predicted == labels).cpu().numpy().mean()
        
        #for bcewithlogits loss
        #acc = 100 * (acc.detach() == labels).cpu().numpy().mean()
        #labels = labels.view(-1,1)
        
        loss =   cls_criterion(predictions, labels).item()
        
        # Keep track of loss and accuracy
        test_metrics["loss"].append(loss)
        test_metrics["acc"].append(acc)
        
        # Log test performance
        sys.stdout.write(
            "\rTesting -- [Batch %d/%d] [Loss: %f (%f), Acc: %.2f%% (%.2f%%)]"
            % (
                batch_i,
                len(test_dataloader_img),
                loss,
                np.mean(test_metrics["loss"]),
                acc,
                np.mean(test_metrics["acc"]),
            )
        )
        
    
    final_acc=round(np.mean(test_metrics["acc"]),3)


    # Save model checkpoint
    
    # Using the learning rate scheduler while monitoring Loss and printing the learning rate after every epoch
    for param_group in optimizer.param_groups:
        print("\nCurrent Learning Rate is : " + str(param_group['lr']))
    
    model_img.train()
    print("")
    # Getting the MCC score for evaluation and plotting the confusion matrix and saving that matrix
    mcc_score = round(mcc(y_true.astype(int), y_pred.astype(int)), 5)

    print(f"MCC score : {mcc_score}")
    mcc_score_str = "MCC Score: " + str(mcc_score) + "\n\n"

    plot_title = mcc_score_str + "Confusion matrix, Without Normalization\n"

    
    class_names = ['Non-stalled', 'Stalled']
    class_names = np.array(class_names)
    plot_confusion_matrix(y_true.astype(int), y_pred.astype(int), classes=class_names, title=plot_title)

    # Save model checkpoint
    global MCC_SCORE

    if MCC_SCORE < mcc_score or int(epoch) % 2 == 0:
        if MCC_SCORE < mcc_score:
            MCC_SCORE = mcc_score
        os.makedirs("model_checkpoints", exist_ok=True)
        torch.save(model_img.state_dict(), f"model_checkpoints/Multimodal_catcrosent_spl1_epoch_{epoch}_acc_{final_acc}_mcc_{mcc_score}.pth")

    os.makedirs("confusion_matrix", exist_ok=True)
    plt.savefig(f"confusion_matrix/epoch{epoch}with_accuracy{final_acc}_mcc_{mcc_score}.png")


# **Multimodal Training**

In [None]:
num_epochs = 60
for epoch in range(num_epochs):
        epoch_metrics = {"loss": [], "acc": []}
        prev_time = time.time()
        print(f"--- Epoch {epoch}---")
        for batch_i, ((X1, y), (X2,y2)) in enumerate(zip(train_dataloader_img, train_dataloader_vox)):
            if X1.size(0) == 1:
                continue
            image_sequences = Variable(X1.to(device), requires_grad=True)
            pt_cloud = Variable(X2.to(device), requires_grad=True)

             
            #print(f"From image loader {y} and pt loader {y2}")
            
            labels = Variable(y.to(device), requires_grad=False)
            
            image_sequences = image_sequences.half()
            pt_cloud = pt_cloud.half()
            
            optimizer.zero_grad()
            
            # Get sequence predictions
            predictions = model_img(image_sequences, pt_cloud)
            
            
            
            predicted = torch.max(predictions.data, 1)[1]   
            acc = 100 * (predicted == labels).cpu().numpy().mean()
            
            #for bcewithlogits loss
            #acc = torch.tensor([0 if i<=0.5 else 1 for i in predictions]).to(device)
            #acc = 100 * (acc.detach() == labels).cpu().numpy().mean()
            # Compute metrics
            #labels = labels.view(-1, 1)
            
            loss = cls_criterion.forward(predictions, labels)
            
            loss.backward()
            optimizer.step()

            # Keep track of epoch metrics
            epoch_metrics["loss"].append(loss.item())
            epoch_metrics["acc"].append(acc)

            # Determine approximate time left
            batches_done = epoch * len(train_dataloader_img) + batch_i
            batches_left = num_epochs * len(train_dataloader_img) - batches_done
            time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
            prev_time = time.time()

            # Print log
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] [Loss: %f (%f), Acc: %.2f%% (%.2f%%)] ETA: %s"
                % (
                    epoch,
                    num_epochs,
                    batch_i,
                    len(train_dataloader_img),
                    loss.item(),
                    np.mean(epoch_metrics["loss"]),
                    acc,
                    np.mean(epoch_metrics["acc"]),
                    time_left,
                )
            )
            #deleting variable to free up memory
            del(X1)
            del(y)
            del(X2)
            del(y2)
            del(image_sequences)
            del(pt_cloud)
            del(labels)
            del(predictions)
            del(loss)


            # Empty cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Evaluate the model on the test set
        test_model(epoch)
        
        train_dataloader_img = DataLoader(train_dataset_img, batch_size= batch_size,sampler=BalancedBatchSampler(train_dataset_img),shuffle=False, num_workers=4)
        train_dataloader_vox = DataLoader(train_dataset_vox, batch_size= batch_size,sampler=BalancedBatchSampler(train_dataset_vox),shuffle=False, num_workers=4)
        
        scheduler.step()

--- Epoch 0---
Current Learning Rate is : 0.0002

MCC score : 0.38822
Confusion matrix, without normalization
[[334  89]
 [ 69 108]]
--- Epoch 1---
Current Learning Rate is : 0.0002

MCC score : 0.4723
Confusion matrix, without normalization
[[251 172]
 [ 14 163]]
--- Epoch 2---