In [10]:
import sys
sys.path.insert(1, '/home/fishial/Fishial/Object-Detection-Model')
import yaml
import torch
import os
import cv2
import random
import argparse

import numpy as np
from pathlib import Path
import torchvision.models as models
import matplotlib.pyplot as plt
from module.classification_package.src.utils import read_json, save_json
from PIL import Image
from torchvision import transforms
# Change the hard-coded paths to match your own directories
from torch import nn
import fiftyone as fo
import fiftyone.zoo as foz

from module.classification_package.src.dataset import FishialDatasetFoOnlineCuting
from module.classification_package.src.utils import get_data_config, update_internal_id

from module.classification_package.src.model import init_model
from module.segmentation_package.src.utils import get_mask

def get_config(path):
    """Parse the YAML file at ``path`` and return its contents.

    Uses ``yaml.safe_load``. If the file is not valid YAML, the parser error
    is printed and ``None`` is returned instead of raising.
    """
    with open(path, "r") as config_file:
        try:
            parsed = yaml.safe_load(config_file)
        except yaml.YAMLError as err:
            print(err)
            parsed = None
    return parsed

In [11]:
class EmbeddingModel(nn.Module):
    """A backbone network followed by a linear projection to an embedding.

    Args:
        backbone: feature extractor; its output must have ``last_layer`` features.
        last_layer: number of features produced by ``backbone``.
        emb_dim: size of the embedding returned by ``forward``.
    """

    def __init__(self, backbone: nn.Module, last_layer=512, emb_dim=128):
        super().__init__()
        self.backbone = backbone
        self.embeddings = nn.Linear(in_features=last_layer, out_features=emb_dim)
        # Not used by forward(); kept so the module's attributes stay unchanged.
        self.softmax = nn.Softmax()

    def forward(self, x: torch.Tensor):
        """Return ``embeddings(backbone(x))``."""
        features = self.backbone(x)
        return self.embeddings(features)

In [12]:
# Preprocessing applied to each masked image before it is fed to the model:
# resize to 224x224 with bilinear interpolation, convert to a tensor, then
# normalise with the standard ImageNet per-channel mean/std values.
loader = transforms.Compose([
    transforms.Resize(size=(224, 224), interpolation=Image.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


In [13]:
# Trained run directory and checkpoint file (hard-coded; adjust to your setup).
absolute_path = '/home/fishial/Fishial/output/classification/resnet_18_184_train_06_12'
model_name = 'model_184.ckpt'
device = 'cpu'

# ResNet-18 with its final `fc` layer replaced by an identity, so the features
# that would have gone into `fc` are fed to EmbeddingModel's projection
# (512 inputs -> 256-dim embedding).
resnet18 = models.resnet18()
resnet18.fc = nn.Identity()
model = EmbeddingModel(resnet18, last_layer=512, emb_dim=256)

# NOTE: torch.load unpickles the checkpoint — only load files you trust.
checkpoint_path = os.path.join(absolute_path, model_name)
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device(device)))
model.eval()
model.to(device)


EmbeddingModel(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, t

In [14]:
# Load the persisted FiftyOne dataset and restrict it to a view of the
# samples matching the 'val' / 'train' tags.
fo_dataset = fo.load_dataset("fish-classification-184").match_tags(['val', 'train'])

In [15]:
# Padding vector for classes with fewer samples (consumed by the next cell).
# Every entry is 100.0 (no randomness is involved despite the name). Its length
# must equal the embedding size, so it is derived from the model rather than
# hard-coded.
list_numbers = [100] * model.embeddings.out_features
random_numbers = torch.Tensor(list_numbers)


def _embed_polygon(img_path, polyline):
    """Embed the region of an image selected by a pixel-space polygon.

    Args:
        img_path: path of the image file, read with OpenCV.
        polyline: list of [x, y] pixel coordinates passed to ``get_mask``.

    Returns:
        The first (and only) row of ``model``'s output for the masked image.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    image = cv2.imread(img_path)
    if image is None:  # cv2.imread signals failure by returning None, not raising
        raise FileNotFoundError(f"Could not read image: {img_path}")
    # NOTE(review): cv2.imread yields BGR channel order while PIL and the
    # ImageNet normalisation assume RGB. Left unchanged so preprocessing keeps
    # matching whatever the model was trained with — confirm against training code.
    masked = get_mask(image, np.array(polyline))
    batch = loader(Image.fromarray(masked)).unsqueeze(0)
    # Inference only: no_grad avoids building an autograd graph for every sample.
    with torch.no_grad():
        output = model(batch)
    return output[0]


data_set_ids = {}      # internal_id -> {'image_id': [...], 'annotation_id': [...], 'drawn_fish_id': [...]}
embedding_tensor = []  # embedding_tensor[int(internal_id)] -> list of embeddings for that label
data_labels = {}       # label -> internal_id (str of first-seen order index)
# Hoisted out of the loop: len() of a FiftyOne view may run a count query on every call.
total_samples = len(fo_dataset)
for sample_id, sample in enumerate(fo_dataset):
    print(f"Left: {sample_id}/{total_samples}", end='\r')
    img_path = sample['filepath']
    label = sample['polyline']['label']
    image_id, annotation_id, drawn_fish_id = sample['image_id'], sample['annotation_id'], sample['drawn_fish_id']
    width, height = sample['width'], sample['height']

    # Points are scaled by width/height here, i.e. stored relative to image size.
    # Only the first polygon of the polyline is used.
    polyline = [[int(point[0] * width), int(point[1] * height)]
                for point in sample['polyline']['points'][0]]

    if label not in data_labels:
        # First time this label is seen: allocate the next internal id.
        internal_id = str(len(data_labels))
        data_labels[label] = internal_id
        data_set_ids[internal_id] = {
            'image_id': [],
            'annotation_id': [],
            'drawn_fish_id': [],
        }
        embedding_tensor.append([])
    internal_id = data_labels[label]

    embedding_tensor[int(internal_id)].append(_embed_polygon(img_path, polyline))

    ids = data_set_ids[internal_id]
    ids['image_id'].append(image_id)
    ids['annotation_id'].append(annotation_id)
    ids['drawn_fish_id'].append(drawn_fish_id)

Left: 30309/30310

In [16]:
# Pad every per-label list up to the length of the largest one with copies of
# `random_numbers`, so all lists can be stacked into a single tensor.
max_val = max(len(i) for i in embedding_tensor)
for class_embeddings in embedding_tensor:
    class_embeddings.extend([random_numbers] * (max_val - len(class_embeddings)))

In [17]:

# Stack the per-label lists into one tensor. The previous cell is expected to
# have padded every list to equal length, giving roughly
# (num_labels, max_samples_per_label, embedding_dim). Save it next to the checkpoint.
data_set = torch.stack([torch.stack(class_embeddings) for class_embeddings in embedding_tensor])
torch.save(data_set, os.path.join(absolute_path, 'embeddings.pt'))

# Save the inverse mapping (internal id -> label) under a new name. Previously
# `data_labels` itself was overwritten, so re-running this cell inverted the
# mapping again and wrote a label->id map into labels.json.
labels_by_id = {internal_id: label for label, internal_id in data_labels.items()}
save_json(labels_by_id, os.path.join(absolute_path, 'labels.json'))
save_json(data_set_ids, os.path.join(absolute_path, 'idx.json'))

In [9]:
# Start the FiftyOne App. No dataset is passed, so the session opens without one
# loaded (the captured output shows "Dataset: -"). NOTE(review): pass e.g.
# `fo_dataset` if the intent is to browse the samples embedded above.
fo.launch_app()

Dataset:     -
Session URL: http://localhost:5151/