In [4]:
import copy
import json
import os
import sys
import time
from os import listdir
from os.path import isfile, join

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision.models as models
from PIL import Image
from sklearn.metrics import ConfusionMatrixDisplay, classification_report, confusion_matrix
from torch import nn
from torchvision import transforms

# Change this path to point at your local checkout of the project;
# it must be on sys.path before the `module.*` imports below.
sys.path.insert(1, '/home/codahead/Fishial/FishialReaserch')

from module.segmentation_package.interpreter_segm import SegmentationInference
from module.classification_package.src.dataset import FishialDataset
from module.classification_package.src.model import EmbeddingModel, Backbone, Model, FcNet


# Let pandas print up to 999 rows so the full per-class tables are visible.
pd.set_option("display.max_rows", 999)

In [5]:
def _get_scores(report, max_rows=61):
    """Split a report dictionary into its keys and a matrix of scores.

    Args:
        report: Mapping of entry name -> value. Every entry whose value is a
            dict contributes one row (the dict's values, in iteration order)
            to the score matrix; entries with any other value contribute no
            row.
        max_rows: Maximum number of rows kept in the score matrix. Defaults
            to 61, the class count used throughout this notebook; pass
            ``None`` to keep every row.

    Returns:
        tuple: ``(labels_row, scores)``. ``labels_row`` lists ALL keys of
        ``report`` in iteration order, including keys whose value is not a
        dict, so it can be longer than ``scores``. ``scores`` is a
        ``numpy.ndarray`` built from the first ``max_rows`` collected rows.
    """
    labels_row = list(report)
    # isinstance (rather than an exact type comparison) also accepts dict
    # subclasses such as OrderedDict.
    rows = [
        list(report[key].values())
        for key in labels_row
        if isinstance(report[key], dict)
    ]
    return labels_row, np.array(rows[:max_rows])

def init_model_class(ckp=None, num_classes=61):
    """Build the ResNet-18 based ``FcNet`` model and optionally restore weights.

    Args:
        ckp: Optional path to a checkpoint containing a ``state_dict`` for the
            model. When falsy, no checkpoint is loaded.
        num_classes: Value passed as the second argument of ``FcNet``
            (presumably the number of output classes - confirm against
            ``FcNet``). Defaults to 61, the value previously hard-coded here.

    Returns:
        The ``FcNet`` instance, switched to evaluation mode.
    """
    resnet18 = models.resnet18(pretrained=True)
    # Replace the final fully connected layer with a pass-through.
    resnet18.fc = nn.Identity()

    model = FcNet(Backbone(resnet18), num_classes)
    if ckp:
        # map_location="cpu" lets checkpoints saved from a CUDA device be
        # loaded on machines without a GPU; load_state_dict then copies the
        # values into the model's own parameters.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        model.load_state_dict(torch.load(ckp, map_location="cpu"))
    model.eval()
    return model

def init_model_embed(ckp=None):
    """Build the ResNet-18 based ``EmbeddingModel`` and optionally restore weights.

    Args:
        ckp: Optional path to a checkpoint containing a ``state_dict`` for the
            model. When falsy, no checkpoint is loaded.

    Returns:
        The ``EmbeddingModel`` instance, switched to evaluation mode.
    """
    resnet18 = models.resnet18(pretrained=True)
    # Replace the final fully connected layer with a pass-through.
    resnet18.fc = nn.Identity()

    model = EmbeddingModel(Backbone(resnet18))
    if ckp:
        # map_location="cpu" lets checkpoints saved from a CUDA device be
        # loaded on machines without a GPU; load_state_dict then copies the
        # values into the model's own parameters.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        model.load_state_dict(torch.load(ckp, map_location="cpu"))
    model.eval()
    return model

def remove_dupliceta(mylist):
    """Return the items of ``mylist`` without duplicates, keeping first occurrences.

    Two items count as duplicates when their tuple forms are equal. The
    original item objects are returned, in their original order.
    """
    first_seen = {}
    for entry in mylist:
        # dicts keep insertion order, so the first item stored for a key
        # also fixes that key's position in the result.
        first_seen.setdefault(tuple(entry), entry)
    return list(first_seen.values())





In [11]:
# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalise each of the three channels.
loader = transforms.Compose(
    [
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)

# Training dataset, described by data_train.json and read from the local
# cutted_fish folder. Adjust root_folder to your own checkout.
data_set_train = FishialDataset(
    json_path="data_train.json",
    root_folder="/home/codahead/Fishial/FishialReaserch/datasets/cutted_fish",
    transform=loader,
)

# NOTE(review): the validation dataset is disabled here, yet a later cell
# still passes `data_set_val` to get_data_base; re-enable this (with the
# correct root_folder) before running that cell.
# data_set_val = FishialDataset(
#         json_path="data_test.json",
#         root_folder="/home/codahead/Fishial/FishialReaserch/dataset/",
#         transform=loader
#     )

In [None]:
# Report files to compare. The first entry acts as the baseline: the cell that
# builds `df` subtracts its columns from those of every later report.
list_of_reports = [
    'output/cross_entropy_best.json',
    'output/triplet_best.json',
]

In [None]:
# Load every report and stack their score matrices side by side (axis=1), so
# each report contributes its own group of columns to `scores`.
# `labels` ends up holding the keys of the LAST report that was read.
labels = []
scores = []
for report in list_of_reports:
    # The notebook never defines or imports `read_json`, so read the file
    # with the standard json module instead.
    with open(report, encoding='utf-8') as report_file:
        data_r = json.load(report_file)

    labels, scores_tmp = _get_scores(data_r)
    if len(scores) == 0:
        scores = scores_tmp
    else:
        scores = np.concatenate((scores, scores_tmp), axis=1)
        
    

In [None]:
# Build one table with three metric columns per report, then turn the columns
# of every report after the first into differences against report 0.
# Each report occupies 4 consecutive columns of `scores`; only the first three
# are used (presumably precision, recall, f1, with support skipped - confirm
# against the report files).
metric_names = ['Pre', 'TPR', 'F1']
n_models = len(scores[0]) // 4

full_data = {}
for model_score in range(n_models):
    for idxx, name in enumerate(metric_names):
        name_metrics = "{}_{}".format(name, model_score)
        full_data[name_metrics] = [round(iix, 2) for iix in scores[:, model_score * 4 + idxx]]

df = pd.DataFrame(full_data)
# Size the index from the frame itself instead of a hard-coded 61 so it always
# matches the number of score rows.
df.index = labels[:len(df)]

for model_score in range(1, n_models):
    for name in metric_names:
        column = "{}_{}".format(name, model_score)
        # Build the baseline name from the metric prefix. Replacing only the
        # last character breaks for two-digit model indices
        # (e.g. "Pre_10" would map onto itself).
        main_column = "{}_0".format(name)
        df[column] = df[column] - df[main_column]

In [None]:
# Column totals for report 1, i.e. its summed per-class difference to report 0.
print(*(sum(df[column_name]) for column_name in ('Pre_1', 'TPR_1', 'F1_1')))

In [None]:
# Restore the classifier and the embedding model from their checkpoints.
model_class = init_model_class('output/ckpt_adam_cross_entropy_0.837245696400626_41800.0.ckpt')
model_embed = init_model_embed('output/ckpt_triplet_cross_entropy_0.87_42000.0.ckpt')

# NOTE(review): `get_data_base` is neither defined nor imported in the visible
# part of this notebook - confirm where it comes from. Its result is later
# passed to `classify` together with an embedding.
data_train = get_data_base(model_embed, data_set_train)
# NOTE(review): `data_set_val` only exists as commented-out code in the
# dataset cell, so this line raises NameError on a fresh kernel.
data_eval = get_data_base(model_embed, data_set_val)

In [None]:
# Label strings of the first 61 entries of the dataset's library, in the
# library's own iteration order. Used to name the classifier's predictions.
lables = [
    data_set_train.library_name[class_key]['label']
    for class_key in data_set_train.library_name
][:61]

In [None]:
# Run both models over the cropped validation images in `dir_valid` (one
# sub-directory per species, the directory name being the true label) and
# collect the true label plus each model's top prediction.
# NOTE(review): `classify` is not defined in the visible part of this
# notebook - confirm where it comes from.
# NOTE(review): `segm_class` is not used in this cell; it is kept so the name
# stays defined for any other cell that relies on it.
segm_class = SegmentationInference(model_path = '../../best_scores/model_0067499_amp_on.pth')
# The classifier output is (batch, classes); dim=1 is the axis the deprecated
# implicit choice (dim=None) resolves to for 2-D input, now stated explicitly.
softmax = nn.Softmax(dim=1)

y_true = []
y_embed_61_pred = []
y_class_pred = []
skipped_imgs = []

dir_valid = '../self_cuted'

dirs = [os.path.join(dir_valid, o) for o in os.listdir(dir_valid)
        if os.path.isdir(os.path.join(dir_valid, o))]

for indiece, specie_dir in enumerate(dirs):
    print("Left: {}".format(len(dirs) - indiece), end='\r')

    basename = os.path.basename(specie_dir)
    if basename not in labels:
        # Report species directories whose name is not a known label.
        print(basename)

    imgs = [f for f in listdir(specie_dir) if isfile(join(specie_dir, f))]

    for img_name in imgs:
        path_img = os.path.join(specie_dir, img_name)
        # cv2.imread signals an unreadable file by returning None rather than
        # raising, so test for that instead of using a bare `except`.
        bgr = cv2.imread(path_img)
        if bgr is None:
            skipped_imgs.append(path_img)
            continue
        mask = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # `loader` already returns a tensor, so no extra torch.tensor() copy.
        image = loader(Image.fromarray(mask)).float()
        batch = image.unsqueeze(0)

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            dump_embed = model_embed(batch).detach().numpy()
            dump = softmax(model_class(batch))

        topest = classify(data_train, dump_embed)
        # Vote count over the first 10 matches. Computed as in the original
        # cell, but the prediction below uses only the single best match.
        flatten = [iii[0] for iii in topest[:10]]
        my_dict = [[i, flatten.count(i)] for i in flatten]
        my_dict = remove_dupliceta(my_dict)
        my_dict = sorted(my_dict, key=lambda x: x[1], reverse=True)

        output = torch.topk(dump, 3)
        y_true.append(basename)
        # NOTE(review): the embedding prediction is named via `labels` while
        # the classifier prediction uses `lables`; confirm both lists are in
        # the same class order.
        y_embed_61_pred.append(labels[topest[0][0]])
        y_class_pred.append(lables[int(output.indices[0][0])])

if skipped_imgs:
    print("Skipped {} unreadable image(s)".format(len(skipped_imgs)))

In [None]:
# Confusion matrix and per-class report for the embedding model.
labels = [data_set_train.library_name[i]['label'] for i in data_set_train.library_name]
# Pass `labels` explicitly: by default scikit-learn orders classes by sorted
# label value, which would not match the dataset order used for the tick
# labels and the report rows.
cm = confusion_matrix(y_true, y_embed_61_pred, labels=labels, normalize='true')
fig, ax = plt.subplots(figsize=(30, 30))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp = disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation=90)
ax.set_title("Embedding model - confusion matrix (normalised per true class)")
plt.show()
print(classification_report(y_true, y_embed_61_pred, labels=labels, target_names=labels))

In [None]:
# Confusion matrix and per-class report for the classification model.
labels = [data_set_train.library_name[i]['label'] for i in data_set_train.library_name]
# Pass `labels` explicitly: by default scikit-learn orders classes by sorted
# label value, which would not match the dataset order used for the tick
# labels and the report rows.
cm = confusion_matrix(y_true, y_class_pred, labels=labels, normalize='true')
fig, ax = plt.subplots(figsize=(30, 30))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp = disp.plot(cmap=plt.cm.Blues, ax=ax, xticks_rotation=90)
ax.set_title("Classification model - confusion matrix (normalised per true class)")
plt.show()
print(classification_report(y_true, y_class_pred, labels=labels, target_names=labels))

The script below uses the Mask R-CNN segmentation network to cut the fish out of every image in the source directory and saves the cut-outs to the destination directory.

In [None]:
# For every species sub-directory of `src_dir`, run the segmentation model on
# each image and save every returned mask as a PNG in a directory of the same
# name under `new_dir`.
segm_class = SegmentationInference(model_path = '../../best_scores/model_0067499_amp_on.pth')

src_dir = '../self_validation'
new_dir = '../self_cuted'
dirs = [os.path.join(src_dir, o) for o in os.listdir(src_dir)
        if os.path.isdir(os.path.join(src_dir, o))]
for path in dirs:
    print(path)
    basename = os.path.basename(path)
    new_spieces_dir_path = os.path.join(new_dir, basename)

    os.makedirs(new_spieces_dir_path, exist_ok=True)
    imgs = [f for f in listdir(path) if isfile(join(path, f))]
    numss = 0  # running index used to number this species' output files
    for img_name in imgs:
        # `path` already is src_dir/basename, so join onto it directly.
        path_img = os.path.join(path, img_name)
        # cv2.imread signals an unreadable file by returning None rather than
        # raising, so test for that instead of using a bare `except`.
        img = cv2.imread(path_img)
        if img is None:
            print("Error: ", path_img)
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        masks = segm_class.simple_inference(img, output='mask')
        for mask in masks:
            numss += 1
            mask_path = os.path.join(new_spieces_dir_path, "{}_img_{}.png".format(basename, numss))
            # cv2.imwrite returns False when the file could not be written.
            if not cv2.imwrite(mask_path, cv2.cvtColor(mask, cv2.COLOR_RGB2BGR)):
                print("Error: ", mask_path)
    print("Total: ", numss)

In [13]:
# Literal mapping of integer index (0-60) -> {'num': <string>, 'label': <species name>}.
# NOTE(review): looks like a pasted copy of the dataset's class library
# (cf. `data_set_train.library_name`, which is indexed the same way) - confirm.
# 'num' is stored as a string; its meaning is not visible in this notebook.
test = {0: {'num': '336', 'label': 'Pomatomus saltatrix'},
 1: {'num': '157', 'label': 'Acanthocybium solandri'},
 2: {'num': '94', 'label': 'Carassius auratus'},
 3: {'num': '210', 'label': 'Caranx hippos'},
 4: {'num': '604', 'label': 'Thunnus atlanticus'},
 5: {'num': '436', 'label': 'Perca flavescens'},
 6: {'num': '20', 'label': 'Salvelinus fontinalis'},
 7: {'num': '18', 'label': 'Oncorhynchus mykiss'},
 8: {'num': '383', 'label': 'Esox lucius'},
 9: {'num': '11', 'label': 'Pterois volitans'},
 10: {'num': '10', 'label': 'Sphyraena barracuda'},
 11: {'num': '529', 'label': 'Lepomis macrochirus'},
 12: {'num': '703', 'label': 'Pogonias cromis'},
 13: {'num': '248', 'label': 'Sciaenops ocellatus'},
 14: {'num': '416', 'label': 'Lepomis gulosus'},
 15: {'num': '684', 'label': 'Esox masquinongy'},
 16: {'num': '23', 'label': 'Salmo trutta'},
 17: {'num': '15', 'label': 'Scomberomorus cavalla'},
 18: {'num': '13', 'label': 'Micropterus salmoides'},
 19: {'num': '5', 'label': 'Coryphaena hippurus'},
 20: {'num': '22', 'label': 'Micropterus dolomieu'},
 21: {'num': '91', 'label': 'Balistes capriscus'},
 22: {'num': '16', 'label': 'Thunnus albacares'},
 23: {'num': '247', 'label': 'Megalops atlanticus'},
 24: {'num': '235', 'label': 'Trachinotus falcatus'},
 25: {'num': '14', 'label': 'Morone saxatilis'},
 26: {'num': '237', 'label': 'Seriola dumerili'},
 27: {'num': '356', 'label': 'Caranx crysos'},
 28: {'num': '658', 'label': 'Lutjanus vivanus'},
 29: {'num': '234', 'label': 'Scomberomorus maculatus'},
 30: {'num': '162', 'label': 'Rachycentron canadum'},
 31: {'num': '29', 'label': 'Amphiprion percula'},
 32: {'num': '252', 'label': 'Haemulon sciurus'},
 33: {'num': '347', 'label': 'Lutjanus synagris'},
 34: {'num': '129', 'label': 'Lutjanus campechanus'},
 35: {'num': '24', 'label': 'Istiophorus albicans'},
 36: {'num': '388', 'label': 'Cynoscion nebulosus'},
 37: {'num': '217', 'label': 'Elops saurus'},
 38: {'num': '12', 'label': 'Carcharias taurus'},
 39: {'num': '394', 'label': 'Lutjanus griseus'},
 40: {'num': '159', 'label': 'Cyprinus carpio'},
 41: {'num': '483', 'label': 'Selene vomer'},
 42: {'num': '192', 'label': 'Centropristis striata'},
 43: {'num': '725', 'label': 'Caranx ruber'},
 44: {'num': '221', 'label': 'Epinephelus morio'},
 45: {'num': '245', 'label': 'Amphiprion ocellaris'},
 46: {'num': '17', 'label': 'Carcharodon carcharias'},
 47: {'num': '449', 'label': 'Chaetodipterus faber'},
 48: {'num': '481', 'label': 'Mycteroperca microlepis'},
 49: {'num': '676', 'label': 'Lagodon rhomboides'},
 50: {'num': '712', 'label': 'Archosargus probatocephalus'},
 51: {'num': '230', 'label': 'Lobotes surinamensis'},
 52: {'num': '276', 'label': 'Xiphias gladius'},
 53: {'num': '100', 'label': 'Pomoxis nigromaculatus'},
 54: {'num': '115', 'label': 'Sander vitreus'},
 55: {'num': '21', 'label': 'Rhincodon typus'},
 56: {'num': '696', 'label': 'Oncorhynchus kisutch'},
 57: {'num': '142', 'label': 'Katsuwonus pelamis'},
 58: {'num': '251', 'label': 'Euthynnus alletteratus'},
 59: {'num': '2', 'label': 'Lutjanus analis'},
 60: {'num': '51', 'label': 'Cyprinus rubrofuscus'}}

In [14]:
# Display the entry stored under index 0 as a quick sanity check.
test[0]

{'num': '336', 'label': 'Pomatomus saltatrix'}