In [None]:
!pip install datasets

In [None]:
from datasets import load_dataset

# Load the `mllab/alfafood` dataset; the later cells read its 'train' split.
dataset = load_dataset(path="mllab/alfafood")

In [None]:
import os
import cv2
import random
import pathlib
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from typing import Tuple, Dict, List

import torch
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

from sklearn.model_selection import train_test_split

import albumentations
from albumentations.pytorch.transforms import ToTensorV2

from PIL import Image, ImageFile, ImageFont, ImageDraw, ImageEnhance
ImageFile.LOAD_TRUNCATED_IMAGES = True

import copy
from time import time

import warnings
warnings.filterwarnings('ignore')

# Target image size as (width, height): it is passed to PIL's `resize` and used
# as the destination size when the bounding boxes are rescaled.
FUSED_SHAPE = (640, 480)
# NOTE(review): not referenced in the visible cells - boxes are rescaled from
# each image's own `.size`. Presumably the nominal (width, height) of the
# source photos; confirm before relying on it.
ORIGINAL_SHAPE = (4000, 3000)

# Prefer the GPU when one is available.
# NOTE(review): the visible cells load the weights onto the CPU and never move
# the model or the inputs to `device` - confirm whether that is intended.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def fused_bbox(bboxes, original_shape, fused_shape):
    """Rescale bounding boxes from the original image size to the resized one.

    The first and third components of every box are multiplied by the
    horizontal ratio ``fused_shape[0] / original_shape[0]``, the second and
    fourth by the vertical ratio ``fused_shape[1] / original_shape[1]``.
    Because both the position and the extent are scaled, the function is valid
    for ``[x, y, width, height]`` as well as ``[x_min, y_min, x_max, y_max]``
    boxes. Components after the fourth are copied through unchanged.

    Args:
        bboxes: Sequence of boxes; each box is a sequence with at least four
            numeric components. Lists and tuples are both accepted.
        original_shape: ``(width, height)`` of the image the boxes refer to.
        fused_shape: ``(width, height)`` of the resized image.

    Returns:
        A new list of lists holding the rescaled boxes. The input is left
        untouched, so calling the function twice on the same annotations can
        no longer scale them twice.

    Raises:
        ZeroDivisionError: If a component of ``original_shape`` is zero.
    """
    x_fused = fused_shape[0] / original_shape[0]
    y_fused = fused_shape[1] / original_shape[1]
    # Box components alternate between the horizontal and the vertical axis.
    scales = (x_fused, y_fused, x_fused, y_fused)

    rescaled = []
    for bbox in bboxes:
        scaled = list(bbox)  # copy: never mutate the caller's annotation
        for position, scale in enumerate(scales):
            scaled[position] = scaled[position] * scale
        rescaled.append(scaled)

    return rescaled

In [None]:
# Build the in-memory sample lists in a single pass over the split. Every
# `dataset['train'][i]` access loads the row (image included) again, so each
# row is read exactly once here instead of three times.
images, objects = [], []
for sample in dataset['train']:
    original_image = sample['image']
    annotation = sample['objects']

    # Boxes are annotated against the original resolution, so they are scaled
    # by the same ratio the image is about to be resized with.
    if annotation['bbox']:
        annotation['bbox'] = fused_bbox(annotation['bbox'], original_image.size, FUSED_SHAPE)

    images.append(original_image.resize(FUSED_SHAPE))
    objects.append(annotation)

In [None]:
# Evaluation-time pipeline: only converts the image to a tensor. No Resize step
# is included because the images are already resized to FUSED_SHAPE when they
# are loaded. Boxes are declared as pascal_voc and travel together with the
# `labels` field.
test_transform = albumentations.Compose(
    transforms=[ToTensorV2()],
    bbox_params=albumentations.BboxParams(
        format='pascal_voc',
        label_fields=['labels'],
    ),
)

In [None]:
class AlfaFoodDataset(Dataset):
    """Detection dataset over pre-loaded images and their annotations.

    Each annotation is a dict with a ``'bbox'`` list and a parallel
    ``'categories'`` list. The incoming boxes are treated as
    ``[x, y, width, height]`` and converted once, on a private deep copy, to
    ``[x_min, y_min, x_max, y_max]``.
    """

    def __init__(self, images: List, objects: List[Dict[str, List]], transform=None) -> None:
        """Store the samples and convert the boxes to corner format.

        Args:
            images: Objects exposing ``convert('RGB')`` (e.g. PIL images).
            objects: One annotation dict per image with ``'bbox'`` and
                ``'categories'`` entries.
            transform: Optional callable invoked as
                ``transform(image=..., bboxes=..., labels=...)`` that returns a
                mapping with the same three keys. Its image is expected to be
                channel-first (e.g. a pipeline ending in ``ToTensorV2``).
        """
        super().__init__()
        self.images = images
        # Deep copy so the conversion below never touches the caller's data.
        self.annotations = copy.deepcopy(objects)
        self.transform = transform
        # Number of distinct category ids that occur in the annotations.
        self.num_classes = len({category for ob in objects for category in ob['categories']})

        # [x, y, width, height] -> [x_min, y_min, x_max, y_max].
        # A local loop variable is used on purpose: nothing is left behind on
        # the instance once the conversion is done.
        for annotation in self.annotations:
            for bbox in annotation['bbox']:
                bbox[2] += bbox[0]
                bbox[3] += bbox[1]

    def __getitem__(self, index: int) -> Tuple[np.ndarray, Dict[str, torch.Tensor]]:
        """Return the sample at ``index`` as ``(image, target)``.

        Returns:
            image: Channel-first array of the RGB image.
            target: Dict with ``'boxes'`` (float tensor, one
                ``[x_min, y_min, x_max, y_max]`` row per object) and
                ``'labels'`` (int64 tensor). A sample without objects gets a
                single degenerate box and the label ``0`` instead of empty
                tensors.
        """
        image = np.array(self.images[index].convert('RGB'))
        bboxes = self.annotations[index]['bbox']
        labels = self.annotations[index]['categories']

        if self.transform:
            transformed = self.transform(image=image, bboxes=bboxes, labels=labels)
            # The pipeline already yields a channel-first image; use it as is.
            image = np.array(transformed['image'])
            bboxes = transformed['bboxes']
            labels = transformed['labels']
        else:
            # (H, W, C) -> (C, H, W)
            image = image.transpose(2, 0, 1)

        target = {
            'boxes': torch.as_tensor(bboxes, dtype=torch.float),
            'labels': torch.as_tensor(labels, dtype=torch.int64),
        }

        # numel() also catches empty results shaped (0, 4), not only shape (0,),
        # so boxes and labels always receive their placeholder together.
        if target['boxes'].numel() == 0:
            target['boxes'] = torch.tensor([[0, 0, 1e-10, 1e-10]], dtype=torch.float)
        if target['labels'].numel() == 0:
            target['labels'] = torch.zeros(size=(1, ), dtype=torch.int64)
        return image, target

    def __len__(self) -> int:
        """Return the total number of samples."""
        return len(self.images)

In [None]:
# No transform is passed: the images were already resized and the boxes
# rescaled when the `images` / `objects` lists were built.
data = AlfaFoodDataset(images=images, objects=objects, transform=None)

In [None]:
class FasterRCNN_ResNet50(torch.nn.Module):
    """Wrapper around torchvision's Faster R-CNN (ResNet-50 FPN) detector.

    The box predictor is replaced by a fresh ``FastRCNNPredictor`` and the
    parameters of every child module except the last one are frozen.
    """

    def __init__(self, num_classes: int=127) -> None:
        """Build the detector with a head of ``num_classes + 2`` outputs."""
        super().__init__()

        detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True, pretrained_backbone=True)

        # NOTE(review): the head gets num_classes + 2 outputs. Presumably one
        # extra slot is the background class; the reason for the second is not
        # visible here. The size has to match the checkpoint that is loaded
        # into this model, so do not change it without retraining.
        old_head = detector.roi_heads.box_predictor
        detector.roi_heads.box_predictor = FastRCNNPredictor(old_head.cls_score.in_features, num_classes + 2)

        # Freeze all children but the last one (presumably `roi_heads`, which
        # holds the new predictor - confirm against torchvision).
        *frozen_children, _trainable = list(detector.children())
        for frozen in frozen_children:
            for weight in frozen.parameters():
                weight.requires_grad = False

        self.model = detector

    def predict(self, X: torch.Tensor) -> torch.Tensor:
        """Run the wrapped detector on ``X`` without targets.

        NOTE(review): in the visible cells the result is indexed as
        ``outputs[0]['boxes']``, i.e. a sequence of dicts rather than a single
        tensor as the annotation suggests.
        """
        predictions = self.model(X)
        return predictions

    def forward(self, images: List[torch.Tensor], annotation: List[Dict[str, torch.Tensor]]) -> Dict[str, int]:
        """Pass images together with their targets to the wrapped detector.

        Intended for computing the loss; the result is whatever the detector
        returns for this call.
        """
        result = self.model(images, annotation)
        return result

# The head is sized from the number of distinct category ids found in the
# annotations (`data.num_classes`).
model = FasterRCNN_ResNet50(num_classes=data.num_classes)

In [None]:
import pathlib

# NOTE(review): hard-coded absolute path (looks like a Colab upload location);
# the checkpoint has to be placed there by hand before this cell is run.
path_to_weights_model = pathlib.Path("/content/best_model (2).pth")

# Tensors are mapped to the CPU while loading.
# NOTE(review): torch.load unpickles the file - only load checkpoints that come
# from a trusted source.
model.load_state_dict(torch.load(path_to_weights_model, map_location=torch.device('cpu')))

In [None]:
import random

def show_image_with_objects(image, bboxes, labels=None):
    """Draw bounding boxes, and optionally their labels, on an image.

    Args:
        image: Channel-first array ``(C, H, W)``; it is transposed to
            ``(H, W, C)`` before being handed to ``Image.fromarray``.
            Presumably uint8 RGB - confirm against the caller.
        bboxes: Sequence of ``[x_min, y_min, x_max, y_max]`` boxes. Tensors,
            arrays and nested lists are all accepted; any number of boxes may
            be passed.
        labels: Optional sequence parallel to ``bboxes``. When ``None`` only
            the rectangles are drawn.

    Returns:
        The ``PIL.Image.Image`` with the boxes drawn on it.
    """
    image = Image.fromarray(image.transpose(1, 2, 0))
    # One drawing context is enough for the whole image.
    draw = ImageDraw.Draw(image)

    for index, bbox in enumerate(bboxes):
        # Plain floats work for tensors (on any device, with or without grad),
        # NumPy rows and Python lists alike.
        x_min, y_min, x_max, y_max = (float(value) for value in bbox[:4])

        # A fresh random colour per box, so there is no fixed palette size
        # that could be exceeded.
        color = (random.randint(40, 240), random.randint(40, 255), random.randint(60, 255))
        draw.rectangle((x_min, y_min, x_max, y_max), outline=color, width=2)

        if labels is None:
            continue

        caption = f"{labels[index]}"
        # Dark backdrop slightly larger than the text keeps the caption legible.
        text_box = draw.textbbox((x_min, y_min), caption)
        draw.rectangle((text_box[0]-2, text_box[1]-2, text_box[2]+2, text_box[3]+2), fill=(30, 20, 20))
        draw.text((x_min, y_min), caption, fill=color)

    return image

# Pick one random sample from the dataset and draw its annotated boxes.
image, target = random.choice(data)
# The dataset returns the image channel-first, so this prints (C, H, W).
print(image.shape)
show_image_with_objects(image, target['boxes'], target['labels'])

In [None]:
# Inference only: switch to evaluation mode and skip gradient tracking.
model.eval()
with torch.no_grad():
    # A leading batch axis is added because a single image is passed.
    # NOTE(review): torch.Tensor(image) keeps the raw pixel values of the array
    # (presumably 0-255), while torchvision detection models document inputs
    # in the 0-1 range - confirm this matches how the checkpoint was trained.
    outputs = model.predict(torch.Tensor(image)[None])

In [None]:
# Raw model output for the batch; the 'boxes' and 'labels' entries of its first
# element are drawn in the next cell.
outputs

In [None]:
# Draw the detections predicted for the sampled image (first item of the batch).
show_image_with_objects(image, bboxes=outputs[0]['boxes'], labels=outputs[0]['labels'])