In [None]:
# ! conda remove torchvision pytorch torchaudio -y

In [None]:
# ! conda install pytorch=2.0.1 torchvision=0.15.2 torchaudio pytorch::torchvision -c pytorch -y

In [None]:
import torch, torchvision
torchvision.disable_beta_transforms_warning()
import sys
import time
import os
import PIL
import pickle
import importlib
import numpy as np

In [None]:
# Project-local modules that save() interns and force-saves into the exported package.
MODULES = [
    'datasets',
    'traininglib',
]


from modellib import DuckDetector

def save(self, destination):
    """Export a model and the project code it needs as a ``torch.package`` archive.

    Args:
        self: The model object to pickle into the archive. It must expose a
            ``class_list`` of strings, which is also stored as a text resource.
        destination: Where to write the archive. A ``str`` is first expanded with
            ``time.strftime`` (so date format codes are filled in) and gets a
            ``.pt.zip`` suffix appended if it does not already end with one.
            Any other value is passed to ``PackageExporter`` unchanged.

    Returns:
        The destination that was written to (after expansion, for strings).
    """
    from torch import package

    if isinstance(destination, str):
        destination = time.strftime(destination)
        has_suffix = destination.endswith('.pt.zip')
        if not has_suffix:
            destination = destination + '.pt.zip'

    with package.PackageExporter(destination) as exporter:
        # Last dotted component of this module's name, followed by the project-local modules.
        interned_modules = [__name__.split('.')[-1]] + MODULES
        # torchvision sub-packages that stay external instead of being copied into the archive.
        torchvision_externs = ['torchvision.ops.**', 'torchvision.datasets.**', 'torchvision.io.**']

        # NOTE(review): keep the intern/extern calls in exactly this order --
        # torch.package resolves a module against the patterns in the order they
        # were registered (see the torch.package docs), so reordering may change
        # what ends up inside the archive.
        exporter.intern(interned_modules)
        exporter.extern('**', exclude=['torchvision.**'])
        exporter.intern('torchvision.**', exclude=torchvision_externs)
        exporter.extern(torchvision_externs)
        exporter.intern('torchvision.models.detection.**')

        # force inclusion of internal modules + re-save if importlib.reload'ed
        for module_name in MODULES + ['modellib']:
            exporter.save_module(module_name)

        exporter.save_pickle('model', 'model.pkl', self)
        class_list_text = '\n'.join(self.class_list)
        exporter.save_text('model', 'class_list.txt', class_list_text)
    return destination

In [None]:
# Label ids used during training, for reference:
# training_label_dict = {4.0: 'MALL', 1.0: 'AMCO', 3.0: 'GWTE', 6.0: 'NSHO', 2.0: 'GADW', 8.0: 'RNDU', 5.0: 'NOPI', 7.0: 'REDH'}

# Build a detector and export it. The class order must match the label_dict
# from training, so entry k below corresponds to training label id k.
basemodel_pt_zip = save(
    self=DuckDetector(
        classes_of_interest=[
            'Fulica americana',    # 1: AMCO
            'Mareca strepera',     # 2: GADW
            'Anas carolinensis',   # 3: GWTE
            'Anas platyrhynchos',  # 4: MALL
            'Anas acuta',          # 5: NOPI
            'Spatula clypeata',    # 6: NSHO
            'Aythya americana',    # 7: REDH
            'Aythya collaris',     # 8: RNDU
        ]
    ),
    destination="basemodel.pt.zip",
)

In [None]:
# Re-open the archive path returned by save() and print the packaged file tree
# to check which modules and resources were actually bundled.
imp = torch.package.PackageImporter(basemodel_pt_zip)
print(imp.file_structure())

### <center>Test that packaged model opens</center>

In [None]:
def load_modelfile(file_path:str) -> "torch.nn.Module":
    """Load a saved model from ``file_path``.

    Two formats are recognised by file extension:

    * ``.pt.zip`` -- a ``torch.package`` archive; the object stored as
      resource ``model.pkl`` in package ``model`` is unpickled onto the CPU.
    * ``.pkl`` -- a plain pickle file.

    Args:
        file_path: Path of the model file to open.

    Returns:
        The unpickled model object.

    Raises:
        ValueError: If ``file_path`` has neither of the supported extensions.
    """
    # SECURITY NOTE: both branches unpickle data, which can execute arbitrary
    # code. Only load model files that come from a trusted source.
    if file_path.endswith('.pt.zip'):
        importer = torch.package.PackageImporter(file_path)
        return importer.load_pickle('model', 'model.pkl', map_location='cpu')
    if file_path.endswith('.pkl'):
        import pickle
        # Context manager so the handle is closed even if unpickling fails.
        with open(file_path, 'rb') as model_file:
            return pickle.load(model_file)
    # Fail loudly instead of silently handing back None.
    raise ValueError(
        f"Unsupported model file {file_path!r}: expected a '.pt.zip' or '.pkl' file"
    )

In [None]:
# Load the packaged model back from disk to confirm the exported archive opens.
file_path = 'basemodel.pt.zip'
model = load_modelfile(file_path=file_path)


In [None]:
# Sanity check: show the string representation of the object that was loaded.
print(model)

### <center> Run model on sample images </center>

In [None]:
# Class names carried by the loaded model; the plotting cells below key their labels off this list.
print(model.class_list)

In [None]:
# Print numbers in plain fixed-point form (no scientific notation) with at most
# 4 decimal places, so the prediction output is easy to read.
np.set_printoptions(suppress=True, precision=4)

# Run the loaded model on one sample image and show the raw result.
prediction = model.process_image('C:/Users/zack/Desktop/DuckNet_Data_Test/Images/DJI_20221216101519_0005_Z.JPG')
print(prediction)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Map 1-based label ids to class names, following the order of model.class_list.
label_dict = dict(enumerate(model.class_list, start=1))

# Palette of visually distinct hex colors.
distinct_colors = [
    '#f032e6', '#ffffff', '#ffe119', '#3cb44b', '#42d4f4',
    '#f58231', '#e6194B', '#dcbeff', '#469990', '#4363d8',
]

# One palette color per label id (assigned in label_dict order), used to
# color-code bounding boxes by class.
label_color_map = {
    label_id: distinct_colors[position]
    for position, label_id in enumerate(label_dict)
}

# function for reshaping boxes 
def get_box(boxes):
    """Return ``boxes`` as a float NumPy array with four columns.

    Args:
        boxes: Anything ``np.array`` accepts whose element count is a multiple
            of 4, e.g. one ``[x1, y1, x2, y2]`` box or a sequence of them.

    Returns:
        A float array reshaped to ``(n, 4)``. A single box is returned as
        shape ``(1, 4)`` so callers can always index it as ``box[:, k]``.
    """
    as_rows = np.array(boxes).astype('float').reshape(-1, 4)
    # A lone row is returned untouched so squeeze never collapses it to 1-D.
    return as_rows if as_rows.shape[0] == 1 else np.squeeze(as_rows)


# function for plotting image
def img_show(image, ax = None, figsize = (6, 9)):
    """Draw ``image`` on a matplotlib axes and return that axes.

    Args:
        image: The image to display; passed straight to ``ax.imshow``.
        ax: Axes to draw on. When ``None``, a new figure and axes are created.
        figsize: Figure size used only when a new figure is created.

    Returns:
        The axes the image was drawn on, with its x-axis ticks moved to the top.
    """
    target_ax = ax
    if target_ax is None:
        # No axes supplied: open a fresh figure of the requested size.
        _, target_ax = plt.subplots(figsize=figsize)
    target_ax.xaxis.tick_top()
    target_ax.imshow(image)
    return target_ax

def plot_bbox_predicted(ax, boxes, labels, scores): # modify plot_bbox to add confidence scores
    """Draw one predicted bounding box with its class name and confidence score.

    Args:
        ax: Matplotlib axes to draw on.
        boxes: One box as ``[x_min, y_min, x_max, y_max]``, typically the
            ``(1, 4)`` array returned by ``get_box``. If several rows are
            given, only the first is drawn, since a single label and score
            are supplied.
        labels: Single-element tensor/array holding the class id; read via ``.item()``.
        scores: Single-element tensor/array holding the confidence; read via ``.item()``.

    Returns:
        The same ``ax``, with the rectangle and caption added.
    """
    # Read the scalar values once instead of calling .item() for every use.
    label_id = labels.item()
    score = scores.item()

    # Hand matplotlib plain floats: passing size-1 arrays relies on NumPy's
    # deprecated implicit array-to-scalar conversion.
    x_min, y_min, x_max, y_max = (
        float(value) for value in np.asarray(boxes, dtype=float).reshape(-1, 4)[0]
    )

    # Unknown class ids fall back to a black box captioned with the raw id,
    # instead of raising KeyError on the name lookup.
    is_known_label = label_id in label_color_map
    box_color = label_color_map[label_id] if is_known_label else 'black'
    label_name = label_dict.get(label_id, str(label_id))
    # Black text would be unreadable on the black fallback background.
    text_color = 'black' if is_known_label else 'white'

    # add box to the image, color-coded by bounding box class
    ax.add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                    fill = False,
                    color = box_color,
                    linewidth = 1.5))

    # add "<class name> <score>" caption, offset 100 data units from the box's (x_min, y_min) corner
    ax.text(x_min, y_min - 100,
        s = f"{label_name} {score:.2f}",
        color = text_color,
        fontsize = 6,
        verticalalignment = 'top',
        bbox = {'color': box_color, 'pad': 0})
    return ax


# function for plotting all predictions on images
def plot_predictions(image, boxes, labels, scores, ax = None):
    """Show ``image`` and overlay every predicted box with its label and score.

    Args:
        image: Image to display (forwarded to ``img_show``).
        boxes: Sequence of boxes, one per detection.
        labels: Per-detection class ids, indexed in parallel with ``boxes``.
        scores: Per-detection confidence scores, indexed in parallel with ``boxes``.
        ax: Optional axes to draw on; a new figure is created when ``None``.
    """
    canvas = img_show(image, ax = ax)
    for idx, raw_box in enumerate(boxes):
        # Normalize each box to a (1, 4) float array before drawing it.
        plot_bbox_predicted(canvas, get_box(raw_box), labels[idx], scores[idx])

In [None]:
# Load the same sample image that was passed to process_image and overlay the predictions on it.
image = model.load_image('C:/Users/zack/Desktop/DuckNet_Data_Test/Images/DJI_20221216101519_0005_Z.JPG')

plot_predictions(
    image,
    boxes=prediction['boxes'],
    labels=prediction['labels'],
    scores=prediction['scores'],
)

In [None]:
# Start detector training on the local test images and their annotation files.
# NOTE(review): both paths are specific to one machine, and the training
# behaviour itself is defined outside this notebook (presumably in modellib) -- confirm
# the expected argument format there before pointing this at other data.
model.start_training_detector(imagefiles_train='C:/Users/zack/Desktop/DuckNet_Data_Test/Images/',
                              jsonfiles_train='C:/Users/zack/Desktop/DuckNet_Data_Test/Annotations/',
                              )


In [None]:
# import datasets module

import datasets

In [None]:
# boxes = datasets.get_boxes_from_jsonfile('C:/Users/zack/Desktop/DuckNet_Data_Test/Annotations/DJI_20211215103949_0003_Z.json')
# labels = datasets.get_labels_from_jsonfile('C:/Users/zack/Desktop/DuckNet_Data_Test/Annotations/DJI_20211215103949_0003_Z.json')

In [None]:
# plot_predictions(image, boxes, labels)

In [None]:
# model = torchvision.models.detection.ssd300_vgg16()
# model.load_state_dict(torch.load('C:/Users/zack/Documents/GitHub/SSD_VGG_PyTorch/ssd300_vgg16_gradientAccumulation_noHen.pth', map_location=torch.device('cpu')))

In [None]:
# model

In [None]:
# label_dict = {4.0: 'MALL', 1.0: 'AMCO', 3.0: 'GWTE', 6.0: 'NSHO', 2.0: 'GADW', 8.0: 'RNDU', 5.0: 'NOPI', 7.0: 'REDH'}

# # distinct colors 
# distinct_colors = ['#f032e6', '#ffffff', '#ffe119', '#3cb44b', '#42d4f4',
#                     '#f58231', '#e6194B', '#dcbeff', '#469990', '#4363d8']

# # label color map for plotting color-coded boxes by class
# label_color_map = {k: distinct_colors[i] for i, k in enumerate(label_dict.keys())}

# # function for reshaping boxes 
# def get_box(boxes):
#     boxes = np.array(boxes)
#     boxes = boxes.astype('float').reshape(-1, 4)
#     if boxes.shape[0] == 1 : return boxes
#     return np.squeeze(boxes)


# # function for plotting image
# def img_show(image, ax = None, figsize = (6, 9)):
#     if ax is None:
#         fig, ax = plt.subplots(figsize = figsize)
#     ax.xaxis.tick_top()
#     ax.imshow(image)
#     return ax

# def plot_bbox_predicted(ax, boxes, labels, scores): # modify plot_bbox to add confidence scores
#     # add box to the image and use label_color_map to color-code by bounding box class if exists else 'black'
#     ax.add_patch(plt.Rectangle((boxes[:, 0], boxes[:, 1]), boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1],
#                     fill = False,
#                     color = label_color_map[labels.item()] if labels.item() in label_color_map else 'black', 
#                     linewidth = 1.5))
    
#     # add label and score to the bounding box. concatenate label and score to one string. 
#     # use label_dict to replace class numbers with class names
#     ax.text(boxes[:, 0], boxes[:, 1] - 100,
#         s = f"{label_dict[labels.item()]} {scores.item():.2f}",
#         color = 'black',
#         fontsize = 6,
#         verticalalignment = 'top',
#         bbox = {'color': label_color_map[labels.item()] if labels.item() in label_color_map else 'black', 'pad': 0})
#     return ax


# # function for plotting all predictions on images
# def plot_predictions(image, boxes, labels, scores, ax = None):
#     ax = img_show(image, ax = ax)
#     for i in range(len(boxes)):
#         box = get_box(boxes[i])
#         plot_bbox_predicted(ax, box, labels[i], scores[i])

In [None]:
# import torchvision.transforms.v2 as T

# def run_model(model, image_path):
#     # set model to evaluation
#     model.eval()

#     image = PIL.Image.open(image_path)
    
#     width, height = image.size

#     # convert image to tensor
#     image_tensor = T.ToImageTensor()(image)

#     # # add batch dimension
#     image_tensor = image_tensor.unsqueeze(0)

#     image_tensor = T.ConvertImageDtype(torch.float32)(image_tensor)

#     # resize to 300x300
#     image_tensor = T.Resize((300, 300), antialias=True)(image_tensor)

#     # normalize image
#     image_tensor = T.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])(image_tensor)
    
#     # run model
#     with torch.no_grad():
#         output = model(image_tensor)

#     # get boxes, labels, and scores
#     boxes = output[0]['boxes']
#     labels = output[0]['labels']
#     scores = output[0]['scores']

#     # filter out boxes with scores less than 0.5
#     boxes = boxes[scores > 0.5]
#     labels = labels[scores > 0.5]
#     scores = scores[scores > 0.5]

#     # rescale boxes to original image size
#     boxes[:, 0] *= width / 300
#     boxes[:, 1] *= height / 300
#     boxes[:, 2] *= width / 300
#     boxes[:, 3] *= height / 300

#     # plot predictions
#     plot_predictions(image, boxes, labels, scores)
#     plt.show()
#     return output

In [None]:
# image_path = 'C:/Users/zack/Desktop/DuckNet_Data_Test/Images/DJI_20211215103949_0003_Z.JPG'
# output = run_model(model, image_path)

In [None]:
# output

In [None]:
# print(type(output))
# print(type(output[0]))
# print(type(output[0]['boxes']))
# print(type(output[0]['labels']))
# print(type(output[0]['scores']))