In [1]:
import utils.dataset_utils as dataset
import utils.train_utils as train

import numpy as np
import copy
import pickle

from sklearn.base import BaseEstimator, ClassifierMixin
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from datetime import datetime
from dateutil.relativedelta import relativedelta

In [2]:
def unit_vector_norm(X):
    """Scale every row of a 2-D array to unit Euclidean (L2) length.

    Parameters
    ----------
    X : array-like of shape (n_rows, n_features)
        Each row is normalised independently.

    Returns
    -------
    numpy.ndarray of shape (n_rows, n_features)
        ``X`` with every row divided by its L2 norm. Rows whose norm is zero
        are returned unchanged (all zeros) instead of becoming NaN.
    """
    X = np.asarray(X)
    # keepdims keeps the norms as a column so they broadcast against the rows
    # directly, without transposing back and forth.
    norms = np.sqrt((X ** 2).sum(axis=1, keepdims=True))
    # An all-zero row has norm 0; dividing by it would yield NaN/inf.
    # Dividing by 1 instead leaves such rows as zeros.
    norms[norms == 0] = 1
    return X / norms

In [3]:
# Directory holding the preprocessed files read below.
file_location = "../data/Raman_Mouse/preprocess/"
filenames = np.load(file_location + "FileNames.npy")
# NOTE: pickle.load executes arbitrary code from the file; only use trusted data.
with open(file_location + 'Sample_labels.pickle', 'rb') as label_file:
    labels = pickle.load(label_file)

# Build a list of (normalised array, label) pairs, one per entry in `filenames`.
data = []
for name in filenames:
    # The array file shares the part of the name before the first '.'.
    stem = name.split('.')[0]
    raman = np.load(f"{file_location}{stem}_raman.npy")
    # Flatten all leading axes so each vector along the last axis becomes a row,
    # normalise every row, then restore the original shape.
    rows = raman.reshape(-1, raman.shape[-1])
    normalised = unit_vector_norm(rows).reshape(raman.shape)
    data.append((normalised, labels[name]))


In [4]:
class MLP(nn.Module):
    """Fully connected network: dropout -> flatten -> stack of Linear/ReLU layers.

    Parameters
    ----------
    size : int
        Number of input features after flattening.
    output : int
        Number of output units of the final linear layer.
    depth : int
        Number of ``neurons``-wide linear layers, counting the input layer.
        ``depth - 1`` hidden ``neurons -> neurons`` layers are inserted after
        the input layer (none when ``depth <= 1``).
    neurons : int
        Width of the input layer and of every hidden layer.
    bias : bool
        Whether the linear layers use a bias term.
    **kwargs
        Ignored; accepted so a shared configuration dict can be passed in.
    """

    def __init__(self, size=1300*7*7 , output=10, depth=2, neurons=1300, bias=True, **kwargs):
        super().__init__()

        modules = [
            nn.Dropout3d(0.25),
            nn.Flatten(),
            nn.Linear(size, neurons, bias=bias),
            nn.ReLU(True),
        ]
        # Create a *new* Linear/ReLU pair for each hidden layer. Repeating a
        # tuple of modules (``(layer, relu) * n``) would reuse the same objects,
        # so all hidden layers would silently share one weight matrix.
        for _ in range(depth - 1):
            modules.append(nn.Linear(neurons, neurons, bias=bias))
            modules.append(nn.ReLU(True))
        modules += [
            nn.Linear(neurons, 100, bias=bias),
            nn.ReLU(True),
            nn.Linear(100, output, bias=bias),
        ]

        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the final layer's output."""
        return self.layers(x)

In [5]:
class SupervisedClassifier(BaseEstimator):
    """Scikit-learn style wrapper that trains an ``MLP`` on the given data.

    All configuration is passed as keyword arguments and kept in
    ``self.kwargs``. Keys read by this class:

    * ``cuda`` (optional, default ``False``): use the GPU when one is available.
    * ``batch_size``, ``patch_size``: forwarded to ``dataset.load_liver``.
    * ``epochs``: number of training epochs.
    * ``loss_func``, ``log_step``: forwarded to ``train.train`` / ``train.test``.

    The whole dict is also forwarded to ``MLP(**kwargs)``.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

        # .get() so that omitting 'cuda' falls back to CPU instead of raising KeyError.
        _use_cuda = torch.cuda.is_available() and kwargs.get('cuda', False)
        if _use_cuda:
            torch.backends.cudnn.enabled = True
            torch.backends.cudnn.benchmark = True
        self.device = torch.device('cuda' if _use_cuda else 'cpu')

    def fit(self, data):
        """Create a fresh model and train it for ``kwargs['epochs']`` epochs.

        Parameters
        ----------
        data
            Passed unchanged to ``dataset.load_liver``, which returns the
            train and test loaders.

        Returns
        -------
        self
        """
        self.model = MLP(**self.kwargs).to(self.device)

        parameters = filter(lambda x: x.requires_grad, self.model.parameters())
        self.optimizer = optim.Adam(parameters, lr=0.00001)
        train_loader, test_loader = dataset.load_liver(data, self.kwargs['batch_size'], self.kwargs['patch_size'])

        for epoch in range(self.kwargs['epochs']):
            print('-'*50)
            print('Epoch {:3d}/{:3d}'.format(epoch+1, self.kwargs['epochs']))
            start_time = datetime.now()
            train.train(self.model, self.optimizer, train_loader, self.kwargs['loss_func'], self.kwargs['log_step'], self.device)
            end_time = datetime.now()
            time_diff = relativedelta(end_time, start_time)
            print('Elapsed time: {}h {}m {}s'.format(time_diff.hours, time_diff.minutes, time_diff.seconds))
            loss = train.test(self.model, test_loader, self.kwargs['loss_func'], self.device)
            print('Validation| bits: {:2.2f}'.format(loss), flush=True)

        return self

    def predict(self, X):
        """Run the trained model on ``X`` and return its raw outputs.

        Parameters
        ----------
        X : array-like or torch.Tensor
            Batch of inputs; converted to a float32 tensor on ``self.device``.
            NOTE(review): assumed to have the same layout as the batches the
            training loader feeds to the model -- confirm against
            ``dataset.load_liver``.

        Returns
        -------
        numpy.ndarray
            The output of the model's final linear layer (no activation
            applied), one row per input sample.

        Raises
        ------
        sklearn.exceptions.NotFittedError
            If ``fit`` has not been called yet.
        """
        # Local import: the notebook's import cell only pulls in sklearn.base.
        from sklearn.exceptions import NotFittedError

        # sklearn's check_is_fitted looks for attributes ending in "_", which
        # this class never sets, so test for the model explicitly.
        if not hasattr(self, 'model'):
            raise NotFittedError(
                "This SupervisedClassifier instance is not fitted yet. "
                "Call 'fit' before using 'predict'."
            )

        inputs = torch.as_tensor(X, dtype=torch.float32, device=self.device)

        # Switch to eval mode so dropout is disabled, and restore the previous
        # mode afterwards so a later training step is unaffected.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                outputs = self.model(inputs)
        finally:
            self.model.train(was_training)

        return outputs.cpu().numpy()

In [6]:
# Tunable run settings.
BATCH_SIZE = 256
EPOCHS = 10
PATCH_SIZE = 7

# Length of the last axis of the first loaded array; used both for the
# flattened input size and for the hidden-layer width.
feature_dim = data[0][0].shape[-1]

# Single configuration dict consumed by SupervisedClassifier (and forwarded to MLP).
kwargs = dict(
    size=PATCH_SIZE * PATCH_SIZE * feature_dim,
    batch_size=BATCH_SIZE,
    patch_size=PATCH_SIZE,
    cuda=True,
    log_step=50,
    epochs=EPOCHS,
    depth=4,
    output=4,
    loss_func=nn.BCEWithLogitsLoss(),
    neurons=feature_dim,
    bias=True,
)

In [None]:
# Build the classifier from the configuration dict defined above.
rvc = SupervisedClassifier(**kwargs)
# Train on the loaded samples; fit() prints per-epoch progress and returns the estimator.
rvc.fit(data)


--------------------------------------------------
Epoch   1/ 10
  2021-12-01 14:47:10|     0/  453| bits: 0.55
  2021-12-01 14:47:22|    50/  453| bits: 0.58
  2021-12-01 14:47:34|   100/  453| bits: 0.55
  2021-12-01 14:47:47|   150/  453| bits: 0.53
  2021-12-01 14:48:00|   200/  453| bits: 0.59
  2021-12-01 14:48:13|   250/  453| bits: 0.54
  2021-12-01 14:48:25|   300/  453| bits: 0.58
  2021-12-01 14:48:38|   350/  453| bits: 0.58
  2021-12-01 14:48:52|   400/  453| bits: 0.57
  2021-12-01 14:49:04|   450/  453| bits: 0.55
Elapsed time: 0h 2m 1s
Validation| bits: 0.67
--------------------------------------------------
Epoch   2/ 10
  2021-12-01 14:49:31|     0/  453| bits: 0.60
  2021-12-01 14:49:44|    50/  453| bits: 0.55
  2021-12-01 14:49:58|   100/  453| bits: 0.54
  2021-12-01 14:50:11|   150/  453| bits: 0.56
  2021-12-01 14:50:23|   200/  453| bits: 0.53
  2021-12-01 14:50:38|   250/  453| bits: 0.57
  2021-12-01 14:50:53|   300/  453| bits: 0.56
  2021-12-01 14:51:08|   