In [0]:
'''

    Standard Modules

'''
import time
import os
import sys
import copy
import pickle as pkl
import collections
from functools import reduce
from functools import partial 


''' 

    External Modules

'''

import cv2
import numpy as np
import pandas as pd

from tqdm import tqdm, trange

''' 

    PyTorch Modules

'''

import torch
from torch.utils.data import Dataset
from sklearn.metrics import precision_score,recall_score,f1_score 


class Specimen_Dataset(Dataset):
    """Map-style dataset that loads one image per row of a metadata table.

    Every item is a dict with the keys ``'img'`` (the array returned by
    ``cv2.imread``), ``'category_id'`` and ``'id'``; if a transform was
    supplied, the transform's return value is handed back instead.

    Args:
        dataset: table accessed through ``.shape`` and ``.loc[idx, column]``.
            It must provide the ``'file_name'`` and ``'id'`` columns and,
            unless ``set_value`` is ``'test_submission'``, the
            ``'category_id_le_preprocessed'`` column.
        set_value: name of the split this dataset represents. Only the value
            ``'test_submission'`` is treated specially (no label is read).
        transform: optional callable applied to the sample dict; a falsy
            value disables it.
    """

    def __init__(self, dataset, set_value, transform):
        self.metadata = dataset
        self.set_value = set_value
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the metadata table."""
        return self.metadata.shape[0]

    def __getitem__(self, idx):
        """Load the sample stored under index label ``idx``.

        Raises:
            FileNotFoundError: if the image cannot be read. ``cv2.imread``
                signals failure by returning ``None`` rather than raising,
                which would otherwise surface later as an unrelated error
                inside the transforms or the batch collation.
        """
        file_name = self.metadata.loc[idx, 'file_name']
        img = cv2.imread(file_name)
        if img is None:
            raise FileNotFoundError(
                'cv2.imread could not read image {!r} (row {!r})'.format(file_name, idx)
            )

        img_id = self.metadata.loc[idx, 'id']

        # The submission split has no ground truth; -1 marks "no label".
        category_id = -1
        if self.set_value != 'test_submission':
            category_id = self.metadata.loc[idx, 'category_id_le_preprocessed']

        sample = {
            'img': img,
            'category_id': category_id,
            'id': img_id,
        }

        if self.transform:
            return self.transform(sample)

        return sample
    
    
class Data_Pipeline:
    """Ordered chain of callables applied one after another to an object."""

    def __init__(self, *args):
        # Stages are kept in the order they were passed in.
        self.content_pipeline = [*args]

    def __call__(self, obj):
        """Pass ``obj`` through every stage and return the last stage's output.

        With no stages configured, ``obj`` is returned unchanged.
        """
        result = obj
        for stage in self.content_pipeline:
            result = stage(result)
        return result
    

class Resizer:
    """Transform that resizes the image of a sample dict with ``cv2.resize``."""

    def __init__(self, output_size):
        # Forwarded unchanged as the ``dsize`` argument of ``cv2.resize``.
        self.output_size = output_size

    def __call__(self, sample):
        """Replace ``sample['img']`` with its resized version and return ``sample``.

        The dict is modified in place; all other keys are left untouched.
        """
        sample['img'] = cv2.resize(src=sample['img'], dsize=self.output_size)
        return sample
    
    
class Normalizer:
    """Transform that standardises the image of a sample dict.

    Args:
        mean: value subtracted from the image.
        std: value the centred image is divided by.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        """Store ``(img - mean) / std`` back into ``sample['img']``.

        The dict is modified in place and returned.
        """
        centred = sample['img'] - self.mean
        sample['img'] = centred / self.std
        return sample
    
class ToTensor:
    """Transform that turns the image of a sample dict into a ``torch.Tensor``."""

    def __call__(self, sample):
        """Convert ``sample['img']`` to a tensor with its last axis moved first.

        The image axes are permuted with ``transpose((2, 0, 1))`` before the
        conversion. ``sample['category_id']`` is read and written back
        unchanged, so a sample without that key raises ``KeyError``.
        The dict is modified in place and returned.
        """
        image = sample['img']
        label = sample['category_id']

        channels_first = image.transpose((2, 0, 1))

        sample['img'] = torch.Tensor(channels_first)
        sample['category_id'] = label
        return sample
    

class NN_Model_Trainer:
    
    def __init__(self, model, optimizer, scheduler, label_encoder, loss_func, loaders, **parametres):
        self.model = model
        self.optimizer = optimizer
        self.loss_func = loss_func
        self.scheduler = scheduler
        self.loaders = loaders
        self.label_encoder = label_encoder
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        self.statistics_df = pd.DataFrame(columns=['epoch', 
                                                   'precision', 
                                                   'recall', 
                                                   'f1_score',
                                                   'learning_rate', 
                                                   'loss', 
                                                   'case'])
        
        self.model_train_parametres = parametres['model_param']
        self.environment_parametres = parametres['environment_param']
        self.model_directory = self.init_home_directory()
        self.time_init = time.ctime().replace(' ', '_')
        self.filelogs_info = self.init_file_logs()
        
        
    def init_home_directory(self):
        model_directory = os.path.join(self.environment_parametres['abs_path'], 
                                       self.environment_parametres['title'], 
                                       self.environment_parametres['version']
        )

        if os.path.exists(model_directory): pass    
        else: os.makedirs(name = model_directory, exist_ok = True)

        return model_directory
    
    def init_file_logs(self,):
    
        filelogs_path = os.path.join(self.model_directory,
                                     self.environment_parametres['title'] + '_model_logs_' + self.time_init + '.txt')
        filelogs_file = open(filelogs_path, 'w') if self.environment_parametres['verbose_mode'] != sys.stdout \
                                                 else self.environment_parametres['verbose_mode']
        
        return (filelogs_file, filelogs_path)
    
    def get_results(self, y_true, y_pred):
        """Compute macro-averaged precision, recall and F1 for the predictions.

        Args:
            y_true: tensor of reference labels.
            y_pred: tensor of predicted labels.

        Returns:
            A ``result_med`` named tuple with the fields ``precision``,
            ``recall`` and ``f1_score``.
        """
        result_med = collections.namedtuple(
            'result_med',
            ['precision', 'recall', 'f1_score'],
        )

        # The metric functions are fed NumPy arrays, so move the tensors to
        # host memory first.
        y_true_cpu, y_pred_cpu = y_true.cpu().numpy(), y_pred.cpu().numpy()

        # Build an actual instance. Assigning the values onto the namedtuple
        # class (as was done before) overwrote its field descriptors and
        # returned a class object instead of a record.
        return result_med(
            precision=precision_score(y_true_cpu, y_pred_cpu, average='macro'),
            recall=recall_score(y_true_cpu, y_pred_cpu, average='macro'),
            f1_score=f1_score(y_true_cpu, y_pred_cpu, average='macro'),
        )
    
    
    def result_wrapper(self, **kwags):
        
        results = {
            'epoch' : kwags['epoch'],
            'precision' : kwags['precision'],
            'recall' : kwags['epoch'],
            'f1_score' : kwags['f1_score'],
            'learning_rate' : kwags['learning_rate'],
            'loss' : kwags['loss'],
            'case' : kwags['case']
        }
        
        return results
    
    def parse_batch(self, batch, case):
        
        category_existance = 'category_id' if case != 'submission' else 'id'
        
        images, categories = batch['img'], batch[category_existance]
        
        if self.device == 'cpu':
            pass
        
        else:
            images, categories = images.to(self.device), \
                                 categories.to(self.device)
            
        return images, categories
    
    
    def test_model(self, loader_key, case):
        
        y_true, y_pred = torch.IntTensor().to(self.device), \
                         torch.IntTensor().to(self.device)
        
        self.model.eval()
        with torch.no_grad():
            for batch_index, batch in enumerate(self.loaders[loader_key], 1):
                
                images, categories = self.parse_batch(batch, case)
                
                outputs = self.model(images)
                _, preds = torch.max(outputs, 1)
                
                y_true = torch.cat((y_true, categories))
                y_pred = torch.cat((y_pred, preds))
                
#                 break
                
        results = self.get_results(y_true, y_pred)
        
        
        return results
    
    def train_model(self):
        """Train the model, checkpoint every epoch, then evaluate and export.

        Steps:
            1. Move the model to ``self.device`` and, when
               ``model_train_parametres['Retrain_path']`` is truthy, restore
               the model, optimizer and bookkeeping from that checkpoint.
            2. Build the optimizer by calling ``self.optimizer`` with the
               model parameters, ``lr`` and ``momentum``, and the scheduler by
               calling ``self.scheduler`` with ``step_size=7, gamma=0.1``.
               Both attributes are replaced by the created objects.
            3. For every remaining epoch: train on ``loaders['train']``,
               evaluate on ``loaders['val']``, step the scheduler, remember
               the weights with the best validation F1, append a train and a
               val row to ``self.statistics_df`` and write the checkpoint and
               the statistics CSV into ``self.model_directory``.
            4. Reload the best weights, evaluate on ``loaders['test']``,
               append a test row and call ``get_submission``.

        Epochs are numbered from 1. A resumed run continues with the first
        epoch the checkpoint has not completed, so ``epochs_left`` stays
        consistent across repeated resumes.
        """
        self.model = self.model.to(self.device)
        model_state_info = self._init_model_state_info()

        retrain_path = self.model_train_parametres['Retrain_path']
        if retrain_path:
            # NOTE: torch.load unpickles the file - only resume from
            # checkpoints you trust. map_location lets a checkpoint written
            # on a GPU be restored on a machine without one.
            model_state_info = torch.load(retrain_path, map_location=self.device)
            self.model.load_state_dict(model_state_info['model_state_dict'])

        self.optimizer = self.optimizer(self.model.parameters(),
                                        lr=self.model_train_parametres['learning_rate'],
                                        momentum=self.model_train_parametres['momentum'])
        if retrain_path:
            self.optimizer.load_state_dict(model_state_info['optimizer_state_dict'])

        self.scheduler = self.scheduler(self.optimizer, step_size=7, gamma=0.1)

        if retrain_path:
            print(model_state_info['best_model']['f1_score'])

        log_file = self.filelogs_info[0]

        parametrs_string = 'learning_rate : {}, momentum : {}, epochs : {}, img_size : {}'.format(
            self.model_train_parametres['learning_rate'],
            self.model_train_parametres['momentum'],
            self.model_train_parametres['epochs'],
            self.model_train_parametres['img_size'],
        )

        cli_string = "optimizer : {}, loss_function : {}, num_labels : {}".format(
            self.model_train_parametres['optimizer'],
            self.model_train_parametres['loss_function'],
            self.model_train_parametres['num_of_classes']
        )

        print('CLI PARAMETRES : ', file=log_file)
        print(cli_string, file=log_file)
        print('HYPERPARAMETRS : ', file=log_file)
        print(parametrs_string, file=log_file)
        print('DEVICE : ', self.device, file=log_file)

        filelogs_fmt = '{} -> Epoch {} -> Batch_index {} -> {} -> {}'

        title = self.environment_parametres['title']
        checkpoint_path = os.path.join(self.model_directory, title + '_model_state_info.pth')
        statistics_path = os.path.join(self.model_directory, title + '_statistics.csv')

        # Resume where the checkpoint stopped instead of restarting at 1;
        # otherwise ``epochs - epoch`` below miscounts the remaining epochs.
        total_epochs = model_state_info['epochs']
        first_epoch = max(1, total_epochs - model_state_info['epochs_left'] + 1)

        # Defaults for the final test row in case no epoch is left to run.
        epoch, epoch_loss = first_epoch - 1, float('nan')

        for epoch in trange(first_epoch, total_epochs + 1, desc='epochs'):

            running_loss, seen_samples, batch_index = 0.0, 0, 0
            true_chunks, pred_chunks = [], []

            self.model.train()

            for batch_index, batch in enumerate(tqdm(self.loaders['train'], leave=False), 1):

                images, categories = self.parse_batch(batch, case='train')

                self.optimizer.zero_grad()

                outputs = self.model(images)
                _, preds = torch.max(outputs, 1)

                pred_chunks.append(preds)
                true_chunks.append(categories)

                loss = self.loss_func(outputs, categories)
                batch_loss = loss.item()
                running_loss += batch_loss * images.size(0)
                seen_samples += images.size(0)

                loss.backward()
                self.optimizer.step()

                print(filelogs_fmt.format(time.ctime(), epoch, batch_index, 'loss', batch_loss), file=log_file)

            # Average over the samples actually seen: dividing by
            # batch_index * batch_size is off whenever the last batch is short.
            epoch_loss = running_loss / seen_samples if seen_samples else float('nan')

            res_train = self.get_results(self._concat_or_empty(true_chunks),
                                         self._concat_or_empty(pred_chunks))
            res_val = self.test_model('val', case='val')

            self.scheduler.step()

            # Plain float so the checkpoint holds no metric-library scalar types.
            val_f1 = float(res_val.f1_score)
            current_lr = self.optimizer.param_groups[0]['lr']

            if val_f1 > model_state_info['best_model']['f1_score']:
                model_state_info['best_model']['f1_score'] = val_f1
                # Deep copies: state_dict() returns references to the live
                # tensors, which keep changing as training goes on.
                model_state_info['best_model']['model_state_dict'] = copy.deepcopy(self.model.state_dict())
                model_state_info['best_model']['optimizer_state_dict'] = copy.deepcopy(self.optimizer.state_dict())

            self._append_statistics(epoch=epoch,
                                    precision=res_train.precision,
                                    recall=res_train.recall,
                                    f1_score=res_train.f1_score,
                                    learning_rate=current_lr,
                                    loss=epoch_loss,
                                    case='train')

            # NOTE: no validation loss is computed; the val row repeats the
            # training loss of the epoch.
            self._append_statistics(epoch=epoch,
                                    precision=res_val.precision,
                                    recall=res_val.recall,
                                    f1_score=res_val.f1_score,
                                    learning_rate=current_lr,
                                    loss=epoch_loss,
                                    case='val')

            print(filelogs_fmt.format(time.ctime(), epoch, batch_index, 'loss', epoch_loss), file=log_file)

            model_state_info['model_logs'][epoch] = val_f1

            model_state_info['model_state_dict'] = self.model.state_dict()
            model_state_info['optimizer_state_dict'] = self.optimizer.state_dict()
            model_state_info['epochs_left'] = total_epochs - epoch
            model_state_info['learning_rate'] = current_lr
            model_state_info['loss'] = epoch_loss
            model_state_info['f1_score'] = val_f1

            torch.save(model_state_info, checkpoint_path)
            self.statistics_df.to_csv(statistics_path)

        # No best weights exist when no epoch ran and the checkpoint had none.
        best_state = model_state_info['best_model']['model_state_dict']
        if best_state is not None:
            self.model.load_state_dict(best_state)

        res_test = self.test_model('test', 'test')
        self._append_statistics(epoch=epoch,
                                precision=res_test.precision,
                                recall=res_test.recall,
                                f1_score=res_test.f1_score,
                                learning_rate=self.optimizer.param_groups[0]['lr'],
                                loss=epoch_loss,
                                case='test')

        self.get_submission()

    def _init_model_state_info(self):
        """Return the checkpoint dictionary for a run that starts from scratch."""
        epochs = self.model_train_parametres['epochs']
        return {
            'model_state_dict': None,
            'optimizer_state_dict': None,
            'best_model': {
                # Lower than any real score, so the first epoch always wins.
                'f1_score': -10,
                'model_state_dict': None,
                'optimizer_state_dict': None,
            },
            'img_size': self.model_train_parametres['img_size'],
            'epochs': epochs,
            'epochs_left': epochs,
            # Same key the training loop updates (was misspelt 'learnin_rate').
            'learning_rate': self.model_train_parametres['learning_rate'],
            'loss': None,
            'f1_score': None,
            # One slot per epoch, keyed by the 1-based epoch number.
            'model_logs': {epoch: None for epoch in range(1, epochs + 1)},
        }

    def _append_statistics(self, **metrics):
        """Append one row built by ``result_wrapper`` to ``self.statistics_df``."""
        # DataFrame.append was removed in pandas 2.0; concat is the replacement.
        row = pd.DataFrame([self.result_wrapper(**metrics)])
        self.statistics_df = pd.concat([self.statistics_df, row], ignore_index=True)

    def _concat_or_empty(self, chunks):
        """Join per-batch tensors, or return an empty tensor when there are none."""
        if chunks:
            return torch.cat(chunks)
        return torch.IntTensor().to(self.device)
    
    
    def get_submission(self):
        """Predict on ``loaders['submission']`` and write the result files.

        Predicted class indices are mapped back through
        ``self.label_encoder.inverse_transform`` and written, together with
        the sample ids, to ``<title>submission_df.csv`` (columns ``Id`` and
        ``Predicted``) in ``self.model_directory``. The statistics table is
        saved as ``<title>_statistics.csv`` in the same directory.
        """
        self.model.eval()

        pred_chunks, id_chunks = [], []

        with torch.no_grad():
            for batch in tqdm(self.loaders['submission']):
                # For the submission case parse_batch returns the sample ids
                # in place of the labels.
                images, ids = self.parse_batch(batch, 'submission')

                outputs = self.model(images)
                _, preds = torch.max(outputs, 1)

                pred_chunks.append(preds)
                id_chunks.append(ids)

        if pred_chunks:
            # Concatenate once instead of growing a tensor batch by batch.
            y_pred, all_ids = torch.cat(pred_chunks), torch.cat(id_chunks)
        else:
            # Empty loader: fall back to empty tensors, as before.
            y_pred, all_ids = torch.IntTensor().to(self.device), \
                              torch.IntTensor().to(self.device)

        y_pred_deprocessed = self.label_encoder.inverse_transform(y_pred.cpu().numpy().astype(int))
        submission_df = pd.DataFrame({'Id': all_ids.cpu().numpy().astype(int),
                                      'Predicted': y_pred_deprocessed})

        title = self.environment_parametres['title']
        self.statistics_df.to_csv(os.path.join(self.model_directory, title + '_statistics.csv'))
        submission_df.to_csv(os.path.join(self.model_directory, title + 'submission_df.csv'), index=False)


            
                
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
        
    

        