In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

np.random.seed(1)

def read_ground_truth(filename):
    """Return every line of the text file `filename` as a list of strings.

    Lines are returned exactly as stored, including their trailing newline
    characters; no parsing or stripping is performed here.
    """
    with open(filename, 'r') as handle:
        # Iterating a text file yields its lines one by one.
        return list(handle)

In [None]:
# Parse the ground-truth file into {image file name: [float, float, ...]}.
# Each non-empty line is whitespace separated: the first column is a path of
# the form "<folder>/<image name>", the remaining columns are numeric labels.
label_data = read_ground_truth(r"groundtruth-events.txt")

dict_labels = {}

for line in label_data:
    columns = line.split()
    if not columns:
        # Blank lines (e.g. a trailing newline at the end of the file) carry
        # no record. Without this guard the previous image's labels would be
        # overwritten with an empty list.
        continue

    image_path, *raw_values = columns
    # Keep the second path component, i.e. the file name in "<folder>/<name>".
    image_name = image_path.split("/")[1]
    dict_labels[image_name] = [float(raw_value) for raw_value in raw_values]



In [None]:
import os

def create_dataset(img_folder, dict_labels):
    """Load the labelled images found in `img_folder`.

    Parameters
    ----------
    img_folder : str
        Directory whose entries are candidate image files.
    dict_labels : dict
        Maps an image file name to its label list.

    Returns
    -------
    (img_data_array, class_name) : tuple of two lists
        `img_data_array[k]` is the RGB pixel array of the k-th loaded image and
        `class_name[k]` is the label list of that same image. Both lists always
        have the same length.

    Files that have no entry in `dict_labels`, or that OpenCV cannot decode,
    are skipped so that images and labels stay aligned index by index.
    """
    img_data_array = []
    class_name = []

    # Sort the directory listing: os.listdir returns entries in arbitrary
    # order, which would make the dataset order differ between runs.
    for image_x in sorted(os.listdir(img_folder)):
        if image_x not in dict_labels:
            # No ground truth for this file: keeping the image would shift
            # every following label by one position.
            continue

        image = cv2.imread(os.path.join(img_folder, image_x))
        if image is None:
            # cv2.imread signals an unreadable / non-image file with None.
            continue

        # OpenCV decodes to BGR; store the image in RGB channel order.
        rgb_image = np.array(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        img_data_array.append(rgb_image)
        class_name.append(dict_labels[image_x])

    return img_data_array, class_name

In [None]:
# Build the image list and the matching label list from the "images" folder.
img_array, ground_labels = create_dataset(
    img_folder=r"images",
    dict_labels=dict_labels,
)

In [None]:
# Keep only the first 299 rows of every image; labels are copied over in the
# same order, one per image.
short_img_array = [image[:299] for image in img_array]
new_ground_labels = [ground_labels[position] for position in range(len(img_array))]

In [None]:
# For every image, sample 53 pixels brighter than `thresh` and describe each
# one by 5 numbers: [row/299, col/299, r/255, g/255, b/255].
thresh = 70      # minimum grey level for a pixel to be eligible for sampling
n_points = 53    # number of pixels sampled per image

all_images_five_d_list = []

for image_index, image in enumerate(short_img_array):
    # The images were converted to RGB when the dataset was created, so the
    # grey conversion must use the RGB (not BGR) channel weights.
    intensity = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # (row, col) coordinates of every pixel above the threshold.
    coords = np.argwhere(intensity > thresh)
    if len(coords) < n_points:
        raise ValueError(
            f"image {image_index} has only {len(coords)} pixels above "
            f"{thresh}; {n_points} are required"
        )

    # Sample coordinate *pairs*. Drawing rows and columns independently would
    # combine a row from one bright pixel with a column from another and
    # yield positions that are not necessarily above the threshold.
    picked = np.random.choice(len(coords), size=n_points, replace=False)
    rows, cols = coords[picked, 0], coords[picked, 1]

    rgb = image[rows, cols]  # one (r, g, b) triple per sampled pixel

    # NOTE(review): both coordinates are scaled by 299, which presumably
    # assumes images at most 299 pixels wide as well as tall - confirm.
    five_dim = np.column_stack((rows / 299, cols / 299, rgb / 255))

    all_images_five_d_list.append(five_dim)


In [None]:
# Turn the per-image feature list and the label list into NumPy arrays.
all_images_five_d_list, new_ground_labels = map(
    np.array, (all_images_five_d_list, new_ground_labels)
)

In [None]:
from sklearn.model_selection import train_test_split

# Use 80% of the samples for training and the remainder for testing; the
# fixed random_state keeps the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    all_images_five_d_list,
    new_ground_labels,
    train_size=0.8,
    random_state=1,
)


In [None]:
import torch

# Wrap the training arrays as torch tensors (from_numpy shares memory with
# the source arrays rather than copying them).
inputs, targets = (torch.from_numpy(array) for array in (X_train, y_train))

In [None]:
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

# Serve the training tensors in shuffled mini-batches.
batch_size_training = 27

dataset = TensorDataset(inputs, targets)
train_loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size_training,
    shuffle=True,
)

In [None]:
# Same wrapping for the held-out split; evaluation batches keep their order.
batch_size = 17

inputs_test, targets_test = torch.from_numpy(X_test), torch.from_numpy(y_test)
dataset_test = TensorDataset(inputs_test, targets_test)
test_loader = DataLoader(
    dataset=dataset_test,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Train a linear model whose residual is shifted by a TMD (diffusion) term,
# once per correction strength in `alpha_list`, and record the test loss
# after the epoch counts listed in `epoch_list`.
alpha_list = [0,0.1,0.3,0.5,0.9]

# Numbers of completed epochs (1-based) after which the model is evaluated.
epoch_list = [10, 134, 259, 383, 508, 632, 757, 882, 1006, 1131, 1255, 1380, 1505, 1629, 1754, 1878, 2003, 2127, 2252, 2377, 2501, 2626, 2750, 2875, 3000]

N_EPOCHS = 3000
LEARNING_RATE = 1e-2
EPSILON = 0.25       # scaling constant used when building the operator L
N_OUTPUTS = 7        # rows of the weight matrix w
N_FEATURES = 265     # 53 sampled points x 5 values, flattened per image

# Seed torch as well as NumPy so that weight initialisation and batch
# shuffling are repeatable.
torch.manual_seed(1)


def mse_loss_withtmd(predictions, targets, tmd):
    """Mean squared value of `predictions - targets + alpha * tmd`.

    `alpha` is read from the module-level variable set by the training loop.
    """
    difference = predictions - targets + (tmd * alpha)
    return torch.sum(difference * difference)/ difference.numel()


def signal_to_noise_loss_withtmd(predictions, targets,tmd):
    """Squared residual (with the alpha-scaled `tmd` shift) over the squared targets.

    `alpha` is read from the module-level variable set by the training loop.
    """
    difference = predictions - targets + (tmd * alpha)
    return torch.sum(difference * difference)/ torch.sum(targets * targets)


def model(X):
    """Linear map `X @ w.T + b` using the current module-level `w` and `b`."""
    return X @ w.t() + b


def _tmd_correction(batch, weights, scale, epsilon):
    """Return the diffusion correction term for one batch.

    batch   : tensor of shape (n_samples, n_points, n_dims)
    weights : weight matrix of shape (n_outputs, n_points * n_dims)
    scale   : one-element tensor multiplying the kernel normalisation
    epsilon : scaling constant of the operator

    The result has shape (n_samples, n_outputs).
    """
    n_samples, n_points, n_dims = batch.shape
    flat = batch.reshape(n_samples, -1).float()

    # Sample-by-sample kernel: entry (i, j) is the sum of all point-to-point
    # distances between sample i and sample j. torch.cdist on the 3-D batch
    # only yields within-sample distances, so all points are compared at once
    # and then summed per pair of samples.
    # NOTE(review): this reading is inferred from the original
    # `.sum(dim=(2, 3))` and the (n, n) shapes used below - confirm.
    points = batch.reshape(n_samples * n_points, n_dims)
    K_epsilon = (
        torch.cdist(points, points)
        .reshape(n_samples, n_points, n_samples, n_points)
        .sum(dim=(1, 3))
    )

    q_epsilon_tilde = K_epsilon.sum(dim=1)
    D_epsilon_tilde = torch.diag_embed(scale / q_epsilon_tilde)
    K_tilde = torch.matmul(K_epsilon, D_epsilon_tilde)
    D_tilde = torch.diag_embed(K_tilde.sum(dim=1))

    # NOTE(review): as written this is (1/eps) * (D^-1 K) - I; if the intended
    # operator is (1/eps) * (D^-1 K - I) the parentheses must move. Kept
    # unchanged pending confirmation.
    L = 1 / epsilon * (torch.inverse(D_tilde).matmul(K_tilde)) - torch.eye(n_samples)

    x_l = torch.matmul(torch.transpose(L.float(), 0, 1), flat)
    xlw = torch.matmul(weights, torch.transpose(x_l, 0, 1))

    # Average over the actual batch size so that a smaller final batch is
    # handled correctly.
    return torch.div(torch.transpose(xlw, 0, 1), n_samples)


eval_epochs = set(epoch_list)

# Test-loss curves of every run, keyed by alpha, so earlier runs are not lost
# when `testing_loss_tmd` is reset for the next alpha.
testing_loss_by_alpha = {}

for alpha in alpha_list:
    testing_loss_tmd = []

    w = torch.randn(N_OUTPUTS, N_FEATURES, requires_grad=True)
    b = torch.randn(N_OUTPUTS, requires_grad=True)
    D_epsilon_tilde_hyper = torch.nn.Parameter(torch.FloatTensor([0.1]))

    for i in range(N_EPOCHS):

        for x, y in train_loader:
            # Correct features using TMD layer
            final_L = _tmd_correction(x, w, D_epsilon_tilde_hyper, EPSILON)

            preds = model(x.reshape(x.shape[0], -1).float())

            loss = mse_loss_withtmd(preds, y, final_L)
            loss.backward()

            with torch.no_grad():
                w -= w.grad * LEARNING_RATE
                b -= b.grad * LEARNING_RATE
                # Gradient step, not a decay of the parameter itself: shrinking
                # it by 1% every step drives it to zero and makes D_tilde
                # singular.
                # NOTE(review): D_tilde^-1 K_tilde does not change when K_tilde
                # is rescaled, so this gradient is expected to be ~0.
                D_epsilon_tilde_hyper -= D_epsilon_tilde_hyper.grad * LEARNING_RATE
                # Set the gradients to zero
                w.grad.zero_()
                b.grad.zero_()
                D_epsilon_tilde_hyper.grad.zero_()

        # `i` is 0-based while `epoch_list` counts completed epochs, so compare
        # i + 1; otherwise the final checkpoint (3000) is never evaluated.
        if (i + 1) in eval_epochs:
            loss_sum = 0.0
            n_seen = 0

            # Evaluation needs no gradients.
            with torch.no_grad():
                # Dedicated loop names so the `y_test` array from the
                # train/test split is not overwritten.
                for x_eval, y_eval in test_loader:
                    # Build the operator from the *test* batch, not from the
                    # last training batch.
                    final_L = _tmd_correction(x_eval, w, D_epsilon_tilde_hyper, EPSILON)

                    preds = model(x_eval.reshape(x_eval.shape[0], -1).float())

                    batch_loss = mse_loss_withtmd(preds, y_eval, final_L)
                    loss_sum += batch_loss.item() * x_eval.shape[0]
                    n_seen += x_eval.shape[0]

            if n_seen:
                # Sample-weighted mean over the whole test set rather than the
                # loss of whichever batch happened to come last.
                testing_loss_tmd.append(loss_sum / n_seen)

    testing_loss_by_alpha[alpha] = testing_loss_tmd